diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index b3b7110a9..49d10696a 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -72,6 +72,8 @@ struct GenerateContentRequest { #[derive(Debug, Serialize)] struct InternalGenerateContentEnvelope { model: String, + generation_config: InternalGenerationConfig, + contents: Vec, #[serde(skip_serializing_if = "Option::is_none")] project: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -108,6 +110,12 @@ struct GenerationConfig { max_output_tokens: u32, } +#[derive(Debug, Serialize, Clone)] +struct InternalGenerationConfig { + temperature: f64, + max_output_tokens: u32, +} + #[derive(Debug, Deserialize)] struct GenerateContentResponse { candidates: Option>, @@ -316,16 +324,37 @@ impl GeminiProvider { let req = self.http_client().post(url).json(request); match auth { GeminiAuth::OAuthToken(token) => { - // cloudcode-pa expects an outer envelope with `request`. - let internal_request = InternalGenerateContentEnvelope { - model: Self::format_internal_model_name(model), - project: None, - user_prompt_id: None, - request: InternalGenerateContentRequest { - contents: request.contents.clone(), - system_instruction: request.system_instruction.clone(), - generation_config: request.generation_config.clone(), + // Internal API expects the model in the request body envelope + let internal_request = InternalGenerateContentRequest { + model: Self::format_model_name(model), + generation_config: InternalGenerationConfig { + temperature: request.generation_config.temperature, + max_output_tokens: request.generation_config.max_output_tokens, }, + contents: request + .contents + .iter() + .map(|c| Content { + role: c.role.clone(), + parts: c + .parts + .iter() + .map(|p| Part { + text: p.text.clone(), + }) + .collect(), + }) + .collect(), + system_instruction: request.system_instruction.as_ref().map(|si| Content { + role: si.role.clone(), + parts: si + .parts + .iter() + .map(|p| Part { + text: p.text.clone(), + }) + .collect(), + }), }; self.http_client() .post(url) @@ -747,16 +776,16 @@ mod tests { #[test] fn internal_request_includes_model() { - let request = InternalGenerateContentEnvelope { - model: "gemini-test-model".to_string(), - project: None, - user_prompt_id: None, - request: InternalGenerateContentRequest { - contents: vec![Content { - role: Some("user".to_string()), - parts: vec![Part { - text: "Hello".to_string(), - }], + let request = InternalGenerateContentRequest { + model: "models/gemini-3-pro-preview".to_string(), + generation_config: InternalGenerationConfig { + temperature: 0.7, + max_output_tokens: 8192, + }, + contents: vec![Content { + role: Some("user".to_string()), + parts: vec![Part { + text: "Hello".to_string(), }], system_instruction: None, generation_config: GenerationConfig { @@ -766,11 +795,12 @@ mod tests { }, }; - let json: serde_json::Value = serde_json::to_value(&request).unwrap(); - assert_eq!(json["model"], "gemini-test-model"); - assert!(json.get("generationConfig").is_none()); - assert!(json["request"].get("generationConfig").is_some()); - assert_eq!(json["request"]["contents"][0]["role"], "user"); + let json = serde_json::to_string(&request).unwrap(); + assert!(json.contains("\"model\":\"models/gemini-3-pro-preview\"")); + assert!(json.contains("\"generation_config\"")); + assert!(json.contains("\"max_output_tokens\":8192")); + assert!(json.contains("\"role\":\"user\"")); + assert!(json.contains("\"temperature\":0.7")); } #[test]