From 244e68b5fef519e1ace6e09dffaab2c66f7c0b5d Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Mon, 2 Mar 2026 13:42:30 -0500 Subject: [PATCH] feat(reliability): support per-fallback API keys for custom endpoints --- src/config/schema.rs | 59 ++++++++++++++++ src/providers/mod.rs | 26 +++++-- tests/reliability_fallback_api_keys.rs | 95 ++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 6 deletions(-) create mode 100644 tests/reliability_fallback_api_keys.rs diff --git a/src/config/schema.rs b/src/config/schema.rs index ea792a8f9..f48bfcb87 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -4009,6 +4009,10 @@ pub struct ReliabilityConfig { /// Fallback provider chain (e.g. `["anthropic", "openai"]`). #[serde(default)] pub fallback_providers: Vec, + /// Optional per-fallback provider API keys keyed by fallback entry name. + /// This allows distinct credentials for multiple `custom:` endpoints. + #[serde(default)] + pub fallback_api_keys: std::collections::HashMap, /// Additional API keys for round-robin rotation on rate-limit (429) errors. /// The primary `api_key` is always tried first; these are extras. #[serde(default)] @@ -4064,6 +4068,7 @@ impl Default for ReliabilityConfig { provider_retries: default_provider_retries(), provider_backoff_ms: default_provider_backoff_ms(), fallback_providers: Vec::new(), + fallback_api_keys: std::collections::HashMap::new(), api_keys: Vec::new(), model_fallbacks: std::collections::HashMap::new(), channel_initial_backoff_secs: default_channel_backoff_secs(), @@ -6875,6 +6880,21 @@ fn decrypt_vec_secrets( Ok(()) } +fn decrypt_map_secrets( + store: &crate::security::SecretStore, + values: &mut std::collections::HashMap, + field_name: &str, +) -> Result<()> { + for (key, value) in values.iter_mut() { + if crate::security::SecretStore::is_encrypted(value) { + *value = store + .decrypt(value) + .with_context(|| format!("Failed to decrypt {field_name}.{key}"))?; + } + } + Ok(()) +} + fn encrypt_optional_secret( store: &crate::security::SecretStore, value: &mut Option, @@ -6920,6 +6940,21 @@ fn encrypt_vec_secrets( Ok(()) } +fn encrypt_map_secrets( + store: &crate::security::SecretStore, + values: &mut std::collections::HashMap, + field_name: &str, +) -> Result<()> { + for (key, value) in values.iter_mut() { + if !crate::security::SecretStore::is_encrypted(value) { + *value = store + .encrypt(value) + .with_context(|| format!("Failed to encrypt {field_name}.{key}"))?; + } + } + Ok(()) +} + fn decrypt_channel_secrets( store: &crate::security::SecretStore, channels: &mut ChannelsConfig, @@ -7645,6 +7680,11 @@ impl Config { &mut config.reliability.api_keys, "config.reliability.api_keys", )?; + decrypt_map_secrets( + &store, + &mut config.reliability.fallback_api_keys, + "config.reliability.fallback_api_keys", + )?; decrypt_vec_secrets( &store, &mut config.gateway.paired_tokens, @@ -9368,6 +9408,11 @@ impl Config { &mut config_to_save.reliability.api_keys, "config.reliability.api_keys", )?; + encrypt_map_secrets( + &store, + &mut config_to_save.reliability.fallback_api_keys, + "config.reliability.fallback_api_keys", + )?; encrypt_vec_secrets( &store, &mut config_to_save.gateway.paired_tokens, @@ -10652,6 +10697,10 @@ denied_tools = ["shell"] config.web_search.jina_api_key = Some("jina-credential".into()); config.storage.provider.config.db_url = Some("postgres://user:pw@host/db".into()); config.reliability.api_keys = vec!["backup-credential".into()]; + config.reliability.fallback_api_keys.insert( + "custom:https://api-a.example.com/v1".into(), + "fallback-a-credential".into(), + ); config.gateway.paired_tokens = vec!["zc_0123456789abcdef".into()]; config.channels_config.telegram = Some(TelegramConfig { bot_token: "telegram-credential".into(), @@ -10786,6 +10835,16 @@ denied_tools = ["shell"] let reliability_key = &stored.reliability.api_keys[0]; assert!(crate::security::SecretStore::is_encrypted(reliability_key)); assert_eq!(store.decrypt(reliability_key).unwrap(), "backup-credential"); + let fallback_key = stored + .reliability + .fallback_api_keys + .get("custom:https://api-a.example.com/v1") + .expect("fallback key should exist"); + assert!(crate::security::SecretStore::is_encrypted(fallback_key)); + assert_eq!( + store.decrypt(fallback_key).unwrap(), + "fallback-a-credential" + ); let paired_token = &stored.gateway.paired_tokens[0]; assert!(crate::security::SecretStore::is_encrypted(paired_token)); diff --git a/src/providers/mod.rs b/src/providers/mod.rs index cf863c7ec..d316805d3 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -1586,15 +1586,22 @@ pub fn create_resilient_provider_with_options( let (provider_name, profile_override) = parse_provider_profile(fallback); - // Each fallback provider resolves its own credential via provider- - // specific env vars (e.g. DEEPSEEK_API_KEY for "deepseek") instead - // of inheriting the primary provider's key. Passing `None` lets - // `resolve_provider_credential` check the correct env var for the - // fallback provider name. + // Fallback providers can use explicit per-entry API keys from + // `reliability.fallback_api_keys` (keyed by full fallback entry), or + // fall back to provider-name keys for compatibility. + // + // If no explicit map entry exists, pass `None` so + // `resolve_provider_credential` can resolve provider-specific env vars. // // When a profile override is present (e.g. "openai-codex:second"), // propagate it through `auth_profile_override` so the provider // picks up the correct OAuth credential set. + let fallback_api_key = reliability + .fallback_api_keys + .get(fallback) + .or_else(|| reliability.fallback_api_keys.get(provider_name)) + .map(String::as_str); + let fallback_options = match profile_override { Some(profile) => { let mut opts = options.clone(); @@ -1604,7 +1611,7 @@ pub fn create_resilient_provider_with_options( None => options.clone(), }; - match create_provider_with_options(provider_name, None, &fallback_options) { + match create_provider_with_options(provider_name, fallback_api_key, &fallback_options) { Ok(provider) => providers.push((fallback.clone(), provider)), Err(_error) => { tracing::warn!( @@ -2962,6 +2969,7 @@ providers = ["demo-plugin-provider"] "openai".into(), "openai".into(), ], + fallback_api_keys: std::collections::HashMap::new(), api_keys: Vec::new(), model_fallbacks: std::collections::HashMap::new(), channel_initial_backoff_secs: 2, @@ -3001,6 +3009,7 @@ providers = ["demo-plugin-provider"] provider_retries: 1, provider_backoff_ms: 100, fallback_providers: vec!["lmstudio".into(), "ollama".into()], + fallback_api_keys: std::collections::HashMap::new(), api_keys: Vec::new(), model_fallbacks: std::collections::HashMap::new(), channel_initial_backoff_secs: 2, @@ -3023,6 +3032,7 @@ providers = ["demo-plugin-provider"] provider_retries: 1, provider_backoff_ms: 100, fallback_providers: vec!["custom:http://host.docker.internal:1234/v1".into()], + fallback_api_keys: std::collections::HashMap::new(), api_keys: Vec::new(), model_fallbacks: std::collections::HashMap::new(), channel_initial_backoff_secs: 2, @@ -3049,6 +3059,7 @@ providers = ["demo-plugin-provider"] "nonexistent-provider".into(), "lmstudio".into(), ], + fallback_api_keys: std::collections::HashMap::new(), api_keys: Vec::new(), model_fallbacks: std::collections::HashMap::new(), channel_initial_backoff_secs: 2, @@ -3081,6 +3092,7 @@ providers = ["demo-plugin-provider"] provider_retries: 1, provider_backoff_ms: 100, fallback_providers: vec!["osaurus".into(), "lmstudio".into()], + fallback_api_keys: std::collections::HashMap::new(), api_keys: Vec::new(), model_fallbacks: std::collections::HashMap::new(), channel_initial_backoff_secs: 2, @@ -3615,6 +3627,7 @@ providers = ["demo-plugin-provider"] provider_retries: 1, provider_backoff_ms: 100, fallback_providers: vec!["openai-codex:second".into()], + fallback_api_keys: std::collections::HashMap::new(), api_keys: Vec::new(), model_fallbacks: std::collections::HashMap::new(), channel_initial_backoff_secs: 2, @@ -3644,6 +3657,7 @@ providers = ["demo-plugin-provider"] "lmstudio".into(), "nonexistent-provider".into(), ], + fallback_api_keys: std::collections::HashMap::new(), api_keys: Vec::new(), model_fallbacks: std::collections::HashMap::new(), channel_initial_backoff_secs: 2, diff --git a/tests/reliability_fallback_api_keys.rs b/tests/reliability_fallback_api_keys.rs new file mode 100644 index 000000000..a9e9551aa --- /dev/null +++ b/tests/reliability_fallback_api_keys.rs @@ -0,0 +1,95 @@ +use std::collections::HashMap; + +use wiremock::matchers::{header, method, path}; +use wiremock::{Mock, MockServer, ResponseTemplate}; +use zeroclaw::config::ReliabilityConfig; +use zeroclaw::providers::create_resilient_provider; + +#[tokio::test] +async fn fallback_api_keys_support_multiple_custom_endpoints() { + let primary_server = MockServer::start().await; + let fallback_server_one = MockServer::start().await; + let fallback_server_two = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with( + ResponseTemplate::new(500) + .set_body_json(serde_json::json!({ "error": "primary unavailable" })), + ) + .expect(1) + .mount(&primary_server) + .await; + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .and(header("authorization", "Bearer fallback-key-1")) + .respond_with( + ResponseTemplate::new(500) + .set_body_json(serde_json::json!({ "error": "fallback one unavailable" })), + ) + .expect(1) + .mount(&fallback_server_one) + .await; + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .and(header("authorization", "Bearer fallback-key-2")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "id": "chatcmpl-1", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "response-from-fallback-two" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2 + } + }))) + .expect(1) + .mount(&fallback_server_two) + .await; + + let primary_provider = format!("custom:{}/v1", primary_server.uri()); + let fallback_provider_one = format!("custom:{}/v1", fallback_server_one.uri()); + let fallback_provider_two = format!("custom:{}/v1", fallback_server_two.uri()); + + let mut fallback_api_keys = HashMap::new(); + fallback_api_keys.insert(fallback_provider_one.clone(), "fallback-key-1".to_string()); + fallback_api_keys.insert(fallback_provider_two.clone(), "fallback-key-2".to_string()); + + let reliability = ReliabilityConfig { + provider_retries: 0, + provider_backoff_ms: 0, + fallback_providers: vec![fallback_provider_one.clone(), fallback_provider_two.clone()], + fallback_api_keys, + api_keys: Vec::new(), + model_fallbacks: HashMap::new(), + channel_initial_backoff_secs: 2, + channel_max_backoff_secs: 60, + scheduler_poll_secs: 15, + scheduler_retries: 2, + }; + + let provider = + create_resilient_provider(&primary_provider, Some("primary-key"), None, &reliability) + .expect("resilient provider should initialize"); + + let reply = provider + .chat_with_system(None, "hello", "gpt-4o-mini", 0.0) + .await + .expect("fallback chain should return final response"); + + assert_eq!(reply, "response-from-fallback-two"); + + fallback_server_one.verify().await; + fallback_server_two.verify().await; +}