merge: resolve conflicts with master (image_gen + sessions)
This commit is contained in:
@@ -373,6 +373,7 @@ impl Agent {
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
config,
|
||||
None,
|
||||
);
|
||||
|
||||
// ── Wire MCP tools (non-fatal) ─────────────────────────────
|
||||
|
||||
@@ -3539,6 +3539,7 @@ pub async fn run(
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
None,
|
||||
);
|
||||
|
||||
let peripheral_tools: Vec<Box<dyn Tool>> =
|
||||
@@ -3833,6 +3834,8 @@ pub async fn run(
|
||||
Some(&config.autonomy),
|
||||
native_tools,
|
||||
config.skills.prompt_injection_mode,
|
||||
config.agent.compact_context,
|
||||
config.agent.max_system_prompt_chars,
|
||||
);
|
||||
|
||||
// Append structured tool-use instructions with schemas (only for non-native providers)
|
||||
@@ -4297,6 +4300,7 @@ pub async fn process_message(
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
None,
|
||||
);
|
||||
let peripheral_tools: Vec<Box<dyn Tool>> =
|
||||
crate::peripherals::create_peripheral_tools(&config.peripherals).await?;
|
||||
@@ -4493,6 +4497,8 @@ pub async fn process_message(
|
||||
Some(&config.autonomy),
|
||||
native_tools,
|
||||
config.skills.prompt_injection_mode,
|
||||
config.agent.compact_context,
|
||||
config.agent.max_system_prompt_chars,
|
||||
);
|
||||
if !native_tools {
|
||||
system_prompt.push_str(&build_tool_instructions(&tools_registry, Some(&i18n_descs)));
|
||||
|
||||
+2
-2
@@ -122,7 +122,7 @@ impl ApprovalManager {
|
||||
}
|
||||
|
||||
// always_ask overrides everything.
|
||||
if self.always_ask.contains(tool_name) {
|
||||
if self.always_ask.contains("*") || self.always_ask.contains(tool_name) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ impl ApprovalManager {
|
||||
}
|
||||
|
||||
// auto_approve skips the prompt.
|
||||
if self.auto_approve.contains(tool_name) {
|
||||
if self.auto_approve.contains("*") || self.auto_approve.contains(tool_name) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,549 @@
|
||||
use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use parking_lot::Mutex;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::memory::{Memory, MemoryCategory};
|
||||
|
||||
/// Discord History channel — connects via Gateway WebSocket, stores ALL non-bot messages
|
||||
/// to a dedicated discord.db, and forwards @mention messages to the agent.
|
||||
pub struct DiscordHistoryChannel {
|
||||
bot_token: String,
|
||||
guild_id: Option<String>,
|
||||
allowed_users: Vec<String>,
|
||||
/// Channel IDs to watch. Empty = watch all channels.
|
||||
channel_ids: Vec<String>,
|
||||
/// Dedicated discord.db memory backend.
|
||||
discord_memory: Arc<dyn Memory>,
|
||||
typing_handles: Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
|
||||
proxy_url: Option<String>,
|
||||
/// When false, DM messages are not stored in discord.db.
|
||||
store_dms: bool,
|
||||
/// When false, @mentions in DMs are not forwarded to the agent.
|
||||
respond_to_dms: bool,
|
||||
}
|
||||
|
||||
impl DiscordHistoryChannel {
|
||||
pub fn new(
|
||||
bot_token: String,
|
||||
guild_id: Option<String>,
|
||||
allowed_users: Vec<String>,
|
||||
channel_ids: Vec<String>,
|
||||
discord_memory: Arc<dyn Memory>,
|
||||
store_dms: bool,
|
||||
respond_to_dms: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
guild_id,
|
||||
allowed_users,
|
||||
channel_ids,
|
||||
discord_memory,
|
||||
typing_handles: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
store_dms,
|
||||
respond_to_dms,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
|
||||
self.proxy_url = proxy_url;
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_channel_proxy_client(
|
||||
"channel.discord_history",
|
||||
self.proxy_url.as_deref(),
|
||||
)
|
||||
}
|
||||
|
||||
fn is_user_allowed(&self, user_id: &str) -> bool {
|
||||
if self.allowed_users.is_empty() {
|
||||
return true; // default open for logging channel
|
||||
}
|
||||
self.allowed_users.iter().any(|u| u == "*" || u == user_id)
|
||||
}
|
||||
|
||||
fn is_channel_watched(&self, channel_id: &str) -> bool {
|
||||
self.channel_ids.is_empty() || self.channel_ids.iter().any(|c| c == channel_id)
|
||||
}
|
||||
|
||||
fn bot_user_id_from_token(token: &str) -> Option<String> {
|
||||
let part = token.split('.').next()?;
|
||||
base64_decode(part)
|
||||
}
|
||||
|
||||
async fn resolve_channel_name(&self, channel_id: &str) -> String {
|
||||
// 1. Check persistent database (via discord_memory)
|
||||
let cache_key = format!("cache:channel_name:{}", channel_id);
|
||||
|
||||
if let Ok(Some(cached_mem)) = self.discord_memory.get(&cache_key).await {
|
||||
// Check if it's still fresh (e.g., less than 24 hours old)
|
||||
// Note: cached_mem.timestamp is an RFC3339 string
|
||||
let is_fresh =
|
||||
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&cached_mem.timestamp) {
|
||||
chrono::Utc::now().signed_duration_since(ts.with_timezone(&chrono::Utc))
|
||||
< chrono::Duration::hours(24)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if is_fresh {
|
||||
return cached_mem.content.clone();
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Fetch from API (either not in DB or stale)
|
||||
let url = format!("https://discord.com/api/v10/channels/{channel_id}");
|
||||
let resp = self
|
||||
.http_client()
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bot {}", self.bot_token))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
let name = if let Ok(r) = resp {
|
||||
if let Ok(json) = r.json::<serde_json::Value>().await {
|
||||
json.get("name")
|
||||
.and_then(|n| n.as_str())
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| {
|
||||
// For DMs, there might not be a 'name', use the recipient's username if available
|
||||
json.get("recipients")
|
||||
.and_then(|r| r.as_array())
|
||||
.and_then(|a| a.first())
|
||||
.and_then(|u| u.get("username"))
|
||||
.and_then(|un| un.as_str())
|
||||
.map(|s| format!("dm-{}", s))
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let resolved = name.unwrap_or_else(|| channel_id.to_string());
|
||||
|
||||
// 3. Store in persistent database
|
||||
let _ = self
|
||||
.discord_memory
|
||||
.store(
|
||||
&cache_key,
|
||||
&resolved,
|
||||
crate::memory::MemoryCategory::Custom("channel_cache".to_string()),
|
||||
Some(channel_id),
|
||||
)
|
||||
.await;
|
||||
|
||||
resolved
|
||||
}
|
||||
}
|
||||
|
||||
const BASE64_ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
fn base64_decode(input: &str) -> Option<String> {
|
||||
let padded = match input.len() % 4 {
|
||||
2 => format!("{input}=="),
|
||||
3 => format!("{input}="),
|
||||
_ => input.to_string(),
|
||||
};
|
||||
let mut bytes = Vec::new();
|
||||
let chars: Vec<u8> = padded.bytes().collect();
|
||||
for chunk in chars.chunks(4) {
|
||||
if chunk.len() < 4 {
|
||||
break;
|
||||
}
|
||||
let mut v = [0usize; 4];
|
||||
for (i, &b) in chunk.iter().enumerate() {
|
||||
if b == b'=' {
|
||||
v[i] = 0;
|
||||
} else {
|
||||
v[i] = BASE64_ALPHABET.iter().position(|&a| a == b)?;
|
||||
}
|
||||
}
|
||||
bytes.push(((v[0] << 2) | (v[1] >> 4)) as u8);
|
||||
if chunk[2] != b'=' {
|
||||
bytes.push((((v[1] & 0xF) << 4) | (v[2] >> 2)) as u8);
|
||||
}
|
||||
if chunk[3] != b'=' {
|
||||
bytes.push((((v[2] & 0x3) << 6) | v[3]) as u8);
|
||||
}
|
||||
}
|
||||
String::from_utf8(bytes).ok()
|
||||
}
|
||||
|
||||
fn contains_bot_mention(content: &str, bot_user_id: &str) -> bool {
|
||||
if bot_user_id.is_empty() {
|
||||
return false;
|
||||
}
|
||||
content.contains(&format!("<@{bot_user_id}>"))
|
||||
|| content.contains(&format!("<@!{bot_user_id}>"))
|
||||
}
|
||||
|
||||
fn strip_bot_mention(content: &str, bot_user_id: &str) -> String {
|
||||
let mut result = content.to_string();
|
||||
for tag in [format!("<@{bot_user_id}>"), format!("<@!{bot_user_id}>")] {
|
||||
result = result.replace(&tag, " ");
|
||||
}
|
||||
result.trim().to_string()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for DiscordHistoryChannel {
|
||||
fn name(&self) -> &str {
|
||||
"discord_history"
|
||||
}
|
||||
|
||||
/// Send a reply back to Discord (used when agent responds to @mention).
|
||||
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
||||
let content = super::strip_tool_call_tags(&message.content);
|
||||
let url = format!(
|
||||
"https://discord.com/api/v10/channels/{}/messages",
|
||||
message.recipient
|
||||
);
|
||||
self.http_client()
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bot {}", self.bot_token))
|
||||
.json(&json!({"content": content}))
|
||||
.send()
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
let bot_user_id = Self::bot_user_id_from_token(&self.bot_token).unwrap_or_default();
|
||||
|
||||
// Get Gateway URL
|
||||
let gw_resp: serde_json::Value = self
|
||||
.http_client()
|
||||
.get("https://discord.com/api/v10/gateway/bot")
|
||||
.header("Authorization", format!("Bot {}", self.bot_token))
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
let gw_url = gw_resp
|
||||
.get("url")
|
||||
.and_then(|u| u.as_str())
|
||||
.unwrap_or("wss://gateway.discord.gg");
|
||||
|
||||
let ws_url = format!("{gw_url}/?v=10&encoding=json");
|
||||
tracing::info!("DiscordHistory: connecting to gateway...");
|
||||
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
|
||||
// Read Hello (opcode 10)
|
||||
let hello = read.next().await.ok_or(anyhow::anyhow!("No hello"))??;
|
||||
let hello_data: serde_json::Value = serde_json::from_str(&hello.to_string())?;
|
||||
let heartbeat_interval = hello_data
|
||||
.get("d")
|
||||
.and_then(|d| d.get("heartbeat_interval"))
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.unwrap_or(41250);
|
||||
|
||||
// Identify with intents for guild + DM messages + message content
|
||||
let identify = json!({
|
||||
"op": 2,
|
||||
"d": {
|
||||
"token": self.bot_token,
|
||||
"intents": 37377,
|
||||
"properties": {
|
||||
"os": "linux",
|
||||
"browser": "zeroclaw",
|
||||
"device": "zeroclaw"
|
||||
}
|
||||
}
|
||||
});
|
||||
write
|
||||
.send(Message::Text(identify.to_string().into()))
|
||||
.await?;
|
||||
|
||||
tracing::info!("DiscordHistory: connected and identified");
|
||||
|
||||
let mut sequence: i64 = -1;
|
||||
|
||||
let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1);
|
||||
tokio::spawn(async move {
|
||||
let mut interval =
|
||||
tokio::time::interval(std::time::Duration::from_millis(heartbeat_interval));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if hb_tx.send(()).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let guild_filter = self.guild_id.clone();
|
||||
let discord_memory = Arc::clone(&self.discord_memory);
|
||||
let store_dms = self.store_dms;
|
||||
let respond_to_dms = self.respond_to_dms;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = hb_rx.recv() => {
|
||||
let d = if sequence >= 0 { json!(sequence) } else { json!(null) };
|
||||
let hb = json!({"op": 1, "d": d});
|
||||
if write.send(Message::Text(hb.to_string().into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
msg = read.next() => {
|
||||
let msg = match msg {
|
||||
Some(Ok(Message::Text(t))) => t,
|
||||
Some(Ok(Message::Ping(payload))) => {
|
||||
if write.send(Message::Pong(payload)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Some(Ok(Message::Close(_))) | None => break,
|
||||
Some(Err(e)) => {
|
||||
tracing::warn!("DiscordHistory: websocket error: {e}");
|
||||
break;
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let event: serde_json::Value = match serde_json::from_str(msg.as_ref()) {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
if let Some(s) = event.get("s").and_then(serde_json::Value::as_i64) {
|
||||
sequence = s;
|
||||
}
|
||||
|
||||
let op = event.get("op").and_then(serde_json::Value::as_u64).unwrap_or(0);
|
||||
match op {
|
||||
1 => {
|
||||
let d = if sequence >= 0 { json!(sequence) } else { json!(null) };
|
||||
let hb = json!({"op": 1, "d": d});
|
||||
if write.send(Message::Text(hb.to_string().into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
7 => { tracing::warn!("DiscordHistory: Reconnect (op 7)"); break; }
|
||||
9 => { tracing::warn!("DiscordHistory: Invalid Session (op 9)"); break; }
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let event_type = event.get("t").and_then(|t| t.as_str()).unwrap_or("");
|
||||
if event_type != "MESSAGE_CREATE" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(d) = event.get("d") else { continue };
|
||||
|
||||
// Skip messages from the bot itself
|
||||
let author_id = d
|
||||
.get("author")
|
||||
.and_then(|a| a.get("id"))
|
||||
.and_then(|i| i.as_str())
|
||||
.unwrap_or("");
|
||||
let username = d
|
||||
.get("author")
|
||||
.and_then(|a| a.get("username"))
|
||||
.and_then(|i| i.as_str())
|
||||
.unwrap_or(author_id);
|
||||
|
||||
if author_id == bot_user_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip other bots
|
||||
if d.get("author")
|
||||
.and_then(|a| a.get("bot"))
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let channel_id = d
|
||||
.get("channel_id")
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
// DM detection: DMs have no guild_id
|
||||
let is_dm_event = d.get("guild_id").and_then(serde_json::Value::as_str).is_none();
|
||||
|
||||
// Resolve channel name (with cache)
|
||||
let channel_display = if is_dm_event {
|
||||
"dm".to_string()
|
||||
} else {
|
||||
self.resolve_channel_name(&channel_id).await
|
||||
};
|
||||
|
||||
if is_dm_event && !store_dms && !respond_to_dms {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Guild filter
|
||||
if let Some(ref gid) = guild_filter {
|
||||
let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str);
|
||||
if let Some(g) = msg_guild {
|
||||
if g != gid {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Channel filter
|
||||
if !self.is_channel_watched(&channel_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if !self.is_user_allowed(author_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let content = d.get("content").and_then(|c| c.as_str()).unwrap_or("");
|
||||
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
||||
let is_mention = contains_bot_mention(content, &bot_user_id);
|
||||
|
||||
// Collect attachment URLs
|
||||
let attachments: Vec<String> = d
|
||||
.get("attachments")
|
||||
.and_then(|a| a.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|a| a.get("url").and_then(|u| u.as_str()))
|
||||
.map(|u| u.to_string())
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
// Store messages to discord.db (skip DMs if store_dms=false)
|
||||
if (!is_dm_event || store_dms) && (!content.is_empty() || !attachments.is_empty()) {
|
||||
let ts = chrono::Utc::now().to_rfc3339();
|
||||
let mut mem_content = format!(
|
||||
"@{username} in #{channel_display} at {ts}: {content}"
|
||||
);
|
||||
if !attachments.is_empty() {
|
||||
mem_content.push_str(" [attachments: ");
|
||||
mem_content.push_str(&attachments.join(", "));
|
||||
mem_content.push(']');
|
||||
}
|
||||
let mem_key = format!(
|
||||
"discord_{}",
|
||||
if message_id.is_empty() {
|
||||
Uuid::new_v4().to_string()
|
||||
} else {
|
||||
message_id.to_string()
|
||||
}
|
||||
);
|
||||
let channel_id_for_session = if channel_id.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(channel_id.as_str())
|
||||
};
|
||||
if let Err(err) = discord_memory
|
||||
.store(
|
||||
&mem_key,
|
||||
&mem_content,
|
||||
MemoryCategory::Custom("discord".to_string()),
|
||||
channel_id_for_session,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("discord_history: failed to store message: {err}");
|
||||
} else {
|
||||
tracing::debug!(
|
||||
"discord_history: stored message from @{username} in #{channel_display}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Forward @mention to agent (skip DMs if respond_to_dms=false)
|
||||
if is_mention && (!is_dm_event || respond_to_dms) {
|
||||
let clean_content = strip_bot_mention(content, &bot_user_id);
|
||||
if clean_content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let channel_msg = ChannelMessage {
|
||||
id: if message_id.is_empty() {
|
||||
Uuid::new_v4().to_string()
|
||||
} else {
|
||||
format!("discord_{message_id}")
|
||||
},
|
||||
sender: author_id.to_string(),
|
||||
reply_target: if channel_id.is_empty() {
|
||||
author_id.to_string()
|
||||
} else {
|
||||
channel_id.clone()
|
||||
},
|
||||
content: clean_content,
|
||||
channel: "discord_history".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
};
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.http_client()
|
||||
.get("https://discord.com/api/v10/users/@me")
|
||||
.header("Authorization", format!("Bot {}", self.bot_token))
|
||||
.send()
|
||||
.await
|
||||
.map(|r| r.status().is_success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
async fn start_typing(&self, recipient: &str) -> anyhow::Result<()> {
|
||||
let mut guard = self.typing_handles.lock();
|
||||
if let Some(h) = guard.remove(recipient) {
|
||||
h.abort();
|
||||
}
|
||||
let client = self.http_client();
|
||||
let token = self.bot_token.clone();
|
||||
let channel_id = recipient.to_string();
|
||||
let handle = tokio::spawn(async move {
|
||||
let url = format!("https://discord.com/api/v10/channels/{channel_id}/typing");
|
||||
loop {
|
||||
let _ = client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bot {token}"))
|
||||
.send()
|
||||
.await;
|
||||
tokio::time::sleep(std::time::Duration::from_secs(8)).await;
|
||||
}
|
||||
});
|
||||
guard.insert(recipient.to_string(), handle);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop_typing(&self, recipient: &str) -> anyhow::Result<()> {
|
||||
let mut guard = self.typing_handles.lock();
|
||||
if let Some(handle) = guard.remove(recipient) {
|
||||
handle.abort();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
+101
-13
@@ -19,7 +19,9 @@ pub mod clawdtalk;
|
||||
pub mod cli;
|
||||
pub mod dingtalk;
|
||||
pub mod discord;
|
||||
pub mod discord_history;
|
||||
pub mod email_channel;
|
||||
pub mod gmail_push;
|
||||
pub mod imessage;
|
||||
pub mod irc;
|
||||
#[cfg(feature = "channel-lark")]
|
||||
@@ -45,6 +47,8 @@ pub mod traits;
|
||||
pub mod transcription;
|
||||
pub mod tts;
|
||||
pub mod twitter;
|
||||
#[cfg(feature = "voice-wake")]
|
||||
pub mod voice_wake;
|
||||
pub mod wati;
|
||||
pub mod webhook;
|
||||
pub mod wecom;
|
||||
@@ -59,7 +63,9 @@ pub use clawdtalk::{ClawdTalkChannel, ClawdTalkConfig};
|
||||
pub use cli::CliChannel;
|
||||
pub use dingtalk::DingTalkChannel;
|
||||
pub use discord::DiscordChannel;
|
||||
pub use discord_history::DiscordHistoryChannel;
|
||||
pub use email_channel::EmailChannel;
|
||||
pub use gmail_push::GmailPushChannel;
|
||||
pub use imessage::IMessageChannel;
|
||||
pub use irc::IrcChannel;
|
||||
#[cfg(feature = "channel-lark")]
|
||||
@@ -82,6 +88,8 @@ pub use traits::{Channel, SendMessage};
|
||||
#[allow(unused_imports)]
|
||||
pub use tts::{TtsManager, TtsProvider};
|
||||
pub use twitter::TwitterChannel;
|
||||
#[cfg(feature = "voice-wake")]
|
||||
pub use voice_wake::VoiceWakeChannel;
|
||||
pub use wati::WatiChannel;
|
||||
pub use webhook::WebhookChannel;
|
||||
pub use wecom::WeComChannel;
|
||||
@@ -3128,9 +3136,12 @@ pub fn build_system_prompt_with_mode(
|
||||
Some(&autonomy_cfg),
|
||||
native_tools,
|
||||
skills_prompt_mode,
|
||||
false,
|
||||
0,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn build_system_prompt_with_mode_and_autonomy(
|
||||
workspace_dir: &std::path::Path,
|
||||
model_name: &str,
|
||||
@@ -3141,6 +3152,8 @@ pub fn build_system_prompt_with_mode_and_autonomy(
|
||||
autonomy_config: Option<&crate::config::AutonomyConfig>,
|
||||
native_tools: bool,
|
||||
skills_prompt_mode: crate::config::SkillsPromptInjectionMode,
|
||||
compact_context: bool,
|
||||
max_system_prompt_chars: usize,
|
||||
) -> String {
|
||||
use std::fmt::Write;
|
||||
let mut prompt = String::with_capacity(8192);
|
||||
@@ -3167,11 +3180,19 @@ pub fn build_system_prompt_with_mode_and_autonomy(
|
||||
// ── 1. Tooling ──────────────────────────────────────────────
|
||||
if !tools.is_empty() {
|
||||
prompt.push_str("## Tools\n\n");
|
||||
prompt.push_str("You have access to the following tools:\n\n");
|
||||
for (name, desc) in tools {
|
||||
let _ = writeln!(prompt, "- **{name}**: {desc}");
|
||||
if compact_context {
|
||||
// Compact mode: tool names only, no descriptions/schemas
|
||||
prompt.push_str("Available tools: ");
|
||||
let names: Vec<&str> = tools.iter().map(|(name, _)| *name).collect();
|
||||
prompt.push_str(&names.join(", "));
|
||||
prompt.push_str("\n\n");
|
||||
} else {
|
||||
prompt.push_str("You have access to the following tools:\n\n");
|
||||
for (name, desc) in tools {
|
||||
let _ = writeln!(prompt, "- **{name}**: {desc}");
|
||||
}
|
||||
prompt.push('\n');
|
||||
}
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
// ── 1b. Hardware (when gpio/arduino tools present) ───────────
|
||||
@@ -3315,11 +3336,13 @@ pub fn build_system_prompt_with_mode_and_autonomy(
|
||||
std::env::consts::OS,
|
||||
);
|
||||
|
||||
// ── 8. Channel Capabilities ─────────────────────────────────────
|
||||
prompt.push_str("## Channel Capabilities\n\n");
|
||||
prompt.push_str("- You are running as a messaging bot. Your response is automatically sent back to the user's channel.\n");
|
||||
prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
|
||||
prompt.push_str(match autonomy_config.map(|cfg| cfg.level) {
|
||||
// ── 8. Channel Capabilities (skipped in compact_context mode) ──
|
||||
if !compact_context {
|
||||
prompt.push_str("## Channel Capabilities\n\n");
|
||||
prompt.push_str("- You are running as a messaging bot. Your response is automatically sent back to the user's channel.\n");
|
||||
prompt
|
||||
.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
|
||||
prompt.push_str(match autonomy_config.map(|cfg| cfg.level) {
|
||||
Some(crate::security::AutonomyLevel::Full) => {
|
||||
"- If the runtime policy already allows a tool, use it directly; do not ask the user for extra approval.\n\
|
||||
- Never pretend you are waiting for a human approval click or confirmation when the runtime policy already permits the action.\n\
|
||||
@@ -3333,10 +3356,23 @@ pub fn build_system_prompt_with_mode_and_autonomy(
|
||||
- If there is no approval path for this channel or the runtime blocks an action, explain that restriction directly instead of simulating an approval flow.\n"
|
||||
}
|
||||
});
|
||||
prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
|
||||
prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n");
|
||||
prompt.push_str("- When a user sends a voice note, it is automatically transcribed to text. Your text reply is automatically converted to a voice note and sent back. Do NOT attempt to generate audio yourself — TTS is handled by the channel.\n");
|
||||
prompt.push_str("- NEVER narrate or describe your tool usage. Do NOT say 'Let me fetch...', 'I will use...', 'Searching...', or similar. Give the FINAL ANSWER only — no intermediate steps, no tool mentions, no progress updates.\n\n");
|
||||
prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
|
||||
prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n");
|
||||
prompt.push_str("- When a user sends a voice note, it is automatically transcribed to text. Your text reply is automatically converted to a voice note and sent back. Do NOT attempt to generate audio yourself — TTS is handled by the channel.\n");
|
||||
prompt.push_str("- NEVER narrate or describe your tool usage. Do NOT say 'Let me fetch...', 'I will use...', 'Searching...', or similar. Give the FINAL ANSWER only — no intermediate steps, no tool mentions, no progress updates.\n\n");
|
||||
} // end if !compact_context (Channel Capabilities)
|
||||
|
||||
// ── 9. Truncation (max_system_prompt_chars budget) ──────────
|
||||
if max_system_prompt_chars > 0 && prompt.len() > max_system_prompt_chars {
|
||||
// Truncate on a char boundary, keeping the top portion (identity + safety).
|
||||
let mut end = max_system_prompt_chars;
|
||||
// Ensure we don't split a multi-byte UTF-8 character.
|
||||
while !prompt.is_char_boundary(end) && end > 0 {
|
||||
end -= 1;
|
||||
}
|
||||
prompt.truncate(end);
|
||||
prompt.push_str("\n\n[System prompt truncated to fit context budget]\n");
|
||||
}
|
||||
|
||||
if prompt.is_empty() {
|
||||
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct."
|
||||
@@ -3747,6 +3783,31 @@ fn collect_configured_channels(
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(ref dh) = config.channels_config.discord_history {
|
||||
match crate::memory::SqliteMemory::new_named(&config.workspace_dir, "discord") {
|
||||
Ok(discord_mem) => {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Discord History",
|
||||
channel: Arc::new(
|
||||
DiscordHistoryChannel::new(
|
||||
dh.bot_token.clone(),
|
||||
dh.guild_id.clone(),
|
||||
dh.allowed_users.clone(),
|
||||
dh.channel_ids.clone(),
|
||||
Arc::new(discord_mem),
|
||||
dh.store_dms,
|
||||
dh.respond_to_dms,
|
||||
)
|
||||
.with_proxy_url(dh.proxy_url.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("discord_history: failed to open discord.db: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref sl) = config.channels_config.slack {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Slack",
|
||||
@@ -3938,6 +3999,15 @@ fn collect_configured_channels(
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(ref gp_cfg) = config.channels_config.gmail_push {
|
||||
if gp_cfg.enabled {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Gmail Push",
|
||||
channel: Arc::new(GmailPushChannel::new(gp_cfg.clone())),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref irc) = config.channels_config.irc {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "IRC",
|
||||
@@ -4114,6 +4184,17 @@ fn collect_configured_channels(
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(feature = "voice-wake")]
|
||||
if let Some(ref vw) = config.channels_config.voice_wake {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "VoiceWake",
|
||||
channel: Arc::new(VoiceWakeChannel::new(
|
||||
vw.clone(),
|
||||
config.transcription.clone(),
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(ref wh) = config.channels_config.webhook {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Webhook",
|
||||
@@ -4278,6 +4359,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
None,
|
||||
);
|
||||
|
||||
// Wire MCP tools into the registry before freezing — non-fatal.
|
||||
@@ -4451,6 +4533,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
Some(&config.autonomy),
|
||||
native_tools,
|
||||
config.skills.prompt_injection_mode,
|
||||
config.agent.compact_context,
|
||||
config.agent.max_system_prompt_chars,
|
||||
);
|
||||
if !native_tools {
|
||||
system_prompt.push_str(&build_tool_instructions(
|
||||
@@ -7782,6 +7866,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
Some(&config),
|
||||
false,
|
||||
crate::config::SkillsPromptInjectionMode::Full,
|
||||
false,
|
||||
0,
|
||||
);
|
||||
|
||||
assert!(
|
||||
@@ -7811,6 +7897,8 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
Some(&config),
|
||||
false,
|
||||
crate::config::SkillsPromptInjectionMode::Full,
|
||||
false,
|
||||
0,
|
||||
);
|
||||
|
||||
assert!(
|
||||
|
||||
@@ -0,0 +1,586 @@
|
||||
//! Voice Wake Word detection channel.
|
||||
//!
|
||||
//! Listens on the default microphone via `cpal`, detects a configurable wake
|
||||
//! word using energy-based VAD followed by transcription-based keyword matching,
|
||||
//! then captures the subsequent utterance and dispatches it as a channel message.
|
||||
//!
|
||||
//! Gated behind the `voice-wake` Cargo feature.
|
||||
|
||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::channels::transcription::transcribe_audio;
|
||||
use crate::config::schema::VoiceWakeConfig;
|
||||
use crate::config::TranscriptionConfig;
|
||||
|
||||
use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
|
||||
// ── State machine ──────────────────────────────────────────────
|
||||
|
||||
/// Maximum allowed capture duration (seconds) to prevent unbounded memory growth.
|
||||
const MAX_CAPTURE_SECS_LIMIT: u32 = 300;
|
||||
|
||||
/// Minimum silence timeout to prevent API hammering.
|
||||
const MIN_SILENCE_TIMEOUT_MS: u32 = 100;
|
||||
|
||||
/// Internal states for the wake-word detector.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum WakeState {
|
||||
/// Passively monitoring microphone energy levels.
|
||||
Listening,
|
||||
/// Energy spike detected — capturing a short window to check for wake word.
|
||||
Triggered,
|
||||
/// Wake word confirmed — capturing the full utterance that follows.
|
||||
Capturing,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WakeState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Listening => write!(f, "Listening"),
|
||||
Self::Triggered => write!(f, "Triggered"),
|
||||
Self::Capturing => write!(f, "Capturing"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Channel implementation ─────────────────────────────────────
|
||||
|
||||
/// Voice wake-word channel that activates on a spoken keyword.
|
||||
pub struct VoiceWakeChannel {
|
||||
config: VoiceWakeConfig,
|
||||
transcription_config: TranscriptionConfig,
|
||||
}
|
||||
|
||||
impl VoiceWakeChannel {
|
||||
/// Create a new `VoiceWakeChannel` from its config sections.
|
||||
pub fn new(config: VoiceWakeConfig, transcription_config: TranscriptionConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
transcription_config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Channel for VoiceWakeChannel {
|
||||
fn name(&self) -> &str {
|
||||
"voice_wake"
|
||||
}
|
||||
|
||||
async fn send(&self, _message: &SendMessage) -> Result<()> {
|
||||
// Voice wake is input-only; outbound messages are not supported.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(&self, tx: mpsc::Sender<ChannelMessage>) -> Result<()> {
|
||||
let config = self.config.clone();
|
||||
let transcription_config = self.transcription_config.clone();
|
||||
|
||||
// ── Validate config ───────────────────────────────────
|
||||
let energy_threshold = config.energy_threshold;
|
||||
if !energy_threshold.is_finite() || energy_threshold <= 0.0 {
|
||||
bail!("VoiceWake: energy_threshold must be a positive finite number, got {energy_threshold}");
|
||||
}
|
||||
if config.silence_timeout_ms < MIN_SILENCE_TIMEOUT_MS {
|
||||
bail!(
|
||||
"VoiceWake: silence_timeout_ms must be >= {MIN_SILENCE_TIMEOUT_MS}, got {}",
|
||||
config.silence_timeout_ms
|
||||
);
|
||||
}
|
||||
let max_capture_secs = config.max_capture_secs.min(MAX_CAPTURE_SECS_LIMIT);
|
||||
if max_capture_secs != config.max_capture_secs {
|
||||
warn!(
|
||||
"VoiceWake: max_capture_secs clamped from {} to {MAX_CAPTURE_SECS_LIMIT}",
|
||||
config.max_capture_secs
|
||||
);
|
||||
}
|
||||
|
||||
// Run the blocking audio capture loop on a dedicated thread.
|
||||
let (audio_tx, mut audio_rx) = mpsc::channel::<Vec<f32>>(64);
|
||||
|
||||
let silence_timeout = Duration::from_millis(u64::from(config.silence_timeout_ms));
|
||||
let max_capture = Duration::from_secs(u64::from(max_capture_secs));
|
||||
let sample_rate: u32;
|
||||
let channels_count: u16;
|
||||
|
||||
// ── Initialise cpal stream ────────────────────────────
|
||||
// cpal::Stream is !Send, so we build and hold it on a dedicated thread.
|
||||
// When the listen function exits, the shutdown oneshot is dropped,
|
||||
// the thread exits, and the stream + microphone are released.
|
||||
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
let (init_tx, init_rx) = tokio::sync::oneshot::channel::<Result<(u32, u16)>>();
|
||||
{
|
||||
let audio_tx_clone = audio_tx.clone();
|
||||
|
||||
std::thread::spawn(move || {
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
|
||||
let result = (|| -> Result<(u32, u16, cpal::Stream)> {
|
||||
let host = cpal::default_host();
|
||||
let device = host.default_input_device().ok_or_else(|| {
|
||||
anyhow::anyhow!("No default audio input device available")
|
||||
})?;
|
||||
|
||||
let supported = device.default_input_config()?;
|
||||
let sr = supported.sample_rate().0;
|
||||
let ch = supported.channels();
|
||||
|
||||
info!(
|
||||
device = ?device.name().unwrap_or_default(),
|
||||
sample_rate = sr,
|
||||
channels = ch,
|
||||
"VoiceWake: opening audio input"
|
||||
);
|
||||
|
||||
let stream_config: cpal::StreamConfig = supported.into();
|
||||
|
||||
let stream = device.build_input_stream(
|
||||
&stream_config,
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
let _ = audio_tx_clone.try_send(data.to_vec());
|
||||
},
|
||||
move |err| {
|
||||
warn!("VoiceWake: audio stream error: {err}");
|
||||
},
|
||||
None,
|
||||
)?;
|
||||
|
||||
stream.play()?;
|
||||
Ok((sr, ch, stream))
|
||||
})();
|
||||
|
||||
match result {
|
||||
Ok((sr, ch, _stream)) => {
|
||||
let _ = init_tx.send(Ok((sr, ch)));
|
||||
// Hold the stream alive until shutdown is signalled.
|
||||
let _ = shutdown_rx.blocking_recv();
|
||||
debug!("VoiceWake: stream holder thread exiting");
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = init_tx.send(Err(e));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let (sr, ch) = init_rx
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("VoiceWake: stream init thread panicked"))??;
|
||||
sample_rate = sr;
|
||||
channels_count = ch;
|
||||
}
|
||||
|
||||
// Drop the extra sender so the channel closes when the stream sender drops.
|
||||
drop(audio_tx);
|
||||
|
||||
// ── Main detection loop ───────────────────────────────
|
||||
let wake_word = config.wake_word.to_lowercase();
|
||||
let mut state = WakeState::Listening;
|
||||
let mut capture_buf: Vec<f32> = Vec::new();
|
||||
let mut last_voice_at = Instant::now();
|
||||
let mut capture_start = Instant::now();
|
||||
let mut msg_counter: u64 = 0;
|
||||
|
||||
// Hard cap on capture buffer: max_capture_secs * sample_rate * channels * 2 (safety margin).
|
||||
let max_buf_samples =
|
||||
max_capture_secs as usize * sample_rate as usize * channels_count as usize * 2;
|
||||
|
||||
info!(wake_word = %wake_word, "VoiceWake: entering listen loop");
|
||||
|
||||
while let Some(chunk) = audio_rx.recv().await {
|
||||
let energy = compute_rms_energy(&chunk);
|
||||
|
||||
match state {
|
||||
WakeState::Listening => {
|
||||
if energy >= energy_threshold {
|
||||
debug!(
|
||||
energy,
|
||||
"VoiceWake: energy spike — transitioning to Triggered"
|
||||
);
|
||||
state = WakeState::Triggered;
|
||||
capture_buf.clear();
|
||||
capture_buf.extend_from_slice(&chunk);
|
||||
last_voice_at = Instant::now();
|
||||
capture_start = Instant::now();
|
||||
}
|
||||
}
|
||||
WakeState::Triggered => {
|
||||
if capture_buf.len() + chunk.len() <= max_buf_samples {
|
||||
capture_buf.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
if energy >= energy_threshold {
|
||||
last_voice_at = Instant::now();
|
||||
}
|
||||
|
||||
let since_voice = last_voice_at.elapsed();
|
||||
let since_start = capture_start.elapsed();
|
||||
|
||||
// After enough silence or max time, transcribe to check for wake word.
|
||||
if since_voice >= silence_timeout || since_start >= max_capture {
|
||||
debug!("VoiceWake: Triggered window closed — transcribing for wake word");
|
||||
|
||||
let wav_bytes =
|
||||
encode_wav_from_f32(&capture_buf, sample_rate, channels_count);
|
||||
|
||||
match transcribe_audio(wav_bytes, "wake_check.wav", &transcription_config)
|
||||
.await
|
||||
{
|
||||
Ok(text) => {
|
||||
let lower = text.to_lowercase();
|
||||
if lower.contains(&wake_word) {
|
||||
info!(text = %text, "VoiceWake: wake word detected — capturing utterance");
|
||||
state = WakeState::Capturing;
|
||||
capture_buf.clear();
|
||||
last_voice_at = Instant::now();
|
||||
capture_start = Instant::now();
|
||||
} else {
|
||||
debug!(text = %text, "VoiceWake: no wake word — back to Listening");
|
||||
state = WakeState::Listening;
|
||||
capture_buf.clear();
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("VoiceWake: transcription error during wake check: {e}");
|
||||
state = WakeState::Listening;
|
||||
capture_buf.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
WakeState::Capturing => {
|
||||
if capture_buf.len() + chunk.len() <= max_buf_samples {
|
||||
capture_buf.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
if energy >= energy_threshold {
|
||||
last_voice_at = Instant::now();
|
||||
}
|
||||
|
||||
let since_voice = last_voice_at.elapsed();
|
||||
let since_start = capture_start.elapsed();
|
||||
|
||||
if since_voice >= silence_timeout || since_start >= max_capture {
|
||||
debug!("VoiceWake: utterance capture complete — transcribing");
|
||||
|
||||
let wav_bytes =
|
||||
encode_wav_from_f32(&capture_buf, sample_rate, channels_count);
|
||||
|
||||
match transcribe_audio(wav_bytes, "utterance.wav", &transcription_config)
|
||||
.await
|
||||
{
|
||||
Ok(text) => {
|
||||
let trimmed = text.trim().to_string();
|
||||
if !trimmed.is_empty() {
|
||||
msg_counter += 1;
|
||||
let ts = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let msg = ChannelMessage {
|
||||
id: format!("voice_wake_{msg_counter}"),
|
||||
sender: "voice_user".into(),
|
||||
reply_target: "voice_user".into(),
|
||||
content: trimmed,
|
||||
channel: "voice_wake".into(),
|
||||
timestamp: ts,
|
||||
thread_ts: None,
|
||||
interruption_scope_id: None,
|
||||
};
|
||||
|
||||
if let Err(e) = tx.send(msg).await {
|
||||
warn!("VoiceWake: failed to dispatch message: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("VoiceWake: transcription error for utterance: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
state = WakeState::Listening;
|
||||
capture_buf.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Signal the stream holder thread to exit and release the microphone.
|
||||
drop(shutdown_tx);
|
||||
bail!("VoiceWake: audio stream ended unexpectedly");
|
||||
}
|
||||
}
|
||||
|
||||
// ── Audio utilities ────────────────────────────────────────────
|
||||
|
||||
/// Compute RMS (root-mean-square) energy of an audio chunk.
|
||||
pub fn compute_rms_energy(samples: &[f32]) -> f32 {
|
||||
if samples.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let sum_sq: f32 = samples.iter().map(|s| s * s).sum();
|
||||
(sum_sq / samples.len() as f32).sqrt()
|
||||
}
|
||||
|
||||
/// Encode raw f32 PCM samples as a WAV byte buffer (16-bit PCM).
|
||||
///
|
||||
/// This produces a minimal valid WAV file that Whisper-compatible APIs accept.
|
||||
pub fn encode_wav_from_f32(samples: &[f32], sample_rate: u32, channels: u16) -> Vec<u8> {
|
||||
let bits_per_sample: u16 = 16;
|
||||
let byte_rate = u32::from(channels) * sample_rate * u32::from(bits_per_sample) / 8;
|
||||
let block_align = channels * bits_per_sample / 8;
|
||||
// Guard against u32 overflow — reject buffers that exceed WAV's 4 GB limit.
|
||||
let data_bytes = samples.len() * 2;
|
||||
assert!(
|
||||
u32::try_from(data_bytes).is_ok(),
|
||||
"audio buffer too large for WAV encoding ({data_bytes} bytes)"
|
||||
);
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let data_len = data_bytes as u32;
|
||||
let file_len = 36 + data_len;
|
||||
|
||||
let mut buf = Vec::with_capacity(file_len as usize + 8);
|
||||
|
||||
// RIFF header
|
||||
buf.extend_from_slice(b"RIFF");
|
||||
buf.extend_from_slice(&file_len.to_le_bytes());
|
||||
buf.extend_from_slice(b"WAVE");
|
||||
|
||||
// fmt chunk
|
||||
buf.extend_from_slice(b"fmt ");
|
||||
buf.extend_from_slice(&16u32.to_le_bytes()); // chunk size
|
||||
buf.extend_from_slice(&1u16.to_le_bytes()); // PCM format
|
||||
buf.extend_from_slice(&channels.to_le_bytes());
|
||||
buf.extend_from_slice(&sample_rate.to_le_bytes());
|
||||
buf.extend_from_slice(&byte_rate.to_le_bytes());
|
||||
buf.extend_from_slice(&block_align.to_le_bytes());
|
||||
buf.extend_from_slice(&bits_per_sample.to_le_bytes());
|
||||
|
||||
// data chunk
|
||||
buf.extend_from_slice(b"data");
|
||||
buf.extend_from_slice(&data_len.to_le_bytes());
|
||||
|
||||
for &sample in samples {
|
||||
let clamped = sample.clamp(-1.0, 1.0);
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let pcm16 = (clamped * 32767.0) as i16; // clamped to [-1,1] so fits i16
|
||||
buf.extend_from_slice(&pcm16.to_le_bytes());
|
||||
}
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::traits::ChannelConfig;
|
||||
|
||||
// ── State machine tests ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn wake_state_display() {
|
||||
assert_eq!(WakeState::Listening.to_string(), "Listening");
|
||||
assert_eq!(WakeState::Triggered.to_string(), "Triggered");
|
||||
assert_eq!(WakeState::Capturing.to_string(), "Capturing");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wake_state_equality() {
|
||||
assert_eq!(WakeState::Listening, WakeState::Listening);
|
||||
assert_ne!(WakeState::Listening, WakeState::Triggered);
|
||||
}
|
||||
|
||||
// ── Energy computation tests ───────────────────────────
|
||||
|
||||
#[test]
|
||||
fn rms_energy_of_silence_is_zero() {
|
||||
let silence = vec![0.0f32; 1024];
|
||||
assert_eq!(compute_rms_energy(&silence), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rms_energy_of_empty_is_zero() {
|
||||
assert_eq!(compute_rms_energy(&[]), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rms_energy_of_constant_signal() {
|
||||
// Constant signal at 0.5 → RMS should be 0.5
|
||||
let signal = vec![0.5f32; 100];
|
||||
let energy = compute_rms_energy(&signal);
|
||||
assert!((energy - 0.5).abs() < 1e-5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rms_energy_above_threshold() {
|
||||
let loud = vec![0.8f32; 256];
|
||||
let energy = compute_rms_energy(&loud);
|
||||
assert!(energy > 0.01, "Loud signal should exceed default threshold");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rms_energy_below_threshold_for_quiet() {
|
||||
let quiet = vec![0.001f32; 256];
|
||||
let energy = compute_rms_energy(&quiet);
|
||||
assert!(
|
||||
energy < 0.01,
|
||||
"Very quiet signal should be below default threshold"
|
||||
);
|
||||
}
|
||||
|
||||
// ── WAV encoding tests ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn wav_header_is_valid() {
|
||||
let samples = vec![0.0f32; 100];
|
||||
let wav = encode_wav_from_f32(&samples, 16000, 1);
|
||||
|
||||
// RIFF header
|
||||
assert_eq!(&wav[0..4], b"RIFF");
|
||||
assert_eq!(&wav[8..12], b"WAVE");
|
||||
|
||||
// fmt chunk
|
||||
assert_eq!(&wav[12..16], b"fmt ");
|
||||
let fmt_size = u32::from_le_bytes(wav[16..20].try_into().unwrap());
|
||||
assert_eq!(fmt_size, 16);
|
||||
|
||||
// PCM format
|
||||
let format = u16::from_le_bytes(wav[20..22].try_into().unwrap());
|
||||
assert_eq!(format, 1);
|
||||
|
||||
// Channels
|
||||
let channels = u16::from_le_bytes(wav[22..24].try_into().unwrap());
|
||||
assert_eq!(channels, 1);
|
||||
|
||||
// Sample rate
|
||||
let sr = u32::from_le_bytes(wav[24..28].try_into().unwrap());
|
||||
assert_eq!(sr, 16000);
|
||||
|
||||
// data chunk
|
||||
assert_eq!(&wav[36..40], b"data");
|
||||
let data_size = u32::from_le_bytes(wav[40..44].try_into().unwrap());
|
||||
assert_eq!(data_size, 200); // 100 samples * 2 bytes each
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wav_total_size_correct() {
|
||||
let samples = vec![0.0f32; 50];
|
||||
let wav = encode_wav_from_f32(&samples, 44100, 2);
|
||||
// header (44 bytes) + data (50 * 2 = 100 bytes)
|
||||
assert_eq!(wav.len(), 144);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wav_encodes_clipped_samples() {
|
||||
// Samples outside [-1, 1] should be clamped
|
||||
let samples = vec![-2.0f32, 2.0, 0.0];
|
||||
let wav = encode_wav_from_f32(&samples, 16000, 1);
|
||||
|
||||
let s0 = i16::from_le_bytes(wav[44..46].try_into().unwrap());
|
||||
let s1 = i16::from_le_bytes(wav[46..48].try_into().unwrap());
|
||||
let s2 = i16::from_le_bytes(wav[48..50].try_into().unwrap());
|
||||
|
||||
assert_eq!(s0, -32767); // clamped to -1.0
|
||||
assert_eq!(s1, 32767); // clamped to 1.0
|
||||
assert_eq!(s2, 0);
|
||||
}
|
||||
|
||||
// ── Config parsing tests ───────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn voice_wake_config_defaults() {
|
||||
let config = VoiceWakeConfig::default();
|
||||
assert_eq!(config.wake_word, "hey zeroclaw");
|
||||
assert_eq!(config.silence_timeout_ms, 2000);
|
||||
assert!((config.energy_threshold - 0.01).abs() < f32::EPSILON);
|
||||
assert_eq!(config.max_capture_secs, 30);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn voice_wake_config_deserialize_partial() {
|
||||
let toml_str = r#"
|
||||
wake_word = "okay agent"
|
||||
max_capture_secs = 60
|
||||
"#;
|
||||
let config: VoiceWakeConfig = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(config.wake_word, "okay agent");
|
||||
assert_eq!(config.max_capture_secs, 60);
|
||||
// Defaults preserved for unset fields
|
||||
assert_eq!(config.silence_timeout_ms, 2000);
|
||||
assert!((config.energy_threshold - 0.01).abs() < f32::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn voice_wake_config_deserialize_all_fields() {
|
||||
let toml_str = r#"
|
||||
wake_word = "hello bot"
|
||||
silence_timeout_ms = 3000
|
||||
energy_threshold = 0.05
|
||||
max_capture_secs = 15
|
||||
"#;
|
||||
let config: VoiceWakeConfig = toml::from_str(toml_str).unwrap();
|
||||
assert_eq!(config.wake_word, "hello bot");
|
||||
assert_eq!(config.silence_timeout_ms, 3000);
|
||||
assert!((config.energy_threshold - 0.05).abs() < f32::EPSILON);
|
||||
assert_eq!(config.max_capture_secs, 15);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn voice_wake_config_channel_config_trait() {
|
||||
assert_eq!(VoiceWakeConfig::name(), "VoiceWake");
|
||||
assert_eq!(VoiceWakeConfig::desc(), "voice wake word detection");
|
||||
}
|
||||
|
||||
// ── State transition logic tests ───────────────────────
|
||||
|
||||
#[test]
|
||||
fn energy_threshold_determines_trigger() {
|
||||
let threshold = 0.01f32;
|
||||
let quiet_energy = compute_rms_energy(&vec![0.005f32; 256]);
|
||||
let loud_energy = compute_rms_energy(&vec![0.5f32; 256]);
|
||||
|
||||
assert!(quiet_energy < threshold, "Quiet should not trigger");
|
||||
assert!(loud_energy >= threshold, "Loud should trigger");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn state_transitions_are_deterministic() {
|
||||
// Verify that the state enum values are distinct and copyable
|
||||
let states = [
|
||||
WakeState::Listening,
|
||||
WakeState::Triggered,
|
||||
WakeState::Capturing,
|
||||
];
|
||||
for (i, a) in states.iter().enumerate() {
|
||||
for (j, b) in states.iter().enumerate() {
|
||||
if i == j {
|
||||
assert_eq!(a, b);
|
||||
} else {
|
||||
assert_ne!(a, b);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn channel_config_impl() {
|
||||
// VoiceWakeConfig implements ChannelConfig
|
||||
assert_eq!(VoiceWakeConfig::name(), "VoiceWake");
|
||||
assert!(!VoiceWakeConfig::desc().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn voice_wake_channel_name() {
|
||||
let config = VoiceWakeConfig::default();
|
||||
let transcription_config = TranscriptionConfig::default();
|
||||
let channel = VoiceWakeChannel::new(config, transcription_config);
|
||||
assert_eq!(channel.name(), "voice_wake");
|
||||
}
|
||||
}
|
||||
+16
-15
@@ -15,21 +15,22 @@ pub use schema::{
|
||||
ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
|
||||
GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig,
|
||||
HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
|
||||
IMessageConfig, IdentityConfig, ImageProviderDalleConfig, ImageProviderFluxConfig,
|
||||
ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig, KnowledgeConfig,
|
||||
LarkConfig, LinkedInConfig, LinkedInContentConfig, LinkedInImageConfig, LocalWhisperConfig,
|
||||
MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config,
|
||||
ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig,
|
||||
NotionConfig, ObservabilityConfig, OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig,
|
||||
OtpConfig, OtpMethod, PacingConfig, PeripheralBoardConfig, PeripheralsConfig, PluginsConfig,
|
||||
ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig,
|
||||
ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig,
|
||||
SchedulerConfig, SecretsConfig, SecurityConfig, SecurityOpsConfig, SkillCreationConfig,
|
||||
SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig,
|
||||
TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig,
|
||||
TunnelConfig, VerifiableIntentConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig, DEFAULT_GWS_SERVICES,
|
||||
IMessageConfig, IdentityConfig, ImageGenConfig, ImageProviderDalleConfig,
|
||||
ImageProviderFluxConfig, ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig,
|
||||
KnowledgeConfig, LarkConfig, LinkedInConfig, LinkedInContentConfig, LinkedInImageConfig,
|
||||
LocalWhisperConfig, MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig,
|
||||
Microsoft365Config, ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig,
|
||||
NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig, OpenAiSttConfig,
|
||||
OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, PacingConfig,
|
||||
PeripheralBoardConfig, PeripheralsConfig, PluginsConfig, ProjectIntelConfig, ProxyConfig,
|
||||
ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
|
||||
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SecurityOpsConfig, SkillCreationConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig,
|
||||
StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig,
|
||||
SwarmStrategy, TelegramConfig, TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode,
|
||||
TranscriptionConfig, TtsConfig, TunnelConfig, VerifiableIntentConfig, WebFetchConfig,
|
||||
WebSearchConfig, WebhookConfig, WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig,
|
||||
DEFAULT_GWS_SERVICES,
|
||||
};
|
||||
|
||||
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
|
||||
|
||||
@@ -357,6 +357,10 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub linkedin: LinkedInConfig,
|
||||
|
||||
/// Standalone image generation tool configuration (`[image_gen]`).
|
||||
#[serde(default)]
|
||||
pub image_gen: ImageGenConfig,
|
||||
|
||||
/// Plugin system configuration (`[plugins]`).
|
||||
#[serde(default)]
|
||||
pub plugins: PluginsConfig,
|
||||
@@ -1248,6 +1252,12 @@ pub struct AgentConfig {
|
||||
/// Default: `[]` (no filtering — all tools included).
|
||||
#[serde(default)]
|
||||
pub tool_filter_groups: Vec<ToolFilterGroup>,
|
||||
/// Maximum characters for the assembled system prompt. When `> 0`, the prompt
|
||||
/// is truncated to this limit after assembly (keeping the top portion which
|
||||
/// contains identity and safety instructions). `0` means unlimited.
|
||||
/// Useful for small-context models (e.g. glm-4.5-air ~8K tokens → set to 8000).
|
||||
#[serde(default = "default_max_system_prompt_chars")]
|
||||
pub max_system_prompt_chars: usize,
|
||||
}
|
||||
|
||||
fn default_agent_max_tool_iterations() -> usize {
|
||||
@@ -1266,6 +1276,10 @@ fn default_agent_tool_dispatcher() -> String {
|
||||
"auto".into()
|
||||
}
|
||||
|
||||
fn default_max_system_prompt_chars() -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
impl Default for AgentConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
@@ -1277,6 +1291,7 @@ impl Default for AgentConfig {
|
||||
tool_dispatcher: default_agent_tool_dispatcher(),
|
||||
tool_call_dedup_exempt: Vec::new(),
|
||||
tool_filter_groups: Vec::new(),
|
||||
max_system_prompt_chars: default_max_system_prompt_chars(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2963,6 +2978,46 @@ impl Default for ImageProviderFluxConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Standalone Image Generation ─────────────────────────────────
|
||||
|
||||
/// Standalone image generation tool configuration (`[image_gen]`).
|
||||
///
|
||||
/// When enabled, registers an `image_gen` tool that generates images via
|
||||
/// fal.ai's synchronous API (Flux / Nano Banana models) and saves them
|
||||
/// to the workspace `images/` directory.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ImageGenConfig {
|
||||
/// Enable the standalone image generation tool. Default: false.
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Default fal.ai model identifier.
|
||||
#[serde(default = "default_image_gen_model")]
|
||||
pub default_model: String,
|
||||
|
||||
/// Environment variable name holding the fal.ai API key.
|
||||
#[serde(default = "default_image_gen_api_key_env")]
|
||||
pub api_key_env: String,
|
||||
}
|
||||
|
||||
fn default_image_gen_model() -> String {
|
||||
"fal-ai/flux/schnell".into()
|
||||
}
|
||||
|
||||
fn default_image_gen_api_key_env() -> String {
|
||||
"FAL_API_KEY".into()
|
||||
}
|
||||
|
||||
impl Default for ImageGenConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
default_model: default_image_gen_model(),
|
||||
api_key_env: default_image_gen_api_key_env(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Claude Code ─────────────────────────────────────────────────
|
||||
|
||||
/// Claude Code CLI tool configuration (`[claude_code]` section).
|
||||
@@ -4867,6 +4922,8 @@ pub struct ChannelsConfig {
|
||||
pub telegram: Option<TelegramConfig>,
|
||||
/// Discord bot channel configuration.
|
||||
pub discord: Option<DiscordConfig>,
|
||||
/// Discord history channel — logs ALL messages and forwards @mentions to agent.
|
||||
pub discord_history: Option<DiscordHistoryConfig>,
|
||||
/// Slack bot channel configuration.
|
||||
pub slack: Option<SlackConfig>,
|
||||
/// Mattermost bot channel configuration.
|
||||
@@ -4889,6 +4946,8 @@ pub struct ChannelsConfig {
|
||||
pub nextcloud_talk: Option<NextcloudTalkConfig>,
|
||||
/// Email channel configuration.
|
||||
pub email: Option<crate::channels::email_channel::EmailConfig>,
|
||||
/// Gmail Pub/Sub push notification channel configuration.
|
||||
pub gmail_push: Option<crate::channels::gmail_push::GmailPushConfig>,
|
||||
/// IRC channel configuration.
|
||||
pub irc: Option<IrcConfig>,
|
||||
/// Lark channel configuration.
|
||||
@@ -4913,6 +4972,9 @@ pub struct ChannelsConfig {
|
||||
pub reddit: Option<RedditConfig>,
|
||||
/// Bluesky channel configuration (AT Protocol).
|
||||
pub bluesky: Option<BlueskyConfig>,
|
||||
/// Voice wake word detection channel configuration.
|
||||
#[cfg(feature = "voice-wake")]
|
||||
pub voice_wake: Option<VoiceWakeConfig>,
|
||||
/// Base timeout in seconds for processing a single channel message (LLM + tools).
|
||||
/// Runtime uses this as a per-turn budget that scales with tool-loop depth
|
||||
/// (up to 4x, capped) so one slow/retried model call does not consume the
|
||||
@@ -4995,6 +5057,10 @@ impl ChannelsConfig {
|
||||
Box::new(ConfigWrapper::new(self.email.as_ref())),
|
||||
self.email.is_some(),
|
||||
),
|
||||
(
|
||||
Box::new(ConfigWrapper::new(self.gmail_push.as_ref())),
|
||||
self.gmail_push.is_some(),
|
||||
),
|
||||
(
|
||||
Box::new(ConfigWrapper::new(self.irc.as_ref())),
|
||||
self.irc.is_some()
|
||||
@@ -5036,6 +5102,11 @@ impl ChannelsConfig {
|
||||
Box::new(ConfigWrapper::new(self.bluesky.as_ref())),
|
||||
self.bluesky.is_some(),
|
||||
),
|
||||
#[cfg(feature = "voice-wake")]
|
||||
(
|
||||
Box::new(ConfigWrapper::new(self.voice_wake.as_ref())),
|
||||
self.voice_wake.is_some(),
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
@@ -5063,6 +5134,7 @@ impl Default for ChannelsConfig {
|
||||
cli: true,
|
||||
telegram: None,
|
||||
discord: None,
|
||||
discord_history: None,
|
||||
slack: None,
|
||||
mattermost: None,
|
||||
webhook: None,
|
||||
@@ -5074,6 +5146,7 @@ impl Default for ChannelsConfig {
|
||||
wati: None,
|
||||
nextcloud_talk: None,
|
||||
email: None,
|
||||
gmail_push: None,
|
||||
irc: None,
|
||||
lark: None,
|
||||
feishu: None,
|
||||
@@ -5087,6 +5160,8 @@ impl Default for ChannelsConfig {
|
||||
clawdtalk: None,
|
||||
reddit: None,
|
||||
bluesky: None,
|
||||
#[cfg(feature = "voice-wake")]
|
||||
voice_wake: None,
|
||||
message_timeout_secs: default_channel_message_timeout_secs(),
|
||||
ack_reactions: true,
|
||||
show_tool_calls: false,
|
||||
@@ -5190,6 +5265,39 @@ impl ChannelConfig for DiscordConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Discord history channel — logs ALL messages to discord.db and forwards @mentions to the agent.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct DiscordHistoryConfig {
|
||||
/// Discord bot token (from Discord Developer Portal).
|
||||
pub bot_token: String,
|
||||
/// Optional guild (server) ID to restrict logging to a single guild.
|
||||
pub guild_id: Option<String>,
|
||||
/// Allowed Discord user IDs. Empty = allow all (open logging).
|
||||
#[serde(default)]
|
||||
pub allowed_users: Vec<String>,
|
||||
/// Discord channel IDs to watch. Empty = watch all channels.
|
||||
#[serde(default)]
|
||||
pub channel_ids: Vec<String>,
|
||||
/// When true (default), store Direct Messages in discord.db.
|
||||
#[serde(default = "default_true")]
|
||||
pub store_dms: bool,
|
||||
/// When true (default), respond to @mentions in Direct Messages.
|
||||
#[serde(default = "default_true")]
|
||||
pub respond_to_dms: bool,
|
||||
/// Per-channel proxy URL (http, https, socks5, socks5h).
|
||||
#[serde(default)]
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ChannelConfig for DiscordHistoryConfig {
|
||||
fn name() -> &'static str {
|
||||
"Discord History"
|
||||
}
|
||||
fn desc() -> &'static str {
|
||||
"log all messages and forward @mentions"
|
||||
}
|
||||
}
|
||||
|
||||
/// Slack bot channel configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct SlackConfig {
|
||||
@@ -6338,6 +6446,74 @@ impl ChannelConfig for BlueskyConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Voice wake word detection channel configuration.
|
||||
///
|
||||
/// Listens on the default microphone for a configurable wake word,
|
||||
/// then captures the following utterance and transcribes it via the
|
||||
/// existing transcription API.
|
||||
#[cfg(feature = "voice-wake")]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct VoiceWakeConfig {
|
||||
/// Wake word phrase to listen for (case-insensitive substring match).
|
||||
/// Default: `"hey zeroclaw"`.
|
||||
#[serde(default = "default_voice_wake_word")]
|
||||
pub wake_word: String,
|
||||
/// Silence timeout in milliseconds — how long to wait after the last
|
||||
/// energy spike before finalizing a capture window. Default: `2000`.
|
||||
#[serde(default = "default_voice_wake_silence_timeout_ms")]
|
||||
pub silence_timeout_ms: u32,
|
||||
/// RMS energy threshold for voice activity detection. Samples below
|
||||
/// this level are treated as silence. Default: `0.01`.
|
||||
#[serde(default = "default_voice_wake_energy_threshold")]
|
||||
pub energy_threshold: f32,
|
||||
/// Maximum capture duration in seconds before forcing transcription.
|
||||
/// Default: `30`.
|
||||
#[serde(default = "default_voice_wake_max_capture_secs")]
|
||||
pub max_capture_secs: u32,
|
||||
}
|
||||
|
||||
#[cfg(feature = "voice-wake")]
|
||||
fn default_voice_wake_word() -> String {
|
||||
"hey zeroclaw".into()
|
||||
}
|
||||
|
||||
#[cfg(feature = "voice-wake")]
|
||||
fn default_voice_wake_silence_timeout_ms() -> u32 {
|
||||
2000
|
||||
}
|
||||
|
||||
#[cfg(feature = "voice-wake")]
|
||||
fn default_voice_wake_energy_threshold() -> f32 {
|
||||
0.01
|
||||
}
|
||||
|
||||
#[cfg(feature = "voice-wake")]
|
||||
fn default_voice_wake_max_capture_secs() -> u32 {
|
||||
30
|
||||
}
|
||||
|
||||
#[cfg(feature = "voice-wake")]
|
||||
impl Default for VoiceWakeConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
wake_word: default_voice_wake_word(),
|
||||
silence_timeout_ms: default_voice_wake_silence_timeout_ms(),
|
||||
energy_threshold: default_voice_wake_energy_threshold(),
|
||||
max_capture_secs: default_voice_wake_max_capture_secs(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "voice-wake")]
|
||||
impl ChannelConfig for VoiceWakeConfig {
|
||||
fn name() -> &'static str {
|
||||
"VoiceWake"
|
||||
}
|
||||
fn desc() -> &'static str {
|
||||
"voice wake word detection"
|
||||
}
|
||||
}
|
||||
|
||||
/// Nostr channel configuration (NIP-04 + NIP-17 private messages)
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
@@ -6811,6 +6987,7 @@ impl Default for Config {
|
||||
node_transport: NodeTransportConfig::default(),
|
||||
knowledge: KnowledgeConfig::default(),
|
||||
linkedin: LinkedInConfig::default(),
|
||||
image_gen: ImageGenConfig::default(),
|
||||
plugins: PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: VerifiableIntentConfig::default(),
|
||||
@@ -7607,6 +7784,13 @@ impl Config {
|
||||
"config.channels_config.email.password",
|
||||
)?;
|
||||
}
|
||||
if let Some(ref mut gp) = config.channels_config.gmail_push {
|
||||
decrypt_secret(
|
||||
&store,
|
||||
&mut gp.oauth_token,
|
||||
"config.channels_config.gmail_push.oauth_token",
|
||||
)?;
|
||||
}
|
||||
if let Some(ref mut irc) = config.channels_config.irc {
|
||||
decrypt_optional_secret(
|
||||
&store,
|
||||
@@ -9038,6 +9222,13 @@ impl Config {
|
||||
"config.channels_config.email.password",
|
||||
)?;
|
||||
}
|
||||
if let Some(ref mut gp) = config_to_save.channels_config.gmail_push {
|
||||
encrypt_secret(
|
||||
&store,
|
||||
&mut gp.oauth_token,
|
||||
"config.channels_config.gmail_push.oauth_token",
|
||||
)?;
|
||||
}
|
||||
if let Some(ref mut irc) = config_to_save.channels_config.irc {
|
||||
encrypt_optional_secret(
|
||||
&store,
|
||||
@@ -9665,6 +9856,7 @@ default_temperature = 0.7
|
||||
proxy_url: None,
|
||||
}),
|
||||
discord: None,
|
||||
discord_history: None,
|
||||
slack: None,
|
||||
mattermost: None,
|
||||
webhook: None,
|
||||
@@ -9676,6 +9868,7 @@ default_temperature = 0.7
|
||||
wati: None,
|
||||
nextcloud_talk: None,
|
||||
email: None,
|
||||
gmail_push: None,
|
||||
irc: None,
|
||||
lark: None,
|
||||
feishu: None,
|
||||
@@ -9689,6 +9882,8 @@ default_temperature = 0.7
|
||||
clawdtalk: None,
|
||||
reddit: None,
|
||||
bluesky: None,
|
||||
#[cfg(feature = "voice-wake")]
|
||||
voice_wake: None,
|
||||
message_timeout_secs: 300,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
@@ -9733,6 +9928,7 @@ default_temperature = 0.7
|
||||
node_transport: NodeTransportConfig::default(),
|
||||
knowledge: KnowledgeConfig::default(),
|
||||
linkedin: LinkedInConfig::default(),
|
||||
image_gen: ImageGenConfig::default(),
|
||||
plugins: PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: VerifiableIntentConfig::default(),
|
||||
@@ -10114,6 +10310,7 @@ default_temperature = 0.7
|
||||
node_transport: NodeTransportConfig::default(),
|
||||
knowledge: KnowledgeConfig::default(),
|
||||
linkedin: LinkedInConfig::default(),
|
||||
image_gen: ImageGenConfig::default(),
|
||||
plugins: PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: VerifiableIntentConfig::default(),
|
||||
@@ -10500,6 +10697,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
cli: true,
|
||||
telegram: None,
|
||||
discord: None,
|
||||
discord_history: None,
|
||||
slack: None,
|
||||
mattermost: None,
|
||||
webhook: None,
|
||||
@@ -10521,6 +10719,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
wati: None,
|
||||
nextcloud_talk: None,
|
||||
email: None,
|
||||
gmail_push: None,
|
||||
irc: None,
|
||||
lark: None,
|
||||
feishu: None,
|
||||
@@ -10533,6 +10732,8 @@ allowed_users = ["@ops:matrix.org"]
|
||||
clawdtalk: None,
|
||||
reddit: None,
|
||||
bluesky: None,
|
||||
#[cfg(feature = "voice-wake")]
|
||||
voice_wake: None,
|
||||
message_timeout_secs: 300,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
@@ -10815,6 +11016,7 @@ channel_id = "C123"
|
||||
cli: true,
|
||||
telegram: None,
|
||||
discord: None,
|
||||
discord_history: None,
|
||||
slack: None,
|
||||
mattermost: None,
|
||||
webhook: None,
|
||||
@@ -10840,6 +11042,7 @@ channel_id = "C123"
|
||||
wati: None,
|
||||
nextcloud_talk: None,
|
||||
email: None,
|
||||
gmail_push: None,
|
||||
irc: None,
|
||||
lark: None,
|
||||
feishu: None,
|
||||
@@ -10852,6 +11055,8 @@ channel_id = "C123"
|
||||
clawdtalk: None,
|
||||
reddit: None,
|
||||
bluesky: None,
|
||||
#[cfg(feature = "voice-wake")]
|
||||
voice_wake: None,
|
||||
message_timeout_secs: 300,
|
||||
ack_reactions: true,
|
||||
show_tool_calls: true,
|
||||
|
||||
+3
-1
@@ -23,7 +23,7 @@ fn extract_bearer_token(headers: &HeaderMap) -> Option<&str> {
|
||||
}
|
||||
|
||||
/// Verify bearer token against PairingGuard. Returns error response if unauthorized.
|
||||
fn require_auth(
|
||||
pub(super) fn require_auth(
|
||||
state: &AppState,
|
||||
headers: &HeaderMap,
|
||||
) -> Result<(), (StatusCode, Json<serde_json::Value>)> {
|
||||
@@ -1429,6 +1429,7 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
@@ -1439,6 +1440,7 @@ mod tests {
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
path_prefix: String::new(),
|
||||
canvas_store: crate::tools::canvas::CanvasStore::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,278 @@
|
||||
//! Live Canvas gateway routes — REST + WebSocket for real-time canvas updates.
|
||||
//!
|
||||
//! - `GET /api/canvas/:id` — get current canvas content (JSON)
|
||||
//! - `POST /api/canvas/:id` — push content programmatically
|
||||
//! - `GET /api/canvas` — list all active canvases
|
||||
//! - `WS /ws/canvas/:id` — real-time canvas updates via WebSocket
|
||||
|
||||
use super::api::require_auth;
|
||||
use super::AppState;
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket},
|
||||
Path, State, WebSocketUpgrade,
|
||||
},
|
||||
http::{header, HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Json},
|
||||
};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::Deserialize;
|
||||
|
||||
/// POST /api/canvas/:id request body.
|
||||
#[derive(Deserialize)]
|
||||
pub struct CanvasPostBody {
|
||||
pub content_type: Option<String>,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// GET /api/canvas — list all active canvases.
|
||||
pub async fn handle_canvas_list(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
let ids = state.canvas_store.list();
|
||||
Json(serde_json::json!({ "canvases": ids })).into_response()
|
||||
}
|
||||
|
||||
/// GET /api/canvas/:id — get current canvas content.
|
||||
pub async fn handle_canvas_get(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
match state.canvas_store.snapshot(&id) {
|
||||
Some(frame) => Json(serde_json::json!({
|
||||
"canvas_id": id,
|
||||
"frame": frame,
|
||||
}))
|
||||
.into_response(),
|
||||
None => (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({ "error": format!("Canvas '{}' not found", id) })),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// GET /api/canvas/:id/history — get canvas frame history.
|
||||
pub async fn handle_canvas_history(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
let history = state.canvas_store.history(&id);
|
||||
Json(serde_json::json!({
|
||||
"canvas_id": id,
|
||||
"frames": history,
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// POST /api/canvas/:id — push content to a canvas.
|
||||
pub async fn handle_canvas_post(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
Json(body): Json<CanvasPostBody>,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
let content_type = body.content_type.as_deref().unwrap_or("html");
|
||||
|
||||
// Validate content_type against allowed set (prevent injecting "eval" frames via REST).
|
||||
if !crate::tools::canvas::ALLOWED_CONTENT_TYPES.contains(&content_type) {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": format!(
|
||||
"Invalid content_type '{}'. Allowed: {:?}",
|
||||
content_type,
|
||||
crate::tools::canvas::ALLOWED_CONTENT_TYPES
|
||||
)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
// Enforce content size limit (same as tool-side validation).
|
||||
if body.content.len() > crate::tools::canvas::MAX_CONTENT_SIZE {
|
||||
return (
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
Json(serde_json::json!({
|
||||
"error": format!(
|
||||
"Content exceeds maximum size of {} bytes",
|
||||
crate::tools::canvas::MAX_CONTENT_SIZE
|
||||
)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
match state.canvas_store.render(&id, content_type, &body.content) {
|
||||
Some(frame) => (
|
||||
StatusCode::CREATED,
|
||||
Json(serde_json::json!({
|
||||
"canvas_id": id,
|
||||
"frame": frame,
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
None => (
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
Json(serde_json::json!({
|
||||
"error": "Maximum canvas count reached. Clear unused canvases first."
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
/// DELETE /api/canvas/:id — clear a canvas.
|
||||
pub async fn handle_canvas_clear(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
state.canvas_store.clear(&id);
|
||||
Json(serde_json::json!({
|
||||
"canvas_id": id,
|
||||
"status": "cleared",
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// WS /ws/canvas/:id — real-time canvas updates.
|
||||
pub async fn handle_ws_canvas(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
headers: HeaderMap,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> impl IntoResponse {
|
||||
// Auth check (same pattern as ws::handle_ws_chat)
|
||||
if state.pairing.require_pairing() {
|
||||
let token = headers
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|auth| auth.strip_prefix("Bearer "))
|
||||
.or_else(|| {
|
||||
// Fallback: check query params in the upgrade request URI
|
||||
headers
|
||||
.get("sec-websocket-protocol")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|protos| {
|
||||
protos
|
||||
.split(',')
|
||||
.map(|p| p.trim())
|
||||
.find_map(|p| p.strip_prefix("bearer."))
|
||||
})
|
||||
})
|
||||
.unwrap_or("");
|
||||
|
||||
if !state.pairing.is_authenticated(token) {
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"Unauthorized — provide Authorization header or Sec-WebSocket-Protocol bearer",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
ws.on_upgrade(move |socket| handle_canvas_socket(socket, state, id))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn handle_canvas_socket(socket: WebSocket, state: AppState, canvas_id: String) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// Subscribe to canvas updates
|
||||
let mut rx = match state.canvas_store.subscribe(&canvas_id) {
|
||||
Some(rx) => rx,
|
||||
None => {
|
||||
let msg = serde_json::json!({
|
||||
"type": "error",
|
||||
"error": "Maximum canvas count reached",
|
||||
});
|
||||
let _ = sender.send(Message::Text(msg.to_string().into())).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Send current state immediately if available
|
||||
if let Some(frame) = state.canvas_store.snapshot(&canvas_id) {
|
||||
let msg = serde_json::json!({
|
||||
"type": "frame",
|
||||
"canvas_id": canvas_id,
|
||||
"frame": frame,
|
||||
});
|
||||
let _ = sender.send(Message::Text(msg.to_string().into())).await;
|
||||
}
|
||||
|
||||
// Send a connected acknowledgement
|
||||
let ack = serde_json::json!({
|
||||
"type": "connected",
|
||||
"canvas_id": canvas_id,
|
||||
});
|
||||
let _ = sender.send(Message::Text(ack.to_string().into())).await;
|
||||
|
||||
// Spawn a task that forwards broadcast updates to the WebSocket
|
||||
let canvas_id_clone = canvas_id.clone();
|
||||
let send_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match rx.recv().await {
|
||||
Ok(frame) => {
|
||||
let msg = serde_json::json!({
|
||||
"type": "frame",
|
||||
"canvas_id": canvas_id_clone,
|
||||
"frame": frame,
|
||||
});
|
||||
if sender
|
||||
.send(Message::Text(msg.to_string().into()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
|
||||
// Client fell behind — notify and continue rather than disconnecting.
|
||||
let msg = serde_json::json!({
|
||||
"type": "lagged",
|
||||
"canvas_id": canvas_id_clone,
|
||||
"missed_frames": n,
|
||||
});
|
||||
let _ = sender.send(Message::Text(msg.to_string().into())).await;
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Read loop: we mostly ignore incoming messages but handle close/ping
|
||||
while let Some(msg) = receiver.next().await {
|
||||
match msg {
|
||||
Ok(Message::Close(_)) | Err(_) => break,
|
||||
_ => {} // Ignore all other messages (pings are handled by axum)
|
||||
}
|
||||
}
|
||||
|
||||
// Abort the send task when the connection is closed
|
||||
send_task.abort();
|
||||
}
|
||||
+122
-2
@@ -11,14 +11,15 @@ pub mod api;
|
||||
pub mod api_pairing;
|
||||
#[cfg(feature = "plugins-wasm")]
|
||||
pub mod api_plugins;
|
||||
pub mod canvas;
|
||||
pub mod nodes;
|
||||
pub mod sse;
|
||||
pub mod static_files;
|
||||
pub mod ws;
|
||||
|
||||
use crate::channels::{
|
||||
session_backend::SessionBackend, session_sqlite::SqliteSessionBackend, Channel, LinqChannel,
|
||||
NextcloudTalkChannel, SendMessage, WatiChannel, WhatsAppChannel,
|
||||
session_backend::SessionBackend, session_sqlite::SqliteSessionBackend, Channel,
|
||||
GmailPushChannel, LinqChannel, NextcloudTalkChannel, SendMessage, WatiChannel, WhatsAppChannel,
|
||||
};
|
||||
use crate::config::Config;
|
||||
use crate::cost::CostTracker;
|
||||
@@ -28,6 +29,7 @@ use crate::runtime;
|
||||
use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard};
|
||||
use crate::security::SecurityPolicy;
|
||||
use crate::tools;
|
||||
use crate::tools::canvas::CanvasStore;
|
||||
use crate::tools::traits::ToolSpec;
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::{Context, Result};
|
||||
@@ -336,6 +338,8 @@ pub struct AppState {
|
||||
/// Nextcloud Talk webhook secret for signature verification
|
||||
pub nextcloud_talk_webhook_secret: Option<Arc<str>>,
|
||||
pub wati: Option<Arc<WatiChannel>>,
|
||||
/// Gmail Pub/Sub push notification channel
|
||||
pub gmail_push: Option<Arc<GmailPushChannel>>,
|
||||
/// Observability backend for metrics scraping
|
||||
pub observer: Arc<dyn crate::observability::Observer>,
|
||||
/// Registered tool specs (for web dashboard tools page)
|
||||
@@ -356,6 +360,8 @@ pub struct AppState {
|
||||
pub device_registry: Option<Arc<api_pairing::DeviceRegistry>>,
|
||||
/// Pending pairing request store
|
||||
pub pending_pairings: Option<Arc<api_pairing::PairingStore>>,
|
||||
/// Shared canvas store for Live Canvas (A2UI) system
|
||||
pub canvas_store: CanvasStore,
|
||||
}
|
||||
|
||||
/// Run the HTTP gateway using axum with proper HTTP/1.1 compliance.
|
||||
@@ -432,6 +438,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let canvas_store = tools::CanvasStore::new();
|
||||
|
||||
let (mut tools_registry_raw, delegate_handle_gw, _reaction_handle_gw) =
|
||||
tools::all_tools_with_runtime(
|
||||
Arc::new(config.clone()),
|
||||
@@ -447,6 +455,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
&config.agents,
|
||||
config.api_key.as_deref(),
|
||||
&config,
|
||||
Some(canvas_store.clone()),
|
||||
);
|
||||
|
||||
// ── Wire MCP tools into the gateway tool registry (non-fatal) ───
|
||||
@@ -630,6 +639,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
})
|
||||
.map(Arc::from);
|
||||
|
||||
// Gmail Push channel (if configured and enabled)
|
||||
let gmail_push_channel: Option<Arc<GmailPushChannel>> = config
|
||||
.channels_config
|
||||
.gmail_push
|
||||
.as_ref()
|
||||
.filter(|gp| gp.enabled)
|
||||
.map(|gp| Arc::new(GmailPushChannel::new(gp.clone())));
|
||||
|
||||
// ── Session persistence for WS chat ─────────────────────
|
||||
let session_backend: Option<Arc<dyn SessionBackend>> = if config.gateway.session_persistence {
|
||||
match SqliteSessionBackend::new(&config.workspace_dir) {
|
||||
@@ -801,6 +818,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
nextcloud_talk: nextcloud_talk_channel,
|
||||
nextcloud_talk_webhook_secret,
|
||||
wati: wati_channel,
|
||||
gmail_push: gmail_push_channel,
|
||||
observer: broadcast_observer,
|
||||
tools_registry,
|
||||
cost_tracker,
|
||||
@@ -811,6 +829,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
device_registry,
|
||||
pending_pairings,
|
||||
path_prefix: path_prefix.unwrap_or("").to_string(),
|
||||
canvas_store,
|
||||
};
|
||||
|
||||
// Config PUT needs larger body limit (1MB)
|
||||
@@ -835,6 +854,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.route("/wati", get(handle_wati_verify))
|
||||
.route("/wati", post(handle_wati_webhook))
|
||||
.route("/nextcloud-talk", post(handle_nextcloud_talk_webhook))
|
||||
.route("/webhook/gmail", post(handle_gmail_push_webhook))
|
||||
// ── Web Dashboard API routes ──
|
||||
.route("/api/status", get(api::handle_api_status))
|
||||
.route("/api/config", get(api::handle_api_config_get))
|
||||
@@ -875,6 +895,18 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.route(
|
||||
"/api/devices/{id}/token/rotate",
|
||||
post(api_pairing::rotate_token),
|
||||
)
|
||||
// ── Live Canvas (A2UI) routes ──
|
||||
.route("/api/canvas", get(canvas::handle_canvas_list))
|
||||
.route(
|
||||
"/api/canvas/{id}",
|
||||
get(canvas::handle_canvas_get)
|
||||
.post(canvas::handle_canvas_post)
|
||||
.delete(canvas::handle_canvas_clear),
|
||||
)
|
||||
.route(
|
||||
"/api/canvas/{id}/history",
|
||||
get(canvas::handle_canvas_history),
|
||||
);
|
||||
|
||||
// ── Plugin management API (requires plugins-wasm feature) ──
|
||||
@@ -889,6 +921,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.route("/api/events", get(sse::handle_sse_events))
|
||||
// ── WebSocket agent chat ──
|
||||
.route("/ws/chat", get(ws::handle_ws_chat))
|
||||
// ── WebSocket canvas updates ──
|
||||
.route("/ws/canvas/{id}", get(canvas::handle_ws_canvas))
|
||||
// ── WebSocket node discovery ──
|
||||
.route("/ws/nodes", get(nodes::handle_ws_nodes))
|
||||
// ── Static assets (web dashboard) ──
|
||||
@@ -1813,6 +1847,74 @@ async fn handle_nextcloud_talk_webhook(
|
||||
(StatusCode::OK, Json(serde_json::json!({"status": "ok"})))
|
||||
}
|
||||
|
||||
/// Maximum request body size for the Gmail webhook endpoint (1 MB).
|
||||
/// Google Pub/Sub messages are typically under 10 KB.
|
||||
const GMAIL_WEBHOOK_MAX_BODY: usize = 1024 * 1024;
|
||||
|
||||
/// POST /webhook/gmail — incoming Gmail Pub/Sub push notification
|
||||
async fn handle_gmail_push_webhook(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> impl IntoResponse {
|
||||
let Some(ref gmail_push) = state.gmail_push else {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({"error": "Gmail push not configured"})),
|
||||
);
|
||||
};
|
||||
|
||||
// Enforce body size limit.
|
||||
if body.len() > GMAIL_WEBHOOK_MAX_BODY {
|
||||
return (
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
Json(serde_json::json!({"error": "Request body too large"})),
|
||||
);
|
||||
}
|
||||
|
||||
// Authenticate the webhook request using a shared secret.
|
||||
let secret = gmail_push.resolve_webhook_secret();
|
||||
if !secret.is_empty() {
|
||||
let provided = headers
|
||||
.get(axum::http::header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|auth| auth.strip_prefix("Bearer "))
|
||||
.unwrap_or("");
|
||||
|
||||
if provided != secret {
|
||||
tracing::warn!("Gmail push webhook: unauthorized request");
|
||||
return (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(serde_json::json!({"error": "Unauthorized"})),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let body_str = String::from_utf8_lossy(&body);
|
||||
let envelope: crate::channels::gmail_push::PubSubEnvelope =
|
||||
match serde_json::from_str(&body_str) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
tracing::warn!("Gmail push webhook: invalid payload: {e}");
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({"error": "Invalid Pub/Sub envelope"})),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Process the notification asynchronously (non-blocking for the webhook response)
|
||||
let channel = Arc::clone(gmail_push);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = channel.handle_notification(&envelope).await {
|
||||
tracing::error!("Gmail push notification processing failed: {e:#}");
|
||||
}
|
||||
});
|
||||
|
||||
// Acknowledge immediately — Google Pub/Sub requires a 2xx within ~10s
|
||||
(StatusCode::OK, Json(serde_json::json!({"status": "ok"})))
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
// ADMIN HANDLERS (for CLI management)
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
@@ -2001,6 +2103,7 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
@@ -2011,6 +2114,7 @@ mod tests {
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let response = handle_metrics(State(state)).await.into_response();
|
||||
@@ -2057,6 +2161,7 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer,
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
@@ -2067,6 +2172,7 @@ mod tests {
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let response = handle_metrics(State(state)).await.into_response();
|
||||
@@ -2442,6 +2548,7 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
@@ -2452,6 +2559,7 @@ mod tests {
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
@@ -2512,6 +2620,7 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
@@ -2522,6 +2631,7 @@ mod tests {
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let headers = HeaderMap::new();
|
||||
@@ -2594,6 +2704,7 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
@@ -2604,6 +2715,7 @@ mod tests {
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let response = handle_webhook(
|
||||
@@ -2648,6 +2760,7 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
@@ -2658,6 +2771,7 @@ mod tests {
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
@@ -2707,6 +2821,7 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
@@ -2717,6 +2832,7 @@ mod tests {
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
@@ -2771,6 +2887,7 @@ mod tests {
|
||||
nextcloud_talk: None,
|
||||
nextcloud_talk_webhook_secret: None,
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
@@ -2781,6 +2898,7 @@ mod tests {
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let response = Box::pin(handle_nextcloud_talk_webhook(
|
||||
@@ -2831,6 +2949,7 @@ mod tests {
|
||||
nextcloud_talk: Some(channel),
|
||||
nextcloud_talk_webhook_secret: Some(Arc::from(secret)),
|
||||
wati: None,
|
||||
gmail_push: None,
|
||||
observer: Arc::new(crate::observability::NoopObserver),
|
||||
tools_registry: Arc::new(Vec::new()),
|
||||
cost_tracker: None,
|
||||
@@ -2841,6 +2960,7 @@ mod tests {
|
||||
session_backend: None,
|
||||
device_registry: None,
|
||||
pending_pairings: None,
|
||||
canvas_store: CanvasStore::new(),
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
@@ -407,7 +407,11 @@ mod tests {
|
||||
// Simpler: write a temp script.
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let script_path = dir.path().join("tool.sh");
|
||||
std::fs::write(&script_path, format!("#!/bin/sh\necho '{}'\n", result_json)).unwrap();
|
||||
std::fs::write(
|
||||
&script_path,
|
||||
format!("#!/bin/sh\ncat > /dev/null\necho '{}'\n", result_json),
|
||||
)
|
||||
.unwrap();
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
@@ -46,6 +46,31 @@ impl SqliteMemory {
|
||||
)
|
||||
}
|
||||
|
||||
/// Like `new`, but stores data in `{db_name}.db` instead of `brain.db`.
|
||||
pub fn new_named(workspace_dir: &Path, db_name: &str) -> anyhow::Result<Self> {
|
||||
let db_path = workspace_dir.join("memory").join(format!("{db_name}.db"));
|
||||
if let Some(parent) = db_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
let conn = Self::open_connection(&db_path, None)?;
|
||||
conn.execute_batch(
|
||||
"PRAGMA journal_mode = WAL;
|
||||
PRAGMA synchronous = NORMAL;
|
||||
PRAGMA mmap_size = 8388608;
|
||||
PRAGMA cache_size = -2000;
|
||||
PRAGMA temp_store = MEMORY;",
|
||||
)?;
|
||||
Self::init_schema(&conn)?;
|
||||
Ok(Self {
|
||||
conn: Arc::new(Mutex::new(conn)),
|
||||
db_path,
|
||||
embedder: Arc::new(super::embeddings::NoopEmbedding),
|
||||
vector_weight: 0.7,
|
||||
keyword_weight: 0.3,
|
||||
cache_max: 10_000,
|
||||
})
|
||||
}
|
||||
|
||||
/// Build SQLite memory with optional open timeout.
|
||||
///
|
||||
/// If `open_timeout_secs` is `Some(n)`, opening the database is limited to `n` seconds
|
||||
|
||||
@@ -197,6 +197,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
node_transport: crate::config::NodeTransportConfig::default(),
|
||||
knowledge: crate::config::KnowledgeConfig::default(),
|
||||
linkedin: crate::config::LinkedInConfig::default(),
|
||||
image_gen: crate::config::ImageGenConfig::default(),
|
||||
plugins: crate::config::PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
|
||||
@@ -620,6 +621,7 @@ async fn run_quick_setup_with_home(
|
||||
node_transport: crate::config::NodeTransportConfig::default(),
|
||||
knowledge: crate::config::KnowledgeConfig::default(),
|
||||
linkedin: crate::config::LinkedInConfig::default(),
|
||||
image_gen: crate::config::ImageGenConfig::default(),
|
||||
plugins: crate::config::PluginsConfig::default(),
|
||||
locale: None,
|
||||
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
|
||||
|
||||
@@ -412,6 +412,7 @@ mod tests {
|
||||
));
|
||||
let mut f = std::fs::File::create(&path).unwrap();
|
||||
writeln!(f, "#!/bin/sh\ncat /dev/stdin").unwrap();
|
||||
f.sync_all().unwrap();
|
||||
drop(f);
|
||||
#[cfg(unix)]
|
||||
{
|
||||
|
||||
@@ -108,6 +108,7 @@ fn is_context_window_exceeded(err: &anyhow::Error) -> bool {
|
||||
"token limit exceeded",
|
||||
"prompt is too long",
|
||||
"input is too long",
|
||||
"prompt exceeds max length",
|
||||
];
|
||||
|
||||
hints.iter().any(|hint| lower.contains(hint))
|
||||
|
||||
+17
-2
@@ -97,7 +97,8 @@ pub struct SecurityPolicy {
|
||||
/// Default allowed commands for Unix platforms.
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn default_allowed_commands() -> Vec<String> {
|
||||
vec![
|
||||
#[allow(unused_mut)]
|
||||
let mut cmds = vec![
|
||||
"git".into(),
|
||||
"npm".into(),
|
||||
"cargo".into(),
|
||||
@@ -111,7 +112,16 @@ fn default_allowed_commands() -> Vec<String> {
|
||||
"head".into(),
|
||||
"tail".into(),
|
||||
"date".into(),
|
||||
]
|
||||
"df".into(),
|
||||
"du".into(),
|
||||
"uname".into(),
|
||||
"uptime".into(),
|
||||
"hostname".into(),
|
||||
];
|
||||
// `free` is Linux-only; it does not exist on macOS or other BSDs.
|
||||
#[cfg(target_os = "linux")]
|
||||
cmds.push("free".into());
|
||||
cmds
|
||||
}
|
||||
|
||||
/// Default allowed commands for Windows platforms.
|
||||
@@ -142,6 +152,11 @@ fn default_allowed_commands() -> Vec<String> {
|
||||
"wc".into(),
|
||||
"head".into(),
|
||||
"tail".into(),
|
||||
"df".into(),
|
||||
"du".into(),
|
||||
"uname".into(),
|
||||
"uptime".into(),
|
||||
"hostname".into(),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,636 @@
|
||||
//! Live Canvas (A2UI) tool — push rendered content to a web canvas in real time.
|
||||
//!
|
||||
//! The agent can render HTML/SVG/Markdown to a named canvas, snapshot its
|
||||
//! current state, clear it, or evaluate a JavaScript expression in the canvas
|
||||
//! context. Content is stored in a shared [`CanvasStore`] and broadcast to
|
||||
//! connected WebSocket clients via per-canvas channels.
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
/// Maximum content size per canvas frame (256 KB).
|
||||
pub const MAX_CONTENT_SIZE: usize = 256 * 1024;
|
||||
|
||||
/// Maximum number of history frames kept per canvas.
|
||||
const MAX_HISTORY_FRAMES: usize = 50;
|
||||
|
||||
/// Broadcast channel capacity per canvas.
|
||||
const BROADCAST_CAPACITY: usize = 64;
|
||||
|
||||
/// Maximum number of concurrent canvases to prevent memory exhaustion.
|
||||
const MAX_CANVAS_COUNT: usize = 100;
|
||||
|
||||
/// Allowed content types for canvas frames via the REST API.
|
||||
pub const ALLOWED_CONTENT_TYPES: &[&str] = &["html", "svg", "markdown", "text"];
|
||||
|
||||
/// A single canvas frame (one render).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CanvasFrame {
|
||||
/// Unique frame identifier.
|
||||
pub frame_id: String,
|
||||
/// Content type: `html`, `svg`, `markdown`, or `text`.
|
||||
pub content_type: String,
|
||||
/// The rendered content.
|
||||
pub content: String,
|
||||
/// ISO-8601 timestamp of when the frame was created.
|
||||
pub timestamp: String,
|
||||
}
|
||||
|
||||
/// Per-canvas state: current content + history + broadcast sender.
|
||||
struct CanvasEntry {
|
||||
current: Option<CanvasFrame>,
|
||||
history: Vec<CanvasFrame>,
|
||||
tx: broadcast::Sender<CanvasFrame>,
|
||||
}
|
||||
|
||||
/// Shared canvas store — holds all active canvases.
|
||||
///
|
||||
/// Thread-safe and cheaply cloneable (wraps `Arc`).
|
||||
#[derive(Clone)]
|
||||
pub struct CanvasStore {
|
||||
inner: Arc<RwLock<HashMap<String, CanvasEntry>>>,
|
||||
}
|
||||
|
||||
impl Default for CanvasStore {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CanvasStore {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Push a new frame to a canvas. Creates the canvas if it does not exist.
|
||||
/// Returns `None` if the maximum canvas count has been reached and this is a new canvas.
|
||||
pub fn render(
|
||||
&self,
|
||||
canvas_id: &str,
|
||||
content_type: &str,
|
||||
content: &str,
|
||||
) -> Option<CanvasFrame> {
|
||||
let frame = CanvasFrame {
|
||||
frame_id: uuid::Uuid::new_v4().to_string(),
|
||||
content_type: content_type.to_string(),
|
||||
content: content.to_string(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
|
||||
let mut store = self.inner.write();
|
||||
|
||||
// Enforce canvas count limit for new canvases.
|
||||
if !store.contains_key(canvas_id) && store.len() >= MAX_CANVAS_COUNT {
|
||||
return None;
|
||||
}
|
||||
|
||||
let entry = store
|
||||
.entry(canvas_id.to_string())
|
||||
.or_insert_with(|| CanvasEntry {
|
||||
current: None,
|
||||
history: Vec::new(),
|
||||
tx: broadcast::channel(BROADCAST_CAPACITY).0,
|
||||
});
|
||||
|
||||
entry.current = Some(frame.clone());
|
||||
entry.history.push(frame.clone());
|
||||
if entry.history.len() > MAX_HISTORY_FRAMES {
|
||||
let excess = entry.history.len() - MAX_HISTORY_FRAMES;
|
||||
entry.history.drain(..excess);
|
||||
}
|
||||
|
||||
// Best-effort broadcast — ignore errors (no receivers is fine).
|
||||
let _ = entry.tx.send(frame.clone());
|
||||
|
||||
Some(frame)
|
||||
}
|
||||
|
||||
/// Get the current (most recent) frame for a canvas.
|
||||
pub fn snapshot(&self, canvas_id: &str) -> Option<CanvasFrame> {
|
||||
let store = self.inner.read();
|
||||
store.get(canvas_id).and_then(|entry| entry.current.clone())
|
||||
}
|
||||
|
||||
/// Get the frame history for a canvas.
|
||||
pub fn history(&self, canvas_id: &str) -> Vec<CanvasFrame> {
|
||||
let store = self.inner.read();
|
||||
store
|
||||
.get(canvas_id)
|
||||
.map(|entry| entry.history.clone())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Clear a canvas (removes current content and history).
|
||||
pub fn clear(&self, canvas_id: &str) -> bool {
|
||||
let mut store = self.inner.write();
|
||||
if let Some(entry) = store.get_mut(canvas_id) {
|
||||
entry.current = None;
|
||||
entry.history.clear();
|
||||
// Send an empty frame to signal clear to subscribers.
|
||||
let clear_frame = CanvasFrame {
|
||||
frame_id: uuid::Uuid::new_v4().to_string(),
|
||||
content_type: "clear".to_string(),
|
||||
content: String::new(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
let _ = entry.tx.send(clear_frame);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to real-time updates for a canvas.
|
||||
/// Creates the canvas entry if it does not exist (subject to canvas count limit).
|
||||
/// Returns `None` if the canvas does not exist and the limit has been reached.
|
||||
pub fn subscribe(&self, canvas_id: &str) -> Option<broadcast::Receiver<CanvasFrame>> {
|
||||
let mut store = self.inner.write();
|
||||
|
||||
// Enforce canvas count limit for new entries.
|
||||
if !store.contains_key(canvas_id) && store.len() >= MAX_CANVAS_COUNT {
|
||||
return None;
|
||||
}
|
||||
|
||||
let entry = store
|
||||
.entry(canvas_id.to_string())
|
||||
.or_insert_with(|| CanvasEntry {
|
||||
current: None,
|
||||
history: Vec::new(),
|
||||
tx: broadcast::channel(BROADCAST_CAPACITY).0,
|
||||
});
|
||||
Some(entry.tx.subscribe())
|
||||
}
|
||||
|
||||
/// List all canvas IDs that currently have content.
|
||||
pub fn list(&self) -> Vec<String> {
|
||||
let store = self.inner.read();
|
||||
store.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// `CanvasTool` — agent-callable tool for the Live Canvas (A2UI) system.
|
||||
pub struct CanvasTool {
|
||||
store: CanvasStore,
|
||||
}
|
||||
|
||||
impl CanvasTool {
|
||||
pub fn new(store: CanvasStore) -> Self {
|
||||
Self { store }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for CanvasTool {
|
||||
fn name(&self) -> &str {
|
||||
"canvas"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Push rendered content (HTML, SVG, Markdown) to a live web canvas that users can see \
|
||||
in real-time. Actions: render (push content), snapshot (get current content), \
|
||||
clear (reset canvas), eval (evaluate JS expression in canvas context). \
|
||||
Each canvas is identified by a canvas_id string."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": "Action to perform on the canvas.",
|
||||
"enum": ["render", "snapshot", "clear", "eval"]
|
||||
},
|
||||
"canvas_id": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the canvas. Defaults to 'default'."
|
||||
},
|
||||
"content_type": {
|
||||
"type": "string",
|
||||
"description": "Content type for render action: html, svg, markdown, or text.",
|
||||
"enum": ["html", "svg", "markdown", "text"]
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Content to render (for render action)."
|
||||
},
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "JavaScript expression to evaluate (for eval action). \
|
||||
The result is returned as text. Evaluated client-side in the canvas iframe."
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let action = match args.get("action").and_then(|v| v.as_str()) {
|
||||
Some(a) => a,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: action".to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let canvas_id = args
|
||||
.get("canvas_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("default");
|
||||
|
||||
match action {
|
||||
"render" => {
|
||||
let content_type = args
|
||||
.get("content_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("html");
|
||||
|
||||
let content = match args.get("content").and_then(|v| v.as_str()) {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Missing required parameter: content (for render action)"
|
||||
.to_string(),
|
||||
),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if content.len() > MAX_CONTENT_SIZE {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Content exceeds maximum size of {} bytes",
|
||||
MAX_CONTENT_SIZE
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
match self.store.render(canvas_id, content_type, content) {
|
||||
Some(frame) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"Rendered {} content to canvas '{}' (frame: {})",
|
||||
content_type, canvas_id, frame.frame_id
|
||||
),
|
||||
error: None,
|
||||
}),
|
||||
None => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Maximum canvas count ({}) reached. Clear unused canvases first.",
|
||||
MAX_CANVAS_COUNT
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
"snapshot" => match self.store.snapshot(canvas_id) {
|
||||
Some(frame) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string_pretty(&frame)
|
||||
.unwrap_or_else(|_| frame.content.clone()),
|
||||
error: None,
|
||||
}),
|
||||
None => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Canvas '{}' is empty", canvas_id),
|
||||
error: None,
|
||||
}),
|
||||
},
|
||||
|
||||
"clear" => {
|
||||
let existed = self.store.clear(canvas_id);
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: if existed {
|
||||
format!("Canvas '{}' cleared", canvas_id)
|
||||
} else {
|
||||
format!("Canvas '{}' was already empty", canvas_id)
|
||||
},
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
|
||||
"eval" => {
|
||||
// Eval is handled client-side. We store an eval request as a special frame
|
||||
// that the web viewer interprets.
|
||||
let expression = match args.get("expression").and_then(|v| v.as_str()) {
|
||||
Some(e) => e,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Missing required parameter: expression (for eval action)"
|
||||
.to_string(),
|
||||
),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Push a special eval frame so connected clients know to evaluate it.
|
||||
match self.store.render(canvas_id, "eval", expression) {
|
||||
Some(frame) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"Eval request sent to canvas '{}' (frame: {}). \
|
||||
Result will be available to connected viewers.",
|
||||
canvas_id, frame.frame_id
|
||||
),
|
||||
error: None,
|
||||
}),
|
||||
None => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Maximum canvas count ({}) reached. Clear unused canvases first.",
|
||||
MAX_CANVAS_COUNT
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
other => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Unknown action: '{}'. Valid actions: render, snapshot, clear, eval",
|
||||
other
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn canvas_store_render_and_snapshot() {
|
||||
let store = CanvasStore::new();
|
||||
let frame = store.render("test", "html", "<h1>Hello</h1>").unwrap();
|
||||
assert_eq!(frame.content_type, "html");
|
||||
assert_eq!(frame.content, "<h1>Hello</h1>");
|
||||
|
||||
let snapshot = store.snapshot("test").unwrap();
|
||||
assert_eq!(snapshot.frame_id, frame.frame_id);
|
||||
assert_eq!(snapshot.content, "<h1>Hello</h1>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_snapshot_empty_returns_none() {
|
||||
let store = CanvasStore::new();
|
||||
assert!(store.snapshot("nonexistent").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_clear_removes_content() {
|
||||
let store = CanvasStore::new();
|
||||
store.render("test", "html", "<p>content</p>");
|
||||
assert!(store.snapshot("test").is_some());
|
||||
|
||||
let cleared = store.clear("test");
|
||||
assert!(cleared);
|
||||
assert!(store.snapshot("test").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_clear_nonexistent_returns_false() {
|
||||
let store = CanvasStore::new();
|
||||
assert!(!store.clear("nonexistent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_history_tracks_frames() {
|
||||
let store = CanvasStore::new();
|
||||
store.render("test", "html", "frame1");
|
||||
store.render("test", "html", "frame2");
|
||||
store.render("test", "html", "frame3");
|
||||
|
||||
let history = store.history("test");
|
||||
assert_eq!(history.len(), 3);
|
||||
assert_eq!(history[0].content, "frame1");
|
||||
assert_eq!(history[2].content, "frame3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_history_limit_enforced() {
|
||||
let store = CanvasStore::new();
|
||||
for i in 0..60 {
|
||||
store.render("test", "html", &format!("frame{i}"));
|
||||
}
|
||||
|
||||
let history = store.history("test");
|
||||
assert_eq!(history.len(), MAX_HISTORY_FRAMES);
|
||||
// Oldest frames should have been dropped
|
||||
assert_eq!(history[0].content, "frame10");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_list_returns_canvas_ids() {
|
||||
let store = CanvasStore::new();
|
||||
store.render("alpha", "html", "a");
|
||||
store.render("beta", "svg", "b");
|
||||
|
||||
let mut ids = store.list();
|
||||
ids.sort();
|
||||
assert_eq!(ids, vec!["alpha", "beta"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_subscribe_receives_updates() {
|
||||
let store = CanvasStore::new();
|
||||
let mut rx = store.subscribe("test").unwrap();
|
||||
store.render("test", "html", "<p>live</p>");
|
||||
|
||||
let frame = rx.try_recv().unwrap();
|
||||
assert_eq!(frame.content, "<p>live</p>");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_render_action() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store.clone());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "render",
|
||||
"canvas_id": "test",
|
||||
"content_type": "html",
|
||||
"content": "<h1>Hello World</h1>"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Rendered html content"));
|
||||
|
||||
let snapshot = store.snapshot("test").unwrap();
|
||||
assert_eq!(snapshot.content, "<h1>Hello World</h1>");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_snapshot_action() {
|
||||
let store = CanvasStore::new();
|
||||
store.render("test", "html", "<p>snap</p>");
|
||||
let tool = CanvasTool::new(store);
|
||||
let result = tool
|
||||
.execute(json!({"action": "snapshot", "canvas_id": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("<p>snap</p>"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_snapshot_empty() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store);
|
||||
let result = tool
|
||||
.execute(json!({"action": "snapshot", "canvas_id": "empty"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("empty"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_clear_action() {
|
||||
let store = CanvasStore::new();
|
||||
store.render("test", "html", "<p>clear me</p>");
|
||||
let tool = CanvasTool::new(store.clone());
|
||||
let result = tool
|
||||
.execute(json!({"action": "clear", "canvas_id": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("cleared"));
|
||||
assert!(store.snapshot("test").is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_eval_action() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store.clone());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "eval",
|
||||
"canvas_id": "test",
|
||||
"expression": "document.title"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Eval request sent"));
|
||||
|
||||
let snapshot = store.snapshot("test").unwrap();
|
||||
assert_eq!(snapshot.content_type, "eval");
|
||||
assert_eq!(snapshot.content, "document.title");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_unknown_action() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store);
|
||||
let result = tool.execute(json!({"action": "invalid"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("Unknown action"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_missing_action() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store);
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("action"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_render_missing_content() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store);
|
||||
let result = tool
|
||||
.execute(json!({"action": "render", "canvas_id": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("content"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_render_content_too_large() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store);
|
||||
let big_content = "x".repeat(MAX_CONTENT_SIZE + 1);
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "render",
|
||||
"canvas_id": "test",
|
||||
"content": big_content
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("maximum size"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_default_canvas_id() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store.clone());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "render",
|
||||
"content_type": "html",
|
||||
"content": "<p>default</p>"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(store.snapshot("default").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canvas_store_enforces_max_canvas_count() {
|
||||
let store = CanvasStore::new();
|
||||
// Create MAX_CANVAS_COUNT canvases
|
||||
for i in 0..MAX_CANVAS_COUNT {
|
||||
assert!(store
|
||||
.render(&format!("canvas_{i}"), "html", "content")
|
||||
.is_some());
|
||||
}
|
||||
// The next new canvas should be rejected
|
||||
assert!(store.render("one_too_many", "html", "content").is_none());
|
||||
// But rendering to an existing canvas should still work
|
||||
assert!(store.render("canvas_0", "html", "updated").is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn canvas_tool_eval_missing_expression() {
|
||||
let store = CanvasStore::new();
|
||||
let tool = CanvasTool::new(store);
|
||||
let result = tool
|
||||
.execute(json!({"action": "eval", "canvas_id": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("expression"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,204 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::memory::Memory;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::fmt::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Search Discord message history stored in discord.db.
|
||||
pub struct DiscordSearchTool {
|
||||
discord_memory: Arc<dyn Memory>,
|
||||
}
|
||||
|
||||
impl DiscordSearchTool {
|
||||
pub fn new(discord_memory: Arc<dyn Memory>) -> Self {
|
||||
Self { discord_memory }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for DiscordSearchTool {
|
||||
fn name(&self) -> &str {
|
||||
"discord_search"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Search Discord message history. Returns messages matching a keyword query, optionally filtered by channel_id, author_id, or time range."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Keywords or phrase to search for in Discord messages (optional if since/until provided)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results to return (default: 10)"
|
||||
},
|
||||
"channel_id": {
|
||||
"type": "string",
|
||||
"description": "Filter results to a specific Discord channel ID"
|
||||
},
|
||||
"since": {
|
||||
"type": "string",
|
||||
"description": "Filter messages at or after this time (RFC 3339, e.g. 2025-03-01T00:00:00Z)"
|
||||
},
|
||||
"until": {
|
||||
"type": "string",
|
||||
"description": "Filter messages at or before this time (RFC 3339)"
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let channel_id = args.get("channel_id").and_then(|v| v.as_str());
|
||||
let since = args.get("since").and_then(|v| v.as_str());
|
||||
let until = args.get("until").and_then(|v| v.as_str());
|
||||
|
||||
if query.trim().is_empty() && since.is_none() && until.is_none() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Provide at least 'query' (keywords) or time range ('since'/'until')".into(),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(s) = since {
|
||||
if chrono::DateTime::parse_from_rfc3339(s).is_err() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid 'since' date: {s}. Expected RFC 3339, e.g. 2025-03-01T00:00:00Z"
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
if let Some(u) = until {
|
||||
if chrono::DateTime::parse_from_rfc3339(u).is_err() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid 'until' date: {u}. Expected RFC 3339, e.g. 2025-03-01T00:00:00Z"
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
if let (Some(s), Some(u)) = (since, until) {
|
||||
if let (Ok(s_dt), Ok(u_dt)) = (
|
||||
chrono::DateTime::parse_from_rfc3339(s),
|
||||
chrono::DateTime::parse_from_rfc3339(u),
|
||||
) {
|
||||
if s_dt >= u_dt {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'since' must be before 'until'".into()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let limit = args
|
||||
.get("limit")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map_or(10, |v| v as usize);
|
||||
|
||||
match self
|
||||
.discord_memory
|
||||
.recall(query, limit, channel_id, since, until)
|
||||
.await
|
||||
{
|
||||
Ok(entries) if entries.is_empty() => Ok(ToolResult {
|
||||
success: true,
|
||||
output: "No Discord messages found.".into(),
|
||||
error: None,
|
||||
}),
|
||||
Ok(entries) => {
|
||||
let mut output = format!("Found {} Discord messages:\n", entries.len());
|
||||
for entry in &entries {
|
||||
let score = entry
|
||||
.score
|
||||
.map_or_else(String::new, |s| format!(" [{s:.0}%]"));
|
||||
let _ = writeln!(output, "- {}{score}", entry.content);
|
||||
}
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Discord search failed: {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{MemoryCategory, SqliteMemory};
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn seeded_discord_mem() -> (TempDir, Arc<dyn Memory>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new_named(tmp.path(), "discord").unwrap();
|
||||
(tmp, Arc::new(mem))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn search_empty() {
|
||||
let (_tmp, mem) = seeded_discord_mem();
|
||||
let tool = DiscordSearchTool::new(mem);
|
||||
let result = tool.execute(json!({"query": "hello"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("No Discord messages found"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn search_finds_match() {
|
||||
let (_tmp, mem) = seeded_discord_mem();
|
||||
mem.store(
|
||||
"discord_001",
|
||||
"@user1 in #general at 2025-01-01T00:00:00Z: hello world",
|
||||
MemoryCategory::Custom("discord".to_string()),
|
||||
Some("general"),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = DiscordSearchTool::new(mem);
|
||||
let result = tool.execute(json!({"query": "hello"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("hello"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn search_requires_query_or_time() {
|
||||
let (_tmp, mem) = seeded_discord_mem();
|
||||
let tool = DiscordSearchTool::new(mem);
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("at least"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_and_schema() {
|
||||
let (_tmp, mem) = seeded_discord_mem();
|
||||
let tool = DiscordSearchTool::new(mem);
|
||||
assert_eq!(tool.name(), "discord_search");
|
||||
assert!(tool.parameters_schema()["properties"]["query"].is_object());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,494 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Standalone image generation tool using fal.ai (Flux / Nano Banana models).
|
||||
///
|
||||
/// Reads the API key from an environment variable (default: `FAL_API_KEY`),
|
||||
/// calls the fal.ai synchronous endpoint, downloads the resulting image,
|
||||
/// and saves it to `{workspace}/images/{filename}.png`.
|
||||
pub struct ImageGenTool {
|
||||
security: Arc<SecurityPolicy>,
|
||||
workspace_dir: PathBuf,
|
||||
default_model: String,
|
||||
api_key_env: String,
|
||||
}
|
||||
|
||||
impl ImageGenTool {
|
||||
pub fn new(
|
||||
security: Arc<SecurityPolicy>,
|
||||
workspace_dir: PathBuf,
|
||||
default_model: String,
|
||||
api_key_env: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
security,
|
||||
workspace_dir,
|
||||
default_model,
|
||||
api_key_env,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a reusable HTTP client with reasonable timeouts.
|
||||
fn http_client() -> reqwest::Client {
|
||||
reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.build()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Read an API key from the environment.
|
||||
fn read_api_key(env_var: &str) -> Result<String, String> {
|
||||
std::env::var(env_var)
|
||||
.map(|v| v.trim().to_string())
|
||||
.ok()
|
||||
.filter(|v| !v.is_empty())
|
||||
.ok_or_else(|| format!("Missing API key: set the {env_var} environment variable"))
|
||||
}
|
||||
|
||||
/// Core generation logic: call fal.ai, download image, save to disk.
|
||||
async fn generate(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
// ── Parse parameters ───────────────────────────────────────
|
||||
let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
|
||||
Some(p) if !p.trim().is_empty() => p.trim().to_string(),
|
||||
_ => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Missing required parameter: 'prompt'".into()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let filename = args
|
||||
.get("filename")
|
||||
.and_then(|v| v.as_str())
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.unwrap_or("generated_image");
|
||||
|
||||
// Sanitize filename — strip path components to prevent traversal.
|
||||
let safe_name = PathBuf::from(filename).file_name().map_or_else(
|
||||
|| "generated_image".to_string(),
|
||||
|n| n.to_string_lossy().to_string(),
|
||||
);
|
||||
|
||||
let size = args
|
||||
.get("size")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("square_hd");
|
||||
|
||||
// Validate size enum.
|
||||
const VALID_SIZES: &[&str] = &[
|
||||
"square_hd",
|
||||
"landscape_4_3",
|
||||
"portrait_4_3",
|
||||
"landscape_16_9",
|
||||
"portrait_16_9",
|
||||
];
|
||||
if !VALID_SIZES.contains(&size) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid size '{size}'. Valid values: {}",
|
||||
VALID_SIZES.join(", ")
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
let model = args
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.filter(|s| !s.trim().is_empty())
|
||||
.unwrap_or(&self.default_model);
|
||||
|
||||
// Validate model identifier: must look like a fal.ai model path
|
||||
// (e.g. "fal-ai/flux/schnell"). Reject values with "..", query
|
||||
// strings, or fragments that could redirect the HTTP request.
|
||||
if model.contains("..")
|
||||
|| model.contains('?')
|
||||
|| model.contains('#')
|
||||
|| model.contains('\\')
|
||||
|| model.starts_with('/')
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid model identifier '{model}'. \
|
||||
Must be a fal.ai model path (e.g. 'fal-ai/flux/schnell')."
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
// ── Read API key ───────────────────────────────────────────
|
||||
let api_key = match Self::read_api_key(&self.api_key_env) {
|
||||
Ok(k) => k,
|
||||
Err(msg) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(msg),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// ── Call fal.ai ────────────────────────────────────────────
|
||||
let client = Self::http_client();
|
||||
let url = format!("https://fal.run/{model}");
|
||||
|
||||
let body = json!({
|
||||
"prompt": prompt,
|
||||
"image_size": size,
|
||||
"num_images": 1
|
||||
});
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Key {api_key}"))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("fal.ai request failed")?;
|
||||
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let body_text = resp.text().await.unwrap_or_default();
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("fal.ai API error ({status}): {body_text}")),
|
||||
});
|
||||
}
|
||||
|
||||
let resp_json: serde_json::Value = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse fal.ai response as JSON")?;
|
||||
|
||||
let image_url = resp_json
|
||||
.pointer("/images/0/url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("No image URL in fal.ai response"))?;
|
||||
|
||||
// ── Download image ─────────────────────────────────────────
|
||||
let img_resp = client
|
||||
.get(image_url)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to download generated image")?;
|
||||
|
||||
if !img_resp.status().is_success() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Failed to download image from {image_url} ({})",
|
||||
img_resp.status()
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
let bytes = img_resp
|
||||
.bytes()
|
||||
.await
|
||||
.context("Failed to read image bytes")?;
|
||||
|
||||
// ── Save to disk ───────────────────────────────────────────
|
||||
let images_dir = self.workspace_dir.join("images");
|
||||
tokio::fs::create_dir_all(&images_dir)
|
||||
.await
|
||||
.context("Failed to create images directory")?;
|
||||
|
||||
let output_path = images_dir.join(format!("{safe_name}.png"));
|
||||
tokio::fs::write(&output_path, &bytes)
|
||||
.await
|
||||
.context("Failed to write image file")?;
|
||||
|
||||
let size_kb = bytes.len() / 1024;
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"Image generated successfully.\n\
|
||||
File: {}\n\
|
||||
Size: {} KB\n\
|
||||
Model: {}\n\
|
||||
Prompt: {}",
|
||||
output_path.display(),
|
||||
size_kb,
|
||||
model,
|
||||
prompt,
|
||||
),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ImageGenTool {
|
||||
fn name(&self) -> &str {
|
||||
"image_gen"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Generate an image from a text prompt using fal.ai (Flux models). \
|
||||
Saves the result to the workspace images directory and returns the file path."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"required": ["prompt"],
|
||||
"properties": {
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Text prompt describing the image to generate."
|
||||
},
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "Output filename without extension (default: 'generated_image'). Saved as PNG in workspace/images/."
|
||||
},
|
||||
"size": {
|
||||
"type": "string",
|
||||
"enum": ["square_hd", "landscape_4_3", "portrait_4_3", "landscape_16_9", "portrait_16_9"],
|
||||
"description": "Image aspect ratio / size preset (default: 'square_hd')."
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "fal.ai model identifier (default: 'fal-ai/flux/schnell')."
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
// Security: image generation is a side-effecting action (HTTP + file write).
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "image_gen")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
self.generate(args).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
workspace_dir: std::env::temp_dir(),
|
||||
..SecurityPolicy::default()
|
||||
})
|
||||
}
|
||||
|
||||
fn test_tool() -> ImageGenTool {
|
||||
ImageGenTool::new(
|
||||
test_security(),
|
||||
std::env::temp_dir(),
|
||||
"fal-ai/flux/schnell".into(),
|
||||
"FAL_API_KEY".into(),
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_name() {
|
||||
let tool = test_tool();
|
||||
assert_eq!(tool.name(), "image_gen");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_description_is_nonempty() {
|
||||
let tool = test_tool();
|
||||
assert!(!tool.description().is_empty());
|
||||
assert!(tool.description().contains("image"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_schema_has_required_prompt() {
|
||||
let tool = test_tool();
|
||||
let schema = tool.parameters_schema();
|
||||
assert_eq!(schema["required"], json!(["prompt"]));
|
||||
assert!(schema["properties"]["prompt"].is_object());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_schema_has_optional_params() {
|
||||
let tool = test_tool();
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["filename"].is_object());
|
||||
assert!(schema["properties"]["size"].is_object());
|
||||
assert!(schema["properties"]["model"].is_object());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_spec_roundtrip() {
|
||||
let tool = test_tool();
|
||||
let spec = tool.spec();
|
||||
assert_eq!(spec.name, "image_gen");
|
||||
assert!(spec.parameters.is_object());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_prompt_returns_error() {
|
||||
let tool = test_tool();
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_deref().unwrap().contains("prompt"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_prompt_returns_error() {
|
||||
let tool = test_tool();
|
||||
let result = tool.execute(json!({"prompt": " "})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_deref().unwrap().contains("prompt"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_api_key_returns_error() {
|
||||
// Temporarily ensure the env var is unset.
|
||||
let original = std::env::var("FAL_API_KEY_TEST_IMAGE_GEN").ok();
|
||||
std::env::remove_var("FAL_API_KEY_TEST_IMAGE_GEN");
|
||||
|
||||
let tool = ImageGenTool::new(
|
||||
test_security(),
|
||||
std::env::temp_dir(),
|
||||
"fal-ai/flux/schnell".into(),
|
||||
"FAL_API_KEY_TEST_IMAGE_GEN".into(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"prompt": "a sunset over the ocean"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap()
|
||||
.contains("FAL_API_KEY_TEST_IMAGE_GEN"));
|
||||
|
||||
// Restore if it was set.
|
||||
if let Some(val) = original {
|
||||
std::env::set_var("FAL_API_KEY_TEST_IMAGE_GEN", val);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_size_returns_error() {
|
||||
// Set a dummy key so we get past the key check.
|
||||
std::env::set_var("FAL_API_KEY_TEST_SIZE", "dummy_key");
|
||||
|
||||
let tool = ImageGenTool::new(
|
||||
test_security(),
|
||||
std::env::temp_dir(),
|
||||
"fal-ai/flux/schnell".into(),
|
||||
"FAL_API_KEY_TEST_SIZE".into(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"prompt": "test", "size": "invalid_size"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_deref().unwrap().contains("Invalid size"));
|
||||
|
||||
std::env::remove_var("FAL_API_KEY_TEST_SIZE");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_only_autonomy_blocks_execution() {
|
||||
let security = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
workspace_dir: std::env::temp_dir(),
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = ImageGenTool::new(
|
||||
security,
|
||||
std::env::temp_dir(),
|
||||
"fal-ai/flux/schnell".into(),
|
||||
"FAL_API_KEY".into(),
|
||||
);
|
||||
let result = tool.execute(json!({"prompt": "test image"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
let err = result.error.as_deref().unwrap();
|
||||
assert!(
|
||||
err.contains("read-only") || err.contains("image_gen"),
|
||||
"expected read-only or image_gen in error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_model_with_traversal_returns_error() {
|
||||
std::env::set_var("FAL_API_KEY_TEST_MODEL", "dummy_key");
|
||||
|
||||
let tool = ImageGenTool::new(
|
||||
test_security(),
|
||||
std::env::temp_dir(),
|
||||
"fal-ai/flux/schnell".into(),
|
||||
"FAL_API_KEY_TEST_MODEL".into(),
|
||||
);
|
||||
let result = tool
|
||||
.execute(json!({"prompt": "test", "model": "../../evil-endpoint"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap()
|
||||
.contains("Invalid model identifier"));
|
||||
|
||||
std::env::remove_var("FAL_API_KEY_TEST_MODEL");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_api_key_missing() {
|
||||
let result = ImageGenTool::read_api_key("DEFINITELY_NOT_SET_ZC_TEST_12345");
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.contains("DEFINITELY_NOT_SET_ZC_TEST_12345"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filename_traversal_is_sanitized() {
|
||||
// Verify that path traversal in filenames is stripped to just the final component.
|
||||
let sanitized = PathBuf::from("../../etc/passwd").file_name().map_or_else(
|
||||
|| "generated_image".to_string(),
|
||||
|n| n.to_string_lossy().to_string(),
|
||||
);
|
||||
assert_eq!(sanitized, "passwd");
|
||||
|
||||
// ".." alone has no file_name, falls back to default.
|
||||
let sanitized = PathBuf::from("..").file_name().map_or_else(
|
||||
|| "generated_image".to_string(),
|
||||
|n| n.to_string_lossy().to_string(),
|
||||
);
|
||||
assert_eq!(sanitized, "generated_image");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_api_key_present() {
|
||||
std::env::set_var("ZC_IMAGE_GEN_TEST_KEY", "test_value_123");
|
||||
let result = ImageGenTool::read_api_key("ZC_IMAGE_GEN_TEST_KEY");
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), "test_value_123");
|
||||
std::env::remove_var("ZC_IMAGE_GEN_TEST_KEY");
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,7 @@ pub mod browser;
|
||||
pub mod browser_delegate;
|
||||
pub mod browser_open;
|
||||
pub mod calculator;
|
||||
pub mod canvas;
|
||||
pub mod claude_code;
|
||||
pub mod cli_discovery;
|
||||
pub mod cloud_ops;
|
||||
@@ -34,6 +35,7 @@ pub mod cron_runs;
|
||||
pub mod cron_update;
|
||||
pub mod data_management;
|
||||
pub mod delegate;
|
||||
pub mod discord_search;
|
||||
pub mod file_edit;
|
||||
pub mod file_read;
|
||||
pub mod file_write;
|
||||
@@ -47,6 +49,7 @@ pub mod hardware_memory_map;
|
||||
#[cfg(feature = "hardware")]
|
||||
pub mod hardware_memory_read;
|
||||
pub mod http_request;
|
||||
pub mod image_gen;
|
||||
pub mod image_info;
|
||||
pub mod jira_tool;
|
||||
pub mod knowledge_tool;
|
||||
@@ -76,6 +79,7 @@ pub mod schedule;
|
||||
pub mod schema;
|
||||
pub mod screenshot;
|
||||
pub mod security_ops;
|
||||
pub mod sessions;
|
||||
pub mod shell;
|
||||
pub mod swarm;
|
||||
pub mod text_browser;
|
||||
@@ -94,6 +98,7 @@ pub use browser::{BrowserTool, ComputerUseConfig};
|
||||
pub use browser_delegate::{BrowserDelegateConfig, BrowserDelegateTool};
|
||||
pub use browser_open::BrowserOpenTool;
|
||||
pub use calculator::CalculatorTool;
|
||||
pub use canvas::{CanvasStore, CanvasTool};
|
||||
pub use claude_code::ClaudeCodeTool;
|
||||
pub use cloud_ops::CloudOpsTool;
|
||||
pub use cloud_patterns::CloudPatternsTool;
|
||||
@@ -107,6 +112,7 @@ pub use cron_runs::CronRunsTool;
|
||||
pub use cron_update::CronUpdateTool;
|
||||
pub use data_management::DataManagementTool;
|
||||
pub use delegate::DelegateTool;
|
||||
pub use discord_search::DiscordSearchTool;
|
||||
pub use file_edit::FileEditTool;
|
||||
pub use file_read::FileReadTool;
|
||||
pub use file_write::FileWriteTool;
|
||||
@@ -120,6 +126,7 @@ pub use hardware_memory_map::HardwareMemoryMapTool;
|
||||
#[cfg(feature = "hardware")]
|
||||
pub use hardware_memory_read::HardwareMemoryReadTool;
|
||||
pub use http_request::HttpRequestTool;
|
||||
pub use image_gen::ImageGenTool;
|
||||
pub use image_info::ImageInfoTool;
|
||||
pub use jira_tool::JiraTool;
|
||||
pub use knowledge_tool::KnowledgeTool;
|
||||
@@ -147,6 +154,7 @@ pub use schedule::ScheduleTool;
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use screenshot::ScreenshotTool;
|
||||
pub use security_ops::SecurityOpsTool;
|
||||
pub use sessions::{SessionsHistoryTool, SessionsListTool, SessionsSendTool};
|
||||
pub use shell::ShellTool;
|
||||
pub use swarm::SwarmTool;
|
||||
pub use text_browser::TextBrowserTool;
|
||||
@@ -264,6 +272,7 @@ pub fn all_tools(
|
||||
agents: &HashMap<String, DelegateAgentConfig>,
|
||||
fallback_api_key: Option<&str>,
|
||||
root_config: &crate::config::Config,
|
||||
canvas_store: Option<CanvasStore>,
|
||||
) -> (
|
||||
Vec<Box<dyn Tool>>,
|
||||
Option<DelegateParentToolsHandle>,
|
||||
@@ -283,6 +292,7 @@ pub fn all_tools(
|
||||
agents,
|
||||
fallback_api_key,
|
||||
root_config,
|
||||
canvas_store,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -302,6 +312,7 @@ pub fn all_tools_with_runtime(
|
||||
agents: &HashMap<String, DelegateAgentConfig>,
|
||||
fallback_api_key: Option<&str>,
|
||||
root_config: &crate::config::Config,
|
||||
canvas_store: Option<CanvasStore>,
|
||||
) -> (
|
||||
Vec<Box<dyn Tool>>,
|
||||
Option<DelegateParentToolsHandle>,
|
||||
@@ -346,8 +357,21 @@ pub fn all_tools_with_runtime(
|
||||
)),
|
||||
Arc::new(CalculatorTool::new()),
|
||||
Arc::new(WeatherTool::new()),
|
||||
Arc::new(CanvasTool::new(canvas_store.unwrap_or_default())),
|
||||
];
|
||||
|
||||
// Register discord_search if discord_history channel is configured
|
||||
if root_config.channels_config.discord_history.is_some() {
|
||||
match crate::memory::SqliteMemory::new_named(workspace_dir, "discord") {
|
||||
Ok(discord_mem) => {
|
||||
tool_arcs.push(Arc::new(DiscordSearchTool::new(Arc::new(discord_mem))));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("discord_search: failed to open discord.db: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if matches!(
|
||||
root_config.skills.prompt_injection_mode,
|
||||
crate::config::SkillsPromptInjectionMode::Compact
|
||||
@@ -555,6 +579,18 @@ pub fn all_tools_with_runtime(
|
||||
tool_arcs.push(Arc::new(ScreenshotTool::new(security.clone())));
|
||||
tool_arcs.push(Arc::new(ImageInfoTool::new(security.clone())));
|
||||
|
||||
// Session-to-session messaging tools (always available when sessions dir exists)
|
||||
if let Ok(session_store) = crate::channels::session_store::SessionStore::new(workspace_dir) {
|
||||
let backend: Arc<dyn crate::channels::session_backend::SessionBackend> =
|
||||
Arc::new(session_store);
|
||||
tool_arcs.push(Arc::new(SessionsListTool::new(backend.clone())));
|
||||
tool_arcs.push(Arc::new(SessionsHistoryTool::new(
|
||||
backend.clone(),
|
||||
security.clone(),
|
||||
)));
|
||||
tool_arcs.push(Arc::new(SessionsSendTool::new(backend, security.clone())));
|
||||
}
|
||||
|
||||
// LinkedIn integration (config-gated)
|
||||
if root_config.linkedin.enabled {
|
||||
tool_arcs.push(Arc::new(LinkedInTool::new(
|
||||
@@ -566,6 +602,16 @@ pub fn all_tools_with_runtime(
|
||||
)));
|
||||
}
|
||||
|
||||
// Standalone image generation tool (config-gated)
|
||||
if root_config.image_gen.enabled {
|
||||
tool_arcs.push(Arc::new(ImageGenTool::new(
|
||||
security.clone(),
|
||||
workspace_dir.to_path_buf(),
|
||||
root_config.image_gen.default_model.clone(),
|
||||
root_config.image_gen.api_key_env.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(key) = composio_key {
|
||||
if !key.is_empty() {
|
||||
tool_arcs.push(Arc::new(ComposioTool::new(
|
||||
@@ -856,6 +902,7 @@ mod tests {
|
||||
&HashMap::new(),
|
||||
None,
|
||||
&cfg,
|
||||
None,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(!names.contains(&"browser_open"));
|
||||
@@ -898,6 +945,7 @@ mod tests {
|
||||
&HashMap::new(),
|
||||
None,
|
||||
&cfg,
|
||||
None,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(names.contains(&"browser_open"));
|
||||
@@ -1051,6 +1099,7 @@ mod tests {
|
||||
&agents,
|
||||
Some("delegate-test-credential"),
|
||||
&cfg,
|
||||
None,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(names.contains(&"delegate"));
|
||||
@@ -1084,6 +1133,7 @@ mod tests {
|
||||
&HashMap::new(),
|
||||
None,
|
||||
&cfg,
|
||||
None,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(!names.contains(&"delegate"));
|
||||
@@ -1118,6 +1168,7 @@ mod tests {
|
||||
&HashMap::new(),
|
||||
None,
|
||||
&cfg,
|
||||
None,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(names.contains(&"read_skill"));
|
||||
@@ -1152,6 +1203,7 @@ mod tests {
|
||||
&HashMap::new(),
|
||||
None,
|
||||
&cfg,
|
||||
None,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(!names.contains(&"read_skill"));
|
||||
|
||||
@@ -0,0 +1,573 @@
|
||||
//! Session-to-session messaging tools for inter-agent communication.
|
||||
//!
|
||||
//! Provides three tools:
|
||||
//! - `sessions_list` — list active sessions with metadata
|
||||
//! - `sessions_history` — read message history from a specific session
|
||||
//! - `sessions_send` — send a message to a specific session
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::channels::session_backend::SessionBackend;
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::fmt::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Validate that a session ID is non-empty and contains at least one
|
||||
/// alphanumeric character (prevents blank keys after sanitization).
|
||||
fn validate_session_id(session_id: &str) -> Result<(), ToolResult> {
|
||||
let trimmed = session_id.trim();
|
||||
if trimmed.is_empty() || !trimmed.chars().any(|c| c.is_alphanumeric()) {
|
||||
return Err(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Invalid 'session_id': must be non-empty and contain at least one alphanumeric character.".into(),
|
||||
),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── SessionsListTool ────────────────────────────────────────────────
|
||||
|
||||
/// Lists active sessions with their channel, last activity time, and message count.
|
||||
pub struct SessionsListTool {
|
||||
backend: Arc<dyn SessionBackend>,
|
||||
}
|
||||
|
||||
impl SessionsListTool {
|
||||
pub fn new(backend: Arc<dyn SessionBackend>) -> Self {
|
||||
Self { backend }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SessionsListTool {
|
||||
fn name(&self) -> &str {
|
||||
"sessions_list"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"List all active conversation sessions with their channel, last activity time, and message count."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max sessions to return (default: 50)"
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let limit = args
|
||||
.get("limit")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map_or(50, |v| v as usize);
|
||||
|
||||
let metadata = self.backend.list_sessions_with_metadata();
|
||||
|
||||
if metadata.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
output: "No active sessions found.".into(),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
let capped: Vec<_> = metadata.into_iter().take(limit).collect();
|
||||
let mut output = format!("Found {} session(s):\n", capped.len());
|
||||
for meta in &capped {
|
||||
// Extract channel from key (convention: channel__identifier)
|
||||
let channel = meta.key.split("__").next().unwrap_or(&meta.key);
|
||||
let _ = writeln!(
|
||||
output,
|
||||
"- {}: channel={}, messages={}, last_activity={}",
|
||||
meta.key, channel, meta.message_count, meta.last_activity
|
||||
);
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── SessionsHistoryTool ─────────────────────────────────────────────
|
||||
|
||||
/// Reads the message history of a specific session by ID.
|
||||
pub struct SessionsHistoryTool {
|
||||
backend: Arc<dyn SessionBackend>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl SessionsHistoryTool {
|
||||
pub fn new(backend: Arc<dyn SessionBackend>, security: Arc<SecurityPolicy>) -> Self {
|
||||
Self { backend, security }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SessionsHistoryTool {
|
||||
fn name(&self) -> &str {
|
||||
"sessions_history"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Read the message history of a specific session by its session ID. Returns the last N messages."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "The session ID to read history from (e.g. telegram__user123)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max messages to return, from most recent (default: 20)"
|
||||
}
|
||||
},
|
||||
"required": ["session_id"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Read, "sessions_history")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
let session_id = args
|
||||
.get("session_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'session_id' parameter"))?;
|
||||
|
||||
if let Err(result) = validate_session_id(session_id) {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let limit = args
|
||||
.get("limit")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map_or(20, |v| v as usize);
|
||||
|
||||
let messages = self.backend.load(session_id);
|
||||
|
||||
if messages.is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("No messages found for session '{session_id}'."),
|
||||
error: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Take the last `limit` messages
|
||||
let start = messages.len().saturating_sub(limit);
|
||||
let tail = &messages[start..];
|
||||
|
||||
let mut output = format!(
|
||||
"Session '{}': showing {}/{} messages\n",
|
||||
session_id,
|
||||
tail.len(),
|
||||
messages.len()
|
||||
);
|
||||
for msg in tail {
|
||||
let _ = writeln!(output, "[{}] {}", msg.role, msg.content);
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ── SessionsSendTool ────────────────────────────────────────────────
|
||||
|
||||
/// Sends a message to a specific session, enabling inter-agent communication.
|
||||
pub struct SessionsSendTool {
|
||||
backend: Arc<dyn SessionBackend>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl SessionsSendTool {
|
||||
pub fn new(backend: Arc<dyn SessionBackend>, security: Arc<SecurityPolicy>) -> Self {
|
||||
Self { backend, security }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SessionsSendTool {
|
||||
fn name(&self) -> &str {
|
||||
"sessions_send"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Send a message to a specific session by its session ID. The message is appended to the session's conversation history as a 'user' message, enabling inter-agent communication."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"session_id": {
|
||||
"type": "string",
|
||||
"description": "The target session ID (e.g. telegram__user123)"
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The message content to send"
|
||||
}
|
||||
},
|
||||
"required": ["session_id", "message"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "sessions_send")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
let session_id = args
|
||||
.get("session_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'session_id' parameter"))?;
|
||||
|
||||
if let Err(result) = validate_session_id(session_id) {
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
let message = args
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))?;
|
||||
|
||||
if message.trim().is_empty() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Message content must not be empty.".into()),
|
||||
});
|
||||
}
|
||||
|
||||
let chat_msg = crate::providers::traits::ChatMessage::user(message);
|
||||
|
||||
match self.backend.append(session_id, &chat_msg) {
|
||||
Ok(()) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Message sent to session '{session_id}'."),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to send message: {e}")),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::channels::session_store::SessionStore;
|
||||
use crate::providers::traits::ChatMessage;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy::default())
|
||||
}
|
||||
|
||||
fn test_backend() -> (TempDir, Arc<dyn SessionBackend>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
(tmp, Arc::new(store))
|
||||
}
|
||||
|
||||
fn seeded_backend() -> (TempDir, Arc<dyn SessionBackend>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let store = SessionStore::new(tmp.path()).unwrap();
|
||||
store
|
||||
.append("telegram__alice", &ChatMessage::user("Hello from Alice"))
|
||||
.unwrap();
|
||||
store
|
||||
.append(
|
||||
"telegram__alice",
|
||||
&ChatMessage::assistant("Hi Alice, how can I help?"),
|
||||
)
|
||||
.unwrap();
|
||||
store
|
||||
.append("discord__bob", &ChatMessage::user("Hey from Bob"))
|
||||
.unwrap();
|
||||
(tmp, Arc::new(store))
|
||||
}
|
||||
|
||||
// ── SessionsListTool tests ──────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_empty_sessions() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsListTool::new(backend);
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("No active sessions"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_sessions_shows_all() {
|
||||
let (_tmp, backend) = seeded_backend();
|
||||
let tool = SessionsListTool::new(backend);
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("2 session(s)"));
|
||||
assert!(result.output.contains("telegram__alice"));
|
||||
assert!(result.output.contains("discord__bob"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_sessions_respects_limit() {
|
||||
let (_tmp, backend) = seeded_backend();
|
||||
let tool = SessionsListTool::new(backend);
|
||||
let result = tool.execute(json!({"limit": 1})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("1 session(s)"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_sessions_extracts_channel() {
|
||||
let (_tmp, backend) = seeded_backend();
|
||||
let tool = SessionsListTool::new(backend);
|
||||
let result = tool.execute(json!({})).await.unwrap();
|
||||
assert!(result.output.contains("channel=telegram"));
|
||||
assert!(result.output.contains("channel=discord"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_tool_name_and_schema() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsListTool::new(backend);
|
||||
assert_eq!(tool.name(), "sessions_list");
|
||||
assert!(tool.parameters_schema()["properties"]["limit"].is_object());
|
||||
}
|
||||
|
||||
// ── SessionsHistoryTool tests ───────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn history_empty_session() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsHistoryTool::new(backend, test_security());
|
||||
let result = tool
|
||||
.execute(json!({"session_id": "nonexistent"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("No messages found"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn history_returns_messages() {
|
||||
let (_tmp, backend) = seeded_backend();
|
||||
let tool = SessionsHistoryTool::new(backend, test_security());
|
||||
let result = tool
|
||||
.execute(json!({"session_id": "telegram__alice"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("showing 2/2 messages"));
|
||||
assert!(result.output.contains("[user] Hello from Alice"));
|
||||
assert!(result.output.contains("[assistant] Hi Alice"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn history_respects_limit() {
|
||||
let (_tmp, backend) = seeded_backend();
|
||||
let tool = SessionsHistoryTool::new(backend, test_security());
|
||||
let result = tool
|
||||
.execute(json!({"session_id": "telegram__alice", "limit": 1}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("showing 1/2 messages"));
|
||||
// Should show only the last message
|
||||
assert!(result.output.contains("[assistant]"));
|
||||
assert!(!result.output.contains("[user] Hello from Alice"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn history_missing_session_id() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsHistoryTool::new(backend, test_security());
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("session_id"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn history_rejects_empty_session_id() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsHistoryTool::new(backend, test_security());
|
||||
let result = tool.execute(json!({"session_id": " "})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Invalid"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn history_tool_name_and_schema() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsHistoryTool::new(backend, test_security());
|
||||
assert_eq!(tool.name(), "sessions_history");
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["session_id"].is_object());
|
||||
assert!(schema["required"]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.contains(&json!("session_id")));
|
||||
}
|
||||
|
||||
// ── SessionsSendTool tests ──────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_appends_message() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend.clone(), test_security());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"session_id": "telegram__alice",
|
||||
"message": "Hello from another agent"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Message sent"));
|
||||
|
||||
// Verify message was appended
|
||||
let messages = backend.load("telegram__alice");
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0].role, "user");
|
||||
assert_eq!(messages[0].content, "Hello from another agent");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_to_existing_session() {
|
||||
let (_tmp, backend) = seeded_backend();
|
||||
let tool = SessionsSendTool::new(backend.clone(), test_security());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"session_id": "telegram__alice",
|
||||
"message": "Inter-agent message"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.success);
|
||||
|
||||
let messages = backend.load("telegram__alice");
|
||||
assert_eq!(messages.len(), 3);
|
||||
assert_eq!(messages[2].content, "Inter-agent message");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_rejects_empty_message() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend, test_security());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"session_id": "telegram__alice",
|
||||
"message": " "
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("empty"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_rejects_empty_session_id() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend, test_security());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"session_id": "",
|
||||
"message": "hello"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Invalid"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_rejects_non_alphanumeric_session_id() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend, test_security());
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"session_id": "///",
|
||||
"message": "hello"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("Invalid"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_missing_session_id() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend, test_security());
|
||||
let result = tool.execute(json!({"message": "hi"})).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("session_id"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_missing_message() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend, test_security());
|
||||
let result = tool.execute(json!({"session_id": "telegram__alice"})).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("message"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_tool_name_and_schema() {
|
||||
let (_tmp, backend) = test_backend();
|
||||
let tool = SessionsSendTool::new(backend, test_security());
|
||||
assert_eq!(tool.name(), "sessions_send");
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["required"]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.contains(&json!("session_id")));
|
||||
assert!(schema["required"]
|
||||
.as_array()
|
||||
.unwrap()
|
||||
.contains(&json!("message")));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user