Compare commits

...

2 Commits

Author SHA1 Message Date
argenis de la rosa
7bcc8d5d26 fix(agent): address review feedback on session state persistence
- Validate version field on load; bail on unsupported versions
- Refresh stale system prompt from history[0] on resume
- Use write-to-temp-then-rename for atomic saves
- Add tests for nonexistent file, malformed JSON, unsupported version,
  and stale system prompt refresh
2026-03-13 18:42:28 -04:00
TOTHEMOON\youdo
95c9080a37 feat(agent): 支持交互会话状态持久化与恢复 2026-03-13 18:08:17 +08:00
4 changed files with 198 additions and 1 deletions

View File

@ -12,9 +12,11 @@ use crate::tools::{self, Tool};
use crate::util::truncate_with_ellipsis;
use anyhow::Result;
use regex::{Regex, RegexSet};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt::Write;
use std::io::Write as _;
use std::path::{Path, PathBuf};
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
@ -236,6 +238,60 @@ async fn auto_compact_history(
Ok(true)
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct InteractiveSessionState {
version: u32,
history: Vec<ChatMessage>,
}
impl InteractiveSessionState {
fn from_history(history: &[ChatMessage]) -> Self {
Self {
version: 1,
history: history.to_vec(),
}
}
}
fn load_interactive_session_history(path: &Path, system_prompt: &str) -> Result<Vec<ChatMessage>> {
if !path.exists() {
return Ok(vec![ChatMessage::system(system_prompt)]);
}
let raw = std::fs::read_to_string(path)?;
let state: InteractiveSessionState = serde_json::from_str(&raw)?;
if state.version != 1 {
anyhow::bail!("unsupported session state version: {}", state.version);
}
let mut history = state.history;
if history.is_empty() {
history.push(ChatMessage::system(system_prompt));
} else if history.first().map(|msg| msg.role.as_str()) == Some("system") {
// Always refresh with the current system prompt so config changes take effect.
history[0] = ChatMessage::system(system_prompt);
} else {
history.insert(0, ChatMessage::system(system_prompt));
}
Ok(history)
}
fn save_interactive_session_history(path: &Path, history: &[ChatMessage]) -> Result<()> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let payload = serde_json::to_string_pretty(&InteractiveSessionState::from_history(history))?;
// Write to a temporary file then rename for atomicity.
let tmp_path = path.with_extension("tmp");
std::fs::write(&tmp_path, payload)?;
std::fs::rename(&tmp_path, path)?;
Ok(())
}
/// Build context preamble by searching memory for relevant entries.
/// Entries with a hybrid score below `min_relevance_score` are dropped to
/// prevent unrelated memories from bleeding into the conversation.
@ -2802,6 +2858,7 @@ pub async fn run(
temperature: f64,
peripheral_overrides: Vec<String>,
interactive: bool,
session_state_file: Option<PathBuf>,
) -> Result<String> {
// ── Wire up agnostic subsystems ──────────────────────────────
let base_observer = observability::create_observer(&config.observability);
@ -3119,7 +3176,11 @@ pub async fn run(
let cli = crate::channels::CliChannel::new();
// Persistent conversation history across turns
let mut history = vec![ChatMessage::system(&system_prompt)];
let mut history = if let Some(path) = session_state_file.as_deref() {
load_interactive_session_history(path, &system_prompt)?
} else {
vec![ChatMessage::system(&system_prompt)]
};
loop {
print!("> ");
@ -3182,6 +3243,9 @@ pub async fn run(
} else {
println!("Conversation cleared.\n");
}
if let Some(path) = session_state_file.as_deref() {
save_interactive_session_history(path, &history)?;
}
continue;
}
_ => {}
@ -3266,6 +3330,10 @@ pub async fn run(
// Hard cap as a safety net.
trim_history(&mut history, config.agent.max_history_messages);
if let Some(path) = session_state_file.as_deref() {
save_interactive_session_history(path, &history)?;
}
}
}
@ -3467,6 +3535,110 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
#[cfg(test)]
mod tests {
use super::{
apply_compaction_summary, build_compaction_transcript, load_interactive_session_history,
save_interactive_session_history, InteractiveSessionState,
};
use crate::providers::ChatMessage;
use tempfile::tempdir;
#[test]
fn interactive_session_state_round_trips_history() {
let dir = tempdir().unwrap();
let path = dir.path().join("session.json");
let history = vec![
ChatMessage::system("system"),
ChatMessage::user("hello"),
ChatMessage::assistant("hi"),
];
save_interactive_session_history(&path, &history).unwrap();
let restored = load_interactive_session_history(&path, "fallback").unwrap();
assert_eq!(restored.len(), 3);
assert_eq!(restored[0].role, "system");
assert_eq!(restored[1].content, "hello");
assert_eq!(restored[2].content, "hi");
}
#[test]
fn interactive_session_state_adds_missing_system_prompt() {
let dir = tempdir().unwrap();
let path = dir.path().join("session.json");
let payload = serde_json::to_string_pretty(&InteractiveSessionState {
version: 1,
history: vec![ChatMessage::user("orphan")],
})
.unwrap();
std::fs::write(&path, payload).unwrap();
let restored = load_interactive_session_history(&path, "fallback system").unwrap();
assert_eq!(restored[0].role, "system");
assert_eq!(restored[0].content, "fallback system");
assert_eq!(restored[1].content, "orphan");
}
#[test]
fn interactive_session_state_nonexistent_file_returns_fresh_history() {
let dir = tempdir().unwrap();
let path = dir.path().join("does_not_exist.json");
let history = load_interactive_session_history(&path, "fresh prompt").unwrap();
assert_eq!(history.len(), 1);
assert_eq!(history[0].role, "system");
assert_eq!(history[0].content, "fresh prompt");
}
#[test]
fn interactive_session_state_malformed_json_returns_error() {
let dir = tempdir().unwrap();
let path = dir.path().join("bad.json");
std::fs::write(&path, "NOT VALID JSON {{{").unwrap();
let result = load_interactive_session_history(&path, "prompt");
assert!(result.is_err());
}
#[test]
fn interactive_session_state_unsupported_version_returns_error() {
let dir = tempdir().unwrap();
let path = dir.path().join("v99.json");
let payload = serde_json::to_string_pretty(&InteractiveSessionState {
version: 99,
history: vec![ChatMessage::user("msg")],
})
.unwrap();
std::fs::write(&path, payload).unwrap();
let result = load_interactive_session_history(&path, "prompt");
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("unsupported session state version"),
"unexpected error: {err_msg}"
);
}
#[test]
fn interactive_session_state_refreshes_stale_system_prompt() {
let dir = tempdir().unwrap();
let path = dir.path().join("session.json");
let history = vec![
ChatMessage::system("old system prompt"),
ChatMessage::user("hello"),
];
save_interactive_session_history(&path, &history).unwrap();
let restored = load_interactive_session_history(&path, "new system prompt").unwrap();
assert_eq!(restored.len(), 2);
assert_eq!(restored[0].role, "system");
assert_eq!(restored[0].content, "new system prompt");
assert_eq!(restored[1].content, "hello");
}
use super::*;
use async_trait::async_trait;
use base64::{engine::general_purpose::STANDARD, Engine as _};

View File

@ -175,6 +175,7 @@ async fn run_agent_job(
config.default_temperature,
vec![],
false,
None,
)
.await
}

