From 40672ee81cdd38d9f668330df4a8def4e0bf49ef Mon Sep 17 00:00:00 2001 From: Argenis Date: Sat, 21 Mar 2026 07:15:36 -0400 Subject: [PATCH] feat(transcription): add LocalWhisperProvider for self-hosted STT (TDD) (#4141) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Self-hosted Whisper-compatible STT provider that POSTs audio to a configurable HTTP endpoint (e.g. faster-whisper over WireGuard). Audio never leaves the platform perimeter. Implemented via red/green TDD cycles: Wave 1 — config schema: LocalWhisperConfig struct, local_whisper field on TranscriptionConfig + Default impl, re-export in config/mod.rs Wave 2 — from_config validation: url non-empty, url parseable, bearer_token non-empty, max_audio_bytes > 0, timeout_secs > 0; returns Result Wave 3 — manager integration: registration with ? propagation (not if let Ok — credentials come directly from config, no env-var fallback; present section with bad values is a hard error, not a silent skip) Wave 4 — transcribe(): resolve_audio_format() extracted from validate_audio() so LocalWhisperProvider can resolve MIME without the 25 MB cloud cap; size check + format resolution before HTTP send Wave 5 — HTTP mock tests: success response, bearer auth header, 503 error 33 tests (20 baseline + 13 new), all passing. Clippy clean. Co-authored-by: Nim G --- src/channels/transcription.rs | 407 ++++++++++++++++++++++++++++++++-- src/config/mod.rs | 26 +-- src/config/schema.rs | 61 ++++- 3 files changed, 461 insertions(+), 33 deletions(-) diff --git a/src/channels/transcription.rs b/src/channels/transcription.rs index 6b22aef53..fb26d2204 100644 --- a/src/channels/transcription.rs +++ b/src/channels/transcription.rs @@ -80,17 +80,10 @@ fn resolve_transcription_api_key(config: &TranscriptionConfig) -> Result ); } -/// Validate audio data and resolve MIME type from file name. +/// Resolve MIME type and normalize filename from extension. /// -/// Returns `(normalized_filename, mime_type)` on success. -fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'static str)> { - if audio_data.len() > MAX_AUDIO_BYTES { - bail!( - "Audio file too large ({} bytes, max {MAX_AUDIO_BYTES})", - audio_data.len() - ); - } - +/// No size check — callers enforce their own limits. +fn resolve_audio_format(file_name: &str) -> Result<(String, &'static str)> { let normalized_name = normalize_audio_filename(file_name); let extension = normalized_name .rsplit_once('.') @@ -98,13 +91,26 @@ fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'stati .unwrap_or(""); let mime = mime_for_audio(extension).ok_or_else(|| { anyhow::anyhow!( - "Unsupported audio format '.{extension}' — accepted: flac, mp3, mp4, mpeg, mpga, m4a, ogg, opus, wav, webm" + "Unsupported audio format '.{extension}' — \ + accepted: flac, mp3, mp4, mpeg, mpga, m4a, ogg, opus, wav, webm" ) })?; - Ok((normalized_name, mime)) } +/// Validate audio data and resolve MIME type from file name. +/// +/// Enforces the 25 MB cloud API cap. Returns `(normalized_filename, mime_type)` on success. +fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'static str)> { + if audio_data.len() > MAX_AUDIO_BYTES { + bail!( + "Audio file too large ({} bytes, max {MAX_AUDIO_BYTES})", + audio_data.len() + ); + } + resolve_audio_format(file_name) +} + // ── TranscriptionProvider trait ───────────────────────────────── /// Trait for speech-to-text provider implementations. @@ -586,21 +592,120 @@ impl TranscriptionProvider for GoogleSttProvider { } } +// ── LocalWhisperProvider ──────────────────────────────────────── + +/// Self-hosted faster-whisper-compatible STT provider. +/// +/// POSTs audio as `multipart/form-data` (field name `file`) to a configurable +/// HTTP endpoint (e.g. faster-whisper on GEX44 over WireGuard). The endpoint +/// must return `{"text": "..."}`. No cloud API key required. Size limit is +/// configurable — not constrained by the 25 MB cloud API cap. +pub struct LocalWhisperProvider { + url: String, + bearer_token: String, + max_audio_bytes: usize, + timeout_secs: u64, +} + +impl LocalWhisperProvider { + /// Build from config. Fails if `url` or `bearer_token` is empty, if `url` + /// is not a valid HTTP/HTTPS URL (scheme must be `http` or `https`), if + /// `max_audio_bytes` is zero, or if `timeout_secs` is zero. + pub fn from_config(config: &crate::config::LocalWhisperConfig) -> Result { + let url = config.url.trim().to_string(); + anyhow::ensure!(!url.is_empty(), "local_whisper: `url` must not be empty"); + let parsed = url + .parse::() + .with_context(|| format!("local_whisper: invalid `url`: {url:?}"))?; + anyhow::ensure!( + matches!(parsed.scheme(), "http" | "https"), + "local_whisper: `url` must use http or https scheme, got {:?}", + parsed.scheme() + ); + + let bearer_token = config.bearer_token.trim().to_string(); + anyhow::ensure!( + !bearer_token.is_empty(), + "local_whisper: `bearer_token` must not be empty" + ); + + anyhow::ensure!( + config.max_audio_bytes > 0, + "local_whisper: `max_audio_bytes` must be greater than zero" + ); + + anyhow::ensure!( + config.timeout_secs > 0, + "local_whisper: `timeout_secs` must be greater than zero" + ); + + Ok(Self { + url, + bearer_token, + max_audio_bytes: config.max_audio_bytes, + timeout_secs: config.timeout_secs, + }) + } +} + +#[async_trait] +impl TranscriptionProvider for LocalWhisperProvider { + fn name(&self) -> &str { + "local_whisper" + } + + async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result { + if audio_data.len() > self.max_audio_bytes { + bail!( + "Audio file too large ({} bytes, local_whisper max {})", + audio_data.len(), + self.max_audio_bytes + ); + } + + let (normalized_name, mime) = resolve_audio_format(file_name)?; + + let client = crate::config::build_runtime_proxy_client("transcription.local_whisper"); + + // to_vec() clones the buffer for the multipart payload; peak memory per + // call is ~2× max_audio_bytes. TODO: replace with streaming upload once + // reqwest supports body streaming in multipart parts. + let file_part = Part::bytes(audio_data.to_vec()) + .file_name(normalized_name) + .mime_str(mime)?; + + let resp = client + .post(&self.url) + .bearer_auth(&self.bearer_token) + .multipart(Form::new().part("file", file_part)) + .timeout(std::time::Duration::from_secs(self.timeout_secs)) + .send() + .await + .context("Failed to send audio to local Whisper endpoint")?; + + parse_whisper_response(resp).await + } +} + // ── Shared response parsing ───────────────────────────────────── -/// Parse a standard Whisper-compatible JSON response (`{ "text": "..." }`). +/// Parse a faster-whisper-compatible JSON response (`{ "text": "..." }`). +/// +/// Checks HTTP status before attempting JSON parsing so that non-JSON error +/// bodies (plain text, HTML, empty 5xx) produce a readable status error +/// rather than a confusing "Failed to parse transcription response". async fn parse_whisper_response(resp: reqwest::Response) -> Result { let status = resp.status(); + if !status.is_success() { + let body = resp.text().await.unwrap_or_default(); + bail!("Transcription API error ({}): {}", status, body.trim()); + } + let body: serde_json::Value = resp .json() .await .context("Failed to parse transcription response")?; - if !status.is_success() { - let error_msg = body["error"]["message"].as_str().unwrap_or("unknown error"); - bail!("Transcription API error ({}): {}", status, error_msg); - } - let text = body["text"] .as_str() .context("Transcription response missing 'text' field")? @@ -657,6 +762,17 @@ impl TranscriptionManager { } } + if let Some(ref local_cfg) = config.local_whisper { + match LocalWhisperProvider::from_config(local_cfg) { + Ok(p) => { + providers.insert("local_whisper".to_string(), Box::new(p)); + } + Err(e) => { + tracing::warn!("local_whisper config invalid, provider skipped: {e}"); + } + } + } + let default_provider = config.default_provider.clone(); if config.enabled && !providers.contains_key(&default_provider) { @@ -1036,5 +1152,260 @@ mod tests { assert!(config.deepgram.is_none()); assert!(config.assemblyai.is_none()); assert!(config.google.is_none()); + assert!(config.local_whisper.is_none()); + } + + // ── LocalWhisperProvider tests (TDD — added below as red/green cycles) ── + + fn local_whisper_config(url: &str) -> crate::config::LocalWhisperConfig { + crate::config::LocalWhisperConfig { + url: url.to_string(), + bearer_token: "test-token".to_string(), + max_audio_bytes: 10 * 1024 * 1024, + timeout_secs: 30, + } + } + + #[test] + fn local_whisper_rejects_empty_url() { + let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe"); + cfg.url = String::new(); + let err = LocalWhisperProvider::from_config(&cfg).err().unwrap(); + assert!( + err.to_string().contains("`url` must not be empty"), + "got: {err}" + ); + } + + #[test] + fn local_whisper_rejects_invalid_url() { + let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe"); + cfg.url = "not-a-url".to_string(); + let err = LocalWhisperProvider::from_config(&cfg).err().unwrap(); + assert!(err.to_string().contains("invalid `url`"), "got: {err}"); + } + + #[test] + fn local_whisper_rejects_non_http_url() { + let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe"); + cfg.url = "ftp://10.10.0.1:8001/v1/transcribe".to_string(); + let err = LocalWhisperProvider::from_config(&cfg).err().unwrap(); + assert!(err.to_string().contains("http or https"), "got: {err}"); + } + + #[test] + fn local_whisper_rejects_empty_bearer_token() { + let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe"); + cfg.bearer_token = String::new(); + let err = LocalWhisperProvider::from_config(&cfg).err().unwrap(); + assert!( + err.to_string().contains("`bearer_token` must not be empty"), + "got: {err}" + ); + } + + #[test] + fn local_whisper_rejects_zero_max_audio_bytes() { + let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe"); + cfg.max_audio_bytes = 0; + let err = LocalWhisperProvider::from_config(&cfg).err().unwrap(); + assert!( + err.to_string() + .contains("`max_audio_bytes` must be greater than zero"), + "got: {err}" + ); + } + + #[test] + fn local_whisper_rejects_zero_timeout() { + let mut cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe"); + cfg.timeout_secs = 0; + let err = LocalWhisperProvider::from_config(&cfg).err().unwrap(); + assert!( + err.to_string() + .contains("`timeout_secs` must be greater than zero"), + "got: {err}" + ); + } + + #[test] + fn local_whisper_registered_when_config_present() { + let mut config = TranscriptionConfig::default(); + config.local_whisper = Some(local_whisper_config("http://127.0.0.1:9999/v1/transcribe")); + config.default_provider = "local_whisper".to_string(); + + let manager = TranscriptionManager::new(&config).unwrap(); + assert!( + manager.available_providers().contains(&"local_whisper"), + "expected local_whisper in {:?}", + manager.available_providers() + ); + } + + #[test] + fn local_whisper_misconfigured_section_fails_manager_construction() { + // A misconfigured local_whisper section logs a warning and skips + // registration. When local_whisper is also the default_provider and + // transcription is enabled, the safety net in TranscriptionManager + // surfaces the error: "not configured". + let mut config = TranscriptionConfig::default(); + let mut bad_cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe"); + bad_cfg.bearer_token = String::new(); + config.local_whisper = Some(bad_cfg); + config.enabled = true; + config.default_provider = "local_whisper".to_string(); + + let err = TranscriptionManager::new(&config).err().unwrap(); + assert!( + err.to_string().contains("not configured"), + "expected 'not configured' from manager safety net, got: {err}" + ); + } + + #[test] + fn validate_audio_still_enforces_25mb_cap() { + // Regression: extracting resolve_audio_format() must not weaken validate_audio(). + let at_limit = vec![0u8; MAX_AUDIO_BYTES]; + assert!(validate_audio(&at_limit, "test.ogg").is_ok()); + let over_limit = vec![0u8; MAX_AUDIO_BYTES + 1]; + let err = validate_audio(&over_limit, "test.ogg").unwrap_err(); + assert!(err.to_string().contains("too large")); + } + + #[tokio::test] + async fn local_whisper_rejects_oversized_audio() { + let cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe"); + let provider = LocalWhisperProvider::from_config(&cfg).unwrap(); + let big = vec![0u8; cfg.max_audio_bytes + 1]; + let err = provider.transcribe(&big, "voice.ogg").await.unwrap_err(); + assert!(err.to_string().contains("too large"), "got: {err}"); + } + + #[tokio::test] + async fn local_whisper_rejects_unsupported_format() { + let cfg = local_whisper_config("http://127.0.0.1:9999/v1/transcribe"); + let provider = LocalWhisperProvider::from_config(&cfg).unwrap(); + let data = vec![0u8; 100]; + let err = provider.transcribe(&data, "voice.aiff").await.unwrap_err(); + assert!( + err.to_string().contains("Unsupported audio format"), + "got: {err}" + ); + } + + // ── LocalWhisperProvider HTTP mock tests ──────────────────── + + #[tokio::test] + async fn local_whisper_returns_text_from_response() { + use wiremock::matchers::{header_exists, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/transcribe")) + .and(header_exists("authorization")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"text": "hello world"})), + ) + .mount(&server) + .await; + + let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri())); + let provider = LocalWhisperProvider::from_config(&cfg).unwrap(); + + let result = provider + .transcribe(b"fake-audio", "voice.ogg") + .await + .unwrap(); + assert_eq!(result, "hello world"); + } + + #[tokio::test] + async fn local_whisper_sends_bearer_auth_header() { + use wiremock::matchers::{header, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/transcribe")) + .and(header("authorization", "Bearer test-token")) + .respond_with( + ResponseTemplate::new(200).set_body_json(serde_json::json!({"text": "auth ok"})), + ) + .mount(&server) + .await; + + let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri())); + let provider = LocalWhisperProvider::from_config(&cfg).unwrap(); + + let result = provider + .transcribe(b"fake-audio", "voice.ogg") + .await + .unwrap(); + assert_eq!(result, "auth ok"); + } + + #[tokio::test] + async fn local_whisper_propagates_http_error() { + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/transcribe")) + .respond_with( + ResponseTemplate::new(503).set_body_json( + serde_json::json!({"error": {"message": "service unavailable"}}), + ), + ) + .mount(&server) + .await; + + let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri())); + let provider = LocalWhisperProvider::from_config(&cfg).unwrap(); + + let err = provider + .transcribe(b"fake-audio", "voice.ogg") + .await + .unwrap_err(); + assert!( + err.to_string().contains("503") || err.to_string().contains("service unavailable"), + "expected HTTP error, got: {err}" + ); + } + + #[tokio::test] + async fn local_whisper_propagates_non_json_http_error() { + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + let server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/transcribe")) + .respond_with( + ResponseTemplate::new(502) + .set_body_string("Bad Gateway") + .insert_header("content-type", "text/plain"), + ) + .mount(&server) + .await; + + let cfg = local_whisper_config(&format!("{}/v1/transcribe", server.uri())); + let provider = LocalWhisperProvider::from_config(&cfg).unwrap(); + + let err = provider + .transcribe(b"fake-audio", "voice.ogg") + .await + .unwrap_err(); + assert!(err.to_string().contains("502"), "got: {err}"); + assert!( + err.to_string().contains("Bad Gateway"), + "expected plain-text body in error, got: {err}" + ); } } diff --git a/src/config/mod.rs b/src/config/mod.rs index 8dc50cb4e..551cd5066 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -16,19 +16,19 @@ pub use schema::{ HeartbeatConfig, HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, ImageProviderDalleConfig, ImageProviderFluxConfig, ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig, KnowledgeConfig, LarkConfig, LinkedInConfig, - LinkedInContentConfig, LinkedInImageConfig, MatrixConfig, McpConfig, McpServerConfig, - McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, MultimodalConfig, - NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig, - OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, - PeripheralBoardConfig, PeripheralsConfig, PluginsConfig, ProjectIntelConfig, ProxyConfig, - ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, - RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, - SecurityOpsConfig, SkillCreationConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, - StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, - SwarmStrategy, TelegramConfig, TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode, - TranscriptionConfig, TtsConfig, TunnelConfig, VerifiableIntentConfig, WebFetchConfig, - WebSearchConfig, WebhookConfig, WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig, - DEFAULT_GWS_SERVICES, + LinkedInContentConfig, LinkedInImageConfig, LocalWhisperConfig, MatrixConfig, McpConfig, + McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, + MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, + ObservabilityConfig, OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, + OtpMethod, PeripheralBoardConfig, PeripheralsConfig, PluginsConfig, ProjectIntelConfig, + ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, + ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, + SecretsConfig, SecurityConfig, SecurityOpsConfig, SkillCreationConfig, SkillsConfig, + SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig, + StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig, + TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig, + TunnelConfig, VerifiableIntentConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, + WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig, DEFAULT_GWS_SERVICES, }; pub fn name_and_presence(channel: Option<&T>) -> (&'static str, bool) { diff --git a/src/config/schema.rs b/src/config/schema.rs index 9c9abc668..7a05bf8cf 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -784,6 +784,9 @@ pub struct TranscriptionConfig { /// Google Cloud Speech-to-Text provider configuration. #[serde(default)] pub google: Option, + /// Local/self-hosted Whisper-compatible STT provider. + #[serde(default)] + pub local_whisper: Option, } impl Default for TranscriptionConfig { @@ -801,6 +804,7 @@ impl Default for TranscriptionConfig { deepgram: None, assemblyai: None, google: None, + local_whisper: None, } } } @@ -1169,6 +1173,35 @@ pub struct GoogleSttConfig { pub language_code: String, } +/// Local/self-hosted Whisper-compatible STT endpoint (`[transcription.local_whisper]`). +/// +/// Audio is sent over WireGuard; never leaves the platform perimeter. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct LocalWhisperConfig { + /// HTTP or HTTPS endpoint URL, e.g. `"http://10.10.0.1:8001/v1/transcribe"`. + pub url: String, + /// Bearer token for endpoint authentication. + pub bearer_token: String, + /// Maximum audio file size in bytes accepted by this endpoint. + /// Defaults to 25 MB — matching the cloud API cap for a safe out-of-the-box + /// experience. Self-hosted endpoints can accept much larger files; raise this + /// as needed, but note that each transcription call clones the audio buffer + /// into a multipart payload, so peak memory per request is ~2× this value. + #[serde(default = "default_local_whisper_max_audio_bytes")] + pub max_audio_bytes: usize, + /// Request timeout in seconds. Defaults to 300 (large files on local GPU). + #[serde(default = "default_local_whisper_timeout_secs")] + pub timeout_secs: u64, +} + +fn default_local_whisper_max_audio_bytes() -> usize { + 25 * 1024 * 1024 +} + +fn default_local_whisper_timeout_secs() -> u64 { + 300 +} + /// Agent orchestration configuration (`[agent]` section). #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct AgentConfig { @@ -8081,10 +8114,10 @@ impl Config { { let dp = self.transcription.default_provider.trim(); match dp { - "groq" | "openai" | "deepgram" | "assemblyai" | "google" => {} + "groq" | "openai" | "deepgram" | "assemblyai" | "google" | "local_whisper" => {} other => { anyhow::bail!( - "transcription.default_provider must be one of: groq, openai, deepgram, assemblyai, google (got '{other}')" + "transcription.default_provider must be one of: groq, openai, deepgram, assemblyai, google, local_whisper (got '{other}')" ); } } @@ -12423,6 +12456,30 @@ require_otp_to_resume = true assert!(err.to_string().contains("gated_domains")); } + #[test] + async fn validate_accepts_local_whisper_as_transcription_default_provider() { + let mut config = Config::default(); + config.transcription.default_provider = "local_whisper".to_string(); + + config.validate().expect( + "local_whisper must be accepted by the transcription.default_provider allowlist", + ); + } + + #[test] + async fn validate_rejects_unknown_transcription_default_provider() { + let mut config = Config::default(); + config.transcription.default_provider = "unknown_stt".to_string(); + + let err = config + .validate() + .expect_err("expected validation to reject unknown transcription provider"); + assert!( + err.to_string().contains("transcription.default_provider"), + "got: {err}" + ); + } + #[tokio::test] async fn channel_secret_telegram_bot_token_roundtrip() { let dir = std::env::temp_dir().join(format!(