fix(multimodal): optimize image markers for prompt budget

This commit is contained in:
argenis de la rosa 2026-02-26 22:43:26 -05:00 committed by Argenis
parent 34852919da
commit cd26886f15

View File

@ -2,9 +2,12 @@ use crate::config::{build_runtime_proxy_client_with_timeouts, MultimodalConfig};
use crate::providers::ChatMessage;
use base64::{engine::general_purpose::STANDARD, Engine as _};
use reqwest::Client;
use std::io::Cursor;
use std::path::Path;
const IMAGE_MARKER_PREFIX: &str = "[IMAGE:";
const OPTIMIZED_IMAGE_MAX_DIMENSION: u32 = 512;
const OPTIMIZED_IMAGE_TARGET_BYTES: usize = 256 * 1024;
const ALLOWED_IMAGE_MIME_TYPES: &[&str] = &[
"image/png",
"image/jpeg",
@ -198,7 +201,7 @@ async fn normalize_image_reference(
remote_client: &Client,
) -> anyhow::Result<String> {
if source.starts_with("data:") {
return normalize_data_uri(source, max_bytes);
return normalize_data_uri(source, max_bytes).await;
}
if source.starts_with("http://") || source.starts_with("https://") {
@ -215,7 +218,7 @@ async fn normalize_image_reference(
normalize_local_image(source, max_bytes).await
}
fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result<String> {
async fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result<String> {
let Some(comma_idx) = source.find(',') else {
return Err(MultimodalError::InvalidMarker {
input: source.to_string(),
@ -252,9 +255,14 @@ fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result<String>
reason: format!("invalid base64 payload: {error}"),
})?;
validate_size(source, decoded.len(), max_bytes)?;
let (optimized_bytes, optimized_mime) =
optimize_image_for_prompt(source, decoded, &mime).await?;
validate_size(source, optimized_bytes.len(), max_bytes)?;
Ok(format!("data:{mime};base64,{}", STANDARD.encode(decoded)))
Ok(format!(
"data:{optimized_mime};base64,{}",
STANDARD.encode(optimized_bytes)
))
}
async fn normalize_remote_image(
@ -307,8 +315,14 @@ async fn normalize_remote_image(
})?;
validate_mime(source, &mime)?;
let (optimized_bytes, optimized_mime) =
optimize_image_for_prompt(source, bytes.to_vec(), &mime).await?;
validate_size(source, optimized_bytes.len(), max_bytes)?;
Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes)))
Ok(format!(
"data:{optimized_mime};base64,{}",
STANDARD.encode(optimized_bytes)
))
}
async fn normalize_local_image(source: &str, max_bytes: usize) -> anyhow::Result<String> {
@ -350,8 +364,78 @@ async fn normalize_local_image(source: &str, max_bytes: usize) -> anyhow::Result
})?;
validate_mime(source, &mime)?;
let (optimized_bytes, optimized_mime) = optimize_image_for_prompt(source, bytes, &mime).await?;
validate_size(source, optimized_bytes.len(), max_bytes)?;
Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes)))
Ok(format!(
"data:{optimized_mime};base64,{}",
STANDARD.encode(optimized_bytes)
))
}
async fn optimize_image_for_prompt(
source: &str,
bytes: Vec<u8>,
mime: &str,
) -> anyhow::Result<(Vec<u8>, String)> {
validate_mime(source, mime)?;
let source_owned = source.to_string();
let mime_owned = mime.to_string();
tokio::task::spawn_blocking(move || {
optimize_image_for_prompt_blocking(source_owned, bytes, mime_owned)
})
.await
.map_err(|error| MultimodalError::InvalidMarker {
input: source.to_string(),
reason: format!("failed to optimize image payload: {error}"),
})?
}
/// Synchronously downsizes and re-encodes an image so it fits the prompt budget.
///
/// Steps:
/// 1. Decode `bytes`; payloads that cannot be decoded are returned unchanged.
/// 2. If either dimension exceeds `OPTIMIZED_IMAGE_MAX_DIMENSION`, scale the
///    image down (aspect ratio preserved) to fit within that bound.
/// 3. Encode as JPEG at decreasing quality (85, 70, 55, 40) and return the
///    first encoding that is at most `OPTIMIZED_IMAGE_TARGET_BYTES`.
/// 4. If no quality reaches the target, return the last JPEG when it is
///    smaller than the original, otherwise the original bytes.
///
/// An image that needed no downscaling and already fits the byte target is
/// returned untouched unless the JPEG re-encode is strictly smaller, so this
/// function never grows a payload that was already within budget.
///
/// Returns the chosen bytes and the MIME type describing them (`image/jpeg`
/// for re-encoded output, the caller-supplied `mime` for pass-through).
///
/// # Errors
///
/// Returns `MultimodalError::InvalidMarker` when JPEG encoding fails.
fn optimize_image_for_prompt_blocking(
    source: String,
    bytes: Vec<u8>,
    mime: String,
) -> anyhow::Result<(Vec<u8>, String)> {
    // Best effort: anything the decoder cannot read is passed through as-is.
    let decoded = match image::load_from_memory(&bytes) {
        Ok(decoded) => decoded,
        Err(_) => return Ok((bytes, mime)),
    };

    let needs_downscale = decoded.width() > OPTIMIZED_IMAGE_MAX_DIMENSION
        || decoded.height() > OPTIMIZED_IMAGE_MAX_DIMENSION;
    let resized = if needs_downscale {
        decoded.thumbnail(OPTIMIZED_IMAGE_MAX_DIMENSION, OPTIMIZED_IMAGE_MAX_DIMENSION)
    } else {
        decoded
    };

    // When the original already satisfies both limits, a re-encode is only
    // worthwhile if it actually saves bytes.
    let original_within_budget = !needs_downscale && bytes.len() <= OPTIMIZED_IMAGE_TARGET_BYTES;

    // JPEG carries no alpha channel, so convert to 8-bit RGB once up front.
    // This keeps the encoder on a pixel layout it always supports instead of
    // relying on how it treats RGBA or 16-bit input.
    let rgb = resized.into_rgb8();

    let mut best_jpeg = Vec::new();
    for quality in [85_u8, 70_u8, 55_u8, 40_u8] {
        let mut encoded = Vec::new();
        {
            let mut cursor = Cursor::new(&mut encoded);
            let mut encoder =
                image::codecs::jpeg::JpegEncoder::new_with_quality(&mut cursor, quality);
            encoder
                .encode_image(&rgb)
                .map_err(|error| MultimodalError::InvalidMarker {
                    input: source.clone(),
                    reason: format!("failed to encode optimized image: {error}"),
                })?;
        }

        // Re-encoding an in-budget original did not help: keep the original
        // rather than trading quality for a larger (or equal) payload.
        if original_within_budget && encoded.len() >= bytes.len() {
            return Ok((bytes, mime));
        }

        best_jpeg = encoded;
        if best_jpeg.len() <= OPTIMIZED_IMAGE_TARGET_BYTES {
            return Ok((best_jpeg, "image/jpeg".to_string()));
        }
    }

    // No quality level hit the target; still prefer whichever payload is smaller.
    if best_jpeg.len() < bytes.len() {
        return Ok((best_jpeg, "image/jpeg".to_string()));
    }
    Ok((bytes, mime))
}
fn validate_size(source: &str, size_bytes: usize, max_bytes: usize) -> anyhow::Result<()> {
@ -560,6 +644,43 @@ mod tests {
.contains("multimodal image size limit exceeded"));
}
/// A large PNG passed as a data URI must come back as a smaller JPEG whose
/// dimensions respect `OPTIMIZED_IMAGE_MAX_DIMENSION`.
#[tokio::test]
async fn normalize_data_uri_downscales_large_images_for_prompt_budget() {
    // Build an 1800x1200 image with a varying per-pixel pattern.
    let canvas = image::RgbImage::from_fn(1800, 1200, |x, y| {
        image::Rgb([(x % 251) as u8, (y % 241) as u8, ((x + y) % 239) as u8])
    });

    let mut png_bytes = Vec::new();
    image::DynamicImage::ImageRgb8(canvas)
        .write_to(
            &mut std::io::Cursor::new(&mut png_bytes),
            image::ImageFormat::Png,
        )
        .unwrap();
    let original_size = png_bytes.len();

    let source = format!("data:image/png;base64,{}", STANDARD.encode(&png_bytes));
    let optimized = normalize_data_uri(&source, 5 * 1024 * 1024)
        .await
        .expect("data uri should normalize");

    // The optimizer is expected to switch the payload to JPEG.
    assert!(optimized.starts_with("data:image/jpeg;base64,"));

    let (_, payload) = optimized
        .split_once(',')
        .expect("optimized data URI payload");
    let optimized_bytes = STANDARD.decode(payload).expect("base64 decode");
    assert!(
        optimized_bytes.len() < original_size,
        "optimized bytes should be smaller than original PNG payload"
    );

    // The re-encoded image must respect the dimension cap on both axes.
    let optimized_image = image::load_from_memory(&optimized_bytes).expect("decode optimized");
    assert!(optimized_image.width() <= OPTIMIZED_IMAGE_MAX_DIMENSION);
    assert!(optimized_image.height() <= OPTIMIZED_IMAGE_MAX_DIMENSION);
}
#[test]
fn extract_ollama_image_payload_supports_data_uris() {
let payload = extract_ollama_image_payload("data:image/png;base64,abcd==")