Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ccd52f3394 | |||
| eb01aa451d | |||
| c785b45f2d | |||
| ffb8b81f90 | |||
| 65f856d710 | |||
| 1682620377 | |||
| aa455ae89b | |||
| a9ffd38912 | |||
| 86a0584513 | |||
| abef4c5719 | |||
| 483b2336c4 | |||
| 14cda3bc9a | |||
| 6e8f0fa43c | |||
| a965b129f8 | |||
| c135de41b7 | |||
| 2d2c2ac9e6 | |||
| 5e774bbd70 | |||
| 33015067eb | |||
| 6b10c0b891 | |||
| bf817e30d2 | |||
| 0051a0c296 | |||
| d753de91f1 | |||
| f6b2f61a01 | |||
| 70e7910cb9 |
@@ -118,3 +118,7 @@ PROVIDER=openrouter
|
||||
# Optional: Brave Search (requires API key from https://brave.com/search/api)
|
||||
# WEB_SEARCH_PROVIDER=brave
|
||||
# BRAVE_API_KEY=your-brave-search-api-key
|
||||
#
|
||||
# Optional: SearXNG (self-hosted, requires instance URL)
|
||||
# WEB_SEARCH_PROVIDER=searxng
|
||||
# SEARXNG_INSTANCE_URL=https://searx.example.com
|
||||
|
||||
Generated
+1
-1
@@ -9530,7 +9530,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.5.7"
|
||||
version = "0.5.8"
|
||||
dependencies = [
|
||||
"aardvark-sys",
|
||||
"anyhow",
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.5.7"
|
||||
version = "0.5.8"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
Vendored
+2
-2
@@ -1,6 +1,6 @@
|
||||
pkgbase = zeroclaw
|
||||
pkgdesc = Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.
|
||||
pkgver = 0.5.7
|
||||
pkgver = 0.5.8
|
||||
pkgrel = 1
|
||||
url = https://github.com/zeroclaw-labs/zeroclaw
|
||||
arch = x86_64
|
||||
@@ -10,7 +10,7 @@ pkgbase = zeroclaw
|
||||
makedepends = git
|
||||
depends = gcc-libs
|
||||
depends = openssl
|
||||
source = zeroclaw-0.5.7.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.7.tar.gz
|
||||
source = zeroclaw-0.5.8.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.8.tar.gz
|
||||
sha256sums = SKIP
|
||||
|
||||
pkgname = zeroclaw
|
||||
|
||||
Vendored
+1
-1
@@ -1,6 +1,6 @@
|
||||
# Maintainer: zeroclaw-labs <bot@zeroclaw.dev>
|
||||
pkgname=zeroclaw
|
||||
pkgver=0.5.7
|
||||
pkgver=0.5.8
|
||||
pkgrel=1
|
||||
pkgdesc="Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
||||
arch=('x86_64')
|
||||
|
||||
Vendored
+2
-2
@@ -1,11 +1,11 @@
|
||||
{
|
||||
"version": "0.5.7",
|
||||
"version": "0.5.8",
|
||||
"description": "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.",
|
||||
"homepage": "https://github.com/zeroclaw-labs/zeroclaw",
|
||||
"license": "MIT|Apache-2.0",
|
||||
"architecture": {
|
||||
"64bit": {
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.7/zeroclaw-x86_64-pc-windows-msvc.zip",
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.8/zeroclaw-x86_64-pc-windows-msvc.zip",
|
||||
"hash": "",
|
||||
"bin": "zeroclaw.exe"
|
||||
}
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
# ADR-004: Tool Shared State Ownership Contract
|
||||
|
||||
**Status:** Accepted
|
||||
|
||||
**Date:** 2026-03-22
|
||||
|
||||
**Issue:** [#4057](https://github.com/zeroclaw/zeroclaw/issues/4057)
|
||||
|
||||
## Context
|
||||
|
||||
ZeroClaw tools execute in a multi-client environment where a single daemon
|
||||
process serves requests from multiple connected clients simultaneously. Several
|
||||
tools already maintain long-lived shared state:
|
||||
|
||||
- **`DelegateParentToolsHandle`** (`src/tools/mod.rs`):
|
||||
`Arc<RwLock<Vec<Arc<dyn Tool>>>>` — holds parent tools for delegate agents
|
||||
with no per-client isolation.
|
||||
- **`ChannelMapHandle`** (`src/tools/reaction.rs`):
|
||||
`Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>` — global channel map shared
|
||||
across all clients.
|
||||
- **`CanvasStore`** (`src/tools/canvas.rs`):
|
||||
`Arc<RwLock<HashMap<String, CanvasEntry>>>` — canvas IDs are plain strings
|
||||
with no client namespace.
|
||||
|
||||
These patterns emerged organically. As the tool surface grows and more clients
|
||||
connect concurrently, we need a clear contract governing ownership, identity,
|
||||
isolation, lifecycle, and reload behavior for tool-held shared state. Without
|
||||
this contract, new tools risk introducing data leaks between clients, stale
|
||||
state after config reloads, or inconsistent initialization timing.
|
||||
|
||||
Additional context:
|
||||
|
||||
- The tool registry is immutable after startup, built once in
|
||||
`all_tools_with_runtime()`.
|
||||
- Client identity is currently derived from IP address only
|
||||
(`src/gateway/mod.rs`), which is insufficient for reliable namespacing.
|
||||
- `SecurityPolicy` is scoped per agent, not per client.
|
||||
- `WorkspaceManager` provides some isolation but workspace switching is global.
|
||||
|
||||
## Decision
|
||||
|
||||
### 1. Ownership: May tools own long-lived shared state?
|
||||
|
||||
**Yes.** Tools MAY own long-lived shared state, provided they follow the
|
||||
established **handle pattern**: wrap the state in `Arc<RwLock<T>>` (or
|
||||
`Arc<parking_lot::RwLock<T>>`) and expose a cloneable handle type.
|
||||
|
||||
This pattern is already proven by three independent implementations:
|
||||
|
||||
| Handle | Location | Inner type |
|
||||
|--------|----------|-----------|
|
||||
| `DelegateParentToolsHandle` | `src/tools/mod.rs` | `Vec<Arc<dyn Tool>>` |
|
||||
| `ChannelMapHandle` | `src/tools/reaction.rs` | `HashMap<String, Arc<dyn Channel>>` |
|
||||
| `CanvasStore` | `src/tools/canvas.rs` | `HashMap<String, CanvasEntry>` |
|
||||
|
||||
Tools that need shared state MUST:
|
||||
|
||||
- Define a named handle type alias (e.g., `pub type FooHandle = Arc<RwLock<T>>`).
|
||||
- Accept the handle at construction time rather than creating global state.
|
||||
- Document the concurrency contract in the handle type's doc comment.
|
||||
|
||||
Tools MUST NOT use static mutable state (`lazy_static!`, `OnceCell` with
|
||||
interior mutability) for per-request or per-client data.
|
||||
|
||||
### 2. Identity assignment: Who constructs identity keys?
|
||||
|
||||
**The daemon SHOULD provide identity.** Tools MUST NOT construct their own
|
||||
client identity keys.
|
||||
|
||||
A new `ClientId` type should be introduced (opaque, `Clone + Eq + Hash + Send + Sync`)
|
||||
that the daemon assigns at connection time. This replaces the current approach
|
||||
of using raw IP addresses (`src/gateway/mod.rs:259-306`), which breaks when
|
||||
multiple clients share a NAT address or when proxied connections arrive.
|
||||
|
||||
`ClientId` is passed to tools that require per-client state namespacing as part
|
||||
of the tool execution context. Tools that do not need per-client isolation
|
||||
(e.g., the immutable tool registry) may ignore it.
|
||||
|
||||
The `ClientId` contract:
|
||||
|
||||
- Generated by the gateway layer at connection establishment.
|
||||
- Opaque to tools — tools must not parse or derive meaning from the value.
|
||||
- Stable for the lifetime of a single client session.
|
||||
- Passed through the execution context, not stored globally.
|
||||
|
||||
### 3. Lifecycle: When may tools run startup-style validation?
|
||||
|
||||
**Validation runs once at first registration, and again when config changes
|
||||
are detected.**
|
||||
|
||||
The lifecycle phases are:
|
||||
|
||||
1. **Construction** — tool is instantiated with handles and config. No I/O or
|
||||
validation occurs here.
|
||||
2. **Registration** — tool is registered in the tool registry via
|
||||
`all_tools_with_runtime()`. At this point the tool MAY perform one-time
|
||||
startup validation (e.g., checking that required credentials exist, verifying
|
||||
external service connectivity).
|
||||
3. **Execution** — tool handles individual requests. No re-validation unless
|
||||
the config-change signal fires (see Reload Semantics below).
|
||||
4. **Shutdown** — daemon is stopping. Tools with open resources SHOULD clean up
|
||||
gracefully via `Drop` or an explicit shutdown method.
|
||||
|
||||
Tools MUST NOT perform blocking validation during execution-phase calls.
|
||||
Validation results SHOULD be cached in the tool's handle state and checked
|
||||
via a fast path during execution.
|
||||
|
||||
### 4. Isolation: What must be isolated per client?
|
||||
|
||||
State falls into two categories with different isolation requirements:
|
||||
|
||||
**MUST be isolated per client:**
|
||||
|
||||
- Security-sensitive state: credentials, API keys, quotas, rate-limit counters,
|
||||
per-client authorization decisions.
|
||||
- User-specific session data: conversation context, user preferences,
|
||||
workspace-scoped file paths.
|
||||
|
||||
Isolation mechanism: tools holding per-client state MUST key their internal
|
||||
maps by `ClientId`. The handle pattern naturally supports this by using
|
||||
`HashMap<ClientId, T>` inside the `RwLock`.
|
||||
|
||||
**MAY be shared across clients (with namespace prefixing):**
|
||||
|
||||
- Broadcast/display state: canvas frames (`CanvasStore`), notification channels
|
||||
(`ChannelMapHandle`).
|
||||
- Read-only reference data: tool registry, static configuration, model
|
||||
metadata.
|
||||
|
||||
When shared state uses string keys (e.g., canvas IDs, channel names), tools
|
||||
SHOULD support optional namespace prefixing (e.g., `{client_id}:{canvas_name}`)
|
||||
to allow per-client isolation when needed without mandating it for broadcast
|
||||
use cases.
|
||||
|
||||
Tools MUST NOT store per-client secrets in shared (non-isolated) state
|
||||
structures.
|
||||
|
||||
### 5. Reload semantics: What invalidates prior shared state on config change?
|
||||
|
||||
**Config changes detected via hash comparison MUST invalidate cached
|
||||
validation state.**
|
||||
|
||||
The reload contract:
|
||||
|
||||
- The daemon computes a hash of the tool-relevant config section at startup and
|
||||
after each config reload event.
|
||||
- When the hash changes, the daemon signals affected tools to re-run their
|
||||
registration-phase validation.
|
||||
- Tools MUST treat their cached validation result as stale when signaled and
|
||||
re-validate before the next execution.
|
||||
|
||||
Specific invalidation rules:
|
||||
|
||||
| Config change | Invalidation scope |
|
||||
|--------------|-------------------|
|
||||
| Credential/secret rotation | Per-tool validation cache; per-client credential state |
|
||||
| Tool enable/disable | Full tool registry rebuild via `all_tools_with_runtime()` |
|
||||
| Security policy change | `SecurityPolicy` re-derivation; per-agent policy state |
|
||||
| Workspace directory change | `WorkspaceManager` state; file-path-dependent tool state |
|
||||
| Provider config change | Provider-dependent tools re-validate connectivity |
|
||||
|
||||
Tools MAY retain non-security shared state (e.g., canvas content, channel
|
||||
subscriptions) across config reloads unless the reload explicitly affects that
|
||||
state's validity.
|
||||
|
||||
## Consequences
|
||||
|
||||
### Positive
|
||||
|
||||
- **Consistency:** All new tools follow the same handle pattern, making shared
|
||||
state discoverable and auditable.
|
||||
- **Safety:** Per-client isolation of security-sensitive state prevents data
|
||||
leaks in multi-tenant scenarios.
|
||||
- **Clarity:** Explicit lifecycle phases eliminate ambiguity about when
|
||||
validation runs.
|
||||
- **Evolvability:** The `ClientId` abstraction decouples tools from transport
|
||||
details, supporting future identity mechanisms (tokens, certificates).
|
||||
|
||||
### Negative
|
||||
|
||||
- **Migration cost:** Existing tools (`CanvasStore`, `ReactionTool`) may need
|
||||
refactoring to accept `ClientId` and namespace their state.
|
||||
- **Complexity:** Tools that were simple singletons now need to consider
|
||||
multi-client semantics even if they currently have one client.
|
||||
- **Performance:** Per-client keying adds a hash lookup on each access, though
|
||||
this is negligible compared to I/O costs.
|
||||
|
||||
### Neutral
|
||||
|
||||
- The tool registry remains immutable after startup; this ADR does not change
|
||||
that invariant.
|
||||
- `SecurityPolicy` remains per-agent; this ADR documents that client isolation
|
||||
is orthogonal to agent-level policy.
|
||||
|
||||
## References
|
||||
|
||||
- `src/tools/mod.rs` — `DelegateParentToolsHandle`, `all_tools_with_runtime()`
|
||||
- `src/tools/reaction.rs` — `ChannelMapHandle`, `ReactionTool`
|
||||
- `src/tools/canvas.rs` — `CanvasStore`, `CanvasEntry`
|
||||
- `src/tools/traits.rs` — `Tool` trait
|
||||
- `src/gateway/mod.rs` — client IP extraction (`forwarded_client_ip`, `resolve_client_ip`)
|
||||
- `src/security/` — `SecurityPolicy`
|
||||
@@ -0,0 +1,215 @@
|
||||
# Browser Automation Setup Guide
|
||||
|
||||
This guide covers setting up browser automation capabilities in ZeroClaw, including both headless automation and GUI access via VNC.
|
||||
|
||||
## Overview
|
||||
|
||||
ZeroClaw supports multiple browser access methods:
|
||||
|
||||
| Method | Use Case | Requirements |
|
||||
|--------|----------|--------------|
|
||||
| **agent-browser CLI** | Headless automation, AI agents | npm, Chrome |
|
||||
| **VNC + noVNC** | GUI access, debugging | Xvfb, x11vnc, noVNC |
|
||||
| **Chrome Remote Desktop** | Remote GUI via Google | XFCE, Google account |
|
||||
|
||||
## Quick Start: Headless Automation
|
||||
|
||||
### 1. Install agent-browser
|
||||
|
||||
```bash
|
||||
# Install CLI
|
||||
npm install -g agent-browser
|
||||
|
||||
# Download Chrome for Testing
|
||||
agent-browser install --with-deps # Linux (includes system deps)
|
||||
agent-browser install # macOS/Windows
|
||||
```
|
||||
|
||||
### 2. Verify ZeroClaw Config
|
||||
|
||||
The browser tool is enabled by default. To verify or customize, edit
|
||||
`~/.zeroclaw/config.toml`:
|
||||
|
||||
```toml
|
||||
[browser]
|
||||
enabled = true # default: true
|
||||
allowed_domains = ["*"] # default: ["*"] (all public hosts)
|
||||
backend = "agent_browser" # default: "agent_browser"
|
||||
native_headless = true # default: true
|
||||
```
|
||||
|
||||
To restrict domains or disable the browser tool:
|
||||
|
||||
```toml
|
||||
[browser]
|
||||
enabled = false # disable entirely
|
||||
# or restrict to specific domains:
|
||||
allowed_domains = ["example.com", "docs.example.com"]
|
||||
```
|
||||
|
||||
### 3. Test
|
||||
|
||||
```bash
|
||||
echo "Open https://example.com and tell me what it says" | zeroclaw agent
|
||||
```
|
||||
|
||||
## VNC Setup (GUI Access)
|
||||
|
||||
For debugging or when you need visual browser access:
|
||||
|
||||
### Install Dependencies
|
||||
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
apt-get install -y xvfb x11vnc fluxbox novnc websockify
|
||||
|
||||
# Optional: Desktop environment for Chrome Remote Desktop
|
||||
apt-get install -y xfce4 xfce4-goodies
|
||||
```
|
||||
|
||||
### Start VNC Server
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# Start virtual display with VNC access
|
||||
|
||||
DISPLAY_NUM=99
|
||||
VNC_PORT=5900
|
||||
NOVNC_PORT=6080
|
||||
RESOLUTION=1920x1080x24
|
||||
|
||||
# Start Xvfb
|
||||
Xvfb :$DISPLAY_NUM -screen 0 $RESOLUTION -ac &
|
||||
sleep 1
|
||||
|
||||
# Start window manager
|
||||
fluxbox -display :$DISPLAY_NUM &
|
||||
sleep 1
|
||||
|
||||
# Start x11vnc
|
||||
x11vnc -display :$DISPLAY_NUM -rfbport $VNC_PORT -forever -shared -nopw -bg
|
||||
sleep 1
|
||||
|
||||
# Start noVNC (web-based VNC)
|
||||
websockify --web=/usr/share/novnc $NOVNC_PORT localhost:$VNC_PORT &
|
||||
|
||||
echo "VNC available at:"
|
||||
echo " VNC Client: localhost:$VNC_PORT"
|
||||
echo " Web Browser: http://localhost:$NOVNC_PORT/vnc.html"
|
||||
```
|
||||
|
||||
### VNC Access
|
||||
|
||||
- **VNC Client**: Connect to `localhost:5900`
|
||||
- **Web Browser**: Open `http://localhost:6080/vnc.html`
|
||||
|
||||
### Start Browser on VNC Display
|
||||
|
||||
```bash
|
||||
DISPLAY=:99 google-chrome --no-sandbox https://example.com &
|
||||
```
|
||||
|
||||
## Chrome Remote Desktop
|
||||
|
||||
### Install
|
||||
|
||||
```bash
|
||||
# Download and install
|
||||
wget https://dl.google.com/linux/direct/chrome-remote-desktop_current_amd64.deb
|
||||
apt-get install -y ./chrome-remote-desktop_current_amd64.deb
|
||||
|
||||
# Configure session
|
||||
echo "xfce4-session" > ~/.chrome-remote-desktop-session
|
||||
chmod +x ~/.chrome-remote-desktop-session
|
||||
```
|
||||
|
||||
### Setup
|
||||
|
||||
1. Visit <https://remotedesktop.google.com/headless>
|
||||
2. Copy the "Debian Linux" setup command
|
||||
3. Run it on your server
|
||||
4. Start the service: `systemctl --user start chrome-remote-desktop`
|
||||
|
||||
### Remote Access
|
||||
|
||||
Go to <https://remotedesktop.google.com/access> from any device.
|
||||
|
||||
## Testing
|
||||
|
||||
### CLI Tests
|
||||
|
||||
```bash
|
||||
# Basic open and close
|
||||
agent-browser open https://example.com
|
||||
agent-browser get title
|
||||
agent-browser close
|
||||
|
||||
# Snapshot with refs
|
||||
agent-browser open https://example.com
|
||||
agent-browser snapshot -i
|
||||
agent-browser close
|
||||
|
||||
# Screenshot
|
||||
agent-browser open https://example.com
|
||||
agent-browser screenshot /tmp/test.png
|
||||
agent-browser close
|
||||
```
|
||||
|
||||
### ZeroClaw Integration Tests
|
||||
|
||||
```bash
|
||||
# Content extraction
|
||||
echo "Open https://example.com and summarize it" | zeroclaw agent
|
||||
|
||||
# Navigation
|
||||
echo "Go to https://github.com/trending and list the top 3 repos" | zeroclaw agent
|
||||
|
||||
# Form interaction
|
||||
echo "Go to Wikipedia, search for 'Rust programming language', and summarize" | zeroclaw agent
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Element not found"
|
||||
|
||||
The page may not be fully loaded. Add a wait:
|
||||
|
||||
```bash
|
||||
agent-browser open https://slow-site.com
|
||||
agent-browser wait --load networkidle
|
||||
agent-browser snapshot -i
|
||||
```
|
||||
|
||||
### Cookie dialogs blocking access
|
||||
|
||||
Handle cookie consent first:
|
||||
|
||||
```bash
|
||||
agent-browser open https://site-with-cookies.com
|
||||
agent-browser snapshot -i
|
||||
agent-browser click @accept_cookies # Click the accept button
|
||||
agent-browser snapshot -i # Now get the actual content
|
||||
```
|
||||
|
||||
### Docker sandbox network restrictions
|
||||
|
||||
If `web_fetch` fails inside Docker sandbox, use agent-browser instead:
|
||||
|
||||
```bash
|
||||
# Instead of web_fetch, use:
|
||||
agent-browser open https://example.com
|
||||
agent-browser get text body
|
||||
```
|
||||
|
||||
## Security Notes
|
||||
|
||||
- `agent-browser` runs Chrome in headless mode with sandboxing
|
||||
- For sensitive sites, use `--session-name` to persist auth state
|
||||
- The `allowed_domains` config option restricts navigation to specific domains
|
||||
- VNC ports (5900, 6080) should be behind a firewall or Tailscale
|
||||
|
||||
## Related
|
||||
|
||||
- [agent-browser Documentation](https://github.com/vercel-labs/agent-browser)
|
||||
- [ZeroClaw Configuration Reference](./config-reference.md)
|
||||
- [Skills Documentation](../skills/)
|
||||
@@ -45,6 +45,15 @@ For complete code examples of each extension trait, see [extension-examples.md](
|
||||
- Keep multilingual entry-point parity for all supported locales (`en`, `zh-CN`, `ja`, `ru`, `fr`, `vi`) when nav or key wording changes.
|
||||
- When shared docs wording changes, sync corresponding localized docs in the same PR (or explicitly document deferral and follow-up PR).
|
||||
|
||||
## Tool Shared State
|
||||
|
||||
- Follow the `Arc<RwLock<T>>` handle pattern for any tool that owns long-lived shared state.
|
||||
- Accept handles at construction; do not create global/static mutable state.
|
||||
- Use `ClientId` (provided by the daemon) to namespace per-client state — never construct identity keys inside the tool.
|
||||
- Isolate security-sensitive state (credentials, quotas) per client; broadcast/display state may be shared with optional namespace prefixing.
|
||||
- Cached validation is invalidated on config change — tools must re-validate before the next execution when signaled.
|
||||
- See [ADR-004: Tool Shared State Ownership](../architecture/adr-004-tool-shared-state-ownership.md) for the full contract.
|
||||
|
||||
## Architecture Boundary Rules
|
||||
|
||||
- Extend capabilities by adding trait implementations + factory wiring first; avoid cross-module rewrites for isolated features.
|
||||
|
||||
@@ -38,3 +38,46 @@ allowed_tools = ["read", "edit", "exec"]
|
||||
max_iterations = 15
|
||||
# Optional: use longer timeout for complex coding tasks
|
||||
agentic_timeout_secs = 600
|
||||
|
||||
# ── Cron Configuration ────────────────────────────────────────
|
||||
[cron]
|
||||
# Enable the cron subsystem. Default: true
|
||||
enabled = true
|
||||
# Run all overdue jobs at scheduler startup. Default: true
|
||||
catch_up_on_startup = true
|
||||
# Maximum number of historical cron run records to retain. Default: 50
|
||||
max_run_history = 50
|
||||
|
||||
# ── Declarative Cron Jobs ─────────────────────────────────────
|
||||
# Define cron jobs directly in config. These are synced to the database
|
||||
# at scheduler startup. Each job needs a stable `id` for merge semantics.
|
||||
|
||||
# Shell job: runs a shell command on a cron schedule
|
||||
[[cron.jobs]]
|
||||
id = "daily-backup"
|
||||
name = "Daily Backup"
|
||||
job_type = "shell"
|
||||
command = "tar czf /tmp/backup.tar.gz /data"
|
||||
schedule = { kind = "cron", expr = "0 2 * * *" }
|
||||
|
||||
# Agent job: runs an agent prompt on an interval
|
||||
[[cron.jobs]]
|
||||
id = "health-check"
|
||||
name = "Health Check"
|
||||
job_type = "agent"
|
||||
prompt = "Check server health: disk space, memory, CPU load"
|
||||
model = "anthropic/claude-sonnet-4"
|
||||
allowed_tools = ["shell", "file_read"]
|
||||
schedule = { kind = "every", every_ms = 300000 }
|
||||
|
||||
# Cron job with timezone and delivery
|
||||
# [[cron.jobs]]
|
||||
# id = "morning-report"
|
||||
# name = "Morning Report"
|
||||
# job_type = "agent"
|
||||
# prompt = "Generate a daily summary of system metrics"
|
||||
# schedule = { kind = "cron", expr = "0 9 * * 1-5", tz = "America/New_York" }
|
||||
# [cron.jobs.delivery]
|
||||
# mode = "announce"
|
||||
# channel = "telegram"
|
||||
# to = "123456789"
|
||||
|
||||
Executable
+21
@@ -0,0 +1,21 @@
|
||||
#!/bin/bash
# Start a browser on a virtual display.
# Usage: ./start-browser.sh [display_num] [url]
#   display_num  X display number to attach to (default: 99)
#   url          URL to open                   (default: https://google.com)

set -e

DISPLAY_NUM=${1:-99}
URL=${2:-"https://google.com"}

export DISPLAY=:$DISPLAY_NUM

# xdpyinfo is needed to probe the display; without this check a missing
# binary would fall through to the misleading "display not running" error.
if ! command -v xdpyinfo &>/dev/null; then
    echo "Error: xdpyinfo not found. Install it first (e.g. apt-get install x11-utils)."
    exit 1
fi

# Check if display is running
if ! xdpyinfo -display :$DISPLAY_NUM &>/dev/null; then
    echo "Error: Display :$DISPLAY_NUM not running."
    echo "Start VNC first: ./start-vnc.sh"
    exit 1
fi

# NOTE(review): --no-sandbox is typically needed when running as root in a
# container; on multi-user hosts prefer keeping the Chrome sandbox enabled.
google-chrome --no-sandbox --disable-gpu --disable-setuid-sandbox "$URL" &
echo "Chrome started on display :$DISPLAY_NUM"
echo "View via VNC or noVNC"
|
||||
Executable
+52
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
# Start a virtual display with VNC access for browser GUI work.
# Usage: ./start-vnc.sh [display_num] [vnc_port] [novnc_port] [resolution]
#   display_num  X display number             (default: 99)
#   vnc_port     x11vnc listen port           (default: 5900)
#   novnc_port   noVNC web client port        (default: 6080)
#   resolution   Xvfb screen geometry WxHxD   (default: 1920x1080x24)

set -e

DISPLAY_NUM=${1:-99}
VNC_PORT=${2:-5900}
NOVNC_PORT=${3:-6080}
RESOLUTION=${4:-1920x1080x24}

echo "Starting virtual display :$DISPLAY_NUM at $RESOLUTION"

# Kill any existing sessions on the same display/ports so restarts are idempotent
pkill -f "Xvfb :$DISPLAY_NUM" 2>/dev/null || true
pkill -f "x11vnc.*:$DISPLAY_NUM" 2>/dev/null || true
pkill -f "websockify.*$NOVNC_PORT" 2>/dev/null || true
sleep 1

# Start Xvfb (virtual framebuffer)
Xvfb :$DISPLAY_NUM -screen 0 $RESOLUTION -ac &
XVFB_PID=$!
sleep 1

# `set -e` cannot catch failures in backgrounded commands, so verify Xvfb
# actually survived startup before layering the other services on top of it.
if ! kill -0 "$XVFB_PID" 2>/dev/null; then
    echo "Error: Xvfb failed to start on display :$DISPLAY_NUM"
    exit 1
fi

# Set DISPLAY
export DISPLAY=:$DISPLAY_NUM

# Start window manager
fluxbox -display :$DISPLAY_NUM 2>/dev/null &
sleep 1

# Start x11vnc (-nopw: passwordless — keep these ports behind a firewall)
x11vnc -display :$DISPLAY_NUM -rfbport $VNC_PORT -forever -shared -nopw -bg 2>/dev/null
sleep 1

# Start noVNC (web-based VNC client)
websockify --web=/usr/share/novnc $NOVNC_PORT localhost:$VNC_PORT &
NOVNC_PID=$!

echo ""
echo "==================================="
echo "VNC Server started!"
echo "==================================="
echo "VNC Direct: localhost:$VNC_PORT"
echo "noVNC Web: http://localhost:$NOVNC_PORT/vnc.html"
echo "Display: :$DISPLAY_NUM"
echo "==================================="
echo ""
echo "To start a browser, run:"
echo "  DISPLAY=:$DISPLAY_NUM google-chrome &"
echo ""
# Point at stop-vnc.sh: a bare `pkill -f Xvfb` would leave x11vnc and
# websockify running; stop-vnc.sh tears down all three processes.
echo "To stop, run: ./stop-vnc.sh $DISPLAY_NUM"
||||
Executable
+11
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
# Stop the virtual display and VNC server started by start-vnc.sh.
# Usage: ./stop-vnc.sh [display_num] [novnc_port]
#   display_num  X display number to stop        (default: 99)
#   novnc_port   websockify/noVNC port to stop   (default: 6080)
#
# The noVNC port is now configurable to match start-vnc.sh, which accepts
# a custom novnc_port; the previous hardcoded 6080 left websockify running
# whenever a non-default port was used.

DISPLAY_NUM=${1:-99}
NOVNC_PORT=${2:-6080}

# Each pkill is best-effort: a process that is not running is not an error.
pkill -f "Xvfb :$DISPLAY_NUM" 2>/dev/null || true
pkill -f "x11vnc.*:$DISPLAY_NUM" 2>/dev/null || true
pkill -f "websockify.*$NOVNC_PORT" 2>/dev/null || true

echo "VNC server stopped"
|
||||
@@ -0,0 +1,122 @@
|
||||
---
|
||||
name: browser
|
||||
description: Headless browser automation using agent-browser CLI
|
||||
metadata: {"zeroclaw":{"emoji":"🌐","requires":{"bins":["agent-browser"]}}}
|
||||
---
|
||||
|
||||
# Browser Skill
|
||||
|
||||
Control a headless browser for web automation, scraping, and testing.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- `agent-browser` CLI installed globally (`npm install -g agent-browser`)
|
||||
- Chrome downloaded (`agent-browser install`)
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Install agent-browser CLI
|
||||
npm install -g agent-browser
|
||||
|
||||
# Download Chrome for Testing
|
||||
agent-browser install --with-deps # Linux
|
||||
agent-browser install # macOS/Windows
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Navigate and snapshot
|
||||
|
||||
```bash
|
||||
agent-browser open https://example.com
|
||||
agent-browser snapshot -i
|
||||
```
|
||||
|
||||
### Interact with elements
|
||||
|
||||
```bash
|
||||
agent-browser click @e1 # Click by ref
|
||||
agent-browser fill @e2 "text" # Fill input
|
||||
agent-browser press Enter # Press key
|
||||
```
|
||||
|
||||
### Extract data
|
||||
|
||||
```bash
|
||||
agent-browser get text @e1 # Get text content
|
||||
agent-browser get url # Get current URL
|
||||
agent-browser screenshot page.png # Take screenshot
|
||||
```
|
||||
|
||||
### Session management
|
||||
|
||||
```bash
|
||||
agent-browser close # Close browser
|
||||
```
|
||||
|
||||
## Common Workflows
|
||||
|
||||
### Login flow
|
||||
|
||||
```bash
|
||||
agent-browser open https://site.com/login
|
||||
agent-browser snapshot -i
|
||||
agent-browser fill @email "user@example.com"
|
||||
agent-browser fill @password "secretpass"
|
||||
agent-browser click @submit
|
||||
agent-browser wait --text "Welcome"
|
||||
```
|
||||
|
||||
### Scrape page content
|
||||
|
||||
```bash
|
||||
agent-browser open https://news.ycombinator.com
|
||||
agent-browser snapshot -i
|
||||
agent-browser get text @e1
|
||||
```
|
||||
|
||||
### Take screenshots
|
||||
|
||||
```bash
|
||||
agent-browser open https://google.com
|
||||
agent-browser screenshot --full page.png
|
||||
```
|
||||
|
||||
## Options
|
||||
|
||||
- `--json` - JSON output for parsing
|
||||
- `--headed` - Show browser window (for debugging)
|
||||
- `--session-name <name>` - Persist session cookies
|
||||
- `--profile <path>` - Use persistent browser profile
|
||||
|
||||
## Configuration
|
||||
|
||||
The browser tool is enabled by default with `allowed_domains = ["*"]` and
|
||||
`backend = "agent_browser"`. To customize, edit `~/.zeroclaw/config.toml`:
|
||||
|
||||
```toml
|
||||
[browser]
|
||||
enabled = true # default: true
|
||||
allowed_domains = ["*"] # default: ["*"] (all public hosts)
|
||||
backend = "agent_browser" # default: "agent_browser"
|
||||
native_headless = true # default: true
|
||||
```
|
||||
|
||||
To restrict domains or disable the browser tool:
|
||||
|
||||
```toml
|
||||
[browser]
|
||||
enabled = false # disable entirely
|
||||
# or restrict to specific domains:
|
||||
allowed_domains = ["example.com", "docs.example.com"]
|
||||
```
|
||||
|
||||
## Full Command Reference
|
||||
|
||||
Run `agent-browser --help` for all available commands.
|
||||
|
||||
## Related
|
||||
|
||||
- [agent-browser GitHub](https://github.com/vercel-labs/agent-browser)
|
||||
- [VNC Setup Guide](../docs/browser-setup.md)
|
||||
+517
-60
@@ -4,7 +4,7 @@ use crate::config::Config;
|
||||
use crate::cost::types::{BudgetCheck, TokenUsage as CostTokenUsage};
|
||||
use crate::cost::CostTracker;
|
||||
use crate::i18n::ToolDescriptions;
|
||||
use crate::memory::{self, Memory, MemoryCategory};
|
||||
use crate::memory::{self, decay, Memory, MemoryCategory};
|
||||
use crate::multimodal;
|
||||
use crate::observability::{self, runtime_trace, Observer, ObserverEvent};
|
||||
use crate::providers::{
|
||||
@@ -561,6 +561,7 @@ fn save_interactive_session_history(path: &Path, history: &[ChatMessage]) -> Res
|
||||
/// Build context preamble by searching memory for relevant entries.
|
||||
/// Entries with a hybrid score below `min_relevance_score` are dropped to
|
||||
/// prevent unrelated memories from bleeding into the conversation.
|
||||
/// Core memories are exempt from time decay (evergreen).
|
||||
async fn build_context(
|
||||
mem: &dyn Memory,
|
||||
user_msg: &str,
|
||||
@@ -570,7 +571,10 @@ async fn build_context(
|
||||
let mut context = String::new();
|
||||
|
||||
// Pull relevant memories for this message
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, session_id, None, None).await {
|
||||
if let Ok(mut entries) = mem.recall(user_msg, 5, session_id, None, None).await {
|
||||
// Apply time decay: older non-Core memories score lower
|
||||
decay::apply_time_decay(&mut entries, decay::DEFAULT_HALF_LIFE_DAYS);
|
||||
|
||||
let relevant: Vec<_> = entries
|
||||
.iter()
|
||||
.filter(|e| match e.score {
|
||||
@@ -2707,16 +2711,53 @@ pub(crate) async fn run_tool_call_loop(
|
||||
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
|
||||
|
||||
let image_marker_count = multimodal::count_image_markers(history);
|
||||
if image_marker_count > 0 && !provider.supports_vision() {
|
||||
return Err(ProviderCapabilityError {
|
||||
provider: provider_name.to_string(),
|
||||
capability: "vision".to_string(),
|
||||
message: format!(
|
||||
"received {image_marker_count} image marker(s), but this provider does not support vision input"
|
||||
),
|
||||
|
||||
// ── Vision provider routing ──────────────────────────
|
||||
// When the default provider lacks vision support but a dedicated
|
||||
// vision_provider is configured, create it on demand and use it
|
||||
// for this iteration. Otherwise, preserve the original error.
|
||||
let vision_provider_box: Option<Box<dyn Provider>> = if image_marker_count > 0
|
||||
&& !provider.supports_vision()
|
||||
{
|
||||
if let Some(ref vp) = multimodal_config.vision_provider {
|
||||
let vp_instance = providers::create_provider(vp, None)
|
||||
.map_err(|e| anyhow::anyhow!("failed to create vision provider '{vp}': {e}"))?;
|
||||
if !vp_instance.supports_vision() {
|
||||
return Err(ProviderCapabilityError {
|
||||
provider: vp.clone(),
|
||||
capability: "vision".to_string(),
|
||||
message: format!(
|
||||
"configured vision_provider '{vp}' does not support vision input"
|
||||
),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
Some(vp_instance)
|
||||
} else {
|
||||
return Err(ProviderCapabilityError {
|
||||
provider: provider_name.to_string(),
|
||||
capability: "vision".to_string(),
|
||||
message: format!(
|
||||
"received {image_marker_count} image marker(s), but this provider does not support vision input"
|
||||
),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
.into());
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (active_provider, active_provider_name, active_model): (&dyn Provider, &str, &str) =
|
||||
if let Some(ref vp_box) = vision_provider_box {
|
||||
let vp_name = multimodal_config
|
||||
.vision_provider
|
||||
.as_deref()
|
||||
.unwrap_or(provider_name);
|
||||
let vm = multimodal_config.vision_model.as_deref().unwrap_or(model);
|
||||
(vp_box.as_ref(), vp_name, vm)
|
||||
} else {
|
||||
(provider, provider_name, model)
|
||||
};
|
||||
|
||||
let prepared_messages =
|
||||
multimodal::prepare_messages_for_provider(history, multimodal_config).await?;
|
||||
@@ -2732,15 +2773,15 @@ pub(crate) async fn run_tool_call_loop(
|
||||
}
|
||||
|
||||
observer.record_event(&ObserverEvent::LlmRequest {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
provider: active_provider_name.to_string(),
|
||||
model: active_model.to_string(),
|
||||
messages_count: history.len(),
|
||||
});
|
||||
runtime_trace::record_event(
|
||||
"llm_request",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_provider_name),
|
||||
Some(active_model),
|
||||
Some(&turn_id),
|
||||
None,
|
||||
None,
|
||||
@@ -2778,12 +2819,12 @@ pub(crate) async fn run_tool_call_loop(
|
||||
None
|
||||
};
|
||||
|
||||
let chat_future = provider.chat(
|
||||
let chat_future = active_provider.chat(
|
||||
ChatRequest {
|
||||
messages: &prepared_messages.messages,
|
||||
tools: request_tools,
|
||||
},
|
||||
model,
|
||||
active_model,
|
||||
temperature,
|
||||
);
|
||||
|
||||
@@ -2836,8 +2877,8 @@ pub(crate) async fn run_tool_call_loop(
|
||||
.unwrap_or((None, None));
|
||||
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
provider: active_provider_name.to_string(),
|
||||
model: active_model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: true,
|
||||
error_message: None,
|
||||
@@ -2846,10 +2887,9 @@ pub(crate) async fn run_tool_call_loop(
|
||||
});
|
||||
|
||||
// Record cost via task-local tracker (no-op when not scoped)
|
||||
let _ = resp
|
||||
.usage
|
||||
.as_ref()
|
||||
.and_then(|usage| record_tool_loop_cost_usage(provider_name, model, usage));
|
||||
let _ = resp.usage.as_ref().and_then(|usage| {
|
||||
record_tool_loop_cost_usage(active_provider_name, active_model, usage)
|
||||
});
|
||||
|
||||
let response_text = resp.text_or_empty().to_string();
|
||||
// First try native structured tool calls (OpenAI-format).
|
||||
@@ -2872,8 +2912,8 @@ pub(crate) async fn run_tool_call_loop(
|
||||
runtime_trace::record_event(
|
||||
"tool_call_parse_issue",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_provider_name),
|
||||
Some(active_model),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&parse_issue),
|
||||
@@ -2890,8 +2930,8 @@ pub(crate) async fn run_tool_call_loop(
|
||||
runtime_trace::record_event(
|
||||
"llm_response",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_provider_name),
|
||||
Some(active_model),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
None,
|
||||
@@ -2940,8 +2980,8 @@ pub(crate) async fn run_tool_call_loop(
|
||||
Err(e) => {
|
||||
let safe_error = crate::providers::sanitize_api_error(&e.to_string());
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
provider: active_provider_name.to_string(),
|
||||
model: active_model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: false,
|
||||
error_message: Some(safe_error.clone()),
|
||||
@@ -2951,8 +2991,8 @@ pub(crate) async fn run_tool_call_loop(
|
||||
runtime_trace::record_event(
|
||||
"llm_response",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(active_provider_name),
|
||||
Some(active_model),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&safe_error),
|
||||
@@ -3701,6 +3741,11 @@ pub async fn run(
|
||||
|
||||
// ── Build system prompt from workspace MD files (OpenClaw framework) ──
|
||||
let skills = crate::skills::load_skills_with_config(&config.workspace_dir, &config);
|
||||
|
||||
// Register skill-defined tools as callable tool specs in the tool registry
|
||||
// so the LLM can invoke them via native function calling, not just XML prompts.
|
||||
tools::register_skill_tools(&mut tools_registry, &skills, security.clone());
|
||||
|
||||
let mut tool_descs: Vec<(&str, &str)> = vec![
|
||||
(
|
||||
"shell",
|
||||
@@ -3865,17 +3910,45 @@ pub async fn run(
|
||||
|
||||
let mut final_output = String::new();
|
||||
|
||||
// Save the base system prompt before any thinking modifications so
|
||||
// the interactive loop can restore it between turns.
|
||||
let base_system_prompt = system_prompt.clone();
|
||||
|
||||
if let Some(msg) = message {
|
||||
// ── Parse thinking directive from user message ─────────
|
||||
let (thinking_directive, effective_msg) =
|
||||
match crate::agent::thinking::parse_thinking_directive(&msg) {
|
||||
Some((level, remaining)) => {
|
||||
tracing::info!(thinking_level = ?level, "Thinking directive parsed from message");
|
||||
(Some(level), remaining)
|
||||
}
|
||||
None => (None, msg.clone()),
|
||||
};
|
||||
let thinking_level = crate::agent::thinking::resolve_thinking_level(
|
||||
thinking_directive,
|
||||
None,
|
||||
&config.agent.thinking,
|
||||
);
|
||||
let thinking_params = crate::agent::thinking::apply_thinking_level(thinking_level);
|
||||
let effective_temperature = crate::agent::thinking::clamp_temperature(
|
||||
temperature + thinking_params.temperature_adjustment,
|
||||
);
|
||||
|
||||
// Prepend thinking system prompt prefix when present.
|
||||
if let Some(ref prefix) = thinking_params.system_prompt_prefix {
|
||||
system_prompt = format!("{prefix}\n\n{system_prompt}");
|
||||
}
|
||||
|
||||
// Auto-save user message to memory (skip short/trivial messages)
|
||||
if config.memory.auto_save
|
||||
&& msg.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
&& !memory::should_skip_autosave_content(&msg)
|
||||
&& effective_msg.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
&& !memory::should_skip_autosave_content(&effective_msg)
|
||||
{
|
||||
let user_key = autosave_memory_key("user_msg");
|
||||
let _ = mem
|
||||
.store(
|
||||
&user_key,
|
||||
&msg,
|
||||
&effective_msg,
|
||||
MemoryCategory::Conversation,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
@@ -3885,7 +3958,7 @@ pub async fn run(
|
||||
// Inject memory + hardware RAG context into user message
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
&msg,
|
||||
&effective_msg,
|
||||
config.memory.min_relevance_score,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
@@ -3893,14 +3966,14 @@ pub async fn run(
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
.map(|r| build_hardware_context(r, &msg, &board_names, rag_limit))
|
||||
.map(|r| build_hardware_context(r, &effective_msg, &board_names, rag_limit))
|
||||
.unwrap_or_default();
|
||||
let context = format!("{mem_context}{hw_context}");
|
||||
let now = chrono::Local::now().format("%Y-%m-%d %H:%M:%S %Z");
|
||||
let enriched = if context.is_empty() {
|
||||
format!("[{now}] {msg}")
|
||||
format!("[{now}] {effective_msg}")
|
||||
} else {
|
||||
format!("{context}[{now}] {msg}")
|
||||
format!("{context}[{now}] {effective_msg}")
|
||||
};
|
||||
|
||||
let mut history = vec![
|
||||
@@ -3909,8 +3982,11 @@ pub async fn run(
|
||||
];
|
||||
|
||||
// Compute per-turn excluded MCP tools from tool_filter_groups.
|
||||
let excluded_tools =
|
||||
compute_excluded_mcp_tools(&tools_registry, &config.agent.tool_filter_groups, &msg);
|
||||
let excluded_tools = compute_excluded_mcp_tools(
|
||||
&tools_registry,
|
||||
&config.agent.tool_filter_groups,
|
||||
&effective_msg,
|
||||
);
|
||||
|
||||
#[allow(unused_assignments)]
|
||||
let mut response = String::new();
|
||||
@@ -3922,7 +3998,7 @@ pub async fn run(
|
||||
observer.as_ref(),
|
||||
&provider_name,
|
||||
&model_name,
|
||||
temperature,
|
||||
effective_temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
@@ -4042,9 +4118,10 @@ pub async fn run(
|
||||
"/quit" | "/exit" => break,
|
||||
"/help" => {
|
||||
println!("Available commands:");
|
||||
println!(" /help Show this help message");
|
||||
println!(" /clear /new Clear conversation history");
|
||||
println!(" /quit /exit Exit interactive mode\n");
|
||||
println!(" /help Show this help message");
|
||||
println!(" /clear /new Clear conversation history");
|
||||
println!(" /quit /exit Exit interactive mode");
|
||||
println!(" /think:<level> Set reasoning depth (off|minimal|low|medium|high|max)\n");
|
||||
continue;
|
||||
}
|
||||
"/clear" | "/new" => {
|
||||
@@ -4096,16 +4173,47 @@ pub async fn run(
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// ── Parse thinking directive from interactive input ───
|
||||
let (thinking_directive, effective_input) =
|
||||
match crate::agent::thinking::parse_thinking_directive(&user_input) {
|
||||
Some((level, remaining)) => {
|
||||
tracing::info!(thinking_level = ?level, "Thinking directive parsed");
|
||||
(Some(level), remaining)
|
||||
}
|
||||
None => (None, user_input.clone()),
|
||||
};
|
||||
let thinking_level = crate::agent::thinking::resolve_thinking_level(
|
||||
thinking_directive,
|
||||
None,
|
||||
&config.agent.thinking,
|
||||
);
|
||||
let thinking_params = crate::agent::thinking::apply_thinking_level(thinking_level);
|
||||
let turn_temperature = crate::agent::thinking::clamp_temperature(
|
||||
temperature + thinking_params.temperature_adjustment,
|
||||
);
|
||||
|
||||
// For non-Medium levels, temporarily patch the system prompt with prefix.
|
||||
let turn_system_prompt;
|
||||
if let Some(ref prefix) = thinking_params.system_prompt_prefix {
|
||||
turn_system_prompt = format!("{prefix}\n\n{system_prompt}");
|
||||
// Update the system message in history for this turn.
|
||||
if let Some(sys_msg) = history.first_mut() {
|
||||
if sys_msg.role == "system" {
|
||||
sys_msg.content = turn_system_prompt.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-save conversation turns (skip short/trivial messages)
|
||||
if config.memory.auto_save
|
||||
&& user_input.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
&& !memory::should_skip_autosave_content(&user_input)
|
||||
&& effective_input.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS
|
||||
&& !memory::should_skip_autosave_content(&effective_input)
|
||||
{
|
||||
let user_key = autosave_memory_key("user_msg");
|
||||
let _ = mem
|
||||
.store(
|
||||
&user_key,
|
||||
&user_input,
|
||||
&effective_input,
|
||||
MemoryCategory::Conversation,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
@@ -4115,7 +4223,7 @@ pub async fn run(
|
||||
// Inject memory + hardware RAG context into user message
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
&user_input,
|
||||
&effective_input,
|
||||
config.memory.min_relevance_score,
|
||||
memory_session_id.as_deref(),
|
||||
)
|
||||
@@ -4123,14 +4231,14 @@ pub async fn run(
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
.map(|r| build_hardware_context(r, &user_input, &board_names, rag_limit))
|
||||
.map(|r| build_hardware_context(r, &effective_input, &board_names, rag_limit))
|
||||
.unwrap_or_default();
|
||||
let context = format!("{mem_context}{hw_context}");
|
||||
let now = chrono::Local::now().format("%Y-%m-%d %H:%M:%S %Z");
|
||||
let enriched = if context.is_empty() {
|
||||
format!("[{now}] {user_input}")
|
||||
format!("[{now}] {effective_input}")
|
||||
} else {
|
||||
format!("{context}[{now}] {user_input}")
|
||||
format!("{context}[{now}] {effective_input}")
|
||||
};
|
||||
|
||||
history.push(ChatMessage::user(&enriched));
|
||||
@@ -4139,7 +4247,7 @@ pub async fn run(
|
||||
let excluded_tools = compute_excluded_mcp_tools(
|
||||
&tools_registry,
|
||||
&config.agent.tool_filter_groups,
|
||||
&user_input,
|
||||
&effective_input,
|
||||
);
|
||||
|
||||
let response = loop {
|
||||
@@ -4150,7 +4258,7 @@ pub async fn run(
|
||||
observer.as_ref(),
|
||||
&provider_name,
|
||||
&model_name,
|
||||
temperature,
|
||||
turn_temperature,
|
||||
false,
|
||||
approval_manager.as_ref(),
|
||||
channel_name,
|
||||
@@ -4235,6 +4343,15 @@ pub async fn run(
|
||||
// Hard cap as a safety net.
|
||||
trim_history(&mut history, config.agent.max_history_messages);
|
||||
|
||||
// Restore base system prompt (remove per-turn thinking prefix).
|
||||
if thinking_params.system_prompt_prefix.is_some() {
|
||||
if let Some(sys_msg) = history.first_mut() {
|
||||
if sys_msg.role == "system" {
|
||||
sys_msg.content.clone_from(&base_system_prompt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(path) = session_state_file.as_deref() {
|
||||
save_interactive_session_history(path, &history)?;
|
||||
}
|
||||
@@ -4415,6 +4532,10 @@ pub async fn process_message(
|
||||
let i18n_descs = crate::i18n::ToolDescriptions::load(&i18n_locale, &i18n_search_dirs);
|
||||
|
||||
let skills = crate::skills::load_skills_with_config(&config.workspace_dir, &config);
|
||||
|
||||
// Register skill-defined tools as callable tool specs (process_message path).
|
||||
tools::register_skill_tools(&mut tools_registry, &skills, security.clone());
|
||||
|
||||
let mut tool_descs: Vec<(&str, &str)> = vec![
|
||||
("shell", "Execute terminal commands."),
|
||||
("file_read", "Read file contents."),
|
||||
@@ -4508,9 +4629,34 @@ pub async fn process_message(
|
||||
system_prompt.push_str(&deferred_section);
|
||||
}
|
||||
|
||||
// ── Parse thinking directive from user message ─────────────
|
||||
let (thinking_directive, effective_message) =
|
||||
match crate::agent::thinking::parse_thinking_directive(message) {
|
||||
Some((level, remaining)) => {
|
||||
tracing::info!(thinking_level = ?level, "Thinking directive parsed from message");
|
||||
(Some(level), remaining)
|
||||
}
|
||||
None => (None, message.to_string()),
|
||||
};
|
||||
let thinking_level = crate::agent::thinking::resolve_thinking_level(
|
||||
thinking_directive,
|
||||
None,
|
||||
&config.agent.thinking,
|
||||
);
|
||||
let thinking_params = crate::agent::thinking::apply_thinking_level(thinking_level);
|
||||
let effective_temperature = crate::agent::thinking::clamp_temperature(
|
||||
config.default_temperature + thinking_params.temperature_adjustment,
|
||||
);
|
||||
|
||||
// Prepend thinking system prompt prefix when present.
|
||||
if let Some(ref prefix) = thinking_params.system_prompt_prefix {
|
||||
system_prompt = format!("{prefix}\n\n{system_prompt}");
|
||||
}
|
||||
|
||||
let effective_msg_ref = effective_message.as_str();
|
||||
let mem_context = build_context(
|
||||
mem.as_ref(),
|
||||
message,
|
||||
effective_msg_ref,
|
||||
config.memory.min_relevance_score,
|
||||
session_id,
|
||||
)
|
||||
@@ -4518,22 +4664,25 @@ pub async fn process_message(
|
||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||
let hw_context = hardware_rag
|
||||
.as_ref()
|
||||
.map(|r| build_hardware_context(r, message, &board_names, rag_limit))
|
||||
.map(|r| build_hardware_context(r, effective_msg_ref, &board_names, rag_limit))
|
||||
.unwrap_or_default();
|
||||
let context = format!("{mem_context}{hw_context}");
|
||||
let now = chrono::Local::now().format("%Y-%m-%d %H:%M:%S %Z");
|
||||
let enriched = if context.is_empty() {
|
||||
format!("[{now}] {message}")
|
||||
format!("[{now}] {effective_message}")
|
||||
} else {
|
||||
format!("{context}[{now}] {message}")
|
||||
format!("{context}[{now}] {effective_message}")
|
||||
};
|
||||
|
||||
let mut history = vec![
|
||||
ChatMessage::system(&system_prompt),
|
||||
ChatMessage::user(&enriched),
|
||||
];
|
||||
let mut excluded_tools =
|
||||
compute_excluded_mcp_tools(&tools_registry, &config.agent.tool_filter_groups, message);
|
||||
let mut excluded_tools = compute_excluded_mcp_tools(
|
||||
&tools_registry,
|
||||
&config.agent.tool_filter_groups,
|
||||
effective_msg_ref,
|
||||
);
|
||||
if config.autonomy.level != AutonomyLevel::Full {
|
||||
excluded_tools.extend(config.autonomy.non_cli_excluded_tools.iter().cloned());
|
||||
}
|
||||
@@ -4545,7 +4694,7 @@ pub async fn process_message(
|
||||
observer.as_ref(),
|
||||
provider_name,
|
||||
&model_name,
|
||||
config.default_temperature,
|
||||
effective_temperature,
|
||||
true,
|
||||
"daemon",
|
||||
None,
|
||||
@@ -5094,6 +5243,7 @@ mod tests {
|
||||
max_images: 4,
|
||||
max_image_size_mb: 1,
|
||||
allow_remote_fetch: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
@@ -5171,6 +5321,313 @@ mod tests {
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
/// When `vision_provider` is not set and the default provider lacks vision
|
||||
/// support, the original `ProviderCapabilityError` should be returned.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_no_vision_provider_config_preserves_error() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = NonVisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"check [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("should fail without vision_provider config");
|
||||
|
||||
assert!(err.to_string().contains("capability=vision"));
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
/// When `vision_provider` is set but the provider factory cannot resolve
|
||||
/// the name, a descriptive error should be returned (not the generic
|
||||
/// capability error).
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_vision_provider_creation_failure() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = NonVisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"inspect [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
vision_model: Some("some-model".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("should fail when vision provider cannot be created");
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("failed to create vision provider"),
|
||||
"expected creation failure error, got: {}",
|
||||
err
|
||||
);
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
/// Messages without image markers should use the default provider even
|
||||
/// when `vision_provider` is configured.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_no_images_uses_default_provider() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec!["hello world"]);
|
||||
|
||||
let mut history = vec![ChatMessage::user("just text, no images".to_string())];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
vision_model: Some("some-model".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Even though vision_provider points to a nonexistent provider, this
|
||||
// should succeed because there are no image markers to trigger routing.
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"scripted",
|
||||
"scripted-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("text-only messages should succeed with default provider");
|
||||
|
||||
assert_eq!(result, "hello world");
|
||||
}
|
||||
|
||||
/// When `vision_provider` is set but `vision_model` is not, the default
|
||||
/// model should be used as fallback for the vision provider.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_vision_provider_without_model_falls_back() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = NonVisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"look [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
// vision_provider set but vision_model is None — the code should
|
||||
// fall back to the default model. Since the provider name is invalid,
|
||||
// we just verify the error path references the correct provider.
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
vision_model: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("should fail due to nonexistent vision provider");
|
||||
|
||||
// Verify the routing was attempted (not the generic capability error).
|
||||
assert!(
|
||||
err.to_string().contains("failed to create vision provider"),
|
||||
"expected creation failure, got: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
/// Empty `[IMAGE:]` markers (which are preserved as literal text by the
|
||||
/// parser) should not trigger vision provider routing.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_empty_image_markers_use_default_provider() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec!["handled"]);
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"empty marker [IMAGE:] should be ignored".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"scripted",
|
||||
"scripted-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect("empty image markers should not trigger vision routing");
|
||||
|
||||
assert_eq!(result, "handled");
|
||||
}
|
||||
|
||||
/// Multiple image markers should still trigger vision routing when
|
||||
/// vision_provider is configured.
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_multiple_images_trigger_vision_routing() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = NonVisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"two images [IMAGE:data:image/png;base64,aQ==] and [IMAGE:data:image/png;base64,bQ==]"
|
||||
.to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
vision_provider: Some("nonexistent-provider-xyz".to_string()),
|
||||
vision_model: Some("llava:7b".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
None,
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
None,
|
||||
&crate::config::PacingConfig::default(),
|
||||
)
|
||||
.await
|
||||
.expect_err("should attempt vision provider creation for multiple images");
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("failed to create vision provider"),
|
||||
"expected creation failure for multiple images, got: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_execute_tools_in_parallel_returns_false_for_single_call() {
|
||||
let calls = vec![ParsedToolCall {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::memory::{self, Memory};
|
||||
use crate::memory::{self, decay, Memory};
|
||||
use async_trait::async_trait;
|
||||
use std::fmt::Write;
|
||||
|
||||
@@ -43,13 +43,16 @@ impl MemoryLoader for DefaultMemoryLoader {
|
||||
user_message: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<String> {
|
||||
let entries = memory
|
||||
let mut entries = memory
|
||||
.recall(user_message, self.limit, session_id, None, None)
|
||||
.await?;
|
||||
if entries.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
// Apply time decay: older non-Core memories score lower
|
||||
decay::apply_time_decay(&mut entries, decay::DEFAULT_HALF_LIFE_DAYS);
|
||||
|
||||
let mut context = String::from("[Memory context]\n");
|
||||
for entry in entries {
|
||||
if memory::is_assistant_autosave_key(&entry.key) {
|
||||
|
||||
@@ -5,6 +5,7 @@ pub mod dispatcher;
|
||||
pub mod loop_;
|
||||
pub mod memory_loader;
|
||||
pub mod prompt;
|
||||
pub mod thinking;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
+7
-6
@@ -473,8 +473,9 @@ mod tests {
|
||||
assert!(output.contains("<available_skills>"));
|
||||
assert!(output.contains("<name>deploy</name>"));
|
||||
assert!(output.contains("<instruction>Run smoke tests before deploy.</instruction>"));
|
||||
assert!(output.contains("<name>release_checklist</name>"));
|
||||
assert!(output.contains("<kind>shell</kind>"));
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names
|
||||
assert!(output.contains("<callable_tools"));
|
||||
assert!(output.contains("<name>deploy.release_checklist</name>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -516,10 +517,10 @@ mod tests {
|
||||
assert!(output.contains("<location>skills/deploy/SKILL.md</location>"));
|
||||
assert!(output.contains("read_skill(name)"));
|
||||
assert!(!output.contains("<instruction>Run smoke tests before deploy.</instruction>"));
|
||||
// Compact mode should still include tools so the LLM knows about them
|
||||
assert!(output.contains("<tools>"));
|
||||
assert!(output.contains("<name>release_checklist</name>"));
|
||||
assert!(output.contains("<kind>shell</kind>"));
|
||||
// Compact mode should still include tools so the LLM knows about them.
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names.
|
||||
assert!(output.contains("<callable_tools"));
|
||||
assert!(output.contains("<name>deploy.release_checklist</name>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -0,0 +1,424 @@
|
||||
//! Thinking/Reasoning Level Control
|
||||
//!
|
||||
//! Allows users to control how deeply the model reasons per message,
|
||||
//! trading speed for depth. Levels range from `Off` (fastest, most concise)
|
||||
//! to `Max` (deepest reasoning, slowest).
|
||||
//!
|
||||
//! Users can set the level via:
|
||||
//! - Inline directive: `/think:high` at the start of a message
|
||||
//! - Agent config: `[agent.thinking]` section with `default_level`
|
||||
//!
|
||||
//! Resolution hierarchy (highest priority first):
|
||||
//! 1. Inline directive (`/think:<level>`)
|
||||
//! 2. Session override (reserved for future use)
|
||||
//! 3. Agent config (`agent.thinking.default_level`)
|
||||
//! 4. Global default (`Medium`)
|
||||
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// How deeply the model should reason for a given message.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ThinkingLevel {
|
||||
/// No chain-of-thought. Fastest, most concise responses.
|
||||
Off,
|
||||
/// Minimal reasoning. Brief, direct answers.
|
||||
Minimal,
|
||||
/// Light reasoning. Short explanations when needed.
|
||||
Low,
|
||||
/// Balanced reasoning (default). Moderate depth.
|
||||
#[default]
|
||||
Medium,
|
||||
/// Deep reasoning. Thorough analysis and step-by-step thinking.
|
||||
High,
|
||||
/// Maximum reasoning depth. Exhaustive analysis.
|
||||
Max,
|
||||
}
|
||||
|
||||
impl ThinkingLevel {
|
||||
/// Parse a thinking level from a string (case-insensitive).
|
||||
pub fn from_str_insensitive(s: &str) -> Option<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"off" | "none" => Some(Self::Off),
|
||||
"minimal" | "min" => Some(Self::Minimal),
|
||||
"low" => Some(Self::Low),
|
||||
"medium" | "med" | "default" => Some(Self::Medium),
|
||||
"high" => Some(Self::High),
|
||||
"max" | "maximum" => Some(Self::Max),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for thinking/reasoning level control.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct ThinkingConfig {
|
||||
/// Default thinking level when no directive is present.
|
||||
#[serde(default)]
|
||||
pub default_level: ThinkingLevel,
|
||||
}
|
||||
|
||||
impl Default for ThinkingConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
default_level: ThinkingLevel::Medium,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameters derived from a thinking level, applied to the LLM request.
///
/// Produced by `apply_thinking_level`.
#[derive(Debug, Clone, PartialEq)]
pub struct ThinkingParams {
    /// Temperature adjustment (added to the base temperature, clamped to 0.0..=2.0).
    pub temperature_adjustment: f64,
    /// Maximum tokens adjustment (added to any existing max_tokens setting).
    /// Negative for the terse levels, positive for the thorough ones.
    pub max_tokens_adjustment: i64,
    /// Optional system prompt prefix injected before the existing system prompt.
    pub system_prompt_prefix: Option<String>,
}
|
||||
|
||||
/// Parse a `/think:<level>` directive from the start of a message.
|
||||
///
|
||||
/// Returns `Some((level, remaining_message))` if a directive is found,
|
||||
/// or `None` if no directive is present. The remaining message has
|
||||
/// leading whitespace after the directive trimmed.
|
||||
pub fn parse_thinking_directive(message: &str) -> Option<(ThinkingLevel, String)> {
|
||||
let trimmed = message.trim_start();
|
||||
if !trimmed.starts_with("/think:") {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Extract the level token (everything between `/think:` and the next whitespace or end).
|
||||
let after_prefix = &trimmed["/think:".len()..];
|
||||
let level_end = after_prefix
|
||||
.find(|c: char| c.is_whitespace())
|
||||
.unwrap_or(after_prefix.len());
|
||||
let level_str = &after_prefix[..level_end];
|
||||
|
||||
let level = ThinkingLevel::from_str_insensitive(level_str)?;
|
||||
|
||||
let remaining = after_prefix[level_end..].trim_start().to_string();
|
||||
Some((level, remaining))
|
||||
}
|
||||
|
||||
/// Convert a `ThinkingLevel` into concrete parameters for the LLM request.
|
||||
pub fn apply_thinking_level(level: ThinkingLevel) -> ThinkingParams {
|
||||
match level {
|
||||
ThinkingLevel::Off => ThinkingParams {
|
||||
temperature_adjustment: -0.2,
|
||||
max_tokens_adjustment: -1000,
|
||||
system_prompt_prefix: Some(
|
||||
"Be extremely concise. Give direct answers without explanation \
|
||||
unless explicitly asked. No preamble."
|
||||
.into(),
|
||||
),
|
||||
},
|
||||
ThinkingLevel::Minimal => ThinkingParams {
|
||||
temperature_adjustment: -0.1,
|
||||
max_tokens_adjustment: -500,
|
||||
system_prompt_prefix: Some(
|
||||
"Be concise and fast. Keep explanations brief. \
|
||||
Prioritize speed over thoroughness."
|
||||
.into(),
|
||||
),
|
||||
},
|
||||
ThinkingLevel::Low => ThinkingParams {
|
||||
temperature_adjustment: -0.05,
|
||||
max_tokens_adjustment: 0,
|
||||
system_prompt_prefix: Some("Keep reasoning light. Explain only when helpful.".into()),
|
||||
},
|
||||
ThinkingLevel::Medium => ThinkingParams {
|
||||
temperature_adjustment: 0.0,
|
||||
max_tokens_adjustment: 0,
|
||||
system_prompt_prefix: None,
|
||||
},
|
||||
ThinkingLevel::High => ThinkingParams {
|
||||
temperature_adjustment: 0.05,
|
||||
max_tokens_adjustment: 1000,
|
||||
system_prompt_prefix: Some(
|
||||
"Think step by step. Provide thorough analysis and \
|
||||
consider edge cases before answering."
|
||||
.into(),
|
||||
),
|
||||
},
|
||||
ThinkingLevel::Max => ThinkingParams {
|
||||
temperature_adjustment: 0.1,
|
||||
max_tokens_adjustment: 2000,
|
||||
system_prompt_prefix: Some(
|
||||
"Think very carefully and exhaustively. Break down the problem \
|
||||
into sub-problems, consider all angles, verify your reasoning, \
|
||||
and provide the most thorough analysis possible."
|
||||
.into(),
|
||||
),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the effective thinking level using the priority hierarchy:
|
||||
/// 1. Inline directive (if present)
|
||||
/// 2. Session override (reserved, currently always `None`)
|
||||
/// 3. Agent config default
|
||||
/// 4. Global default (`Medium`)
|
||||
pub fn resolve_thinking_level(
|
||||
inline_directive: Option<ThinkingLevel>,
|
||||
session_override: Option<ThinkingLevel>,
|
||||
config: &ThinkingConfig,
|
||||
) -> ThinkingLevel {
|
||||
inline_directive
|
||||
.or(session_override)
|
||||
.unwrap_or(config.default_level)
|
||||
}
|
||||
|
||||
/// Clamp a temperature value to the valid range `[0.0, 2.0]`.
///
/// Values below the range snap to `0.0`, values above it snap to `2.0`;
/// anything already in range (including the endpoints) passes through.
pub fn clamp_temperature(temp: f64) -> f64 {
    if temp < 0.0 {
        0.0
    } else if temp > 2.0 {
        2.0
    } else {
        temp
    }
}
|
||||
|
||||
/// Unit tests for level parsing, directive extraction, parameter
/// application, resolution priority, clamping, and serde behavior.
#[cfg(test)]
mod tests {
    use super::*;

    // ── ThinkingLevel parsing ────────────────────────────────────

    #[test]
    fn thinking_level_from_str_canonical_names() {
        assert_eq!(
            ThinkingLevel::from_str_insensitive("off"),
            Some(ThinkingLevel::Off)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("minimal"),
            Some(ThinkingLevel::Minimal)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("low"),
            Some(ThinkingLevel::Low)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("medium"),
            Some(ThinkingLevel::Medium)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("high"),
            Some(ThinkingLevel::High)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("max"),
            Some(ThinkingLevel::Max)
        );
    }

    #[test]
    fn thinking_level_from_str_aliases() {
        assert_eq!(
            ThinkingLevel::from_str_insensitive("none"),
            Some(ThinkingLevel::Off)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("min"),
            Some(ThinkingLevel::Minimal)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("med"),
            Some(ThinkingLevel::Medium)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("default"),
            Some(ThinkingLevel::Medium)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("maximum"),
            Some(ThinkingLevel::Max)
        );
    }

    #[test]
    fn thinking_level_from_str_case_insensitive() {
        assert_eq!(
            ThinkingLevel::from_str_insensitive("HIGH"),
            Some(ThinkingLevel::High)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("Max"),
            Some(ThinkingLevel::Max)
        );
        assert_eq!(
            ThinkingLevel::from_str_insensitive("OFF"),
            Some(ThinkingLevel::Off)
        );
    }

    #[test]
    fn thinking_level_from_str_invalid_returns_none() {
        assert_eq!(ThinkingLevel::from_str_insensitive("turbo"), None);
        assert_eq!(ThinkingLevel::from_str_insensitive(""), None);
        assert_eq!(ThinkingLevel::from_str_insensitive("super-high"), None);
    }

    // ── Directive parsing ────────────────────────────────────────

    #[test]
    fn parse_directive_extracts_level_and_remaining_message() {
        let result = parse_thinking_directive("/think:high What is Rust?");
        assert!(result.is_some());
        let (level, remaining) = result.unwrap();
        assert_eq!(level, ThinkingLevel::High);
        assert_eq!(remaining, "What is Rust?");
    }

    #[test]
    fn parse_directive_handles_directive_only() {
        // A bare directive with no message yields an empty remainder.
        let result = parse_thinking_directive("/think:off");
        assert!(result.is_some());
        let (level, remaining) = result.unwrap();
        assert_eq!(level, ThinkingLevel::Off);
        assert_eq!(remaining, "");
    }

    #[test]
    fn parse_directive_strips_leading_whitespace() {
        let result = parse_thinking_directive("  /think:low Tell me about Rust");
        assert!(result.is_some());
        let (level, remaining) = result.unwrap();
        assert_eq!(level, ThinkingLevel::Low);
        assert_eq!(remaining, "Tell me about Rust");
    }

    #[test]
    fn parse_directive_returns_none_for_no_directive() {
        assert!(parse_thinking_directive("Hello world").is_none());
        assert!(parse_thinking_directive("").is_none());
        assert!(parse_thinking_directive("/think").is_none());
    }

    #[test]
    fn parse_directive_returns_none_for_invalid_level() {
        assert!(parse_thinking_directive("/think:turbo What?").is_none());
    }

    #[test]
    fn parse_directive_not_triggered_mid_message() {
        // The directive is only honored at the start of the message.
        assert!(parse_thinking_directive("Hello /think:high world").is_none());
    }

    // ── Level application ────────────────────────────────────────

    #[test]
    fn apply_thinking_level_off_is_concise() {
        let params = apply_thinking_level(ThinkingLevel::Off);
        assert!(params.temperature_adjustment < 0.0);
        assert!(params.max_tokens_adjustment < 0);
        assert!(params.system_prompt_prefix.is_some());
        assert!(params
            .system_prompt_prefix
            .unwrap()
            .to_lowercase()
            .contains("concise"));
    }

    #[test]
    fn apply_thinking_level_medium_is_neutral() {
        // Medium is the baseline: no deltas, no prompt prefix.
        let params = apply_thinking_level(ThinkingLevel::Medium);
        assert!((params.temperature_adjustment - 0.0).abs() < f64::EPSILON);
        assert_eq!(params.max_tokens_adjustment, 0);
        assert!(params.system_prompt_prefix.is_none());
    }

    #[test]
    fn apply_thinking_level_high_adds_step_by_step() {
        let params = apply_thinking_level(ThinkingLevel::High);
        assert!(params.temperature_adjustment > 0.0);
        assert!(params.max_tokens_adjustment > 0);
        let prefix = params.system_prompt_prefix.unwrap();
        assert!(prefix.to_lowercase().contains("step by step"));
    }

    #[test]
    fn apply_thinking_level_max_is_most_thorough() {
        let params = apply_thinking_level(ThinkingLevel::Max);
        assert!(params.temperature_adjustment > 0.0);
        assert!(params.max_tokens_adjustment > 0);
        let prefix = params.system_prompt_prefix.unwrap();
        assert!(prefix.to_lowercase().contains("exhaustively"));
    }

    // ── Resolution hierarchy ─────────────────────────────────────

    #[test]
    fn resolve_inline_directive_takes_priority() {
        let config = ThinkingConfig {
            default_level: ThinkingLevel::Low,
        };
        let result =
            resolve_thinking_level(Some(ThinkingLevel::Max), Some(ThinkingLevel::High), &config);
        assert_eq!(result, ThinkingLevel::Max);
    }

    #[test]
    fn resolve_session_override_takes_priority_over_config() {
        let config = ThinkingConfig {
            default_level: ThinkingLevel::Low,
        };
        let result = resolve_thinking_level(None, Some(ThinkingLevel::High), &config);
        assert_eq!(result, ThinkingLevel::High);
    }

    #[test]
    fn resolve_falls_back_to_config_default() {
        let config = ThinkingConfig {
            default_level: ThinkingLevel::Minimal,
        };
        let result = resolve_thinking_level(None, None, &config);
        assert_eq!(result, ThinkingLevel::Minimal);
    }

    #[test]
    fn resolve_default_config_uses_medium() {
        let config = ThinkingConfig::default();
        let result = resolve_thinking_level(None, None, &config);
        assert_eq!(result, ThinkingLevel::Medium);
    }

    // ── Temperature clamping ─────────────────────────────────────

    #[test]
    fn clamp_temperature_within_range() {
        assert!((clamp_temperature(0.7) - 0.7).abs() < f64::EPSILON);
        assert!((clamp_temperature(0.0) - 0.0).abs() < f64::EPSILON);
        assert!((clamp_temperature(2.0) - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn clamp_temperature_below_minimum() {
        assert!((clamp_temperature(-0.5) - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn clamp_temperature_above_maximum() {
        assert!((clamp_temperature(3.0) - 2.0).abs() < f64::EPSILON);
    }

    // ── Serde round-trip ─────────────────────────────────────────

    #[test]
    fn thinking_config_deserializes_from_toml() {
        let toml_str = r#"default_level = "high""#;
        let config: ThinkingConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.default_level, ThinkingLevel::High);
    }

    #[test]
    fn thinking_config_default_level_deserializes() {
        // An empty TOML document exercises the #[serde(default)] path.
        let toml_str = "";
        let config: ThinkingConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.default_level, ThinkingLevel::Medium);
    }

    #[test]
    fn thinking_level_serializes_lowercase() {
        let level = ThinkingLevel::High;
        let json = serde_json::to_string(&level).unwrap();
        assert_eq!(json, "\"high\"");
    }
}
|
||||
@@ -562,4 +562,50 @@ mod tests {
|
||||
let parsed: ApprovalRequest = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.tool_name, "shell");
|
||||
}
|
||||
|
||||
// ── Regression: #4247 default approved tools in channels ──
|
||||
|
||||
#[test]
|
||||
fn non_interactive_allows_default_auto_approve_tools() {
|
||||
let config = AutonomyConfig::default();
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
|
||||
for tool in &config.auto_approve {
|
||||
assert!(
|
||||
!mgr.needs_approval(tool),
|
||||
"default auto_approve tool '{tool}' should not need approval in non-interactive mode"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_denies_unknown_tools() {
|
||||
let config = AutonomyConfig::default();
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
assert!(
|
||||
mgr.needs_approval("some_unknown_tool"),
|
||||
"unknown tool should need approval"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_interactive_weather_is_auto_approved() {
|
||||
let config = AutonomyConfig::default();
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
assert!(
|
||||
!mgr.needs_approval("weather"),
|
||||
"weather tool must not need approval — it is in the default auto_approve list"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn always_ask_overrides_auto_approve() {
|
||||
let mut config = AutonomyConfig::default();
|
||||
config.always_ask = vec!["weather".into()];
|
||||
let mgr = ApprovalManager::for_non_interactive(&config);
|
||||
assert!(
|
||||
mgr.needs_approval("weather"),
|
||||
"always_ask must override auto_approve"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
+116
-1
@@ -20,6 +20,9 @@ pub struct DiscordChannel {
|
||||
typing_handles: Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
/// Voice transcription config — when set, audio attachments are
|
||||
/// downloaded, transcribed, and their text inlined into the message.
|
||||
transcription: Option<crate::config::TranscriptionConfig>,
|
||||
}
|
||||
|
||||
impl DiscordChannel {
|
||||
@@ -38,6 +41,7 @@ impl DiscordChannel {
|
||||
mention_only,
|
||||
typing_handles: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
transcription: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +51,14 @@ impl DiscordChannel {
|
||||
self
|
||||
}
|
||||
|
||||
    /// Configure voice transcription for audio attachments.
    ///
    /// The config is only stored when `config.enabled` is true; passing a
    /// disabled config leaves `self.transcription` as `None`, so transcription
    /// stays off. Builder-style: consumes and returns `self`.
    pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
        if config.enabled {
            self.transcription = Some(config);
        }
        self
    }
|
||||
|
||||
    /// Build the HTTP client for this channel, honoring the per-channel
    /// proxy override (`self.proxy_url`) when set.
    fn http_client(&self) -> reqwest::Client {
        crate::config::build_channel_proxy_client("channel.discord", self.proxy_url.as_deref())
    }
|
||||
@@ -113,6 +125,88 @@ async fn process_attachments(
|
||||
parts.join("\n---\n")
|
||||
}
|
||||
|
||||
/// Audio file extensions accepted for voice transcription.
const DISCORD_AUDIO_EXTENSIONS: &[&str] = &[
    "flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
];

/// Check if a content type or filename indicates an audio file.
///
/// Matches either an `audio/*` content type or a recognized audio file
/// extension (case-insensitive). Fix: a filename without any dot is no
/// longer treated as if its whole name were an extension (previously a
/// bare filename like `"mp3"` was classified as audio).
fn is_discord_audio_attachment(content_type: &str, filename: &str) -> bool {
    if content_type.starts_with("audio/") {
        return true;
    }
    // rsplit_once returns None for dotless names, so only a real extension
    // (text after the last '.') is compared against the allow-list.
    filename
        .rsplit_once('.')
        .map(|(_, ext)| DISCORD_AUDIO_EXTENSIONS.contains(&ext.to_ascii_lowercase().as_str()))
        .unwrap_or(false)
}
|
||||
|
||||
/// Download and transcribe audio attachments from a Discord message.
///
/// Returns transcribed text blocks for any audio attachments found.
/// Non-audio attachments and failures are silently skipped.
///
/// Each successfully transcribed attachment contributes one
/// `"[Voice] <text>"` block; blocks are joined with newlines, so an
/// empty return string means no audio was transcribed.
async fn transcribe_discord_audio_attachments(
    attachments: &[serde_json::Value],
    client: &reqwest::Client,
    config: &crate::config::TranscriptionConfig,
) -> String {
    let mut parts: Vec<String> = Vec::new();
    for att in attachments {
        // Discord attachment objects carry content_type/filename/url fields;
        // missing values fall back to defaults rather than aborting.
        let ct = att
            .get("content_type")
            .and_then(|v| v.as_str())
            .unwrap_or("");
        let name = att
            .get("filename")
            .and_then(|v| v.as_str())
            .unwrap_or("file");

        // Only audio attachments are considered for transcription.
        if !is_discord_audio_attachment(ct, name) {
            continue;
        }

        let Some(url) = att.get("url").and_then(|v| v.as_str()) else {
            continue;
        };

        // Download failures are logged and skipped — a single bad
        // attachment must not prevent the rest from being processed.
        let audio_data = match client.get(url).send().await {
            Ok(resp) if resp.status().is_success() => match resp.bytes().await {
                Ok(bytes) => bytes.to_vec(),
                Err(e) => {
                    tracing::warn!(name, error = %e, "discord: failed to read audio attachment bytes");
                    continue;
                }
            },
            Ok(resp) => {
                tracing::warn!(name, status = %resp.status(), "discord: audio attachment download failed");
                continue;
            }
            Err(e) => {
                tracing::warn!(name, error = %e, "discord: audio attachment fetch error");
                continue;
            }
        };

        // Hand the raw bytes to the shared transcription backend; empty
        // transcripts are dropped rather than emitting an empty block.
        match super::transcription::transcribe_audio(audio_data, name, config).await {
            Ok(text) => {
                let trimmed = text.trim();
                if !trimmed.is_empty() {
                    tracing::info!(
                        "Discord: transcribed audio attachment {} ({} chars)",
                        name,
                        trimmed.len()
                    );
                    parts.push(format!("[Voice] {trimmed}"));
                }
            }
            Err(e) => {
                tracing::warn!(name, error = %e, "discord: voice transcription failed");
            }
        }
    }
    parts.join("\n")
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
enum DiscordAttachmentKind {
|
||||
Image,
|
||||
@@ -737,7 +831,28 @@ impl Channel for DiscordChannel {
|
||||
.and_then(|a| a.as_array())
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
process_attachments(&atts, &self.http_client()).await
|
||||
let client = self.http_client();
|
||||
let mut text_parts = process_attachments(&atts, &client).await;
|
||||
|
||||
// Transcribe audio attachments when transcription is configured
|
||||
if let Some(ref transcription_config) = self.transcription {
|
||||
let voice_text = transcribe_discord_audio_attachments(
|
||||
&atts,
|
||||
&client,
|
||||
transcription_config,
|
||||
)
|
||||
.await;
|
||||
if !voice_text.is_empty() {
|
||||
if text_parts.is_empty() {
|
||||
text_parts = voice_text;
|
||||
} else {
|
||||
text_parts = format!("{text_parts}
|
||||
{voice_text}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text_parts
|
||||
};
|
||||
let final_content = if attachment_text.is_empty() {
|
||||
clean_content
|
||||
|
||||
+535
-47
@@ -1,5 +1,6 @@
|
||||
use super::traits::{Channel, ChannelMessage, SendMessage};
|
||||
use async_trait::async_trait;
|
||||
use base64::Engine as _;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use prost::Message as ProstMessage;
|
||||
use std::collections::HashMap;
|
||||
@@ -221,6 +222,21 @@ const LARK_INVALID_ACCESS_TOKEN_CODE: i64 = 99_991_663;
|
||||
/// Lark card payloads have a ~30 KB limit; leave margin for JSON envelope.
|
||||
const LARK_CARD_MARKDOWN_MAX_BYTES: usize = 28_000;
|
||||
|
||||
/// Maximum image size we will download and inline (5 MiB).
|
||||
const LARK_IMAGE_MAX_BYTES: usize = 5 * 1024 * 1024;
|
||||
|
||||
/// Maximum file size we will download and present as text (512 KiB).
|
||||
const LARK_FILE_MAX_BYTES: usize = 512 * 1024;
|
||||
|
||||
/// Image MIME types we support for inline base64 encoding.
|
||||
const LARK_SUPPORTED_IMAGE_MIMES: &[&str] = &[
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
"image/bmp",
|
||||
];
|
||||
|
||||
/// Returns true when the WebSocket frame indicates live traffic that should
|
||||
/// refresh the heartbeat watchdog.
|
||||
fn should_refresh_last_recv(msg: &WsMsg) -> bool {
|
||||
@@ -520,6 +536,17 @@ impl LarkChannel {
|
||||
format!("{}/im/v1/messages/{message_id}/reactions", self.api_base())
|
||||
}
|
||||
|
||||
    /// Build the Lark API URL for downloading an image by its `image_key`.
    fn image_download_url(&self, image_key: &str) -> String {
        format!("{}/im/v1/images/{image_key}", self.api_base())
    }
|
||||
|
||||
    /// Build the Lark API URL for downloading a message file resource by
    /// its `file_key` (the `type=file` query selects the file variant).
    fn file_download_url(&self, message_id: &str, file_key: &str) -> String {
        format!(
            "{}/im/v1/messages/{message_id}/resources/{file_key}?type=file",
            self.api_base()
        )
    }
|
||||
|
||||
fn resolved_bot_open_id(&self) -> Option<String> {
|
||||
self.resolved_bot_open_id
|
||||
.read()
|
||||
@@ -866,6 +893,44 @@ impl LarkChannel {
|
||||
Some(details) => (details.text, details.mentioned_open_ids),
|
||||
None => continue,
|
||||
},
|
||||
"image" => {
|
||||
let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let image_key = match v.get("image_key").and_then(|k| k.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => { tracing::debug!("Lark WS: image message missing image_key"); continue; }
|
||||
};
|
||||
match self.download_image_as_marker(&image_key).await {
|
||||
Some(marker) => (marker, Vec::new()),
|
||||
None => {
|
||||
tracing::warn!("Lark WS: failed to download image {image_key}");
|
||||
(format!("[IMAGE:{image_key} | download failed]"), Vec::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
"file" => {
|
||||
let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let file_key = match v.get("file_key").and_then(|k| k.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => { tracing::debug!("Lark WS: file message missing file_key"); continue; }
|
||||
};
|
||||
let file_name = v.get("file_name")
|
||||
.and_then(|n| n.as_str())
|
||||
.unwrap_or("unknown_file")
|
||||
.to_string();
|
||||
match self.download_file_as_content(&lark_msg.message_id, &file_key, &file_name).await {
|
||||
Some(content) => (content, Vec::new()),
|
||||
None => {
|
||||
tracing::warn!("Lark WS: failed to download file {file_key}");
|
||||
(format!("[ATTACHMENT:{file_name} | download failed]"), Vec::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => { tracing::debug!("Lark WS: skipping unsupported type '{}'", lark_msg.message_type); continue; }
|
||||
};
|
||||
|
||||
@@ -986,6 +1051,183 @@ impl LarkChannel {
|
||||
*cached = None;
|
||||
}
|
||||
|
||||
    /// Download an image from the Lark API and return an `[IMAGE:data:...]` marker string.
    ///
    /// Returns `None` on any failure (token fetch, HTTP error, oversized or
    /// empty body, unsupported MIME type); callers substitute a fallback
    /// marker. The image is inlined as a base64 data URI.
    async fn download_image_as_marker(&self, image_key: &str) -> Option<String> {
        let token = match self.get_tenant_access_token().await {
            Ok(t) => t,
            Err(e) => {
                tracing::warn!("Lark: failed to get token for image download: {e}");
                return None;
            }
        };

        let url = self.image_download_url(image_key);
        let resp = match self
            .http_client()
            .get(&url)
            .header("Authorization", format!("Bearer {token}"))
            .send()
            .await
        {
            Ok(r) => r,
            Err(e) => {
                tracing::warn!("Lark: image download request failed for {image_key}: {e}");
                return None;
            }
        };

        if !resp.status().is_success() {
            tracing::warn!(
                "Lark: image download failed for {image_key}: status={}",
                resp.status()
            );
            return None;
        }

        // Reject oversized images from the Content-Length header, when
        // present, before buffering the body.
        if let Some(cl) = resp.content_length() {
            if cl > LARK_IMAGE_MAX_BYTES as u64 {
                tracing::warn!("Lark: image too large for {image_key}: {cl} bytes exceeds limit");
                return None;
            }
        }

        let content_type = resp
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .map(str::to_string);

        let bytes = match resp.bytes().await {
            Ok(b) => b,
            Err(e) => {
                tracing::warn!("Lark: image body read failed for {image_key}: {e}");
                return None;
            }
        };

        // Re-check the size after download: Content-Length may be absent or wrong.
        if bytes.is_empty() || bytes.len() > LARK_IMAGE_MAX_BYTES {
            tracing::warn!(
                "Lark: image body empty or too large for {image_key}: {} bytes",
                bytes.len()
            );
            return None;
        }

        // Magic-byte sniffing takes priority over the Content-Type header;
        // only whitelisted image MIME types are inlined.
        let mime = lark_detect_image_mime(content_type.as_deref(), &bytes)?;
        if !LARK_SUPPORTED_IMAGE_MIMES.contains(&mime.as_str()) {
            tracing::warn!("Lark: unsupported image MIME for {image_key}: {mime}");
            return None;
        }

        let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
        Some(format!("[IMAGE:data:{mime};base64,{encoded}]"))
    }
|
||||
|
||||
/// Download a file from the Lark API and return a text content marker.
|
||||
/// For text-like files, the content is inlined. For binary files, a summary is returned.
|
||||
async fn download_file_as_content(
|
||||
&self,
|
||||
message_id: &str,
|
||||
file_key: &str,
|
||||
file_name: &str,
|
||||
) -> Option<String> {
|
||||
let token = match self.get_tenant_access_token().await {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: failed to get token for file download: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let url = self.file_download_url(message_id, file_key);
|
||||
let resp = match self
|
||||
.http_client()
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: file download request failed for {file_key}: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
tracing::warn!(
|
||||
"Lark: file download failed for {file_key}: status={}",
|
||||
resp.status()
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(cl) = resp.content_length() {
|
||||
if cl > LARK_FILE_MAX_BYTES as u64 {
|
||||
tracing::warn!("Lark: file too large for {file_key}: {cl} bytes exceeds limit");
|
||||
return Some(format!(
|
||||
"[ATTACHMENT:{file_name} | size={cl} bytes | too large to inline]"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let content_type = resp
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let bytes = match resp.bytes().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
tracing::warn!("Lark: file body read failed for {file_key}: {e}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if bytes.is_empty() {
|
||||
tracing::warn!("Lark: file body is empty for {file_key}");
|
||||
return None;
|
||||
}
|
||||
|
||||
// If the content is image-like, return as image marker
|
||||
if content_type.starts_with("image/") && bytes.len() <= LARK_IMAGE_MAX_BYTES {
|
||||
if let Some(mime) = lark_detect_image_mime(Some(&content_type), &bytes) {
|
||||
if LARK_SUPPORTED_IMAGE_MIMES.contains(&mime.as_str()) {
|
||||
let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
|
||||
return Some(format!("[IMAGE:data:{mime};base64,{encoded}]"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the file looks like text, inline it
|
||||
if bytes.len() <= LARK_FILE_MAX_BYTES
|
||||
&& !bytes.contains(&0)
|
||||
&& (content_type.starts_with("text/")
|
||||
|| content_type.contains("json")
|
||||
|| content_type.contains("xml")
|
||||
|| content_type.contains("yaml")
|
||||
|| content_type.contains("javascript")
|
||||
|| content_type.contains("csv")
|
||||
|| lark_is_text_filename(file_name))
|
||||
{
|
||||
let text = String::from_utf8_lossy(&bytes);
|
||||
let truncated = if text.len() > 50_000 {
|
||||
format!("{}...\n[truncated]", &text[..50_000])
|
||||
} else {
|
||||
text.into_owned()
|
||||
};
|
||||
let ext = file_name.rsplit('.').next().unwrap_or("text");
|
||||
return Some(format!("[FILE:{file_name}]\n```{ext}\n{truncated}\n```"));
|
||||
}
|
||||
|
||||
Some(format!(
|
||||
"[ATTACHMENT:{file_name} | mime={content_type} | size={} bytes]",
|
||||
bytes.len()
|
||||
))
|
||||
}
|
||||
|
||||
async fn fetch_bot_open_id_with_token(
|
||||
&self,
|
||||
token: &str,
|
||||
@@ -1085,8 +1327,9 @@ impl LarkChannel {
|
||||
Ok((status, parsed))
|
||||
}
|
||||
|
||||
/// Parse an event callback payload and extract text messages
|
||||
pub fn parse_event_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
|
||||
/// Parse an event callback payload and extract messages.
|
||||
/// Supports text, post, image, and file message types.
|
||||
pub async fn parse_event_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
// Lark event v2 structure:
|
||||
@@ -1143,6 +1386,11 @@ impl LarkChannel {
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let evt_message_id = event
|
||||
.pointer("/message/message_id")
|
||||
.and_then(|m| m.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let (text, post_mentioned_open_ids): (String, Vec<String>) = match msg_type {
|
||||
"text" => {
|
||||
let extracted = serde_json::from_str::<serde_json::Value>(content_str)
|
||||
@@ -1162,6 +1410,62 @@ impl LarkChannel {
|
||||
Some(details) => (details.text, details.mentioned_open_ids),
|
||||
None => return messages,
|
||||
},
|
||||
"image" => {
|
||||
let image_key = serde_json::from_str::<serde_json::Value>(content_str)
|
||||
.ok()
|
||||
.and_then(|v| {
|
||||
v.get("image_key")
|
||||
.and_then(|k| k.as_str())
|
||||
.map(String::from)
|
||||
});
|
||||
match image_key {
|
||||
Some(key) => {
|
||||
let marker = match self.download_image_as_marker(&key).await {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
tracing::warn!("Lark: failed to download image {key}");
|
||||
format!("[IMAGE:{key} | download failed]")
|
||||
}
|
||||
};
|
||||
(marker, Vec::new())
|
||||
}
|
||||
None => {
|
||||
tracing::debug!("Lark: image message missing image_key");
|
||||
return messages;
|
||||
}
|
||||
}
|
||||
}
|
||||
"file" => {
|
||||
let parsed = serde_json::from_str::<serde_json::Value>(content_str).ok();
|
||||
let file_key = parsed
|
||||
.as_ref()
|
||||
.and_then(|v| v.get("file_key").and_then(|k| k.as_str()))
|
||||
.map(String::from);
|
||||
let file_name = parsed
|
||||
.as_ref()
|
||||
.and_then(|v| v.get("file_name").and_then(|n| n.as_str()))
|
||||
.unwrap_or("unknown_file")
|
||||
.to_string();
|
||||
match file_key {
|
||||
Some(key) => {
|
||||
let content = match self
|
||||
.download_file_as_content(evt_message_id, &key, &file_name)
|
||||
.await
|
||||
{
|
||||
Some(c) => c,
|
||||
None => {
|
||||
tracing::warn!("Lark: failed to download file {key}");
|
||||
format!("[ATTACHMENT:{file_name} | download failed]")
|
||||
}
|
||||
};
|
||||
(content, Vec::new())
|
||||
}
|
||||
None => {
|
||||
tracing::debug!("Lark: file message missing file_key");
|
||||
return messages;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
tracing::debug!("Lark: skipping unsupported message type: {msg_type}");
|
||||
return messages;
|
||||
@@ -1305,7 +1609,7 @@ impl LarkChannel {
|
||||
}
|
||||
|
||||
// Parse event messages
|
||||
let messages = state.channel.parse_event_payload(&payload);
|
||||
let messages = state.channel.parse_event_payload(&payload).await;
|
||||
if !messages.is_empty() {
|
||||
if let Some(message_id) = payload
|
||||
.pointer("/event/message/message_id")
|
||||
@@ -1556,6 +1860,72 @@ fn detect_lark_ack_locale(
|
||||
detect_locale_from_text(fallback_text).unwrap_or(LarkAckLocale::En)
|
||||
}
|
||||
|
||||
/// Detect image MIME type from magic bytes, falling back to Content-Type header.
///
/// Magic-byte sniffing runs first because it is more reliable than a
/// server-supplied header. Returns `None` when neither the bytes nor the
/// header identify an image type.
fn lark_detect_image_mime(content_type: Option<&str>, bytes: &[u8]) -> Option<String> {
    // `starts_with` already returns false for slices shorter than the
    // pattern, so the explicit length guards the original carried were
    // redundant (kept only where an inner slice is taken, see WEBP).
    if bytes.starts_with(&[0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n']) {
        return Some("image/png".to_string());
    }
    if bytes.starts_with(&[0xff, 0xd8, 0xff]) {
        return Some("image/jpeg".to_string());
    }
    if bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a") {
        return Some("image/gif".to_string());
    }
    // WEBP layout: "RIFF" <4-byte size> "WEBP" — the slice at 8..12 needs a
    // real length check.
    if bytes.len() >= 12 && bytes.starts_with(b"RIFF") && &bytes[8..12] == b"WEBP" {
        return Some("image/webp".to_string());
    }
    if bytes.starts_with(b"BM") {
        return Some("image/bmp".to_string());
    }
    // Fallback: Content-Type header with any parameters stripped
    // (e.g. "image/png; charset=binary" -> "image/png"); non-image types
    // are rejected.
    content_type
        .and_then(|ct| ct.split(';').next())
        .map(|ct| ct.trim().to_lowercase())
        .filter(|ct| ct.starts_with("image/"))
}
|
||||
|
||||
/// Check if a filename looks like a text file based on extension.
///
/// Matching is ASCII-case-insensitive. A name with no dot (e.g. "Dockerfile")
/// is compared as a whole against the list, which is how the extensionless
/// entries "dockerfile" and "makefile" match.
fn lark_is_text_filename(name: &str) -> bool {
    // Lowercased extensions we are willing to treat as inline-able text.
    const TEXT_EXTS: &[&str] = &[
        "txt", "md", "rs", "py", "js", "ts", "tsx", "jsx", "java", "c", "h",
        "cpp", "hpp", "go", "rb", "sh", "bash", "zsh", "toml", "yaml", "yml",
        "json", "xml", "html", "css", "sql", "csv", "tsv", "log", "cfg", "ini",
        "conf", "env", "dockerfile", "makefile",
    ];
    let ext = name.rsplit('.').next().unwrap_or("").to_ascii_lowercase();
    TEXT_EXTS.contains(&ext.as_str())
}
|
||||
|
||||
fn random_lark_ack_reaction(
|
||||
payload: Option<&serde_json::Value>,
|
||||
fallback_text: &str,
|
||||
@@ -1892,8 +2262,8 @@ mod tests {
|
||||
assert!(!ch.is_user_allowed("ou_anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_challenge() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_challenge() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"challenge": "abc123",
|
||||
@@ -1901,12 +2271,12 @@ mod tests {
|
||||
"type": "url_verification"
|
||||
});
|
||||
// Challenge payloads should not produce messages
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_valid_text_message() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_valid_text_message() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"header": {
|
||||
@@ -1927,7 +2297,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].content, "Hello ZeroClaw!");
|
||||
assert_eq!(msgs[0].sender, "oc_chat123");
|
||||
@@ -1935,8 +2305,8 @@ mod tests {
|
||||
assert_eq!(msgs[0].timestamp, 1_699_999_999);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_unauthorized_user() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_unauthorized_user() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
@@ -1951,12 +2321,38 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_non_text_message_skipped() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_unsupported_message_type_skipped() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
None,
|
||||
vec!["*".into()],
|
||||
true,
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
"event": {
|
||||
"sender": { "sender_id": { "open_id": "ou_user" } },
|
||||
"message": {
|
||||
"message_type": "sticker",
|
||||
"content": "{}",
|
||||
"chat_id": "oc_chat"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lark_parse_image_missing_key_skipped() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -1977,12 +2373,38 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_empty_text_skipped() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_file_missing_key_skipped() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
None,
|
||||
vec!["*".into()],
|
||||
true,
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
"event": {
|
||||
"sender": { "sender_id": { "open_id": "ou_user" } },
|
||||
"message": {
|
||||
"message_type": "file",
|
||||
"content": "{}",
|
||||
"chat_id": "oc_chat"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lark_parse_empty_text_skipped() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -2003,24 +2425,24 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_wrong_event_type() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_wrong_event_type() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.chat.disbanded_v1" },
|
||||
"event": {}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_missing_sender() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_missing_sender() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -2040,12 +2462,12 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_unicode_message() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_unicode_message() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -2067,24 +2489,24 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].content, "Hello world 🌍");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_missing_event() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_missing_event() {
|
||||
let ch = make_channel();
|
||||
let payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" }
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_invalid_content_json() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_invalid_content_json() {
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
"secret".into(),
|
||||
@@ -2105,7 +2527,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
@@ -2237,8 +2659,8 @@ mod tests {
|
||||
assert_eq!(ch.name(), "feishu");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_fallback_sender_to_open_id() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_fallback_sender_to_open_id() {
|
||||
// When chat_id is missing, sender should fall back to open_id
|
||||
let ch = LarkChannel::new(
|
||||
"id".into(),
|
||||
@@ -2260,13 +2682,13 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
let msgs = ch.parse_event_payload(&payload);
|
||||
let msgs = ch.parse_event_payload(&payload).await;
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(msgs[0].sender, "ou_user");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_group_message_requires_bot_mention_when_enabled() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_group_message_requires_bot_mention_when_enabled() {
|
||||
let ch = with_bot_open_id(
|
||||
LarkChannel::new(
|
||||
"cli_app123".into(),
|
||||
@@ -2292,7 +2714,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
});
|
||||
assert!(ch.parse_event_payload(&no_mention_payload).is_empty());
|
||||
assert!(ch.parse_event_payload(&no_mention_payload).await.is_empty());
|
||||
|
||||
let wrong_mention_payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
@@ -2307,7 +2729,10 @@ mod tests {
|
||||
}
|
||||
}
|
||||
});
|
||||
assert!(ch.parse_event_payload(&wrong_mention_payload).is_empty());
|
||||
assert!(ch
|
||||
.parse_event_payload(&wrong_mention_payload)
|
||||
.await
|
||||
.is_empty());
|
||||
|
||||
let bot_mention_payload = serde_json::json!({
|
||||
"header": { "event_type": "im.message.receive_v1" },
|
||||
@@ -2322,11 +2747,11 @@ mod tests {
|
||||
}
|
||||
}
|
||||
});
|
||||
assert_eq!(ch.parse_event_payload(&bot_mention_payload).len(), 1);
|
||||
assert_eq!(ch.parse_event_payload(&bot_mention_payload).await.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_group_post_message_accepts_at_when_top_level_mentions_empty() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_group_post_message_accepts_at_when_top_level_mentions_empty() {
|
||||
let ch = with_bot_open_id(
|
||||
LarkChannel::new(
|
||||
"cli_app123".into(),
|
||||
@@ -2353,11 +2778,11 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(ch.parse_event_payload(&payload).len(), 1);
|
||||
assert_eq!(ch.parse_event_payload(&payload).await.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_parse_group_message_allows_without_mention_when_disabled() {
|
||||
#[tokio::test]
|
||||
async fn lark_parse_group_message_allows_without_mention_when_disabled() {
|
||||
let ch = LarkChannel::new(
|
||||
"cli_app123".into(),
|
||||
"secret".into(),
|
||||
@@ -2381,7 +2806,7 @@ mod tests {
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(ch.parse_event_payload(&payload).len(), 1);
|
||||
assert_eq!(ch.parse_event_payload(&payload).await.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -2409,6 +2834,69 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
    // Image download URL uses the larksuite host configured by `make_channel`
    // and embeds the raw image key.
    #[test]
    fn lark_image_download_url_matches_region() {
        let ch = make_channel();
        assert_eq!(
            ch.image_download_url("img_abc123"),
            "https://open.larksuite.com/open-apis/im/v1/images/img_abc123"
        );
    }
|
||||
|
||||
    // File download URL is message-scoped (message id + file key) and carries
    // the `type=file` query parameter.
    #[test]
    fn lark_file_download_url_matches_region() {
        let ch = make_channel();
        assert_eq!(
            ch.file_download_url("om_msg123", "file_abc"),
            "https://open.larksuite.com/open-apis/im/v1/messages/om_msg123/resources/file_abc?type=file"
        );
    }
|
||||
|
||||
    // Magic-byte detection should win outright; the Content-Type header is
    // only consulted when the bytes are unrecognized.
    #[test]
    fn lark_detect_image_mime_from_magic_bytes() {
        let png = [0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n'];
        assert_eq!(
            lark_detect_image_mime(None, &png).as_deref(),
            Some("image/png")
        );

        let jpeg = [0xff, 0xd8, 0xff, 0xe0];
        assert_eq!(
            lark_detect_image_mime(None, &jpeg).as_deref(),
            Some("image/jpeg")
        );

        let gif = b"GIF89a...";
        assert_eq!(
            lark_detect_image_mime(None, gif).as_deref(),
            Some("image/gif")
        );

        // Unknown bytes should fall back to content-type header
        let unknown = [0x00, 0x01, 0x02];
        assert_eq!(
            lark_detect_image_mime(Some("image/webp"), &unknown).as_deref(),
            Some("image/webp")
        );

        // Non-image content-type should be rejected
        assert_eq!(lark_detect_image_mime(Some("text/html"), &unknown), None);

        // No info at all should return None
        assert_eq!(lark_detect_image_mime(None, &unknown), None);
    }
|
||||
|
||||
    // Positive cases cover the extension allow-list (including uppercase
    // names via ASCII-lowercasing); binaries and archives must be rejected.
    #[test]
    fn lark_is_text_filename_recognizes_common_extensions() {
        assert!(lark_is_text_filename("script.py"));
        assert!(lark_is_text_filename("config.toml"));
        assert!(lark_is_text_filename("data.csv"));
        assert!(lark_is_text_filename("README.md"));
        assert!(!lark_is_text_filename("image.png"));
        assert!(!lark_is_text_filename("archive.zip"));
        assert!(!lark_is_text_filename("binary.exe"));
    }
|
||||
|
||||
#[test]
|
||||
fn lark_reaction_locale_explicit_language_tags() {
|
||||
assert_eq!(map_locale_tag("zh-CN"), Some(LarkAckLocale::ZhCn));
|
||||
|
||||
@@ -0,0 +1,462 @@
|
||||
//! Link enricher: auto-detects URLs in inbound messages, fetches their content,
|
||||
//! and prepends summaries so the agent has link context without explicit tool calls.
|
||||
|
||||
use regex::Regex;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::LazyLock;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Configuration for the link enricher pipeline stage.
#[derive(Debug, Clone)]
pub struct LinkEnricherConfig {
    // Master switch; when false the enricher is a no-op.
    pub enabled: bool,
    // Maximum number of distinct URLs summarized per message.
    pub max_links: usize,
    // Per-request fetch timeout, in seconds.
    pub timeout_secs: u64,
}

impl Default for LinkEnricherConfig {
    /// Conservative defaults: disabled, at most 3 links, 10-second timeout.
    fn default() -> Self {
        LinkEnricherConfig {
            enabled: false,
            max_links: 3,
            timeout_secs: 10,
        }
    }
}
|
||||
|
||||
/// URL regex: matches http:// and https:// URLs, stopping at whitespace, angle
|
||||
/// brackets, or double-quotes.
|
||||
static URL_RE: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r#"https?://[^\s<>"']+"#).expect("URL regex must compile"));
|
||||
|
||||
/// Extract URLs from message text, returning up to `max` unique URLs.
|
||||
pub fn extract_urls(text: &str, max: usize) -> Vec<String> {
|
||||
let mut seen = Vec::new();
|
||||
for m in URL_RE.find_iter(text) {
|
||||
let url = m.as_str().to_string();
|
||||
if !seen.contains(&url) {
|
||||
seen.push(url);
|
||||
if seen.len() >= max {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
seen
|
||||
}
|
||||
|
||||
/// Returns `true` if the URL points to a private/local address that should be
/// blocked for SSRF protection.
///
/// Only literal hosts are inspected: hostnames that *resolve* to private
/// addresses, or IPs written in exotic encodings (e.g. `0x7f.0.0.1`,
/// decimal-integer form), are NOT caught here — NOTE(review): closing that
/// gap requires a resolver-level check at fetch time.
pub fn is_ssrf_target(url: &str) -> bool {
    let Some(host) = extract_host(url) else {
        return true; // unparseable URLs are rejected
    };

    // Hostname-based locals (`localhost`, mDNS `.local`, and friends).
    if host == "localhost"
        || host == "local"
        || host.ends_with(".localhost")
        || host.ends_with(".local")
    {
        return true;
    }

    // Literal IPs in private/reserved ranges.
    if let Ok(ip) = host.parse::<IpAddr>() {
        return is_private_ip(ip);
    }

    false
}

/// Extract the host portion from a URL string.
///
/// Returns `None` for non-HTTP(S) schemes, empty authorities, and bracketed
/// IPv6 literals — the caller treats `None` as "blocked", so rejecting
/// bracketed IPv6 is a conservative simplification.
fn extract_host(url: &str) -> Option<String> {
    let rest = url
        .strip_prefix("https://")
        .or_else(|| url.strip_prefix("http://"))?;
    let authority = rest.split(['/', '?', '#']).next()?;
    if authority.is_empty() {
        return None;
    }
    if authority.starts_with('[') {
        // IPv6 in brackets — reject for simplicity (caller blocks the URL).
        return None;
    }
    // Strip an optional `:port` suffix.
    let host = authority.split(':').next().unwrap_or(authority);
    Some(host.to_lowercase())
}

/// Check if an IP address falls within private/reserved ranges.
fn is_private_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            v4.is_loopback() // 127.0.0.0/8
                || v4.is_private() // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
                || v4.is_link_local() // 169.254.0.0/16
                || v4.is_unspecified() // 0.0.0.0
                || v4.is_broadcast() // 255.255.255.255
                || v4.is_multicast() // 224.0.0.0/4
        }
        IpAddr::V6(v6) => {
            let seg0 = v6.segments()[0];
            v6.is_loopback() // ::1
                || v6.is_unspecified() // ::
                || v6.is_multicast()
                // fc00::/7 unique-local and fe80::/10 link-local were missing
                // from the original check (mask-tested manually because the
                // std helpers are unstable).
                || (seg0 & 0xfe00) == 0xfc00
                || (seg0 & 0xffc0) == 0xfe80
                // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d)
                || v6.to_ipv4_mapped().is_some_and(|v4| {
                    v4.is_loopback()
                        || v4.is_private()
                        || v4.is_link_local()
                        || v4.is_unspecified()
                })
        }
    }
}
|
||||
|
||||
/// Extract the `<title>` tag content from HTML.
|
||||
pub fn extract_title(html: &str) -> Option<String> {
|
||||
// Case-insensitive search for <title>...</title>
|
||||
let lower = html.to_lowercase();
|
||||
let start = lower.find("<title")? + "<title".len();
|
||||
// Skip attributes if any (e.g. <title lang="en">)
|
||||
let start = lower[start..].find('>')? + start + 1;
|
||||
let end = lower[start..].find("</title")? + start;
|
||||
let title = lower[start..end].trim().to_string();
|
||||
if title.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(html_entity_decode_basic(&title))
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the first `max_chars` of visible body text from HTML.
|
||||
pub fn extract_body_text(html: &str, max_chars: usize) -> String {
|
||||
let text = nanohtml2text::html2text(html);
|
||||
let trimmed = text.trim();
|
||||
if trimmed.len() <= max_chars {
|
||||
trimmed.to_string()
|
||||
} else {
|
||||
let mut result: String = trimmed.chars().take(max_chars).collect();
|
||||
result.push_str("...");
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Basic HTML entity decoding for title content.
///
/// Handles only the handful of entities that commonly appear in `<title>`
/// text. `&amp;` is decoded LAST so double-escaped input such as `&amp;lt;`
/// yields `&lt;` rather than `<`.
fn html_entity_decode_basic(s: &str) -> String {
    s.replace("&lt;", "<")
        .replace("&gt;", ">")
        .replace("&quot;", "\"")
        .replace("&#39;", "'")
        .replace("&apos;", "'")
        .replace("&amp;", "&")
}
|
||||
|
||||
/// Summary of a fetched link.
|
||||
struct LinkSummary {
|
||||
title: String,
|
||||
snippet: String,
|
||||
}
|
||||
|
||||
/// Fetch a single URL and extract a summary. Returns `None` on any failure.
|
||||
async fn fetch_link_summary(url: &str, timeout_secs: u64) -> Option<LinkSummary> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
.connect_timeout(Duration::from_secs(5))
|
||||
.redirect(reqwest::redirect::Policy::limited(5))
|
||||
.user_agent("ZeroClaw/0.1 (link-enricher)")
|
||||
.build()
|
||||
.ok()?;
|
||||
|
||||
let response = client.get(url).send().await.ok()?;
|
||||
if !response.status().is_success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Only process text/html responses
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_lowercase();
|
||||
|
||||
if !content_type.contains("text/html") && !content_type.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Read up to 256KB to extract title and snippet
|
||||
let max_bytes: usize = 256 * 1024;
|
||||
let bytes = response.bytes().await.ok()?;
|
||||
let body = if bytes.len() > max_bytes {
|
||||
String::from_utf8_lossy(&bytes[..max_bytes]).into_owned()
|
||||
} else {
|
||||
String::from_utf8_lossy(&bytes).into_owned()
|
||||
};
|
||||
|
||||
let title = extract_title(&body).unwrap_or_else(|| "Untitled".to_string());
|
||||
let snippet = extract_body_text(&body, 200);
|
||||
|
||||
Some(LinkSummary { title, snippet })
|
||||
}
|
||||
|
||||
/// Enrich a message by prepending link summaries for any URLs found in the text.
|
||||
///
|
||||
/// This is the main entry point called from the channel message processing pipeline.
|
||||
/// If the enricher is disabled or no URLs are found, the original message is returned
|
||||
/// unchanged.
|
||||
pub async fn enrich_message(content: &str, config: &LinkEnricherConfig) -> String {
|
||||
if !config.enabled || config.max_links == 0 {
|
||||
return content.to_string();
|
||||
}
|
||||
|
||||
let urls = extract_urls(content, config.max_links);
|
||||
if urls.is_empty() {
|
||||
return content.to_string();
|
||||
}
|
||||
|
||||
// Filter out SSRF targets
|
||||
let safe_urls: Vec<&str> = urls
|
||||
.iter()
|
||||
.filter(|u| !is_ssrf_target(u))
|
||||
.map(|u| u.as_str())
|
||||
.collect();
|
||||
if safe_urls.is_empty() {
|
||||
return content.to_string();
|
||||
}
|
||||
|
||||
let mut enrichments = Vec::new();
|
||||
for url in safe_urls {
|
||||
match fetch_link_summary(url, config.timeout_secs).await {
|
||||
Some(summary) => {
|
||||
enrichments.push(format!("[Link: {} — {}]", summary.title, summary.snippet));
|
||||
}
|
||||
None => {
|
||||
tracing::debug!(url, "Link enricher: failed to fetch or extract summary");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if enrichments.is_empty() {
|
||||
return content.to_string();
|
||||
}
|
||||
|
||||
let prefix = enrichments.join("\n");
|
||||
format!("{prefix}\n{content}")
|
||||
}
|
||||
|
||||
// Unit tests for the link enricher. Network-touching paths are exercised only
// through their early-return branches (disabled config, no URLs, SSRF-only
// URLs), so no test below performs real I/O.
#[cfg(test)]
mod tests {
    use super::*;

    // ── URL extraction ──────────────────────────────────────────────

    #[test]
    fn extract_urls_finds_http_and_https() {
        let text = "Check https://example.com and http://test.org/page for info";
        let urls = extract_urls(text, 10);
        assert_eq!(urls, vec!["https://example.com", "http://test.org/page",]);
    }

    #[test]
    fn extract_urls_respects_max() {
        let text = "https://a.com https://b.com https://c.com https://d.com";
        let urls = extract_urls(text, 2);
        assert_eq!(urls.len(), 2);
        assert_eq!(urls[0], "https://a.com");
        assert_eq!(urls[1], "https://b.com");
    }

    #[test]
    fn extract_urls_deduplicates() {
        let text = "Visit https://example.com and https://example.com again";
        let urls = extract_urls(text, 10);
        assert_eq!(urls.len(), 1);
    }

    #[test]
    fn extract_urls_handles_no_urls() {
        let text = "Just a normal message without links";
        let urls = extract_urls(text, 10);
        assert!(urls.is_empty());
    }

    #[test]
    fn extract_urls_stops_at_angle_brackets() {
        let text = "Link: <https://example.com/path> done";
        let urls = extract_urls(text, 10);
        assert_eq!(urls, vec!["https://example.com/path"]);
    }

    #[test]
    fn extract_urls_stops_at_quotes() {
        let text = r#"href="https://example.com/page" end"#;
        let urls = extract_urls(text, 10);
        assert_eq!(urls, vec!["https://example.com/page"]);
    }

    // ── SSRF protection ─────────────────────────────────────────────

    #[test]
    fn ssrf_blocks_localhost() {
        assert!(is_ssrf_target("http://localhost/admin"));
        assert!(is_ssrf_target("https://localhost:8080/api"));
    }

    #[test]
    fn ssrf_blocks_loopback_ip() {
        assert!(is_ssrf_target("http://127.0.0.1/secret"));
        assert!(is_ssrf_target("http://127.0.0.2:9090"));
    }

    #[test]
    fn ssrf_blocks_private_10_network() {
        assert!(is_ssrf_target("http://10.0.0.1/internal"));
        assert!(is_ssrf_target("http://10.255.255.255"));
    }

    #[test]
    fn ssrf_blocks_private_172_network() {
        assert!(is_ssrf_target("http://172.16.0.1/admin"));
        assert!(is_ssrf_target("http://172.31.255.255"));
    }

    #[test]
    fn ssrf_blocks_private_192_168_network() {
        assert!(is_ssrf_target("http://192.168.1.1/router"));
        assert!(is_ssrf_target("http://192.168.0.100:3000"));
    }

    #[test]
    fn ssrf_blocks_link_local() {
        // Includes the cloud metadata endpoint 169.254.169.254.
        assert!(is_ssrf_target("http://169.254.0.1/metadata"));
        assert!(is_ssrf_target("http://169.254.169.254/latest"));
    }

    #[test]
    fn ssrf_blocks_ipv6_loopback() {
        // IPv6 in brackets is rejected by extract_host
        assert!(is_ssrf_target("http://[::1]/admin"));
    }

    #[test]
    fn ssrf_blocks_dot_local() {
        assert!(is_ssrf_target("http://myhost.local/api"));
    }

    #[test]
    fn ssrf_allows_public_urls() {
        assert!(!is_ssrf_target("https://example.com/page"));
        assert!(!is_ssrf_target("https://www.google.com"));
        assert!(!is_ssrf_target("http://93.184.216.34/resource"));
    }

    // ── Title extraction ────────────────────────────────────────────
    // NOTE: extract_title returns the LOWERCASED title (it slices the
    // lowercased document); the expectations below pin that behavior.

    #[test]
    fn extract_title_basic() {
        let html = "<html><head><title>My Page Title</title></head><body>Hello</body></html>";
        assert_eq!(extract_title(html), Some("my page title".to_string()));
    }

    #[test]
    fn extract_title_with_entities() {
        let html = "<title>Tom & Jerry's Page</title>";
        assert_eq!(extract_title(html), Some("tom & jerry's page".to_string()));
    }

    #[test]
    fn extract_title_case_insensitive() {
        let html = "<HTML><HEAD><TITLE>Upper Case</TITLE></HEAD></HTML>";
        assert_eq!(extract_title(html), Some("upper case".to_string()));
    }

    #[test]
    fn extract_title_multibyte_chars_no_panic() {
        // İ (U+0130) lowercases to 2 chars, changing byte length.
        // This must not panic or produce wrong offsets.
        let html = "<title>İstanbul Guide</title>";
        let result = extract_title(html);
        assert!(result.is_some());
        let title = result.unwrap();
        assert!(title.contains("stanbul"));
    }

    #[test]
    fn extract_title_missing() {
        let html = "<html><body>No title here</body></html>";
        assert_eq!(extract_title(html), None);
    }

    #[test]
    fn extract_title_empty() {
        // Whitespace-only titles are treated as absent.
        let html = "<title> </title>";
        assert_eq!(extract_title(html), None);
    }

    // ── Body text extraction ────────────────────────────────────────

    #[test]
    fn extract_body_text_strips_html() {
        let html = "<html><body><h1>Header</h1><p>Some content here</p></body></html>";
        let text = extract_body_text(html, 200);
        assert!(text.contains("Header"));
        assert!(text.contains("Some content"));
        assert!(!text.contains("<h1>"));
    }

    #[test]
    fn extract_body_text_truncates() {
        let html = "<p>A very long paragraph that should be truncated to fit within the limit.</p>";
        let text = extract_body_text(html, 20);
        assert!(text.len() <= 25); // 20 chars + "..."
        assert!(text.ends_with("..."));
    }

    // ── Config toggle ───────────────────────────────────────────────

    #[tokio::test]
    async fn enrich_message_disabled_returns_original() {
        let config = LinkEnricherConfig {
            enabled: false,
            max_links: 3,
            timeout_secs: 10,
        };
        let msg = "Check https://example.com for details";
        let result = enrich_message(msg, &config).await;
        assert_eq!(result, msg);
    }

    #[tokio::test]
    async fn enrich_message_no_urls_returns_original() {
        let config = LinkEnricherConfig {
            enabled: true,
            max_links: 3,
            timeout_secs: 10,
        };
        let msg = "No links in this message";
        let result = enrich_message(msg, &config).await;
        assert_eq!(result, msg);
    }

    #[tokio::test]
    async fn enrich_message_ssrf_urls_returns_original() {
        // Both URLs are filtered out before any fetch, so no network I/O.
        let config = LinkEnricherConfig {
            enabled: true,
            max_links: 3,
            timeout_secs: 10,
        };
        let msg = "Try http://127.0.0.1/admin and http://192.168.1.1/router";
        let result = enrich_message(msg, &config).await;
        assert_eq!(result, msg);
    }

    #[test]
    fn default_config_is_disabled() {
        let config = LinkEnricherConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.max_links, 3);
        assert_eq!(config.timeout_secs, 10);
    }
}
|
||||
+210
-7
@@ -8,6 +8,7 @@ use matrix_sdk::{
|
||||
events::reaction::ReactionEventContent,
|
||||
events::receipt::ReceiptThread,
|
||||
events::relation::{Annotation, Thread},
|
||||
events::room::member::StrippedRoomMemberEvent,
|
||||
events::room::message::{
|
||||
MessageType, OriginalSyncRoomMessageEvent, Relation, RoomMessageEventContent,
|
||||
},
|
||||
@@ -32,6 +33,7 @@ pub struct MatrixChannel {
|
||||
access_token: String,
|
||||
room_id: String,
|
||||
allowed_users: Vec<String>,
|
||||
allowed_rooms: Vec<String>,
|
||||
session_owner_hint: Option<String>,
|
||||
session_device_id_hint: Option<String>,
|
||||
zeroclaw_dir: Option<PathBuf>,
|
||||
@@ -48,6 +50,7 @@ impl std::fmt::Debug for MatrixChannel {
|
||||
.field("homeserver", &self.homeserver)
|
||||
.field("room_id", &self.room_id)
|
||||
.field("allowed_users", &self.allowed_users)
|
||||
.field("allowed_rooms", &self.allowed_rooms)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
@@ -121,7 +124,16 @@ impl MatrixChannel {
|
||||
room_id: String,
|
||||
allowed_users: Vec<String>,
|
||||
) -> Self {
|
||||
Self::new_with_session_hint(homeserver, access_token, room_id, allowed_users, None, None)
|
||||
Self::new_full(
|
||||
homeserver,
|
||||
access_token,
|
||||
room_id,
|
||||
allowed_users,
|
||||
vec![],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_with_session_hint(
|
||||
@@ -132,11 +144,12 @@ impl MatrixChannel {
|
||||
owner_hint: Option<String>,
|
||||
device_id_hint: Option<String>,
|
||||
) -> Self {
|
||||
Self::new_with_session_hint_and_zeroclaw_dir(
|
||||
Self::new_full(
|
||||
homeserver,
|
||||
access_token,
|
||||
room_id,
|
||||
allowed_users,
|
||||
vec![],
|
||||
owner_hint,
|
||||
device_id_hint,
|
||||
None,
|
||||
@@ -151,6 +164,28 @@ impl MatrixChannel {
|
||||
owner_hint: Option<String>,
|
||||
device_id_hint: Option<String>,
|
||||
zeroclaw_dir: Option<PathBuf>,
|
||||
) -> Self {
|
||||
Self::new_full(
|
||||
homeserver,
|
||||
access_token,
|
||||
room_id,
|
||||
allowed_users,
|
||||
vec![],
|
||||
owner_hint,
|
||||
device_id_hint,
|
||||
zeroclaw_dir,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn new_full(
|
||||
homeserver: String,
|
||||
access_token: String,
|
||||
room_id: String,
|
||||
allowed_users: Vec<String>,
|
||||
allowed_rooms: Vec<String>,
|
||||
owner_hint: Option<String>,
|
||||
device_id_hint: Option<String>,
|
||||
zeroclaw_dir: Option<PathBuf>,
|
||||
) -> Self {
|
||||
let homeserver = homeserver.trim_end_matches('/').to_string();
|
||||
let access_token = access_token.trim().to_string();
|
||||
@@ -160,12 +195,18 @@ impl MatrixChannel {
|
||||
.map(|user| user.trim().to_string())
|
||||
.filter(|user| !user.is_empty())
|
||||
.collect();
|
||||
let allowed_rooms = allowed_rooms
|
||||
.into_iter()
|
||||
.map(|room| room.trim().to_string())
|
||||
.filter(|room| !room.is_empty())
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
homeserver,
|
||||
access_token,
|
||||
room_id,
|
||||
allowed_users,
|
||||
allowed_rooms,
|
||||
session_owner_hint: Self::normalize_optional_field(owner_hint),
|
||||
session_device_id_hint: Self::normalize_optional_field(device_id_hint),
|
||||
zeroclaw_dir,
|
||||
@@ -220,6 +261,21 @@ impl MatrixChannel {
|
||||
allowed_users.iter().any(|u| u.eq_ignore_ascii_case(sender))
|
||||
}
|
||||
|
||||
/// Check whether a room (by its canonical ID) is in the allowed_rooms list.
|
||||
/// If allowed_rooms is empty, all rooms are allowed.
|
||||
fn is_room_allowed_static(allowed_rooms: &[String], room_id: &str) -> bool {
|
||||
if allowed_rooms.is_empty() {
|
||||
return true;
|
||||
}
|
||||
allowed_rooms
|
||||
.iter()
|
||||
.any(|r| r.eq_ignore_ascii_case(room_id))
|
||||
}
|
||||
|
||||
fn is_room_allowed(&self, room_id: &str) -> bool {
|
||||
Self::is_room_allowed_static(&self.allowed_rooms, room_id)
|
||||
}
|
||||
|
||||
fn is_supported_message_type(msgtype: &str) -> bool {
|
||||
matches!(msgtype, "m.text" | "m.notice")
|
||||
}
|
||||
@@ -228,6 +284,10 @@ impl MatrixChannel {
|
||||
!body.trim().is_empty()
|
||||
}
|
||||
|
||||
fn room_matches_target(target_room_id: &str, incoming_room_id: &str) -> bool {
|
||||
target_room_id == incoming_room_id
|
||||
}
|
||||
|
||||
fn cache_event_id(
|
||||
event_id: &str,
|
||||
recent_order: &mut std::collections::VecDeque<String>,
|
||||
@@ -526,8 +586,9 @@ impl MatrixChannel {
|
||||
if client.encryption().backups().are_enabled().await {
|
||||
tracing::info!("Matrix room-key backup is enabled for this device.");
|
||||
} else {
|
||||
client.encryption().backups().disable().await;
|
||||
tracing::warn!(
|
||||
"Matrix room-key backup is not enabled for this device; `matrix_sdk_crypto::backups` warnings about missing backup keys may appear until recovery is configured."
|
||||
"Matrix room-key backup is not enabled for this device; automatic backup attempts have been disabled to suppress recurring warnings. To enable backups, configure server-side key backup and recovery for this device."
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -697,6 +758,7 @@ impl Channel for MatrixChannel {
|
||||
let target_room_for_handler = target_room.clone();
|
||||
let my_user_id_for_handler = my_user_id.clone();
|
||||
let allowed_users_for_handler = self.allowed_users.clone();
|
||||
let allowed_rooms_for_handler = self.allowed_rooms.clone();
|
||||
let dedupe_for_handler = Arc::clone(&recent_event_cache);
|
||||
let homeserver_for_handler = self.homeserver.clone();
|
||||
let access_token_for_handler = self.access_token.clone();
|
||||
@@ -704,18 +766,29 @@ impl Channel for MatrixChannel {
|
||||
|
||||
client.add_event_handler(move |event: OriginalSyncRoomMessageEvent, room: Room| {
|
||||
let tx = tx_handler.clone();
|
||||
let _target_room = target_room_for_handler.clone();
|
||||
let target_room = target_room_for_handler.clone();
|
||||
let my_user_id = my_user_id_for_handler.clone();
|
||||
let allowed_users = allowed_users_for_handler.clone();
|
||||
let allowed_rooms = allowed_rooms_for_handler.clone();
|
||||
let dedupe = Arc::clone(&dedupe_for_handler);
|
||||
let homeserver = homeserver_for_handler.clone();
|
||||
let access_token = access_token_for_handler.clone();
|
||||
let voice_mode = Arc::clone(&voice_mode_for_handler);
|
||||
|
||||
async move {
|
||||
if false
|
||||
/* multi-room: room_id filter disabled */
|
||||
{
|
||||
if !MatrixChannel::room_matches_target(
|
||||
target_room.as_str(),
|
||||
room.room_id().as_str(),
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Room allowlist: skip messages from rooms not in the configured list
|
||||
if !MatrixChannel::is_room_allowed_static(&allowed_rooms, room.room_id().as_ref()) {
|
||||
tracing::debug!(
|
||||
"Matrix: ignoring message from room {} (not in allowed_rooms)",
|
||||
room.room_id()
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -907,6 +980,45 @@ impl Channel for MatrixChannel {
|
||||
}
|
||||
});
|
||||
|
||||
// Invite handler: auto-accept invites for allowed rooms, auto-reject others
|
||||
let allowed_rooms_for_invite = self.allowed_rooms.clone();
|
||||
client.add_event_handler(move |event: StrippedRoomMemberEvent, room: Room| {
|
||||
let allowed_rooms = allowed_rooms_for_invite.clone();
|
||||
async move {
|
||||
// Only process invite events targeting us
|
||||
if event.content.membership
|
||||
!= matrix_sdk::ruma::events::room::member::MembershipState::Invite
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let room_id_str = room.room_id().to_string();
|
||||
|
||||
if MatrixChannel::is_room_allowed_static(&allowed_rooms, &room_id_str) {
|
||||
// Room is allowed (or no allowlist configured): auto-accept
|
||||
tracing::info!(
|
||||
"Matrix: auto-accepting invite for allowed room {}",
|
||||
room_id_str
|
||||
);
|
||||
if let Err(error) = room.join().await {
|
||||
tracing::warn!("Matrix: failed to auto-join room {}: {error}", room_id_str);
|
||||
}
|
||||
} else {
|
||||
// Room is NOT in allowlist: auto-reject
|
||||
tracing::info!(
|
||||
"Matrix: auto-rejecting invite for room {} (not in allowed_rooms)",
|
||||
room_id_str
|
||||
);
|
||||
if let Err(error) = room.leave().await {
|
||||
tracing::warn!(
|
||||
"Matrix: failed to reject invite for room {}: {error}",
|
||||
room_id_str
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let sync_settings = SyncSettings::new().timeout(std::time::Duration::from_secs(30));
|
||||
client
|
||||
.sync_with_result_callback(sync_settings, |sync_result| {
|
||||
@@ -1294,6 +1406,22 @@ mod tests {
|
||||
assert_eq!(value["room"]["timeline"]["limit"], 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn room_scope_matches_configured_room() {
|
||||
assert!(MatrixChannel::room_matches_target(
|
||||
"!ops:matrix.org",
|
||||
"!ops:matrix.org"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn room_scope_rejects_other_rooms() {
|
||||
assert!(!MatrixChannel::room_matches_target(
|
||||
"!ops:matrix.org",
|
||||
"!other:matrix.org"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_id_cache_deduplicates_and_evicts_old_entries() {
|
||||
let mut recent_order = std::collections::VecDeque::new();
|
||||
@@ -1549,4 +1677,79 @@ mod tests {
|
||||
let resp: SyncResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.rooms.join.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_allowed_rooms_permits_all() {
|
||||
let ch = make_channel();
|
||||
assert!(ch.is_room_allowed("!any:matrix.org"));
|
||||
assert!(ch.is_room_allowed("!other:evil.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowed_rooms_filters_by_id() {
|
||||
let ch = MatrixChannel::new_full(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec!["@user:m".to_string()],
|
||||
vec!["!allowed:matrix.org".to_string()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert!(ch.is_room_allowed("!allowed:matrix.org"));
|
||||
assert!(!ch.is_room_allowed("!forbidden:matrix.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowed_rooms_supports_aliases() {
|
||||
let ch = MatrixChannel::new_full(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec!["@user:m".to_string()],
|
||||
vec![
|
||||
"#ops:matrix.org".to_string(),
|
||||
"!direct:matrix.org".to_string(),
|
||||
],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert!(ch.is_room_allowed("!direct:matrix.org"));
|
||||
assert!(ch.is_room_allowed("#ops:matrix.org"));
|
||||
assert!(!ch.is_room_allowed("!other:matrix.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowed_rooms_case_insensitive() {
|
||||
let ch = MatrixChannel::new_full(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec![],
|
||||
vec!["!Room:Matrix.org".to_string()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert!(ch.is_room_allowed("!room:matrix.org"));
|
||||
assert!(ch.is_room_allowed("!ROOM:MATRIX.ORG"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowed_rooms_trims_whitespace() {
|
||||
let ch = MatrixChannel::new_full(
|
||||
"https://m.org".to_string(),
|
||||
"tok".to_string(),
|
||||
"!r:m".to_string(),
|
||||
vec![],
|
||||
vec![" !room:matrix.org ".to_string(), " ".to_string()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert_eq!(ch.allowed_rooms.len(), 1);
|
||||
assert!(ch.is_room_allowed("!room:matrix.org"));
|
||||
}
|
||||
}
|
||||
|
||||
+45
-18
@@ -26,6 +26,7 @@ pub mod imessage;
|
||||
pub mod irc;
|
||||
#[cfg(feature = "channel-lark")]
|
||||
pub mod lark;
|
||||
pub mod link_enricher;
|
||||
pub mod linq;
|
||||
#[cfg(feature = "channel-matrix")]
|
||||
pub mod matrix;
|
||||
@@ -2066,6 +2067,25 @@ async fn process_channel_message(
|
||||
msg
|
||||
};
|
||||
|
||||
// ── Link enricher: prepend URL summaries before agent sees the message ──
|
||||
let le_config = &ctx.prompt_config.link_enricher;
|
||||
if le_config.enabled {
|
||||
let enricher_cfg = link_enricher::LinkEnricherConfig {
|
||||
enabled: le_config.enabled,
|
||||
max_links: le_config.max_links,
|
||||
timeout_secs: le_config.timeout_secs,
|
||||
};
|
||||
let enriched = link_enricher::enrich_message(&msg.content, &enricher_cfg).await;
|
||||
if enriched != msg.content {
|
||||
tracing::info!(
|
||||
channel = %msg.channel,
|
||||
sender = %msg.sender,
|
||||
"Link enricher: prepended URL summaries to message"
|
||||
);
|
||||
msg.content = enriched;
|
||||
}
|
||||
}
|
||||
|
||||
let target_channel = ctx
|
||||
.channels_by_name
|
||||
.get(&msg.channel)
|
||||
@@ -3670,13 +3690,16 @@ fn build_channel_by_id(config: &Config, channel_id: &str) -> Result<Arc<dyn Chan
|
||||
.discord
|
||||
.as_ref()
|
||||
.context("Discord channel is not configured")?;
|
||||
Ok(Arc::new(DiscordChannel::new(
|
||||
dc.bot_token.clone(),
|
||||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)))
|
||||
Ok(Arc::new(
|
||||
DiscordChannel::new(
|
||||
dc.bot_token.clone(),
|
||||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)
|
||||
.with_transcription(config.transcription.clone()),
|
||||
))
|
||||
}
|
||||
"slack" => {
|
||||
let sl = config
|
||||
@@ -3692,7 +3715,8 @@ fn build_channel_by_id(config: &Config, channel_id: &str) -> Result<Arc<dyn Chan
|
||||
Vec::new(),
|
||||
sl.allowed_users.clone(),
|
||||
)
|
||||
.with_workspace_dir(config.workspace_dir.clone()),
|
||||
.with_workspace_dir(config.workspace_dir.clone())
|
||||
.with_transcription(config.transcription.clone()),
|
||||
))
|
||||
}
|
||||
other => anyhow::bail!("Unknown channel '{other}'. Supported: telegram, discord, slack"),
|
||||
@@ -3778,7 +3802,8 @@ fn collect_configured_channels(
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)
|
||||
.with_proxy_url(dc.proxy_url.clone()),
|
||||
.with_proxy_url(dc.proxy_url.clone())
|
||||
.with_transcription(config.transcription.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
@@ -3822,7 +3847,8 @@ fn collect_configured_channels(
|
||||
.with_thread_replies(sl.thread_replies.unwrap_or(true))
|
||||
.with_group_reply_policy(sl.mention_only, Vec::new())
|
||||
.with_workspace_dir(config.workspace_dir.clone())
|
||||
.with_proxy_url(sl.proxy_url.clone()),
|
||||
.with_proxy_url(sl.proxy_url.clone())
|
||||
.with_transcription(config.transcription.clone()),
|
||||
),
|
||||
});
|
||||
}
|
||||
@@ -3855,11 +3881,12 @@ fn collect_configured_channels(
|
||||
if let Some(ref mx) = config.channels_config.matrix {
|
||||
channels.push(ConfiguredChannel {
|
||||
display_name: "Matrix",
|
||||
channel: Arc::new(MatrixChannel::new_with_session_hint_and_zeroclaw_dir(
|
||||
channel: Arc::new(MatrixChannel::new_full(
|
||||
mx.homeserver.clone(),
|
||||
mx.access_token.clone(),
|
||||
mx.room_id.clone(),
|
||||
mx.allowed_users.clone(),
|
||||
mx.allowed_rooms.clone(),
|
||||
mx.user_id.clone(),
|
||||
mx.device_id.clone(),
|
||||
config.config_path.parent().map(|path| path.to_path_buf()),
|
||||
@@ -7699,9 +7726,9 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert!(prompt.contains("<instructions>"));
|
||||
assert!(prompt
|
||||
.contains("<instruction>Always run cargo test before final response.</instruction>"));
|
||||
assert!(prompt.contains("<tools>"));
|
||||
assert!(prompt.contains("<name>lint</name>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>code-review.lint</name>"));
|
||||
assert!(!prompt.contains("loaded on demand"));
|
||||
}
|
||||
|
||||
@@ -7744,10 +7771,10 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
assert!(!prompt.contains("<instructions>"));
|
||||
assert!(!prompt
|
||||
.contains("<instruction>Always run cargo test before final response.</instruction>"));
|
||||
// Compact mode should still include tools so the LLM knows about them
|
||||
assert!(prompt.contains("<tools>"));
|
||||
assert!(prompt.contains("<name>lint</name>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
// Compact mode should still include tools so the LLM knows about them.
|
||||
// Registered tools (shell kind) appear under <callable_tools> with prefixed names.
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>code-review.lint</name>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -12,6 +12,8 @@ use chrono::{DateTime, Utc};
|
||||
pub struct SessionMetadata {
|
||||
/// Session key (e.g. `telegram_user123`).
|
||||
pub key: String,
|
||||
/// Optional human-readable name (e.g. `eyrie-commander-briefing`).
|
||||
pub name: Option<String>,
|
||||
/// When the session was first created.
|
||||
pub created_at: DateTime<Utc>,
|
||||
/// When the last message was appended.
|
||||
@@ -54,6 +56,7 @@ pub trait SessionBackend: Send + Sync {
|
||||
let messages = self.load(&key);
|
||||
SessionMetadata {
|
||||
key,
|
||||
name: None,
|
||||
created_at: Utc::now(),
|
||||
last_activity: Utc::now(),
|
||||
message_count: messages.len(),
|
||||
@@ -81,6 +84,16 @@ pub trait SessionBackend: Send + Sync {
|
||||
fn delete_session(&self, _session_key: &str) -> std::io::Result<bool> {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Set or update the human-readable name for a session.
|
||||
fn set_session_name(&self, _session_key: &str, _name: &str) -> std::io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the human-readable name for a session (if set).
|
||||
fn get_session_name(&self, _session_key: &str) -> std::io::Result<Option<String>> {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -91,6 +104,7 @@ mod tests {
|
||||
fn session_metadata_is_constructible() {
|
||||
let meta = SessionMetadata {
|
||||
key: "test".into(),
|
||||
name: None,
|
||||
created_at: Utc::now(),
|
||||
last_activity: Utc::now(),
|
||||
message_count: 5,
|
||||
|
||||
@@ -51,7 +51,8 @@ impl SqliteSessionBackend {
|
||||
session_key TEXT PRIMARY KEY,
|
||||
created_at TEXT NOT NULL,
|
||||
last_activity TEXT NOT NULL,
|
||||
message_count INTEGER NOT NULL DEFAULT 0
|
||||
message_count INTEGER NOT NULL DEFAULT 0,
|
||||
name TEXT
|
||||
);
|
||||
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS sessions_fts USING fts5(
|
||||
@@ -69,6 +70,18 @@ impl SqliteSessionBackend {
|
||||
)
|
||||
.context("Failed to initialize session schema")?;
|
||||
|
||||
// Migration: add name column to existing databases
|
||||
let has_name: bool = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) > 0 FROM pragma_table_info('session_metadata') WHERE name = 'name'",
|
||||
[],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.unwrap_or(false);
|
||||
if !has_name {
|
||||
let _ = conn.execute("ALTER TABLE session_metadata ADD COLUMN name TEXT", []);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
conn: Mutex::new(conn),
|
||||
db_path,
|
||||
@@ -226,7 +239,7 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
fn list_sessions_with_metadata(&self) -> Vec<SessionMetadata> {
|
||||
let conn = self.conn.lock();
|
||||
let mut stmt = match conn.prepare(
|
||||
"SELECT session_key, created_at, last_activity, message_count
|
||||
"SELECT session_key, created_at, last_activity, message_count, name
|
||||
FROM session_metadata ORDER BY last_activity DESC",
|
||||
) {
|
||||
Ok(s) => s,
|
||||
@@ -238,6 +251,7 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
let created_str: String = row.get(1)?;
|
||||
let activity_str: String = row.get(2)?;
|
||||
let count: i64 = row.get(3)?;
|
||||
let name: Option<String> = row.get(4)?;
|
||||
|
||||
let created = DateTime::parse_from_rfc3339(&created_str)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
@@ -249,6 +263,7 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok(SessionMetadata {
|
||||
key,
|
||||
name,
|
||||
created_at: created,
|
||||
last_activity: activity,
|
||||
message_count: count as usize,
|
||||
@@ -321,6 +336,27 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn set_session_name(&self, session_key: &str, name: &str) -> std::io::Result<()> {
|
||||
let conn = self.conn.lock();
|
||||
let name_val = if name.is_empty() { None } else { Some(name) };
|
||||
conn.execute(
|
||||
"UPDATE session_metadata SET name = ?1 WHERE session_key = ?2",
|
||||
params![name_val, session_key],
|
||||
)
|
||||
.map_err(std::io::Error::other)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_session_name(&self, session_key: &str) -> std::io::Result<Option<String>> {
|
||||
let conn = self.conn.lock();
|
||||
conn.query_row(
|
||||
"SELECT name FROM session_metadata WHERE session_key = ?1",
|
||||
params![session_key],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.map_err(std::io::Error::other)
|
||||
}
|
||||
|
||||
fn search(&self, query: &SessionQuery) -> Vec<SessionMetadata> {
|
||||
let Some(keyword) = &query.keyword else {
|
||||
return self.list_sessions_with_metadata();
|
||||
@@ -357,14 +393,16 @@ impl SessionBackend for SqliteSessionBackend {
|
||||
keys.iter()
|
||||
.filter_map(|key| {
|
||||
conn.query_row(
|
||||
"SELECT created_at, last_activity, message_count FROM session_metadata WHERE session_key = ?1",
|
||||
"SELECT created_at, last_activity, message_count, name FROM session_metadata WHERE session_key = ?1",
|
||||
params![key],
|
||||
|row| {
|
||||
let created_str: String = row.get(0)?;
|
||||
let activity_str: String = row.get(1)?;
|
||||
let count: i64 = row.get(2)?;
|
||||
let name: Option<String> = row.get(3)?;
|
||||
Ok(SessionMetadata {
|
||||
key: key.clone(),
|
||||
name,
|
||||
created_at: DateTime::parse_from_rfc3339(&created_str)
|
||||
.map(|dt| dt.with_timezone(&Utc))
|
||||
.unwrap_or_else(|_| Utc::now()),
|
||||
@@ -555,4 +593,55 @@ mod tests {
|
||||
assert_eq!(msgs.len(), 2);
|
||||
assert_eq!(msgs[0].content, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_session_name_persists() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
|
||||
|
||||
backend.append("s1", &ChatMessage::user("hello")).unwrap();
|
||||
backend.set_session_name("s1", "My Session").unwrap();
|
||||
|
||||
let meta = backend.list_sessions_with_metadata();
|
||||
assert_eq!(meta.len(), 1);
|
||||
assert_eq!(meta[0].name.as_deref(), Some("My Session"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_session_name_updates_existing() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
|
||||
|
||||
backend.append("s1", &ChatMessage::user("hello")).unwrap();
|
||||
backend.set_session_name("s1", "First").unwrap();
|
||||
backend.set_session_name("s1", "Second").unwrap();
|
||||
|
||||
let meta = backend.list_sessions_with_metadata();
|
||||
assert_eq!(meta[0].name.as_deref(), Some("Second"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sessions_without_name_return_none() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
|
||||
|
||||
backend.append("s1", &ChatMessage::user("hello")).unwrap();
|
||||
|
||||
let meta = backend.list_sessions_with_metadata();
|
||||
assert_eq!(meta.len(), 1);
|
||||
assert!(meta[0].name.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_name_clears_to_none() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
|
||||
|
||||
backend.append("s1", &ChatMessage::user("hello")).unwrap();
|
||||
backend.set_session_name("s1", "Named").unwrap();
|
||||
backend.set_session_name("s1", "").unwrap();
|
||||
|
||||
let meta = backend.list_sessions_with_metadata();
|
||||
assert!(meta[0].name.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +34,9 @@ pub struct SlackChannel {
|
||||
active_assistant_thread: Mutex<HashMap<String, String>>,
|
||||
/// Per-channel proxy URL override.
|
||||
proxy_url: Option<String>,
|
||||
/// Voice transcription config — when set, audio file attachments are
|
||||
/// downloaded, transcribed, and their text inlined into the message.
|
||||
transcription: Option<crate::config::TranscriptionConfig>,
|
||||
}
|
||||
|
||||
const SLACK_HISTORY_MAX_RETRIES: u32 = 3;
|
||||
@@ -125,6 +128,7 @@ impl SlackChannel {
|
||||
workspace_dir: None,
|
||||
active_assistant_thread: Mutex::new(HashMap::new()),
|
||||
proxy_url: None,
|
||||
transcription: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,6 +162,14 @@ impl SlackChannel {
|
||||
self
|
||||
}
|
||||
|
||||
/// Configure voice transcription for audio file attachments.
|
||||
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
|
||||
if config.enabled {
|
||||
self.transcription = Some(config);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
fn http_client(&self) -> reqwest::Client {
|
||||
crate::config::build_channel_proxy_client_with_timeouts(
|
||||
"channel.slack",
|
||||
@@ -558,6 +570,13 @@ impl SlackChannel {
|
||||
.await
|
||||
.unwrap_or_else(|| raw_file.clone());
|
||||
|
||||
// Voice / audio transcription: if transcription is configured and the
|
||||
// file looks like an audio attachment, download and transcribe it.
|
||||
if Self::is_audio_file(&file) {
|
||||
if let Some(transcribed) = self.try_transcribe_audio_file(&file).await {
|
||||
return Some(transcribed);
|
||||
}
|
||||
}
|
||||
if Self::is_image_file(&file) {
|
||||
if let Some(marker) = self.fetch_image_marker(&file).await {
|
||||
return Some(marker);
|
||||
@@ -1449,6 +1468,106 @@ impl SlackChannel {
|
||||
.is_some_and(|ext| Self::mime_from_extension(ext).is_some())
|
||||
}
|
||||
|
||||
/// Audio file extensions accepted for voice transcription.
|
||||
const AUDIO_EXTENSIONS: &[&str] = &[
|
||||
"flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
|
||||
];
|
||||
|
||||
/// Check whether a Slack file object looks like an audio attachment
|
||||
/// (voice memo, audio message, or uploaded audio file).
|
||||
fn is_audio_file(file: &serde_json::Value) -> bool {
|
||||
// Slack voice messages use subtype "slack_audio"
|
||||
if let Some(subtype) = file.get("subtype").and_then(|v| v.as_str()) {
|
||||
if subtype == "slack_audio" {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if Self::slack_file_mime(file)
|
||||
.as_deref()
|
||||
.is_some_and(|mime| mime.starts_with("audio/"))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
if let Some(ft) = file
|
||||
.get("filetype")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|v| v.to_ascii_lowercase())
|
||||
{
|
||||
if Self::AUDIO_EXTENSIONS.contains(&ft.as_str()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
Self::file_extension(&Self::slack_file_name(file))
|
||||
.as_deref()
|
||||
.is_some_and(|ext| Self::AUDIO_EXTENSIONS.contains(&ext))
|
||||
}
|
||||
|
||||
/// Download an audio file attachment and transcribe it using the configured
|
||||
/// transcription provider. Returns `None` if transcription is not configured
|
||||
/// or if the download/transcription fails.
|
||||
async fn try_transcribe_audio_file(&self, file: &serde_json::Value) -> Option<String> {
|
||||
let config = self.transcription.as_ref()?;
|
||||
|
||||
let url = Self::slack_file_download_url(file)?;
|
||||
let file_name = Self::slack_file_name(file);
|
||||
let redacted_url = Self::redact_raw_slack_url(url);
|
||||
|
||||
let resp = self.fetch_slack_private_file(url).await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
tracing::warn!(
|
||||
"Slack voice file download failed for {} ({status})",
|
||||
redacted_url
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
let audio_data = match resp.bytes().await {
|
||||
Ok(bytes) => bytes.to_vec(),
|
||||
Err(e) => {
|
||||
tracing::warn!("Slack voice file read failed for {}: {e}", redacted_url);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
// Determine a filename with extension for the transcription API.
|
||||
let transcription_filename = if Self::file_extension(&file_name).is_some() {
|
||||
file_name.clone()
|
||||
} else {
|
||||
// Fall back to extension from mimetype or default to .ogg
|
||||
let mime_ext = Self::slack_file_mime(file)
|
||||
.and_then(|mime| mime.rsplit('/').next().map(|s| s.to_string()))
|
||||
.unwrap_or_else(|| "ogg".to_string());
|
||||
format!("voice.{mime_ext}")
|
||||
};
|
||||
|
||||
match super::transcription::transcribe_audio(audio_data, &transcription_filename, config)
|
||||
.await
|
||||
{
|
||||
Ok(text) => {
|
||||
let trimmed = text.trim();
|
||||
if trimmed.is_empty() {
|
||||
tracing::info!("Slack voice transcription returned empty text, skipping");
|
||||
None
|
||||
} else {
|
||||
tracing::info!(
|
||||
"Slack: transcribed voice file {} ({} chars)",
|
||||
file_name,
|
||||
trimmed.len()
|
||||
);
|
||||
Some(format!("[Voice] {trimmed}"))
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Slack voice transcription failed for {}: {e}", file_name);
|
||||
Some(Self::format_attachment_summary(file))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn download_text_snippet(&self, file: &serde_json::Value) -> Option<String> {
|
||||
let url = Self::slack_file_download_url(file)?;
|
||||
let redacted_url = Self::redact_raw_slack_url(url);
|
||||
|
||||
@@ -1140,6 +1140,11 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
content = format!("{quote}\n\n{content}");
|
||||
}
|
||||
|
||||
// Prepend forwarding attribution when the message was forwarded
|
||||
if let Some(attr) = Self::format_forward_attribution(message) {
|
||||
content = format!("{attr}{content}");
|
||||
}
|
||||
|
||||
Some(ChannelMessage {
|
||||
id: format!("telegram_{chat_id}_{message_id}"),
|
||||
sender: sender_identity,
|
||||
@@ -1263,6 +1268,13 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
format!("[Voice] {text}")
|
||||
};
|
||||
|
||||
// Prepend forwarding attribution when the message was forwarded
|
||||
let content = if let Some(attr) = Self::format_forward_attribution(message) {
|
||||
format!("{attr}{content}")
|
||||
} else {
|
||||
content
|
||||
};
|
||||
|
||||
Some(ChannelMessage {
|
||||
id: format!("telegram_{chat_id}_{message_id}"),
|
||||
sender: sender_identity,
|
||||
@@ -1299,6 +1311,41 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
(username, sender_id, sender_identity)
|
||||
}
|
||||
|
||||
/// Build a forwarding attribution prefix from Telegram forward fields.
|
||||
///
|
||||
/// Returns `Some("[Forwarded from ...] ")` when the message is forwarded,
|
||||
/// `None` otherwise.
|
||||
fn format_forward_attribution(message: &serde_json::Value) -> Option<String> {
|
||||
if let Some(from_chat) = message.get("forward_from_chat") {
|
||||
// Forwarded from a channel or group
|
||||
let title = from_chat
|
||||
.get("title")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.unwrap_or("unknown channel");
|
||||
Some(format!("[Forwarded from channel: {title}] "))
|
||||
} else if let Some(from_user) = message.get("forward_from") {
|
||||
// Forwarded from a user (privacy allows identity)
|
||||
let label = from_user
|
||||
.get("username")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(|u| format!("@{u}"))
|
||||
.or_else(|| {
|
||||
from_user
|
||||
.get("first_name")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(String::from)
|
||||
})
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
Some(format!("[Forwarded from {label}] "))
|
||||
} else {
|
||||
// Forwarded from a user who hides their identity
|
||||
message
|
||||
.get("forward_sender_name")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(|name| format!("[Forwarded from {name}] "))
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract reply context from a Telegram `reply_to_message`, if present.
|
||||
fn extract_reply_context(&self, message: &serde_json::Value) -> Option<String> {
|
||||
let reply = message.get("reply_to_message")?;
|
||||
@@ -1420,6 +1467,13 @@ Allowlist Telegram username (without '@') or numeric user ID.",
|
||||
content
|
||||
};
|
||||
|
||||
// Prepend forwarding attribution when the message was forwarded
|
||||
let content = if let Some(attr) = Self::format_forward_attribution(message) {
|
||||
format!("{attr}{content}")
|
||||
} else {
|
||||
content
|
||||
};
|
||||
|
||||
// Exit voice-chat mode when user switches back to typing
|
||||
if let Ok(mut vc) = self.voice_chats.lock() {
|
||||
vc.remove(&reply_target);
|
||||
@@ -4871,4 +4925,153 @@ mod tests {
|
||||
TelegramChannel::new("token".into(), vec!["*".into()], false).with_ack_reactions(true);
|
||||
assert!(ch.ack_reactions);
|
||||
}
|
||||
|
||||
// ── Forwarded message tests ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_forwarded_from_user_with_username() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 100,
|
||||
"message": {
|
||||
"message_id": 50,
|
||||
"text": "Check this out",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"forward_from": {
|
||||
"id": 42,
|
||||
"first_name": "Bob",
|
||||
"username": "bob"
|
||||
},
|
||||
"forward_date": 1_700_000_000
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("forwarded message should parse");
|
||||
assert_eq!(msg.content, "[Forwarded from @bob] Check this out");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_forwarded_from_channel() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 101,
|
||||
"message": {
|
||||
"message_id": 51,
|
||||
"text": "Breaking news",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"forward_from_chat": {
|
||||
"id": -1_001_234_567_890_i64,
|
||||
"title": "Daily News",
|
||||
"username": "dailynews",
|
||||
"type": "channel"
|
||||
},
|
||||
"forward_date": 1_700_000_000
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("channel-forwarded message should parse");
|
||||
assert_eq!(
|
||||
msg.content,
|
||||
"[Forwarded from channel: Daily News] Breaking news"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_forwarded_hidden_sender() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 102,
|
||||
"message": {
|
||||
"message_id": 52,
|
||||
"text": "Secret tip",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"forward_sender_name": "Hidden User",
|
||||
"forward_date": 1_700_000_000
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("hidden-sender forwarded message should parse");
|
||||
assert_eq!(msg.content, "[Forwarded from Hidden User] Secret tip");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_non_forwarded_unaffected() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 103,
|
||||
"message": {
|
||||
"message_id": 53,
|
||||
"text": "Normal message",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 }
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("non-forwarded message should parse");
|
||||
assert_eq!(msg.content, "Normal message");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_forwarded_from_user_no_username() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 104,
|
||||
"message": {
|
||||
"message_id": 54,
|
||||
"text": "Hello there",
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"forward_from": {
|
||||
"id": 77,
|
||||
"first_name": "Charlie"
|
||||
},
|
||||
"forward_date": 1_700_000_000
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("forwarded message without username should parse");
|
||||
assert_eq!(msg.content, "[Forwarded from Charlie] Hello there");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarded_photo_attachment_has_attribution() {
|
||||
// Verify that format_forward_attribution produces correct prefix
|
||||
// for a photo message (the actual download is async, so we test the
|
||||
// helper directly with a photo-bearing message structure).
|
||||
let message = serde_json::json!({
|
||||
"message_id": 60,
|
||||
"from": { "id": 1, "username": "alice" },
|
||||
"chat": { "id": 999 },
|
||||
"photo": [
|
||||
{ "file_id": "abc123", "file_unique_id": "u1", "width": 320, "height": 240 }
|
||||
],
|
||||
"forward_from": {
|
||||
"id": 42,
|
||||
"username": "bob"
|
||||
},
|
||||
"forward_date": 1_700_000_000
|
||||
});
|
||||
|
||||
let attr =
|
||||
TelegramChannel::format_forward_attribution(&message).expect("should detect forward");
|
||||
assert_eq!(attr, "[Forwarded from @bob] ");
|
||||
|
||||
// Simulate what try_parse_attachment_message does after building content
|
||||
let photo_content = "[IMAGE:/tmp/photo.jpg]".to_string();
|
||||
let content = format!("{attr}{photo_content}");
|
||||
assert_eq!(content, "[Forwarded from @bob] [IMAGE:/tmp/photo.jpg]");
|
||||
}
|
||||
}
|
||||
|
||||
+130
-1
@@ -1,6 +1,7 @@
|
||||
//! Multi-provider Text-to-Speech (TTS) subsystem.
|
||||
//!
|
||||
//! Supports OpenAI, ElevenLabs, Google Cloud TTS, and Edge TTS (free, subprocess-based).
|
||||
//! Supports OpenAI, ElevenLabs, Google Cloud TTS, Edge TTS (free, subprocess-based),
|
||||
//! and Piper TTS (local GPU-accelerated, OpenAI-compatible endpoint).
|
||||
//! Provider selection is driven by [`TtsConfig`] in `config.toml`.
|
||||
|
||||
use std::collections::HashMap;
|
||||
@@ -451,6 +452,80 @@ impl TtsProvider for EdgeTtsProvider {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Piper TTS (local, OpenAI-compatible) ─────────────────────────
|
||||
|
||||
/// Piper TTS provider — local GPU-accelerated server with an OpenAI-compatible endpoint.
|
||||
pub struct PiperTtsProvider {
|
||||
client: reqwest::Client,
|
||||
api_url: String,
|
||||
}
|
||||
|
||||
impl PiperTtsProvider {
|
||||
/// Create a new Piper TTS provider pointing at the given API URL.
|
||||
pub fn new(api_url: &str) -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::builder()
|
||||
.timeout(TTS_HTTP_TIMEOUT)
|
||||
.build()
|
||||
.expect("Failed to build HTTP client for Piper TTS"),
|
||||
api_url: api_url.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TtsProvider for PiperTtsProvider {
|
||||
fn name(&self) -> &str {
|
||||
"piper"
|
||||
}
|
||||
|
||||
async fn synthesize(&self, text: &str, voice: &str) -> Result<Vec<u8>> {
|
||||
let body = serde_json::json!({
|
||||
"model": "tts-1",
|
||||
"input": text,
|
||||
"voice": voice,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&self.api_url)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send Piper TTS request")?;
|
||||
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let error_body: serde_json::Value = resp
|
||||
.json()
|
||||
.await
|
||||
.unwrap_or_else(|_| serde_json::json!({"error": "unknown"}));
|
||||
let msg = error_body["error"]["message"]
|
||||
.as_str()
|
||||
.unwrap_or("unknown error");
|
||||
bail!("Piper TTS API error ({}): {}", status, msg);
|
||||
}
|
||||
|
||||
let bytes = resp
|
||||
.bytes()
|
||||
.await
|
||||
.context("Failed to read Piper TTS response body")?;
|
||||
Ok(bytes.to_vec())
|
||||
}
|
||||
|
||||
fn supported_voices(&self) -> Vec<String> {
|
||||
// Piper voices depend on installed models; return empty (dynamic).
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
fn supported_formats(&self) -> Vec<String> {
|
||||
["mp3", "wav", "opus"]
|
||||
.iter()
|
||||
.map(|s| (*s).to_string())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ── TtsManager ───────────────────────────────────────────────────
|
||||
|
||||
/// Central manager for multi-provider TTS synthesis.
|
||||
@@ -510,6 +585,11 @@ impl TtsManager {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref piper_cfg) = config.piper {
|
||||
let provider = PiperTtsProvider::new(&piper_cfg.api_url);
|
||||
providers.insert("piper".to_string(), Box::new(provider));
|
||||
}
|
||||
|
||||
let max_text_length = if config.max_text_length == 0 {
|
||||
DEFAULT_MAX_TEXT_LENGTH
|
||||
} else {
|
||||
@@ -652,6 +732,54 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn piper_provider_creation() {
|
||||
let provider = PiperTtsProvider::new("http://127.0.0.1:5000/v1/audio/speech");
|
||||
assert_eq!(provider.name(), "piper");
|
||||
assert_eq!(provider.api_url, "http://127.0.0.1:5000/v1/audio/speech");
|
||||
assert_eq!(provider.supported_formats(), vec!["mp3", "wav", "opus"]);
|
||||
// Piper voices depend on installed models; list is empty.
|
||||
assert!(provider.supported_voices().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tts_manager_with_piper_provider() {
|
||||
let mut config = default_tts_config();
|
||||
config.default_provider = "piper".to_string();
|
||||
config.piper = Some(crate::config::PiperTtsConfig {
|
||||
api_url: "http://127.0.0.1:5000/v1/audio/speech".into(),
|
||||
});
|
||||
|
||||
let manager = TtsManager::new(&config).unwrap();
|
||||
assert_eq!(manager.available_providers(), vec!["piper"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tts_rejects_empty_text_for_piper() {
|
||||
let mut config = default_tts_config();
|
||||
config.default_provider = "piper".to_string();
|
||||
config.piper = Some(crate::config::PiperTtsConfig {
|
||||
api_url: "http://127.0.0.1:5000/v1/audio/speech".into(),
|
||||
});
|
||||
|
||||
let manager = TtsManager::new(&config).unwrap();
|
||||
let err = manager
|
||||
.synthesize_with_provider("", "piper", "default")
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("must not be empty"),
|
||||
"expected empty-text error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn piper_not_registered_when_config_is_none() {
|
||||
let config = default_tts_config();
|
||||
let manager = TtsManager::new(&config).unwrap();
|
||||
assert!(!manager.available_providers().contains(&"piper".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tts_config_defaults() {
|
||||
let config = TtsConfig::default();
|
||||
@@ -664,6 +792,7 @@ mod tests {
|
||||
assert!(config.elevenlabs.is_none());
|
||||
assert!(config.google.is_none());
|
||||
assert!(config.edge.is_none());
|
||||
assert!(config.piper.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
+15
-14
@@ -10,21 +10,22 @@ pub use schema::{
|
||||
AgentConfig, AssemblyAiSttConfig, AuditConfig, AutonomyConfig, BackupConfig,
|
||||
BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
|
||||
ClassificationRule, ClaudeCodeConfig, CloudOpsConfig, ComposioConfig, Config,
|
||||
ConversationalAiConfig, CostConfig, CronConfig, DataRetentionConfig, DeepgramSttConfig,
|
||||
DelegateAgentConfig, DelegateToolConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig,
|
||||
ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
|
||||
GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig,
|
||||
HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
|
||||
IMessageConfig, IdentityConfig, ImageGenConfig, ImageProviderDalleConfig,
|
||||
ConversationalAiConfig, CostConfig, CronConfig, CronJobDecl, CronScheduleDecl,
|
||||
DataRetentionConfig, DeepgramSttConfig, DelegateAgentConfig, DelegateToolConfig, DiscordConfig,
|
||||
DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig,
|
||||
FeishuConfig, GatewayConfig, GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation,
|
||||
GoogleWorkspaceConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
|
||||
HttpRequestConfig, IMessageConfig, IdentityConfig, ImageGenConfig, ImageProviderDalleConfig,
|
||||
ImageProviderFluxConfig, ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig,
|
||||
KnowledgeConfig, LarkConfig, LinkedInConfig, LinkedInContentConfig, LinkedInImageConfig,
|
||||
LocalWhisperConfig, MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig,
|
||||
MemoryPolicyConfig, Microsoft365Config, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig,
|
||||
OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, PacingConfig,
|
||||
PeripheralBoardConfig, PeripheralsConfig, PluginsConfig, ProjectIntelConfig, ProxyConfig,
|
||||
ProxyScope, QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
|
||||
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
KnowledgeConfig, LarkConfig, LinkEnricherConfig, LinkedInConfig, LinkedInContentConfig,
|
||||
LinkedInImageConfig, LocalWhisperConfig, MatrixConfig, McpConfig, McpServerConfig,
|
||||
McpTransport, MemoryConfig, MemoryPolicyConfig, Microsoft365Config, ModelRouteConfig,
|
||||
MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig,
|
||||
ObservabilityConfig, OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig,
|
||||
OtpMethod, PacingConfig, PeripheralBoardConfig, PeripheralsConfig, PiperTtsConfig,
|
||||
PluginsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SecurityOpsConfig, SkillCreationConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig,
|
||||
StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig,
|
||||
SwarmStrategy, TelegramConfig, TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode,
|
||||
|
||||
+338
-18
@@ -265,6 +265,10 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub web_fetch: WebFetchConfig,
|
||||
|
||||
/// Link enricher configuration (`[link_enricher]`).
|
||||
#[serde(default)]
|
||||
pub link_enricher: LinkEnricherConfig,
|
||||
|
||||
/// Text browser tool configuration (`[text_browser]`).
|
||||
#[serde(default)]
|
||||
pub text_browser: TextBrowserConfig,
|
||||
@@ -1005,6 +1009,10 @@ fn default_edge_tts_binary_path() -> String {
|
||||
"edge-tts".into()
|
||||
}
|
||||
|
||||
fn default_piper_tts_api_url() -> String {
|
||||
"http://127.0.0.1:5000/v1/audio/speech".into()
|
||||
}
|
||||
|
||||
/// Text-to-Speech configuration (`[tts]`).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct TtsConfig {
|
||||
@@ -1035,6 +1043,9 @@ pub struct TtsConfig {
|
||||
/// Edge TTS provider configuration (`[tts.edge]`).
|
||||
#[serde(default)]
|
||||
pub edge: Option<EdgeTtsConfig>,
|
||||
/// Piper TTS provider configuration (`[tts.piper]`).
|
||||
#[serde(default)]
|
||||
pub piper: Option<PiperTtsConfig>,
|
||||
}
|
||||
|
||||
impl Default for TtsConfig {
|
||||
@@ -1049,6 +1060,7 @@ impl Default for TtsConfig {
|
||||
elevenlabs: None,
|
||||
google: None,
|
||||
edge: None,
|
||||
piper: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1103,6 +1115,14 @@ pub struct EdgeTtsConfig {
|
||||
pub binary_path: String,
|
||||
}
|
||||
|
||||
/// Piper TTS provider configuration (local GPU-accelerated, OpenAI-compatible endpoint).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct PiperTtsConfig {
|
||||
/// Base URL for the Piper TTS HTTP server (e.g. `"http://127.0.0.1:5000/v1/audio/speech"`).
|
||||
#[serde(default = "default_piper_tts_api_url")]
|
||||
pub api_url: String,
|
||||
}
|
||||
|
||||
/// Determines when a `ToolFilterGroup` is active.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
@@ -1258,6 +1278,10 @@ pub struct AgentConfig {
|
||||
/// Useful for small-context models (e.g. glm-4.5-air ~8K tokens → set to 8000).
|
||||
#[serde(default = "default_max_system_prompt_chars")]
|
||||
pub max_system_prompt_chars: usize,
|
||||
/// Thinking/reasoning level control. Configures how deeply the model reasons
|
||||
/// per message. Users can override per-message with `/think:<level>` directives.
|
||||
#[serde(default)]
|
||||
pub thinking: crate::agent::thinking::ThinkingConfig,
|
||||
}
|
||||
|
||||
fn default_agent_max_tool_iterations() -> usize {
|
||||
@@ -1292,6 +1316,7 @@ impl Default for AgentConfig {
|
||||
tool_call_dedup_exempt: Vec::new(),
|
||||
tool_filter_groups: Vec::new(),
|
||||
max_system_prompt_chars: default_max_system_prompt_chars(),
|
||||
thinking: crate::agent::thinking::ThinkingConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1413,6 +1438,15 @@ pub struct MultimodalConfig {
|
||||
/// Allow fetching remote image URLs (http/https). Disabled by default.
|
||||
#[serde(default)]
|
||||
pub allow_remote_fetch: bool,
|
||||
/// Provider name to use for vision/image messages (e.g. `"ollama"`).
|
||||
/// When set, messages containing `[IMAGE:]` markers are routed to this
|
||||
/// provider instead of the default text provider.
|
||||
#[serde(default)]
|
||||
pub vision_provider: Option<String>,
|
||||
/// Model to use when routing to the vision provider (e.g. `"llava:7b"`).
|
||||
/// Only used when `vision_provider` is set.
|
||||
#[serde(default)]
|
||||
pub vision_model: Option<String>,
|
||||
}
|
||||
|
||||
fn default_multimodal_max_images() -> usize {
|
||||
@@ -1438,6 +1472,8 @@ impl Default for MultimodalConfig {
|
||||
max_images: default_multimodal_max_images(),
|
||||
max_image_size_mb: default_multimodal_max_image_size_mb(),
|
||||
allow_remote_fetch: false,
|
||||
vision_provider: None,
|
||||
vision_model: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2120,8 +2156,8 @@ fn default_browser_webdriver_url() -> String {
|
||||
impl Default for BrowserConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
allowed_domains: Vec::new(),
|
||||
enabled: true,
|
||||
allowed_domains: vec!["*".into()],
|
||||
session_name: None,
|
||||
backend: default_browser_backend(),
|
||||
native_headless: default_true(),
|
||||
@@ -2136,7 +2172,9 @@ impl Default for BrowserConfig {
|
||||
|
||||
/// HTTP request tool configuration (`[http_request]` section).
|
||||
///
|
||||
/// Deny-by-default: if `allowed_domains` is empty, all HTTP requests are rejected.
|
||||
/// Domain filtering: `allowed_domains` controls which hosts are reachable (use `["*"]`
|
||||
/// for all public hosts, which is the default). If `allowed_domains` is empty, all
|
||||
/// requests are rejected.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct HttpRequestConfig {
|
||||
/// Enable `http_request` tool for API interactions
|
||||
@@ -2160,8 +2198,8 @@ pub struct HttpRequestConfig {
|
||||
impl Default for HttpRequestConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
allowed_domains: vec![],
|
||||
enabled: true,
|
||||
allowed_domains: vec!["*".into()],
|
||||
max_response_size: default_http_max_response_size(),
|
||||
timeout_secs: default_http_timeout_secs(),
|
||||
allow_private_hosts: false,
|
||||
@@ -2219,7 +2257,7 @@ fn default_web_fetch_allowed_domains() -> Vec<String> {
|
||||
impl Default for WebFetchConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
enabled: true,
|
||||
allowed_domains: vec!["*".into()],
|
||||
blocked_domains: vec![],
|
||||
max_response_size: default_web_fetch_max_response_size(),
|
||||
@@ -2228,6 +2266,45 @@ impl Default for WebFetchConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Link enricher ─────────────────────────────────────────────────
|
||||
|
||||
/// Automatic link understanding for inbound channel messages (`[link_enricher]`).
|
||||
///
|
||||
/// When enabled, URLs in incoming messages are automatically fetched and
|
||||
/// summarised. The summary is prepended to the message before the agent
|
||||
/// processes it, giving the LLM context about linked pages without an
|
||||
/// explicit tool call.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct LinkEnricherConfig {
|
||||
/// Enable the link enricher pipeline stage (default: false)
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// Maximum number of links to fetch per message (default: 3)
|
||||
#[serde(default = "default_link_enricher_max_links")]
|
||||
pub max_links: usize,
|
||||
/// Per-link fetch timeout in seconds (default: 10)
|
||||
#[serde(default = "default_link_enricher_timeout_secs")]
|
||||
pub timeout_secs: u64,
|
||||
}
|
||||
|
||||
fn default_link_enricher_max_links() -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
fn default_link_enricher_timeout_secs() -> u64 {
|
||||
10
|
||||
}
|
||||
|
||||
impl Default for LinkEnricherConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
max_links: default_link_enricher_max_links(),
|
||||
timeout_secs: default_link_enricher_timeout_secs(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Text browser ─────────────────────────────────────────────────
|
||||
|
||||
/// Text browser tool configuration (`[text_browser]` section).
|
||||
@@ -2269,12 +2346,15 @@ pub struct WebSearchConfig {
|
||||
/// Enable `web_search_tool` for web searches
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// Search provider: "duckduckgo" (free, no API key) or "brave" (requires API key)
|
||||
/// Search provider: "duckduckgo" (free), "brave" (requires API key), or "searxng" (self-hosted)
|
||||
#[serde(default = "default_web_search_provider")]
|
||||
pub provider: String,
|
||||
/// Brave Search API key (required if provider is "brave")
|
||||
#[serde(default)]
|
||||
pub brave_api_key: Option<String>,
|
||||
/// SearXNG instance URL (required if provider is "searxng"), e.g. "https://searx.example.com"
|
||||
#[serde(default)]
|
||||
pub searxng_instance_url: Option<String>,
|
||||
/// Maximum results per search (1-10)
|
||||
#[serde(default = "default_web_search_max_results")]
|
||||
pub max_results: usize,
|
||||
@@ -2298,9 +2378,10 @@ fn default_web_search_timeout_secs() -> u64 {
|
||||
impl Default for WebSearchConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
enabled: true,
|
||||
provider: default_web_search_provider(),
|
||||
brave_api_key: None,
|
||||
searxng_instance_url: None,
|
||||
max_results: default_web_search_max_results(),
|
||||
timeout_secs: default_web_search_timeout_secs(),
|
||||
}
|
||||
@@ -4267,6 +4348,19 @@ fn default_always_ask() -> Vec<String> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
impl AutonomyConfig {
|
||||
/// Merge the built-in default `auto_approve` entries into the current
|
||||
/// list, preserving any user-supplied additions.
|
||||
pub fn ensure_default_auto_approve(&mut self) {
|
||||
let defaults = default_auto_approve();
|
||||
for entry in defaults {
|
||||
if !self.auto_approve.iter().any(|existing| existing == &entry) {
|
||||
self.auto_approve.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_valid_env_var_name(name: &str) -> bool {
|
||||
let mut chars = name.chars();
|
||||
match chars.next() {
|
||||
@@ -4764,6 +4858,92 @@ pub struct CronConfig {
|
||||
/// Maximum number of historical cron run records to retain. Default: `50`.
|
||||
#[serde(default = "default_max_run_history")]
|
||||
pub max_run_history: u32,
|
||||
/// Declarative cron job definitions (`[[cron.jobs]]`).
|
||||
///
|
||||
/// Jobs declared here are synced into the database at scheduler startup.
|
||||
/// They use `source = "declarative"` to distinguish them from jobs
|
||||
/// created imperatively via CLI or API. Declarative config takes
|
||||
/// precedence on each sync: if the config changes, the DB is updated
|
||||
/// to match. Imperative jobs are never deleted by the sync process.
|
||||
#[serde(default)]
|
||||
pub jobs: Vec<CronJobDecl>,
|
||||
}
|
||||
|
||||
/// A declarative cron job definition for the `[[cron.jobs]]` config array.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CronJobDecl {
|
||||
/// Stable identifier used for merge semantics across syncs.
|
||||
pub id: String,
|
||||
/// Human-readable name.
|
||||
#[serde(default)]
|
||||
pub name: Option<String>,
|
||||
/// Job type: `"shell"` (default) or `"agent"`.
|
||||
#[serde(default = "default_job_type_decl")]
|
||||
pub job_type: String,
|
||||
/// Schedule for the job.
|
||||
pub schedule: CronScheduleDecl,
|
||||
/// Shell command to run (required when `job_type = "shell"`).
|
||||
#[serde(default)]
|
||||
pub command: Option<String>,
|
||||
/// Agent prompt (required when `job_type = "agent"`).
|
||||
#[serde(default)]
|
||||
pub prompt: Option<String>,
|
||||
/// Whether the job is enabled. Default: `true`.
|
||||
#[serde(default = "default_true")]
|
||||
pub enabled: bool,
|
||||
/// Model override for agent jobs.
|
||||
#[serde(default)]
|
||||
pub model: Option<String>,
|
||||
/// Allowlist of tool names for agent jobs.
|
||||
#[serde(default)]
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
/// Session target: `"isolated"` (default) or `"main"`.
|
||||
#[serde(default)]
|
||||
pub session_target: Option<String>,
|
||||
/// Delivery configuration.
|
||||
#[serde(default)]
|
||||
pub delivery: Option<DeliveryConfigDecl>,
|
||||
}
|
||||
|
||||
/// Schedule variant for declarative cron jobs.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(tag = "kind", rename_all = "lowercase")]
|
||||
pub enum CronScheduleDecl {
|
||||
/// Classic cron expression.
|
||||
Cron {
|
||||
expr: String,
|
||||
#[serde(default)]
|
||||
tz: Option<String>,
|
||||
},
|
||||
/// Interval in milliseconds.
|
||||
Every { every_ms: u64 },
|
||||
/// One-shot at an RFC 3339 timestamp.
|
||||
At { at: String },
|
||||
}
|
||||
|
||||
/// Delivery configuration for declarative cron jobs.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct DeliveryConfigDecl {
|
||||
/// Delivery mode: `"none"` or `"announce"`.
|
||||
#[serde(default = "default_delivery_mode")]
|
||||
pub mode: String,
|
||||
/// Channel name (e.g. `"telegram"`, `"discord"`).
|
||||
#[serde(default)]
|
||||
pub channel: Option<String>,
|
||||
/// Target/recipient identifier.
|
||||
#[serde(default)]
|
||||
pub to: Option<String>,
|
||||
/// Best-effort delivery. Default: `true`.
|
||||
#[serde(default = "default_true")]
|
||||
pub best_effort: bool,
|
||||
}
|
||||
|
||||
fn default_job_type_decl() -> String {
|
||||
"shell".to_string()
|
||||
}
|
||||
|
||||
fn default_delivery_mode() -> String {
|
||||
"none".to_string()
|
||||
}
|
||||
|
||||
fn default_max_run_history() -> u32 {
|
||||
@@ -4776,6 +4956,7 @@ impl Default for CronConfig {
|
||||
enabled: true,
|
||||
catch_up_on_startup: true,
|
||||
max_run_history: default_max_run_history(),
|
||||
jobs: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5457,6 +5638,10 @@ pub struct MatrixConfig {
|
||||
pub room_id: String,
|
||||
/// Allowed Matrix user IDs. Empty = deny all.
|
||||
pub allowed_users: Vec<String>,
|
||||
/// Allowed Matrix room IDs or aliases. Empty = allow all rooms.
|
||||
/// Supports canonical room IDs (`!abc:server`) and aliases (`#room:server`).
|
||||
#[serde(default)]
|
||||
pub allowed_rooms: Vec<String>,
|
||||
/// Whether to interrupt an in-flight agent response when a new message arrives.
|
||||
#[serde(default)]
|
||||
pub interrupt_on_new_message: bool,
|
||||
@@ -6977,6 +7162,7 @@ impl Default for Config {
|
||||
http_request: HttpRequestConfig::default(),
|
||||
multimodal: MultimodalConfig::default(),
|
||||
web_fetch: WebFetchConfig::default(),
|
||||
link_enricher: LinkEnricherConfig::default(),
|
||||
text_browser: TextBrowserConfig::default(),
|
||||
web_search: WebSearchConfig::default(),
|
||||
project_intel: ProjectIntelConfig::default(),
|
||||
@@ -7544,6 +7730,19 @@ impl Config {
|
||||
let mut config: Config =
|
||||
toml::from_str(&contents).context("Failed to deserialize config file")?;
|
||||
|
||||
// Ensure the built-in default auto_approve entries are always
|
||||
// present. When a user specifies `auto_approve` in their TOML
|
||||
// (e.g. to add a custom tool), serde replaces the default list
|
||||
// instead of merging. This caused default-safe tools like
|
||||
// `weather` or `calculator` to lose their auto-approve status
|
||||
// and get silently denied in non-interactive channel runs.
|
||||
// See #4247.
|
||||
//
|
||||
// Users who want to require approval for a default tool can
|
||||
// add it to `always_ask`, which takes precedence over
|
||||
// `auto_approve` in the approval decision (see approval/mod.rs).
|
||||
config.autonomy.ensure_default_auto_approve();
|
||||
|
||||
// Detect unknown/ignored config keys for diagnostic warnings.
|
||||
// This second pass uses serde_ignored but discards the parsed
|
||||
// result — only the ignored-path list is kept.
|
||||
@@ -8883,6 +9082,16 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
// SearXNG instance URL: ZEROCLAW_SEARXNG_INSTANCE_URL or SEARXNG_INSTANCE_URL
|
||||
if let Ok(instance_url) = std::env::var("ZEROCLAW_SEARXNG_INSTANCE_URL")
|
||||
.or_else(|_| std::env::var("SEARXNG_INSTANCE_URL"))
|
||||
{
|
||||
let instance_url = instance_url.trim();
|
||||
if !instance_url.is_empty() {
|
||||
self.web_search.searxng_instance_url = Some(instance_url.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Web search max results: ZEROCLAW_WEB_SEARCH_MAX_RESULTS or WEB_SEARCH_MAX_RESULTS
|
||||
if let Ok(max_results) = std::env::var("ZEROCLAW_WEB_SEARCH_MAX_RESULTS")
|
||||
.or_else(|_| std::env::var("WEB_SEARCH_MAX_RESULTS"))
|
||||
@@ -9560,7 +9769,9 @@ mod tests {
|
||||
merged.push(']');
|
||||
}
|
||||
merged.push('\n');
|
||||
toml::from_str(&merged).unwrap()
|
||||
let mut config: Config = toml::from_str(&merged).unwrap();
|
||||
config.autonomy.ensure_default_auto_approve();
|
||||
config
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -9568,8 +9779,8 @@ mod tests {
|
||||
let cfg = HttpRequestConfig::default();
|
||||
assert_eq!(cfg.timeout_secs, 30);
|
||||
assert_eq!(cfg.max_response_size, 1_000_000);
|
||||
assert!(!cfg.enabled);
|
||||
assert!(cfg.allowed_domains.is_empty());
|
||||
assert!(cfg.enabled);
|
||||
assert_eq!(cfg.allowed_domains, vec!["*".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -9758,6 +9969,7 @@ recipient = "42"
|
||||
enabled: false,
|
||||
catch_up_on_startup: false,
|
||||
max_run_history: 100,
|
||||
jobs: Vec::new(),
|
||||
};
|
||||
let json = serde_json::to_string(&c).unwrap();
|
||||
let parsed: CronConfig = serde_json::from_str(&json).unwrap();
|
||||
@@ -9932,6 +10144,7 @@ default_temperature = 0.7
|
||||
http_request: HttpRequestConfig::default(),
|
||||
multimodal: MultimodalConfig::default(),
|
||||
web_fetch: WebFetchConfig::default(),
|
||||
link_enricher: LinkEnricherConfig::default(),
|
||||
text_browser: TextBrowserConfig::default(),
|
||||
web_search: WebSearchConfig::default(),
|
||||
project_intel: ProjectIntelConfig::default(),
|
||||
@@ -10046,6 +10259,109 @@ auto_approve = ["file_read", "memory_recall", "http_request"]
|
||||
);
|
||||
}
|
||||
|
||||
/// Regression test for #4247: when a user provides a custom auto_approve
|
||||
/// list, the built-in defaults must still be present.
|
||||
#[test]
|
||||
async fn auto_approve_merges_user_entries_with_defaults() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
|
||||
[autonomy]
|
||||
auto_approve = ["my_custom_tool", "another_tool"]
|
||||
"#;
|
||||
let parsed = parse_test_config(raw);
|
||||
// User entries are preserved
|
||||
assert!(
|
||||
parsed
|
||||
.autonomy
|
||||
.auto_approve
|
||||
.contains(&"my_custom_tool".to_string()),
|
||||
"user-supplied tool must remain in auto_approve"
|
||||
);
|
||||
assert!(
|
||||
parsed
|
||||
.autonomy
|
||||
.auto_approve
|
||||
.contains(&"another_tool".to_string()),
|
||||
"user-supplied tool must remain in auto_approve"
|
||||
);
|
||||
// Defaults are merged in
|
||||
for default_tool in &[
|
||||
"file_read",
|
||||
"memory_recall",
|
||||
"weather",
|
||||
"calculator",
|
||||
"web_fetch",
|
||||
] {
|
||||
assert!(
|
||||
parsed.autonomy.auto_approve.contains(&default_tool.to_string()),
|
||||
"default tool '{default_tool}' must be present in auto_approve even when user provides custom list"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Regression test: empty auto_approve still gets defaults merged.
|
||||
#[test]
|
||||
async fn auto_approve_empty_list_gets_defaults() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
|
||||
[autonomy]
|
||||
auto_approve = []
|
||||
"#;
|
||||
let parsed = parse_test_config(raw);
|
||||
let defaults = default_auto_approve();
|
||||
for tool in &defaults {
|
||||
assert!(
|
||||
parsed.autonomy.auto_approve.contains(tool),
|
||||
"default tool '{tool}' must be present even when user sets auto_approve = []"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// When no autonomy section is provided, defaults are applied normally.
|
||||
#[test]
|
||||
async fn auto_approve_defaults_when_no_autonomy_section() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
"#;
|
||||
let parsed = parse_test_config(raw);
|
||||
let defaults = default_auto_approve();
|
||||
for tool in &defaults {
|
||||
assert!(
|
||||
parsed.autonomy.auto_approve.contains(tool),
|
||||
"default tool '{tool}' must be present when no [autonomy] section"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Duplicates are not introduced when ensure_default_auto_approve runs
|
||||
/// on a list that already contains the defaults.
|
||||
#[test]
|
||||
async fn auto_approve_no_duplicates() {
|
||||
let raw = r#"
|
||||
default_temperature = 0.7
|
||||
|
||||
[autonomy]
|
||||
auto_approve = ["weather", "file_read"]
|
||||
"#;
|
||||
let parsed = parse_test_config(raw);
|
||||
let weather_count = parsed
|
||||
.autonomy
|
||||
.auto_approve
|
||||
.iter()
|
||||
.filter(|t| *t == "weather")
|
||||
.count();
|
||||
assert_eq!(weather_count, 1, "weather must not be duplicated");
|
||||
let file_read_count = parsed
|
||||
.autonomy
|
||||
.auto_approve
|
||||
.iter()
|
||||
.filter(|t| *t == "file_read")
|
||||
.count();
|
||||
assert_eq!(file_read_count, 1, "file_read must not be duplicated");
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn provider_timeout_secs_parses_from_toml() {
|
||||
let raw = r#"
|
||||
@@ -10345,6 +10661,7 @@ default_temperature = 0.7
|
||||
http_request: HttpRequestConfig::default(),
|
||||
multimodal: MultimodalConfig::default(),
|
||||
web_fetch: WebFetchConfig::default(),
|
||||
link_enricher: LinkEnricherConfig::default(),
|
||||
text_browser: TextBrowserConfig::default(),
|
||||
web_search: WebSearchConfig::default(),
|
||||
project_intel: ProjectIntelConfig::default(),
|
||||
@@ -10657,6 +10974,7 @@ default_temperature = 0.7
|
||||
device_id: Some("DEVICE123".into()),
|
||||
room_id: "!room123:matrix.org".into(),
|
||||
allowed_users: vec!["@user:matrix.org".into()],
|
||||
allowed_rooms: vec![],
|
||||
interrupt_on_new_message: false,
|
||||
};
|
||||
let json = serde_json::to_string(&mc).unwrap();
|
||||
@@ -10678,6 +10996,7 @@ default_temperature = 0.7
|
||||
device_id: None,
|
||||
room_id: "!abc:synapse.local".into(),
|
||||
allowed_users: vec!["@admin:synapse.local".into(), "*".into()],
|
||||
allowed_rooms: vec![],
|
||||
interrupt_on_new_message: false,
|
||||
};
|
||||
let toml_str = toml::to_string(&mc).unwrap();
|
||||
@@ -10771,6 +11090,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
device_id: None,
|
||||
room_id: "!r:m".into(),
|
||||
allowed_users: vec!["@u:m".into()],
|
||||
allowed_rooms: vec![],
|
||||
interrupt_on_new_message: false,
|
||||
}),
|
||||
signal: None,
|
||||
@@ -11366,15 +11686,15 @@ default_temperature = 0.7
|
||||
assert!(!c.composio.enabled);
|
||||
assert!(c.composio.api_key.is_none());
|
||||
assert!(c.secrets.encrypt);
|
||||
assert!(!c.browser.enabled);
|
||||
assert!(c.browser.allowed_domains.is_empty());
|
||||
assert!(c.browser.enabled);
|
||||
assert_eq!(c.browser.allowed_domains, vec!["*".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn browser_config_default_disabled() {
|
||||
async fn browser_config_default_enabled() {
|
||||
let b = BrowserConfig::default();
|
||||
assert!(!b.enabled);
|
||||
assert!(b.allowed_domains.is_empty());
|
||||
assert!(b.enabled);
|
||||
assert_eq!(b.allowed_domains, vec!["*".to_string()]);
|
||||
assert_eq!(b.backend, "agent_browser");
|
||||
assert!(b.native_headless);
|
||||
assert_eq!(b.native_webdriver_url, "http://127.0.0.1:9515");
|
||||
@@ -11439,8 +11759,8 @@ config_path = "/tmp/config.toml"
|
||||
default_temperature = 0.7
|
||||
"#;
|
||||
let parsed = parse_test_config(minimal);
|
||||
assert!(!parsed.browser.enabled);
|
||||
assert!(parsed.browser.allowed_domains.is_empty());
|
||||
assert!(parsed.browser.enabled);
|
||||
assert_eq!(parsed.browser.allowed_domains, vec!["*".to_string()]);
|
||||
}
|
||||
|
||||
// ── Environment variable overrides (Docker support) ─────────
|
||||
|
||||
+1
-1
@@ -15,7 +15,7 @@ pub use schedule::{
|
||||
#[allow(unused_imports)]
|
||||
pub use store::{
|
||||
add_agent_job, all_overdue_jobs, due_jobs, get_job, list_jobs, list_runs, record_last_run,
|
||||
record_run, remove_job, reschedule_after_run, update_job,
|
||||
record_run, remove_job, reschedule_after_run, sync_declarative_jobs, update_job,
|
||||
};
|
||||
pub use types::{
|
||||
deserialize_maybe_stringified, CronJob, CronJobPatch, CronRun, DeliveryConfig, JobType,
|
||||
|
||||
+48
-2
@@ -1,5 +1,7 @@
|
||||
#[cfg(feature = "channel-matrix")]
|
||||
use crate::channels::MatrixChannel;
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
use crate::channels::WhatsAppWebChannel;
|
||||
use crate::channels::{
|
||||
Channel, DiscordChannel, MattermostChannel, QQChannel, SendMessage, SignalChannel,
|
||||
SlackChannel, TelegramChannel,
|
||||
@@ -7,8 +9,8 @@ use crate::channels::{
|
||||
use crate::config::Config;
|
||||
use crate::cron::{
|
||||
all_overdue_jobs, due_jobs, next_run_for_schedule, record_last_run, record_run, remove_job,
|
||||
reschedule_after_run, update_job, CronJob, CronJobPatch, DeliveryConfig, JobType, Schedule,
|
||||
SessionTarget,
|
||||
reschedule_after_run, sync_declarative_jobs, update_job, CronJob, CronJobPatch, DeliveryConfig,
|
||||
JobType, Schedule, SessionTarget,
|
||||
};
|
||||
use crate::security::SecurityPolicy;
|
||||
use anyhow::Result;
|
||||
@@ -34,6 +36,19 @@ pub async fn run(config: Config) -> Result<()> {
|
||||
|
||||
crate::health::mark_component_ok(SCHEDULER_COMPONENT);
|
||||
|
||||
// ── Declarative job sync: reconcile config-defined jobs with the DB.
|
||||
match sync_declarative_jobs(&config, &config.cron.jobs) {
|
||||
Ok(()) => {
|
||||
if !config.cron.jobs.is_empty() {
|
||||
tracing::info!(
|
||||
count = config.cron.jobs.len(),
|
||||
"Synced declarative cron jobs from config"
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::warn!("Failed to sync declarative cron jobs: {e}"),
|
||||
}
|
||||
|
||||
// ── Startup catch-up: run ALL overdue jobs before entering the
|
||||
// normal polling loop. The regular loop is capped by `max_tasks`,
|
||||
// which could leave some overdue jobs waiting across many cycles
|
||||
@@ -483,6 +498,36 @@ pub(crate) async fn deliver_announcement(
|
||||
anyhow::bail!("matrix delivery channel requires `channel-matrix` feature");
|
||||
}
|
||||
}
|
||||
"whatsapp" | "whatsapp-web" | "whatsapp_web" => {
|
||||
#[cfg(feature = "whatsapp-web")]
|
||||
{
|
||||
let wa = config
|
||||
.channels_config
|
||||
.whatsapp
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("whatsapp channel not configured"))?;
|
||||
if !wa.is_web_config() {
|
||||
anyhow::bail!(
|
||||
"whatsapp cron delivery requires Web mode (session_path must be set)"
|
||||
);
|
||||
}
|
||||
let channel = WhatsAppWebChannel::new(
|
||||
wa.session_path.clone().unwrap_or_default(),
|
||||
wa.pair_phone.clone(),
|
||||
wa.pair_code.clone(),
|
||||
wa.allowed_numbers.clone(),
|
||||
wa.mode.clone(),
|
||||
wa.dm_policy.clone(),
|
||||
wa.group_policy.clone(),
|
||||
wa.self_chat_mode,
|
||||
);
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
}
|
||||
#[cfg(not(feature = "whatsapp-web"))]
|
||||
{
|
||||
anyhow::bail!("whatsapp delivery channel requires `whatsapp-web` feature");
|
||||
}
|
||||
}
|
||||
"qq" => {
|
||||
let qq = config
|
||||
.channels_config
|
||||
@@ -657,6 +702,7 @@ mod tests {
|
||||
delivery: DeliveryConfig::default(),
|
||||
delete_after_run: false,
|
||||
allowed_tools: None,
|
||||
source: "imperative".into(),
|
||||
created_at: Utc::now(),
|
||||
next_run: Utc::now(),
|
||||
last_run: None,
|
||||
|
||||
+521
-4
@@ -124,7 +124,7 @@ pub fn list_jobs(config: &Config) -> Result<Vec<CronJob>> {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
|
||||
allowed_tools
|
||||
allowed_tools, source
|
||||
FROM cron_jobs ORDER BY next_run ASC",
|
||||
)?;
|
||||
|
||||
@@ -143,7 +143,7 @@ pub fn get_job(config: &Config, job_id: &str) -> Result<CronJob> {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
|
||||
allowed_tools
|
||||
allowed_tools, source
|
||||
FROM cron_jobs WHERE id = ?1",
|
||||
)?;
|
||||
|
||||
@@ -177,7 +177,7 @@ pub fn due_jobs(config: &Config, now: DateTime<Utc>) -> Result<Vec<CronJob>> {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
|
||||
allowed_tools
|
||||
allowed_tools, source
|
||||
FROM cron_jobs
|
||||
WHERE enabled = 1 AND next_run <= ?1
|
||||
ORDER BY next_run ASC
|
||||
@@ -206,7 +206,8 @@ pub fn all_overdue_jobs(config: &Config, now: DateTime<Utc>) -> Result<Vec<CronJ
|
||||
with_connection(config, |conn| {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output, allowed_tools
|
||||
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
|
||||
allowed_tools, source
|
||||
FROM cron_jobs
|
||||
WHERE enabled = 1 AND next_run <= ?1
|
||||
ORDER BY next_run ASC",
|
||||
@@ -488,6 +489,7 @@ fn map_cron_job_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<CronJob> {
|
||||
let last_run_raw: Option<String> = row.get(14)?;
|
||||
let created_at_raw: String = row.get(12)?;
|
||||
let allowed_tools_raw: Option<String> = row.get(17)?;
|
||||
let source: Option<String> = row.get(18)?;
|
||||
|
||||
Ok(CronJob {
|
||||
id: row.get(0)?,
|
||||
@@ -502,6 +504,7 @@ fn map_cron_job_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<CronJob> {
|
||||
enabled: row.get::<_, i64>(9)? != 0,
|
||||
delivery,
|
||||
delete_after_run: row.get::<_, i64>(11)? != 0,
|
||||
source: source.unwrap_or_else(|| "imperative".to_string()),
|
||||
created_at: parse_rfc3339(&created_at_raw).map_err(sql_conversion_error)?,
|
||||
next_run: parse_rfc3339(&next_run_raw).map_err(sql_conversion_error)?,
|
||||
last_run: match last_run_raw {
|
||||
@@ -564,6 +567,277 @@ fn decode_allowed_tools(raw: Option<&str>) -> Result<Option<Vec<String>>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Synchronize declarative cron job definitions from config into the database.
|
||||
///
|
||||
/// For each declarative job (identified by `id`):
|
||||
/// - If the job exists in DB: update it to match the config definition.
|
||||
/// - If the job does not exist: insert it.
|
||||
///
|
||||
/// Jobs created imperatively (via CLI/API) are never modified or deleted.
|
||||
/// Declarative jobs that are no longer present in config are removed.
|
||||
pub fn sync_declarative_jobs(
|
||||
config: &Config,
|
||||
decls: &[crate::config::schema::CronJobDecl],
|
||||
) -> Result<()> {
|
||||
use crate::config::schema::CronScheduleDecl;
|
||||
|
||||
if decls.is_empty() {
|
||||
// If no declarative jobs are defined, clean up any previously
|
||||
// synced declarative jobs that are no longer in config.
|
||||
with_connection(config, |conn| {
|
||||
let deleted = conn
|
||||
.execute("DELETE FROM cron_jobs WHERE source = 'declarative'", [])
|
||||
.context("Failed to remove stale declarative cron jobs")?;
|
||||
if deleted > 0 {
|
||||
tracing::info!(
|
||||
count = deleted,
|
||||
"Removed declarative cron jobs no longer in config"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
})?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Validate declarations before touching the DB.
|
||||
for decl in decls {
|
||||
validate_decl(decl)?;
|
||||
}
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
with_connection(config, |conn| {
|
||||
// Collect IDs of all declarative jobs currently defined in config.
|
||||
let config_ids: std::collections::HashSet<&str> =
|
||||
decls.iter().map(|d| d.id.as_str()).collect();
|
||||
|
||||
// Remove declarative jobs no longer in config.
|
||||
{
|
||||
let mut stmt = conn.prepare("SELECT id FROM cron_jobs WHERE source = 'declarative'")?;
|
||||
let db_ids: Vec<String> = stmt
|
||||
.query_map([], |row| row.get(0))?
|
||||
.filter_map(|r| r.ok())
|
||||
.collect();
|
||||
|
||||
for db_id in &db_ids {
|
||||
if !config_ids.contains(db_id.as_str()) {
|
||||
conn.execute("DELETE FROM cron_jobs WHERE id = ?1", params![db_id])
|
||||
.with_context(|| {
|
||||
format!("Failed to remove stale declarative cron job '{db_id}'")
|
||||
})?;
|
||||
tracing::info!(
|
||||
job_id = %db_id,
|
||||
"Removed declarative cron job no longer in config"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for decl in decls {
|
||||
let schedule = convert_schedule_decl(&decl.schedule)?;
|
||||
let expression = schedule_cron_expression(&schedule).unwrap_or_default();
|
||||
let schedule_json = serde_json::to_string(&schedule)?;
|
||||
let job_type = &decl.job_type;
|
||||
let session_target = decl.session_target.as_deref().unwrap_or("isolated");
|
||||
let delivery = match &decl.delivery {
|
||||
Some(d) => convert_delivery_decl(d),
|
||||
None => DeliveryConfig::default(),
|
||||
};
|
||||
let delivery_json = serde_json::to_string(&delivery)?;
|
||||
let allowed_tools_json = encode_allowed_tools(decl.allowed_tools.as_ref())?;
|
||||
let command = decl.command.as_deref().unwrap_or("");
|
||||
let delete_after_run = matches!(decl.schedule, CronScheduleDecl::At { .. });
|
||||
|
||||
// Check if job already exists.
|
||||
let exists: bool = conn
|
||||
.prepare("SELECT COUNT(*) FROM cron_jobs WHERE id = ?1")?
|
||||
.query_row(params![decl.id], |row| row.get::<_, i64>(0))
|
||||
.map(|c| c > 0)
|
||||
.unwrap_or(false);
|
||||
|
||||
if exists {
|
||||
// Update existing declarative job — preserve runtime state
|
||||
// (next_run, last_run, last_status, last_output, created_at).
|
||||
// Only update the schedule's next_run if the schedule itself changed.
|
||||
let current_schedule_raw: Option<String> = conn
|
||||
.prepare("SELECT schedule FROM cron_jobs WHERE id = ?1")?
|
||||
.query_row(params![decl.id], |row| row.get(0))
|
||||
.ok();
|
||||
|
||||
let schedule_changed = current_schedule_raw.as_deref() != Some(&schedule_json);
|
||||
|
||||
if schedule_changed {
|
||||
let next_run = next_run_for_schedule(&schedule, now)?;
|
||||
conn.execute(
|
||||
"UPDATE cron_jobs
|
||||
SET expression = ?1, command = ?2, schedule = ?3, job_type = ?4,
|
||||
prompt = ?5, name = ?6, session_target = ?7, model = ?8,
|
||||
enabled = ?9, delivery = ?10, delete_after_run = ?11,
|
||||
allowed_tools = ?12, source = 'declarative', next_run = ?13
|
||||
WHERE id = ?14",
|
||||
params![
|
||||
expression,
|
||||
command,
|
||||
schedule_json,
|
||||
job_type,
|
||||
decl.prompt,
|
||||
decl.name,
|
||||
session_target,
|
||||
decl.model,
|
||||
if decl.enabled { 1 } else { 0 },
|
||||
delivery_json,
|
||||
if delete_after_run { 1 } else { 0 },
|
||||
allowed_tools_json,
|
||||
next_run.to_rfc3339(),
|
||||
decl.id,
|
||||
],
|
||||
)
|
||||
.with_context(|| {
|
||||
format!("Failed to update declarative cron job '{}'", decl.id)
|
||||
})?;
|
||||
} else {
|
||||
conn.execute(
|
||||
"UPDATE cron_jobs
|
||||
SET expression = ?1, command = ?2, schedule = ?3, job_type = ?4,
|
||||
prompt = ?5, name = ?6, session_target = ?7, model = ?8,
|
||||
enabled = ?9, delivery = ?10, delete_after_run = ?11,
|
||||
allowed_tools = ?12, source = 'declarative'
|
||||
WHERE id = ?13",
|
||||
params![
|
||||
expression,
|
||||
command,
|
||||
schedule_json,
|
||||
job_type,
|
||||
decl.prompt,
|
||||
decl.name,
|
||||
session_target,
|
||||
decl.model,
|
||||
if decl.enabled { 1 } else { 0 },
|
||||
delivery_json,
|
||||
if delete_after_run { 1 } else { 0 },
|
||||
allowed_tools_json,
|
||||
decl.id,
|
||||
],
|
||||
)
|
||||
.with_context(|| {
|
||||
format!("Failed to update declarative cron job '{}'", decl.id)
|
||||
})?;
|
||||
}
|
||||
|
||||
tracing::debug!(job_id = %decl.id, "Updated declarative cron job");
|
||||
} else {
|
||||
// Insert new declarative job.
|
||||
let next_run = next_run_for_schedule(&schedule, now)?;
|
||||
conn.execute(
|
||||
"INSERT INTO cron_jobs (
|
||||
id, expression, command, schedule, job_type, prompt, name,
|
||||
session_target, model, enabled, delivery, delete_after_run,
|
||||
allowed_tools, source, created_at, next_run
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, 'declarative', ?14, ?15)",
|
||||
params![
|
||||
decl.id,
|
||||
expression,
|
||||
command,
|
||||
schedule_json,
|
||||
job_type,
|
||||
decl.prompt,
|
||||
decl.name,
|
||||
session_target,
|
||||
decl.model,
|
||||
if decl.enabled { 1 } else { 0 },
|
||||
delivery_json,
|
||||
if delete_after_run { 1 } else { 0 },
|
||||
allowed_tools_json,
|
||||
now.to_rfc3339(),
|
||||
next_run.to_rfc3339(),
|
||||
],
|
||||
)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to insert declarative cron job '{}'",
|
||||
decl.id
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!(job_id = %decl.id, "Inserted declarative cron job from config");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
/// Validate a declarative cron job definition.
|
||||
fn validate_decl(decl: &crate::config::schema::CronJobDecl) -> Result<()> {
|
||||
if decl.id.trim().is_empty() {
|
||||
anyhow::bail!("Declarative cron job has empty id");
|
||||
}
|
||||
|
||||
match decl.job_type.to_lowercase().as_str() {
|
||||
"shell" => {
|
||||
if decl
|
||||
.command
|
||||
.as_deref()
|
||||
.map_or(true, |c| c.trim().is_empty())
|
||||
{
|
||||
anyhow::bail!(
|
||||
"Declarative cron job '{}': shell job requires a non-empty 'command'",
|
||||
decl.id
|
||||
);
|
||||
}
|
||||
}
|
||||
"agent" => {
|
||||
if decl.prompt.as_deref().map_or(true, |p| p.trim().is_empty()) {
|
||||
anyhow::bail!(
|
||||
"Declarative cron job '{}': agent job requires a non-empty 'prompt'",
|
||||
decl.id
|
||||
);
|
||||
}
|
||||
}
|
||||
other => {
|
||||
anyhow::bail!(
|
||||
"Declarative cron job '{}': invalid job_type '{}', expected 'shell' or 'agent'",
|
||||
decl.id,
|
||||
other
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert a `CronScheduleDecl` to the runtime `Schedule` type.
|
||||
fn convert_schedule_decl(decl: &crate::config::schema::CronScheduleDecl) -> Result<Schedule> {
|
||||
use crate::config::schema::CronScheduleDecl;
|
||||
match decl {
|
||||
CronScheduleDecl::Cron { expr, tz } => Ok(Schedule::Cron {
|
||||
expr: expr.clone(),
|
||||
tz: tz.clone(),
|
||||
}),
|
||||
CronScheduleDecl::Every { every_ms } => Ok(Schedule::Every {
|
||||
every_ms: *every_ms,
|
||||
}),
|
||||
CronScheduleDecl::At { at } => {
|
||||
let parsed = DateTime::parse_from_rfc3339(at)
|
||||
.with_context(|| {
|
||||
format!("Invalid RFC3339 timestamp in declarative cron 'at': {at}")
|
||||
})?
|
||||
.with_timezone(&Utc);
|
||||
Ok(Schedule::At { at: parsed })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a `DeliveryConfigDecl` to the runtime `DeliveryConfig`.
|
||||
fn convert_delivery_decl(decl: &crate::config::schema::DeliveryConfigDecl) -> DeliveryConfig {
|
||||
DeliveryConfig {
|
||||
mode: decl.mode.clone(),
|
||||
channel: decl.channel.clone(),
|
||||
to: decl.to.clone(),
|
||||
best_effort: decl.best_effort,
|
||||
}
|
||||
}
|
||||
|
||||
fn add_column_if_missing(conn: &Connection, name: &str, sql_type: &str) -> Result<()> {
|
||||
let mut stmt = conn.prepare("PRAGMA table_info(cron_jobs)")?;
|
||||
let mut rows = stmt.query([])?;
|
||||
@@ -654,6 +928,7 @@ fn with_connection<T>(config: &Config, f: impl FnOnce(&Connection) -> Result<T>)
|
||||
add_column_if_missing(&conn, "delivery", "TEXT")?;
|
||||
add_column_if_missing(&conn, "delete_after_run", "INTEGER NOT NULL DEFAULT 0")?;
|
||||
add_column_if_missing(&conn, "allowed_tools", "TEXT")?;
|
||||
add_column_if_missing(&conn, "source", "TEXT DEFAULT 'imperative'")?;
|
||||
|
||||
f(&conn)
|
||||
}
|
||||
@@ -1170,4 +1445,246 @@ mod tests {
|
||||
assert!(last_output.ends_with(TRUNCATED_OUTPUT_MARKER));
|
||||
assert!(last_output.len() <= MAX_CRON_OUTPUT_BYTES);
|
||||
}
|
||||
|
||||
// ── Declarative cron job sync tests ──────────────────────────
|
||||
|
||||
fn make_shell_decl(id: &str, expr: &str, cmd: &str) -> crate::config::schema::CronJobDecl {
|
||||
crate::config::schema::CronJobDecl {
|
||||
id: id.to_string(),
|
||||
name: Some(format!("decl-{id}")),
|
||||
job_type: "shell".to_string(),
|
||||
schedule: crate::config::schema::CronScheduleDecl::Cron {
|
||||
expr: expr.to_string(),
|
||||
tz: None,
|
||||
},
|
||||
command: Some(cmd.to_string()),
|
||||
prompt: None,
|
||||
enabled: true,
|
||||
model: None,
|
||||
allowed_tools: None,
|
||||
session_target: None,
|
||||
delivery: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_agent_decl(id: &str, expr: &str, prompt: &str) -> crate::config::schema::CronJobDecl {
|
||||
crate::config::schema::CronJobDecl {
|
||||
id: id.to_string(),
|
||||
name: Some(format!("decl-{id}")),
|
||||
job_type: "agent".to_string(),
|
||||
schedule: crate::config::schema::CronScheduleDecl::Cron {
|
||||
expr: expr.to_string(),
|
||||
tz: None,
|
||||
},
|
||||
command: None,
|
||||
prompt: Some(prompt.to_string()),
|
||||
enabled: true,
|
||||
model: None,
|
||||
allowed_tools: None,
|
||||
session_target: None,
|
||||
delivery: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_inserts_new_declarative_job() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decls = vec![make_shell_decl("daily-backup", "0 2 * * *", "echo backup")];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
let job = get_job(&config, "daily-backup").unwrap();
|
||||
assert_eq!(job.command, "echo backup");
|
||||
assert_eq!(job.source, "declarative");
|
||||
assert_eq!(job.name.as_deref(), Some("decl-daily-backup"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_updates_existing_declarative_job() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decls = vec![make_shell_decl("updatable", "0 2 * * *", "echo v1")];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
let job_v1 = get_job(&config, "updatable").unwrap();
|
||||
assert_eq!(job_v1.command, "echo v1");
|
||||
|
||||
let decls_v2 = vec![make_shell_decl("updatable", "0 3 * * *", "echo v2")];
|
||||
sync_declarative_jobs(&config, &decls_v2).unwrap();
|
||||
|
||||
let job_v2 = get_job(&config, "updatable").unwrap();
|
||||
assert_eq!(job_v2.command, "echo v2");
|
||||
assert_eq!(job_v2.expression, "0 3 * * *");
|
||||
assert_eq!(job_v2.source, "declarative");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_does_not_delete_imperative_jobs() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
// Create an imperative job via the normal API.
|
||||
let imperative = add_job(&config, "*/10 * * * *", "echo imperative").unwrap();
|
||||
|
||||
// Sync declarative jobs (none of which match the imperative job).
|
||||
let decls = vec![make_shell_decl("my-decl", "0 2 * * *", "echo decl")];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
// Imperative job should still exist.
|
||||
let still_there = get_job(&config, &imperative.id).unwrap();
|
||||
assert_eq!(still_there.command, "echo imperative");
|
||||
assert_eq!(still_there.source, "imperative");
|
||||
|
||||
// Declarative job should also exist.
|
||||
let decl_job = get_job(&config, "my-decl").unwrap();
|
||||
assert_eq!(decl_job.command, "echo decl");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_removes_stale_declarative_jobs() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
// Insert two declarative jobs.
|
||||
let decls = vec![
|
||||
make_shell_decl("keeper", "0 2 * * *", "echo keep"),
|
||||
make_shell_decl("stale", "0 3 * * *", "echo stale"),
|
||||
];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
// Now sync with only "keeper" — "stale" should be removed.
|
||||
let decls_v2 = vec![make_shell_decl("keeper", "0 2 * * *", "echo keep")];
|
||||
sync_declarative_jobs(&config, &decls_v2).unwrap();
|
||||
|
||||
assert!(get_job(&config, "stale").is_err());
|
||||
assert!(get_job(&config, "keeper").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_empty_removes_all_declarative_jobs() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decls = vec![make_shell_decl("to-remove", "0 2 * * *", "echo bye")];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
assert!(get_job(&config, "to-remove").is_ok());
|
||||
|
||||
// Sync with empty list.
|
||||
sync_declarative_jobs(&config, &[]).unwrap();
|
||||
assert!(get_job(&config, "to-remove").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_validates_shell_job_requires_command() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let mut decl = make_shell_decl("bad", "0 2 * * *", "echo ok");
|
||||
decl.command = None;
|
||||
|
||||
let result = sync_declarative_jobs(&config, &[decl]);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("command"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_validates_agent_job_requires_prompt() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let mut decl = make_agent_decl("bad-agent", "0 2 * * *", "do stuff");
|
||||
decl.prompt = None;
|
||||
|
||||
let result = sync_declarative_jobs(&config, &[decl]);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("prompt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_agent_job_inserts_correctly() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decls = vec![make_agent_decl(
|
||||
"agent-check",
|
||||
"*/15 * * * *",
|
||||
"check health",
|
||||
)];
|
||||
sync_declarative_jobs(&config, &decls).unwrap();
|
||||
|
||||
let job = get_job(&config, "agent-check").unwrap();
|
||||
assert_eq!(job.job_type, JobType::Agent);
|
||||
assert_eq!(job.prompt.as_deref(), Some("check health"));
|
||||
assert_eq!(job.source, "declarative");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sync_every_schedule_works() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let config = test_config(&tmp);
|
||||
|
||||
let decl = crate::config::schema::CronJobDecl {
|
||||
id: "interval-job".to_string(),
|
||||
name: None,
|
||||
job_type: "shell".to_string(),
|
||||
schedule: crate::config::schema::CronScheduleDecl::Every { every_ms: 60000 },
|
||||
command: Some("echo interval".to_string()),
|
||||
prompt: None,
|
||||
enabled: true,
|
||||
model: None,
|
||||
allowed_tools: None,
|
||||
session_target: None,
|
||||
delivery: None,
|
||||
};
|
||||
|
||||
sync_declarative_jobs(&config, &[decl]).unwrap();
|
||||
|
||||
let job = get_job(&config, "interval-job").unwrap();
|
||||
assert!(matches!(job.schedule, Schedule::Every { every_ms: 60000 }));
|
||||
assert_eq!(job.command, "echo interval");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn declarative_config_parses_from_toml() {
|
||||
let toml_str = r#"
|
||||
enabled = true
|
||||
|
||||
[[jobs]]
|
||||
id = "daily-report"
|
||||
name = "Daily Report"
|
||||
job_type = "shell"
|
||||
command = "echo report"
|
||||
schedule = { kind = "cron", expr = "0 9 * * *" }
|
||||
|
||||
[[jobs]]
|
||||
id = "health-check"
|
||||
job_type = "agent"
|
||||
prompt = "Check server health"
|
||||
schedule = { kind = "every", every_ms = 300000 }
|
||||
"#;
|
||||
|
||||
let parsed: crate::config::schema::CronConfig = toml::from_str(toml_str).unwrap();
|
||||
assert!(parsed.enabled);
|
||||
assert_eq!(parsed.jobs.len(), 2);
|
||||
|
||||
assert_eq!(parsed.jobs[0].id, "daily-report");
|
||||
assert_eq!(parsed.jobs[0].command.as_deref(), Some("echo report"));
|
||||
assert!(matches!(
|
||||
parsed.jobs[0].schedule,
|
||||
crate::config::schema::CronScheduleDecl::Cron { ref expr, .. } if expr == "0 9 * * *"
|
||||
));
|
||||
|
||||
assert_eq!(parsed.jobs[1].id, "health-check");
|
||||
assert_eq!(parsed.jobs[1].job_type, "agent");
|
||||
assert_eq!(
|
||||
parsed.jobs[1].prompt.as_deref(),
|
||||
Some("Check server health")
|
||||
);
|
||||
assert!(matches!(
|
||||
parsed.jobs[1].schedule,
|
||||
crate::config::schema::CronScheduleDecl::Every { every_ms: 300_000 }
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,6 +127,10 @@ fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_source() -> String {
|
||||
"imperative".to_string()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CronJob {
|
||||
pub id: String,
|
||||
@@ -146,6 +150,9 @@ pub struct CronJob {
|
||||
/// When `None`, all tools are available (backward compatible default).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
/// How the job was created: `"imperative"` (CLI/API) or `"declarative"` (config).
|
||||
#[serde(default = "default_source")]
|
||||
pub source: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub next_run: DateTime<Utc>,
|
||||
pub last_run: Option<DateTime<Utc>>,
|
||||
|
||||
+56
-2
@@ -1280,12 +1280,16 @@ pub async fn handle_api_sessions_list(
|
||||
.into_iter()
|
||||
.filter_map(|meta| {
|
||||
let session_id = meta.key.strip_prefix("gw_")?;
|
||||
Some(serde_json::json!({
|
||||
let mut entry = serde_json::json!({
|
||||
"session_id": session_id,
|
||||
"created_at": meta.created_at.to_rfc3339(),
|
||||
"last_activity": meta.last_activity.to_rfc3339(),
|
||||
"message_count": meta.message_count,
|
||||
}))
|
||||
});
|
||||
if let Some(name) = meta.name {
|
||||
entry["name"] = serde_json::Value::String(name);
|
||||
}
|
||||
Some(entry)
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -1326,6 +1330,56 @@ pub async fn handle_api_session_delete(
|
||||
}
|
||||
}
|
||||
|
||||
/// PUT /api/sessions/{id} — rename a gateway session
|
||||
pub async fn handle_api_session_rename(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
Json(body): Json<serde_json::Value>,
|
||||
) -> impl IntoResponse {
|
||||
if let Err(e) = require_auth(&state, &headers) {
|
||||
return e.into_response();
|
||||
}
|
||||
|
||||
let Some(ref backend) = state.session_backend else {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({"error": "Session persistence is disabled"})),
|
||||
)
|
||||
.into_response();
|
||||
};
|
||||
|
||||
let name = body["name"].as_str().unwrap_or("").trim();
|
||||
if name.is_empty() {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({"error": "name is required"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let session_key = format!("gw_{id}");
|
||||
|
||||
// Verify the session exists before renaming
|
||||
let sessions = backend.list_sessions();
|
||||
if !sessions.contains(&session_key) {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({"error": "Session not found"})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
match backend.set_session_name(&session_key, name) {
|
||||
Ok(()) => Json(serde_json::json!({"session_id": id, "name": name})).into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({"error": format!("Failed to rename session: {e}")})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
+1
-1
@@ -886,7 +886,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||
.route("/api/cli-tools", get(api::handle_api_cli_tools))
|
||||
.route("/api/health", get(api::handle_api_health))
|
||||
.route("/api/sessions", get(api::handle_api_sessions_list))
|
||||
.route("/api/sessions/{id}", delete(api::handle_api_session_delete))
|
||||
.route("/api/sessions/{id}", delete(api::handle_api_session_delete).put(api::handle_api_session_rename))
|
||||
// ── Pairing + Device management API ──
|
||||
.route("/api/pairing/initiate", post(api_pairing::initiate_pairing))
|
||||
.route("/api/pair", post(api_pairing::submit_pairing_enhanced))
|
||||
|
||||
+34
-3
@@ -1,13 +1,21 @@
|
||||
//! WebSocket agent chat handler.
|
||||
//!
|
||||
//! Connect: `ws://host:port/ws/chat?session_id=ID&name=My+Session`
|
||||
//!
|
||||
//! Protocol:
|
||||
//! ```text
|
||||
//! Server -> Client: {"type":"session_start","session_id":"...","name":"...","resumed":true,"message_count":42}
|
||||
//! Client -> Server: {"type":"message","content":"Hello"}
|
||||
//! Server -> Client: {"type":"chunk","content":"Hi! "}
|
||||
//! Server -> Client: {"type":"tool_call","name":"shell","args":{...}}
|
||||
//! Server -> Client: {"type":"tool_result","name":"shell","output":"..."}
|
||||
//! Server -> Client: {"type":"done","full_response":"..."}
|
||||
//! ```
|
||||
//!
|
||||
//! Query params:
|
||||
//! - `session_id` — resume or create a session (default: new UUID)
|
||||
//! - `name` — optional human-readable label for the session
|
||||
//! - `token` — bearer auth token (alternative to Authorization header)
|
||||
|
||||
use super::AppState;
|
||||
use axum::{
|
||||
@@ -53,6 +61,8 @@ const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
|
||||
pub struct WsQuery {
|
||||
pub token: Option<String>,
|
||||
pub session_id: Option<String>,
|
||||
/// Optional human-readable name for the session.
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
/// Extract a bearer token from WebSocket-compatible sources.
|
||||
@@ -134,14 +144,20 @@ pub async fn handle_ws_chat(
|
||||
};
|
||||
|
||||
let session_id = params.session_id;
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state, session_id))
|
||||
let session_name = params.name;
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state, session_id, session_name))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Gateway session key prefix to avoid collisions with channel sessions.
|
||||
const GW_SESSION_PREFIX: &str = "gw_";
|
||||
|
||||
async fn handle_socket(socket: WebSocket, state: AppState, session_id: Option<String>) {
|
||||
async fn handle_socket(
|
||||
socket: WebSocket,
|
||||
state: AppState,
|
||||
session_id: Option<String>,
|
||||
session_name: Option<String>,
|
||||
) {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
// Resolve session ID: use provided or generate a new UUID
|
||||
@@ -163,6 +179,7 @@ async fn handle_socket(socket: WebSocket, state: AppState, session_id: Option<St
|
||||
// Hydrate agent from persisted session (if available)
|
||||
let mut resumed = false;
|
||||
let mut message_count: usize = 0;
|
||||
let mut effective_name: Option<String> = None;
|
||||
if let Some(ref backend) = state.session_backend {
|
||||
let messages = backend.load(&session_key);
|
||||
if !messages.is_empty() {
|
||||
@@ -170,15 +187,29 @@ async fn handle_socket(socket: WebSocket, state: AppState, session_id: Option<St
|
||||
agent.seed_history(&messages);
|
||||
resumed = true;
|
||||
}
|
||||
// Set session name if provided (non-empty) on connect
|
||||
if let Some(ref name) = session_name {
|
||||
if !name.is_empty() {
|
||||
let _ = backend.set_session_name(&session_key, name);
|
||||
effective_name = Some(name.clone());
|
||||
}
|
||||
}
|
||||
// If no name was provided via query param, load the stored name
|
||||
if effective_name.is_none() {
|
||||
effective_name = backend.get_session_name(&session_key).unwrap_or(None);
|
||||
}
|
||||
}
|
||||
|
||||
// Send session_start message to client
|
||||
let session_start = serde_json::json!({
|
||||
let mut session_start = serde_json::json!({
|
||||
"type": "session_start",
|
||||
"session_id": session_id,
|
||||
"resumed": resumed,
|
||||
"message_count": message_count,
|
||||
});
|
||||
if let Some(ref name) = effective_name {
|
||||
session_start["name"] = serde_json::Value::String(name.clone());
|
||||
}
|
||||
let _ = sender
|
||||
.send(Message::Text(session_start.to_string().into()))
|
||||
.await;
|
||||
|
||||
@@ -891,6 +891,7 @@ mod tests {
|
||||
device_id: None,
|
||||
room_id: "!r:m".into(),
|
||||
allowed_users: vec![],
|
||||
allowed_rooms: vec![],
|
||||
interrupt_on_new_message: false,
|
||||
});
|
||||
let entries = all_integrations();
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
use super::traits::{MemoryCategory, MemoryEntry};
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
/// Default half-life in days for time-decay scoring.
|
||||
/// After this many days, a non-Core memory's score drops to 50%.
|
||||
pub const DEFAULT_HALF_LIFE_DAYS: f64 = 7.0;
|
||||
|
||||
/// Apply exponential time decay to memory entry scores.
|
||||
///
|
||||
/// - `Core` memories are exempt ("evergreen") — their scores are never decayed.
|
||||
/// - Entries without a parseable RFC3339 timestamp are left unchanged.
|
||||
/// - Entries without a score (`None`) are left unchanged.
|
||||
///
|
||||
/// Decay formula: `score * 2^(-age_days / half_life_days)`
|
||||
pub fn apply_time_decay(entries: &mut [MemoryEntry], half_life_days: f64) {
|
||||
let half_life = if half_life_days <= 0.0 {
|
||||
DEFAULT_HALF_LIFE_DAYS
|
||||
} else {
|
||||
half_life_days
|
||||
};
|
||||
|
||||
let now = Utc::now();
|
||||
|
||||
for entry in entries.iter_mut() {
|
||||
// Core memories are evergreen — never decay
|
||||
if entry.category == MemoryCategory::Core {
|
||||
continue;
|
||||
}
|
||||
|
||||
let score = match entry.score {
|
||||
Some(s) => s,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let ts = match DateTime::parse_from_rfc3339(&entry.timestamp) {
|
||||
Ok(dt) => dt.with_timezone(&Utc),
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let age_days = now.signed_duration_since(ts).num_seconds().max(0) as f64 / 86_400.0;
|
||||
|
||||
let decay_factor = (-age_days / half_life * std::f64::consts::LN_2).exp();
|
||||
entry.score = Some(score * decay_factor);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Build a minimal MemoryEntry fixture; only the fields decay inspects
    // (category, score, timestamp) vary per test, the rest are placeholders.
    fn make_entry(category: MemoryCategory, score: Option<f64>, timestamp: &str) -> MemoryEntry {
        MemoryEntry {
            id: "1".into(),
            key: "test".into(),
            content: "value".into(),
            category,
            timestamp: timestamp.into(),
            session_id: None,
            score,
            namespace: "default".into(),
            importance: None,
            superseded_by: None,
        }
    }

    // RFC3339 timestamp for "right now" — effectively zero age.
    fn recent_rfc3339() -> String {
        Utc::now().to_rfc3339()
    }

    // RFC3339 timestamp `days` days in the past.
    fn days_ago_rfc3339(days: i64) -> String {
        (Utc::now() - chrono::Duration::days(days)).to_rfc3339()
    }

    // Core entries must be exempt even when far older than the half-life.
    #[test]
    fn core_memories_are_never_decayed() {
        let mut entries = vec![make_entry(
            MemoryCategory::Core,
            Some(0.9),
            &days_ago_rfc3339(30),
        )];
        apply_time_decay(&mut entries, 7.0);
        assert_eq!(entries[0].score, Some(0.9));
    }

    // An entry with ~zero age should keep (almost exactly) its score.
    #[test]
    fn recent_entry_score_barely_changes() {
        let mut entries = vec![make_entry(
            MemoryCategory::Conversation,
            Some(0.8),
            &recent_rfc3339(),
        )];
        apply_time_decay(&mut entries, 7.0);
        let decayed = entries[0].score.unwrap();
        assert!(
            (decayed - 0.8).abs() < 0.01,
            "recent entry should barely decay, got {decayed}"
        );
    }

    // Age == half-life → score halves (within tolerance for clock skew).
    #[test]
    fn one_half_life_halves_score() {
        let mut entries = vec![make_entry(
            MemoryCategory::Conversation,
            Some(1.0),
            &days_ago_rfc3339(7),
        )];
        apply_time_decay(&mut entries, 7.0);
        let decayed = entries[0].score.unwrap();
        assert!(
            (decayed - 0.5).abs() < 0.05,
            "score after one half-life should be ~0.5, got {decayed}"
        );
    }

    // Age == 2 × half-life → score quarters.
    #[test]
    fn two_half_lives_quarters_score() {
        let mut entries = vec![make_entry(
            MemoryCategory::Conversation,
            Some(1.0),
            &days_ago_rfc3339(14),
        )];
        apply_time_decay(&mut entries, 7.0);
        let decayed = entries[0].score.unwrap();
        assert!(
            (decayed - 0.25).abs() < 0.05,
            "score after two half-lives should be ~0.25, got {decayed}"
        );
    }

    // `None` score must survive decay untouched (not become Some(0.0)).
    #[test]
    fn no_score_entry_is_unchanged() {
        let mut entries = vec![make_entry(
            MemoryCategory::Conversation,
            None,
            &days_ago_rfc3339(30),
        )];
        apply_time_decay(&mut entries, 7.0);
        assert_eq!(entries[0].score, None);
    }

    // Garbage timestamps are skipped rather than treated as age zero.
    #[test]
    fn unparseable_timestamp_is_unchanged() {
        let mut entries = vec![make_entry(
            MemoryCategory::Conversation,
            Some(0.9),
            "not-a-date",
        )];
        apply_time_decay(&mut entries, 7.0);
        assert_eq!(entries[0].score, Some(0.9));
    }
}
|
||||
@@ -4,6 +4,7 @@ pub mod chunker;
|
||||
pub mod cli;
|
||||
pub mod conflict;
|
||||
pub mod consolidation;
|
||||
pub mod decay;
|
||||
pub mod embeddings;
|
||||
pub mod hygiene;
|
||||
pub mod importance;
|
||||
|
||||
@@ -507,6 +507,7 @@ mod tests {
|
||||
max_images: 1,
|
||||
max_image_size_mb: 5,
|
||||
allow_remote_fetch: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let error = prepare_messages_for_provider(&messages, &config)
|
||||
@@ -549,6 +550,7 @@ mod tests {
|
||||
max_images: 4,
|
||||
max_image_size_mb: 1,
|
||||
allow_remote_fetch: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let error = prepare_messages_for_provider(&messages, &config)
|
||||
|
||||
@@ -173,6 +173,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
http_request: crate::config::HttpRequestConfig::default(),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
link_enricher: crate::config::LinkEnricherConfig::default(),
|
||||
text_browser: crate::config::TextBrowserConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
@@ -605,6 +606,7 @@ async fn run_quick_setup_with_home(
|
||||
http_request: crate::config::HttpRequestConfig::default(),
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
link_enricher: crate::config::LinkEnricherConfig::default(),
|
||||
text_browser: crate::config::TextBrowserConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
@@ -4193,6 +4195,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
device_id: detected_device_id,
|
||||
room_id,
|
||||
allowed_users,
|
||||
allowed_rooms: vec![],
|
||||
interrupt_on_new_message: false,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -106,8 +106,14 @@ impl SkillCreator {
|
||||
// Trim leading/trailing hyphens, then truncate.
|
||||
let trimmed = collapsed.trim_matches('-');
|
||||
if trimmed.len() > 64 {
|
||||
// Truncate at a hyphen boundary if possible.
|
||||
let truncated = &trimmed[..64];
|
||||
// Find the nearest valid character boundary at or before 64 bytes.
|
||||
let safe_index = trimmed
|
||||
.char_indices()
|
||||
.map(|(i, _)| i)
|
||||
.take_while(|&i| i <= 64)
|
||||
.last()
|
||||
.unwrap_or(0);
|
||||
let truncated = &trimmed[..safe_index];
|
||||
truncated.trim_end_matches('-').to_string()
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
|
||||
+89
-14
@@ -738,15 +738,47 @@ pub fn skills_to_prompt_with_mode(
|
||||
}
|
||||
|
||||
if !skill.tools.is_empty() {
|
||||
let _ = writeln!(prompt, " <tools>");
|
||||
for tool in &skill.tools {
|
||||
let _ = writeln!(prompt, " <tool>");
|
||||
write_xml_text_element(&mut prompt, 8, "name", &tool.name);
|
||||
write_xml_text_element(&mut prompt, 8, "description", &tool.description);
|
||||
write_xml_text_element(&mut prompt, 8, "kind", &tool.kind);
|
||||
let _ = writeln!(prompt, " </tool>");
|
||||
// Tools with known kinds (shell, script, http) are registered as
|
||||
// callable tool specs and can be invoked directly via function calling.
|
||||
// We note them here for context but mark them as callable.
|
||||
let registered: Vec<_> = skill
|
||||
.tools
|
||||
.iter()
|
||||
.filter(|t| matches!(t.kind.as_str(), "shell" | "script" | "http"))
|
||||
.collect();
|
||||
let unregistered: Vec<_> = skill
|
||||
.tools
|
||||
.iter()
|
||||
.filter(|t| !matches!(t.kind.as_str(), "shell" | "script" | "http"))
|
||||
.collect();
|
||||
|
||||
if !registered.is_empty() {
|
||||
let _ = writeln!(prompt, " <callable_tools hint=\"These are registered as callable tool specs. Invoke them directly by name ({{}}.{{}}) instead of using shell.\">");
|
||||
for tool in ®istered {
|
||||
let _ = writeln!(prompt, " <tool>");
|
||||
write_xml_text_element(
|
||||
&mut prompt,
|
||||
8,
|
||||
"name",
|
||||
&format!("{}.{}", skill.name, tool.name),
|
||||
);
|
||||
write_xml_text_element(&mut prompt, 8, "description", &tool.description);
|
||||
let _ = writeln!(prompt, " </tool>");
|
||||
}
|
||||
let _ = writeln!(prompt, " </callable_tools>");
|
||||
}
|
||||
|
||||
if !unregistered.is_empty() {
|
||||
let _ = writeln!(prompt, " <tools>");
|
||||
for tool in &unregistered {
|
||||
let _ = writeln!(prompt, " <tool>");
|
||||
write_xml_text_element(&mut prompt, 8, "name", &tool.name);
|
||||
write_xml_text_element(&mut prompt, 8, "description", &tool.description);
|
||||
write_xml_text_element(&mut prompt, 8, "kind", &tool.kind);
|
||||
let _ = writeln!(prompt, " </tool>");
|
||||
}
|
||||
let _ = writeln!(prompt, " </tools>");
|
||||
}
|
||||
let _ = writeln!(prompt, " </tools>");
|
||||
}
|
||||
|
||||
let _ = writeln!(prompt, " </skill>");
|
||||
@@ -756,6 +788,47 @@ pub fn skills_to_prompt_with_mode(
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Convert skill tools into callable `Tool` trait objects.
|
||||
///
|
||||
/// Each skill's `[[tools]]` entries are converted to either `SkillShellTool`
|
||||
/// (for `shell`/`script` kinds) or `SkillHttpTool` (for `http` kind),
|
||||
/// enabling them to appear as first-class callable tool specs rather than
|
||||
/// only as XML in the system prompt.
|
||||
pub fn skills_to_tools(
|
||||
skills: &[Skill],
|
||||
security: std::sync::Arc<crate::security::SecurityPolicy>,
|
||||
) -> Vec<Box<dyn crate::tools::traits::Tool>> {
|
||||
let mut tools: Vec<Box<dyn crate::tools::traits::Tool>> = Vec::new();
|
||||
for skill in skills {
|
||||
for tool in &skill.tools {
|
||||
match tool.kind.as_str() {
|
||||
"shell" | "script" => {
|
||||
tools.push(Box::new(crate::tools::skill_tool::SkillShellTool::new(
|
||||
&skill.name,
|
||||
tool,
|
||||
security.clone(),
|
||||
)));
|
||||
}
|
||||
"http" => {
|
||||
tools.push(Box::new(crate::tools::skill_http::SkillHttpTool::new(
|
||||
&skill.name,
|
||||
tool,
|
||||
)));
|
||||
}
|
||||
other => {
|
||||
tracing::warn!(
|
||||
"Unknown skill tool kind '{}' for {}.{}, skipping",
|
||||
other,
|
||||
skill.name,
|
||||
tool.name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tools
|
||||
}
|
||||
|
||||
/// Get the skills directory path
|
||||
pub fn skills_dir(workspace_dir: &Path) -> PathBuf {
|
||||
workspace_dir.join("skills")
|
||||
@@ -1517,10 +1590,10 @@ command = "echo hello"
|
||||
assert!(prompt.contains("read_skill(name)"));
|
||||
assert!(!prompt.contains("<instructions>"));
|
||||
assert!(!prompt.contains("<instruction>Do the thing.</instruction>"));
|
||||
// Compact mode should still include tools so the LLM knows about them
|
||||
assert!(prompt.contains("<tools>"));
|
||||
assert!(prompt.contains("<name>run</name>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
// Compact mode should still include tools so the LLM knows about them.
|
||||
// Registered tools (shell/script/http) appear under <callable_tools>.
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>test.run</name>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1710,9 +1783,11 @@ description = "Bare minimum"
|
||||
}];
|
||||
let prompt = skills_to_prompt(&skills, Path::new("/tmp"));
|
||||
assert!(prompt.contains("weather"));
|
||||
assert!(prompt.contains("<name>get_weather</name>"));
|
||||
// Registered tools (shell kind) now appear under <callable_tools> with
|
||||
// prefixed names (skill_name.tool_name).
|
||||
assert!(prompt.contains("<callable_tools"));
|
||||
assert!(prompt.contains("<name>weather.get_weather</name>"));
|
||||
assert!(prompt.contains("<description>Fetch forecast</description>"));
|
||||
assert!(prompt.contains("<kind>shell</kind>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -81,6 +81,8 @@ pub mod screenshot;
|
||||
pub mod security_ops;
|
||||
pub mod sessions;
|
||||
pub mod shell;
|
||||
pub mod skill_http;
|
||||
pub mod skill_tool;
|
||||
pub mod swarm;
|
||||
pub mod text_browser;
|
||||
pub mod tool_search;
|
||||
@@ -156,6 +158,10 @@ pub use screenshot::ScreenshotTool;
|
||||
pub use security_ops::SecurityOpsTool;
|
||||
pub use sessions::{SessionsHistoryTool, SessionsListTool, SessionsSendTool};
|
||||
pub use shell::ShellTool;
|
||||
#[allow(unused_imports)]
|
||||
pub use skill_http::SkillHttpTool;
|
||||
#[allow(unused_imports)]
|
||||
pub use skill_tool::SkillShellTool;
|
||||
pub use swarm::SwarmTool;
|
||||
pub use text_browser::TextBrowserTool;
|
||||
pub use tool_search::ToolSearchTool;
|
||||
@@ -257,6 +263,33 @@ pub fn default_tools_with_runtime(
|
||||
]
|
||||
}
|
||||
|
||||
/// Register skill-defined tools into an existing tool registry.
|
||||
///
|
||||
/// Converts each skill's `[[tools]]` entries into callable `Tool` implementations
|
||||
/// and appends them to the registry. Skill tools that would shadow a built-in tool
|
||||
/// name are skipped with a warning.
|
||||
pub fn register_skill_tools(
|
||||
tools_registry: &mut Vec<Box<dyn Tool>>,
|
||||
skills: &[crate::skills::Skill],
|
||||
security: Arc<SecurityPolicy>,
|
||||
) {
|
||||
let skill_tools = crate::skills::skills_to_tools(skills, security);
|
||||
let existing_names: std::collections::HashSet<String> = tools_registry
|
||||
.iter()
|
||||
.map(|t| t.name().to_string())
|
||||
.collect();
|
||||
for tool in skill_tools {
|
||||
if existing_names.contains(tool.name()) {
|
||||
tracing::warn!(
|
||||
"Skill tool '{}' shadows built-in tool, skipping",
|
||||
tool.name()
|
||||
);
|
||||
} else {
|
||||
tools_registry.push(tool);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create full tool registry including memory tools and optional Composio
|
||||
#[allow(clippy::implicit_hasher, clippy::too_many_arguments)]
|
||||
pub fn all_tools(
|
||||
@@ -458,6 +491,7 @@ pub fn all_tools_with_runtime(
|
||||
tool_arcs.push(Arc::new(WebSearchTool::new_with_config(
|
||||
root_config.web_search.provider.clone(),
|
||||
root_config.web_search.brave_api_key.clone(),
|
||||
root_config.web_search.searxng_instance_url.clone(),
|
||||
root_config.web_search.max_results,
|
||||
root_config.web_search.timeout_secs,
|
||||
root_config.config_path.clone(),
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
//! HTTP-based tool derived from a skill's `[[tools]]` section.
|
||||
//!
|
||||
//! Each `SkillTool` with `kind = "http"` is converted into a `SkillHttpTool`
|
||||
//! that implements the `Tool` trait. The command field is used as the URL
|
||||
//! template and args are substituted as query parameters or path segments.
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Maximum response body size (1 MB).
|
||||
const MAX_RESPONSE_BYTES: usize = 1_048_576;
|
||||
/// HTTP request timeout (seconds).
|
||||
const HTTP_TIMEOUT_SECS: u64 = 30;
|
||||
|
||||
/// A tool derived from a skill's `[[tools]]` section that makes HTTP requests.
|
||||
pub struct SkillHttpTool {
|
||||
tool_name: String,
|
||||
tool_description: String,
|
||||
url_template: String,
|
||||
args: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl SkillHttpTool {
|
||||
/// Create a new skill HTTP tool.
|
||||
///
|
||||
/// The tool name is prefixed with the skill name (`skill_name.tool_name`)
|
||||
/// to prevent collisions with built-in tools.
|
||||
pub fn new(skill_name: &str, tool: &crate::skills::SkillTool) -> Self {
|
||||
Self {
|
||||
tool_name: format!("{}.{}", skill_name, tool.name),
|
||||
tool_description: tool.description.clone(),
|
||||
url_template: tool.command.clone(),
|
||||
args: tool.args.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_parameters_schema(&self) -> serde_json::Value {
|
||||
let mut properties = serde_json::Map::new();
|
||||
let mut required = Vec::new();
|
||||
|
||||
for (name, description) in &self.args {
|
||||
properties.insert(
|
||||
name.clone(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"description": description
|
||||
}),
|
||||
);
|
||||
required.push(serde_json::Value::String(name.clone()));
|
||||
}
|
||||
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
})
|
||||
}
|
||||
|
||||
/// Substitute `{{arg_name}}` placeholders in the URL template with
|
||||
/// the provided argument values.
|
||||
fn substitute_args(&self, args: &serde_json::Value) -> String {
|
||||
let mut url = self.url_template.clone();
|
||||
if let Some(obj) = args.as_object() {
|
||||
for (key, value) in obj {
|
||||
let placeholder = format!("{{{{{}}}}}", key);
|
||||
let replacement = value.as_str().unwrap_or_default();
|
||||
url = url.replace(&placeholder, replacement);
|
||||
}
|
||||
}
|
||||
url
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SkillHttpTool {
|
||||
fn name(&self) -> &str {
|
||||
&self.tool_name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
&self.tool_description
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
self.build_parameters_schema()
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let url = self.substitute_args(&args);
|
||||
|
||||
// Validate URL scheme
|
||||
if !url.starts_with("http://") && !url.starts_with("https://") {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Only http:// and https:// URLs are allowed, got: {url}"
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
|
||||
.build()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to build HTTP client: {e}"))?;
|
||||
|
||||
let response = match client.get(&url).send().await {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("HTTP request failed: {e}")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let status = response.status();
|
||||
let body = match response.bytes().await {
|
||||
Ok(bytes) => {
|
||||
let mut text = String::from_utf8_lossy(&bytes).to_string();
|
||||
if text.len() > MAX_RESPONSE_BYTES {
|
||||
let mut b = MAX_RESPONSE_BYTES.min(text.len());
|
||||
while b > 0 && !text.is_char_boundary(b) {
|
||||
b -= 1;
|
||||
}
|
||||
text.truncate(b);
|
||||
text.push_str("\n... [response truncated at 1MB]");
|
||||
}
|
||||
text
|
||||
}
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to read response body: {e}")),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ToolResult {
|
||||
success: status.is_success(),
|
||||
output: body,
|
||||
error: if status.is_success() {
|
||||
None
|
||||
} else {
|
||||
Some(format!("HTTP {}", status))
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skills::SkillTool;

    // Fixture: an http-kind skill tool with one templated `city` argument.
    fn sample_http_tool() -> SkillTool {
        let mut args = HashMap::new();
        args.insert("city".to_string(), "City name to look up".to_string());

        SkillTool {
            name: "get_weather".to_string(),
            description: "Fetch weather for a city".to_string(),
            kind: "http".to_string(),
            command: "https://api.example.com/weather?city={{city}}".to_string(),
            args,
        }
    }

    // Name must be qualified as `skill_name.tool_name`.
    #[test]
    fn skill_http_tool_name_is_prefixed() {
        let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
        assert_eq!(tool.name(), "weather_skill.get_weather");
    }

    // Description passes through unchanged.
    #[test]
    fn skill_http_tool_description() {
        let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
        assert_eq!(tool.description(), "Fetch weather for a city");
    }

    // Declared args become string properties in the JSON schema.
    #[test]
    fn skill_http_tool_parameters_schema() {
        let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
        let schema = tool.parameters_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["city"].is_object());
        assert_eq!(schema["properties"]["city"]["type"], "string");
    }

    // `{{city}}` placeholder is replaced by the call-time argument value.
    #[test]
    fn skill_http_tool_substitute_args() {
        let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
        let result = tool.substitute_args(&serde_json::json!({"city": "London"}));
        assert_eq!(result, "https://api.example.com/weather?city=London");
    }

    // spec() must reflect name, description, and schema consistently.
    #[test]
    fn skill_http_tool_spec_roundtrip() {
        let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
        let spec = tool.spec();
        assert_eq!(spec.name, "weather_skill.get_weather");
        assert_eq!(spec.description, "Fetch weather for a city");
        assert_eq!(spec.parameters["type"], "object");
    }

    // A tool with no declared args yields an empty properties object.
    #[test]
    fn skill_http_tool_empty_args() {
        let st = SkillTool {
            name: "ping".to_string(),
            description: "Ping endpoint".to_string(),
            kind: "http".to_string(),
            command: "https://api.example.com/ping".to_string(),
            args: HashMap::new(),
        };
        let tool = SkillHttpTool::new("s", &st);
        let schema = tool.parameters_schema();
        assert!(schema["properties"].as_object().unwrap().is_empty());
    }
}
|
||||
@@ -0,0 +1,323 @@
|
||||
//! Shell-based tool derived from a skill's `[[tools]]` section.
|
||||
//!
|
||||
//! Each `SkillTool` with `kind = "shell"` or `kind = "script"` is converted
|
||||
//! into a `SkillShellTool` that implements the `Tool` trait. The tool name is
|
||||
//! prefixed with the skill name (e.g. `my_skill.run_lint`) to avoid collisions
|
||||
//! with built-in tools.
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Maximum execution time for a skill shell command (seconds).
|
||||
const SKILL_SHELL_TIMEOUT_SECS: u64 = 60;
|
||||
/// Maximum output size in bytes (1 MB).
|
||||
const MAX_OUTPUT_BYTES: usize = 1_048_576;
|
||||
|
||||
/// A tool derived from a skill's `[[tools]]` section that executes shell commands.
|
||||
pub struct SkillShellTool {
|
||||
tool_name: String,
|
||||
tool_description: String,
|
||||
command_template: String,
|
||||
args: HashMap<String, String>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl SkillShellTool {
|
||||
/// Create a new skill shell tool.
|
||||
///
|
||||
/// The tool name is prefixed with the skill name (`skill_name.tool_name`)
|
||||
/// to prevent collisions with built-in tools.
|
||||
pub fn new(
|
||||
skill_name: &str,
|
||||
tool: &crate::skills::SkillTool,
|
||||
security: Arc<SecurityPolicy>,
|
||||
) -> Self {
|
||||
Self {
|
||||
tool_name: format!("{}.{}", skill_name, tool.name),
|
||||
tool_description: tool.description.clone(),
|
||||
command_template: tool.command.clone(),
|
||||
args: tool.args.clone(),
|
||||
security,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_parameters_schema(&self) -> serde_json::Value {
|
||||
let mut properties = serde_json::Map::new();
|
||||
let mut required = Vec::new();
|
||||
|
||||
for (name, description) in &self.args {
|
||||
properties.insert(
|
||||
name.clone(),
|
||||
serde_json::json!({
|
||||
"type": "string",
|
||||
"description": description
|
||||
}),
|
||||
);
|
||||
required.push(serde_json::Value::String(name.clone()));
|
||||
}
|
||||
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
})
|
||||
}
|
||||
|
||||
/// Substitute `{{arg_name}}` placeholders in the command template with
|
||||
/// the provided argument values. Unknown placeholders are left as-is.
|
||||
fn substitute_args(&self, args: &serde_json::Value) -> String {
|
||||
let mut command = self.command_template.clone();
|
||||
if let Some(obj) = args.as_object() {
|
||||
for (key, value) in obj {
|
||||
let placeholder = format!("{{{{{}}}}}", key);
|
||||
let replacement = value.as_str().unwrap_or_default();
|
||||
command = command.replace(&placeholder, replacement);
|
||||
}
|
||||
}
|
||||
command
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SkillShellTool {
|
||||
fn name(&self) -> &str {
|
||||
&self.tool_name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
&self.tool_description
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
self.build_parameters_schema()
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let command = self.substitute_args(&args);
|
||||
|
||||
// Rate limit check
|
||||
if self.security.is_rate_limited() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Rate limit exceeded: too many actions in the last hour".into()),
|
||||
});
|
||||
}
|
||||
|
||||
// Security validation — always requires explicit approval (approved=true)
|
||||
// since skill tools are user-defined and should be treated as medium-risk.
|
||||
match self.security.validate_command_execution(&command, true) {
|
||||
Ok(_) => {}
|
||||
Err(reason) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(reason),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(path) = self.security.forbidden_path_argument(&command) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Path blocked by security policy: {path}")),
|
||||
});
|
||||
}
|
||||
|
||||
if !self.security.record_action() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Rate limit exceeded: action budget exhausted".into()),
|
||||
});
|
||||
}
|
||||
|
||||
// Build and execute the command
|
||||
let mut cmd = tokio::process::Command::new("sh");
|
||||
cmd.arg("-c").arg(&command);
|
||||
cmd.current_dir(&self.security.workspace_dir);
|
||||
cmd.env_clear();
|
||||
|
||||
// Only pass safe environment variables
|
||||
for var in &[
|
||||
"PATH", "HOME", "TERM", "LANG", "LC_ALL", "USER", "SHELL", "TMPDIR",
|
||||
] {
|
||||
if let Ok(val) = std::env::var(var) {
|
||||
cmd.env(var, val);
|
||||
}
|
||||
}
|
||||
|
||||
let result =
|
||||
tokio::time::timeout(Duration::from_secs(SKILL_SHELL_TIMEOUT_SECS), cmd.output()).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(output)) => {
|
||||
let mut stdout = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
let mut stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
|
||||
if stdout.len() > MAX_OUTPUT_BYTES {
|
||||
let mut b = MAX_OUTPUT_BYTES.min(stdout.len());
|
||||
while b > 0 && !stdout.is_char_boundary(b) {
|
||||
b -= 1;
|
||||
}
|
||||
stdout.truncate(b);
|
||||
stdout.push_str("\n... [output truncated at 1MB]");
|
||||
}
|
||||
if stderr.len() > MAX_OUTPUT_BYTES {
|
||||
let mut b = MAX_OUTPUT_BYTES.min(stderr.len());
|
||||
while b > 0 && !stderr.is_char_boundary(b) {
|
||||
b -= 1;
|
||||
}
|
||||
stderr.truncate(b);
|
||||
stderr.push_str("\n... [stderr truncated at 1MB]");
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: output.status.success(),
|
||||
output: stdout,
|
||||
error: if stderr.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stderr)
|
||||
},
|
||||
})
|
||||
}
|
||||
Ok(Err(e)) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!("Failed to execute command: {e}")),
|
||||
}),
|
||||
Err(_) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Command timed out after {SKILL_SHELL_TIMEOUT_SECS}s and was killed"
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::{AutonomyLevel, SecurityPolicy};
    use crate::skills::SkillTool;

    /// Fully-permissive security policy rooted at the OS temp dir, shared by
    /// every test in this module.
    fn test_security() -> Arc<SecurityPolicy> {
        let policy = SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        };
        Arc::new(policy)
    }

    /// Two-argument shell skill fixture ("file" and "format" placeholders).
    fn sample_skill_tool() -> SkillTool {
        let args = HashMap::from([
            ("file".to_string(), "The file to lint".to_string()),
            (
                "format".to_string(),
                "Output format (json|text)".to_string(),
            ),
        ]);

        SkillTool {
            name: "run_lint".to_string(),
            description: "Run the linter on a file".to_string(),
            kind: "shell".to_string(),
            command: "lint --file {{file}} --format {{format}}".to_string(),
            args,
        }
    }

    #[test]
    fn skill_shell_tool_name_is_prefixed() {
        // Tool names are exposed to the model as "<skill>.<tool>".
        let sut = SkillShellTool::new("my_skill", &sample_skill_tool(), test_security());
        assert_eq!(sut.name(), "my_skill.run_lint");
    }

    #[test]
    fn skill_shell_tool_description() {
        let sut = SkillShellTool::new("my_skill", &sample_skill_tool(), test_security());
        assert_eq!(sut.description(), "Run the linter on a file");
    }

    #[test]
    fn skill_shell_tool_parameters_schema() {
        let sut = SkillShellTool::new("my_skill", &sample_skill_tool(), test_security());
        let schema = sut.parameters_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["file"].is_object());
        assert_eq!(schema["properties"]["file"]["type"], "string");
        assert!(schema["properties"]["format"].is_object());

        // Every declared arg is marked required.
        let required = schema["required"]
            .as_array()
            .expect("required should be array");
        assert_eq!(required.len(), 2);
    }

    #[test]
    fn skill_shell_tool_substitute_args() {
        let sut = SkillShellTool::new("my_skill", &sample_skill_tool(), test_security());
        let rendered = sut.substitute_args(&serde_json::json!({
            "file": "src/main.rs",
            "format": "json"
        }));
        assert_eq!(rendered, "lint --file src/main.rs --format json");
    }

    #[test]
    fn skill_shell_tool_substitute_missing_arg() {
        let sut = SkillShellTool::new("my_skill", &sample_skill_tool(), test_security());
        let rendered = sut.substitute_args(&serde_json::json!({"file": "test.rs"}));
        // Missing {{format}} placeholder stays in the command
        assert!(rendered.contains("{{format}}"));
        assert!(rendered.contains("test.rs"));
    }

    #[test]
    fn skill_shell_tool_empty_args_schema() {
        // A tool that declares no args still yields a valid (empty) schema.
        let fixture = SkillTool {
            name: "simple".to_string(),
            description: "Simple tool".to_string(),
            kind: "shell".to_string(),
            command: "echo hello".to_string(),
            args: HashMap::new(),
        };
        let sut = SkillShellTool::new("s", &fixture, test_security());
        let schema = sut.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"].as_object().unwrap().is_empty());
        assert!(schema["required"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn skill_shell_tool_executes_echo() {
        // End-to-end: the command actually runs under `sh` and its stdout is
        // captured in the ToolResult.
        let fixture = SkillTool {
            name: "hello".to_string(),
            description: "Say hello".to_string(),
            kind: "shell".to_string(),
            command: "echo hello-skill".to_string(),
            args: HashMap::new(),
        };
        let sut = SkillShellTool::new("test", &fixture, test_security());
        let outcome = sut.execute(serde_json::json!({})).await.unwrap();
        assert!(outcome.success);
        assert!(outcome.output.contains("hello-skill"));
    }

    #[test]
    fn skill_shell_tool_spec_roundtrip() {
        // The generated spec mirrors name/description/schema.
        let sut = SkillShellTool::new("my_skill", &sample_skill_tool(), test_security());
        let spec = sut.spec();
        assert_eq!(spec.name, "my_skill.run_lint");
        assert_eq!(spec.description, "Run the linter on a file");
        assert_eq!(spec.parameters["type"], "object");
    }
}
|
||||
@@ -2,6 +2,7 @@
|
||||
/// Concrete search backend a request is routed to after provider-name
/// alias resolution.
pub enum WebSearchProviderRoute {
    /// DuckDuckGo — the free default; no credentials required.
    DuckDuckGo,
    /// Brave Search API — requires an API key.
    Brave,
    /// Self-hosted SearXNG instance — requires an instance URL.
    SearXNG,
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@@ -13,6 +14,7 @@ pub struct WebSearchProviderResolution {
|
||||
|
||||
/// Canonical name of the provider used as the fallback when the configured
/// value is unrecognized (see `resolve_web_search_provider`).
pub const DEFAULT_WEB_SEARCH_PROVIDER: &str = "duckduckgo";
/// Canonical provider name for Brave Search.
const BRAVE_PROVIDER: &str = "brave";
/// Canonical provider name for SearXNG.
const SEARXNG_PROVIDER: &str = "searxng";
|
||||
|
||||
pub fn resolve_web_search_provider(raw_provider: &str) -> WebSearchProviderResolution {
|
||||
let normalized = raw_provider.trim().to_ascii_lowercase();
|
||||
@@ -29,6 +31,11 @@ pub fn resolve_web_search_provider(raw_provider: &str) -> WebSearchProviderResol
|
||||
canonical_provider: BRAVE_PROVIDER,
|
||||
used_fallback: false,
|
||||
},
|
||||
"searxng" | "searx" | "searx-ng" | "searx_ng" => WebSearchProviderResolution {
|
||||
route: WebSearchProviderRoute::SearXNG,
|
||||
canonical_provider: SEARXNG_PROVIDER,
|
||||
used_fallback: false,
|
||||
},
|
||||
_ => WebSearchProviderResolution {
|
||||
route: WebSearchProviderRoute::DuckDuckGo,
|
||||
canonical_provider: DEFAULT_WEB_SEARCH_PROVIDER,
|
||||
@@ -63,6 +70,17 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
fn resolve_aliases_to_searxng() {
    // Every accepted spelling must map to the canonical SearXNG route
    // without triggering the fallback path.
    for alias in ["searxng", "searx", "searx-ng", "searx_ng"] {
        let resolution = resolve_web_search_provider(alias);
        assert_eq!(resolution.route, WebSearchProviderRoute::SearXNG);
        assert_eq!(resolution.canonical_provider, SEARXNG_PROVIDER);
        assert!(!resolution.used_fallback);
    }
}
|
||||
|
||||
#[test]
|
||||
fn resolve_unknown_provider_falls_back_to_default() {
|
||||
let resolved = resolve_web_search_provider("bing");
|
||||
|
||||
@@ -7,7 +7,8 @@ use std::path::{Path, PathBuf};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Web search tool for searching the internet.
|
||||
/// Supports multiple providers: DuckDuckGo (free), Brave (requires API key).
|
||||
/// Supports multiple providers: DuckDuckGo (free), Brave (requires API key),
|
||||
/// SearXNG (self-hosted, requires instance URL).
|
||||
///
|
||||
/// The Brave API key is resolved lazily at execution time: if the boot-time key
|
||||
/// is missing or still encrypted, the tool re-reads `config.toml`, decrypts the
|
||||
@@ -18,6 +19,8 @@ pub struct WebSearchTool {
|
||||
provider: String,
|
||||
/// Boot-time key snapshot (may be `None` if not yet configured at startup).
|
||||
boot_brave_api_key: Option<String>,
|
||||
/// SearXNG instance base URL (e.g. "https://searx.example.com").
|
||||
searxng_instance_url: Option<String>,
|
||||
max_results: usize,
|
||||
timeout_secs: u64,
|
||||
/// Path to `config.toml` for lazy re-read of keys at execution time.
|
||||
@@ -36,6 +39,7 @@ impl WebSearchTool {
|
||||
Self {
|
||||
provider: provider.trim().to_lowercase(),
|
||||
boot_brave_api_key: brave_api_key,
|
||||
searxng_instance_url: None,
|
||||
max_results: max_results.clamp(1, 10),
|
||||
timeout_secs: timeout_secs.max(1),
|
||||
config_path: PathBuf::new(),
|
||||
@@ -51,6 +55,7 @@ impl WebSearchTool {
|
||||
pub fn new_with_config(
|
||||
provider: String,
|
||||
brave_api_key: Option<String>,
|
||||
searxng_instance_url: Option<String>,
|
||||
max_results: usize,
|
||||
timeout_secs: u64,
|
||||
config_path: PathBuf,
|
||||
@@ -59,6 +64,7 @@ impl WebSearchTool {
|
||||
Self {
|
||||
provider: provider.trim().to_lowercase(),
|
||||
boot_brave_api_key: brave_api_key,
|
||||
searxng_instance_url,
|
||||
max_results: max_results.clamp(1, 10),
|
||||
timeout_secs: timeout_secs.max(1),
|
||||
config_path,
|
||||
@@ -248,6 +254,105 @@ impl WebSearchTool {
|
||||
|
||||
Ok(lines.join("\n"))
|
||||
}
|
||||
|
||||
/// Resolve the SearXNG instance URL from the boot-time config or by
|
||||
/// re-reading `config.toml` at runtime.
|
||||
fn resolve_searxng_instance_url(&self) -> anyhow::Result<String> {
|
||||
if let Some(ref url) = self.searxng_instance_url {
|
||||
if !url.is_empty() {
|
||||
return Ok(url.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: re-read config.toml to pick up values set after boot.
|
||||
let contents = std::fs::read_to_string(&self.config_path).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to read config file {} for SearXNG instance URL: {e}",
|
||||
self.config_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
let config: crate::config::Config = toml::from_str(&contents).map_err(|e| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to parse config file {} for SearXNG instance URL: {e}",
|
||||
self.config_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
config
|
||||
.web_search
|
||||
.searxng_instance_url
|
||||
.filter(|u| !u.is_empty())
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"SearXNG instance URL not configured. Set [web_search] searxng_instance_url \
|
||||
in config.toml or the SEARXNG_INSTANCE_URL environment variable."
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
async fn search_searxng(&self, query: &str) -> anyhow::Result<String> {
|
||||
let instance_url = self.resolve_searxng_instance_url()?;
|
||||
let base_url = instance_url.trim_end_matches('/');
|
||||
|
||||
let encoded_query = urlencoding::encode(query);
|
||||
let search_url = format!(
|
||||
"{}/search?q={}&format=json&pageno=1",
|
||||
base_url, encoded_query
|
||||
);
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(self.timeout_secs))
|
||||
.user_agent("ZeroClaw/1.0")
|
||||
.build()?;
|
||||
|
||||
let response = client
|
||||
.get(&search_url)
|
||||
.header("Accept", "application/json")
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!("SearXNG search failed with status: {}", response.status());
|
||||
}
|
||||
|
||||
let json: serde_json::Value = response.json().await?;
|
||||
self.parse_searxng_results(&json, query)
|
||||
}
|
||||
|
||||
fn parse_searxng_results(
|
||||
&self,
|
||||
json: &serde_json::Value,
|
||||
query: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let results = json
|
||||
.get("results")
|
||||
.and_then(|r| r.as_array())
|
||||
.ok_or_else(|| anyhow::anyhow!("Invalid SearXNG API response"))?;
|
||||
|
||||
if results.is_empty() {
|
||||
return Ok(format!("No results found for: {}", query));
|
||||
}
|
||||
|
||||
let mut lines = vec![format!("Search results for: {} (via SearXNG)", query)];
|
||||
|
||||
for (i, result) in results.iter().take(self.max_results).enumerate() {
|
||||
let title = result
|
||||
.get("title")
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("No title");
|
||||
let url = result.get("url").and_then(|u| u.as_str()).unwrap_or("");
|
||||
let content = result.get("content").and_then(|c| c.as_str()).unwrap_or("");
|
||||
|
||||
lines.push(format!("{}. {}", i + 1, title));
|
||||
lines.push(format!(" {}", url));
|
||||
if !content.is_empty() {
|
||||
lines.push(format!(" {}", content));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(lines.join("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_ddg_redirect_url(raw_url: &str) -> String {
|
||||
@@ -314,6 +419,7 @@ impl Tool for WebSearchTool {
|
||||
let result = match resolution.route {
|
||||
WebSearchProviderRoute::DuckDuckGo => self.search_duckduckgo(query).await?,
|
||||
WebSearchProviderRoute::Brave => self.search_brave(query).await?,
|
||||
WebSearchProviderRoute::SearXNG => self.search_searxng(query).await?,
|
||||
};
|
||||
|
||||
Ok(ToolResult {
|
||||
@@ -443,8 +549,15 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
// No boot key -- forces reload from config
|
||||
let tool =
|
||||
WebSearchTool::new_with_config("brave".to_string(), None, 5, 15, config_path, false);
|
||||
let tool = WebSearchTool::new_with_config(
|
||||
"brave".to_string(),
|
||||
None,
|
||||
None,
|
||||
5,
|
||||
15,
|
||||
config_path,
|
||||
false,
|
||||
);
|
||||
let key = tool.resolve_brave_api_key().unwrap();
|
||||
assert_eq!(key, "fresh-key-from-disk");
|
||||
}
|
||||
@@ -466,6 +579,7 @@ mod tests {
|
||||
let tool = WebSearchTool::new_with_config(
|
||||
"brave".to_string(),
|
||||
Some(encrypted),
|
||||
None,
|
||||
5,
|
||||
15,
|
||||
config_path,
|
||||
@@ -475,6 +589,111 @@ mod tests {
|
||||
assert_eq!(key, "brave-secret-key");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_searxng_without_instance_url() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let config_path = tmp.path().join("config.toml");
|
||||
std::fs::write(&config_path, "[web_search]\n").unwrap();
|
||||
|
||||
let tool = WebSearchTool::new_with_config(
|
||||
"searxng".to_string(),
|
||||
None,
|
||||
None,
|
||||
5,
|
||||
15,
|
||||
config_path,
|
||||
false,
|
||||
);
|
||||
let result = tool.execute(json!({"query": "test"})).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("SearXNG instance URL not configured"));
|
||||
}
|
||||
|
||||
#[test]
fn test_parse_searxng_results_empty() {
    // An empty results array renders a "no results" message, not an error.
    let tool = WebSearchTool::new("searxng".to_string(), None, 5, 15);
    let payload = serde_json::json!({"results": []});
    let rendered = tool.parse_searxng_results(&payload, "test").unwrap();
    assert!(rendered.contains("No results found"));
}
|
||||
|
||||
#[test]
fn test_parse_searxng_results_with_data() {
    let tool = WebSearchTool::new("searxng".to_string(), None, 5, 15);
    let payload = serde_json::json!({
        "results": [
            {
                "title": "SearXNG Example",
                "url": "https://example.com",
                "content": "A privacy-respecting metasearch engine"
            },
            {
                "title": "Another Result",
                "url": "https://example.org",
                "content": "More information here"
            }
        ]
    });
    let rendered = tool.parse_searxng_results(&payload, "test").unwrap();
    // Title, URL, snippet, and the provider attribution must all appear.
    for expected in [
        "SearXNG Example",
        "https://example.com",
        "A privacy-respecting metasearch engine",
        "via SearXNG",
    ] {
        assert!(rendered.contains(expected));
    }
}
|
||||
|
||||
#[test]
fn test_parse_searxng_results_invalid_response() {
    // A payload without a `results` array is rejected as invalid.
    let tool = WebSearchTool::new("searxng".to_string(), None, 5, 15);
    let payload = serde_json::json!({"error": "bad request"});
    let outcome = tool.parse_searxng_results(&payload, "test");
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Invalid SearXNG API response"));
}
|
||||
|
||||
#[test]
fn test_resolve_searxng_instance_url_from_boot() {
    // A URL captured at boot is returned directly, without touching disk
    // (config_path is deliberately empty here).
    let tool = WebSearchTool {
        provider: "searxng".to_string(),
        boot_brave_api_key: None,
        searxng_instance_url: Some("https://searx.example.com".to_string()),
        max_results: 5,
        timeout_secs: 15,
        config_path: PathBuf::new(),
        secrets_encrypt: false,
    };
    assert_eq!(
        tool.resolve_searxng_instance_url().unwrap(),
        "https://searx.example.com"
    );
}
|
||||
|
||||
#[test]
fn test_resolve_searxng_instance_url_reloads_from_config() {
    // No boot-time URL: the resolver must fall back to re-reading
    // config.toml at runtime.
    let tmp = tempfile::TempDir::new().unwrap();
    let config_path = tmp.path().join("config.toml");
    std::fs::write(
        &config_path,
        "[web_search]\nsearxng_instance_url = \"https://search.local\"\n",
    )
    .unwrap();

    let tool = WebSearchTool::new_with_config(
        "searxng".to_string(),
        None,
        None,
        5,
        15,
        config_path,
        false,
    );
    assert_eq!(
        tool.resolve_searxng_instance_url().unwrap(),
        "https://search.local"
    );
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_brave_api_key_picks_up_runtime_update() {
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
@@ -486,6 +705,7 @@ mod tests {
|
||||
let tool = WebSearchTool::new_with_config(
|
||||
"brave".to_string(),
|
||||
None,
|
||||
None,
|
||||
5,
|
||||
15,
|
||||
config_path.clone(),
|
||||
|
||||
@@ -401,7 +401,7 @@ fn config_nested_optional_sections_default_when_absent() {
|
||||
assert!(parsed.channels_config.telegram.is_none());
|
||||
assert!(!parsed.composio.enabled);
|
||||
assert!(parsed.composio.api_key.is_none());
|
||||
assert!(!parsed.browser.enabled);
|
||||
assert!(parsed.browser.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user