fix(channel): consume provider streaming in tool loop drafts
This commit is contained in:
parent
201de8a300
commit
52b9e6a221
@ -11,6 +11,7 @@ use crate::security::SecurityPolicy;
|
||||
use crate::tools::{self, Tool};
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::Result;
|
||||
use futures_util::StreamExt;
|
||||
use regex::{Regex, RegexSet};
|
||||
use rustyline::error::ReadlineError;
|
||||
use std::collections::{BTreeSet, HashSet};
|
||||
@ -44,6 +45,8 @@ use parsing::{
|
||||
|
||||
/// Minimum characters per chunk when relaying LLM text to a streaming draft.
const STREAM_CHUNK_MIN_CHARS: usize = 80;
/// Rolling window size for detecting streamed tool-call payload markers.
// NOTE(review): compared against `String::len()` downstream, so this is
// effectively a byte budget (equals chars only for ASCII) — confirm intent.
const STREAM_TOOL_MARKER_WINDOW_CHARS: usize = 512;
|
||||
|
||||
/// Default maximum agentic tool-use iterations per user message to prevent runaway loops.
|
||||
/// Used as a safe fallback when `max_tool_iterations` is unset or configured as zero.
|
||||
@ -494,6 +497,127 @@ pub(crate) fn is_tool_iteration_limit_error(err: &anyhow::Error) -> bool {
|
||||
})
|
||||
}
|
||||
|
||||
/// Result of consuming a provider's streaming chat response.
#[derive(Debug, Default)]
struct StreamedChatOutcome {
    // Full concatenation of every non-final, non-empty delta from the stream.
    response_text: String,
    // True when at least one delta was relayed to the live draft channel and
    // not later retracted after a tool-call payload marker was detected.
    forwarded_live_deltas: bool,
}
|
||||
|
||||
/// Returns `true` when the rolling stream window appears to contain the start
/// of a tool-call payload: an XML-ish `<tool_call>`/`<toolcall>` tag or a
/// JSON `"tool_calls"` key, matched case-insensitively (ASCII).
fn looks_like_streamed_tool_payload(window: &str) -> bool {
    const MARKERS: [&str; 3] = ["<tool_call", "<toolcall", "\"tool_calls\""];
    let haystack = window.to_ascii_lowercase();
    MARKERS.iter().any(|marker| haystack.contains(marker))
}
|
||||
|
||||
async fn call_provider_chat(
|
||||
provider: &dyn Provider,
|
||||
messages: &[ChatMessage],
|
||||
request_tools: Option<&[crate::tools::ToolSpec]>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
cancellation_token: Option<&CancellationToken>,
|
||||
) -> Result<crate::providers::ChatResponse> {
|
||||
let chat_future = provider.chat(
|
||||
ChatRequest {
|
||||
messages,
|
||||
tools: request_tools,
|
||||
},
|
||||
model,
|
||||
temperature,
|
||||
);
|
||||
|
||||
if let Some(token) = cancellation_token {
|
||||
tokio::select! {
|
||||
() = token.cancelled() => Err(ToolLoopCancelled.into()),
|
||||
result = chat_future => result,
|
||||
}
|
||||
} else {
|
||||
chat_future.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Consume a provider's streaming chat response, accumulating the full text
/// and optionally relaying deltas to a live draft channel.
///
/// Deltas are forwarded through `on_delta` until the rolling marker window
/// looks like a streamed tool-call payload; from that point forwarding is
/// suppressed for the rest of the stream, and any text already shown in the
/// draft is retracted via `DRAFT_CLEAR_SENTINEL`, so raw tool-call markup
/// never reaches the user-visible draft. The complete text — including any
/// tool payload — is still returned in `StreamedChatOutcome::response_text`
/// for downstream parsing.
///
/// # Errors
/// Returns `ToolLoopCancelled` if `cancellation_token` fires mid-stream, or a
/// wrapped error if the stream yields one.
async fn consume_provider_streaming_response(
    provider: &dyn Provider,
    messages: &[ChatMessage],
    model: &str,
    temperature: f64,
    cancellation_token: Option<&CancellationToken>,
    on_delta: Option<&tokio::sync::mpsc::Sender<String>>,
) -> Result<StreamedChatOutcome> {
    let mut provider_stream = provider.stream_chat_with_history(
        messages,
        model,
        temperature,
        crate::providers::traits::StreamOptions::new(true),
    );
    let mut outcome = StreamedChatOutcome::default();
    // Dropped to `None` once the receiver hangs up, so we stop sending.
    let mut delta_sender = on_delta;
    // Latched on the first tool-call marker sighting; never cleared after.
    let mut suppress_forwarding = false;
    // Rolling tail of recent text, so markers split across chunk boundaries
    // are still detected.
    let mut marker_window = String::new();

    loop {
        // Race the next chunk against cancellation when a token is provided.
        let next_chunk = if let Some(token) = cancellation_token {
            tokio::select! {
                () = token.cancelled() => return Err(ToolLoopCancelled.into()),
                chunk = provider_stream.next() => chunk,
            }
        } else {
            provider_stream.next().await
        };

        // `None` means the stream is exhausted.
        let Some(chunk_result) = next_chunk else {
            break;
        };

        let chunk = chunk_result.map_err(|err| anyhow::anyhow!("provider stream error: {err}"))?;
        // NOTE(review): a final chunk's own delta is discarded here — assumes
        // providers always send `is_final` with an empty delta; confirm.
        if chunk.is_final {
            break;
        }
        if chunk.delta.is_empty() {
            continue;
        }

        outcome.response_text.push_str(&chunk.delta);
        marker_window.push_str(&chunk.delta);

        // Trim the window back to its byte budget, stepping forward to the
        // next char boundary so the drain never splits a UTF-8 sequence.
        if marker_window.len() > STREAM_TOOL_MARKER_WINDOW_CHARS {
            let keep_from = marker_window.len() - STREAM_TOOL_MARKER_WINDOW_CHARS;
            let boundary = marker_window
                .char_indices()
                .find(|(idx, _)| *idx >= keep_from)
                .map_or(0, |(idx, _)| idx);
            marker_window.drain(..boundary);
        }

        // First marker sighting: stop forwarding, and retract anything the
        // draft has already shown by clearing it.
        if !suppress_forwarding && looks_like_streamed_tool_payload(&marker_window) {
            suppress_forwarding = true;
            if outcome.forwarded_live_deltas {
                if let Some(tx) = delta_sender {
                    let _ = tx.send(DRAFT_CLEAR_SENTINEL.to_string()).await;
                }
                outcome.forwarded_live_deltas = false;
            }
        }

        if suppress_forwarding {
            continue;
        }

        if let Some(tx) = delta_sender {
            // Clear stale draft content (e.g. progress lines) exactly once,
            // just before the first live delta.
            if !outcome.forwarded_live_deltas {
                let _ = tx.send(DRAFT_CLEAR_SENTINEL.to_string()).await;
                outcome.forwarded_live_deltas = true;
            }
            // Receiver gone: keep consuming the stream but stop sending.
            if tx.send(chunk.delta).await.is_err() {
                delta_sender = None;
            }
        }
    }

    Ok(outcome)
}
|
||||
|
||||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||
/// execute tools, and loop until the LLM produces a final text response.
|
||||
/// When `silent` is true, suppresses stdout (for channel use).
|
||||
@ -779,23 +903,74 @@ pub(crate) async fn run_tool_call_loop(
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let should_consume_provider_stream =
|
||||
on_delta.is_some() && provider.supports_streaming() && request_tools.is_none();
|
||||
let mut streamed_live_deltas = false;
|
||||
|
||||
let chat_future = provider.chat(
|
||||
ChatRequest {
|
||||
messages: &prepared_messages.messages,
|
||||
tools: request_tools,
|
||||
},
|
||||
model,
|
||||
temperature,
|
||||
);
|
||||
|
||||
let chat_result = if let Some(token) = cancellation_token.as_ref() {
|
||||
tokio::select! {
|
||||
() = token.cancelled() => return Err(ToolLoopCancelled.into()),
|
||||
result = chat_future => result,
|
||||
let chat_result = if should_consume_provider_stream {
|
||||
match consume_provider_streaming_response(
|
||||
provider,
|
||||
&prepared_messages.messages,
|
||||
model,
|
||||
temperature,
|
||||
cancellation_token.as_ref(),
|
||||
on_delta.as_ref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(streamed) => {
|
||||
streamed_live_deltas = streamed.forwarded_live_deltas;
|
||||
Ok(crate::providers::ChatResponse {
|
||||
text: Some(streamed.response_text),
|
||||
tool_calls: Vec::new(),
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
})
|
||||
}
|
||||
Err(stream_err) => {
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model = model,
|
||||
iteration = iteration + 1,
|
||||
"provider streaming failed, falling back to non-streaming chat: {stream_err}"
|
||||
);
|
||||
runtime_trace::record_event(
|
||||
"llm_stream_fallback",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some("provider stream failed; fallback to non-streaming chat"),
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"error": scrub_credentials(&stream_err.to_string()),
|
||||
}),
|
||||
);
|
||||
if let Some(ref tx) = on_delta {
|
||||
let _ = tx.send(DRAFT_CLEAR_SENTINEL.to_string()).await;
|
||||
}
|
||||
call_provider_chat(
|
||||
provider,
|
||||
&prepared_messages.messages,
|
||||
request_tools,
|
||||
model,
|
||||
temperature,
|
||||
cancellation_token.as_ref(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
} else {
|
||||
chat_future.await
|
||||
call_provider_chat(
|
||||
provider,
|
||||
&prepared_messages.messages,
|
||||
request_tools,
|
||||
model,
|
||||
temperature,
|
||||
cancellation_token.as_ref(),
|
||||
)
|
||||
.await
|
||||
};
|
||||
|
||||
let (
|
||||
@ -805,6 +980,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
assistant_history_content,
|
||||
native_tool_calls,
|
||||
parse_issue_detected,
|
||||
response_streamed_live,
|
||||
) = match chat_result {
|
||||
Ok(resp) => {
|
||||
let (resp_input_tokens, resp_output_tokens) = resp
|
||||
@ -908,6 +1084,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
assistant_history_content,
|
||||
native_calls,
|
||||
parse_issue.is_some(),
|
||||
streamed_live_deltas,
|
||||
)
|
||||
}
|
||||
Err(e) => {
|
||||
@ -1052,6 +1229,12 @@ pub(crate) async fn run_tool_call_loop(
|
||||
// If a streaming sender is provided, relay the text in small chunks
|
||||
// so the channel can progressively update the draft message.
|
||||
if let Some(ref tx) = on_delta {
|
||||
let should_emit_post_hoc_chunks =
|
||||
!response_streamed_live || display_text != response_text;
|
||||
if !should_emit_post_hoc_chunks {
|
||||
history.push(ChatMessage::assistant(response_text.clone()));
|
||||
return Ok(display_text);
|
||||
}
|
||||
// Clear accumulated progress lines before streaming the final answer.
|
||||
let _ = tx.send(DRAFT_CLEAR_SENTINEL.to_string()).await;
|
||||
// Split on whitespace boundaries, accumulating chunks of at least
|
||||
@ -2371,7 +2554,7 @@ mod tests {
|
||||
|
||||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||
use crate::observability::NoopObserver;
|
||||
use crate::providers::traits::ProviderCapabilities;
|
||||
use crate::providers::traits::{ProviderCapabilities, StreamChunk, StreamOptions};
|
||||
use crate::providers::ChatResponse;
|
||||
use crate::runtime::NativeRuntime;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy, ShellRedirectPolicy};
|
||||
@ -2504,6 +2687,81 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
/// Test double that serves scripted responses through the streaming API and
/// counts how often each provider entry point is invoked.
struct StreamingScriptedProvider {
    /// Scripted response texts, consumed front-to-back, one per stream call.
    responses: Arc<Mutex<VecDeque<String>>>,
    /// How many times the streaming entry point was called.
    stream_calls: Arc<AtomicUsize>,
    /// How many times the non-streaming `chat` entry point was called.
    chat_calls: Arc<AtomicUsize>,
}

impl StreamingScriptedProvider {
    /// Build a provider whose stream yields each given text once, in order.
    fn from_text_responses(responses: Vec<&str>) -> Self {
        let scripted: VecDeque<String> =
            responses.iter().map(|text| (*text).to_string()).collect();
        Self {
            responses: Arc::new(Mutex::new(scripted)),
            stream_calls: Arc::new(AtomicUsize::new(0)),
            chat_calls: Arc::new(AtomicUsize::new(0)),
        }
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for StreamingScriptedProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
anyhow::bail!(
|
||||
"chat_with_system should not be used in streaming scripted provider tests"
|
||||
);
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
_request: ChatRequest<'_>,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
self.chat_calls.fetch_add(1, Ordering::SeqCst);
|
||||
anyhow::bail!("chat should not be called when streaming succeeds")
|
||||
}
|
||||
|
||||
fn supports_streaming(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn stream_chat_with_history(
|
||||
&self,
|
||||
_messages: &[ChatMessage],
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
options: StreamOptions,
|
||||
) -> futures_util::stream::BoxStream<
|
||||
'static,
|
||||
crate::providers::traits::StreamResult<StreamChunk>,
|
||||
> {
|
||||
self.stream_calls.fetch_add(1, Ordering::SeqCst);
|
||||
if !options.enabled {
|
||||
return Box::pin(futures_util::stream::empty());
|
||||
}
|
||||
|
||||
let response = self
|
||||
.responses
|
||||
.lock()
|
||||
.expect("responses lock should be valid")
|
||||
.pop_front()
|
||||
.unwrap_or_default();
|
||||
|
||||
Box::pin(futures_util::stream::iter(vec![
|
||||
Ok(StreamChunk::delta(response)),
|
||||
Ok(StreamChunk::final_chunk()),
|
||||
]))
|
||||
}
|
||||
}
|
||||
|
||||
struct CountingTool {
|
||||
name: String,
|
||||
invocations: Arc<AtomicUsize>,
|
||||
@ -3539,6 +3797,116 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
// Happy path: with no tools registered and a streaming-capable provider, the
// loop should consume the provider stream directly, relay the text to the
// draft channel exactly once, and never fall back to non-streaming `chat`.
#[tokio::test]
async fn run_tool_call_loop_consumes_provider_stream_for_final_response() {
    let provider =
        StreamingScriptedProvider::from_text_responses(vec!["streamed final answer"]);
    let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
    let mut history = vec![
        ChatMessage::system("test-system"),
        ChatMessage::user("say hi"),
    ];
    let observer = NoopObserver;
    // `rx` drains the draft deltas after the loop completes.
    let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(32);

    // Positional args follow `run_tool_call_loop`'s signature; the string
    // pair is provider/model naming and `Some(tx)` is the streaming delta
    // sender drained below — verify other positions against the signature.
    let result = run_tool_call_loop(
        &provider,
        &mut history,
        &tools_registry,
        &observer,
        "mock-provider",
        "mock-model",
        0.0,
        true,
        None,
        "telegram",
        &crate::config::MultimodalConfig::default(),
        4,
        None,
        Some(tx),
        None,
        &[],
    )
    .await
    .expect("streaming provider should complete");

    // Collect only user-visible deltas, skipping clear/progress sentinels.
    let mut visible_deltas = String::new();
    while let Some(delta) = rx.recv().await {
        if delta == DRAFT_CLEAR_SENTINEL || delta.starts_with(DRAFT_PROGRESS_SENTINEL) {
            continue;
        }
        visible_deltas.push_str(&delta);
    }

    assert_eq!(result, "streamed final answer");
    assert_eq!(
        visible_deltas, "streamed final answer",
        "draft should receive upstream deltas once without post-hoc duplication"
    );
    // One streaming call, zero non-streaming fallbacks.
    assert_eq!(provider.stream_calls.load(Ordering::SeqCst), 1);
    assert_eq!(provider.chat_calls.load(Ordering::SeqCst), 0);
}
|
||||
|
||||
// Tool-loop path: the first streamed response carries a tool-call payload,
// which must be executed (not shown in the draft); the second streamed
// response is the final answer. Verifies the streaming path preserves the
// tool-iteration semantics of the non-streaming loop.
#[tokio::test]
async fn run_tool_call_loop_streaming_path_preserves_tool_loop_semantics() {
    let provider = StreamingScriptedProvider::from_text_responses(vec![
        r#"<tool_call>
{"name":"count_tool","arguments":{"value":"A"}}
</tool_call>"#,
        "done",
    ]);
    // Shared counter asserting the tool actually ran.
    let invocations = Arc::new(AtomicUsize::new(0));
    let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(CountingTool::new(
        "count_tool",
        Arc::clone(&invocations),
    ))];
    let mut history = vec![
        ChatMessage::system("test-system"),
        ChatMessage::user("run tool calls"),
    ];
    let observer = NoopObserver;
    let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(64);

    // Positional args follow `run_tool_call_loop`'s signature; `Some(tx)` is
    // the streaming delta sender drained below — verify other positions
    // against the signature.
    let result = run_tool_call_loop(
        &provider,
        &mut history,
        &tools_registry,
        &observer,
        "mock-provider",
        "mock-model",
        0.0,
        true,
        None,
        "telegram",
        &crate::config::MultimodalConfig::default(),
        5,
        None,
        Some(tx),
        None,
        &[],
    )
    .await
    .expect("streaming tool loop should execute tool and finish");

    // Collect only user-visible deltas, skipping clear/progress sentinels.
    let mut visible_deltas = String::new();
    while let Some(delta) = rx.recv().await {
        if delta == DRAFT_CLEAR_SENTINEL || delta.starts_with(DRAFT_PROGRESS_SENTINEL) {
            continue;
        }
        visible_deltas.push_str(&delta);
    }

    assert_eq!(result, "done");
    assert_eq!(invocations.load(Ordering::SeqCst), 1);
    // Two iterations → two stream calls; fallback `chat` never used.
    assert_eq!(provider.stream_calls.load(Ordering::SeqCst), 2);
    assert_eq!(provider.chat_calls.load(Ordering::SeqCst), 0);
    assert_eq!(visible_deltas, "done");
    assert!(
        !visible_deltas.contains("<tool_call"),
        "draft text should not leak streamed tool payload markers"
    );
}
|
||||
|
||||
#[test]
|
||||
fn looks_like_unverified_action_completion_without_tool_call_detects_claimed_side_effects() {
|
||||
assert!(looks_like_unverified_action_completion_without_tool_call(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user