zeroclaw/src/tools/mcp_client.rs
VirtualHotBar 0a42329ca5 fix: session leftovers after db47f56
- Demote session normal flow logs to debug\n- Skip session operations when CHANNEL_SESSION_CONFIG is uninitialized\n- Add spawn_blocking panic context for SQLite session manager\n- Fix fmt/clippy regressions (Box::pin large futures, cfg features, misc lints)
2026-02-28 15:23:42 +08:00

358 lines
12 KiB
Rust

//! MCP (Model Context Protocol) client — connects to external tool servers.
//!
//! Supports multiple transports: stdio (spawn local process), HTTP, and SSE.
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use anyhow::{anyhow, bail, Context, Result};
use serde_json::json;
use tokio::sync::Mutex;
use tokio::time::{timeout, Duration};
use crate::config::schema::McpServerConfig;
use crate::tools::mcp_protocol::{
JsonRpcRequest, McpToolDef, McpToolsListResult, MCP_PROTOCOL_VERSION,
};
use crate::tools::mcp_transport::{create_transport, McpTransportConn};
/// Timeout (seconds) for receiving a response from an MCP server during init/list.
/// Applied by `McpServer::connect` to each handshake round-trip it waits on.
/// Prevents a hung server from blocking the daemon indefinitely.
const RECV_TIMEOUT_SECS: u64 = 30;
/// Default timeout for tool calls (seconds) when not configured per-server,
/// i.e. when the server config's `tool_timeout_secs` is `None`.
const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 180;
/// Maximum allowed tool call timeout (seconds) — hard safety ceiling.
/// A larger configured value is clamped down to this in `McpServer::call_tool`.
const MAX_TOOL_TIMEOUT_SECS: u64 = 600;
// ── Internal server state ──────────────────────────────────────────────────
/// State owned by a single MCP server connection.
///
/// `McpServer` wraps this in an `Arc<Mutex<_>>`, so every field is only touched
/// while that lock is held.
struct McpServerInner {
    /// Configuration the connection was created from (name, tool timeout, …).
    config: McpServerConfig,
    /// Transport used to send JSON-RPC requests and receive their responses.
    transport: Box<dyn McpTransportConn>,
    /// Id handed to the next JSON-RPC request. `connect` uses ids 1 and 2 for the
    /// handshake, so the counter is initialised to 3.
    next_id: AtomicU64,
    /// Tool definitions parsed from the `tools/list` response at connect time.
    tools: Vec<McpToolDef>,
}
// ── McpServer ──────────────────────────────────────────────────────────────
/// A live connection to one MCP server (any transport).
///
/// Clones share the same underlying connection state: the handle is an `Arc`
/// around a `tokio::sync::Mutex`, and every method locks that mutex before
/// touching the state.
#[derive(Clone)]
pub struct McpServer {
    /// Shared, lock-guarded connection state (config, transport, id counter, tools).
    inner: Arc<Mutex<McpServerInner>>,
}
impl McpServer {
    /// Connect to the server, perform the initialize handshake, and fetch the tool list.
    ///
    /// The sequence is:
    /// 1. create the transport described by `config`;
    /// 2. send `initialize` (request id 1) and wait for the reply;
    /// 3. send the `notifications/initialized` notification (best effort);
    /// 4. send `tools/list` (request id 2) and parse the advertised tools.
    ///
    /// Every network/process round-trip is bounded by [`RECV_TIMEOUT_SECS`] so a
    /// hung server cannot stall the caller indefinitely.
    ///
    /// # Errors
    ///
    /// Returns an error if the transport cannot be created, if `initialize` or
    /// `tools/list` times out or fails at the transport level, if the server
    /// answers either request with a JSON-RPC error, or if the `tools/list`
    /// result is missing or cannot be parsed.
    pub async fn connect(config: McpServerConfig) -> Result<Self> {
        // Create transport based on config
        let mut transport = create_transport(&config).with_context(|| {
            format!(
                "failed to create transport for MCP server `{}`",
                config.name
            )
        })?;

        // Initialize handshake
        let id = 1u64;
        let init_req = JsonRpcRequest::new(
            id,
            "initialize",
            json!({
                "protocolVersion": MCP_PROTOCOL_VERSION,
                "capabilities": {},
                "clientInfo": {
                    "name": "zeroclaw",
                    "version": env!("CARGO_PKG_VERSION")
                }
            }),
        );

        // Outer `?` handles the timeout, inner `?` handles the transport error.
        let init_resp = timeout(
            Duration::from_secs(RECV_TIMEOUT_SECS),
            transport.send_and_recv(&init_req),
        )
        .await
        .with_context(|| {
            format!(
                "MCP server `{}` timed out after {}s waiting for initialize response",
                config.name, RECV_TIMEOUT_SECS
            )
        })??;

        if init_resp.error.is_some() {
            bail!(
                "MCP server `{}` rejected initialize: {:?}",
                config.name,
                init_resp.error
            );
        }

        // Notify the server that the client is initialized. Notifications carry no
        // reply, so this is best effort: a failure is logged and otherwise ignored.
        // The send is still bounded by a timeout so that a transport which blocks
        // here cannot hang `connect` forever.
        let notif = JsonRpcRequest::notification("notifications/initialized", json!({}));
        match timeout(
            Duration::from_secs(RECV_TIMEOUT_SECS),
            transport.send_and_recv(&notif),
        )
        .await
        {
            Ok(Ok(_)) => {}
            Ok(Err(e)) => {
                tracing::debug!(
                    "MCP server `{}`: initialized notification failed (ignored): {:#}",
                    config.name,
                    e
                );
            }
            Err(_) => {
                tracing::debug!(
                    "MCP server `{}`: initialized notification timed out after {}s (ignored)",
                    config.name,
                    RECV_TIMEOUT_SECS
                );
            }
        }

        // Fetch available tools
        let id = 2u64;
        let list_req = JsonRpcRequest::new(id, "tools/list", json!({}));
        let list_resp = timeout(
            Duration::from_secs(RECV_TIMEOUT_SECS),
            transport.send_and_recv(&list_req),
        )
        .await
        .with_context(|| {
            format!(
                "MCP server `{}` timed out after {}s waiting for tools/list response",
                config.name, RECV_TIMEOUT_SECS
            )
        })??;

        // Surface a server-side JSON-RPC error explicitly instead of letting it
        // degrade into the generic "returned no result" message below.
        if let Some(err) = list_resp.error {
            bail!(
                "MCP server `{}` returned error {} for tools/list: {}",
                config.name,
                err.code,
                err.message
            );
        }

        let result = list_resp
            .result
            .ok_or_else(|| anyhow!("tools/list returned no result from `{}`", config.name))?;
        let tool_list: McpToolsListResult = serde_json::from_value(result)
            .with_context(|| format!("failed to parse tools/list from `{}`", config.name))?;
        let tool_count = tool_list.tools.len();

        let inner = McpServerInner {
            config,
            transport,
            next_id: AtomicU64::new(3), // Start at 3 since we used 1 and 2
            tools: tool_list.tools,
        };

        tracing::info!(
            "MCP server `{}` connected — {} tool(s) available",
            inner.config.name,
            tool_count
        );

        Ok(Self {
            inner: Arc::new(Mutex::new(inner)),
        })
    }

