Compare commits
2 Commits
master
...
feature/in
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7bcc8d5d26 | ||
|
|
95c9080a37 |
@ -12,9 +12,11 @@ use crate::tools::{self, Tool};
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::Result;
|
||||
use regex::{Regex, RegexSet};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use std::fmt::Write;
|
||||
use std::io::Write as _;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@ -236,6 +238,60 @@ async fn auto_compact_history(
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct InteractiveSessionState {
|
||||
version: u32,
|
||||
history: Vec<ChatMessage>,
|
||||
}
|
||||
|
||||
impl InteractiveSessionState {
|
||||
fn from_history(history: &[ChatMessage]) -> Self {
|
||||
Self {
|
||||
version: 1,
|
||||
history: history.to_vec(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_interactive_session_history(path: &Path, system_prompt: &str) -> Result<Vec<ChatMessage>> {
|
||||
if !path.exists() {
|
||||
return Ok(vec![ChatMessage::system(system_prompt)]);
|
||||
}
|
||||
|
||||
let raw = std::fs::read_to_string(path)?;
|
||||
let state: InteractiveSessionState = serde_json::from_str(&raw)?;
|
||||
|
||||
if state.version != 1 {
|
||||
anyhow::bail!("unsupported session state version: {}", state.version);
|
||||
}
|
||||
|
||||
let mut history = state.history;
|
||||
if history.is_empty() {
|
||||
history.push(ChatMessage::system(system_prompt));
|
||||
} else if history.first().map(|msg| msg.role.as_str()) == Some("system") {
|
||||
// Always refresh with the current system prompt so config changes take effect.
|
||||
history[0] = ChatMessage::system(system_prompt);
|
||||
} else {
|
||||
history.insert(0, ChatMessage::system(system_prompt));
|
||||
}
|
||||
|
||||
Ok(history)
|
||||
}
|
||||
|
||||
fn save_interactive_session_history(path: &Path, history: &[ChatMessage]) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let payload = serde_json::to_string_pretty(&InteractiveSessionState::from_history(history))?;
|
||||
|
||||
// Write to a temporary file then rename for atomicity.
|
||||
let tmp_path = path.with_extension("tmp");
|
||||
std::fs::write(&tmp_path, payload)?;
|
||||
std::fs::rename(&tmp_path, path)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build context preamble by searching memory for relevant entries.
|
||||
/// Entries with a hybrid score below `min_relevance_score` are dropped to
|
||||
/// prevent unrelated memories from bleeding into the conversation.
|
||||
@ -2802,6 +2858,7 @@ pub async fn run(
|
||||
temperature: f64,
|
||||
peripheral_overrides: Vec<String>,
|
||||
interactive: bool,
|
||||
session_state_file: Option<PathBuf>,
|
||||
) -> Result<String> {
|
||||
// ── Wire up agnostic subsystems ──────────────────────────────
|
||||
let base_observer = observability::create_observer(&config.observability);
|
||||
@ -3119,7 +3176,11 @@ pub async fn run(
|
||||
let cli = crate::channels::CliChannel::new();
|
||||
|
||||
// Persistent conversation history across turns
|
||||
let mut history = vec![ChatMessage::system(&system_prompt)];
|
||||
let mut history = if let Some(path) = session_state_file.as_deref() {
|
||||
load_interactive_session_history(path, &system_prompt)?
|
||||
} else {
|
||||
vec![ChatMessage::system(&system_prompt)]
|
||||
};
|
||||
|
||||
loop {
|
||||
print!("> ");
|
||||
@ -3182,6 +3243,9 @@ pub async fn run(
|
||||
} else {
|
||||
println!("Conversation cleared.\n");
|
||||
}
|
||||
if let Some(path) = session_state_file.as_deref() {
|
||||
save_interactive_session_history(path, &history)?;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
@ -3266,6 +3330,10 @@ pub async fn run(
|
||||
|
||||
// Hard cap as a safety net.
|
||||
trim_history(&mut history, config.agent.max_history_messages);
|
||||
|
||||
if let Some(path) = session_state_file.as_deref() {
|
||||
save_interactive_session_history(path, &history)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -3467,6 +3535,110 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
apply_compaction_summary, build_compaction_transcript, load_interactive_session_history,
|
||||
save_interactive_session_history, InteractiveSessionState,
|
||||
};
|
||||
use crate::providers::ChatMessage;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn interactive_session_state_round_trips_history() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("session.json");
|
||||
let history = vec![
|
||||
ChatMessage::system("system"),
|
||||
ChatMessage::user("hello"),
|
||||
ChatMessage::assistant("hi"),
|
||||
];
|
||||
|
||||
save_interactive_session_history(&path, &history).unwrap();
|
||||
let restored = load_interactive_session_history(&path, "fallback").unwrap();
|
||||
|
||||
assert_eq!(restored.len(), 3);
|
||||
assert_eq!(restored[0].role, "system");
|
||||
assert_eq!(restored[1].content, "hello");
|
||||
assert_eq!(restored[2].content, "hi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interactive_session_state_adds_missing_system_prompt() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("session.json");
|
||||
let payload = serde_json::to_string_pretty(&InteractiveSessionState {
|
||||
version: 1,
|
||||
history: vec![ChatMessage::user("orphan")],
|
||||
})
|
||||
.unwrap();
|
||||
std::fs::write(&path, payload).unwrap();
|
||||
|
||||
let restored = load_interactive_session_history(&path, "fallback system").unwrap();
|
||||
|
||||
assert_eq!(restored[0].role, "system");
|
||||
assert_eq!(restored[0].content, "fallback system");
|
||||
assert_eq!(restored[1].content, "orphan");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interactive_session_state_nonexistent_file_returns_fresh_history() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("does_not_exist.json");
|
||||
|
||||
let history = load_interactive_session_history(&path, "fresh prompt").unwrap();
|
||||
|
||||
assert_eq!(history.len(), 1);
|
||||
assert_eq!(history[0].role, "system");
|
||||
assert_eq!(history[0].content, "fresh prompt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interactive_session_state_malformed_json_returns_error() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("bad.json");
|
||||
std::fs::write(&path, "NOT VALID JSON {{{").unwrap();
|
||||
|
||||
let result = load_interactive_session_history(&path, "prompt");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interactive_session_state_unsupported_version_returns_error() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("v99.json");
|
||||
let payload = serde_json::to_string_pretty(&InteractiveSessionState {
|
||||
version: 99,
|
||||
history: vec![ChatMessage::user("msg")],
|
||||
})
|
||||
.unwrap();
|
||||
std::fs::write(&path, payload).unwrap();
|
||||
|
||||
let result = load_interactive_session_history(&path, "prompt");
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err_msg.contains("unsupported session state version"),
|
||||
"unexpected error: {err_msg}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interactive_session_state_refreshes_stale_system_prompt() {
|
||||
let dir = tempdir().unwrap();
|
||||
let path = dir.path().join("session.json");
|
||||
let history = vec![
|
||||
ChatMessage::system("old system prompt"),
|
||||
ChatMessage::user("hello"),
|
||||
];
|
||||
|
||||
save_interactive_session_history(&path, &history).unwrap();
|
||||
let restored = load_interactive_session_history(&path, "new system prompt").unwrap();
|
||||
|
||||
assert_eq!(restored.len(), 2);
|
||||
assert_eq!(restored[0].role, "system");
|
||||
assert_eq!(restored[0].content, "new system prompt");
|
||||
assert_eq!(restored[1].content, "hello");
|
||||
}
|
||||
|
||||
use super::*;
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose::STANDARD, Engine as _};
|
||||
|
||||
@ -175,6 +175,7 @@ async fn run_agent_job(
|
||||
config.default_temperature,
|
||||
vec![],
|
||||
false,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
@ -235,6 +235,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
||||
temp,
|
||||
vec![],
|
||||
false,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
23
src/main.rs
23
src/main.rs
@ -37,6 +37,7 @@ use clap::{CommandFactory, Parser, Subcommand, ValueEnum};
|
||||
use dialoguer::{Input, Password};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
use tracing::{info, warn};
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
|
||||
@ -180,6 +181,10 @@ Examples:
|
||||
#[arg(short, long)]
|
||||
message: Option<String>,
|
||||
|
||||
/// Load and save interactive session state in this JSON file
|
||||
#[arg(long)]
|
||||
session_state_file: Option<PathBuf>,
|
||||
|
||||
/// Provider to use (openrouter, anthropic, openai, openai-codex)
|
||||
#[arg(short, long)]
|
||||
provider: Option<String>,
|
||||
@ -814,6 +819,7 @@ async fn main() -> Result<()> {
|
||||
|
||||
Commands::Agent {
|
||||
message,
|
||||
session_state_file,
|
||||
provider,
|
||||
model,
|
||||
temperature,
|
||||
@ -829,6 +835,7 @@ async fn main() -> Result<()> {
|
||||
final_temperature,
|
||||
peripheral,
|
||||
true,
|
||||
session_state_file,
|
||||
)
|
||||
.await
|
||||
.map(|_| ())
|
||||
@ -2218,6 +2225,22 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agent_command_parses_session_state_file() {
|
||||
let cli =
|
||||
Cli::try_parse_from(["zeroclaw", "agent", "--session-state-file", "session.json"])
|
||||
.expect("agent command with session state file should parse");
|
||||
|
||||
match cli.command {
|
||||
Commands::Agent {
|
||||
session_state_file, ..
|
||||
} => {
|
||||
assert_eq!(session_state_file, Some(PathBuf::from("session.json")));
|
||||
}
|
||||
other => panic!("expected agent command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agent_fallback_uses_config_default_temperature() {
|
||||
// Test that when user doesn't provide --temperature,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user