286 lines
9.6 KiB
Rust
286 lines
9.6 KiB
Rust
//! MCP transport abstraction — supports stdio, SSE, and HTTP transports.
|
|
|
|
use anyhow::{anyhow, bail, Context, Result};
|
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
|
use tokio::process::{Child, Command};
|
|
use tokio::time::{timeout, Duration};
|
|
|
|
use crate::config::schema::{McpServerConfig, McpTransport};
|
|
use crate::tools::mcp_protocol::{JsonRpcRequest, JsonRpcResponse};
|
|
|
|
/// Maximum bytes for a single JSON-RPC response.
|
|
const MAX_LINE_BYTES: usize = 4 * 1024 * 1024; // 4 MB
|
|
|
|
/// Timeout for init/list operations.
|
|
const RECV_TIMEOUT_SECS: u64 = 30;
|
|
|
|
// ── Transport Trait ──────────────────────────────────────────────────────
|
|
|
|
/// Abstract transport for MCP communication.
|
|
#[async_trait::async_trait]
|
|
pub trait McpTransportConn: Send + Sync {
|
|
/// Send a JSON-RPC request and receive the response.
|
|
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse>;
|
|
|
|
/// Close the connection.
|
|
async fn close(&mut self) -> Result<()>;
|
|
}
|
|
|
|
// ── Stdio Transport ──────────────────────────────────────────────────────
|
|
|
|
/// Stdio-based transport (spawn local process).
|
|
pub struct StdioTransport {
|
|
_child: Child,
|
|
stdin: tokio::process::ChildStdin,
|
|
stdout_lines: tokio::io::Lines<BufReader<tokio::process::ChildStdout>>,
|
|
}
|
|
|
|
impl StdioTransport {
|
|
pub fn new(config: &McpServerConfig) -> Result<Self> {
|
|
let mut child = Command::new(&config.command)
|
|
.args(&config.args)
|
|
.envs(&config.env)
|
|
.stdin(std::process::Stdio::piped())
|
|
.stdout(std::process::Stdio::piped())
|
|
.stderr(std::process::Stdio::inherit())
|
|
.kill_on_drop(true)
|
|
.spawn()
|
|
.with_context(|| format!("failed to spawn MCP server `{}`", config.name))?;
|
|
|
|
let stdin = child
|
|
.stdin
|
|
.take()
|
|
.ok_or_else(|| anyhow!("no stdin on MCP server `{}`", config.name))?;
|
|
let stdout = child
|
|
.stdout
|
|
.take()
|
|
.ok_or_else(|| anyhow!("no stdout on MCP server `{}`", config.name))?;
|
|
let stdout_lines = BufReader::new(stdout).lines();
|
|
|
|
Ok(Self {
|
|
_child: child,
|
|
stdin,
|
|
stdout_lines,
|
|
})
|
|
}
|
|
|
|
async fn send_raw(&mut self, line: &str) -> Result<()> {
|
|
self.stdin
|
|
.write_all(line.as_bytes())
|
|
.await
|
|
.context("failed to write to MCP server stdin")?;
|
|
self.stdin
|
|
.write_all(b"\n")
|
|
.await
|
|
.context("failed to write newline to MCP server stdin")?;
|
|
self.stdin.flush().await.context("failed to flush stdin")?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn recv_raw(&mut self) -> Result<String> {
|
|
let line = self
|
|
.stdout_lines
|
|
.next_line()
|
|
.await?
|
|
.ok_or_else(|| anyhow!("MCP server closed stdout"))?;
|
|
if line.len() > MAX_LINE_BYTES {
|
|
bail!("MCP response too large: {} bytes", line.len());
|
|
}
|
|
Ok(line)
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl McpTransportConn for StdioTransport {
|
|
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
|
let line = serde_json::to_string(request)?;
|
|
self.send_raw(&line).await?;
|
|
let resp_line = timeout(Duration::from_secs(RECV_TIMEOUT_SECS), self.recv_raw())
|
|
.await
|
|
.context("timeout waiting for MCP response")??;
|
|
let resp: JsonRpcResponse = serde_json::from_str(&resp_line)
|
|
.with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?;
|
|
Ok(resp)
|
|
}
|
|
|
|
async fn close(&mut self) -> Result<()> {
|
|
let _ = self.stdin.shutdown().await;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
// ── HTTP Transport ───────────────────────────────────────────────────────
|
|
|
|
/// HTTP-based transport (POST requests).
|
|
pub struct HttpTransport {
|
|
url: String,
|
|
client: reqwest::Client,
|
|
headers: std::collections::HashMap<String, String>,
|
|
}
|
|
|
|
impl HttpTransport {
|
|
pub fn new(config: &McpServerConfig) -> Result<Self> {
|
|
let url = config
|
|
.url
|
|
.as_ref()
|
|
.ok_or_else(|| anyhow!("URL required for HTTP transport"))?
|
|
.clone();
|
|
|
|
let client = reqwest::Client::builder()
|
|
.timeout(Duration::from_secs(120))
|
|
.build()
|
|
.context("failed to build HTTP client")?;
|
|
|
|
Ok(Self {
|
|
url,
|
|
client,
|
|
headers: config.headers.clone(),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl McpTransportConn for HttpTransport {
|
|
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
|
let body = serde_json::to_string(request)?;
|
|
|
|
let mut req = self.client.post(&self.url).body(body);
|
|
for (key, value) in &self.headers {
|
|
req = req.header(key, value);
|
|
}
|
|
|
|
let resp = req
|
|
.send()
|
|
.await
|
|
.context("HTTP request to MCP server failed")?;
|
|
|
|
if !resp.status().is_success() {
|
|
bail!("MCP server returned HTTP {}", resp.status());
|
|
}
|
|
|
|
let resp_text = resp.text().await.context("failed to read HTTP response")?;
|
|
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
|
|
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
|
|
|
|
Ok(mcp_resp)
|
|
}
|
|
|
|
async fn close(&mut self) -> Result<()> {
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
// ── SSE Transport ─────────────────────────────────────────────────────────
|
|
|
|
/// SSE-based transport (HTTP POST for requests, SSE for responses).
|
|
pub struct SseTransport {
|
|
base_url: String,
|
|
client: reqwest::Client,
|
|
headers: std::collections::HashMap<String, String>,
|
|
#[allow(dead_code)]
|
|
event_source: Option<tokio::task::JoinHandle<()>>,
|
|
}
|
|
|
|
impl SseTransport {
|
|
pub fn new(config: &McpServerConfig) -> Result<Self> {
|
|
let base_url = config
|
|
.url
|
|
.as_ref()
|
|
.ok_or_else(|| anyhow!("URL required for SSE transport"))?
|
|
.clone();
|
|
|
|
let client = reqwest::Client::builder()
|
|
.timeout(Duration::from_secs(120))
|
|
.build()
|
|
.context("failed to build HTTP client")?;
|
|
|
|
Ok(Self {
|
|
base_url,
|
|
client,
|
|
headers: config.headers.clone(),
|
|
event_source: None,
|
|
})
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl McpTransportConn for SseTransport {
|
|
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
|
// For SSE, we POST the request and the response comes via SSE stream.
|
|
// Simplified implementation: treat as HTTP for now, proper SSE would
|
|
// maintain a persistent event stream.
|
|
let body = serde_json::to_string(request)?;
|
|
let url = format!("{}/message", self.base_url.trim_end_matches('/'));
|
|
|
|
let mut req = self
|
|
.client
|
|
.post(&url)
|
|
.body(body)
|
|
.header("Content-Type", "application/json");
|
|
for (key, value) in &self.headers {
|
|
req = req.header(key, value);
|
|
}
|
|
|
|
let resp = req.send().await.context("SSE POST to MCP server failed")?;
|
|
|
|
if !resp.status().is_success() {
|
|
bail!("MCP server returned HTTP {}", resp.status());
|
|
}
|
|
|
|
// For now, parse response directly. Full SSE would read from event stream.
|
|
let resp_text = resp.text().await.context("failed to read SSE response")?;
|
|
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
|
|
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
|
|
|
|
Ok(mcp_resp)
|
|
}
|
|
|
|
async fn close(&mut self) -> Result<()> {
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
// ── Factory ──────────────────────────────────────────────────────────────
|
|
|
|
/// Create a transport based on config.
|
|
pub fn create_transport(config: &McpServerConfig) -> Result<Box<dyn McpTransportConn>> {
|
|
match config.transport {
|
|
McpTransport::Stdio => Ok(Box::new(StdioTransport::new(config)?)),
|
|
McpTransport::Http => Ok(Box::new(HttpTransport::new(config)?)),
|
|
McpTransport::Sse => Ok(Box::new(SseTransport::new(config)?)),
|
|
}
|
|
}
|
|
|
|
// ── Tests ─────────────────────────────────────────────────────────────────
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_transport_default_is_stdio() {
|
|
let config = McpServerConfig::default();
|
|
assert_eq!(config.transport, McpTransport::Stdio);
|
|
}
|
|
|
|
#[test]
|
|
fn test_http_transport_requires_url() {
|
|
let config = McpServerConfig {
|
|
name: "test".into(),
|
|
transport: McpTransport::Http,
|
|
..Default::default()
|
|
};
|
|
assert!(HttpTransport::new(&config).is_err());
|
|
}
|
|
|
|
#[test]
|
|
fn test_sse_transport_requires_url() {
|
|
let config = McpServerConfig {
|
|
name: "test".into(),
|
|
transport: McpTransport::Sse,
|
|
..Default::default()
|
|
};
|
|
assert!(SseTransport::new(&config).is_err());
|
|
}
|
|
}
|