fix(channel): consume provider streaming in tool loop drafts

This commit is contained in:
argenis de la rosa 2026-03-05 09:57:00 -05:00 committed by Argenis
parent 201de8a300
commit 52b9e6a221

View File

@@ -11,6 +11,7 @@ use crate::security::SecurityPolicy;
use crate::tools::{self, Tool};
use crate::util::truncate_with_ellipsis;
use anyhow::Result;
use futures_util::StreamExt;
use regex::{Regex, RegexSet};
use rustyline::error::ReadlineError;
use std::collections::{BTreeSet, HashSet};
@@ -44,6 +45,8 @@ use parsing::{
/// Minimum characters per chunk when relaying LLM text to a streaming draft.
/// Batching below this size would spam the channel with tiny draft edits.
const STREAM_CHUNK_MIN_CHARS: usize = 80;
/// Rolling window size for detecting streamed tool-call payload markers.
/// Sized to span a marker that straddles adjacent stream chunks.
const STREAM_TOOL_MARKER_WINDOW_CHARS: usize = 512;
/// Default maximum agentic tool-use iterations per user message to prevent runaway loops.
/// Used as a safe fallback when `max_tool_iterations` is unset or configured as zero.
@@ -494,6 +497,127 @@ pub(crate) fn is_tool_iteration_limit_error(err: &anyhow::Error) -> bool {
})
}
/// Aggregated result of draining a provider's streaming chat response.
#[derive(Debug, Default)]
struct StreamedChatOutcome {
    /// Full assistant text accumulated from every non-final stream delta.
    response_text: String,
    /// True when deltas were forwarded live to the draft channel and were
    /// not retracted (i.e. no tool-call payload marker was detected).
    forwarded_live_deltas: bool,
}
/// Returns true when the rolling stream window appears to contain the start
/// of a tool-call payload: an XML-style `<tool_call`/`<toolcall` tag or a
/// native JSON `"tool_calls"` field. Matching is ASCII case-insensitive.
fn looks_like_streamed_tool_payload(window: &str) -> bool {
    const MARKERS: [&str; 3] = ["<tool_call", "<toolcall", "\"tool_calls\""];
    let haystack = window.to_ascii_lowercase();
    MARKERS.iter().any(|marker| haystack.contains(marker))
}
/// Issue a single non-streaming chat request to the provider.
///
/// When a cancellation token is supplied, the request races against it and
/// returns `ToolLoopCancelled` if the token fires first; otherwise the
/// request is simply awaited to completion.
async fn call_provider_chat(
    provider: &dyn Provider,
    messages: &[ChatMessage],
    request_tools: Option<&[crate::tools::ToolSpec]>,
    model: &str,
    temperature: f64,
    cancellation_token: Option<&CancellationToken>,
) -> Result<crate::providers::ChatResponse> {
    let request = ChatRequest {
        messages,
        tools: request_tools,
    };
    match cancellation_token {
        Some(token) => {
            tokio::select! {
                () = token.cancelled() => Err(ToolLoopCancelled.into()),
                result = provider.chat(request, model, temperature) => result,
            }
        }
        None => provider.chat(request, model, temperature).await,
    }
}
/// Drain a provider's streaming chat response into a [`StreamedChatOutcome`].
///
/// Every non-empty delta is accumulated into `response_text` and, when
/// `on_delta` is provided, forwarded live to the draft channel. A rolling
/// tail of recent text is scanned for tool-call payload markers; once one is
/// seen, live forwarding stops for the rest of the stream and any text
/// already shown in the draft is retracted with `DRAFT_CLEAR_SENTINEL`, so
/// raw tool-call JSON/XML never leaks into a user-visible draft.
///
/// # Errors
/// Returns `ToolLoopCancelled` if the cancellation token fires mid-stream,
/// or a wrapped provider error if the stream yields one.
async fn consume_provider_streaming_response(
    provider: &dyn Provider,
    messages: &[ChatMessage],
    model: &str,
    temperature: f64,
    cancellation_token: Option<&CancellationToken>,
    on_delta: Option<&tokio::sync::mpsc::Sender<String>>,
) -> Result<StreamedChatOutcome> {
    let mut provider_stream = provider.stream_chat_with_history(
        messages,
        model,
        temperature,
        crate::providers::traits::StreamOptions::new(true),
    );
    let mut outcome = StreamedChatOutcome::default();
    // Set to None once the receiver hangs up so we stop sending but keep
    // accumulating the response text.
    let mut delta_sender = on_delta;
    // Latched true at the first marker sighting; never reset for this stream.
    let mut suppress_forwarding = false;
    // Rolling tail of recent text used for cross-chunk marker detection.
    let mut marker_window = String::new();
    loop {
        // Race the next chunk against cancellation when a token is present.
        let next_chunk = if let Some(token) = cancellation_token {
            tokio::select! {
                () = token.cancelled() => return Err(ToolLoopCancelled.into()),
                chunk = provider_stream.next() => chunk,
            }
        } else {
            provider_stream.next().await
        };
        let Some(chunk_result) = next_chunk else {
            break;
        };
        let chunk = chunk_result.map_err(|err| anyhow::anyhow!("provider stream error: {err}"))?;
        // NOTE(review): a final chunk's delta (if any) is discarded here —
        // assumes providers never attach text to `is_final` chunks; confirm.
        if chunk.is_final {
            break;
        }
        if chunk.delta.is_empty() {
            continue;
        }
        outcome.response_text.push_str(&chunk.delta);
        marker_window.push_str(&chunk.delta);
        // Trim the window back to budget, stepping forward to the nearest
        // char boundary so a multi-byte UTF-8 sequence is never split.
        // NOTE(review): the budget compares byte length (`len()`) although
        // the constant name says CHARS — harmless for ASCII markers, but
        // worth confirming the intended unit.
        if marker_window.len() > STREAM_TOOL_MARKER_WINDOW_CHARS {
            let keep_from = marker_window.len() - STREAM_TOOL_MARKER_WINDOW_CHARS;
            let boundary = marker_window
                .char_indices()
                .find(|(idx, _)| *idx >= keep_from)
                .map_or(0, |(idx, _)| idx);
            marker_window.drain(..boundary);
        }
        if !suppress_forwarding && looks_like_streamed_tool_payload(&marker_window) {
            suppress_forwarding = true;
            // Retract any partial text already shown in the draft.
            if outcome.forwarded_live_deltas {
                if let Some(tx) = delta_sender {
                    let _ = tx.send(DRAFT_CLEAR_SENTINEL.to_string()).await;
                }
                outcome.forwarded_live_deltas = false;
            }
        }
        if suppress_forwarding {
            continue;
        }
        if let Some(tx) = delta_sender {
            // Clear accumulated progress lines once before the first delta.
            if !outcome.forwarded_live_deltas {
                let _ = tx.send(DRAFT_CLEAR_SENTINEL.to_string()).await;
                outcome.forwarded_live_deltas = true;
            }
            // Receiver gone: disable forwarding for the rest of the stream.
            if tx.send(chunk.delta).await.is_err() {
                delta_sender = None;
            }
        }
    }
    Ok(outcome)
}
/// Execute a single turn of the agent loop: send messages, parse tool calls,
/// execute tools, and loop until the LLM produces a final text response.
/// When `silent` is true, suppresses stdout (for channel use).
@@ -779,23 +903,74 @@ pub(crate) async fn run_tool_call_loop(
} else {
None
};
let should_consume_provider_stream =
on_delta.is_some() && provider.supports_streaming() && request_tools.is_none();
let mut streamed_live_deltas = false;
let chat_future = provider.chat(
ChatRequest {
messages: &prepared_messages.messages,
tools: request_tools,
},
model,
temperature,
);
let chat_result = if let Some(token) = cancellation_token.as_ref() {
tokio::select! {
() = token.cancelled() => return Err(ToolLoopCancelled.into()),
result = chat_future => result,
let chat_result = if should_consume_provider_stream {
match consume_provider_streaming_response(
provider,
&prepared_messages.messages,
model,
temperature,
cancellation_token.as_ref(),
on_delta.as_ref(),
)
.await
{
Ok(streamed) => {
streamed_live_deltas = streamed.forwarded_live_deltas;
Ok(crate::providers::ChatResponse {
text: Some(streamed.response_text),
tool_calls: Vec::new(),
usage: None,
reasoning_content: None,
})
}
Err(stream_err) => {
tracing::warn!(
provider = provider_name,
model = model,
iteration = iteration + 1,
"provider streaming failed, falling back to non-streaming chat: {stream_err}"
);
runtime_trace::record_event(
"llm_stream_fallback",
Some(channel_name),
Some(provider_name),
Some(model),
Some(&turn_id),
Some(false),
Some("provider stream failed; fallback to non-streaming chat"),
serde_json::json!({
"iteration": iteration + 1,
"error": scrub_credentials(&stream_err.to_string()),
}),
);
if let Some(ref tx) = on_delta {
let _ = tx.send(DRAFT_CLEAR_SENTINEL.to_string()).await;
}
call_provider_chat(
provider,
&prepared_messages.messages,
request_tools,
model,
temperature,
cancellation_token.as_ref(),
)
.await
}
}
} else {
chat_future.await
call_provider_chat(
provider,
&prepared_messages.messages,
request_tools,
model,
temperature,
cancellation_token.as_ref(),
)
.await
};
let (
@@ -805,6 +980,7 @@ pub(crate) async fn run_tool_call_loop(
assistant_history_content,
native_tool_calls,
parse_issue_detected,
response_streamed_live,
) = match chat_result {
Ok(resp) => {
let (resp_input_tokens, resp_output_tokens) = resp
@@ -908,6 +1084,7 @@ pub(crate) async fn run_tool_call_loop(
assistant_history_content,
native_calls,
parse_issue.is_some(),
streamed_live_deltas,
)
}
Err(e) => {
@@ -1052,6 +1229,12 @@ pub(crate) async fn run_tool_call_loop(
// If a streaming sender is provided, relay the text in small chunks
// so the channel can progressively update the draft message.
if let Some(ref tx) = on_delta {
let should_emit_post_hoc_chunks =
!response_streamed_live || display_text != response_text;
if !should_emit_post_hoc_chunks {
history.push(ChatMessage::assistant(response_text.clone()));
return Ok(display_text);
}
// Clear accumulated progress lines before streaming the final answer.
let _ = tx.send(DRAFT_CLEAR_SENTINEL.to_string()).await;
// Split on whitespace boundaries, accumulating chunks of at least
@@ -2371,7 +2554,7 @@ mod tests {
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
use crate::observability::NoopObserver;
use crate::providers::traits::ProviderCapabilities;
use crate::providers::traits::{ProviderCapabilities, StreamChunk, StreamOptions};
use crate::providers::ChatResponse;
use crate::runtime::NativeRuntime;
use crate::security::{AutonomyLevel, SecurityPolicy, ShellRedirectPolicy};
@@ -2504,6 +2687,81 @@ mod tests {
}
}
/// Test double: a `Provider` that serves scripted texts over the streaming
/// path and counts how often the streaming vs. non-streaming entry points
/// are invoked.
struct StreamingScriptedProvider {
    // Scripted assistant texts, popped front-first on each stream call.
    responses: Arc<Mutex<VecDeque<String>>>,
    // Number of `stream_chat_with_history` invocations.
    stream_calls: Arc<AtomicUsize>,
    // Number of `chat` invocations (expected to stay 0 when streaming works).
    chat_calls: Arc<AtomicUsize>,
}
impl StreamingScriptedProvider {
fn from_text_responses(responses: Vec<&str>) -> Self {
Self {
responses: Arc::new(Mutex::new(
responses.into_iter().map(ToString::to_string).collect(),
)),
stream_calls: Arc::new(AtomicUsize::new(0)),
chat_calls: Arc::new(AtomicUsize::new(0)),
}
}
}
#[async_trait]
impl Provider for StreamingScriptedProvider {
    /// Unused in these tests; bails so an accidental call fails loudly.
    async fn chat_with_system(
        &self,
        _system_prompt: Option<&str>,
        _message: &str,
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        anyhow::bail!(
            "chat_with_system should not be used in streaming scripted provider tests"
        );
    }

    /// Counts the call, then bails: the streaming path under test is
    /// expected to leave the non-streaming `chat` entry point untouched.
    async fn chat(
        &self,
        _request: ChatRequest<'_>,
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<ChatResponse> {
        self.chat_calls.fetch_add(1, Ordering::SeqCst);
        anyhow::bail!("chat should not be called when streaming succeeds")
    }

    // Always advertise streaming so the loop takes the streaming path.
    fn supports_streaming(&self) -> bool {
        true
    }

    /// Serves the next scripted response as a single delta followed by a
    /// final chunk; returns an empty stream when streaming is disabled.
    fn stream_chat_with_history(
        &self,
        _messages: &[ChatMessage],
        _model: &str,
        _temperature: f64,
        options: StreamOptions,
    ) -> futures_util::stream::BoxStream<
        'static,
        crate::providers::traits::StreamResult<StreamChunk>,
    > {
        self.stream_calls.fetch_add(1, Ordering::SeqCst);
        if !options.enabled {
            return Box::pin(futures_util::stream::empty());
        }
        // An exhausted script yields an empty delta rather than panicking.
        let response = self
            .responses
            .lock()
            .expect("responses lock should be valid")
            .pop_front()
            .unwrap_or_default();
        Box::pin(futures_util::stream::iter(vec![
            Ok(StreamChunk::delta(response)),
            Ok(StreamChunk::final_chunk()),
        ]))
    }
}
struct CountingTool {
name: String,
invocations: Arc<AtomicUsize>,
@@ -3539,6 +3797,116 @@ mod tests {
);
}
#[tokio::test]
async fn run_tool_call_loop_consumes_provider_stream_for_final_response() {
    // Single scripted turn: the provider streams the final answer directly,
    // so the non-streaming `chat` path must never be touched.
    let provider =
        StreamingScriptedProvider::from_text_responses(vec!["streamed final answer"]);
    let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
    let mut history = vec![
        ChatMessage::system("test-system"),
        ChatMessage::user("say hi"),
    ];
    let observer = NoopObserver;
    let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(32);
    let final_text = run_tool_call_loop(
        &provider,
        &mut history,
        &tools_registry,
        &observer,
        "mock-provider",
        "mock-model",
        0.0,
        true,
        None,
        "telegram",
        &crate::config::MultimodalConfig::default(),
        4,
        None,
        Some(tx),
        None,
        &[],
    )
    .await
    .expect("streaming provider should complete");
    // Collect only user-visible deltas, dropping control sentinels.
    let mut visible_deltas = String::new();
    while let Some(delta) = rx.recv().await {
        let is_control =
            delta == DRAFT_CLEAR_SENTINEL || delta.starts_with(DRAFT_PROGRESS_SENTINEL);
        if !is_control {
            visible_deltas.push_str(&delta);
        }
    }
    assert_eq!(final_text, "streamed final answer");
    assert_eq!(
        visible_deltas, "streamed final answer",
        "draft should receive upstream deltas once without post-hoc duplication"
    );
    // Streaming was used exactly once and the fallback path never fired.
    assert_eq!(provider.stream_calls.load(Ordering::SeqCst), 1);
    assert_eq!(provider.chat_calls.load(Ordering::SeqCst), 0);
}
#[tokio::test]
async fn run_tool_call_loop_streaming_path_preserves_tool_loop_semantics() {
    // Two scripted turns: first a streamed tool-call payload, then the
    // final text once the tool has been executed.
    let provider = StreamingScriptedProvider::from_text_responses(vec![
        r#"<tool_call>
{"name":"count_tool","arguments":{"value":"A"}}
</tool_call>"#,
        "done",
    ]);
    let invocations = Arc::new(AtomicUsize::new(0));
    let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(CountingTool::new(
        "count_tool",
        Arc::clone(&invocations),
    ))];
    let mut history = vec![
        ChatMessage::system("test-system"),
        ChatMessage::user("run tool calls"),
    ];
    let observer = NoopObserver;
    let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(64);
    let final_text = run_tool_call_loop(
        &provider,
        &mut history,
        &tools_registry,
        &observer,
        "mock-provider",
        "mock-model",
        0.0,
        true,
        None,
        "telegram",
        &crate::config::MultimodalConfig::default(),
        5,
        None,
        Some(tx),
        None,
        &[],
    )
    .await
    .expect("streaming tool loop should execute tool and finish");
    // Gather the text the draft would actually display.
    let mut visible_deltas = String::new();
    while let Some(delta) = rx.recv().await {
        let is_control =
            delta == DRAFT_CLEAR_SENTINEL || delta.starts_with(DRAFT_PROGRESS_SENTINEL);
        if !is_control {
            visible_deltas.push_str(&delta);
        }
    }
    assert_eq!(final_text, "done");
    // The tool ran exactly once, and both LLM turns used streaming.
    assert_eq!(invocations.load(Ordering::SeqCst), 1);
    assert_eq!(provider.stream_calls.load(Ordering::SeqCst), 2);
    assert_eq!(provider.chat_calls.load(Ordering::SeqCst), 0);
    assert_eq!(visible_deltas, "done");
    assert!(
        !visible_deltas.contains("<tool_call"),
        "draft text should not leak streamed tool payload markers"
    );
}
#[test]
fn looks_like_unverified_action_completion_without_tool_call_detects_claimed_side_effects() {
assert!(looks_like_unverified_action_completion_without_tool_call(