From 3089eb57a033b83b23a1491365fd6d9ed0fcb4c0 Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Wed, 4 Mar 2026 00:33:48 -0500 Subject: [PATCH] fix(discord): transcribe inbound audio attachments --- src/channels/discord.rs | 389 +++++++++++++++++++++++++++++++++++++++- src/channels/mod.rs | 1 + 2 files changed, 384 insertions(+), 6 deletions(-) diff --git a/src/channels/discord.rs b/src/channels/discord.rs index 689e3d9d7..f3674255c 100644 --- a/src/channels/discord.rs +++ b/src/channels/discord.rs @@ -1,4 +1,5 @@ use super::traits::{Channel, ChannelMessage, SendMessage}; +use crate::config::TranscriptionConfig; use anyhow::Context; use async_trait::async_trait; use futures_util::{SinkExt, StreamExt}; @@ -18,6 +19,7 @@ pub struct DiscordChannel { listen_to_bots: bool, mention_only: bool, group_reply_allowed_sender_ids: Vec, + transcription: Option, workspace_dir: Option, typing_handles: Mutex>>, } @@ -37,6 +39,7 @@ impl DiscordChannel { listen_to_bots, mention_only, group_reply_allowed_sender_ids: Vec::new(), + transcription: None, workspace_dir: None, typing_handles: Mutex::new(HashMap::new()), } @@ -48,6 +51,13 @@ impl DiscordChannel { self } + /// Configure voice/audio transcription. + pub fn with_transcription(mut self, config: TranscriptionConfig) -> Self { + if config.enabled { + self.transcription = Some(config); + } + self + } /// Configure workspace directory used for validating local attachment paths. pub fn with_workspace_dir(mut self, dir: PathBuf) -> Self { self.workspace_dir = Some(dir); @@ -132,11 +142,16 @@ fn normalize_group_reply_allowed_sender_ids(sender_ids: Vec) -> Vec]` markers. For +/// `application/octet-stream` or missing MIME types, image-like filename/url +/// extensions are also treated as images. +/// `audio/*` attachments are transcribed when `[transcription].enabled = true`. +/// `text/*` MIME types are fetched and inlined. Other types are skipped. +/// Fetch errors are logged as warnings. async fn process_attachments( attachments: &[serde_json::Value], client: &reqwest::Client, + transcription: Option<&TranscriptionConfig>, ) -> String { let mut parts: Vec = Vec::new(); for att in attachments { @@ -152,7 +167,63 @@ async fn process_attachments( tracing::warn!(name, "discord: attachment has no url, skipping"); continue; }; - if ct.starts_with("text/") { + if is_image_attachment(ct, name, url) { + parts.push(format!("[IMAGE:{url}]")); + } else if is_audio_attachment(ct, name, url) { + let Some(config) = transcription else { + tracing::debug!( + name, + content_type = ct, + "discord: skipping audio attachment because transcription is disabled" + ); + continue; + }; + + if let Some(duration_secs) = parse_attachment_duration_secs(att) { + if duration_secs > config.max_duration_secs { + tracing::warn!( + name, + duration_secs, + max_duration_secs = config.max_duration_secs, + "discord: skipping audio attachment that exceeds transcription duration limit" + ); + continue; + } + } + + let audio_data = match client.get(url).send().await { + Ok(resp) if resp.status().is_success() => match resp.bytes().await { + Ok(bytes) => bytes.to_vec(), + Err(error) => { + tracing::warn!(name, error = %error, "discord: failed to read audio attachment body"); + continue; + } + }, + Ok(resp) => { + tracing::warn!(name, status = %resp.status(), "discord audio attachment fetch failed"); + continue; + } + Err(error) => { + tracing::warn!(name, error = %error, "discord audio attachment fetch error"); + continue; + } + }; + + let file_name = infer_audio_filename(name, url, ct); + match super::transcription::transcribe_audio(audio_data, &file_name, config).await { + Ok(transcript) => { + let transcript = transcript.trim(); + if transcript.is_empty() { + tracing::info!(name, "discord: transcription returned empty text"); + } else { + parts.push(format!("[Voice:{file_name}] {transcript}")); + } + } + Err(error) => { + tracing::warn!(name, error = %error, "discord: audio transcription failed"); + } + } + } else if ct.starts_with("text/") { match client.get(url).send().await { Ok(resp) if resp.status().is_success() => { if let Ok(text) = resp.text().await { @@ -177,6 +248,132 @@ async fn process_attachments( parts.join("\n---\n") } +fn normalize_content_type(content_type: &str) -> String { + content_type + .split(';') + .next() + .unwrap_or("") + .trim() + .to_ascii_lowercase() +} + +fn is_image_attachment(content_type: &str, filename: &str, url: &str) -> bool { + let normalized_content_type = normalize_content_type(content_type); + + if !normalized_content_type.is_empty() { + if normalized_content_type.starts_with("image/") { + return true; + } + // Trust explicit non-image MIME to avoid false positives from filename extensions. + if normalized_content_type != "application/octet-stream" { + return false; + } + } + + has_image_extension(filename) || has_image_extension(url) +} + +fn is_audio_attachment(content_type: &str, filename: &str, url: &str) -> bool { + let normalized_content_type = normalize_content_type(content_type); + + if !normalized_content_type.is_empty() { + if normalized_content_type.starts_with("audio/") { + return true; + } + // Trust explicit non-audio MIME to avoid false positives from filename extensions. + if normalized_content_type != "application/octet-stream" { + return false; + } + } + + has_audio_extension(filename) || has_audio_extension(url) +} + +fn parse_attachment_duration_secs(attachment: &serde_json::Value) -> Option { + let raw = attachment + .get("duration_secs") + .and_then(|value| value.as_f64().or_else(|| value.as_u64().map(|v| v as f64)))?; + if !raw.is_finite() || raw.is_sign_negative() { + return None; + } + Some(raw.ceil() as u64) +} + +fn extension_from_media_path(value: &str) -> Option { + let base = value.split('?').next().unwrap_or(value); + let base = base.split('#').next().unwrap_or(base); + Path::new(base) + .extension() + .and_then(|ext| ext.to_str()) + .map(|ext| ext.to_ascii_lowercase()) +} + +fn is_supported_audio_extension(extension: &str) -> bool { + matches!( + extension, + "flac" | "mp3" | "mpeg" | "mpga" | "mp4" | "m4a" | "ogg" | "oga" | "opus" | "wav" | "webm" + ) +} + +fn has_audio_extension(value: &str) -> bool { + matches!( + extension_from_media_path(value).as_deref(), + Some(ext) if is_supported_audio_extension(ext) + ) +} + +fn audio_extension_from_content_type(content_type: &str) -> Option<&'static str> { + match normalize_content_type(content_type).as_str() { + "audio/flac" | "audio/x-flac" => Some("flac"), + "audio/mpeg" => Some("mp3"), + "audio/mpga" => Some("mpga"), + "audio/mp4" | "audio/x-m4a" | "audio/m4a" => Some("m4a"), + "audio/ogg" | "application/ogg" => Some("ogg"), + "audio/opus" => Some("opus"), + "audio/wav" | "audio/x-wav" | "audio/wave" => Some("wav"), + "audio/webm" => Some("webm"), + _ => None, + } +} + +fn infer_audio_filename(filename: &str, url: &str, content_type: &str) -> String { + let trimmed_name = filename.trim(); + if !trimmed_name.is_empty() && has_audio_extension(trimmed_name) { + return trimmed_name.to_string(); + } + + if let Some(ext) = + extension_from_media_path(url).filter(|ext| is_supported_audio_extension(ext)) + { + return format!("audio.{ext}"); + } + + if let Some(ext) = audio_extension_from_content_type(content_type) { + return format!("audio.{ext}"); + } + + "audio.ogg".to_string() +} + +fn has_image_extension(value: &str) -> bool { + matches!( + extension_from_media_path(value).as_deref(), + Some( + "png" + | "jpg" + | "jpeg" + | "gif" + | "webp" + | "bmp" + | "tif" + | "tiff" + | "svg" + | "avif" + | "heic" + | "heif" + ) + ) +} #[derive(Debug, Clone, PartialEq, Eq)] enum DiscordAttachmentKind { Image, @@ -807,7 +1004,8 @@ impl Channel for DiscordChannel { .and_then(|a| a.as_array()) .cloned() .unwrap_or_default(); - process_attachments(&atts, &self.http_client()).await + process_attachments(&atts, &self.http_client(), self.transcription.as_ref()) + .await }; let final_content = if attachment_text.is_empty() { clean_content @@ -984,6 +1182,8 @@ impl Channel for DiscordChannel { #[cfg(test)] mod tests { use super::*; + use axum::{routing::get, routing::post, Json, Router}; + use serde_json::json as json_value; #[test] fn discord_channel_name() { @@ -1542,7 +1742,7 @@ mod tests { #[tokio::test] async fn process_attachments_empty_list_returns_empty() { let client = reqwest::Client::new(); - let result = process_attachments(&[], &client).await; + let result = process_attachments(&[], &client, None).await; assert!(result.is_empty()); } @@ -1554,10 +1754,170 @@ mod tests { "filename": "doc.pdf", "content_type": "application/pdf" })]; - let result = process_attachments(&attachments, &client).await; + let result = process_attachments(&attachments, &client, None).await; assert!(result.is_empty()); } + #[tokio::test] + async fn process_attachments_emits_image_marker_for_image_content_type() { + let client = reqwest::Client::new(); + let attachments = vec![serde_json::json!({ + "url": "https://cdn.discordapp.com/attachments/123/456/photo.png", + "filename": "photo.png", + "content_type": "image/png" + })]; + let result = process_attachments(&attachments, &client, None).await; + assert_eq!( + result, + "[IMAGE:https://cdn.discordapp.com/attachments/123/456/photo.png]" + ); + } + + #[tokio::test] + async fn process_attachments_emits_multiple_image_markers() { + let client = reqwest::Client::new(); + let attachments = vec![ + serde_json::json!({ + "url": "https://cdn.discordapp.com/attachments/123/456/one.jpg", + "filename": "one.jpg", + "content_type": "image/jpeg" + }), + serde_json::json!({ + "url": "https://cdn.discordapp.com/attachments/123/456/two.webp", + "filename": "two.webp", + "content_type": "image/webp" + }), + ]; + let result = process_attachments(&attachments, &client, None).await; + assert_eq!( + result, + "[IMAGE:https://cdn.discordapp.com/attachments/123/456/one.jpg]\n---\n[IMAGE:https://cdn.discordapp.com/attachments/123/456/two.webp]" + ); + } + + #[tokio::test] + async fn process_attachments_emits_image_marker_from_filename_without_content_type() { + let client = reqwest::Client::new(); + let attachments = vec![serde_json::json!({ + "url": "https://cdn.discordapp.com/attachments/123/456/photo.jpeg?size=1024", + "filename": "photo.jpeg" + })]; + let result = process_attachments(&attachments, &client, None).await; + assert_eq!( + result, + "[IMAGE:https://cdn.discordapp.com/attachments/123/456/photo.jpeg?size=1024]" + ); + } + + #[tokio::test] + #[ignore = "requires local loopback TCP bind"] + async fn process_attachments_transcribes_audio_when_enabled() { + async fn audio_handler() -> ([(String, String); 1], Vec) { + ( + [( + "content-type".to_string(), + "audio/ogg; codecs=opus".to_string(), + )], + vec![1_u8, 2, 3, 4, 5, 6], + ) + } + + async fn transcribe_handler() -> Json { + Json(json_value!({ "text": "hello from discord audio" })) + } + + let app = Router::new() + .route("/audio.ogg", get(audio_handler)) + .route("/transcribe", post(transcribe_handler)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind test server"); + let addr = listener.local_addr().expect("local addr"); + tokio::spawn(async move { + let _ = axum::serve(listener, app).await; + }); + + let mut transcription = TranscriptionConfig::default(); + transcription.enabled = true; + transcription.api_url = format!("http://{addr}/transcribe"); + transcription.model = "whisper-test".to_string(); + + let client = reqwest::Client::new(); + let attachments = vec![serde_json::json!({ + "url": format!("http://{addr}/audio.ogg"), + "filename": "voice.ogg", + "content_type": "audio/ogg", + "duration_secs": 4 + })]; + + let result = process_attachments(&attachments, &client, Some(&transcription)).await; + assert_eq!(result, "[Voice:voice.ogg] hello from discord audio"); + } + + #[tokio::test] + async fn process_attachments_skips_audio_when_duration_exceeds_limit() { + let mut transcription = TranscriptionConfig::default(); + transcription.enabled = true; + transcription.api_url = "http://127.0.0.1:1/transcribe".to_string(); + transcription.max_duration_secs = 5; + + let client = reqwest::Client::new(); + let attachments = vec![serde_json::json!({ + "url": "http://127.0.0.1:1/audio.ogg", + "filename": "voice.ogg", + "content_type": "audio/ogg", + "duration_secs": 120 + })]; + + let result = process_attachments(&attachments, &client, Some(&transcription)).await; + assert!(result.is_empty()); + } + + #[test] + fn is_image_attachment_prefers_non_image_content_type_over_extension() { + assert!(!is_image_attachment( + "text/plain", + "photo.png", + "https://cdn.discordapp.com/attachments/123/456/photo.png" + )); + } + + #[test] + fn is_image_attachment_allows_octet_stream_extension_fallback() { + assert!(is_image_attachment( + "application/octet-stream", + "photo.png", + "https://cdn.discordapp.com/attachments/123/456/photo.png" + )); + } + + #[test] + fn is_audio_attachment_prefers_non_audio_content_type_over_extension() { + assert!(!is_audio_attachment( + "text/plain", + "voice.ogg", + "https://cdn.discordapp.com/attachments/123/456/voice.ogg" + )); + } + + #[test] + fn is_audio_attachment_allows_octet_stream_extension_fallback() { + assert!(is_audio_attachment( + "application/octet-stream", + "voice.ogg", + "https://cdn.discordapp.com/attachments/123/456/voice.ogg" + )); + } + + #[test] + fn infer_audio_filename_uses_content_type_when_name_lacks_extension() { + let file_name = infer_audio_filename( + "voice_upload", + "https://cdn.discordapp.com/attachments/123/456/blob", + "audio/ogg; codecs=opus", + ); + assert_eq!(file_name, "audio.ogg"); + } #[test] fn parse_attachment_markers_extracts_supported_markers() { let input = "Report\n[IMAGE:https://example.com/a.png]\n[DOCUMENT:/tmp/a.pdf]"; @@ -1632,6 +1992,23 @@ mod tests { ); } + #[test] + fn with_transcription_sets_config_when_enabled() { + let mut tc = TranscriptionConfig::default(); + tc.enabled = true; + let channel = + DiscordChannel::new("fake".into(), None, vec![], false, false).with_transcription(tc); + assert!(channel.transcription.is_some()); + } + + #[test] + fn with_transcription_skips_when_disabled() { + let tc = TranscriptionConfig::default(); + let channel = + DiscordChannel::new("fake".into(), None, vec![], false, false).with_transcription(tc); + assert!(channel.transcription.is_none()); + } + #[test] fn resolve_local_attachment_path_blocks_workspace_escape() { let temp = tempfile::tempdir().expect("tempdir"); diff --git a/src/channels/mod.rs b/src/channels/mod.rs index b115bc177..fb343e34a 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -4194,6 +4194,7 @@ fn collect_configured_channels( dc.effective_group_reply_mode().requires_mention(), ) .with_group_reply_allowed_senders(dc.group_reply_allowed_sender_ids()) + .with_transcription(config.transcription.clone()) .with_workspace_dir(config.workspace_dir.clone()), ), });