diff --git a/docs/config-reference.md b/docs/config-reference.md index f18a09bcc..22d4dbde4 100644 --- a/docs/config-reference.md +++ b/docs/config-reference.md @@ -26,6 +26,14 @@ Schema export command: | `provider_api` | unset | Optional API mode for `custom:` providers: `openai-chat-completions` or `openai-responses` | | `default_model` | `anthropic/claude-sonnet-4-6` | model routed through selected provider | | `default_temperature` | `0.7` | model temperature | +| `model_support_vision` | unset (`None`) | Vision support override for active provider/model | + +Notes: + +- `model_support_vision = true` forces vision support on (e.g. Ollama running `llava`). +- `model_support_vision = false` forces vision support off. +- Unset keeps the provider's built-in default. +- Environment override: `ZEROCLAW_MODEL_SUPPORT_VISION` or `MODEL_SUPPORT_VISION` (values: `true`/`false`/`1`/`0`/`yes`/`no`/`on`/`off`). ## `[observability]` diff --git a/docs/i18n/vi/config-reference.md b/docs/i18n/vi/config-reference.md index e49c8489e..f3bb0ca22 100644 --- a/docs/i18n/vi/config-reference.md +++ b/docs/i18n/vi/config-reference.md @@ -25,6 +25,14 @@ Lệnh xuất schema: | `default_provider` | `openrouter` | ID hoặc bí danh provider | | `default_model` | `anthropic/claude-sonnet-4-6` | Model định tuyến qua provider đã chọn | | `default_temperature` | `0.7` | Nhiệt độ model | +| `model_support_vision` | chưa đặt (`None`) | Ghi đè hỗ trợ vision cho provider/model đang dùng | + +Lưu ý: + +- `model_support_vision = true` bật vision (ví dụ Ollama chạy `llava`). +- `model_support_vision = false` tắt vision. +- Để trống giữ mặc định của provider. +- Biến môi trường: `ZEROCLAW_MODEL_SUPPORT_VISION` hoặc `MODEL_SUPPORT_VISION` (giá trị: `true`/`false`/`1`/`0`/`yes`/`no`/`on`/`off`). ## `[observability]` diff --git a/docs/i18n/vi/providers-reference.md b/docs/i18n/vi/providers-reference.md index 00ac11584..1d1454bbd 100644 --- a/docs/i18n/vi/providers-reference.md +++ b/docs/i18n/vi/providers-reference.md @@ -94,6 +94,25 @@ Hành vi: - `true`: gửi `think: true`. - Không đặt: bỏ qua `think` và giữ nguyên mặc định của Ollama/model. +### Ghi đè Vision cho Ollama + +Một số model Ollama hỗ trợ vision (ví dụ `llava`, `llama3.2-vision`) trong khi các model khác thì không. +Vì ZeroClaw không thể tự động phát hiện, bạn có thể ghi đè trong `config.toml`: + +```toml +default_provider = "ollama" +default_model = "llava" +model_support_vision = true +``` + +Hành vi: + +- `true`: bật xử lý hình ảnh đính kèm trong vòng lặp agent. +- `false`: tắt vision ngay cả khi provider báo hỗ trợ. +- Không đặt: dùng mặc định của provider. + +Biến môi trường: `ZEROCLAW_MODEL_SUPPORT_VISION=true` + ### Ghi chú về Kimi Code - Provider ID: `kimi-code` diff --git a/docs/providers-reference.md b/docs/providers-reference.md index 7f156558e..b2990501c 100644 --- a/docs/providers-reference.md +++ b/docs/providers-reference.md @@ -152,6 +152,25 @@ Behavior: - `true`: sends `think: true`. - Unset: omits `think` and keeps Ollama/model defaults. +### Ollama Vision Override + +Some Ollama models support vision (e.g. `llava`, `llama3.2-vision`) while others do not. +Since ZeroClaw cannot auto-detect this, you can override it in `config.toml`: + +```toml +default_provider = "ollama" +default_model = "llava" +model_support_vision = true +``` + +Behavior: + +- `true`: enables image attachment processing in the agent loop. +- `false`: disables vision even if the provider reports support. +- Unset: uses the provider's built-in default. + +Environment override: `ZEROCLAW_MODEL_SUPPORT_VISION=true` + ### Kimi Code Notes - Provider ID: `kimi-code` diff --git a/docs/vi/config-reference.md b/docs/vi/config-reference.md index e49c8489e..f3bb0ca22 100644 --- a/docs/vi/config-reference.md +++ b/docs/vi/config-reference.md @@ -25,6 +25,14 @@ Lệnh xuất schema: | `default_provider` | `openrouter` | ID hoặc bí danh provider | | `default_model` | `anthropic/claude-sonnet-4-6` | Model định tuyến qua provider đã chọn | | `default_temperature` | `0.7` | Nhiệt độ model | +| `model_support_vision` | chưa đặt (`None`) | Ghi đè hỗ trợ vision cho provider/model đang dùng | + +Lưu ý: + +- `model_support_vision = true` bật vision (ví dụ Ollama chạy `llava`). +- `model_support_vision = false` tắt vision. +- Để trống giữ mặc định của provider. +- Biến môi trường: `ZEROCLAW_MODEL_SUPPORT_VISION` hoặc `MODEL_SUPPORT_VISION` (giá trị: `true`/`false`/`1`/`0`/`yes`/`no`/`on`/`off`). ## `[observability]` diff --git a/docs/vi/providers-reference.md b/docs/vi/providers-reference.md index 00ac11584..1d1454bbd 100644 --- a/docs/vi/providers-reference.md +++ b/docs/vi/providers-reference.md @@ -94,6 +94,25 @@ Hành vi: - `true`: gửi `think: true`. - Không đặt: bỏ qua `think` và giữ nguyên mặc định của Ollama/model. +### Ghi đè Vision cho Ollama + +Một số model Ollama hỗ trợ vision (ví dụ `llava`, `llama3.2-vision`) trong khi các model khác thì không. +Vì ZeroClaw không thể tự động phát hiện, bạn có thể ghi đè trong `config.toml`: + +```toml +default_provider = "ollama" +default_model = "llava" +model_support_vision = true +``` + +Hành vi: + +- `true`: bật xử lý hình ảnh đính kèm trong vòng lặp agent. +- `false`: tắt vision ngay cả khi provider báo hỗ trợ. +- Không đặt: dùng mặc định của provider. + +Biến môi trường: `ZEROCLAW_MODEL_SUPPORT_VISION=true` + ### Ghi chú về Kimi Code - Provider ID: `kimi-code` diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index c5f751c30..a79fe8e04 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -1137,6 +1137,7 @@ pub async fn run( reasoning_enabled: config.runtime.reasoning_enabled, custom_provider_api_mode: config.provider_api.map(|mode| mode.as_compatible_mode()), max_tokens_override: None, + model_support_vision: config.model_support_vision, }; let provider: Box = providers::create_routed_provider_with_options( @@ -1598,6 +1599,7 @@ pub async fn process_message(config: Config, message: &str) -> Result { reasoning_enabled: config.runtime.reasoning_enabled, custom_provider_api_mode: config.provider_api.map(|mode| mode.as_compatible_mode()), max_tokens_override: None, + model_support_vision: config.model_support_vision, }; let provider: Box = providers::create_routed_provider_with_options( provider_name, diff --git a/src/channels/mod.rs b/src/channels/mod.rs index e9c4a5de4..5791ce3d7 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -413,6 +413,17 @@ fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> { - Keep normal text outside markers and never wrap markers in code fences.\n\ - Use tool results silently: answer the latest user message directly, and do not narrate delayed/internal tool execution bookkeeping.", ), + "whatsapp" => Some( + "When responding on WhatsApp:\n\ + - Use *bold* for emphasis (WhatsApp uses single asterisks).\n\ + - Be concise. No markdown headers (## etc.) — they don't render.\n\ + - No markdown tables — use bullet lists instead.\n\ + - For sending images, documents, videos, or audio files use markers: [IMAGE:], [DOCUMENT:], [VIDEO:], [AUDIO:]\n\ + - The path MUST be an absolute filesystem path to a local file (e.g. [IMAGE:/home/nicolas/.zeroclaw/workspace/images/chart.png]).\n\ + - Keep normal text outside markers and never wrap markers in code fences.\n\ + - You can combine text and media in one response — text is sent first, then each attachment.\n\ + - Use tool results silently: answer the latest user message directly, and do not narrate delayed/internal tool execution bookkeeping.", + ), _ => None, } } @@ -3047,6 +3058,7 @@ pub async fn start_channels(config: Config) -> Result<()> { reasoning_enabled: config.runtime.reasoning_enabled, custom_provider_api_mode: config.provider_api.map(|mode| mode.as_compatible_mode()), max_tokens_override: None, + model_support_vision: config.model_support_vision, }; let provider: Arc = Arc::from( create_resilient_provider_nonblocking( diff --git a/src/channels/whatsapp_web.rs b/src/channels/whatsapp_web.rs index a16ba4338..a794740e0 100644 --- a/src/channels/whatsapp_web.rs +++ b/src/channels/whatsapp_web.rs @@ -34,6 +34,127 @@ use parking_lot::Mutex; use std::sync::Arc; use tokio::select; +// ── Media attachment support ────────────────────────────────────────── + +/// Supported WhatsApp media attachment kinds. +#[cfg(feature = "whatsapp-web")] +#[derive(Debug, Clone, Copy)] +enum WaAttachmentKind { + Image, + Document, + Video, + Audio, +} + +#[cfg(feature = "whatsapp-web")] +impl WaAttachmentKind { + /// Parse from the marker prefix (case-insensitive). + fn from_marker(s: &str) -> Option { + match s.to_ascii_uppercase().as_str() { + "IMAGE" => Some(Self::Image), + "DOCUMENT" => Some(Self::Document), + "VIDEO" => Some(Self::Video), + "AUDIO" => Some(Self::Audio), + _ => None, + } + } + + /// Map to the wa-rs `MediaType` used for upload encryption. + fn media_type(self) -> wa_rs_core::download::MediaType { + match self { + Self::Image => wa_rs_core::download::MediaType::Image, + Self::Document => wa_rs_core::download::MediaType::Document, + Self::Video => wa_rs_core::download::MediaType::Video, + Self::Audio => wa_rs_core::download::MediaType::Audio, + } + } +} + +/// A parsed media attachment from `[KIND:path]` markers in the response text. +#[cfg(feature = "whatsapp-web")] +#[derive(Debug, Clone)] +struct WaAttachment { + kind: WaAttachmentKind, + target: String, +} + +/// Parse `[IMAGE:/path]`, `[DOCUMENT:/path]`, etc. markers out of a message. +/// Returns the cleaned text (markers removed) and a vec of attachments. +#[cfg(feature = "whatsapp-web")] +fn parse_wa_attachment_markers(message: &str) -> (String, Vec) { + let mut cleaned = String::with_capacity(message.len()); + let mut attachments = Vec::new(); + let mut cursor = 0; + + while cursor < message.len() { + let Some(open_rel) = message[cursor..].find('[') else { + cleaned.push_str(&message[cursor..]); + break; + }; + + let open = cursor + open_rel; + cleaned.push_str(&message[cursor..open]); + + let Some(close_rel) = message[open..].find(']') else { + cleaned.push_str(&message[open..]); + break; + }; + + let close = open + close_rel; + let marker = &message[open + 1..close]; + + let parsed = marker.split_once(':').and_then(|(kind, target)| { + let kind = WaAttachmentKind::from_marker(kind)?; + let target = target.trim(); + if target.is_empty() { + return None; + } + Some(WaAttachment { + kind, + target: target.to_string(), + }) + }); + + if let Some(attachment) = parsed { + attachments.push(attachment); + } else { + // Not a valid media marker — keep the original text. + cleaned.push_str(&message[open..=close]); + } + + cursor = close + 1; + } + + (cleaned.trim().to_string(), attachments) +} + +/// Infer MIME type from file extension. +#[cfg(feature = "whatsapp-web")] +fn mime_from_path(path: &std::path::Path) -> &'static str { + match path + .extension() + .and_then(|e| e.to_str()) + .unwrap_or("") + .to_ascii_lowercase() + .as_str() + { + "png" => "image/png", + "jpg" | "jpeg" => "image/jpeg", + "gif" => "image/gif", + "webp" => "image/webp", + "mp4" => "video/mp4", + "mov" => "video/quicktime", + "mp3" => "audio/mpeg", + "ogg" | "opus" => "audio/ogg", + "pdf" => "application/pdf", + "doc" => "application/msword", + "docx" => "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "xls" => "application/vnd.ms-excel", + "xlsx" => "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + _ => "application/octet-stream", + } +} + /// WhatsApp Web channel using wa-rs with custom rusqlite storage /// /// # Status: Functional Implementation @@ -233,6 +354,108 @@ impl WhatsAppWebChannel { Ok(wa_rs_binary::jid::Jid::pn(digits)) } + + /// Upload a file to WhatsApp media servers and send it as the appropriate message type. + #[cfg(feature = "whatsapp-web")] + async fn send_media_attachment( + &self, + client: &Arc, + to: &wa_rs_binary::jid::Jid, + attachment: &WaAttachment, + ) -> Result<()> { + use std::path::Path; + + let path = Path::new(&attachment.target); + if !path.exists() { + anyhow::bail!("Media file not found: {}", attachment.target); + } + + let data = tokio::fs::read(path).await?; + let file_len = data.len() as u64; + let mimetype = mime_from_path(path).to_string(); + + tracing::info!( + "WhatsApp Web: uploading {:?} ({} bytes, {})", + attachment.kind, + file_len, + mimetype + ); + + let upload = client.upload(data, attachment.kind.media_type()).await?; + + let outgoing = match attachment.kind { + WaAttachmentKind::Image => wa_rs_proto::whatsapp::Message { + image_message: Some(Box::new(wa_rs_proto::whatsapp::message::ImageMessage { + url: Some(upload.url), + direct_path: Some(upload.direct_path), + media_key: Some(upload.media_key), + file_enc_sha256: Some(upload.file_enc_sha256), + file_sha256: Some(upload.file_sha256), + file_length: Some(upload.file_length), + mimetype: Some(mimetype), + ..Default::default() + })), + ..Default::default() + }, + WaAttachmentKind::Document => { + let file_name = path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("file") + .to_string(); + wa_rs_proto::whatsapp::Message { + document_message: Some(Box::new( + wa_rs_proto::whatsapp::message::DocumentMessage { + url: Some(upload.url), + direct_path: Some(upload.direct_path), + media_key: Some(upload.media_key), + file_enc_sha256: Some(upload.file_enc_sha256), + file_sha256: Some(upload.file_sha256), + file_length: Some(upload.file_length), + mimetype: Some(mimetype), + file_name: Some(file_name), + ..Default::default() + }, + )), + ..Default::default() + } + } + WaAttachmentKind::Video => wa_rs_proto::whatsapp::Message { + video_message: Some(Box::new(wa_rs_proto::whatsapp::message::VideoMessage { + url: Some(upload.url), + direct_path: Some(upload.direct_path), + media_key: Some(upload.media_key), + file_enc_sha256: Some(upload.file_enc_sha256), + file_sha256: Some(upload.file_sha256), + file_length: Some(upload.file_length), + mimetype: Some(mimetype), + ..Default::default() + })), + ..Default::default() + }, + WaAttachmentKind::Audio => wa_rs_proto::whatsapp::Message { + audio_message: Some(Box::new(wa_rs_proto::whatsapp::message::AudioMessage { + url: Some(upload.url), + direct_path: Some(upload.direct_path), + media_key: Some(upload.media_key), + file_enc_sha256: Some(upload.file_enc_sha256), + file_sha256: Some(upload.file_sha256), + file_length: Some(upload.file_length), + mimetype: Some(mimetype), + ..Default::default() + })), + ..Default::default() + }, + }; + + let msg_id = client.send_message(to.clone(), outgoing).await?; + tracing::info!( + "WhatsApp Web: sent {:?} media (id: {})", + attachment.kind, + msg_id + ); + Ok(()) + } } #[cfg(feature = "whatsapp-web")] @@ -261,17 +484,59 @@ impl Channel for WhatsAppWebChannel { } let to = self.recipient_to_jid(&message.recipient)?; - let outgoing = wa_rs_proto::whatsapp::Message { - conversation: Some(message.content.clone()), - ..Default::default() - }; - let message_id = client.send_message(to, outgoing).await?; - tracing::debug!( - "WhatsApp Web: sent message to {} (id: {})", - message.recipient, - message_id - ); + // Parse media attachment markers from the response text. + let (text_without_markers, attachments) = parse_wa_attachment_markers(&message.content); + + // Send any text portion first. + if !text_without_markers.is_empty() { + let text_msg = wa_rs_proto::whatsapp::Message { + conversation: Some(text_without_markers.clone()), + ..Default::default() + }; + let msg_id = client.send_message(to.clone(), text_msg).await?; + tracing::debug!( + "WhatsApp Web: sent text to {} (id: {})", + message.recipient, + msg_id + ); + } + + // Send each media attachment. + for attachment in &attachments { + if let Err(e) = self.send_media_attachment(&client, &to, attachment).await { + tracing::error!( + "WhatsApp Web: failed to send {:?} attachment {}: {}", + attachment.kind, + attachment.target, + e + ); + // Fall back to sending the path as text so the user knows something went wrong. + let fallback = wa_rs_proto::whatsapp::Message { + conversation: Some(format!("[Failed to send media: {}]", attachment.target)), + ..Default::default() + }; + let _ = client.send_message(to.clone(), fallback).await; + } + } + + // If there were no markers and no text (shouldn't happen), send original content. + if attachments.is_empty() + && text_without_markers.is_empty() + && !message.content.trim().is_empty() + { + let outgoing = wa_rs_proto::whatsapp::Message { + conversation: Some(message.content.clone()), + ..Default::default() + }; + let message_id = client.send_message(to, outgoing).await?; + tracing::debug!( + "WhatsApp Web: sent message to {} (id: {})", + message.recipient, + message_id + ); + } + Ok(()) } @@ -720,4 +985,44 @@ mod tests { let ch = make_channel(); assert!(!ch.health_check().await); } + + #[test] + #[cfg(feature = "whatsapp-web")] + fn parse_wa_markers_image() { + let msg = "Here is the timeline [IMAGE:/tmp/chart.png]"; + let (text, attachments) = parse_wa_attachment_markers(msg); + assert_eq!(text, "Here is the timeline"); + assert_eq!(attachments.len(), 1); + assert_eq!(attachments[0].target, "/tmp/chart.png"); + assert!(matches!(attachments[0].kind, WaAttachmentKind::Image)); + } + + #[test] + #[cfg(feature = "whatsapp-web")] + fn parse_wa_markers_multiple() { + let msg = "Text [IMAGE:/a.png] more [DOCUMENT:/b.pdf]"; + let (text, attachments) = parse_wa_attachment_markers(msg); + assert_eq!(text, "Text more"); + assert_eq!(attachments.len(), 2); + assert!(matches!(attachments[0].kind, WaAttachmentKind::Image)); + assert!(matches!(attachments[1].kind, WaAttachmentKind::Document)); + } + + #[test] + #[cfg(feature = "whatsapp-web")] + fn parse_wa_markers_no_markers() { + let msg = "Just regular text"; + let (text, attachments) = parse_wa_attachment_markers(msg); + assert_eq!(text, "Just regular text"); + assert!(attachments.is_empty()); + } + + #[test] + #[cfg(feature = "whatsapp-web")] + fn parse_wa_markers_unknown_kind_preserved() { + let msg = "Check [UNKNOWN:/foo] out"; + let (text, attachments) = parse_wa_attachment_markers(msg); + assert_eq!(text, "Check [UNKNOWN:/foo] out"); + assert!(attachments.is_empty()); + } } diff --git a/src/config/schema.rs b/src/config/schema.rs index 20ff63837..02df1002f 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -252,6 +252,13 @@ pub struct Config { /// Inter-process agent communication (`[agents_ipc]`). #[serde(default)] pub agents_ipc: AgentsIpcConfig, + + /// Vision support override for the active provider/model. + /// - `None` (default): use provider's built-in default + /// - `Some(true)`: force vision support on (e.g. Ollama running llava) + /// - `Some(false)`: force vision support off + #[serde(default)] + pub model_support_vision: Option, } /// Named provider profile definition compatible with Codex app-server style config. @@ -3888,6 +3895,7 @@ impl Default for Config { query_classification: QueryClassificationConfig::default(), transcription: TranscriptionConfig::default(), agents_ipc: AgentsIpcConfig::default(), + model_support_vision: None, } } } @@ -4808,6 +4816,18 @@ impl Config { } } + // Vision support override: ZEROCLAW_MODEL_SUPPORT_VISION or MODEL_SUPPORT_VISION + if let Ok(flag) = std::env::var("ZEROCLAW_MODEL_SUPPORT_VISION") + .or_else(|_| std::env::var("MODEL_SUPPORT_VISION")) + { + let normalized = flag.trim().to_ascii_lowercase(); + match normalized.as_str() { + "1" | "true" | "yes" | "on" => self.model_support_vision = Some(true), + "0" | "false" | "no" | "off" => self.model_support_vision = Some(false), + _ => {} + } + } + // Web search enabled: ZEROCLAW_WEB_SEARCH_ENABLED or WEB_SEARCH_ENABLED if let Ok(enabled) = std::env::var("ZEROCLAW_WEB_SEARCH_ENABLED") .or_else(|_| std::env::var("WEB_SEARCH_ENABLED")) @@ -5437,6 +5457,7 @@ default_temperature = 0.7 hardware: HardwareConfig::default(), transcription: TranscriptionConfig::default(), agents_ipc: AgentsIpcConfig::default(), + model_support_vision: None, }; let toml_str = toml::to_string_pretty(&config).unwrap(); @@ -5528,6 +5549,24 @@ reasoning_enabled = false assert_eq!(parsed.runtime.reasoning_enabled, Some(false)); } + #[test] + async fn model_support_vision_deserializes() { + let raw = r#" +default_temperature = 0.7 +model_support_vision = true +"#; + + let parsed: Config = toml::from_str(raw).unwrap(); + assert_eq!(parsed.model_support_vision, Some(true)); + + // Default (omitted) should be None + let raw_no_vision = r#" +default_temperature = 0.7 +"#; + let parsed2: Config = toml::from_str(raw_no_vision).unwrap(); + assert_eq!(parsed2.model_support_vision, None); + } + #[test] async fn agent_config_defaults() { let cfg = AgentConfig::default(); @@ -5622,6 +5661,7 @@ tool_dispatcher = "xml" hardware: HardwareConfig::default(), transcription: TranscriptionConfig::default(), agents_ipc: AgentsIpcConfig::default(), + model_support_vision: None, }; config.save().await.unwrap(); @@ -7394,6 +7434,28 @@ default_model = "legacy-model" std::env::remove_var("ZEROCLAW_REASONING_ENABLED"); } + #[test] + async fn env_override_model_support_vision() { + let _env_guard = env_override_lock().await; + let mut config = Config::default(); + assert_eq!(config.model_support_vision, None); + + std::env::set_var("ZEROCLAW_MODEL_SUPPORT_VISION", "true"); + config.apply_env_overrides(); + assert_eq!(config.model_support_vision, Some(true)); + + std::env::set_var("ZEROCLAW_MODEL_SUPPORT_VISION", "false"); + config.apply_env_overrides(); + assert_eq!(config.model_support_vision, Some(false)); + + std::env::set_var("ZEROCLAW_MODEL_SUPPORT_VISION", "maybe"); + config.model_support_vision = Some(true); + config.apply_env_overrides(); + assert_eq!(config.model_support_vision, Some(true)); + + std::env::remove_var("ZEROCLAW_MODEL_SUPPORT_VISION"); + } + #[test] async fn env_override_invalid_port_ignored() { let _env_guard = env_override_lock().await; diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 11f95aa71..009e9b01d 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -365,6 +365,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { reasoning_enabled: config.runtime.reasoning_enabled, custom_provider_api_mode: config.provider_api.map(|mode| mode.as_compatible_mode()), max_tokens_override: None, + model_support_vision: config.model_support_vision, }, )?); let model = config diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 8191a3c3f..58bf306bd 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -176,6 +176,7 @@ pub async fn run_wizard(force: bool) -> Result { query_classification: crate::config::QueryClassificationConfig::default(), transcription: crate::config::TranscriptionConfig::default(), agents_ipc: crate::config::AgentsIpcConfig::default(), + model_support_vision: None, }; println!( @@ -530,6 +531,7 @@ async fn run_quick_setup_with_home( query_classification: crate::config::QueryClassificationConfig::default(), transcription: crate::config::TranscriptionConfig::default(), agents_ipc: crate::config::AgentsIpcConfig::default(), + model_support_vision: None, }; config.save().await?; diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index ef94fb185..ed3d60d85 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -1,6 +1,6 @@ use crate::providers::traits::{ ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, - Provider, TokenUsage, ToolCall as ProviderToolCall, + Provider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, }; use crate::tools::ToolSpec; use async_trait::async_trait; @@ -468,13 +468,6 @@ impl AnthropicProvider { #[async_trait] impl Provider for AnthropicProvider { - fn capabilities(&self) -> crate::providers::traits::ProviderCapabilities { - crate::providers::traits::ProviderCapabilities { - vision: true, - native_tool_calling: true, - } - } - async fn chat_with_system( &self, system_prompt: Option<&str>, @@ -566,6 +559,13 @@ impl Provider for AnthropicProvider { true } + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + native_tool_calling: true, + vision: true, + } + } + async fn chat_with_tools( &self, messages: &[ChatMessage], diff --git a/src/providers/mod.rs b/src/providers/mod.rs index b5eb6f2d2..2d86b81a6 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -678,6 +678,7 @@ pub struct ProviderRuntimeOptions { pub reasoning_enabled: Option, pub custom_provider_api_mode: Option, pub max_tokens_override: Option, + pub model_support_vision: Option, } impl Default for ProviderRuntimeOptions { @@ -690,6 +691,7 @@ impl Default for ProviderRuntimeOptions { reasoning_enabled: None, custom_provider_api_mode: None, max_tokens_override: None, + model_support_vision: None, } } } @@ -1355,7 +1357,8 @@ pub fn create_resilient_provider_with_options( reliability.provider_backoff_ms, ) .with_api_keys(reliability.api_keys.clone()) - .with_model_fallbacks(reliability.model_fallbacks.clone()); + .with_model_fallbacks(reliability.model_fallbacks.clone()) + .with_vision_override(options.model_support_vision); Ok(Box::new(reliable)) } @@ -1427,8 +1430,7 @@ pub fn create_routed_provider_with_options( .then_some(api_url) .flatten(); - let mut route_options = options.clone(); - route_options.max_tokens_override = route.max_tokens; + let route_options = options.clone(); match create_resilient_provider_with_options( &route.provider, @@ -1458,11 +1460,24 @@ pub fn create_routed_provider_with_options( } } - Ok(Box::new(router::RouterProvider::new( - providers, - routes, - default_model.to_string(), - ))) + // Build route table + let routes: Vec<(String, router::Route)> = model_routes + .iter() + .map(|r| { + ( + r.hint.clone(), + router::Route { + provider_name: r.provider.clone(), + model: r.model.clone(), + }, + ) + }) + .collect(); + + Ok(Box::new( + router::RouterProvider::new(providers, routes, default_model.to_string()) + .with_vision_override(options.model_support_vision), + )) } /// Information about a supported provider for display purposes. diff --git a/src/providers/openai_codex.rs b/src/providers/openai_codex.rs index 41e5dcfb4..3036bf9ff 100644 --- a/src/providers/openai_codex.rs +++ b/src/providers/openai_codex.rs @@ -1034,6 +1034,7 @@ data: [DONE] reasoning_enabled: None, custom_provider_api_mode: None, max_tokens_override: None, + model_support_vision: None, }; let provider = OpenAiCodexProvider::new(&options, None).expect("provider should initialize"); diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index a81877f0f..88ae1e76c 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -233,6 +233,8 @@ pub struct ReliableProvider { model_fallbacks: HashMap>, /// Provider-scoped model remaps: provider_name → [model_1, model_2, ...] provider_model_fallbacks: HashMap>, + /// Vision support override from config (`None` = defer to provider). + vision_override: Option, } impl ReliableProvider { @@ -249,6 +251,7 @@ impl ReliableProvider { key_index: AtomicUsize::new(0), model_fallbacks: HashMap::new(), provider_model_fallbacks: HashMap::new(), + vision_override: None, } } @@ -279,6 +282,12 @@ impl ReliableProvider { self } + /// Set vision support override from runtime config. + pub fn with_vision_override(mut self, vision_override: Option) -> Self { + self.vision_override = vision_override; + self + } + /// Build the list of models to try: [original, fallback1, fallback2, ...] fn model_chain<'a>(&'a self, model: &'a str) -> Vec<&'a str> { let mut chain = vec![model]; @@ -605,9 +614,11 @@ impl Provider for ReliableProvider { } fn supports_vision(&self) -> bool { - self.providers - .iter() - .any(|(_, provider)| provider.supports_vision()) + self.vision_override.unwrap_or_else(|| { + self.providers + .iter() + .any(|(_, provider)| provider.supports_vision()) + }) } async fn chat_with_tools( @@ -2105,4 +2116,68 @@ mod tests { assert_eq!(primary_calls.load(Ordering::SeqCst), 1); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); } + + #[test] + fn vision_override_forces_true() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = ReliableProvider::new( + vec![( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&calls), + fail_until_attempt: 0, + response: "ok", + error: "", + }) as Box, + )], + 1, + 100, + ) + .with_vision_override(Some(true)); + + // MockProvider default capabilities → vision: false + // Override should force true + assert!(provider.supports_vision()); + } + + #[test] + fn vision_override_forces_false() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = ReliableProvider::new( + vec![( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&calls), + fail_until_attempt: 0, + response: "ok", + error: "", + }) as Box, + )], + 1, + 100, + ) + .with_vision_override(Some(false)); + + assert!(!provider.supports_vision()); + } + + #[test] + fn vision_override_none_defers_to_provider() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = ReliableProvider::new( + vec![( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&calls), + fail_until_attempt: 0, + response: "ok", + error: "", + }) as Box, + )], + 1, + 100, + ); + // No override set → should defer to provider default (false) + assert!(!provider.supports_vision()); + } } diff --git a/src/providers/router.rs b/src/providers/router.rs index b12bd5205..fd5c6a46f 100644 --- a/src/providers/router.rs +++ b/src/providers/router.rs @@ -23,6 +23,8 @@ pub struct RouterProvider { providers: Vec<(String, Box)>, default_index: usize, default_model: String, + /// Vision support override from config (`None` = defer to providers). + vision_override: Option, } impl RouterProvider { @@ -66,9 +68,16 @@ impl RouterProvider { providers, default_index: 0, default_model, + vision_override: None, } } + /// Set vision support override from runtime config. + pub fn with_vision_override(mut self, vision_override: Option) -> Self { + self.vision_override = vision_override; + self + } + /// Resolve a model parameter to a (provider, actual_model) pair. /// /// If the model starts with "hint:", look up the hint in the route table. @@ -159,9 +168,11 @@ impl Provider for RouterProvider { } fn supports_vision(&self) -> bool { - self.providers - .iter() - .any(|(_, provider)| provider.supports_vision()) + self.vision_override.unwrap_or_else(|| { + self.providers + .iter() + .any(|(_, provider)| provider.supports_vision()) + }) } async fn warmup(&self) -> anyhow::Result<()> { diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 49301c086..591431e50 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -364,6 +364,7 @@ pub fn all_tools_with_runtime( .provider_api .map(|mode| mode.as_compatible_mode()), max_tokens_override: None, + model_support_vision: root_config.model_support_vision, }, ) .with_parent_tools(parent_tools) diff --git a/tests/openai_codex_vision_e2e.rs b/tests/openai_codex_vision_e2e.rs index af1ed7813..68dc10232 100644 --- a/tests/openai_codex_vision_e2e.rs +++ b/tests/openai_codex_vision_e2e.rs @@ -153,6 +153,7 @@ async fn openai_codex_second_vision_support() -> Result<()> { reasoning_enabled: None, custom_provider_api_mode: None, max_tokens_override: None, + model_support_vision: None, }; let provider = zeroclaw::providers::create_provider_with_options("openai-codex", None, &opts)?;