From 3a8e7d6edf13a37ae2b7181b48612e4767a13a26 Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Mon, 2 Mar 2026 13:44:33 -0500 Subject: [PATCH] feat(providers): support custom auth_header for custom endpoints --- docs/config-reference.md | 2 + src/agent/loop_.rs | 2 + src/channels/mod.rs | 1 + src/config/schema.rs | 163 +++++++++++++++++++++++++++++++ src/gateway/mod.rs | 1 + src/providers/compatible.rs | 90 +++++++++++++---- src/providers/mod.rs | 80 ++++++++++++++- src/providers/openai_codex.rs | 1 + src/tools/mod.rs | 1 + tests/openai_codex_vision_e2e.rs | 1 + 10 files changed, 324 insertions(+), 18 deletions(-) diff --git a/docs/config-reference.md b/docs/config-reference.md index ab301f078..2a62f5e48 100644 --- a/docs/config-reference.md +++ b/docs/config-reference.md @@ -46,6 +46,7 @@ Use named profiles to map a logical provider id to a provider name/base URL and |---|---|---| | `name` | unset | Optional provider id override (for example `openai`, `openai-codex`) | | `base_url` | unset | Optional OpenAI-compatible endpoint URL | +| `auth_header` | unset | Optional auth header for `custom:` endpoints (for example `api-key` for Azure OpenAI) | | `wire_api` | unset | Optional protocol mode: `responses` or `chat_completions` | | `model` | unset | Optional profile-scoped default model | | `api_key` | unset | Optional profile-scoped API key (used when top-level `api_key` is empty) | @@ -55,6 +56,7 @@ Notes: - If both top-level `api_key` and profile `api_key` are present, top-level `api_key` wins. - If top-level `default_model` is still the global OpenRouter default, profile `model` is used as an automatic compatibility override. +- `auth_header` is only applied when the resolved provider is `custom:` and the profile `base_url` matches that custom URL. - Secrets encryption applies to profile API keys when `secrets.encrypt = true`. Example: diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index d8bcf901c..b46f589be 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -2718,6 +2718,7 @@ pub async fn run( reasoning_enabled: config.runtime.reasoning_enabled, reasoning_level: config.effective_provider_reasoning_level(), custom_provider_api_mode: config.provider_api.map(|mode| mode.as_compatible_mode()), + custom_provider_auth_header: config.effective_custom_provider_auth_header(), max_tokens_override: None, model_support_vision: config.model_support_vision, }; @@ -3423,6 +3424,7 @@ pub async fn process_message_with_session( reasoning_enabled: config.runtime.reasoning_enabled, reasoning_level: config.effective_provider_reasoning_level(), custom_provider_api_mode: config.provider_api.map(|mode| mode.as_compatible_mode()), + custom_provider_auth_header: config.effective_custom_provider_auth_header(), max_tokens_override: None, model_support_vision: config.model_support_vision, }; diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 6483c9ff8..aa99ab1f2 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -5421,6 +5421,7 @@ pub async fn start_channels(config: Config) -> Result<()> { reasoning_enabled: config.runtime.reasoning_enabled, reasoning_level: config.effective_provider_reasoning_level(), custom_provider_api_mode: config.provider_api.map(|mode| mode.as_compatible_mode()), + custom_provider_auth_header: config.effective_custom_provider_auth_header(), max_tokens_override: None, model_support_vision: config.model_support_vision, }; diff --git a/src/config/schema.rs b/src/config/schema.rs index ea792a8f9..3c1f88437 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -403,6 +403,10 @@ pub struct ModelProviderConfig { /// Optional base URL for OpenAI-compatible endpoints. #[serde(default)] pub base_url: Option, + /// Optional custom authentication header for `custom:` providers + /// (for example `api-key` for Azure OpenAI). + #[serde(default)] + pub auth_header: Option, /// Provider protocol variant ("responses" or "chat_completions"). #[serde(default)] pub wire_api: Option, @@ -7734,6 +7738,10 @@ impl Config { } } + fn urls_match_ignoring_trailing_slash(lhs: &str, rhs: &str) -> bool { + lhs.trim().trim_end_matches('/') == rhs.trim().trim_end_matches('/') + } + /// Resolve provider reasoning level with backward-compatible runtime alias. /// /// Priority: @@ -7787,6 +7795,53 @@ impl Config { Self::normalize_provider_transport(self.provider.transport.as_deref(), "provider.transport") } + /// Resolve custom provider auth header from a matching `[model_providers.*]` profile. + /// + /// This is used when `default_provider = "custom:"` and a profile with the + /// same `base_url` declares `auth_header` (for example `api-key` for Azure OpenAI). + pub fn effective_custom_provider_auth_header(&self) -> Option { + let custom_provider_url = self + .default_provider + .as_deref() + .map(str::trim) + .and_then(|provider| provider.strip_prefix("custom:")) + .map(str::trim) + .filter(|value| !value.is_empty())?; + + let mut profile_keys = self.model_providers.keys().collect::>(); + profile_keys.sort_unstable(); + + for profile_key in profile_keys { + let Some(profile) = self.model_providers.get(profile_key) else { + continue; + }; + + let Some(header) = profile + .auth_header + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + else { + continue; + }; + + let Some(base_url) = profile + .base_url + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + else { + continue; + }; + + if Self::urls_match_ignoring_trailing_slash(custom_provider_url, base_url) { + return Some(header.to_string()); + } + } + + None + } + fn lookup_model_provider_profile( &self, provider_name: &str, @@ -8560,6 +8615,18 @@ impl Config { ); } } + + if let Some(auth_header) = profile.auth_header.as_deref().map(str::trim) { + if !auth_header.is_empty() { + reqwest::header::HeaderName::from_bytes(auth_header.as_bytes()).with_context( + || { + format!( + "model_providers.{profile_name}.auth_header is invalid; expected a valid HTTP header name" + ) + }, + )?; + } + } } // Ollama cloud-routing safety checks @@ -12552,6 +12619,7 @@ provider_api = "not-a-real-mode" ModelProviderConfig { name: Some("sub2api".to_string()), base_url: Some("https://api.tonsof.blue/v1".to_string()), + auth_header: None, wire_api: None, default_model: None, api_key: None, @@ -12572,6 +12640,73 @@ provider_api = "not-a-real-mode" ); } + #[test] + async fn model_provider_profile_surfaces_custom_auth_header_for_matching_custom_provider() { + let _env_guard = env_override_lock().await; + let mut config = Config { + default_provider: Some("azure".to_string()), + model_providers: HashMap::from([( + "azure".to_string(), + ModelProviderConfig { + name: Some("azure".to_string()), + base_url: Some( + "https://resource.openai.azure.com/openai/deployments/my-model/chat/completions?api-version=2024-02-01" + .to_string(), + ), + auth_header: Some("api-key".to_string()), + wire_api: None, + default_model: None, + api_key: None, + requires_openai_auth: false, + }, + )]), + ..Config::default() + }; + + config.apply_env_overrides(); + assert_eq!( + config.default_provider.as_deref(), + Some( + "custom:https://resource.openai.azure.com/openai/deployments/my-model/chat/completions?api-version=2024-02-01" + ) + ); + assert_eq!( + config.effective_custom_provider_auth_header().as_deref(), + Some("api-key") + ); + } + + #[test] + async fn model_provider_profile_custom_auth_header_requires_url_match() { + let _env_guard = env_override_lock().await; + let mut config = Config { + default_provider: Some("azure".to_string()), + model_providers: HashMap::from([( + "azure".to_string(), + ModelProviderConfig { + name: Some("azure".to_string()), + base_url: Some( + "https://resource.openai.azure.com/openai/deployments/other-model/chat/completions?api-version=2024-02-01" + .to_string(), + ), + auth_header: Some("api-key".to_string()), + wire_api: None, + default_model: None, + api_key: None, + requires_openai_auth: false, + }, + )]), + ..Config::default() + }; + + config.apply_env_overrides(); + config.default_provider = Some( + "custom:https://resource.openai.azure.com/openai/deployments/my-model/chat/completions?api-version=2024-02-01" + .to_string(), + ); + assert!(config.effective_custom_provider_auth_header().is_none()); + } + #[test] async fn model_provider_profile_responses_uses_openai_codex_and_openai_key() { let _env_guard = env_override_lock().await; @@ -12582,6 +12717,7 @@ provider_api = "not-a-real-mode" ModelProviderConfig { name: Some("sub2api".to_string()), base_url: Some("https://api.tonsof.blue".to_string()), + auth_header: None, wire_api: Some("responses".to_string()), default_model: None, api_key: None, @@ -12646,6 +12782,7 @@ provider_api = "not-a-real-mode" ModelProviderConfig { name: Some("sub2api".to_string()), base_url: Some("https://api.tonsof.blue/v1".to_string()), + auth_header: None, wire_api: Some("ws".to_string()), default_model: None, api_key: None, @@ -12661,6 +12798,30 @@ provider_api = "not-a-real-mode" .contains("wire_api must be one of: responses, chat_completions")); } + #[test] + async fn validate_rejects_invalid_model_provider_auth_header() { + let _env_guard = env_override_lock().await; + let config = Config { + default_provider: Some("sub2api".to_string()), + model_providers: HashMap::from([( + "sub2api".to_string(), + ModelProviderConfig { + name: Some("sub2api".to_string()), + base_url: Some("https://api.tonsof.blue/v1".to_string()), + auth_header: Some("not a header".to_string()), + wire_api: None, + default_model: None, + api_key: None, + requires_openai_auth: false, + }, + )]), + ..Config::default() + }; + + let error = config.validate().expect_err("expected validation failure"); + assert!(error.to_string().contains("auth_header is invalid")); + } + #[test] async fn model_provider_profile_uses_profile_api_key_when_global_is_missing() { let _env_guard = env_override_lock().await; @@ -12672,6 +12833,7 @@ provider_api = "not-a-real-mode" ModelProviderConfig { name: Some("sub2api".to_string()), base_url: Some("https://api.tonsof.blue/v1".to_string()), + auth_header: None, wire_api: None, default_model: None, api_key: Some("profile-api-key".to_string()), @@ -12696,6 +12858,7 @@ provider_api = "not-a-real-mode" ModelProviderConfig { name: Some("sub2api".to_string()), base_url: Some("https://api.tonsof.blue/v1".to_string()), + auth_header: None, wire_api: None, default_model: Some("qwen-max".to_string()), api_key: None, diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 726f9eb74..25e7da738 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -417,6 +417,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { reasoning_enabled: config.runtime.reasoning_enabled, reasoning_level: config.effective_provider_reasoning_level(), custom_provider_api_mode: config.provider_api.map(|mode| mode.as_compatible_mode()), + custom_provider_auth_header: config.effective_custom_provider_auth_header(), max_tokens_override: None, model_support_vision: config.model_support_vision, }, diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index 6e16e4c89..342dd4d4e 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -317,22 +317,26 @@ impl OpenAiCompatibleProvider { /// This allows custom providers with non-standard endpoints (e.g., VolcEngine ARK uses /// `/api/coding/v3/chat/completions` instead of `/v1/chat/completions`). fn chat_completions_url(&self) -> String { - let has_full_endpoint = reqwest::Url::parse(&self.base_url) - .map(|url| { - url.path() - .trim_end_matches('/') - .ends_with("/chat/completions") - }) - .unwrap_or_else(|_| { - self.base_url - .trim_end_matches('/') - .ends_with("/chat/completions") - }); + if let Ok(mut url) = reqwest::Url::parse(&self.base_url) { + let path = url.path().trim_end_matches('/').to_string(); + if path.ends_with("/chat/completions") { + return url.to_string(); + } - if has_full_endpoint { - self.base_url.clone() + let target_path = if path.is_empty() || path == "/" { + "/chat/completions".to_string() + } else { + format!("{path}/chat/completions") + }; + url.set_path(&target_path); + return url.to_string(); + } + + let normalized = self.base_url.trim_end_matches('/'); + if normalized.ends_with("/chat/completions") { + normalized.to_string() } else { - format!("{}/chat/completions", self.base_url) + format!("{normalized}/chat/completions") } } @@ -355,19 +359,32 @@ impl OpenAiCompatibleProvider { /// Build the full URL for responses API, detecting if base_url already includes the path. fn responses_url(&self) -> String { + if let Ok(mut url) = reqwest::Url::parse(&self.base_url) { + let path = url.path().trim_end_matches('/').to_string(); + let target_path = if path.ends_with("/responses") { + return url.to_string(); + } else if let Some(prefix) = path.strip_suffix("/chat/completions") { + format!("{prefix}/responses") + } else if !path.is_empty() && path != "/" { + format!("{path}/responses") + } else { + "/v1/responses".to_string() + }; + + url.set_path(&target_path); + return url.to_string(); + } + if self.path_ends_with("/responses") { return self.base_url.clone(); } let normalized_base = self.base_url.trim_end_matches('/'); - // If chat endpoint is explicitly configured, derive sibling responses endpoint. if let Some(prefix) = normalized_base.strip_suffix("/chat/completions") { return format!("{prefix}/responses"); } - // If an explicit API path already exists (e.g. /v1, /openai, /api/coding/v3), - // append responses directly to avoid duplicate /v1 segments. if self.has_explicit_api_path() { format!("{normalized_base}/responses") } else { @@ -3318,6 +3335,32 @@ mod tests { ); } + #[test] + fn chat_completions_url_preserves_query_params_for_full_endpoint() { + let p = make_provider( + "custom", + "https://resource.openai.azure.com/openai/deployments/my-model/chat/completions?api-version=2024-02-01", + None, + ); + assert_eq!( + p.chat_completions_url(), + "https://resource.openai.azure.com/openai/deployments/my-model/chat/completions?api-version=2024-02-01" + ); + } + + #[test] + fn chat_completions_url_appends_path_before_existing_query_params() { + let p = make_provider( + "custom", + "https://resource.openai.azure.com/openai/deployments/my-model?api-version=2024-02-01", + None, + ); + assert_eq!( + p.chat_completions_url(), + "https://resource.openai.azure.com/openai/deployments/my-model/chat/completions?api-version=2024-02-01" + ); + } + #[test] fn chat_completions_url_requires_exact_suffix_match() { let p = make_provider( @@ -3365,6 +3408,19 @@ mod tests { ); } + #[test] + fn responses_url_preserves_query_params_from_chat_endpoint() { + let p = make_provider( + "custom", + "https://resource.openai.azure.com/openai/deployments/my-model/chat/completions?api-version=2024-02-01", + None, + ); + assert_eq!( + p.responses_url(), + "https://resource.openai.azure.com/openai/deployments/my-model/responses?api-version=2024-02-01" + ); + } + #[test] fn responses_url_derives_from_chat_endpoint() { let p = make_provider( diff --git a/src/providers/mod.rs b/src/providers/mod.rs index cf863c7ec..ebb681ae1 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -742,6 +742,7 @@ pub struct ProviderRuntimeOptions { pub reasoning_enabled: Option, pub reasoning_level: Option, pub custom_provider_api_mode: Option, + pub custom_provider_auth_header: Option, pub max_tokens_override: Option, pub model_support_vision: Option, } @@ -757,6 +758,7 @@ impl Default for ProviderRuntimeOptions { reasoning_enabled: None, reasoning_level: None, custom_provider_api_mode: None, + custom_provider_auth_header: None, max_tokens_override: None, model_support_vision: None, } @@ -1098,6 +1100,36 @@ fn parse_custom_provider_url( } } +fn resolve_custom_provider_auth_style(options: &ProviderRuntimeOptions) -> AuthStyle { + let Some(header) = options + .custom_provider_auth_header + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + else { + return AuthStyle::Bearer; + }; + + if header.eq_ignore_ascii_case("authorization") { + return AuthStyle::Bearer; + } + + if header.eq_ignore_ascii_case("x-api-key") { + return AuthStyle::XApiKey; + } + + match reqwest::header::HeaderName::from_bytes(header.as_bytes()) { + Ok(_) => AuthStyle::Custom(header.to_string()), + Err(error) => { + tracing::warn!( + header = %header, + "Ignoring invalid custom provider auth header and falling back to Bearer: {error}" + ); + AuthStyle::Bearer + } + } +} + /// Factory: create the right provider from config (without custom URL) pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result> { create_provider_with_options(name, api_key, &ProviderRuntimeOptions::default()) @@ -1488,11 +1520,12 @@ fn create_provider_with_url_and_options( let api_mode = options .custom_provider_api_mode .unwrap_or(CompatibleApiMode::OpenAiChatCompletions); + let auth_style = resolve_custom_provider_auth_style(options); Ok(Box::new(OpenAiCompatibleProvider::new_custom_with_mode( "Custom", &base_url, key, - AuthStyle::Bearer, + auth_style, true, api_mode, options.max_tokens_override, @@ -2852,6 +2885,51 @@ mod tests { assert!(p.is_ok()); } + #[test] + fn custom_provider_auth_style_defaults_to_bearer() { + let options = ProviderRuntimeOptions::default(); + assert!(matches!( + resolve_custom_provider_auth_style(&options), + AuthStyle::Bearer + )); + } + + #[test] + fn custom_provider_auth_style_maps_x_api_key() { + let options = ProviderRuntimeOptions { + custom_provider_auth_header: Some("x-api-key".to_string()), + ..ProviderRuntimeOptions::default() + }; + assert!(matches!( + resolve_custom_provider_auth_style(&options), + AuthStyle::XApiKey + )); + } + + #[test] + fn custom_provider_auth_style_maps_custom_header() { + let options = ProviderRuntimeOptions { + custom_provider_auth_header: Some("api-key".to_string()), + ..ProviderRuntimeOptions::default() + }; + assert!(matches!( + resolve_custom_provider_auth_style(&options), + AuthStyle::Custom(header) if header == "api-key" + )); + } + + #[test] + fn custom_provider_auth_style_invalid_header_falls_back_to_bearer() { + let options = ProviderRuntimeOptions { + custom_provider_auth_header: Some("not a header".to_string()), + ..ProviderRuntimeOptions::default() + }; + assert!(matches!( + resolve_custom_provider_auth_style(&options), + AuthStyle::Bearer + )); + } + // ── Anthropic-compatible custom endpoints ───────────────── #[test] diff --git a/src/providers/openai_codex.rs b/src/providers/openai_codex.rs index aeafd20af..02e384548 100644 --- a/src/providers/openai_codex.rs +++ b/src/providers/openai_codex.rs @@ -1601,6 +1601,7 @@ data: [DONE] reasoning_enabled: None, reasoning_level: None, custom_provider_api_mode: None, + custom_provider_auth_header: None, max_tokens_override: None, model_support_vision: None, }; diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 8864182a4..c1658544d 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -644,6 +644,7 @@ pub fn all_tools_with_runtime( custom_provider_api_mode: root_config .provider_api .map(|mode| mode.as_compatible_mode()), + custom_provider_auth_header: root_config.effective_custom_provider_auth_header(), max_tokens_override: None, model_support_vision: root_config.model_support_vision, }; diff --git a/tests/openai_codex_vision_e2e.rs b/tests/openai_codex_vision_e2e.rs index a4a7875fa..bfe47678c 100644 --- a/tests/openai_codex_vision_e2e.rs +++ b/tests/openai_codex_vision_e2e.rs @@ -154,6 +154,7 @@ async fn openai_codex_second_vision_support() -> Result<()> { reasoning_enabled: None, reasoning_level: None, custom_provider_api_mode: None, + custom_provider_auth_header: None, max_tokens_override: None, model_support_vision: None, };