feat(providers): add Azure OpenAI provider support (#3246)
Closes #3176
This commit is contained in:
parent
3fe3fe23b1
commit
448682c440
@ -242,6 +242,15 @@ pub struct ModelProviderConfig {
|
||||
/// If true, load OpenAI auth material (OPENAI_API_KEY or ~/.codex/auth.json).
|
||||
#[serde(default)]
|
||||
pub requires_openai_auth: bool,
|
||||
/// Azure OpenAI resource name (e.g. "my-resource" in https://my-resource.openai.azure.com).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub azure_openai_resource: Option<String>,
|
||||
/// Azure OpenAI deployment name (e.g. "gpt-4o").
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub azure_openai_deployment: Option<String>,
|
||||
/// Azure OpenAI API version (defaults to "2024-08-01-preview").
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub azure_openai_api_version: Option<String>,
|
||||
}
|
||||
|
||||
// ── Delegate Agents ──────────────────────────────────────────────
|
||||
@ -7141,6 +7150,9 @@ requires_openai_auth = true
|
||||
base_url: Some("https://api.tonsof.blue/v1".to_string()),
|
||||
wire_api: None,
|
||||
requires_openai_auth: false,
|
||||
azure_openai_resource: None,
|
||||
azure_openai_deployment: None,
|
||||
azure_openai_api_version: None,
|
||||
},
|
||||
)]),
|
||||
..Config::default()
|
||||
@ -7169,6 +7181,9 @@ requires_openai_auth = true
|
||||
base_url: Some("https://api.tonsof.blue".to_string()),
|
||||
wire_api: Some("responses".to_string()),
|
||||
requires_openai_auth: true,
|
||||
azure_openai_resource: None,
|
||||
azure_openai_deployment: None,
|
||||
azure_openai_api_version: None,
|
||||
},
|
||||
)]),
|
||||
api_key: None,
|
||||
@ -7231,6 +7246,9 @@ requires_openai_auth = true
|
||||
base_url: Some("https://api.tonsof.blue/v1".to_string()),
|
||||
wire_api: Some("ws".to_string()),
|
||||
requires_openai_auth: false,
|
||||
azure_openai_resource: None,
|
||||
azure_openai_deployment: None,
|
||||
azure_openai_api_version: None,
|
||||
},
|
||||
)]),
|
||||
..Config::default()
|
||||
|
||||
756
src/providers/azure_openai.rs
Normal file
756
src/providers/azure_openai.rs
Normal file
@ -0,0 +1,756 @@
|
||||
use crate::providers::traits::{
|
||||
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||
Provider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolsPayload,
|
||||
};
|
||||
use crate::tools::ToolSpec;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const DEFAULT_API_VERSION: &str = "2024-08-01-preview";
|
||||
|
||||
pub struct AzureOpenAiProvider {
|
||||
credential: Option<String>,
|
||||
resource_name: String,
|
||||
deployment_name: String,
|
||||
api_version: String,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
messages: Vec<Message>,
|
||||
temperature: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: ResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseMessage {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
impl ResponseMessage {
|
||||
fn effective_content(&self) -> String {
|
||||
match &self.content {
|
||||
Some(c) if !c.is_empty() => c.clone(),
|
||||
_ => self.reasoning_content.clone().unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeChatRequest {
|
||||
messages: Vec<NativeMessage>,
|
||||
temperature: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<NativeToolSpec>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_choice: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeMessage {
|
||||
role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
reasoning_content: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeToolSpec {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
function: NativeToolFunctionSpec,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeToolFunctionSpec {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
fn parse_native_tool_spec(value: serde_json::Value) -> anyhow::Result<NativeToolSpec> {
|
||||
let spec: NativeToolSpec = serde_json::from_value(value)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid Azure OpenAI tool specification: {e}"))?;
|
||||
|
||||
if spec.kind != "function" {
|
||||
anyhow::bail!(
|
||||
"Invalid Azure OpenAI tool specification: unsupported tool type '{}', expected 'function'",
|
||||
spec.kind
|
||||
);
|
||||
}
|
||||
|
||||
Ok(spec)
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
||||
kind: Option<String>,
|
||||
function: NativeFunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeFunctionCall {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NativeChatResponse {
|
||||
choices: Vec<NativeChoice>,
|
||||
#[serde(default)]
|
||||
usage: Option<UsageInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UsageInfo {
|
||||
#[serde(default)]
|
||||
prompt_tokens: Option<u64>,
|
||||
#[serde(default)]
|
||||
completion_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NativeChoice {
|
||||
message: NativeResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NativeResponseMessage {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
reasoning_content: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
}
|
||||
|
||||
impl NativeResponseMessage {
|
||||
fn effective_content(&self) -> Option<String> {
|
||||
match &self.content {
|
||||
Some(c) if !c.is_empty() => Some(c.clone()),
|
||||
_ => self.reasoning_content.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AzureOpenAiProvider {
|
||||
pub fn new(
|
||||
credential: Option<&str>,
|
||||
resource_name: &str,
|
||||
deployment_name: &str,
|
||||
api_version: Option<&str>,
|
||||
) -> Self {
|
||||
let version = api_version.unwrap_or(DEFAULT_API_VERSION);
|
||||
let base_url = format!(
|
||||
"https://{}.openai.azure.com/openai/deployments/{}",
|
||||
resource_name, deployment_name
|
||||
);
|
||||
Self {
|
||||
credential: credential.map(ToString::to_string),
|
||||
resource_name: resource_name.to_string(),
|
||||
deployment_name: deployment_name.to_string(),
|
||||
api_version: version.to_string(),
|
||||
base_url,
|
||||
}
|
||||
}
|
||||
|
||||
fn chat_completions_url(&self) -> String {
|
||||
format!(
|
||||
"{}/chat/completions?api-version={}",
|
||||
self.base_url, self.api_version
|
||||
)
|
||||
}
|
||||
|
||||
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
|
||||
tools.map(|items| {
|
||||
items
|
||||
.iter()
|
||||
.map(|tool| NativeToolSpec {
|
||||
kind: "function".to_string(),
|
||||
function: NativeToolFunctionSpec {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
parameters: tool.parameters.clone(),
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
|
||||
messages
|
||||
.iter()
|
||||
.map(|m| {
|
||||
if m.role == "assistant" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
|
||||
if let Some(tool_calls_value) = value.get("tool_calls") {
|
||||
if let Ok(parsed_calls) =
|
||||
serde_json::from_value::<Vec<ProviderToolCall>>(
|
||||
tool_calls_value.clone(),
|
||||
)
|
||||
{
|
||||
let tool_calls = parsed_calls
|
||||
.into_iter()
|
||||
.map(|tc| NativeToolCall {
|
||||
id: Some(tc.id),
|
||||
kind: Some("function".to_string()),
|
||||
function: NativeFunctionCall {
|
||||
name: tc.name,
|
||||
arguments: tc.arguments,
|
||||
},
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
let reasoning_content = value
|
||||
.get("reasoning_content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
return NativeMessage {
|
||||
role: "assistant".to_string(),
|
||||
content,
|
||||
tool_call_id: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
reasoning_content,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if m.role == "tool" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
|
||||
let tool_call_id = value
|
||||
.get("tool_call_id")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
return NativeMessage {
|
||||
role: "tool".to_string(),
|
||||
content,
|
||||
tool_call_id,
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
NativeMessage {
|
||||
role: m.role.clone(),
|
||||
content: Some(m.content.clone()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse {
|
||||
let text = message.effective_content();
|
||||
let reasoning_content = message.reasoning_content.clone();
|
||||
let tool_calls = message
|
||||
.tool_calls
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|tc| ProviderToolCall {
|
||||
id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
name: tc.function.name,
|
||||
arguments: tc.function.arguments,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
ProviderChatResponse {
|
||||
text,
|
||||
tool_calls,
|
||||
usage: None,
|
||||
reasoning_content,
|
||||
}
|
||||
}
|
||||
|
||||
fn http_client(&self) -> Client {
|
||||
crate::config::build_runtime_proxy_client_with_timeouts("provider.azure_openai", 120, 10)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for AzureOpenAiProvider {
|
||||
fn capabilities(&self) -> ProviderCapabilities {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
vision: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
|
||||
ToolsPayload::OpenAI {
|
||||
tools: tools
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
serde_json::json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn supports_vision(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
_model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(sys) = system_prompt {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: sys.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: message.to_string(),
|
||||
});
|
||||
|
||||
let request = ChatRequest {
|
||||
messages,
|
||||
temperature,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http_client()
|
||||
.post(self.chat_completions_url())
|
||||
.header("api-key", credential.as_str())
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("Azure OpenAI", response).await);
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
|
||||
chat_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.effective_content())
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ProviderChatRequest<'_>,
|
||||
_model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
|
||||
)
|
||||
})?;
|
||||
|
||||
let tools = Self::convert_tools(request.tools);
|
||||
let native_request = NativeChatRequest {
|
||||
messages: Self::convert_messages(request.messages),
|
||||
temperature,
|
||||
tool_choice: tools.as_ref().map(|_| "auto".to_string()),
|
||||
tools,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http_client()
|
||||
.post(self.chat_completions_url())
|
||||
.header("api-key", credential.as_str())
|
||||
.json(&native_request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("Azure OpenAI", response).await);
|
||||
}
|
||||
|
||||
let native_response: NativeChatResponse = response.json().await?;
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?;
|
||||
let mut result = Self::parse_native_response(message);
|
||||
result.usage = usage;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn chat_with_tools(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
tools: &[serde_json::Value],
|
||||
_model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
|
||||
)
|
||||
})?;
|
||||
|
||||
let native_tools: Option<Vec<NativeToolSpec>> = if tools.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
tools
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(parse_native_tool_spec)
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
)
|
||||
};
|
||||
|
||||
let native_request = NativeChatRequest {
|
||||
messages: Self::convert_messages(messages),
|
||||
temperature,
|
||||
tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
|
||||
tools: native_tools,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http_client()
|
||||
.post(self.chat_completions_url())
|
||||
.header("api-key", credential.as_str())
|
||||
.json(&native_request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("Azure OpenAI", response).await);
|
||||
}
|
||||
|
||||
let native_response: NativeChatResponse = response.json().await?;
|
||||
let usage = native_response.usage.map(|u| TokenUsage {
|
||||
input_tokens: u.prompt_tokens,
|
||||
output_tokens: u.completion_tokens,
|
||||
});
|
||||
let message = native_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?;
|
||||
let mut result = Self::parse_native_response(message);
|
||||
result.usage = usage;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
// Azure OpenAI does not have a lightweight models endpoint,
|
||||
// so warmup is a no-op to avoid unnecessary API calls.
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn url_construction_default_version() {
|
||||
let p = AzureOpenAiProvider::new(Some("test-key"), "my-resource", "gpt-4o", None);
|
||||
assert_eq!(
|
||||
p.chat_completions_url(),
|
||||
"https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn url_construction_custom_version() {
|
||||
let p = AzureOpenAiProvider::new(
|
||||
Some("test-key"),
|
||||
"my-resource",
|
||||
"gpt-4o",
|
||||
Some("2024-06-01"),
|
||||
);
|
||||
assert_eq!(
|
||||
p.chat_completions_url(),
|
||||
"https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-06-01"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn url_construction_preserves_resource_and_deployment() {
|
||||
let p = AzureOpenAiProvider::new(Some("key"), "contoso-ai", "my-gpt35-deployment", None);
|
||||
let url = p.chat_completions_url();
|
||||
assert!(url.contains("contoso-ai.openai.azure.com"));
|
||||
assert!(url.contains("/deployments/my-gpt35-deployment/"));
|
||||
assert!(url.contains("api-version=2024-08-01-preview"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_header_uses_api_key_not_bearer() {
|
||||
// This test verifies the provider stores the credential correctly
|
||||
// and that the auth header name is "api-key" (verified via the
|
||||
// implementation in chat_with_system which uses .header("api-key", ...)).
|
||||
let p = AzureOpenAiProvider::new(Some("my-azure-key"), "resource", "deployment", None);
|
||||
assert_eq!(p.credential.as_deref(), Some("my-azure-key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_credential() {
|
||||
let p = AzureOpenAiProvider::new(
|
||||
Some("azure-test-credential"),
|
||||
"resource",
|
||||
"deployment",
|
||||
None,
|
||||
);
|
||||
assert_eq!(p.credential.as_deref(), Some("azure-test-credential"));
|
||||
assert_eq!(p.resource_name, "resource");
|
||||
assert_eq!(p.deployment_name, "deployment");
|
||||
assert_eq!(p.api_version, DEFAULT_API_VERSION);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_credential() {
|
||||
let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
|
||||
assert!(p.credential.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_fails_without_key() {
|
||||
let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
|
||||
let result = p.chat_with_system(None, "hello", "gpt-4o", 0.7).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("API key not set"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_system_fails_without_key() {
|
||||
let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
|
||||
let result = p
|
||||
.chat_with_system(Some("You are ZeroClaw"), "test", "gpt-4o", 0.5)
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_with_system_message() {
|
||||
let req = ChatRequest {
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "system".to_string(),
|
||||
content: "You are ZeroClaw".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "hello".to_string(),
|
||||
},
|
||||
],
|
||||
temperature: 0.7,
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"role\":\"system\""));
|
||||
assert!(json.contains("\"role\":\"user\""));
|
||||
// Azure requests should NOT contain a model field (deployment is in the URL)
|
||||
assert!(!json.contains("\"model\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_without_system() {
|
||||
let req = ChatRequest {
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: "hello".to_string(),
|
||||
}],
|
||||
temperature: 0.0,
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(!json.contains("system"));
|
||||
assert!(json.contains("\"temperature\":0.0"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes_single_choice() {
|
||||
let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.choices.len(), 1);
|
||||
assert_eq!(resp.choices[0].message.effective_content(), "Hi!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes_empty_choices() {
|
||||
let json = r#"{"choices":[]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.choices.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes_multiple_choices() {
|
||||
let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.choices.len(), 2);
|
||||
assert_eq!(resp.choices[0].message.effective_content(), "A");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_call_response_parsing() {
|
||||
let json = r#"{"choices":[{"message":{
|
||||
"content":"Let me check",
|
||||
"tool_calls":[{
|
||||
"id":"call_abc123",
|
||||
"type":"function",
|
||||
"function":{"name":"shell","arguments":"{\"command\":\"ls\"}"}
|
||||
}]
|
||||
}}],"usage":{"prompt_tokens":50,"completion_tokens":25}}"#;
|
||||
let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
|
||||
let message = resp.choices.into_iter().next().unwrap().message;
|
||||
let parsed = AzureOpenAiProvider::parse_native_response(message);
|
||||
assert_eq!(parsed.text.as_deref(), Some("Let me check"));
|
||||
assert_eq!(parsed.tool_calls.len(), 1);
|
||||
assert_eq!(parsed.tool_calls[0].id, "call_abc123");
|
||||
assert_eq!(parsed.tool_calls[0].name, "shell");
|
||||
assert!(parsed.tool_calls[0].arguments.contains("ls"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_call_response_without_id_generates_uuid() {
|
||||
let json = r#"{"choices":[{"message":{
|
||||
"content":null,
|
||||
"tool_calls":[{
|
||||
"function":{"name":"test","arguments":"{}"}
|
||||
}]
|
||||
}}]}"#;
|
||||
let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
|
||||
let message = resp.choices.into_iter().next().unwrap().message;
|
||||
let parsed = AzureOpenAiProvider::parse_native_response(message);
|
||||
assert_eq!(parsed.tool_calls.len(), 1);
|
||||
assert!(!parsed.tool_calls[0].id.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_tools_fails_without_key() {
|
||||
let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
|
||||
let messages = vec![ChatMessage::user("hello".to_string())];
|
||||
let tools = vec![serde_json::json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "shell",
|
||||
"description": "Run a shell command",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": { "type": "string" }
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
})];
|
||||
let result = p.chat_with_tools(&messages, &tools, "gpt-4o", 0.7).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("API key not set"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn native_response_parses_usage() {
|
||||
let json = r#"{
|
||||
"choices": [{"message": {"content": "Hello"}}],
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 50}
|
||||
}"#;
|
||||
let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
|
||||
let usage = resp.usage.unwrap();
|
||||
assert_eq!(usage.prompt_tokens, Some(100));
|
||||
assert_eq!(usage.completion_tokens, Some(50));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capabilities_reports_native_tools_and_vision() {
|
||||
let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
|
||||
let caps = <AzureOpenAiProvider as Provider>::capabilities(&p);
|
||||
assert!(caps.native_tool_calling);
|
||||
assert!(caps.vision);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_native_tools_returns_true() {
|
||||
let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
|
||||
assert!(p.supports_native_tools());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_vision_returns_true() {
|
||||
let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
|
||||
assert!(p.supports_vision());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn warmup_is_noop() {
|
||||
let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
|
||||
let result = p.warmup().await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_api_version_stored() {
|
||||
let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", Some("2025-01-01"));
|
||||
assert_eq!(p.api_version, "2025-01-01");
|
||||
}
|
||||
}
|
||||
@ -17,6 +17,7 @@
|
||||
//! in [`create_provider_with_url`]. See `AGENTS.md` §7.1 for the full change playbook.
|
||||
|
||||
pub mod anthropic;
|
||||
pub mod azure_openai;
|
||||
pub mod bedrock;
|
||||
pub mod compatible;
|
||||
pub mod copilot;
|
||||
@ -849,6 +850,7 @@ fn resolve_provider_credential(name: &str, credential_override: Option<&str>) ->
|
||||
"vllm" => vec!["VLLM_API_KEY"],
|
||||
"osaurus" => vec!["OSAURUS_API_KEY"],
|
||||
"telnyx" => vec!["TELNYX_API_KEY"],
|
||||
"azure_openai" | "azure-openai" | "azure" => vec!["AZURE_OPENAI_API_KEY"],
|
||||
_ => vec![],
|
||||
};
|
||||
|
||||
@ -1125,6 +1127,19 @@ fn create_provider_with_url_and_options(
|
||||
AuthStyle::Bearer,
|
||||
)
|
||||
)),
|
||||
"azure_openai" | "azure-openai" | "azure" => {
|
||||
let resource = std::env::var("AZURE_OPENAI_RESOURCE")
|
||||
.unwrap_or_else(|_| "my-resource".to_string());
|
||||
let deployment = std::env::var("AZURE_OPENAI_DEPLOYMENT")
|
||||
.unwrap_or_else(|_| "gpt-4o".to_string());
|
||||
let api_version = std::env::var("AZURE_OPENAI_API_VERSION").ok();
|
||||
Ok(Box::new(azure_openai::AzureOpenAiProvider::new(
|
||||
key,
|
||||
&resource,
|
||||
&deployment,
|
||||
api_version.as_deref(),
|
||||
)))
|
||||
}
|
||||
"bedrock" | "aws-bedrock" => Ok(Box::new(bedrock::BedrockProvider::new())),
|
||||
name if is_qwen_oauth_alias(name) => {
|
||||
let base_url = api_url
|
||||
|
||||
Loading…
Reference in New Issue
Block a user