diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 94c79ae0c..a2350b5c2 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -1693,6 +1693,31 @@ async fn create_resilient_provider_nonblocking( .context("failed to join provider initialization task")? } +async fn create_routed_provider_nonblocking( + provider_name: &str, + api_key: Option, + api_url: Option, + reliability: crate::config::ReliabilityConfig, + model_routes: Vec, + default_model: String, + provider_runtime_options: providers::ProviderRuntimeOptions, +) -> anyhow::Result> { + let provider_name = provider_name.to_string(); + tokio::task::spawn_blocking(move || { + providers::create_routed_provider_with_options( + &provider_name, + api_key.as_deref(), + api_url.as_deref(), + &reliability, + &model_routes, + &default_model, + &provider_runtime_options, + ) + }) + .await + .context("failed to join routed provider initialization task")? +} + fn build_models_help_response(current: &ChannelRouteSelection, workspace_dir: &Path) -> String { let mut response = String::new(); let _ = writeln!( @@ -4669,6 +4694,7 @@ pub async fn doctor_channels(config: Config) -> Result<()> { #[allow(clippy::too_many_lines)] pub async fn start_channels(config: Config) -> Result<()> { let provider_name = resolved_default_provider(&config); + let model = resolved_default_model(&config); let provider_runtime_options = providers::ProviderRuntimeOptions { auth_profile_override: None, provider_api_url: config.api_url.clone(), @@ -4681,11 +4707,13 @@ pub async fn start_channels(config: Config) -> Result<()> { model_support_vision: config.model_support_vision, }; let provider: Arc = Arc::from( - create_resilient_provider_nonblocking( + create_routed_provider_nonblocking( &provider_name, config.api_key.clone(), config.api_url.clone(), config.reliability.clone(), + config.model_routes.clone(), + model.clone(), provider_runtime_options.clone(), ) .await?, @@ -4719,7 +4747,6 @@ pub async fn start_channels(config: Config) -> Result<()> { &config.autonomy, &config.workspace_dir, )); - let model = resolved_default_model(&config); let temperature = config.default_temperature; let mem: Arc = Arc::from(memory::create_memory_with_storage( &config.memory, @@ -8175,6 +8202,39 @@ BTC is currently around $65,000 based on latest tool output."# store.remove(&config_path); } + #[tokio::test] + async fn start_channels_uses_model_routes_when_global_provider_key_is_missing() { + let temp = tempfile::TempDir::new().expect("temp dir"); + let workspace_dir = temp.path().join("workspace"); + std::fs::create_dir_all(&workspace_dir).expect("workspace dir"); + + let mut cfg = Config::default(); + cfg.workspace_dir = workspace_dir; + cfg.config_path = temp.path().join("config.toml"); + cfg.default_provider = None; + cfg.api_key = None; + cfg.default_model = Some("hint:fast".to_string()); + cfg.model_routes = vec![crate::config::ModelRouteConfig { + hint: "fast".to_string(), + provider: "openai-codex".to_string(), + model: "gpt-5.3-codex".to_string(), + max_tokens: Some(512), + api_key: Some("route-specific-key".to_string()), + }]; + + let config_path = cfg.config_path.clone(); + let result = start_channels(cfg).await; + let mut store = runtime_config_store() + .lock() + .unwrap_or_else(|e| e.into_inner()); + store.remove(&config_path); + + assert!( + result.is_ok(), + "start_channels should support routed providers without global credentials: {result:?}" + ); + } + #[tokio::test] async fn process_channel_message_respects_configured_max_tool_iterations_above_default() { let channel_impl = Arc::new(RecordingChannel::default()); diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index 76f02e762..48793221a 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -7,6 +7,54 @@ use tokio::task::JoinHandle; use tokio::time::Duration; const STATUS_FLUSH_SECONDS: u64 = 5; +const SHUTDOWN_GRACE_SECONDS: u64 = 5; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum ShutdownSignal { + CtrlC, + SigTerm, +} + +fn shutdown_reason(signal: ShutdownSignal) -> &'static str { + match signal { + ShutdownSignal::CtrlC => "shutdown requested (SIGINT)", + ShutdownSignal::SigTerm => "shutdown requested (SIGTERM)", + } +} + +#[cfg(unix)] +fn shutdown_hint() -> &'static str { + "Ctrl+C or SIGTERM to stop" +} + +#[cfg(not(unix))] +fn shutdown_hint() -> &'static str { + "Ctrl+C to stop" +} + +async fn wait_for_shutdown_signal() -> Result { + #[cfg(unix)] + { + use tokio::signal::unix::{signal, SignalKind}; + + let mut sigterm = signal(SignalKind::terminate())?; + tokio::select! { + ctrl_c = tokio::signal::ctrl_c() => { + ctrl_c?; + Ok(ShutdownSignal::CtrlC) + } + sigterm_result = sigterm.recv() => match sigterm_result { + Some(()) => Ok(ShutdownSignal::SigTerm), + None => anyhow::bail!("SIGTERM signal stream unexpectedly closed"), + }, + } + } + #[cfg(not(unix))] + { + tokio::signal::ctrl_c().await?; + Ok(ShutdownSignal::CtrlC) + } +} pub async fn run(config: Config, host: String, port: u16) -> Result<()> { let initial_backoff = config.reliability.channel_initial_backoff_secs.max(1); @@ -90,19 +138,40 @@ pub async fn run(config: Config, host: String, port: u16) -> Result<()> { println!("🧠 ZeroClaw daemon started"); println!(" Gateway: http://{host}:{port}"); println!(" Components: gateway, channels, heartbeat, scheduler"); - println!(" Ctrl+C to stop"); + println!(" {}", shutdown_hint()); - tokio::signal::ctrl_c().await?; - crate::health::mark_component_error("daemon", "shutdown requested"); + let signal = wait_for_shutdown_signal().await?; + crate::health::mark_component_error("daemon", shutdown_reason(signal)); + let aborted = + shutdown_handles_with_grace(handles, Duration::from_secs(SHUTDOWN_GRACE_SECONDS)).await; + if aborted > 0 { + tracing::warn!( + aborted, + grace_seconds = SHUTDOWN_GRACE_SECONDS, + "Forced shutdown for daemon tasks that exceeded graceful drain window" + ); + } + Ok(()) +} + +async fn shutdown_handles_with_grace(handles: Vec>, grace: Duration) -> usize { + let deadline = tokio::time::Instant::now() + grace; + while !handles.iter().all(JoinHandle::is_finished) && tokio::time::Instant::now() < deadline { + tokio::time::sleep(Duration::from_millis(50)).await; + } + + let mut aborted = 0usize; for handle in &handles { - handle.abort(); + if !handle.is_finished() { + handle.abort(); + aborted += 1; + } } for handle in handles { let _ = handle.await; } - - Ok(()) + aborted } pub fn state_file_path(config: &Config) -> PathBuf { @@ -350,6 +419,54 @@ mod tests { assert_eq!(path, tmp.path().join("daemon_state.json")); } + #[test] + fn shutdown_reason_for_ctrl_c_mentions_sigint() { + assert_eq!( + shutdown_reason(ShutdownSignal::CtrlC), + "shutdown requested (SIGINT)" + ); + } + + #[test] + fn shutdown_reason_for_sigterm_mentions_sigterm() { + assert_eq!( + shutdown_reason(ShutdownSignal::SigTerm), + "shutdown requested (SIGTERM)" + ); + } + + #[test] + fn shutdown_hint_matches_platform_signal_support() { + #[cfg(unix)] + assert_eq!(shutdown_hint(), "Ctrl+C or SIGTERM to stop"); + + #[cfg(not(unix))] + assert_eq!(shutdown_hint(), "Ctrl+C to stop"); + } + + #[tokio::test] + async fn graceful_shutdown_waits_for_completed_handles_without_abort() { + let finished = tokio::spawn(async {}); + let aborted = shutdown_handles_with_grace(vec![finished], Duration::from_millis(20)).await; + assert_eq!(aborted, 0); + } + + #[tokio::test] + async fn graceful_shutdown_aborts_stuck_handles_after_timeout() { + let never_finishes = tokio::spawn(async { + tokio::time::sleep(Duration::from_secs(30)).await; + }); + let started = tokio::time::Instant::now(); + let aborted = + shutdown_handles_with_grace(vec![never_finishes], Duration::from_millis(20)).await; + + assert_eq!(aborted, 1); + assert!( + started.elapsed() < Duration::from_secs(2), + "shutdown should not block indefinitely" + ); + } + #[tokio::test] async fn supervisor_marks_error_and_restart_on_failure() { let handle = spawn_component_supervisor("daemon-test-fail", 1, 1, || async {