    /// Tools advertised by this server (a cloned snapshot of the list fetched at connect).
    ///
    /// Acquires the connection lock, so it waits while a `call_tool` on the same
    /// server is in flight.
    pub async fn tools(&self) -> Vec<McpToolDef> {
        self.inner.lock().await.tools.clone()
    }

    /// Server display name, taken from the configuration.
    ///
    /// Acquires the connection lock, so it waits while a `call_tool` on the same
    /// server is in flight.
    pub async fn name(&self) -> String {
        self.inner.lock().await.config.name.clone()
    }

    /// Call a tool on this server. Returns the raw JSON result.
    ///
    /// The connection lock is held for the whole request/response round-trip, so
    /// calls to the same server are executed one at a time.
    ///
    /// The call is bounded by the per-server `tool_timeout_secs` (falling back to
    /// [`DEFAULT_TOOL_TIMEOUT_SECS`]) and never exceeds [`MAX_TOOL_TIMEOUT_SECS`].
    ///
    /// # Errors
    ///
    /// Returns an error if the call times out, if the transport fails, or if the
    /// server answers with a JSON-RPC error. A response without a `result` is
    /// mapped to `serde_json::Value::Null`.
    pub async fn call_tool(
        &self,
        tool_name: &str,
        arguments: serde_json::Value,
    ) -> Result<serde_json::Value> {
        let mut inner = self.inner.lock().await;
        // The counter is only touched while the mutex is held, so relaxed ordering suffices.
        let id = inner.next_id.fetch_add(1, Ordering::Relaxed);
        let req = JsonRpcRequest::new(
            id,
            "tools/call",
            json!({ "name": tool_name, "arguments": arguments }),
        );

        // Use per-server tool timeout if configured, otherwise default.
        // Cap at MAX_TOOL_TIMEOUT_SECS for safety.
        let tool_timeout = inner
            .config
            .tool_timeout_secs
            .unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS)
            .min(MAX_TOOL_TIMEOUT_SECS);

