diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index 000c4a079..0cd209657 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -4,7 +4,8 @@ use crate::providers::traits::{ ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, - Provider, StreamChunk, StreamError, StreamOptions, StreamResult, ToolCall as ProviderToolCall, + Provider, StreamChunk, StreamError, StreamOptions, StreamResult, TokenUsage, + ToolCall as ProviderToolCall, }; use async_trait::async_trait; use futures_util::{stream, StreamExt}; @@ -272,6 +273,16 @@ struct Message { #[derive(Debug, Deserialize)] struct ApiChatResponse { choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct UsageInfo { + #[serde(default)] + prompt_tokens: Option, + #[serde(default)] + completion_tokens: Option, } #[derive(Debug, Deserialize)] @@ -1270,6 +1281,10 @@ impl Provider for OpenAiCompatibleProvider { let body = response.text().await?; let chat_response = parse_chat_response_body(&self.name, &body)?; + let usage = chat_response.usage.map(|u| TokenUsage { + input_tokens: u.prompt_tokens, + output_tokens: u.completion_tokens, + }); let choice = chat_response .choices .into_iter() @@ -1297,7 +1312,7 @@ impl Provider for OpenAiCompatibleProvider { Ok(ProviderChatResponse { text, tool_calls, - usage: None, + usage, }) } @@ -1401,6 +1416,10 @@ impl Provider for OpenAiCompatibleProvider { } let native_response: ApiChatResponse = response.json().await?; + let usage = native_response.usage.map(|u| TokenUsage { + input_tokens: u.prompt_tokens, + output_tokens: u.completion_tokens, + }); let message = native_response .choices .into_iter() @@ -1408,7 +1427,9 @@ impl Provider for OpenAiCompatibleProvider { .map(|choice| choice.message) .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?; - Ok(Self::parse_native_response(message)) + let mut result = Self::parse_native_response(message); + result.usage = usage; + Ok(result) } fn supports_native_tools(&self) -> bool { @@ -2460,4 +2481,23 @@ mod tests { let result = parse_sse_line(line).unwrap(); assert_eq!(result, None); } + + #[test] + fn api_response_parses_usage() { + let json = r#"{ + "choices": [{"message": {"content": "Hello"}}], + "usage": {"prompt_tokens": 150, "completion_tokens": 60} + }"#; + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let usage = resp.usage.unwrap(); + assert_eq!(usage.prompt_tokens, Some(150)); + assert_eq!(usage.completion_tokens, Some(60)); + } + + #[test] + fn api_response_parses_without_usage() { + let json = r#"{"choices": [{"message": {"content": "Hello"}}]}"#; + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + assert!(resp.usage.is_none()); + } } diff --git a/src/providers/copilot.rs b/src/providers/copilot.rs index bdb350966..ce768ab3a 100644 --- a/src/providers/copilot.rs +++ b/src/providers/copilot.rs @@ -13,7 +13,7 @@ use crate::providers::traits::{ ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, - Provider, ToolCall as ProviderToolCall, + Provider, TokenUsage, ToolCall as ProviderToolCall, }; use crate::tools::ToolSpec; use async_trait::async_trait; @@ -134,6 +134,16 @@ struct NativeFunctionCall { #[derive(Debug, Deserialize)] struct ApiChatResponse { choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct UsageInfo { + #[serde(default)] + prompt_tokens: Option, + #[serde(default)] + completion_tokens: Option, } #[derive(Debug, Deserialize)] @@ -340,6 +350,10 @@ impl CopilotProvider { } let api_response: ApiChatResponse = response.json().await?; + let usage = api_response.usage.map(|u| TokenUsage { + input_tokens: u.prompt_tokens, + output_tokens: u.completion_tokens, + }); let choice = api_response .choices .into_iter() @@ -363,7 +377,7 @@ impl CopilotProvider { Ok(ProviderChatResponse { text: choice.message.content, tool_calls, - usage: None, + usage, }) } @@ -701,4 +715,23 @@ mod tests { let provider = CopilotProvider::new(None); assert!(provider.supports_native_tools()); } + + #[test] + fn api_response_parses_usage() { + let json = r#"{ + "choices": [{"message": {"content": "Hello"}}], + "usage": {"prompt_tokens": 200, "completion_tokens": 80} + }"#; + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let usage = resp.usage.unwrap(); + assert_eq!(usage.prompt_tokens, Some(200)); + assert_eq!(usage.completion_tokens, Some(80)); + } + + #[test] + fn api_response_parses_without_usage() { + let json = r#"{"choices": [{"message": {"content": "Hello"}}]}"#; + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + assert!(resp.usage.is_none()); + } } diff --git a/src/providers/openai.rs b/src/providers/openai.rs index c69536966..02400d067 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -1,6 +1,6 @@ use crate::providers::traits::{ ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, - Provider, ToolCall as ProviderToolCall, + Provider, TokenUsage, ToolCall as ProviderToolCall, }; use crate::tools::ToolSpec; use async_trait::async_trait; @@ -121,6 +121,16 @@ struct NativeFunctionCall { #[derive(Debug, Deserialize)] struct NativeChatResponse { choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct UsageInfo { + #[serde(default)] + prompt_tokens: Option, + #[serde(default)] + completion_tokens: Option, } #[derive(Debug, Deserialize)] @@ -359,13 +369,19 @@ impl Provider for OpenAiProvider { } let native_response: NativeChatResponse = response.json().await?; + let usage = native_response.usage.map(|u| TokenUsage { + input_tokens: u.prompt_tokens, + output_tokens: u.completion_tokens, + }); let message = native_response .choices .into_iter() .next() .map(|c| c.message) .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?; - Ok(Self::parse_native_response(message)) + let mut result = Self::parse_native_response(message); + result.usage = usage; + Ok(result) } fn supports_native_tools(&self) -> bool { @@ -416,13 +432,19 @@ impl Provider for OpenAiProvider { } let native_response: NativeChatResponse = response.json().await?; + let usage = native_response.usage.map(|u| TokenUsage { + input_tokens: u.prompt_tokens, + output_tokens: u.completion_tokens, + }); let message = native_response .choices .into_iter() .next() .map(|c| c.message) .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?; - Ok(Self::parse_native_response(message)) + let mut result = Self::parse_native_response(message); + result.usage = usage; + Ok(result) } async fn warmup(&self) -> anyhow::Result<()> { @@ -678,4 +700,23 @@ mod tests { assert_eq!(spec.kind, "function"); assert_eq!(spec.function.name, "shell"); } + + #[test] + fn native_response_parses_usage() { + let json = r#"{ + "choices": [{"message": {"content": "Hello"}}], + "usage": {"prompt_tokens": 100, "completion_tokens": 50} + }"#; + let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); + let usage = resp.usage.unwrap(); + assert_eq!(usage.prompt_tokens, Some(100)); + assert_eq!(usage.completion_tokens, Some(50)); + } + + #[test] + fn native_response_parses_without_usage() { + let json = r#"{"choices": [{"message": {"content": "Hello"}}]}"#; + let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); + assert!(resp.usage.is_none()); + } } diff --git a/src/providers/openrouter.rs b/src/providers/openrouter.rs index 0a2347a29..8d1030c18 100644 --- a/src/providers/openrouter.rs +++ b/src/providers/openrouter.rs @@ -1,6 +1,6 @@ use crate::providers::traits::{ ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, - Provider, ToolCall as ProviderToolCall, + Provider, TokenUsage, ToolCall as ProviderToolCall, }; use crate::tools::ToolSpec; use async_trait::async_trait; @@ -93,6 +93,16 @@ struct NativeFunctionCall { #[derive(Debug, Deserialize)] struct NativeChatResponse { choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct UsageInfo { + #[serde(default)] + prompt_tokens: Option, + #[serde(default)] + completion_tokens: Option, } #[derive(Debug, Deserialize)] @@ -388,13 +398,19 @@ impl Provider for OpenRouterProvider { } let native_response: NativeChatResponse = response.json().await?; + let usage = native_response.usage.map(|u| TokenUsage { + input_tokens: u.prompt_tokens, + output_tokens: u.completion_tokens, + }); let message = native_response .choices .into_iter() .next() .map(|c| c.message) .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?; - Ok(Self::parse_native_response(message)) + let mut result = Self::parse_native_response(message); + result.usage = usage; + Ok(result) } fn supports_native_tools(&self) -> bool { @@ -476,13 +492,19 @@ impl Provider for OpenRouterProvider { } let native_response: NativeChatResponse = response.json().await?; + let usage = native_response.usage.map(|u| TokenUsage { + input_tokens: u.prompt_tokens, + output_tokens: u.completion_tokens, + }); let message = native_response .choices .into_iter() .next() .map(|c| c.message) .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?; - Ok(Self::parse_native_response(message)) + let mut result = Self::parse_native_response(message); + result.usage = usage; + Ok(result) } } @@ -749,4 +771,23 @@ mod tests { assert_eq!(converted[0].content.as_deref(), Some("done")); assert!(converted[0].tool_calls.is_none()); } + + #[test] + fn native_response_parses_usage() { + let json = r#"{ + "choices": [{"message": {"content": "Hello"}}], + "usage": {"prompt_tokens": 42, "completion_tokens": 15} + }"#; + let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); + let usage = resp.usage.unwrap(); + assert_eq!(usage.prompt_tokens, Some(42)); + assert_eq!(usage.completion_tokens, Some(15)); + } + + #[test] + fn native_response_parses_without_usage() { + let json = r#"{"choices": [{"message": {"content": "Hello"}}]}"#; + let resp: NativeChatResponse = serde_json::from_str(json).unwrap(); + assert!(resp.usage.is_none()); + } }