View File

@ -235,6 +235,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
temp,
vec![],
false,
None,
)
.await
{

View File

@ -37,6 +37,7 @@ use clap::{CommandFactory, Parser, Subcommand, ValueEnum};
use dialoguer::{Input, Password};
use serde::{Deserialize, Serialize};
use std::io::Write;
use std::path::PathBuf;
use tracing::{info, warn};
use tracing_subscriber::{fmt, EnvFilter};
@ -180,6 +181,10 @@ Examples:
#[arg(short, long)]
message: Option<String>,
/// Load and save interactive session state in this JSON file
#[arg(long)]
session_state_file: Option<PathBuf>,
/// Provider to use (openrouter, anthropic, openai, openai-codex)
#[arg(short, long)]
provider: Option<String>,
@ -814,6 +819,7 @@ async fn main() -> Result<()> {
Commands::Agent {
message,
session_state_file,
provider,
model,
temperature,
@ -829,6 +835,7 @@ async fn main() -> Result<()> {
final_temperature,
peripheral,
true,
session_state_file,
)
.await
.map(|_| ())
@ -2218,6 +2225,22 @@ mod tests {
}
}
#[test]
fn agent_command_parses_session_state_file() {
let cli =
Cli::try_parse_from(["zeroclaw", "agent", "--session-state-file", "session.json"])
.expect("agent command with session state file should parse");
match cli.command {
Commands::Agent {
session_state_file, ..
} => {
assert_eq!(session_state_file, Some(PathBuf::from("session.json")));
}
other => panic!("expected agent command, got {other:?}"),
}
}
#[test]
fn agent_fallback_uses_config_default_temperature() {
// Test that when user doesn't provide --temperature,