diff --git a/src/multimodal.rs b/src/multimodal.rs index 7182df7a8..50722dc6b 100644 --- a/src/multimodal.rs +++ b/src/multimodal.rs @@ -2,9 +2,12 @@ use crate::config::{build_runtime_proxy_client_with_timeouts, MultimodalConfig}; use crate::providers::ChatMessage; use base64::{engine::general_purpose::STANDARD, Engine as _}; use reqwest::Client; +use std::io::Cursor; use std::path::Path; const IMAGE_MARKER_PREFIX: &str = "[IMAGE:"; +const OPTIMIZED_IMAGE_MAX_DIMENSION: u32 = 512; +const OPTIMIZED_IMAGE_TARGET_BYTES: usize = 256 * 1024; const ALLOWED_IMAGE_MIME_TYPES: &[&str] = &[ "image/png", "image/jpeg", @@ -198,7 +201,7 @@ async fn normalize_image_reference( remote_client: &Client, ) -> anyhow::Result { if source.starts_with("data:") { - return normalize_data_uri(source, max_bytes); + return normalize_data_uri(source, max_bytes).await; } if source.starts_with("http://") || source.starts_with("https://") { @@ -215,7 +218,7 @@ async fn normalize_image_reference( normalize_local_image(source, max_bytes).await } -fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result { +async fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result { let Some(comma_idx) = source.find(',') else { return Err(MultimodalError::InvalidMarker { input: source.to_string(), @@ -252,9 +255,14 @@ fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result reason: format!("invalid base64 payload: {error}"), })?; - validate_size(source, decoded.len(), max_bytes)?; + let (optimized_bytes, optimized_mime) = + optimize_image_for_prompt(source, decoded, &mime).await?; + validate_size(source, optimized_bytes.len(), max_bytes)?; - Ok(format!("data:{mime};base64,{}", STANDARD.encode(decoded))) + Ok(format!( + "data:{optimized_mime};base64,{}", + STANDARD.encode(optimized_bytes) + )) } async fn normalize_remote_image( @@ -307,8 +315,14 @@ async fn normalize_remote_image( })?; validate_mime(source, &mime)?; + let (optimized_bytes, optimized_mime) = + optimize_image_for_prompt(source, bytes.to_vec(), &mime).await?; + validate_size(source, optimized_bytes.len(), max_bytes)?; - Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes))) + Ok(format!( + "data:{optimized_mime};base64,{}", + STANDARD.encode(optimized_bytes) + )) } async fn normalize_local_image(source: &str, max_bytes: usize) -> anyhow::Result { @@ -350,8 +364,78 @@ async fn normalize_local_image(source: &str, max_bytes: usize) -> anyhow::Result })?; validate_mime(source, &mime)?; + let (optimized_bytes, optimized_mime) = optimize_image_for_prompt(source, bytes, &mime).await?; + validate_size(source, optimized_bytes.len(), max_bytes)?; - Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes))) + Ok(format!( + "data:{optimized_mime};base64,{}", + STANDARD.encode(optimized_bytes) + )) +} + +async fn optimize_image_for_prompt( + source: &str, + bytes: Vec, + mime: &str, +) -> anyhow::Result<(Vec, String)> { + validate_mime(source, mime)?; + + let source_owned = source.to_string(); + let mime_owned = mime.to_string(); + tokio::task::spawn_blocking(move || { + optimize_image_for_prompt_blocking(source_owned, bytes, mime_owned) + }) + .await + .map_err(|error| MultimodalError::InvalidMarker { + input: source.to_string(), + reason: format!("failed to optimize image payload: {error}"), + })? +} + +fn optimize_image_for_prompt_blocking( + source: String, + bytes: Vec, + mime: String, +) -> anyhow::Result<(Vec, String)> { + let decoded = match image::load_from_memory(&bytes) { + Ok(decoded) => decoded, + Err(_) => return Ok((bytes, mime)), + }; + + let resized = if decoded.width() > OPTIMIZED_IMAGE_MAX_DIMENSION + || decoded.height() > OPTIMIZED_IMAGE_MAX_DIMENSION + { + decoded.thumbnail(OPTIMIZED_IMAGE_MAX_DIMENSION, OPTIMIZED_IMAGE_MAX_DIMENSION) + } else { + decoded + }; + + let mut best_jpeg = Vec::new(); + for quality in [85_u8, 70_u8, 55_u8, 40_u8] { + let mut encoded = Vec::new(); + { + let mut cursor = Cursor::new(&mut encoded); + let mut encoder = + image::codecs::jpeg::JpegEncoder::new_with_quality(&mut cursor, quality); + encoder + .encode_image(&resized) + .map_err(|error| MultimodalError::InvalidMarker { + input: source.clone(), + reason: format!("failed to encode optimized image: {error}"), + })?; + } + + best_jpeg = encoded; + if best_jpeg.len() <= OPTIMIZED_IMAGE_TARGET_BYTES { + return Ok((best_jpeg, "image/jpeg".to_string())); + } + } + + if best_jpeg.len() < bytes.len() { + return Ok((best_jpeg, "image/jpeg".to_string())); + } + + Ok((bytes, mime)) } fn validate_size(source: &str, size_bytes: usize, max_bytes: usize) -> anyhow::Result<()> { @@ -560,6 +644,43 @@ mod tests { .contains("multimodal image size limit exceeded")); } + #[tokio::test] + async fn normalize_data_uri_downscales_large_images_for_prompt_budget() { + let mut image = image::RgbImage::new(1800, 1200); + for (x, y, pixel) in image.enumerate_pixels_mut() { + *pixel = image::Rgb([(x % 251) as u8, (y % 241) as u8, ((x + y) % 239) as u8]); + } + + let mut png_bytes = Vec::new(); + image::DynamicImage::ImageRgb8(image) + .write_to( + &mut std::io::Cursor::new(&mut png_bytes), + image::ImageFormat::Png, + ) + .unwrap(); + let original_size = png_bytes.len(); + + let source = format!("data:image/png;base64,{}", STANDARD.encode(&png_bytes)); + let optimized = normalize_data_uri(&source, 5 * 1024 * 1024) + .await + .expect("data uri should normalize"); + assert!(optimized.starts_with("data:image/jpeg;base64,")); + + let payload = optimized + .split_once(',') + .map(|(_, payload)| payload) + .expect("optimized data URI payload"); + let optimized_bytes = STANDARD.decode(payload).expect("base64 decode"); + assert!( + optimized_bytes.len() < original_size, + "optimized bytes should be smaller than original PNG payload" + ); + + let optimized_image = image::load_from_memory(&optimized_bytes).expect("decode optimized"); + assert!(optimized_image.width() <= OPTIMIZED_IMAGE_MAX_DIMENSION); + assert!(optimized_image.height() <= OPTIMIZED_IMAGE_MAX_DIMENSION); + } + #[test] fn extract_ollama_image_payload_supports_data_uris() { let payload = extract_ollama_image_payload("data:image/png;base64,abcd==")