fix(multimodal): optimize image markers for prompt budget
This commit is contained in:
parent
34852919da
commit
cd26886f15
@ -2,9 +2,12 @@ use crate::config::{build_runtime_proxy_client_with_timeouts, MultimodalConfig};
|
||||
use crate::providers::ChatMessage;
|
||||
use base64::{engine::general_purpose::STANDARD, Engine as _};
|
||||
use reqwest::Client;
|
||||
use std::io::Cursor;
|
||||
use std::path::Path;
|
||||
|
||||
const IMAGE_MARKER_PREFIX: &str = "[IMAGE:";
|
||||
const OPTIMIZED_IMAGE_MAX_DIMENSION: u32 = 512;
|
||||
const OPTIMIZED_IMAGE_TARGET_BYTES: usize = 256 * 1024;
|
||||
const ALLOWED_IMAGE_MIME_TYPES: &[&str] = &[
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
@ -198,7 +201,7 @@ async fn normalize_image_reference(
|
||||
remote_client: &Client,
|
||||
) -> anyhow::Result<String> {
|
||||
if source.starts_with("data:") {
|
||||
return normalize_data_uri(source, max_bytes);
|
||||
return normalize_data_uri(source, max_bytes).await;
|
||||
}
|
||||
|
||||
if source.starts_with("http://") || source.starts_with("https://") {
|
||||
@ -215,7 +218,7 @@ async fn normalize_image_reference(
|
||||
normalize_local_image(source, max_bytes).await
|
||||
}
|
||||
|
||||
fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result<String> {
|
||||
async fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result<String> {
|
||||
let Some(comma_idx) = source.find(',') else {
|
||||
return Err(MultimodalError::InvalidMarker {
|
||||
input: source.to_string(),
|
||||
@ -252,9 +255,14 @@ fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result<String>
|
||||
reason: format!("invalid base64 payload: {error}"),
|
||||
})?;
|
||||
|
||||
validate_size(source, decoded.len(), max_bytes)?;
|
||||
let (optimized_bytes, optimized_mime) =
|
||||
optimize_image_for_prompt(source, decoded, &mime).await?;
|
||||
validate_size(source, optimized_bytes.len(), max_bytes)?;
|
||||
|
||||
Ok(format!("data:{mime};base64,{}", STANDARD.encode(decoded)))
|
||||
Ok(format!(
|
||||
"data:{optimized_mime};base64,{}",
|
||||
STANDARD.encode(optimized_bytes)
|
||||
))
|
||||
}
|
||||
|
||||
async fn normalize_remote_image(
|
||||
@ -307,8 +315,14 @@ async fn normalize_remote_image(
|
||||
})?;
|
||||
|
||||
validate_mime(source, &mime)?;
|
||||
let (optimized_bytes, optimized_mime) =
|
||||
optimize_image_for_prompt(source, bytes.to_vec(), &mime).await?;
|
||||
validate_size(source, optimized_bytes.len(), max_bytes)?;
|
||||
|
||||
Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes)))
|
||||
Ok(format!(
|
||||
"data:{optimized_mime};base64,{}",
|
||||
STANDARD.encode(optimized_bytes)
|
||||
))
|
||||
}
|
||||
|
||||
async fn normalize_local_image(source: &str, max_bytes: usize) -> anyhow::Result<String> {
|
||||
@ -350,8 +364,78 @@ async fn normalize_local_image(source: &str, max_bytes: usize) -> anyhow::Result
|
||||
})?;
|
||||
|
||||
validate_mime(source, &mime)?;
|
||||
let (optimized_bytes, optimized_mime) = optimize_image_for_prompt(source, bytes, &mime).await?;
|
||||
validate_size(source, optimized_bytes.len(), max_bytes)?;
|
||||
|
||||
Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes)))
|
||||
Ok(format!(
|
||||
"data:{optimized_mime};base64,{}",
|
||||
STANDARD.encode(optimized_bytes)
|
||||
))
|
||||
}
|
||||
|
||||
async fn optimize_image_for_prompt(
|
||||
source: &str,
|
||||
bytes: Vec<u8>,
|
||||
mime: &str,
|
||||
) -> anyhow::Result<(Vec<u8>, String)> {
|
||||
validate_mime(source, mime)?;
|
||||
|
||||
let source_owned = source.to_string();
|
||||
let mime_owned = mime.to_string();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
optimize_image_for_prompt_blocking(source_owned, bytes, mime_owned)
|
||||
})
|
||||
.await
|
||||
.map_err(|error| MultimodalError::InvalidMarker {
|
||||
input: source.to_string(),
|
||||
reason: format!("failed to optimize image payload: {error}"),
|
||||
})?
|
||||
}
|
||||
|
||||
fn optimize_image_for_prompt_blocking(
|
||||
source: String,
|
||||
bytes: Vec<u8>,
|
||||
mime: String,
|
||||
) -> anyhow::Result<(Vec<u8>, String)> {
|
||||
let decoded = match image::load_from_memory(&bytes) {
|
||||
Ok(decoded) => decoded,
|
||||
Err(_) => return Ok((bytes, mime)),
|
||||
};
|
||||
|
||||
let resized = if decoded.width() > OPTIMIZED_IMAGE_MAX_DIMENSION
|
||||
|| decoded.height() > OPTIMIZED_IMAGE_MAX_DIMENSION
|
||||
{
|
||||
decoded.thumbnail(OPTIMIZED_IMAGE_MAX_DIMENSION, OPTIMIZED_IMAGE_MAX_DIMENSION)
|
||||
} else {
|
||||
decoded
|
||||
};
|
||||
|
||||
let mut best_jpeg = Vec::new();
|
||||
for quality in [85_u8, 70_u8, 55_u8, 40_u8] {
|
||||
let mut encoded = Vec::new();
|
||||
{
|
||||
let mut cursor = Cursor::new(&mut encoded);
|
||||
let mut encoder =
|
||||
image::codecs::jpeg::JpegEncoder::new_with_quality(&mut cursor, quality);
|
||||
encoder
|
||||
.encode_image(&resized)
|
||||
.map_err(|error| MultimodalError::InvalidMarker {
|
||||
input: source.clone(),
|
||||
reason: format!("failed to encode optimized image: {error}"),
|
||||
})?;
|
||||
}
|
||||
|
||||
best_jpeg = encoded;
|
||||
if best_jpeg.len() <= OPTIMIZED_IMAGE_TARGET_BYTES {
|
||||
return Ok((best_jpeg, "image/jpeg".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
if best_jpeg.len() < bytes.len() {
|
||||
return Ok((best_jpeg, "image/jpeg".to_string()));
|
||||
}
|
||||
|
||||
Ok((bytes, mime))
|
||||
}
|
||||
|
||||
fn validate_size(source: &str, size_bytes: usize, max_bytes: usize) -> anyhow::Result<()> {
|
||||
@ -560,6 +644,43 @@ mod tests {
|
||||
.contains("multimodal image size limit exceeded"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn normalize_data_uri_downscales_large_images_for_prompt_budget() {
|
||||
let mut image = image::RgbImage::new(1800, 1200);
|
||||
for (x, y, pixel) in image.enumerate_pixels_mut() {
|
||||
*pixel = image::Rgb([(x % 251) as u8, (y % 241) as u8, ((x + y) % 239) as u8]);
|
||||
}
|
||||
|
||||
let mut png_bytes = Vec::new();
|
||||
image::DynamicImage::ImageRgb8(image)
|
||||
.write_to(
|
||||
&mut std::io::Cursor::new(&mut png_bytes),
|
||||
image::ImageFormat::Png,
|
||||
)
|
||||
.unwrap();
|
||||
let original_size = png_bytes.len();
|
||||
|
||||
let source = format!("data:image/png;base64,{}", STANDARD.encode(&png_bytes));
|
||||
let optimized = normalize_data_uri(&source, 5 * 1024 * 1024)
|
||||
.await
|
||||
.expect("data uri should normalize");
|
||||
assert!(optimized.starts_with("data:image/jpeg;base64,"));
|
||||
|
||||
let payload = optimized
|
||||
.split_once(',')
|
||||
.map(|(_, payload)| payload)
|
||||
.expect("optimized data URI payload");
|
||||
let optimized_bytes = STANDARD.decode(payload).expect("base64 decode");
|
||||
assert!(
|
||||
optimized_bytes.len() < original_size,
|
||||
"optimized bytes should be smaller than original PNG payload"
|
||||
);
|
||||
|
||||
let optimized_image = image::load_from_memory(&optimized_bytes).expect("decode optimized");
|
||||
assert!(optimized_image.width() <= OPTIMIZED_IMAGE_MAX_DIMENSION);
|
||||
assert!(optimized_image.height() <= OPTIMIZED_IMAGE_MAX_DIMENSION);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_ollama_image_payload_supports_data_uris() {
|
||||
let payload = extract_ollama_image_payload("data:image/png;base64,abcd==")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user