Merge remote-tracking branch 'origin/master'

This commit is contained in:
argenis de la rosa 2026-03-13 09:38:00 -04:00
commit 7e0570abd6
20 changed files with 2212 additions and 67 deletions

View File

@ -48,8 +48,8 @@ schemars = "1.2"
tracing = { version = "0.1", default-features = false }
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter"] }
# Observability - Prometheus metrics
prometheus = { version = "0.14", default-features = false }
# Observability - Prometheus metrics (optional; requires AtomicU64, unavailable on 32-bit)
prometheus = { version = "0.14", default-features = false, optional = true }
# Base64 encoding (screenshots, image data)
base64 = "0.22"
@ -205,6 +205,8 @@ sandbox-landlock = ["dep:landlock"]
sandbox-bubblewrap = []
# Backward-compatible alias for older invocations
landlock = ["sandbox-landlock"]
# Prometheus metrics observer (requires 64-bit atomics; disable on 32-bit targets)
metrics = ["dep:prometheus"]
# probe = probe-rs for Nucleo memory read (adds ~50 deps; optional)
probe = ["dep:probe-rs"]
# rag-pdf = PDF ingestion for datasheet RAG

110
build.rs
View File

@ -1,6 +1,110 @@
use std::path::Path;
use std::process::Command;
fn main() {
let dir = std::path::Path::new("web/dist");
if !dir.exists() {
std::fs::create_dir_all(dir).expect("failed to create web/dist/");
let dist_dir = Path::new("web/dist");
let web_dir = Path::new("web");
// Tell Cargo to re-run this script when web source files change.
println!("cargo:rerun-if-changed=web/src");
println!("cargo:rerun-if-changed=web/index.html");
println!("cargo:rerun-if-changed=web/package.json");
println!("cargo:rerun-if-changed=web/vite.config.ts");
// Attempt to build the web frontend if npm is available and web/dist is
// missing or stale. The build is best-effort: when Node.js is not
// installed (e.g. CI containers, cross-compilation, minimal dev setups)
// we fall back to the existing stub/empty dist directory so the Rust
// build still succeeds.
let needs_build = !dist_dir.join("index.html").exists();
if needs_build && web_dir.join("package.json").exists() {
if let Ok(npm) = which_npm() {
eprintln!("cargo:warning=Building web frontend (web/dist is missing or stale)...");
// npm ci / npm install
let install_status = Command::new(&npm)
.args(["ci", "--ignore-scripts"])
.current_dir(web_dir)
.status();
match install_status {
Ok(s) if s.success() => {}
Ok(s) => {
// Fall back to `npm install` if `npm ci` fails (no lockfile, etc.)
eprintln!("cargo:warning=npm ci exited with {s}, trying npm install...");
let fallback = Command::new(&npm)
.args(["install"])
.current_dir(web_dir)
.status();
if !matches!(fallback, Ok(s) if s.success()) {
eprintln!("cargo:warning=npm install failed — skipping web build");
ensure_dist_dir(dist_dir);
return;
}
}
Err(e) => {
eprintln!("cargo:warning=Could not run npm: {e} — skipping web build");
ensure_dist_dir(dist_dir);
return;
}
}
// npm run build
let build_status = Command::new(&npm)
.args(["run", "build"])
.current_dir(web_dir)
.status();
match build_status {
Ok(s) if s.success() => {
eprintln!("cargo:warning=Web frontend built successfully.");
}
Ok(s) => {
eprintln!(
"cargo:warning=npm run build exited with {s} — web dashboard may be unavailable"
);
}
Err(e) => {
eprintln!(
"cargo:warning=Could not run npm build: {e} — web dashboard may be unavailable"
);
}
}
}
}
ensure_dist_dir(dist_dir);
}
/// Ensure `web/dist` exists so `rust-embed` does not fail at compile time
/// even when the web frontend was never built.
///
/// Panics if the directory cannot be created — a build script cannot
/// continue meaningfully without it.
fn ensure_dist_dir(dist_dir: &Path) {
    if dist_dir.exists() {
        return;
    }
    std::fs::create_dir_all(dist_dir).expect("failed to create web/dist/");
}
/// Locate the `npm` binary on the system PATH.
///
/// Delegates to the platform lookup tool (`where` on Windows, `which`
/// elsewhere) and returns the first matching path line, trimmed.
/// `Err(())` means npm could not be found or the lookup tool failed.
fn which_npm() -> Result<String, ()> {
    let lookup = if cfg!(target_os = "windows") {
        "where"
    } else {
        "which"
    };
    let output = Command::new(lookup).arg("npm").output().map_err(|_| ())?;
    if !output.status.success() {
        return Err(());
    }
    let stdout = String::from_utf8(output.stdout).map_err(|_| ())?;
    // `where` may print multiple matches; take the first. The "npm" fallback
    // covers the (unlikely) case of a successful lookup with empty output.
    Ok(stdout.lines().next().unwrap_or("npm").trim().to_string())
}

View File

@ -211,8 +211,35 @@ should_attempt_prebuilt_for_resources() {
return 1
}
resolve_asset_url() {
    # Resolve the download URL for a named release asset via the GitHub API.
    # Prints the URL of the newest release (prereleases included) containing
    # the asset; returns 1 when nothing matches or the API is unreachable.
    local asset_name="$1"
    local api_url="https://api.github.com/repos/zeroclaw-labs/zeroclaw/releases"
    local releases_json download_url asset_regex
    # Fetch up to 10 recent releases (includes prereleases) and find the first
    # one that contains the requested asset.
    releases_json="$(curl -fsSL "${api_url}?per_page=10" 2>/dev/null || true)"
    if [[ -z "$releases_json" ]]; then
        return 1
    fi
    # Escape regex metacharacters in the asset name so the dots in names like
    # "zeroclaw-x86_64.tar.gz" match literally instead of "any character".
    asset_regex="$(printf '%s' "$asset_name" | sed 's/[].[^$\*]/\\&/g')"
    # Parse with simple grep/sed — avoids jq dependency.
    download_url="$(printf '%s\n' "$releases_json" \
        | tr ',' '\n' \
        | grep '"browser_download_url"' \
        | sed 's/.*"browser_download_url"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/' \
        | grep "/${asset_regex}\$" \
        | head -n 1)"
    if [[ -z "$download_url" ]]; then
        return 1
    fi
    echo "$download_url"
}
install_prebuilt_binary() {
local target archive_url temp_dir archive_path extracted_bin install_dir
local target archive_url temp_dir archive_path extracted_bin install_dir asset_name
if ! have_cmd curl; then
warn "curl is required for pre-built binary installation."
@ -229,9 +256,17 @@ install_prebuilt_binary() {
return 1
fi
archive_url="https://github.com/zeroclaw-labs/zeroclaw/releases/latest/download/zeroclaw-${target}.tar.gz"
asset_name="zeroclaw-${target}.tar.gz"
# Try the GitHub API first to find the newest release (including prereleases)
# that actually contains the asset, then fall back to /releases/latest/.
archive_url="$(resolve_asset_url "$asset_name" || true)"
if [[ -z "$archive_url" ]]; then
archive_url="https://github.com/zeroclaw-labs/zeroclaw/releases/latest/download/${asset_name}"
fi
temp_dir="$(mktemp -d -t zeroclaw-prebuilt-XXXXXX)"
archive_path="$temp_dir/zeroclaw-${target}.tar.gz"
archive_path="$temp_dir/${asset_name}"
info "Attempting pre-built binary install for target: $target"
if ! curl -fsSL "$archive_url" -o "$archive_path"; then

View File

@ -1,6 +1,10 @@
use crate::channels::traits::{Channel, ChannelMessage, SendMessage};
use async_trait::async_trait;
use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(not(target_has_atomic = "64"))]
use std::sync::atomic::AtomicU32;
#[cfg(target_has_atomic = "64")]
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::{mpsc, Mutex};
@ -13,7 +17,10 @@ use tokio_rustls::rustls;
const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
/// Monotonic counter to ensure unique message IDs under burst traffic.
#[cfg(target_has_atomic = "64")]
static MSG_SEQ: AtomicU64 = AtomicU64::new(0);
#[cfg(not(target_has_atomic = "64"))]
static MSG_SEQ: AtomicU32 = AtomicU32::new(0);
/// IRC over TLS channel.
///

View File

@ -61,20 +61,33 @@ impl LinqChannel {
/// Parse an incoming webhook payload from Linq and extract messages.
///
/// Linq webhook envelope:
/// Supports two webhook formats:
///
/// **New format (webhook_version 2026-02-03):**
/// ```json
/// {
/// "api_version": "v3",
/// "webhook_version": "2026-02-03",
/// "event_type": "message.received",
/// "data": {
/// "id": "msg-...",
/// "direction": "inbound",
/// "sender_handle": { "handle": "+1...", "is_me": false },
/// "chat": { "id": "chat-..." },
/// "parts": [{ "type": "text", "value": "..." }]
/// }
/// }
/// ```
///
/// **Legacy format (webhook_version 2025-01-01):**
/// ```json
/// {
/// "api_version": "v3",
/// "event_type": "message.received",
/// "event_id": "...",
/// "created_at": "...",
/// "trace_id": "...",
/// "data": {
/// "chat_id": "...",
/// "from": "+1...",
/// "recipient_phone": "+1...",
/// "is_from_me": false,
/// "service": "iMessage",
/// "message": {
/// "id": "...",
/// "parts": [{ "type": "text", "value": "..." }]
@ -99,18 +112,44 @@ impl LinqChannel {
return messages;
};
// Detect format: new format has `sender_handle`, legacy has `from`.
let is_new_format = data.get("sender_handle").is_some();
// Skip messages sent by the bot itself
if data
.get("is_from_me")
.and_then(|v| v.as_bool())
.unwrap_or(false)
{
let is_from_me = if is_new_format {
// New format: data.sender_handle.is_me or data.direction == "outbound"
data.get("sender_handle")
.and_then(|sh| sh.get("is_me"))
.and_then(|v| v.as_bool())
.unwrap_or(false)
|| data
.get("direction")
.and_then(|d| d.as_str())
.is_some_and(|d| d == "outbound")
} else {
// Legacy format: data.is_from_me
data.get("is_from_me")
.and_then(|v| v.as_bool())
.unwrap_or(false)
};
if is_from_me {
tracing::debug!("Linq: skipping is_from_me message");
return messages;
}
// Get sender phone number
let Some(from) = data.get("from").and_then(|f| f.as_str()) else {
let from = if is_new_format {
// New format: data.sender_handle.handle
data.get("sender_handle")
.and_then(|sh| sh.get("handle"))
.and_then(|h| h.as_str())
} else {
// Legacy format: data.from
data.get("from").and_then(|f| f.as_str())
};
let Some(from) = from else {
return messages;
};
@ -132,18 +171,33 @@ impl LinqChannel {
}
// Get chat_id for reply routing
let chat_id = data
.get("chat_id")
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string();
// Extract text from message parts
let Some(message) = data.get("message") else {
return messages;
let chat_id = if is_new_format {
// New format: data.chat.id
data.get("chat")
.and_then(|c| c.get("id"))
.and_then(|id| id.as_str())
.unwrap_or("")
.to_string()
} else {
// Legacy format: data.chat_id
data.get("chat_id")
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string()
};
let Some(parts) = message.get("parts").and_then(|p| p.as_array()) else {
// Extract message parts
let parts = if is_new_format {
// New format: data.parts (directly on data)
data.get("parts").and_then(|p| p.as_array())
} else {
// Legacy format: data.message.parts
data.get("message")
.and_then(|m| m.get("parts"))
.and_then(|p| p.as_array())
};
let Some(parts) = parts else {
return messages;
};
@ -790,4 +844,217 @@ mod tests {
let ch = make_channel();
assert_eq!(ch.phone_number(), "+15551234567");
}
// ---- New format (2026-02-03) tests ----

// Happy path: a full inbound text envelope in the 2026-02-03 format is parsed
// into one ChannelMessage with sender, content, channel and reply routing set.
#[test]
fn linq_parse_new_format_text_message() {
    let ch = make_channel();
    let payload = serde_json::json!({
        "api_version": "v3",
        "webhook_version": "2026-02-03",
        "event_type": "message.received",
        "event_id": "evt-123",
        "created_at": "2026-03-01T12:00:00Z",
        "trace_id": "trace-456",
        "data": {
            "id": "msg-abc",
            "direction": "inbound",
            "sender_handle": {
                "handle": "+1234567890",
                "is_me": false
            },
            "chat": { "id": "chat-789" },
            "service": "iMessage",
            "parts": [{
                "type": "text",
                "value": "Hello from new format!"
            }]
        }
    });
    let msgs = ch.parse_webhook_payload(&payload);
    assert_eq!(msgs.len(), 1);
    assert_eq!(msgs[0].sender, "+1234567890");
    assert_eq!(msgs[0].content, "Hello from new format!");
    assert_eq!(msgs[0].channel, "linq");
    // Replies should be routed to the chat id, not the raw phone number.
    assert_eq!(msgs[0].reply_target, "chat-789");
}
// The parser skips a message when `sender_handle.is_me` is true OR
// `direction` is "outbound". Use direction "inbound" here so this test
// isolates the `is_me` flag; the outbound-direction path has its own
// dedicated test.
#[test]
fn linq_parse_new_format_skip_is_me() {
    let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
    let payload = serde_json::json!({
        "event_type": "message.received",
        "webhook_version": "2026-02-03",
        "data": {
            "id": "msg-abc",
            "direction": "inbound",
            "sender_handle": {
                "handle": "+15551234567",
                "is_me": true
            },
            "chat": { "id": "chat-789" },
            "parts": [{ "type": "text", "value": "My own message" }]
        }
    });
    let msgs = ch.parse_webhook_payload(&payload);
    assert!(
        msgs.is_empty(),
        "is_me messages should be skipped in new format"
    );
}
// An outbound message (direction == "outbound") must be skipped even when
// `is_me` is false — it was sent by this account from another client.
#[test]
fn linq_parse_new_format_skip_outbound_direction() {
    let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
    let payload = serde_json::json!({
        "event_type": "message.received",
        "webhook_version": "2026-02-03",
        "data": {
            "id": "msg-abc",
            "direction": "outbound",
            "sender_handle": {
                "handle": "+15551234567",
                "is_me": false
            },
            "chat": { "id": "chat-789" },
            "parts": [{ "type": "text", "value": "Outbound" }]
        }
    });
    let msgs = ch.parse_webhook_payload(&payload);
    assert!(msgs.is_empty(), "outbound direction should be skipped");
}

// Senders not on the allowlist (make_channel's default) are dropped.
#[test]
fn linq_parse_new_format_unauthorized_sender() {
    let ch = make_channel();
    let payload = serde_json::json!({
        "event_type": "message.received",
        "webhook_version": "2026-02-03",
        "data": {
            "id": "msg-abc",
            "direction": "inbound",
            "sender_handle": {
                "handle": "+9999999999",
                "is_me": false
            },
            "chat": { "id": "chat-789" },
            "parts": [{ "type": "text", "value": "Spam" }]
        }
    });
    let msgs = ch.parse_webhook_payload(&payload);
    assert!(
        msgs.is_empty(),
        "Unauthorized senders should be filtered in new format"
    );
}

// A media part with an image mime type is rendered as an [IMAGE:<url>] tag.
#[test]
fn linq_parse_new_format_media_image() {
    let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
    let payload = serde_json::json!({
        "event_type": "message.received",
        "webhook_version": "2026-02-03",
        "data": {
            "id": "msg-abc",
            "direction": "inbound",
            "sender_handle": {
                "handle": "+1234567890",
                "is_me": false
            },
            "chat": { "id": "chat-789" },
            "parts": [{
                "type": "media",
                "url": "https://example.com/photo.png",
                "mime_type": "image/png"
            }]
        }
    });
    let msgs = ch.parse_webhook_payload(&payload);
    assert_eq!(msgs.len(), 1);
    assert_eq!(msgs[0].content, "[IMAGE:https://example.com/photo.png]");
}

// Multiple parts are joined with newlines into a single message.
#[test]
fn linq_parse_new_format_multiple_parts() {
    let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
    let payload = serde_json::json!({
        "event_type": "message.received",
        "webhook_version": "2026-02-03",
        "data": {
            "id": "msg-abc",
            "direction": "inbound",
            "sender_handle": {
                "handle": "+1234567890",
                "is_me": false
            },
            "chat": { "id": "chat-789" },
            "parts": [
                { "type": "text", "value": "Check this out" },
                { "type": "media", "url": "https://example.com/img.jpg", "mime_type": "image/jpeg" }
            ]
        }
    });
    let msgs = ch.parse_webhook_payload(&payload);
    assert_eq!(msgs.len(), 1);
    assert_eq!(
        msgs[0].content,
        "Check this out\n[IMAGE:https://example.com/img.jpg]"
    );
}

// Without a `chat` object, replies fall back to the sender's handle.
#[test]
fn linq_parse_new_format_fallback_reply_target_when_no_chat() {
    let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
    let payload = serde_json::json!({
        "event_type": "message.received",
        "webhook_version": "2026-02-03",
        "data": {
            "id": "msg-abc",
            "direction": "inbound",
            "sender_handle": {
                "handle": "+1234567890",
                "is_me": false
            },
            "parts": [{ "type": "text", "value": "Hi" }]
        }
    });
    let msgs = ch.parse_webhook_payload(&payload);
    assert_eq!(msgs.len(), 1);
    assert_eq!(msgs[0].reply_target, "+1234567890");
}

// A handle missing the leading "+" is normalized before the allowlist check,
// so "1234567890" matches the allowed "+1234567890".
#[test]
fn linq_parse_new_format_normalizes_phone() {
    let ch = LinqChannel::new(
        "tok".into(),
        "+15551234567".into(),
        vec!["+1234567890".into()],
    );
    let payload = serde_json::json!({
        "event_type": "message.received",
        "webhook_version": "2026-02-03",
        "data": {
            "id": "msg-abc",
            "direction": "inbound",
            "sender_handle": {
                "handle": "1234567890",
                "is_me": false
            },
            "chat": { "id": "chat-789" },
            "parts": [{ "type": "text", "value": "Hi" }]
        }
    });
    let msgs = ch.parse_webhook_payload(&payload);
    assert_eq!(msgs.len(), 1);
    assert_eq!(msgs[0].sender, "+1234567890");
}
}

View File

@ -89,7 +89,11 @@ use std::collections::{HashMap, HashSet};
use std::fmt::Write;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
#[cfg(not(target_has_atomic = "64"))]
use std::sync::atomic::AtomicU32;
#[cfg(target_has_atomic = "64")]
use std::sync::atomic::AtomicU64;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant, SystemTime};
use tokio_util::sync::CancellationToken;
@ -2305,7 +2309,10 @@ async fn run_message_dispatch_loop(
String,
InFlightSenderTaskState,
>::new()));
#[cfg(target_has_atomic = "64")]
let task_sequence = Arc::new(AtomicU64::new(1));
#[cfg(not(target_has_atomic = "64"))]
let task_sequence = Arc::new(AtomicU32::new(1));
while let Some(msg) = rx.recv().await {
let permit = match Arc::clone(&semaphore).acquire_owned().await {
@ -2323,7 +2330,7 @@ async fn run_message_dispatch_loop(
let sender_scope_key = interruption_scope_key(&msg);
let cancellation_token = CancellationToken::new();
let completion = Arc::new(InFlightTaskCompletion::new());
let task_id = task_sequence.fetch_add(1, Ordering::Relaxed);
let task_id = task_sequence.fetch_add(1, Ordering::Relaxed) as u64;
if interrupt_enabled {
let previous = {
@ -3364,7 +3371,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
};
// Build system prompt from workspace identity files + skills
let workspace = config.workspace_dir.clone();
let tools_registry = Arc::new(tools::all_tools_with_runtime(
let mut built_tools = tools::all_tools_with_runtime(
Arc::new(config.clone()),
&security,
runtime,
@ -3378,7 +3385,44 @@ pub async fn start_channels(config: Config) -> Result<()> {
&config.agents,
config.api_key.as_deref(),
&config,
));
);
// Wire MCP tools into the registry before freezing — non-fatal.
if config.mcp.enabled && !config.mcp.servers.is_empty() {
tracing::info!(
"Initializing MCP client — {} server(s) configured",
config.mcp.servers.len()
);
match crate::tools::mcp_client::McpRegistry::connect_all(&config.mcp.servers).await {
Ok(registry) => {
let registry = std::sync::Arc::new(registry);
let names = registry.tool_names();
let mut registered = 0usize;
for name in names {
if let Some(def) = registry.get_tool_def(&name).await {
let wrapper = crate::tools::mcp_tool::McpToolWrapper::new(
name,
def,
std::sync::Arc::clone(&registry),
);
built_tools.push(Box::new(wrapper));
registered += 1;
}
}
tracing::info!(
"MCP: {} tool(s) registered from {} server(s)",
registered,
registry.server_count()
);
}
Err(e) => {
// Non-fatal — daemon continues with the tools registered above.
tracing::error!("MCP registry failed to initialize: {e:#}");
}
}
}
let tools_registry = Arc::new(built_tools);
let skills = crate::skills::load_skills_with_config(&workspace, &config);

View File

@ -10,12 +10,13 @@ pub use schema::{
CronConfig, DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig,
ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig,
ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, ObservabilityConfig, OpenAiTtsConfig,
OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope,
QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig,
McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig, MultimodalConfig,
NextcloudTalkConfig, ObservabilityConfig, OpenAiTtsConfig, OtpConfig, OtpMethod,
PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QdrantConfig,
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig,
SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
StorageProviderSection, StreamMode, TelegramConfig, TranscriptionConfig, TtsConfig,
TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
};

View File

@ -232,6 +232,10 @@ pub struct Config {
/// Text-to-Speech configuration (`[tts]`).
#[serde(default)]
pub tts: TtsConfig,
/// External MCP server connections (`[mcp]`).
#[serde(default, alias = "mcpServers")]
pub mcp: McpConfig,
}
/// Named provider profile definition compatible with Codex app-server style config.
@ -455,6 +459,60 @@ impl Default for TranscriptionConfig {
}
}
// ── MCP ─────────────────────────────────────────────────────────

/// Transport type for MCP server connections.
///
/// Serialized in lowercase, so valid config values are `"stdio"`, `"http"`
/// and `"sse"`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum McpTransport {
    /// Spawn a local process and communicate over stdin/stdout.
    #[default]
    Stdio,
    /// Connect via HTTP POST.
    Http,
    /// Connect via HTTP + Server-Sent Events.
    Sse,
}
/// Configuration for a single external MCP server.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
pub struct McpServerConfig {
    /// Display name used as a tool prefix (`<server>__<tool>`).
    /// Validation requires it to be non-empty and unique across servers
    /// (compared case-insensitively).
    pub name: String,
    /// Transport type (default: stdio).
    #[serde(default)]
    pub transport: McpTransport,
    /// URL for HTTP/SSE transports. Validation requires a valid http/https
    /// URL when `transport` is `http` or `sse`; ignored for stdio.
    #[serde(default)]
    pub url: Option<String>,
    /// Executable to spawn for stdio transport. Validation requires it to be
    /// non-empty when `transport` is `stdio`.
    #[serde(default)]
    pub command: String,
    /// Command arguments for stdio transport.
    #[serde(default)]
    pub args: Vec<String>,
    /// Optional environment variables for stdio transport.
    #[serde(default)]
    pub env: HashMap<String, String>,
    /// Optional HTTP headers for HTTP/SSE transports.
    #[serde(default)]
    pub headers: HashMap<String, String>,
    /// Optional per-call timeout in seconds (hard capped in validation).
    /// Must be in `1..=600` when set.
    #[serde(default)]
    pub tool_timeout_secs: Option<u64>,
}
/// External MCP client configuration (`[mcp]` section).
///
/// Disabled by default; validation of the servers list only runs when
/// `enabled` is true.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
pub struct McpConfig {
    /// Enable MCP tool loading.
    #[serde(default)]
    pub enabled: bool,
    /// Configured MCP servers. The `mcpServers` alias accepts the key name
    /// used by other MCP-aware tools' config files.
    #[serde(default, alias = "mcpServers")]
    pub servers: Vec<McpServerConfig>,
}
// ── TTS (Text-to-Speech) ─────────────────────────────────────────
fn default_tts_provider() -> String {
@ -1634,6 +1692,65 @@ fn service_selector_matches(selector: &str, service_key: &str) -> bool {
false
}
/// Hard ceiling for per-call MCP tool timeouts, in seconds.
const MCP_MAX_TOOL_TIMEOUT_SECS: u64 = 600;

/// Validate the `[mcp]` section.
///
/// Checks that every server has a non-empty, case-insensitively unique name,
/// that any per-call timeout lies in `1..=MCP_MAX_TOOL_TIMEOUT_SECS`, and
/// that each transport carries the field it needs: `stdio` requires a
/// non-empty `command`, while `http`/`sse` require a valid http(s) `url`.
fn validate_mcp_config(config: &McpConfig) -> Result<()> {
    let mut seen = std::collections::HashSet::new();
    for (i, server) in config.servers.iter().enumerate() {
        let name = server.name.trim();
        if name.is_empty() {
            anyhow::bail!("mcp.servers[{i}].name must not be empty");
        }
        // Case-insensitive comparison catches near-duplicate names.
        if !seen.insert(name.to_ascii_lowercase()) {
            anyhow::bail!("mcp.servers contains duplicate name: {name}");
        }

        if let Some(timeout) = server.tool_timeout_secs {
            if timeout == 0 {
                anyhow::bail!("mcp.servers[{i}].tool_timeout_secs must be greater than 0");
            }
            if timeout > MCP_MAX_TOOL_TIMEOUT_SECS {
                anyhow::bail!(
                    "mcp.servers[{i}].tool_timeout_secs exceeds max {MCP_MAX_TOOL_TIMEOUT_SECS}"
                );
            }
        }

        match server.transport {
            McpTransport::Stdio => {
                if server.command.trim().is_empty() {
                    anyhow::bail!(
                        "mcp.servers[{i}] with transport=stdio requires non-empty command"
                    );
                }
            }
            McpTransport::Http | McpTransport::Sse => {
                let transport_label = match server.transport {
                    McpTransport::Http => "http",
                    McpTransport::Sse => "sse",
                    McpTransport::Stdio => "stdio",
                };
                // A missing, empty, or whitespace-only URL is rejected.
                let url = server
                    .url
                    .as_deref()
                    .map(str::trim)
                    .filter(|value| !value.is_empty())
                    .ok_or_else(|| {
                        anyhow::anyhow!(
                            "mcp.servers[{i}] with transport={transport_label} requires url"
                        )
                    })?;
                let parsed = reqwest::Url::parse(url)
                    .with_context(|| format!("mcp.servers[{i}].url is not a valid URL"))?;
                let scheme = parsed.scheme();
                if scheme != "http" && scheme != "https" {
                    anyhow::bail!("mcp.servers[{i}].url must use http/https");
                }
            }
        }
    }
    Ok(())
}
fn validate_proxy_url(field: &str, url: &str) -> Result<()> {
let parsed = reqwest::Url::parse(url)
.with_context(|| format!("Invalid {field} URL: '{url}' is not a valid URL"))?;
@ -3885,6 +4002,7 @@ impl Default for Config {
query_classification: QueryClassificationConfig::default(),
transcription: TranscriptionConfig::default(),
tts: TtsConfig::default(),
mcp: McpConfig::default(),
}
}
}
@ -4844,6 +4962,11 @@ impl Config {
}
}
// MCP
if self.mcp.enabled {
validate_mcp_config(&self.mcp)?;
}
// Proxy (delegate to existing validation)
self.proxy.validate()?;
@ -5850,6 +5973,7 @@ default_temperature = 0.7
hardware: HardwareConfig::default(),
transcription: TranscriptionConfig::default(),
tts: TtsConfig::default(),
mcp: McpConfig::default(),
};
let toml_str = toml::to_string_pretty(&config).unwrap();
@ -6046,6 +6170,7 @@ tool_dispatcher = "xml"
hardware: HardwareConfig::default(),
transcription: TranscriptionConfig::default(),
tts: TtsConfig::default(),
mcp: McpConfig::default(),
};
config.save().await.unwrap();

View File

@ -748,15 +748,25 @@ const PROMETHEUS_CONTENT_TYPE: &str = "text/plain; version=0.0.4; charset=utf-8"
/// GET /metrics — Prometheus text exposition format
async fn handle_metrics(State(state): State<AppState>) -> impl IntoResponse {
let body = if let Some(prom) = state
.observer
.as_ref()
.as_any()
.downcast_ref::<crate::observability::PrometheusObserver>()
{
prom.encode()
} else {
String::from("# Prometheus backend not enabled. Set [observability] backend = \"prometheus\" in config.\n")
let body = {
#[cfg(feature = "metrics")]
{
if let Some(prom) = state
.observer
.as_ref()
.as_any()
.downcast_ref::<crate::observability::PrometheusObserver>()
{
prom.encode()
} else {
String::from("# Prometheus backend not enabled. Set [observability] backend = \"prometheus\" in config.\n")
}
}
#[cfg(not(feature = "metrics"))]
{
let _ = &state;
String::from("# Prometheus backend not enabled. Set [observability] backend = \"prometheus\" in config.\n")
}
};
(
@ -1739,6 +1749,7 @@ mod tests {
assert!(text.contains("Prometheus backend not enabled"));
}
#[cfg(feature = "metrics")]
#[tokio::test]
async fn metrics_endpoint_renders_prometheus_output() {
let prom = Arc::new(crate::observability::PrometheusObserver::new());

View File

@ -3,6 +3,7 @@ pub mod multi;
pub mod noop;
#[cfg(feature = "observability-otel")]
pub mod otel;
#[cfg(feature = "metrics")]
pub mod prometheus;
pub mod runtime_trace;
pub mod traits;
@ -15,6 +16,7 @@ pub use self::multi::MultiObserver;
pub use noop::NoopObserver;
#[cfg(feature = "observability-otel")]
pub use otel::OtelObserver;
#[cfg(feature = "metrics")]
pub use prometheus::PrometheusObserver;
pub use traits::{Observer, ObserverEvent};
#[allow(unused_imports)]
@ -26,7 +28,19 @@ use crate::config::ObservabilityConfig;
pub fn create_observer(config: &ObservabilityConfig) -> Box<dyn Observer> {
match config.backend.as_str() {
"log" => Box::new(LogObserver::new()),
"prometheus" => Box::new(PrometheusObserver::new()),
"prometheus" => {
#[cfg(feature = "metrics")]
{
Box::new(PrometheusObserver::new())
}
#[cfg(not(feature = "metrics"))]
{
tracing::warn!(
"Prometheus backend requested but this build was compiled without `metrics`; falling back to noop."
);
Box::new(NoopObserver)
}
}
"otel" | "opentelemetry" | "otlp" => {
#[cfg(feature = "observability-otel")]
match OtelObserver::new(
@ -104,7 +118,12 @@ mod tests {
backend: "prometheus".into(),
..ObservabilityConfig::default()
};
assert_eq!(create_observer(&cfg).name(), "prometheus");
let expected = if cfg!(feature = "metrics") {
"prometheus"
} else {
"noop"
};
assert_eq!(create_observer(&cfg).name(), expected);
}
#[test]

View File

@ -173,6 +173,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
query_classification: crate::config::QueryClassificationConfig::default(),
transcription: crate::config::TranscriptionConfig::default(),
tts: crate::config::TtsConfig::default(),
mcp: crate::config::McpConfig::default(),
};
println!(
@ -526,6 +527,7 @@ async fn run_quick_setup_with_home(
query_classification: crate::config::QueryClassificationConfig::default(),
transcription: crate::config::TranscriptionConfig::default(),
tts: crate::config::TtsConfig::default(),
mcp: crate::config::McpConfig::default(),
};
config.save().await?;

View File

@ -27,7 +27,7 @@ struct ChatRequest {
tools: Option<Vec<serde_json::Value>>,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Clone, Serialize)]
struct Message {
role: String,
#[serde(skip_serializing_if = "Option::is_none")]
@ -40,14 +40,14 @@ struct Message {
tool_name: Option<String>,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Clone, Serialize)]
struct OutgoingToolCall {
#[serde(rename = "type")]
kind: String,
function: OutgoingFunction,
}
#[derive(Debug, Serialize)]
#[derive(Debug, Clone, Serialize)]
struct OutgoingFunction {
name: String,
arguments: serde_json::Value,
@ -258,13 +258,31 @@ impl OllamaProvider {
model: &str,
temperature: f64,
tools: Option<&[serde_json::Value]>,
) -> ChatRequest {
self.build_chat_request_with_think(
messages,
model,
temperature,
tools,
self.reasoning_enabled,
)
}
/// Build a chat request with an explicit `think` value.
///
/// This is the shared constructor behind `build_chat_request`; callers can
/// override the provider-level reasoning setting per request.
fn build_chat_request_with_think(
    &self,
    messages: Vec<Message>,
    model: &str,
    temperature: f64,
    tools: Option<&[serde_json::Value]>,
    think: Option<bool>,
) -> ChatRequest {
    // Tool schemas are copied into an owned Vec when present.
    let tools = tools.map(<[serde_json::Value]>::to_vec);
    ChatRequest {
        model: model.to_owned(),
        messages,
        // Streaming is handled elsewhere; this request path is one-shot.
        stream: false,
        options: Options { temperature },
        think,
        tools,
    }
}
@ -396,17 +414,18 @@ impl OllamaProvider {
.collect()
}
/// Send a request to Ollama and get the parsed response.
/// Pass `tools` to enable native function-calling for models that support it.
async fn send_request(
/// Send a single HTTP request to Ollama and parse the response.
async fn send_request_inner(
&self,
messages: Vec<Message>,
messages: &[Message],
model: &str,
temperature: f64,
should_auth: bool,
tools: Option<&[serde_json::Value]>,
think: Option<bool>,
) -> anyhow::Result<ApiChatResponse> {
let request = self.build_chat_request(messages, model, temperature, tools);
let request =
self.build_chat_request_with_think(messages.to_vec(), model, temperature, tools, think);
let url = format!("{}/api/chat", self.base_url);
@ -466,6 +485,59 @@ impl OllamaProvider {
Ok(chat_response)
}
/// Send a request to Ollama and get the parsed response.
/// Pass `tools` to enable native function-calling for models that support it.
///
/// When `reasoning_enabled` (`think`) is set to `true`, the first request
/// includes `think: true`. If that request fails (the model may not support
/// the `think` parameter), we automatically retry once with `think` omitted
/// so the call succeeds instead of entering an infinite retry loop.
async fn send_request(
    &self,
    messages: Vec<Message>,
    model: &str,
    temperature: f64,
    should_auth: bool,
    tools: Option<&[serde_json::Value]>,
) -> anyhow::Result<ApiChatResponse> {
    let first_attempt = self
        .send_request_inner(
            &messages,
            model,
            temperature,
            should_auth,
            tools,
            self.reasoning_enabled,
        )
        .await;

    // Only the think=true failure path gets a second chance; everything
    // else (success, or failure without reasoning) is returned as-is.
    if self.reasoning_enabled != Some(true) {
        return first_attempt;
    }
    let first_err = match first_attempt {
        Ok(resp) => return Ok(resp),
        Err(err) => err,
    };

    tracing::warn!(
        model = model,
        error = %first_err,
        "Ollama request failed with think=true; retrying without reasoning \
         (model may not support it)"
    );

    // Retry with think omitted from the request entirely.
    let retry = self
        .send_request_inner(&messages, model, temperature, should_auth, tools, None)
        .await;
    match retry {
        Ok(resp) => Ok(resp),
        Err(retry_err) => {
            // Both attempts failed — return the original error for clarity.
            tracing::error!(
                model = model,
                original_error = %first_err,
                retry_error = %retry_err,
                "Ollama request also failed without think; returning original error"
            );
            Err(first_err)
        }
    }
}
/// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
///
/// Handles quirky model behavior where tool calls are wrapped:

357
src/tools/mcp_client.rs Normal file
View File

@ -0,0 +1,357 @@
//! MCP (Model Context Protocol) client — connects to external tool servers.
//!
//! Supports multiple transports: stdio (spawn local process), HTTP, and SSE.
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use anyhow::{anyhow, bail, Context, Result};
use serde_json::json;
use tokio::sync::Mutex;
use tokio::time::{timeout, Duration};
use crate::config::schema::McpServerConfig;
use crate::tools::mcp_protocol::{
JsonRpcRequest, McpToolDef, McpToolsListResult, MCP_PROTOCOL_VERSION,
};
use crate::tools::mcp_transport::{create_transport, McpTransportConn};
/// Timeout for receiving a response from an MCP server during init/list.
/// Prevents a hung server from blocking the daemon indefinitely.
const RECV_TIMEOUT_SECS: u64 = 30;

/// Default timeout for tool calls (seconds) when not configured per-server.
const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 180;

/// Maximum allowed tool call timeout (seconds) — hard safety ceiling.
/// NOTE(review): keep in sync with `MCP_MAX_TOOL_TIMEOUT_SECS` (also 600)
/// used by config validation, so validated configs never exceed this cap.
const MAX_TOOL_TIMEOUT_SECS: u64 = 600;
// ── Internal server state ──────────────────────────────────────────────────
/// Mutable state for one connected server, guarded by the `Mutex` in [`McpServer`].
struct McpServerInner {
    /// Config this connection was built from (name, transport, timeouts).
    config: McpServerConfig,
    /// Live transport (stdio/HTTP/SSE) carrying all JSON-RPC traffic.
    transport: Box<dyn McpTransportConn>,
    /// Monotonic JSON-RPC request id; ids 1 and 2 are consumed by `connect`.
    next_id: AtomicU64,
    /// Tools the server advertised at connect time (`tools/list`).
    tools: Vec<McpToolDef>,
}
// ── McpServer ──────────────────────────────────────────────────────────────
/// A live connection to one MCP server (any transport).
///
/// Cheap to clone: all state lives behind `Arc<Mutex<..>>`, so clones share
/// the same underlying connection.
#[derive(Clone)]
pub struct McpServer {
    inner: Arc<Mutex<McpServerInner>>,
}
impl McpServer {
    /// Connect to the server, perform the initialize handshake, and fetch the tool list.
    ///
    /// Sequence: create transport → `initialize` (id 1) →
    /// `notifications/initialized` (best effort) → `tools/list` (id 2).
    /// Subsequent request ids therefore start at 3. Each wire exchange is
    /// bounded by [`RECV_TIMEOUT_SECS`].
    pub async fn connect(config: McpServerConfig) -> Result<Self> {
        // Create transport based on config
        let mut transport = create_transport(&config).with_context(|| {
            format!(
                "failed to create transport for MCP server `{}`",
                config.name
            )
        })?;
        // Initialize handshake
        let id = 1u64;
        let init_req = JsonRpcRequest::new(
            id,
            "initialize",
            json!({
                "protocolVersion": MCP_PROTOCOL_VERSION,
                "capabilities": {},
                "clientInfo": {
                    "name": "zeroclaw",
                    "version": env!("CARGO_PKG_VERSION")
                }
            }),
        );
        let init_resp = timeout(
            Duration::from_secs(RECV_TIMEOUT_SECS),
            transport.send_and_recv(&init_req),
        )
        .await
        .with_context(|| {
            format!(
                "MCP server `{}` timed out after {}s waiting for initialize response",
                config.name, RECV_TIMEOUT_SECS
            )
        })??;
        if init_resp.error.is_some() {
            bail!(
                "MCP server `{}` rejected initialize: {:?}",
                config.name,
                init_resp.error
            );
        }
        // Notify server that client is initialized (no response expected for notifications)
        let notif = JsonRpcRequest::notification("notifications/initialized", json!({}));
        // Best effort - ignore errors for notifications
        let _ = transport.send_and_recv(&notif).await;
        // Fetch available tools
        let id = 2u64;
        let list_req = JsonRpcRequest::new(id, "tools/list", json!({}));
        let list_resp = timeout(
            Duration::from_secs(RECV_TIMEOUT_SECS),
            transport.send_and_recv(&list_req),
        )
        .await
        .with_context(|| {
            format!(
                "MCP server `{}` timed out after {}s waiting for tools/list response",
                config.name, RECV_TIMEOUT_SECS
            )
        })??;
        let result = list_resp
            .result
            .ok_or_else(|| anyhow!("tools/list returned no result from `{}`", config.name))?;
        let tool_list: McpToolsListResult = serde_json::from_value(result)
            .with_context(|| format!("failed to parse tools/list from `{}`", config.name))?;
        let tool_count = tool_list.tools.len();
        let inner = McpServerInner {
            config,
            transport,
            next_id: AtomicU64::new(3), // Start at 3 since we used 1 and 2
            tools: tool_list.tools,
        };
        tracing::info!(
            "MCP server `{}` connected — {} tool(s) available",
            inner.config.name,
            tool_count
        );
        Ok(Self {
            inner: Arc::new(Mutex::new(inner)),
        })
    }
    /// Tools advertised by this server (cloned snapshot taken at connect time).
    pub async fn tools(&self) -> Vec<McpToolDef> {
        self.inner.lock().await.tools.clone()
    }
    /// Server display name.
    #[allow(dead_code)]
    pub async fn name(&self) -> String {
        self.inner.lock().await.config.name.clone()
    }
    /// Call a tool on this server. Returns the raw JSON result.
    ///
    /// The connection lock is held for the whole call, so calls to the same
    /// server are serialized. Timeout is the per-server `tool_timeout_secs`
    /// (default [`DEFAULT_TOOL_TIMEOUT_SECS`]) clamped to
    /// [`MAX_TOOL_TIMEOUT_SECS`].
    pub async fn call_tool(
        &self,
        tool_name: &str,
        arguments: serde_json::Value,
    ) -> Result<serde_json::Value> {
        let mut inner = self.inner.lock().await;
        let id = inner.next_id.fetch_add(1, Ordering::Relaxed);
        let req = JsonRpcRequest::new(
            id,
            "tools/call",
            json!({ "name": tool_name, "arguments": arguments }),
        );
        // Use per-server tool timeout if configured, otherwise default.
        // Cap at MAX_TOOL_TIMEOUT_SECS for safety.
        let tool_timeout = inner
            .config
            .tool_timeout_secs
            .unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS)
            .min(MAX_TOOL_TIMEOUT_SECS);
        let resp = timeout(
            Duration::from_secs(tool_timeout),
            inner.transport.send_and_recv(&req),
        )
        .await
        .map_err(|_| {
            anyhow!(
                "MCP server `{}` timed out after {}s during tool call `{tool_name}`",
                inner.config.name,
                tool_timeout
            )
        })?
        .with_context(|| {
            format!(
                "MCP server `{}` error during tool call `{tool_name}`",
                inner.config.name
            )
        })?;
        if let Some(err) = resp.error {
            bail!("MCP tool `{tool_name}` error {}: {}", err.code, err.message);
        }
        // A success response may omit `result`; treat that as JSON null.
        Ok(resp.result.unwrap_or(serde_json::Value::Null))
    }
}
// ── McpRegistry ───────────────────────────────────────────────────────────
/// Registry of all connected MCP servers, with a flat tool index.
pub struct McpRegistry {
    /// Connected servers, in config order (failed connects are skipped).
    servers: Vec<McpServer>,
    /// prefixed_name -> (server_index, original_tool_name)
    tool_index: HashMap<String, (usize, String)>,
}
impl McpRegistry {
/// Connect to all configured servers. Non-fatal: failures are logged and skipped.
pub async fn connect_all(configs: &[McpServerConfig]) -> Result<Self> {
let mut servers = Vec::new();
let mut tool_index = HashMap::new();
for config in configs {
match McpServer::connect(config.clone()).await {
Ok(server) => {
let server_idx = servers.len();
// Collect tools while holding the lock once, then release
let tools = server.tools().await;
for tool in &tools {
// Prefix prevents name collisions across servers
let prefixed = format!("{}__{}", config.name, tool.name);
tool_index.insert(prefixed, (server_idx, tool.name.clone()));
}
servers.push(server);
}
// Non-fatal — log and continue with remaining servers
Err(e) => {
tracing::error!("Failed to connect to MCP server `{}`: {:#}", config.name, e);
}
}
}
Ok(Self {
servers,
tool_index,
})
}
/// All prefixed tool names across all connected servers.
pub fn tool_names(&self) -> Vec<String> {
self.tool_index.keys().cloned().collect()
}
/// Tool definition for a given prefixed name (cloned).
pub async fn get_tool_def(&self, prefixed_name: &str) -> Option<McpToolDef> {
let (server_idx, original_name) = self.tool_index.get(prefixed_name)?;
let inner = self.servers[*server_idx].inner.lock().await;
inner
.tools
.iter()
.find(|t| &t.name == original_name)
.cloned()
}
/// Execute a tool by prefixed name.
pub async fn call_tool(
&self,
prefixed_name: &str,
arguments: serde_json::Value,
) -> Result<String> {
let (server_idx, original_name) = self
.tool_index
.get(prefixed_name)
.ok_or_else(|| anyhow!("unknown MCP tool `{prefixed_name}`"))?;
let result = self.servers[*server_idx]
.call_tool(original_name, arguments)
.await?;
serde_json::to_string_pretty(&result)
.with_context(|| format!("failed to serialize result of MCP tool `{prefixed_name}`"))
}
pub fn is_empty(&self) -> bool {
self.servers.is_empty()
}
pub fn server_count(&self) -> usize {
self.servers.len()
}
pub fn tool_count(&self) -> usize {
self.tool_index.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::schema::McpTransport;
    // Prefixing scheme used by McpRegistry::connect_all.
    #[test]
    fn tool_name_prefix_format() {
        let prefixed = format!("{}__{}", "filesystem", "read_file");
        assert_eq!(prefixed, "filesystem__read_file");
    }
    #[tokio::test]
    async fn connect_nonexistent_command_fails_cleanly() {
        // A command that doesn't exist should fail at spawn, not panic.
        let config = McpServerConfig {
            name: "nonexistent".to_string(),
            command: "/usr/bin/this_binary_does_not_exist_zeroclaw_test".to_string(),
            args: vec![],
            env: HashMap::default(),
            tool_timeout_secs: None,
            transport: McpTransport::Stdio,
            url: None,
            headers: HashMap::default(),
        };
        let result = McpServer::connect(config).await;
        assert!(result.is_err());
        let msg = result.err().unwrap().to_string();
        assert!(msg.contains("failed to create transport"), "got: {msg}");
    }
    #[tokio::test]
    async fn connect_all_nonfatal_on_single_failure() {
        // If one server config is bad, connect_all should succeed (with 0 servers).
        let configs = vec![McpServerConfig {
            name: "bad".to_string(),
            command: "/usr/bin/does_not_exist_zc_test".to_string(),
            args: vec![],
            env: HashMap::default(),
            tool_timeout_secs: None,
            transport: McpTransport::Stdio,
            url: None,
            headers: HashMap::default(),
        }];
        let registry = McpRegistry::connect_all(&configs)
            .await
            .expect("connect_all should not fail");
        assert!(registry.is_empty());
        assert_eq!(registry.tool_count(), 0);
    }
    // HTTP/SSE transports cannot be constructed without a URL.
    #[test]
    fn http_transport_requires_url() {
        let config = McpServerConfig {
            name: "test".into(),
            transport: McpTransport::Http,
            ..Default::default()
        };
        let result = create_transport(&config);
        assert!(result.is_err());
    }
    #[test]
    fn sse_transport_requires_url() {
        let config = McpServerConfig {
            name: "test".into(),
            transport: McpTransport::Sse,
            ..Default::default()
        };
        let result = create_transport(&config);
        assert!(result.is_err());
    }
}

130
src/tools/mcp_protocol.rs Normal file
View File

@ -0,0 +1,130 @@
//! MCP (Model Context Protocol) JSON-RPC 2.0 protocol types.
//! Protocol version: 2024-11-05
//! Adapted from ops-mcp-server/src/protocol.rs for client use.
//! Both Serialize and Deserialize are derived — the client both sends (Serialize)
//! and receives (Deserialize) JSON-RPC messages.
use serde::{Deserialize, Serialize};
/// JSON-RPC version string placed in every message.
pub const JSONRPC_VERSION: &str = "2.0";
/// MCP protocol revision this client speaks (sent during `initialize`).
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
// Standard JSON-RPC 2.0 error codes
#[allow(dead_code)]
pub const PARSE_ERROR: i32 = -32700;
#[allow(dead_code)]
pub const INVALID_REQUEST: i32 = -32600;
#[allow(dead_code)]
pub const METHOD_NOT_FOUND: i32 = -32601;
#[allow(dead_code)]
pub const INVALID_PARAMS: i32 = -32602;
pub const INTERNAL_ERROR: i32 = -32603;
/// Outbound JSON-RPC request (client -> MCP server).
/// Used for both method calls (with id) and notifications (id = None).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
    /// Always [`JSONRPC_VERSION`] ("2.0").
    pub jsonrpc: String,
    /// Request id; `None` marks a notification (no response expected).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<serde_json::Value>,
    /// Method name, e.g. `initialize`, `tools/list`, `tools/call`.
    pub method: String,
    /// Method parameters; omitted from the wire when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<serde_json::Value>,
}
impl JsonRpcRequest {
    /// Build a method call carrying a numeric request id.
    pub fn new(id: u64, method: impl Into<String>, params: serde_json::Value) -> Self {
        // A call is a notification plus an id.
        Self {
            id: Some(serde_json::Value::from(id)),
            ..Self::notification(method, params)
        }
    }
    /// Build a notification — no id, and no response is expected.
    pub fn notification(method: impl Into<String>, params: serde_json::Value) -> Self {
        Self {
            jsonrpc: JSONRPC_VERSION.to_owned(),
            id: None,
            method: method.into(),
            params: Some(params),
        }
    }
}
/// Inbound JSON-RPC response (MCP server -> client).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
    /// "2.0" for conforming servers.
    pub jsonrpc: String,
    /// Echo of the request id; may be absent (e.g. parse-level failures).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<serde_json::Value>,
    /// Success payload — per JSON-RPC 2.0 a response carries either
    /// `result` or `error`, not both.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    /// Error object when the call failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
/// JSON-RPC error object embedded in a response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
    /// Numeric error code (see the `*_ERROR` constants above).
    pub code: i32,
    /// Human-readable error description.
    pub message: String,
    /// Optional extra error detail supplied by the server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
/// A tool advertised by an MCP server (from `tools/list` response).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolDef {
    /// Tool name as the server knows it (unprefixed).
    pub name: String,
    /// Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// JSON Schema describing the tool's arguments (wire name `inputSchema`).
    #[serde(rename = "inputSchema")]
    pub input_schema: serde_json::Value,
}
/// Expected shape of the `tools/list` result payload.
#[derive(Debug, Deserialize)]
pub struct McpToolsListResult {
    /// Tools advertised by the server.
    pub tools: Vec<McpToolDef>,
}
#[cfg(test)]
mod tests {
    use super::*;
    // Wire-format checks: ids serialize, notifications omit `id`, and
    // responses / tool defs deserialize from representative payloads.
    #[test]
    fn request_serializes_with_id() {
        let req = JsonRpcRequest::new(1, "tools/list", serde_json::json!({}));
        let s = serde_json::to_string(&req).unwrap();
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"tools/list\""));
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
    }
    #[test]
    fn notification_omits_id() {
        let notif =
            JsonRpcRequest::notification("notifications/initialized", serde_json::json!({}));
        let s = serde_json::to_string(&notif).unwrap();
        assert!(!s.contains("\"id\""));
    }
    #[test]
    fn response_deserializes() {
        let json = r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#;
        let resp: JsonRpcResponse = serde_json::from_str(json).unwrap();
        assert!(resp.result.is_some());
        assert!(resp.error.is_none());
    }
    #[test]
    fn tool_def_deserializes_input_schema() {
        let json = r#"{"name":"read_file","description":"Read a file","inputSchema":{"type":"object","properties":{"path":{"type":"string"}}}}"#;
        let def: McpToolDef = serde_json::from_str(json).unwrap();
        assert_eq!(def.name, "read_file");
        assert!(def.input_schema.is_object());
    }
}

68
src/tools/mcp_tool.rs Normal file
View File

@ -0,0 +1,68 @@
//! Wraps a discovered MCP tool as a zeroclaw [`Tool`] so it is dispatched
//! through the existing tool registry and agent loop without modification.
use std::sync::Arc;
use async_trait::async_trait;
use crate::tools::mcp_client::McpRegistry;
use crate::tools::mcp_protocol::McpToolDef;
use crate::tools::traits::{Tool, ToolResult};
/// A zeroclaw [`Tool`] backed by an MCP server tool.
///
/// The `prefixed_name` (e.g. `filesystem__read_file`) is what the agent loop
/// sees. The registry knows how to route it to the correct server.
/// When the server supplies no description, the placeholder "MCP tool" is used.
pub struct McpToolWrapper {
    /// Prefixed name: `<server_name>__<tool_name>`.
    prefixed_name: String,
    /// Description extracted from the MCP tool definition. Stored as an owned
    /// String so that `description()` can return `&str` with self's lifetime.
    description: String,
    /// JSON schema for the tool's input parameters.
    input_schema: serde_json::Value,
    /// Shared registry — used to dispatch actual tool calls.
    registry: Arc<McpRegistry>,
}
impl McpToolWrapper {
pub fn new(prefixed_name: String, def: McpToolDef, registry: Arc<McpRegistry>) -> Self {
let description = def.description.unwrap_or_else(|| "MCP tool".to_string());
Self {
prefixed_name,
description,
input_schema: def.input_schema,
registry,
}
}
}
#[async_trait]
impl Tool for McpToolWrapper {
    fn name(&self) -> &str {
        self.prefixed_name.as_str()
    }
    fn description(&self) -> &str {
        self.description.as_str()
    }
    fn parameters_schema(&self) -> serde_json::Value {
        self.input_schema.clone()
    }
    /// Route the call through the shared registry. An MCP-side failure is
    /// reported as an unsuccessful `ToolResult` rather than an `Err`, so the
    /// agent loop can surface it to the model.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let outcome = self.registry.call_tool(&self.prefixed_name, args).await;
        Ok(match outcome {
            Ok(output) => ToolResult {
                success: true,
                output,
                error: None,
            },
            Err(e) => ToolResult {
                success: false,
                output: String::new(),
                error: Some(e.to_string()),
            },
        })
    }
}

868
src/tools/mcp_transport.rs Normal file
View File

@ -0,0 +1,868 @@
//! MCP transport abstraction — supports stdio, SSE, and HTTP transports.
use std::borrow::Cow;
use anyhow::{anyhow, bail, Context, Result};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::{oneshot, Mutex, Notify};
use tokio::time::{timeout, Duration};
use tokio_stream::StreamExt;
use crate::config::schema::{McpServerConfig, McpTransport};
use crate::tools::mcp_protocol::{JsonRpcError, JsonRpcRequest, JsonRpcResponse, INTERNAL_ERROR};
/// Maximum bytes for a single JSON-RPC response.
/// Lines longer than this are rejected to bound memory use.
const MAX_LINE_BYTES: usize = 4 * 1024 * 1024; // 4 MB
/// Timeout for init/list operations.
const RECV_TIMEOUT_SECS: u64 = 30;
// ── Transport Trait ──────────────────────────────────────────────────────
/// Abstract transport for MCP communication.
///
/// Implemented by `StdioTransport`, `HttpTransport`, and `SseTransport`;
/// callers obtain one via [`create_transport`].
#[async_trait::async_trait]
pub trait McpTransportConn: Send + Sync {
    /// Send a JSON-RPC request and receive the response.
    /// For notifications (`request.id == None`), implementations return a
    /// synthetic empty response instead of waiting for the server.
    async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse>;
    /// Close the connection.
    async fn close(&mut self) -> Result<()>;
}
// ── Stdio Transport ──────────────────────────────────────────────────────
/// Stdio-based transport (spawn local process).
pub struct StdioTransport {
    /// Kept alive for the transport's lifetime; `kill_on_drop` reaps it.
    _child: Child,
    /// Child's stdin — one JSON-RPC message per line is written here.
    stdin: tokio::process::ChildStdin,
    /// Line reader over the child's stdout (one JSON-RPC message per line).
    stdout_lines: tokio::io::Lines<BufReader<tokio::process::ChildStdout>>,
}
impl StdioTransport {
    /// Spawn the configured command and wire up line-delimited JSON-RPC over
    /// its stdin/stdout. stderr is inherited so server logs remain visible.
    pub fn new(config: &McpServerConfig) -> Result<Self> {
        use std::process::Stdio;
        let mut cmd = Command::new(&config.command);
        cmd.args(&config.args)
            .envs(&config.env)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);
        let mut child = cmd
            .spawn()
            .with_context(|| format!("failed to spawn MCP server `{}`", config.name))?;
        let stdin = child
            .stdin
            .take()
            .ok_or_else(|| anyhow!("no stdin on MCP server `{}`", config.name))?;
        let stdout = child
            .stdout
            .take()
            .ok_or_else(|| anyhow!("no stdout on MCP server `{}`", config.name))?;
        Ok(Self {
            _child: child,
            stdin,
            stdout_lines: BufReader::new(stdout).lines(),
        })
    }
    /// Write one JSON-RPC message as a single newline-terminated line.
    async fn send_raw(&mut self, line: &str) -> Result<()> {
        self.stdin
            .write_all(line.as_bytes())
            .await
            .context("failed to write to MCP server stdin")?;
        self.stdin
            .write_all(b"\n")
            .await
            .context("failed to write newline to MCP server stdin")?;
        self.stdin.flush().await.context("failed to flush stdin")?;
        Ok(())
    }
    /// Read one line from the server's stdout, enforcing the size ceiling.
    async fn recv_raw(&mut self) -> Result<String> {
        match self.stdout_lines.next_line().await? {
            None => Err(anyhow!("MCP server closed stdout")),
            Some(line) if line.len() > MAX_LINE_BYTES => {
                Err(anyhow!("MCP response too large: {} bytes", line.len()))
            }
            Some(line) => Ok(line),
        }
    }
}
#[async_trait::async_trait]
impl McpTransportConn for StdioTransport {
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
let line = serde_json::to_string(request)?;
self.send_raw(&line).await?;
if request.id.is_none() {
return Ok(JsonRpcResponse {
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
id: None,
result: None,
error: None,
});
}
let resp_line = timeout(Duration::from_secs(RECV_TIMEOUT_SECS), self.recv_raw())
.await
.context("timeout waiting for MCP response")??;
let resp: JsonRpcResponse = serde_json::from_str(&resp_line)
.with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?;
Ok(resp)
}
async fn close(&mut self) -> Result<()> {
let _ = self.stdin.shutdown().await;
Ok(())
}
}
// ── HTTP Transport ───────────────────────────────────────────────────────
/// HTTP-based transport (POST requests).
///
/// Stateless request/response: each JSON-RPC message is one POST to `url`
/// with `headers` applied.
pub struct HttpTransport {
    /// Endpoint all requests are POSTed to (from config).
    url: String,
    /// Client built with a 120s request timeout.
    client: reqwest::Client,
    /// Extra headers applied to every request (from config).
    headers: std::collections::HashMap<String, String>,
}
impl HttpTransport {
    /// Build a POST-based transport from `config`.
    ///
    /// Fails when `config.url` is absent or the HTTP client cannot be built.
    /// The client carries a 120-second request timeout.
    pub fn new(config: &McpServerConfig) -> Result<Self> {
        let Some(url) = config.url.clone() else {
            return Err(anyhow!("URL required for HTTP transport"));
        };
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .context("failed to build HTTP client")?;
        Ok(Self {
            url,
            client,
            headers: config.headers.clone(),
        })
    }
}
#[async_trait::async_trait]
impl McpTransportConn for HttpTransport {
    /// POST the serialized request and parse the JSON body as the response.
    /// Notifications (`id == None`) skip the body and return a synthetic
    /// empty response once the server acknowledges with a 2xx status.
    async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
        let payload = serde_json::to_string(request)?;
        let mut builder = self.client.post(&self.url).body(payload);
        for (name, value) in self.headers.iter() {
            builder = builder.header(name, value);
        }
        let resp = builder
            .send()
            .await
            .context("HTTP request to MCP server failed")?;
        if !resp.status().is_success() {
            bail!("MCP server returned HTTP {}", resp.status());
        }
        if request.id.is_none() {
            // Notification: nothing to parse.
            return Ok(JsonRpcResponse {
                jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
                id: None,
                result: None,
                error: None,
            });
        }
        let resp_text = resp.text().await.context("failed to read HTTP response")?;
        let parsed: JsonRpcResponse = serde_json::from_str(&resp_text)
            .with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
        Ok(parsed)
    }
    async fn close(&mut self) -> Result<()> {
        Ok(())
    }
}
// ── SSE Transport ─────────────────────────────────────────────────────────
// SSE-based transport (HTTP POST for requests, SSE for responses).
/// Whether the server's GET endpoint speaks SSE; probed in `ensure_connected`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum SseStreamState {
    /// Not yet probed.
    Unknown,
    /// Event stream established; the reader task is pumping responses.
    Connected,
    /// Server rejected the SSE GET (404/405 or non-SSE content type);
    /// fall back to plain POST request/response.
    Unsupported,
}
pub struct SseTransport {
    /// URL of the SSE GET stream (from config).
    sse_url: String,
    /// Server display name, used in logs.
    server_name: String,
    client: reqwest::Client,
    /// Extra headers applied to every request (from config).
    headers: std::collections::HashMap<String, String>,
    /// Result of probing the SSE GET endpoint.
    stream_state: SseStreamState,
    /// State shared with the background reader task (message URL + pending map).
    shared: std::sync::Arc<Mutex<SseSharedState>>,
    /// Signalled when the reader task learns the message endpoint URL.
    notify: std::sync::Arc<Notify>,
    /// Signals the background reader task to stop.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Handle of the background SSE reader task.
    reader_task: Option<tokio::task::JoinHandle<()>>,
}
impl SseTransport {
    /// Build an SSE transport from `config`; requires `config.url`.
    /// No network I/O happens here — the stream is opened lazily on first use.
    pub fn new(config: &McpServerConfig) -> Result<Self> {
        let sse_url = config
            .url
            .as_ref()
            .ok_or_else(|| anyhow!("URL required for SSE transport"))?
            .clone();
        let client = reqwest::Client::builder()
            .build()
            .context("failed to build HTTP client")?;
        Ok(Self {
            sse_url,
            server_name: config.name.clone(),
            client,
            headers: config.headers.clone(),
            stream_state: SseStreamState::Unknown,
            shared: std::sync::Arc::new(Mutex::new(SseSharedState::default())),
            notify: std::sync::Arc::new(Notify::new()),
            shutdown_tx: None,
            reader_task: None,
        })
    }
    /// Probe/open the SSE GET stream and spawn the background reader task.
    ///
    /// Outcomes: stream already live → no-op; server answers 404/405 or a
    /// non-SSE content type → mark `Unsupported` (plain POST fallback);
    /// otherwise spawn a task that parses SSE frames and routes them via
    /// `handle_sse_event` until shutdown or EOF.
    async fn ensure_connected(&mut self) -> Result<()> {
        if self.stream_state == SseStreamState::Unsupported {
            return Ok(());
        }
        if let Some(task) = &self.reader_task {
            if !task.is_finished() {
                self.stream_state = SseStreamState::Connected;
                return Ok(());
            }
        }
        let mut req = self
            .client
            .get(&self.sse_url)
            .header("Accept", "text/event-stream")
            .header("Cache-Control", "no-cache");
        for (key, value) in &self.headers {
            req = req.header(key, value);
        }
        let resp = req.send().await.context("SSE GET to MCP server failed")?;
        if resp.status() == reqwest::StatusCode::NOT_FOUND
            || resp.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED
        {
            self.stream_state = SseStreamState::Unsupported;
            return Ok(());
        }
        if !resp.status().is_success() {
            return Err(anyhow!("MCP server returned HTTP {}", resp.status()));
        }
        let is_event_stream = resp
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
        if !is_event_stream {
            self.stream_state = SseStreamState::Unsupported;
            return Ok(());
        }
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
        self.shutdown_tx = Some(shutdown_tx);
        let shared = self.shared.clone();
        let notify = self.notify.clone();
        let sse_url = self.sse_url.clone();
        let server_name = self.server_name.clone();
        self.reader_task = Some(tokio::spawn(async move {
            // Adapt the byte stream into buffered lines for SSE parsing.
            let stream = resp
                .bytes_stream()
                .map(|item| item.map_err(std::io::Error::other));
            let reader = tokio_util::io::StreamReader::new(stream);
            let mut lines = BufReader::new(reader).lines();
            // Accumulators for the SSE frame currently being parsed.
            let mut cur_event: Option<String> = None;
            let mut cur_id: Option<String> = None;
            let mut cur_data: Vec<String> = Vec::new();
            loop {
                tokio::select! {
                    _ = &mut shutdown_rx => {
                        break;
                    }
                    line = lines.next_line() => {
                        let Ok(line_opt) = line else { break; };
                        let Some(mut line) = line_opt else { break; };
                        if line.ends_with('\r') {
                            line.pop();
                        }
                        // Blank line = end of one SSE event; dispatch it.
                        if line.is_empty() {
                            if cur_event.is_none() && cur_id.is_none() && cur_data.is_empty() {
                                continue;
                            }
                            let event = cur_event.take();
                            let data = cur_data.join("\n");
                            cur_data.clear();
                            let id = cur_id.take();
                            handle_sse_event(
                                &server_name,
                                &sse_url,
                                &shared,
                                &notify,
                                event.as_deref(),
                                id.as_deref(),
                                data,
                            )
                            .await;
                            continue;
                        }
                        // Lines starting with ':' are SSE comments.
                        if line.starts_with(':') {
                            continue;
                        }
                        if let Some(rest) = line.strip_prefix("event:") {
                            cur_event = Some(rest.trim().to_string());
                            continue;
                        }
                        if let Some(rest) = line.strip_prefix("data:") {
                            let rest = rest.strip_prefix(' ').unwrap_or(rest);
                            cur_data.push(rest.to_string());
                            continue;
                        }
                        if let Some(rest) = line.strip_prefix("id:") {
                            cur_id = Some(rest.trim().to_string());
                        }
                    }
                }
            }
            // Stream ended: fail any still-pending requests so callers wake up.
            let pending = {
                let mut guard = shared.lock().await;
                std::mem::take(&mut guard.pending)
            };
            for (_, tx) in pending {
                let _ = tx.send(JsonRpcResponse {
                    jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
                    id: None,
                    result: None,
                    error: Some(JsonRpcError {
                        code: INTERNAL_ERROR,
                        message: "SSE connection closed".to_string(),
                        data: None,
                    }),
                });
            }
        }));
        self.stream_state = SseStreamState::Connected;
        Ok(())
    }
    /// Return the POST endpoint for requests plus whether it came from a
    /// server `endpoint` event (authoritative) or was derived heuristically.
    async fn get_message_url(&self) -> Result<(String, bool)> {
        let guard = self.shared.lock().await;
        if let Some(url) = &guard.message_url {
            return Ok((url.clone(), guard.message_url_from_endpoint));
        }
        drop(guard);
        // No endpoint announced yet — derive `.../messages` (or `.../message`)
        // from the SSE URL as a best-effort guess.
        let derived = derive_message_url(&self.sse_url, "messages")
            .or_else(|| derive_message_url(&self.sse_url, "message"))
            .ok_or_else(|| anyhow!("invalid SSE URL"))?;
        let mut guard = self.shared.lock().await;
        if guard.message_url.is_none() {
            guard.message_url = Some(derived.clone());
            guard.message_url_from_endpoint = false;
        }
        Ok((derived, false))
    }
}
/// State shared between [`SseTransport`] and its background reader task.
#[derive(Default)]
struct SseSharedState {
    /// POST endpoint for requests; either derived from the SSE URL or
    /// announced by the server via an `endpoint` event.
    message_url: Option<String>,
    /// True when `message_url` came from a server `endpoint` event
    /// (authoritative) rather than being derived heuristically.
    message_url_from_endpoint: bool,
    /// In-flight requests by id, resolved by the reader task.
    pending: std::collections::HashMap<u64, oneshot::Sender<JsonRpcResponse>>,
}
/// Heuristically derive the POST "message" endpoint from an SSE URL:
/// a trailing `/sse` segment is replaced with `message_path`, anything else
/// gets `/<message_path>` appended. Returns `None` for unparseable URLs or
/// URLs that cannot carry path segments.
fn derive_message_url(sse_url: &str, message_path: &str) -> Option<String> {
    let parsed = reqwest::Url::parse(sse_url).ok()?;
    let segments: Vec<&str> = parsed.path_segments()?.collect();
    if segments.is_empty() {
        return None;
    }
    let mut result = parsed.clone();
    if let Some((&"sse", head)) = segments.split_last() {
        // `.../sse` → swap the trailing segment for the message path.
        let mut replaced: Vec<&str> = head.to_vec();
        replaced.push(message_path);
        result.set_path(&format!("/{}", replaced.join("/")));
    } else {
        // Otherwise append the message path as an extra segment.
        let base = parsed.path().trim_end_matches('/');
        result.set_path(&format!("{base}/{message_path}"));
    }
    Some(result.to_string())
}
/// Dispatch one parsed SSE event from the background reader task.
///
/// `endpoint`/`mcp-endpoint` events update the shared message URL and wake
/// waiters; `message` events carrying a JSON-RPC response with a numeric id
/// resolve the matching pending request. Everything else is ignored.
async fn handle_sse_event(
    server_name: &str,
    sse_url: &str,
    shared: &std::sync::Arc<Mutex<SseSharedState>>,
    notify: &std::sync::Arc<Notify>,
    event: Option<&str>,
    _id: Option<&str>,
    data: String,
) {
    // Per the SSE spec, a missing event name means "message".
    let event = event.unwrap_or("message");
    let trimmed = data.trim();
    if trimmed.is_empty() {
        return;
    }
    if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint") {
        if let Some(url) = parse_endpoint_from_data(sse_url, trimmed) {
            let mut guard = shared.lock().await;
            guard.message_url = Some(url);
            guard.message_url_from_endpoint = true;
            drop(guard);
            // Wake send_and_recv callers polling for the endpoint URL.
            notify.notify_waiters();
        }
        return;
    }
    if !event.eq_ignore_ascii_case("message") {
        return;
    }
    let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) else {
        return;
    };
    let Ok(resp) = serde_json::from_value::<JsonRpcResponse>(value.clone()) else {
        // Not a response; could be a server-initiated request — ignored here.
        let _ = serde_json::from_value::<JsonRpcRequest>(value);
        return;
    };
    // Only responses with a numeric id can be matched to a pending request.
    let Some(id_val) = resp.id.clone() else {
        return;
    };
    let id = match id_val.as_u64() {
        Some(v) => v,
        None => return,
    };
    let tx = {
        let mut guard = shared.lock().await;
        guard.pending.remove(&id)
    };
    if let Some(tx) = tx {
        let _ = tx.send(resp);
    } else {
        tracing::debug!(
            "MCP SSE `{}` received response for unknown id {}",
            server_name,
            id
        );
    }
}
/// Resolve the payload of an SSE `endpoint` event into an absolute URL.
///
/// Accepts three shapes: a JSON object `{"endpoint": "..."}` (unwrapped and
/// re-parsed), an absolute http(s) URL (used as-is), or a relative path
/// (joined against `sse_url`).
fn parse_endpoint_from_data(sse_url: &str, data: &str) -> Option<String> {
    // JSON-wrapped form: unwrap the "endpoint" field and recurse.
    if data.starts_with('{') {
        let value = serde_json::from_str::<serde_json::Value>(data).ok()?;
        let inner = value.get("endpoint").and_then(|e| e.as_str())?;
        return parse_endpoint_from_data(sse_url, inner);
    }
    // Already absolute — return unchanged.
    if ["http://", "https://"].iter().any(|p| data.starts_with(p)) {
        return Some(data.to_owned());
    }
    // Relative — resolve against the SSE URL.
    reqwest::Url::parse(sse_url)
        .ok()?
        .join(data)
        .ok()
        .map(|resolved| resolved.to_string())
}
/// Pull the JSON payload out of a buffered SSE-framed body.
///
/// Joins the `data:` lines of the *last* event in `resp_text` (a trailing
/// unterminated event wins over earlier completed ones); SSE comment lines
/// (`:`) are ignored. When no `data:` lines exist at all, the trimmed input
/// is returned as-is. Borrows when possible, allocating only for multi-line
/// payloads.
fn extract_json_from_sse_text(resp_text: &str) -> Cow<'_, str> {
    // Strip a UTF-8 BOM if present, then scan the SSE framing.
    let text = resp_text.trim_start_matches('\u{feff}');
    let mut pending: Vec<&str> = Vec::new(); // data lines of the event being read
    let mut completed: Vec<&str> = Vec::new(); // data lines of the last finished event
    for raw in text.lines() {
        let line = raw.trim_end_matches('\r').trim_start();
        match line {
            // Blank line terminates the current event.
            "" => {
                if !pending.is_empty() {
                    completed = std::mem::take(&mut pending);
                }
            }
            // SSE comment lines start with ':'.
            _ if line.starts_with(':') => {}
            _ => {
                if let Some(payload) = line.strip_prefix("data:") {
                    pending.push(payload.strip_prefix(' ').unwrap_or(payload));
                }
            }
        }
    }
    // An unterminated trailing event takes precedence.
    if !pending.is_empty() {
        completed = pending;
    }
    match completed.as_slice() {
        [] => Cow::Borrowed(text.trim()),
        [only] => Cow::Borrowed(only.trim()),
        many => Cow::Owned(many.join("\n").trim().to_string()),
    }
}
/// Read SSE frames from a POST response body until the first `message` event
/// that parses as a JSON-RPC response; `endpoint` and other event types are
/// skipped. Returns `Ok(None)` when the stream ends without one.
async fn read_first_jsonrpc_from_sse_response(
    resp: reqwest::Response,
) -> Result<Option<JsonRpcResponse>> {
    // Adapt the byte stream into buffered lines for SSE parsing.
    let stream = resp
        .bytes_stream()
        .map(|item| item.map_err(std::io::Error::other));
    let reader = tokio_util::io::StreamReader::new(stream);
    let mut lines = BufReader::new(reader).lines();
    // Accumulators for the SSE frame currently being parsed.
    let mut cur_event: Option<String> = None;
    let mut cur_data: Vec<String> = Vec::new();
    while let Ok(line_opt) = lines.next_line().await {
        let Some(mut line) = line_opt else { break };
        if line.ends_with('\r') {
            line.pop();
        }
        // Blank line terminates the current frame.
        if line.is_empty() {
            if cur_event.is_none() && cur_data.is_empty() {
                continue;
            }
            let event = cur_event.take();
            let data = cur_data.join("\n");
            cur_data.clear();
            // Missing event name defaults to "message" per the SSE spec.
            let event = event.unwrap_or_else(|| "message".to_string());
            if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint")
            {
                continue;
            }
            if !event.eq_ignore_ascii_case("message") {
                continue;
            }
            let trimmed = data.trim();
            if trimmed.is_empty() {
                continue;
            }
            let json_str = extract_json_from_sse_text(trimmed);
            if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
                return Ok(Some(resp));
            }
            continue;
        }
        // SSE comment line.
        if line.starts_with(':') {
            continue;
        }
        if let Some(rest) = line.strip_prefix("event:") {
            cur_event = Some(rest.trim().to_string());
            continue;
        }
        if let Some(rest) = line.strip_prefix("data:") {
            let rest = rest.strip_prefix(' ').unwrap_or(rest);
            cur_data.push(rest.to_string());
        }
    }
    Ok(None)
}
#[async_trait::async_trait]
impl McpTransportConn for SseTransport {
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
self.ensure_connected().await?;
let id = request.id.as_ref().and_then(|v| v.as_u64());
let body = serde_json::to_string(request)?;
let (mut message_url, mut from_endpoint) = self.get_message_url().await?;
if self.stream_state == SseStreamState::Connected && !from_endpoint {
for _ in 0..3 {
{
let guard = self.shared.lock().await;
if guard.message_url_from_endpoint {
if let Some(url) = &guard.message_url {
message_url = url.clone();
from_endpoint = true;
break;
}
}
}
let _ = timeout(Duration::from_millis(300), self.notify.notified()).await;
}
}
let primary_url = if from_endpoint {
message_url.clone()
} else {
self.sse_url.clone()
};
let secondary_url = if message_url == self.sse_url {
None
} else if primary_url == message_url {
Some(self.sse_url.clone())
} else {
Some(message_url.clone())
};
let has_secondary = secondary_url.is_some();
let mut rx = None;
if let Some(id) = id {
if self.stream_state == SseStreamState::Connected {
let (tx, ch) = oneshot::channel();
{
let mut guard = self.shared.lock().await;
guard.pending.insert(id, tx);
}
rx = Some((id, ch));
}
}
let mut got_direct = None;
let mut last_status = None;
for (i, url) in std::iter::once(primary_url)
.chain(secondary_url.into_iter())
.enumerate()
{
let mut req = self
.client
.post(&url)
.timeout(Duration::from_secs(120))
.body(body.clone())
.header("Content-Type", "application/json");
for (key, value) in &self.headers {
req = req.header(key, value);
}
if !self
.headers
.keys()
.any(|k| k.eq_ignore_ascii_case("Accept"))
{
req = req.header("Accept", "application/json, text/event-stream");
}
let resp = req.send().await.context("SSE POST to MCP server failed")?;
let status = resp.status();
last_status = Some(status);
if (status == reqwest::StatusCode::NOT_FOUND
|| status == reqwest::StatusCode::METHOD_NOT_ALLOWED)
&& i == 0
{
continue;
}
if !status.is_success() {
break;
}
if request.id.is_none() {
got_direct = Some(JsonRpcResponse {
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
id: None,
result: None,
error: None,
});
break;
}
let is_sse = resp
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
if is_sse {
if i == 0 && has_secondary {
match timeout(
Duration::from_secs(3),
read_first_jsonrpc_from_sse_response(resp),
)
.await
{
Ok(res) => {
if let Some(resp) = res? {
got_direct = Some(resp);
}
break;
}
Err(_) => continue,
}
}
if let Some(resp) = read_first_jsonrpc_from_sse_response(resp).await? {
got_direct = Some(resp);
}
break;
}
let text = if i == 0 && has_secondary {
match timeout(Duration::from_secs(3), resp.text()).await {
Ok(Ok(t)) => t,
Ok(Err(_)) => String::new(),
Err(_) => continue,
}
} else {
resp.text().await.unwrap_or_default()
};
let trimmed = text.trim();
if !trimmed.is_empty() {
let json_str = if trimmed.contains("\ndata:") || trimmed.starts_with("data:") {
extract_json_from_sse_text(trimmed)
} else {
Cow::Borrowed(trimmed)
};
if let Ok(mcp_resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
got_direct = Some(mcp_resp);
}
}
break;
}
if let Some((id, _)) = rx.as_ref() {
if got_direct.is_some() {
let mut guard = self.shared.lock().await;
guard.pending.remove(id);
} else if let Some(status) = last_status {
if !status.is_success() {
let mut guard = self.shared.lock().await;
guard.pending.remove(id);
}
}
}
if let Some(resp) = got_direct {
return Ok(resp);
}
if let Some(status) = last_status {
if !status.is_success() {
bail!("MCP server returned HTTP {}", status);
}
} else {
bail!("MCP request not sent");
}
let Some((_id, rx)) = rx else {
bail!("MCP server returned no response");
};
rx.await.map_err(|_| anyhow!("SSE response channel closed"))
}
/// Tear down the SSE transport.
///
/// Signals the background reader loop to stop (best-effort — the
/// receiver may already have been dropped) and then aborts the reader
/// task so no stray polling outlives this connection.
async fn close(&mut self) -> Result<()> {
    // The oneshot send can fail if the reader already exited; that is
    // fine, so the result is deliberately discarded.
    if let Some(stop_signal) = self.shutdown_tx.take() {
        let _ = stop_signal.send(());
    }
    // Abort unconditionally in case the shutdown signal was not observed.
    if let Some(reader_handle) = self.reader_task.take() {
        reader_handle.abort();
    }
    Ok(())
}
}
// ── Factory ──────────────────────────────────────────────────────────────
/// Build the transport connection selected by `config.transport`.
///
/// Returns a boxed trait object so callers can treat stdio, HTTP and
/// SSE connections uniformly. Construction errors (e.g. a missing URL
/// for the network transports) are propagated to the caller.
pub fn create_transport(config: &McpServerConfig) -> Result<Box<dyn McpTransportConn>> {
    let conn: Box<dyn McpTransportConn> = match config.transport {
        McpTransport::Stdio => Box::new(StdioTransport::new(config)?),
        McpTransport::Http => Box::new(HttpTransport::new(config)?),
        McpTransport::Sse => Box::new(SseTransport::new(config)?),
    };
    Ok(conn)
}
// ── Tests ─────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    /// Run SSE extraction on `input` and require that the extracted
    /// payload deserializes as a JSON-RPC response.
    fn assert_extracts_jsonrpc(input: &str) {
        let extracted = extract_json_from_sse_text(input);
        serde_json::from_str::<JsonRpcResponse>(extracted.as_ref())
            .expect("extracted payload should parse as a JSON-RPC response");
    }

    #[test]
    fn test_transport_default_is_stdio() {
        // An unconfigured server falls back to the stdio transport.
        assert_eq!(McpServerConfig::default().transport, McpTransport::Stdio);
    }

    #[test]
    fn test_http_transport_requires_url() {
        // HTTP transport construction must fail when no URL is set.
        let cfg = McpServerConfig {
            name: "test".into(),
            transport: McpTransport::Http,
            ..Default::default()
        };
        assert!(HttpTransport::new(&cfg).is_err());
    }

    #[test]
    fn test_sse_transport_requires_url() {
        // SSE transport construction must fail when no URL is set.
        let cfg = McpServerConfig {
            name: "test".into(),
            transport: McpTransport::Sse,
            ..Default::default()
        };
        assert!(SseTransport::new(&cfg).is_err());
    }

    #[test]
    fn test_extract_json_from_sse_data_no_space() {
        // `data:` with no space after the colon is still valid SSE.
        assert_extracts_jsonrpc("data:{\"jsonrpc\":\"2.0\",\"result\":{}}\n\n");
    }

    #[test]
    fn test_extract_json_from_sse_with_event_and_id() {
        // `id:` and `event:` fields must be skipped over.
        assert_extracts_jsonrpc("id: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n");
    }

    #[test]
    fn test_extract_json_from_sse_multiline_data() {
        // Consecutive `data:` lines are joined into one JSON document.
        assert_extracts_jsonrpc(
            "event: message\ndata: {\ndata: \"jsonrpc\": \"2.0\",\ndata: \"result\": {}\ndata: }\n\n",
        );
    }

    #[test]
    fn test_extract_json_from_sse_skips_bom_and_leading_whitespace() {
        // A UTF-8 BOM and leading blank lines must not break extraction.
        assert_extracts_jsonrpc("\u{feff}\n\n  data: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n");
    }

    #[test]
    fn test_extract_json_from_sse_uses_last_event_with_data() {
        // Comment-only keep-alive events are ignored in favor of the
        // later event that actually carries data.
        assert_extracts_jsonrpc(
            ": keep-alive\n\nid: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n",
        );
    }
}

