diff --git a/Cargo.toml b/Cargo.toml index 137937b8f..664be66a8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,8 +48,8 @@ schemars = "1.2" tracing = { version = "0.1", default-features = false } tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter"] } -# Observability - Prometheus metrics -prometheus = { version = "0.14", default-features = false } +# Observability - Prometheus metrics (optional; requires AtomicU64, unavailable on 32-bit) +prometheus = { version = "0.14", default-features = false, optional = true } # Base64 encoding (screenshots, image data) base64 = "0.22" @@ -205,6 +205,8 @@ sandbox-landlock = ["dep:landlock"] sandbox-bubblewrap = [] # Backward-compatible alias for older invocations landlock = ["sandbox-landlock"] +# Prometheus metrics observer (requires 64-bit atomics; disable on 32-bit targets) +metrics = ["dep:prometheus"] # probe = probe-rs for Nucleo memory read (adds ~50 deps; optional) probe = ["dep:probe-rs"] # rag-pdf = PDF ingestion for datasheet RAG diff --git a/build.rs b/build.rs index 0c7da4abb..01afed645 100644 --- a/build.rs +++ b/build.rs @@ -1,6 +1,110 @@ +use std::path::Path; +use std::process::Command; + fn main() { - let dir = std::path::Path::new("web/dist"); - if !dir.exists() { - std::fs::create_dir_all(dir).expect("failed to create web/dist/"); + let dist_dir = Path::new("web/dist"); + let web_dir = Path::new("web"); + + // Tell Cargo to re-run this script when web source files change. + println!("cargo:rerun-if-changed=web/src"); + println!("cargo:rerun-if-changed=web/index.html"); + println!("cargo:rerun-if-changed=web/package.json"); + println!("cargo:rerun-if-changed=web/vite.config.ts"); + + // Attempt to build the web frontend if npm is available and web/dist is + // missing or stale. The build is best-effort: when Node.js is not + // installed (e.g. CI containers, cross-compilation, minimal dev setups) + // we fall back to the existing stub/empty dist directory so the Rust + // build still succeeds. + let needs_build = !dist_dir.join("index.html").exists(); + + if needs_build && web_dir.join("package.json").exists() { + if let Ok(npm) = which_npm() { + eprintln!("cargo:warning=Building web frontend (web/dist is missing or stale)..."); + + // npm ci / npm install + let install_status = Command::new(&npm) + .args(["ci", "--ignore-scripts"]) + .current_dir(web_dir) + .status(); + + match install_status { + Ok(s) if s.success() => {} + Ok(s) => { + // Fall back to `npm install` if `npm ci` fails (no lockfile, etc.) + eprintln!("cargo:warning=npm ci exited with {s}, trying npm install..."); + let fallback = Command::new(&npm) + .args(["install"]) + .current_dir(web_dir) + .status(); + if !matches!(fallback, Ok(s) if s.success()) { + eprintln!("cargo:warning=npm install failed — skipping web build"); + ensure_dist_dir(dist_dir); + return; + } + } + Err(e) => { + eprintln!("cargo:warning=Could not run npm: {e} — skipping web build"); + ensure_dist_dir(dist_dir); + return; + } + } + + // npm run build + let build_status = Command::new(&npm) + .args(["run", "build"]) + .current_dir(web_dir) + .status(); + + match build_status { + Ok(s) if s.success() => { + eprintln!("cargo:warning=Web frontend built successfully."); + } + Ok(s) => { + eprintln!( + "cargo:warning=npm run build exited with {s} — web dashboard may be unavailable" + ); + } + Err(e) => { + eprintln!( + "cargo:warning=Could not run npm build: {e} — web dashboard may be unavailable" + ); + } + } + } + } + + ensure_dist_dir(dist_dir); +} + +/// Ensure the dist directory exists so `rust-embed` does not fail at compile +/// time even when the web frontend is not built. +fn ensure_dist_dir(dist_dir: &Path) { + if !dist_dir.exists() { + std::fs::create_dir_all(dist_dir).expect("failed to create web/dist/"); } } + +/// Locate the `npm` binary on the system PATH. +fn which_npm() -> Result { + let cmd = if cfg!(target_os = "windows") { + "where" + } else { + "which" + }; + + Command::new(cmd) + .arg("npm") + .output() + .ok() + .and_then(|output| { + if output.status.success() { + String::from_utf8(output.stdout) + .ok() + .map(|s| s.lines().next().unwrap_or("npm").trim().to_string()) + } else { + None + } + }) + .ok_or(()) +} diff --git a/install.sh b/install.sh index b4b53dfe3..d8879f3ea 100755 --- a/install.sh +++ b/install.sh @@ -211,8 +211,35 @@ should_attempt_prebuilt_for_resources() { return 1 } +resolve_asset_url() { + local asset_name="$1" + local api_url="https://api.github.com/repos/zeroclaw-labs/zeroclaw/releases" + local releases_json download_url + + # Fetch up to 10 recent releases (includes prereleases) and find the first + # one that contains the requested asset. + releases_json="$(curl -fsSL "${api_url}?per_page=10" 2>/dev/null || true)" + if [[ -z "$releases_json" ]]; then + return 1 + fi + + # Parse with simple grep/sed — avoids jq dependency. + download_url="$(printf '%s\n' "$releases_json" \ + | tr ',' '\n' \ + | grep '"browser_download_url"' \ + | sed 's/.*"browser_download_url"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/' \ + | grep "/${asset_name}\$" \ + | head -n 1)" + + if [[ -z "$download_url" ]]; then + return 1 + fi + + echo "$download_url" +} + install_prebuilt_binary() { - local target archive_url temp_dir archive_path extracted_bin install_dir + local target archive_url temp_dir archive_path extracted_bin install_dir asset_name if ! have_cmd curl; then warn "curl is required for pre-built binary installation." @@ -229,9 +256,17 @@ install_prebuilt_binary() { return 1 fi - archive_url="https://github.com/zeroclaw-labs/zeroclaw/releases/latest/download/zeroclaw-${target}.tar.gz" + asset_name="zeroclaw-${target}.tar.gz" + + # Try the GitHub API first to find the newest release (including prereleases) + # that actually contains the asset, then fall back to /releases/latest/. + archive_url="$(resolve_asset_url "$asset_name" || true)" + if [[ -z "$archive_url" ]]; then + archive_url="https://github.com/zeroclaw-labs/zeroclaw/releases/latest/download/${asset_name}" + fi + temp_dir="$(mktemp -d -t zeroclaw-prebuilt-XXXXXX)" - archive_path="$temp_dir/zeroclaw-${target}.tar.gz" + archive_path="$temp_dir/${asset_name}" info "Attempting pre-built binary install for target: $target" if ! curl -fsSL "$archive_url" -o "$archive_path"; then diff --git a/src/channels/irc.rs b/src/channels/irc.rs index f942692d2..15946572e 100644 --- a/src/channels/irc.rs +++ b/src/channels/irc.rs @@ -1,6 +1,10 @@ use crate::channels::traits::{Channel, ChannelMessage, SendMessage}; use async_trait::async_trait; -use std::sync::atomic::{AtomicU64, Ordering}; +#[cfg(not(target_has_atomic = "64"))] +use std::sync::atomic::AtomicU32; +#[cfg(target_has_atomic = "64")] +use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering; use std::sync::Arc; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::sync::{mpsc, Mutex}; @@ -13,7 +17,10 @@ use tokio_rustls::rustls; const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); /// Monotonic counter to ensure unique message IDs under burst traffic. +#[cfg(target_has_atomic = "64")] static MSG_SEQ: AtomicU64 = AtomicU64::new(0); +#[cfg(not(target_has_atomic = "64"))] +static MSG_SEQ: AtomicU32 = AtomicU32::new(0); /// IRC over TLS channel. /// diff --git a/src/channels/linq.rs b/src/channels/linq.rs index 123322fdd..214f68427 100644 --- a/src/channels/linq.rs +++ b/src/channels/linq.rs @@ -61,20 +61,33 @@ impl LinqChannel { /// Parse an incoming webhook payload from Linq and extract messages. /// - /// Linq webhook envelope: + /// Supports two webhook formats: + /// + /// **New format (webhook_version 2026-02-03):** + /// ```json + /// { + /// "api_version": "v3", + /// "webhook_version": "2026-02-03", + /// "event_type": "message.received", + /// "data": { + /// "id": "msg-...", + /// "direction": "inbound", + /// "sender_handle": { "handle": "+1...", "is_me": false }, + /// "chat": { "id": "chat-..." }, + /// "parts": [{ "type": "text", "value": "..." }] + /// } + /// } + /// ``` + /// + /// **Legacy format (webhook_version 2025-01-01):** /// ```json /// { /// "api_version": "v3", /// "event_type": "message.received", - /// "event_id": "...", - /// "created_at": "...", - /// "trace_id": "...", /// "data": { /// "chat_id": "...", /// "from": "+1...", - /// "recipient_phone": "+1...", /// "is_from_me": false, - /// "service": "iMessage", /// "message": { /// "id": "...", /// "parts": [{ "type": "text", "value": "..." }] @@ -99,18 +112,44 @@ impl LinqChannel { return messages; }; + // Detect format: new format has `sender_handle`, legacy has `from`. + let is_new_format = data.get("sender_handle").is_some(); + // Skip messages sent by the bot itself - if data - .get("is_from_me") - .and_then(|v| v.as_bool()) - .unwrap_or(false) - { + let is_from_me = if is_new_format { + // New format: data.sender_handle.is_me or data.direction == "outbound" + data.get("sender_handle") + .and_then(|sh| sh.get("is_me")) + .and_then(|v| v.as_bool()) + .unwrap_or(false) + || data + .get("direction") + .and_then(|d| d.as_str()) + .is_some_and(|d| d == "outbound") + } else { + // Legacy format: data.is_from_me + data.get("is_from_me") + .and_then(|v| v.as_bool()) + .unwrap_or(false) + }; + + if is_from_me { tracing::debug!("Linq: skipping is_from_me message"); return messages; } // Get sender phone number - let Some(from) = data.get("from").and_then(|f| f.as_str()) else { + let from = if is_new_format { + // New format: data.sender_handle.handle + data.get("sender_handle") + .and_then(|sh| sh.get("handle")) + .and_then(|h| h.as_str()) + } else { + // Legacy format: data.from + data.get("from").and_then(|f| f.as_str()) + }; + + let Some(from) = from else { return messages; }; @@ -132,18 +171,33 @@ impl LinqChannel { } // Get chat_id for reply routing - let chat_id = data - .get("chat_id") - .and_then(|c| c.as_str()) - .unwrap_or("") - .to_string(); - - // Extract text from message parts - let Some(message) = data.get("message") else { - return messages; + let chat_id = if is_new_format { + // New format: data.chat.id + data.get("chat") + .and_then(|c| c.get("id")) + .and_then(|id| id.as_str()) + .unwrap_or("") + .to_string() + } else { + // Legacy format: data.chat_id + data.get("chat_id") + .and_then(|c| c.as_str()) + .unwrap_or("") + .to_string() }; - let Some(parts) = message.get("parts").and_then(|p| p.as_array()) else { + // Extract message parts + let parts = if is_new_format { + // New format: data.parts (directly on data) + data.get("parts").and_then(|p| p.as_array()) + } else { + // Legacy format: data.message.parts + data.get("message") + .and_then(|m| m.get("parts")) + .and_then(|p| p.as_array()) + }; + + let Some(parts) = parts else { return messages; }; @@ -790,4 +844,217 @@ mod tests { let ch = make_channel(); assert_eq!(ch.phone_number(), "+15551234567"); } + + // ---- New format (2026-02-03) tests ---- + + #[test] + fn linq_parse_new_format_text_message() { + let ch = make_channel(); + let payload = serde_json::json!({ + "api_version": "v3", + "webhook_version": "2026-02-03", + "event_type": "message.received", + "event_id": "evt-123", + "created_at": "2026-03-01T12:00:00Z", + "trace_id": "trace-456", + "data": { + "id": "msg-abc", + "direction": "inbound", + "sender_handle": { + "handle": "+1234567890", + "is_me": false + }, + "chat": { "id": "chat-789" }, + "service": "iMessage", + "parts": [{ + "type": "text", + "value": "Hello from new format!" + }] + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].sender, "+1234567890"); + assert_eq!(msgs[0].content, "Hello from new format!"); + assert_eq!(msgs[0].channel, "linq"); + assert_eq!(msgs[0].reply_target, "chat-789"); + } + + #[test] + fn linq_parse_new_format_skip_is_me() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + let payload = serde_json::json!({ + "event_type": "message.received", + "webhook_version": "2026-02-03", + "data": { + "id": "msg-abc", + "direction": "outbound", + "sender_handle": { + "handle": "+15551234567", + "is_me": true + }, + "chat": { "id": "chat-789" }, + "parts": [{ "type": "text", "value": "My own message" }] + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert!( + msgs.is_empty(), + "is_me messages should be skipped in new format" + ); + } + + #[test] + fn linq_parse_new_format_skip_outbound_direction() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + let payload = serde_json::json!({ + "event_type": "message.received", + "webhook_version": "2026-02-03", + "data": { + "id": "msg-abc", + "direction": "outbound", + "sender_handle": { + "handle": "+15551234567", + "is_me": false + }, + "chat": { "id": "chat-789" }, + "parts": [{ "type": "text", "value": "Outbound" }] + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert!(msgs.is_empty(), "outbound direction should be skipped"); + } + + #[test] + fn linq_parse_new_format_unauthorized_sender() { + let ch = make_channel(); + let payload = serde_json::json!({ + "event_type": "message.received", + "webhook_version": "2026-02-03", + "data": { + "id": "msg-abc", + "direction": "inbound", + "sender_handle": { + "handle": "+9999999999", + "is_me": false + }, + "chat": { "id": "chat-789" }, + "parts": [{ "type": "text", "value": "Spam" }] + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert!( + msgs.is_empty(), + "Unauthorized senders should be filtered in new format" + ); + } + + #[test] + fn linq_parse_new_format_media_image() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + let payload = serde_json::json!({ + "event_type": "message.received", + "webhook_version": "2026-02-03", + "data": { + "id": "msg-abc", + "direction": "inbound", + "sender_handle": { + "handle": "+1234567890", + "is_me": false + }, + "chat": { "id": "chat-789" }, + "parts": [{ + "type": "media", + "url": "https://example.com/photo.png", + "mime_type": "image/png" + }] + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content, "[IMAGE:https://example.com/photo.png]"); + } + + #[test] + fn linq_parse_new_format_multiple_parts() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + let payload = serde_json::json!({ + "event_type": "message.received", + "webhook_version": "2026-02-03", + "data": { + "id": "msg-abc", + "direction": "inbound", + "sender_handle": { + "handle": "+1234567890", + "is_me": false + }, + "chat": { "id": "chat-789" }, + "parts": [ + { "type": "text", "value": "Check this out" }, + { "type": "media", "url": "https://example.com/img.jpg", "mime_type": "image/jpeg" } + ] + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert_eq!(msgs.len(), 1); + assert_eq!( + msgs[0].content, + "Check this out\n[IMAGE:https://example.com/img.jpg]" + ); + } + + #[test] + fn linq_parse_new_format_fallback_reply_target_when_no_chat() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + let payload = serde_json::json!({ + "event_type": "message.received", + "webhook_version": "2026-02-03", + "data": { + "id": "msg-abc", + "direction": "inbound", + "sender_handle": { + "handle": "+1234567890", + "is_me": false + }, + "parts": [{ "type": "text", "value": "Hi" }] + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].reply_target, "+1234567890"); + } + + #[test] + fn linq_parse_new_format_normalizes_phone() { + let ch = LinqChannel::new( + "tok".into(), + "+15551234567".into(), + vec!["+1234567890".into()], + ); + let payload = serde_json::json!({ + "event_type": "message.received", + "webhook_version": "2026-02-03", + "data": { + "id": "msg-abc", + "direction": "inbound", + "sender_handle": { + "handle": "1234567890", + "is_me": false + }, + "chat": { "id": "chat-789" }, + "parts": [{ "type": "text", "value": "Hi" }] + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].sender, "+1234567890"); + } } diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 76b6f7e12..a2501e9c3 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -89,7 +89,11 @@ use std::collections::{HashMap, HashSet}; use std::fmt::Write; use std::path::{Path, PathBuf}; use std::process::Command; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +#[cfg(not(target_has_atomic = "64"))] +use std::sync::atomic::AtomicU32; +#[cfg(target_has_atomic = "64")] +use std::sync::atomic::AtomicU64; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex, OnceLock}; use std::time::{Duration, Instant, SystemTime}; use tokio_util::sync::CancellationToken; @@ -2305,7 +2309,10 @@ async fn run_message_dispatch_loop( String, InFlightSenderTaskState, >::new())); + #[cfg(target_has_atomic = "64")] let task_sequence = Arc::new(AtomicU64::new(1)); + #[cfg(not(target_has_atomic = "64"))] + let task_sequence = Arc::new(AtomicU32::new(1)); while let Some(msg) = rx.recv().await { let permit = match Arc::clone(&semaphore).acquire_owned().await { @@ -2323,7 +2330,7 @@ async fn run_message_dispatch_loop( let sender_scope_key = interruption_scope_key(&msg); let cancellation_token = CancellationToken::new(); let completion = Arc::new(InFlightTaskCompletion::new()); - let task_id = task_sequence.fetch_add(1, Ordering::Relaxed); + let task_id = task_sequence.fetch_add(1, Ordering::Relaxed) as u64; if interrupt_enabled { let previous = { @@ -3364,7 +3371,7 @@ pub async fn start_channels(config: Config) -> Result<()> { }; // Build system prompt from workspace identity files + skills let workspace = config.workspace_dir.clone(); - let tools_registry = Arc::new(tools::all_tools_with_runtime( + let mut built_tools = tools::all_tools_with_runtime( Arc::new(config.clone()), &security, runtime, @@ -3378,7 +3385,44 @@ pub async fn start_channels(config: Config) -> Result<()> { &config.agents, config.api_key.as_deref(), &config, - )); + ); + + // Wire MCP tools into the registry before freezing — non-fatal. + if config.mcp.enabled && !config.mcp.servers.is_empty() { + tracing::info!( + "Initializing MCP client — {} server(s) configured", + config.mcp.servers.len() + ); + match crate::tools::mcp_client::McpRegistry::connect_all(&config.mcp.servers).await { + Ok(registry) => { + let registry = std::sync::Arc::new(registry); + let names = registry.tool_names(); + let mut registered = 0usize; + for name in names { + if let Some(def) = registry.get_tool_def(&name).await { + let wrapper = crate::tools::mcp_tool::McpToolWrapper::new( + name, + def, + std::sync::Arc::clone(®istry), + ); + built_tools.push(Box::new(wrapper)); + registered += 1; + } + } + tracing::info!( + "MCP: {} tool(s) registered from {} server(s)", + registered, + registry.server_count() + ); + } + Err(e) => { + // Non-fatal — daemon continues with the tools registered above. + tracing::error!("MCP registry failed to initialize: {e:#}"); + } + } + } + + let tools_registry = Arc::new(built_tools); let skills = crate::skills::load_skills_with_config(&workspace, &config); diff --git a/src/config/mod.rs b/src/config/mod.rs index afb4b15ac..6327581d3 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -10,12 +10,13 @@ pub use schema::{ CronConfig, DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig, GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, - HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, - ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, ObservabilityConfig, OpenAiTtsConfig, - OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, - QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, - RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, - SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig, + HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig, + McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig, MultimodalConfig, + NextcloudTalkConfig, ObservabilityConfig, OpenAiTtsConfig, OtpConfig, OtpMethod, + PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QdrantConfig, + QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, + SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig, + SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, TelegramConfig, TranscriptionConfig, TtsConfig, TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig, }; diff --git a/src/config/schema.rs b/src/config/schema.rs index 23729b2a4..4ef7c9773 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -232,6 +232,10 @@ pub struct Config { /// Text-to-Speech configuration (`[tts]`). #[serde(default)] pub tts: TtsConfig, + + /// External MCP server connections (`[mcp]`). + #[serde(default, alias = "mcpServers")] + pub mcp: McpConfig, } /// Named provider profile definition compatible with Codex app-server style config. @@ -455,6 +459,60 @@ impl Default for TranscriptionConfig { } } +// ── MCP ───────────────────────────────────────────────────────── + +/// Transport type for MCP server connections. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Default)] +#[serde(rename_all = "lowercase")] +pub enum McpTransport { + /// Spawn a local process and communicate over stdin/stdout. + #[default] + Stdio, + /// Connect via HTTP POST. + Http, + /// Connect via HTTP + Server-Sent Events. + Sse, +} + +/// Configuration for a single external MCP server. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)] +pub struct McpServerConfig { + /// Display name used as a tool prefix (`__`). + pub name: String, + /// Transport type (default: stdio). + #[serde(default)] + pub transport: McpTransport, + /// URL for HTTP/SSE transports. + #[serde(default)] + pub url: Option, + /// Executable to spawn for stdio transport. + #[serde(default)] + pub command: String, + /// Command arguments for stdio transport. + #[serde(default)] + pub args: Vec, + /// Optional environment variables for stdio transport. + #[serde(default)] + pub env: HashMap, + /// Optional HTTP headers for HTTP/SSE transports. + #[serde(default)] + pub headers: HashMap, + /// Optional per-call timeout in seconds (hard capped in validation). + #[serde(default)] + pub tool_timeout_secs: Option, +} + +/// External MCP client configuration (`[mcp]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)] +pub struct McpConfig { + /// Enable MCP tool loading. + #[serde(default)] + pub enabled: bool, + /// Configured MCP servers. + #[serde(default, alias = "mcpServers")] + pub servers: Vec, +} + // ── TTS (Text-to-Speech) ───────────────────────────────────────── fn default_tts_provider() -> String { @@ -1634,6 +1692,65 @@ fn service_selector_matches(selector: &str, service_key: &str) -> bool { false } +const MCP_MAX_TOOL_TIMEOUT_SECS: u64 = 600; + +fn validate_mcp_config(config: &McpConfig) -> Result<()> { + let mut seen_names = std::collections::HashSet::new(); + for (i, server) in config.servers.iter().enumerate() { + let name = server.name.trim(); + if name.is_empty() { + anyhow::bail!("mcp.servers[{i}].name must not be empty"); + } + if !seen_names.insert(name.to_ascii_lowercase()) { + anyhow::bail!("mcp.servers contains duplicate name: {name}"); + } + + if let Some(timeout) = server.tool_timeout_secs { + if timeout == 0 { + anyhow::bail!("mcp.servers[{i}].tool_timeout_secs must be greater than 0"); + } + if timeout > MCP_MAX_TOOL_TIMEOUT_SECS { + anyhow::bail!( + "mcp.servers[{i}].tool_timeout_secs exceeds max {MCP_MAX_TOOL_TIMEOUT_SECS}" + ); + } + } + + match server.transport { + McpTransport::Stdio => { + if server.command.trim().is_empty() { + anyhow::bail!( + "mcp.servers[{i}] with transport=stdio requires non-empty command" + ); + } + } + McpTransport::Http | McpTransport::Sse => { + let url = server + .url + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .ok_or_else(|| { + anyhow::anyhow!( + "mcp.servers[{i}] with transport={} requires url", + match server.transport { + McpTransport::Http => "http", + McpTransport::Sse => "sse", + McpTransport::Stdio => "stdio", + } + ) + })?; + let parsed = reqwest::Url::parse(url) + .with_context(|| format!("mcp.servers[{i}].url is not a valid URL"))?; + if !matches!(parsed.scheme(), "http" | "https") { + anyhow::bail!("mcp.servers[{i}].url must use http/https"); + } + } + } + } + Ok(()) +} + fn validate_proxy_url(field: &str, url: &str) -> Result<()> { let parsed = reqwest::Url::parse(url) .with_context(|| format!("Invalid {field} URL: '{url}' is not a valid URL"))?; @@ -3885,6 +4002,7 @@ impl Default for Config { query_classification: QueryClassificationConfig::default(), transcription: TranscriptionConfig::default(), tts: TtsConfig::default(), + mcp: McpConfig::default(), } } } @@ -4844,6 +4962,11 @@ impl Config { } } + // MCP + if self.mcp.enabled { + validate_mcp_config(&self.mcp)?; + } + // Proxy (delegate to existing validation) self.proxy.validate()?; @@ -5850,6 +5973,7 @@ default_temperature = 0.7 hardware: HardwareConfig::default(), transcription: TranscriptionConfig::default(), tts: TtsConfig::default(), + mcp: McpConfig::default(), }; let toml_str = toml::to_string_pretty(&config).unwrap(); @@ -6046,6 +6170,7 @@ tool_dispatcher = "xml" hardware: HardwareConfig::default(), transcription: TranscriptionConfig::default(), tts: TtsConfig::default(), + mcp: McpConfig::default(), }; config.save().await.unwrap(); diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 176bfa06a..5a5bc1021 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -748,15 +748,25 @@ const PROMETHEUS_CONTENT_TYPE: &str = "text/plain; version=0.0.4; charset=utf-8" /// GET /metrics — Prometheus text exposition format async fn handle_metrics(State(state): State) -> impl IntoResponse { - let body = if let Some(prom) = state - .observer - .as_ref() - .as_any() - .downcast_ref::() - { - prom.encode() - } else { - String::from("# Prometheus backend not enabled. Set [observability] backend = \"prometheus\" in config.\n") + let body = { + #[cfg(feature = "metrics")] + { + if let Some(prom) = state + .observer + .as_ref() + .as_any() + .downcast_ref::() + { + prom.encode() + } else { + String::from("# Prometheus backend not enabled. Set [observability] backend = \"prometheus\" in config.\n") + } + } + #[cfg(not(feature = "metrics"))] + { + let _ = &state; + String::from("# Prometheus backend not enabled. Set [observability] backend = \"prometheus\" in config.\n") + } }; ( @@ -1739,6 +1749,7 @@ mod tests { assert!(text.contains("Prometheus backend not enabled")); } + #[cfg(feature = "metrics")] #[tokio::test] async fn metrics_endpoint_renders_prometheus_output() { let prom = Arc::new(crate::observability::PrometheusObserver::new()); diff --git a/src/observability/mod.rs b/src/observability/mod.rs index 0f4bddcef..b43692852 100644 --- a/src/observability/mod.rs +++ b/src/observability/mod.rs @@ -3,6 +3,7 @@ pub mod multi; pub mod noop; #[cfg(feature = "observability-otel")] pub mod otel; +#[cfg(feature = "metrics")] pub mod prometheus; pub mod runtime_trace; pub mod traits; @@ -15,6 +16,7 @@ pub use self::multi::MultiObserver; pub use noop::NoopObserver; #[cfg(feature = "observability-otel")] pub use otel::OtelObserver; +#[cfg(feature = "metrics")] pub use prometheus::PrometheusObserver; pub use traits::{Observer, ObserverEvent}; #[allow(unused_imports)] @@ -26,7 +28,19 @@ use crate::config::ObservabilityConfig; pub fn create_observer(config: &ObservabilityConfig) -> Box { match config.backend.as_str() { "log" => Box::new(LogObserver::new()), - "prometheus" => Box::new(PrometheusObserver::new()), + "prometheus" => { + #[cfg(feature = "metrics")] + { + Box::new(PrometheusObserver::new()) + } + #[cfg(not(feature = "metrics"))] + { + tracing::warn!( + "Prometheus backend requested but this build was compiled without `metrics`; falling back to noop." + ); + Box::new(NoopObserver) + } + } "otel" | "opentelemetry" | "otlp" => { #[cfg(feature = "observability-otel")] match OtelObserver::new( @@ -104,7 +118,12 @@ mod tests { backend: "prometheus".into(), ..ObservabilityConfig::default() }; - assert_eq!(create_observer(&cfg).name(), "prometheus"); + let expected = if cfg!(feature = "metrics") { + "prometheus" + } else { + "noop" + }; + assert_eq!(create_observer(&cfg).name(), expected); } #[test] diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 1006a3e5f..ae0feb5f5 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -173,6 +173,7 @@ pub async fn run_wizard(force: bool) -> Result { query_classification: crate::config::QueryClassificationConfig::default(), transcription: crate::config::TranscriptionConfig::default(), tts: crate::config::TtsConfig::default(), + mcp: crate::config::McpConfig::default(), }; println!( @@ -526,6 +527,7 @@ async fn run_quick_setup_with_home( query_classification: crate::config::QueryClassificationConfig::default(), transcription: crate::config::TranscriptionConfig::default(), tts: crate::config::TtsConfig::default(), + mcp: crate::config::McpConfig::default(), }; config.save().await?; diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs index 1e69c8e83..e91625d96 100644 --- a/src/providers/ollama.rs +++ b/src/providers/ollama.rs @@ -27,7 +27,7 @@ struct ChatRequest { tools: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] struct Message { role: String, #[serde(skip_serializing_if = "Option::is_none")] @@ -40,14 +40,14 @@ struct Message { tool_name: Option, } -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] struct OutgoingToolCall { #[serde(rename = "type")] kind: String, function: OutgoingFunction, } -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] struct OutgoingFunction { name: String, arguments: serde_json::Value, @@ -258,13 +258,31 @@ impl OllamaProvider { model: &str, temperature: f64, tools: Option<&[serde_json::Value]>, + ) -> ChatRequest { + self.build_chat_request_with_think( + messages, + model, + temperature, + tools, + self.reasoning_enabled, + ) + } + + /// Build a chat request with an explicit `think` value. + fn build_chat_request_with_think( + &self, + messages: Vec, + model: &str, + temperature: f64, + tools: Option<&[serde_json::Value]>, + think: Option, ) -> ChatRequest { ChatRequest { model: model.to_string(), messages, stream: false, options: Options { temperature }, - think: self.reasoning_enabled, + think, tools: tools.map(|t| t.to_vec()), } } @@ -396,17 +414,18 @@ impl OllamaProvider { .collect() } - /// Send a request to Ollama and get the parsed response. - /// Pass `tools` to enable native function-calling for models that support it. - async fn send_request( + /// Send a single HTTP request to Ollama and parse the response. + async fn send_request_inner( &self, - messages: Vec, + messages: &[Message], model: &str, temperature: f64, should_auth: bool, tools: Option<&[serde_json::Value]>, + think: Option, ) -> anyhow::Result { - let request = self.build_chat_request(messages, model, temperature, tools); + let request = + self.build_chat_request_with_think(messages.to_vec(), model, temperature, tools, think); let url = format!("{}/api/chat", self.base_url); @@ -466,6 +485,59 @@ impl OllamaProvider { Ok(chat_response) } + /// Send a request to Ollama and get the parsed response. + /// Pass `tools` to enable native function-calling for models that support it. + /// + /// When `reasoning_enabled` (`think`) is set to `true`, the first request + /// includes `think: true`. If that request fails (the model may not support + /// the `think` parameter), we automatically retry once with `think` omitted + /// so the call succeeds instead of entering an infinite retry loop. + async fn send_request( + &self, + messages: Vec, + model: &str, + temperature: f64, + should_auth: bool, + tools: Option<&[serde_json::Value]>, + ) -> anyhow::Result { + let result = self + .send_request_inner( + &messages, + model, + temperature, + should_auth, + tools, + self.reasoning_enabled, + ) + .await; + + match result { + Ok(resp) => Ok(resp), + Err(first_err) if self.reasoning_enabled == Some(true) => { + tracing::warn!( + model = model, + error = %first_err, + "Ollama request failed with think=true; retrying without reasoning \ + (model may not support it)" + ); + // Retry with think omitted from the request entirely. + self.send_request_inner(&messages, model, temperature, should_auth, tools, None) + .await + .map_err(|retry_err| { + // Both attempts failed — return the original error for clarity. + tracing::error!( + model = model, + original_error = %first_err, + retry_error = %retry_err, + "Ollama request also failed without think; returning original error" + ); + first_err + }) + } + Err(e) => Err(e), + } + } + /// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs /// /// Handles quirky model behavior where tool calls are wrapped: diff --git a/src/tools/mcp_client.rs b/src/tools/mcp_client.rs new file mode 100644 index 000000000..bc53ad1af --- /dev/null +++ b/src/tools/mcp_client.rs @@ -0,0 +1,357 @@ +//! MCP (Model Context Protocol) client — connects to external tool servers. +//! +//! Supports multiple transports: stdio (spawn local process), HTTP, and SSE. + +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; + +use anyhow::{anyhow, bail, Context, Result}; +use serde_json::json; +use tokio::sync::Mutex; +use tokio::time::{timeout, Duration}; + +use crate::config::schema::McpServerConfig; +use crate::tools::mcp_protocol::{ + JsonRpcRequest, McpToolDef, McpToolsListResult, MCP_PROTOCOL_VERSION, +}; +use crate::tools::mcp_transport::{create_transport, McpTransportConn}; + +/// Timeout for receiving a response from an MCP server during init/list. +/// Prevents a hung server from blocking the daemon indefinitely. +const RECV_TIMEOUT_SECS: u64 = 30; + +/// Default timeout for tool calls (seconds) when not configured per-server. +const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 180; + +/// Maximum allowed tool call timeout (seconds) — hard safety ceiling. +const MAX_TOOL_TIMEOUT_SECS: u64 = 600; + +// ── Internal server state ────────────────────────────────────────────────── + +struct McpServerInner { + config: McpServerConfig, + transport: Box, + next_id: AtomicU64, + tools: Vec, +} + +// ── McpServer ────────────────────────────────────────────────────────────── + +/// A live connection to one MCP server (any transport). +#[derive(Clone)] +pub struct McpServer { + inner: Arc>, +} + +impl McpServer { + /// Connect to the server, perform the initialize handshake, and fetch the tool list. + pub async fn connect(config: McpServerConfig) -> Result { + // Create transport based on config + let mut transport = create_transport(&config).with_context(|| { + format!( + "failed to create transport for MCP server `{}`", + config.name + ) + })?; + + // Initialize handshake + let id = 1u64; + let init_req = JsonRpcRequest::new( + id, + "initialize", + json!({ + "protocolVersion": MCP_PROTOCOL_VERSION, + "capabilities": {}, + "clientInfo": { + "name": "zeroclaw", + "version": env!("CARGO_PKG_VERSION") + } + }), + ); + + let init_resp = timeout( + Duration::from_secs(RECV_TIMEOUT_SECS), + transport.send_and_recv(&init_req), + ) + .await + .with_context(|| { + format!( + "MCP server `{}` timed out after {}s waiting for initialize response", + config.name, RECV_TIMEOUT_SECS + ) + })??; + + if init_resp.error.is_some() { + bail!( + "MCP server `{}` rejected initialize: {:?}", + config.name, + init_resp.error + ); + } + + // Notify server that client is initialized (no response expected for notifications) + let notif = JsonRpcRequest::notification("notifications/initialized", json!({})); + // Best effort - ignore errors for notifications + let _ = transport.send_and_recv(¬if).await; + + // Fetch available tools + let id = 2u64; + let list_req = JsonRpcRequest::new(id, "tools/list", json!({})); + + let list_resp = timeout( + Duration::from_secs(RECV_TIMEOUT_SECS), + transport.send_and_recv(&list_req), + ) + .await + .with_context(|| { + format!( + "MCP server `{}` timed out after {}s waiting for tools/list response", + config.name, RECV_TIMEOUT_SECS + ) + })??; + + let result = list_resp + .result + .ok_or_else(|| anyhow!("tools/list returned no result from `{}`", config.name))?; + let tool_list: McpToolsListResult = serde_json::from_value(result) + .with_context(|| format!("failed to parse tools/list from `{}`", config.name))?; + + let tool_count = tool_list.tools.len(); + + let inner = McpServerInner { + config, + transport, + next_id: AtomicU64::new(3), // Start at 3 since we used 1 and 2 + tools: tool_list.tools, + }; + + tracing::info!( + "MCP server `{}` connected — {} tool(s) available", + inner.config.name, + tool_count + ); + + Ok(Self { + inner: Arc::new(Mutex::new(inner)), + }) + } + + /// Tools advertised by this server. + pub async fn tools(&self) -> Vec { + self.inner.lock().await.tools.clone() + } + + /// Server display name. + #[allow(dead_code)] + pub async fn name(&self) -> String { + self.inner.lock().await.config.name.clone() + } + + /// Call a tool on this server. Returns the raw JSON result. + pub async fn call_tool( + &self, + tool_name: &str, + arguments: serde_json::Value, + ) -> Result { + let mut inner = self.inner.lock().await; + let id = inner.next_id.fetch_add(1, Ordering::Relaxed); + let req = JsonRpcRequest::new( + id, + "tools/call", + json!({ "name": tool_name, "arguments": arguments }), + ); + + // Use per-server tool timeout if configured, otherwise default. + // Cap at MAX_TOOL_TIMEOUT_SECS for safety. + let tool_timeout = inner + .config + .tool_timeout_secs + .unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS) + .min(MAX_TOOL_TIMEOUT_SECS); + + let resp = timeout( + Duration::from_secs(tool_timeout), + inner.transport.send_and_recv(&req), + ) + .await + .map_err(|_| { + anyhow!( + "MCP server `{}` timed out after {}s during tool call `{tool_name}`", + inner.config.name, + tool_timeout + ) + })? + .with_context(|| { + format!( + "MCP server `{}` error during tool call `{tool_name}`", + inner.config.name + ) + })?; + + if let Some(err) = resp.error { + bail!("MCP tool `{tool_name}` error {}: {}", err.code, err.message); + } + Ok(resp.result.unwrap_or(serde_json::Value::Null)) + } +} + +// ── McpRegistry ─────────────────────────────────────────────────────────── + +/// Registry of all connected MCP servers, with a flat tool index. +pub struct McpRegistry { + servers: Vec, + /// prefixed_name -> (server_index, original_tool_name) + tool_index: HashMap, +} + +impl McpRegistry { + /// Connect to all configured servers. Non-fatal: failures are logged and skipped. + pub async fn connect_all(configs: &[McpServerConfig]) -> Result { + let mut servers = Vec::new(); + let mut tool_index = HashMap::new(); + + for config in configs { + match McpServer::connect(config.clone()).await { + Ok(server) => { + let server_idx = servers.len(); + // Collect tools while holding the lock once, then release + let tools = server.tools().await; + for tool in &tools { + // Prefix prevents name collisions across servers + let prefixed = format!("{}__{}", config.name, tool.name); + tool_index.insert(prefixed, (server_idx, tool.name.clone())); + } + servers.push(server); + } + // Non-fatal — log and continue with remaining servers + Err(e) => { + tracing::error!("Failed to connect to MCP server `{}`: {:#}", config.name, e); + } + } + } + + Ok(Self { + servers, + tool_index, + }) + } + + /// All prefixed tool names across all connected servers. + pub fn tool_names(&self) -> Vec { + self.tool_index.keys().cloned().collect() + } + + /// Tool definition for a given prefixed name (cloned). + pub async fn get_tool_def(&self, prefixed_name: &str) -> Option { + let (server_idx, original_name) = self.tool_index.get(prefixed_name)?; + let inner = self.servers[*server_idx].inner.lock().await; + inner + .tools + .iter() + .find(|t| &t.name == original_name) + .cloned() + } + + /// Execute a tool by prefixed name. + pub async fn call_tool( + &self, + prefixed_name: &str, + arguments: serde_json::Value, + ) -> Result { + let (server_idx, original_name) = self + .tool_index + .get(prefixed_name) + .ok_or_else(|| anyhow!("unknown MCP tool `{prefixed_name}`"))?; + let result = self.servers[*server_idx] + .call_tool(original_name, arguments) + .await?; + serde_json::to_string_pretty(&result) + .with_context(|| format!("failed to serialize result of MCP tool `{prefixed_name}`")) + } + + pub fn is_empty(&self) -> bool { + self.servers.is_empty() + } + + pub fn server_count(&self) -> usize { + self.servers.len() + } + + pub fn tool_count(&self) -> usize { + self.tool_index.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::schema::McpTransport; + + #[test] + fn tool_name_prefix_format() { + let prefixed = format!("{}__{}", "filesystem", "read_file"); + assert_eq!(prefixed, "filesystem__read_file"); + } + + #[tokio::test] + async fn connect_nonexistent_command_fails_cleanly() { + // A command that doesn't exist should fail at spawn, not panic. + let config = McpServerConfig { + name: "nonexistent".to_string(), + command: "/usr/bin/this_binary_does_not_exist_zeroclaw_test".to_string(), + args: vec![], + env: HashMap::default(), + tool_timeout_secs: None, + transport: McpTransport::Stdio, + url: None, + headers: HashMap::default(), + }; + let result = McpServer::connect(config).await; + assert!(result.is_err()); + let msg = result.err().unwrap().to_string(); + assert!(msg.contains("failed to create transport"), "got: {msg}"); + } + + #[tokio::test] + async fn connect_all_nonfatal_on_single_failure() { + // If one server config is bad, connect_all should succeed (with 0 servers). + let configs = vec![McpServerConfig { + name: "bad".to_string(), + command: "/usr/bin/does_not_exist_zc_test".to_string(), + args: vec![], + env: HashMap::default(), + tool_timeout_secs: None, + transport: McpTransport::Stdio, + url: None, + headers: HashMap::default(), + }]; + let registry = McpRegistry::connect_all(&configs) + .await + .expect("connect_all should not fail"); + assert!(registry.is_empty()); + assert_eq!(registry.tool_count(), 0); + } + + #[test] + fn http_transport_requires_url() { + let config = McpServerConfig { + name: "test".into(), + transport: McpTransport::Http, + ..Default::default() + }; + let result = create_transport(&config); + assert!(result.is_err()); + } + + #[test] + fn sse_transport_requires_url() { + let config = McpServerConfig { + name: "test".into(), + transport: McpTransport::Sse, + ..Default::default() + }; + let result = create_transport(&config); + assert!(result.is_err()); + } +} diff --git a/src/tools/mcp_protocol.rs b/src/tools/mcp_protocol.rs new file mode 100644 index 000000000..c3d77e9dc --- /dev/null +++ b/src/tools/mcp_protocol.rs @@ -0,0 +1,130 @@ +//! MCP (Model Context Protocol) JSON-RPC 2.0 protocol types. +//! Protocol version: 2024-11-05 +//! Adapted from ops-mcp-server/src/protocol.rs for client use. +//! Both Serialize and Deserialize are derived — the client both sends (Serialize) +//! and receives (Deserialize) JSON-RPC messages. + +use serde::{Deserialize, Serialize}; + +pub const JSONRPC_VERSION: &str = "2.0"; +pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05"; + +// Standard JSON-RPC 2.0 error codes +#[allow(dead_code)] +pub const PARSE_ERROR: i32 = -32700; +#[allow(dead_code)] +pub const INVALID_REQUEST: i32 = -32600; +#[allow(dead_code)] +pub const METHOD_NOT_FOUND: i32 = -32601; +#[allow(dead_code)] +pub const INVALID_PARAMS: i32 = -32602; +pub const INTERNAL_ERROR: i32 = -32603; + +/// Outbound JSON-RPC request (client -> MCP server). +/// Used for both method calls (with id) and notifications (id = None). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcRequest { + pub jsonrpc: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +impl JsonRpcRequest { + /// Create a method call request with a numeric id. + pub fn new(id: u64, method: impl Into, params: serde_json::Value) -> Self { + Self { + jsonrpc: JSONRPC_VERSION.to_string(), + id: Some(serde_json::Value::Number(id.into())), + method: method.into(), + params: Some(params), + } + } + + /// Create a notification — no id, no response expected from server. + pub fn notification(method: impl Into, params: serde_json::Value) -> Self { + Self { + jsonrpc: JSONRPC_VERSION.to_string(), + id: None, + method: method.into(), + params: Some(params), + } + } +} + +/// Inbound JSON-RPC response (MCP server -> client). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcResponse { + pub jsonrpc: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +/// JSON-RPC error object embedded in a response. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcError { + pub code: i32, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +/// A tool advertised by an MCP server (from `tools/list` response). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpToolDef { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(rename = "inputSchema")] + pub input_schema: serde_json::Value, +} + +/// Expected shape of the `tools/list` result payload. +#[derive(Debug, Deserialize)] +pub struct McpToolsListResult { + pub tools: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn request_serializes_with_id() { + let req = JsonRpcRequest::new(1, "tools/list", serde_json::json!({})); + let s = serde_json::to_string(&req).unwrap(); + assert!(s.contains("\"id\":1")); + assert!(s.contains("\"method\":\"tools/list\"")); + assert!(s.contains("\"jsonrpc\":\"2.0\"")); + } + + #[test] + fn notification_omits_id() { + let notif = + JsonRpcRequest::notification("notifications/initialized", serde_json::json!({})); + let s = serde_json::to_string(¬if).unwrap(); + assert!(!s.contains("\"id\"")); + } + + #[test] + fn response_deserializes() { + let json = r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#; + let resp: JsonRpcResponse = serde_json::from_str(json).unwrap(); + assert!(resp.result.is_some()); + assert!(resp.error.is_none()); + } + + #[test] + fn tool_def_deserializes_input_schema() { + let json = r#"{"name":"read_file","description":"Read a file","inputSchema":{"type":"object","properties":{"path":{"type":"string"}}}}"#; + let def: McpToolDef = serde_json::from_str(json).unwrap(); + assert_eq!(def.name, "read_file"); + assert!(def.input_schema.is_object()); + } +} diff --git a/src/tools/mcp_tool.rs b/src/tools/mcp_tool.rs new file mode 100644 index 000000000..d2c08a84e --- /dev/null +++ b/src/tools/mcp_tool.rs @@ -0,0 +1,68 @@ +//! Wraps a discovered MCP tool as a zeroclaw [`Tool`] so it is dispatched +//! through the existing tool registry and agent loop without modification. + +use std::sync::Arc; + +use async_trait::async_trait; + +use crate::tools::mcp_client::McpRegistry; +use crate::tools::mcp_protocol::McpToolDef; +use crate::tools::traits::{Tool, ToolResult}; + +/// A zeroclaw [`Tool`] backed by an MCP server tool. +/// +/// The `prefixed_name` (e.g. `filesystem__read_file`) is what the agent loop +/// sees. The registry knows how to route it to the correct server. +pub struct McpToolWrapper { + /// Prefixed name: `__`. + prefixed_name: String, + /// Description extracted from the MCP tool definition. Stored as an owned + /// String so that `description()` can return `&str` with self's lifetime. + description: String, + /// JSON schema for the tool's input parameters. + input_schema: serde_json::Value, + /// Shared registry — used to dispatch actual tool calls. + registry: Arc, +} + +impl McpToolWrapper { + pub fn new(prefixed_name: String, def: McpToolDef, registry: Arc) -> Self { + let description = def.description.unwrap_or_else(|| "MCP tool".to_string()); + Self { + prefixed_name, + description, + input_schema: def.input_schema, + registry, + } + } +} + +#[async_trait] +impl Tool for McpToolWrapper { + fn name(&self) -> &str { + &self.prefixed_name + } + + fn description(&self) -> &str { + &self.description + } + + fn parameters_schema(&self) -> serde_json::Value { + self.input_schema.clone() + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + match self.registry.call_tool(&self.prefixed_name, args).await { + Ok(output) => Ok(ToolResult { + success: true, + output, + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e.to_string()), + }), + } + } +} diff --git a/src/tools/mcp_transport.rs b/src/tools/mcp_transport.rs new file mode 100644 index 000000000..85415ef96 --- /dev/null +++ b/src/tools/mcp_transport.rs @@ -0,0 +1,868 @@ +//! MCP transport abstraction — supports stdio, SSE, and HTTP transports. + +use std::borrow::Cow; + +use anyhow::{anyhow, bail, Context, Result}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::process::{Child, Command}; +use tokio::sync::{oneshot, Mutex, Notify}; +use tokio::time::{timeout, Duration}; +use tokio_stream::StreamExt; + +use crate::config::schema::{McpServerConfig, McpTransport}; +use crate::tools::mcp_protocol::{JsonRpcError, JsonRpcRequest, JsonRpcResponse, INTERNAL_ERROR}; + +/// Maximum bytes for a single JSON-RPC response. +const MAX_LINE_BYTES: usize = 4 * 1024 * 1024; // 4 MB + +/// Timeout for init/list operations. +const RECV_TIMEOUT_SECS: u64 = 30; + +// ── Transport Trait ────────────────────────────────────────────────────── + +/// Abstract transport for MCP communication. +#[async_trait::async_trait] +pub trait McpTransportConn: Send + Sync { + /// Send a JSON-RPC request and receive the response. + async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result; + + /// Close the connection. + async fn close(&mut self) -> Result<()>; +} + +// ── Stdio Transport ────────────────────────────────────────────────────── + +/// Stdio-based transport (spawn local process). +pub struct StdioTransport { + _child: Child, + stdin: tokio::process::ChildStdin, + stdout_lines: tokio::io::Lines>, +} + +impl StdioTransport { + pub fn new(config: &McpServerConfig) -> Result { + let mut child = Command::new(&config.command) + .args(&config.args) + .envs(&config.env) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn() + .with_context(|| format!("failed to spawn MCP server `{}`", config.name))?; + + let stdin = child + .stdin + .take() + .ok_or_else(|| anyhow!("no stdin on MCP server `{}`", config.name))?; + let stdout = child + .stdout + .take() + .ok_or_else(|| anyhow!("no stdout on MCP server `{}`", config.name))?; + let stdout_lines = BufReader::new(stdout).lines(); + + Ok(Self { + _child: child, + stdin, + stdout_lines, + }) + } + + async fn send_raw(&mut self, line: &str) -> Result<()> { + self.stdin + .write_all(line.as_bytes()) + .await + .context("failed to write to MCP server stdin")?; + self.stdin + .write_all(b"\n") + .await + .context("failed to write newline to MCP server stdin")?; + self.stdin.flush().await.context("failed to flush stdin")?; + Ok(()) + } + + async fn recv_raw(&mut self) -> Result { + let line = self + .stdout_lines + .next_line() + .await? + .ok_or_else(|| anyhow!("MCP server closed stdout"))?; + if line.len() > MAX_LINE_BYTES { + bail!("MCP response too large: {} bytes", line.len()); + } + Ok(line) + } +} + +#[async_trait::async_trait] +impl McpTransportConn for StdioTransport { + async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result { + let line = serde_json::to_string(request)?; + self.send_raw(&line).await?; + if request.id.is_none() { + return Ok(JsonRpcResponse { + jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(), + id: None, + result: None, + error: None, + }); + } + let resp_line = timeout(Duration::from_secs(RECV_TIMEOUT_SECS), self.recv_raw()) + .await + .context("timeout waiting for MCP response")??; + let resp: JsonRpcResponse = serde_json::from_str(&resp_line) + .with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?; + Ok(resp) + } + + async fn close(&mut self) -> Result<()> { + let _ = self.stdin.shutdown().await; + Ok(()) + } +} + +// ── HTTP Transport ─────────────────────────────────────────────────────── + +/// HTTP-based transport (POST requests). +pub struct HttpTransport { + url: String, + client: reqwest::Client, + headers: std::collections::HashMap, +} + +impl HttpTransport { + pub fn new(config: &McpServerConfig) -> Result { + let url = config + .url + .as_ref() + .ok_or_else(|| anyhow!("URL required for HTTP transport"))? + .clone(); + + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(120)) + .build() + .context("failed to build HTTP client")?; + + Ok(Self { + url, + client, + headers: config.headers.clone(), + }) + } +} + +#[async_trait::async_trait] +impl McpTransportConn for HttpTransport { + async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result { + let body = serde_json::to_string(request)?; + + let mut req = self.client.post(&self.url).body(body); + for (key, value) in &self.headers { + req = req.header(key, value); + } + + let resp = req + .send() + .await + .context("HTTP request to MCP server failed")?; + + if !resp.status().is_success() { + bail!("MCP server returned HTTP {}", resp.status()); + } + + if request.id.is_none() { + return Ok(JsonRpcResponse { + jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(), + id: None, + result: None, + error: None, + }); + } + + let resp_text = resp.text().await.context("failed to read HTTP response")?; + let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text) + .with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?; + + Ok(mcp_resp) + } + + async fn close(&mut self) -> Result<()> { + Ok(()) + } +} + +// ── SSE Transport ───────────────────────────────────────────────────────── + +/// SSE-based transport (HTTP POST for requests, SSE for responses). +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +enum SseStreamState { + Unknown, + Connected, + Unsupported, +} + +pub struct SseTransport { + sse_url: String, + server_name: String, + client: reqwest::Client, + headers: std::collections::HashMap, + stream_state: SseStreamState, + shared: std::sync::Arc>, + notify: std::sync::Arc, + shutdown_tx: Option>, + reader_task: Option>, +} + +impl SseTransport { + pub fn new(config: &McpServerConfig) -> Result { + let sse_url = config + .url + .as_ref() + .ok_or_else(|| anyhow!("URL required for SSE transport"))? + .clone(); + + let client = reqwest::Client::builder() + .build() + .context("failed to build HTTP client")?; + + Ok(Self { + sse_url, + server_name: config.name.clone(), + client, + headers: config.headers.clone(), + stream_state: SseStreamState::Unknown, + shared: std::sync::Arc::new(Mutex::new(SseSharedState::default())), + notify: std::sync::Arc::new(Notify::new()), + shutdown_tx: None, + reader_task: None, + }) + } + + async fn ensure_connected(&mut self) -> Result<()> { + if self.stream_state == SseStreamState::Unsupported { + return Ok(()); + } + if let Some(task) = &self.reader_task { + if !task.is_finished() { + self.stream_state = SseStreamState::Connected; + return Ok(()); + } + } + + let mut req = self + .client + .get(&self.sse_url) + .header("Accept", "text/event-stream") + .header("Cache-Control", "no-cache"); + for (key, value) in &self.headers { + req = req.header(key, value); + } + + let resp = req.send().await.context("SSE GET to MCP server failed")?; + if resp.status() == reqwest::StatusCode::NOT_FOUND + || resp.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED + { + self.stream_state = SseStreamState::Unsupported; + return Ok(()); + } + if !resp.status().is_success() { + return Err(anyhow!("MCP server returned HTTP {}", resp.status())); + } + let is_event_stream = resp + .headers() + .get(reqwest::header::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream")); + if !is_event_stream { + self.stream_state = SseStreamState::Unsupported; + return Ok(()); + } + + let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>(); + self.shutdown_tx = Some(shutdown_tx); + + let shared = self.shared.clone(); + let notify = self.notify.clone(); + let sse_url = self.sse_url.clone(); + let server_name = self.server_name.clone(); + + self.reader_task = Some(tokio::spawn(async move { + let stream = resp + .bytes_stream() + .map(|item| item.map_err(std::io::Error::other)); + let reader = tokio_util::io::StreamReader::new(stream); + let mut lines = BufReader::new(reader).lines(); + + let mut cur_event: Option = None; + let mut cur_id: Option = None; + let mut cur_data: Vec = Vec::new(); + + loop { + tokio::select! { + _ = &mut shutdown_rx => { + break; + } + line = lines.next_line() => { + let Ok(line_opt) = line else { break; }; + let Some(mut line) = line_opt else { break; }; + if line.ends_with('\r') { + line.pop(); + } + if line.is_empty() { + if cur_event.is_none() && cur_id.is_none() && cur_data.is_empty() { + continue; + } + let event = cur_event.take(); + let data = cur_data.join("\n"); + cur_data.clear(); + let id = cur_id.take(); + handle_sse_event( + &server_name, + &sse_url, + &shared, + ¬ify, + event.as_deref(), + id.as_deref(), + data, + ) + .await; + continue; + } + + if line.starts_with(':') { + continue; + } + + if let Some(rest) = line.strip_prefix("event:") { + cur_event = Some(rest.trim().to_string()); + continue; + } + if let Some(rest) = line.strip_prefix("data:") { + let rest = rest.strip_prefix(' ').unwrap_or(rest); + cur_data.push(rest.to_string()); + continue; + } + if let Some(rest) = line.strip_prefix("id:") { + cur_id = Some(rest.trim().to_string()); + } + } + } + } + + let pending = { + let mut guard = shared.lock().await; + std::mem::take(&mut guard.pending) + }; + for (_, tx) in pending { + let _ = tx.send(JsonRpcResponse { + jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(), + id: None, + result: None, + error: Some(JsonRpcError { + code: INTERNAL_ERROR, + message: "SSE connection closed".to_string(), + data: None, + }), + }); + } + })); + self.stream_state = SseStreamState::Connected; + + Ok(()) + } + + async fn get_message_url(&self) -> Result<(String, bool)> { + let guard = self.shared.lock().await; + if let Some(url) = &guard.message_url { + return Ok((url.clone(), guard.message_url_from_endpoint)); + } + drop(guard); + + let derived = derive_message_url(&self.sse_url, "messages") + .or_else(|| derive_message_url(&self.sse_url, "message")) + .ok_or_else(|| anyhow!("invalid SSE URL"))?; + let mut guard = self.shared.lock().await; + if guard.message_url.is_none() { + guard.message_url = Some(derived.clone()); + guard.message_url_from_endpoint = false; + } + Ok((derived, false)) + } +} + +#[derive(Default)] +struct SseSharedState { + message_url: Option, + message_url_from_endpoint: bool, + pending: std::collections::HashMap>, +} + +fn derive_message_url(sse_url: &str, message_path: &str) -> Option { + let url = reqwest::Url::parse(sse_url).ok()?; + let mut segments: Vec<&str> = url.path_segments()?.collect(); + if segments.is_empty() { + return None; + } + if segments.last().copied() == Some("sse") { + segments.pop(); + segments.push(message_path); + let mut new_url = url.clone(); + new_url.set_path(&format!("/{}", segments.join("/"))); + return Some(new_url.to_string()); + } + let mut new_url = url.clone(); + let mut path = url.path().trim_end_matches('/').to_string(); + path.push('/'); + path.push_str(message_path); + new_url.set_path(&path); + Some(new_url.to_string()) +} + +async fn handle_sse_event( + server_name: &str, + sse_url: &str, + shared: &std::sync::Arc>, + notify: &std::sync::Arc, + event: Option<&str>, + _id: Option<&str>, + data: String, +) { + let event = event.unwrap_or("message"); + let trimmed = data.trim(); + if trimmed.is_empty() { + return; + } + + if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint") { + if let Some(url) = parse_endpoint_from_data(sse_url, trimmed) { + let mut guard = shared.lock().await; + guard.message_url = Some(url); + guard.message_url_from_endpoint = true; + drop(guard); + notify.notify_waiters(); + } + return; + } + + if !event.eq_ignore_ascii_case("message") { + return; + } + + let Ok(value) = serde_json::from_str::(trimmed) else { + return; + }; + + let Ok(resp) = serde_json::from_value::(value.clone()) else { + let _ = serde_json::from_value::(value); + return; + }; + + let Some(id_val) = resp.id.clone() else { + return; + }; + let id = match id_val.as_u64() { + Some(v) => v, + None => return, + }; + + let tx = { + let mut guard = shared.lock().await; + guard.pending.remove(&id) + }; + if let Some(tx) = tx { + let _ = tx.send(resp); + } else { + tracing::debug!( + "MCP SSE `{}` received response for unknown id {}", + server_name, + id + ); + } +} + +fn parse_endpoint_from_data(sse_url: &str, data: &str) -> Option { + if data.starts_with('{') { + let v: serde_json::Value = serde_json::from_str(data).ok()?; + let endpoint = v.get("endpoint")?.as_str()?; + return parse_endpoint_from_data(sse_url, endpoint); + } + if data.starts_with("http://") || data.starts_with("https://") { + return Some(data.to_string()); + } + let base = reqwest::Url::parse(sse_url).ok()?; + base.join(data).ok().map(|u| u.to_string()) +} + +fn extract_json_from_sse_text(resp_text: &str) -> Cow<'_, str> { + let text = resp_text.trim_start_matches('\u{feff}'); + let mut current_data_lines: Vec<&str> = Vec::new(); + let mut last_event_data_lines: Vec<&str> = Vec::new(); + + for raw_line in text.lines() { + let line = raw_line.trim_end_matches('\r').trim_start(); + if line.is_empty() { + if !current_data_lines.is_empty() { + last_event_data_lines = std::mem::take(&mut current_data_lines); + } + continue; + } + + if line.starts_with(':') { + continue; + } + + if let Some(rest) = line.strip_prefix("data:") { + let rest = rest.strip_prefix(' ').unwrap_or(rest); + current_data_lines.push(rest); + } + } + + if !current_data_lines.is_empty() { + last_event_data_lines = current_data_lines; + } + + if last_event_data_lines.is_empty() { + return Cow::Borrowed(text.trim()); + } + + if last_event_data_lines.len() == 1 { + return Cow::Borrowed(last_event_data_lines[0].trim()); + } + + let joined = last_event_data_lines.join("\n"); + Cow::Owned(joined.trim().to_string()) +} + +async fn read_first_jsonrpc_from_sse_response( + resp: reqwest::Response, +) -> Result> { + let stream = resp + .bytes_stream() + .map(|item| item.map_err(std::io::Error::other)); + let reader = tokio_util::io::StreamReader::new(stream); + let mut lines = BufReader::new(reader).lines(); + + let mut cur_event: Option = None; + let mut cur_data: Vec = Vec::new(); + + while let Ok(line_opt) = lines.next_line().await { + let Some(mut line) = line_opt else { break }; + if line.ends_with('\r') { + line.pop(); + } + if line.is_empty() { + if cur_event.is_none() && cur_data.is_empty() { + continue; + } + let event = cur_event.take(); + let data = cur_data.join("\n"); + cur_data.clear(); + + let event = event.unwrap_or_else(|| "message".to_string()); + if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint") + { + continue; + } + if !event.eq_ignore_ascii_case("message") { + continue; + } + + let trimmed = data.trim(); + if trimmed.is_empty() { + continue; + } + let json_str = extract_json_from_sse_text(trimmed); + if let Ok(resp) = serde_json::from_str::(json_str.as_ref()) { + return Ok(Some(resp)); + } + continue; + } + + if line.starts_with(':') { + continue; + } + if let Some(rest) = line.strip_prefix("event:") { + cur_event = Some(rest.trim().to_string()); + continue; + } + if let Some(rest) = line.strip_prefix("data:") { + let rest = rest.strip_prefix(' ').unwrap_or(rest); + cur_data.push(rest.to_string()); + } + } + + Ok(None) +} + +#[async_trait::async_trait] +impl McpTransportConn for SseTransport { + async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result { + self.ensure_connected().await?; + + let id = request.id.as_ref().and_then(|v| v.as_u64()); + let body = serde_json::to_string(request)?; + + let (mut message_url, mut from_endpoint) = self.get_message_url().await?; + if self.stream_state == SseStreamState::Connected && !from_endpoint { + for _ in 0..3 { + { + let guard = self.shared.lock().await; + if guard.message_url_from_endpoint { + if let Some(url) = &guard.message_url { + message_url = url.clone(); + from_endpoint = true; + break; + } + } + } + let _ = timeout(Duration::from_millis(300), self.notify.notified()).await; + } + } + let primary_url = if from_endpoint { + message_url.clone() + } else { + self.sse_url.clone() + }; + let secondary_url = if message_url == self.sse_url { + None + } else if primary_url == message_url { + Some(self.sse_url.clone()) + } else { + Some(message_url.clone()) + }; + let has_secondary = secondary_url.is_some(); + + let mut rx = None; + if let Some(id) = id { + if self.stream_state == SseStreamState::Connected { + let (tx, ch) = oneshot::channel(); + { + let mut guard = self.shared.lock().await; + guard.pending.insert(id, tx); + } + rx = Some((id, ch)); + } + } + + let mut got_direct = None; + let mut last_status = None; + + for (i, url) in std::iter::once(primary_url) + .chain(secondary_url.into_iter()) + .enumerate() + { + let mut req = self + .client + .post(&url) + .timeout(Duration::from_secs(120)) + .body(body.clone()) + .header("Content-Type", "application/json"); + for (key, value) in &self.headers { + req = req.header(key, value); + } + if !self + .headers + .keys() + .any(|k| k.eq_ignore_ascii_case("Accept")) + { + req = req.header("Accept", "application/json, text/event-stream"); + } + + let resp = req.send().await.context("SSE POST to MCP server failed")?; + let status = resp.status(); + last_status = Some(status); + + if (status == reqwest::StatusCode::NOT_FOUND + || status == reqwest::StatusCode::METHOD_NOT_ALLOWED) + && i == 0 + { + continue; + } + + if !status.is_success() { + break; + } + + if request.id.is_none() { + got_direct = Some(JsonRpcResponse { + jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(), + id: None, + result: None, + error: None, + }); + break; + } + + let is_sse = resp + .headers() + .get(reqwest::header::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream")); + + if is_sse { + if i == 0 && has_secondary { + match timeout( + Duration::from_secs(3), + read_first_jsonrpc_from_sse_response(resp), + ) + .await + { + Ok(res) => { + if let Some(resp) = res? { + got_direct = Some(resp); + } + break; + } + Err(_) => continue, + } + } + if let Some(resp) = read_first_jsonrpc_from_sse_response(resp).await? { + got_direct = Some(resp); + } + break; + } + + let text = if i == 0 && has_secondary { + match timeout(Duration::from_secs(3), resp.text()).await { + Ok(Ok(t)) => t, + Ok(Err(_)) => String::new(), + Err(_) => continue, + } + } else { + resp.text().await.unwrap_or_default() + }; + let trimmed = text.trim(); + if !trimmed.is_empty() { + let json_str = if trimmed.contains("\ndata:") || trimmed.starts_with("data:") { + extract_json_from_sse_text(trimmed) + } else { + Cow::Borrowed(trimmed) + }; + if let Ok(mcp_resp) = serde_json::from_str::(json_str.as_ref()) { + got_direct = Some(mcp_resp); + } + } + break; + } + + if let Some((id, _)) = rx.as_ref() { + if got_direct.is_some() { + let mut guard = self.shared.lock().await; + guard.pending.remove(id); + } else if let Some(status) = last_status { + if !status.is_success() { + let mut guard = self.shared.lock().await; + guard.pending.remove(id); + } + } + } + + if let Some(resp) = got_direct { + return Ok(resp); + } + + if let Some(status) = last_status { + if !status.is_success() { + bail!("MCP server returned HTTP {}", status); + } + } else { + bail!("MCP request not sent"); + } + + let Some((_id, rx)) = rx else { + bail!("MCP server returned no response"); + }; + + rx.await.map_err(|_| anyhow!("SSE response channel closed")) + } + + async fn close(&mut self) -> Result<()> { + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + if let Some(task) = self.reader_task.take() { + task.abort(); + } + Ok(()) + } +} + +// ── Factory ────────────────────────────────────────────────────────────── + +/// Create a transport based on config. +pub fn create_transport(config: &McpServerConfig) -> Result> { + match config.transport { + McpTransport::Stdio => Ok(Box::new(StdioTransport::new(config)?)), + McpTransport::Http => Ok(Box::new(HttpTransport::new(config)?)), + McpTransport::Sse => Ok(Box::new(SseTransport::new(config)?)), + } +} + +// ── Tests ───────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_transport_default_is_stdio() { + let config = McpServerConfig::default(); + assert_eq!(config.transport, McpTransport::Stdio); + } + + #[test] + fn test_http_transport_requires_url() { + let config = McpServerConfig { + name: "test".into(), + transport: McpTransport::Http, + ..Default::default() + }; + assert!(HttpTransport::new(&config).is_err()); + } + + #[test] + fn test_sse_transport_requires_url() { + let config = McpServerConfig { + name: "test".into(), + transport: McpTransport::Sse, + ..Default::default() + }; + assert!(SseTransport::new(&config).is_err()); + } + + #[test] + fn test_extract_json_from_sse_data_no_space() { + let input = "data:{\"jsonrpc\":\"2.0\",\"result\":{}}\n\n"; + let extracted = extract_json_from_sse_text(input); + let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap(); + } + + #[test] + fn test_extract_json_from_sse_with_event_and_id() { + let input = "id: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n"; + let extracted = extract_json_from_sse_text(input); + let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap(); + } + + #[test] + fn test_extract_json_from_sse_multiline_data() { + let input = "event: message\ndata: {\ndata: \"jsonrpc\": \"2.0\",\ndata: \"result\": {}\ndata: }\n\n"; + let extracted = extract_json_from_sse_text(input); + let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap(); + } + + #[test] + fn test_extract_json_from_sse_skips_bom_and_leading_whitespace() { + let input = "\u{feff}\n\n data: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n"; + let extracted = extract_json_from_sse_text(input); + let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap(); + } + + #[test] + fn test_extract_json_from_sse_uses_last_event_with_data() { + let input = + ": keep-alive\n\nid: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n"; + let extracted = extract_json_from_sse_text(input); + let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap(); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 18e22aa53..ae6b623fa 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -40,6 +40,10 @@ pub mod hardware_memory_map; pub mod hardware_memory_read; pub mod http_request; pub mod image_info; +pub mod mcp_client; +pub mod mcp_protocol; +pub mod mcp_tool; +pub mod mcp_transport; pub mod memory_forget; pub mod memory_recall; pub mod memory_store; diff --git a/web/src/lib/uuid.ts b/web/src/lib/uuid.ts new file mode 100644 index 000000000..e45bc5da9 --- /dev/null +++ b/web/src/lib/uuid.ts @@ -0,0 +1,27 @@ +/** + * Generate a UUID v4 string. + * + * Uses `crypto.randomUUID()` when available (modern browsers, secure contexts) + * and falls back to a manual implementation backed by `crypto.getRandomValues()` + * for older browsers (e.g. Safari < 15.4, some Electron/Raspberry-Pi builds). + * + * Closes #3303, #3261. + */ +export function generateUUID(): string { + if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') { + return crypto.randomUUID(); + } + + // Fallback: RFC 4122 version 4 UUID via getRandomValues + // crypto must exist if we reached here (only randomUUID is missing) + const c = globalThis.crypto; + const bytes = new Uint8Array(16); + c.getRandomValues(bytes); + + // Set version (4) and variant (10xx) bits per RFC 4122 + bytes[6] = (bytes[6]! & 0x0f) | 0x40; + bytes[8] = (bytes[8]! & 0x3f) | 0x80; + + const hex = Array.from(bytes, (b) => b.toString(16).padStart(2, '0')).join(''); + return `${hex.slice(0, 8)}-${hex.slice(8, 12)}-${hex.slice(12, 16)}-${hex.slice(16, 20)}-${hex.slice(20)}`; +} diff --git a/web/src/lib/ws.ts b/web/src/lib/ws.ts index 4772a7e74..fdb7315d4 100644 --- a/web/src/lib/ws.ts +++ b/web/src/lib/ws.ts @@ -1,5 +1,6 @@ import type { WsMessage } from '../types/api'; import { getToken } from './auth'; +import { generateUUID } from './uuid'; export type WsMessageHandler = (msg: WsMessage) => void; export type WsOpenHandler = () => void; @@ -26,7 +27,7 @@ const SESSION_STORAGE_KEY = 'zeroclaw_session_id'; function getOrCreateSessionId(): string { let id = sessionStorage.getItem(SESSION_STORAGE_KEY); if (!id) { - id = crypto.randomUUID(); + id = generateUUID(); sessionStorage.setItem(SESSION_STORAGE_KEY, id); } return id; diff --git a/web/src/pages/AgentChat.tsx b/web/src/pages/AgentChat.tsx index 8311707e2..31e46fd0b 100644 --- a/web/src/pages/AgentChat.tsx +++ b/web/src/pages/AgentChat.tsx @@ -2,6 +2,7 @@ import { useState, useEffect, useRef, useCallback } from 'react'; import { Send, Bot, User, AlertCircle, Copy, Check } from 'lucide-react'; import type { WsMessage } from '@/types/api'; import { WebSocketClient } from '@/lib/ws'; +import { generateUUID } from '@/lib/uuid'; interface ChatMessage { id: string; @@ -53,7 +54,7 @@ export default function AgentChat() { setMessages((prev) => [ ...prev, { - id: crypto.randomUUID(), + id: generateUUID(), role: 'agent', content, timestamp: new Date(), @@ -69,7 +70,7 @@ export default function AgentChat() { setMessages((prev) => [ ...prev, { - id: crypto.randomUUID(), + id: generateUUID(), role: 'agent', content: `[Tool Call] ${msg.name ?? 'unknown'}(${JSON.stringify(msg.args ?? {})})`, timestamp: new Date(), @@ -81,7 +82,7 @@ export default function AgentChat() { setMessages((prev) => [ ...prev, { - id: crypto.randomUUID(), + id: generateUUID(), role: 'agent', content: `[Tool Result] ${msg.output ?? ''}`, timestamp: new Date(), @@ -93,7 +94,7 @@ export default function AgentChat() { setMessages((prev) => [ ...prev, { - id: crypto.randomUUID(), + id: generateUUID(), role: 'agent', content: `[Error] ${msg.message ?? 'Unknown error'}`, timestamp: new Date(), @@ -124,7 +125,7 @@ export default function AgentChat() { setMessages((prev) => [ ...prev, { - id: crypto.randomUUID(), + id: generateUUID(), role: 'user', content: trimmed, timestamp: new Date(),