merge(main): sync upstream main into feature branch
This commit is contained in:
commit
f547e4d966
2
.github/workflows/ci-queue-hygiene.yml
vendored
2
.github/workflows/ci-queue-hygiene.yml
vendored
@ -51,6 +51,8 @@ jobs:
|
||||
- name: Run queue hygiene policy
|
||||
id: hygiene
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
mkdir -p artifacts
|
||||
|
||||
9
.github/workflows/pr-auto-response.yml
vendored
9
.github/workflows/pr-auto-response.yml
vendored
@ -8,7 +8,9 @@ on:
|
||||
types: [opened, labeled, unlabeled]
|
||||
|
||||
concurrency:
|
||||
group: pr-auto-response-${{ github.event.pull_request.number || github.event.issue.number || github.run_id }}
|
||||
# Keep cancellation within the same lifecycle action to avoid `labeled`
|
||||
# events canceling an in-flight `opened` run for the same issue/PR.
|
||||
group: pr-auto-response-${{ github.event.pull_request.number || github.event.issue.number || github.run_id }}-${{ github.event.action || 'unknown' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions: {}
|
||||
@ -21,11 +23,10 @@ env:
|
||||
|
||||
jobs:
|
||||
contributor-tier-issues:
|
||||
# Only run for opened/reopened events to avoid duplicate runs with labeled-routes job
|
||||
if: >-
|
||||
(github.event_name == 'issues' &&
|
||||
(github.event.action == 'opened' || github.event.action == 'reopened' || github.event.action == 'labeled' || github.event.action == 'unlabeled')) ||
|
||||
(github.event_name == 'pull_request_target' &&
|
||||
(github.event.action == 'labeled' || github.event.action == 'unlabeled'))
|
||||
(github.event.action == 'opened' || github.event.action == 'reopened'))
|
||||
runs-on: ubuntu-22.04
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
16
.github/workflows/test-self-hosted.yml
vendored
16
.github/workflows/test-self-hosted.yml
vendored
@ -11,6 +11,18 @@ jobs:
|
||||
run: |
|
||||
echo "Runner: $(hostname)"
|
||||
echo "OS: $(uname -a)"
|
||||
echo "Docker: $(docker --version)"
|
||||
if command -v docker >/dev/null 2>&1; then
|
||||
echo "Docker: $(docker --version)"
|
||||
else
|
||||
echo "Docker: <not installed>"
|
||||
fi
|
||||
- name: Test Docker
|
||||
run: docker run --rm hello-world
|
||||
shell: bash
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if ! command -v docker >/dev/null 2>&1; then
|
||||
echo "::notice::Docker is not installed on this self-hosted runner. Skipping docker smoke test."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
docker run --rm hello-world
|
||||
|
||||
@ -6,6 +6,7 @@ resolver = "2"
|
||||
name = "zeroclaw"
|
||||
version = "0.1.7"
|
||||
edition = "2021"
|
||||
build = "build.rs"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
description = "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
||||
|
||||
80
build.rs
Normal file
80
build.rs
Normal file
@ -0,0 +1,80 @@
|
||||
use std::env;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
|
||||
fn git_short_sha(manifest_dir: &str) -> Option<String> {
|
||||
let output = Command::new("git")
|
||||
.args(["rev-parse", "--short", "HEAD"])
|
||||
.current_dir(manifest_dir)
|
||||
.output()
|
||||
.ok()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let short_sha = String::from_utf8(output.stdout).ok()?;
|
||||
let trimmed = short_sha.trim();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_git_rerun_hints(manifest_dir: &str) {
|
||||
let output = Command::new("git")
|
||||
.args(["rev-parse", "--git-dir"])
|
||||
.current_dir(manifest_dir)
|
||||
.output();
|
||||
|
||||
let Ok(output) = output else {
|
||||
return;
|
||||
};
|
||||
if !output.status.success() {
|
||||
return;
|
||||
}
|
||||
|
||||
let Ok(git_dir_raw) = String::from_utf8(output.stdout) else {
|
||||
return;
|
||||
};
|
||||
let git_dir_raw = git_dir_raw.trim();
|
||||
if git_dir_raw.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let git_dir = if PathBuf::from(git_dir_raw).is_absolute() {
|
||||
PathBuf::from(git_dir_raw)
|
||||
} else {
|
||||
PathBuf::from(manifest_dir).join(git_dir_raw)
|
||||
};
|
||||
|
||||
println!("cargo:rerun-if-changed={}", git_dir.join("HEAD").display());
|
||||
println!("cargo:rerun-if-changed={}", git_dir.join("refs").display());
|
||||
}
|
||||
|
||||
fn main() {
|
||||
println!("cargo:rerun-if-changed=build.rs");
|
||||
println!("cargo:rerun-if-env-changed=ZEROCLAW_GIT_SHORT_SHA");
|
||||
|
||||
let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
|
||||
emit_git_rerun_hints(&manifest_dir);
|
||||
|
||||
let package_version = env::var("CARGO_PKG_VERSION").unwrap_or_else(|_| "0.0.0".to_string());
|
||||
let short_sha = env::var("ZEROCLAW_GIT_SHORT_SHA")
|
||||
.ok()
|
||||
.filter(|v| !v.trim().is_empty())
|
||||
.or_else(|| git_short_sha(&manifest_dir));
|
||||
|
||||
let build_version = if let Some(sha) = short_sha.as_deref() {
|
||||
format!("{package_version} ({sha})")
|
||||
} else {
|
||||
package_version
|
||||
};
|
||||
|
||||
println!("cargo:rustc-env=ZEROCLAW_BUILD_VERSION={build_version}");
|
||||
println!(
|
||||
"cargo:rustc-env=ZEROCLAW_GIT_SHORT_SHA={}",
|
||||
short_sha.unwrap_or_default()
|
||||
);
|
||||
}
|
||||
@ -56,6 +56,8 @@ Telegram/Discord sender-scoped model routing:
|
||||
Supervised tool approvals (all non-CLI channels):
|
||||
- `/approve-request <tool-name>` — create a pending approval request
|
||||
- `/approve-confirm <request-id>` — confirm pending request (same sender + same chat/channel only)
|
||||
- `/approve-allow <request-id>` — approve the current pending runtime execution request once (no policy persistence)
|
||||
- `/approve-deny <request-id>` — deny the current pending runtime execution request
|
||||
- `/approve-pending` — list pending requests for your current sender+chat/channel scope
|
||||
- `/approve <tool-name>` — direct one-step approve + persist (`autonomy.auto_approve`, compatibility path)
|
||||
- `/unapprove <tool-name>` — revoke and remove persisted approval
|
||||
@ -76,6 +78,7 @@ Notes:
|
||||
- You can restrict who can use approval-management commands via `[autonomy].non_cli_approval_approvers`.
|
||||
- Configure natural-language approval mode via `[autonomy].non_cli_natural_language_approval_mode`.
|
||||
- `autonomy.non_cli_excluded_tools` is reloaded from `config.toml` at runtime; `/approvals` shows the currently effective list.
|
||||
- Default non-CLI exclusions include both `shell` and `process`; remove `process` from `[autonomy].non_cli_excluded_tools` only when you explicitly want background command execution in chat channels.
|
||||
- Each incoming message injects a runtime tool-availability snapshot into the system prompt, derived from the same exclusion policy used by execution.
|
||||
|
||||
## Inbound Image Marker Protocol
|
||||
|
||||
@ -191,6 +191,8 @@ Runtime in-chat commands while channel server is running:
|
||||
- Supervised tool approvals (all non-CLI channels):
|
||||
- `/approve-request <tool-name>` (create pending approval request)
|
||||
- `/approve-confirm <request-id>` (confirm pending request; same sender + same chat/channel only)
|
||||
- `/approve-allow <request-id>` (approve current pending runtime execution request once; no policy persistence)
|
||||
- `/approve-deny <request-id>` (deny current pending runtime execution request)
|
||||
- `/approve-pending` (list pending requests in current sender+chat/channel scope)
|
||||
- `/approve <tool-name>` (direct one-step grant + persist to `autonomy.auto_approve`, compatibility path)
|
||||
- `/unapprove <tool-name>` (revoke + remove from `autonomy.auto_approve`)
|
||||
|
||||
@ -712,8 +712,8 @@ When using `credential_profile`, do not also set the same header key in `args.he
|
||||
| Key | Default | Purpose |
|
||||
|---|---|---|
|
||||
| `enabled` | `false` | Enable `web_fetch` for page-to-text extraction |
|
||||
| `provider` | `fast_html2md` | Fetch/render backend: `fast_html2md`, `nanohtml2text`, `firecrawl` |
|
||||
| `api_key` | unset | API key for provider backends that require it (e.g. `firecrawl`) |
|
||||
| `provider` | `fast_html2md` | Fetch/render backend: `fast_html2md`, `nanohtml2text`, `firecrawl`, `tavily` |
|
||||
| `api_key` | unset | API key for provider backends that require it (e.g. `firecrawl`, `tavily`) |
|
||||
| `api_url` | unset | Optional API URL override (self-hosted/alternate endpoint) |
|
||||
| `allowed_domains` | `["*"]` | Domain allowlist (`"*"` allows all public domains) |
|
||||
| `blocked_domains` | `[]` | Denylist applied before allowlist |
|
||||
@ -857,6 +857,7 @@ Environment overrides:
|
||||
| `level` | `supervised` | `read_only`, `supervised`, or `full` |
|
||||
| `workspace_only` | `true` | reject absolute path inputs unless explicitly disabled |
|
||||
| `allowed_commands` | _required for shell execution_ | allowlist of executable names, explicit executable paths, or `"*"` |
|
||||
| `command_context_rules` | `[]` | per-command context-aware allow/deny rules (domain/path constraints, optional high-risk override) |
|
||||
| `forbidden_paths` | built-in protected list | explicit path denylist (system paths + sensitive dotdirs by default) |
|
||||
| `allowed_roots` | `[]` | additional roots allowed outside workspace after canonicalization |
|
||||
| `max_actions_per_hour` | `20` | per-policy action budget |
|
||||
@ -867,7 +868,7 @@ Environment overrides:
|
||||
| `allow_sensitive_file_writes` | `false` | allow `file_write`/`file_edit` on sensitive files/dirs (for example `.env`, `.aws/credentials`, private keys) |
|
||||
| `auto_approve` | `[]` | tool operations always auto-approved |
|
||||
| `always_ask` | `[]` | tool operations that always require approval |
|
||||
| `non_cli_excluded_tools` | `[]` | tools hidden from non-CLI channel tool specs |
|
||||
| `non_cli_excluded_tools` | built-in denylist (includes `shell`, `process`, `file_write`, ...) | tools hidden from non-CLI channel tool specs |
|
||||
| `non_cli_approval_approvers` | `[]` | optional allowlist for who can run non-CLI approval-management commands |
|
||||
| `non_cli_natural_language_approval_mode` | `direct` | natural-language behavior for approval-management commands (`direct`, `request_confirm`, `disabled`) |
|
||||
| `non_cli_natural_language_approval_mode_by_channel` | `{}` | per-channel override map for natural-language approval mode |
|
||||
@ -878,6 +879,10 @@ Notes:
|
||||
- Access outside the workspace requires `allowed_roots`, even when `workspace_only = false`.
|
||||
- `allowed_roots` supports absolute paths, `~/...`, and workspace-relative paths.
|
||||
- `allowed_commands` entries can be command names (for example, `"git"`), explicit executable paths (for example, `"/usr/bin/antigravity"`), or `"*"` to allow any command name/path (risk gates still apply).
|
||||
- `command_context_rules` can narrow or override `allowed_commands` for matching commands:
|
||||
- `action = "allow"` rules are restrictive when present for a command: at least one allow rule must match.
|
||||
- `action = "deny"` rules explicitly block matching contexts.
|
||||
- `allow_high_risk = true` allows a matching high-risk command to pass the hard block, but supervised mode still requires `approved=true`.
|
||||
- `file_read` blocks sensitive secret-bearing files/directories by default. Set `allow_sensitive_file_reads = true` only for controlled debugging sessions.
|
||||
- `file_write` and `file_edit` block sensitive secret-bearing files/directories by default. Set `allow_sensitive_file_writes = true` only for controlled break-glass sessions.
|
||||
- `file_read`, `file_write`, and `file_edit` refuse multiply-linked files (hard-link guard) to reduce workspace path bypass risk via hard-link escapes.
|
||||
@ -887,6 +892,10 @@ Notes:
|
||||
- One-step flow: `/approve <tool>`.
|
||||
- Two-step flow: `/approve-request <tool>` then `/approve-confirm <request-id>` (same sender + same chat/channel).
|
||||
Both paths write to `autonomy.auto_approve` and remove the tool from `autonomy.always_ask`.
|
||||
- For pending runtime execution prompts (including Telegram inline approval buttons), use:
|
||||
- `/approve-allow <request-id>` to approve only the current pending request.
|
||||
- `/approve-deny <request-id>` to reject the current pending request.
|
||||
This path does not modify `autonomy.auto_approve` or `autonomy.always_ask`.
|
||||
- `non_cli_natural_language_approval_mode` controls how strict natural-language approval intents are:
|
||||
- `direct` (default): natural-language approval grants immediately (private-chat friendly).
|
||||
- `request_confirm`: natural-language approval creates a pending request that needs explicit confirm.
|
||||
@ -899,6 +908,7 @@ Notes:
|
||||
- `telegram:alice` allows only that channel+sender pair.
|
||||
- `telegram:*` allows any sender on Telegram.
|
||||
- `*:alice` allows `alice` on any channel.
|
||||
- By default, `process` is excluded on non-CLI channels alongside `shell`. To opt in intentionally, remove `"process"` from `[autonomy].non_cli_excluded_tools` in `config.toml`.
|
||||
- Use `/unapprove <tool>` to remove persisted approval from `autonomy.auto_approve`.
|
||||
- `/approve-pending` lists pending requests for the current sender+chat/channel scope.
|
||||
- If a tool remains unavailable after approval, check `autonomy.non_cli_excluded_tools` (runtime `/approvals` shows this list). Channel runtime reloads this list from `config.toml` automatically.
|
||||
@ -908,6 +918,18 @@ Notes:
|
||||
workspace_only = false
|
||||
forbidden_paths = ["/etc", "/root", "/proc", "/sys", "~/.ssh", "~/.gnupg", "~/.aws"]
|
||||
allowed_roots = ["~/Desktop/projects", "/opt/shared-repo"]
|
||||
|
||||
[[autonomy.command_context_rules]]
|
||||
command = "curl"
|
||||
action = "allow"
|
||||
allowed_domains = ["api.github.com", "*.example.internal"]
|
||||
allow_high_risk = true
|
||||
|
||||
[[autonomy.command_context_rules]]
|
||||
command = "rm"
|
||||
action = "allow"
|
||||
allowed_path_prefixes = ["/tmp"]
|
||||
allow_high_risk = true
|
||||
```
|
||||
|
||||
## `[memory]`
|
||||
|
||||
@ -321,6 +321,13 @@ def main() -> int:
|
||||
|
||||
owner, repo = split_repo(args.repo)
|
||||
token = resolve_token(args.token)
|
||||
if args.apply and not token:
|
||||
print(
|
||||
"queue_hygiene: apply mode requires authentication token "
|
||||
"(set GH_TOKEN/GITHUB_TOKEN, pass --token, or configure gh auth).",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 2
|
||||
api = GitHubApi(args.api_url, token)
|
||||
|
||||
if args.runs_json:
|
||||
|
||||
@ -3946,6 +3946,64 @@ class CiScriptsBehaviorTest(unittest.TestCase):
|
||||
self.assertEqual(report["planned_actions"], [])
|
||||
self.assertEqual(report["policies"]["non_pr_key"], "sha")
|
||||
|
||||
def test_queue_hygiene_apply_requires_authentication_token(self) -> None:
|
||||
runs_json = self.tmp / "runs-apply-auth.json"
|
||||
output_json = self.tmp / "queue-hygiene-apply-auth.json"
|
||||
runs_json.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"workflow_runs": [
|
||||
{
|
||||
"id": 401,
|
||||
"name": "CI Run",
|
||||
"event": "push",
|
||||
"head_branch": "main",
|
||||
"head_sha": "sha-401",
|
||||
"created_at": "2026-02-27T20:00:00Z",
|
||||
},
|
||||
{
|
||||
"id": 402,
|
||||
"name": "CI Run",
|
||||
"event": "push",
|
||||
"head_branch": "main",
|
||||
"head_sha": "sha-402",
|
||||
"created_at": "2026-02-27T20:01:00Z",
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
isolated_home = self.tmp / "isolated-home"
|
||||
isolated_home.mkdir(parents=True, exist_ok=True)
|
||||
isolated_xdg = self.tmp / "isolated-xdg"
|
||||
isolated_xdg.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
env = dict(os.environ)
|
||||
env["GH_TOKEN"] = ""
|
||||
env["GITHUB_TOKEN"] = ""
|
||||
env["HOME"] = str(isolated_home)
|
||||
env["XDG_CONFIG_HOME"] = str(isolated_xdg)
|
||||
|
||||
proc = run_cmd(
|
||||
[
|
||||
"python3",
|
||||
self._script("queue_hygiene.py"),
|
||||
"--runs-json",
|
||||
str(runs_json),
|
||||
"--dedupe-workflow",
|
||||
"CI Run",
|
||||
"--apply",
|
||||
"--output-json",
|
||||
str(output_json),
|
||||
],
|
||||
env=env,
|
||||
)
|
||||
self.assertEqual(proc.returncode, 2)
|
||||
self.assertIn("requires authentication token", proc.stderr.lower())
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@ -243,6 +243,10 @@ impl Agent {
|
||||
AgentBuilder::new()
|
||||
}
|
||||
|
||||
pub fn tool_specs(&self) -> &[ToolSpec] {
|
||||
&self.tool_specs
|
||||
}
|
||||
|
||||
pub fn history(&self) -> &[ConversationMessage] {
|
||||
&self.history
|
||||
}
|
||||
|
||||
@ -1043,7 +1043,7 @@ pub(crate) async fn run_tool_call_loop_with_non_cli_approval_context(
|
||||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||
/// execute tools, and loop until the LLM produces a final text response.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn run_tool_call_loop(
|
||||
pub async fn run_tool_call_loop(
|
||||
provider: &dyn Provider,
|
||||
history: &mut Vec<ChatMessage>,
|
||||
tools_registry: &[Box<dyn Tool>],
|
||||
@ -3017,6 +3017,7 @@ pub async fn run(
|
||||
&model_name,
|
||||
config.agent.max_history_messages,
|
||||
effective_hooks,
|
||||
Some(mem.as_ref()),
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
@ -1,9 +1,29 @@
|
||||
use crate::memory::{self, Memory};
|
||||
use crate::memory::{self, decay, Memory, MemoryCategory};
|
||||
use std::fmt::Write;
|
||||
|
||||
/// Default half-life (days) for time decay in context building.
|
||||
const CONTEXT_DECAY_HALF_LIFE_DAYS: f64 = 7.0;
|
||||
|
||||
/// Score boost applied to `Core` category memories so durable facts and
|
||||
/// preferences surface even when keyword/semantic similarity is moderate.
|
||||
const CORE_CATEGORY_SCORE_BOOST: f64 = 0.3;
|
||||
|
||||
/// Maximum number of memory entries included in the context preamble.
|
||||
const CONTEXT_ENTRY_LIMIT: usize = 5;
|
||||
|
||||
/// Over-fetch factor: retrieve more candidates than the output limit so
|
||||
/// that Core boost and re-ranking can select the best subset.
|
||||
const RECALL_OVER_FETCH_FACTOR: usize = 2;
|
||||
|
||||
/// Build context preamble by searching memory for relevant entries.
|
||||
/// Entries with a hybrid score below `min_relevance_score` are dropped to
|
||||
/// prevent unrelated memories from bleeding into the conversation.
|
||||
///
|
||||
/// Core memories are exempt from time decay (evergreen).
|
||||
///
|
||||
/// `Core` category memories receive a score boost so that durable facts,
|
||||
/// preferences, and project rules are more likely to appear in context
|
||||
/// even when semantic similarity to the current message is moderate.
|
||||
pub(super) async fn build_context(
|
||||
mem: &dyn Memory,
|
||||
user_msg: &str,
|
||||
@ -12,29 +32,41 @@ pub(super) async fn build_context(
|
||||
) -> String {
|
||||
let mut context = String::new();
|
||||
|
||||
// Pull relevant memories for this message
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id).await {
|
||||
let relevant: Vec<_> = entries
|
||||
// Over-fetch so Core-boosted entries can compete fairly after re-ranking.
|
||||
let fetch_limit = CONTEXT_ENTRY_LIMIT * RECALL_OVER_FETCH_FACTOR;
|
||||
if let Ok(mut entries) = mem.recall(user_msg, fetch_limit, session_id).await {
|
||||
// Apply time decay: older non-Core memories score lower.
|
||||
decay::apply_time_decay(&mut entries, CONTEXT_DECAY_HALF_LIFE_DAYS);
|
||||
|
||||
// Apply Core category boost and filter by minimum relevance.
|
||||
let mut scored: Vec<_> = entries
|
||||
.iter()
|
||||
.filter(|e| match e.score {
|
||||
Some(score) => score >= min_relevance_score,
|
||||
None => true,
|
||||
.filter(|e| !memory::is_assistant_autosave_key(&e.key))
|
||||
.filter_map(|e| {
|
||||
let base = e.score.unwrap_or(min_relevance_score);
|
||||
let boosted = if e.category == MemoryCategory::Core {
|
||||
(base + CORE_CATEGORY_SCORE_BOOST).min(1.0)
|
||||
} else {
|
||||
base
|
||||
};
|
||||
if boosted >= min_relevance_score {
|
||||
Some((e, boosted))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !relevant.is_empty() {
|
||||
// Sort by boosted score descending, then truncate to output limit.
|
||||
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
scored.truncate(CONTEXT_ENTRY_LIMIT);
|
||||
|
||||
if !scored.is_empty() {
|
||||
context.push_str("[Memory context]\n");
|
||||
for entry in &relevant {
|
||||
if memory::is_assistant_autosave_key(&entry.key) {
|
||||
continue;
|
||||
}
|
||||
for (entry, _) in &scored {
|
||||
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
||||
}
|
||||
if context == "[Memory context]\n" {
|
||||
context.clear();
|
||||
} else {
|
||||
context.push('\n');
|
||||
}
|
||||
context.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,3 +112,135 @@ pub(super) fn build_hardware_context(
|
||||
context.push('\n');
|
||||
context
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{Memory, MemoryCategory, MemoryEntry};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
|
||||
struct MockMemory {
|
||||
entries: Arc<Vec<MemoryEntry>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for MockMemory {
|
||||
async fn store(
|
||||
&self,
|
||||
_key: &str,
|
||||
_content: &str,
|
||||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(
|
||||
&self,
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(self.entries.as_ref().clone())
|
||||
}
|
||||
|
||||
async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
_category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
Ok(self.entries.len())
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"mock-memory"
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn build_context_promotes_core_entries_with_score_boost() {
|
||||
let memory = MockMemory {
|
||||
entries: Arc::new(vec![
|
||||
MemoryEntry {
|
||||
id: "1".into(),
|
||||
key: "conv_note".into(),
|
||||
content: "small talk".into(),
|
||||
category: MemoryCategory::Conversation,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: Some(0.6),
|
||||
},
|
||||
MemoryEntry {
|
||||
id: "2".into(),
|
||||
key: "core_rule".into(),
|
||||
content: "always provide tests".into(),
|
||||
category: MemoryCategory::Core,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: Some(0.2),
|
||||
},
|
||||
MemoryEntry {
|
||||
id: "3".into(),
|
||||
key: "conv_low".into(),
|
||||
content: "irrelevant".into(),
|
||||
category: MemoryCategory::Conversation,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: Some(0.1),
|
||||
},
|
||||
]),
|
||||
};
|
||||
|
||||
let context = build_context(&memory, "test query", 0.4, None).await;
|
||||
assert!(
|
||||
context.contains("core_rule"),
|
||||
"expected core boost to include core_rule"
|
||||
);
|
||||
assert!(
|
||||
!context.contains("conv_low"),
|
||||
"low-score non-core should be filtered"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn build_context_keeps_output_limit_at_five_entries() {
|
||||
let entries = (0..8)
|
||||
.map(|idx| MemoryEntry {
|
||||
id: idx.to_string(),
|
||||
key: format!("k{idx}"),
|
||||
content: format!("v{idx}"),
|
||||
category: MemoryCategory::Conversation,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: Some(0.9 - (idx as f64 * 0.01)),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let memory = MockMemory {
|
||||
entries: Arc::new(entries),
|
||||
};
|
||||
|
||||
let context = build_context(&memory, "limit", 0.0, None).await;
|
||||
let listed = context
|
||||
.lines()
|
||||
.filter(|line| line.starts_with("- "))
|
||||
.count();
|
||||
assert_eq!(listed, 5, "context output limit should remain 5 entries");
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
use crate::memory::{Memory, MemoryCategory};
|
||||
use crate::providers::{ChatMessage, Provider};
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::Result;
|
||||
@ -12,6 +13,9 @@ const COMPACTION_MAX_SOURCE_CHARS: usize = 12_000;
|
||||
/// Max characters retained in stored compaction summary.
|
||||
const COMPACTION_MAX_SUMMARY_CHARS: usize = 2_000;
|
||||
|
||||
/// Safety cap for durable facts extracted during pre-compaction flush.
|
||||
const COMPACTION_MAX_FLUSH_FACTS: usize = 8;
|
||||
|
||||
/// Trim conversation history to prevent unbounded growth.
|
||||
/// Preserves the system prompt (first message if role=system) and the most recent messages.
|
||||
pub(super) fn trim_history(history: &mut Vec<ChatMessage>, max_history: usize) {
|
||||
@ -67,6 +71,7 @@ pub(super) async fn auto_compact_history(
|
||||
model: &str,
|
||||
max_history: usize,
|
||||
hooks: Option<&crate::hooks::HookRunner>,
|
||||
memory: Option<&dyn Memory>,
|
||||
) -> Result<bool> {
|
||||
let has_system = history.first().map_or(false, |m| m.role == "system");
|
||||
let non_system_count = if has_system {
|
||||
@ -105,6 +110,13 @@ pub(super) async fn auto_compact_history(
|
||||
};
|
||||
let transcript = build_compaction_transcript(&to_compact);
|
||||
|
||||
// ── Pre-compaction memory flush ──────────────────────────────────
|
||||
// Before discarding old messages, ask the LLM to extract durable
|
||||
// facts and store them as Core memories so they survive compaction.
|
||||
if let Some(mem) = memory {
|
||||
flush_durable_facts(provider, model, &transcript, mem).await;
|
||||
}
|
||||
|
||||
let summarizer_system = "You are a conversation compaction engine. Summarize older chat history into concise context for future turns. Preserve: user preferences, commitments, decisions, unresolved tasks, key facts. Omit: filler, repeated chit-chat, verbose tool logs. Output plain text bullet points only.";
|
||||
|
||||
let summarizer_user = format!(
|
||||
@ -137,6 +149,86 @@ pub(super) async fn auto_compact_history(
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Extract durable facts from a conversation transcript and store them as
|
||||
/// `Core` memories. Called before compaction discards old messages.
|
||||
///
|
||||
/// Best-effort: failures are logged but never block compaction.
|
||||
async fn flush_durable_facts(
|
||||
provider: &dyn Provider,
|
||||
model: &str,
|
||||
transcript: &str,
|
||||
memory: &dyn Memory,
|
||||
) {
|
||||
const FLUSH_SYSTEM: &str = "\
|
||||
You extract durable facts from a conversation that is about to be compacted. \
|
||||
Output ONLY facts worth remembering long-term — user preferences, project decisions, \
|
||||
technical constraints, commitments, or important discoveries. \
|
||||
Output one fact per line, prefixed with a short key in brackets. \
|
||||
Example:\n\
|
||||
[preferred_language] User prefers Rust over Go\n\
|
||||
[db_choice] Project uses PostgreSQL 16\n\
|
||||
If there are no durable facts, output exactly: NONE";
|
||||
|
||||
let flush_user = format!(
|
||||
"Extract durable facts from this conversation (max 8 facts):\n\n{}",
|
||||
transcript
|
||||
);
|
||||
|
||||
let response = match provider
|
||||
.chat_with_system(Some(FLUSH_SYSTEM), &flush_user, model, 0.2)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!("Pre-compaction memory flush failed: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if response.trim().eq_ignore_ascii_case("NONE") || response.trim().is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut stored = 0usize;
|
||||
for line in response.lines() {
|
||||
if stored >= COMPACTION_MAX_FLUSH_FACTS {
|
||||
break;
|
||||
}
|
||||
let line = line.trim();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
// Parse "[key] content" format
|
||||
if let Some((key, content)) = parse_fact_line(line) {
|
||||
let prefixed_key = format!("compaction_fact_{key}");
|
||||
if let Err(e) = memory
|
||||
.store(&prefixed_key, content, MemoryCategory::Core, None)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("Failed to store compaction fact '{prefixed_key}': {e}");
|
||||
} else {
|
||||
stored += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if stored > 0 {
|
||||
tracing::info!("Pre-compaction flush: stored {stored} durable fact(s) to Core memory");
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a `[key] content` line from the fact extraction output.
|
||||
fn parse_fact_line(line: &str) -> Option<(&str, &str)> {
|
||||
let line = line.trim_start_matches(|c: char| c == '-' || c.is_whitespace());
|
||||
let rest = line.strip_prefix('[')?;
|
||||
let close = rest.find(']')?;
|
||||
let key = rest[..close].trim();
|
||||
let content = rest[close + 1..].trim();
|
||||
if key.is_empty() || content.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some((key, content))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -215,10 +307,16 @@ mod tests {
|
||||
// previously cut right before the tool result (index 2).
|
||||
assert_eq!(history.len(), 22);
|
||||
|
||||
let compacted =
|
||||
auto_compact_history(&mut history, &StaticSummaryProvider, "test-model", 21, None)
|
||||
.await
|
||||
.expect("compaction should succeed");
|
||||
let compacted = auto_compact_history(
|
||||
&mut history,
|
||||
&StaticSummaryProvider,
|
||||
"test-model",
|
||||
21,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("compaction should succeed");
|
||||
|
||||
assert!(compacted);
|
||||
assert_eq!(history[0].role, "assistant");
|
||||
@ -231,4 +329,301 @@ mod tests {
|
||||
"first retained message must not be an orphan tool result"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_fact_line_extracts_key_and_content() {
|
||||
assert_eq!(
|
||||
parse_fact_line("[preferred_language] User prefers Rust over Go"),
|
||||
Some(("preferred_language", "User prefers Rust over Go"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_fact_line_handles_leading_dash() {
|
||||
assert_eq!(
|
||||
parse_fact_line("- [db_choice] Project uses PostgreSQL 16"),
|
||||
Some(("db_choice", "Project uses PostgreSQL 16"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_fact_line_rejects_empty_key_or_content() {
|
||||
assert_eq!(parse_fact_line("[] some content"), None);
|
||||
assert_eq!(parse_fact_line("[key]"), None);
|
||||
assert_eq!(parse_fact_line("[key] "), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_fact_line_rejects_malformed_input() {
|
||||
assert_eq!(parse_fact_line("no brackets here"), None);
|
||||
assert_eq!(parse_fact_line(""), None);
|
||||
assert_eq!(parse_fact_line("[unclosed bracket"), None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auto_compact_with_memory_stores_durable_facts() {
|
||||
use crate::memory::{MemoryCategory, MemoryEntry};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
struct FactCapture {
|
||||
stored: Mutex<Vec<(String, String)>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for FactCapture {
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.stored
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push((key.to_string(), content.to_string()));
|
||||
Ok(())
|
||||
}
|
||||
async fn recall(
|
||||
&self,
|
||||
_q: &str,
|
||||
_l: usize,
|
||||
_s: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
async fn get(&self, _k: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
Ok(None)
|
||||
}
|
||||
async fn list(
|
||||
&self,
|
||||
_c: Option<&MemoryCategory>,
|
||||
_s: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
async fn forget(&self, _k: &str) -> anyhow::Result<bool> {
|
||||
Ok(true)
|
||||
}
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
Ok(0)
|
||||
}
|
||||
async fn health_check(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn name(&self) -> &str {
|
||||
"fact-capture"
|
||||
}
|
||||
}
|
||||
|
||||
/// Provider that returns facts for the first call (flush) and summary for the second (compaction).
|
||||
struct FlushThenSummaryProvider {
|
||||
call_count: Mutex<usize>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for FlushThenSummaryProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut count = self.call_count.lock().unwrap();
|
||||
*count += 1;
|
||||
if *count == 1 {
|
||||
// flush_durable_facts call
|
||||
Ok("[lang] User prefers Rust\n[db] PostgreSQL 16".to_string())
|
||||
} else {
|
||||
// summarizer call
|
||||
Ok("- summarized context".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
_request: ChatRequest<'_>,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
Ok(ChatResponse {
|
||||
text: Some("- summarized context".to_string()),
|
||||
tool_calls: Vec::new(),
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
stop_reason: None,
|
||||
raw_stop_reason: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let mem = Arc::new(FactCapture {
|
||||
stored: Mutex::new(Vec::new()),
|
||||
});
|
||||
let provider = FlushThenSummaryProvider {
|
||||
call_count: Mutex::new(0),
|
||||
};
|
||||
|
||||
let mut history: Vec<ChatMessage> = Vec::new();
|
||||
for i in 0..25 {
|
||||
history.push(ChatMessage::user(format!("msg-{i}")));
|
||||
}
|
||||
|
||||
let compacted = auto_compact_history(
|
||||
&mut history,
|
||||
&provider,
|
||||
"test-model",
|
||||
21,
|
||||
None,
|
||||
Some(mem.as_ref()),
|
||||
)
|
||||
.await
|
||||
.expect("compaction should succeed");
|
||||
|
||||
assert!(compacted);
|
||||
|
||||
let stored = mem.stored.lock().unwrap();
|
||||
assert_eq!(stored.len(), 2, "should store 2 durable facts");
|
||||
assert_eq!(stored[0].0, "compaction_fact_lang");
|
||||
assert_eq!(stored[0].1, "User prefers Rust");
|
||||
assert_eq!(stored[1].0, "compaction_fact_db");
|
||||
assert_eq!(stored[1].1, "PostgreSQL 16");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auto_compact_with_memory_caps_fact_flush_at_eight_entries() {
|
||||
use crate::memory::{MemoryCategory, MemoryEntry};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
struct FactCapture {
|
||||
stored: Mutex<Vec<(String, String)>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for FactCapture {
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.stored
|
||||
.lock()
|
||||
.expect("fact capture lock")
|
||||
.push((key.to_string(), content.to_string()));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(
|
||||
&self,
|
||||
_q: &str,
|
||||
_l: usize,
|
||||
_s: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
async fn get(&self, _k: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
_c: Option<&MemoryCategory>,
|
||||
_s: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
async fn forget(&self, _k: &str) -> anyhow::Result<bool> {
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"fact-capture-cap"
|
||||
}
|
||||
}
|
||||
|
||||
struct FlushManyFactsProvider {
|
||||
call_count: Mutex<usize>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for FlushManyFactsProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut count = self.call_count.lock().expect("provider lock");
|
||||
*count += 1;
|
||||
if *count == 1 {
|
||||
let lines = (0..12)
|
||||
.map(|idx| format!("[k{idx}] fact-{idx}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
Ok(lines)
|
||||
} else {
|
||||
Ok("- summarized context".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
_request: ChatRequest<'_>,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
Ok(ChatResponse {
|
||||
text: Some("- summarized context".to_string()),
|
||||
tool_calls: Vec::new(),
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
quota_metadata: None,
|
||||
stop_reason: None,
|
||||
raw_stop_reason: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let mem = Arc::new(FactCapture {
|
||||
stored: Mutex::new(Vec::new()),
|
||||
});
|
||||
let provider = FlushManyFactsProvider {
|
||||
call_count: Mutex::new(0),
|
||||
};
|
||||
let mut history = (0..30)
|
||||
.map(|idx| ChatMessage::user(format!("msg-{idx}")))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let compacted = auto_compact_history(
|
||||
&mut history,
|
||||
&provider,
|
||||
"test-model",
|
||||
21,
|
||||
None,
|
||||
Some(mem.as_ref()),
|
||||
)
|
||||
.await
|
||||
.expect("compaction should succeed");
|
||||
assert!(compacted);
|
||||
|
||||
let stored = mem.stored.lock().expect("fact capture lock");
|
||||
assert_eq!(stored.len(), COMPACTION_MAX_FLUSH_FACTS);
|
||||
assert_eq!(stored[0].0, "compaction_fact_k0");
|
||||
assert_eq!(stored[7].0, "compaction_fact_k7");
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,7 +1,18 @@
|
||||
use crate::memory::{self, Memory};
|
||||
use crate::memory::{self, decay, Memory, MemoryCategory};
|
||||
use async_trait::async_trait;
|
||||
use std::fmt::Write;
|
||||
|
||||
/// Default half-life (days) for time decay in memory loading.
|
||||
const LOADER_DECAY_HALF_LIFE_DAYS: f64 = 7.0;
|
||||
|
||||
/// Score boost applied to `Core` category memories so durable facts and
|
||||
/// preferences surface even when keyword/semantic similarity is moderate.
|
||||
const CORE_CATEGORY_SCORE_BOOST: f64 = 0.3;
|
||||
|
||||
/// Over-fetch factor: retrieve more candidates than the output limit so
|
||||
/// that Core boost and re-ranking can select the best subset.
|
||||
const RECALL_OVER_FETCH_FACTOR: usize = 2;
|
||||
|
||||
#[async_trait]
|
||||
pub trait MemoryLoader: Send + Sync {
|
||||
async fn load_context(&self, memory: &dyn Memory, user_message: &str)
|
||||
@ -38,29 +49,47 @@ impl MemoryLoader for DefaultMemoryLoader {
|
||||
memory: &dyn Memory,
|
||||
user_message: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let entries = memory.recall(user_message, self.limit, None).await?;
|
||||
// Over-fetch so Core-boosted entries can compete fairly after re-ranking.
|
||||
let fetch_limit = self.limit * RECALL_OVER_FETCH_FACTOR;
|
||||
let mut entries = memory.recall(user_message, fetch_limit, None).await?;
|
||||
if entries.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
let mut context = String::from("[Memory context]\n");
|
||||
for entry in entries {
|
||||
if memory::is_assistant_autosave_key(&entry.key) {
|
||||
continue;
|
||||
}
|
||||
if let Some(score) = entry.score {
|
||||
if score < self.min_relevance_score {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
||||
}
|
||||
// Apply time decay: older non-Core memories score lower.
|
||||
decay::apply_time_decay(&mut entries, LOADER_DECAY_HALF_LIFE_DAYS);
|
||||
|
||||
// If all entries were below threshold, return empty
|
||||
if context == "[Memory context]\n" {
|
||||
// Apply Core category boost and filter by minimum relevance.
|
||||
let mut scored: Vec<_> = entries
|
||||
.iter()
|
||||
.filter(|e| !memory::is_assistant_autosave_key(&e.key))
|
||||
.filter_map(|e| {
|
||||
let base = e.score.unwrap_or(self.min_relevance_score);
|
||||
let boosted = if e.category == MemoryCategory::Core {
|
||||
(base + CORE_CATEGORY_SCORE_BOOST).min(1.0)
|
||||
} else {
|
||||
base
|
||||
};
|
||||
if boosted >= self.min_relevance_score {
|
||||
Some((e, boosted))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Sort by boosted score descending, then truncate to output limit.
|
||||
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
scored.truncate(self.limit);
|
||||
|
||||
if scored.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
let mut context = String::from("[Memory context]\n");
|
||||
for (entry, _) in &scored {
|
||||
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
||||
}
|
||||
context.push('\n');
|
||||
Ok(context)
|
||||
}
|
||||
@ -227,4 +256,93 @@ mod tests {
|
||||
assert!(!context.contains("assistant_resp_legacy"));
|
||||
assert!(!context.contains("fabricated detail"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn core_category_boost_promotes_low_score_core_entry() {
|
||||
let loader = DefaultMemoryLoader::new(2, 0.4);
|
||||
let memory = MockMemoryWithEntries {
|
||||
entries: Arc::new(vec![
|
||||
MemoryEntry {
|
||||
id: "1".into(),
|
||||
key: "chat_detail".into(),
|
||||
content: "talked about weather".into(),
|
||||
category: MemoryCategory::Conversation,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: Some(0.6),
|
||||
},
|
||||
MemoryEntry {
|
||||
id: "2".into(),
|
||||
key: "project_rule".into(),
|
||||
content: "always use async/await".into(),
|
||||
category: MemoryCategory::Core,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
// Below threshold without boost (0.25 < 0.4),
|
||||
// but above with +0.3 boost (0.55 >= 0.4).
|
||||
score: Some(0.25),
|
||||
},
|
||||
MemoryEntry {
|
||||
id: "3".into(),
|
||||
key: "low_conv".into(),
|
||||
content: "irrelevant chatter".into(),
|
||||
category: MemoryCategory::Conversation,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: Some(0.2),
|
||||
},
|
||||
]),
|
||||
};
|
||||
|
||||
let context = loader.load_context(&memory, "code style").await.unwrap();
|
||||
// Core entry should survive thanks to boost
|
||||
assert!(
|
||||
context.contains("project_rule"),
|
||||
"Core entry should be promoted by boost: {context}"
|
||||
);
|
||||
// Low-score Conversation entry should be filtered out
|
||||
assert!(
|
||||
!context.contains("low_conv"),
|
||||
"Low-score non-Core entry should be filtered: {context}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn core_boost_reranks_above_conversation() {
|
||||
let loader = DefaultMemoryLoader::new(1, 0.0);
|
||||
let memory = MockMemoryWithEntries {
|
||||
entries: Arc::new(vec![
|
||||
MemoryEntry {
|
||||
id: "1".into(),
|
||||
key: "conv_high".into(),
|
||||
content: "recent conversation".into(),
|
||||
category: MemoryCategory::Conversation,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
score: Some(0.6),
|
||||
},
|
||||
MemoryEntry {
|
||||
id: "2".into(),
|
||||
key: "core_pref".into(),
|
||||
content: "user prefers Rust".into(),
|
||||
category: MemoryCategory::Core,
|
||||
timestamp: "now".into(),
|
||||
session_id: None,
|
||||
// 0.5 + 0.3 boost = 0.8 > 0.6
|
||||
score: Some(0.5),
|
||||
},
|
||||
]),
|
||||
};
|
||||
|
||||
let context = loader.load_context(&memory, "language").await.unwrap();
|
||||
// With limit=1 and Core boost, Core entry (0.8) should win over Conversation (0.6)
|
||||
assert!(
|
||||
context.contains("core_pref"),
|
||||
"Boosted Core should rank above Conversation: {context}"
|
||||
);
|
||||
assert!(
|
||||
!context.contains("conv_high"),
|
||||
"Conversation should be truncated when limit=1: {context}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,4 +16,4 @@ mod tests;
|
||||
#[allow(unused_imports)]
|
||||
pub use agent::{Agent, AgentBuilder};
|
||||
#[allow(unused_imports)]
|
||||
pub use loop_::{process_message, process_message_with_session, run};
|
||||
pub use loop_::{process_message, process_message_with_session, run, run_tool_call_loop};
|
||||
|
||||
@ -744,6 +744,20 @@ async fn native_dispatcher_sends_tool_specs() {
|
||||
assert!(dispatcher.should_send_tool_specs());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agent_tool_specs_accessor_exposes_registered_tools() {
|
||||
let provider = Box::new(ScriptedProvider::new(vec![text_response("ok")]));
|
||||
let agent = build_agent_with(
|
||||
provider,
|
||||
vec![Box::new(EchoTool)],
|
||||
Box::new(NativeToolDispatcher),
|
||||
);
|
||||
|
||||
let specs = agent.tool_specs();
|
||||
assert_eq!(specs.len(), 1);
|
||||
assert_eq!(specs[0].name, "echo");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn xml_dispatcher_does_not_send_tool_specs() {
|
||||
let dispatcher = XmlToolDispatcher;
|
||||
|
||||
@ -251,6 +251,14 @@ struct ChannelRuntimeDefaults {
|
||||
api_url: Option<String>,
|
||||
reliability: crate::config::ReliabilityConfig,
|
||||
cost: crate::config::CostConfig,
|
||||
auto_save_memory: bool,
|
||||
max_tool_iterations: usize,
|
||||
min_relevance_score: f64,
|
||||
message_timeout_secs: u64,
|
||||
interrupt_on_new_message: bool,
|
||||
multimodal: crate::config::MultimodalConfig,
|
||||
query_classification: crate::config::QueryClassificationConfig,
|
||||
model_routes: Vec<crate::config::ModelRouteConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@ -1048,6 +1056,14 @@ fn resolved_default_model(config: &Config) -> String {
|
||||
}
|
||||
|
||||
fn runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults {
|
||||
let message_timeout_secs =
|
||||
effective_channel_message_timeout_secs(config.channels_config.message_timeout_secs);
|
||||
let interrupt_on_new_message = config
|
||||
.channels_config
|
||||
.telegram
|
||||
.as_ref()
|
||||
.is_some_and(|tg| tg.interrupt_on_new_message);
|
||||
|
||||
ChannelRuntimeDefaults {
|
||||
default_provider: resolved_default_provider(config),
|
||||
model: resolved_default_model(config),
|
||||
@ -1056,6 +1072,14 @@ fn runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults {
|
||||
api_url: config.api_url.clone(),
|
||||
reliability: config.reliability.clone(),
|
||||
cost: config.cost.clone(),
|
||||
auto_save_memory: config.memory.auto_save,
|
||||
max_tool_iterations: config.agent.max_tool_iterations,
|
||||
min_relevance_score: config.memory.min_relevance_score,
|
||||
message_timeout_secs,
|
||||
interrupt_on_new_message,
|
||||
multimodal: config.multimodal.clone(),
|
||||
query_classification: config.query_classification.clone(),
|
||||
model_routes: config.model_routes.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -1102,6 +1126,14 @@ fn runtime_defaults_snapshot(ctx: &ChannelRuntimeContext) -> ChannelRuntimeDefau
|
||||
api_url: ctx.api_url.clone(),
|
||||
reliability: (*ctx.reliability).clone(),
|
||||
cost: crate::config::CostConfig::default(),
|
||||
auto_save_memory: ctx.auto_save_memory,
|
||||
max_tool_iterations: ctx.max_tool_iterations,
|
||||
min_relevance_score: ctx.min_relevance_score,
|
||||
message_timeout_secs: ctx.message_timeout_secs,
|
||||
interrupt_on_new_message: ctx.interrupt_on_new_message,
|
||||
multimodal: ctx.multimodal.clone(),
|
||||
query_classification: ctx.query_classification.clone(),
|
||||
model_routes: ctx.model_routes.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -1722,14 +1754,14 @@ fn get_route_selection(ctx: &ChannelRuntimeContext, sender_key: &str) -> Channel
|
||||
/// Classify a user message and return the appropriate route selection with logging.
|
||||
/// Returns None if classification is disabled or no rules match.
|
||||
fn classify_message_route(
|
||||
ctx: &ChannelRuntimeContext,
|
||||
query_classification: &crate::config::QueryClassificationConfig,
|
||||
model_routes: &[crate::config::ModelRouteConfig],
|
||||
message: &str,
|
||||
) -> Option<ChannelRouteSelection> {
|
||||
let decision =
|
||||
crate::agent::classifier::classify_with_decision(&ctx.query_classification, message)?;
|
||||
let decision = crate::agent::classifier::classify_with_decision(query_classification, message)?;
|
||||
|
||||
// Find the matching model route
|
||||
let route = ctx.model_routes.iter().find(|r| r.hint == decision.hint)?;
|
||||
let route = model_routes.iter().find(|r| r.hint == decision.hint)?;
|
||||
|
||||
tracing::info!(
|
||||
target: "query_classification",
|
||||
@ -1956,9 +1988,9 @@ async fn get_or_create_provider(
|
||||
|
||||
let provider = create_resilient_provider_nonblocking(
|
||||
provider_name,
|
||||
ctx.api_key.clone(),
|
||||
defaults.api_key.clone(),
|
||||
api_url.map(ToString::to_string),
|
||||
ctx.reliability.as_ref().clone(),
|
||||
defaults.reliability.clone(),
|
||||
ctx.provider_runtime_options.clone(),
|
||||
)
|
||||
.await?;
|
||||
@ -2245,6 +2277,27 @@ async fn handle_runtime_command_if_needed(
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle `/approve-allow <request-id>` for pending runtime execution prompts.
|
||||
///
|
||||
/// This path confirms only the current pending request and intentionally does
|
||||
/// not persist approval policy changes for normal tools.
|
||||
async fn handle_pending_runtime_approval_side_effects(
|
||||
ctx: &ChannelRuntimeContext,
|
||||
request_id: &str,
|
||||
tool_name: &str,
|
||||
) -> String {
|
||||
if tool_name == APPROVAL_ALL_TOOLS_ONCE_TOKEN {
|
||||
let remaining = ctx.approval_manager.grant_non_cli_allow_all_once();
|
||||
format!(
|
||||
"Approved one-time all-tools bypass from request `{request_id}`.\nApplies to the next non-CLI agent tool-execution turn only.\nThis bypass is runtime-only and does not persist to config.\nChannel exclusions from `autonomy.non_cli_excluded_tools` still apply.\nQueued one-time all-tools bypass tokens: `{remaining}`."
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"Approved pending execution request `{request_id}` for `{tool_name}`.\nThis approval applies only to the current pending request and does not change persisted approval policy.\nTo persist approval for future requests, use `/approve {tool_name}` or the `/approve-request` + `/approve-confirm` flow."
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
let response = match command {
|
||||
ChannelRuntimeCommand::ShowProviders => build_providers_help_response(¤t),
|
||||
ChannelRuntimeCommand::SetProvider(raw_provider) => {
|
||||
@ -2708,11 +2761,10 @@ async fn handle_runtime_command_if_needed(
|
||||
Ok(req) => {
|
||||
ctx.approval_manager
|
||||
.record_non_cli_pending_resolution(&request_id, ApprovalResponse::Yes);
|
||||
let approval_message = handle_confirm_tool_approval_side_effects(
|
||||
let approval_message = handle_pending_runtime_approval_side_effects(
|
||||
ctx,
|
||||
&request_id,
|
||||
&req.tool_name,
|
||||
source_channel,
|
||||
)
|
||||
.await;
|
||||
runtime_trace::record_event(
|
||||
@ -3426,10 +3478,14 @@ or tune thresholds in config.",
|
||||
}
|
||||
}
|
||||
}
|
||||
// Try classification first, fall back to sender/default route
|
||||
let route = classify_message_route(ctx.as_ref(), &msg.content)
|
||||
.unwrap_or_else(|| get_route_selection(ctx.as_ref(), &history_key));
|
||||
let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref());
|
||||
// Try classification first, fall back to sender/default route.
|
||||
let route = classify_message_route(
|
||||
&runtime_defaults.query_classification,
|
||||
&runtime_defaults.model_routes,
|
||||
&msg.content,
|
||||
)
|
||||
.unwrap_or_else(|| get_route_selection(ctx.as_ref(), &history_key));
|
||||
let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await {
|
||||
Ok(provider) => provider,
|
||||
Err(err) => {
|
||||
@ -3449,7 +3505,9 @@ or tune thresholds in config.",
|
||||
return;
|
||||
}
|
||||
};
|
||||
if ctx.auto_save_memory && msg.content.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS {
|
||||
if runtime_defaults.auto_save_memory
|
||||
&& msg.content.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
{
|
||||
let autosave_key = conversation_memory_key(&msg);
|
||||
let _ = ctx
|
||||
.memory
|
||||
@ -3512,7 +3570,7 @@ or tune thresholds in config.",
|
||||
let memory_context = build_memory_context(
|
||||
ctx.memory.as_ref(),
|
||||
&msg.content,
|
||||
ctx.min_relevance_score,
|
||||
runtime_defaults.min_relevance_score,
|
||||
Some(&history_key),
|
||||
)
|
||||
.await;
|
||||
@ -3666,8 +3724,10 @@ or tune thresholds in config.",
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
let timeout_budget_secs =
|
||||
channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
|
||||
let timeout_budget_secs = channel_message_timeout_budget_secs(
|
||||
runtime_defaults.message_timeout_secs,
|
||||
runtime_defaults.max_tool_iterations,
|
||||
);
|
||||
let cost_enforcement_context = crate::agent::loop_::create_cost_enforcement_context(
|
||||
&runtime_defaults.cost,
|
||||
ctx.workspace_dir.as_path(),
|
||||
@ -3731,8 +3791,8 @@ or tune thresholds in config.",
|
||||
Some(ctx.approval_manager.as_ref()),
|
||||
msg.channel.as_str(),
|
||||
non_cli_approval_context,
|
||||
&ctx.multimodal,
|
||||
ctx.max_tool_iterations,
|
||||
&runtime_defaults.multimodal,
|
||||
runtime_defaults.max_tool_iterations,
|
||||
Some(cancellation_token.clone()),
|
||||
delta_tx,
|
||||
ctx.hooks.as_deref(),
|
||||
@ -3911,7 +3971,7 @@ or tune thresholds in config.",
|
||||
&history_key,
|
||||
ChatMessage::assistant(&history_response),
|
||||
);
|
||||
if ctx.auto_save_memory
|
||||
if runtime_defaults.auto_save_memory
|
||||
&& delivered_response.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
{
|
||||
let assistant_key = assistant_memory_key(&msg);
|
||||
@ -4024,7 +4084,7 @@ or tune thresholds in config.",
|
||||
}
|
||||
}
|
||||
} else if is_tool_iteration_limit_error(&e) {
|
||||
let limit = ctx.max_tool_iterations.max(1);
|
||||
let limit = runtime_defaults.max_tool_iterations.max(1);
|
||||
let pause_text = format!(
|
||||
"⚠️ Reached tool-iteration limit ({limit}) for this turn. Context and progress were preserved. Reply \"continue\" to resume, or increase `agent.max_tool_iterations`."
|
||||
);
|
||||
@ -4120,7 +4180,9 @@ or tune thresholds in config.",
|
||||
LlmExecutionResult::Completed(Err(_)) => {
|
||||
let timeout_msg = format!(
|
||||
"LLM response timed out after {}s (base={}s, max_tool_iterations={})",
|
||||
timeout_budget_secs, ctx.message_timeout_secs, ctx.max_tool_iterations
|
||||
timeout_budget_secs,
|
||||
runtime_defaults.message_timeout_secs,
|
||||
runtime_defaults.max_tool_iterations
|
||||
);
|
||||
runtime_trace::record_event(
|
||||
"channel_message_timeout",
|
||||
@ -4201,8 +4263,9 @@ async fn run_message_dispatch_loop(
|
||||
let task_sequence = Arc::clone(&task_sequence);
|
||||
workers.spawn(async move {
|
||||
let _permit = permit;
|
||||
let runtime_defaults = runtime_defaults_snapshot(worker_ctx.as_ref());
|
||||
let interrupt_enabled =
|
||||
worker_ctx.interrupt_on_new_message && msg.channel == "telegram";
|
||||
runtime_defaults.interrupt_on_new_message && msg.channel == "telegram";
|
||||
let sender_scope_key = interruption_scope_key(&msg);
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let completion = Arc::new(InFlightTaskCompletion::new());
|
||||
@ -5117,8 +5180,14 @@ fn collect_configured_channels(
|
||||
|
||||
#[cfg(not(feature = "channel-lark"))]
|
||||
if config.channels_config.lark.is_some() || config.channels_config.feishu.is_some() {
|
||||
let executable = std::env::current_exe()
|
||||
.map(|path| path.display().to_string())
|
||||
.unwrap_or_else(|_| "<unknown>".to_string());
|
||||
tracing::warn!(
|
||||
"Lark/Feishu channel is configured but this build was compiled without `channel-lark`; skipping Lark/Feishu health check."
|
||||
"Lark/Feishu channel is configured but this binary was compiled without `channel-lark`; skipping Lark/Feishu startup. \
|
||||
binary={executable}. \
|
||||
If you built from source, run the built artifact directly (for example `./target/release/zeroclaw daemon`) \
|
||||
or run `cargo run --features channel-lark -- daemon`."
|
||||
);
|
||||
}
|
||||
|
||||
@ -6721,6 +6790,36 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
}
|
||||
}
|
||||
|
||||
struct MockProcessTool;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Tool for MockProcessTool {
|
||||
fn name(&self) -> &str {
|
||||
"process"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Mock process tool for runtime visibility tests"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": { "type": "string" }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: String::new(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_runtime_tool_visibility_prompt_respects_excluded_snapshot() {
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockPriceTool), Box::new(MockEchoTool)];
|
||||
@ -6739,6 +6838,23 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert!(!native.contains("## Tool Use Protocol"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_runtime_tool_visibility_prompt_excludes_process_with_default_policy() {
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockProcessTool), Box::new(MockEchoTool)];
|
||||
let excluded = crate::config::AutonomyConfig::default().non_cli_excluded_tools;
|
||||
|
||||
assert!(
|
||||
excluded.contains(&"process".to_string()),
|
||||
"default non-CLI exclusion list must include process"
|
||||
);
|
||||
|
||||
let prompt = build_runtime_tool_visibility_prompt(&tools, &excluded, false);
|
||||
assert!(prompt.contains("Excluded by runtime policy:"));
|
||||
assert!(prompt.contains("process"));
|
||||
assert!(!prompt.contains("**process**:"));
|
||||
assert!(prompt.contains("`mock_echo`"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_injects_runtime_tool_visibility_prompt() {
|
||||
let channel_impl = Arc::new(RecordingChannel::default());
|
||||
@ -7608,7 +7724,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
|
||||
let sent = channel_impl.sent_messages.lock().await;
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert!(sent[0].contains("Approved supervised execution for `mock_price`"));
|
||||
assert!(sent[0].contains("Approved pending execution request"));
|
||||
assert!(sent[0].contains("mock_price"));
|
||||
drop(sent);
|
||||
|
||||
@ -7820,7 +7936,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
);
|
||||
assert!(
|
||||
sent.iter()
|
||||
.any(|entry| entry.contains("Approved supervised execution for `mock_price`")),
|
||||
.any(|entry| entry.contains("Approved pending execution request")),
|
||||
"channel should acknowledge explicit approval command"
|
||||
);
|
||||
assert!(
|
||||
@ -7832,6 +7948,17 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
sent.iter().all(|entry| !entry.contains("Denied by user.")),
|
||||
"always_ask tool should not be silently denied once non-cli approval prompt path is wired"
|
||||
);
|
||||
assert!(
|
||||
runtime_ctx.approval_manager.needs_approval("mock_price"),
|
||||
"/approve-allow should not downgrade always_ask policy for future requests"
|
||||
);
|
||||
assert!(
|
||||
runtime_ctx
|
||||
.approval_manager
|
||||
.always_ask_tools()
|
||||
.contains("mock_price"),
|
||||
"always_ask runtime policy should remain intact after one-shot approval"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@ -8245,7 +8372,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
|
||||
let sent = channel_impl.sent_messages.lock().await;
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert!(sent[0].contains("Approved supervised execution for `mock_price`"));
|
||||
assert!(sent[0].contains("Approved pending execution request"));
|
||||
assert!(sent[0].contains("mock_price"));
|
||||
drop(sent);
|
||||
|
||||
@ -8256,6 +8383,14 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager.take_non_cli_pending_resolution(&request_id),
|
||||
Some(ApprovalResponse::Yes)
|
||||
);
|
||||
assert!(
|
||||
approval_manager.needs_approval("mock_price"),
|
||||
"/approve-allow should not persistently auto-approve tools"
|
||||
);
|
||||
assert!(
|
||||
approval_manager.always_ask_tools().contains("mock_price"),
|
||||
"always_ask tool should remain in always_ask after one-shot approval"
|
||||
);
|
||||
assert_eq!(provider_impl.call_count.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
@ -9411,6 +9546,14 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
api_url: None,
|
||||
reliability: crate::config::ReliabilityConfig::default(),
|
||||
cost: crate::config::CostConfig::default(),
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
model_routes: Vec::new(),
|
||||
},
|
||||
perplexity_filter: crate::config::PerplexityFilterConfig::default(),
|
||||
outbound_leak_guard: crate::config::OutboundLeakGuardConfig::default(),
|
||||
@ -9593,6 +9736,13 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
cfg.default_provider = Some("ollama".to_string());
|
||||
cfg.default_model = Some("llama3.2".to_string());
|
||||
cfg.api_key = Some("http://127.0.0.1:11434".to_string());
|
||||
cfg.memory.auto_save = false;
|
||||
cfg.memory.min_relevance_score = 0.15;
|
||||
cfg.agent.max_tool_iterations = 5;
|
||||
cfg.channels_config.message_timeout_secs = 45;
|
||||
cfg.multimodal.allow_remote_fetch = false;
|
||||
cfg.query_classification.enabled = false;
|
||||
cfg.model_routes = vec![];
|
||||
cfg.autonomy.non_cli_natural_language_approval_mode =
|
||||
crate::config::NonCliNaturalLanguageApprovalMode::Direct;
|
||||
cfg.autonomy.non_cli_excluded_tools = vec!["shell".to_string()];
|
||||
@ -9659,6 +9809,14 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
runtime_outbound_leak_guard_snapshot(runtime_ctx.as_ref()).action,
|
||||
crate::config::OutboundLeakGuardAction::Redact
|
||||
);
|
||||
let defaults = runtime_defaults_snapshot(runtime_ctx.as_ref());
|
||||
assert!(!defaults.auto_save_memory);
|
||||
assert_eq!(defaults.min_relevance_score, 0.15);
|
||||
assert_eq!(defaults.max_tool_iterations, 5);
|
||||
assert_eq!(defaults.message_timeout_secs, 45);
|
||||
assert!(!defaults.multimodal.allow_remote_fetch);
|
||||
assert!(!defaults.query_classification.enabled);
|
||||
assert!(defaults.model_routes.is_empty());
|
||||
|
||||
cfg.autonomy.non_cli_natural_language_approval_mode =
|
||||
crate::config::NonCliNaturalLanguageApprovalMode::Disabled;
|
||||
@ -9674,6 +9832,28 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
cfg.security.perplexity_filter.perplexity_threshold = 12.5;
|
||||
cfg.security.outbound_leak_guard.action = crate::config::OutboundLeakGuardAction::Block;
|
||||
cfg.security.outbound_leak_guard.sensitivity = 0.92;
|
||||
cfg.memory.auto_save = true;
|
||||
cfg.memory.min_relevance_score = 0.65;
|
||||
cfg.agent.max_tool_iterations = 11;
|
||||
cfg.channels_config.message_timeout_secs = 120;
|
||||
cfg.multimodal.allow_remote_fetch = true;
|
||||
cfg.query_classification.enabled = true;
|
||||
cfg.query_classification.rules = vec![crate::config::ClassificationRule {
|
||||
hint: "reasoning".to_string(),
|
||||
keywords: vec!["analyze".to_string()],
|
||||
patterns: vec!["deep".to_string()],
|
||||
min_length: None,
|
||||
max_length: None,
|
||||
priority: 10,
|
||||
}];
|
||||
cfg.model_routes = vec![crate::config::ModelRouteConfig {
|
||||
hint: "reasoning".to_string(),
|
||||
provider: "openrouter".to_string(),
|
||||
model: "openai/gpt-5.2".to_string(),
|
||||
max_tokens: Some(512),
|
||||
api_key: None,
|
||||
transport: None,
|
||||
}];
|
||||
cfg.save().await.expect("save updated config");
|
||||
|
||||
maybe_apply_runtime_config_update(runtime_ctx.as_ref())
|
||||
@ -9705,6 +9885,15 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
crate::config::OutboundLeakGuardAction::Block
|
||||
);
|
||||
assert_eq!(leak_guard_cfg.sensitivity, 0.92);
|
||||
let defaults = runtime_defaults_snapshot(runtime_ctx.as_ref());
|
||||
assert!(defaults.auto_save_memory);
|
||||
assert_eq!(defaults.min_relevance_score, 0.65);
|
||||
assert_eq!(defaults.max_tool_iterations, 11);
|
||||
assert_eq!(defaults.message_timeout_secs, 120);
|
||||
assert!(defaults.multimodal.allow_remote_fetch);
|
||||
assert!(defaults.query_classification.enabled);
|
||||
assert_eq!(defaults.query_classification.rules.len(), 1);
|
||||
assert_eq!(defaults.model_routes.len(), 1);
|
||||
|
||||
let mut store = runtime_config_store()
|
||||
.lock()
|
||||
|
||||
@ -589,6 +589,40 @@ impl TelegramChannel {
|
||||
body
|
||||
}
|
||||
|
||||
fn build_approval_prompt_body(
|
||||
chat_id: &str,
|
||||
thread_id: Option<&str>,
|
||||
request_id: &str,
|
||||
tool_name: &str,
|
||||
args_preview: &str,
|
||||
) -> serde_json::Value {
|
||||
let mut body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"text": format!(
|
||||
"Approval required for tool `{tool_name}`.\nRequest ID: `{request_id}`\nArgs: `{args_preview}`",
|
||||
),
|
||||
"parse_mode": "Markdown",
|
||||
"reply_markup": {
|
||||
"inline_keyboard": [[
|
||||
{
|
||||
"text": "Approve",
|
||||
"callback_data": format!("{TELEGRAM_APPROVAL_CALLBACK_APPROVE_PREFIX}{request_id}")
|
||||
},
|
||||
{
|
||||
"text": "Deny",
|
||||
"callback_data": format!("{TELEGRAM_APPROVAL_CALLBACK_DENY_PREFIX}{request_id}")
|
||||
}
|
||||
]]
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(thread_id) = thread_id {
|
||||
body["message_thread_id"] = serde_json::Value::String(thread_id.to_string());
|
||||
}
|
||||
|
||||
body
|
||||
}
|
||||
|
||||
fn extract_update_message_ack_target(
|
||||
update: &serde_json::Value,
|
||||
) -> Option<(String, i64, AckReactionContextChatType, Option<String>)> {
|
||||
@ -3153,28 +3187,13 @@ impl Channel for TelegramChannel {
|
||||
raw_args
|
||||
};
|
||||
|
||||
let mut body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"text": format!(
|
||||
"Approval required for tool `{tool_name}`.\nRequest ID: `{request_id}`\nArgs: `{args_preview}`",
|
||||
),
|
||||
"reply_markup": {
|
||||
"inline_keyboard": [[
|
||||
{
|
||||
"text": "Approve",
|
||||
"callback_data": format!("{TELEGRAM_APPROVAL_CALLBACK_APPROVE_PREFIX}{request_id}")
|
||||
},
|
||||
{
|
||||
"text": "Deny",
|
||||
"callback_data": format!("{TELEGRAM_APPROVAL_CALLBACK_DENY_PREFIX}{request_id}")
|
||||
}
|
||||
]]
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(thread_id) = thread_id {
|
||||
body["message_thread_id"] = serde_json::Value::String(thread_id);
|
||||
}
|
||||
let body = Self::build_approval_prompt_body(
|
||||
&chat_id,
|
||||
thread_id.as_deref(),
|
||||
request_id,
|
||||
tool_name,
|
||||
&args_preview,
|
||||
);
|
||||
|
||||
let response = self
|
||||
.http_client()
|
||||
@ -3654,6 +3673,24 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn approval_prompt_includes_markdown_parse_mode() {
|
||||
let body = TelegramChannel::build_approval_prompt_body(
|
||||
"12345",
|
||||
Some("67890"),
|
||||
"apr-1234",
|
||||
"shell",
|
||||
"{\"command\":\"echo hello\"}",
|
||||
);
|
||||
|
||||
assert_eq!(body["parse_mode"], "Markdown");
|
||||
assert_eq!(body["chat_id"], "12345");
|
||||
assert_eq!(body["message_thread_id"], "67890");
|
||||
assert!(body["text"]
|
||||
.as_str()
|
||||
.is_some_and(|text| text.contains("`shell`")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_telegram_error_redacts_bot_token_in_url() {
|
||||
let input =
|
||||
|
||||
@ -10,12 +10,13 @@ pub use schema::{
|
||||
AckReactionRuleConfig, AckReactionStrategy, AgentConfig, AgentSessionBackend,
|
||||
AgentSessionConfig, AgentSessionStrategy, AgentsIpcConfig, AuditConfig, AutonomyConfig,
|
||||
BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
|
||||
ClassificationRule, ComposioConfig, Config, CoordinationConfig, CostConfig, CronConfig,
|
||||
DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EconomicConfig, EconomicTokenPricing,
|
||||
EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig, GroupReplyConfig,
|
||||
GroupReplyMode, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
|
||||
HttpRequestConfig, HttpRequestCredentialProfile, IMessageConfig, IdentityConfig, LarkConfig,
|
||||
MatrixConfig, MemoryConfig, ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig,
|
||||
ClassificationRule, CommandContextRuleAction, CommandContextRuleConfig, ComposioConfig, Config,
|
||||
CoordinationConfig, CostConfig, CronConfig, DelegateAgentConfig, DiscordConfig,
|
||||
DockerRuntimeConfig, EconomicConfig, EconomicTokenPricing, EmbeddingRouteConfig, EstopConfig,
|
||||
FeishuConfig, GatewayConfig, GroupReplyConfig, GroupReplyMode, HardwareConfig,
|
||||
HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
|
||||
HttpRequestCredentialProfile, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig,
|
||||
MemoryConfig, ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig,
|
||||
NonCliNaturalLanguageApprovalMode, ObservabilityConfig, OtpChallengeDelivery, OtpConfig,
|
||||
OtpMethod, OutboundLeakGuardAction, OutboundLeakGuardConfig, PeripheralBoardConfig,
|
||||
PeripheralsConfig, PerplexityFilterConfig, PluginEntryConfig, PluginsConfig, ProgressMode,
|
||||
|
||||
@ -3135,6 +3135,67 @@ pub enum NonCliNaturalLanguageApprovalMode {
|
||||
Direct,
|
||||
}
|
||||
|
||||
/// Action to apply when a command-context rule matches.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CommandContextRuleAction {
|
||||
/// Matching context is explicitly allowed.
|
||||
#[default]
|
||||
Allow,
|
||||
/// Matching context is explicitly denied.
|
||||
Deny,
|
||||
}
|
||||
|
||||
/// Context-aware allow/deny rule for shell commands.
|
||||
///
|
||||
/// Rules are evaluated per command segment. Command matching accepts command
|
||||
/// names (`curl`), explicit paths (`/usr/bin/curl`), and wildcard (`*`).
|
||||
///
|
||||
/// Matching semantics:
|
||||
/// - `action = "deny"`: if all constraints match, the segment is rejected.
|
||||
/// - `action = "allow"`: if at least one allow rule exists for a command,
|
||||
/// segments must match at least one of those allow rules.
|
||||
///
|
||||
/// Constraints are optional:
|
||||
/// - `allowed_domains`: require URL arguments to match these hosts/patterns.
|
||||
/// - `allowed_path_prefixes`: require path-like arguments to stay under these prefixes.
|
||||
/// - `denied_path_prefixes`: for deny rules, match when any path-like argument
|
||||
/// is under these prefixes; for allow rules, require path arguments not to hit
|
||||
/// these prefixes.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
|
||||
pub struct CommandContextRuleConfig {
|
||||
/// Command name/path pattern (`git`, `/usr/bin/curl`, or `*`).
|
||||
pub command: String,
|
||||
|
||||
/// Rule action (`allow` | `deny`). Defaults to `allow`.
|
||||
#[serde(default)]
|
||||
pub action: CommandContextRuleAction,
|
||||
|
||||
/// Allowed host patterns for URL arguments.
|
||||
///
|
||||
/// Supports exact hosts (`api.example.com`) and wildcard suffixes (`*.example.com`).
|
||||
#[serde(default)]
|
||||
pub allowed_domains: Vec<String>,
|
||||
|
||||
/// Allowed path prefixes for path-like arguments.
|
||||
///
|
||||
/// Prefixes may be absolute, `~/...`, or workspace-relative.
|
||||
#[serde(default)]
|
||||
pub allowed_path_prefixes: Vec<String>,
|
||||
|
||||
/// Denied path prefixes for path-like arguments.
|
||||
///
|
||||
/// Prefixes may be absolute, `~/...`, or workspace-relative.
|
||||
#[serde(default)]
|
||||
pub denied_path_prefixes: Vec<String>,
|
||||
|
||||
/// Permit high-risk commands when this allow rule matches.
|
||||
///
|
||||
/// The command still requires explicit `approved=true` in supervised mode.
|
||||
#[serde(default)]
|
||||
pub allow_high_risk: bool,
|
||||
}
|
||||
|
||||
/// Autonomy and security policy configuration (`[autonomy]` section).
|
||||
///
|
||||
/// Controls what the agent is allowed to do: shell commands, filesystem access,
|
||||
@ -3148,6 +3209,13 @@ pub struct AutonomyConfig {
|
||||
pub workspace_only: bool,
|
||||
/// Allowlist of executable names permitted for shell execution.
|
||||
pub allowed_commands: Vec<String>,
|
||||
|
||||
/// Context-aware shell command allow/deny rules.
|
||||
///
|
||||
/// These rules are evaluated per command segment and can narrow or override
|
||||
/// global `allowed_commands` behavior for matching commands.
|
||||
#[serde(default)]
|
||||
pub command_context_rules: Vec<CommandContextRuleConfig>,
|
||||
/// Explicit path denylist. Default includes system-critical paths and sensitive dotdirs.
|
||||
pub forbidden_paths: Vec<String>,
|
||||
/// Maximum actions allowed per hour per policy. Default: `100`.
|
||||
@ -3252,6 +3320,7 @@ fn default_always_ask() -> Vec<String> {
|
||||
fn default_non_cli_excluded_tools() -> Vec<String> {
|
||||
[
|
||||
"shell",
|
||||
"process",
|
||||
"file_write",
|
||||
"file_edit",
|
||||
"git_operations",
|
||||
@ -3310,6 +3379,7 @@ impl Default for AutonomyConfig {
|
||||
"tail".into(),
|
||||
"date".into(),
|
||||
],
|
||||
command_context_rules: Vec::new(),
|
||||
forbidden_paths: vec![
|
||||
"/etc".into(),
|
||||
"/root".into(),
|
||||
@ -7032,6 +7102,75 @@ fn validate_mcp_config(config: &McpConfig) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn legacy_feishu_table(raw_toml: &toml::Value) -> Option<&toml::map::Map<String, toml::Value>> {
|
||||
raw_toml
|
||||
.get("channels_config")?
|
||||
.as_table()?
|
||||
.get("feishu")?
|
||||
.as_table()
|
||||
}
|
||||
|
||||
fn extract_legacy_feishu_mention_only(raw_toml: &toml::Value) -> Option<bool> {
|
||||
legacy_feishu_table(raw_toml)?
|
||||
.get("mention_only")
|
||||
.and_then(toml::Value::as_bool)
|
||||
}
|
||||
|
||||
fn has_legacy_feishu_mention_only(raw_toml: &toml::Value) -> bool {
|
||||
legacy_feishu_table(raw_toml)
|
||||
.and_then(|table| table.get("mention_only"))
|
||||
.is_some()
|
||||
}
|
||||
|
||||
fn has_legacy_feishu_use_feishu(raw_toml: &toml::Value) -> bool {
|
||||
legacy_feishu_table(raw_toml)
|
||||
.and_then(|table| table.get("use_feishu"))
|
||||
.is_some()
|
||||
}
|
||||
|
||||
fn apply_feishu_legacy_compat(
|
||||
config: &mut Config,
|
||||
legacy_feishu_mention_only: Option<bool>,
|
||||
legacy_feishu_use_feishu_present: bool,
|
||||
saw_legacy_feishu_mention_only_path: bool,
|
||||
saw_legacy_feishu_use_feishu_path: bool,
|
||||
) {
|
||||
// Backward compatibility: users sometimes migrate config snippets from
|
||||
// [channels_config.lark] to [channels_config.feishu] and keep old keys.
|
||||
if let Some(feishu_cfg) = config.channels_config.feishu.as_mut() {
|
||||
if let Some(legacy_mention_only) = legacy_feishu_mention_only {
|
||||
if feishu_cfg.group_reply.is_none() {
|
||||
let mapped_mode = if legacy_mention_only {
|
||||
GroupReplyMode::MentionOnly
|
||||
} else {
|
||||
GroupReplyMode::AllMessages
|
||||
};
|
||||
feishu_cfg.group_reply = Some(GroupReplyConfig {
|
||||
mode: Some(mapped_mode),
|
||||
allowed_sender_ids: Vec::new(),
|
||||
});
|
||||
tracing::warn!(
|
||||
"Legacy key [channels_config.feishu].mention_only is deprecated; mapped to [channels_config.feishu.group_reply].mode."
|
||||
);
|
||||
} else if saw_legacy_feishu_mention_only_path {
|
||||
tracing::warn!(
|
||||
"Legacy key [channels_config.feishu].mention_only is ignored because [channels_config.feishu.group_reply] is already set."
|
||||
);
|
||||
}
|
||||
} else if saw_legacy_feishu_mention_only_path {
|
||||
tracing::warn!(
|
||||
"Legacy key [channels_config.feishu].mention_only is invalid; expected boolean."
|
||||
);
|
||||
}
|
||||
|
||||
if legacy_feishu_use_feishu_present || saw_legacy_feishu_use_feishu_path {
|
||||
tracing::warn!(
|
||||
"Legacy key [channels_config.feishu].use_feishu is redundant and ignored; [channels_config.feishu] always uses Feishu endpoints."
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub async fn load_or_init() -> Result<Self> {
|
||||
let (default_zeroclaw_dir, default_workspace_dir) = default_config_and_workspace_dirs()?;
|
||||
@ -7070,8 +7209,23 @@ impl Config {
|
||||
.await
|
||||
.context("Failed to read config file")?;
|
||||
|
||||
// Parse raw TOML first so legacy compatibility rewrites can be applied after
|
||||
// deserialization.
|
||||
let raw_toml: toml::Value =
|
||||
toml::from_str(&contents).context("Failed to parse config file")?;
|
||||
let legacy_feishu_mention_only = extract_legacy_feishu_mention_only(&raw_toml);
|
||||
let legacy_feishu_mention_only_present = has_legacy_feishu_mention_only(&raw_toml);
|
||||
let legacy_feishu_use_feishu_present = has_legacy_feishu_use_feishu(&raw_toml);
|
||||
let mut config: Config =
|
||||
toml::from_str(&contents).context("Failed to deserialize config file")?;
|
||||
|
||||
apply_feishu_legacy_compat(
|
||||
&mut config,
|
||||
legacy_feishu_mention_only,
|
||||
legacy_feishu_use_feishu_present,
|
||||
legacy_feishu_mention_only_present,
|
||||
legacy_feishu_use_feishu_present,
|
||||
);
|
||||
// Set computed paths that are skipped during serialization
|
||||
config.config_path = config_path.clone();
|
||||
config.workspace_dir = workspace_dir;
|
||||
@ -7431,6 +7585,61 @@ impl Config {
|
||||
);
|
||||
}
|
||||
}
|
||||
for (i, rule) in self.autonomy.command_context_rules.iter().enumerate() {
|
||||
let command = rule.command.trim();
|
||||
if command.is_empty() {
|
||||
anyhow::bail!("autonomy.command_context_rules[{i}].command must not be empty");
|
||||
}
|
||||
if !command
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '/' | '.' | '*'))
|
||||
{
|
||||
anyhow::bail!(
|
||||
"autonomy.command_context_rules[{i}].command contains invalid characters: {command}"
|
||||
);
|
||||
}
|
||||
|
||||
for (j, domain) in rule.allowed_domains.iter().enumerate() {
|
||||
let normalized = domain.trim();
|
||||
if normalized.is_empty() {
|
||||
anyhow::bail!(
|
||||
"autonomy.command_context_rules[{i}].allowed_domains[{j}] must not be empty"
|
||||
);
|
||||
}
|
||||
if normalized.chars().any(char::is_whitespace) {
|
||||
anyhow::bail!(
|
||||
"autonomy.command_context_rules[{i}].allowed_domains[{j}] must not contain whitespace"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for (j, prefix) in rule.allowed_path_prefixes.iter().enumerate() {
|
||||
let normalized = prefix.trim();
|
||||
if normalized.is_empty() {
|
||||
anyhow::bail!(
|
||||
"autonomy.command_context_rules[{i}].allowed_path_prefixes[{j}] must not be empty"
|
||||
);
|
||||
}
|
||||
if normalized.contains('\0') {
|
||||
anyhow::bail!(
|
||||
"autonomy.command_context_rules[{i}].allowed_path_prefixes[{j}] must not contain null bytes"
|
||||
);
|
||||
}
|
||||
}
|
||||
for (j, prefix) in rule.denied_path_prefixes.iter().enumerate() {
|
||||
let normalized = prefix.trim();
|
||||
if normalized.is_empty() {
|
||||
anyhow::bail!(
|
||||
"autonomy.command_context_rules[{i}].denied_path_prefixes[{j}] must not be empty"
|
||||
);
|
||||
}
|
||||
if normalized.contains('\0') {
|
||||
anyhow::bail!(
|
||||
"autonomy.command_context_rules[{i}].denied_path_prefixes[{j}] must not contain null bytes"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut seen_non_cli_excluded = std::collections::HashSet::new();
|
||||
for (i, tool_name) in self.autonomy.non_cli_excluded_tools.iter().enumerate() {
|
||||
let normalized = tool_name.trim();
|
||||
@ -9215,9 +9424,11 @@ mod tests {
|
||||
assert!(a.require_approval_for_medium_risk);
|
||||
assert!(a.block_high_risk_commands);
|
||||
assert!(a.shell_env_passthrough.is_empty());
|
||||
assert!(a.command_context_rules.is_empty());
|
||||
assert!(!a.allow_sensitive_file_reads);
|
||||
assert!(!a.allow_sensitive_file_writes);
|
||||
assert!(a.non_cli_excluded_tools.contains(&"shell".to_string()));
|
||||
assert!(a.non_cli_excluded_tools.contains(&"process".to_string()));
|
||||
assert!(a.non_cli_excluded_tools.contains(&"delegate".to_string()));
|
||||
}
|
||||
|
||||
@ -9246,12 +9457,53 @@ allowed_roots = []
|
||||
!parsed.allow_sensitive_file_writes,
|
||||
"Missing allow_sensitive_file_writes must default to false"
|
||||
);
|
||||
assert!(
|
||||
parsed.command_context_rules.is_empty(),
|
||||
"Missing command_context_rules must default to empty"
|
||||
);
|
||||
assert!(parsed.non_cli_excluded_tools.contains(&"shell".to_string()));
|
||||
assert!(parsed
|
||||
.non_cli_excluded_tools
|
||||
.contains(&"process".to_string()));
|
||||
assert!(parsed
|
||||
.non_cli_excluded_tools
|
||||
.contains(&"browser".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn config_validate_rejects_invalid_command_context_rule_command() {
|
||||
let mut cfg = Config::default();
|
||||
cfg.autonomy.command_context_rules = vec![CommandContextRuleConfig {
|
||||
command: "curl;rm".into(),
|
||||
action: CommandContextRuleAction::Allow,
|
||||
allowed_domains: vec![],
|
||||
allowed_path_prefixes: vec![],
|
||||
denied_path_prefixes: vec![],
|
||||
allow_high_risk: false,
|
||||
}];
|
||||
let err = cfg.validate().unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("autonomy.command_context_rules[0].command"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn config_validate_rejects_empty_command_context_rule_domain() {
|
||||
let mut cfg = Config::default();
|
||||
cfg.autonomy.command_context_rules = vec![CommandContextRuleConfig {
|
||||
command: "curl".into(),
|
||||
action: CommandContextRuleAction::Allow,
|
||||
allowed_domains: vec![" ".into()],
|
||||
allowed_path_prefixes: vec![],
|
||||
denied_path_prefixes: vec![],
|
||||
allow_high_risk: true,
|
||||
}];
|
||||
let err = cfg.validate().unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("autonomy.command_context_rules[0].allowed_domains[0]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn config_validate_rejects_duplicate_non_cli_excluded_tools() {
|
||||
let mut cfg = Config::default();
|
||||
@ -9447,6 +9699,7 @@ ws_url = "ws://127.0.0.1:3002"
|
||||
level: AutonomyLevel::Full,
|
||||
workspace_only: false,
|
||||
allowed_commands: vec!["docker".into()],
|
||||
command_context_rules: vec![],
|
||||
forbidden_paths: vec!["/secret".into()],
|
||||
max_actions_per_hour: 50,
|
||||
max_cost_per_day_cents: 1000,
|
||||
@ -13103,6 +13356,83 @@ default_model = "legacy-model"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn feishu_legacy_key_extractors_detect_compat_fields() {
|
||||
let raw: toml::Value = toml::from_str(
|
||||
r#"
|
||||
[channels_config.feishu]
|
||||
app_id = "cli_123"
|
||||
app_secret = "secret"
|
||||
mention_only = true
|
||||
use_feishu = true
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(extract_legacy_feishu_mention_only(&raw), Some(true));
|
||||
assert!(has_legacy_feishu_mention_only(&raw));
|
||||
assert!(has_legacy_feishu_use_feishu(&raw));
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn feishu_legacy_mention_only_maps_to_group_reply_mode() {
|
||||
let mut parsed = Config::default();
|
||||
parsed.channels_config.feishu = Some(FeishuConfig {
|
||||
app_id: "cli_123".into(),
|
||||
app_secret: "secret".into(),
|
||||
encrypt_key: None,
|
||||
verification_token: None,
|
||||
allowed_users: vec![],
|
||||
group_reply: None,
|
||||
receive_mode: LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
draft_update_interval_ms: default_lark_draft_update_interval_ms(),
|
||||
max_draft_edits: default_lark_max_draft_edits(),
|
||||
});
|
||||
|
||||
apply_feishu_legacy_compat(&mut parsed, Some(true), true, true, true);
|
||||
|
||||
let feishu = parsed
|
||||
.channels_config
|
||||
.feishu
|
||||
.expect("feishu config should exist");
|
||||
assert_eq!(
|
||||
feishu.effective_group_reply_mode(),
|
||||
GroupReplyMode::MentionOnly
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn feishu_legacy_mention_only_does_not_override_group_reply() {
|
||||
let mut parsed = Config::default();
|
||||
parsed.channels_config.feishu = Some(FeishuConfig {
|
||||
app_id: "cli_123".into(),
|
||||
app_secret: "secret".into(),
|
||||
encrypt_key: None,
|
||||
verification_token: None,
|
||||
allowed_users: vec![],
|
||||
group_reply: Some(GroupReplyConfig {
|
||||
mode: Some(GroupReplyMode::AllMessages),
|
||||
allowed_sender_ids: vec![],
|
||||
}),
|
||||
receive_mode: LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
draft_update_interval_ms: default_lark_draft_update_interval_ms(),
|
||||
max_draft_edits: default_lark_max_draft_edits(),
|
||||
});
|
||||
|
||||
apply_feishu_legacy_compat(&mut parsed, Some(true), false, true, false);
|
||||
|
||||
let feishu = parsed
|
||||
.channels_config
|
||||
.feishu
|
||||
.expect("feishu config should exist");
|
||||
assert_eq!(
|
||||
feishu.effective_group_reply_mode(),
|
||||
GroupReplyMode::AllMessages
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn qq_config_defaults_to_webhook_receive_mode() {
|
||||
let json = r#"{"app_id":"123","app_secret":"secret"}"#;
|
||||
|
||||
@ -42,6 +42,7 @@ use tracing::{info, warn};
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
|
||||
const PROFILE_MISMATCH_PREFIX: &str = "Pending login profile mismatch:";
|
||||
const ZEROCLAW_BUILD_VERSION: &str = env!("ZEROCLAW_BUILD_VERSION");
|
||||
|
||||
#[derive(Debug, Clone, ValueEnum)]
|
||||
enum QuotaFormat {
|
||||
@ -132,7 +133,7 @@ enum EstopLevelArg {
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "zeroclaw")]
|
||||
#[command(author = "theonlyhennygod")]
|
||||
#[command(version)]
|
||||
#[command(version = ZEROCLAW_BUILD_VERSION)]
|
||||
#[command(about = "The fastest, smallest AI assistant.", long_about = None)]
|
||||
struct Cli {
|
||||
#[arg(long, global = true)]
|
||||
@ -1021,7 +1022,7 @@ async fn main() -> Result<()> {
|
||||
Commands::Status => {
|
||||
println!("🦀 ZeroClaw Status");
|
||||
println!();
|
||||
println!("Version: {}", env!("CARGO_PKG_VERSION"));
|
||||
println!("Version: {}", ZEROCLAW_BUILD_VERSION);
|
||||
println!("Workspace: {}", config.workspace_dir.display());
|
||||
println!("Config: {}", config.config_path.display());
|
||||
println!();
|
||||
|
||||
152
src/memory/decay.rs
Normal file
152
src/memory/decay.rs
Normal file
@ -0,0 +1,152 @@
|
||||
use super::traits::{MemoryCategory, MemoryEntry};
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
/// Default half-life in days for time-decay scoring.
|
||||
/// After this many days, a non-Core memory's score drops to 50%.
|
||||
const DEFAULT_HALF_LIFE_DAYS: f64 = 7.0;
|
||||
|
||||
/// Apply exponential time decay to memory entry scores.
|
||||
///
|
||||
/// - `Core` memories are exempt ("evergreen") — their scores are never decayed.
|
||||
/// - Entries without a parseable RFC3339 timestamp are left unchanged.
|
||||
/// - Entries without a score (`None`) are left unchanged.
|
||||
///
|
||||
/// Decay formula: `score * 2^(-age_days / half_life_days)`
|
||||
pub fn apply_time_decay(entries: &mut [MemoryEntry], half_life_days: f64) {
|
||||
let half_life = if half_life_days <= 0.0 {
|
||||
DEFAULT_HALF_LIFE_DAYS
|
||||
} else {
|
||||
half_life_days
|
||||
};
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
// Core memories are evergreen — never decay
|
||||
if entry.category == MemoryCategory::Core {
|
||||
continue;
|
||||
}
|
||||
|
||||
let score = match entry.score {
|
||||
Some(s) => s,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let ts = match DateTime::parse_from_rfc3339(&entry.timestamp) {
|
||||
Ok(dt) => dt.with_timezone(&Utc),
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let age_days = now
|
||||
.signed_duration_since(ts)
|
||||
.num_seconds()
|
||||
.max(0) as f64
|
||||
/ 86_400.0;
|
||||
|
||||
let decay_factor = (-age_days / half_life * std::f64::consts::LN_2).exp();
|
||||
entry.score = Some(score * decay_factor);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_entry(category: MemoryCategory, score: Option<f64>, timestamp: &str) -> MemoryEntry {
|
||||
MemoryEntry {
|
||||
id: "1".into(),
|
||||
key: "test".into(),
|
||||
content: "value".into(),
|
||||
category,
|
||||
timestamp: timestamp.into(),
|
||||
session_id: None,
|
||||
score,
|
||||
}
|
||||
}
|
||||
|
||||
fn recent_rfc3339() -> String {
|
||||
Utc::now().to_rfc3339()
|
||||
}
|
||||
|
||||
fn days_ago_rfc3339(days: i64) -> String {
|
||||
(Utc::now() - chrono::Duration::days(days)).to_rfc3339()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn core_memories_are_never_decayed() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Core,
|
||||
Some(0.9),
|
||||
&days_ago_rfc3339(30),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
assert_eq!(entries[0].score, Some(0.9));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recent_entry_score_barely_changes() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(0.8),
|
||||
&recent_rfc3339(),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
let decayed = entries[0].score.unwrap();
|
||||
assert!(
|
||||
(decayed - 0.8).abs() < 0.01,
|
||||
"recent entry should barely decay, got {decayed}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_half_life_halves_score() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(1.0),
|
||||
&days_ago_rfc3339(7),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
let decayed = entries[0].score.unwrap();
|
||||
assert!(
|
||||
(decayed - 0.5).abs() < 0.05,
|
||||
"score after one half-life should be ~0.5, got {decayed}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn two_half_lives_quarters_score() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(1.0),
|
||||
&days_ago_rfc3339(14),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
let decayed = entries[0].score.unwrap();
|
||||
assert!(
|
||||
(decayed - 0.25).abs() < 0.05,
|
||||
"score after two half-lives should be ~0.25, got {decayed}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_score_entry_is_unchanged() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
&days_ago_rfc3339(30),
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
assert_eq!(entries[0].score, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unparseable_timestamp_is_unchanged() {
|
||||
let mut entries = vec![make_entry(
|
||||
MemoryCategory::Conversation,
|
||||
Some(0.9),
|
||||
"not-a-date",
|
||||
)];
|
||||
apply_time_decay(&mut entries, 7.0);
|
||||
assert_eq!(entries[0].score, Some(0.9));
|
||||
}
|
||||
}
|
||||
@ -2,6 +2,7 @@ pub mod backend;
|
||||
pub mod chunker;
|
||||
pub mod cli;
|
||||
pub mod cortex;
|
||||
pub mod decay;
|
||||
pub mod embeddings;
|
||||
pub mod hybrid;
|
||||
pub mod hygiene;
|
||||
|
||||
@ -410,8 +410,9 @@ impl AnthropicProvider {
|
||||
response
|
||||
.content
|
||||
.into_iter()
|
||||
.find(|c| c.kind == "text")
|
||||
.and_then(|c| c.text)
|
||||
.filter(|c| c.kind == "text")
|
||||
.filter_map(|c| c.text.map(|text| text.trim().to_string()))
|
||||
.find(|text| !text.is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from Anthropic"))
|
||||
}
|
||||
|
||||
@ -1421,6 +1422,36 @@ mod tests {
|
||||
assert!(result.usage.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_text_response_ignores_empty_and_whitespace_text_blocks() {
|
||||
let json = r#"{
|
||||
"content": [
|
||||
{"type": "text", "text": ""},
|
||||
{"type": "text", "text": " \n "},
|
||||
{"type": "text", "text": " final answer "}
|
||||
]
|
||||
}"#;
|
||||
let response: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
|
||||
let parsed = AnthropicProvider::parse_text_response(response).unwrap();
|
||||
assert_eq!(parsed, "final answer");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_text_response_rejects_empty_or_whitespace_only_text_blocks() {
|
||||
let json = r#"{
|
||||
"content": [
|
||||
{"type": "text", "text": ""},
|
||||
{"type": "text", "text": " \n "},
|
||||
{"type": "tool_use", "id": "tool_1", "name": "shell"}
|
||||
]
|
||||
}"#;
|
||||
let response: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
|
||||
let err = AnthropicProvider::parse_text_response(response).unwrap_err();
|
||||
assert!(err.to_string().contains("No response from Anthropic"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn capabilities_reports_vision_and_native_tool_calling() {
|
||||
let provider = AnthropicProvider::new(Some("test-key"));
|
||||
|
||||
@ -16,6 +16,7 @@ use reqwest::{
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashSet;
|
||||
use tokio_tungstenite::{
|
||||
connect_async,
|
||||
tungstenite::{
|
||||
@ -29,6 +30,7 @@ use tokio_tungstenite::{
|
||||
/// A provider that speaks the OpenAI-compatible chat completions API.
|
||||
/// Used by: Venice, Vercel AI Gateway, Cloudflare AI Gateway, Moonshot,
|
||||
/// Synthetic, `OpenCode` Zen, `Z.AI`, `GLM`, `MiniMax`, Bedrock, Qianfan, Groq, Mistral, `xAI`, etc.
|
||||
#[derive(Clone)]
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
pub struct OpenAiCompatibleProvider {
|
||||
pub(crate) name: String,
|
||||
@ -1151,6 +1153,90 @@ impl OpenAiCompatibleProvider {
|
||||
self.api_mode == CompatibleApiMode::OpenAiResponses
|
||||
}
|
||||
|
||||
fn chat_completions_fallback_provider(&self) -> Self {
|
||||
let mut provider = self.clone();
|
||||
provider.api_mode = CompatibleApiMode::OpenAiChatCompletions;
|
||||
provider.supports_responses_fallback = false;
|
||||
provider
|
||||
}
|
||||
|
||||
fn error_status_code(error: &anyhow::Error) -> Option<reqwest::StatusCode> {
|
||||
if let Some(reqwest_error) = error.downcast_ref::<reqwest::Error>() {
|
||||
if let Some(status) = reqwest_error.status() {
|
||||
return Some(status);
|
||||
}
|
||||
}
|
||||
|
||||
let message = error.to_string();
|
||||
for token in message.split(|c: char| !c.is_ascii_digit()) {
|
||||
let Ok(code) = token.parse::<u16>() else {
|
||||
continue;
|
||||
};
|
||||
if let Ok(status) = reqwest::StatusCode::from_u16(code) {
|
||||
if status.is_client_error() || status.is_server_error() {
|
||||
return Some(status);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn is_authentication_error(error: &anyhow::Error) -> bool {
|
||||
if let Some(status) = Self::error_status_code(error) {
|
||||
if status == reqwest::StatusCode::UNAUTHORIZED
|
||||
|| status == reqwest::StatusCode::FORBIDDEN
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
let lower = error.to_string().to_ascii_lowercase();
|
||||
let auth_hints = [
|
||||
"invalid api key",
|
||||
"incorrect api key",
|
||||
"missing api key",
|
||||
"api key not set",
|
||||
"authentication failed",
|
||||
"auth failed",
|
||||
"unauthorized",
|
||||
"forbidden",
|
||||
"permission denied",
|
||||
"access denied",
|
||||
"invalid token",
|
||||
];
|
||||
|
||||
auth_hints.iter().any(|hint| lower.contains(hint))
|
||||
}
|
||||
|
||||
fn should_fallback_to_chat_completions(error: &anyhow::Error) -> bool {
|
||||
if Self::is_authentication_error(error) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(status) = Self::error_status_code(error) {
|
||||
return status == reqwest::StatusCode::NOT_FOUND
|
||||
|| status == reqwest::StatusCode::REQUEST_TIMEOUT
|
||||
|| status == reqwest::StatusCode::TOO_MANY_REQUESTS
|
||||
|| status.is_server_error();
|
||||
}
|
||||
|
||||
if let Some(reqwest_error) = error.downcast_ref::<reqwest::Error>() {
|
||||
if reqwest_error.is_connect()
|
||||
|| reqwest_error.is_timeout()
|
||||
|| reqwest_error.is_request()
|
||||
|| reqwest_error.is_body()
|
||||
|| reqwest_error.is_decode()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
let lower = error.to_string().to_ascii_lowercase();
|
||||
lower.contains("responses api returned an unexpected payload")
|
||||
|| lower.contains("no response from")
|
||||
}
|
||||
|
||||
fn effective_max_tokens(&self) -> Option<u32> {
|
||||
self.max_tokens_override.filter(|value| *value > 0)
|
||||
}
|
||||
@ -1335,8 +1421,10 @@ impl OpenAiCompatibleProvider {
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error = response.text().await?;
|
||||
anyhow::bail!("{} Responses API error: {error}", self.name);
|
||||
let sanitized = super::sanitize_api_error(&error);
|
||||
anyhow::bail!("{} Responses API error ({status}): {sanitized}", self.name);
|
||||
}
|
||||
|
||||
let body = response.text().await?;
|
||||
@ -1387,10 +1475,37 @@ impl OpenAiCompatibleProvider {
|
||||
credential: &str,
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let responses = self
|
||||
let responses = match self
|
||||
.send_responses_request(credential, messages, model, None)
|
||||
.await?;
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(responses_err) => {
|
||||
if self.should_use_responses_mode()
|
||||
&& Self::should_fallback_to_chat_completions(&responses_err)
|
||||
{
|
||||
tracing::warn!(
|
||||
provider = %self.name,
|
||||
error = %responses_err,
|
||||
"Responses API request failed in responses mode; retrying via chat completions"
|
||||
);
|
||||
let fallback_provider = self.chat_completions_fallback_provider();
|
||||
let sanitized = super::sanitize_api_error(&responses_err.to_string());
|
||||
return fallback_provider
|
||||
.chat_with_history(messages, model, temperature)
|
||||
.await
|
||||
.map_err(|chat_err| {
|
||||
anyhow::anyhow!(
|
||||
"{} Responses API failed: {sanitized} (chat-completions fallback failed: {chat_err})",
|
||||
self.name
|
||||
)
|
||||
});
|
||||
}
|
||||
return Err(responses_err);
|
||||
}
|
||||
};
|
||||
extract_responses_text(&responses)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name))
|
||||
}
|
||||
@ -1401,10 +1516,51 @@ impl OpenAiCompatibleProvider {
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
tools: Option<Vec<Value>>,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let responses = self
|
||||
.send_responses_request(credential, messages, model, tools)
|
||||
.await?;
|
||||
let responses = match self
|
||||
.send_responses_request(credential, messages, model, tools.clone())
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(responses_err) => {
|
||||
if self.should_use_responses_mode()
|
||||
&& Self::should_fallback_to_chat_completions(&responses_err)
|
||||
{
|
||||
tracing::warn!(
|
||||
provider = %self.name,
|
||||
error = %responses_err,
|
||||
"Responses API request failed in responses mode; retrying via chat completions"
|
||||
);
|
||||
let fallback_provider = self.chat_completions_fallback_provider();
|
||||
let fallback_tool_specs = tools
|
||||
.as_deref()
|
||||
.map(Self::openai_tools_to_tool_specs)
|
||||
.unwrap_or_default();
|
||||
let fallback_tools =
|
||||
(!fallback_tool_specs.is_empty()).then_some(fallback_tool_specs.as_slice());
|
||||
let sanitized = super::sanitize_api_error(&responses_err.to_string());
|
||||
|
||||
return fallback_provider
|
||||
.chat(
|
||||
ProviderChatRequest {
|
||||
messages,
|
||||
tools: fallback_tools,
|
||||
},
|
||||
model,
|
||||
temperature,
|
||||
)
|
||||
.await
|
||||
.map_err(|chat_err| {
|
||||
anyhow::anyhow!(
|
||||
"{} Responses API failed: {sanitized} (chat-completions fallback failed: {chat_err})",
|
||||
self.name
|
||||
)
|
||||
});
|
||||
}
|
||||
return Err(responses_err);
|
||||
}
|
||||
};
|
||||
let parsed = parse_responses_chat_response(responses);
|
||||
if parsed.text.is_none() && parsed.tool_calls.is_empty() {
|
||||
anyhow::bail!("No response from {} Responses API", self.name);
|
||||
@ -1467,90 +1623,173 @@ impl OpenAiCompatibleProvider {
|
||||
messages: &[ChatMessage],
|
||||
allow_user_image_parts: bool,
|
||||
) -> Vec<NativeMessage> {
|
||||
messages
|
||||
.iter()
|
||||
.map(|message| {
|
||||
if message.role == "assistant" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
|
||||
{
|
||||
if let Some(tool_calls_value) = value.get("tool_calls") {
|
||||
if let Ok(parsed_calls) =
|
||||
serde_json::from_value::<Vec<ProviderToolCall>>(
|
||||
tool_calls_value.clone(),
|
||||
)
|
||||
{
|
||||
let tool_calls = parsed_calls
|
||||
.into_iter()
|
||||
.map(|tc| ToolCall {
|
||||
id: Some(tc.id),
|
||||
kind: Some("function".to_string()),
|
||||
function: Some(Function {
|
||||
name: Some(tc.name),
|
||||
arguments: Some(tc.arguments),
|
||||
}),
|
||||
name: None,
|
||||
arguments: None,
|
||||
parameters: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let mut native_messages = Vec::with_capacity(messages.len());
|
||||
let mut assistant_tool_call_ids = HashSet::new();
|
||||
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(|value| MessageContent::Text(value.to_string()));
|
||||
|
||||
let reasoning_content = value
|
||||
.get("reasoning_content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
|
||||
return NativeMessage {
|
||||
role: "assistant".to_string(),
|
||||
content,
|
||||
tool_call_id: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
reasoning_content,
|
||||
};
|
||||
for message in messages {
|
||||
if message.role == "assistant" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||
if let Some(tool_calls) = Self::parse_history_tool_calls(&value) {
|
||||
for call in &tool_calls {
|
||||
if let Some(id) = call.id.as_ref() {
|
||||
assistant_tool_call_ids.insert(id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if message.role == "tool" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||
let tool_call_id = value
|
||||
.get("tool_call_id")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
// Some OpenAI-compatible providers (including NVIDIA NIM models)
|
||||
// reject assistant tool-call messages if `content` is omitted.
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(|value| MessageContent::Text(value.to_string()))
|
||||
.or_else(|| Some(MessageContent::Text(message.content.clone())));
|
||||
.map(ToString::to_string)
|
||||
.unwrap_or_default();
|
||||
|
||||
return NativeMessage {
|
||||
role: "tool".to_string(),
|
||||
content,
|
||||
tool_call_id,
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
};
|
||||
let reasoning_content = value
|
||||
.get("reasoning_content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
|
||||
native_messages.push(NativeMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: Some(MessageContent::Text(content)),
|
||||
tool_call_id: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
reasoning_content,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
NativeMessage {
|
||||
role: message.role.clone(),
|
||||
content: Some(Self::to_message_content(
|
||||
&message.role,
|
||||
&message.content,
|
||||
allow_user_image_parts,
|
||||
)),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
if message.role == "tool" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||
let tool_call_id = value
|
||||
.get("tool_call_id")
|
||||
.or_else(|| value.get("tool_use_id"))
|
||||
.or_else(|| value.get("toolUseId"))
|
||||
.or_else(|| value.get("id"))
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
|
||||
let content_text = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string)
|
||||
.unwrap_or_else(|| message.content.clone());
|
||||
|
||||
if let Some(id) = tool_call_id {
|
||||
if assistant_tool_call_ids.contains(&id) {
|
||||
native_messages.push(NativeMessage {
|
||||
role: "tool".to_string(),
|
||||
content: Some(MessageContent::Text(content_text)),
|
||||
tool_call_id: Some(id),
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::warn!(
|
||||
tool_call_id = %id,
|
||||
"Dropping orphan tool-role message; no matching assistant tool_call in history"
|
||||
);
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"Dropping tool-role message missing tool_call_id; preserving as user text fallback"
|
||||
);
|
||||
}
|
||||
|
||||
native_messages.push(NativeMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(MessageContent::Text(format!(
|
||||
"[Tool result]\n{}",
|
||||
content_text
|
||||
))),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
native_messages.push(NativeMessage {
|
||||
role: message.role.clone(),
|
||||
content: Some(Self::to_message_content(
|
||||
&message.role,
|
||||
&message.content,
|
||||
allow_user_image_parts,
|
||||
)),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
reasoning_content: None,
|
||||
});
|
||||
}
|
||||
|
||||
native_messages
|
||||
}
|
||||
|
||||
fn parse_history_tool_calls(value: &serde_json::Value) -> Option<Vec<ToolCall>> {
|
||||
let tool_calls_value = value.get("tool_calls")?;
|
||||
|
||||
if let Ok(parsed_calls) =
|
||||
serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
|
||||
{
|
||||
let tool_calls = parsed_calls
|
||||
.into_iter()
|
||||
.map(|tc| ToolCall {
|
||||
id: Some(tc.id),
|
||||
kind: Some("function".to_string()),
|
||||
function: Some(Function {
|
||||
name: Some(tc.name),
|
||||
arguments: Some(Self::normalize_tool_arguments(tc.arguments)),
|
||||
}),
|
||||
name: None,
|
||||
arguments: None,
|
||||
parameters: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if !tool_calls.is_empty() {
|
||||
return Some(tool_calls);
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(parsed_calls) = serde_json::from_value::<Vec<ToolCall>>(tool_calls_value.clone())
|
||||
{
|
||||
let mut normalized_calls = Vec::with_capacity(parsed_calls.len());
|
||||
for call in parsed_calls {
|
||||
let Some(name) = call.function_name() else {
|
||||
continue;
|
||||
};
|
||||
let arguments = call
|
||||
.function_arguments()
|
||||
.unwrap_or_else(|| "{}".to_string());
|
||||
normalized_calls.push(ToolCall {
|
||||
id: Some(call.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())),
|
||||
kind: Some("function".to_string()),
|
||||
function: Some(Function {
|
||||
name: Some(name),
|
||||
arguments: Some(Self::normalize_tool_arguments(arguments)),
|
||||
}),
|
||||
name: None,
|
||||
arguments: None,
|
||||
parameters: None,
|
||||
});
|
||||
}
|
||||
if !normalized_calls.is_empty() {
|
||||
return Some(normalized_calls);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn normalize_tool_arguments(arguments: String) -> String {
|
||||
if serde_json::from_str::<serde_json::Value>(&arguments).is_ok() {
|
||||
arguments
|
||||
} else {
|
||||
"{}".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn with_prompt_guided_tool_instructions(
|
||||
@ -1595,17 +1834,14 @@ impl OpenAiCompatibleProvider {
|
||||
.filter_map(|tc| {
|
||||
let name = tc.function_name()?;
|
||||
let arguments = tc.function_arguments().unwrap_or_else(|| "{}".to_string());
|
||||
let normalized_arguments =
|
||||
if serde_json::from_str::<serde_json::Value>(&arguments).is_ok() {
|
||||
arguments
|
||||
} else {
|
||||
tracing::warn!(
|
||||
function = %name,
|
||||
arguments = %arguments,
|
||||
"Invalid JSON in native tool-call arguments, using empty object"
|
||||
);
|
||||
"{}".to_string()
|
||||
};
|
||||
let normalized_arguments = Self::normalize_tool_arguments(arguments.clone());
|
||||
if normalized_arguments == "{}" && arguments != "{}" {
|
||||
tracing::warn!(
|
||||
function = %name,
|
||||
arguments = %arguments,
|
||||
"Invalid JSON in native tool-call arguments, using empty object"
|
||||
);
|
||||
}
|
||||
Some(ProviderToolCall {
|
||||
id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
name,
|
||||
@ -1727,7 +1963,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
|
||||
if self.should_use_responses_mode() {
|
||||
return self
|
||||
.chat_via_responses(credential, &fallback_messages, model)
|
||||
.chat_via_responses(credential, &fallback_messages, model, temperature)
|
||||
.await;
|
||||
}
|
||||
|
||||
@ -1741,7 +1977,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
if self.supports_responses_fallback {
|
||||
let sanitized = super::sanitize_api_error(&chat_error.to_string());
|
||||
return self
|
||||
.chat_via_responses(credential, &fallback_messages, model)
|
||||
.chat_via_responses(credential, &fallback_messages, model, temperature)
|
||||
.await
|
||||
.map_err(|responses_err| {
|
||||
anyhow::anyhow!(
|
||||
@ -1762,7 +1998,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
|
||||
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
||||
return self
|
||||
.chat_via_responses(credential, &fallback_messages, model)
|
||||
.chat_via_responses(credential, &fallback_messages, model, temperature)
|
||||
.await
|
||||
.map_err(|responses_err| {
|
||||
anyhow::anyhow!(
|
||||
@ -1843,7 +2079,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
|
||||
if self.should_use_responses_mode() {
|
||||
return self
|
||||
.chat_via_responses(credential, &effective_messages, model)
|
||||
.chat_via_responses(credential, &effective_messages, model, temperature)
|
||||
.await;
|
||||
}
|
||||
|
||||
@ -1858,7 +2094,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
if self.supports_responses_fallback {
|
||||
let sanitized = super::sanitize_api_error(&chat_error.to_string());
|
||||
return self
|
||||
.chat_via_responses(credential, &effective_messages, model)
|
||||
.chat_via_responses(credential, &effective_messages, model, temperature)
|
||||
.await
|
||||
.map_err(|responses_err| {
|
||||
anyhow::anyhow!(
|
||||
@ -1878,7 +2114,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
// Mirror chat_with_system: 404 may mean this provider uses the Responses API
|
||||
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
||||
return self
|
||||
.chat_via_responses(credential, &effective_messages, model)
|
||||
.chat_via_responses(credential, &effective_messages, model, temperature)
|
||||
.await
|
||||
.map_err(|responses_err| {
|
||||
anyhow::anyhow!(
|
||||
@ -1973,6 +2209,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
&effective_messages,
|
||||
model,
|
||||
(!tools.is_empty()).then(|| tools.to_vec()),
|
||||
temperature,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@ -2026,6 +2263,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
&effective_messages,
|
||||
model,
|
||||
(!tools.is_empty()).then(|| tools.to_vec()),
|
||||
temperature,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@ -2120,6 +2358,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
&effective_messages,
|
||||
model,
|
||||
response_tools.clone(),
|
||||
temperature,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@ -2143,6 +2382,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
&effective_messages,
|
||||
model,
|
||||
response_tools.clone(),
|
||||
temperature,
|
||||
)
|
||||
.await
|
||||
.map_err(|responses_err| {
|
||||
@ -2180,6 +2420,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||
&effective_messages,
|
||||
model,
|
||||
response_tools.clone(),
|
||||
temperature,
|
||||
)
|
||||
.await
|
||||
.map_err(|responses_err| {
|
||||
@ -2695,7 +2936,12 @@ mod tests {
|
||||
async fn chat_via_responses_requires_non_system_message() {
|
||||
let provider = make_provider("custom", "https://api.example.com", Some("test-key"));
|
||||
let err = provider
|
||||
.chat_via_responses("test-key", &[ChatMessage::system("policy")], "gpt-test")
|
||||
.chat_via_responses(
|
||||
"test-key",
|
||||
&[ChatMessage::system("policy")],
|
||||
"gpt-test",
|
||||
0.7,
|
||||
)
|
||||
.await
|
||||
.expect_err("system-only fallback payload should fail");
|
||||
|
||||
@ -2704,6 +2950,278 @@ mod tests {
|
||||
.contains("requires at least one non-system message"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_mode_falls_back_to_chat_completions_on_responses_404() {
|
||||
#[derive(Clone, Default)]
|
||||
struct FallbackState {
|
||||
hits: Arc<Mutex<Vec<String>>>,
|
||||
}
|
||||
|
||||
async fn responses_endpoint(
|
||||
State(state): State<FallbackState>,
|
||||
Json(_payload): Json<Value>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
state.hits.lock().await.push("responses".to_string());
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "responses endpoint unavailable" }
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
async fn chat_endpoint(
|
||||
State(state): State<FallbackState>,
|
||||
Json(payload): Json<Value>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
state.hits.lock().await.push("chat".to_string());
|
||||
assert_eq!(
|
||||
payload.get("model").and_then(Value::as_str),
|
||||
Some("test-model")
|
||||
);
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": "chat fallback ok"
|
||||
}
|
||||
}]
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
let state = FallbackState::default();
|
||||
let app = Router::new()
|
||||
.route("/v1/responses", post(responses_endpoint))
|
||||
.route("/chat/completions", post(chat_endpoint))
|
||||
.with_state(state.clone());
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("bind test server");
|
||||
let addr = listener.local_addr().expect("server local addr");
|
||||
let server = tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.expect("serve test app");
|
||||
});
|
||||
|
||||
let provider = OpenAiCompatibleProvider::new_custom_with_mode(
|
||||
"custom",
|
||||
&format!("http://{}", addr),
|
||||
Some("test-key"),
|
||||
AuthStyle::Bearer,
|
||||
false,
|
||||
CompatibleApiMode::OpenAiResponses,
|
||||
None,
|
||||
);
|
||||
let text = provider
|
||||
.chat_with_system(Some("system"), "hello", "test-model", 0.2)
|
||||
.await
|
||||
.expect("responses 404 should retry chat completions in responses mode");
|
||||
assert_eq!(text, "chat fallback ok");
|
||||
|
||||
let hits = state.hits.lock().await.clone();
|
||||
assert_eq!(
|
||||
hits,
|
||||
vec!["responses".to_string(), "chat".to_string()],
|
||||
"must attempt responses first, then chat-completions fallback"
|
||||
);
|
||||
|
||||
server.abort();
|
||||
let _ = server.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_mode_does_not_fallback_to_chat_completions_on_auth_error() {
|
||||
#[derive(Clone, Default)]
|
||||
struct AuthFailureState {
|
||||
hits: Arc<Mutex<Vec<String>>>,
|
||||
}
|
||||
|
||||
async fn responses_endpoint(
|
||||
State(state): State<AuthFailureState>,
|
||||
Json(_payload): Json<Value>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
state.hits.lock().await.push("responses".to_string());
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "invalid api key" }
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
async fn chat_endpoint(
|
||||
State(state): State<AuthFailureState>,
|
||||
Json(_payload): Json<Value>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
state.hits.lock().await.push("chat".to_string());
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": "should not be reached"
|
||||
}
|
||||
}]
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
let state = AuthFailureState::default();
|
||||
let app = Router::new()
|
||||
.route("/v1/responses", post(responses_endpoint))
|
||||
.route("/chat/completions", post(chat_endpoint))
|
||||
.with_state(state.clone());
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("bind test server");
|
||||
let addr = listener.local_addr().expect("server local addr");
|
||||
let server = tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.expect("serve test app");
|
||||
});
|
||||
|
||||
let provider = OpenAiCompatibleProvider::new_custom_with_mode(
|
||||
"custom",
|
||||
&format!("http://{}", addr),
|
||||
Some("test-key"),
|
||||
AuthStyle::Bearer,
|
||||
false,
|
||||
CompatibleApiMode::OpenAiResponses,
|
||||
None,
|
||||
);
|
||||
let err = provider
|
||||
.chat_with_system(None, "hello", "test-model", 0.2)
|
||||
.await
|
||||
.expect_err("auth errors should not trigger chat-completions fallback");
|
||||
assert!(err.to_string().contains("401"));
|
||||
|
||||
let hits = state.hits.lock().await.clone();
|
||||
assert_eq!(
|
||||
hits,
|
||||
vec!["responses".to_string()],
|
||||
"auth failures must not trigger fallback chat attempt"
|
||||
);
|
||||
|
||||
server.abort();
|
||||
let _ = server.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn responses_mode_native_chat_falls_back_and_preserves_tool_call_id() {
|
||||
#[derive(Clone, Default)]
|
||||
struct NativeFallbackState {
|
||||
hits: Arc<Mutex<Vec<String>>>,
|
||||
chat_payloads: Arc<Mutex<Vec<Value>>>,
|
||||
}
|
||||
|
||||
async fn responses_endpoint(
|
||||
State(state): State<NativeFallbackState>,
|
||||
Json(_payload): Json<Value>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
state.hits.lock().await.push("responses".to_string());
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": { "message": "responses backend unavailable" }
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
async fn chat_endpoint(
|
||||
State(state): State<NativeFallbackState>,
|
||||
Json(payload): Json<Value>,
|
||||
) -> (StatusCode, Json<Value>) {
|
||||
state.hits.lock().await.push("chat".to_string());
|
||||
state.chat_payloads.lock().await.push(payload);
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": "call_abc",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "shell",
|
||||
"arguments": "{\"command\":\"pwd\"}"
|
||||
}
|
||||
}]
|
||||
}
|
||||
}]
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
let state = NativeFallbackState::default();
|
||||
let app = Router::new()
|
||||
.route("/v1/responses", post(responses_endpoint))
|
||||
.route("/chat/completions", post(chat_endpoint))
|
||||
.with_state(state.clone());
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("bind test server");
|
||||
let addr = listener.local_addr().expect("server local addr");
|
||||
let server = tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.expect("serve test app");
|
||||
});
|
||||
|
||||
let provider = OpenAiCompatibleProvider::new_custom_with_mode(
|
||||
"custom",
|
||||
&format!("http://{}", addr),
|
||||
Some("test-key"),
|
||||
AuthStyle::Bearer,
|
||||
false,
|
||||
CompatibleApiMode::OpenAiResponses,
|
||||
None,
|
||||
);
|
||||
let messages = vec![ChatMessage::user("run a command")];
|
||||
let tools = vec![crate::tools::ToolSpec {
|
||||
name: "shell".to_string(),
|
||||
description: "Run a command".to_string(),
|
||||
parameters: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {"command": {"type": "string"}},
|
||||
"required": ["command"]
|
||||
}),
|
||||
}];
|
||||
let result = provider
|
||||
.chat(
|
||||
ProviderChatRequest {
|
||||
messages: &messages,
|
||||
tools: Some(&tools),
|
||||
},
|
||||
"test-model",
|
||||
0.2,
|
||||
)
|
||||
.await
|
||||
.expect("responses server errors should retry via native chat-completions");
|
||||
|
||||
assert_eq!(result.tool_calls.len(), 1);
|
||||
assert_eq!(result.tool_calls[0].id, "call_abc");
|
||||
assert_eq!(result.tool_calls[0].name, "shell");
|
||||
|
||||
let hits = state.hits.lock().await.clone();
|
||||
assert_eq!(
|
||||
hits,
|
||||
vec!["responses".to_string(), "chat".to_string()],
|
||||
"responses mode should retry via chat for retryable errors"
|
||||
);
|
||||
|
||||
let chat_payloads = state.chat_payloads.lock().await;
|
||||
assert_eq!(chat_payloads.len(), 1);
|
||||
assert!(
|
||||
chat_payloads[0].get("tools").is_some(),
|
||||
"fallback native chat request should preserve tool schema"
|
||||
);
|
||||
|
||||
server.abort();
|
||||
let _ = server.await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_call_function_name_falls_back_to_top_level_name() {
|
||||
let call: ToolCall = serde_json::from_value(serde_json::json!({
|
||||
@ -2970,18 +3488,80 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn convert_messages_for_native_maps_tool_result_payload() {
|
||||
let input = vec![ChatMessage::tool(
|
||||
r#"{"tool_call_id":"call_abc","content":"done"}"#,
|
||||
let input = vec![
|
||||
ChatMessage::assistant(
|
||||
r#"{"content":"","tool_calls":[{"id":"call_abc","name":"shell","arguments":"{}"}]}"#,
|
||||
),
|
||||
ChatMessage::tool(r#"{"tool_call_id":"call_abc","content":"done"}"#),
|
||||
];
|
||||
|
||||
let converted = OpenAiCompatibleProvider::convert_messages_for_native(&input, true);
|
||||
assert_eq!(converted.len(), 2);
|
||||
assert_eq!(converted[1].role, "tool");
|
||||
assert_eq!(converted[1].tool_call_id.as_deref(), Some("call_abc"));
|
||||
assert!(matches!(
|
||||
converted[1].content.as_ref(),
|
||||
Some(MessageContent::Text(value)) if value == "done"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_messages_for_native_parses_openai_style_assistant_tool_calls() {
|
||||
let input = vec![ChatMessage::assistant(
|
||||
r#"{
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": "call_openai_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "shell",
|
||||
"arguments": "{\"command\":\"pwd\"}"
|
||||
}
|
||||
}]
|
||||
}"#,
|
||||
)];
|
||||
|
||||
let converted = OpenAiCompatibleProvider::convert_messages_for_native(&input, true);
|
||||
assert_eq!(converted.len(), 1);
|
||||
assert_eq!(converted[0].role, "tool");
|
||||
assert_eq!(converted[0].tool_call_id.as_deref(), Some("call_abc"));
|
||||
assert_eq!(converted[0].role, "assistant");
|
||||
assert!(matches!(
|
||||
converted[0].content.as_ref(),
|
||||
Some(MessageContent::Text(value)) if value == "done"
|
||||
Some(MessageContent::Text(value)) if value.is_empty()
|
||||
));
|
||||
|
||||
let calls = converted[0]
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.expect("assistant message should include tool_calls");
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].id.as_deref(), Some("call_openai_1"));
|
||||
assert!(matches!(
|
||||
calls[0].function.as_ref().and_then(|f| f.name.as_deref()),
|
||||
Some("shell")
|
||||
));
|
||||
assert!(matches!(
|
||||
calls[0]
|
||||
.function
|
||||
.as_ref()
|
||||
.and_then(|f| f.arguments.as_deref()),
|
||||
Some("{\"command\":\"pwd\"}")
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_messages_for_native_rewrites_orphan_tool_message_as_user() {
|
||||
let input = vec![ChatMessage::tool(
|
||||
r#"{"tool_call_id":"call_missing","content":"done"}"#,
|
||||
)];
|
||||
|
||||
let converted = OpenAiCompatibleProvider::convert_messages_for_native(&input, true);
|
||||
assert_eq!(converted.len(), 1);
|
||||
assert_eq!(converted[0].role, "user");
|
||||
assert!(matches!(
|
||||
converted[0].content.as_ref(),
|
||||
Some(MessageContent::Text(value)) if value.contains("[Tool result]") && value.contains("done")
|
||||
));
|
||||
assert!(converted[0].tool_call_id.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
//! - Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`)
|
||||
|
||||
use crate::auth::AuthService;
|
||||
use crate::multimodal;
|
||||
use crate::providers::traits::{
|
||||
ChatMessage, ChatResponse, NormalizedStopReason, Provider, TokenUsage,
|
||||
};
|
||||
@ -137,8 +138,22 @@ struct Content {
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
struct Part {
|
||||
text: String,
|
||||
#[serde(untagged)]
|
||||
enum Part {
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
InlineData {
|
||||
#[serde(rename = "inlineData")]
|
||||
inline_data: InlineDataPart,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
struct InlineDataPart {
|
||||
#[serde(rename = "mimeType")]
|
||||
mime_type: String,
|
||||
data: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
@ -934,6 +949,57 @@ impl GeminiProvider {
|
||||
|| status.is_server_error()
|
||||
|| error_text.contains("RESOURCE_EXHAUSTED")
|
||||
}
|
||||
|
||||
fn parse_inline_image_marker(image_ref: &str) -> Option<InlineDataPart> {
|
||||
let rest = image_ref.strip_prefix("data:")?;
|
||||
let semi_index = rest.find(';')?;
|
||||
let mime_type = rest[..semi_index].trim();
|
||||
if mime_type.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let payload = rest[semi_index + 1..].strip_prefix("base64,")?.trim();
|
||||
if payload.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(InlineDataPart {
|
||||
mime_type: mime_type.to_string(),
|
||||
data: payload.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn build_user_parts(content: &str) -> Vec<Part> {
|
||||
let (cleaned_text, image_refs) = multimodal::parse_image_markers(content);
|
||||
if image_refs.is_empty() {
|
||||
return vec![Part::Text {
|
||||
text: content.to_string(),
|
||||
}];
|
||||
}
|
||||
|
||||
let mut parts: Vec<Part> = Vec::with_capacity(image_refs.len() + 1);
|
||||
if !cleaned_text.is_empty() {
|
||||
parts.push(Part::Text { text: cleaned_text });
|
||||
}
|
||||
|
||||
for image_ref in image_refs {
|
||||
if let Some(inline_data) = Self::parse_inline_image_marker(&image_ref) {
|
||||
parts.push(Part::InlineData { inline_data });
|
||||
} else {
|
||||
parts.push(Part::Text {
|
||||
text: format!("[IMAGE:{image_ref}]"),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if parts.is_empty() {
|
||||
vec![Part::Text {
|
||||
text: String::new(),
|
||||
}]
|
||||
} else {
|
||||
parts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GeminiProvider {
|
||||
@ -1167,16 +1233,14 @@ impl Provider for GeminiProvider {
|
||||
) -> anyhow::Result<String> {
|
||||
let system_instruction = system_prompt.map(|sys| Content {
|
||||
role: None,
|
||||
parts: vec![Part {
|
||||
parts: vec![Part::Text {
|
||||
text: sys.to_string(),
|
||||
}],
|
||||
});
|
||||
|
||||
let contents = vec![Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
text: message.to_string(),
|
||||
}],
|
||||
parts: Self::build_user_parts(message),
|
||||
}];
|
||||
|
||||
let (text_opt, _usage, _stop_reason, _raw_stop_reason) = self
|
||||
@ -1203,16 +1267,14 @@ impl Provider for GeminiProvider {
|
||||
"user" => {
|
||||
contents.push(Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
text: msg.content.clone(),
|
||||
}],
|
||||
parts: Self::build_user_parts(&msg.content),
|
||||
});
|
||||
}
|
||||
"assistant" => {
|
||||
// Gemini API uses "model" role instead of "assistant"
|
||||
contents.push(Content {
|
||||
role: Some("model".to_string()),
|
||||
parts: vec![Part {
|
||||
parts: vec![Part::Text {
|
||||
text: msg.content.clone(),
|
||||
}],
|
||||
});
|
||||
@ -1226,7 +1288,7 @@ impl Provider for GeminiProvider {
|
||||
} else {
|
||||
Some(Content {
|
||||
role: None,
|
||||
parts: vec![Part {
|
||||
parts: vec![Part::Text {
|
||||
text: system_parts.join("\n\n"),
|
||||
}],
|
||||
})
|
||||
@ -1253,13 +1315,11 @@ impl Provider for GeminiProvider {
|
||||
"system" => system_parts.push(&msg.content),
|
||||
"user" => contents.push(Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
text: msg.content.clone(),
|
||||
}],
|
||||
parts: Self::build_user_parts(&msg.content),
|
||||
}),
|
||||
"assistant" => contents.push(Content {
|
||||
role: Some("model".to_string()),
|
||||
parts: vec![Part {
|
||||
parts: vec![Part::Text {
|
||||
text: msg.content.clone(),
|
||||
}],
|
||||
}),
|
||||
@ -1272,7 +1332,7 @@ impl Provider for GeminiProvider {
|
||||
} else {
|
||||
Some(Content {
|
||||
role: None,
|
||||
parts: vec![Part {
|
||||
parts: vec![Part::Text {
|
||||
text: system_parts.join("\n\n"),
|
||||
}],
|
||||
})
|
||||
@ -1562,7 +1622,7 @@ mod tests {
|
||||
let body = GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some("user".into()),
|
||||
parts: vec![Part {
|
||||
parts: vec![Part::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
}],
|
||||
@ -1603,7 +1663,7 @@ mod tests {
|
||||
let body = GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some("user".into()),
|
||||
parts: vec![Part {
|
||||
parts: vec![Part::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
}],
|
||||
@ -1647,7 +1707,7 @@ mod tests {
|
||||
let body = GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some("user".into()),
|
||||
parts: vec![Part {
|
||||
parts: vec![Part::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
}],
|
||||
@ -1679,13 +1739,13 @@ mod tests {
|
||||
let request = GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
parts: vec![Part::Text {
|
||||
text: "Hello".to_string(),
|
||||
}],
|
||||
}],
|
||||
system_instruction: Some(Content {
|
||||
role: None,
|
||||
parts: vec![Part {
|
||||
parts: vec![Part::Text {
|
||||
text: "You are helpful".to_string(),
|
||||
}],
|
||||
}),
|
||||
@ -1704,6 +1764,74 @@ mod tests {
|
||||
assert!(json.contains("\"maxOutputTokens\":8192"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_user_parts_text_only_is_backward_compatible() {
|
||||
let content = "Plain text message without image markers.";
|
||||
let parts = GeminiProvider::build_user_parts(content);
|
||||
assert_eq!(parts.len(), 1);
|
||||
match &parts[0] {
|
||||
Part::Text { text } => assert_eq!(text, content),
|
||||
Part::InlineData { .. } => panic!("text-only message must stay text-only"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_user_parts_single_image() {
|
||||
let parts = GeminiProvider::build_user_parts(
|
||||
"Describe this image [IMAGE:data:image/png;base64,aGVsbG8=]",
|
||||
);
|
||||
assert_eq!(parts.len(), 2);
|
||||
match &parts[0] {
|
||||
Part::Text { text } => assert_eq!(text, "Describe this image"),
|
||||
Part::InlineData { .. } => panic!("first part should be text"),
|
||||
}
|
||||
match &parts[1] {
|
||||
Part::InlineData { inline_data } => {
|
||||
assert_eq!(inline_data.mime_type, "image/png");
|
||||
assert_eq!(inline_data.data, "aGVsbG8=");
|
||||
}
|
||||
Part::Text { .. } => panic!("second part should be inline image data"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_user_parts_multiple_images() {
|
||||
let parts = GeminiProvider::build_user_parts(
|
||||
"Compare [IMAGE:data:image/png;base64,aQ==] and [IMAGE:data:image/jpeg;base64,ag==]",
|
||||
);
|
||||
assert_eq!(parts.len(), 3);
|
||||
assert!(matches!(parts[0], Part::Text { .. }));
|
||||
assert!(matches!(parts[1], Part::InlineData { .. }));
|
||||
assert!(matches!(parts[2], Part::InlineData { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_user_parts_image_only() {
|
||||
let parts = GeminiProvider::build_user_parts("[IMAGE:data:image/webp;base64,YWJjZA==]");
|
||||
assert_eq!(parts.len(), 1);
|
||||
match &parts[0] {
|
||||
Part::InlineData { inline_data } => {
|
||||
assert_eq!(inline_data.mime_type, "image/webp");
|
||||
assert_eq!(inline_data.data, "YWJjZA==");
|
||||
}
|
||||
Part::Text { .. } => panic!("image-only message should create inline image part"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_user_parts_fallback_for_non_data_uri_markers() {
|
||||
let parts = GeminiProvider::build_user_parts("Inspect [IMAGE:https://example.com/img.png]");
|
||||
assert_eq!(parts.len(), 2);
|
||||
match &parts[0] {
|
||||
Part::Text { text } => assert_eq!(text, "Inspect"),
|
||||
Part::InlineData { .. } => panic!("first part should be text"),
|
||||
}
|
||||
match &parts[1] {
|
||||
Part::Text { text } => assert_eq!(text, "[IMAGE:https://example.com/img.png]"),
|
||||
Part::InlineData { .. } => panic!("invalid markers should fall back to text"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn internal_request_includes_model() {
|
||||
let request = InternalGenerateContentEnvelope {
|
||||
@ -1713,7 +1841,7 @@ mod tests {
|
||||
request: InternalGenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
parts: vec![Part::Text {
|
||||
text: "Hello".to_string(),
|
||||
}],
|
||||
}],
|
||||
@ -1745,7 +1873,7 @@ mod tests {
|
||||
request: InternalGenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
parts: vec![Part::Text {
|
||||
text: "Hello".to_string(),
|
||||
}],
|
||||
}],
|
||||
@ -1768,7 +1896,7 @@ mod tests {
|
||||
request: InternalGenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
parts: vec![Part::Text {
|
||||
text: "Hello".to_string(),
|
||||
}],
|
||||
}],
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
use parking_lot::Mutex;
|
||||
use reqwest::Url;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
@ -47,6 +48,24 @@ pub enum ToolOperation {
|
||||
Act,
|
||||
}
|
||||
|
||||
/// Action applied when a command context rule matches.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CommandContextRuleAction {
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
|
||||
/// Context-aware allow/deny rule for shell commands.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CommandContextRule {
|
||||
pub command: String,
|
||||
pub action: CommandContextRuleAction,
|
||||
pub allowed_domains: Vec<String>,
|
||||
pub allowed_path_prefixes: Vec<String>,
|
||||
pub denied_path_prefixes: Vec<String>,
|
||||
pub allow_high_risk: bool,
|
||||
}
|
||||
|
||||
/// Sliding-window action tracker for rate limiting.
|
||||
#[derive(Debug)]
|
||||
pub struct ActionTracker {
|
||||
@ -99,6 +118,7 @@ pub struct SecurityPolicy {
|
||||
pub workspace_dir: PathBuf,
|
||||
pub workspace_only: bool,
|
||||
pub allowed_commands: Vec<String>,
|
||||
pub command_context_rules: Vec<CommandContextRule>,
|
||||
pub forbidden_paths: Vec<String>,
|
||||
pub allowed_roots: Vec<PathBuf>,
|
||||
pub max_actions_per_hour: u32,
|
||||
@ -132,6 +152,7 @@ impl Default for SecurityPolicy {
|
||||
"tail".into(),
|
||||
"date".into(),
|
||||
],
|
||||
command_context_rules: Vec::new(),
|
||||
forbidden_paths: vec![
|
||||
// System directories (blocked even when workspace_only=false)
|
||||
"/etc".into(),
|
||||
@ -565,7 +586,366 @@ fn is_allowlist_entry_match(allowed: &str, executable: &str, executable_base: &s
|
||||
allowed == executable_base
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum SegmentRuleDecision {
|
||||
NoMatch,
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
struct SegmentRuleOutcome {
|
||||
decision: SegmentRuleDecision,
|
||||
allow_high_risk: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
struct CommandAllowlistEvaluation {
|
||||
high_risk_overridden: bool,
|
||||
}
|
||||
|
||||
fn is_high_risk_base_command(base: &str) -> bool {
|
||||
matches!(
|
||||
base,
|
||||
"rm" | "mkfs"
|
||||
| "dd"
|
||||
| "shutdown"
|
||||
| "reboot"
|
||||
| "halt"
|
||||
| "poweroff"
|
||||
| "sudo"
|
||||
| "su"
|
||||
| "chown"
|
||||
| "chmod"
|
||||
| "useradd"
|
||||
| "userdel"
|
||||
| "usermod"
|
||||
| "passwd"
|
||||
| "mount"
|
||||
| "umount"
|
||||
| "iptables"
|
||||
| "ufw"
|
||||
| "firewall-cmd"
|
||||
| "curl"
|
||||
| "wget"
|
||||
| "nc"
|
||||
| "ncat"
|
||||
| "netcat"
|
||||
| "scp"
|
||||
| "ssh"
|
||||
| "ftp"
|
||||
| "telnet"
|
||||
)
|
||||
}
|
||||
|
||||
impl SecurityPolicy {
|
||||
fn path_matches_rule_prefix(&self, candidate: &str, prefix: &str) -> bool {
|
||||
let candidate_path = expand_user_path(candidate);
|
||||
let prefix_path = expand_user_path(prefix);
|
||||
|
||||
let normalized_candidate = if candidate_path.is_absolute() {
|
||||
candidate_path
|
||||
} else {
|
||||
self.workspace_dir.join(candidate_path)
|
||||
};
|
||||
let normalized_prefix = if prefix_path.is_absolute() {
|
||||
prefix_path
|
||||
} else {
|
||||
self.workspace_dir.join(prefix_path)
|
||||
};
|
||||
|
||||
normalized_candidate.starts_with(&normalized_prefix)
|
||||
}
|
||||
|
||||
fn host_matches_pattern(host: &str, pattern: &str) -> bool {
|
||||
let host = host.trim().to_ascii_lowercase();
|
||||
let pattern = pattern.trim().to_ascii_lowercase();
|
||||
if host.is_empty() || pattern.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(suffix) = pattern.strip_prefix("*.") {
|
||||
host == suffix || host.ends_with(&format!(".{suffix}"))
|
||||
} else {
|
||||
host == pattern
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_segment_url_hosts(args: &[&str]) -> Vec<String> {
|
||||
args.iter()
|
||||
.filter_map(|token| {
|
||||
let candidate = strip_wrapping_quotes(token)
|
||||
.trim()
|
||||
.trim_matches(|c: char| matches!(c, ',' | ';'));
|
||||
if candidate.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Url::parse(candidate)
|
||||
.ok()
|
||||
.and_then(|url| url.host_str().map(|host| host.to_ascii_lowercase()))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn extract_segment_path_args(args: &[&str]) -> Vec<String> {
|
||||
let mut paths = Vec::new();
|
||||
|
||||
for token in args {
|
||||
let candidate = strip_wrapping_quotes(token).trim();
|
||||
if candidate.is_empty() || candidate.contains("://") {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(target) = redirection_target(candidate) {
|
||||
let normalized = strip_wrapping_quotes(target).trim();
|
||||
if !normalized.is_empty() && looks_like_path(normalized) {
|
||||
paths.push(normalized.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if candidate.starts_with('-') {
|
||||
if let Some((_, value)) = candidate.split_once('=') {
|
||||
let normalized = strip_wrapping_quotes(value).trim();
|
||||
if !normalized.is_empty()
|
||||
&& !normalized.contains("://")
|
||||
&& looks_like_path(normalized)
|
||||
{
|
||||
paths.push(normalized.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(value) = attached_short_option_value(candidate) {
|
||||
let normalized = strip_wrapping_quotes(value).trim();
|
||||
if !normalized.is_empty()
|
||||
&& !normalized.contains("://")
|
||||
&& looks_like_path(normalized)
|
||||
{
|
||||
paths.push(normalized.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
if looks_like_path(candidate) {
|
||||
paths.push(candidate.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
paths
|
||||
}
|
||||
|
||||
fn rule_conditions_match(&self, rule: &CommandContextRule, args: &[&str]) -> bool {
|
||||
if !rule.allowed_domains.is_empty() {
|
||||
let hosts = Self::extract_segment_url_hosts(args);
|
||||
if hosts.is_empty() {
|
||||
return false;
|
||||
}
|
||||
if !hosts.iter().all(|host| {
|
||||
rule.allowed_domains
|
||||
.iter()
|
||||
.any(|pattern| Self::host_matches_pattern(host, pattern))
|
||||
}) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
let path_args =
|
||||
if rule.allowed_path_prefixes.is_empty() && rule.denied_path_prefixes.is_empty() {
|
||||
Vec::new()
|
||||
} else {
|
||||
Self::extract_segment_path_args(args)
|
||||
};
|
||||
|
||||
if !rule.allowed_path_prefixes.is_empty() {
|
||||
if path_args.is_empty() {
|
||||
return false;
|
||||
}
|
||||
if !path_args.iter().all(|path| {
|
||||
rule.allowed_path_prefixes
|
||||
.iter()
|
||||
.any(|prefix| self.path_matches_rule_prefix(path, prefix))
|
||||
}) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if !rule.denied_path_prefixes.is_empty() {
|
||||
if path_args.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let has_denied_path = path_args.iter().any(|path| {
|
||||
rule.denied_path_prefixes
|
||||
.iter()
|
||||
.any(|prefix| self.path_matches_rule_prefix(path, prefix))
|
||||
});
|
||||
match rule.action {
|
||||
CommandContextRuleAction::Allow => {
|
||||
if has_denied_path {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
CommandContextRuleAction::Deny => {
|
||||
if !has_denied_path {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn evaluate_segment_context_rules(
|
||||
&self,
|
||||
executable: &str,
|
||||
base_cmd: &str,
|
||||
args: &[&str],
|
||||
) -> SegmentRuleOutcome {
|
||||
let mut has_allow_rules = false;
|
||||
let mut allow_match = false;
|
||||
let mut allow_high_risk = false;
|
||||
|
||||
for rule in &self.command_context_rules {
|
||||
if !is_allowlist_entry_match(&rule.command, executable, base_cmd) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if matches!(rule.action, CommandContextRuleAction::Allow) {
|
||||
has_allow_rules = true;
|
||||
}
|
||||
|
||||
if !self.rule_conditions_match(rule, args) {
|
||||
continue;
|
||||
}
|
||||
|
||||
match rule.action {
|
||||
CommandContextRuleAction::Deny => {
|
||||
return SegmentRuleOutcome {
|
||||
decision: SegmentRuleDecision::Deny,
|
||||
allow_high_risk: false,
|
||||
};
|
||||
}
|
||||
CommandContextRuleAction::Allow => {
|
||||
allow_match = true;
|
||||
allow_high_risk |= rule.allow_high_risk;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if has_allow_rules {
|
||||
if allow_match {
|
||||
SegmentRuleOutcome {
|
||||
decision: SegmentRuleDecision::Allow,
|
||||
allow_high_risk,
|
||||
}
|
||||
} else {
|
||||
SegmentRuleOutcome {
|
||||
decision: SegmentRuleDecision::Deny,
|
||||
allow_high_risk: false,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
SegmentRuleOutcome {
|
||||
decision: SegmentRuleDecision::NoMatch,
|
||||
allow_high_risk: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_command_allowlist(
|
||||
&self,
|
||||
command: &str,
|
||||
) -> Result<CommandAllowlistEvaluation, String> {
|
||||
if self.autonomy == AutonomyLevel::ReadOnly {
|
||||
return Err("readonly autonomy level blocks shell command execution".into());
|
||||
}
|
||||
|
||||
if command.contains('`')
|
||||
|| contains_unquoted_shell_variable_expansion(command)
|
||||
|| command.contains("<(")
|
||||
|| command.contains(">(")
|
||||
{
|
||||
return Err("command contains disallowed shell expansion syntax".into());
|
||||
}
|
||||
|
||||
if contains_unquoted_char(command, '>') || contains_unquoted_char(command, '<') {
|
||||
return Err("command contains disallowed redirection syntax".into());
|
||||
}
|
||||
|
||||
if command
|
||||
.split_whitespace()
|
||||
.any(|w| w == "tee" || w.ends_with("/tee"))
|
||||
{
|
||||
return Err("command contains disallowed tee usage".into());
|
||||
}
|
||||
|
||||
if contains_unquoted_single_ampersand(command) {
|
||||
return Err("command contains disallowed background chaining operator '&'".into());
|
||||
}
|
||||
|
||||
let segments = split_unquoted_segments(command);
|
||||
let mut has_cmd = false;
|
||||
let mut saw_high_risk_segment = false;
|
||||
let mut all_high_risk_segments_overridden = true;
|
||||
|
||||
for segment in &segments {
|
||||
let cmd_part = skip_env_assignments(segment);
|
||||
let mut words = cmd_part.split_whitespace();
|
||||
let executable = strip_wrapping_quotes(words.next().unwrap_or("")).trim();
|
||||
let base_cmd = executable.rsplit('/').next().unwrap_or("").trim();
|
||||
|
||||
if base_cmd.is_empty() {
|
||||
continue;
|
||||
}
|
||||
has_cmd = true;
|
||||
|
||||
let args_raw: Vec<&str> = words.collect();
|
||||
let args_lower: Vec<String> = args_raw.iter().map(|w| w.to_ascii_lowercase()).collect();
|
||||
|
||||
let context_outcome =
|
||||
self.evaluate_segment_context_rules(executable, base_cmd, &args_raw);
|
||||
if context_outcome.decision == SegmentRuleDecision::Deny {
|
||||
return Err(format!("context rule denied command segment `{base_cmd}`"));
|
||||
}
|
||||
|
||||
if context_outcome.decision != SegmentRuleDecision::Allow
|
||||
&& !self
|
||||
.allowed_commands
|
||||
.iter()
|
||||
.any(|allowed| is_allowlist_entry_match(allowed, executable, base_cmd))
|
||||
{
|
||||
return Err(format!(
|
||||
"command segment `{base_cmd}` is not present in allowed_commands"
|
||||
));
|
||||
}
|
||||
|
||||
if !self.is_args_safe(base_cmd, &args_lower) {
|
||||
return Err(format!(
|
||||
"command segment `{base_cmd}` contains unsafe arguments"
|
||||
));
|
||||
}
|
||||
|
||||
let base_lower = base_cmd.to_ascii_lowercase();
|
||||
if is_high_risk_base_command(&base_lower) {
|
||||
saw_high_risk_segment = true;
|
||||
if !(context_outcome.decision == SegmentRuleDecision::Allow
|
||||
&& context_outcome.allow_high_risk)
|
||||
{
|
||||
all_high_risk_segments_overridden = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !has_cmd {
|
||||
return Err("command is empty after parsing".into());
|
||||
}
|
||||
|
||||
Ok(CommandAllowlistEvaluation {
|
||||
high_risk_overridden: saw_high_risk_segment && all_high_risk_segments_overridden,
|
||||
})
|
||||
}
|
||||
|
||||
// ── Risk Classification ──────────────────────────────────────────────
|
||||
// Risk is assessed per-segment (split on shell operators), and the
|
||||
// highest risk across all segments wins. This prevents bypasses like
|
||||
@ -592,37 +972,7 @@ impl SecurityPolicy {
|
||||
let joined_segment = cmd_part.to_ascii_lowercase();
|
||||
|
||||
// High-risk commands
|
||||
if matches!(
|
||||
base.as_str(),
|
||||
"rm" | "mkfs"
|
||||
| "dd"
|
||||
| "shutdown"
|
||||
| "reboot"
|
||||
| "halt"
|
||||
| "poweroff"
|
||||
| "sudo"
|
||||
| "su"
|
||||
| "chown"
|
||||
| "chmod"
|
||||
| "useradd"
|
||||
| "userdel"
|
||||
| "usermod"
|
||||
| "passwd"
|
||||
| "mount"
|
||||
| "umount"
|
||||
| "iptables"
|
||||
| "ufw"
|
||||
| "firewall-cmd"
|
||||
| "curl"
|
||||
| "wget"
|
||||
| "nc"
|
||||
| "ncat"
|
||||
| "netcat"
|
||||
| "scp"
|
||||
| "ssh"
|
||||
| "ftp"
|
||||
| "telnet"
|
||||
) {
|
||||
if is_high_risk_base_command(base.as_str()) {
|
||||
return CommandRiskLevel::High;
|
||||
}
|
||||
|
||||
@ -693,9 +1043,9 @@ impl SecurityPolicy {
|
||||
command: &str,
|
||||
approved: bool,
|
||||
) -> Result<CommandRiskLevel, String> {
|
||||
if !self.is_command_allowed(command) {
|
||||
return Err(format!("Command not allowed by security policy: {command}"));
|
||||
}
|
||||
let allowlist_eval = self
|
||||
.evaluate_command_allowlist(command)
|
||||
.map_err(|reason| format!("Command not allowed by security policy: {reason}"))?;
|
||||
|
||||
if let Some(path) = self.forbidden_path_argument(command) {
|
||||
return Err(format!("Path blocked by security policy: {path}"));
|
||||
@ -704,7 +1054,7 @@ impl SecurityPolicy {
|
||||
let risk = self.command_risk_level(command);
|
||||
|
||||
if risk == CommandRiskLevel::High {
|
||||
if self.block_high_risk_commands {
|
||||
if self.block_high_risk_commands && !allowlist_eval.high_risk_overridden {
|
||||
let lower = command.to_ascii_lowercase();
|
||||
if lower.contains("curl") || lower.contains("wget") {
|
||||
return Err(
|
||||
@ -750,81 +1100,7 @@ impl SecurityPolicy {
|
||||
/// - Blocks shell redirections (`<`, `>`, `>>`) that can bypass path policy
|
||||
/// - Blocks dangerous arguments (e.g. `find -exec`, `git config`)
|
||||
pub fn is_command_allowed(&self, command: &str) -> bool {
|
||||
if self.autonomy == AutonomyLevel::ReadOnly {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Block subshell/expansion operators — these allow hiding arbitrary
|
||||
// commands inside an allowed command (e.g. `echo $(rm -rf /)`) and
|
||||
// bypassing path checks through variable indirection. The helper below
|
||||
// ignores escapes and literals inside single quotes, so `$(` or `${`
|
||||
// literals are permitted there.
|
||||
if command.contains('`')
|
||||
|| contains_unquoted_shell_variable_expansion(command)
|
||||
|| command.contains("<(")
|
||||
|| command.contains(">(")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Block shell redirections (`<`, `>`, `>>`) — they can read/write
|
||||
// arbitrary paths and bypass path checks.
|
||||
// Ignore quoted literals, e.g. `echo "a>b"` and `echo "a<b"`.
|
||||
if contains_unquoted_char(command, '>') || contains_unquoted_char(command, '<') {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Block `tee` — it can write to arbitrary files, bypassing the
|
||||
// redirect check above (e.g. `echo secret | tee /etc/crontab`)
|
||||
if command
|
||||
.split_whitespace()
|
||||
.any(|w| w == "tee" || w.ends_with("/tee"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Block background command chaining (`&`), which can hide extra
|
||||
// sub-commands and outlive timeout expectations. Keep `&&` allowed.
|
||||
if contains_unquoted_single_ampersand(command) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Split on unquoted command separators and validate each sub-command.
|
||||
let segments = split_unquoted_segments(command);
|
||||
for segment in &segments {
|
||||
// Strip leading env var assignments (e.g. FOO=bar cmd)
|
||||
let cmd_part = skip_env_assignments(segment);
|
||||
|
||||
let mut words = cmd_part.split_whitespace();
|
||||
let executable = strip_wrapping_quotes(words.next().unwrap_or("")).trim();
|
||||
let base_cmd = executable.rsplit('/').next().unwrap_or("");
|
||||
|
||||
if base_cmd.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if !self
|
||||
.allowed_commands
|
||||
.iter()
|
||||
.any(|allowed| is_allowlist_entry_match(allowed, executable, base_cmd))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate arguments for the command
|
||||
let args: Vec<String> = words.map(|w| w.to_ascii_lowercase()).collect();
|
||||
if !self.is_args_safe(base_cmd, &args) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// At least one command must be present
|
||||
let has_cmd = segments.iter().any(|s| {
|
||||
let s = skip_env_assignments(s.trim());
|
||||
s.split_whitespace().next().is_some_and(|w| !w.is_empty())
|
||||
});
|
||||
|
||||
has_cmd
|
||||
self.evaluate_command_allowlist(command).is_ok()
|
||||
}
|
||||
|
||||
/// Check for dangerous arguments that allow sub-command execution.
|
||||
@ -1214,6 +1490,11 @@ impl SecurityPolicy {
|
||||
format!("{} (others rejected)", shown.join(", "))
|
||||
}
|
||||
};
|
||||
let context_rules = if self.command_context_rules.is_empty() {
|
||||
"none".to_string()
|
||||
} else {
|
||||
format!("{} configured", self.command_context_rules.len())
|
||||
};
|
||||
|
||||
let high_risk = if self.block_high_risk_commands {
|
||||
"blocked"
|
||||
@ -1226,6 +1507,7 @@ impl SecurityPolicy {
|
||||
- Workspace: {workspace} (workspace_only: {ws_only})\n\
|
||||
- Forbidden paths: {forbidden_preview}\n\
|
||||
- Allowed commands: {commands_preview}\n\
|
||||
- Command context rules: {context_rules}\n\
|
||||
- High-risk commands: {high_risk}\n\
|
||||
- Do not exfiltrate data, bypass approval, or run destructive commands without asking."
|
||||
)
|
||||
@ -1240,6 +1522,25 @@ impl SecurityPolicy {
|
||||
workspace_dir: workspace_dir.to_path_buf(),
|
||||
workspace_only: autonomy_config.workspace_only,
|
||||
allowed_commands: autonomy_config.allowed_commands.clone(),
|
||||
command_context_rules: autonomy_config
|
||||
.command_context_rules
|
||||
.iter()
|
||||
.map(|rule| CommandContextRule {
|
||||
command: rule.command.clone(),
|
||||
action: match rule.action {
|
||||
crate::config::CommandContextRuleAction::Allow => {
|
||||
CommandContextRuleAction::Allow
|
||||
}
|
||||
crate::config::CommandContextRuleAction::Deny => {
|
||||
CommandContextRuleAction::Deny
|
||||
}
|
||||
},
|
||||
allowed_domains: rule.allowed_domains.clone(),
|
||||
allowed_path_prefixes: rule.allowed_path_prefixes.clone(),
|
||||
denied_path_prefixes: rule.denied_path_prefixes.clone(),
|
||||
allow_high_risk: rule.allow_high_risk,
|
||||
})
|
||||
.collect(),
|
||||
forbidden_paths: autonomy_config.forbidden_paths.clone(),
|
||||
allowed_roots: autonomy_config
|
||||
.allowed_roots
|
||||
@ -1461,6 +1762,102 @@ mod tests {
|
||||
assert!(!p.is_command_allowed("echo hello"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_allow_rule_overrides_global_allowlist_for_curl_domain() {
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec![],
|
||||
command_context_rules: vec![CommandContextRule {
|
||||
command: "curl".into(),
|
||||
action: CommandContextRuleAction::Allow,
|
||||
allowed_domains: vec!["api.example.com".into()],
|
||||
allowed_path_prefixes: vec![],
|
||||
denied_path_prefixes: vec![],
|
||||
allow_high_risk: true,
|
||||
}],
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
assert!(p.is_command_allowed("curl https://api.example.com/v1/health"));
|
||||
assert!(p
|
||||
.validate_command_execution("curl https://api.example.com/v1/health", true)
|
||||
.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_allow_rule_restricts_curl_to_matching_domains() {
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
allowed_commands: vec!["curl".into()],
|
||||
command_context_rules: vec![CommandContextRule {
|
||||
command: "curl".into(),
|
||||
action: CommandContextRuleAction::Allow,
|
||||
allowed_domains: vec!["api.example.com".into()],
|
||||
allowed_path_prefixes: vec![],
|
||||
denied_path_prefixes: vec![],
|
||||
allow_high_risk: true,
|
||||
}],
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
assert!(!p.is_command_allowed("curl https://evil.example.com/steal"));
|
||||
let err = p
|
||||
.validate_command_execution("curl https://evil.example.com/steal", true)
|
||||
.expect_err("non-matching domains should be denied by context rules");
|
||||
assert!(err.contains("context rule denied"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_allow_rule_restricts_rm_to_allowed_path_prefix() {
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
workspace_only: false,
|
||||
allowed_commands: vec!["rm".into()],
|
||||
forbidden_paths: vec![],
|
||||
command_context_rules: vec![CommandContextRule {
|
||||
command: "rm".into(),
|
||||
action: CommandContextRuleAction::Allow,
|
||||
allowed_domains: vec![],
|
||||
allowed_path_prefixes: vec!["/tmp".into()],
|
||||
denied_path_prefixes: vec![],
|
||||
allow_high_risk: true,
|
||||
}],
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
assert!(p.is_command_allowed("rm -rf /tmp/cleanup"));
|
||||
assert!(p
|
||||
.validate_command_execution("rm -rf /tmp/cleanup", true)
|
||||
.is_ok());
|
||||
|
||||
assert!(!p.is_command_allowed("rm -rf /var/log"));
|
||||
let err = p
|
||||
.validate_command_execution("rm -rf /var/log", true)
|
||||
.expect_err("paths outside /tmp should be denied");
|
||||
assert!(err.contains("context rule denied"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn context_deny_rule_can_block_specific_domain_even_when_allowlisted() {
|
||||
let p = SecurityPolicy {
|
||||
autonomy: AutonomyLevel::Full,
|
||||
block_high_risk_commands: false,
|
||||
allowed_commands: vec!["curl".into()],
|
||||
command_context_rules: vec![CommandContextRule {
|
||||
command: "curl".into(),
|
||||
action: CommandContextRuleAction::Deny,
|
||||
allowed_domains: vec!["evil.example.com".into()],
|
||||
allowed_path_prefixes: vec![],
|
||||
denied_path_prefixes: vec![],
|
||||
allow_high_risk: false,
|
||||
}],
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
|
||||
assert!(p.is_command_allowed("curl https://api.example.com/v1/health"));
|
||||
assert!(!p.is_command_allowed("curl https://evil.example.com/steal"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_risk_low_for_read_commands() {
|
||||
let p = default_policy();
|
||||
|
||||
@ -9,10 +9,12 @@ use std::sync::Arc;
|
||||
|
||||
/// Edit a file by replacing an exact string match with new content.
|
||||
///
|
||||
/// Uses `old_string` → `new_string` precise replacement within the workspace.
|
||||
/// The `old_string` must appear exactly once in the file (zero matches = not
|
||||
/// found, multiple matches = ambiguous). `new_string` may be empty to delete
|
||||
/// the matched text. Security checks mirror [`super::file_write::FileWriteTool`].
|
||||
/// Uses `old_string` → `new_string` replacement within the workspace.
|
||||
/// Exact matching is preferred and unchanged. When exact matching finds zero
|
||||
/// matches, the tool falls back to whitespace-flexible line matching.
|
||||
/// The final match must still be unique (zero matches = not found, multiple
|
||||
/// matches = ambiguous). `new_string` may be empty to delete the matched text.
|
||||
/// Security checks mirror [`super::file_write::FileWriteTool`].
|
||||
pub struct FileEditTool {
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
@ -38,6 +40,169 @@ fn hard_link_edit_block_message(path: &Path) -> String {
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct LineSpan {
|
||||
start: usize,
|
||||
content_end: usize,
|
||||
end: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct MatchOutcome {
|
||||
start: usize,
|
||||
end: usize,
|
||||
used_whitespace_flex: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum FlexibleLineMatch {
|
||||
NoMatch,
|
||||
Unique { start: usize, end: usize },
|
||||
Ambiguous { count: usize },
|
||||
}
|
||||
|
||||
fn normalize_line(line: &str) -> String {
|
||||
let trimmed = line.trim_end_matches([' ', '\t']);
|
||||
let mut normalized = String::with_capacity(trimmed.len());
|
||||
let mut in_whitespace_run = false;
|
||||
|
||||
for ch in trimmed.chars() {
|
||||
if ch == ' ' || ch == '\t' {
|
||||
if !in_whitespace_run {
|
||||
normalized.push(' ');
|
||||
in_whitespace_run = true;
|
||||
}
|
||||
} else {
|
||||
normalized.push(ch);
|
||||
in_whitespace_run = false;
|
||||
}
|
||||
}
|
||||
|
||||
normalized
|
||||
}
|
||||
|
||||
fn compute_line_spans(content: &str) -> Vec<LineSpan> {
|
||||
let mut spans = Vec::new();
|
||||
let bytes = content.as_bytes();
|
||||
let mut line_start = 0usize;
|
||||
|
||||
for (idx, byte) in bytes.iter().enumerate() {
|
||||
if *byte == b'\n' {
|
||||
let mut content_end = idx;
|
||||
if content_end > line_start && bytes[content_end - 1] == b'\r' {
|
||||
content_end -= 1;
|
||||
}
|
||||
spans.push(LineSpan {
|
||||
start: line_start,
|
||||
content_end,
|
||||
end: idx + 1,
|
||||
});
|
||||
line_start = idx + 1;
|
||||
}
|
||||
}
|
||||
|
||||
if line_start < content.len() {
|
||||
spans.push(LineSpan {
|
||||
start: line_start,
|
||||
content_end: content.len(),
|
||||
end: content.len(),
|
||||
});
|
||||
}
|
||||
|
||||
spans
|
||||
}
|
||||
|
||||
fn try_flexible_line_match(content: &str, old_string: &str) -> FlexibleLineMatch {
|
||||
let content_spans = compute_line_spans(content);
|
||||
let old_spans = compute_line_spans(old_string);
|
||||
|
||||
if old_spans.is_empty() || content_spans.len() < old_spans.len() {
|
||||
return FlexibleLineMatch::NoMatch;
|
||||
}
|
||||
|
||||
let normalized_old_lines: Vec<String> = old_spans
|
||||
.iter()
|
||||
.map(|span| normalize_line(&old_string[span.start..span.content_end]))
|
||||
.collect();
|
||||
let normalized_content_lines: Vec<String> = content_spans
|
||||
.iter()
|
||||
.map(|span| normalize_line(&content[span.start..span.content_end]))
|
||||
.collect();
|
||||
|
||||
let mut match_count = 0usize;
|
||||
let mut matched_start_line = 0usize;
|
||||
let window_size = old_spans.len();
|
||||
|
||||
for start_line in 0..=(content_spans.len() - window_size) {
|
||||
let mut window_matches = true;
|
||||
for line_offset in 0..window_size {
|
||||
if normalized_content_lines[start_line + line_offset]
|
||||
!= normalized_old_lines[line_offset]
|
||||
{
|
||||
window_matches = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if window_matches {
|
||||
match_count += 1;
|
||||
if match_count == 1 {
|
||||
matched_start_line = start_line;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if match_count == 0 {
|
||||
return FlexibleLineMatch::NoMatch;
|
||||
}
|
||||
|
||||
if match_count > 1 {
|
||||
return FlexibleLineMatch::Ambiguous { count: match_count };
|
||||
}
|
||||
|
||||
let first_span = content_spans[matched_start_line];
|
||||
let last_span = content_spans[matched_start_line + window_size - 1];
|
||||
let end = if old_string.ends_with('\n') {
|
||||
last_span.end
|
||||
} else {
|
||||
last_span.content_end
|
||||
};
|
||||
|
||||
FlexibleLineMatch::Unique {
|
||||
start: first_span.start,
|
||||
end,
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_match(content: &str, old_string: &str) -> Result<MatchOutcome, String> {
|
||||
let mut exact_matches = content.match_indices(old_string);
|
||||
if let Some((start, _)) = exact_matches.next() {
|
||||
if exact_matches.next().is_some() {
|
||||
let match_count = 2 + exact_matches.count();
|
||||
return Err(format!(
|
||||
"old_string matches {match_count} times; must match exactly once"
|
||||
));
|
||||
}
|
||||
return Ok(MatchOutcome {
|
||||
start,
|
||||
end: start + old_string.len(),
|
||||
used_whitespace_flex: false,
|
||||
});
|
||||
}
|
||||
|
||||
match try_flexible_line_match(content, old_string) {
|
||||
FlexibleLineMatch::NoMatch => Err("old_string not found in file".into()),
|
||||
FlexibleLineMatch::Ambiguous { count } => Err(format!(
|
||||
"old_string matches {count} times with whitespace flexibility; must match exactly once"
|
||||
)),
|
||||
FlexibleLineMatch::Unique { start, end } => Ok(MatchOutcome {
|
||||
start,
|
||||
end,
|
||||
used_whitespace_flex: true,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for FileEditTool {
|
||||
fn name(&self) -> &str {
|
||||
@ -45,7 +210,7 @@ impl Tool for FileEditTool {
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Edit a file by replacing an exact string match with new content. Sensitive files (for example .env and key material) are blocked by default."
|
||||
"Edit a file by replacing text in a file. Exact matching is preferred; if exact matching fails, whitespace-flexible line matching is used. Sensitive files (for example .env and key material) are blocked by default."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
@ -58,7 +223,7 @@ impl Tool for FileEditTool {
|
||||
},
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
"description": "The exact text to find and replace (must appear exactly once in the file)"
|
||||
"description": "The text to find and replace. Exact matching is attempted first; if no exact match is found, whitespace-flexible line matching is attempted."
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
@ -226,34 +391,43 @@ impl Tool for FileEditTool {
|
||||
}
|
||||
};
|
||||
|
||||
let match_count = content.matches(old_string).count();
|
||||
let match_outcome = match resolve_match(&content, old_string) {
|
||||
Ok(outcome) => outcome,
|
||||
Err(error) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if match_count == 0 {
|
||||
if match_outcome.end < match_outcome.start || match_outcome.end > content.len() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("old_string not found in file".into()),
|
||||
error: Some("Internal matching error: invalid replacement range".into()),
|
||||
});
|
||||
}
|
||||
|
||||
if match_count > 1 {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"old_string matches {match_count} times; must match exactly once"
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
let new_content = content.replacen(old_string, new_string, 1);
|
||||
let mut new_content = String::with_capacity(
|
||||
content.len() - (match_outcome.end - match_outcome.start) + new_string.len(),
|
||||
);
|
||||
new_content.push_str(&content[..match_outcome.start]);
|
||||
new_content.push_str(new_string);
|
||||
new_content.push_str(&content[match_outcome.end..]);
|
||||
|
||||
match tokio::fs::write(&resolved_target, &new_content).await {
|
||||
Ok(()) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"Edited {path}: replaced 1 occurrence ({} bytes)",
|
||||
new_content.len()
|
||||
"Edited {path}: replaced 1 occurrence ({} bytes){}",
|
||||
new_content.len(),
|
||||
if match_outcome.used_whitespace_flex {
|
||||
" (matched with whitespace flexibility)"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
),
|
||||
error: None,
|
||||
}),
|
||||
@ -384,6 +558,272 @@ mod tests {
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_edit_flexible_match_indentation_difference() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_edit_flex_indent");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
tokio::fs::write(
|
||||
dir.join("test.txt"),
|
||||
"fn main() {\n println!(\"hi\");\n}\n",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = FileEditTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"path": "test.txt",
|
||||
"old_string": "fn main() {\n println!(\"hi\");\n}\n",
|
||||
"new_string": "fn main() {\n println!(\"hello\");\n}\n"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
result.success,
|
||||
"flexible indentation match should succeed: {:?}",
|
||||
result.error
|
||||
);
|
||||
assert!(result.output.contains("whitespace flexibility"));
|
||||
|
||||
let content = tokio::fs::read_to_string(dir.join("test.txt"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(content, "fn main() {\n println!(\"hello\");\n}\n");
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_edit_flexible_match_tab_space_difference() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_edit_flex_tabs");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
tokio::fs::write(dir.join("test.txt"), "alpha\n\tbeta\ngamma\n")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = FileEditTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"path": "test.txt",
|
||||
"old_string": "alpha\n beta\ngamma\n",
|
||||
"new_string": "alpha\n\tdelta\ngamma\n"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success, "tab/space flex match should succeed");
|
||||
assert!(result.output.contains("whitespace flexibility"));
|
||||
|
||||
let content = tokio::fs::read_to_string(dir.join("test.txt"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(content, "alpha\n\tdelta\ngamma\n");
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_edit_flexible_match_trailing_whitespace_difference() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_edit_flex_trailing");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
tokio::fs::write(dir.join("test.txt"), "line one \nline two\t\t\n")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = FileEditTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"path": "test.txt",
|
||||
"old_string": "line one\nline two\n",
|
||||
"new_string": "line one\nline 2\n"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
result.success,
|
||||
"trailing whitespace flex match should succeed"
|
||||
);
|
||||
assert!(result.output.contains("whitespace flexibility"));
|
||||
|
||||
let content = tokio::fs::read_to_string(dir.join("test.txt"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(content, "line one\nline 2\n");
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_edit_flexible_match_collapsed_spaces() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_edit_flex_spaces");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
tokio::fs::write(dir.join("test.txt"), "let value = 42;\n")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = FileEditTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"path": "test.txt",
|
||||
"old_string": "let value = 42;\n",
|
||||
"new_string": "let value = 7;\n"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success, "collapsed-space flex match should succeed");
|
||||
assert!(result.output.contains("whitespace flexibility"));
|
||||
|
||||
let content = tokio::fs::read_to_string(dir.join("test.txt"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(content, "let value = 7;\n");
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_edit_flexible_match_ambiguous_errors() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_edit_flex_ambiguous");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
tokio::fs::write(
|
||||
dir.join("test.txt"),
|
||||
"if cond {\n work();\n}\n\nif cond {\n\twork();\n}\n",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = FileEditTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"path": "test.txt",
|
||||
"old_string": "if cond {\n work();\n}\n",
|
||||
"new_string": "if cond {\n done();\n}\n"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success, "ambiguous flex match must fail");
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("whitespace flexibility"));
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("matches 2 times"));
|
||||
|
||||
let content = tokio::fs::read_to_string(dir.join("test.txt"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
content,
|
||||
"if cond {\n work();\n}\n\nif cond {\n\twork();\n}\n"
|
||||
);
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_edit_flexible_match_not_found_when_no_line_match() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_edit_flex_not_found");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
tokio::fs::write(dir.join("test.txt"), "alpha\nbeta\n")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = FileEditTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"path": "test.txt",
|
||||
"old_string": "gamma\n",
|
||||
"new_string": "delta\n"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success, "non-matching flex case should fail");
|
||||
assert!(result.error.as_deref().unwrap_or("").contains("not found"));
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_edit_prefers_exact_match_over_flexible() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_edit_exact_preference");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
tokio::fs::write(
|
||||
dir.join("test.txt"),
|
||||
"let value = 1;\nlet value = 1;\n",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = FileEditTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"path": "test.txt",
|
||||
"old_string": "let value = 1;",
|
||||
"new_string": "let value = 2;"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.success, "exact match should succeed");
|
||||
assert!(!result.output.contains("whitespace flexibility"));
|
||||
|
||||
let content = tokio::fs::read_to_string(dir.join("test.txt"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(content, "let value = 2;\nlet value = 1;\n");
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_edit_flexible_match_preserves_trailing_newline_when_old_string_has_none() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_edit_flex_no_trailing_nl");
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
tokio::fs::create_dir_all(&dir).await.unwrap();
|
||||
tokio::fs::write(dir.join("test.txt"), "line one\n line two\nline three\n")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = FileEditTool::new(test_security(dir.clone()));
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"path": "test.txt",
|
||||
"old_string": "line one\n line two\nline three",
|
||||
"new_string": "updated block"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
result.success,
|
||||
"flex match without trailing newline should succeed"
|
||||
);
|
||||
assert!(result.output.contains("whitespace flexibility"));
|
||||
|
||||
let content = tokio::fs::read_to_string(dir.join("test.txt"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(content, "updated block\n");
|
||||
|
||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn file_edit_multiple_matches() {
|
||||
let dir = std::env::temp_dir().join("zeroclaw_test_file_edit_multi");
|
||||
|
||||
@ -18,6 +18,12 @@ const MAX_LINE_BYTES: usize = 4 * 1024 * 1024; // 4 MB
|
||||
/// Timeout for init/list operations.
|
||||
const RECV_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
/// Streamable HTTP Accept header required by MCP HTTP transport.
|
||||
const MCP_STREAMABLE_ACCEPT: &str = "application/json, text/event-stream";
|
||||
|
||||
/// Default media type for MCP JSON-RPC request bodies.
|
||||
const MCP_JSON_CONTENT_TYPE: &str = "application/json";
|
||||
|
||||
// ── Transport Trait ──────────────────────────────────────────────────────
|
||||
|
||||
/// Abstract transport for MCP communication.
|
||||
@ -171,10 +177,25 @@ impl McpTransportConn for HttpTransport {
|
||||
async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
|
||||
let body = serde_json::to_string(request)?;
|
||||
|
||||
let has_accept = self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Accept"));
|
||||
let has_content_type = self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Content-Type"));
|
||||
|
||||
let mut req = self.client.post(&self.url).body(body);
|
||||
if !has_content_type {
|
||||
req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE);
|
||||
}
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
if !has_accept {
|
||||
req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
@ -194,11 +215,24 @@ impl McpTransportConn for HttpTransport {
|
||||
});
|
||||
}
|
||||
|
||||
let resp_text = resp.text().await.context("failed to read HTTP response")?;
|
||||
let mcp_resp: JsonRpcResponse = serde_json::from_str(&resp_text)
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
|
||||
let is_sse = resp
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
|
||||
if is_sse {
|
||||
let maybe_resp = timeout(
|
||||
Duration::from_secs(RECV_TIMEOUT_SECS),
|
||||
read_first_jsonrpc_from_sse_response(resp),
|
||||
)
|
||||
.await
|
||||
.context("timeout waiting for MCP response from streamable HTTP SSE stream")??;
|
||||
return maybe_resp
|
||||
.ok_or_else(|| anyhow!("MCP server returned no response in SSE stream"));
|
||||
}
|
||||
|
||||
Ok(mcp_resp)
|
||||
let resp_text = resp.text().await.context("failed to read HTTP response")?;
|
||||
parse_jsonrpc_response_text(&resp_text)
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<()> {
|
||||
@ -264,14 +298,21 @@ impl SseTransport {
|
||||
}
|
||||
}
|
||||
|
||||
let has_accept = self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Accept"));
|
||||
|
||||
let mut req = self
|
||||
.client
|
||||
.get(&self.sse_url)
|
||||
.header("Accept", "text/event-stream")
|
||||
.header("Cache-Control", "no-cache");
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
if !has_accept {
|
||||
req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
|
||||
}
|
||||
|
||||
let resp = req.send().await.context("SSE GET to MCP server failed")?;
|
||||
if resp.status() == reqwest::StatusCode::NOT_FOUND
|
||||
@ -556,6 +597,30 @@ fn extract_json_from_sse_text(resp_text: &str) -> Cow<'_, str> {
|
||||
Cow::Owned(joined.trim().to_string())
|
||||
}
|
||||
|
||||
fn parse_jsonrpc_response_text(resp_text: &str) -> Result<JsonRpcResponse> {
|
||||
let trimmed = resp_text.trim();
|
||||
if trimmed.is_empty() {
|
||||
bail!("MCP server returned no response");
|
||||
}
|
||||
|
||||
let json_text = if looks_like_sse_text(trimmed) {
|
||||
extract_json_from_sse_text(trimmed)
|
||||
} else {
|
||||
Cow::Borrowed(trimmed)
|
||||
};
|
||||
|
||||
let mcp_resp: JsonRpcResponse = serde_json::from_str(json_text.as_ref())
|
||||
.with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
|
||||
Ok(mcp_resp)
|
||||
}
|
||||
|
||||
fn looks_like_sse_text(text: &str) -> bool {
|
||||
text.starts_with("data:")
|
||||
|| text.starts_with("event:")
|
||||
|| text.contains("\ndata:")
|
||||
|| text.contains("\nevent:")
|
||||
}
|
||||
|
||||
async fn read_first_jsonrpc_from_sse_response(
|
||||
resp: reqwest::Response,
|
||||
) -> Result<Option<JsonRpcResponse>> {
|
||||
@ -673,21 +738,27 @@ impl McpTransportConn for SseTransport {
|
||||
.chain(secondary_url.into_iter())
|
||||
.enumerate()
|
||||
{
|
||||
let has_accept = self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Accept"));
|
||||
let has_content_type = self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Content-Type"));
|
||||
let mut req = self
|
||||
.client
|
||||
.post(&url)
|
||||
.timeout(Duration::from_secs(120))
|
||||
.body(body.clone())
|
||||
.header("Content-Type", "application/json");
|
||||
.body(body.clone());
|
||||
if !has_content_type {
|
||||
req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE);
|
||||
}
|
||||
for (key, value) in &self.headers {
|
||||
req = req.header(key, value);
|
||||
}
|
||||
if !self
|
||||
.headers
|
||||
.keys()
|
||||
.any(|k| k.eq_ignore_ascii_case("Accept"))
|
||||
{
|
||||
req = req.header("Accept", "application/json, text/event-stream");
|
||||
if !has_accept {
|
||||
req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
|
||||
}
|
||||
|
||||
let resp = req.send().await.context("SSE POST to MCP server failed")?;
|
||||
@ -887,4 +958,34 @@ mod tests {
|
||||
let extracted = extract_json_from_sse_text(input);
|
||||
let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_jsonrpc_response_text_handles_plain_json() {
|
||||
let parsed = parse_jsonrpc_response_text("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}")
|
||||
.expect("plain JSON response should parse");
|
||||
assert_eq!(parsed.id, Some(serde_json::json!(1)));
|
||||
assert!(parsed.error.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_jsonrpc_response_text_handles_sse_framed_json() {
|
||||
let sse =
|
||||
"event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":2,\"result\":{\"ok\":true}}\n\n";
|
||||
let parsed =
|
||||
parse_jsonrpc_response_text(sse).expect("SSE-framed JSON response should parse");
|
||||
assert_eq!(parsed.id, Some(serde_json::json!(2)));
|
||||
assert_eq!(
|
||||
parsed
|
||||
.result
|
||||
.as_ref()
|
||||
.and_then(|v| v.get("ok"))
|
||||
.and_then(|v| v.as_bool()),
|
||||
Some(true)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_jsonrpc_response_text_rejects_empty_payload() {
|
||||
assert!(parse_jsonrpc_response_text(" \n\t ").is_err());
|
||||
}
|
||||
}
|
||||
|
||||
1
web/dist/assets/index-BCWngppm.css
vendored
Normal file
1
web/dist/assets/index-BCWngppm.css
vendored
Normal file
File diff suppressed because one or more lines are too long
1
web/dist/assets/index-C70eaW2F.css
vendored
1
web/dist/assets/index-C70eaW2F.css
vendored
File diff suppressed because one or more lines are too long
320
web/dist/assets/index-CJ6bGkAt.js
vendored
320
web/dist/assets/index-CJ6bGkAt.js
vendored
File diff suppressed because one or more lines are too long
706
web/dist/assets/index-DOPK6_Za.js
vendored
Normal file
706
web/dist/assets/index-DOPK6_Za.js
vendored
Normal file
File diff suppressed because one or more lines are too long
4
web/dist/index.html
vendored
4
web/dist/index.html
vendored
@ -9,8 +9,8 @@
|
||||
/>
|
||||
<meta name="color-scheme" content="dark" />
|
||||
<title>ZeroClaw</title>
|
||||
<script type="module" crossorigin src="/_app/assets/index-CJ6bGkAt.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/_app/assets/index-C70eaW2F.css">
|
||||
<script type="module" crossorigin src="/_app/assets/index-DOPK6_Za.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/_app/assets/index-BCWngppm.css">
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
|
||||
232
web/package-lock.json
generated
232
web/package-lock.json
generated
@ -9,6 +9,11 @@
|
||||
"version": "0.1.0",
|
||||
"license": "(MIT OR Apache-2.0)",
|
||||
"dependencies": {
|
||||
"@codemirror/language": "^6.12.2",
|
||||
"@codemirror/legacy-modes": "^6.5.2",
|
||||
"@codemirror/theme-one-dark": "^6.1.3",
|
||||
"@codemirror/view": "^6.39.15",
|
||||
"@uiw/react-codemirror": "^4.25.5",
|
||||
"lucide-react": "^0.468.0",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
@ -260,6 +265,15 @@
|
||||
"@babel/core": "^7.0.0-0"
|
||||
}
|
||||
},
|
||||
"node_modules/@babel/runtime": {
|
||||
"version": "7.28.6",
|
||||
"resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.28.6.tgz",
|
||||
"integrity": "sha512-05WQkdpL9COIMz4LjTxGpPNCdlpyimKppYNoJ5Di5EUObifl8t4tuLuUBBZEpoLYOmfvIWrsp9fCl0HoPRVTdA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=6.9.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@babel/template": {
|
||||
"version": "7.28.6",
|
||||
"resolved": "https://registry.npmjs.org/@babel/template/-/template-7.28.6.tgz",
|
||||
@ -308,6 +322,108 @@
|
||||
"node": ">=6.9.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@codemirror/autocomplete": {
|
||||
"version": "6.20.0",
|
||||
"resolved": "https://registry.npmjs.org/@codemirror/autocomplete/-/autocomplete-6.20.0.tgz",
|
||||
"integrity": "sha512-bOwvTOIJcG5FVo5gUUupiwYh8MioPLQ4UcqbcRf7UQ98X90tCa9E1kZ3Z7tqwpZxYyOvh1YTYbmZE9RTfTp5hg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@codemirror/language": "^6.0.0",
|
||||
"@codemirror/state": "^6.0.0",
|
||||
"@codemirror/view": "^6.17.0",
|
||||
"@lezer/common": "^1.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@codemirror/commands": {
|
||||
"version": "6.10.2",
|
||||
"resolved": "https://registry.npmjs.org/@codemirror/commands/-/commands-6.10.2.tgz",
|
||||
"integrity": "sha512-vvX1fsih9HledO1c9zdotZYUZnE4xV0m6i3m25s5DIfXofuprk6cRcLUZvSk3CASUbwjQX21tOGbkY2BH8TpnQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@codemirror/language": "^6.0.0",
|
||||
"@codemirror/state": "^6.4.0",
|
||||
"@codemirror/view": "^6.27.0",
|
||||
"@lezer/common": "^1.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@codemirror/language": {
|
||||
"version": "6.12.2",
|
||||
"resolved": "https://registry.npmjs.org/@codemirror/language/-/language-6.12.2.tgz",
|
||||
"integrity": "sha512-jEPmz2nGGDxhRTg3lTpzmIyGKxz3Gp3SJES4b0nAuE5SWQoKdT5GoQ69cwMmFd+wvFUhYirtDTr0/DRHpQAyWg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@codemirror/state": "^6.0.0",
|
||||
"@codemirror/view": "^6.23.0",
|
||||
"@lezer/common": "^1.5.0",
|
||||
"@lezer/highlight": "^1.0.0",
|
||||
"@lezer/lr": "^1.0.0",
|
||||
"style-mod": "^4.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@codemirror/legacy-modes": {
|
||||
"version": "6.5.2",
|
||||
"resolved": "https://registry.npmjs.org/@codemirror/legacy-modes/-/legacy-modes-6.5.2.tgz",
|
||||
"integrity": "sha512-/jJbwSTazlQEDOQw2FJ8LEEKVS72pU0lx6oM54kGpL8t/NJ2Jda3CZ4pcltiKTdqYSRk3ug1B3pil1gsjA6+8Q==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@codemirror/language": "^6.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@codemirror/lint": {
|
||||
"version": "6.9.4",
|
||||
"resolved": "https://registry.npmjs.org/@codemirror/lint/-/lint-6.9.4.tgz",
|
||||
"integrity": "sha512-ABc9vJ8DEmvOWuH26P3i8FpMWPQkduD9Rvba5iwb6O3hxASgclm3T3krGo8NASXkHCidz6b++LWlzWIUfEPSWw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@codemirror/state": "^6.0.0",
|
||||
"@codemirror/view": "^6.35.0",
|
||||
"crelt": "^1.0.5"
|
||||
}
|
||||
},
|
||||
"node_modules/@codemirror/search": {
|
||||
"version": "6.6.0",
|
||||
"resolved": "https://registry.npmjs.org/@codemirror/search/-/search-6.6.0.tgz",
|
||||
"integrity": "sha512-koFuNXcDvyyotWcgOnZGmY7LZqEOXZaaxD/j6n18TCLx2/9HieZJ5H6hs1g8FiRxBD0DNfs0nXn17g872RmYdw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@codemirror/state": "^6.0.0",
|
||||
"@codemirror/view": "^6.37.0",
|
||||
"crelt": "^1.0.5"
|
||||
}
|
||||
},
|
||||
"node_modules/@codemirror/state": {
|
||||
"version": "6.5.4",
|
||||
"resolved": "https://registry.npmjs.org/@codemirror/state/-/state-6.5.4.tgz",
|
||||
"integrity": "sha512-8y7xqG/hpB53l25CIoit9/ngxdfoG+fx+V3SHBrinnhOtLvKHRyAJJuHzkWrR4YXXLX8eXBsejgAAxHUOdW1yw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@marijn/find-cluster-break": "^1.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@codemirror/theme-one-dark": {
|
||||
"version": "6.1.3",
|
||||
"resolved": "https://registry.npmjs.org/@codemirror/theme-one-dark/-/theme-one-dark-6.1.3.tgz",
|
||||
"integrity": "sha512-NzBdIvEJmx6fjeremiGp3t/okrLPYT0d9orIc7AFun8oZcRk58aejkqhv6spnz4MLAevrKNPMQYXEWMg4s+sKA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@codemirror/language": "^6.0.0",
|
||||
"@codemirror/state": "^6.0.0",
|
||||
"@codemirror/view": "^6.0.0",
|
||||
"@lezer/highlight": "^1.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@codemirror/view": {
|
||||
"version": "6.39.15",
|
||||
"resolved": "https://registry.npmjs.org/@codemirror/view/-/view-6.39.15.tgz",
|
||||
"integrity": "sha512-aCWjgweIIXLBHh7bY6cACvXuyrZ0xGafjQ2VInjp4RM4gMfscK5uESiNdrH0pE+e1lZr2B4ONGsjchl2KsKZzg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@codemirror/state": "^6.5.0",
|
||||
"crelt": "^1.0.6",
|
||||
"style-mod": "^4.1.0",
|
||||
"w3c-keyname": "^2.2.4"
|
||||
}
|
||||
},
|
||||
"node_modules/@esbuild/aix-ppc64": {
|
||||
"version": "0.25.12",
|
||||
"resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.12.tgz",
|
||||
@ -800,6 +916,36 @@
|
||||
"@jridgewell/sourcemap-codec": "^1.4.14"
|
||||
}
|
||||
},
|
||||
"node_modules/@lezer/common": {
|
||||
"version": "1.5.1",
|
||||
"resolved": "https://registry.npmjs.org/@lezer/common/-/common-1.5.1.tgz",
|
||||
"integrity": "sha512-6YRVG9vBkaY7p1IVxL4s44n5nUnaNnGM2/AckNgYOnxTG2kWh1vR8BMxPseWPjRNpb5VtXnMpeYAEAADoRV1Iw==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@lezer/highlight": {
|
||||
"version": "1.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@lezer/highlight/-/highlight-1.2.3.tgz",
|
||||
"integrity": "sha512-qXdH7UqTvGfdVBINrgKhDsVTJTxactNNxLk7+UMwZhU13lMHaOBlJe9Vqp907ya56Y3+ed2tlqzys7jDkTmW0g==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@lezer/common": "^1.3.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@lezer/lr": {
|
||||
"version": "1.4.8",
|
||||
"resolved": "https://registry.npmjs.org/@lezer/lr/-/lr-1.4.8.tgz",
|
||||
"integrity": "sha512-bPWa0Pgx69ylNlMlPvBPryqeLYQjyJjqPx+Aupm5zydLIF3NE+6MMLT8Yi23Bd9cif9VS00aUebn+6fDIGBcDA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@lezer/common": "^1.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@marijn/find-cluster-break": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@marijn/find-cluster-break/-/find-cluster-break-1.0.2.tgz",
|
||||
"integrity": "sha512-l0h88YhZFyKdXIFNfSWpyjStDjGHwZ/U7iobcK1cQQD8sejsONdQtTVU+1wVN1PBw40PiiHB1vA5S7VTfQiP9g==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@rolldown/pluginutils": {
|
||||
"version": "1.0.0-beta.27",
|
||||
"resolved": "https://registry.npmjs.org/@rolldown/pluginutils/-/pluginutils-1.0.0-beta.27.tgz",
|
||||
@ -1511,6 +1657,59 @@
|
||||
"@types/react": "^19.2.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@uiw/codemirror-extensions-basic-setup": {
|
||||
"version": "4.25.5",
|
||||
"resolved": "https://registry.npmjs.org/@uiw/codemirror-extensions-basic-setup/-/codemirror-extensions-basic-setup-4.25.5.tgz",
|
||||
"integrity": "sha512-2KWS4NqrS9SQzlPs/3sxFhuArvjB3JF6WpsrZqBtGHM5/smCNTULX3lUGeRH+f3mkfMt0k6DR+q0xCW9k+Up5w==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@codemirror/autocomplete": "^6.0.0",
|
||||
"@codemirror/commands": "^6.0.0",
|
||||
"@codemirror/language": "^6.0.0",
|
||||
"@codemirror/lint": "^6.0.0",
|
||||
"@codemirror/search": "^6.0.0",
|
||||
"@codemirror/state": "^6.0.0",
|
||||
"@codemirror/view": "^6.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://jaywcjlove.github.io/#/sponsor"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@codemirror/autocomplete": ">=6.0.0",
|
||||
"@codemirror/commands": ">=6.0.0",
|
||||
"@codemirror/language": ">=6.0.0",
|
||||
"@codemirror/lint": ">=6.0.0",
|
||||
"@codemirror/search": ">=6.0.0",
|
||||
"@codemirror/state": ">=6.0.0",
|
||||
"@codemirror/view": ">=6.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@uiw/react-codemirror": {
|
||||
"version": "4.25.5",
|
||||
"resolved": "https://registry.npmjs.org/@uiw/react-codemirror/-/react-codemirror-4.25.5.tgz",
|
||||
"integrity": "sha512-WUMBGwfstufdbnaiMzQzmOf+6Mzf0IbiOoleexC9ItWcDTJybidLtEi20aP2N58Wn/AQxsd5Otebydaimh7Opw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.18.6",
|
||||
"@codemirror/commands": "^6.1.0",
|
||||
"@codemirror/state": "^6.1.1",
|
||||
"@codemirror/theme-one-dark": "^6.0.0",
|
||||
"@uiw/codemirror-extensions-basic-setup": "4.25.5",
|
||||
"codemirror": "^6.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://jaywcjlove.github.io/#/sponsor"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@babel/runtime": ">=7.11.0",
|
||||
"@codemirror/state": ">=6.0.0",
|
||||
"@codemirror/theme-one-dark": ">=6.0.0",
|
||||
"@codemirror/view": ">=6.0.0",
|
||||
"codemirror": ">=6.0.0",
|
||||
"react": ">=17.0.0",
|
||||
"react-dom": ">=17.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@vitejs/plugin-react": {
|
||||
"version": "4.7.0",
|
||||
"resolved": "https://registry.npmjs.org/@vitejs/plugin-react/-/plugin-react-4.7.0.tgz",
|
||||
@ -1600,6 +1799,21 @@
|
||||
],
|
||||
"license": "CC-BY-4.0"
|
||||
},
|
||||
"node_modules/codemirror": {
|
||||
"version": "6.0.2",
|
||||
"resolved": "https://registry.npmjs.org/codemirror/-/codemirror-6.0.2.tgz",
|
||||
"integrity": "sha512-VhydHotNW5w1UGK0Qj96BwSk/Zqbp9WbnyK2W/eVMv4QyF41INRGpjUhFJY7/uDNuudSc33a/PKr4iDqRduvHw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@codemirror/autocomplete": "^6.0.0",
|
||||
"@codemirror/commands": "^6.0.0",
|
||||
"@codemirror/language": "^6.0.0",
|
||||
"@codemirror/lint": "^6.0.0",
|
||||
"@codemirror/search": "^6.0.0",
|
||||
"@codemirror/state": "^6.0.0",
|
||||
"@codemirror/view": "^6.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/convert-source-map": {
|
||||
"version": "2.0.0",
|
||||
"resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-2.0.0.tgz",
|
||||
@ -1620,6 +1834,12 @@
|
||||
"url": "https://opencollective.com/express"
|
||||
}
|
||||
},
|
||||
"node_modules/crelt": {
|
||||
"version": "1.0.6",
|
||||
"resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz",
|
||||
"integrity": "sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/csstype": {
|
||||
"version": "3.2.3",
|
||||
"resolved": "https://registry.npmjs.org/csstype/-/csstype-3.2.3.tgz",
|
||||
@ -2351,6 +2571,12 @@
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/style-mod": {
|
||||
"version": "4.1.3",
|
||||
"resolved": "https://registry.npmjs.org/style-mod/-/style-mod-4.1.3.tgz",
|
||||
"integrity": "sha512-i/n8VsZydrugj3Iuzll8+x/00GH2vnYsk1eomD8QiRrSAeW6ItbCQDtfXCeJHd0iwiNagqjQkvpvREEPtW3IoQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/tailwindcss": {
|
||||
"version": "4.2.0",
|
||||
"resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.2.0.tgz",
|
||||
@ -2516,6 +2742,12 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/w3c-keyname": {
|
||||
"version": "2.2.8",
|
||||
"resolved": "https://registry.npmjs.org/w3c-keyname/-/w3c-keyname-2.2.8.tgz",
|
||||
"integrity": "sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/yallist": {
|
||||
"version": "3.1.1",
|
||||
"resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz",
|
||||
|
||||
@ -10,6 +10,11 @@
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@codemirror/language": "^6.12.2",
|
||||
"@codemirror/legacy-modes": "^6.5.2",
|
||||
"@codemirror/theme-one-dark": "^6.1.3",
|
||||
"@codemirror/view": "^6.39.15",
|
||||
"@uiw/react-codemirror": "^4.25.5",
|
||||
"lucide-react": "^0.468.0",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
|
||||
@ -1,9 +1,17 @@
|
||||
import { StreamLanguage } from '@codemirror/language';
|
||||
import { oneDark } from '@codemirror/theme-one-dark';
|
||||
import { EditorView } from '@codemirror/view';
|
||||
import { toml } from '@codemirror/legacy-modes/mode/toml';
|
||||
import CodeMirror from '@uiw/react-codemirror';
|
||||
|
||||
interface Props {
|
||||
rawToml: string;
|
||||
onChange: (raw: string) => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
const tomlLanguage = StreamLanguage.define(toml);
|
||||
|
||||
export default function ConfigRawEditor({ rawToml, onChange, disabled }: Props) {
|
||||
return (
|
||||
<div className="bg-gray-900 rounded-xl border border-gray-800 overflow-hidden">
|
||||
@ -15,14 +23,22 @@ export default function ConfigRawEditor({ rawToml, onChange, disabled }: Props)
|
||||
{rawToml.split('\n').length} lines
|
||||
</span>
|
||||
</div>
|
||||
<textarea
|
||||
<CodeMirror
|
||||
value={rawToml}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
disabled={disabled}
|
||||
spellCheck={false}
|
||||
aria-label="Raw TOML configuration editor"
|
||||
className="w-full min-h-[500px] bg-gray-950 text-gray-200 font-mono text-sm p-4 resize-y focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-inset disabled:opacity-50"
|
||||
style={{ tabSize: 4 }}
|
||||
onChange={onChange}
|
||||
theme={oneDark}
|
||||
readOnly={Boolean(disabled)}
|
||||
editable={!disabled}
|
||||
height="500px"
|
||||
basicSetup={{
|
||||
lineNumbers: true,
|
||||
foldGutter: false,
|
||||
highlightActiveLineGutter: false,
|
||||
highlightActiveLine: false,
|
||||
}}
|
||||
extensions={[tomlLanguage, EditorView.lineWrapping]}
|
||||
className="text-sm [&_.cm-scroller]:font-mono [&_.cm-scroller]:leading-6 [&_.cm-content]:py-4 [&_.cm-content]:px-0 [&_.cm-gutters]:border-r [&_.cm-gutters]:border-gray-800 [&_.cm-gutters]:bg-gray-950 [&_.cm-editor]:bg-gray-950 [&_.cm-editor]:focus:outline-none [&_.cm-focused]:ring-2 [&_.cm-focused]:ring-blue-500/70 [&_.cm-focused]:ring-inset"
|
||||
aria-label="Raw TOML configuration editor with syntax highlighting"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@ -55,9 +55,9 @@ export default function AgentChat() {
|
||||
ws.onMessage = (msg: WsMessage) => {
|
||||
switch (msg.type) {
|
||||
case 'history': {
|
||||
const restored = (msg.messages ?? [])
|
||||
const restored: ChatMessage[] = (msg.messages ?? [])
|
||||
.filter((entry) => entry.content?.trim())
|
||||
.map((entry) => ({
|
||||
.map((entry): ChatMessage => ({
|
||||
id: makeMessageId(),
|
||||
role: entry.role === 'user' ? 'user' : 'agent',
|
||||
content: entry.content.trim(),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user