        // NOTE(review): after a timeout the request future is dropped; whether a late
        // reply can be mistaken for the next call's response depends on the transport
        // implementation — confirm it matches responses by id.
        let resp = timeout(
            Duration::from_secs(tool_timeout),
            inner.transport.send_and_recv(&req),
        )
        .await
        .map_err(|_| {
            anyhow!(
                "MCP server `{}` timed out after {}s during tool call `{tool_name}`",
                inner.config.name,
                tool_timeout
            )
        })?
        .with_context(|| {
            format!(
                "MCP server `{}` error during tool call `{tool_name}`",
                inner.config.name
            )
        })?;

        if let Some(err) = resp.error {
            bail!("MCP tool `{tool_name}` error {}: {}", err.code, err.message);
        }
        Ok(resp.result.unwrap_or(serde_json::Value::Null))
    }
}
// ── McpRegistry ───────────────────────────────────────────────────────────
/// Registry of all connected MCP servers, with a flat tool index.
///
/// Tools are addressed by a prefixed name of the form `<server_name>__<tool_name>`,
/// which `connect_all` builds while populating the index.
pub struct McpRegistry {
    /// Servers that connected successfully, in the order they were connected.
    servers: Vec<McpServer>,
    /// prefixed_name → (server_index, original_tool_name)
    /// `server_index` is a position in `servers`.
    tool_index: HashMap<String, (usize, String)>,
}
impl McpRegistry {
    /// Connect to all configured servers. Non-fatal: failures are logged and skipped.
    ///
    /// Each tool of a connected server is indexed under `<server_name>__<tool_name>`.
    /// If two tools end up with the same prefixed name (e.g. duplicate server
    /// names), the later registration replaces the earlier one and a warning is
    /// logged so the shadowing is visible.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; per-server connection failures are only logged.
    pub async fn connect_all(configs: &[McpServerConfig]) -> Result<Self> {
        let mut servers = Vec::with_capacity(configs.len());
        let mut tool_index = HashMap::new();

        for config in configs {
            match McpServer::connect(config.clone()).await {
                Ok(server) => {
                    let server_idx = servers.len();
                    // Snapshot the tool list (locks the server once, then releases).
                    let tools = server.tools().await;
                    for tool in &tools {
                        // Prefix prevents name collisions across servers
                        let prefixed = format!("{}__{}", config.name, tool.name);
                        let replaced = tool_index
                            .insert(prefixed, (server_idx, tool.name.clone()))
                            .is_some();
                        if replaced {
                            tracing::warn!(
                                "MCP tool `{}__{}` was registered more than once; \
                                 the later registration replaces the earlier one",
                                config.name,
                                tool.name
                            );
                        }
                    }
                    servers.push(server);
                }
                // Non-fatal — log and continue with remaining servers
                Err(e) => {
                    tracing::error!("Failed to connect to MCP server `{}`: {:#}", config.name, e);
                }
            }
        }

        Ok(Self {
            servers,
            tool_index,
        })
    }

