From ff6027fce79a4c438bab070b2ed7d43120d802bc Mon Sep 17 00:00:00 2001 From: xj Date: Thu, 19 Feb 2026 01:37:12 -0800 Subject: [PATCH] feat(hooks): add HookHandler trait, HookResult, and HookRunner dispatcher Co-Authored-By: Claude Opus 4.6 --- src/config/mod.rs | 12 +- src/hooks/mod.rs | 5 + src/hooks/runner.rs | 414 ++++++++++++++++++++++++++++++++++++++++++++ src/hooks/traits.rs | 140 +++++++++++++++ src/lib.rs | 1 + 5 files changed, 566 insertions(+), 6 deletions(-) create mode 100644 src/hooks/mod.rs create mode 100644 src/hooks/runner.rs create mode 100644 src/hooks/traits.rs diff --git a/src/config/mod.rs b/src/config/mod.rs index f329d8a96..a65270db0 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -9,12 +9,12 @@ pub use schema::{ DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EmbeddingRouteConfig, GatewayConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, MultimodalConfig, - NextcloudTalkConfig, ObservabilityConfig, PeripheralBoardConfig, PeripheralsConfig, - ProxyConfig, ProxyScope, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, - RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, - SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig, - StorageProviderSection, StreamMode, TelegramConfig, TranscriptionConfig, TunnelConfig, - WebSearchConfig, WebhookConfig, + BuiltinHooksConfig, HooksConfig, NextcloudTalkConfig, ObservabilityConfig, + PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QueryClassificationConfig, + ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, + SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig, SkillsPromptInjectionMode, + SlackConfig, StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, + TelegramConfig, TranscriptionConfig, TunnelConfig, WebSearchConfig, WebhookConfig, }; #[cfg(test)] diff --git a/src/hooks/mod.rs b/src/hooks/mod.rs new file mode 100644 index 000000000..31df8274c --- /dev/null +++ b/src/hooks/mod.rs @@ -0,0 +1,5 @@ +mod runner; +mod traits; + +pub use runner::HookRunner; +pub use traits::{HookHandler, HookResult}; diff --git a/src/hooks/runner.rs b/src/hooks/runner.rs new file mode 100644 index 000000000..94844459b --- /dev/null +++ b/src/hooks/runner.rs @@ -0,0 +1,414 @@ +use std::time::Duration; + +use futures::future::join_all; +use serde_json::Value; +use tracing::info; + +use crate::channels::traits::ChannelMessage; +use crate::providers::traits::{ChatMessage, ChatResponse}; +use crate::tools::traits::ToolResult; + +use super::traits::{HookHandler, HookResult}; + +/// Dispatcher that manages registered hook handlers. +/// +/// Void hooks are dispatched in parallel via `join_all`. +/// Modifying hooks run sequentially by priority (higher first), piping output +/// and short-circuiting on `Cancel`. +pub struct HookRunner { + handlers: Vec>, +} + +impl HookRunner { + /// Create an empty runner with no handlers. + pub fn new() -> Self { + Self { + handlers: Vec::new(), + } + } + + /// Register a handler and re-sort by descending priority. + pub fn register(&mut self, handler: Box) { + self.handlers.push(handler); + self.handlers.sort_by(|a, b| b.priority().cmp(&a.priority())); + } + + // --------------------------------------------------------------- + // Void dispatchers (parallel, fire-and-forget) + // --------------------------------------------------------------- + + pub async fn fire_gateway_start(&self, host: &str, port: u16) { + let futs: Vec<_> = self + .handlers + .iter() + .map(|h| h.on_gateway_start(host, port)) + .collect(); + join_all(futs).await; + } + + pub async fn fire_gateway_stop(&self) { + let futs: Vec<_> = self.handlers.iter().map(|h| h.on_gateway_stop()).collect(); + join_all(futs).await; + } + + pub async fn fire_session_start(&self, session_id: &str, channel: &str) { + let futs: Vec<_> = self + .handlers + .iter() + .map(|h| h.on_session_start(session_id, channel)) + .collect(); + join_all(futs).await; + } + + pub async fn fire_session_end(&self, session_id: &str, channel: &str) { + let futs: Vec<_> = self + .handlers + .iter() + .map(|h| h.on_session_end(session_id, channel)) + .collect(); + join_all(futs).await; + } + + pub async fn fire_llm_input(&self, messages: &[ChatMessage], model: &str) { + let futs: Vec<_> = self + .handlers + .iter() + .map(|h| h.on_llm_input(messages, model)) + .collect(); + join_all(futs).await; + } + + pub async fn fire_llm_output(&self, response: &ChatResponse) { + let futs: Vec<_> = self + .handlers + .iter() + .map(|h| h.on_llm_output(response)) + .collect(); + join_all(futs).await; + } + + pub async fn fire_after_tool_call(&self, tool: &str, result: &ToolResult, duration: Duration) { + let futs: Vec<_> = self + .handlers + .iter() + .map(|h| h.on_after_tool_call(tool, result, duration)) + .collect(); + join_all(futs).await; + } + + pub async fn fire_message_sent(&self, channel: &str, recipient: &str, content: &str) { + let futs: Vec<_> = self + .handlers + .iter() + .map(|h| h.on_message_sent(channel, recipient, content)) + .collect(); + join_all(futs).await; + } + + pub async fn fire_heartbeat_tick(&self) { + let futs: Vec<_> = self + .handlers + .iter() + .map(|h| h.on_heartbeat_tick()) + .collect(); + join_all(futs).await; + } + + // --------------------------------------------------------------- + // Modifying dispatchers (sequential by priority, short-circuit on Cancel) + // --------------------------------------------------------------- + + pub async fn run_before_model_resolve( + &self, + mut provider: String, + mut model: String, + ) -> HookResult<(String, String)> { + for h in &self.handlers { + match h.before_model_resolve(provider, model).await { + HookResult::Continue((p, m)) => { + provider = p; + model = m; + } + HookResult::Cancel(reason) => { + info!( + hook = h.name(), + reason, "before_model_resolve cancelled by hook" + ); + return HookResult::Cancel(reason); + } + } + } + HookResult::Continue((provider, model)) + } + + pub async fn run_before_prompt_build(&self, mut prompt: String) -> HookResult { + for h in &self.handlers { + match h.before_prompt_build(prompt).await { + HookResult::Continue(p) => prompt = p, + HookResult::Cancel(reason) => { + info!( + hook = h.name(), + reason, "before_prompt_build cancelled by hook" + ); + return HookResult::Cancel(reason); + } + } + } + HookResult::Continue(prompt) + } + + pub async fn run_before_llm_call( + &self, + mut messages: Vec, + mut model: String, + ) -> HookResult<(Vec, String)> { + for h in &self.handlers { + match h.before_llm_call(messages, model).await { + HookResult::Continue((m, mdl)) => { + messages = m; + model = mdl; + } + HookResult::Cancel(reason) => { + info!(hook = h.name(), reason, "before_llm_call cancelled by hook"); + return HookResult::Cancel(reason); + } + } + } + HookResult::Continue((messages, model)) + } + + pub async fn run_before_tool_call( + &self, + mut name: String, + mut args: Value, + ) -> HookResult<(String, Value)> { + for h in &self.handlers { + match h.before_tool_call(name, args).await { + HookResult::Continue((n, a)) => { + name = n; + args = a; + } + HookResult::Cancel(reason) => { + info!( + hook = h.name(), + reason, "before_tool_call cancelled by hook" + ); + return HookResult::Cancel(reason); + } + } + } + HookResult::Continue((name, args)) + } + + pub async fn run_on_message_received( + &self, + mut message: ChannelMessage, + ) -> HookResult { + for h in &self.handlers { + match h.on_message_received(message).await { + HookResult::Continue(m) => message = m, + HookResult::Cancel(reason) => { + info!( + hook = h.name(), + reason, "on_message_received cancelled by hook" + ); + return HookResult::Cancel(reason); + } + } + } + HookResult::Continue(message) + } + + pub async fn run_on_message_sending( + &self, + mut channel: String, + mut recipient: String, + mut content: String, + ) -> HookResult<(String, String, String)> { + for h in &self.handlers { + match h.on_message_sending(channel, recipient, content).await { + HookResult::Continue((c, r, ct)) => { + channel = c; + recipient = r; + content = ct; + } + HookResult::Cancel(reason) => { + info!( + hook = h.name(), + reason, "on_message_sending cancelled by hook" + ); + return HookResult::Cancel(reason); + } + } + } + HookResult::Continue((channel, recipient, content)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use async_trait::async_trait; + use std::sync::atomic::{AtomicU32, Ordering}; + use std::sync::Arc; + + /// A hook that records how many times void events fire. + struct CountingHook { + name: String, + priority: i32, + fire_count: Arc, + } + + impl CountingHook { + fn new(name: &str, priority: i32) -> (Self, Arc) { + let count = Arc::new(AtomicU32::new(0)); + ( + Self { + name: name.to_string(), + priority, + fire_count: count.clone(), + }, + count, + ) + } + } + + #[async_trait] + impl HookHandler for CountingHook { + fn name(&self) -> &str { + &self.name + } + fn priority(&self) -> i32 { + self.priority + } + async fn on_heartbeat_tick(&self) { + self.fire_count.fetch_add(1, Ordering::SeqCst); + } + } + + /// A modifying hook that uppercases the prompt. + struct UppercasePromptHook { + name: String, + priority: i32, + } + + #[async_trait] + impl HookHandler for UppercasePromptHook { + fn name(&self) -> &str { + &self.name + } + fn priority(&self) -> i32 { + self.priority + } + async fn before_prompt_build(&self, prompt: String) -> HookResult { + HookResult::Continue(prompt.to_uppercase()) + } + } + + /// A modifying hook that cancels before_prompt_build. + struct CancelPromptHook { + name: String, + priority: i32, + } + + #[async_trait] + impl HookHandler for CancelPromptHook { + fn name(&self) -> &str { + &self.name + } + fn priority(&self) -> i32 { + self.priority + } + async fn before_prompt_build(&self, _prompt: String) -> HookResult { + HookResult::Cancel("blocked by policy".into()) + } + } + + /// A modifying hook that appends a suffix to the prompt. + struct SuffixPromptHook { + name: String, + priority: i32, + suffix: String, + } + + #[async_trait] + impl HookHandler for SuffixPromptHook { + fn name(&self) -> &str { + &self.name + } + fn priority(&self) -> i32 { + self.priority + } + async fn before_prompt_build(&self, prompt: String) -> HookResult { + HookResult::Continue(format!("{}{}", prompt, self.suffix)) + } + } + + #[test] + fn register_and_sort_by_priority() { + let mut runner = HookRunner::new(); + let (low, _) = CountingHook::new("low", 1); + let (high, _) = CountingHook::new("high", 10); + let (mid, _) = CountingHook::new("mid", 5); + + runner.register(Box::new(low)); + runner.register(Box::new(high)); + runner.register(Box::new(mid)); + + let names: Vec<&str> = runner.handlers.iter().map(|h| h.name()).collect(); + assert_eq!(names, vec!["high", "mid", "low"]); + } + + #[tokio::test] + async fn void_hooks_fire_all_handlers() { + let mut runner = HookRunner::new(); + let (h1, c1) = CountingHook::new("hook_a", 0); + let (h2, c2) = CountingHook::new("hook_b", 0); + + runner.register(Box::new(h1)); + runner.register(Box::new(h2)); + + runner.fire_heartbeat_tick().await; + + assert_eq!(c1.load(Ordering::SeqCst), 1); + assert_eq!(c2.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn modifying_hook_can_cancel() { + let mut runner = HookRunner::new(); + runner.register(Box::new(CancelPromptHook { + name: "blocker".into(), + priority: 10, + })); + runner.register(Box::new(UppercasePromptHook { + name: "upper".into(), + priority: 0, + })); + + let result = runner.run_before_prompt_build("hello".into()).await; + assert!(result.is_cancel()); + } + + #[tokio::test] + async fn modifying_hook_pipelines_data() { + let mut runner = HookRunner::new(); + + // Priority 10 runs first: uppercases + runner.register(Box::new(UppercasePromptHook { + name: "upper".into(), + priority: 10, + })); + // Priority 0 runs second: appends suffix + runner.register(Box::new(SuffixPromptHook { + name: "suffix".into(), + priority: 0, + suffix: "_done".into(), + })); + + match runner.run_before_prompt_build("hello".into()).await { + HookResult::Continue(result) => assert_eq!(result, "HELLO_done"), + HookResult::Cancel(_) => panic!("should not cancel"), + } + } +} diff --git a/src/hooks/traits.rs b/src/hooks/traits.rs new file mode 100644 index 000000000..81f8e6efe --- /dev/null +++ b/src/hooks/traits.rs @@ -0,0 +1,140 @@ +use async_trait::async_trait; +use serde_json::Value; +use std::time::Duration; + +use crate::channels::traits::ChannelMessage; +use crate::providers::traits::{ChatMessage, ChatResponse}; +use crate::tools::traits::ToolResult; + +/// Result of a modifying hook — continue with (possibly modified) data, or cancel. +#[derive(Debug, Clone)] +pub enum HookResult { + Continue(T), + Cancel(String), +} + +impl HookResult { + pub fn is_cancel(&self) -> bool { + matches!(self, HookResult::Cancel(_)) + } +} + +/// Trait for hook handlers. All methods have default no-op implementations. +/// Implement only the events you care about. +#[async_trait] +pub trait HookHandler: Send + Sync { + fn name(&self) -> &str; + fn priority(&self) -> i32 { + 0 + } + + // --- Void hooks (parallel, fire-and-forget) --- + async fn on_gateway_start(&self, _host: &str, _port: u16) {} + async fn on_gateway_stop(&self) {} + async fn on_session_start(&self, _session_id: &str, _channel: &str) {} + async fn on_session_end(&self, _session_id: &str, _channel: &str) {} + async fn on_llm_input(&self, _messages: &[ChatMessage], _model: &str) {} + async fn on_llm_output(&self, _response: &ChatResponse) {} + async fn on_after_tool_call(&self, _tool: &str, _result: &ToolResult, _duration: Duration) {} + async fn on_message_sent(&self, _channel: &str, _recipient: &str, _content: &str) {} + async fn on_heartbeat_tick(&self) {} + + // --- Modifying hooks (sequential by priority, can cancel) --- + async fn before_model_resolve( + &self, + provider: String, + model: String, + ) -> HookResult<(String, String)> { + HookResult::Continue((provider, model)) + } + + async fn before_prompt_build(&self, prompt: String) -> HookResult { + HookResult::Continue(prompt) + } + + async fn before_llm_call( + &self, + messages: Vec, + model: String, + ) -> HookResult<(Vec, String)> { + HookResult::Continue((messages, model)) + } + + async fn before_tool_call(&self, name: String, args: Value) -> HookResult<(String, Value)> { + HookResult::Continue((name, args)) + } + + async fn on_message_received(&self, message: ChannelMessage) -> HookResult { + HookResult::Continue(message) + } + + async fn on_message_sending( + &self, + channel: String, + recipient: String, + content: String, + ) -> HookResult<(String, String, String)> { + HookResult::Continue((channel, recipient, content)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct TestHook { + name: String, + priority: i32, + } + + impl TestHook { + fn new(name: &str, priority: i32) -> Self { + Self { + name: name.to_string(), + priority, + } + } + } + + #[async_trait] + impl HookHandler for TestHook { + fn name(&self) -> &str { + &self.name + } + fn priority(&self) -> i32 { + self.priority + } + } + + #[test] + fn hook_result_is_cancel() { + let ok: HookResult = HookResult::Continue("hi".into()); + assert!(!ok.is_cancel()); + let cancel: HookResult = HookResult::Cancel("blocked".into()); + assert!(cancel.is_cancel()); + } + + #[test] + fn default_priority_is_zero() { + struct MinimalHook; + #[async_trait] + impl HookHandler for MinimalHook { + fn name(&self) -> &str { + "minimal" + } + } + assert_eq!(MinimalHook.priority(), 0); + } + + #[tokio::test] + async fn default_modifying_hooks_pass_through() { + let hook = TestHook::new("test", 0); + match hook + .before_tool_call("shell".into(), serde_json::json!({"cmd": "ls"})) + .await + { + HookResult::Continue((name, _args)) => assert_eq!(name, "shell"), + HookResult::Cancel(_) => panic!("should not cancel"), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 6d6e743a1..aa06883bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,6 +51,7 @@ pub mod gateway; pub(crate) mod hardware; pub(crate) mod health; pub(crate) mod heartbeat; +pub mod hooks; pub(crate) mod identity; pub(crate) mod integrations; pub mod memory;