- Demote session normal flow logs to debug\n- Skip session operations when CHANNEL_SESSION_CONFIG is uninitialized\n- Add spawn_blocking panic context for SQLite session manager\n- Fix fmt/clippy regressions (Box::pin large futures, cfg features, misc lints)
358 lines
12 KiB
Rust
358 lines
12 KiB
Rust
//! MCP (Model Context Protocol) client — connects to external tool servers.
|
|
//!
|
|
//! Supports multiple transports: stdio (spawn local process), HTTP, and SSE.
|
|
|
|
use std::collections::HashMap;
|
|
use std::sync::atomic::{AtomicU64, Ordering};
|
|
use std::sync::Arc;
|
|
|
|
use anyhow::{anyhow, bail, Context, Result};
|
|
use serde_json::json;
|
|
use tokio::sync::Mutex;
|
|
use tokio::time::{timeout, Duration};
|
|
|
|
use crate::config::schema::McpServerConfig;
|
|
use crate::tools::mcp_protocol::{
|
|
JsonRpcRequest, McpToolDef, McpToolsListResult, MCP_PROTOCOL_VERSION,
|
|
};
|
|
use crate::tools::mcp_transport::{create_transport, McpTransportConn};
|
|
|
|
/// Seconds to wait for each handshake response (`initialize`, `tools/list`)
/// while connecting, so a hung MCP server cannot block the daemon indefinitely.
const RECV_TIMEOUT_SECS: u64 = 30;

/// Tool-call timeout, in seconds, applied when a server's config does not
/// set `tool_timeout_secs`.
const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 180;