View File

@ -40,6 +40,10 @@ pub mod hardware_memory_map;
pub mod hardware_memory_read;
pub mod http_request;
pub mod image_info;
pub mod mcp_client;
pub mod mcp_protocol;
pub mod mcp_tool;
pub mod mcp_transport;
pub mod memory_forget;
pub mod memory_recall;
pub mod memory_store;

27
web/src/lib/uuid.ts Normal file
View File

@ -0,0 +1,27 @@
/**
 * Generate an RFC 4122 version 4 UUID string.
 *
 * Resolution order:
 * 1. `crypto.randomUUID()` — modern browsers in secure contexts.
 * 2. `crypto.getRandomValues()` — older browsers (e.g. Safari < 15.4,
 *    some Electron/Raspberry-Pi builds) that lack `randomUUID`.
 * 3. `Math.random()` — last-resort, NON-cryptographic fallback so id
 *    generation never throws even where Web Crypto is entirely absent.
 *    (The previous fallback dereferenced `globalThis.crypto` after the
 *    guard had established `crypto` may be undefined, which threw a
 *    TypeError in crypto-less environments.)
 *
 * Closes #3303, #3261.
 */
export function generateUUID(): string {
  // `globalThis.crypto` avoids a ReferenceError where `crypto` is not a global.
  const c = globalThis.crypto;
  if (c && typeof c.randomUUID === 'function') {
    return c.randomUUID();
  }
  const bytes = new Uint8Array(16);
  if (c && typeof c.getRandomValues === 'function') {
    c.getRandomValues(bytes);
  } else {
    // Non-cryptographic fallback: acceptable for client-side UI/session ids only.
    for (let i = 0; i < 16; i++) {
      bytes[i] = Math.floor(Math.random() * 256);
    }
  }
  // Set version (4) and variant (10xx) bits per RFC 4122.
  bytes[6] = (bytes[6]! & 0x0f) | 0x40;
  bytes[8] = (bytes[8]! & 0x3f) | 0x80;
  const hex = Array.from(bytes, (b) => b.toString(16).padStart(2, '0')).join('');
  return `${hex.slice(0, 8)}-${hex.slice(8, 12)}-${hex.slice(12, 16)}-${hex.slice(16, 20)}-${hex.slice(20)}`;
}

