From 448682c4409044baaee1b8e4791028a68ba403c8 Mon Sep 17 00:00:00 2001 From: Argenis Date: Thu, 12 Mar 2026 00:06:11 -0400 Subject: [PATCH] feat(providers): add Azure OpenAI provider support (#3246) Closes #3176 --- src/config/schema.rs | 18 + src/providers/azure_openai.rs | 756 ++++++++++++++++++++++++++++++++++ src/providers/mod.rs | 15 + 3 files changed, 789 insertions(+) create mode 100644 src/providers/azure_openai.rs diff --git a/src/config/schema.rs b/src/config/schema.rs index b47afe97d..c0f7f6d08 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -242,6 +242,15 @@ pub struct ModelProviderConfig { /// If true, load OpenAI auth material (OPENAI_API_KEY or ~/.codex/auth.json). #[serde(default)] pub requires_openai_auth: bool, + /// Azure OpenAI resource name (e.g. "my-resource" in https://my-resource.openai.azure.com). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub azure_openai_resource: Option, + /// Azure OpenAI deployment name (e.g. "gpt-4o"). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub azure_openai_deployment: Option, + /// Azure OpenAI API version (defaults to "2024-08-01-preview"). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub azure_openai_api_version: Option, } // ── Delegate Agents ────────────────────────────────────────────── @@ -7141,6 +7150,9 @@ requires_openai_auth = true base_url: Some("https://api.tonsof.blue/v1".to_string()), wire_api: None, requires_openai_auth: false, + azure_openai_resource: None, + azure_openai_deployment: None, + azure_openai_api_version: None, }, )]), ..Config::default() @@ -7169,6 +7181,9 @@ requires_openai_auth = true base_url: Some("https://api.tonsof.blue".to_string()), wire_api: Some("responses".to_string()), requires_openai_auth: true, + azure_openai_resource: None, + azure_openai_deployment: None, + azure_openai_api_version: None, }, )]), api_key: None, @@ -7231,6 +7246,9 @@ requires_openai_auth = true base_url: Some("https://api.tonsof.blue/v1".to_string()), wire_api: Some("ws".to_string()), requires_openai_auth: false, + azure_openai_resource: None, + azure_openai_deployment: None, + azure_openai_api_version: None, }, )]), ..Config::default() diff --git a/src/providers/azure_openai.rs b/src/providers/azure_openai.rs new file mode 100644 index 000000000..1bdaeee07 --- /dev/null +++ b/src/providers/azure_openai.rs @@ -0,0 +1,756 @@ +use crate::providers::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolsPayload, +}; +use crate::tools::ToolSpec; +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; + +const DEFAULT_API_VERSION: &str = "2024-08-01-preview"; + +pub struct AzureOpenAiProvider { + credential: Option, + resource_name: String, + deployment_name: String, + api_version: String, + base_url: String, +} + +#[derive(Debug, Serialize)] +struct ChatRequest { + messages: Vec, + temperature: f64, +} + +#[derive(Debug, Serialize)] +struct Message { + role: String, + content: String, +} + +#[derive(Debug, Deserialize)] +struct ChatResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct Choice { + message: ResponseMessage, +} + +#[derive(Debug, Deserialize)] +struct ResponseMessage { + #[serde(default)] + content: Option, + #[serde(default)] + reasoning_content: Option, +} + +impl ResponseMessage { + fn effective_content(&self) -> String { + match &self.content { + Some(c) if !c.is_empty() => c.clone(), + _ => self.reasoning_content.clone().unwrap_or_default(), + } + } +} + +#[derive(Debug, Serialize)] +struct NativeChatRequest { + messages: Vec, + temperature: f64, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, +} + +#[derive(Debug, Serialize)] +struct NativeMessage { + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + reasoning_content: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeToolSpec { + #[serde(rename = "type")] + kind: String, + function: NativeToolFunctionSpec, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeToolFunctionSpec { + name: String, + description: String, + parameters: serde_json::Value, +} + +fn parse_native_tool_spec(value: serde_json::Value) -> anyhow::Result { + let spec: NativeToolSpec = serde_json::from_value(value) + .map_err(|e| anyhow::anyhow!("Invalid Azure OpenAI tool specification: {e}"))?; + + if spec.kind != "function" { + anyhow::bail!( + "Invalid Azure OpenAI tool specification: unsupported tool type '{}', expected 'function'", + spec.kind + ); + } + + Ok(spec) +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeToolCall { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + kind: Option, + function: NativeFunctionCall, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeFunctionCall { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct NativeChatResponse { + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct UsageInfo { + #[serde(default)] + prompt_tokens: Option, + #[serde(default)] + completion_tokens: Option, +} + +#[derive(Debug, Deserialize)] +struct NativeChoice { + message: NativeResponseMessage, +} + +#[derive(Debug, Deserialize)] +struct NativeResponseMessage { + #[serde(default)] + content: Option, + #[serde(default)] + reasoning_content: Option, + #[serde(default)] + tool_calls: Option>, +} + +impl NativeResponseMessage { + fn effective_content(&self) -> Option { + match &self.content { + Some(c) if !c.is_empty() => Some(c.clone()), + _ => self.reasoning_content.clone(), + } + } +} + +impl AzureOpenAiProvider { + pub fn new( + credential: Option<&str>, + resource_name: &str, + deployment_name: &str, + api_version: Option<&str>, + ) -> Self { + let version = api_version.unwrap_or(DEFAULT_API_VERSION); + let base_url = format!( + "https://{}.openai.azure.com/openai/deployments/{}", + resource_name, deployment_name + ); + Self { + credential: credential.map(ToString::to_string), + resource_name: resource_name.to_string(), + deployment_name: deployment_name.to_string(), + api_version: version.to_string(), + base_url, + } + } + + fn chat_completions_url(&self) -> String { + format!( + "{}/chat/completions?api-version={}", + self.base_url, self.api_version + ) + } + + fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { + tools.map(|items| { + items + .iter() + .map(|tool| NativeToolSpec { + kind: "function".to_string(), + function: NativeToolFunctionSpec { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: tool.parameters.clone(), + }, + }) + .collect() + }) + } + + fn convert_messages(messages: &[ChatMessage]) -> Vec { + messages + .iter() + .map(|m| { + if m.role == "assistant" { + if let Ok(value) = serde_json::from_str::(&m.content) { + if let Some(tool_calls_value) = value.get("tool_calls") { + if let Ok(parsed_calls) = + serde_json::from_value::>( + tool_calls_value.clone(), + ) + { + let tool_calls = parsed_calls + .into_iter() + .map(|tc| NativeToolCall { + id: Some(tc.id), + kind: Some("function".to_string()), + function: NativeFunctionCall { + name: tc.name, + arguments: tc.arguments, + }, + }) + .collect::>(); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let reasoning_content = value + .get("reasoning_content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return NativeMessage { + role: "assistant".to_string(), + content, + tool_call_id: None, + tool_calls: Some(tool_calls), + reasoning_content, + }; + } + } + } + } + + if m.role == "tool" { + if let Ok(value) = serde_json::from_str::(&m.content) { + let tool_call_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return NativeMessage { + role: "tool".to_string(), + content, + tool_call_id, + tool_calls: None, + reasoning_content: None, + }; + } + } + + NativeMessage { + role: m.role.clone(), + content: Some(m.content.clone()), + tool_call_id: None, + tool_calls: None, + reasoning_content: None, + } + }) + .collect() + } + + fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse { + let text = message.effective_content(); + let reasoning_content = message.reasoning_content.clone(); + let tool_calls = message + .tool_calls + .unwrap_or_default() + .into_iter() + .map(|tc| ProviderToolCall { + id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + name: tc.function.name, + arguments: tc.function.arguments, + }) + .collect::>(); + + ProviderChatResponse { + text, + tool_calls, + usage: None, + reasoning_content, + } + } + + fn http_client(&self) -> Client { + crate::config::build_runtime_proxy_client_with_timeouts("provider.azure_openai", 120, 10) + } +} + +#[async_trait] +impl Provider for AzureOpenAiProvider { + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + native_tool_calling: true, + vision: true, + } + } + + fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload { + ToolsPayload::OpenAI { + tools: tools + .iter() + .map(|tool| { + serde_json::json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + } + }) + }) + .collect(), + } + } + + fn supports_native_tools(&self) -> bool { + true + } + + fn supports_vision(&self) -> bool { + true + } + + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + _model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml." + ) + })?; + + let mut messages = Vec::new(); + + if let Some(sys) = system_prompt { + messages.push(Message { + role: "system".to_string(), + content: sys.to_string(), + }); + } + + messages.push(Message { + role: "user".to_string(), + content: message.to_string(), + }); + + let request = ChatRequest { + messages, + temperature, + }; + + let response = self + .http_client() + .post(self.chat_completions_url()) + .header("api-key", credential.as_str()) + .json(&request) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error("Azure OpenAI", response).await); + } + + let chat_response: ChatResponse = response.json().await?; + + chat_response + .choices + .into_iter() + .next() + .map(|c| c.message.effective_content()) + .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI")) + } + + async fn chat( + &self, + request: ProviderChatRequest<'_>, + _model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml." + ) + })?; + + let tools = Self::convert_tools(request.tools); + let native_request = NativeChatRequest { + messages: Self::convert_messages(request.messages), + temperature, + tool_choice: tools.as_ref().map(|_| "auto".to_string()), + tools, + }; + + let response = self + .http_client() + .post(self.chat_completions_url()) + .header("api-key", credential.as_str()) + .json(&native_request) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error("Azure OpenAI", response).await); + } + + let native_response: NativeChatResponse = response.json().await?; + let usage = native_response.usage.map(|u| TokenUsage { + input_tokens: u.prompt_tokens, + output_tokens: u.completion_tokens, + }); + let message = native_response + .choices + .into_iter() + .next() + .map(|c| c.message) + .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?; + let mut result = Self::parse_native_response(message); + result.usage = usage; + Ok(result) + } + + async fn chat_with_tools( + &self, + messages: &[ChatMessage], + tools: &[serde_json::Value], + _model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml." + ) + })?; + + let native_tools: Option> = if tools.is_empty() { + None + } else { + Some( + tools + .iter() + .cloned() + .map(parse_native_tool_spec) + .collect::, _>>()?, + ) + }; + + let native_request = NativeChatRequest { + messages: Self::convert_messages(messages), + temperature, + tool_choice: native_tools.as_ref().map(|_| "auto".to_string()), + tools: native_tools, + }; + + let response = self + .http_client() + .post(self.chat_completions_url()) + .header("api-key", credential.as_str()) + .json(&native_request) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error("Azure OpenAI", response).await); + } + + let native_response: NativeChatResponse = response.json().await?; + let usage = native_response.usage.map(|u| TokenUsage { + input_tokens: u.prompt_tokens, + output_tokens: u.completion_tokens, + }); + let message = native_response + .choices + .into_iter() + .next() + .map(|c| c.message) + .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?; + let mut result = Self::parse_native_response(message); + result.usage = usage; + Ok(result) + } + + async fn warmup(&self) -> anyhow::Result<()> { + // Azure OpenAI does not have a lightweight models endpoint, + // so warmup is a no-op to avoid unnecessary API calls. + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn url_construction_default_version() { + let p = AzureOpenAiProvider::new(Some("test-key"), "my-resource", "gpt-4o", None); + assert_eq!( + p.chat_completions_url(), + "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview" + ); + } + + #[test] + fn url_construction_custom_version() { + let p = AzureOpenAiProvider::new( + Some("test-key"), + "my-resource", + "gpt-4o", + Some("2024-06-01"), + ); + assert_eq!( + p.chat_completions_url(), + "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-06-01" + ); + } + + #[test] + fn url_construction_preserves_resource_and_deployment() { + let p = AzureOpenAiProvider::new(Some("key"), "contoso-ai", "my-gpt35-deployment", None); + let url = p.chat_completions_url(); + assert!(url.contains("contoso-ai.openai.azure.com")); + assert!(url.contains("/deployments/my-gpt35-deployment/")); + assert!(url.contains("api-version=2024-08-01-preview")); + } + + #[test] + fn auth_header_uses_api_key_not_bearer() { + // This test verifies the provider stores the credential correctly + // and that the auth header name is "api-key" (verified via the + // implementation in chat_with_system which uses .header("api-key", ...)). + let p = AzureOpenAiProvider::new(Some("my-azure-key"), "resource", "deployment", None); + assert_eq!(p.credential.as_deref(), Some("my-azure-key")); + } + + #[test] + fn creates_with_credential() { + let p = AzureOpenAiProvider::new( + Some("azure-test-credential"), + "resource", + "deployment", + None, + ); + assert_eq!(p.credential.as_deref(), Some("azure-test-credential")); + assert_eq!(p.resource_name, "resource"); + assert_eq!(p.deployment_name, "deployment"); + assert_eq!(p.api_version, DEFAULT_API_VERSION); + } + + #[test] + fn creates_without_credential() { + let p = AzureOpenAiProvider::new(None, "resource", "deployment", None); + assert!(p.credential.is_none()); + } + + #[tokio::test] + async fn chat_fails_without_key() { + let p = AzureOpenAiProvider::new(None, "resource", "deployment", None); + let result = p.chat_with_system(None, "hello", "gpt-4o", 0.7).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("API key not set")); + } + + #[tokio::test] + async fn chat_with_system_fails_without_key() { + let p = AzureOpenAiProvider::new(None, "resource", "deployment", None); + let result = p + .chat_with_system(Some("You are ZeroClaw"), "test", "gpt-4o", 0.5) + .await; + assert!(result.is_err()); + } + + #[test] + fn request_serializes_with_system_message() { + let req = ChatRequest { + messages: vec![ + Message { + role: "system".to_string(), + content: "You are ZeroClaw".to_string(), + }, + Message { + role: "user".to_string(), + content: "hello".to_string(), + }, + ], + temperature: 0.7, + }; + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("\"role\":\"system\"")); + assert!(json.contains("\"role\":\"user\"")); + // Azure requests should NOT contain a model field (deployment is in the URL) + assert!(!json.contains("\"model\"")); + } + + #[test] + fn request_serializes_without_system() { + let req = ChatRequest { + messages: vec![Message { + role: "user".to_string(), + content: "hello".to_string(), + }], + temperature: 0.0, + }; + let json = serde_json::to_string(&req).unwrap(); + assert!(!json.contains("system")); + assert!(json.contains("\"temperature\":0.0")); + } + + #[test] + fn response_deserializes_single_choice() { + let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#; + let resp: ChatResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.choices.len(), 1); + assert_eq!(resp.choices[0].message.effective_content(), "Hi!"); + } + + #[test] + fn response_deserializes_empty_choices() { + let json = r#"{"choices":[]}"#; + let resp: ChatResponse = serde_json::from_str(json).unwrap(); + assert!(resp.choices.is_empty()); + } + + #[test] + fn response_deserializes_multiple_choices() { + let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#; + let resp: ChatResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.choices.len(), 2); + assert_eq!(resp.choices[0].message.effective_content(), "A"); + } + + #[test] + fn tool_call_response_parsing() { + let json = r#"{"choices":[{"message":{ + "content":"Let me check", + "tool_calls":[{ + "id":"call_abc123", + "type":"function", + "function":{"name":"shell","arguments":"{\"command\":\"ls\"}"} + }] + }}],"usage":{"prompt_tokens":50,"completion_tokens":25}}"#; + let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); + let message = resp.choices.into_iter().next().unwrap().message; + let parsed = AzureOpenAiProvider::parse_native_response(message); + assert_eq!(parsed.text.as_deref(), Some("Let me check")); + assert_eq!(parsed.tool_calls.len(), 1); + assert_eq!(parsed.tool_calls[0].id, "call_abc123"); + assert_eq!(parsed.tool_calls[0].name, "shell"); + assert!(parsed.tool_calls[0].arguments.contains("ls")); + } + + #[test] + fn tool_call_response_without_id_generates_uuid() { + let json = r#"{"choices":[{"message":{ + "content":null, + "tool_calls":[{ + "function":{"name":"test","arguments":"{}"} + }] + }}]}"#; + let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); + let message = resp.choices.into_iter().next().unwrap().message; + let parsed = AzureOpenAiProvider::parse_native_response(message); + assert_eq!(parsed.tool_calls.len(), 1); + assert!(!parsed.tool_calls[0].id.is_empty()); + } + + #[tokio::test] + async fn chat_with_tools_fails_without_key() { + let p = AzureOpenAiProvider::new(None, "resource", "deployment", None); + let messages = vec![ChatMessage::user("hello".to_string())]; + let tools = vec![serde_json::json!({ + "type": "function", + "function": { + "name": "shell", + "description": "Run a shell command", + "parameters": { + "type": "object", + "properties": { + "command": { "type": "string" } + }, + "required": ["command"] + } + } + })]; + let result = p.chat_with_tools(&messages, &tools, "gpt-4o", 0.7).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("API key not set")); + } + + #[test] + fn native_response_parses_usage() { + let json = r#"{ + "choices": [{"message": {"content": "Hello"}}], + "usage": {"prompt_tokens": 100, "completion_tokens": 50} + }"#; + let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); + let usage = resp.usage.unwrap(); + assert_eq!(usage.prompt_tokens, Some(100)); + assert_eq!(usage.completion_tokens, Some(50)); + } + + #[test] + fn capabilities_reports_native_tools_and_vision() { + let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None); + let caps = ::capabilities(&p); + assert!(caps.native_tool_calling); + assert!(caps.vision); + } + + #[test] + fn supports_native_tools_returns_true() { + let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None); + assert!(p.supports_native_tools()); + } + + #[test] + fn supports_vision_returns_true() { + let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None); + assert!(p.supports_vision()); + } + + #[tokio::test] + async fn warmup_is_noop() { + let p = AzureOpenAiProvider::new(None, "resource", "deployment", None); + let result = p.warmup().await; + assert!(result.is_ok()); + } + + #[test] + fn custom_api_version_stored() { + let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", Some("2025-01-01")); + assert_eq!(p.api_version, "2025-01-01"); + } +} diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 89eca28a2..21314f89f 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -17,6 +17,7 @@ //! in [`create_provider_with_url`]. See `AGENTS.md` §7.1 for the full change playbook. pub mod anthropic; +pub mod azure_openai; pub mod bedrock; pub mod compatible; pub mod copilot; @@ -849,6 +850,7 @@ fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> "vllm" => vec!["VLLM_API_KEY"], "osaurus" => vec!["OSAURUS_API_KEY"], "telnyx" => vec!["TELNYX_API_KEY"], + "azure_openai" | "azure-openai" | "azure" => vec!["AZURE_OPENAI_API_KEY"], _ => vec![], }; @@ -1125,6 +1127,19 @@ fn create_provider_with_url_and_options( AuthStyle::Bearer, ) )), + "azure_openai" | "azure-openai" | "azure" => { + let resource = std::env::var("AZURE_OPENAI_RESOURCE") + .unwrap_or_else(|_| "my-resource".to_string()); + let deployment = std::env::var("AZURE_OPENAI_DEPLOYMENT") + .unwrap_or_else(|_| "gpt-4o".to_string()); + let api_version = std::env::var("AZURE_OPENAI_API_VERSION").ok(); + Ok(Box::new(azure_openai::AzureOpenAiProvider::new( + key, + &resource, + &deployment, + api_version.as_deref(), + ))) + } "bedrock" | "aws-bedrock" => Ok(Box::new(bedrock::BedrockProvider::new())), name if is_qwen_oauth_alias(name) => { let base_url = api_url