2274 lines
80 KiB
Rust
2274 lines
80 KiB
Rust
//! Google Gemini provider with support for:
|
|
//! - Direct API key (`GEMINI_API_KEY` env var or config)
|
|
//! - Gemini CLI OAuth tokens (reuse existing ~/.gemini/ authentication)
|
|
//! - ZeroClaw auth-profiles OAuth tokens
|
|
//! - Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`)
|
|
|
|
use crate::auth::AuthService;
|
|
use crate::multimodal;
|
|
use crate::providers::traits::{ChatMessage, ChatResponse, Provider, TokenUsage};
|
|
use async_trait::async_trait;
|
|
use base64::Engine;
|
|
use directories::UserDirs;
|
|
use reqwest::Client;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
|
|
/// Gemini provider supporting multiple authentication methods.
|
|
pub struct GeminiProvider {
|
|
auth: Option<GeminiAuth>,
|
|
oauth_project: Arc<tokio::sync::Mutex<Option<String>>>,
|
|
oauth_cred_paths: Vec<PathBuf>,
|
|
oauth_index: Arc<tokio::sync::Mutex<usize>>,
|
|
/// AuthService for managed profiles (auth-profiles.json).
|
|
auth_service: Option<AuthService>,
|
|
/// Override profile name for managed auth.
|
|
auth_profile_override: Option<String>,
|
|
}
|
|
|
|
/// Mutable OAuth token state — supports runtime refresh for long-lived processes.
|
|
struct OAuthTokenState {
|
|
access_token: String,
|
|
refresh_token: Option<String>,
|
|
client_id: Option<String>,
|
|
client_secret: Option<String>,
|
|
/// Expiry as unix millis. `None` means unknown (treat as potentially expired).
|
|
expiry_millis: Option<i64>,
|
|
}
|
|
|
|
/// Resolved credential — the variant determines both the HTTP auth method
|
|
/// and the diagnostic label returned by `auth_source()`.
|
|
enum GeminiAuth {
|
|
/// Explicit API key from config: sent as `?key=` query parameter.
|
|
ExplicitKey(String),
|
|
/// API key from `GEMINI_API_KEY` env var: sent as `?key=`.
|
|
EnvGeminiKey(String),
|
|
/// API key from `GOOGLE_API_KEY` env var: sent as `?key=`.
|
|
EnvGoogleKey(String),
|
|
/// OAuth access token from Gemini CLI: sent as `Authorization: Bearer`.
|
|
/// Wrapped in a Mutex to allow runtime token refresh.
|
|
OAuthToken(Arc<tokio::sync::Mutex<OAuthTokenState>>),
|
|
/// OAuth token managed by AuthService (auth-profiles.json).
|
|
/// Token refresh is handled by AuthService, not here.
|
|
ManagedOAuth,
|
|
}
|
|
|
|
impl GeminiAuth {
|
|
/// Whether this credential is an API key (sent as `?key=` query param).
|
|
fn is_api_key(&self) -> bool {
|
|
matches!(
|
|
self,
|
|
GeminiAuth::ExplicitKey(_) | GeminiAuth::EnvGeminiKey(_) | GeminiAuth::EnvGoogleKey(_)
|
|
)
|
|
}
|
|
|
|
/// Whether this credential is an OAuth token (CLI or managed).
|
|
fn is_oauth(&self) -> bool {
|
|
matches!(self, GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth)
|
|
}
|
|
|
|
/// The raw credential string (for API key variants only).
|
|
fn api_key_credential(&self) -> &str {
|
|
match self {
|
|
GeminiAuth::ExplicitKey(s)
|
|
| GeminiAuth::EnvGeminiKey(s)
|
|
| GeminiAuth::EnvGoogleKey(s) => s,
|
|
GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => "",
|
|
}
|
|
}
|
|
}
|
|
|
|
// ══════════════════════════════════════════════════════════════════════════════
|
|
// API REQUEST/RESPONSE TYPES
|
|
// ══════════════════════════════════════════════════════════════════════════════
|
|
|
|
#[derive(Debug, Serialize, Clone)]
|
|
struct GenerateContentRequest {
|
|
contents: Vec<Content>,
|
|
#[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
|
|
system_instruction: Option<Content>,
|
|
#[serde(rename = "generationConfig")]
|
|
generation_config: GenerationConfig,
|
|
}
|
|
|
|
/// Request envelope for the internal cloudcode-pa API.
|
|
/// OAuth tokens from Gemini CLI are scoped for this endpoint.
|
|
///
|
|
/// The internal API expects a nested structure:
|
|
/// ```json
|
|
/// {
|
|
/// "model": "models/gemini-...",
|
|
/// "project": "...",
|
|
/// "request": {
|
|
/// "contents": [...],
|
|
/// "systemInstruction": {...},
|
|
/// "generationConfig": {...}
|
|
/// }
|
|
/// }
|
|
/// ```
|
|
/// Ref: gemini-cli `packages/core/src/code_assist/converter.ts`
|
|
#[derive(Debug, Serialize)]
|
|
struct InternalGenerateContentEnvelope {
|
|
model: String,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
project: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
user_prompt_id: Option<String>,
|
|
request: InternalGenerateContentRequest,
|
|
}
|
|
|
|
/// Nested request payload for cloudcode-pa's code assist APIs.
|
|
#[derive(Debug, Serialize)]
|
|
struct InternalGenerateContentRequest {
|
|
contents: Vec<Content>,
|
|
#[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
|
|
system_instruction: Option<Content>,
|
|
#[serde(rename = "generationConfig", skip_serializing_if = "Option::is_none")]
|
|
generation_config: Option<GenerationConfig>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Clone)]
|
|
struct Content {
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
role: Option<String>,
|
|
parts: Vec<Part>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Clone)]
|
|
#[serde(untagged)]
|
|
enum Part {
|
|
Text {
|
|
text: String,
|
|
},
|
|
InlineData {
|
|
#[serde(rename = "inlineData")]
|
|
inline_data: InlineDataPart,
|
|
},
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Clone)]
|
|
struct InlineDataPart {
|
|
#[serde(rename = "mimeType")]
|
|
mime_type: String,
|
|
data: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Clone)]
|
|
struct GenerationConfig {
|
|
temperature: f64,
|
|
#[serde(rename = "maxOutputTokens")]
|
|
max_output_tokens: u32,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct GenerateContentResponse {
|
|
candidates: Option<Vec<Candidate>>,
|
|
error: Option<ApiError>,
|
|
#[serde(default)]
|
|
response: Option<Box<GenerateContentResponse>>,
|
|
#[serde(default, rename = "usageMetadata")]
|
|
usage_metadata: Option<GeminiUsageMetadata>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct GeminiUsageMetadata {
|
|
#[serde(default, rename = "promptTokenCount")]
|
|
prompt_token_count: Option<u64>,
|
|
#[serde(default, rename = "candidatesTokenCount")]
|
|
candidates_token_count: Option<u64>,
|
|
}
|
|
|
|
/// Response envelope for the internal cloudcode-pa API.
|
|
/// The internal API nests the standard response under a `response` field.
|
|
#[derive(Debug, Deserialize)]
|
|
struct InternalGenerateContentResponse {
|
|
response: GenerateContentResponse,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct Candidate {
|
|
#[serde(default)]
|
|
content: Option<CandidateContent>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct CandidateContent {
|
|
parts: Vec<ResponsePart>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ResponsePart {
|
|
#[serde(default)]
|
|
text: Option<String>,
|
|
/// Thinking models (e.g. gemini-3-pro-preview) mark reasoning parts with `thought: true`.
|
|
#[serde(default)]
|
|
thought: bool,
|
|
}
|
|
|
|
impl CandidateContent {
|
|
/// Extract effective text, skipping thinking/signature parts.
|
|
///
|
|
/// Gemini thinking models (e.g. gemini-3-pro-preview) return parts like:
|
|
/// - `{"thought": true, "text": "reasoning..."}` — internal reasoning
|
|
/// - `{"text": "actual answer"}` — the real response
|
|
/// - `{"thoughtSignature": "..."}` — opaque signature (no text field)
|
|
///
|
|
/// Returns the non-thinking text, falling back to thinking text only when
|
|
/// no non-thinking content is available.
|
|
fn effective_text(self) -> Option<String> {
|
|
let mut answer_parts: Vec<String> = Vec::new();
|
|
let mut first_thinking: Option<String> = None;
|
|
|
|
for part in self.parts {
|
|
if let Some(text) = part.text {
|
|
if text.is_empty() {
|
|
continue;
|
|
}
|
|
if !part.thought {
|
|
answer_parts.push(text);
|
|
} else if first_thinking.is_none() {
|
|
first_thinking = Some(text);
|
|
}
|
|
}
|
|
}
|
|
|
|
if answer_parts.is_empty() {
|
|
first_thinking
|
|
} else {
|
|
Some(answer_parts.join(""))
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ApiError {
|
|
message: String,
|
|
}
|
|
|
|
impl GenerateContentResponse {
|
|
/// cloudcode-pa wraps the actual response under `response`.
|
|
fn into_effective_response(self) -> Self {
|
|
match self {
|
|
Self {
|
|
response: Some(inner),
|
|
..
|
|
} => *inner,
|
|
other => other,
|
|
}
|
|
}
|
|
}
|
|
|
|
// ══════════════════════════════════════════════════════════════════════════════
|
|
// GEMINI CLI TOKEN STRUCTURES
|
|
// ══════════════════════════════════════════════════════════════════════════════
|
|
|
|
/// OAuth token stored by Gemini CLI in `~/.gemini/oauth_creds.json`
|
|
#[derive(Debug, Deserialize)]
|
|
struct GeminiCliOAuthCreds {
|
|
access_token: Option<String>,
|
|
#[serde(alias = "idToken")]
|
|
id_token: Option<String>,
|
|
refresh_token: Option<String>,
|
|
#[serde(alias = "clientId")]
|
|
client_id: Option<String>,
|
|
#[serde(alias = "clientSecret")]
|
|
client_secret: Option<String>,
|
|
/// Unix milliseconds expiry (used by newer Gemini CLI versions).
|
|
#[serde(alias = "expiryDate")]
|
|
expiry_date: Option<i64>,
|
|
/// RFC 3339 expiry string (used by older Gemini CLI versions).
|
|
expiry: Option<String>,
|
|
}
|
|
|
|
// ══════════════════════════════════════════════════════════════════════════════
|
|
// GEMINI CLI OAUTH CONSTANTS
|
|
// ══════════════════════════════════════════════════════════════════════════════
|
|
|
|
/// Google OAuth token endpoint.
|
|
const GOOGLE_TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
|
|
|
|
/// Internal API endpoint used by Gemini CLI for OAuth users.
|
|
/// See: https://github.com/google-gemini/gemini-cli/issues/19200
|
|
const CLOUDCODE_PA_ENDPOINT: &str = "https://cloudcode-pa.googleapis.com/v1internal";
|
|
|
|
/// loadCodeAssist endpoint for resolving the project ID.
|
|
const LOAD_CODE_ASSIST_ENDPOINT: &str =
|
|
"https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist";
|
|
|
|
/// Public API endpoint for API key users.
|
|
const PUBLIC_API_ENDPOINT: &str = "https://generativelanguage.googleapis.com/v1beta";
|
|
|
|
// ══════════════════════════════════════════════════════════════════════════════
|
|
// TOKEN REFRESH
|
|
// ══════════════════════════════════════════════════════════════════════════════
|
|
|
|
/// Result of a successful token refresh.
|
|
struct RefreshedToken {
|
|
access_token: String,
|
|
/// Expiry as unix millis (computed from `expires_in` seconds in the response).
|
|
expiry_millis: Option<i64>,
|
|
}
|
|
|
|
/// Refresh an expired Gemini CLI OAuth token using the refresh_token grant.
|
|
///
|
|
/// Client credentials are optional and can be sourced from:
|
|
/// - `oauth_creds.json` if present
|
|
/// - `GEMINI_OAUTH_CLIENT_ID` / `GEMINI_OAUTH_CLIENT_SECRET` env vars
|
|
fn refresh_gemini_cli_token(
|
|
refresh_token: &str,
|
|
client_id: Option<&str>,
|
|
client_secret: Option<&str>,
|
|
) -> anyhow::Result<RefreshedToken> {
|
|
let client = reqwest::blocking::Client::builder()
|
|
.timeout(std::time::Duration::from_secs(15))
|
|
.connect_timeout(std::time::Duration::from_secs(5))
|
|
.build()
|
|
.unwrap_or_else(|_| reqwest::blocking::Client::new());
|
|
|
|
let form = build_oauth_refresh_form(refresh_token, client_id, client_secret);
|
|
|
|
let response = client
|
|
.post(GOOGLE_TOKEN_ENDPOINT)
|
|
.header("Content-Type", "application/x-www-form-urlencoded")
|
|
.header("Accept", "application/json")
|
|
.form(&form)
|
|
.send()
|
|
.map_err(|error| anyhow::anyhow!("Gemini CLI OAuth refresh request failed: {error}"))?;
|
|
|
|
let status = response.status();
|
|
let body = response
|
|
.text()
|
|
.unwrap_or_else(|_| "<failed to read response body>".to_string());
|
|
|
|
if !status.is_success() {
|
|
let sanitized = super::sanitize_api_error(&body);
|
|
anyhow::bail!("Gemini CLI OAuth refresh failed (HTTP {status}): {sanitized}");
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct TokenResponse {
|
|
access_token: Option<String>,
|
|
expires_in: Option<i64>,
|
|
}
|
|
|
|
let parsed: TokenResponse = serde_json::from_str(&body)
|
|
.map_err(|_| anyhow::anyhow!("Gemini CLI OAuth refresh response is not valid JSON"))?;
|
|
|
|
let access_token = parsed
|
|
.access_token
|
|
.filter(|t| !t.trim().is_empty())
|
|
.ok_or_else(|| anyhow::anyhow!("Gemini CLI OAuth refresh response missing access_token"))?;
|
|
|
|
let expiry_millis = parsed.expires_in.and_then(|secs| {
|
|
let now_millis = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.ok()
|
|
.and_then(|d| i64::try_from(d.as_millis()).ok())?;
|
|
now_millis.checked_add(secs.checked_mul(1000)?)
|
|
});
|
|
|
|
Ok(RefreshedToken {
|
|
access_token,
|
|
expiry_millis,
|
|
})
|
|
}
|
|
|
|
fn build_oauth_refresh_form(
|
|
refresh_token: &str,
|
|
client_id: Option<&str>,
|
|
client_secret: Option<&str>,
|
|
) -> Vec<(&'static str, String)> {
|
|
let mut form = vec![
|
|
("grant_type", "refresh_token".to_string()),
|
|
("refresh_token", refresh_token.to_string()),
|
|
];
|
|
if let Some(id) = client_id.and_then(GeminiProvider::normalize_non_empty) {
|
|
form.push(("client_id", id));
|
|
}
|
|
if let Some(secret) = client_secret.and_then(GeminiProvider::normalize_non_empty) {
|
|
form.push(("client_secret", secret));
|
|
}
|
|
form
|
|
}
|
|
|
|
fn extract_client_id_from_id_token(id_token: &str) -> Option<String> {
|
|
let payload = id_token.split('.').nth(1)?;
|
|
let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
|
|
.decode(payload)
|
|
.or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(payload))
|
|
.ok()?;
|
|
|
|
#[derive(Deserialize)]
|
|
struct IdTokenClaims {
|
|
aud: Option<String>,
|
|
azp: Option<String>,
|
|
}
|
|
|
|
let claims: IdTokenClaims = serde_json::from_slice(&decoded).ok()?;
|
|
claims
|
|
.aud
|
|
.as_deref()
|
|
.and_then(GeminiProvider::normalize_non_empty)
|
|
.or_else(|| {
|
|
claims
|
|
.azp
|
|
.as_deref()
|
|
.and_then(GeminiProvider::normalize_non_empty)
|
|
})
|
|
}
|
|
|
|
/// Async version of token refresh for use during runtime (inside tokio context).
|
|
async fn refresh_gemini_cli_token_async(
|
|
refresh_token: &str,
|
|
client_id: Option<&str>,
|
|
client_secret: Option<&str>,
|
|
) -> anyhow::Result<RefreshedToken> {
|
|
let refresh_token = refresh_token.to_string();
|
|
let client_id = client_id.map(str::to_string);
|
|
let client_secret = client_secret.map(str::to_string);
|
|
tokio::task::spawn_blocking(move || {
|
|
refresh_gemini_cli_token(
|
|
&refresh_token,
|
|
client_id.as_deref(),
|
|
client_secret.as_deref(),
|
|
)
|
|
})
|
|
.await
|
|
.map_err(|e| anyhow::anyhow!("Token refresh task panicked: {e}"))?
|
|
}
|
|
|
|
impl GeminiProvider {
|
|
/// Create a new Gemini provider.
|
|
///
|
|
/// Authentication priority:
|
|
/// 1. Explicit API key passed in
|
|
/// 2. `GEMINI_API_KEY` environment variable
|
|
/// 3. `GOOGLE_API_KEY` environment variable
|
|
/// 4. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`)
|
|
pub fn new(api_key: Option<&str>) -> Self {
|
|
let oauth_cred_paths = Self::discover_oauth_cred_paths();
|
|
let resolved_auth = api_key
|
|
.and_then(Self::normalize_non_empty)
|
|
.map(GeminiAuth::ExplicitKey)
|
|
.or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey))
|
|
.or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey))
|
|
.or_else(|| {
|
|
Self::try_load_gemini_cli_token(oauth_cred_paths.first())
|
|
.map(|state| GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(state))))
|
|
});
|
|
|
|
Self {
|
|
auth: resolved_auth,
|
|
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
|
oauth_cred_paths,
|
|
oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
|
|
auth_service: None,
|
|
auth_profile_override: None,
|
|
}
|
|
}
|
|
|
|
/// Create a new Gemini provider with managed OAuth from auth-profiles.json.
|
|
///
|
|
/// Authentication priority:
|
|
/// 1. Explicit API key passed in
|
|
/// 2. `GEMINI_API_KEY` environment variable
|
|
/// 3. `GOOGLE_API_KEY` environment variable
|
|
/// 4. Managed OAuth from auth-profiles.json (if auth_service provided)
|
|
/// 5. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`)
|
|
pub fn new_with_auth(
|
|
api_key: Option<&str>,
|
|
auth_service: AuthService,
|
|
profile_override: Option<String>,
|
|
) -> Self {
|
|
let oauth_cred_paths = Self::discover_oauth_cred_paths();
|
|
|
|
// First check API keys
|
|
let resolved_auth = api_key
|
|
.and_then(Self::normalize_non_empty)
|
|
.map(GeminiAuth::ExplicitKey)
|
|
.or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey))
|
|
.or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey));
|
|
|
|
// If no API key, we'll use managed OAuth (checked at runtime)
|
|
// or fall back to CLI OAuth
|
|
let (auth, use_managed) = if resolved_auth.is_some() {
|
|
(resolved_auth, false)
|
|
} else {
|
|
// Check if we have a managed profile - this is a blocking check
|
|
// but we need to know at construction time
|
|
let has_managed = std::thread::scope(|s| {
|
|
s.spawn(|| {
|
|
let rt = tokio::runtime::Builder::new_current_thread()
|
|
.enable_all()
|
|
.build()
|
|
.ok()?;
|
|
rt.block_on(async {
|
|
auth_service
|
|
.get_gemini_profile(profile_override.as_deref())
|
|
.await
|
|
.ok()
|
|
.flatten()
|
|
})
|
|
})
|
|
.join()
|
|
.ok()
|
|
.flatten()
|
|
.is_some()
|
|
});
|
|
|
|
if has_managed {
|
|
(Some(GeminiAuth::ManagedOAuth), true)
|
|
} else {
|
|
// Fall back to CLI OAuth
|
|
let cli_auth = Self::try_load_gemini_cli_token(oauth_cred_paths.first())
|
|
.map(|state| GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(state))));
|
|
(cli_auth, false)
|
|
}
|
|
};
|
|
|
|
Self {
|
|
auth,
|
|
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
|
oauth_cred_paths,
|
|
oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
|
|
auth_service: if use_managed {
|
|
Some(auth_service)
|
|
} else {
|
|
None
|
|
},
|
|
auth_profile_override: profile_override,
|
|
}
|
|
}
|
|
|
|
fn normalize_non_empty(value: &str) -> Option<String> {
|
|
let trimmed = value.trim();
|
|
if trimmed.is_empty() {
|
|
None
|
|
} else {
|
|
Some(trimmed.to_string())
|
|
}
|
|
}
|
|
|
|
fn load_non_empty_env(name: &str) -> Option<String> {
|
|
std::env::var(name)
|
|
.ok()
|
|
.and_then(|value| Self::normalize_non_empty(&value))
|
|
}
|
|
|
|
fn load_gemini_cli_creds(creds_path: &PathBuf) -> Option<GeminiCliOAuthCreds> {
|
|
if !creds_path.exists() {
|
|
return None;
|
|
}
|
|
let content = std::fs::read_to_string(creds_path).ok()?;
|
|
serde_json::from_str(&content).ok()
|
|
}
|
|
|
|
/// Discover all OAuth credential files from known Gemini CLI installations.
|
|
///
|
|
/// Looks in `~/.gemini/oauth_creds.json` (default) plus any
|
|
/// `~/.gemini-*-home/.gemini/oauth_creds.json` siblings.
|
|
fn discover_oauth_cred_paths() -> Vec<PathBuf> {
|
|
let home = match UserDirs::new() {
|
|
Some(u) => u.home_dir().to_path_buf(),
|
|
None => return Vec::new(),
|
|
};
|
|
|
|
let mut paths = Vec::new();
|
|
|
|
let primary = home.join(".gemini").join("oauth_creds.json");
|
|
if primary.exists() {
|
|
paths.push(primary);
|
|
}
|
|
|
|
if let Ok(entries) = std::fs::read_dir(&home) {
|
|
let mut extras: Vec<PathBuf> = entries
|
|
.filter_map(|e| e.ok())
|
|
.filter_map(|e| {
|
|
let name = e.file_name().to_string_lossy().to_string();
|
|
if name.starts_with(".gemini-") && name.ends_with("-home") {
|
|
let path = e.path().join(".gemini").join("oauth_creds.json");
|
|
if path.exists() {
|
|
return Some(path);
|
|
}
|
|
}
|
|
None
|
|
})
|
|
.collect();
|
|
extras.sort();
|
|
paths.extend(extras);
|
|
}
|
|
|
|
paths
|
|
}
|
|
|
|
/// Try to load OAuth credentials from Gemini CLI's cached credentials.
|
|
/// Location: `~/.gemini/oauth_creds.json`
|
|
///
|
|
/// Returns the full `OAuthTokenState` so the provider can refresh at runtime.
|
|
fn try_load_gemini_cli_token(path: Option<&PathBuf>) -> Option<OAuthTokenState> {
|
|
let creds = Self::load_gemini_cli_creds(path?)?;
|
|
|
|
// Determine expiry in millis: prefer expiry_date over expiry (RFC 3339)
|
|
let expiry_millis = creds.expiry_date.or_else(|| {
|
|
creds.expiry.as_deref().and_then(|expiry| {
|
|
chrono::DateTime::parse_from_rfc3339(expiry)
|
|
.ok()
|
|
.map(|dt| dt.timestamp_millis())
|
|
})
|
|
});
|
|
|
|
let access_token = creds
|
|
.access_token
|
|
.and_then(|token| Self::normalize_non_empty(&token))?;
|
|
|
|
let id_token_client_id = creds
|
|
.id_token
|
|
.as_deref()
|
|
.and_then(extract_client_id_from_id_token);
|
|
|
|
let client_id = Self::load_non_empty_env("GEMINI_OAUTH_CLIENT_ID")
|
|
.or_else(|| {
|
|
creds
|
|
.client_id
|
|
.as_deref()
|
|
.and_then(Self::normalize_non_empty)
|
|
})
|
|
.or(id_token_client_id);
|
|
let client_secret = Self::load_non_empty_env("GEMINI_OAUTH_CLIENT_SECRET").or_else(|| {
|
|
creds
|
|
.client_secret
|
|
.as_deref()
|
|
.and_then(Self::normalize_non_empty)
|
|
});
|
|
|
|
Some(OAuthTokenState {
|
|
access_token,
|
|
refresh_token: creds.refresh_token,
|
|
client_id,
|
|
client_secret,
|
|
expiry_millis,
|
|
})
|
|
}
|
|
|
|
/// Get the Gemini CLI config directory (~/.gemini)
|
|
fn gemini_cli_dir() -> Option<PathBuf> {
|
|
UserDirs::new().map(|u| u.home_dir().join(".gemini"))
|
|
}
|
|
|
|
/// Check if Gemini CLI is configured and has valid credentials
|
|
pub fn has_cli_credentials() -> bool {
|
|
Self::discover_oauth_cred_paths().iter().any(|path| {
|
|
Self::load_gemini_cli_creds(path)
|
|
.and_then(|creds| {
|
|
creds
|
|
.access_token
|
|
.as_deref()
|
|
.and_then(Self::normalize_non_empty)
|
|
})
|
|
.is_some()
|
|
})
|
|
}
|
|
|
|
/// Check if any Gemini authentication is available
|
|
pub fn has_any_auth() -> bool {
|
|
Self::load_non_empty_env("GEMINI_API_KEY").is_some()
|
|
|| Self::load_non_empty_env("GOOGLE_API_KEY").is_some()
|
|
|| Self::has_cli_credentials()
|
|
}
|
|
|
|
/// Get authentication source description for diagnostics.
|
|
/// Uses the stored enum variant — no env var re-reading at call time.
|
|
pub fn auth_source(&self) -> &'static str {
|
|
match self.auth.as_ref() {
|
|
Some(GeminiAuth::ExplicitKey(_)) => "config",
|
|
Some(GeminiAuth::EnvGeminiKey(_)) => "GEMINI_API_KEY env var",
|
|
Some(GeminiAuth::EnvGoogleKey(_)) => "GOOGLE_API_KEY env var",
|
|
Some(GeminiAuth::OAuthToken(_)) => "Gemini CLI OAuth",
|
|
Some(GeminiAuth::ManagedOAuth) => "auth-profiles",
|
|
None => "none",
|
|
}
|
|
}
|
|
|
|
/// Get a valid OAuth access token, refreshing if expired.
|
|
/// Adds a 60-second buffer before actual expiry to avoid edge-case failures.
|
|
async fn get_valid_oauth_token(
|
|
state: &Arc<tokio::sync::Mutex<OAuthTokenState>>,
|
|
) -> anyhow::Result<String> {
|
|
let mut guard = state.lock().await;
|
|
|
|
let now_millis = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.ok()
|
|
.and_then(|d| i64::try_from(d.as_millis()).ok())
|
|
.unwrap_or(i64::MAX);
|
|
|
|
// Refresh if expiry is unknown, already expired, or within 60s of expiry.
|
|
let needs_refresh = guard
|
|
.expiry_millis
|
|
.map_or(true, |exp| exp <= now_millis.saturating_add(60_000));
|
|
|
|
if needs_refresh {
|
|
if let Some(ref refresh_token) = guard.refresh_token {
|
|
let refreshed = refresh_gemini_cli_token_async(
|
|
refresh_token,
|
|
guard.client_id.as_deref(),
|
|
guard.client_secret.as_deref(),
|
|
)
|
|
.await?;
|
|
tracing::info!("Gemini CLI OAuth token refreshed successfully (runtime)");
|
|
guard.access_token = refreshed.access_token;
|
|
guard.expiry_millis = refreshed.expiry_millis;
|
|
} else {
|
|
anyhow::bail!(
|
|
"Gemini CLI OAuth token expired and no refresh_token available — re-run `gemini` to authenticate"
|
|
);
|
|
}
|
|
}
|
|
|
|
Ok(guard.access_token.clone())
|
|
}
|
|
|
|
/// Rotate to the next available OAuth credentials file and swap state.
|
|
/// Returns `true` when rotation succeeded.
|
|
async fn rotate_oauth_credential(
|
|
&self,
|
|
state: &Arc<tokio::sync::Mutex<OAuthTokenState>>,
|
|
) -> bool {
|
|
if self.oauth_cred_paths.len() <= 1 {
|
|
return false;
|
|
}
|
|
|
|
let mut idx = self.oauth_index.lock().await;
|
|
let start = *idx;
|
|
|
|
loop {
|
|
let next = (*idx + 1) % self.oauth_cred_paths.len();
|
|
*idx = next;
|
|
|
|
if next == start {
|
|
return false;
|
|
}
|
|
|
|
if let Some(next_state) =
|
|
Self::try_load_gemini_cli_token(self.oauth_cred_paths.get(next))
|
|
{
|
|
{
|
|
let mut guard = state.lock().await;
|
|
*guard = next_state;
|
|
}
|
|
{
|
|
let mut cached_project = self.oauth_project.lock().await;
|
|
*cached_project = None;
|
|
}
|
|
tracing::warn!(
|
|
"Gemini OAuth: rotated credential to {}",
|
|
self.oauth_cred_paths[next].display()
|
|
);
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
|
|
fn format_model_name(model: &str) -> String {
|
|
if model.starts_with("models/") {
|
|
model.to_string()
|
|
} else {
|
|
format!("models/{model}")
|
|
}
|
|
}
|
|
|
|
fn format_internal_model_name(model: &str) -> String {
|
|
model.strip_prefix("models/").unwrap_or(model).to_string()
|
|
}
|
|
|
|
/// Build the API URL based on auth type.
|
|
///
|
|
/// - API key users → public `generativelanguage.googleapis.com/v1beta`
|
|
/// - OAuth users → internal `cloudcode-pa.googleapis.com/v1internal`
|
|
///
|
|
/// The Gemini CLI OAuth tokens are scoped for the internal Code Assist API,
|
|
/// not the public API. Sending them to the public endpoint results in
|
|
/// "400 Bad Request: API key not valid" errors.
|
|
/// See: https://github.com/google-gemini/gemini-cli/issues/19200
|
|
fn build_generate_content_url(model: &str, auth: &GeminiAuth) -> String {
|
|
match auth {
|
|
GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => {
|
|
// OAuth tokens are scoped for the internal Code Assist API.
|
|
// The model is passed in the request body, not the URL path.
|
|
format!("{CLOUDCODE_PA_ENDPOINT}:generateContent")
|
|
}
|
|
_ => {
|
|
let model_name = Self::format_model_name(model);
|
|
let base_url = format!("{PUBLIC_API_ENDPOINT}/{model_name}:generateContent");
|
|
|
|
if auth.is_api_key() {
|
|
format!("{base_url}?key={}", auth.api_key_credential())
|
|
} else {
|
|
base_url
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn http_client(&self) -> Client {
|
|
crate::config::build_runtime_proxy_client_with_timeouts("provider.gemini", 120, 10)
|
|
}
|
|
|
|
/// Resolve the GCP project ID for OAuth by calling the loadCodeAssist endpoint.
|
|
/// Caches the result for subsequent calls.
|
|
async fn resolve_oauth_project(&self, token: &str) -> anyhow::Result<String> {
|
|
let project_seed = Self::load_non_empty_env("GOOGLE_CLOUD_PROJECT")
|
|
.or_else(|| Self::load_non_empty_env("GOOGLE_CLOUD_PROJECT_ID"));
|
|
let project_seed_for_request = project_seed.clone();
|
|
let duet_project_for_request = project_seed.clone();
|
|
|
|
// Check cache first
|
|
{
|
|
let cached = self.oauth_project.lock().await;
|
|
if let Some(ref project) = *cached {
|
|
return Ok(project.clone());
|
|
}
|
|
}
|
|
|
|
// Call loadCodeAssist
|
|
let client = self.http_client();
|
|
let response = client
|
|
.post(LOAD_CODE_ASSIST_ENDPOINT)
|
|
.bearer_auth(token)
|
|
.json(&serde_json::json!({
|
|
"cloudaicompanionProject": project_seed_for_request,
|
|
"metadata": {
|
|
"ideType": "GEMINI_CLI",
|
|
"platform": "PLATFORM_UNSPECIFIED",
|
|
"pluginType": "GEMINI",
|
|
"duetProject": duet_project_for_request,
|
|
}
|
|
}))
|
|
.send()
|
|
.await?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let body = response.text().await.unwrap_or_default();
|
|
if let Some(seed) = project_seed {
|
|
tracing::warn!(
|
|
"loadCodeAssist failed (HTTP {status}); using GOOGLE_CLOUD_PROJECT fallback"
|
|
);
|
|
return Ok(seed);
|
|
}
|
|
let sanitized = super::sanitize_api_error(&body);
|
|
anyhow::bail!("loadCodeAssist failed (HTTP {status}): {sanitized}");
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct LoadCodeAssistResponse {
|
|
#[serde(rename = "cloudaicompanionProject")]
|
|
cloudaicompanion_project: Option<String>,
|
|
}
|
|
|
|
let result: LoadCodeAssistResponse = response.json().await?;
|
|
let project = result
|
|
.cloudaicompanion_project
|
|
.filter(|p| !p.trim().is_empty())
|
|
.or(project_seed)
|
|
.ok_or_else(|| anyhow::anyhow!("loadCodeAssist response missing project context"))?;
|
|
|
|
// Cache for future calls
|
|
{
|
|
let mut cached = self.oauth_project.lock().await;
|
|
*cached = Some(project.clone());
|
|
}
|
|
|
|
Ok(project)
|
|
}
|
|
|
|
/// Build the HTTP request for generateContent.
|
|
///
|
|
/// For OAuth, pass the resolved `oauth_token` and `project`.
|
|
/// For API key, both are `None`.
|
|
fn build_generate_content_request(
|
|
&self,
|
|
auth: &GeminiAuth,
|
|
url: &str,
|
|
request: &GenerateContentRequest,
|
|
model: &str,
|
|
include_generation_config: bool,
|
|
project: Option<&str>,
|
|
oauth_token: Option<&str>,
|
|
) -> reqwest::RequestBuilder {
|
|
let req = self.http_client().post(url).json(request);
|
|
match auth {
|
|
GeminiAuth::OAuthToken(_) | GeminiAuth::ManagedOAuth => {
|
|
let token = oauth_token.unwrap_or_default();
|
|
// Internal Code Assist API uses a wrapped payload shape:
|
|
// { model, project?, user_prompt_id?, request: { contents, systemInstruction?, generationConfig } }
|
|
let internal_request = InternalGenerateContentEnvelope {
|
|
model: Self::format_internal_model_name(model),
|
|
project: project.map(|value| value.to_string()),
|
|
user_prompt_id: Some(uuid::Uuid::new_v4().to_string()),
|
|
request: InternalGenerateContentRequest {
|
|
contents: request.contents.clone(),
|
|
system_instruction: request.system_instruction.clone(),
|
|
generation_config: if include_generation_config {
|
|
Some(request.generation_config.clone())
|
|
} else {
|
|
None
|
|
},
|
|
},
|
|
};
|
|
self.http_client()
|
|
.post(url)
|
|
.json(&internal_request)
|
|
.bearer_auth(token)
|
|
}
|
|
_ => req,
|
|
}
|
|
}
|
|
|
|
fn should_retry_oauth_without_generation_config(
|
|
status: reqwest::StatusCode,
|
|
error_text: &str,
|
|
) -> bool {
|
|
if status != reqwest::StatusCode::BAD_REQUEST {
|
|
return false;
|
|
}
|
|
|
|
error_text.contains("Unknown name \"generationConfig\"")
|
|
|| error_text.contains("Unknown name 'generationConfig'")
|
|
|| error_text.contains(r#"Unknown name \"generationConfig\""#)
|
|
}
|
|
|
|
fn should_rotate_oauth_on_error(status: reqwest::StatusCode, error_text: &str) -> bool {
|
|
status == reqwest::StatusCode::TOO_MANY_REQUESTS
|
|
|| status == reqwest::StatusCode::SERVICE_UNAVAILABLE
|
|
|| status.is_server_error()
|
|
|| error_text.contains("RESOURCE_EXHAUSTED")
|
|
}
|
|
|
|
fn parse_inline_image_marker(image_ref: &str) -> Option<InlineDataPart> {
|
|
let rest = image_ref.strip_prefix("data:")?;
|
|
let semi_index = rest.find(';')?;
|
|
let mime_type = rest[..semi_index].trim();
|
|
if mime_type.is_empty() {
|
|
return None;
|
|
}
|
|
|
|
let payload = rest[semi_index + 1..].strip_prefix("base64,")?.trim();
|
|
if payload.is_empty() {
|
|
return None;
|
|
}
|
|
|
|
Some(InlineDataPart {
|
|
mime_type: mime_type.to_string(),
|
|
data: payload.to_string(),
|
|
})
|
|
}
|
|
|
|
fn build_user_parts(content: &str) -> Vec<Part> {
|
|
let (cleaned_text, image_refs) = multimodal::parse_image_markers(content);
|
|
if image_refs.is_empty() {
|
|
return vec![Part::Text {
|
|
text: content.to_string(),
|
|
}];
|
|
}
|
|
|
|
let mut parts: Vec<Part> = Vec::with_capacity(image_refs.len() + 1);
|
|
if !cleaned_text.is_empty() {
|
|
parts.push(Part::Text { text: cleaned_text });
|
|
}
|
|
|
|
for image_ref in image_refs {
|
|
if let Some(inline_data) = Self::parse_inline_image_marker(&image_ref) {
|
|
parts.push(Part::InlineData { inline_data });
|
|
} else {
|
|
parts.push(Part::Text {
|
|
text: format!("[IMAGE:{image_ref}]"),
|
|
});
|
|
}
|
|
}
|
|
|
|
if parts.is_empty() {
|
|
vec![Part::Text {
|
|
text: String::new(),
|
|
}]
|
|
} else {
|
|
parts
|
|
}
|
|
}
|
|
}
|
|
|
|
impl GeminiProvider {
|
|
async fn send_generate_content(
|
|
&self,
|
|
contents: Vec<Content>,
|
|
system_instruction: Option<Content>,
|
|
model: &str,
|
|
temperature: f64,
|
|
) -> anyhow::Result<(String, Option<TokenUsage>)> {
|
|
let auth = self.auth.as_ref().ok_or_else(|| {
|
|
anyhow::anyhow!(
|
|
"Gemini API key not found. Options:\n\
|
|
1. Set GEMINI_API_KEY env var\n\
|
|
2. Run `gemini` CLI to authenticate (tokens will be reused)\n\
|
|
3. Run `zeroclaw auth login --provider gemini`\n\
|
|
4. Get an API key from https://aistudio.google.com/app/apikey\n\
|
|
5. Run `zeroclaw onboard` to configure"
|
|
)
|
|
})?;
|
|
|
|
let oauth_state = match auth {
|
|
GeminiAuth::OAuthToken(state) => Some(state.clone()),
|
|
_ => None,
|
|
};
|
|
|
|
// For OAuth: get a valid (potentially refreshed) token and resolve project
|
|
let (mut oauth_token, mut project) = match auth {
|
|
GeminiAuth::OAuthToken(state) => {
|
|
let token = Self::get_valid_oauth_token(state).await?;
|
|
let proj = self.resolve_oauth_project(&token).await?;
|
|
(Some(token), Some(proj))
|
|
}
|
|
GeminiAuth::ManagedOAuth => {
|
|
let auth_service = self
|
|
.auth_service
|
|
.as_ref()
|
|
.ok_or_else(|| anyhow::anyhow!("ManagedOAuth requires auth_service"))?;
|
|
let token = auth_service
|
|
.get_valid_gemini_access_token(self.auth_profile_override.as_deref())
|
|
.await?
|
|
.ok_or_else(|| {
|
|
anyhow::anyhow!(
|
|
"Gemini auth profile not found. Run `zeroclaw auth login --provider gemini`."
|
|
)
|
|
})?;
|
|
let proj = self.resolve_oauth_project(&token).await?;
|
|
(Some(token), Some(proj))
|
|
}
|
|
_ => (None, None),
|
|
};
|
|
|
|
let request = GenerateContentRequest {
|
|
contents,
|
|
system_instruction,
|
|
generation_config: GenerationConfig {
|
|
temperature,
|
|
max_output_tokens: 8192,
|
|
},
|
|
};
|
|
|
|
let url = Self::build_generate_content_url(model, auth);
|
|
|
|
let mut response = self
|
|
.build_generate_content_request(
|
|
auth,
|
|
&url,
|
|
&request,
|
|
model,
|
|
true,
|
|
project.as_deref(),
|
|
oauth_token.as_deref(),
|
|
)
|
|
.send()
|
|
.await?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let error_text = response.text().await.unwrap_or_default();
|
|
|
|
if auth.is_oauth() && Self::should_rotate_oauth_on_error(status, &error_text) {
|
|
// For CLI OAuth: rotate credentials
|
|
// For ManagedOAuth: AuthService handles refresh, just retry
|
|
let can_retry = match auth {
|
|
GeminiAuth::OAuthToken(_) => {
|
|
if let Some(state) = oauth_state.as_ref() {
|
|
self.rotate_oauth_credential(state).await
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
GeminiAuth::ManagedOAuth => true, // AuthService refreshes automatically
|
|
_ => false,
|
|
};
|
|
|
|
if can_retry {
|
|
// Re-fetch token (may be refreshed)
|
|
let (new_token, new_project) = match auth {
|
|
GeminiAuth::OAuthToken(state) => {
|
|
let token = Self::get_valid_oauth_token(state).await?;
|
|
let proj = self.resolve_oauth_project(&token).await?;
|
|
(token, proj)
|
|
}
|
|
GeminiAuth::ManagedOAuth => {
|
|
let auth_service = self.auth_service.as_ref().unwrap();
|
|
let token = auth_service
|
|
.get_valid_gemini_access_token(
|
|
self.auth_profile_override.as_deref(),
|
|
)
|
|
.await?
|
|
.ok_or_else(|| anyhow::anyhow!("Gemini auth profile not found"))?;
|
|
let proj = self.resolve_oauth_project(&token).await?;
|
|
(token, proj)
|
|
}
|
|
_ => unreachable!(),
|
|
};
|
|
oauth_token = Some(new_token);
|
|
project = Some(new_project);
|
|
response = self
|
|
.build_generate_content_request(
|
|
auth,
|
|
&url,
|
|
&request,
|
|
model,
|
|
true,
|
|
project.as_deref(),
|
|
oauth_token.as_deref(),
|
|
)
|
|
.send()
|
|
.await?;
|
|
} else {
|
|
anyhow::bail!("Gemini API error ({status}): {error_text}");
|
|
}
|
|
} else if auth.is_oauth()
|
|
&& Self::should_retry_oauth_without_generation_config(status, &error_text)
|
|
{
|
|
tracing::warn!(
|
|
"Gemini OAuth internal endpoint rejected generationConfig; retrying without generationConfig"
|
|
);
|
|
response = self
|
|
.build_generate_content_request(
|
|
auth,
|
|
&url,
|
|
&request,
|
|
model,
|
|
false,
|
|
project.as_deref(),
|
|
oauth_token.as_deref(),
|
|
)
|
|
.send()
|
|
.await?;
|
|
} else {
|
|
anyhow::bail!("Gemini API error ({status}): {error_text}");
|
|
}
|
|
}
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let error_text = response.text().await.unwrap_or_default();
|
|
if auth.is_oauth()
|
|
&& Self::should_retry_oauth_without_generation_config(status, &error_text)
|
|
{
|
|
tracing::warn!(
|
|
"Gemini OAuth internal endpoint rejected generationConfig; retrying without generationConfig"
|
|
);
|
|
response = self
|
|
.build_generate_content_request(
|
|
auth,
|
|
&url,
|
|
&request,
|
|
model,
|
|
false,
|
|
project.as_deref(),
|
|
oauth_token.as_deref(),
|
|
)
|
|
.send()
|
|
.await?;
|
|
} else {
|
|
anyhow::bail!("Gemini API error ({status}): {error_text}");
|
|
}
|
|
}
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let error_text = response.text().await.unwrap_or_default();
|
|
anyhow::bail!("Gemini API error ({status}): {error_text}");
|
|
}
|
|
|
|
let result: GenerateContentResponse = response.json().await?;
|
|
if let Some(err) = &result.error {
|
|
anyhow::bail!("Gemini API error: {}", err.message);
|
|
}
|
|
let result = result.into_effective_response();
|
|
if let Some(err) = result.error {
|
|
anyhow::bail!("Gemini API error: {}", err.message);
|
|
}
|
|
|
|
let usage = result.usage_metadata.map(|u| TokenUsage {
|
|
input_tokens: u.prompt_token_count,
|
|
output_tokens: u.candidates_token_count,
|
|
});
|
|
|
|
let text = result
|
|
.candidates
|
|
.and_then(|c| c.into_iter().next())
|
|
.and_then(|c| c.content)
|
|
.and_then(|c| c.effective_text())
|
|
.ok_or_else(|| anyhow::anyhow!("No response from Gemini"))?;
|
|
|
|
Ok((text, usage))
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Provider for GeminiProvider {
|
|
async fn chat_with_system(
|
|
&self,
|
|
system_prompt: Option<&str>,
|
|
message: &str,
|
|
model: &str,
|
|
temperature: f64,
|
|
) -> anyhow::Result<String> {
|
|
let system_instruction = system_prompt.map(|sys| Content {
|
|
role: None,
|
|
parts: vec![Part::Text {
|
|
text: sys.to_string(),
|
|
}],
|
|
});
|
|
|
|
let contents = vec![Content {
|
|
role: Some("user".to_string()),
|
|
parts: Self::build_user_parts(message),
|
|
}];
|
|
|
|
let (text, _usage) = self
|
|
.send_generate_content(contents, system_instruction, model, temperature)
|
|
.await?;
|
|
Ok(text)
|
|
}
|
|
|
|
async fn chat_with_history(
|
|
&self,
|
|
messages: &[ChatMessage],
|
|
model: &str,
|
|
temperature: f64,
|
|
) -> anyhow::Result<String> {
|
|
let mut system_parts: Vec<&str> = Vec::new();
|
|
let mut contents: Vec<Content> = Vec::new();
|
|
|
|
for msg in messages {
|
|
match msg.role.as_str() {
|
|
"system" => {
|
|
system_parts.push(&msg.content);
|
|
}
|
|
"user" => {
|
|
contents.push(Content {
|
|
role: Some("user".to_string()),
|
|
parts: Self::build_user_parts(&msg.content),
|
|
});
|
|
}
|
|
"assistant" => {
|
|
// Gemini API uses "model" role instead of "assistant"
|
|
contents.push(Content {
|
|
role: Some("model".to_string()),
|
|
parts: vec![Part::Text {
|
|
text: msg.content.clone(),
|
|
}],
|
|
});
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
let system_instruction = if system_parts.is_empty() {
|
|
None
|
|
} else {
|
|
Some(Content {
|
|
role: None,
|
|
parts: vec![Part::Text {
|
|
text: system_parts.join("\n\n"),
|
|
}],
|
|
})
|
|
};
|
|
|
|
let (text, _usage) = self
|
|
.send_generate_content(contents, system_instruction, model, temperature)
|
|
.await?;
|
|
Ok(text)
|
|
}
|
|
|
|
async fn chat(
|
|
&self,
|
|
request: crate::providers::traits::ChatRequest<'_>,
|
|
model: &str,
|
|
temperature: f64,
|
|
) -> anyhow::Result<ChatResponse> {
|
|
let mut system_parts: Vec<&str> = Vec::new();
|
|
let mut contents: Vec<Content> = Vec::new();
|
|
|
|
for msg in request.messages {
|
|
match msg.role.as_str() {
|
|
"system" => system_parts.push(&msg.content),
|
|
"user" => contents.push(Content {
|
|
role: Some("user".to_string()),
|
|
parts: Self::build_user_parts(&msg.content),
|
|
}),
|
|
"assistant" => contents.push(Content {
|
|
role: Some("model".to_string()),
|
|
parts: vec![Part::Text {
|
|
text: msg.content.clone(),
|
|
}],
|
|
}),
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
let system_instruction = if system_parts.is_empty() {
|
|
None
|
|
} else {
|
|
Some(Content {
|
|
role: None,
|
|
parts: vec![Part::Text {
|
|
text: system_parts.join("\n\n"),
|
|
}],
|
|
})
|
|
};
|
|
|
|
let (text, usage) = self
|
|
.send_generate_content(contents, system_instruction, model, temperature)
|
|
.await?;
|
|
|
|
Ok(ChatResponse {
|
|
text: Some(text),
|
|
tool_calls: Vec::new(),
|
|
usage,
|
|
reasoning_content: None,
|
|
quota_metadata: None,
|
|
})
|
|
}
|
|
|
|
async fn warmup(&self) -> anyhow::Result<()> {
|
|
if let Some(auth) = self.auth.as_ref() {
|
|
match auth {
|
|
GeminiAuth::ManagedOAuth => {
|
|
// For ManagedOAuth, verify and refresh the token if needed.
|
|
// This ensures fallback works even if tokens expired during daemon uptime.
|
|
let auth_service = self
|
|
.auth_service
|
|
.as_ref()
|
|
.ok_or_else(|| anyhow::anyhow!("ManagedOAuth requires auth_service"))?;
|
|
|
|
let _token = auth_service
|
|
.get_valid_gemini_access_token(self.auth_profile_override.as_deref())
|
|
.await?
|
|
.ok_or_else(|| {
|
|
anyhow::anyhow!(
|
|
"Gemini auth profile not found or expired. Run: zeroclaw auth login --provider gemini"
|
|
)
|
|
})?;
|
|
|
|
// Token refresh happens in get_valid_gemini_access_token().
|
|
// We don't call resolve_oauth_project() here to keep warmup fast.
|
|
// OAuth project will be resolved lazily on first real request.
|
|
}
|
|
GeminiAuth::OAuthToken(_) => {
|
|
// CLI OAuth — cloudcode-pa does not expose a lightweight model-list probe.
|
|
// Token will be validated on first real request.
|
|
}
|
|
_ => {
|
|
// API key path — verify with public API models endpoint.
|
|
let url = if auth.is_api_key() {
|
|
format!(
|
|
"https://generativelanguage.googleapis.com/v1beta/models?key={}",
|
|
auth.api_key_credential()
|
|
)
|
|
} else {
|
|
"https://generativelanguage.googleapis.com/v1beta/models".to_string()
|
|
};
|
|
|
|
self.http_client()
|
|
.get(&url)
|
|
.send()
|
|
.await?
|
|
.error_for_status()?;
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use reqwest::{header::AUTHORIZATION, StatusCode};
|
|
|
|
/// Helper to create a test OAuth auth variant.
|
|
fn test_oauth_auth(token: &str) -> GeminiAuth {
|
|
GeminiAuth::OAuthToken(Arc::new(tokio::sync::Mutex::new(OAuthTokenState {
|
|
access_token: token.to_string(),
|
|
refresh_token: None,
|
|
client_id: None,
|
|
client_secret: None,
|
|
expiry_millis: None,
|
|
})))
|
|
}
|
|
|
|
fn test_provider(auth: Option<GeminiAuth>) -> GeminiProvider {
|
|
GeminiProvider {
|
|
auth,
|
|
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
|
oauth_cred_paths: Vec::new(),
|
|
oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
|
|
auth_service: None,
|
|
auth_profile_override: None,
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn normalize_non_empty_trims_and_filters() {
|
|
assert_eq!(
|
|
GeminiProvider::normalize_non_empty(" value "),
|
|
Some("value".into())
|
|
);
|
|
assert_eq!(GeminiProvider::normalize_non_empty(""), None);
|
|
assert_eq!(GeminiProvider::normalize_non_empty(" \t\n"), None);
|
|
}
|
|
|
|
#[test]
|
|
fn oauth_refresh_form_uses_provided_client_credentials() {
|
|
let form = build_oauth_refresh_form("refresh-token", Some("client-id"), Some("secret"));
|
|
let map: std::collections::HashMap<_, _> = form.into_iter().collect();
|
|
assert_eq!(map.get("grant_type"), Some(&"refresh_token".to_string()));
|
|
assert_eq!(map.get("refresh_token"), Some(&"refresh-token".to_string()));
|
|
assert_eq!(map.get("client_id"), Some(&"client-id".to_string()));
|
|
assert_eq!(map.get("client_secret"), Some(&"secret".to_string()));
|
|
}
|
|
|
|
#[test]
|
|
fn oauth_refresh_form_omits_client_credentials_when_missing() {
|
|
let form = build_oauth_refresh_form("refresh-token", None, None);
|
|
let map: std::collections::HashMap<_, _> = form.into_iter().collect();
|
|
assert!(!map.contains_key("client_id"));
|
|
assert!(!map.contains_key("client_secret"));
|
|
}
|
|
|
|
#[test]
|
|
fn extract_client_id_from_id_token_prefers_aud_claim() {
|
|
let payload = serde_json::json!({
|
|
"aud": "aud-client-id",
|
|
"azp": "azp-client-id"
|
|
});
|
|
let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
|
|
.encode(serde_json::to_vec(&payload).unwrap());
|
|
let token = format!("header.{payload_b64}.sig");
|
|
|
|
assert_eq!(
|
|
extract_client_id_from_id_token(&token),
|
|
Some("aud-client-id".to_string())
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn extract_client_id_from_id_token_uses_azp_when_aud_missing() {
|
|
let payload = serde_json::json!({
|
|
"azp": "azp-client-id"
|
|
});
|
|
let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
|
|
.encode(serde_json::to_vec(&payload).unwrap());
|
|
let token = format!("header.{payload_b64}.sig");
|
|
|
|
assert_eq!(
|
|
extract_client_id_from_id_token(&token),
|
|
Some("azp-client-id".to_string())
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn extract_client_id_from_id_token_returns_none_for_invalid_tokens() {
|
|
assert_eq!(extract_client_id_from_id_token("invalid"), None);
|
|
assert_eq!(extract_client_id_from_id_token("a.b.c"), None);
|
|
}
|
|
|
|
#[test]
|
|
fn try_load_cli_token_derives_client_id_from_id_token_when_missing() {
|
|
let payload = serde_json::json!({ "aud": "derived-client-id" });
|
|
let payload_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
|
|
.encode(serde_json::to_vec(&payload).unwrap());
|
|
let id_token = format!("header.{payload_b64}.sig");
|
|
|
|
let file = tempfile::NamedTempFile::new().unwrap();
|
|
let json = format!(
|
|
r#"{{
|
|
"access_token": "ya29.test-access",
|
|
"refresh_token": "1//test-refresh",
|
|
"id_token": "{id_token}"
|
|
}}"#
|
|
);
|
|
std::fs::write(file.path(), json).unwrap();
|
|
|
|
let path = file.path().to_path_buf();
|
|
let state = GeminiProvider::try_load_gemini_cli_token(Some(&path)).unwrap();
|
|
assert_eq!(state.client_id.as_deref(), Some("derived-client-id"));
|
|
assert_eq!(state.client_secret, None);
|
|
}
|
|
|
|
#[test]
|
|
fn provider_creates_without_key() {
|
|
let provider = GeminiProvider::new(None);
|
|
// May pick up env vars; just verify it doesn't panic
|
|
let _ = provider.auth_source();
|
|
}
|
|
|
|
#[test]
|
|
fn provider_creates_with_key() {
|
|
let provider = GeminiProvider::new(Some("test-api-key"));
|
|
assert!(matches!(
|
|
provider.auth,
|
|
Some(GeminiAuth::ExplicitKey(ref key)) if key == "test-api-key"
|
|
));
|
|
}
|
|
|
|
#[test]
|
|
fn provider_rejects_empty_key() {
|
|
let provider = GeminiProvider::new(Some(""));
|
|
assert!(!matches!(provider.auth, Some(GeminiAuth::ExplicitKey(_))));
|
|
}
|
|
|
|
#[test]
|
|
fn gemini_cli_dir_returns_path() {
|
|
let dir = GeminiProvider::gemini_cli_dir();
|
|
// Should return Some on systems with home dir
|
|
if UserDirs::new().is_some() {
|
|
assert!(dir.is_some());
|
|
assert!(dir.unwrap().ends_with(".gemini"));
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn auth_source_explicit_key() {
|
|
let provider = test_provider(Some(GeminiAuth::ExplicitKey("key".into())));
|
|
assert_eq!(provider.auth_source(), "config");
|
|
}
|
|
|
|
#[test]
|
|
fn auth_source_none_without_credentials() {
|
|
let provider = test_provider(None);
|
|
assert_eq!(provider.auth_source(), "none");
|
|
}
|
|
|
|
#[test]
|
|
fn auth_source_oauth() {
|
|
let provider = test_provider(Some(test_oauth_auth("ya29.mock")));
|
|
assert_eq!(provider.auth_source(), "Gemini CLI OAuth");
|
|
}
|
|
|
|
#[test]
|
|
fn model_name_formatting() {
|
|
assert_eq!(
|
|
GeminiProvider::format_model_name("gemini-2.0-flash"),
|
|
"models/gemini-2.0-flash"
|
|
);
|
|
assert_eq!(
|
|
GeminiProvider::format_model_name("models/gemini-1.5-pro"),
|
|
"models/gemini-1.5-pro"
|
|
);
|
|
assert_eq!(
|
|
GeminiProvider::format_internal_model_name("models/gemini-2.5-flash"),
|
|
"gemini-2.5-flash"
|
|
);
|
|
assert_eq!(
|
|
GeminiProvider::format_internal_model_name("gemini-2.5-flash"),
|
|
"gemini-2.5-flash"
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn api_key_url_includes_key_query_param() {
|
|
let auth = GeminiAuth::ExplicitKey("api-key-123".into());
|
|
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
|
|
assert!(url.contains(":generateContent?key=api-key-123"));
|
|
}
|
|
|
|
#[test]
|
|
fn oauth_url_uses_internal_endpoint() {
|
|
let auth = test_oauth_auth("ya29.test-token");
|
|
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
|
|
assert!(url.starts_with("https://cloudcode-pa.googleapis.com/v1internal"));
|
|
assert!(url.ends_with(":generateContent"));
|
|
assert!(!url.contains("generativelanguage.googleapis.com"));
|
|
assert!(!url.contains("?key="));
|
|
}
|
|
|
|
#[test]
|
|
fn api_key_url_uses_public_endpoint() {
|
|
let auth = GeminiAuth::ExplicitKey("api-key-123".into());
|
|
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
|
|
assert!(url.contains("generativelanguage.googleapis.com/v1beta"));
|
|
assert!(url.contains("models/gemini-2.0-flash"));
|
|
}
|
|
|
|
#[test]
|
|
fn oauth_request_uses_bearer_auth_header() {
|
|
let provider = test_provider(Some(test_oauth_auth("ya29.mock-token")));
|
|
let auth = test_oauth_auth("ya29.mock-token");
|
|
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
|
|
let body = GenerateContentRequest {
|
|
contents: vec![Content {
|
|
role: Some("user".into()),
|
|
parts: vec![Part::Text {
|
|
text: "hello".into(),
|
|
}],
|
|
}],
|
|
system_instruction: None,
|
|
generation_config: GenerationConfig {
|
|
temperature: 0.7,
|
|
max_output_tokens: 8192,
|
|
},
|
|
};
|
|
|
|
let request = provider
|
|
.build_generate_content_request(
|
|
&auth,
|
|
&url,
|
|
&body,
|
|
"gemini-2.0-flash",
|
|
true,
|
|
Some("test-project"),
|
|
Some("ya29.mock-token"),
|
|
)
|
|
.build()
|
|
.unwrap();
|
|
|
|
assert_eq!(
|
|
request
|
|
.headers()
|
|
.get(AUTHORIZATION)
|
|
.and_then(|h| h.to_str().ok()),
|
|
Some("Bearer ya29.mock-token")
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn oauth_request_wraps_payload_in_request_envelope() {
|
|
let provider = test_provider(Some(test_oauth_auth("ya29.mock-token")));
|
|
let auth = test_oauth_auth("ya29.mock-token");
|
|
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
|
|
let body = GenerateContentRequest {
|
|
contents: vec![Content {
|
|
role: Some("user".into()),
|
|
parts: vec![Part::Text {
|
|
text: "hello".into(),
|
|
}],
|
|
}],
|
|
system_instruction: None,
|
|
generation_config: GenerationConfig {
|
|
temperature: 0.7,
|
|
max_output_tokens: 8192,
|
|
},
|
|
};
|
|
|
|
let request = provider
|
|
.build_generate_content_request(
|
|
&auth,
|
|
&url,
|
|
&body,
|
|
"models/gemini-2.0-flash",
|
|
true,
|
|
Some("test-project"),
|
|
Some("ya29.mock-token"),
|
|
)
|
|
.build()
|
|
.unwrap();
|
|
|
|
let payload = request
|
|
.body()
|
|
.and_then(|b| b.as_bytes())
|
|
.expect("json request body should be bytes");
|
|
let json: serde_json::Value = serde_json::from_slice(payload).unwrap();
|
|
|
|
assert_eq!(json["model"], "gemini-2.0-flash");
|
|
assert!(json.get("generationConfig").is_none());
|
|
assert!(json.get("request").is_some());
|
|
assert!(json["request"].get("generationConfig").is_some());
|
|
}
|
|
|
|
#[test]
|
|
fn api_key_request_does_not_set_bearer_header() {
|
|
let provider = test_provider(Some(GeminiAuth::ExplicitKey("api-key-123".into())));
|
|
let auth = GeminiAuth::ExplicitKey("api-key-123".into());
|
|
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
|
|
let body = GenerateContentRequest {
|
|
contents: vec![Content {
|
|
role: Some("user".into()),
|
|
parts: vec![Part::Text {
|
|
text: "hello".into(),
|
|
}],
|
|
}],
|
|
system_instruction: None,
|
|
generation_config: GenerationConfig {
|
|
temperature: 0.7,
|
|
max_output_tokens: 8192,
|
|
},
|
|
};
|
|
|
|
let request = provider
|
|
.build_generate_content_request(
|
|
&auth,
|
|
&url,
|
|
&body,
|
|
"gemini-2.0-flash",
|
|
true,
|
|
None,
|
|
None,
|
|
)
|
|
.build()
|
|
.unwrap();
|
|
|
|
assert!(request.headers().get(AUTHORIZATION).is_none());
|
|
}
|
|
|
|
#[test]
|
|
fn request_serialization() {
|
|
let request = GenerateContentRequest {
|
|
contents: vec![Content {
|
|
role: Some("user".to_string()),
|
|
parts: vec![Part::Text {
|
|
text: "Hello".to_string(),
|
|
}],
|
|
}],
|
|
system_instruction: Some(Content {
|
|
role: None,
|
|
parts: vec![Part::Text {
|
|
text: "You are helpful".to_string(),
|
|
}],
|
|
}),
|
|
generation_config: GenerationConfig {
|
|
temperature: 0.7,
|
|
max_output_tokens: 8192,
|
|
},
|
|
};
|
|
|
|
let json = serde_json::to_string(&request).unwrap();
|
|
assert!(json.contains("\"role\":\"user\""));
|
|
assert!(json.contains("\"text\":\"Hello\""));
|
|
assert!(json.contains("\"systemInstruction\""));
|
|
assert!(!json.contains("\"system_instruction\""));
|
|
assert!(json.contains("\"temperature\":0.7"));
|
|
assert!(json.contains("\"maxOutputTokens\":8192"));
|
|
}
|
|
|
|
#[test]
|
|
fn build_user_parts_text_only_is_backward_compatible() {
|
|
let content = "Plain text message without image markers.";
|
|
let parts = GeminiProvider::build_user_parts(content);
|
|
assert_eq!(parts.len(), 1);
|
|
match &parts[0] {
|
|
Part::Text { text } => assert_eq!(text, content),
|
|
Part::InlineData { .. } => panic!("text-only message must stay text-only"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn build_user_parts_single_image() {
|
|
let parts = GeminiProvider::build_user_parts(
|
|
"Describe this image [IMAGE:data:image/png;base64,aGVsbG8=]",
|
|
);
|
|
assert_eq!(parts.len(), 2);
|
|
match &parts[0] {
|
|
Part::Text { text } => assert_eq!(text, "Describe this image"),
|
|
Part::InlineData { .. } => panic!("first part should be text"),
|
|
}
|
|
match &parts[1] {
|
|
Part::InlineData { inline_data } => {
|
|
assert_eq!(inline_data.mime_type, "image/png");
|
|
assert_eq!(inline_data.data, "aGVsbG8=");
|
|
}
|
|
Part::Text { .. } => panic!("second part should be inline image data"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn build_user_parts_multiple_images() {
|
|
let parts = GeminiProvider::build_user_parts(
|
|
"Compare [IMAGE:data:image/png;base64,aQ==] and [IMAGE:data:image/jpeg;base64,ag==]",
|
|
);
|
|
assert_eq!(parts.len(), 3);
|
|
assert!(matches!(parts[0], Part::Text { .. }));
|
|
assert!(matches!(parts[1], Part::InlineData { .. }));
|
|
assert!(matches!(parts[2], Part::InlineData { .. }));
|
|
}
|
|
|
|
#[test]
|
|
fn build_user_parts_image_only() {
|
|
let parts = GeminiProvider::build_user_parts("[IMAGE:data:image/webp;base64,YWJjZA==]");
|
|
assert_eq!(parts.len(), 1);
|
|
match &parts[0] {
|
|
Part::InlineData { inline_data } => {
|
|
assert_eq!(inline_data.mime_type, "image/webp");
|
|
assert_eq!(inline_data.data, "YWJjZA==");
|
|
}
|
|
Part::Text { .. } => panic!("image-only message should create inline image part"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn build_user_parts_fallback_for_non_data_uri_markers() {
|
|
let parts = GeminiProvider::build_user_parts("Inspect [IMAGE:https://example.com/img.png]");
|
|
assert_eq!(parts.len(), 2);
|
|
match &parts[0] {
|
|
Part::Text { text } => assert_eq!(text, "Inspect"),
|
|
Part::InlineData { .. } => panic!("first part should be text"),
|
|
}
|
|
match &parts[1] {
|
|
Part::Text { text } => assert_eq!(text, "[IMAGE:https://example.com/img.png]"),
|
|
Part::InlineData { .. } => panic!("invalid markers should fall back to text"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn internal_request_includes_model() {
|
|
let request = InternalGenerateContentEnvelope {
|
|
model: "gemini-3-pro-preview".to_string(),
|
|
project: Some("test-project".to_string()),
|
|
user_prompt_id: Some("prompt-123".to_string()),
|
|
request: InternalGenerateContentRequest {
|
|
contents: vec![Content {
|
|
role: Some("user".to_string()),
|
|
parts: vec![Part::Text {
|
|
text: "Hello".to_string(),
|
|
}],
|
|
}],
|
|
system_instruction: None,
|
|
generation_config: Some(GenerationConfig {
|
|
temperature: 0.7,
|
|
max_output_tokens: 8192,
|
|
}),
|
|
},
|
|
};
|
|
|
|
let json = serde_json::to_string(&request).unwrap();
|
|
assert!(json.contains("\"model\":\"gemini-3-pro-preview\""));
|
|
assert!(json.contains("\"request\""));
|
|
assert!(json.contains("\"generationConfig\""));
|
|
assert!(json.contains("\"maxOutputTokens\":8192"));
|
|
assert!(json.contains("\"user_prompt_id\":\"prompt-123\""));
|
|
assert!(json.contains("\"project\":\"test-project\""));
|
|
assert!(json.contains("\"role\":\"user\""));
|
|
assert!(json.contains("\"temperature\":0.7"));
|
|
}
|
|
|
|
#[test]
|
|
fn internal_request_omits_generation_config_when_none() {
|
|
let request = InternalGenerateContentEnvelope {
|
|
model: "gemini-3-pro-preview".to_string(),
|
|
project: Some("test-project".to_string()),
|
|
user_prompt_id: None,
|
|
request: InternalGenerateContentRequest {
|
|
contents: vec![Content {
|
|
role: Some("user".to_string()),
|
|
parts: vec![Part::Text {
|
|
text: "Hello".to_string(),
|
|
}],
|
|
}],
|
|
system_instruction: None,
|
|
generation_config: None,
|
|
},
|
|
};
|
|
|
|
let json = serde_json::to_string(&request).unwrap();
|
|
assert!(!json.contains("generationConfig"));
|
|
assert!(json.contains("\"model\":\"gemini-3-pro-preview\""));
|
|
}
|
|
|
|
#[test]
|
|
fn internal_request_includes_project() {
|
|
let request = InternalGenerateContentEnvelope {
|
|
model: "gemini-2.5-flash".to_string(),
|
|
project: Some("my-gcp-project-id".to_string()),
|
|
user_prompt_id: None,
|
|
request: InternalGenerateContentRequest {
|
|
contents: vec![Content {
|
|
role: Some("user".to_string()),
|
|
parts: vec![Part::Text {
|
|
text: "Hello".to_string(),
|
|
}],
|
|
}],
|
|
system_instruction: None,
|
|
generation_config: None,
|
|
},
|
|
};
|
|
|
|
let json = serde_json::to_string(&request).unwrap();
|
|
assert!(json.contains("\"project\":\"my-gcp-project-id\""));
|
|
}
|
|
|
|
#[test]
|
|
fn internal_response_deserialize_nested() {
|
|
let json = r#"{
|
|
"response": {
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": [{"text": "Hello from internal API!"}]
|
|
}
|
|
}]
|
|
}
|
|
}"#;
|
|
|
|
let internal: InternalGenerateContentResponse = serde_json::from_str(json).unwrap();
|
|
let text = internal
|
|
.response
|
|
.candidates
|
|
.unwrap()
|
|
.into_iter()
|
|
.next()
|
|
.unwrap()
|
|
.content
|
|
.unwrap()
|
|
.parts
|
|
.into_iter()
|
|
.next()
|
|
.unwrap()
|
|
.text;
|
|
assert_eq!(text, Some("Hello from internal API!".to_string()));
|
|
}
|
|
|
|
#[test]
|
|
fn creds_deserialize_with_expiry_date() {
|
|
let json = r#"{
|
|
"access_token": "ya29.test-token",
|
|
"refresh_token": "1//test-refresh",
|
|
"expiry_date": 4102444800000
|
|
}"#;
|
|
|
|
let creds: GeminiCliOAuthCreds = serde_json::from_str(json).unwrap();
|
|
assert_eq!(creds.access_token.as_deref(), Some("ya29.test-token"));
|
|
assert_eq!(creds.refresh_token.as_deref(), Some("1//test-refresh"));
|
|
assert_eq!(creds.expiry_date, Some(4_102_444_800_000));
|
|
assert!(creds.expiry.is_none());
|
|
}
|
|
|
|
#[test]
|
|
fn creds_deserialize_accepts_camel_case_fields() {
|
|
let json = r#"{
|
|
"access_token": "ya29.test-token",
|
|
"idToken": "header.payload.sig",
|
|
"refresh_token": "1//test-refresh",
|
|
"clientId": "test-client-id",
|
|
"clientSecret": "test-client-secret",
|
|
"expiryDate": 4102444800000
|
|
}"#;
|
|
|
|
let creds: GeminiCliOAuthCreds = serde_json::from_str(json).unwrap();
|
|
assert_eq!(creds.id_token.as_deref(), Some("header.payload.sig"));
|
|
assert_eq!(creds.client_id.as_deref(), Some("test-client-id"));
|
|
assert_eq!(creds.client_secret.as_deref(), Some("test-client-secret"));
|
|
assert_eq!(creds.expiry_date, Some(4_102_444_800_000));
|
|
}
|
|
|
|
#[test]
|
|
fn oauth_retry_detection_for_generation_config_rejection() {
|
|
// Bare quotes (e.g. pre-parsed error string)
|
|
let err =
|
|
"Invalid JSON payload received. Unknown name \"generationConfig\": Cannot find field.";
|
|
assert!(
|
|
GeminiProvider::should_retry_oauth_without_generation_config(
|
|
StatusCode::BAD_REQUEST,
|
|
err
|
|
)
|
|
);
|
|
// JSON-escaped quotes (raw response body from Google API)
|
|
let err_json = r#"Invalid JSON payload received. Unknown name \"generationConfig\": Cannot find field."#;
|
|
assert!(
|
|
GeminiProvider::should_retry_oauth_without_generation_config(
|
|
StatusCode::BAD_REQUEST,
|
|
err_json
|
|
)
|
|
);
|
|
assert!(
|
|
!GeminiProvider::should_retry_oauth_without_generation_config(
|
|
StatusCode::UNAUTHORIZED,
|
|
err
|
|
)
|
|
);
|
|
assert!(
|
|
!GeminiProvider::should_retry_oauth_without_generation_config(
|
|
StatusCode::BAD_REQUEST,
|
|
"something else"
|
|
)
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn response_deserialization() {
|
|
let json = r#"{
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": [{"text": "Hello there!"}]
|
|
}
|
|
}]
|
|
}"#;
|
|
|
|
let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
|
|
assert!(response.candidates.is_some());
|
|
let text = response
|
|
.candidates
|
|
.unwrap()
|
|
.into_iter()
|
|
.next()
|
|
.unwrap()
|
|
.content
|
|
.unwrap()
|
|
.parts
|
|
.into_iter()
|
|
.next()
|
|
.unwrap()
|
|
.text;
|
|
assert_eq!(text, Some("Hello there!".to_string()));
|
|
}
|
|
|
|
#[test]
|
|
fn error_response_deserialization() {
|
|
let json = r#"{
|
|
"error": {
|
|
"message": "Invalid API key"
|
|
}
|
|
}"#;
|
|
|
|
let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
|
|
assert!(response.error.is_some());
|
|
assert_eq!(response.error.unwrap().message, "Invalid API key");
|
|
}
|
|
|
|
#[test]
|
|
fn internal_response_deserialization() {
|
|
let json = r#"{
|
|
"response": {
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": [{"text": "Hello from internal"}]
|
|
}
|
|
}]
|
|
}
|
|
}"#;
|
|
|
|
let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
|
|
let text = response
|
|
.into_effective_response()
|
|
.candidates
|
|
.unwrap()
|
|
.into_iter()
|
|
.next()
|
|
.unwrap()
|
|
.content
|
|
.unwrap()
|
|
.parts
|
|
.into_iter()
|
|
.next()
|
|
.unwrap()
|
|
.text;
|
|
assert_eq!(text, Some("Hello from internal".to_string()));
|
|
}
|
|
|
|
// ── Thinking model response tests ──────────────────────────────────────
|
|
|
|
#[test]
|
|
fn thinking_response_extracts_non_thinking_text() {
|
|
let json = r#"{
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": [
|
|
{"thought": true, "text": "Let me think about this..."},
|
|
{"text": "The answer is 42."},
|
|
{"thoughtSignature": "c2lnbmF0dXJl"}
|
|
]
|
|
}
|
|
}]
|
|
}"#;
|
|
|
|
let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
|
|
let candidate = response.candidates.unwrap().into_iter().next().unwrap();
|
|
let text = candidate.content.unwrap().effective_text();
|
|
assert_eq!(text, Some("The answer is 42.".to_string()));
|
|
}
|
|
|
|
#[test]
|
|
fn non_thinking_response_unaffected() {
|
|
let json = r#"{
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": [{"text": "Hello there!"}]
|
|
}
|
|
}]
|
|
}"#;
|
|
|
|
let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
|
|
let candidate = response.candidates.unwrap().into_iter().next().unwrap();
|
|
let text = candidate.content.unwrap().effective_text();
|
|
assert_eq!(text, Some("Hello there!".to_string()));
|
|
}
|
|
|
|
#[test]
|
|
fn thinking_only_response_falls_back_to_thinking_text() {
|
|
let json = r#"{
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": [
|
|
{"thought": true, "text": "I need more context..."},
|
|
{"thoughtSignature": "c2lnbmF0dXJl"}
|
|
]
|
|
}
|
|
}]
|
|
}"#;
|
|
|
|
let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
|
|
let candidate = response.candidates.unwrap().into_iter().next().unwrap();
|
|
let text = candidate.content.unwrap().effective_text();
|
|
assert_eq!(text, Some("I need more context...".to_string()));
|
|
}
|
|
|
|
#[test]
|
|
fn empty_parts_returns_none() {
|
|
let json = r#"{
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": []
|
|
}
|
|
}]
|
|
}"#;
|
|
|
|
let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
|
|
let candidate = response.candidates.unwrap().into_iter().next().unwrap();
|
|
let text = candidate.content.unwrap().effective_text();
|
|
assert_eq!(text, None);
|
|
}
|
|
|
|
#[test]
|
|
fn multiple_text_parts_concatenated() {
|
|
let json = r#"{
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": [
|
|
{"text": "Part one. "},
|
|
{"text": "Part two."}
|
|
]
|
|
}
|
|
}]
|
|
}"#;
|
|
|
|
let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
|
|
let candidate = response.candidates.unwrap().into_iter().next().unwrap();
|
|
let text = candidate.content.unwrap().effective_text();
|
|
assert_eq!(text, Some("Part one. Part two.".to_string()));
|
|
}
|
|
|
|
#[test]
|
|
fn thought_signature_only_parts_skipped() {
|
|
let json = r#"{
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": [
|
|
{"thoughtSignature": "c2lnbmF0dXJl"}
|
|
]
|
|
}
|
|
}]
|
|
}"#;
|
|
|
|
let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
|
|
let candidate = response.candidates.unwrap().into_iter().next().unwrap();
|
|
let text = candidate.content.unwrap().effective_text();
|
|
assert_eq!(text, None);
|
|
}
|
|
|
|
#[test]
|
|
fn internal_response_thinking_model() {
|
|
let json = r#"{
|
|
"response": {
|
|
"candidates": [{
|
|
"content": {
|
|
"parts": [
|
|
{"thought": true, "text": "reasoning..."},
|
|
{"text": "final answer"}
|
|
]
|
|
}
|
|
}]
|
|
}
|
|
}"#;
|
|
|
|
let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
|
|
let effective = response.into_effective_response();
|
|
let candidate = effective.candidates.unwrap().into_iter().next().unwrap();
|
|
let text = candidate.content.unwrap().effective_text();
|
|
assert_eq!(text, Some("final answer".to_string()));
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn warmup_without_key_is_noop() {
|
|
let provider = test_provider(None);
|
|
let result = provider.warmup().await;
|
|
assert!(result.is_ok());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn warmup_oauth_is_noop() {
|
|
let provider = test_provider(Some(test_oauth_auth("ya29.mock-token")));
|
|
let result = provider.warmup().await;
|
|
assert!(result.is_ok());
|
|
}
|
|
|
|
#[test]
|
|
fn discover_oauth_cred_paths_does_not_panic() {
|
|
let _paths = GeminiProvider::discover_oauth_cred_paths();
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn rotate_oauth_without_alternatives_returns_false() {
|
|
let state = Arc::new(tokio::sync::Mutex::new(OAuthTokenState {
|
|
access_token: "ya29.mock".to_string(),
|
|
refresh_token: None,
|
|
client_id: None,
|
|
client_secret: None,
|
|
expiry_millis: None,
|
|
}));
|
|
let provider = test_provider(Some(GeminiAuth::OAuthToken(state.clone())));
|
|
assert!(!provider.rotate_oauth_credential(&state).await);
|
|
}
|
|
|
|
#[test]
|
|
fn response_parses_usage_metadata() {
|
|
let json = r#"{
|
|
"candidates": [{"content": {"parts": [{"text": "Hello"}]}}],
|
|
"usageMetadata": {"promptTokenCount": 120, "candidatesTokenCount": 40}
|
|
}"#;
|
|
let resp: GenerateContentResponse = serde_json::from_str(json).unwrap();
|
|
let usage = resp.usage_metadata.unwrap();
|
|
assert_eq!(usage.prompt_token_count, Some(120));
|
|
assert_eq!(usage.candidates_token_count, Some(40));
|
|
}
|
|
|
|
#[test]
|
|
fn response_parses_without_usage_metadata() {
|
|
let json = r#"{"candidates": [{"content": {"parts": [{"text": "Hello"}]}}]}"#;
|
|
let resp: GenerateContentResponse = serde_json::from_str(json).unwrap();
|
|
assert!(resp.usage_metadata.is_none());
|
|
}
|
|
|
|
/// Validates that warmup() for ManagedOAuth requires auth_service.
|
|
#[tokio::test]
|
|
async fn warmup_managed_oauth_requires_auth_service() {
|
|
let provider = GeminiProvider {
|
|
auth: Some(GeminiAuth::ManagedOAuth),
|
|
oauth_project: Arc::new(tokio::sync::Mutex::new(None)),
|
|
oauth_cred_paths: Vec::new(),
|
|
oauth_index: Arc::new(tokio::sync::Mutex::new(0)),
|
|
auth_service: None, // Missing auth_service
|
|
auth_profile_override: None,
|
|
};
|
|
|
|
let result = provider.warmup().await;
|
|
assert!(result.is_err());
|
|
assert!(result
|
|
.unwrap_err()
|
|
.to_string()
|
|
.contains("ManagedOAuth requires auth_service"));
|
|
}
|
|
|
|
/// Validates that warmup() for CLI OAuth skips validation (existing behavior).
|
|
#[tokio::test]
|
|
async fn warmup_cli_oauth_skips_validation() {
|
|
let provider = test_provider(Some(test_oauth_auth("fake_token")));
|
|
let result = provider.warmup().await;
|
|
// Should succeed without making HTTP requests
|
|
assert!(result.is_ok());
|
|
}
|
|
}
|