View File

@ -1,5 +1,6 @@
import type { WsMessage } from '../types/api';
import { getToken } from './auth';
import { generateUUID } from './uuid';
export type WsMessageHandler = (msg: WsMessage) => void;
export type WsOpenHandler = () => void;
@ -26,7 +27,7 @@ const SESSION_STORAGE_KEY = 'zeroclaw_session_id';
function getOrCreateSessionId(): string {
let id = sessionStorage.getItem(SESSION_STORAGE_KEY);
if (!id) {
id = crypto.randomUUID();
id = generateUUID();
sessionStorage.setItem(SESSION_STORAGE_KEY, id);
}
return id;

View File

@ -2,6 +2,7 @@ import { useState, useEffect, useRef, useCallback } from 'react';
import { Send, Bot, User, AlertCircle, Copy, Check } from 'lucide-react';
import type { WsMessage } from '@/types/api';
import { WebSocketClient } from '@/lib/ws';
import { generateUUID } from '@/lib/uuid';
interface ChatMessage {
id: string;
@ -53,7 +54,7 @@ export default function AgentChat() {
setMessages((prev) => [
...prev,
{
id: crypto.randomUUID(),
id: generateUUID(),
role: 'agent',
content,
timestamp: new Date(),
@ -69,7 +70,7 @@ export default function AgentChat() {
setMessages((prev) => [
...prev,
{
id: crypto.randomUUID(),
id: generateUUID(),
role: 'agent',
content: `[Tool Call] ${msg.name ?? 'unknown'}(${JSON.stringify(msg.args ?? {})})`,
timestamp: new Date(),
@ -81,7 +82,7 @@ export default function AgentChat() {
setMessages((prev) => [
...prev,
{
id: crypto.randomUUID(),
id: generateUUID(),
role: 'agent',
content: `[Tool Result] ${msg.output ?? ''}`,
timestamp: new Date(),
@ -93,7 +94,7 @@ export default function AgentChat() {
setMessages((prev) => [
...prev,
{
id: crypto.randomUUID(),
id: generateUUID(),
role: 'agent',
content: `[Error] ${msg.message ?? 'Unknown error'}`,
timestamp: new Date(),
@ -124,7 +125,7 @@ export default function AgentChat() {
setMessages((prev) => [
...prev,
{
id: crypto.randomUUID(),
id: generateUUID(),
role: 'user',
content: trimmed,
timestamp: new Date(),