    /// All prefixed tool names across all connected servers, sorted so the
    /// result is the same on every call (hash-map iteration order is not stable).
    pub fn tool_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.tool_index.keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Tool definition for a given prefixed name (cloned).
    ///
    /// Returns `None` when the prefixed name is not in the index, or when the
    /// owning server's tool list has no tool with the recorded original name.
    pub async fn get_tool_def(&self, prefixed_name: &str) -> Option<McpToolDef> {
        let (server_idx, original_name) = self.tool_index.get(prefixed_name)?;
        // Indices stored in `tool_index` are positions assigned in `connect_all`,
        // and `servers` is never modified afterwards.
        let inner = self.servers[*server_idx].inner.lock().await;
        inner
            .tools
            .iter()
            .find(|t| &t.name == original_name)
            .cloned()
    }

    /// Execute a tool by prefixed name.
    ///
    /// The call is forwarded to the owning server under the tool's original
    /// (unprefixed) name, and the JSON result is returned pretty-printed.
    ///
    /// # Errors
    ///
    /// Returns an error if the prefixed name is unknown, if the underlying
    /// `McpServer::call_tool` fails, or if the result cannot be serialized.
    pub async fn call_tool(
        &self,
        prefixed_name: &str,
        arguments: serde_json::Value,
    ) -> Result<String> {
        let (server_idx, original_name) = self
            .tool_index
            .get(prefixed_name)
            .ok_or_else(|| anyhow!("unknown MCP tool `{prefixed_name}`"))?;
        let result = self.servers[*server_idx]
            .call_tool(original_name, arguments)
            .await?;
        serde_json::to_string_pretty(&result)
            .with_context(|| format!("failed to serialize result of MCP tool `{prefixed_name}`"))
    }

    /// `true` when no server connected successfully.
    pub fn is_empty(&self) -> bool {
        self.servers.is_empty()
    }

    /// Number of successfully connected servers.
    pub fn server_count(&self) -> usize {
        self.servers.len()
    }

    /// Number of distinct prefixed tool names in the index.
    pub fn tool_count(&self) -> usize {
        self.tool_index.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::schema::McpTransport;

    /// Fully spelled-out stdio server config with the given name and command;
    /// every other field is empty / unset.
    fn stdio_config(name: &str, command: &str) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            command: command.to_string(),
            args: vec![],
            env: std::collections::HashMap::default(),
            tool_timeout_secs: None,
            transport: McpTransport::Stdio,
            url: None,
            headers: std::collections::HashMap::default(),
        }
    }

    /// Config that only names the server and selects a URL-based transport,
    /// leaving every other field at its default.
    fn url_less_config(transport: McpTransport) -> McpServerConfig {
        McpServerConfig {
            name: "test".into(),
            transport,
            ..Default::default()
        }
    }

    #[test]
    fn tool_name_prefix_format() {
        let server = "filesystem";
        let tool = "read_file";
        assert_eq!(format!("{}__{}", server, tool), "filesystem__read_file");
    }

    #[tokio::test]
    async fn connect_nonexistent_command_fails_cleanly() {
        // Spawning a missing binary must surface as an error, never as a panic.
        let config = stdio_config(
            "nonexistent",
            "/usr/bin/this_binary_does_not_exist_zeroclaw_test",
        );
        let msg = match McpServer::connect(config).await {
            Ok(_) => panic!("connecting to a nonexistent command should fail"),
            Err(e) => e.to_string(),
        };
        assert!(msg.contains("failed to create transport"), "got: {msg}");
    }

    #[tokio::test]
    async fn connect_all_nonfatal_on_single_failure() {
        // One bad server config must not fail the whole registry; it just ends up empty.
        let configs = vec![stdio_config("bad", "/usr/bin/does_not_exist_zc_test")];
        let registry = McpRegistry::connect_all(&configs)
            .await
            .expect("connect_all should not fail");
        assert!(registry.is_empty());
        assert_eq!(registry.tool_count(), 0);
    }

    #[test]
    fn http_transport_requires_url() {
        let config = url_less_config(McpTransport::Http);
        assert!(create_transport(&config).is_err());
    }

    #[test]
    fn sse_transport_requires_url() {
        let config = url_less_config(McpTransport::Sse);
        assert!(create_transport(&config).is_err());
    }
}