Merge pull request #2013 from zeroclaw-labs/issue-1380-mcp-main
feat(mcp): add external MCP server support on main
This commit is contained in:
commit
4fa8206332
22
Cargo.lock
generated
22
Cargo.lock
generated
@ -851,9 +851,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.44"
|
||||
version = "0.4.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0"
|
||||
checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118"
|
||||
dependencies = [
|
||||
"iana-time-zone",
|
||||
"js-sys",
|
||||
@ -3212,6 +3212,12 @@ dependencies = [
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.12.1"
|
||||
@ -4103,7 +4109,7 @@ dependencies = [
|
||||
"core-foundation-sys",
|
||||
"futures-core",
|
||||
"io-kit-sys 0.5.0",
|
||||
"linux-raw-sys",
|
||||
"linux-raw-sys 0.11.0",
|
||||
"log",
|
||||
"once_cell",
|
||||
"rustix",
|
||||
@ -5601,15 +5607,15 @@ dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"linux-raw-sys 0.12.1",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.23.37"
|
||||
version = "0.23.36"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4"
|
||||
checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"log",
|
||||
@ -6101,9 +6107,9 @@ checksum = "dc6fe69c597f9c37bfeeeeeb33da3530379845f10be461a66d16d03eca2ded77"
|
||||
|
||||
[[package]]
|
||||
name = "shellexpand"
|
||||
version = "3.1.2"
|
||||
version = "3.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32824fab5e16e6c4d86dc1ba84489390419a39f97699852b66480bb87d297ed8"
|
||||
checksum = "8b1fdf65dd6331831494dd616b30351c38e96e45921a27745cf98490458b90bb"
|
||||
dependencies = [
|
||||
"dirs",
|
||||
]
|
||||
|
||||
@ -4642,7 +4642,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
};
|
||||
// Build system prompt from workspace identity files + skills
|
||||
let workspace = config.workspace_dir.clone();
|
||||
let tools_registry = Arc::new(tools::all_tools_with_runtime(
|
||||
let mut built_tools = tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
@ -4656,7 +4656,44 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
));
|
||||
);
|
||||
|
||||
// Wire MCP tools into the registry before freezing — non-fatal.
|
||||
if config.mcp.enabled && !config.mcp.servers.is_empty() {
|
||||
tracing::info!(
|
||||
"Initializing MCP client — {} server(s) configured",
|
||||
config.mcp.servers.len()
|
||||
);
|
||||
match crate::tools::McpRegistry::connect_all(&config.mcp.servers).await {
|
||||
Ok(registry) => {
|
||||
let registry = std::sync::Arc::new(registry);
|
||||
let names = registry.tool_names();
|
||||
let mut registered = 0usize;
|
||||
for name in names {
|
||||
if let Some(def) = registry.get_tool_def(&name).await {
|
||||
let wrapper = crate::tools::McpToolWrapper::new(
|
||||
name,
|
||||
def,
|
||||
std::sync::Arc::clone(®istry),
|
||||
);
|
||||
built_tools.push(Box::new(wrapper));
|
||||
registered += 1;
|
||||
}
|
||||
}
|
||||
tracing::info!(
|
||||
"MCP: {} tool(s) registered from {} server(s)",
|
||||
registered,
|
||||
registry.server_count()
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
// Non-fatal — daemon continues with the tools registered above.
|
||||
tracing::error!("MCP registry failed to initialize: {e:#}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let tools_registry = Arc::new(built_tools);
|
||||
|
||||
let skills = crate::skills::load_skills_with_config(&workspace, &config);
|
||||
|
||||
|
||||
@ -269,6 +269,10 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub agents_ipc: AgentsIpcConfig,
|
||||
|
||||
/// External MCP server connections (`[mcp]`).
|
||||
#[serde(default, alias = "mcpServers")]
|
||||
pub mcp: McpConfig,
|
||||
|
||||
/// Vision support override for the active provider/model.
|
||||
/// - `None` (default): use provider's built-in default
|
||||
/// - `Some(true)`: force vision support on (e.g. Ollama running llava)
|
||||
@ -530,6 +534,60 @@ impl Default for TranscriptionConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ── MCP ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Transport type for MCP server connections.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum McpTransport {
|
||||
/// Spawn a local process and communicate over stdin/stdout.
|
||||
#[default]
|
||||
Stdio,
|
||||
/// Connect via HTTP POST.
|
||||
Http,
|
||||
/// Connect via HTTP + Server-Sent Events.
|
||||
Sse,
|
||||
}
|
||||
|
||||
/// Configuration for a single external MCP server.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
|
||||
pub struct McpServerConfig {
|
||||
/// Display name used as a tool prefix (`<server>__<tool>`).
|
||||
pub name: String,
|
||||
/// Transport type (default: stdio).
|
||||
#[serde(default)]
|
||||
pub transport: McpTransport,
|
||||
/// URL for HTTP/SSE transports.
|
||||
#[serde(default)]
|
||||
pub url: Option<String>,
|
||||
/// Executable to spawn for stdio transport.
|
||||
#[serde(default)]
|
||||
pub command: String,
|
||||
/// Command arguments for stdio transport.
|
||||
#[serde(default)]
|
||||
pub args: Vec<String>,
|
||||
/// Optional environment variables for stdio transport.
|
||||
#[serde(default)]
|
||||
pub env: HashMap<String, String>,
|
||||
/// Optional HTTP headers for HTTP/SSE transports.
|
||||
#[serde(default)]
|
||||
pub headers: HashMap<String, String>,
|
||||
/// Optional per-call timeout in seconds (hard capped in validation).
|
||||
#[serde(default)]
|
||||
pub tool_timeout_secs: Option<u64>,
|
||||
}
|
||||
|
||||
/// External MCP client configuration (`[mcp]` section).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
|
||||
pub struct McpConfig {
|
||||
/// Enable MCP tool loading.
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// Configured MCP servers.
|
||||
#[serde(default, alias = "mcpServers")]
|
||||
pub servers: Vec<McpServerConfig>,
|
||||
}
|
||||
|
||||
// ── Agents IPC ──────────────────────────────────────────────────
|
||||
|
||||
fn default_agents_ipc_db_path() -> String {
|
||||
@ -4929,6 +4987,7 @@ impl Default for Config {
|
||||
query_classification: QueryClassificationConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
agents_ipc: AgentsIpcConfig::default(),
|
||||
mcp: McpConfig::default(),
|
||||
model_support_vision: None,
|
||||
wasm: WasmConfig::default(),
|
||||
}
|
||||
@ -5661,6 +5720,65 @@ fn read_codex_openai_api_key() -> Option<String> {
|
||||
.map(ToString::to_string)
|
||||
}
|
||||
|
||||
const MCP_MAX_TOOL_TIMEOUT_SECS: u64 = 600;
|
||||
|
||||
fn validate_mcp_config(config: &McpConfig) -> Result<()> {
|
||||
let mut seen_names = std::collections::HashSet::new();
|
||||
for (i, server) in config.servers.iter().enumerate() {
|
||||
let name = server.name.trim();
|
||||
if name.is_empty() {
|
||||
anyhow::bail!("mcp.servers[{i}].name must not be empty");
|
||||
}
|
||||
if !seen_names.insert(name.to_ascii_lowercase()) {
|
||||
anyhow::bail!("mcp.servers contains duplicate name: {name}");
|
||||
}
|
||||
|
||||
if let Some(timeout) = server.tool_timeout_secs {
|
||||
if timeout == 0 {
|
||||
anyhow::bail!("mcp.servers[{i}].tool_timeout_secs must be greater than 0");
|
||||
}
|
||||
if timeout > MCP_MAX_TOOL_TIMEOUT_SECS {
|
||||
anyhow::bail!(
|
||||
"mcp.servers[{i}].tool_timeout_secs exceeds max {MCP_MAX_TOOL_TIMEOUT_SECS}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
match server.transport {
|
||||
McpTransport::Stdio => {
|
||||
if server.command.trim().is_empty() {
|
||||
anyhow::bail!(
|
||||
"mcp.servers[{i}] with transport=stdio requires non-empty command"
|
||||
);
|
||||
}
|
||||
}
|
||||
McpTransport::Http | McpTransport::Sse => {
|
||||
let url = server
|
||||
.url
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"mcp.servers[{i}] with transport={} requires url",
|
||||
match server.transport {
|
||||
McpTransport::Http => "http",
|
||||
McpTransport::Sse => "sse",
|
||||
McpTransport::Stdio => "stdio",
|
||||
}
|
||||
)
|
||||
})?;
|
||||
let parsed = reqwest::Url::parse(url)
|
||||
.with_context(|| format!("mcp.servers[{i}].url is not a valid URL"))?;
|
||||
if !matches!(parsed.scheme(), "http" | "https") {
|
||||
anyhow::bail!("mcp.servers[{i}].url must use http/https");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub async fn load_or_init() -> Result<Self> {
|
||||
let (default_zeroclaw_dir, default_workspace_dir) = default_config_and_workspace_dirs()?;
|
||||
@ -6316,6 +6434,11 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
// MCP
|
||||
if self.mcp.enabled {
|
||||
validate_mcp_config(&self.mcp)?;
|
||||
}
|
||||
|
||||
// Proxy (delegate to existing validation)
|
||||
self.proxy.validate()?;
|
||||
|
||||
@ -7397,6 +7520,7 @@ default_temperature = 0.7
|
||||
hardware: HardwareConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
agents_ipc: AgentsIpcConfig::default(),
|
||||
mcp: McpConfig::default(),
|
||||
model_support_vision: None,
|
||||
wasm: WasmConfig::default(),
|
||||
};
|
||||
@ -7767,6 +7891,7 @@ tool_dispatcher = "xml"
|
||||
hardware: HardwareConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
agents_ipc: AgentsIpcConfig::default(),
|
||||
mcp: McpConfig::default(),
|
||||
model_support_vision: None,
|
||||
wasm: WasmConfig::default(),
|
||||
};
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
use crate::config::Config;
|
||||
use anyhow::Result;
|
||||
use anyhow::{bail, Result};
|
||||
use chrono::Utc;
|
||||
use std::future::Future;
|
||||
use std::path::PathBuf;
|
||||
@ -9,6 +9,22 @@ use tokio::time::Duration;
|
||||
const STATUS_FLUSH_SECONDS: u64 = 5;
|
||||
|
||||
pub async fn run(config: Config, host: String, port: u16) -> Result<()> {
|
||||
// Pre-flight: check if port is already in use by another zeroclaw daemon
|
||||
if let Err(_e) = check_port_available(&host, port).await {
|
||||
// Port is in use - check if it's our daemon
|
||||
if is_zeroclaw_daemon_running(&host, port).await {
|
||||
tracing::info!("ZeroClaw daemon already running on {host}:{port}");
|
||||
println!("✓ ZeroClaw daemon already running on http://{host}:{port}");
|
||||
println!(" Use 'zeroclaw restart' to restart, or 'zeroclaw status' to check health.");
|
||||
return Ok(());
|
||||
}
|
||||
// Something else is using the port
|
||||
bail!(
|
||||
"Port {port} is already in use by another process. \
|
||||
Run 'lsof -i :{port}' to identify it, or use a different port."
|
||||
);
|
||||
}
|
||||
|
||||
let initial_backoff = config.reliability.channel_initial_backoff_secs.max(1);
|
||||
let max_backoff = config
|
||||
.reliability
|
||||
@ -326,6 +342,49 @@ fn has_supervised_channels(config: &Config) -> bool {
|
||||
.any(|(_, ok)| *ok)
|
||||
}
|
||||
|
||||
/// Check if a port is available for binding
|
||||
async fn check_port_available(host: &str, port: u16) -> Result<()> {
|
||||
let addr: std::net::SocketAddr = format!("{host}:{port}").parse()?;
|
||||
match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(listener) => {
|
||||
// Successfully bound - close it and return Ok
|
||||
drop(listener);
|
||||
Ok(())
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => {
|
||||
bail!("Port {} is already in use", port)
|
||||
}
|
||||
Err(e) => bail!("Failed to check port {}: {}", port, e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a running daemon on this port is our zeroclaw daemon
|
||||
async fn is_zeroclaw_daemon_running(host: &str, port: u16) -> bool {
|
||||
let url = format!("http://{}:{}/health", host, port);
|
||||
match reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(2))
|
||||
.build()
|
||||
{
|
||||
Ok(client) => match client.get(&url).send().await {
|
||||
Ok(resp) => {
|
||||
if resp.status().is_success() {
|
||||
// Check if response looks like our health endpoint
|
||||
if let Ok(json) = resp.json::<serde_json::Value>().await {
|
||||
// Our health endpoint has "status" and "runtime.components"
|
||||
json.get("status").is_some() && json.get("runtime").is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
Err(_) => false,
|
||||
},
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@ -182,6 +182,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
transcription: crate::config::TranscriptionConfig::default(),
|
||||
agents_ipc: crate::config::AgentsIpcConfig::default(),
|
||||
mcp: crate::config::schema::McpConfig::default(),
|
||||
model_support_vision: None,
|
||||
wasm: crate::config::WasmConfig::default(),
|
||||
};
|
||||
@ -542,6 +543,7 @@ async fn run_quick_setup_with_home(
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
transcription: crate::config::TranscriptionConfig::default(),
|
||||
agents_ipc: crate::config::AgentsIpcConfig::default(),
|
||||
mcp: crate::config::schema::McpConfig::default(),
|
||||
model_support_vision: None,
|
||||
wasm: crate::config::WasmConfig::default(),
|
||||
};
|
||||
|
||||
357
src/tools/mcp_client.rs
Normal file
357
src/tools/mcp_client.rs
Normal file
@ -0,0 +1,357 @@
|
||||
//! MCP (Model Context Protocol) client — connects to external tool servers.
|
||||
//!
|
||||
//! Supports multiple transports: stdio (spawn local process), HTTP, and SSE.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
use crate::config::schema::McpServerConfig;
|
||||
use crate::tools::mcp_protocol::{
|
||||
JsonRpcRequest, McpToolDef, McpToolsListResult, MCP_PROTOCOL_VERSION,
|
||||
};
|
||||
use crate::tools::mcp_transport::{create_transport, McpTransportConn};
|
||||
|
||||
/// Timeout for receiving a response from an MCP server during init/list.
|
||||
/// Prevents a hung server from blocking the daemon indefinitely.
|
||||
const RECV_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
/// Default timeout for tool calls (seconds) when not configured per-server.
|
||||
const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 180;
|
||||
|
||||
/// Maximum allowed tool call timeout (seconds) — hard safety ceiling.
|
||||
const MAX_TOOL_TIMEOUT_SECS: u64 = 600;
|
||||
|
||||
// ── Internal server state ──────────────────────────────────────────────────
|
||||
|
||||
struct McpServerInner {
|
||||
config: McpServerConfig,
|
||||
transport: Box<dyn McpTransportConn>,
|
||||
next_id: AtomicU64,
|
||||
tools: Vec<McpToolDef>,
|
||||
}
|
||||
|
||||
// ── McpServer ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// A live connection to one MCP server (any transport).
|
||||
#[derive(Clone)]
|
||||
pub struct McpServer {
|
||||
inner: Arc<Mutex<McpServerInner>>,
|
||||
}
|
||||
|
||||
impl McpServer {
|
||||
/// Connect to the server, perform the initialize handshake, and fetch the tool list.
|
||||
pub async fn connect(config: McpServerConfig) -> Result<Self> {
|
||||
// Create transport based on config
|
||||
let mut transport = create_transport(&config).with_context(|| {
|
||||
format!(
|
||||
"failed to create transport for MCP server `{}`",
|
||||
config.name
|
||||
)
|
||||
})?;
|
||||
|
||||
// Initialize handshake
|
||||
let id = 1u64;
|
||||
let init_req = JsonRpcRequest::new(
|
||||
id,
|
||||
"initialize",
|
||||
json!({
|
||||
"protocolVersion": MCP_PROTOCOL_VERSION,
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "zeroclaw",
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let init_resp = timeout(
|
||||
Duration::from_secs(RECV_TIMEOUT_SECS),
|
||||
transport.send_and_recv(&init_req),
|
||||
)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"MCP server `{}` timed out after {}s waiting for initialize response",
|
||||
config.name, RECV_TIMEOUT_SECS
|
||||
)
|
||||
})??;
|
||||
|
||||
if init_resp.error.is_some() {
|
||||
bail!(
|
||||
"MCP server `{}` rejected initialize: {:?}",
|
||||
config.name,
|
||||
init_resp.error
|
||||
);
|
||||
}
|
||||
|
||||
// Notify server that client is initialized (no response expected for notifications)
|
||||
// For notifications, we send but don't wait for response
|
||||
let notif = JsonRpcRequest::notification("notifications/initialized", json!({}));
|
||||
// Best effort - ignore errors for notifications
|
||||
let _ = transport.send_and_recv(¬if).await;
|
||||
|
||||
// Fetch available tools
|
||||
let id = 2u64;
|
||||
let list_req = JsonRpcRequest::new(id, "tools/list", json!({}));
|
||||
|
||||
let list_resp = timeout(
|
||||
Duration::from_secs(RECV_TIMEOUT_SECS),
|
||||
transport.send_and_recv(&list_req),
|
||||
)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"MCP server `{}` timed out after {}s waiting for tools/list response",
|
||||
config.name, RECV_TIMEOUT_SECS
|
||||
)
|
||||
})??;
|
||||
|
||||
let result = list_resp
|
||||
.result
|
||||
.ok_or_else(|| anyhow!("tools/list returned no result from `{}`", config.name))?;
|
||||
let tool_list: McpToolsListResult = serde_json::from_value(result)
|
||||
.with_context(|| format!("failed to parse tools/list from `{}`", config.name))?;
|
||||
|
||||
let tool_count = tool_list.tools.len();
|
||||
|
||||
let inner = McpServerInner {
|
||||
config,
|
||||
transport,
|
||||
next_id: AtomicU64::new(3), // Start at 3 since we used 1 and 2
|
||||
tools: tool_list.tools,
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
"MCP server `{}` connected — {} tool(s) available",
|
||||
inner.config.name,
|
||||
tool_count
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
inner: Arc::new(Mutex::new(inner)),
|
||||
})
|
||||
}
|
||||
|
||||
/// Tools advertised by this server.
|
||||
pub async fn tools(&self) -> Vec<McpToolDef> {
|
||||
self.inner.lock().await.tools.clone()
|
||||
}
|
||||
|
||||
/// Server display name.
|
||||
pub async fn name(&self) -> String {
|
||||
self.inner.lock().await.config.name.clone()
|
||||
}
|
||||
|
||||
/// Call a tool on this server. Returns the raw JSON result.
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: serde_json::Value,
|
||||
) -> Result<serde_json::Value> {
|
||||
let mut inner = self.inner.lock().await;
|
||||
let id = inner.next_id.fetch_add(1, Ordering::Relaxed);
|
||||
let req = JsonRpcRequest::new(
|
||||
id,
|
||||
"tools/call",
|
||||
json!({ "name": tool_name, "arguments": arguments }),
|
||||
);
|
||||
|
||||
// Use per-server tool timeout if configured, otherwise default.
|
||||
// Cap at MAX_TOOL_TIMEOUT_SECS for safety.
|
||||
let tool_timeout = inner
|
||||
.config
|
||||
.tool_timeout_secs
|
||||
.unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS)
|
||||
.min(MAX_TOOL_TIMEOUT_SECS);
|
||||
|
||||
let resp = timeout(
|
||||
Duration::from_secs(tool_timeout),
|
||||
inner.transport.send_and_recv(&req),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
anyhow!(
|
||||
"MCP server `{}` timed out after {}s during tool call `{tool_name}`",
|
||||
inner.config.name,
|
||||
tool_timeout
|
||||
)
|
||||
})?
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"MCP server `{}` error during tool call `{tool_name}`",
|
||||
inner.config.name
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(err) = resp.error {
|
||||
bail!("MCP tool `{tool_name}` error {}: {}", err.code, err.message);
|
||||
}
|
||||
Ok(resp.result.unwrap_or(serde_json::Value::Null))
|
||||
}
|
||||
}
|
||||
|
||||
// ── McpRegistry ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Registry of all connected MCP servers, with a flat tool index.
|
||||
pub struct McpRegistry {
|
||||
servers: Vec<McpServer>,
|
||||
/// prefixed_name → (server_index, original_tool_name)
|
||||
tool_index: HashMap<String, (usize, String)>,
|
||||
}
|
||||
|
||||
impl McpRegistry {
|
||||
/// Connect to all configured servers. Non-fatal: failures are logged and skipped.
|
||||
pub async fn connect_all(configs: &[McpServerConfig]) -> Result<Self> {
|
||||
let mut servers = Vec::new();
|
||||
let mut tool_index = HashMap::new();
|
||||
|
||||
for config in configs {
|
||||
match McpServer::connect(config.clone()).await {
|
||||
Ok(server) => {
|
||||
let server_idx = servers.len();
|
||||
// Collect tools while holding the lock once, then release
|
||||
let tools = server.tools().await;
|
||||
for tool in &tools {
|
||||
// Prefix prevents name collisions across servers
|
||||
let prefixed = format!("{}__{}", config.name, tool.name);
|
||||
tool_index.insert(prefixed, (server_idx, tool.name.clone()));
|
||||
}
|
||||
servers.push(server);
|
||||
}
|
||||
// Non-fatal — log and continue with remaining servers
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to connect to MCP server `{}`: {:#}", config.name, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
servers,
|
||||
tool_index,
|
||||
})
|
||||
}
|
||||
|
||||
/// All prefixed tool names across all connected servers.
|
||||
pub fn tool_names(&self) -> Vec<String> {
|
||||
self.tool_index.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Tool definition for a given prefixed name (cloned).
|
||||
pub async fn get_tool_def(&self, prefixed_name: &str) -> Option<McpToolDef> {
|
||||
let (server_idx, original_name) = self.tool_index.get(prefixed_name)?;
|
||||
let inner = self.servers[*server_idx].inner.lock().await;
|
||||
inner
|
||||
.tools
|
||||
.iter()
|
||||
.find(|t| &t.name == original_name)
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Execute a tool by prefixed name.
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
prefixed_name: &str,
|
||||
arguments: serde_json::Value,
|
||||
) -> Result<String> {
|
||||
let (server_idx, original_name) = self
|
||||
.tool_index
|
||||
.get(prefixed_name)
|
||||
.ok_or_else(|| anyhow!("unknown MCP tool `{prefixed_name}`"))?;
|
||||
let result = self.servers[*server_idx]
|
||||
.call_tool(original_name, arguments)
|
||||
.await?;
|
||||
serde_json::to_string_pretty(&result)
|
||||
.with_context(|| format!("failed to serialize result of MCP tool `{prefixed_name}`"))
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.servers.is_empty()
|
||||
}
|
||||
|
||||
pub fn server_count(&self) -> usize {
|
||||
self.servers.len()
|
||||
}
|
||||
|
||||
pub fn tool_count(&self) -> usize {
|
||||
self.tool_index.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::schema::McpTransport;
|
||||
|
||||
#[test]
|
||||
fn tool_name_prefix_format() {
|
||||
let prefixed = format!("{}__{}", "filesystem", "read_file");
|
||||
assert_eq!(prefixed, "filesystem__read_file");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_nonexistent_command_fails_cleanly() {
|
||||
// A command that doesn't exist should fail at spawn, not panic.
|
||||
let config = McpServerConfig {
|
||||
name: "nonexistent".to_string(),
|
||||
command: "/usr/bin/this_binary_does_not_exist_zeroclaw_test".to_string(),
|
||||
args: vec![],
|
||||
env: Default::default(),
|
||||
tool_timeout_secs: None,
|
||||
transport: McpTransport::Stdio,
|
||||
url: None,
|
||||
headers: Default::default(),
|
||||
};
|
||||
let result = McpServer::connect(config).await;
|
||||
assert!(result.is_err());
|
||||
let msg = result.err().unwrap().to_string();
|
||||
assert!(msg.contains("failed to create transport"), "got: {msg}");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_all_nonfatal_on_single_failure() {
|
||||
// If one server config is bad, connect_all should succeed (with 0 servers).
|
||||
let configs = vec![McpServerConfig {
|
||||
name: "bad".to_string(),
|
||||
command: "/usr/bin/does_not_exist_zc_test".to_string(),
|
||||
args: vec![],
|
||||
env: Default::default(),
|
||||
tool_timeout_secs: None,
|
||||
transport: McpTransport::Stdio,
|
||||
url: None,
|
||||
headers: Default::default(),
|
||||
}];
|
||||
let registry = McpRegistry::connect_all(&configs)
|
||||
.await
|
||||
.expect("connect_all should not fail");
|
||||
assert!(registry.is_empty());
|
||||
assert_eq!(registry.tool_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn http_transport_requires_url() {
|
||||
let config = McpServerConfig {
|
||||
name: "test".into(),
|
||||
transport: McpTransport::Http,
|
||||
..Default::default()
|
||||
};
|
||||
let result = create_transport(&config);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sse_transport_requires_url() {
|
||||
let config = McpServerConfig {
|
||||
name: "test".into(),
|
||||
transport: McpTransport::Sse,
|
||||
..Default::default()
|
||||
};
|
||||
let result = create_transport(&config);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
126
src/tools/mcp_protocol.rs
Normal file
126
src/tools/mcp_protocol.rs
Normal file
@ -0,0 +1,126 @@
|
||||
//! MCP (Model Context Protocol) JSON-RPC 2.0 protocol types.
|
||||
//! Protocol version: 2024-11-05
|
||||
//! Adapted from ops-mcp-server/src/protocol.rs for client use.
|
||||
//! Both Serialize and Deserialize are derived — the client both sends (Serialize)
|
||||
//! and receives (Deserialize) JSON-RPC messages.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const JSONRPC_VERSION: &str = "2.0";
|
||||
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
|
||||
|
||||
// Standard JSON-RPC 2.0 error codes
|
||||
pub const PARSE_ERROR: i32 = -32700;
|
||||
pub const INVALID_REQUEST: i32 = -32600;
|
||||
pub const METHOD_NOT_FOUND: i32 = -32601;
|
||||
pub const INVALID_PARAMS: i32 = -32602;
|
||||
pub const INTERNAL_ERROR: i32 = -32603;
|
||||
|
||||
/// Outbound JSON-RPC request (client → MCP server).
|
||||
/// Used for both method calls (with id) and notifications (id = None).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcRequest {
|
||||
pub jsonrpc: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<serde_json::Value>,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl JsonRpcRequest {
|
||||
/// Create a method call request with a numeric id.
|
||||
pub fn new(id: u64, method: impl Into<String>, params: serde_json::Value) -> Self {
|
||||
Self {
|
||||
jsonrpc: JSONRPC_VERSION.to_string(),
|
||||
id: Some(serde_json::Value::Number(id.into())),
|
||||
method: method.into(),
|
||||
params: Some(params),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a notification — no id, no response expected from server.
|
||||
pub fn notification(method: impl Into<String>, params: serde_json::Value) -> Self {
|
||||
Self {
|
||||
jsonrpc: JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
method: method.into(),
|
||||
params: Some(params),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Inbound JSON-RPC response (MCP server → client).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcResponse {
|
||||
pub jsonrpc: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<JsonRpcError>,
|
||||
}
|
||||
|
||||
/// JSON-RPC error object embedded in a response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcError {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// A tool advertised by an MCP server (from `tools/list` response).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpToolDef {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(rename = "inputSchema")]
|
||||
pub input_schema: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Expected shape of the `tools/list` result payload.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct McpToolsListResult {
|
||||
pub tools: Vec<McpToolDef>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn request_serializes_with_id() {
|
||||
let req = JsonRpcRequest::new(1, "tools/list", serde_json::json!({}));
|
||||
let s = serde_json::to_string(&req).unwrap();
|
||||
assert!(s.contains("\"id\":1"));
|
||||
assert!(s.contains("\"method\":\"tools/list\""));
|
||||
assert!(s.contains("\"jsonrpc\":\"2.0\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn notification_omits_id() {
|
||||
let notif =
|
||||
JsonRpcRequest::notification("notifications/initialized", serde_json::json!({}));
|
||||
let s = serde_json::to_string(¬if).unwrap();
|
||||
assert!(!s.contains("\"id\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes() {
|
||||
let json = r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#;
|
||||
let resp: JsonRpcResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.result.is_some());
|
||||
assert!(resp.error.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_def_deserializes_input_schema() {
|
||||
let json = r#"{"name":"read_file","description":"Read a file","inputSchema":{"type":"object","properties":{"path":{"type":"string"}}}}"#;
|
||||
let def: McpToolDef = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(def.name, "read_file");
|
||||
assert!(def.input_schema.is_object());
|
||||
}
|
||||
}
|
||||
68
src/tools/mcp_tool.rs
Normal file
68
src/tools/mcp_tool.rs
Normal file
@ -0,0 +1,68 @@
|
||||
//! Wraps a discovered MCP tool as a zeroclaw [`Tool`] so it is dispatched
|
||||
//! through the existing tool registry and agent loop without modification.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::tools::mcp_client::McpRegistry;
|
||||
use crate::tools::mcp_protocol::McpToolDef;
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
/// A zeroclaw [`Tool`] backed by an MCP server tool.
|
||||
///
|
||||
/// The `prefixed_name` (e.g. `filesystem__read_file`) is what the agent loop
|
||||
/// sees. The registry knows how to route it to the correct server.
|
||||
pub struct McpToolWrapper {
|
||||
/// Prefixed name: `<server_name>__<tool_name>`.
|
||||
prefixed_name: String,
|
||||
/// Description extracted from the MCP tool definition. Stored as an owned
|
||||
/// String so that `description()` can return `&str` with self's lifetime.
|
||||
description: String,
|
||||
/// JSON schema for the tool's input parameters.
|
||||
input_schema: serde_json::Value,
|
||||
/// Shared registry — used to dispatch actual tool calls.
|
||||
registry: Arc<McpRegistry>,
|
||||
}
|
||||
|
||||
impl McpToolWrapper {
|
||||
pub fn new(prefixed_name: String, def: McpToolDef, registry: Arc<McpRegistry>) -> Self {
|
||||
let description = def.description.unwrap_or_else(|| "MCP tool".to_string());
|
||||
Self {
|
||||
prefixed_name,
|
||||
description,
|
||||
input_schema: def.input_schema,
|
||||
registry,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for McpToolWrapper {
|
||||
fn name(&self) -> &str {
|
||||
&self.prefixed_name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
&self.description
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
self.input_schema.clone()
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
match self.registry.call_tool(&self.prefixed_name, args).await {
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e.to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
285
src/tools/mcp_transport.rs
Normal file
285
src/tools/mcp_transport.rs
Normal file
@ -0,0 +1,285 @@
|
||||
//! MCP transport abstraction — supports stdio, SSE, and HTTP transports.
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::{Child, Command};
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
use crate::config::schema::{McpServerConfig, McpTransport};
|
||||
use crate::tools::mcp_protocol::{JsonRpcRequest, JsonRpcResponse};
|
||||
|
||||
/// Maximum bytes for a single JSON-RPC response.
|
||||
const MAX_LINE_BYTES: usize = 4 * 1024 * 1024; // 4 MB
|
||||
|
||||
/// Timeout for init/list operations.
|
||||
const RECV_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
// ── Transport Trait ──────────────────────────────────────────────────────
|
||||
|
||||
/// Abstract transport for MCP communication.
|
||||
#[async_trait::async_trait]
|
||||
pub trait McpTransportConn: Send + Sync {
|
||||
/// Send a JSON-RPC request and receive the response.
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse>;
|
||||
|
||||
/// Close the connection.
|
||||
async fn close(&mut self) -> Result<()>;
|
||||
}
|
||||
|
||||
// ── Stdio Transport ──────────────────────────────────────────────────────
|
||||
|
||||
/// Stdio-based transport (spawn local process).
|
||||
pub struct StdioTransport {
|
||||
_child: Child,
|
||||
stdin: tokio::process::ChildStdin,
|
||||
stdout_lines: tokio::io::Lines<BufReader<tokio::process::ChildStdout>>,
|
||||
}
|
||||
|
||||
impl StdioTransport {
|
||||
pub fn new(config: &McpServerConfig) -> Result<Self> {
|
||||
let mut child = Command::new(&config.command)
|
||||
.args(&config.args)
|
||||
.envs(&config.env)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()
|
||||
.with_context(|| format!("failed to spawn MCP server `{}`", config.name))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| anyhow!("no stdin on MCP server `{}`", config.name))?;
|
||||
let stdout = child
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or_else(|| anyhow!("no stdout on MCP server `{}`", config.name))?;
|
||||
let stdout_lines = BufReader::new(stdout).lines();
|
||||
|
||||
Ok(Self {
|
||||
_child: child,
|
||||
stdin,
|
||||
stdout_lines,
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_raw(&mut self, line: &str) -> Result<()> {
|
||||
self.stdin
|
||||
.write_all(line.as_bytes())
|
||||
.await
|
||||
.context("failed to write to MCP server stdin")?;
|
||||
self.stdin
|
||||
.write_all(b"\n")
|
||||
.await
|
||||
.context("failed to write newline to MCP server stdin")?;
|
||||
self.stdin.flush().await.context("failed to flush stdin")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv_raw(&mut self) -> Result<String> {
|
||||
let line = self
|
||||
.stdout_lines
|
||||
.next_line()
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("MCP server closed stdout"))?;
|
||||
if line.len() > MAX_LINE_BYTES {
|
||||
bail!("MCP response too large: {} bytes", line.len());
|
||||
}
|
||||
Ok(line)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransportConn for StdioTransport {
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
||||
let line = serde_json::to_string(request)?;
|
||||
self.send_raw(&line).await?;
|
||||
let resp_line = timeout(Duration::from_secs(RECV_TIMEOUT_SECS), self.recv_raw())
|
||||
.await
|
||||
.context("timeout waiting for MCP response")??;
|
||||
let resp: JsonRpcResponse = serde_json::from_str(&resp_line)
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?;
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
let _ = self.stdin.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ── HTTP Transport ───────────────────────────────────────────────────────
|
||||
|
||||
/// HTTP-based transport (POST requests).
|
||||
pub struct HttpTransport {
|
||||
url: String,
|
||||
client: reqwest::Client,
|
||||
headers: std::collections::HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl HttpTransport {
|
||||
pub fn new(config: &McpServerConfig) -> Result<Self> {
|
||||
let url = config
|
||||
.url
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("URL required for HTTP transport"))?
|
||||
.clone();
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.build()
|
||||
.context("failed to build HTTP client")?;
|
||||
|
||||
Ok(Self {
|
||||
url,
|
||||
client,
|
||||
headers: config.headers.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransportConn for HttpTransport {
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
||||
let body = serde_json::to_string(request)?;
|
||||
|
||||
let mut req = self.client.post(&self.url).body(body);
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.context("HTTP request to MCP server failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
bail!("MCP server returned HTTP {}", resp.status());
|
||||
}
|
||||
|
||||
let resp_text = resp.text().await.context("failed to read HTTP response")?;
|
||||
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
|
||||
|
||||
Ok(mcp_resp)
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ── SSE Transport ─────────────────────────────────────────────────────────
|
||||
|
||||
/// SSE-based transport (HTTP POST for requests, SSE for responses).
|
||||
pub struct SseTransport {
|
||||
base_url: String,
|
||||
client: reqwest::Client,
|
||||
headers: std::collections::HashMap<String, String>,
|
||||
#[allow(dead_code)]
|
||||
event_source: Option<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl SseTransport {
|
||||
pub fn new(config: &McpServerConfig) -> Result<Self> {
|
||||
let base_url = config
|
||||
.url
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("URL required for SSE transport"))?
|
||||
.clone();
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.build()
|
||||
.context("failed to build HTTP client")?;
|
||||
|
||||
Ok(Self {
|
||||
base_url,
|
||||
client,
|
||||
headers: config.headers.clone(),
|
||||
event_source: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransportConn for SseTransport {
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
||||
// For SSE, we POST the request and the response comes via SSE stream.
|
||||
// Simplified implementation: treat as HTTP for now, proper SSE would
|
||||
// maintain a persistent event stream.
|
||||
let body = serde_json::to_string(request)?;
|
||||
let url = format!("{}/message", self.base_url.trim_end_matches('/'));
|
||||
|
||||
let mut req = self
|
||||
.client
|
||||
.post(&url)
|
||||
.body(body)
|
||||
.header("Content-Type", "application/json");
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
|
||||
let resp = req.send().await.context("SSE POST to MCP server failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
bail!("MCP server returned HTTP {}", resp.status());
|
||||
}
|
||||
|
||||
// For now, parse response directly. Full SSE would read from event stream.
|
||||
let resp_text = resp.text().await.context("failed to read SSE response")?;
|
||||
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
|
||||
|
||||
Ok(mcp_resp)
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ── Factory ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Create a transport based on config.
|
||||
pub fn create_transport(config: &McpServerConfig) -> Result<Box<dyn McpTransportConn>> {
|
||||
match config.transport {
|
||||
McpTransport::Stdio => Ok(Box::new(StdioTransport::new(config)?)),
|
||||
McpTransport::Http => Ok(Box::new(HttpTransport::new(config)?)),
|
||||
McpTransport::Sse => Ok(Box::new(SseTransport::new(config)?)),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_transport_default_is_stdio() {
|
||||
let config = McpServerConfig::default();
|
||||
assert_eq!(config.transport, McpTransport::Stdio);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_http_transport_requires_url() {
|
||||
let config = McpServerConfig {
|
||||
name: "test".into(),
|
||||
transport: McpTransport::Http,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(HttpTransport::new(&config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_transport_requires_url() {
|
||||
let config = McpServerConfig {
|
||||
name: "test".into(),
|
||||
transport: McpTransport::Sse,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(SseTransport::new(&config).is_err());
|
||||
}
|
||||
}
|
||||
@ -43,6 +43,10 @@ pub mod hardware_memory_map;
|
||||
pub mod hardware_memory_read;
|
||||
pub mod http_request;
|
||||
pub mod image_info;
|
||||
pub mod mcp_client;
|
||||
pub mod mcp_protocol;
|
||||
pub mod mcp_tool;
|
||||
pub mod mcp_transport;
|
||||
pub mod memory_forget;
|
||||
pub mod memory_recall;
|
||||
pub mod memory_store;
|
||||
@ -93,6 +97,8 @@ pub use hardware_memory_map::HardwareMemoryMapTool;
|
||||
pub use hardware_memory_read::HardwareMemoryReadTool;
|
||||
pub use http_request::HttpRequestTool;
|
||||
pub use image_info::ImageInfoTool;
|
||||
pub use mcp_client::McpRegistry;
|
||||
pub use mcp_tool::McpToolWrapper;
|
||||
pub use memory_forget::MemoryForgetTool;
|
||||
pub use memory_recall::MemoryRecallTool;
|
||||
pub use memory_store::MemoryStoreTool;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user