/// Hard safety ceiling, in seconds, for any tool-call timeout; configured
/// values above this are clamped down to it.
const MAX_TOOL_TIMEOUT_SECS: u64 = 600;
|
|
|
|
// ── Internal server state ──────────────────────────────────────────────────
|
|
|
|
struct McpServerInner {
|
|
config: McpServerConfig,
|
|
transport: Box<dyn McpTransportConn>,
|
|
next_id: AtomicU64,
|
|
tools: Vec<McpToolDef>,
|
|
}
|
|
|
|
// ── McpServer ──────────────────────────────────────────────────────────────
|
|
|
|
/// A live connection to one MCP server (any transport).
|
|
#[derive(Clone)]
|
|
pub struct McpServer {
|
|
inner: Arc<Mutex<McpServerInner>>,
|
|
}
|
|
|
|
impl McpServer {
|
|
/// Connect to the server, perform the initialize handshake, and fetch the tool list.
|
|
pub async fn connect(config: McpServerConfig) -> Result<Self> {
|
|
// Create transport based on config
|
|
let mut transport = create_transport(&config).with_context(|| {
|
|
format!(
|
|
"failed to create transport for MCP server `{}`",
|
|
config.name
|
|
)
|
|
})?;
|
|
|
|
// Initialize handshake
|
|
let id = 1u64;
|
|
let init_req = JsonRpcRequest::new(
|
|
id,
|
|
"initialize",
|
|
json!({
|
|
"protocolVersion": MCP_PROTOCOL_VERSION,
|
|
"capabilities": {},
|
|
"clientInfo": {
|
|
"name": "zeroclaw",
|
|
"version": env!("CARGO_PKG_VERSION")
|
|
}
|
|
}),
|
|
);
|
|
|
|
let init_resp = timeout(
|
|
Duration::from_secs(RECV_TIMEOUT_SECS),
|
|
transport.send_and_recv(&init_req),
|
|
)
|
|
.await
|
|
.with_context(|| {
|
|
format!(
|
|
"MCP server `{}` timed out after {}s waiting for initialize response",
|
|
config.name, RECV_TIMEOUT_SECS
|
|
)
|
|
})??;
|
|
|
|
if init_resp.error.is_some() {
|
|
bail!(
|
|
"MCP server `{}` rejected initialize: {:?}",
|
|
config.name,
|
|
init_resp.error
|
|
);
|
|
}
|
|
|
|
// Notify server that client is initialized (no response expected for notifications)
|
|
// For notifications, we send but don't wait for response
|
|
let notif = JsonRpcRequest::notification("notifications/initialized", json!({}));
|
|
// Best effort - ignore errors for notifications
|
|
let _ = transport.send_and_recv(¬if).await;
|
|
|
|
// Fetch available tools
|
|
let id = 2u64;
|
|
let list_req = JsonRpcRequest::new(id, "tools/list", json!({}));
|
|
|
|
let list_resp = timeout(
|
|
Duration::from_secs(RECV_TIMEOUT_SECS),
|
|
transport.send_and_recv(&list_req),
|
|
)
|
|
.await
|
|
.with_context(|| {
|
|
format!(
|
|
"MCP server `{}` timed out after {}s waiting for tools/list response",
|
|
config.name, RECV_TIMEOUT_SECS
|
|
)
|
|
})??;
|
|
|
|
let result = list_resp
|
|
.result
|
|
.ok_or_else(|| anyhow!("tools/list returned no result from `{}`", config.name))?;
|
|
let tool_list: McpToolsListResult = serde_json::from_value(result)
|
|
.with_context(|| format!("failed to parse tools/list from `{}`", config.name))?;
|
|
|
|
let tool_count = tool_list.tools.len();
|
|
|
|
let inner = McpServerInner {
|
|
config,
|
|
transport,
|
|
next_id: AtomicU64::new(3), // Start at 3 since we used 1 and 2
|
|
tools: tool_list.tools,
|
|
};
|
|
|
|
tracing::info!(
|
|
"MCP server `{}` connected — {} tool(s) available",
|
|
inner.config.name,
|
|
tool_count
|
|
);
|
|
|
|
Ok(Self {
|
|
inner: Arc::new(Mutex::new(inner)),
|
|
})
|
|
}
|
|
|
|
/// Tools advertised by this server.
|
|
pub async fn tools(&self) -> Vec<McpToolDef> {
|
|
self.inner.lock().await.tools.clone()
|
|
}
|
|
|
|
/// Server display name.
|
|
pub async fn name(&self) -> String {
|
|
self.inner.lock().await.config.name.clone()
|
|
}
|
|
|
|
/// Call a tool on this server. Returns the raw JSON result.
|
|
pub async fn call_tool(
|
|
&self,
|
|
tool_name: &str,
|
|
arguments: serde_json::Value,
|
|
) -> Result<serde_json::Value> {
|
|
let mut inner = self.inner.lock().await;
|
|
let id = inner.next_id.fetch_add(1, Ordering::Relaxed);
|
|
let req = JsonRpcRequest::new(
|
|
id,
|
|
"tools/call",
|
|
json!({ "name": tool_name, "arguments": arguments }),
|
|
);
|
|
|
|
// Use per-server tool timeout if configured, otherwise default.
|
|
// Cap at MAX_TOOL_TIMEOUT_SECS for safety.
|
|
let tool_timeout = inner
|
|
.config
|
|
.tool_timeout_secs
|
|
.unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS)
|
|
.min(MAX_TOOL_TIMEOUT_SECS);
|
|
|
|
let resp = timeout(
|
|
Duration::from_secs(tool_timeout),
|
|
inner.transport.send_and_recv(&req),
|
|
)
|
|
.await
|
|
.map_err(|_| {
|
|
anyhow!(
|
|
"MCP server `{}` timed out after {}s during tool call `{tool_name}`",
|
|
inner.config.name,
|
|
tool_timeout
|
|
)
|
|
})?
|
|
.with_context(|| {
|
|
format!(
|
|
"MCP server `{}` error during tool call `{tool_name}`",
|
|
inner.config.name
|
|
)
|
|
})?;
|
|
|
|
if let Some(err) = resp.error {
|
|
bail!("MCP tool `{tool_name}` error {}: {}", err.code, err.message);
|
|
}
|
|
Ok(resp.result.unwrap_or(serde_json::Value::Null))
|
|
}
|
|
}
|
|
|
|
// ── McpRegistry ───────────────────────────────────────────────────────────
|
|
|
|
/// Registry of all connected MCP servers, with a flat tool index.
|
|
pub struct McpRegistry {
|
|
servers: Vec<McpServer>,
|
|
/// prefixed_name → (server_index, original_tool_name)
|
|
tool_index: HashMap<String, (usize, String)>,
|
|
}
|
|
|
|
impl McpRegistry {
|
|
/// Connect to all configured servers. Non-fatal: failures are logged and skipped.
|
|
pub async fn connect_all(configs: &[McpServerConfig]) -> Result<Self> {
|
|
let mut servers = Vec::new();
|
|
let mut tool_index = HashMap::new();
|
|
|
|
for config in configs {
|
|
match McpServer::connect(config.clone()).await {
|
|
Ok(server) => {
|
|
let server_idx = servers.len();
|
|
// Collect tools while holding the lock once, then release
|
|
let tools = server.tools().await;
|
|
for tool in &tools {
|
|
// Prefix prevents name collisions across servers
|
|
let prefixed = format!("{}__{}", config.name, tool.name);
|
|
tool_index.insert(prefixed, (server_idx, tool.name.clone()));
|
|
}
|
|
servers.push(server);
|
|
}
|
|
// Non-fatal — log and continue with remaining servers
|
|
Err(e) => {
|
|
tracing::error!("Failed to connect to MCP server `{}`: {:#}", config.name, e);
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(Self {
|
|
servers,
|
|
tool_index,
|
|
})
|
|
}
|
|
|
|
/// All prefixed tool names across all connected servers.
|
|
pub fn tool_names(&self) -> Vec<String> {
|
|
self.tool_index.keys().cloned().collect()
|
|
}
|
|
|
|
/// Tool definition for a given prefixed name (cloned).
|
|
pub async fn get_tool_def(&self, prefixed_name: &str) -> Option<McpToolDef> {
|
|
let (server_idx, original_name) = self.tool_index.get(prefixed_name)?;
|
|
let inner = self.servers[*server_idx].inner.lock().await;
|
|
inner
|
|
.tools
|
|
.iter()
|
|
.find(|t| &t.name == original_name)
|
|
.cloned()
|
|
}
|
|
|
|
/// Execute a tool by prefixed name.
|
|
pub async fn call_tool(
|
|
&self,
|
|
prefixed_name: &str,
|
|
arguments: serde_json::Value,
|
|
) -> Result<String> {
|
|
let (server_idx, original_name) = self
|
|
.tool_index
|
|
.get(prefixed_name)
|
|
.ok_or_else(|| anyhow!("unknown MCP tool `{prefixed_name}`"))?;
|
|
let result = self.servers[*server_idx]
|
|
.call_tool(original_name, arguments)
|
|
.await?;
|
|
serde_json::to_string_pretty(&result)
|
|
.with_context(|| format!("failed to serialize result of MCP tool `{prefixed_name}`"))
|
|
}
|
|
|
|
pub fn is_empty(&self) -> bool {
|
|
self.servers.is_empty()
|
|
}
|
|
|
|
pub fn server_count(&self) -> usize {
|
|
self.servers.len()
|
|
}
|
|
|
|
pub fn tool_count(&self) -> usize {
|
|
self.tool_index.len()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::schema::McpTransport;

    /// Builds a stdio server config with the given name and command and
    /// every other field left empty / unset.
    fn stdio_config(name: &str, command: &str) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            command: command.to_string(),
            args: vec![],
            env: std::collections::HashMap::default(),
            tool_timeout_secs: None,
            transport: McpTransport::Stdio,
            url: None,
            headers: std::collections::HashMap::default(),
        }
    }

    /// Builds a config for a URL-based transport without setting a URL.
    fn url_less_config(transport: McpTransport) -> McpServerConfig {
        McpServerConfig {
            name: "test".into(),
            transport,
            ..Default::default()
        }
    }

    #[test]
    fn tool_name_prefix_format() {
        let prefixed = format!("{}__{}", "filesystem", "read_file");
        assert_eq!(prefixed, "filesystem__read_file");
    }

    #[tokio::test]
    async fn connect_nonexistent_command_fails_cleanly() {
        // Spawning a missing binary must surface as an error, not a panic.
        let config = stdio_config(
            "nonexistent",
            "/usr/bin/this_binary_does_not_exist_zeroclaw_test",
        );
        let result = McpServer::connect(config).await;
        assert!(result.is_err());
        let msg = result.err().unwrap().to_string();
        assert!(msg.contains("failed to create transport"), "got: {msg}");
    }

    #[tokio::test]
    async fn connect_all_nonfatal_on_single_failure() {
        // One bad server config must not fail connect_all; the registry is
        // simply left with zero servers.
        let configs = vec![stdio_config("bad", "/usr/bin/does_not_exist_zc_test")];
        let registry = McpRegistry::connect_all(&configs)
            .await
            .expect("connect_all should not fail");
        assert!(registry.is_empty());
        assert_eq!(registry.tool_count(), 0);
    }

    #[test]
    fn http_transport_requires_url() {
        let result = create_transport(&url_less_config(McpTransport::Http));
        assert!(result.is_err());
    }

    #[test]
    fn sse_transport_requires_url() {
        let result = create_transport(&url_less_config(McpTransport::Sse));
        assert!(result.is_err());
    }
}
|