From 467fea87c6d64cbec3268e958f8145220c45205f Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Sat, 28 Feb 2026 19:45:51 -0500 Subject: [PATCH] refactor(hooks): extract HookRunner factory and make plugin registry init idempotent - Add HookRunner::from_config() factory that encapsulates hook construction from HooksConfig, replacing 3 duplicated blocks in agent/loop_, gateway, and channels modules. - Make plugin registry initialize_from_config() idempotent: skip re-init if already initialized, log debug message instead of silently overwriting. - Add capability gating for tool_result_persist hook modifications. --- src/channels/mod.rs | 16 +--- src/gateway/mod.rs | 17 +--- src/hooks/runner.rs | 181 +++++++++++++++++++++++++++++++++++++++-- src/plugins/runtime.rs | 9 ++ 4 files changed, 187 insertions(+), 36 deletions(-) diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 69742fab0..caf830137 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -5515,21 +5515,7 @@ pub async fn start_channels(config: Config) -> Result<()> { message_timeout_secs, interrupt_on_new_message, multimodal: config.multimodal.clone(), - hooks: if config.hooks.enabled { - let mut runner = crate::hooks::HookRunner::new(); - if config.hooks.builtin.boot_script { - runner.register(Box::new(crate::hooks::builtin::BootScriptHook)); - } - if config.hooks.builtin.command_logger { - runner.register(Box::new(crate::hooks::builtin::CommandLoggerHook::new())); - } - if config.hooks.builtin.session_memory { - runner.register(Box::new(crate::hooks::builtin::SessionMemoryHook)); - } - Some(Arc::new(runner)) - } else { - None - }, + hooks: crate::hooks::HookRunner::from_config(&config.hooks).map(Arc::new), non_cli_excluded_tools: Arc::new(Mutex::new( config.autonomy.non_cli_excluded_tools.clone(), )), diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 469dca696..c4f731ec0 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -369,21 +369,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { let config_state = Arc::new(Mutex::new(config.clone())); // ── Hooks ────────────────────────────────────────────────────── - let hooks: Option> = if config.hooks.enabled { - let mut runner = crate::hooks::HookRunner::new(); - if config.hooks.builtin.boot_script { - runner.register(Box::new(crate::hooks::builtin::BootScriptHook)); - } - if config.hooks.builtin.command_logger { - runner.register(Box::new(crate::hooks::builtin::CommandLoggerHook::new())); - } - if config.hooks.builtin.session_memory { - runner.register(Box::new(crate::hooks::builtin::SessionMemoryHook)); - } - Some(std::sync::Arc::new(runner)) - } else { - None - }; + let hooks = crate::hooks::HookRunner::from_config(&config.hooks) + .map(std::sync::Arc::new); let addr: SocketAddr = format!("{host}:{port}").parse()?; let listener = tokio::net::TcpListener::bind(addr).await?; diff --git a/src/hooks/runner.rs b/src/hooks/runner.rs index 2af598dc4..e09c4b43e 100644 --- a/src/hooks/runner.rs +++ b/src/hooks/runner.rs @@ -6,6 +6,8 @@ use std::panic::AssertUnwindSafe; use tracing::info; use crate::channels::traits::ChannelMessage; +use crate::config::HooksConfig; +use crate::plugins::traits::PluginCapability; use crate::providers::traits::{ChatMessage, ChatResponse}; use crate::tools::traits::ToolResult; @@ -28,6 +30,26 @@ impl HookRunner { } } + /// Build a hook runner from configuration, registering enabled built-in hooks. + /// + /// Returns `None` if hooks are disabled in config. + pub fn from_config(config: &HooksConfig) -> Option { + if !config.enabled { + return None; + } + let mut runner = Self::new(); + if config.builtin.boot_script { + runner.register(Box::new(super::builtin::BootScriptHook)); + } + if config.builtin.command_logger { + runner.register(Box::new(super::builtin::CommandLoggerHook::new())); + } + if config.builtin.session_memory { + runner.register(Box::new(super::builtin::SessionMemoryHook)); + } + Some(runner) + } + /// Register a handler and re-sort by descending priority. pub fn register(&mut self, handler: Box) { self.handlers.push(handler); @@ -307,17 +329,45 @@ impl HookRunner { ) -> HookResult { for h in &self.handlers { let hook_name = h.name(); + let has_modify_cap = h + .capabilities() + .contains(&PluginCapability::ModifyToolResults); match AssertUnwindSafe(h.tool_result_persist(tool.clone(), result.clone())) .catch_unwind() .await { - Ok(HookResult::Continue(next_result)) => result = next_result, + Ok(HookResult::Continue(next_result)) => { + if next_result.success != result.success + || next_result.output != result.output + || next_result.error != result.error + { + if has_modify_cap { + result = next_result; + } else { + tracing::warn!( + hook = hook_name, + "hook attempted to modify tool result without ModifyToolResults capability; ignoring modification" + ); + } + } else { + // No actual modification — pass-through is always allowed. + result = next_result; + } + } Ok(HookResult::Cancel(reason)) => { - info!( - hook = hook_name, - reason, "tool_result_persist cancelled by hook" - ); - return HookResult::Cancel(reason); + if has_modify_cap { + info!( + hook = hook_name, + reason, "tool_result_persist cancelled by hook" + ); + return HookResult::Cancel(reason); + } else { + tracing::warn!( + hook = hook_name, + reason, + "hook attempted to cancel tool result without ModifyToolResults capability; ignoring cancellation" + ); + } } Err(_) => { tracing::error!( @@ -565,4 +615,123 @@ mod tests { HookResult::Cancel(_) => panic!("should not cancel"), } } + + // -- Capability-gated tool_result_persist tests -- + + /// Hook that flips success to false (modification) without capability. + struct UncappedResultMutator; + + #[async_trait] + impl HookHandler for UncappedResultMutator { + fn name(&self) -> &str { + "uncapped_mutator" + } + async fn tool_result_persist( + &self, + _tool: String, + mut result: ToolResult, + ) -> HookResult { + result.success = false; + result.output = "tampered".into(); + HookResult::Continue(result) + } + } + + /// Hook that flips success with the required capability. + struct CappedResultMutator; + + #[async_trait] + impl HookHandler for CappedResultMutator { + fn name(&self) -> &str { + "capped_mutator" + } + fn capabilities(&self) -> &[PluginCapability] { + &[PluginCapability::ModifyToolResults] + } + async fn tool_result_persist( + &self, + _tool: String, + mut result: ToolResult, + ) -> HookResult { + result.success = false; + result.output = "authorized_change".into(); + HookResult::Continue(result) + } + } + + /// Hook without capability that tries to cancel. + struct UncappedResultCanceller; + + #[async_trait] + impl HookHandler for UncappedResultCanceller { + fn name(&self) -> &str { + "uncapped_canceller" + } + async fn tool_result_persist( + &self, + _tool: String, + _result: ToolResult, + ) -> HookResult { + HookResult::Cancel("blocked".into()) + } + } + + fn sample_tool_result() -> ToolResult { + ToolResult { + success: true, + output: "original".into(), + error: None, + } + } + + #[tokio::test] + async fn tool_result_persist_blocks_modification_without_capability() { + let mut runner = HookRunner::new(); + runner.register(Box::new(UncappedResultMutator)); + + let result = runner + .run_tool_result_persist("shell".into(), sample_tool_result()) + .await; + match result { + HookResult::Continue(r) => { + assert!(r.success, "modification should have been blocked"); + assert_eq!(r.output, "original"); + } + HookResult::Cancel(_) => panic!("should not cancel"), + } + } + + #[tokio::test] + async fn tool_result_persist_allows_modification_with_capability() { + let mut runner = HookRunner::new(); + runner.register(Box::new(CappedResultMutator)); + + let result = runner + .run_tool_result_persist("shell".into(), sample_tool_result()) + .await; + match result { + HookResult::Continue(r) => { + assert!(!r.success, "modification should have been applied"); + assert_eq!(r.output, "authorized_change"); + } + HookResult::Cancel(_) => panic!("should not cancel"), + } + } + + #[tokio::test] + async fn tool_result_persist_blocks_cancel_without_capability() { + let mut runner = HookRunner::new(); + runner.register(Box::new(UncappedResultCanceller)); + + let result = runner + .run_tool_result_persist("shell".into(), sample_tool_result()) + .await; + match result { + HookResult::Continue(r) => { + assert!(r.success, "cancel should have been blocked"); + assert_eq!(r.output, "original"); + } + HookResult::Cancel(_) => panic!("cancel without capability should be blocked"), + } + } } diff --git a/src/plugins/runtime.rs b/src/plugins/runtime.rs index 6c90d458b..3a3f12ef3 100644 --- a/src/plugins/runtime.rs +++ b/src/plugins/runtime.rs @@ -70,13 +70,22 @@ fn registry_cell() -> &'static RwLock { CELL.get_or_init(|| RwLock::new(PluginRegistry::default())) } +/// Whether `initialize_from_config` has completed at least once. +static INITIALIZED: std::sync::atomic::AtomicBool = + std::sync::atomic::AtomicBool::new(false); + pub fn initialize_from_config(config: &PluginsConfig) -> Result<()> { + if INITIALIZED.load(std::sync::atomic::Ordering::Acquire) { + tracing::debug!("plugin registry already initialized, skipping re-init"); + return Ok(()); + } let runtime = PluginRuntime::new(); let registry = runtime.load_registry_from_config(config)?; let mut guard = registry_cell() .write() .unwrap_or_else(std::sync::PoisonError::into_inner); *guard = registry; + INITIALIZED.store(true, std::sync::atomic::Ordering::Release); Ok(()) }