Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
7e0570abd6
@ -48,8 +48,8 @@ schemars = "1.2"
|
||||
tracing = { version = "0.1", default-features = false }
|
||||
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter"] }
|
||||
|
||||
# Observability - Prometheus metrics
|
||||
prometheus = { version = "0.14", default-features = false }
|
||||
# Observability - Prometheus metrics (optional; requires AtomicU64, unavailable on 32-bit)
|
||||
prometheus = { version = "0.14", default-features = false, optional = true }
|
||||
|
||||
# Base64 encoding (screenshots, image data)
|
||||
base64 = "0.22"
|
||||
@ -205,6 +205,8 @@ sandbox-landlock = ["dep:landlock"]
|
||||
sandbox-bubblewrap = []
|
||||
# Backward-compatible alias for older invocations
|
||||
landlock = ["sandbox-landlock"]
|
||||
# Prometheus metrics observer (requires 64-bit atomics; disable on 32-bit targets)
|
||||
metrics = ["dep:prometheus"]
|
||||
# probe = probe-rs for Nucleo memory read (adds ~50 deps; optional)
|
||||
probe = ["dep:probe-rs"]
|
||||
# rag-pdf = PDF ingestion for datasheet RAG
|
||||
|
||||
110
build.rs
110
build.rs
@ -1,6 +1,110 @@
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
fn main() {
|
||||
let dir = std::path::Path::new("web/dist");
|
||||
if !dir.exists() {
|
||||
std::fs::create_dir_all(dir).expect("failed to create web/dist/");
|
||||
let dist_dir = Path::new("web/dist");
|
||||
let web_dir = Path::new("web");
|
||||
|
||||
// Tell Cargo to re-run this script when web source files change.
|
||||
println!("cargo:rerun-if-changed=web/src");
|
||||
println!("cargo:rerun-if-changed=web/index.html");
|
||||
println!("cargo:rerun-if-changed=web/package.json");
|
||||
println!("cargo:rerun-if-changed=web/vite.config.ts");
|
||||
|
||||
// Attempt to build the web frontend if npm is available and web/dist is
|
||||
// missing or stale. The build is best-effort: when Node.js is not
|
||||
// installed (e.g. CI containers, cross-compilation, minimal dev setups)
|
||||
// we fall back to the existing stub/empty dist directory so the Rust
|
||||
// build still succeeds.
|
||||
let needs_build = !dist_dir.join("index.html").exists();
|
||||
|
||||
if needs_build && web_dir.join("package.json").exists() {
|
||||
if let Ok(npm) = which_npm() {
|
||||
eprintln!("cargo:warning=Building web frontend (web/dist is missing or stale)...");
|
||||
|
||||
// npm ci / npm install
|
||||
let install_status = Command::new(&npm)
|
||||
.args(["ci", "--ignore-scripts"])
|
||||
.current_dir(web_dir)
|
||||
.status();
|
||||
|
||||
match install_status {
|
||||
Ok(s) if s.success() => {}
|
||||
Ok(s) => {
|
||||
// Fall back to `npm install` if `npm ci` fails (no lockfile, etc.)
|
||||
eprintln!("cargo:warning=npm ci exited with {s}, trying npm install...");
|
||||
let fallback = Command::new(&npm)
|
||||
.args(["install"])
|
||||
.current_dir(web_dir)
|
||||
.status();
|
||||
if !matches!(fallback, Ok(s) if s.success()) {
|
||||
eprintln!("cargo:warning=npm install failed — skipping web build");
|
||||
ensure_dist_dir(dist_dir);
|
||||
return;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("cargo:warning=Could not run npm: {e} — skipping web build");
|
||||
ensure_dist_dir(dist_dir);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// npm run build
|
||||
let build_status = Command::new(&npm)
|
||||
.args(["run", "build"])
|
||||
.current_dir(web_dir)
|
||||
.status();
|
||||
|
||||
match build_status {
|
||||
Ok(s) if s.success() => {
|
||||
eprintln!("cargo:warning=Web frontend built successfully.");
|
||||
}
|
||||
Ok(s) => {
|
||||
eprintln!(
|
||||
"cargo:warning=npm run build exited with {s} — web dashboard may be unavailable"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"cargo:warning=Could not run npm build: {e} — web dashboard may be unavailable"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ensure_dist_dir(dist_dir);
|
||||
}
|
||||
|
||||
/// Ensure the dist directory exists so `rust-embed` does not fail at compile
|
||||
/// time even when the web frontend is not built.
|
||||
fn ensure_dist_dir(dist_dir: &Path) {
|
||||
if !dist_dir.exists() {
|
||||
std::fs::create_dir_all(dist_dir).expect("failed to create web/dist/");
|
||||
}
|
||||
}
|
||||
|
||||
/// Locate the `npm` binary on the system PATH.
|
||||
fn which_npm() -> Result<String, ()> {
|
||||
let cmd = if cfg!(target_os = "windows") {
|
||||
"where"
|
||||
} else {
|
||||
"which"
|
||||
};
|
||||
|
||||
Command::new(cmd)
|
||||
.arg("npm")
|
||||
.output()
|
||||
.ok()
|
||||
.and_then(|output| {
|
||||
if output.status.success() {
|
||||
String::from_utf8(output.stdout)
|
||||
.ok()
|
||||
.map(|s| s.lines().next().unwrap_or("npm").trim().to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or(())
|
||||
}
|
||||
|
||||
41
install.sh
41
install.sh
@ -211,8 +211,35 @@ should_attempt_prebuilt_for_resources() {
|
||||
return 1
|
||||
}
|
||||
|
||||
resolve_asset_url() {
|
||||
local asset_name="$1"
|
||||
local api_url="https://api.github.com/repos/zeroclaw-labs/zeroclaw/releases"
|
||||
local releases_json download_url
|
||||
|
||||
# Fetch up to 10 recent releases (includes prereleases) and find the first
|
||||
# one that contains the requested asset.
|
||||
releases_json="$(curl -fsSL "${api_url}?per_page=10" 2>/dev/null || true)"
|
||||
if [[ -z "$releases_json" ]]; then
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Parse with simple grep/sed — avoids jq dependency.
|
||||
download_url="$(printf '%s\n' "$releases_json" \
|
||||
| tr ',' '\n' \
|
||||
| grep '"browser_download_url"' \
|
||||
| sed 's/.*"browser_download_url"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/' \
|
||||
| grep "/${asset_name}\$" \
|
||||
| head -n 1)"
|
||||
|
||||
if [[ -z "$download_url" ]]; then
|
||||
return 1
|
||||
fi
|
||||
|
||||
echo "$download_url"
|
||||
}
|
||||
|
||||
install_prebuilt_binary() {
|
||||
local target archive_url temp_dir archive_path extracted_bin install_dir
|
||||
local target archive_url temp_dir archive_path extracted_bin install_dir asset_name
|
||||
|
||||
if ! have_cmd curl; then
|
||||
warn "curl is required for pre-built binary installation."
|
||||
@ -229,9 +256,17 @@ install_prebuilt_binary() {
|
||||
return 1
|
||||
fi
|
||||
|
||||
archive_url="https://github.com/zeroclaw-labs/zeroclaw/releases/latest/download/zeroclaw-${target}.tar.gz"
|
||||
asset_name="zeroclaw-${target}.tar.gz"
|
||||
|
||||
# Try the GitHub API first to find the newest release (including prereleases)
|
||||
# that actually contains the asset, then fall back to /releases/latest/.
|
||||
archive_url="$(resolve_asset_url "$asset_name" || true)"
|
||||
if [[ -z "$archive_url" ]]; then
|
||||
archive_url="https://github.com/zeroclaw-labs/zeroclaw/releases/latest/download/${asset_name}"
|
||||
fi
|
||||
|
||||
temp_dir="$(mktemp -d -t zeroclaw-prebuilt-XXXXXX)"
|
||||
archive_path="$temp_dir/zeroclaw-${target}.tar.gz"
|
||||
archive_path="$temp_dir/${asset_name}"
|
||||
|
||||
info "Attempting pre-built binary install for target: $target"
|
||||
if ! curl -fsSL "$archive_url" -o "$archive_path"; then
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
use crate::channels::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
#[cfg(not(target_has_atomic = "64"))]
|
||||
use std::sync::atomic::AtomicU32;
|
||||
#[cfg(target_has_atomic = "64")]
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
@ -13,7 +17,10 @@ use tokio_rustls::rustls;
|
||||
const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
|
||||
|
||||
/// Monotonic counter to ensure unique message IDs under burst traffic.
|
||||
#[cfg(target_has_atomic = "64")]
|
||||
static MSG_SEQ: AtomicU64 = AtomicU64::new(0);
|
||||
#[cfg(not(target_has_atomic = "64"))]
|
||||
static MSG_SEQ: AtomicU32 = AtomicU32::new(0);
|
||||
|
||||
/// IRC over TLS channel.
|
||||
///
|
||||
|
||||
@ -61,20 +61,33 @@ impl LinqChannel {
|
||||
|
||||
/// Parse an incoming webhook payload from Linq and extract messages.
|
||||
///
|
||||
/// Linq webhook envelope:
|
||||
/// Supports two webhook formats:
|
||||
///
|
||||
/// **New format (webhook_version 2026-02-03):**
|
||||
/// ```json
|
||||
/// {
|
||||
/// "api_version": "v3",
|
||||
/// "webhook_version": "2026-02-03",
|
||||
/// "event_type": "message.received",
|
||||
/// "data": {
|
||||
/// "id": "msg-...",
|
||||
/// "direction": "inbound",
|
||||
/// "sender_handle": { "handle": "+1...", "is_me": false },
|
||||
/// "chat": { "id": "chat-..." },
|
||||
/// "parts": [{ "type": "text", "value": "..." }]
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// **Legacy format (webhook_version 2025-01-01):**
|
||||
/// ```json
|
||||
/// {
|
||||
/// "api_version": "v3",
|
||||
/// "event_type": "message.received",
|
||||
/// "event_id": "...",
|
||||
/// "created_at": "...",
|
||||
/// "trace_id": "...",
|
||||
/// "data": {
|
||||
/// "chat_id": "...",
|
||||
/// "from": "+1...",
|
||||
/// "recipient_phone": "+1...",
|
||||
/// "is_from_me": false,
|
||||
/// "service": "iMessage",
|
||||
/// "message": {
|
||||
/// "id": "...",
|
||||
/// "parts": [{ "type": "text", "value": "..." }]
|
||||
@ -99,18 +112,44 @@ impl LinqChannel {
|
||||
return messages;
|
||||
};
|
||||
|
||||
// Detect format: new format has `sender_handle`, legacy has `from`.
|
||||
let is_new_format = data.get("sender_handle").is_some();
|
||||
|
||||
// Skip messages sent by the bot itself
|
||||
if data
|
||||
.get("is_from_me")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let is_from_me = if is_new_format {
|
||||
// New format: data.sender_handle.is_me or data.direction == "outbound"
|
||||
data.get("sender_handle")
|
||||
.and_then(|sh| sh.get("is_me"))
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false)
|
||||
|| data
|
||||
.get("direction")
|
||||
.and_then(|d| d.as_str())
|
||||
.is_some_and(|d| d == "outbound")
|
||||
} else {
|
||||
// Legacy format: data.is_from_me
|
||||
data.get("is_from_me")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false)
|
||||
};
|
||||
|
||||
if is_from_me {
|
||||
tracing::debug!("Linq: skipping is_from_me message");
|
||||
return messages;
|
||||
}
|
||||
|
||||
// Get sender phone number
|
||||
let Some(from) = data.get("from").and_then(|f| f.as_str()) else {
|
||||
let from = if is_new_format {
|
||||
// New format: data.sender_handle.handle
|
||||
data.get("sender_handle")
|
||||
.and_then(|sh| sh.get("handle"))
|
||||
.and_then(|h| h.as_str())
|
||||
} else {
|
||||
// Legacy format: data.from
|
||||
data.get("from").and_then(|f| f.as_str())
|
||||
};
|
||||
|
||||
let Some(from) = from else {
|
||||
return messages;
|
||||
};
|
||||
|
||||
@ -132,18 +171,33 @@ impl LinqChannel {
|
||||
}
|
||||
|
||||
// Get chat_id for reply routing
|
||||
let chat_id = data
|
||||
.get("chat_id")
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
// Extract text from message parts
|
||||
let Some(message) = data.get("message") else {
|
||||
return messages;
|
||||
let chat_id = if is_new_format {
|
||||
// New format: data.chat.id
|
||||
data.get("chat")
|
||||
.and_then(|c| c.get("id"))
|
||||
.and_then(|id| id.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
} else {
|
||||
// Legacy format: data.chat_id
|
||||
data.get("chat_id")
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
};
|
||||
|
||||
let Some(parts) = message.get("parts").and_then(|p| p.as_array()) else {
|
||||
// Extract message parts
|
||||
let parts = if is_new_format {
|
||||
// New format: data.parts (directly on data)
|
||||
data.get("parts").and_then(|p| p.as_array())
|
||||
} else {
|
||||
// Legacy format: data.message.parts
|
||||
data.get("message")
|
||||
.and_then(|m| m.get("parts"))
|
||||
.and_then(|p| p.as_array())
|
||||
};
|
||||
|
||||
let Some(parts) = parts else {
|
||||
return messages;
|
||||
};
|
||||
|
||||
@ -790,4 +844,217 @@ mod tests {
|
||||
let ch = make_channel();
|
||||
assert_eq!(ch.phone_number(), "+15551234567");
|
||||
}
|
||||
|
||||
// ---- New format (2026-02-03) tests ----
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_text_message() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"api_version": "v3",
|
||||
"webhook_version": "2026-02-03",
|
||||
"event_type": "message.received",
|
||||
"event_id": "evt-123",
|
||||
"created_at": "2026-03-01T12:00:00Z",
|
||||
"trace_id": "trace-456",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "inbound",
|
||||
"sender_handle": {
|
||||
"handle": "+1234567890",
|
||||
"is_me": false
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"service": "iMessage",
|
||||
"parts": [{
|
||||
"type": "text",
|
||||
"value": "Hello from new format!"
|
||||
}]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].sender, "+1234567890");
|
||||
assert_eq!(msgs[0].content, "Hello from new format!");
|
||||
assert_eq!(msgs[0].channel, "linq");
|
||||
assert_eq!(msgs[0].reply_target, "chat-789");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_skip_is_me() {
|
||||
let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "outbound",
|
||||
"sender_handle": {
|
||||
"handle": "+15551234567",
|
||||
"is_me": true
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"parts": [{ "type": "text", "value": "My own message" }]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert!(
|
||||
msgs.is_empty(),
|
||||
"is_me messages should be skipped in new format"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_skip_outbound_direction() {
|
||||
let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "outbound",
|
||||
"sender_handle": {
|
||||
"handle": "+15551234567",
|
||||
"is_me": false
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"parts": [{ "type": "text", "value": "Outbound" }]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert!(msgs.is_empty(), "outbound direction should be skipped");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_unauthorized_sender() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "inbound",
|
||||
"sender_handle": {
|
||||
"handle": "+9999999999",
|
||||
"is_me": false
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"parts": [{ "type": "text", "value": "Spam" }]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert!(
|
||||
msgs.is_empty(),
|
||||
"Unauthorized senders should be filtered in new format"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_media_image() {
|
||||
let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "inbound",
|
||||
"sender_handle": {
|
||||
"handle": "+1234567890",
|
||||
"is_me": false
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"parts": [{
|
||||
"type": "media",
|
||||
"url": "https://example.com/photo.png",
|
||||
"mime_type": "image/png"
|
||||
}]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].content, "[IMAGE:https://example.com/photo.png]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_multiple_parts() {
|
||||
let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "inbound",
|
||||
"sender_handle": {
|
||||
"handle": "+1234567890",
|
||||
"is_me": false
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"parts": [
|
||||
{ "type": "text", "value": "Check this out" },
|
||||
{ "type": "media", "url": "https://example.com/img.jpg", "mime_type": "image/jpeg" }
|
||||
]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(
|
||||
msgs[0].content,
|
||||
"Check this out\n[IMAGE:https://example.com/img.jpg]"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_fallback_reply_target_when_no_chat() {
|
||||
let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "inbound",
|
||||
"sender_handle": {
|
||||
"handle": "+1234567890",
|
||||
"is_me": false
|
||||
},
|
||||
"parts": [{ "type": "text", "value": "Hi" }]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].reply_target, "+1234567890");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn linq_parse_new_format_normalizes_phone() {
|
||||
let ch = LinqChannel::new(
|
||||
"tok".into(),
|
||||
"+15551234567".into(),
|
||||
vec!["+1234567890".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"event_type": "message.received",
|
||||
"webhook_version": "2026-02-03",
|
||||
"data": {
|
||||
"id": "msg-abc",
|
||||
"direction": "inbound",
|
||||
"sender_handle": {
|
||||
"handle": "1234567890",
|
||||
"is_me": false
|
||||
},
|
||||
"chat": { "id": "chat-789" },
|
||||
"parts": [{ "type": "text", "value": "Hi" }]
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_webhook_payload(&payload);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].sender, "+1234567890");
|
||||
}
|
||||
}
|
||||
|
||||
@ -89,7 +89,11 @@ use std::collections::{HashMap, HashSet};
|
||||
use std::fmt::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
#[cfg(not(target_has_atomic = "64"))]
|
||||
use std::sync::atomic::AtomicU32;
|
||||
#[cfg(target_has_atomic = "64")]
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use std::time::{Duration, Instant, SystemTime};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@ -2305,7 +2309,10 @@ async fn run_message_dispatch_loop(
|
||||
String,
|
||||
InFlightSenderTaskState,
|
||||
>::new()));
|
||||
#[cfg(target_has_atomic = "64")]
|
||||
let task_sequence = Arc::new(AtomicU64::new(1));
|
||||
#[cfg(not(target_has_atomic = "64"))]
|
||||
let task_sequence = Arc::new(AtomicU32::new(1));
|
||||
|
||||
while let Some(msg) = rx.recv().await {
|
||||
let permit = match Arc::clone(&semaphore).acquire_owned().await {
|
||||
@ -2323,7 +2330,7 @@ async fn run_message_dispatch_loop(
|
||||
let sender_scope_key = interruption_scope_key(&msg);
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let completion = Arc::new(InFlightTaskCompletion::new());
|
||||
let task_id = task_sequence.fetch_add(1, Ordering::Relaxed);
|
||||
let task_id = task_sequence.fetch_add(1, Ordering::Relaxed) as u64;
|
||||
|
||||
if interrupt_enabled {
|
||||
let previous = {
|
||||
@ -3364,7 +3371,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
};
|
||||
// Build system prompt from workspace identity files + skills
|
||||
let workspace = config.workspace_dir.clone();
|
||||
let tools_registry = Arc::new(tools::all_tools_with_runtime(
|
||||
let mut built_tools = tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
&security,
|
||||
runtime,
|
||||
@ -3378,7 +3385,44 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
));
|
||||
);
|
||||
|
||||
// Wire MCP tools into the registry before freezing — non-fatal.
|
||||
if config.mcp.enabled && !config.mcp.servers.is_empty() {
|
||||
tracing::info!(
|
||||
"Initializing MCP client — {} server(s) configured",
|
||||
config.mcp.servers.len()
|
||||
);
|
||||
match crate::tools::mcp_client::McpRegistry::connect_all(&config.mcp.servers).await {
|
||||
Ok(registry) => {
|
||||
let registry = std::sync::Arc::new(registry);
|
||||
let names = registry.tool_names();
|
||||
let mut registered = 0usize;
|
||||
for name in names {
|
||||
if let Some(def) = registry.get_tool_def(&name).await {
|
||||
let wrapper = crate::tools::mcp_tool::McpToolWrapper::new(
|
||||
name,
|
||||
def,
|
||||
std::sync::Arc::clone(®istry),
|
||||
);
|
||||
built_tools.push(Box::new(wrapper));
|
||||
registered += 1;
|
||||
}
|
||||
}
|
||||
tracing::info!(
|
||||
"MCP: {} tool(s) registered from {} server(s)",
|
||||
registered,
|
||||
registry.server_count()
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
// Non-fatal — daemon continues with the tools registered above.
|
||||
tracing::error!("MCP registry failed to initialize: {e:#}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let tools_registry = Arc::new(built_tools);
|
||||
|
||||
let skills = crate::skills::load_skills_with_config(&workspace, &config);
|
||||
|
||||
|
||||
@ -10,12 +10,13 @@ pub use schema::{
|
||||
CronConfig, DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig,
|
||||
ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
|
||||
GoogleTtsConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
|
||||
HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig,
|
||||
ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, ObservabilityConfig, OpenAiTtsConfig,
|
||||
OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope,
|
||||
QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
|
||||
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig,
|
||||
McpServerConfig, McpTransport, MemoryConfig, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, ObservabilityConfig, OpenAiTtsConfig, OtpConfig, OtpMethod,
|
||||
PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig,
|
||||
SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, TelegramConfig, TranscriptionConfig, TtsConfig,
|
||||
TunnelConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
};
|
||||
|
||||
@ -232,6 +232,10 @@ pub struct Config {
|
||||
/// Text-to-Speech configuration (`[tts]`).
|
||||
#[serde(default)]
|
||||
pub tts: TtsConfig,
|
||||
|
||||
/// External MCP server connections (`[mcp]`).
|
||||
#[serde(default, alias = "mcpServers")]
|
||||
pub mcp: McpConfig,
|
||||
}
|
||||
|
||||
/// Named provider profile definition compatible with Codex app-server style config.
|
||||
@ -455,6 +459,60 @@ impl Default for TranscriptionConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ── MCP ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Transport type for MCP server connections.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum McpTransport {
|
||||
/// Spawn a local process and communicate over stdin/stdout.
|
||||
#[default]
|
||||
Stdio,
|
||||
/// Connect via HTTP POST.
|
||||
Http,
|
||||
/// Connect via HTTP + Server-Sent Events.
|
||||
Sse,
|
||||
}
|
||||
|
||||
/// Configuration for a single external MCP server.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
|
||||
pub struct McpServerConfig {
|
||||
/// Display name used as a tool prefix (`<server>__<tool>`).
|
||||
pub name: String,
|
||||
/// Transport type (default: stdio).
|
||||
#[serde(default)]
|
||||
pub transport: McpTransport,
|
||||
/// URL for HTTP/SSE transports.
|
||||
#[serde(default)]
|
||||
pub url: Option<String>,
|
||||
/// Executable to spawn for stdio transport.
|
||||
#[serde(default)]
|
||||
pub command: String,
|
||||
/// Command arguments for stdio transport.
|
||||
#[serde(default)]
|
||||
pub args: Vec<String>,
|
||||
/// Optional environment variables for stdio transport.
|
||||
#[serde(default)]
|
||||
pub env: HashMap<String, String>,
|
||||
/// Optional HTTP headers for HTTP/SSE transports.
|
||||
#[serde(default)]
|
||||
pub headers: HashMap<String, String>,
|
||||
/// Optional per-call timeout in seconds (hard capped in validation).
|
||||
#[serde(default)]
|
||||
pub tool_timeout_secs: Option<u64>,
|
||||
}
|
||||
|
||||
/// External MCP client configuration (`[mcp]` section).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
|
||||
pub struct McpConfig {
|
||||
/// Enable MCP tool loading.
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// Configured MCP servers.
|
||||
#[serde(default, alias = "mcpServers")]
|
||||
pub servers: Vec<McpServerConfig>,
|
||||
}
|
||||
|
||||
// ── TTS (Text-to-Speech) ─────────────────────────────────────────
|
||||
|
||||
fn default_tts_provider() -> String {
|
||||
@ -1634,6 +1692,65 @@ fn service_selector_matches(selector: &str, service_key: &str) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
const MCP_MAX_TOOL_TIMEOUT_SECS: u64 = 600;
|
||||
|
||||
fn validate_mcp_config(config: &McpConfig) -> Result<()> {
|
||||
let mut seen_names = std::collections::HashSet::new();
|
||||
for (i, server) in config.servers.iter().enumerate() {
|
||||
let name = server.name.trim();
|
||||
if name.is_empty() {
|
||||
anyhow::bail!("mcp.servers[{i}].name must not be empty");
|
||||
}
|
||||
if !seen_names.insert(name.to_ascii_lowercase()) {
|
||||
anyhow::bail!("mcp.servers contains duplicate name: {name}");
|
||||
}
|
||||
|
||||
if let Some(timeout) = server.tool_timeout_secs {
|
||||
if timeout == 0 {
|
||||
anyhow::bail!("mcp.servers[{i}].tool_timeout_secs must be greater than 0");
|
||||
}
|
||||
if timeout > MCP_MAX_TOOL_TIMEOUT_SECS {
|
||||
anyhow::bail!(
|
||||
"mcp.servers[{i}].tool_timeout_secs exceeds max {MCP_MAX_TOOL_TIMEOUT_SECS}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
match server.transport {
|
||||
McpTransport::Stdio => {
|
||||
if server.command.trim().is_empty() {
|
||||
anyhow::bail!(
|
||||
"mcp.servers[{i}] with transport=stdio requires non-empty command"
|
||||
);
|
||||
}
|
||||
}
|
||||
McpTransport::Http | McpTransport::Sse => {
|
||||
let url = server
|
||||
.url
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"mcp.servers[{i}] with transport={} requires url",
|
||||
match server.transport {
|
||||
McpTransport::Http => "http",
|
||||
McpTransport::Sse => "sse",
|
||||
McpTransport::Stdio => "stdio",
|
||||
}
|
||||
)
|
||||
})?;
|
||||
let parsed = reqwest::Url::parse(url)
|
||||
.with_context(|| format!("mcp.servers[{i}].url is not a valid URL"))?;
|
||||
if !matches!(parsed.scheme(), "http" | "https") {
|
||||
anyhow::bail!("mcp.servers[{i}].url must use http/https");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_proxy_url(field: &str, url: &str) -> Result<()> {
|
||||
let parsed = reqwest::Url::parse(url)
|
||||
.with_context(|| format!("Invalid {field} URL: '{url}' is not a valid URL"))?;
|
||||
@ -3885,6 +4002,7 @@ impl Default for Config {
|
||||
query_classification: QueryClassificationConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
tts: TtsConfig::default(),
|
||||
mcp: McpConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -4844,6 +4962,11 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
// MCP
|
||||
if self.mcp.enabled {
|
||||
validate_mcp_config(&self.mcp)?;
|
||||
}
|
||||
|
||||
// Proxy (delegate to existing validation)
|
||||
self.proxy.validate()?;
|
||||
|
||||
@ -5850,6 +5973,7 @@ default_temperature = 0.7
|
||||
hardware: HardwareConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
tts: TtsConfig::default(),
|
||||
mcp: McpConfig::default(),
|
||||
};
|
||||
|
||||
let toml_str = toml::to_string_pretty(&config).unwrap();
|
||||
@ -6046,6 +6170,7 @@ tool_dispatcher = "xml"
|
||||
hardware: HardwareConfig::default(),
|
||||
transcription: TranscriptionConfig::default(),
|
||||
tts: TtsConfig::default(),
|
||||
mcp: McpConfig::default(),
|
||||
};
|
||||
|
||||
config.save().await.unwrap();
|
||||
|
||||
@ -748,15 +748,25 @@ const PROMETHEUS_CONTENT_TYPE: &str = "text/plain; version=0.0.4; charset=utf-8"
|
||||
|
||||
/// GET /metrics — Prometheus text exposition format
|
||||
async fn handle_metrics(State(state): State<AppState>) -> impl IntoResponse {
|
||||
let body = if let Some(prom) = state
|
||||
.observer
|
||||
.as_ref()
|
||||
.as_any()
|
||||
.downcast_ref::<crate::observability::PrometheusObserver>()
|
||||
{
|
||||
prom.encode()
|
||||
} else {
|
||||
String::from("# Prometheus backend not enabled. Set [observability] backend = \"prometheus\" in config.\n")
|
||||
let body = {
|
||||
#[cfg(feature = "metrics")]
|
||||
{
|
||||
if let Some(prom) = state
|
||||
.observer
|
||||
.as_ref()
|
||||
.as_any()
|
||||
.downcast_ref::<crate::observability::PrometheusObserver>()
|
||||
{
|
||||
prom.encode()
|
||||
} else {
|
||||
String::from("# Prometheus backend not enabled. Set [observability] backend = \"prometheus\" in config.\n")
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "metrics"))]
|
||||
{
|
||||
let _ = &state;
|
||||
String::from("# Prometheus backend not enabled. Set [observability] backend = \"prometheus\" in config.\n")
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
@ -1739,6 +1749,7 @@ mod tests {
|
||||
assert!(text.contains("Prometheus backend not enabled"));
|
||||
}
|
||||
|
||||
#[cfg(feature = "metrics")]
|
||||
#[tokio::test]
|
||||
async fn metrics_endpoint_renders_prometheus_output() {
|
||||
let prom = Arc::new(crate::observability::PrometheusObserver::new());
|
||||
|
||||
@ -3,6 +3,7 @@ pub mod multi;
|
||||
pub mod noop;
|
||||
#[cfg(feature = "observability-otel")]
|
||||
pub mod otel;
|
||||
#[cfg(feature = "metrics")]
|
||||
pub mod prometheus;
|
||||
pub mod runtime_trace;
|
||||
pub mod traits;
|
||||
@ -15,6 +16,7 @@ pub use self::multi::MultiObserver;
|
||||
pub use noop::NoopObserver;
|
||||
#[cfg(feature = "observability-otel")]
|
||||
pub use otel::OtelObserver;
|
||||
#[cfg(feature = "metrics")]
|
||||
pub use prometheus::PrometheusObserver;
|
||||
pub use traits::{Observer, ObserverEvent};
|
||||
#[allow(unused_imports)]
|
||||
@ -26,7 +28,19 @@ use crate::config::ObservabilityConfig;
|
||||
pub fn create_observer(config: &ObservabilityConfig) -> Box<dyn Observer> {
|
||||
match config.backend.as_str() {
|
||||
"log" => Box::new(LogObserver::new()),
|
||||
"prometheus" => Box::new(PrometheusObserver::new()),
|
||||
"prometheus" => {
|
||||
#[cfg(feature = "metrics")]
|
||||
{
|
||||
Box::new(PrometheusObserver::new())
|
||||
}
|
||||
#[cfg(not(feature = "metrics"))]
|
||||
{
|
||||
tracing::warn!(
|
||||
"Prometheus backend requested but this build was compiled without `metrics`; falling back to noop."
|
||||
);
|
||||
Box::new(NoopObserver)
|
||||
}
|
||||
}
|
||||
"otel" | "opentelemetry" | "otlp" => {
|
||||
#[cfg(feature = "observability-otel")]
|
||||
match OtelObserver::new(
|
||||
@ -104,7 +118,12 @@ mod tests {
|
||||
backend: "prometheus".into(),
|
||||
..ObservabilityConfig::default()
|
||||
};
|
||||
assert_eq!(create_observer(&cfg).name(), "prometheus");
|
||||
let expected = if cfg!(feature = "metrics") {
|
||||
"prometheus"
|
||||
} else {
|
||||
"noop"
|
||||
};
|
||||
assert_eq!(create_observer(&cfg).name(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -173,6 +173,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
transcription: crate::config::TranscriptionConfig::default(),
|
||||
tts: crate::config::TtsConfig::default(),
|
||||
mcp: crate::config::McpConfig::default(),
|
||||
};
|
||||
|
||||
println!(
|
||||
@ -526,6 +527,7 @@ async fn run_quick_setup_with_home(
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
transcription: crate::config::TranscriptionConfig::default(),
|
||||
tts: crate::config::TtsConfig::default(),
|
||||
mcp: crate::config::McpConfig::default(),
|
||||
};
|
||||
|
||||
config.save().await?;
|
||||
|
||||
@ -27,7 +27,7 @@ struct ChatRequest {
|
||||
tools: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct Message {
|
||||
role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@ -40,14 +40,14 @@ struct Message {
|
||||
tool_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct OutgoingToolCall {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
function: OutgoingFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct OutgoingFunction {
|
||||
name: String,
|
||||
arguments: serde_json::Value,
|
||||
@ -258,13 +258,31 @@ impl OllamaProvider {
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
) -> ChatRequest {
|
||||
self.build_chat_request_with_think(
|
||||
messages,
|
||||
model,
|
||||
temperature,
|
||||
tools,
|
||||
self.reasoning_enabled,
|
||||
)
|
||||
}
|
||||
|
||||
/// Build a chat request with an explicit `think` value.
|
||||
fn build_chat_request_with_think(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
think: Option<bool>,
|
||||
) -> ChatRequest {
|
||||
ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options: Options { temperature },
|
||||
think: self.reasoning_enabled,
|
||||
think,
|
||||
tools: tools.map(|t| t.to_vec()),
|
||||
}
|
||||
}
|
||||
@ -396,17 +414,18 @@ impl OllamaProvider {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Send a request to Ollama and get the parsed response.
|
||||
/// Pass `tools` to enable native function-calling for models that support it.
|
||||
async fn send_request(
|
||||
/// Send a single HTTP request to Ollama and parse the response.
|
||||
async fn send_request_inner(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
messages: &[Message],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
should_auth: bool,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
think: Option<bool>,
|
||||
) -> anyhow::Result<ApiChatResponse> {
|
||||
let request = self.build_chat_request(messages, model, temperature, tools);
|
||||
let request =
|
||||
self.build_chat_request_with_think(messages.to_vec(), model, temperature, tools, think);
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
|
||||
@ -466,6 +485,59 @@ impl OllamaProvider {
|
||||
Ok(chat_response)
|
||||
}
|
||||
|
||||
/// Send a request to Ollama and get the parsed response.
|
||||
/// Pass `tools` to enable native function-calling for models that support it.
|
||||
///
|
||||
/// When `reasoning_enabled` (`think`) is set to `true`, the first request
|
||||
/// includes `think: true`. If that request fails (the model may not support
|
||||
/// the `think` parameter), we automatically retry once with `think` omitted
|
||||
/// so the call succeeds instead of entering an infinite retry loop.
|
||||
async fn send_request(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
should_auth: bool,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
) -> anyhow::Result<ApiChatResponse> {
|
||||
let result = self
|
||||
.send_request_inner(
|
||||
&messages,
|
||||
model,
|
||||
temperature,
|
||||
should_auth,
|
||||
tools,
|
||||
self.reasoning_enabled,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(resp) => Ok(resp),
|
||||
Err(first_err) if self.reasoning_enabled == Some(true) => {
|
||||
tracing::warn!(
|
||||
model = model,
|
||||
error = %first_err,
|
||||
"Ollama request failed with think=true; retrying without reasoning \
|
||||
(model may not support it)"
|
||||
);
|
||||
// Retry with think omitted from the request entirely.
|
||||
self.send_request_inner(&messages, model, temperature, should_auth, tools, None)
|
||||
.await
|
||||
.map_err(|retry_err| {
|
||||
// Both attempts failed — return the original error for clarity.
|
||||
tracing::error!(
|
||||
model = model,
|
||||
original_error = %first_err,
|
||||
retry_error = %retry_err,
|
||||
"Ollama request also failed without think; returning original error"
|
||||
);
|
||||
first_err
|
||||
})
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
|
||||
///
|
||||
/// Handles quirky model behavior where tool calls are wrapped:
|
||||
|
||||
357
src/tools/mcp_client.rs
Normal file
357
src/tools/mcp_client.rs
Normal file
@ -0,0 +1,357 @@
|
||||
//! MCP (Model Context Protocol) client — connects to external tool servers.
|
||||
//!
|
||||
//! Supports multiple transports: stdio (spawn local process), HTTP, and SSE.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
use crate::config::schema::McpServerConfig;
|
||||
use crate::tools::mcp_protocol::{
|
||||
JsonRpcRequest, McpToolDef, McpToolsListResult, MCP_PROTOCOL_VERSION,
|
||||
};
|
||||
use crate::tools::mcp_transport::{create_transport, McpTransportConn};
|
||||
|
||||
/// Timeout for receiving a response from an MCP server during init/list.
|
||||
/// Prevents a hung server from blocking the daemon indefinitely.
|
||||
const RECV_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
/// Default timeout for tool calls (seconds) when not configured per-server.
|
||||
const DEFAULT_TOOL_TIMEOUT_SECS: u64 = 180;
|
||||
|
||||
/// Maximum allowed tool call timeout (seconds) — hard safety ceiling.
|
||||
const MAX_TOOL_TIMEOUT_SECS: u64 = 600;
|
||||
|
||||
// ── Internal server state ──────────────────────────────────────────────────
|
||||
|
||||
struct McpServerInner {
|
||||
config: McpServerConfig,
|
||||
transport: Box<dyn McpTransportConn>,
|
||||
next_id: AtomicU64,
|
||||
tools: Vec<McpToolDef>,
|
||||
}
|
||||
|
||||
// ── McpServer ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// A live connection to one MCP server (any transport).
|
||||
#[derive(Clone)]
|
||||
pub struct McpServer {
|
||||
inner: Arc<Mutex<McpServerInner>>,
|
||||
}
|
||||
|
||||
impl McpServer {
|
||||
/// Connect to the server, perform the initialize handshake, and fetch the tool list.
|
||||
pub async fn connect(config: McpServerConfig) -> Result<Self> {
|
||||
// Create transport based on config
|
||||
let mut transport = create_transport(&config).with_context(|| {
|
||||
format!(
|
||||
"failed to create transport for MCP server `{}`",
|
||||
config.name
|
||||
)
|
||||
})?;
|
||||
|
||||
// Initialize handshake
|
||||
let id = 1u64;
|
||||
let init_req = JsonRpcRequest::new(
|
||||
id,
|
||||
"initialize",
|
||||
json!({
|
||||
"protocolVersion": MCP_PROTOCOL_VERSION,
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "zeroclaw",
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let init_resp = timeout(
|
||||
Duration::from_secs(RECV_TIMEOUT_SECS),
|
||||
transport.send_and_recv(&init_req),
|
||||
)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"MCP server `{}` timed out after {}s waiting for initialize response",
|
||||
config.name, RECV_TIMEOUT_SECS
|
||||
)
|
||||
})??;
|
||||
|
||||
if init_resp.error.is_some() {
|
||||
bail!(
|
||||
"MCP server `{}` rejected initialize: {:?}",
|
||||
config.name,
|
||||
init_resp.error
|
||||
);
|
||||
}
|
||||
|
||||
// Notify server that client is initialized (no response expected for notifications)
|
||||
let notif = JsonRpcRequest::notification("notifications/initialized", json!({}));
|
||||
// Best effort - ignore errors for notifications
|
||||
let _ = transport.send_and_recv(¬if).await;
|
||||
|
||||
// Fetch available tools
|
||||
let id = 2u64;
|
||||
let list_req = JsonRpcRequest::new(id, "tools/list", json!({}));
|
||||
|
||||
let list_resp = timeout(
|
||||
Duration::from_secs(RECV_TIMEOUT_SECS),
|
||||
transport.send_and_recv(&list_req),
|
||||
)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"MCP server `{}` timed out after {}s waiting for tools/list response",
|
||||
config.name, RECV_TIMEOUT_SECS
|
||||
)
|
||||
})??;
|
||||
|
||||
let result = list_resp
|
||||
.result
|
||||
.ok_or_else(|| anyhow!("tools/list returned no result from `{}`", config.name))?;
|
||||
let tool_list: McpToolsListResult = serde_json::from_value(result)
|
||||
.with_context(|| format!("failed to parse tools/list from `{}`", config.name))?;
|
||||
|
||||
let tool_count = tool_list.tools.len();
|
||||
|
||||
let inner = McpServerInner {
|
||||
config,
|
||||
transport,
|
||||
next_id: AtomicU64::new(3), // Start at 3 since we used 1 and 2
|
||||
tools: tool_list.tools,
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
"MCP server `{}` connected — {} tool(s) available",
|
||||
inner.config.name,
|
||||
tool_count
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
inner: Arc::new(Mutex::new(inner)),
|
||||
})
|
||||
}
|
||||
|
||||
/// Tools advertised by this server.
|
||||
pub async fn tools(&self) -> Vec<McpToolDef> {
|
||||
self.inner.lock().await.tools.clone()
|
||||
}
|
||||
|
||||
/// Server display name.
|
||||
#[allow(dead_code)]
|
||||
pub async fn name(&self) -> String {
|
||||
self.inner.lock().await.config.name.clone()
|
||||
}
|
||||
|
||||
/// Call a tool on this server. Returns the raw JSON result.
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: serde_json::Value,
|
||||
) -> Result<serde_json::Value> {
|
||||
let mut inner = self.inner.lock().await;
|
||||
let id = inner.next_id.fetch_add(1, Ordering::Relaxed);
|
||||
let req = JsonRpcRequest::new(
|
||||
id,
|
||||
"tools/call",
|
||||
json!({ "name": tool_name, "arguments": arguments }),
|
||||
);
|
||||
|
||||
// Use per-server tool timeout if configured, otherwise default.
|
||||
// Cap at MAX_TOOL_TIMEOUT_SECS for safety.
|
||||
let tool_timeout = inner
|
||||
.config
|
||||
.tool_timeout_secs
|
||||
.unwrap_or(DEFAULT_TOOL_TIMEOUT_SECS)
|
||||
.min(MAX_TOOL_TIMEOUT_SECS);
|
||||
|
||||
let resp = timeout(
|
||||
Duration::from_secs(tool_timeout),
|
||||
inner.transport.send_and_recv(&req),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
anyhow!(
|
||||
"MCP server `{}` timed out after {}s during tool call `{tool_name}`",
|
||||
inner.config.name,
|
||||
tool_timeout
|
||||
)
|
||||
})?
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"MCP server `{}` error during tool call `{tool_name}`",
|
||||
inner.config.name
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(err) = resp.error {
|
||||
bail!("MCP tool `{tool_name}` error {}: {}", err.code, err.message);
|
||||
}
|
||||
Ok(resp.result.unwrap_or(serde_json::Value::Null))
|
||||
}
|
||||
}
|
||||
|
||||
// ── McpRegistry ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Registry of all connected MCP servers, with a flat tool index.
|
||||
pub struct McpRegistry {
|
||||
servers: Vec<McpServer>,
|
||||
/// prefixed_name -> (server_index, original_tool_name)
|
||||
tool_index: HashMap<String, (usize, String)>,
|
||||
}
|
||||
|
||||
impl McpRegistry {
|
||||
/// Connect to all configured servers. Non-fatal: failures are logged and skipped.
|
||||
pub async fn connect_all(configs: &[McpServerConfig]) -> Result<Self> {
|
||||
let mut servers = Vec::new();
|
||||
let mut tool_index = HashMap::new();
|
||||
|
||||
for config in configs {
|
||||
match McpServer::connect(config.clone()).await {
|
||||
Ok(server) => {
|
||||
let server_idx = servers.len();
|
||||
// Collect tools while holding the lock once, then release
|
||||
let tools = server.tools().await;
|
||||
for tool in &tools {
|
||||
// Prefix prevents name collisions across servers
|
||||
let prefixed = format!("{}__{}", config.name, tool.name);
|
||||
tool_index.insert(prefixed, (server_idx, tool.name.clone()));
|
||||
}
|
||||
servers.push(server);
|
||||
}
|
||||
// Non-fatal — log and continue with remaining servers
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to connect to MCP server `{}`: {:#}", config.name, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
servers,
|
||||
tool_index,
|
||||
})
|
||||
}
|
||||
|
||||
/// All prefixed tool names across all connected servers.
|
||||
pub fn tool_names(&self) -> Vec<String> {
|
||||
self.tool_index.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Tool definition for a given prefixed name (cloned).
|
||||
pub async fn get_tool_def(&self, prefixed_name: &str) -> Option<McpToolDef> {
|
||||
let (server_idx, original_name) = self.tool_index.get(prefixed_name)?;
|
||||
let inner = self.servers[*server_idx].inner.lock().await;
|
||||
inner
|
||||
.tools
|
||||
.iter()
|
||||
.find(|t| &t.name == original_name)
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Execute a tool by prefixed name.
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
prefixed_name: &str,
|
||||
arguments: serde_json::Value,
|
||||
) -> Result<String> {
|
||||
let (server_idx, original_name) = self
|
||||
.tool_index
|
||||
.get(prefixed_name)
|
||||
.ok_or_else(|| anyhow!("unknown MCP tool `{prefixed_name}`"))?;
|
||||
let result = self.servers[*server_idx]
|
||||
.call_tool(original_name, arguments)
|
||||
.await?;
|
||||
serde_json::to_string_pretty(&result)
|
||||
.with_context(|| format!("failed to serialize result of MCP tool `{prefixed_name}`"))
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.servers.is_empty()
|
||||
}
|
||||
|
||||
pub fn server_count(&self) -> usize {
|
||||
self.servers.len()
|
||||
}
|
||||
|
||||
pub fn tool_count(&self) -> usize {
|
||||
self.tool_index.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::schema::McpTransport;
|
||||
|
||||
#[test]
|
||||
fn tool_name_prefix_format() {
|
||||
let prefixed = format!("{}__{}", "filesystem", "read_file");
|
||||
assert_eq!(prefixed, "filesystem__read_file");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_nonexistent_command_fails_cleanly() {
|
||||
// A command that doesn't exist should fail at spawn, not panic.
|
||||
let config = McpServerConfig {
|
||||
name: "nonexistent".to_string(),
|
||||
command: "/usr/bin/this_binary_does_not_exist_zeroclaw_test".to_string(),
|
||||
args: vec![],
|
||||
env: HashMap::default(),
|
||||
tool_timeout_secs: None,
|
||||
transport: McpTransport::Stdio,
|
||||
url: None,
|
||||
headers: HashMap::default(),
|
||||
};
|
||||
let result = McpServer::connect(config).await;
|
||||
assert!(result.is_err());
|
||||
let msg = result.err().unwrap().to_string();
|
||||
assert!(msg.contains("failed to create transport"), "got: {msg}");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_all_nonfatal_on_single_failure() {
|
||||
// If one server config is bad, connect_all should succeed (with 0 servers).
|
||||
let configs = vec![McpServerConfig {
|
||||
name: "bad".to_string(),
|
||||
command: "/usr/bin/does_not_exist_zc_test".to_string(),
|
||||
args: vec![],
|
||||
env: HashMap::default(),
|
||||
tool_timeout_secs: None,
|
||||
transport: McpTransport::Stdio,
|
||||
url: None,
|
||||
headers: HashMap::default(),
|
||||
}];
|
||||
let registry = McpRegistry::connect_all(&configs)
|
||||
.await
|
||||
.expect("connect_all should not fail");
|
||||
assert!(registry.is_empty());
|
||||
assert_eq!(registry.tool_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn http_transport_requires_url() {
|
||||
let config = McpServerConfig {
|
||||
name: "test".into(),
|
||||
transport: McpTransport::Http,
|
||||
..Default::default()
|
||||
};
|
||||
let result = create_transport(&config);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sse_transport_requires_url() {
|
||||
let config = McpServerConfig {
|
||||
name: "test".into(),
|
||||
transport: McpTransport::Sse,
|
||||
..Default::default()
|
||||
};
|
||||
let result = create_transport(&config);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
130
src/tools/mcp_protocol.rs
Normal file
130
src/tools/mcp_protocol.rs
Normal file
@ -0,0 +1,130 @@
|
||||
//! MCP (Model Context Protocol) JSON-RPC 2.0 protocol types.
|
||||
//! Protocol version: 2024-11-05
|
||||
//! Adapted from ops-mcp-server/src/protocol.rs for client use.
|
||||
//! Both Serialize and Deserialize are derived — the client both sends (Serialize)
|
||||
//! and receives (Deserialize) JSON-RPC messages.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const JSONRPC_VERSION: &str = "2.0";
|
||||
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
|
||||
|
||||
// Standard JSON-RPC 2.0 error codes
|
||||
#[allow(dead_code)]
|
||||
pub const PARSE_ERROR: i32 = -32700;
|
||||
#[allow(dead_code)]
|
||||
pub const INVALID_REQUEST: i32 = -32600;
|
||||
#[allow(dead_code)]
|
||||
pub const METHOD_NOT_FOUND: i32 = -32601;
|
||||
#[allow(dead_code)]
|
||||
pub const INVALID_PARAMS: i32 = -32602;
|
||||
pub const INTERNAL_ERROR: i32 = -32603;
|
||||
|
||||
/// Outbound JSON-RPC request (client -> MCP server).
|
||||
/// Used for both method calls (with id) and notifications (id = None).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcRequest {
|
||||
pub jsonrpc: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<serde_json::Value>,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl JsonRpcRequest {
|
||||
/// Create a method call request with a numeric id.
|
||||
pub fn new(id: u64, method: impl Into<String>, params: serde_json::Value) -> Self {
|
||||
Self {
|
||||
jsonrpc: JSONRPC_VERSION.to_string(),
|
||||
id: Some(serde_json::Value::Number(id.into())),
|
||||
method: method.into(),
|
||||
params: Some(params),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a notification — no id, no response expected from server.
|
||||
pub fn notification(method: impl Into<String>, params: serde_json::Value) -> Self {
|
||||
Self {
|
||||
jsonrpc: JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
method: method.into(),
|
||||
params: Some(params),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Inbound JSON-RPC response (MCP server -> client).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcResponse {
|
||||
pub jsonrpc: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<JsonRpcError>,
|
||||
}
|
||||
|
||||
/// JSON-RPC error object embedded in a response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcError {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// A tool advertised by an MCP server (from `tools/list` response).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpToolDef {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(rename = "inputSchema")]
|
||||
pub input_schema: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Expected shape of the `tools/list` result payload.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct McpToolsListResult {
|
||||
pub tools: Vec<McpToolDef>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn request_serializes_with_id() {
|
||||
let req = JsonRpcRequest::new(1, "tools/list", serde_json::json!({}));
|
||||
let s = serde_json::to_string(&req).unwrap();
|
||||
assert!(s.contains("\"id\":1"));
|
||||
assert!(s.contains("\"method\":\"tools/list\""));
|
||||
assert!(s.contains("\"jsonrpc\":\"2.0\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn notification_omits_id() {
|
||||
let notif =
|
||||
JsonRpcRequest::notification("notifications/initialized", serde_json::json!({}));
|
||||
let s = serde_json::to_string(¬if).unwrap();
|
||||
assert!(!s.contains("\"id\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes() {
|
||||
let json = r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#;
|
||||
let resp: JsonRpcResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.result.is_some());
|
||||
assert!(resp.error.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_def_deserializes_input_schema() {
|
||||
let json = r#"{"name":"read_file","description":"Read a file","inputSchema":{"type":"object","properties":{"path":{"type":"string"}}}}"#;
|
||||
let def: McpToolDef = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(def.name, "read_file");
|
||||
assert!(def.input_schema.is_object());
|
||||
}
|
||||
}
|
||||
68
src/tools/mcp_tool.rs
Normal file
68
src/tools/mcp_tool.rs
Normal file
@ -0,0 +1,68 @@
|
||||
//! Wraps a discovered MCP tool as a zeroclaw [`Tool`] so it is dispatched
|
||||
//! through the existing tool registry and agent loop without modification.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::tools::mcp_client::McpRegistry;
|
||||
use crate::tools::mcp_protocol::McpToolDef;
|
||||
use crate::tools::traits::{Tool, ToolResult};
|
||||
|
||||
/// A zeroclaw [`Tool`] backed by an MCP server tool.
|
||||
///
|
||||
/// The `prefixed_name` (e.g. `filesystem__read_file`) is what the agent loop
|
||||
/// sees. The registry knows how to route it to the correct server.
|
||||
pub struct McpToolWrapper {
|
||||
/// Prefixed name: `<server_name>__<tool_name>`.
|
||||
prefixed_name: String,
|
||||
/// Description extracted from the MCP tool definition. Stored as an owned
|
||||
/// String so that `description()` can return `&str` with self's lifetime.
|
||||
description: String,
|
||||
/// JSON schema for the tool's input parameters.
|
||||
input_schema: serde_json::Value,
|
||||
/// Shared registry — used to dispatch actual tool calls.
|
||||
registry: Arc<McpRegistry>,
|
||||
}
|
||||
|
||||
impl McpToolWrapper {
|
||||
pub fn new(prefixed_name: String, def: McpToolDef, registry: Arc<McpRegistry>) -> Self {
|
||||
let description = def.description.unwrap_or_else(|| "MCP tool".to_string());
|
||||
Self {
|
||||
prefixed_name,
|
||||
description,
|
||||
input_schema: def.input_schema,
|
||||
registry,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for McpToolWrapper {
|
||||
fn name(&self) -> &str {
|
||||
&self.prefixed_name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
&self.description
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
self.input_schema.clone()
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
match self.registry.call_tool(&self.prefixed_name, args).await {
|
||||
Ok(output) => Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e.to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
868
src/tools/mcp_transport.rs
Normal file
868
src/tools/mcp_transport.rs
Normal file
@ -0,0 +1,868 @@
|
||||
//! MCP transport abstraction — supports stdio, SSE, and HTTP transports.
|
||||
|
||||
use std::borrow::Cow;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::{Child, Command};
|
||||
use tokio::sync::{oneshot, Mutex, Notify};
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
use crate::config::schema::{McpServerConfig, McpTransport};
|
||||
use crate::tools::mcp_protocol::{JsonRpcError, JsonRpcRequest, JsonRpcResponse, INTERNAL_ERROR};
|
||||
|
||||
/// Maximum bytes for a single JSON-RPC response.
|
||||
const MAX_LINE_BYTES: usize = 4 * 1024 * 1024; // 4 MB
|
||||
|
||||
/// Timeout for init/list operations.
|
||||
const RECV_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
// ── Transport Trait ──────────────────────────────────────────────────────
|
||||
|
||||
/// Abstract transport for MCP communication.
|
||||
#[async_trait::async_trait]
|
||||
pub trait McpTransportConn: Send + Sync {
|
||||
/// Send a JSON-RPC request and receive the response.
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse>;
|
||||
|
||||
/// Close the connection.
|
||||
async fn close(&mut self) -> Result<()>;
|
||||
}
|
||||
|
||||
// ── Stdio Transport ──────────────────────────────────────────────────────
|
||||
|
||||
/// Stdio-based transport (spawn local process).
|
||||
pub struct StdioTransport {
|
||||
_child: Child,
|
||||
stdin: tokio::process::ChildStdin,
|
||||
stdout_lines: tokio::io::Lines<BufReader<tokio::process::ChildStdout>>,
|
||||
}
|
||||
|
||||
impl StdioTransport {
|
||||
pub fn new(config: &McpServerConfig) -> Result<Self> {
|
||||
let mut child = Command::new(&config.command)
|
||||
.args(&config.args)
|
||||
.envs(&config.env)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.spawn()
|
||||
.with_context(|| format!("failed to spawn MCP server `{}`", config.name))?;
|
||||
|
||||
let stdin = child
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| anyhow!("no stdin on MCP server `{}`", config.name))?;
|
||||
let stdout = child
|
||||
.stdout
|
||||
.take()
|
||||
.ok_or_else(|| anyhow!("no stdout on MCP server `{}`", config.name))?;
|
||||
let stdout_lines = BufReader::new(stdout).lines();
|
||||
|
||||
Ok(Self {
|
||||
_child: child,
|
||||
stdin,
|
||||
stdout_lines,
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_raw(&mut self, line: &str) -> Result<()> {
|
||||
self.stdin
|
||||
.write_all(line.as_bytes())
|
||||
.await
|
||||
.context("failed to write to MCP server stdin")?;
|
||||
self.stdin
|
||||
.write_all(b"\n")
|
||||
.await
|
||||
.context("failed to write newline to MCP server stdin")?;
|
||||
self.stdin.flush().await.context("failed to flush stdin")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv_raw(&mut self) -> Result<String> {
|
||||
let line = self
|
||||
.stdout_lines
|
||||
.next_line()
|
||||
.await?
|
||||
.ok_or_else(|| anyhow!("MCP server closed stdout"))?;
|
||||
if line.len() > MAX_LINE_BYTES {
|
||||
bail!("MCP response too large: {} bytes", line.len());
|
||||
}
|
||||
Ok(line)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransportConn for StdioTransport {
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
||||
let line = serde_json::to_string(request)?;
|
||||
self.send_raw(&line).await?;
|
||||
if request.id.is_none() {
|
||||
return Ok(JsonRpcResponse {
|
||||
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
result: None,
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
let resp_line = timeout(Duration::from_secs(RECV_TIMEOUT_SECS), self.recv_raw())
|
||||
.await
|
||||
.context("timeout waiting for MCP response")??;
|
||||
let resp: JsonRpcResponse = serde_json::from_str(&resp_line)
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?;
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
let _ = self.stdin.shutdown().await;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ── HTTP Transport ───────────────────────────────────────────────────────
|
||||
|
||||
/// HTTP-based transport (POST requests).
|
||||
pub struct HttpTransport {
|
||||
url: String,
|
||||
client: reqwest::Client,
|
||||
headers: std::collections::HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl HttpTransport {
|
||||
pub fn new(config: &McpServerConfig) -> Result<Self> {
|
||||
let url = config
|
||||
.url
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("URL required for HTTP transport"))?
|
||||
.clone();
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.build()
|
||||
.context("failed to build HTTP client")?;
|
||||
|
||||
Ok(Self {
|
||||
url,
|
||||
client,
|
||||
headers: config.headers.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransportConn for HttpTransport {
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
||||
let body = serde_json::to_string(request)?;
|
||||
|
||||
let mut req = self.client.post(&self.url).body(body);
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.context("HTTP request to MCP server failed")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
bail!("MCP server returned HTTP {}", resp.status());
|
||||
}
|
||||
|
||||
if request.id.is_none() {
|
||||
return Ok(JsonRpcResponse {
|
||||
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
result: None,
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
let resp_text = resp.text().await.context("failed to read HTTP response")?;
|
||||
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
|
||||
|
||||
Ok(mcp_resp)
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ── SSE Transport ─────────────────────────────────────────────────────────
|
||||
|
||||
/// SSE-based transport (HTTP POST for requests, SSE for responses).
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
enum SseStreamState {
|
||||
Unknown,
|
||||
Connected,
|
||||
Unsupported,
|
||||
}
|
||||
|
||||
pub struct SseTransport {
|
||||
sse_url: String,
|
||||
server_name: String,
|
||||
client: reqwest::Client,
|
||||
headers: std::collections::HashMap<String, String>,
|
||||
stream_state: SseStreamState,
|
||||
shared: std::sync::Arc<Mutex<SseSharedState>>,
|
||||
notify: std::sync::Arc<Notify>,
|
||||
shutdown_tx: Option<oneshot::Sender<()>>,
|
||||
reader_task: Option<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl SseTransport {
|
||||
pub fn new(config: &McpServerConfig) -> Result<Self> {
|
||||
let sse_url = config
|
||||
.url
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("URL required for SSE transport"))?
|
||||
.clone();
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.build()
|
||||
.context("failed to build HTTP client")?;
|
||||
|
||||
Ok(Self {
|
||||
sse_url,
|
||||
server_name: config.name.clone(),
|
||||
client,
|
||||
headers: config.headers.clone(),
|
||||
stream_state: SseStreamState::Unknown,
|
||||
shared: std::sync::Arc::new(Mutex::new(SseSharedState::default())),
|
||||
notify: std::sync::Arc::new(Notify::new()),
|
||||
shutdown_tx: None,
|
||||
reader_task: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn ensure_connected(&mut self) -> Result<()> {
|
||||
if self.stream_state == SseStreamState::Unsupported {
|
||||
return Ok(());
|
||||
}
|
||||
if let Some(task) = &self.reader_task {
|
||||
if !task.is_finished() {
|
||||
self.stream_state = SseStreamState::Connected;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let mut req = self
|
||||
.client
|
||||
.get(&self.sse_url)
|
||||
.header("Accept", "text/event-stream")
|
||||
.header("Cache-Control", "no-cache");
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
|
||||
let resp = req.send().await.context("SSE GET to MCP server failed")?;
|
||||
if resp.status() == reqwest::StatusCode::NOT_FOUND
|
||||
|| resp.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED
|
||||
{
|
||||
self.stream_state = SseStreamState::Unsupported;
|
||||
return Ok(());
|
||||
}
|
||||
if !resp.status().is_success() {
|
||||
return Err(anyhow!("MCP server returned HTTP {}", resp.status()));
|
||||
}
|
||||
let is_event_stream = resp
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
|
||||
if !is_event_stream {
|
||||
self.stream_state = SseStreamState::Unsupported;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
|
||||
self.shutdown_tx = Some(shutdown_tx);
|
||||
|
||||
let shared = self.shared.clone();
|
||||
let notify = self.notify.clone();
|
||||
let sse_url = self.sse_url.clone();
|
||||
let server_name = self.server_name.clone();
|
||||
|
||||
self.reader_task = Some(tokio::spawn(async move {
|
||||
let stream = resp
|
||||
.bytes_stream()
|
||||
.map(|item| item.map_err(std::io::Error::other));
|
||||
let reader = tokio_util::io::StreamReader::new(stream);
|
||||
let mut lines = BufReader::new(reader).lines();
|
||||
|
||||
let mut cur_event: Option<String> = None;
|
||||
let mut cur_id: Option<String> = None;
|
||||
let mut cur_data: Vec<String> = Vec::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = &mut shutdown_rx => {
|
||||
break;
|
||||
}
|
||||
line = lines.next_line() => {
|
||||
let Ok(line_opt) = line else { break; };
|
||||
let Some(mut line) = line_opt else { break; };
|
||||
if line.ends_with('\r') {
|
||||
line.pop();
|
||||
}
|
||||
if line.is_empty() {
|
||||
if cur_event.is_none() && cur_id.is_none() && cur_data.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let event = cur_event.take();
|
||||
let data = cur_data.join("\n");
|
||||
cur_data.clear();
|
||||
let id = cur_id.take();
|
||||
handle_sse_event(
|
||||
&server_name,
|
||||
&sse_url,
|
||||
&shared,
|
||||
¬ify,
|
||||
event.as_deref(),
|
||||
id.as_deref(),
|
||||
data,
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
|
||||
if line.starts_with(':') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(rest) = line.strip_prefix("event:") {
|
||||
cur_event = Some(rest.trim().to_string());
|
||||
continue;
|
||||
}
|
||||
if let Some(rest) = line.strip_prefix("data:") {
|
||||
let rest = rest.strip_prefix(' ').unwrap_or(rest);
|
||||
cur_data.push(rest.to_string());
|
||||
continue;
|
||||
}
|
||||
if let Some(rest) = line.strip_prefix("id:") {
|
||||
cur_id = Some(rest.trim().to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let pending = {
|
||||
let mut guard = shared.lock().await;
|
||||
std::mem::take(&mut guard.pending)
|
||||
};
|
||||
for (_, tx) in pending {
|
||||
let _ = tx.send(JsonRpcResponse {
|
||||
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
result: None,
|
||||
error: Some(JsonRpcError {
|
||||
code: INTERNAL_ERROR,
|
||||
message: "SSE connection closed".to_string(),
|
||||
data: None,
|
||||
}),
|
||||
});
|
||||
}
|
||||
}));
|
||||
self.stream_state = SseStreamState::Connected;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_message_url(&self) -> Result<(String, bool)> {
|
||||
let guard = self.shared.lock().await;
|
||||
if let Some(url) = &guard.message_url {
|
||||
return Ok((url.clone(), guard.message_url_from_endpoint));
|
||||
}
|
||||
drop(guard);
|
||||
|
||||
let derived = derive_message_url(&self.sse_url, "messages")
|
||||
.or_else(|| derive_message_url(&self.sse_url, "message"))
|
||||
.ok_or_else(|| anyhow!("invalid SSE URL"))?;
|
||||
let mut guard = self.shared.lock().await;
|
||||
if guard.message_url.is_none() {
|
||||
guard.message_url = Some(derived.clone());
|
||||
guard.message_url_from_endpoint = false;
|
||||
}
|
||||
Ok((derived, false))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct SseSharedState {
|
||||
message_url: Option<String>,
|
||||
message_url_from_endpoint: bool,
|
||||
pending: std::collections::HashMap<u64, oneshot::Sender<JsonRpcResponse>>,
|
||||
}
|
||||
|
||||
fn derive_message_url(sse_url: &str, message_path: &str) -> Option<String> {
|
||||
let url = reqwest::Url::parse(sse_url).ok()?;
|
||||
let mut segments: Vec<&str> = url.path_segments()?.collect();
|
||||
if segments.is_empty() {
|
||||
return None;
|
||||
}
|
||||
if segments.last().copied() == Some("sse") {
|
||||
segments.pop();
|
||||
segments.push(message_path);
|
||||
let mut new_url = url.clone();
|
||||
new_url.set_path(&format!("/{}", segments.join("/")));
|
||||
return Some(new_url.to_string());
|
||||
}
|
||||
let mut new_url = url.clone();
|
||||
let mut path = url.path().trim_end_matches('/').to_string();
|
||||
path.push('/');
|
||||
path.push_str(message_path);
|
||||
new_url.set_path(&path);
|
||||
Some(new_url.to_string())
|
||||
}
|
||||
|
||||
async fn handle_sse_event(
|
||||
server_name: &str,
|
||||
sse_url: &str,
|
||||
shared: &std::sync::Arc<Mutex<SseSharedState>>,
|
||||
notify: &std::sync::Arc<Notify>,
|
||||
event: Option<&str>,
|
||||
_id: Option<&str>,
|
||||
data: String,
|
||||
) {
|
||||
let event = event.unwrap_or("message");
|
||||
let trimmed = data.trim();
|
||||
if trimmed.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint") {
|
||||
if let Some(url) = parse_endpoint_from_data(sse_url, trimmed) {
|
||||
let mut guard = shared.lock().await;
|
||||
guard.message_url = Some(url);
|
||||
guard.message_url_from_endpoint = true;
|
||||
drop(guard);
|
||||
notify.notify_waiters();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if !event.eq_ignore_ascii_case("message") {
|
||||
return;
|
||||
}
|
||||
|
||||
let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Ok(resp) = serde_json::from_value::<JsonRpcResponse>(value.clone()) else {
|
||||
let _ = serde_json::from_value::<JsonRpcRequest>(value);
|
||||
return;
|
||||
};
|
||||
|
||||
let Some(id_val) = resp.id.clone() else {
|
||||
return;
|
||||
};
|
||||
let id = match id_val.as_u64() {
|
||||
Some(v) => v,
|
||||
None => return,
|
||||
};
|
||||
|
||||
let tx = {
|
||||
let mut guard = shared.lock().await;
|
||||
guard.pending.remove(&id)
|
||||
};
|
||||
if let Some(tx) = tx {
|
||||
let _ = tx.send(resp);
|
||||
} else {
|
||||
tracing::debug!(
|
||||
"MCP SSE `{}` received response for unknown id {}",
|
||||
server_name,
|
||||
id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_endpoint_from_data(sse_url: &str, data: &str) -> Option<String> {
|
||||
if data.starts_with('{') {
|
||||
let v: serde_json::Value = serde_json::from_str(data).ok()?;
|
||||
let endpoint = v.get("endpoint")?.as_str()?;
|
||||
return parse_endpoint_from_data(sse_url, endpoint);
|
||||
}
|
||||
if data.starts_with("http://") || data.starts_with("https://") {
|
||||
return Some(data.to_string());
|
||||
}
|
||||
let base = reqwest::Url::parse(sse_url).ok()?;
|
||||
base.join(data).ok().map(|u| u.to_string())
|
||||
}
|
||||
|
||||
fn extract_json_from_sse_text(resp_text: &str) -> Cow<'_, str> {
|
||||
let text = resp_text.trim_start_matches('\u{feff}');
|
||||
let mut current_data_lines: Vec<&str> = Vec::new();
|
||||
let mut last_event_data_lines: Vec<&str> = Vec::new();
|
||||
|
||||
for raw_line in text.lines() {
|
||||
let line = raw_line.trim_end_matches('\r').trim_start();
|
||||
if line.is_empty() {
|
||||
if !current_data_lines.is_empty() {
|
||||
last_event_data_lines = std::mem::take(&mut current_data_lines);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if line.starts_with(':') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(rest) = line.strip_prefix("data:") {
|
||||
let rest = rest.strip_prefix(' ').unwrap_or(rest);
|
||||
current_data_lines.push(rest);
|
||||
}
|
||||
}
|
||||
|
||||
if !current_data_lines.is_empty() {
|
||||
last_event_data_lines = current_data_lines;
|
||||
}
|
||||
|
||||
if last_event_data_lines.is_empty() {
|
||||
return Cow::Borrowed(text.trim());
|
||||
}
|
||||
|
||||
if last_event_data_lines.len() == 1 {
|
||||
return Cow::Borrowed(last_event_data_lines[0].trim());
|
||||
}
|
||||
|
||||
let joined = last_event_data_lines.join("\n");
|
||||
Cow::Owned(joined.trim().to_string())
|
||||
}
|
||||
|
||||
async fn read_first_jsonrpc_from_sse_response(
|
||||
resp: reqwest::Response,
|
||||
) -> Result<Option<JsonRpcResponse>> {
|
||||
let stream = resp
|
||||
.bytes_stream()
|
||||
.map(|item| item.map_err(std::io::Error::other));
|
||||
let reader = tokio_util::io::StreamReader::new(stream);
|
||||
let mut lines = BufReader::new(reader).lines();
|
||||
|
||||
let mut cur_event: Option<String> = None;
|
||||
let mut cur_data: Vec<String> = Vec::new();
|
||||
|
||||
while let Ok(line_opt) = lines.next_line().await {
|
||||
let Some(mut line) = line_opt else { break };
|
||||
if line.ends_with('\r') {
|
||||
line.pop();
|
||||
}
|
||||
if line.is_empty() {
|
||||
if cur_event.is_none() && cur_data.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let event = cur_event.take();
|
||||
let data = cur_data.join("\n");
|
||||
cur_data.clear();
|
||||
|
||||
let event = event.unwrap_or_else(|| "message".to_string());
|
||||
if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if !event.eq_ignore_ascii_case("message") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let trimmed = data.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let json_str = extract_json_from_sse_text(trimmed);
|
||||
if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
|
||||
return Ok(Some(resp));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if line.starts_with(':') {
|
||||
continue;
|
||||
}
|
||||
if let Some(rest) = line.strip_prefix("event:") {
|
||||
cur_event = Some(rest.trim().to_string());
|
||||
continue;
|
||||
}
|
||||
if let Some(rest) = line.strip_prefix("data:") {
|
||||
let rest = rest.strip_prefix(' ').unwrap_or(rest);
|
||||
cur_data.push(rest.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl McpTransportConn for SseTransport {
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
||||
self.ensure_connected().await?;
|
||||
|
||||
let id = request.id.as_ref().and_then(|v| v.as_u64());
|
||||
let body = serde_json::to_string(request)?;
|
||||
|
||||
let (mut message_url, mut from_endpoint) = self.get_message_url().await?;
|
||||
if self.stream_state == SseStreamState::Connected && !from_endpoint {
|
||||
for _ in 0..3 {
|
||||
{
|
||||
let guard = self.shared.lock().await;
|
||||
if guard.message_url_from_endpoint {
|
||||
if let Some(url) = &guard.message_url {
|
||||
message_url = url.clone();
|
||||
from_endpoint = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = timeout(Duration::from_millis(300), self.notify.notified()).await;
|
||||
}
|
||||
}
|
||||
let primary_url = if from_endpoint {
|
||||
message_url.clone()
|
||||
} else {
|
||||
self.sse_url.clone()
|
||||
};
|
||||
let secondary_url = if message_url == self.sse_url {
|
||||
None
|
||||
} else if primary_url == message_url {
|
||||
Some(self.sse_url.clone())
|
||||
} else {
|
||||
Some(message_url.clone())
|
||||
};
|
||||
let has_secondary = secondary_url.is_some();
|
||||
|
||||
let mut rx = None;
|
||||
if let Some(id) = id {
|
||||
if self.stream_state == SseStreamState::Connected {
|
||||
let (tx, ch) = oneshot::channel();
|
||||
{
|
||||
let mut guard = self.shared.lock().await;
|
||||
guard.pending.insert(id, tx);
|
||||
}
|
||||
rx = Some((id, ch));
|
||||
}
|
||||
}
|
||||
|
||||
let mut got_direct = None;
|
||||
let mut last_status = None;
|
||||
|
||||
for (i, url) in std::iter::once(primary_url)
|
||||
.chain(secondary_url.into_iter())
|
||||
.enumerate()
|
||||
{
|
||||
let mut req = self
|
||||
.client
|
||||
.post(&url)
|
||||
.timeout(Duration::from_secs(120))
|
||||
.body(body.clone())
|
||||
.header("Content-Type", "application/json");
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
if !self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Accept"))
|
||||
{
|
||||
req = req.header("Accept", "application/json, text/event-stream");
|
||||
}
|
||||
|
||||
let resp = req.send().await.context("SSE POST to MCP server failed")?;
|
||||
let status = resp.status();
|
||||
last_status = Some(status);
|
||||
|
||||
if (status == reqwest::StatusCode::NOT_FOUND
|
||||
|| status == reqwest::StatusCode::METHOD_NOT_ALLOWED)
|
||||
&& i == 0
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if !status.is_success() {
|
||||
break;
|
||||
}
|
||||
|
||||
if request.id.is_none() {
|
||||
got_direct = Some(JsonRpcResponse {
|
||||
jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
|
||||
id: None,
|
||||
result: None,
|
||||
error: None,
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
let is_sse = resp
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
|
||||
|
||||
if is_sse {
|
||||
if i == 0 && has_secondary {
|
||||
match timeout(
|
||||
Duration::from_secs(3),
|
||||
read_first_jsonrpc_from_sse_response(resp),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
if let Some(resp) = res? {
|
||||
got_direct = Some(resp);
|
||||
}
|
||||
break;
|
||||
}
|
||||
Err(_) => continue,
|
||||
}
|
||||
}
|
||||
if let Some(resp) = read_first_jsonrpc_from_sse_response(resp).await? {
|
||||
got_direct = Some(resp);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
let text = if i == 0 && has_secondary {
|
||||
match timeout(Duration::from_secs(3), resp.text()).await {
|
||||
Ok(Ok(t)) => t,
|
||||
Ok(Err(_)) => String::new(),
|
||||
Err(_) => continue,
|
||||
}
|
||||
} else {
|
||||
resp.text().await.unwrap_or_default()
|
||||
};
|
||||
let trimmed = text.trim();
|
||||
if !trimmed.is_empty() {
|
||||
let json_str = if trimmed.contains("\ndata:") || trimmed.starts_with("data:") {
|
||||
extract_json_from_sse_text(trimmed)
|
||||
} else {
|
||||
Cow::Borrowed(trimmed)
|
||||
};
|
||||
if let Ok(mcp_resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
|
||||
got_direct = Some(mcp_resp);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some((id, _)) = rx.as_ref() {
|
||||
if got_direct.is_some() {
|
||||
let mut guard = self.shared.lock().await;
|
||||
guard.pending.remove(id);
|
||||
} else if let Some(status) = last_status {
|
||||
if !status.is_success() {
|
||||
let mut guard = self.shared.lock().await;
|
||||
guard.pending.remove(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(resp) = got_direct {
|
||||
return Ok(resp);
|
||||
}
|
||||
|
||||
if let Some(status) = last_status {
|
||||
if !status.is_success() {
|
||||
bail!("MCP server returned HTTP {}", status);
|
||||
}
|
||||
} else {
|
||||
bail!("MCP request not sent");
|
||||
}
|
||||
|
||||
let Some((_id, rx)) = rx else {
|
||||
bail!("MCP server returned no response");
|
||||
};
|
||||
|
||||
rx.await.map_err(|_| anyhow!("SSE response channel closed"))
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
if let Some(tx) = self.shutdown_tx.take() {
|
||||
let _ = tx.send(());
|
||||
}
|
||||
if let Some(task) = self.reader_task.take() {
|
||||
task.abort();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ── Factory ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Create a transport based on config.
|
||||
pub fn create_transport(config: &McpServerConfig) -> Result<Box<dyn McpTransportConn>> {
|
||||
match config.transport {
|
||||
McpTransport::Stdio => Ok(Box::new(StdioTransport::new(config)?)),
|
||||
McpTransport::Http => Ok(Box::new(HttpTransport::new(config)?)),
|
||||
McpTransport::Sse => Ok(Box::new(SseTransport::new(config)?)),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_transport_default_is_stdio() {
|
||||
let config = McpServerConfig::default();
|
||||
assert_eq!(config.transport, McpTransport::Stdio);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_http_transport_requires_url() {
|
||||
let config = McpServerConfig {
|
||||
name: "test".into(),
|
||||
transport: McpTransport::Http,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(HttpTransport::new(&config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_transport_requires_url() {
|
||||
let config = McpServerConfig {
|
||||
name: "test".into(),
|
||||
transport: McpTransport::Sse,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(SseTransport::new(&config).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_data_no_space() {
|
||||
let input = "data:{\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_with_event_and_id() {
|
||||
let input = "id: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_multiline_data() {
|
||||
let input = "event: message\ndata: {\ndata: \"jsonrpc\": \"2.0\",\ndata: \"result\": {}\ndata: }\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_skips_bom_and_leading_whitespace() {
|
||||
let input = "\u{feff}\n\n data: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_json_from_sse_uses_last_event_with_data() {
|
||||
let input =
|
||||
": keep-alive\n\nid: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
}
|
||||
@ -40,6 +40,10 @@ pub mod hardware_memory_map;
|
||||
pub mod hardware_memory_read;
|
||||
pub mod http_request;
|
||||
pub mod image_info;
|
||||
pub mod mcp_client;
|
||||
pub mod mcp_protocol;
|
||||
pub mod mcp_tool;
|
||||
pub mod mcp_transport;
|
||||
pub mod memory_forget;
|
||||
pub mod memory_recall;
|
||||
pub mod memory_store;
|
||||
|
||||
27
web/src/lib/uuid.ts
Normal file
27
web/src/lib/uuid.ts
Normal file
@ -0,0 +1,27 @@
|
||||
/**
|
||||
* Generate a UUID v4 string.
|
||||
*
|
||||
* Uses `crypto.randomUUID()` when available (modern browsers, secure contexts)
|
||||
* and falls back to a manual implementation backed by `crypto.getRandomValues()`
|
||||
* for older browsers (e.g. Safari < 15.4, some Electron/Raspberry-Pi builds).
|
||||
*
|
||||
* Closes #3303, #3261.
|
||||
*/
|
||||
export function generateUUID(): string {
|
||||
if (typeof crypto !== 'undefined' && typeof crypto.randomUUID === 'function') {
|
||||
return crypto.randomUUID();
|
||||
}
|
||||
|
||||
// Fallback: RFC 4122 version 4 UUID via getRandomValues
|
||||
// crypto must exist if we reached here (only randomUUID is missing)
|
||||
const c = globalThis.crypto;
|
||||
const bytes = new Uint8Array(16);
|
||||
c.getRandomValues(bytes);
|
||||
|
||||
// Set version (4) and variant (10xx) bits per RFC 4122
|
||||
bytes[6] = (bytes[6]! & 0x0f) | 0x40;
|
||||
bytes[8] = (bytes[8]! & 0x3f) | 0x80;
|
||||
|
||||
const hex = Array.from(bytes, (b) => b.toString(16).padStart(2, '0')).join('');
|
||||
return `${hex.slice(0, 8)}-${hex.slice(8, 12)}-${hex.slice(12, 16)}-${hex.slice(16, 20)}-${hex.slice(20)}`;
|
||||
}
|
||||
@ -1,5 +1,6 @@
|
||||
import type { WsMessage } from '../types/api';
|
||||
import { getToken } from './auth';
|
||||
import { generateUUID } from './uuid';
|
||||
|
||||
export type WsMessageHandler = (msg: WsMessage) => void;
|
||||
export type WsOpenHandler = () => void;
|
||||
@ -26,7 +27,7 @@ const SESSION_STORAGE_KEY = 'zeroclaw_session_id';
|
||||
function getOrCreateSessionId(): string {
|
||||
let id = sessionStorage.getItem(SESSION_STORAGE_KEY);
|
||||
if (!id) {
|
||||
id = crypto.randomUUID();
|
||||
id = generateUUID();
|
||||
sessionStorage.setItem(SESSION_STORAGE_KEY, id);
|
||||
}
|
||||
return id;
|
||||
|
||||
@ -2,6 +2,7 @@ import { useState, useEffect, useRef, useCallback } from 'react';
|
||||
import { Send, Bot, User, AlertCircle, Copy, Check } from 'lucide-react';
|
||||
import type { WsMessage } from '@/types/api';
|
||||
import { WebSocketClient } from '@/lib/ws';
|
||||
import { generateUUID } from '@/lib/uuid';
|
||||
|
||||
interface ChatMessage {
|
||||
id: string;
|
||||
@ -53,7 +54,7 @@ export default function AgentChat() {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: crypto.randomUUID(),
|
||||
id: generateUUID(),
|
||||
role: 'agent',
|
||||
content,
|
||||
timestamp: new Date(),
|
||||
@ -69,7 +70,7 @@ export default function AgentChat() {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: crypto.randomUUID(),
|
||||
id: generateUUID(),
|
||||
role: 'agent',
|
||||
content: `[Tool Call] ${msg.name ?? 'unknown'}(${JSON.stringify(msg.args ?? {})})`,
|
||||
timestamp: new Date(),
|
||||
@ -81,7 +82,7 @@ export default function AgentChat() {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: crypto.randomUUID(),
|
||||
id: generateUUID(),
|
||||
role: 'agent',
|
||||
content: `[Tool Result] ${msg.output ?? ''}`,
|
||||
timestamp: new Date(),
|
||||
@ -93,7 +94,7 @@ export default function AgentChat() {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: crypto.randomUUID(),
|
||||
id: generateUUID(),
|
||||
role: 'agent',
|
||||
content: `[Error] ${msg.message ?? 'Unknown error'}`,
|
||||
timestamp: new Date(),
|
||||
@ -124,7 +125,7 @@ export default function AgentChat() {
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
id: crypto.randomUUID(),
|
||||
id: generateUUID(),
|
||||
role: 'user',
|
||||
content: trimmed,
|
||||
timestamp: new Date(),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user