From 7bcc8d5d26992298c7e475c10d144fb08de15542 Mon Sep 17 00:00:00 2001 From: argenis de la rosa Date: Fri, 13 Mar 2026 18:42:28 -0400 Subject: [PATCH] fix(agent): address review feedback on session state persistence - Validate version field on load; bail on unsupported versions - Refresh stale system prompt from history[0] on resume - Use write-to-temp-then-rename for atomic saves - Add tests for nonexistent file, malformed JSON, unsupported version, and stale system prompt refresh --- src/agent/loop_.rs | 87 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 80 insertions(+), 7 deletions(-) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index a6522392a..f8d277a69 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -259,14 +259,23 @@ fn load_interactive_session_history(path: &Path, system_prompt: &str) -> Result< } let raw = std::fs::read_to_string(path)?; - let mut state: InteractiveSessionState = serde_json::from_str(&raw)?; - if state.history.is_empty() { - state.history.push(ChatMessage::system(system_prompt)); - } else if state.history.first().map(|msg| msg.role.as_str()) != Some("system") { - state.history.insert(0, ChatMessage::system(system_prompt)); + let state: InteractiveSessionState = serde_json::from_str(&raw)?; + + if state.version != 1 { + anyhow::bail!("unsupported session state version: {}", state.version); } - Ok(state.history) + let mut history = state.history; + if history.is_empty() { + history.push(ChatMessage::system(system_prompt)); + } else if history.first().map(|msg| msg.role.as_str()) == Some("system") { + // Always refresh with the current system prompt so config changes take effect. + history[0] = ChatMessage::system(system_prompt); + } else { + history.insert(0, ChatMessage::system(system_prompt)); + } + + Ok(history) } fn save_interactive_session_history(path: &Path, history: &[ChatMessage]) -> Result<()> { @@ -275,7 +284,11 @@ fn save_interactive_session_history(path: &Path, history: &[ChatMessage]) -> Res } let payload = serde_json::to_string_pretty(&InteractiveSessionState::from_history(history))?; - std::fs::write(path, payload)?; + + // Write to a temporary file then rename for atomicity. + let tmp_path = path.with_extension("tmp"); + std::fs::write(&tmp_path, payload)?; + std::fs::rename(&tmp_path, path)?; Ok(()) } @@ -3566,6 +3579,66 @@ mod tests { assert_eq!(restored[1].content, "orphan"); } + #[test] + fn interactive_session_state_nonexistent_file_returns_fresh_history() { + let dir = tempdir().unwrap(); + let path = dir.path().join("does_not_exist.json"); + + let history = load_interactive_session_history(&path, "fresh prompt").unwrap(); + + assert_eq!(history.len(), 1); + assert_eq!(history[0].role, "system"); + assert_eq!(history[0].content, "fresh prompt"); + } + + #[test] + fn interactive_session_state_malformed_json_returns_error() { + let dir = tempdir().unwrap(); + let path = dir.path().join("bad.json"); + std::fs::write(&path, "NOT VALID JSON {{{").unwrap(); + + let result = load_interactive_session_history(&path, "prompt"); + assert!(result.is_err()); + } + + #[test] + fn interactive_session_state_unsupported_version_returns_error() { + let dir = tempdir().unwrap(); + let path = dir.path().join("v99.json"); + let payload = serde_json::to_string_pretty(&InteractiveSessionState { + version: 99, + history: vec![ChatMessage::user("msg")], + }) + .unwrap(); + std::fs::write(&path, payload).unwrap(); + + let result = load_interactive_session_history(&path, "prompt"); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("unsupported session state version"), + "unexpected error: {err_msg}" + ); + } + + #[test] + fn interactive_session_state_refreshes_stale_system_prompt() { + let dir = tempdir().unwrap(); + let path = dir.path().join("session.json"); + let history = vec![ + ChatMessage::system("old system prompt"), + ChatMessage::user("hello"), + ]; + + save_interactive_session_history(&path, &history).unwrap(); + let restored = load_interactive_session_history(&path, "new system prompt").unwrap(); + + assert_eq!(restored.len(), 2); + assert_eq!(restored[0].role, "system"); + assert_eq!(restored[0].content, "new system prompt"); + assert_eq!(restored[1].content, "hello"); + } + use super::*; use async_trait::async_trait; use base64::{engine::general_purpose::STANDARD, Engine as _};