feat(hooks): add HookHandler trait, HookResult, and HookRunner dispatcher
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
9ff86c372f
commit
ff6027fce7
@ -9,12 +9,12 @@ pub use schema::{
|
||||
DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EmbeddingRouteConfig, GatewayConfig,
|
||||
HardwareConfig, HardwareTransport, HeartbeatConfig, HttpRequestConfig, IMessageConfig,
|
||||
IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, ObservabilityConfig, PeripheralBoardConfig, PeripheralsConfig,
|
||||
ProxyConfig, ProxyScope, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
|
||||
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
|
||||
StorageProviderSection, StreamMode, TelegramConfig, TranscriptionConfig, TunnelConfig,
|
||||
WebSearchConfig, WebhookConfig,
|
||||
BuiltinHooksConfig, HooksConfig, NextcloudTalkConfig, ObservabilityConfig,
|
||||
PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QueryClassificationConfig,
|
||||
ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig,
|
||||
SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig, SkillsPromptInjectionMode,
|
||||
SlackConfig, StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode,
|
||||
TelegramConfig, TranscriptionConfig, TunnelConfig, WebSearchConfig, WebhookConfig,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
5
src/hooks/mod.rs
Normal file
5
src/hooks/mod.rs
Normal file
@ -0,0 +1,5 @@
|
||||
mod runner;
|
||||
mod traits;
|
||||
|
||||
pub use runner::HookRunner;
|
||||
pub use traits::{HookHandler, HookResult};
|
||||
414
src/hooks/runner.rs
Normal file
414
src/hooks/runner.rs
Normal file
@ -0,0 +1,414 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::future::join_all;
|
||||
use serde_json::Value;
|
||||
use tracing::info;
|
||||
|
||||
use crate::channels::traits::ChannelMessage;
|
||||
use crate::providers::traits::{ChatMessage, ChatResponse};
|
||||
use crate::tools::traits::ToolResult;
|
||||
|
||||
use super::traits::{HookHandler, HookResult};
|
||||
|
||||
/// Dispatcher that manages registered hook handlers.
|
||||
///
|
||||
/// Void hooks are dispatched in parallel via `join_all`.
|
||||
/// Modifying hooks run sequentially by priority (higher first), piping output
|
||||
/// and short-circuiting on `Cancel`.
|
||||
pub struct HookRunner {
|
||||
handlers: Vec<Box<dyn HookHandler>>,
|
||||
}
|
||||
|
||||
impl HookRunner {
|
||||
/// Create an empty runner with no handlers.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
handlers: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a handler and re-sort by descending priority.
|
||||
pub fn register(&mut self, handler: Box<dyn HookHandler>) {
|
||||
self.handlers.push(handler);
|
||||
self.handlers.sort_by(|a, b| b.priority().cmp(&a.priority()));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Void dispatchers (parallel, fire-and-forget)
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
pub async fn fire_gateway_start(&self, host: &str, port: u16) {
|
||||
let futs: Vec<_> = self
|
||||
.handlers
|
||||
.iter()
|
||||
.map(|h| h.on_gateway_start(host, port))
|
||||
.collect();
|
||||
join_all(futs).await;
|
||||
}
|
||||
|
||||
pub async fn fire_gateway_stop(&self) {
|
||||
let futs: Vec<_> = self.handlers.iter().map(|h| h.on_gateway_stop()).collect();
|
||||
join_all(futs).await;
|
||||
}
|
||||
|
||||
pub async fn fire_session_start(&self, session_id: &str, channel: &str) {
|
||||
let futs: Vec<_> = self
|
||||
.handlers
|
||||
.iter()
|
||||
.map(|h| h.on_session_start(session_id, channel))
|
||||
.collect();
|
||||
join_all(futs).await;
|
||||
}
|
||||
|
||||
pub async fn fire_session_end(&self, session_id: &str, channel: &str) {
|
||||
let futs: Vec<_> = self
|
||||
.handlers
|
||||
.iter()
|
||||
.map(|h| h.on_session_end(session_id, channel))
|
||||
.collect();
|
||||
join_all(futs).await;
|
||||
}
|
||||
|
||||
pub async fn fire_llm_input(&self, messages: &[ChatMessage], model: &str) {
|
||||
let futs: Vec<_> = self
|
||||
.handlers
|
||||
.iter()
|
||||
.map(|h| h.on_llm_input(messages, model))
|
||||
.collect();
|
||||
join_all(futs).await;
|
||||
}
|
||||
|
||||
pub async fn fire_llm_output(&self, response: &ChatResponse) {
|
||||
let futs: Vec<_> = self
|
||||
.handlers
|
||||
.iter()
|
||||
.map(|h| h.on_llm_output(response))
|
||||
.collect();
|
||||
join_all(futs).await;
|
||||
}
|
||||
|
||||
pub async fn fire_after_tool_call(&self, tool: &str, result: &ToolResult, duration: Duration) {
|
||||
let futs: Vec<_> = self
|
||||
.handlers
|
||||
.iter()
|
||||
.map(|h| h.on_after_tool_call(tool, result, duration))
|
||||
.collect();
|
||||
join_all(futs).await;
|
||||
}
|
||||
|
||||
pub async fn fire_message_sent(&self, channel: &str, recipient: &str, content: &str) {
|
||||
let futs: Vec<_> = self
|
||||
.handlers
|
||||
.iter()
|
||||
.map(|h| h.on_message_sent(channel, recipient, content))
|
||||
.collect();
|
||||
join_all(futs).await;
|
||||
}
|
||||
|
||||
pub async fn fire_heartbeat_tick(&self) {
|
||||
let futs: Vec<_> = self
|
||||
.handlers
|
||||
.iter()
|
||||
.map(|h| h.on_heartbeat_tick())
|
||||
.collect();
|
||||
join_all(futs).await;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Modifying dispatchers (sequential by priority, short-circuit on Cancel)
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
pub async fn run_before_model_resolve(
|
||||
&self,
|
||||
mut provider: String,
|
||||
mut model: String,
|
||||
) -> HookResult<(String, String)> {
|
||||
for h in &self.handlers {
|
||||
match h.before_model_resolve(provider, model).await {
|
||||
HookResult::Continue((p, m)) => {
|
||||
provider = p;
|
||||
model = m;
|
||||
}
|
||||
HookResult::Cancel(reason) => {
|
||||
info!(
|
||||
hook = h.name(),
|
||||
reason, "before_model_resolve cancelled by hook"
|
||||
);
|
||||
return HookResult::Cancel(reason);
|
||||
}
|
||||
}
|
||||
}
|
||||
HookResult::Continue((provider, model))
|
||||
}
|
||||
|
||||
pub async fn run_before_prompt_build(&self, mut prompt: String) -> HookResult<String> {
|
||||
for h in &self.handlers {
|
||||
match h.before_prompt_build(prompt).await {
|
||||
HookResult::Continue(p) => prompt = p,
|
||||
HookResult::Cancel(reason) => {
|
||||
info!(
|
||||
hook = h.name(),
|
||||
reason, "before_prompt_build cancelled by hook"
|
||||
);
|
||||
return HookResult::Cancel(reason);
|
||||
}
|
||||
}
|
||||
}
|
||||
HookResult::Continue(prompt)
|
||||
}
|
||||
|
||||
pub async fn run_before_llm_call(
|
||||
&self,
|
||||
mut messages: Vec<ChatMessage>,
|
||||
mut model: String,
|
||||
) -> HookResult<(Vec<ChatMessage>, String)> {
|
||||
for h in &self.handlers {
|
||||
match h.before_llm_call(messages, model).await {
|
||||
HookResult::Continue((m, mdl)) => {
|
||||
messages = m;
|
||||
model = mdl;
|
||||
}
|
||||
HookResult::Cancel(reason) => {
|
||||
info!(hook = h.name(), reason, "before_llm_call cancelled by hook");
|
||||
return HookResult::Cancel(reason);
|
||||
}
|
||||
}
|
||||
}
|
||||
HookResult::Continue((messages, model))
|
||||
}
|
||||
|
||||
pub async fn run_before_tool_call(
|
||||
&self,
|
||||
mut name: String,
|
||||
mut args: Value,
|
||||
) -> HookResult<(String, Value)> {
|
||||
for h in &self.handlers {
|
||||
match h.before_tool_call(name, args).await {
|
||||
HookResult::Continue((n, a)) => {
|
||||
name = n;
|
||||
args = a;
|
||||
}
|
||||
HookResult::Cancel(reason) => {
|
||||
info!(
|
||||
hook = h.name(),
|
||||
reason, "before_tool_call cancelled by hook"
|
||||
);
|
||||
return HookResult::Cancel(reason);
|
||||
}
|
||||
}
|
||||
}
|
||||
HookResult::Continue((name, args))
|
||||
}
|
||||
|
||||
pub async fn run_on_message_received(
|
||||
&self,
|
||||
mut message: ChannelMessage,
|
||||
) -> HookResult<ChannelMessage> {
|
||||
for h in &self.handlers {
|
||||
match h.on_message_received(message).await {
|
||||
HookResult::Continue(m) => message = m,
|
||||
HookResult::Cancel(reason) => {
|
||||
info!(
|
||||
hook = h.name(),
|
||||
reason, "on_message_received cancelled by hook"
|
||||
);
|
||||
return HookResult::Cancel(reason);
|
||||
}
|
||||
}
|
||||
}
|
||||
HookResult::Continue(message)
|
||||
}
|
||||
|
||||
pub async fn run_on_message_sending(
|
||||
&self,
|
||||
mut channel: String,
|
||||
mut recipient: String,
|
||||
mut content: String,
|
||||
) -> HookResult<(String, String, String)> {
|
||||
for h in &self.handlers {
|
||||
match h.on_message_sending(channel, recipient, content).await {
|
||||
HookResult::Continue((c, r, ct)) => {
|
||||
channel = c;
|
||||
recipient = r;
|
||||
content = ct;
|
||||
}
|
||||
HookResult::Cancel(reason) => {
|
||||
info!(
|
||||
hook = h.name(),
|
||||
reason, "on_message_sending cancelled by hook"
|
||||
);
|
||||
return HookResult::Cancel(reason);
|
||||
}
|
||||
}
|
||||
}
|
||||
HookResult::Continue((channel, recipient, content))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use async_trait::async_trait;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A hook that records how many times void events fire.
|
||||
struct CountingHook {
|
||||
name: String,
|
||||
priority: i32,
|
||||
fire_count: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
impl CountingHook {
|
||||
fn new(name: &str, priority: i32) -> (Self, Arc<AtomicU32>) {
|
||||
let count = Arc::new(AtomicU32::new(0));
|
||||
(
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
priority,
|
||||
fire_count: count.clone(),
|
||||
},
|
||||
count,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HookHandler for CountingHook {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
fn priority(&self) -> i32 {
|
||||
self.priority
|
||||
}
|
||||
async fn on_heartbeat_tick(&self) {
|
||||
self.fire_count.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
/// A modifying hook that uppercases the prompt.
|
||||
struct UppercasePromptHook {
|
||||
name: String,
|
||||
priority: i32,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HookHandler for UppercasePromptHook {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
fn priority(&self) -> i32 {
|
||||
self.priority
|
||||
}
|
||||
async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
|
||||
HookResult::Continue(prompt.to_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
/// A modifying hook that cancels before_prompt_build.
|
||||
struct CancelPromptHook {
|
||||
name: String,
|
||||
priority: i32,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HookHandler for CancelPromptHook {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
fn priority(&self) -> i32 {
|
||||
self.priority
|
||||
}
|
||||
async fn before_prompt_build(&self, _prompt: String) -> HookResult<String> {
|
||||
HookResult::Cancel("blocked by policy".into())
|
||||
}
|
||||
}
|
||||
|
||||
/// A modifying hook that appends a suffix to the prompt.
|
||||
struct SuffixPromptHook {
|
||||
name: String,
|
||||
priority: i32,
|
||||
suffix: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HookHandler for SuffixPromptHook {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
fn priority(&self) -> i32 {
|
||||
self.priority
|
||||
}
|
||||
async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
|
||||
HookResult::Continue(format!("{}{}", prompt, self.suffix))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_and_sort_by_priority() {
|
||||
let mut runner = HookRunner::new();
|
||||
let (low, _) = CountingHook::new("low", 1);
|
||||
let (high, _) = CountingHook::new("high", 10);
|
||||
let (mid, _) = CountingHook::new("mid", 5);
|
||||
|
||||
runner.register(Box::new(low));
|
||||
runner.register(Box::new(high));
|
||||
runner.register(Box::new(mid));
|
||||
|
||||
let names: Vec<&str> = runner.handlers.iter().map(|h| h.name()).collect();
|
||||
assert_eq!(names, vec!["high", "mid", "low"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn void_hooks_fire_all_handlers() {
|
||||
let mut runner = HookRunner::new();
|
||||
let (h1, c1) = CountingHook::new("hook_a", 0);
|
||||
let (h2, c2) = CountingHook::new("hook_b", 0);
|
||||
|
||||
runner.register(Box::new(h1));
|
||||
runner.register(Box::new(h2));
|
||||
|
||||
runner.fire_heartbeat_tick().await;
|
||||
|
||||
assert_eq!(c1.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(c2.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn modifying_hook_can_cancel() {
|
||||
let mut runner = HookRunner::new();
|
||||
runner.register(Box::new(CancelPromptHook {
|
||||
name: "blocker".into(),
|
||||
priority: 10,
|
||||
}));
|
||||
runner.register(Box::new(UppercasePromptHook {
|
||||
name: "upper".into(),
|
||||
priority: 0,
|
||||
}));
|
||||
|
||||
let result = runner.run_before_prompt_build("hello".into()).await;
|
||||
assert!(result.is_cancel());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn modifying_hook_pipelines_data() {
|
||||
let mut runner = HookRunner::new();
|
||||
|
||||
// Priority 10 runs first: uppercases
|
||||
runner.register(Box::new(UppercasePromptHook {
|
||||
name: "upper".into(),
|
||||
priority: 10,
|
||||
}));
|
||||
// Priority 0 runs second: appends suffix
|
||||
runner.register(Box::new(SuffixPromptHook {
|
||||
name: "suffix".into(),
|
||||
priority: 0,
|
||||
suffix: "_done".into(),
|
||||
}));
|
||||
|
||||
match runner.run_before_prompt_build("hello".into()).await {
|
||||
HookResult::Continue(result) => assert_eq!(result, "HELLO_done"),
|
||||
HookResult::Cancel(_) => panic!("should not cancel"),
|
||||
}
|
||||
}
|
||||
}
|
||||
140
src/hooks/traits.rs
Normal file
140
src/hooks/traits.rs
Normal file
@ -0,0 +1,140 @@
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::channels::traits::ChannelMessage;
|
||||
use crate::providers::traits::{ChatMessage, ChatResponse};
|
||||
use crate::tools::traits::ToolResult;
|
||||
|
||||
/// Result of a modifying hook — continue with (possibly modified) data, or cancel.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum HookResult<T> {
|
||||
Continue(T),
|
||||
Cancel(String),
|
||||
}
|
||||
|
||||
impl<T> HookResult<T> {
|
||||
pub fn is_cancel(&self) -> bool {
|
||||
matches!(self, HookResult::Cancel(_))
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for hook handlers. All methods have default no-op implementations.
|
||||
/// Implement only the events you care about.
|
||||
#[async_trait]
|
||||
pub trait HookHandler: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
fn priority(&self) -> i32 {
|
||||
0
|
||||
}
|
||||
|
||||
// --- Void hooks (parallel, fire-and-forget) ---
|
||||
async fn on_gateway_start(&self, _host: &str, _port: u16) {}
|
||||
async fn on_gateway_stop(&self) {}
|
||||
async fn on_session_start(&self, _session_id: &str, _channel: &str) {}
|
||||
async fn on_session_end(&self, _session_id: &str, _channel: &str) {}
|
||||
async fn on_llm_input(&self, _messages: &[ChatMessage], _model: &str) {}
|
||||
async fn on_llm_output(&self, _response: &ChatResponse) {}
|
||||
async fn on_after_tool_call(&self, _tool: &str, _result: &ToolResult, _duration: Duration) {}
|
||||
async fn on_message_sent(&self, _channel: &str, _recipient: &str, _content: &str) {}
|
||||
async fn on_heartbeat_tick(&self) {}
|
||||
|
||||
// --- Modifying hooks (sequential by priority, can cancel) ---
|
||||
async fn before_model_resolve(
|
||||
&self,
|
||||
provider: String,
|
||||
model: String,
|
||||
) -> HookResult<(String, String)> {
|
||||
HookResult::Continue((provider, model))
|
||||
}
|
||||
|
||||
async fn before_prompt_build(&self, prompt: String) -> HookResult<String> {
|
||||
HookResult::Continue(prompt)
|
||||
}
|
||||
|
||||
async fn before_llm_call(
|
||||
&self,
|
||||
messages: Vec<ChatMessage>,
|
||||
model: String,
|
||||
) -> HookResult<(Vec<ChatMessage>, String)> {
|
||||
HookResult::Continue((messages, model))
|
||||
}
|
||||
|
||||
async fn before_tool_call(&self, name: String, args: Value) -> HookResult<(String, Value)> {
|
||||
HookResult::Continue((name, args))
|
||||
}
|
||||
|
||||
async fn on_message_received(&self, message: ChannelMessage) -> HookResult<ChannelMessage> {
|
||||
HookResult::Continue(message)
|
||||
}
|
||||
|
||||
async fn on_message_sending(
|
||||
&self,
|
||||
channel: String,
|
||||
recipient: String,
|
||||
content: String,
|
||||
) -> HookResult<(String, String, String)> {
|
||||
HookResult::Continue((channel, recipient, content))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
struct TestHook {
|
||||
name: String,
|
||||
priority: i32,
|
||||
}
|
||||
|
||||
impl TestHook {
|
||||
fn new(name: &str, priority: i32) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
priority,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HookHandler for TestHook {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
fn priority(&self) -> i32 {
|
||||
self.priority
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hook_result_is_cancel() {
|
||||
let ok: HookResult<String> = HookResult::Continue("hi".into());
|
||||
assert!(!ok.is_cancel());
|
||||
let cancel: HookResult<String> = HookResult::Cancel("blocked".into());
|
||||
assert!(cancel.is_cancel());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_priority_is_zero() {
|
||||
struct MinimalHook;
|
||||
#[async_trait]
|
||||
impl HookHandler for MinimalHook {
|
||||
fn name(&self) -> &str {
|
||||
"minimal"
|
||||
}
|
||||
}
|
||||
assert_eq!(MinimalHook.priority(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn default_modifying_hooks_pass_through() {
|
||||
let hook = TestHook::new("test", 0);
|
||||
match hook
|
||||
.before_tool_call("shell".into(), serde_json::json!({"cmd": "ls"}))
|
||||
.await
|
||||
{
|
||||
HookResult::Continue((name, _args)) => assert_eq!(name, "shell"),
|
||||
HookResult::Cancel(_) => panic!("should not cancel"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -51,6 +51,7 @@ pub mod gateway;
|
||||
pub(crate) mod hardware;
|
||||
pub(crate) mod health;
|
||||
pub(crate) mod heartbeat;
|
||||
pub mod hooks;
|
||||
pub(crate) mod identity;
|
||||
pub(crate) mod integrations;
|
||||
pub mod memory;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user