From 52b9e6a2210df08dbccf08ff81907d6d93cf9254 Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Thu, 5 Mar 2026 09:57:00 -0500 Subject: [PATCH] fix(channel): consume provider streaming in tool loop drafts --- src/agent/loop_.rs | 398 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 383 insertions(+), 15 deletions(-) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 0bb627fda..a0ca7fb49 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -11,6 +11,7 @@ use crate::security::SecurityPolicy; use crate::tools::{self, Tool}; use crate::util::truncate_with_ellipsis; use anyhow::Result; +use futures_util::StreamExt; use regex::{Regex, RegexSet}; use rustyline::error::ReadlineError; use std::collections::{BTreeSet, HashSet}; @@ -44,6 +45,8 @@ use parsing::{ /// Minimum characters per chunk when relaying LLM text to a streaming draft. const STREAM_CHUNK_MIN_CHARS: usize = 80; +/// Rolling window size for detecting streamed tool-call payload markers. +const STREAM_TOOL_MARKER_WINDOW_CHARS: usize = 512; /// Default maximum agentic tool-use iterations per user message to prevent runaway loops. /// Used as a safe fallback when `max_tool_iterations` is unset or configured as zero. @@ -494,6 +497,127 @@ pub(crate) fn is_tool_iteration_limit_error(err: &anyhow::Error) -> bool { }) } +#[derive(Debug, Default)] +struct StreamedChatOutcome { + response_text: String, + forwarded_live_deltas: bool, +} + +fn looks_like_streamed_tool_payload(window: &str) -> bool { + let lowered = window.to_ascii_lowercase(); + lowered.contains(", + model: &str, + temperature: f64, + cancellation_token: Option<&CancellationToken>, +) -> Result { + let chat_future = provider.chat( + ChatRequest { + messages, + tools: request_tools, + }, + model, + temperature, + ); + + if let Some(token) = cancellation_token { + tokio::select! { + () = token.cancelled() => Err(ToolLoopCancelled.into()), + result = chat_future => result, + } + } else { + chat_future.await + } +} + +async fn consume_provider_streaming_response( + provider: &dyn Provider, + messages: &[ChatMessage], + model: &str, + temperature: f64, + cancellation_token: Option<&CancellationToken>, + on_delta: Option<&tokio::sync::mpsc::Sender>, +) -> Result { + let mut provider_stream = provider.stream_chat_with_history( + messages, + model, + temperature, + crate::providers::traits::StreamOptions::new(true), + ); + let mut outcome = StreamedChatOutcome::default(); + let mut delta_sender = on_delta; + let mut suppress_forwarding = false; + let mut marker_window = String::new(); + + loop { + let next_chunk = if let Some(token) = cancellation_token { + tokio::select! { + () = token.cancelled() => return Err(ToolLoopCancelled.into()), + chunk = provider_stream.next() => chunk, + } + } else { + provider_stream.next().await + }; + + let Some(chunk_result) = next_chunk else { + break; + }; + + let chunk = chunk_result.map_err(|err| anyhow::anyhow!("provider stream error: {err}"))?; + if chunk.is_final { + break; + } + if chunk.delta.is_empty() { + continue; + } + + outcome.response_text.push_str(&chunk.delta); + marker_window.push_str(&chunk.delta); + + if marker_window.len() > STREAM_TOOL_MARKER_WINDOW_CHARS { + let keep_from = marker_window.len() - STREAM_TOOL_MARKER_WINDOW_CHARS; + let boundary = marker_window + .char_indices() + .find(|(idx, _)| *idx >= keep_from) + .map_or(0, |(idx, _)| idx); + marker_window.drain(..boundary); + } + + if !suppress_forwarding && looks_like_streamed_tool_payload(&marker_window) { + suppress_forwarding = true; + if outcome.forwarded_live_deltas { + if let Some(tx) = delta_sender { + let _ = tx.send(DRAFT_CLEAR_SENTINEL.to_string()).await; + } + outcome.forwarded_live_deltas = false; + } + } + + if suppress_forwarding { + continue; + } + + if let Some(tx) = delta_sender { + if !outcome.forwarded_live_deltas { + let _ = tx.send(DRAFT_CLEAR_SENTINEL.to_string()).await; + outcome.forwarded_live_deltas = true; + } + if tx.send(chunk.delta).await.is_err() { + delta_sender = None; + } + } + } + + Ok(outcome) +} + /// Execute a single turn of the agent loop: send messages, parse tool calls, /// execute tools, and loop until the LLM produces a final text response. /// When `silent` is true, suppresses stdout (for channel use). @@ -779,23 +903,74 @@ pub(crate) async fn run_tool_call_loop( } else { None }; + let should_consume_provider_stream = + on_delta.is_some() && provider.supports_streaming() && request_tools.is_none(); + let mut streamed_live_deltas = false; - let chat_future = provider.chat( - ChatRequest { - messages: &prepared_messages.messages, - tools: request_tools, - }, - model, - temperature, - ); - - let chat_result = if let Some(token) = cancellation_token.as_ref() { - tokio::select! { - () = token.cancelled() => return Err(ToolLoopCancelled.into()), - result = chat_future => result, + let chat_result = if should_consume_provider_stream { + match consume_provider_streaming_response( + provider, + &prepared_messages.messages, + model, + temperature, + cancellation_token.as_ref(), + on_delta.as_ref(), + ) + .await + { + Ok(streamed) => { + streamed_live_deltas = streamed.forwarded_live_deltas; + Ok(crate::providers::ChatResponse { + text: Some(streamed.response_text), + tool_calls: Vec::new(), + usage: None, + reasoning_content: None, + }) + } + Err(stream_err) => { + tracing::warn!( + provider = provider_name, + model = model, + iteration = iteration + 1, + "provider streaming failed, falling back to non-streaming chat: {stream_err}" + ); + runtime_trace::record_event( + "llm_stream_fallback", + Some(channel_name), + Some(provider_name), + Some(model), + Some(&turn_id), + Some(false), + Some("provider stream failed; fallback to non-streaming chat"), + serde_json::json!({ + "iteration": iteration + 1, + "error": scrub_credentials(&stream_err.to_string()), + }), + ); + if let Some(ref tx) = on_delta { + let _ = tx.send(DRAFT_CLEAR_SENTINEL.to_string()).await; + } + call_provider_chat( + provider, + &prepared_messages.messages, + request_tools, + model, + temperature, + cancellation_token.as_ref(), + ) + .await + } } } else { - chat_future.await + call_provider_chat( + provider, + &prepared_messages.messages, + request_tools, + model, + temperature, + cancellation_token.as_ref(), + ) + .await }; let ( @@ -805,6 +980,7 @@ pub(crate) async fn run_tool_call_loop( assistant_history_content, native_tool_calls, parse_issue_detected, + response_streamed_live, ) = match chat_result { Ok(resp) => { let (resp_input_tokens, resp_output_tokens) = resp @@ -908,6 +1084,7 @@ pub(crate) async fn run_tool_call_loop( assistant_history_content, native_calls, parse_issue.is_some(), + streamed_live_deltas, ) } Err(e) => { @@ -1052,6 +1229,12 @@ pub(crate) async fn run_tool_call_loop( // If a streaming sender is provided, relay the text in small chunks // so the channel can progressively update the draft message. if let Some(ref tx) = on_delta { + let should_emit_post_hoc_chunks = + !response_streamed_live || display_text != response_text; + if !should_emit_post_hoc_chunks { + history.push(ChatMessage::assistant(response_text.clone())); + return Ok(display_text); + } // Clear accumulated progress lines before streaming the final answer. let _ = tx.send(DRAFT_CLEAR_SENTINEL.to_string()).await; // Split on whitespace boundaries, accumulating chunks of at least @@ -2371,7 +2554,7 @@ mod tests { use crate::memory::{Memory, MemoryCategory, SqliteMemory}; use crate::observability::NoopObserver; - use crate::providers::traits::ProviderCapabilities; + use crate::providers::traits::{ProviderCapabilities, StreamChunk, StreamOptions}; use crate::providers::ChatResponse; use crate::runtime::NativeRuntime; use crate::security::{AutonomyLevel, SecurityPolicy, ShellRedirectPolicy}; @@ -2504,6 +2687,81 @@ mod tests { } } + struct StreamingScriptedProvider { + responses: Arc>>, + stream_calls: Arc, + chat_calls: Arc, + } + + impl StreamingScriptedProvider { + fn from_text_responses(responses: Vec<&str>) -> Self { + Self { + responses: Arc::new(Mutex::new( + responses.into_iter().map(ToString::to_string).collect(), + )), + stream_calls: Arc::new(AtomicUsize::new(0)), + chat_calls: Arc::new(AtomicUsize::new(0)), + } + } + } + + #[async_trait] + impl Provider for StreamingScriptedProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + anyhow::bail!( + "chat_with_system should not be used in streaming scripted provider tests" + ); + } + + async fn chat( + &self, + _request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + self.chat_calls.fetch_add(1, Ordering::SeqCst); + anyhow::bail!("chat should not be called when streaming succeeds") + } + + fn supports_streaming(&self) -> bool { + true + } + + fn stream_chat_with_history( + &self, + _messages: &[ChatMessage], + _model: &str, + _temperature: f64, + options: StreamOptions, + ) -> futures_util::stream::BoxStream< + 'static, + crate::providers::traits::StreamResult, + > { + self.stream_calls.fetch_add(1, Ordering::SeqCst); + if !options.enabled { + return Box::pin(futures_util::stream::empty()); + } + + let response = self + .responses + .lock() + .expect("responses lock should be valid") + .pop_front() + .unwrap_or_default(); + + Box::pin(futures_util::stream::iter(vec![ + Ok(StreamChunk::delta(response)), + Ok(StreamChunk::final_chunk()), + ])) + } + } + struct CountingTool { name: String, invocations: Arc, @@ -3539,6 +3797,116 @@ mod tests { ); } + #[tokio::test] + async fn run_tool_call_loop_consumes_provider_stream_for_final_response() { + let provider = + StreamingScriptedProvider::from_text_responses(vec!["streamed final answer"]); + let tools_registry: Vec> = Vec::new(); + let mut history = vec![ + ChatMessage::system("test-system"), + ChatMessage::user("say hi"), + ]; + let observer = NoopObserver; + let (tx, mut rx) = tokio::sync::mpsc::channel::(32); + + let result = run_tool_call_loop( + &provider, + &mut history, + &tools_registry, + &observer, + "mock-provider", + "mock-model", + 0.0, + true, + None, + "telegram", + &crate::config::MultimodalConfig::default(), + 4, + None, + Some(tx), + None, + &[], + ) + .await + .expect("streaming provider should complete"); + + let mut visible_deltas = String::new(); + while let Some(delta) = rx.recv().await { + if delta == DRAFT_CLEAR_SENTINEL || delta.starts_with(DRAFT_PROGRESS_SENTINEL) { + continue; + } + visible_deltas.push_str(&delta); + } + + assert_eq!(result, "streamed final answer"); + assert_eq!( + visible_deltas, "streamed final answer", + "draft should receive upstream deltas once without post-hoc duplication" + ); + assert_eq!(provider.stream_calls.load(Ordering::SeqCst), 1); + assert_eq!(provider.chat_calls.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn run_tool_call_loop_streaming_path_preserves_tool_loop_semantics() { + let provider = StreamingScriptedProvider::from_text_responses(vec![ + r#" +{"name":"count_tool","arguments":{"value":"A"}} +"#, + "done", + ]); + let invocations = Arc::new(AtomicUsize::new(0)); + let tools_registry: Vec> = vec![Box::new(CountingTool::new( + "count_tool", + Arc::clone(&invocations), + ))]; + let mut history = vec![ + ChatMessage::system("test-system"), + ChatMessage::user("run tool calls"), + ]; + let observer = NoopObserver; + let (tx, mut rx) = tokio::sync::mpsc::channel::(64); + + let result = run_tool_call_loop( + &provider, + &mut history, + &tools_registry, + &observer, + "mock-provider", + "mock-model", + 0.0, + true, + None, + "telegram", + &crate::config::MultimodalConfig::default(), + 5, + None, + Some(tx), + None, + &[], + ) + .await + .expect("streaming tool loop should execute tool and finish"); + + let mut visible_deltas = String::new(); + while let Some(delta) = rx.recv().await { + if delta == DRAFT_CLEAR_SENTINEL || delta.starts_with(DRAFT_PROGRESS_SENTINEL) { + continue; + } + visible_deltas.push_str(&delta); + } + + assert_eq!(result, "done"); + assert_eq!(invocations.load(Ordering::SeqCst), 1); + assert_eq!(provider.stream_calls.load(Ordering::SeqCst), 2); + assert_eq!(provider.chat_calls.load(Ordering::SeqCst), 0); + assert_eq!(visible_deltas, "done"); + assert!( + !visible_deltas.contains("