zeroclaw/src/providers/traits.rs

241 lines
6.7 KiB
Rust

use crate::tools::ToolSpec;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// One entry in a conversation: a role label paired with its text.
///
/// The role is stored as a plain string; the constructors on this type
/// produce the values `"system"`, `"user"`, `"assistant"` and `"tool"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who produced the message (e.g. `"system"`, `"user"`).
    pub role: String,
    /// The message text.
    pub content: String,
}
impl ChatMessage {
    /// Shared constructor body: pairs the given role label with the content.
    fn with_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_owned(),
            content: content.into(),
        }
    }

    /// Builds a message whose role is `"system"`.
    pub fn system(content: impl Into<String>) -> Self {
        Self::with_role("system", content)
    }

    /// Builds a message whose role is `"user"`.
    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role("user", content)
    }

    /// Builds a message whose role is `"assistant"`.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role("assistant", content)
    }

    /// Builds a message whose role is `"tool"`.
    pub fn tool(content: impl Into<String>) -> Self {
        Self::with_role("tool", content)
    }
}
/// A single tool invocation requested by the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call (tool results refer back to a call by id).
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments for the call, carried as a string.
    // NOTE(review): the tests in this file pass JSON text here; presumably
    // this is always JSON-encoded — confirm against the providers.
    pub arguments: String,
}
/// What the LLM sent back: optional text, zero or more tool calls, or both.
#[derive(Debug, Clone)]
pub struct ChatResponse {
    /// Text content of the response; `None` when there is no text
    /// (e.g. the reply consists only of tool calls).
    pub text: Option<String>,
    /// Tool calls requested by the LLM; empty when none were requested.
    pub tool_calls: Vec<ToolCall>,
}
impl ChatResponse {
    /// Returns `true` if the response carries one or more tool calls.
    pub fn has_tool_calls(&self) -> bool {
        let no_calls = self.tool_calls.is_empty();
        !no_calls
    }

    /// Borrows the response text, substituting `""` when `text` is `None`.
    pub fn text_or_empty(&self) -> &str {
        match &self.text {
            Some(body) => body.as_str(),
            None => "",
        }
    }
}
/// Borrowed request payload handed to a provider's chat call.
///
/// Holds only references, so the whole value is `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct ChatRequest<'a> {
    /// Conversation messages to send.
    pub messages: &'a [ChatMessage],
    /// Tool specifications to offer, or `None` when no tools are supplied.
    pub tools: Option<&'a [ToolSpec]>,
}
/// The outcome of one tool execution, to be fed back to the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMessage {
    /// Id of the tool call this result belongs to.
    pub tool_call_id: String,
    /// The tool's output, as text.
    pub content: String,
}
/// One entry of a multi-turn conversation history, covering plain chat
/// messages as well as tool interactions.
///
/// Serialized with the variant name under `"type"` and the variant's payload
/// under `"data"` (see the `serde` attribute below).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum ConversationMessage {
    /// An ordinary chat message (system, user, assistant).
    Chat(ChatMessage),
    /// Tool calls issued by the assistant, kept so the history stays faithful.
    AssistantToolCalls {
        /// Text that accompanied the tool calls, if any.
        text: Option<String>,
        /// The tool calls the assistant requested.
        tool_calls: Vec<ToolCall>,
    },
    /// Results of executed tools, to be fed back to the LLM.
    ToolResults(Vec<ToolResultMessage>),
}
#[async_trait]
pub trait Provider: Send + Sync {
/// Simple one-shot chat (single user message, no explicit system prompt).
///
/// This is the preferred API for non-agentic direct interactions.
async fn simple_chat(
&self,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
self.chat_with_system(None, message, model, temperature).await
}
/// One-shot chat with optional system prompt.
///
/// Kept for compatibility and advanced one-shot prompting.
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String>;
/// Multi-turn conversation. Default implementation extracts the last user
/// message and delegates to `chat_with_system`.
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let system = messages
.iter()
.find(|m| m.role == "system")
.map(|m| m.content.as_str());
let last_user = messages
.iter()
.rfind(|m| m.role == "user")
.map(|m| m.content.as_str())
.unwrap_or("");
self.chat_with_system(system, last_user, model, temperature)
.await
}
/// Structured chat API for agent loop callers.
async fn chat(
&self,
request: ChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
let text = self
.chat_with_history(request.messages, model, temperature)
.await?;
Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
})
}
/// Whether provider supports native tool calls over API.
fn supports_native_tools(&self) -> bool {
false
}
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
/// Default implementation is a no-op; providers with HTTP clients should override.
async fn warmup(&self) -> anyhow::Result<()> {
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every role constructor stamps the matching role string.
    #[test]
    fn chat_message_constructors() {
        let system_msg = ChatMessage::system("Be helpful");
        assert_eq!(system_msg.role, "system");
        assert_eq!(system_msg.content, "Be helpful");

        assert_eq!(ChatMessage::user("Hello").role, "user");
        assert_eq!(ChatMessage::assistant("Hi there").role, "assistant");
        assert_eq!(ChatMessage::tool("{}").role, "tool");
    }

    /// `has_tool_calls` / `text_or_empty` on an empty and a populated response.
    #[test]
    fn chat_response_helpers() {
        let nothing = ChatResponse {
            text: None,
            tool_calls: Vec::new(),
        };
        assert!(!nothing.has_tool_calls());
        assert_eq!(nothing.text_or_empty(), "");

        let shell_call = ToolCall {
            id: "1".into(),
            name: "shell".into(),
            arguments: "{}".into(),
        };
        let populated = ChatResponse {
            text: Some("Let me check".into()),
            tool_calls: vec![shell_call],
        };
        assert!(populated.has_tool_calls());
        assert_eq!(populated.text_or_empty(), "Let me check");
    }

    /// A serialized tool call contains its id and tool name.
    #[test]
    fn tool_call_serialization() {
        let call = ToolCall {
            id: "call_123".into(),
            name: "file_read".into(),
            arguments: r#"{"path":"test.txt"}"#.into(),
        };
        let encoded = serde_json::to_string(&call).unwrap();
        assert!(encoded.contains("call_123"));
        assert!(encoded.contains("file_read"));
    }

    /// Each variant is serialized with its name under the `"type"` key.
    #[test]
    fn conversation_message_variants() {
        let chat_json =
            serde_json::to_string(&ConversationMessage::Chat(ChatMessage::user("hi"))).unwrap();
        assert!(chat_json.contains(r#""type":"Chat""#));

        let results = vec![ToolResultMessage {
            tool_call_id: "1".into(),
            content: "done".into(),
        }];
        let results_json =
            serde_json::to_string(&ConversationMessage::ToolResults(results)).unwrap();
        assert!(results_json.contains(r#""type":"ToolResults""#));
    }
}