zeroclaw/src/auth/gemini_oauth.rs

600 lines
20 KiB
Rust

//! Google/Gemini OAuth2 authentication flow.
//!
//! Supports:
//! - Authorization code flow with PKCE (loopback redirect)
//! - Device code flow for headless environments
//!
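//! Client credentials are read from the `GEMINI_OAUTH_CLIENT_ID` and
//! `GEMINI_OAUTH_CLIENT_SECRET` environment variables; supply the same credentials
//! Gemini CLI uses for compatibility with it.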
use crate::auth::oauth_common::{parse_query_params, url_decode, url_encode};
use crate::auth::profiles::TokenSet;
use anyhow::{Context, Result};
use base64::Engine;
use chrono::Utc;
use reqwest::Client;
use serde::Deserialize;
use std::collections::BTreeMap;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
// Re-export for external use (used by main.rs)
#[allow(unused_imports)]
pub use crate::auth::oauth_common::{generate_pkce_state, PkceState};
/// Read the Gemini OAuth client ID from the `GEMINI_OAUTH_CLIENT_ID` environment variable.
///
/// Returns `None` when the variable is unset, empty, or not valid Unicode.
pub fn gemini_oauth_client_id() -> Option<String> {
    match std::env::var("GEMINI_OAUTH_CLIENT_ID") {
        Ok(value) if !value.is_empty() => Some(value),
        _ => None,
    }
}
/// Read the Gemini OAuth client secret from the `GEMINI_OAUTH_CLIENT_SECRET` environment
/// variable.
///
/// Returns `None` when the variable is unset, empty, or not valid Unicode.
pub fn gemini_oauth_client_secret() -> Option<String> {
    let secret = std::env::var("GEMINI_OAUTH_CLIENT_SECRET").ok()?;
    if secret.is_empty() {
        None
    } else {
        Some(secret)
    }
}
/// Fetch both OAuth credentials as `(client_id, client_secret)`.
///
/// # Errors
/// Names the first missing variable, checking the client ID before the secret.
fn get_oauth_credentials() -> Result<(String, String)> {
    match (gemini_oauth_client_id(), gemini_oauth_client_secret()) {
        (Some(id), Some(secret)) => Ok((id, secret)),
        (None, _) => Err(anyhow::anyhow!(
            "GEMINI_OAUTH_CLIENT_ID environment variable is required"
        )),
        (Some(_), None) => Err(anyhow::anyhow!(
            "GEMINI_OAUTH_CLIENT_SECRET environment variable is required"
        )),
    }
}
/// Google authorization endpoint the user's browser is sent to (see `build_authorize_url`).
pub const GOOGLE_OAUTH_AUTHORIZE_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth";
/// Token endpoint used for code exchange, token refresh, and device-code polling.
pub const GOOGLE_OAUTH_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
/// Endpoint that issues the device/user code pair for the headless flow.
pub const GOOGLE_OAUTH_DEVICE_CODE_URL: &str = "https://oauth2.googleapis.com/device/code";
/// Loopback redirect URI. Its port (1456) and path (`/auth/callback`) must match the
/// listener bound in `receive_loopback_code`, and the value must be identical in the
/// authorize request and the code exchange.
pub const GEMINI_OAUTH_REDIRECT_URI: &str = "http://localhost:1456/auth/callback";
/// Scopes required for Gemini API access.
pub const GEMINI_OAUTH_SCOPES: &str =
"openid profile email https://www.googleapis.com/auth/cloud-platform";
/// Result of starting the device-code flow (`start_device_code_flow`).
///
/// Holds what the user needs to authorize (where to go, which code to enter) and what
/// `poll_device_code_tokens` needs to wait for the result.
#[derive(Debug, Clone)]
pub struct DeviceCodeStart {
/// Opaque code sent back to the token endpoint while polling.
pub device_code: String,
/// Short code the user types on the verification page.
pub user_code: String,
/// Page where the user enters `user_code` (Google's `verification_url`).
pub verification_uri: String,
/// Verification link with the user code pre-filled as a query parameter, when one was built.
pub verification_uri_complete: Option<String>,
/// Lifetime of the device code in seconds; polling gives up after this long.
pub expires_in: u64,
/// Suggested polling interval in seconds (polling never goes below 5 seconds).
pub interval: u64,
}
/// Successful token-endpoint response body, shared by code exchange, refresh, and
/// device-code polling.
#[derive(Debug, Deserialize)]
struct TokenResponse {
access_token: String,
/// Optional; the server may omit it (for example on a refresh).
#[serde(default)]
refresh_token: Option<String>,
/// JWT ID token; requested via the `openid` scope in `GEMINI_OAUTH_SCOPES`.
#[serde(default)]
id_token: Option<String>,
/// Access-token lifetime in seconds; callers turn it into an absolute `expires_at`.
#[serde(default)]
expires_in: Option<i64>,
/// Callers fall back to `"Bearer"` when this is missing.
#[serde(default)]
token_type: Option<String>,
/// Space-separated granted scopes, if echoed by the server.
#[serde(default)]
scope: Option<String>,
}
/// Body returned by `GOOGLE_OAUTH_DEVICE_CODE_URL`.
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
device_code: String,
user_code: String,
/// Google's field name; RFC 8628 calls this `verification_uri`.
verification_url: String,
/// Device-code lifetime in seconds; `start_device_code_flow` defaults it to 1800.
#[serde(default)]
expires_in: Option<u64>,
/// Minimum polling interval in seconds; `start_device_code_flow` defaults it to 5.
#[serde(default)]
interval: Option<u64>,
}
/// Error body returned by Google's OAuth endpoints, e.g. `{"error": "authorization_pending"}`.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
/// Machine-readable error code (`authorization_pending`, `slow_down`, `access_denied`, ...).
error: String,
/// Optional human-readable explanation.
#[serde(default)]
error_description: Option<String>,
}
/// Build the Google consent-page URL for the PKCE authorization-code flow.
///
/// Requests offline access with `prompt=consent` so Google returns a refresh token. The
/// query parameters are sorted by key, so the URL is deterministic for a given `pkce`.
///
/// # Errors
/// Fails if the OAuth client credentials are not configured.
pub fn build_authorize_url(pkce: &PkceState) -> Result<String> {
    let (client_id, _client_secret) = get_oauth_credentials()?;
    let params: BTreeMap<&str, &str> = BTreeMap::from([
        ("response_type", "code"),
        ("client_id", client_id.as_str()),
        ("redirect_uri", GEMINI_OAUTH_REDIRECT_URI),
        ("scope", GEMINI_OAUTH_SCOPES),
        ("code_challenge", pkce.code_challenge.as_str()),
        ("code_challenge_method", "S256"),
        ("state", pkce.state.as_str()),
        ("access_type", "offline"),
        ("prompt", "consent"),
    ]);
    let query = params
        .into_iter()
        .map(|(key, value)| format!("{}={}", url_encode(key), url_encode(value)))
        .collect::<Vec<_>>()
        .join("&");
    Ok(format!("{GOOGLE_OAUTH_AUTHORIZE_URL}?{query}"))
}
pub async fn exchange_code_for_tokens(
client: &Client,
code: &str,
pkce: &PkceState,
) -> Result<TokenSet> {
let (client_id, client_secret) = get_oauth_credentials()?;
let form = [
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", GEMINI_OAUTH_REDIRECT_URI),
("client_id", client_id.as_str()),
("client_secret", client_secret.as_str()),
("code_verifier", &pkce.code_verifier),
];
let response = client
.post(GOOGLE_OAUTH_TOKEN_URL)
.form(&form)
.send()
.await
.context("Failed to send token exchange request")?;
let status = response.status();
let body = response
.text()
.await
.context("Failed to read token response body")?;
if !status.is_success() {
if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&body) {
anyhow::bail!(
"Google OAuth error: {} - {}",
err.error,
err.error_description.unwrap_or_default()
);
}
anyhow::bail!("Google OAuth token exchange failed ({}): {}", status, body);
}
let token_response: TokenResponse =
serde_json::from_str(&body).context("Failed to parse token response")?;
let expires_at = token_response
.expires_in
.map(|secs| Utc::now() + chrono::Duration::seconds(secs));
Ok(TokenSet {
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
id_token: token_response.id_token,
expires_at,
token_type: token_response.token_type.or_else(|| Some("Bearer".into())),
scope: token_response.scope,
})
}
/// Exchange a refresh token for a fresh access token.
///
/// Google normally omits `refresh_token` from refresh responses. RFC 6749 §6 says the
/// existing refresh token then stays valid, so it is carried over into the returned
/// [`TokenSet`] instead of being dropped. A rotated token sent by the server takes
/// precedence. `expires_in` becomes an absolute `expires_at`, and `token_type` defaults to
/// `"Bearer"`.
///
/// # Errors
/// Missing credentials, transport failures, a non-2xx response (using Google's
/// `error`/`error_description` when the body has them), or an unparseable success body.
pub async fn refresh_access_token(client: &Client, refresh_token: &str) -> Result<TokenSet> {
    let (client_id, client_secret) = get_oauth_credentials()?;
    let form = [
        ("grant_type", "refresh_token"),
        ("refresh_token", refresh_token),
        ("client_id", client_id.as_str()),
        ("client_secret", client_secret.as_str()),
    ];
    let response = client
        .post(GOOGLE_OAUTH_TOKEN_URL)
        .form(&form)
        .send()
        .await
        .context("Failed to send refresh token request")?;
    let status = response.status();
    let body = response
        .text()
        .await
        .context("Failed to read refresh response body")?;
    if !status.is_success() {
        if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&body) {
            anyhow::bail!(
                "Google OAuth refresh error: {} - {}",
                err.error,
                err.error_description.unwrap_or_default()
            );
        }
        anyhow::bail!("Google OAuth refresh failed ({}): {}", status, body);
    }
    let token_response: TokenResponse =
        serde_json::from_str(&body).context("Failed to parse refresh response")?;
    let expires_at = token_response
        .expires_in
        .map(|secs| Utc::now() + chrono::Duration::seconds(secs));
    Ok(TokenSet {
        access_token: token_response.access_token,
        // Keep the caller's refresh token unless the server rotated it; otherwise a
        // persisted profile would lose its only way to refresh again.
        refresh_token: token_response
            .refresh_token
            .or_else(|| Some(refresh_token.to_string())),
        id_token: token_response.id_token,
        expires_at,
        token_type: token_response.token_type.or_else(|| Some("Bearer".into())),
        scope: token_response.scope,
    })
}
pub async fn start_device_code_flow(client: &Client) -> Result<DeviceCodeStart> {
let (client_id, _) = get_oauth_credentials()?;
let form = [
("client_id", client_id.as_str()),
("scope", GEMINI_OAUTH_SCOPES),
];
let response = client
.post(GOOGLE_OAUTH_DEVICE_CODE_URL)
.form(&form)
.send()
.await
.context("Failed to start device code flow")?;
let status = response.status();
let body = response
.text()
.await
.context("Failed to read device code response")?;
if !status.is_success() {
if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&body) {
anyhow::bail!(
"Google device code error: {} - {}",
err.error,
err.error_description.unwrap_or_default()
);
}
anyhow::bail!("Google device code request failed ({}): {}", status, body);
}
let device_response: DeviceCodeResponse =
serde_json::from_str(&body).context("Failed to parse device code response")?;
let user_code = device_response.user_code;
let verification_url = device_response.verification_url;
Ok(DeviceCodeStart {
device_code: device_response.device_code,
verification_uri_complete: Some(format!("{}?user_code={}", &verification_url, &user_code)),
user_code,
verification_uri: verification_url,
expires_in: device_response.expires_in.unwrap_or(1800),
interval: device_response.interval.unwrap_or(5),
})
}
pub async fn poll_device_code_tokens(
client: &Client,
device: &DeviceCodeStart,
) -> Result<TokenSet> {
let (client_id, client_secret) = get_oauth_credentials()?;
let deadline = std::time::Instant::now() + Duration::from_secs(device.expires_in);
let interval = Duration::from_secs(device.interval.max(5));
loop {
if std::time::Instant::now() > deadline {
anyhow::bail!("Device code expired before authorization was completed");
}
tokio::time::sleep(interval).await;
let form = [
("client_id", client_id.as_str()),
("client_secret", client_secret.as_str()),
("device_code", device.device_code.as_str()),
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
];
let response = client
.post(GOOGLE_OAUTH_TOKEN_URL)
.form(&form)
.send()
.await
.context("Failed to poll device code")?;
let status = response.status();
let body = response.text().await.unwrap_or_default();
if status.is_success() {
let token_response: TokenResponse =
serde_json::from_str(&body).context("Failed to parse token response")?;
let expires_at = token_response
.expires_in
.map(|secs| Utc::now() + chrono::Duration::seconds(secs));
return Ok(TokenSet {
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
id_token: token_response.id_token,
expires_at,
token_type: token_response.token_type.or_else(|| Some("Bearer".into())),
scope: token_response.scope,
});
}
if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&body) {
match err.error.as_str() {
"authorization_pending" => {}
"slow_down" => {
tokio::time::sleep(Duration::from_secs(5)).await;
}
"access_denied" => {
anyhow::bail!("User denied authorization");
}
"expired_token" => {
anyhow::bail!("Device code expired");
}
_ => {
anyhow::bail!(
"Google OAuth error: {} - {}",
err.error,
err.error_description.unwrap_or_default()
);
}
}
}
}
}
/// Receive OAuth code via loopback callback OR manual stdin input.
///
/// If the callback server can't receive the redirect (e.g., remote/headless environment),
/// the user can paste the full callback URL or just the code.
pub async fn receive_loopback_code(expected_state: &str, timeout: Duration) -> Result<String> {
// Try to bind to the callback port
let listener = match TcpListener::bind("127.0.0.1:1456").await {
Ok(l) => l,
Err(e) => {
eprintln!("Could not bind to localhost:1456: {e}");
eprintln!("Falling back to manual input.");
return receive_code_from_stdin(expected_state).await;
}
};
println!("Waiting for callback at http://localhost:1456/auth/callback ...");
println!("(Or paste the full callback URL / authorization code here if running remotely)");
// Race between: callback arriving OR stdin input
tokio::select! {
accept_result = async {
tokio::time::timeout(timeout, listener.accept()).await
} => {
match accept_result {
Ok(Ok((mut stream, _))) => {
let mut buffer = vec![0u8; 4096];
let n = stream
.read(&mut buffer)
.await
.context("Failed to read from callback connection")?;
let request = String::from_utf8_lossy(&buffer[..n]);
let (code, state) = parse_callback_request(&request)?;
if state != expected_state {
let response = "HTTP/1.1 400 Bad Request\r\nContent-Type: text/html\r\n\r\n\
<html><body><h1>State mismatch</h1><p>Please try again.</p></body></html>";
let _ = stream.write_all(response.as_bytes()).await;
anyhow::bail!("OAuth state mismatch");
}
let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\
<html><body><h1>Success!</h1><p>You can close this window and return to the terminal.</p></body></html>";
let _ = stream.write_all(response.as_bytes()).await;
Ok(code)
}
Ok(Err(e)) => Err(anyhow::anyhow!("Failed to accept connection: {e}")),
Err(_) => {
eprintln!("\nCallback timeout. Falling back to manual input.");
receive_code_from_stdin(expected_state).await
}
}
}
stdin_result = receive_code_from_stdin(expected_state) => {
stdin_result
}
}
}
/// Read one line from stdin and interpret it as a callback URL, query string, or raw code.
///
/// `std::io::Stdin` has no async API, so the read runs on a blocking worker thread. Read
/// errors are ignored and reported as "No input received"; the text is then validated by
/// `parse_code_from_redirect` against `expected_state`.
async fn receive_code_from_stdin(expected_state: &str) -> Result<String> {
    use std::io::BufRead;

    let state = expected_state.to_owned();
    let reader = tokio::task::spawn_blocking(move || -> Result<String> {
        let mut raw = String::new();
        let _ = std::io::stdin().lock().read_line(&mut raw);
        match raw.trim() {
            "" => Err(anyhow::anyhow!("No input received")),
            entered => parse_code_from_redirect(entered, Some(state.as_str())),
        }
    });
    reader.await.context("Failed to read from stdin")?
}
fn parse_callback_request(request: &str) -> Result<(String, String)> {
let first_line = request.lines().next().unwrap_or("");
let path = first_line
.split_whitespace()
.nth(1)
.unwrap_or("")
.to_string();
let query_start = path.find('?').map(|i| i + 1).unwrap_or(path.len());
let query = &path[query_start..];
let mut code = None;
let mut state = None;
for pair in query.split('&') {
if let Some((key, value)) = pair.split_once('=') {
match key {
"code" => code = Some(url_decode(value)),
"state" => state = Some(url_decode(value)),
_ => {}
}
}
}
let code = code.ok_or_else(|| anyhow::anyhow!("No 'code' parameter in callback"))?;
let state = state.ok_or_else(|| anyhow::anyhow!("No 'state' parameter in callback"))?;
Ok((code, state))
}
pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Result<String> {
let trimmed = input.trim();
if trimmed.is_empty() {
anyhow::bail!("No OAuth code provided");
}
// Extract query string
let query = if let Some((_, right)) = trimmed.split_once('?') {
right
} else {
trimmed
};
let params = parse_query_params(query);
// If we have code param, extract it
if let Some(code) = params.get("code") {
// Validate state if expected
if let Some(expected) = expected_state {
if let Some(actual) = params.get("state") {
if actual != expected {
anyhow::bail!("OAuth state mismatch: expected {expected}, got {actual}");
}
}
}
return Ok(code.clone());
}
// Otherwise, assume it's the raw code (if long enough and no spaces)
if trimmed.len() > 10 && !trimmed.contains(' ') && !trimmed.contains('&') {
return Ok(trimmed.to_string());
}
anyhow::bail!("Could not parse OAuth code from input")
}
/// Extract the `email` claim from a Google ID token (JWT) **without verifying it**.
///
/// Only the payload (the second `.`-separated segment) is base64url-decoded and parsed. The
/// signature is not checked, so use the result for display/labelling only, never for
/// authorization decisions. Trailing `=` padding on the payload is tolerated.
///
/// Returns `None` in any of these cases:
/// - the token does not have exactly three segments;
/// - the payload is not valid base64url or JSON;
/// - the payload has no `email` claim.
pub fn extract_account_email_from_id_token(id_token: &str) -> Option<String> {
    #[derive(Deserialize)]
    struct IdTokenPayload {
        email: Option<String>,
    }

    let mut segments = id_token.split('.');
    let payload_b64 = match (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) {
        (Some(_header), Some(payload), Some(_signature), None) => payload,
        _ => return None,
    };
    // `URL_SAFE_NO_PAD` rejects padding, but some JWT producers keep it; strip it first.
    let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(payload_b64.trim_end_matches('='))
        .ok()?;
    serde_json::from_slice::<IdTokenPayload>(&payload).ok()?.email
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets an environment variable for the lifetime of the guard, then restores the
    /// previous value (or removes the variable) on drop.
    ///
    /// NOTE: environment variables are process-global; tests running concurrently can still
    /// observe the temporary value while the guard is alive.
    struct EnvVarRestore {
        key: &'static str,
        original: Option<String>,
    }

    impl EnvVarRestore {
        fn set(key: &'static str, value: &str) -> Self {
            let original = std::env::var(key).ok();
            std::env::set_var(key, value);
            Self { key, original }
        }
    }

    impl Drop for EnvVarRestore {
        fn drop(&mut self) {
            if let Some(ref original) = self.original {
                std::env::set_var(self.key, original);
            } else {
                std::env::remove_var(self.key);
            }
        }
    }

    #[test]
    fn pkce_generates_valid_state() {
        let pkce = generate_pkce_state();
        assert!(!pkce.code_verifier.is_empty());
        assert!(!pkce.code_challenge.is_empty());
        assert!(!pkce.state.is_empty());
    }

    #[test]
    fn authorize_url_contains_required_params() {
        // Restore the environment afterwards so this test does not leak into later tests.
        let _client_id_guard = EnvVarRestore::set("GEMINI_OAUTH_CLIENT_ID", "test-client-id");
        let _client_secret_guard =
            EnvVarRestore::set("GEMINI_OAUTH_CLIENT_SECRET", "test-client-secret");
        let pkce = generate_pkce_state();
        let url = build_authorize_url(&pkce).expect("Failed to build authorize URL");
        assert!(url.contains("accounts.google.com"));
        assert!(url.contains("client_id="));
        assert!(url.contains("redirect_uri="));
        assert!(url.contains("code_challenge="));
        assert!(url.contains("access_type=offline"));
    }

    #[test]
    fn parse_code_from_url() {
        let url = "http://localhost:1456/auth/callback?code=4/0test&state=xyz";
        let code = parse_code_from_redirect(url, Some("xyz")).unwrap();
        assert_eq!(code, "4/0test");
    }

    #[test]
    fn parse_code_from_url_rejects_state_mismatch() {
        let url = "http://localhost:1456/auth/callback?code=4/0test&state=other";
        assert!(parse_code_from_redirect(url, Some("xyz")).is_err());
    }

    #[test]
    fn parse_code_rejects_blank_input() {
        assert!(parse_code_from_redirect("   ", None).is_err());
    }

    #[test]
    fn parse_code_from_raw() {
        let raw = "4/0AcvDMrC1234567890abcdef";
        let code = parse_code_from_redirect(raw, None).unwrap();
        assert_eq!(code, raw);
    }

    #[test]
    fn parse_callback_request_extracts_code_and_state() {
        let request =
            "GET /auth/callback?code=abc123&state=xyz HTTP/1.1\r\nHost: localhost:1456\r\n\r\n";
        let (code, state) = parse_callback_request(request).unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "xyz");
    }

    #[test]
    fn parse_callback_request_requires_code() {
        let request = "GET /auth/callback?state=xyz HTTP/1.1\r\n\r\n";
        assert!(parse_callback_request(request).is_err());
    }

    #[test]
    fn extract_email_from_id_token() {
        // Minimal test JWT with email claim
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"email":"test@example.com"}"#);
        let token = format!("{}.{}.signature", header, payload);
        let email = extract_account_email_from_id_token(&token);
        assert_eq!(email, Some("test@example.com".to_string()));
    }

    #[test]
    fn extract_email_handles_malformed_tokens() {
        assert_eq!(extract_account_email_from_id_token("not-a-jwt"), None);
        assert_eq!(extract_account_email_from_id_token("a.!!!.c"), None);
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"sub":"123"}"#);
        let token = format!("{}.{}.signature", header, payload);
        assert_eq!(extract_account_email_from_id_token(&token), None);
    }
}