diff --git a/src/tools/mcp_transport.rs b/src/tools/mcp_transport.rs index 27398451c..cc98c3c78 100644 --- a/src/tools/mcp_transport.rs +++ b/src/tools/mcp_transport.rs @@ -18,6 +18,12 @@ const MAX_LINE_BYTES: usize = 4 * 1024 * 1024; // 4 MB /// Timeout for init/list operations. const RECV_TIMEOUT_SECS: u64 = 30; +/// Streamable HTTP Accept header required by MCP HTTP transport. +const MCP_STREAMABLE_ACCEPT: &str = "application/json, text/event-stream"; + +/// Default media type for MCP JSON-RPC request bodies. +const MCP_JSON_CONTENT_TYPE: &str = "application/json"; + // ── Transport Trait ────────────────────────────────────────────────────── /// Abstract transport for MCP communication. @@ -171,10 +177,25 @@ impl McpTransportConn for HttpTransport { async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result { let body = serde_json::to_string(request)?; + let has_accept = self + .headers + .keys() + .any(|k| k.eq_ignore_ascii_case("Accept")); + let has_content_type = self + .headers + .keys() + .any(|k| k.eq_ignore_ascii_case("Content-Type")); + let mut req = self.client.post(&self.url).body(body); + if !has_content_type { + req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE); + } for (key, value) in &self.headers { req = req.header(key, value); } + if !has_accept { + req = req.header("Accept", MCP_STREAMABLE_ACCEPT); + } let resp = req .send() @@ -194,11 +215,24 @@ impl McpTransportConn for HttpTransport { }); } - let resp_text = resp.text().await.context("failed to read HTTP response")?; - let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text) - .with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?; + let is_sse = resp + .headers() + .get(reqwest::header::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream")); + if is_sse { + let maybe_resp = timeout( + Duration::from_secs(RECV_TIMEOUT_SECS), + read_first_jsonrpc_from_sse_response(resp), + ) + .await + .context("timeout waiting for MCP response from streamable HTTP SSE stream")??; + return maybe_resp + .ok_or_else(|| anyhow!("MCP server returned no response in SSE stream")); + } - Ok(mcp_resp) + let resp_text = resp.text().await.context("failed to read HTTP response")?; + parse_jsonrpc_response_text(&resp_text) } async fn close(&mut self) -> Result<()> { @@ -264,14 +298,21 @@ impl SseTransport { } } + let has_accept = self + .headers + .keys() + .any(|k| k.eq_ignore_ascii_case("Accept")); + let mut req = self .client .get(&self.sse_url) - .header("Accept", "text/event-stream") .header("Cache-Control", "no-cache"); for (key, value) in &self.headers { req = req.header(key, value); } + if !has_accept { + req = req.header("Accept", MCP_STREAMABLE_ACCEPT); + } let resp = req.send().await.context("SSE GET to MCP server failed")?; if resp.status() == reqwest::StatusCode::NOT_FOUND @@ -556,6 +597,30 @@ fn extract_json_from_sse_text(resp_text: &str) -> Cow<'_, str> { Cow::Owned(joined.trim().to_string()) } +fn parse_jsonrpc_response_text(resp_text: &str) -> Result { + let trimmed = resp_text.trim(); + if trimmed.is_empty() { + bail!("MCP server returned no response"); + } + + let json_text = if looks_like_sse_text(trimmed) { + extract_json_from_sse_text(trimmed) + } else { + Cow::Borrowed(trimmed) + }; + + let mcp_resp: JsonRpcResponse = serde_json::from_str(json_text.as_ref()) + .with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?; + Ok(mcp_resp) +} + +fn looks_like_sse_text(text: &str) -> bool { + text.starts_with("data:") + || text.starts_with("event:") + || text.contains("\ndata:") + || text.contains("\nevent:") +} + async fn read_first_jsonrpc_from_sse_response( resp: reqwest::Response, ) -> Result> { @@ -673,21 +738,27 @@ impl McpTransportConn for SseTransport { .chain(secondary_url.into_iter()) .enumerate() { + let has_accept = self + .headers + .keys() + .any(|k| k.eq_ignore_ascii_case("Accept")); + let has_content_type = self + .headers + .keys() + .any(|k| k.eq_ignore_ascii_case("Content-Type")); let mut req = self .client .post(&url) .timeout(Duration::from_secs(120)) - .body(body.clone()) - .header("Content-Type", "application/json"); + .body(body.clone()); + if !has_content_type { + req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE); + } for (key, value) in &self.headers { req = req.header(key, value); } - if !self - .headers - .keys() - .any(|k| k.eq_ignore_ascii_case("Accept")) - { - req = req.header("Accept", "application/json, text/event-stream"); + if !has_accept { + req = req.header("Accept", MCP_STREAMABLE_ACCEPT); } let resp = req.send().await.context("SSE POST to MCP server failed")?; @@ -887,4 +958,34 @@ mod tests { let extracted = extract_json_from_sse_text(input); let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap(); } + + #[test] + fn test_parse_jsonrpc_response_text_handles_plain_json() { + let parsed = parse_jsonrpc_response_text("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}") + .expect("plain JSON response should parse"); + assert_eq!(parsed.id, Some(serde_json::json!(1))); + assert!(parsed.error.is_none()); + } + + #[test] + fn test_parse_jsonrpc_response_text_handles_sse_framed_json() { + let sse = + "event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":2,\"result\":{\"ok\":true}}\n\n"; + let parsed = + parse_jsonrpc_response_text(sse).expect("SSE-framed JSON response should parse"); + assert_eq!(parsed.id, Some(serde_json::json!(2))); + assert_eq!( + parsed + .result + .as_ref() + .and_then(|v| v.get("ok")) + .and_then(|v| v.as_bool()), + Some(true) + ); + } + + #[test] + fn test_parse_jsonrpc_response_text_rejects_empty_payload() { + assert!(parse_jsonrpc_response_text(" \n\t ").is_err()); + } }