Compare commits
2 Commits
master
...
feature/in
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7bcc8d5d26 | ||
|
|
95c9080a37 |
@ -12,9 +12,11 @@ use crate::tools::{self, Tool};
|
|||||||
use crate::util::truncate_with_ellipsis;
|
use crate::util::truncate_with_ellipsis;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use regex::{Regex, RegexSet};
|
use regex::{Regex, RegexSet};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
use std::io::Write as _;
|
use std::io::Write as _;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::{Arc, LazyLock};
|
use std::sync::{Arc, LazyLock};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
@ -236,6 +238,60 @@ async fn auto_compact_history(
|
|||||||
Ok(true)
|
Ok(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
struct InteractiveSessionState {
|
||||||
|
version: u32,
|
||||||
|
history: Vec<ChatMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InteractiveSessionState {
|
||||||
|
fn from_history(history: &[ChatMessage]) -> Self {
|
||||||
|
Self {
|
||||||
|
version: 1,
|
||||||
|
history: history.to_vec(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_interactive_session_history(path: &Path, system_prompt: &str) -> Result<Vec<ChatMessage>> {
|
||||||
|
if !path.exists() {
|
||||||
|
return Ok(vec![ChatMessage::system(system_prompt)]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let raw = std::fs::read_to_string(path)?;
|
||||||
|
let state: InteractiveSessionState = serde_json::from_str(&raw)?;
|
||||||
|
|
||||||
|
if state.version != 1 {
|
||||||
|
anyhow::bail!("unsupported session state version: {}", state.version);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut history = state.history;
|
||||||
|
if history.is_empty() {
|
||||||
|
history.push(ChatMessage::system(system_prompt));
|
||||||
|
} else if history.first().map(|msg| msg.role.as_str()) == Some("system") {
|
||||||
|
// Always refresh with the current system prompt so config changes take effect.
|
||||||
|
history[0] = ChatMessage::system(system_prompt);
|
||||||
|
} else {
|
||||||
|
history.insert(0, ChatMessage::system(system_prompt));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(history)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save_interactive_session_history(path: &Path, history: &[ChatMessage]) -> Result<()> {
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
std::fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let payload = serde_json::to_string_pretty(&InteractiveSessionState::from_history(history))?;
|
||||||
|
|
||||||
|
// Write to a temporary file then rename for atomicity.
|
||||||
|
let tmp_path = path.with_extension("tmp");
|
||||||
|
std::fs::write(&tmp_path, payload)?;
|
||||||
|
std::fs::rename(&tmp_path, path)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Build context preamble by searching memory for relevant entries.
|
/// Build context preamble by searching memory for relevant entries.
|
||||||
/// Entries with a hybrid score below `min_relevance_score` are dropped to
|
/// Entries with a hybrid score below `min_relevance_score` are dropped to
|
||||||
/// prevent unrelated memories from bleeding into the conversation.
|
/// prevent unrelated memories from bleeding into the conversation.
|
||||||
@ -2802,6 +2858,7 @@ pub async fn run(
|
|||||||
temperature: f64,
|
temperature: f64,
|
||||||
peripheral_overrides: Vec<String>,
|
peripheral_overrides: Vec<String>,
|
||||||
interactive: bool,
|
interactive: bool,
|
||||||
|
session_state_file: Option<PathBuf>,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
// ── Wire up agnostic subsystems ──────────────────────────────
|
// ── Wire up agnostic subsystems ──────────────────────────────
|
||||||
let base_observer = observability::create_observer(&config.observability);
|
let base_observer = observability::create_observer(&config.observability);
|
||||||
@ -3119,7 +3176,11 @@ pub async fn run(
|
|||||||
let cli = crate::channels::CliChannel::new();
|
let cli = crate::channels::CliChannel::new();
|
||||||
|
|
||||||
// Persistent conversation history across turns
|
// Persistent conversation history across turns
|
||||||
let mut history = vec![ChatMessage::system(&system_prompt)];
|
let mut history = if let Some(path) = session_state_file.as_deref() {
|
||||||
|
load_interactive_session_history(path, &system_prompt)?
|
||||||
|
} else {
|
||||||
|
vec![ChatMessage::system(&system_prompt)]
|
||||||
|
};
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
print!("> ");
|
print!("> ");
|
||||||
@ -3182,6 +3243,9 @@ pub async fn run(
|
|||||||
} else {
|
} else {
|
||||||
println!("Conversation cleared.\n");
|
println!("Conversation cleared.\n");
|
||||||
}
|
}
|
||||||
|
if let Some(path) = session_state_file.as_deref() {
|
||||||
|
save_interactive_session_history(path, &history)?;
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
@ -3266,6 +3330,10 @@ pub async fn run(
|
|||||||
|
|
||||||
// Hard cap as a safety net.
|
// Hard cap as a safety net.
|
||||||
trim_history(&mut history, config.agent.max_history_messages);
|
trim_history(&mut history, config.agent.max_history_messages);
|
||||||
|
|
||||||
|
if let Some(path) = session_state_file.as_deref() {
|
||||||
|
save_interactive_session_history(path, &history)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3467,6 +3535,110 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use super::{
|
||||||
|
apply_compaction_summary, build_compaction_transcript, load_interactive_session_history,
|
||||||
|
save_interactive_session_history, InteractiveSessionState,
|
||||||
|
};
|
||||||
|
use crate::providers::ChatMessage;
|
||||||
|
use tempfile::tempdir;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn interactive_session_state_round_trips_history() {
|
||||||
|
let dir = tempdir().unwrap();
|
||||||
|
let path = dir.path().join("session.json");
|
||||||
|
let history = vec![
|
||||||
|
ChatMessage::system("system"),
|
||||||
|
ChatMessage::user("hello"),
|
||||||
|
ChatMessage::assistant("hi"),
|
||||||
|
];
|
||||||
|
|
||||||
|
save_interactive_session_history(&path, &history).unwrap();
|
||||||
|
let restored = load_interactive_session_history(&path, "fallback").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(restored.len(), 3);
|
||||||
|
assert_eq!(restored[0].role, "system");
|
||||||
|
assert_eq!(restored[1].content, "hello");
|
||||||
|
assert_eq!(restored[2].content, "hi");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn interactive_session_state_adds_missing_system_prompt() {
|
||||||
|
let dir = tempdir().unwrap();
|
||||||
|
let path = dir.path().join("session.json");
|
||||||
|
let payload = serde_json::to_string_pretty(&InteractiveSessionState {
|
||||||
|
version: 1,
|
||||||
|
history: vec![ChatMessage::user("orphan")],
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
std::fs::write(&path, payload).unwrap();
|
||||||
|
|
||||||
|
let restored = load_interactive_session_history(&path, "fallback system").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(restored[0].role, "system");
|
||||||
|
assert_eq!(restored[0].content, "fallback system");
|
||||||
|
assert_eq!(restored[1].content, "orphan");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn interactive_session_state_nonexistent_file_returns_fresh_history() {
|
||||||
|
let dir = tempdir().unwrap();
|
||||||
|
let path = dir.path().join("does_not_exist.json");
|
||||||
|
|
||||||
|
let history = load_interactive_session_history(&path, "fresh prompt").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(history.len(), 1);
|
||||||
|
assert_eq!(history[0].role, "system");
|
||||||
|
assert_eq!(history[0].content, "fresh prompt");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn interactive_session_state_malformed_json_returns_error() {
|
||||||
|
let dir = tempdir().unwrap();
|
||||||
|
let path = dir.path().join("bad.json");
|
||||||
|
std::fs::write(&path, "NOT VALID JSON {{{").unwrap();
|
||||||
|
|
||||||
|
let result = load_interactive_session_history(&path, "prompt");
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn interactive_session_state_unsupported_version_returns_error() {
|
||||||
|
let dir = tempdir().unwrap();
|
||||||
|
let path = dir.path().join("v99.json");
|
||||||
|
let payload = serde_json::to_string_pretty(&InteractiveSessionState {
|
||||||
|
version: 99,
|
||||||
|
history: vec![ChatMessage::user("msg")],
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
std::fs::write(&path, payload).unwrap();
|
||||||
|
|
||||||
|
let result = load_interactive_session_history(&path, "prompt");
|
||||||
|
assert!(result.is_err());
|
||||||
|
let err_msg = result.unwrap_err().to_string();
|
||||||
|
assert!(
|
||||||
|
err_msg.contains("unsupported session state version"),
|
||||||
|
"unexpected error: {err_msg}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn interactive_session_state_refreshes_stale_system_prompt() {
|
||||||
|
let dir = tempdir().unwrap();
|
||||||
|
let path = dir.path().join("session.json");
|
||||||
|
let history = vec![
|
||||||
|
ChatMessage::system("old system prompt"),
|
||||||
|
ChatMessage::user("hello"),
|
||||||
|
];
|
||||||
|
|
||||||
|
save_interactive_session_history(&path, &history).unwrap();
|
||||||
|
let restored = load_interactive_session_history(&path, "new system prompt").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(restored.len(), 2);
|
||||||
|
assert_eq!(restored[0].role, "system");
|
||||||
|
assert_eq!(restored[0].content, "new system prompt");
|
||||||
|
assert_eq!(restored[1].content, "hello");
|
||||||
|
}
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use base64::{engine::general_purpose::STANDARD, Engine as _};
|
use base64::{engine::general_purpose::STANDARD, Engine as _};
|
||||||
|
|||||||
@ -175,6 +175,7 @@ async fn run_agent_job(
|
|||||||
config.default_temperature,
|
config.default_temperature,
|
||||||
vec![],
|
vec![],
|
||||||
false,
|
false,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|||||||
@ -235,6 +235,7 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
|
|||||||
temp,
|
temp,
|
||||||
vec![],
|
vec![],
|
||||||
false,
|
false,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
|||||||
23
src/main.rs
23
src/main.rs
@ -37,6 +37,7 @@ use clap::{CommandFactory, Parser, Subcommand, ValueEnum};
|
|||||||
use dialoguer::{Input, Password};
|
use dialoguer::{Input, Password};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
use std::path::PathBuf;
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
use tracing_subscriber::{fmt, EnvFilter};
|
use tracing_subscriber::{fmt, EnvFilter};
|
||||||
|
|
||||||
@ -180,6 +181,10 @@ Examples:
|
|||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
message: Option<String>,
|
message: Option<String>,
|
||||||
|
|
||||||
|
/// Load and save interactive session state in this JSON file
|
||||||
|
#[arg(long)]
|
||||||
|
session_state_file: Option<PathBuf>,
|
||||||
|
|
||||||
/// Provider to use (openrouter, anthropic, openai, openai-codex)
|
/// Provider to use (openrouter, anthropic, openai, openai-codex)
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
provider: Option<String>,
|
provider: Option<String>,
|
||||||
@ -814,6 +819,7 @@ async fn main() -> Result<()> {
|
|||||||
|
|
||||||
Commands::Agent {
|
Commands::Agent {
|
||||||
message,
|
message,
|
||||||
|
session_state_file,
|
||||||
provider,
|
provider,
|
||||||
model,
|
model,
|
||||||
temperature,
|
temperature,
|
||||||
@ -829,6 +835,7 @@ async fn main() -> Result<()> {
|
|||||||
final_temperature,
|
final_temperature,
|
||||||
peripheral,
|
peripheral,
|
||||||
true,
|
true,
|
||||||
|
session_state_file,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map(|_| ())
|
.map(|_| ())
|
||||||
@ -2218,6 +2225,22 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn agent_command_parses_session_state_file() {
|
||||||
|
let cli =
|
||||||
|
Cli::try_parse_from(["zeroclaw", "agent", "--session-state-file", "session.json"])
|
||||||
|
.expect("agent command with session state file should parse");
|
||||||
|
|
||||||
|
match cli.command {
|
||||||
|
Commands::Agent {
|
||||||
|
session_state_file, ..
|
||||||
|
} => {
|
||||||
|
assert_eq!(session_state_file, Some(PathBuf::from("session.json")));
|
||||||
|
}
|
||||||
|
other => panic!("expected agent command, got {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn agent_fallback_uses_config_default_temperature() {
|
fn agent_fallback_uses_config_default_temperature() {
|
||||||
// Test that when user doesn't provide --temperature,
|
// Test that when user doesn't provide --temperature,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user