Compare commits

...

89 Commits

Author SHA1 Message Date
Argenis 41dd23175f chore(ci): unify release pipeline for full auto-sync (#4283)
- Add tag-push trigger to Release Stable workflow so `git push origin v0.5.9`
  auto-triggers the full pipeline (builds, Docker, crates.io, website,
  Scoop, AUR, Homebrew, tweet) in one shot
- Add Homebrew Core as downstream job (was manual-only, never auto-triggered)
- Add `workflow_call` to pub-homebrew-core.yml so it can be called from
  the stable release workflow
- Skip beta releases on version bump commits (prevents beta/stable race)
- Skip auto crates.io publish when stable tag exists (prevents double-publish)
- Auto-create git tag on manual dispatch so tag always exists for downstream
- Fix cut_release_tag.sh to reference correct workflow name
2026-03-22 20:56:47 -04:00
Argenis 864d754b56 chore: bump version to 0.5.9 (#4282)
Release 0.5.9 includes:
- feat: enable internet access by default (#4270)
- feat: add browser automation skill and VNC setup scripts (#4281)
- feat: add voice message transcription support (#4278)
- feat: add image and file support for Feishu/Lark channel (#4280)
- feat: declarative cron job configuration (#4045)
- feat: add SearXNG search provider support (#4272)
- feat: register skill tools as callable tool specs (#4040)
- feat: named sessions with reconnect and validation (#4275)
- feat: restore time-decay scoring for memory (#4274)
- fix: prevent thinking level prefix leak across turns (#4277)
- fix: link enricher title extraction byte offset bug (#4271)
- fix: WhatsApp Web delivery channel with backend validation (#4273)
2026-03-22 20:23:55 -04:00
Argenis ccd52f3394 feat: add browser automation skill and VNC setup scripts (#4281)
* feat: add browser automation skill and VNC setup scripts

- Add browser skill template for agent-browser CLI integration
- Add VNC setup scripts for GUI browser access (Xvfb, x11vnc, noVNC)
- Add comprehensive browser setup documentation
- Enables headless browser automation for AI agents

Tested on: Ubuntu 24.04, ZeroClaw 0.5.7, agent-browser 0.21.4

Co-authored-by: OpenClaw Assistant

* fix(docs): fix markdown lint errors and update browser config docs

- SKILL.md: add blank lines around headings (MD022)
- browser-setup.md: wrap bare URLs in angle brackets (MD034)
- browser-setup.md: rename duplicate "Access" heading (MD024)
- Update config examples to reflect browser enabled by default
- Add examples for restricting/disabling browser via config

---------

Co-authored-by: Argenis <theonlyhennygod@users.noreply.github.com>
2026-03-22 19:35:20 -04:00
Argenis eb01aa451d feat: add voice message transcription support (#4278)
Closes #4231. Adds voice memo detection and transcription for Slack and Discord channels. Audio files are downloaded, transcribed via the existing transcription module, and passed as text to the LLM.
2026-03-22 19:18:07 -04:00
Argenis c785b45f2d feat: add image and file support for Feishu/Lark channel (#4280)
Closes #4235. Adds image and file message type handling in the Lark channel - downloads images/files via Lark API, detects MIME types, and passes content to the model for analysis.
2026-03-22 19:14:43 -04:00
Argenis ffb8b81f90 fix(agent): prevent thinking level prefix from leaking across turns (#4277)
* feat(agent): add thinking/reasoning level control per message

Users can set reasoning depth per message via /think:high etc., resolved
through a hierarchy (inline > session > config > default). Six levels,
from Off to Max, adjust the temperature and system prompt.

* fix(agent): prevent thinking level prefix from leaking across interactive turns

system_prompt was mutated in place for the first message's thinking
directive, then used as the "baseline" for restoration after each
interactive turn. This caused the first turn's thinking prefix to
persist across all subsequent turns.

Fix: save the original system_prompt before any thinking modifications
and restore from that saved copy between turns.
2026-03-22 19:09:12 -04:00
Argenis 65f856d710 fix(channels): link enricher title extraction byte offset bug (#4271)
* feat(channels): add automatic link understanding for inbound messages

Auto-detects URLs in inbound messages, fetches content, extracts
title+summary, and enriches the message before the agent sees it.
Includes SSRF protection (rejects private IPs).

* fix(channels): use lowercased string for title extraction to prevent byte offset mismatch

extract_title used byte offsets from the lowercased HTML to index into
the original HTML. Multi-byte characters whose lowercase form has a
different byte length (e.g. İ → i̇) would produce wrong slices or panics.
Fix: extract from the lowercased string directly. Add multibyte test.
2026-03-22 19:09:09 -04:00
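The fix described above can be sketched as follows: every byte offset comes from, and indexes back into, the same lowercased string, so a case change that alters byte length cannot desync them. This is a minimal illustration; the real `extract_title` in the channels module is assumed to differ in signature and robustness.

```rust
// Illustrative version of the fixed approach: search and slice the SAME
// lowercased string, never mix its offsets with the original HTML.
fn extract_title(html: &str) -> Option<String> {
    let lower = html.to_lowercase();
    let start = lower.find("<title>")? + "<title>".len();
    let end = start + lower[start..].find("</title>")?;
    Some(lower[start..end].trim().to_string())
}

fn main() {
    // 'İ' (2 bytes) lowercases to "i\u{307}" (3 bytes), so offsets computed
    // on the lowercased text would be wrong for the original string.
    let html = "<HTML><TITLE>İstanbul News</TITLE></HTML>";
    println!("{:?}", extract_title(html));
}
```

The trade-off, per the commit, is that the extracted title is the lowercased form.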
Argenis 1682620377 feat(tools): enable internet access by default (#4270)
* feat(tools): enable internet access by default

Enable web_fetch, web_search, http_request, and browser tools by
default so ZeroClaw has internet access out of the box. Each tool
remains fully toggleable via config (set `enabled = false` to disable).

- web_fetch: enabled with allowed_domains = ["*"]
- web_search: enabled with DuckDuckGo (free, no API key)
- http_request: enabled with allowed_domains = ["*"]
- browser: enabled with allowed_domains = ["*"], agent_browser backend
- text_browser: remains opt-in (requires external binary)

* fix(tests): update component test for browser enabled by default

Update config_nested_optional_sections_default_when_absent to expect
browser.enabled = true, matching the new default.
2026-03-22 19:07:12 -04:00
Argenis aa455ae89b feat: declarative cron job configuration (#4045)
Add support for defining cron jobs directly in the TOML config file via
`[[cron.jobs]]` array entries. Declarative jobs are synced to the SQLite
database at scheduler startup with upsert semantics:

- New declarative jobs are inserted
- Existing declarative jobs are updated to match config
- Stale declarative jobs (removed from config) are deleted
- Imperative jobs (created via CLI/API) are never modified

Each declarative job requires a stable `id` for merge tracking. A new
`source` column (`"imperative"` or `"declarative"`) distinguishes the
two creation paths. Shell jobs require `command`, agent jobs require
`prompt`, validated before any DB writes.
2026-03-22 19:03:00 -04:00
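A declarative entry of the shape described above might look like this — `id`, `command`, and `prompt` are the field names the commit states; `schedule` and `kind` are illustrative assumptions, as are all values:

```toml
# Hypothetical [[cron.jobs]] entries; ids must be stable for merge tracking.
[[cron.jobs]]
id = "nightly-backup"         # stable id, required for upsert tracking
kind = "shell"
schedule = "0 3 * * *"
command = "backup.sh --full"  # shell jobs require `command`

[[cron.jobs]]
id = "morning-brief"
kind = "agent"
schedule = "0 8 * * *"
prompt = "Summarize overnight alerts"  # agent jobs require `prompt`
```

Removing an entry from this file deletes the corresponding declarative job at the next scheduler startup; jobs created via CLI/API are untouched.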
Argenis a9ffd38912 feat(memory): restore time-decay scoring lost in main→master migration (#4274)
Apply exponential time decay (2^(-age/half_life), 7-day half-life) to
memory entry scores post-recall. Core memories are exempt (evergreen).
Consolidate duplicate half-life constants into a single public constant
in the decay module.

Based on PR #4266 by 5queezer with constant consolidation fix.
2026-03-22 19:01:40 -04:00
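The decay formula above — `2^(-age/half_life)` with a 7-day half-life — can be sketched as (names are illustrative, not the crate's actual API):

```rust
// Exponential time decay per the commit: a score halves every half-life.
// Core memories would skip this entirely (exempt as evergreen).
const HALF_LIFE_HOURS: f64 = 7.0 * 24.0; // 7-day half-life

fn decayed_score(score: f64, age_hours: f64) -> f64 {
    score * (2f64).powf(-age_hours / HALF_LIFE_HOURS)
}

fn main() {
    println!("fresh: {}", decayed_score(1.0, 0.0));
    println!("one half-life: {}", decayed_score(1.0, HALF_LIFE_HOURS));
    println!("two half-lives: {}", decayed_score(1.0, 2.0 * HALF_LIFE_HOURS));
}
```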
Argenis 86a0584513 feat: add SearXNG search provider support (#4272)
Closes #4152. Adds SearXNG as a search provider option with JSON output, a configurable instance URL, env-var overrides, and 7 new tests.
2026-03-22 19:01:35 -04:00
Argenis abef4c5719 fix(cron): add WhatsApp Web delivery channel with backend validation (#4273)
Apply PR #4258 changes to add whatsapp/whatsapp-web/whatsapp_web match
arm in deliver_announcement, feature-gated behind whatsapp-web.

Added is_web_config() guard to bail early when the WhatsApp config is
for Cloud API mode (no session_path), preventing a confusing runtime
failure with an empty session path.
2026-03-22 18:58:26 -04:00
Argenis 483b2336c4 feat(gateway): add named sessions with reconnect and validation fixes (#4275)
* fix(cron): add WhatsApp Web delivery channel with backend validation

Apply PR #4258 changes to add whatsapp/whatsapp-web/whatsapp_web match
arm in deliver_announcement, feature-gated behind whatsapp-web.

Added is_web_config() guard to bail early when the WhatsApp config is
for Cloud API mode (no session_path), preventing a confusing runtime
failure with an empty session path.

* feat(gateway): add named sessions with human-readable labels

Apply PR #4267 changes with bug fixes:
- Add get_session_name trait method so WS session_start includes the
  stored name on reconnect (not just when ?name= query param is present)
- Rename API now returns 404 for non-existent sessions instead of
  silently succeeding
- Empty ?name= query param on WS connect no longer clears existing name
2026-03-22 18:58:15 -04:00
Argenis 14cda3bc9a feat: register skill tools as callable tool specs (#4040)
Skill tools defined in [[tools]] sections are now registered as first-class
callable tool specs via the Tool trait, rather than only appearing as XML
in the system prompt. This enables the LLM to invoke skill tools through
native function calling.

- Add SkillShellTool for shell/script kind skill tools
- Add SkillHttpTool for http kind skill tools
- Add skills_to_tools() conversion and register_skill_tools() wiring
- Wire registration into both CLI and process_message agent paths
- Update prompt rendering to mark registered tools as callable
- Update affected tests across skills, agent/prompt, and channels
2026-03-22 18:51:24 -04:00
Argenis 6e8f0fa43c docs: add ADR for tool shared state ownership contract (#4057)
Define the contract for long-lived shared state in multi-client tool
execution, covering ownership (handle pattern), identity assignment
(daemon-provided ClientId), lifecycle (validation at registration),
isolation (per-client for security state), and reload semantics
(config hash invalidation).
2026-03-22 18:40:34 -04:00
argenis de la rosa a965b129f8 chore: bump version to 0.5.8
Release trigger bump after recent fixes.
2026-03-22 16:29:45 -04:00
Argenis c135de41b7 feat(matrix): add allowed_rooms config for room-level gating (#4230) (#4260)
Add an `allowed_rooms` field to MatrixConfig that controls which rooms
the bot will accept messages from and join invites for. When the list
is non-empty, messages from unlisted rooms are silently dropped and
room invites are auto-rejected. When empty (default), all rooms are
allowed, preserving backward compatibility.

- Config: add `allowed_rooms: Vec<String>` with `#[serde(default)]`
- Message handler: replace disabled room_id filter with allowlist check
- Invite handler: auto-accept allowed rooms, auto-reject others
- Support both canonical room IDs and aliases, case-insensitive
2026-03-22 14:41:43 -04:00
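The allowlist semantics described above can be sketched as a standalone check — empty list allows everything (backward compatible), otherwise match case-insensitively. This is a hypothetical helper, not the Matrix channel's actual function:

```rust
// Minimal sketch of the room allowlist: empty = allow all, else the room id
// or alias must appear in the list (case-insensitive comparison assumed to
// be ASCII-insensitive here; room ids and aliases are typically ASCII).
fn room_allowed(allowed_rooms: &[String], room: &str) -> bool {
    if allowed_rooms.is_empty() {
        return true; // empty default preserves backward compatibility
    }
    allowed_rooms.iter().any(|r| r.eq_ignore_ascii_case(room))
}

fn main() {
    let allow = vec!["#ops:example.org".to_string()];
    println!("{}", room_allowed(&allow, "#OPS:example.org"));    // accepted
    println!("{}", room_allowed(&allow, "#random:example.org")); // dropped
}
```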
Argenis 2d2c2ac9e6 feat(telegram): support forwarded messages with attribution (#4265)
Parse forward_from, forward_from_chat, and forward_sender_name fields
from Telegram message updates. Prepend forwarding attribution to message
content so the LLM has context about the original sender.

Closes #4118
2026-03-22 14:36:31 -04:00
Argenis 5e774bbd70 feat(multimodal): route image messages to dedicated vision provider (#4264)
When vision_provider is configured in [multimodal] config, messages
containing [IMAGE:] markers are automatically routed to the specified
vision-capable provider instead of failing on the default text provider.

Closes #4119
2026-03-22 14:36:29 -04:00
Argenis 33015067eb feat(tts): add local Piper TTS provider (#4263)
Add a piper TTS provider that communicates with a local Piper/Coqui TTS
server via an OpenAI-compatible HTTP endpoint. This enables fully offline
voice pipelines: Whisper (STT) → LLM → Piper (TTS).

Closes #4116
2026-03-22 14:36:26 -04:00
Argenis 6b10c0b891 fix(approval): merge default auto_approve entries with user config (#4262)
When a user provides a custom `auto_approve` list in their TOML
config (e.g. to add an MCP tool), serde replaces the built-in
defaults instead of merging. This causes default safe tools like
`weather`, `calculator`, and `file_read` to lose auto-approve
status and get silently denied in non-interactive channel runs.

Add `ensure_default_auto_approve()` which merges built-in entries
into the user's list after deserialization, preserving user
additions while guaranteeing defaults are always present. Users
who want to require approval for a default tool can use
`always_ask`, which takes precedence.

Closes #4247
2026-03-22 14:28:09 -04:00
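The post-deserialization merge described above can be sketched like this — built-in defaults are appended unless already present, so user additions survive and defaults are never silently lost (illustrative signature; the real `ensure_default_auto_approve` operates on the config struct):

```rust
// Merge built-in auto_approve defaults into the user's list after serde
// deserialization, instead of letting the user list replace them wholesale.
fn ensure_default_auto_approve(user: &mut Vec<String>, defaults: &[&str]) {
    for d in defaults {
        if !user.iter().any(|u| u.as_str() == *d) {
            user.push((*d).to_string());
        }
    }
}

fn main() {
    // User config only listed their MCP tool; defaults are merged back in.
    let mut list = vec!["my_mcp_tool".to_string()];
    ensure_default_auto_approve(&mut list, &["weather", "calculator", "file_read"]);
    println!("{:?}", list);
}
```

Per the commit, `always_ask` takes precedence, so a user can still force approval for a default tool.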
Argenis bf817e30d2 fix(provider): prevent async runtime panic during model refresh (#4261)
Wrap `fetch_live_models_for_provider` calls in
`tokio::task::spawn_blocking` so the `reqwest::blocking::Client`
is created and dropped on a dedicated thread pool instead of
inside the async Tokio context. This prevents the
"Cannot drop a runtime in a context where blocking is not allowed"
panic when running `models refresh --provider openai`.

Closes #4253
2026-03-22 14:22:47 -04:00
Alix-007 0051a0c296 fix(matrix): enforce configured room scope on inbound events (#4251)
Co-authored-by: Alix-007 <267018309+Alix-007@users.noreply.github.com>
2026-03-22 14:08:13 -04:00
Canberk Özkan d753de91f1 fix(skills): prevent panic by ensuring UTF-8 char boundary during truncation (#4252)
Fixed issue #4139.

Previously, slicing a string at exactly 64 bytes could land in the middle of a multi-byte UTF-8 character (e.g., Chinese characters), causing a runtime panic.

Changes:
- Replaced direct byte slicing with a safe boundary lookup using .char_indices().
- Ensures truncation always occurs at a valid character boundary at or before the 64-byte limit.
- Maintained existing hyphen-trimming logic.

Co-authored-by: loriscience <loriscience@gmail.com>
2026-03-22 14:08:01 -04:00
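The boundary-safe truncation described above can be sketched with `char_indices`, as the commit states — find the largest character boundary at or below the byte limit instead of slicing blindly (illustrative version; the in-tree code also trims hyphens):

```rust
// Truncate at or before max_bytes without ever splitting a multi-byte
// UTF-8 character. A direct &s[..64] would panic mid-character.
fn truncate_at_boundary(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    // Largest char-start index that is still within the limit.
    let end = s
        .char_indices()
        .map(|(i, _)| i)
        .take_while(|&i| i <= max_bytes)
        .last()
        .unwrap_or(0);
    &s[..end]
}

fn main() {
    let s = "汉".repeat(30); // 90 bytes of 3-byte characters
    let t = truncate_at_boundary(&s, 64);
    println!("{} bytes", t.len()); // 63: largest boundary at or below 64
}
```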
Argenis f6b2f61a01 fix(matrix): disable automatic key backup when no backup key is configured (#4259)
Explicitly call `client.encryption().backups().disable()` when backups
are not enabled, preventing the matrix_sdk_crypto crate from attempting
room key backups on every sync cycle and spamming the logs with
"Trying to backup room keys but no backup key was found" warnings.

Closes #4227
2026-03-22 13:55:45 -04:00
Argenis 70e7910cb9 fix(web): remove unused import blocking release pipeline (#4234)
2026-03-22 01:35:26 -04:00
argenis de la rosa a8868768e8 fix(web): remove unused ChevronsUpDown import blocking release pipeline
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-22 01:20:51 -04:00
Argenis 67293c50df chore: bump version to 0.5.7 (#4232)
2026-03-22 01:14:08 -04:00
argenis de la rosa 1646079d25 chore: bump version to 0.5.7
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-22 00:49:41 -04:00
Argenis 25b639435f fix: merge voice-wake feature (PR #4162) with conflict resolution (#4225)
* feat(channels): add voice wake word detection channel

Add VoiceWakeChannel behind the `voice-wake` feature flag that:
- Captures audio from the default microphone via cpal
- Uses energy-based VAD to detect speech activity
- Transcribes speech via the existing transcription API (Whisper)
- Checks for a configurable wake word in the transcription
- On detection, captures the following utterance and dispatches it
  as a ChannelMessage

State machine: Listening -> Triggered -> Capturing -> Processing -> Listening

Config keys (under [channels_config.voice_wake]):
- wake_word (default: "hey zeroclaw")
- silence_timeout_ms (default: 2000)
- energy_threshold (default: 0.01)
- max_capture_secs (default: 30)

Includes tests for config parsing, state machine, RMS energy
computation, and WAV encoding.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(config): fix pre-existing test compilation errors in schema.rs

- Remove #[cfg(unix)] gate on `use tempfile::TempDir` import since
  TempDir is used unconditionally in bootstrap file tests
- Add explicit type annotations on tokio::fs::* calls to resolve
  type inference failures (create_dir_all, write, read_to_string)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(channels): exclude voice-wake from all-features CI check

Add a `ci-all` meta-feature in Cargo.toml that includes every feature
except `voice-wake`, which requires `libasound2-dev` (ALSA) not present
on CI runners. Update the check-all-features CI job to use
`--features ci-all` instead of `--all-features`.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Giulio V <vannini.gv@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 00:49:12 -04:00
Argenis 77779844e5 feat(memory): layered architecture upgrade + remove mem0 backend (#4226)
2026-03-22 00:47:42 -04:00
Argenis f658d5806a fix: honor [autonomy] config section in daemon/channel mode
Fixes #4171
2026-03-22 00:47:32 -04:00
Argenis 7134fe0824 Merge pull request #4223 from zeroclaw-labs/fix/4214-heartbeat-utf8-safety
fix(heartbeat): prevent UTF-8 panic, add memory bounds and path validation
2026-03-22 00:41:47 -04:00
Argenis 263802b3df Merge pull request #4224 from zeroclaw-labs/fix/4215-thai-i18n-cleanup
fix(i18n): remove extra keys and translate notion in th.toml
2026-03-22 00:41:21 -04:00
Argenis 3c25fddb2a fix: merge Gmail Pub/Sub push PR #4164 (already integrated via #4200) (#4222)
* feat(channels): add Gmail Pub/Sub push notifications for real-time email

Add GmailPushChannel that replaces IMAP polling with Google's Pub/Sub
push notification system for real-time email-driven automation.

- New channel at src/channels/gmail_push.rs implementing the Channel trait
- Registers Gmail watch subscription (POST /gmail/v1/users/me/watch)
  with automatic renewal before the 7-day expiry
- Handles incoming Pub/Sub notifications at POST /webhook/gmail
- Fetches new messages via Gmail History API (startHistoryId-based)
- Dispatches email messages to the agent with full metadata
- Sends replies via Gmail messages.send API
- Config: gmail_push.enabled, topic, label_filter, oauth_token,
  allowed_senders, webhook_url
- OAuth token encrypted at rest via existing secret store
- Webhook endpoint added to gateway router
- 30+ unit tests covering notification parsing, header extraction,
  body decoding, sender allowlist, and config serialization

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(config): fix pre-existing test compilation errors in schema.rs

- Remove #[cfg(unix)] gate on `use tempfile::TempDir` import since
  TempDir is used unconditionally in bootstrap file tests
- Add explicit type annotations on tokio::fs::* calls to resolve
  type inference failures (create_dir_all, write, read_to_string)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(channels): fix extract_body_text_plain test

Gmail API sends base64url without padding. The decode_body function
converted URL-safe chars back to standard base64 but did not restore
the padding, causing STANDARD decoder to fail and falling back to
snippet. Add padding restoration before decoding.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Giulio V <vannini.gv@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 00:40:42 -04:00
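The base64url padding fix from the last bullet can be sketched like this — map the URL-safe alphabet back to standard base64 and restore `=` padding to a multiple of 4 before decoding (a minimal sketch; the real `decode_body` then hands the result to the project's base64 dependency):

```rust
// Gmail sends unpadded base64url. STANDARD decoders expect padded input,
// so restore padding after swapping the URL-safe characters back.
fn to_padded_standard_b64(data: &str) -> String {
    let mut s: String = data
        .chars()
        .map(|c| match c {
            '-' => '+',
            '_' => '/',
            other => other,
        })
        .collect();
    // A valid standard base64 string has a length that is a multiple of 4.
    while s.len() % 4 != 0 {
        s.push('=');
    }
    s
}

fn main() {
    // "SGVsbG8" is unpadded base64url for "Hello"; padded form is "SGVsbG8=".
    println!("{}", to_padded_standard_b64("SGVsbG8"));
}
```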
Argenis a6a46bdd25 fix: add weather tool to default auto_approve list
Fixes #4170
2026-03-22 00:21:33 -04:00
Argenis 235d4d2f1c fix: replace ILIKE substring matching with full-text search in postgres memory recall()
Fixes #4204
2026-03-22 00:20:11 -04:00
argenis de la rosa bd1e8c8e1a merge: resolve conflicts with master + remove memory-mem0 from ci-all
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-22 00:18:09 -04:00
Argenis f81807bff6 fix: serialize env-dependent codex tests to prevent race (#4210) (#4218)
Add a process-scoped Mutex that all env-var-mutating tests in
openai_codex::tests must hold. This prevents std::env::set_var /
remove_var calls from racing when Rust's test harness runs them on
parallel threads.

Affected tests:
- resolve_responses_url_prefers_explicit_endpoint_env
- resolve_responses_url_uses_provider_api_url_override
- resolve_reasoning_effort_prefers_configured_override
- resolve_reasoning_effort_uses_legacy_env_when_unconfigured
2026-03-22 00:14:01 -04:00
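The process-scoped mutex pattern above can be sketched as follows — a lazily initialized `static` lock that every env-mutating test acquires before touching the environment (illustrative names; the in-tree tests hold it around their `set_var`/`remove_var` calls):

```rust
// One lock for the whole process: while any test holds it, no other test
// can interleave its env mutations, even under the parallel test harness.
use std::sync::{Mutex, OnceLock};

fn env_lock() -> &'static Mutex<()> {
    static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    LOCK.get_or_init(|| Mutex::new(()))
}

fn main() {
    let _guard = env_lock().lock().unwrap();
    std::env::set_var("ZC_TEST_ENDPOINT", "https://example.invalid");
    assert_eq!(std::env::var("ZC_TEST_ENDPOINT").unwrap(), "https://example.invalid");
    std::env::remove_var("ZC_TEST_ENDPOINT");
    // _guard drops here, releasing the lock for the next test.
}
```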
argenis de la rosa bb7006313c feat(memory): layered architecture upgrade + remove mem0 backend
Implement 6-phase memory system improvement:
- Multi-stage retrieval pipeline (cache → FTS → vector)
- Namespace isolation with strict filtering
- Importance scoring (category + keyword heuristics)
- Conflict resolution via Jaccard similarity + superseded_by
- Audit trail decorator (AuditedMemory<M>)
- Policy engine (quotas, read-only namespaces, retention rules)
- Deterministic sort tiebreaker on equal scores

Remove mem0 (OpenMemory) backend — all capabilities now covered
natively with better performance (local SQLite vs external REST API).

46 battle tests + 262 existing tests pass. Backward-compatible:
existing databases auto-migrate, existing configs work unchanged.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-22 00:09:43 -04:00
Argenis 9a49626376 fix: use POSIX-compatible sh -c instead of dash-specific -lc (#4209) (#4217)
* fix: build web dashboard during install.sh (#4207)

* fix: use POSIX-compatible sh -c instead of dash-specific -lc in cron scheduler (#4209)
2026-03-22 00:07:37 -04:00
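The portable invocation the cron fix switches to is a plain `-c` with the command string, which every common `/bin/sh` implementation accepts (the command below is only an illustration):

```shell
# Portable: plain -c works with dash, ash, busybox sh, and bash alike.
# The -l (login shell) flag dropped by this fix is not reliable across
# /bin/sh implementations.
sh -c 'echo "cron job ran"'
```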
Argenis 8b978a721f fix: build web dashboard during install.sh (#4207) (#4216) 2026-03-22 00:02:54 -04:00
argenis de la rosa 75b4c1d4a4 fix(heartbeat): prevent UTF-8 panic, add memory bounds and path validation in session context
- Use char_indices for safe UTF-8 truncation instead of byte slicing
- Replace unbounded Vec with VecDeque rolling window in load_jsonl_messages
- Add path separator validation for channel/to to prevent directory traversal
2026-03-22 00:01:44 -04:00
argenis de la rosa b2087e6065 fix(i18n): remove extra keys and translate untranslated notion entry in th.toml 2026-03-21 23:59:46 -04:00
Nisit Sirimarnkit ad8f81ad76 Merge branch 'master' into i18n/thai-tool-descriptions 2026-03-22 10:28:47 +07:00
ninenox c58e1c1fb3 i18n: add Thai tool descriptions 2026-03-22 10:09:03 +07:00
Martin Minkus cb0779d761 feat(heartbeat): add load_session_context to inject conversation history
When `load_session_context = true` in `[heartbeat]`, the daemon loads the
last 20 messages from the target user's JSONL session file and prepends them
to the heartbeat task prompt before calling the LLM.

This gives the companion context — who the user is, what was last discussed —
so outreach messages feel like a natural continuation rather than a blank-slate
ping. Defaults to `false` (opt-in, no change to existing behaviour).

Key behaviours:
- Session context is re-read on every heartbeat tick (not cached at startup)
- Skips context injection if only assistant messages are present (prevents
  heartbeat outputs feeding back in a loop)
- Scans sessions directory for matching JSONL files using flexible filename
  matching: {channel}_{to}.jsonl, {channel}_*_{to}.jsonl, or
  {channel}_{to}_*.jsonl — handles varying session key formats
- Injects file mtime as "last message ~Xh ago" so the LLM knows how long
  the user has been silent

Config example:
  [heartbeat]
  enabled = true
  interval_minutes = 120
  load_session_context = true
  target = "telegram"
  to = "your_username"

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-22 02:44:19 +00:00
Chris Hengge daca2d9354 fix(web/tools): make section headings collapsible (#4180)
Agent Tools and CLI Tools section headings were static divs with no
way to collapse sections the user is not interested in, making the
page unwieldy with a large tool set.

- Convert both section heading divs to button elements toggling
  agentSectionOpen / cliSectionOpen state (both default open)
- Section content renders conditionally on those booleans
- ChevronsUpDown icon added (already in lucide-react bundle) that
  fades in on hover and indicates collapsed/expanded state
- No change to individual tool card parameter schema expand/collapse

Risk: Low — UI state only, no API or logic change.
Does not change: search/filter behaviour, tool card expand/collapse,
CLI tools table structure.
2026-03-21 22:25:18 -04:00
Chris Hengge 3c1e710c38 fix(web/logs): layout, footer status indicator, and empty-state note (#4203)
Three issues addressed:

Empty page: the logs page shows nothing at idle because the SSE stream
only carries ObserverEvent variants (llm_request, tool_call, error,
agent_start, agent_end). Daemon stdout and RUST_LOG tracing output go
to the terminal/log file and are never forwarded to the broadcast
channel — this is correct behaviour, not a misconfiguration. Added a
dismissible informational banner explaining what appears on the stream
and how to access tracing output (RUST_LOG + terminal).

Layout: flex-1 log entries div was missing min-h-0, which can cause
flex children to refuse to shrink below content size in some browsers.

Connection indicator: moved from the toolbar (where it cluttered the
title and controls) to a compact footer strip below the scroll area,
matching the /agent page pattern exactly.

Also added colour rules for llm_request, agent_start, agent_end,
tool_call_start event types which previously fell through to default
grey.

Risk: Low — UI layout and informational copy only, no backend change.
Does not change: SSE connection logic, event parsing, pause/resume,
type filters, or the underlying broadcast observer.
2026-03-21 22:11:20 -04:00
Chris Hengge 0aefde95f2 fix(web/config): fill viewport and add TOML syntax highlighting (#4201)
Two issues with the config editor:

Layout: the page root had no height constraint and the textarea used
min-h-[500px] resize-y, causing independent scrollbars on both the
page and the editor. Fixed by adopting the Memory/Cron flex column
pattern so the editor fills the remaining viewport height with a single
scroll surface.

Highlighting: plain textarea with no visual structure for TOML.
Added a zero-dependency layered pre-overlay technique — no new npm
packages (per CLAUDE.md anti-pattern rules). A pre element sits
absolute behind a transparent textarea; highlightToml() produces HTML
colour-coding sections, keys, strings, booleans, numbers, datetimes,
and comments via per-line regex. onScroll syncs the overlay. Tab key
inserts two spaces instead of leaving focus.

dangerouslySetInnerHTML used on the pre — content is the user's own
local config, not from the network, risk equivalent to any local editor.

Risk: Low-Medium — no API or backend change. New rendering logic
in editor only.
Does not change: save/load API calls, validation, sensitive field
masking behaviour.
2026-03-21 22:11:18 -04:00
Chris Hengge a84aa60554 fix(web/cron): contain table scroll within viewport (#4186)
The cron page used a block-flow root with no height constraint, causing
the jobs table to grow taller than the viewport and the page itself to
scroll. This was inconsistent with the Memory page pattern.

- Change page root to flex flex-col h-full matching Memory's layout
- Table wrapper gains flex-1 min-h-0 overflow-auto so it fills
  remaining height and scrolls both axes internally
- Table header already has position:sticky so it pins correctly
  inside the scrolling container with no CSS change needed

Risk: Low — layout only, no logic or API change.
Does not change: job CRUD, modal, catch-up toggle, run history panel.
2026-03-21 22:11:15 -04:00
Chris Hengge edd4b37325 fix(web/dashboard): rename channels card heading and add internal scroll (#4178)
The card heading used the key dashboard.active_channels ("Active Channels")
even though the card has a toggle between Active and All views, making the
static heading misleading. The channel list div had no height cap, causing
tall channel lists to stretch the card and break 3-column grid alignment.

- Change heading to t("dashboard.channels") — key already present in all
  three locales (zh/en/tr), no i18n changes needed
- Add overflow-y-auto max-h-48 pr-1 to the channel list wrapper so it
  scrolls internally instead of stretching the card
2026-03-21 22:09:00 -04:00
Argenis c5f0155061 Merge pull request #4193 from zeroclaw-labs/fix/reaction-tool
fix(tools): pass platform channel_id to reaction trait
2026-03-21 21:38:32 -04:00
argenis de la rosa 9ee06ed6fc merge: resolve conflicts with master (image_gen + sessions) 2026-03-21 21:18:46 -04:00
Argenis ac6b43e9f4 fix: remove unused channel_names field from DiscordHistoryChannel (#4199)
* feat: add discord history logging and search tool with persistent channel cache

* fix: remove unused channel_names field from DiscordHistoryChannel

The channel_names HashMap was declared and initialized but never used.
Channel name caching is handled via discord_memory.get()/store() with
the cache:channel_name: prefix. Remove the dead field.

* style: run cargo fmt on discord_history.rs

---------

Co-authored-by: ninenox <nisit15@hotmail.com>
2026-03-21 21:15:23 -04:00
Argenis 6c5573ad96 Merge pull request #4194 from zeroclaw-labs/fix/session-messaging-tools
fix(security): add enforcement and validation to session tools
2026-03-21 21:15:17 -04:00
Argenis 1d57a0d1e5 fix(web/tools): improve a11y in collapsible section headings (#4197)
* fix(web/tools): make section headings collapsible

Agent Tools and CLI Tools section headings were static divs with no
way to collapse sections the user is not interested in, making the
page unwieldy with a large tool set.

- Convert both section heading divs to button elements toggling
  agentSectionOpen / cliSectionOpen state (both default open)
- Section content renders conditionally on those booleans
- ChevronsUpDown icon added (already in lucide-react bundle) that
  fades in on hover and indicates collapsed/expanded state
- No change to individual tool card parameter schema expand/collapse

Risk: Low — UI state only, no API or logic change.
Does not change: search/filter behaviour, tool card expand/collapse,
CLI tools table structure.

* fix(web/tools): improve a11y and fix invalid HTML in collapsible sections

- Replace <h2> inside <button> with <span role="heading" aria-level={2}>
  to fix invalid HTML (heading elements not permitted in interactive content)
- Add aria-expanded attribute to section toggle buttons for screen readers
- Add aria-controls + id linking buttons to their controlled sections
- Replace ChevronsUpDown with ChevronDown icon — ChevronsUpDown is
  visually symmetric so rotating 180deg has no visible effect; ChevronDown
  rotating to -90deg gives a clear directional cue
- Remove unused ChevronsUpDown import

---------

Co-authored-by: WareWolf-MoonWall <chris.hengge@gmail.com>
2026-03-21 21:02:10 -04:00
Argenis 9780c7d797 fix: restrict free command to Linux-only in security policy (#4198)
* fix: resolve claude-code test flakiness and update security policy

* fix: restrict `free` command to Linux-only in security policy

`free` is not available on macOS or other BSDs. Move it behind
a #[cfg(target_os = "linux")] gate so it is only included in the
default allowed commands on Linux systems.

---------

Co-authored-by: ninenox <nisit15@hotmail.com>
2026-03-21 21:02:05 -04:00
Argenis 35a5451a17 fix(channels): address critical security bugs in Gmail Pub/Sub push (#4200)
* feat(channels): add Gmail Pub/Sub push notifications for real-time email

Add GmailPushChannel that replaces IMAP polling with Google's Pub/Sub
push notification system for real-time email-driven automation.

- New channel at src/channels/gmail_push.rs implementing the Channel trait
- Registers Gmail watch subscription (POST /gmail/v1/users/me/watch)
  with automatic renewal before the 7-day expiry
- Handles incoming Pub/Sub notifications at POST /webhook/gmail
- Fetches new messages via Gmail History API (startHistoryId-based)
- Dispatches email messages to the agent with full metadata
- Sends replies via Gmail messages.send API
- Config: gmail_push.enabled, topic, label_filter, oauth_token,
  allowed_senders, webhook_url
- OAuth token encrypted at rest via existing secret store
- Webhook endpoint added to gateway router
- 30+ unit tests covering notification parsing, header extraction,
  body decoding, sender allowlist, and config serialization

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(config): fix pre-existing test compilation errors in schema.rs

- Remove #[cfg(unix)] gate on `use tempfile::TempDir` import since
  TempDir is used unconditionally in bootstrap file tests
- Add explicit type annotations on tokio::fs::* calls to resolve
  type inference failures (create_dir_all, write, read_to_string)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(channels): fix extract_body_text_plain test

Gmail API sends base64url without padding. The decode_body function
converted URL-safe chars back to standard base64 but did not restore
the padding, causing STANDARD decoder to fail and falling back to
snippet. Add padding restoration before decoding.
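A sketch of the fix under the assumptions in the message (function name hypothetical): map the URL-safe alphabet back, then restore the padding the Gmail API strips.

```rust
/// Convert unpadded base64url to standard base64: '-' -> '+',
/// '_' -> '/', then append '=' until the length is a multiple of 4,
/// which a strict STANDARD-alphabet decoder requires.
fn to_standard_base64(body: &str) -> String {
    let mut s: String = body
        .chars()
        .map(|c| match c {
            '-' => '+',
            '_' => '/',
            other => other,
        })
        .collect();
    while s.len() % 4 != 0 {
        s.push('=');
    }
    s
}
```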

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(channels): address critical security bugs in Gmail Pub/Sub push

- Add webhook authentication via shared secret (webhook_secret config
  field or GMAIL_PUSH_WEBHOOK_SECRET env var), preventing unauthorized
  message injection through the unauthenticated webhook endpoint
- Add 1MB body size limit on webhook endpoint to prevent memory exhaustion
- Fix race condition in handle_notification: hold history_id lock across
  the read-fetch-update cycle to prevent duplicate message processing
  when concurrent webhook notifications arrive
- Sanitize RFC 2822 headers (To/Subject) to prevent CRLF injection
  attacks that could add arbitrary headers to outgoing emails
- Fix extract_email_from_header panic on malformed angle brackets by
  using rfind('>') and validating bracket ordering
- Add 30s default HTTP client timeout for all Gmail API calls,
  preventing indefinite hangs
- Clone tx sender before message processing loop to avoid holding
  the mutex lock across network calls
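Two of the fixes above can be sketched as small helpers (names and exact policy assumed, not the project's actual code):

```rust
/// Strip CR/LF so a crafted To/Subject value cannot smuggle extra
/// RFC 2822 headers into the outgoing message.
fn sanitize_header(value: &str) -> String {
    value.chars().filter(|c| *c != '\r' && *c != '\n').collect()
}

/// Extract the address from "Name <addr>" without panicking on
/// malformed brackets: use rfind('>') and require '<' before '>'.
fn extract_email_from_header(header: &str) -> &str {
    match (header.find('<'), header.rfind('>')) {
        (Some(open), Some(close)) if open < close => &header[open + 1..close],
        _ => header.trim(),
    }
}
```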

---------

Co-authored-by: Giulio V <vannini.gv@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 20:59:56 -04:00
Argenis 8e81d44d54 fix(gateway): address critical security and reliability bugs in Live Canvas (#4196)
* feat(gateway): add Live Canvas (A2UI) tool and real-time web viewer

Add a Live Canvas system that enables the agent to push rendered content
(HTML, SVG, Markdown, text) to a web-visible canvas in real time.

Backend:
- src/tools/canvas.rs: CanvasTool with render/snapshot/clear/eval actions,
  backed by a shared CanvasStore (Arc<RwLock<HashMap>>) with per-canvas
  broadcast channels for real-time updates
- src/gateway/canvas.rs: REST endpoints (GET/POST/DELETE /api/canvas/:id,
  GET /api/canvas/:id/history, GET /api/canvas) and WebSocket endpoint
  (WS /ws/canvas/:id) for real-time frame delivery

Frontend:
- web/src/pages/Canvas.tsx: Canvas viewer page with WebSocket connection,
  iframe sandbox rendering, canvas switcher, frame history panel

Registration:
- CanvasTool registered in all_tools_with_runtime (always available)
- Canvas routes wired into gateway router
- CanvasStore added to AppState
- Canvas page added to App.tsx router and Sidebar navigation
- i18n keys added for en/zh/tr locales

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(config): fix pre-existing test compilation errors in schema.rs

- Remove #[cfg(unix)] gate on `use tempfile::TempDir` import since
  TempDir is used unconditionally in bootstrap file tests
- Add explicit type annotations on tokio::fs::* calls to resolve
  type inference failures (create_dir_all, write, read_to_string)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(gateway): share CanvasStore between tool and REST API

The CanvasTool and gateway AppState each created their own CanvasStore,
so content rendered via the tool never appeared in the REST API.

Create the CanvasStore once in the gateway, pass it to
all_tools_with_runtime via a new optional parameter, and reuse the
same instance in AppState.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(gateway): address critical security and reliability bugs in Live Canvas

- Validate content_type in REST POST endpoint against allowed set,
  preventing injection of "eval" frames via the REST API
- Enforce MAX_CONTENT_SIZE (256KB) limit on REST POST endpoint,
  matching tool-side validation to prevent memory exhaustion
- Add MAX_CANVAS_COUNT (100) limit to prevent unbounded canvas creation
  and memory exhaustion from CanvasStore
- Handle broadcast RecvError::Lagged in WebSocket handler gracefully
  instead of disconnecting the client
- Make MAX_CONTENT_SIZE and ALLOWED_CONTENT_TYPES pub for gateway reuse
- Update CanvasStore::render and subscribe to return Option for
  canvas count enforcement
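The REST-side validation described above can be sketched as follows; the exact allowed-type list is assumed, the key point being that "eval" is absent so it cannot be injected via POST:

```rust
/// Frame types the REST endpoint accepts (list illustrative).
const ALLOWED_CONTENT_TYPES: &[&str] = &["html", "svg", "markdown", "text"];
/// Matches the tool-side 256KB content limit.
const MAX_CONTENT_SIZE: usize = 256 * 1024;

fn validate_frame(content_type: &str, content: &str) -> Result<(), String> {
    if !ALLOWED_CONTENT_TYPES.contains(&content_type) {
        return Err(format!("content_type {content_type:?} is not allowed"));
    }
    if content.len() > MAX_CONTENT_SIZE {
        return Err("content exceeds 256KB limit".into());
    }
    Ok(())
}
```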

---------

Co-authored-by: Giulio V <vannini.gv@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: rareba <rareba@users.noreply.github.com>
2026-03-21 20:59:18 -04:00
Argenis 86ad0c6a2b fix(channels): address critical bugs in voice wake word detection (#4191)
* feat(channels): add voice wake word detection channel

Add VoiceWakeChannel behind the `voice-wake` feature flag that:
- Captures audio from the default microphone via cpal
- Uses energy-based VAD to detect speech activity
- Transcribes speech via the existing transcription API (Whisper)
- Checks for a configurable wake word in the transcription
- On detection, captures the following utterance and dispatches it
  as a ChannelMessage

State machine: Listening -> Triggered -> Capturing -> Processing -> Listening

Config keys (under [channels_config.voice_wake]):
- wake_word (default: "hey zeroclaw")
- silence_timeout_ms (default: 2000)
- energy_threshold (default: 0.01)
- max_capture_secs (default: 30)

Includes tests for config parsing, state machine, RMS energy
computation, and WAV encoding.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(config): fix pre-existing test compilation errors in schema.rs

- Remove #[cfg(unix)] gate on `use tempfile::TempDir` import since
  TempDir is used unconditionally in bootstrap file tests
- Add explicit type annotations on tokio::fs::* calls to resolve
  type inference failures (create_dir_all, write, read_to_string)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(channels): exclude voice-wake from all-features CI check

Add a `ci-all` meta-feature in Cargo.toml that includes every feature
except `voice-wake`, which requires `libasound2-dev` (ALSA) not present
on CI runners. Update the check-all-features CI job to use
`--features ci-all` instead of `--all-features`.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(channels): address critical bugs in voice wake word detection

- Replace std::mem::forget(stream) with dedicated thread that holds the
  cpal stream and shuts down cleanly via oneshot channel, preventing
  microphone resource leaks on task cancellation
- Add config validation: energy_threshold must be positive+finite,
  silence_timeout_ms >= 100ms, max_capture_secs clamped to 300
- Guard WAV encoding against u32 overflow for large audio buffers
- Add hard cap on capture_buf size to prevent unbounded memory growth
- Increase audio channel buffer from 4 to 64 slots to reduce chunk
  drops during transcription API calls
- Remove dead WakeState::Processing variant that was never entered
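The config validation rules above can be sketched as one function (signature and error strings assumed):

```rust
/// Validate wake-word config: positive, finite energy threshold;
/// silence timeout of at least 100ms; capture length clamped to 300s.
fn validate_wake_config(
    energy_threshold: f32,
    silence_timeout_ms: u64,
    max_capture_secs: u64,
) -> Result<(f32, u64, u64), String> {
    if !energy_threshold.is_finite() || energy_threshold <= 0.0 {
        return Err("energy_threshold must be positive and finite".into());
    }
    if silence_timeout_ms < 100 {
        return Err("silence_timeout_ms must be at least 100".into());
    }
    Ok((energy_threshold, silence_timeout_ms, max_capture_secs.min(300)))
}
```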

---------

Co-authored-by: Giulio V <vannini.gv@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 20:43:19 -04:00
Argenis 6ecf89d6a9 fix(ci): skip release and publish workflows on forks (#4190)
When a fork syncs with upstream, GitHub attributes the push to the fork
owner, causing release-beta-on-push and publish-crates-auto to run
under the wrong identity — leading to confusing notifications and
guaranteed failures (missing secrets).

Add repository guards to root jobs so the entire pipeline is skipped
on forks.
2026-03-21 20:42:55 -04:00
argenis de la rosa 691efa4d8c style: fix cargo fmt formatting in reaction tool 2026-03-21 20:38:24 -04:00
argenis de la rosa d1e3f435b4 style: fix cargo fmt formatting in session tools 2026-03-21 20:38:08 -04:00
Argenis 44c3e264ad Merge pull request #4192 from zeroclaw-labs/fix/image-gen-tool
fix(tools): harden image_gen security and model validation
2026-03-21 20:37:27 -04:00
argenis de la rosa f2b6013329 fix(tools): harden image_gen security enforcement and model validation
- Replace manual can_act()/record_action() with enforce_tool_operation()
  to match the codebase convention used by all other tools (notion,
  memory_forget, claude_code, delegate, etc.), producing consistent
  error messages and avoiding logic duplication.

- Add model parameter validation to prevent URL path traversal attacks
  via crafted model identifiers (e.g. "../../evil-endpoint").

- Add tests for model traversal rejection and filename sanitization.
2026-03-21 20:08:51 -04:00
argenis de la rosa 05d3c51a30 fix(security): add security policy enforcement and input validation to session tools
SessionsSendTool was missing security gate enforcement entirely - any agent
could send messages to any session without security policy checks. Similarly,
SessionsHistoryTool had no security enforcement for reading session data.

Changes:
- Add SecurityPolicy field to SessionsHistoryTool (enforces ToolOperation::Read)
- Add SecurityPolicy field to SessionsSendTool (enforces ToolOperation::Act)
- Add session_id validation to reject empty or non-alphanumeric-only IDs
- Pass security policy from all_tools_with_runtime registration
- Add tests for empty session_id, non-alphanumeric session_id validation
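The session_id rule described above, as a sketch (exact policy assumed):

```rust
/// Reject empty IDs and IDs containing no alphanumeric characters
/// at all, before any session store lookup.
fn is_valid_session_id(id: &str) -> bool {
    !id.trim().is_empty() && id.chars().any(|c| c.is_alphanumeric())
}
```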
2026-03-21 20:04:44 -04:00
argenis de la rosa 2ceda31ce2 fix(tools): pass platform channel_id to reaction trait instead of channel name
The reaction tool was passing the channel adapter name (e.g. "discord",
"slack") as the first argument to Channel::add_reaction() and
Channel::remove_reaction(), but the trait signature expects a
platform-specific channel_id (e.g. Discord channel snowflake, Slack
channel ID like "C0123ABCD"). This would cause all reaction API calls
to fail at the platform level.

Fixes:
- Add required "channel_id" parameter to the tool schema
- Extract and pass channel_id (not channel_name) to trait methods
- Update tool description to mention the new parameter
- Add MockChannel channel_id capture for test verification
- Add test asserting channel_id (not name) reaches the trait
- Update all existing tests to supply channel_id
2026-03-21 20:01:22 -04:00
Argenis 9069bc3c1f fix(agent): add system prompt budgeting for small-context models (#4185)
For models with small context windows (e.g. glm-4.5-air ~8K tokens),
the system prompt alone can exceed the limit. This adds:

- max_system_prompt_chars config option (default 0 = unlimited)
- compact_context now also compacts the system prompt: skips the
  Channel Capabilities section and shows only tool names
- Truncation with marker when prompt exceeds the budget

Users can set `max_system_prompt_chars = 8000` in [agent] config
to cap the system prompt for small-context models.

Closes #4124
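The truncation behavior can be sketched like this; the marker text and a character-based (rather than token-based) budget are assumptions:

```rust
/// Cap the system prompt at `max_chars` (0 = unlimited), appending a
/// marker when content was dropped.
fn cap_system_prompt(prompt: &str, max_chars: usize) -> String {
    if max_chars == 0 || prompt.chars().count() <= max_chars {
        return prompt.to_string();
    }
    let truncated: String = prompt.chars().take(max_chars).collect();
    format!("{truncated}\n[...system prompt truncated...]")
}
```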
2026-03-21 19:40:21 -04:00
Argenis 9319fe18da fix(approval): support wildcard * in auto_approve and always_ask (#4184)
auto_approve = ["*"] was doing exact string matching, so only the
literal tool name "*" was matched. Users expecting wildcard semantics
had every tool blocked in supervised mode.
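A sketch of wildcard-aware matching (function name assumed): "*" approves any tool, otherwise fall back to the old exact comparison.

```rust
/// Approval check with wildcard support. The previous behavior was
/// the exact-match arm only, so ["*"] matched no real tool name.
fn is_auto_approved(auto_approve: &[&str], tool_name: &str) -> bool {
    auto_approve.iter().any(|p| *p == "*" || *p == tool_name)
}
```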

Also adds "prompt exceeds max length" to the context-window error
detection hints (fixes GLM/ZAI error 1261 detection).

Closes #4127
2026-03-21 19:38:11 -04:00
Argenis cc454a86c8 fix(install): remove pairing code display from installer (#4176)
The gateway pairing code is now shown in the dashboard, so displaying
it in the installer output is redundant and cluttered (showed 3 codes).
2026-03-21 19:06:37 -04:00
Argenis 256e8ccebf chore: bump version to v0.5.6 (#4174)
Update version across all distribution manifests:
- Cargo.toml / Cargo.lock
- dist/aur/PKGBUILD + .SRCINFO
- dist/scoop/zeroclaw.json
2026-03-21 18:03:38 -04:00
Argenis 72c9e6b6ca fix(publish): publish aardvark-sys dep before main crate (#4172)
* fix(publish): add aardvark-sys version and publish it before main crate

- Add version = "0.1.0" to aardvark-sys path dependency in Cargo.toml
- Update all three publish workflows to publish aardvark-sys first
- Add aardvark-sys COPY to Dockerfile for workspace builds
- Fixes cargo publish failure: "dependency aardvark-sys does not
  specify a version"

* ci: publish aardvark-sys before main crate in all publish workflows

All three crates.io publish workflows now publish aardvark-sys first,
wait for indexing, then publish the main zeroclawlabs crate.
2026-03-21 16:20:50 -04:00
Argenis 755a129ca2 fix(install): use /dev/tty for sudo in curl|bash Xcode license accept (#4169)
When run via `curl | bash`, stdin is the curl pipe, so sudo cannot
prompt for a password. Redirect sudo's stdin from /dev/tty to reach
the real terminal, allowing the password prompt to work in piped
invocations.
2026-03-21 14:15:21 -04:00
Argenis 8b0d3684c5 fix(install): auto-accept Xcode license instead of bailing out (#4165)
Instead of exiting with a manual remediation step, the installer now
attempts to accept the Xcode/CLT license automatically via
`sudo xcodebuild -license accept`. Falls back to a clear error message
only if sudo fails (e.g. no terminal or password).
2026-03-21 13:57:38 -04:00
Giulio V cdb5ac1471 fix(tools): fix remove_reaction_success test
The output format used "{action}ed" which produced "removeed" for the
remove action. Use explicit past-tense mapping instead.
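The fix amounts to an explicit mapping (fallback arm assumed):

```rust
/// Explicit past-tense mapping for the two supported actions; the old
/// format!("{action}ed") produced "removeed" for "remove".
fn past_tense(action: &str) -> &'static str {
    match action {
        "add" => "added",
        "remove" => "removed",
        _ => "processed",
    }
}
```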

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 18:49:35 +01:00
Giulio V 67acb1a0bb fix(config): fix pre-existing test compilation errors in schema.rs
- Remove #[cfg(unix)] gate on `use tempfile::TempDir` import since
  TempDir is used unconditionally in bootstrap file tests
- Add explicit type annotations on tokio::fs::* calls to resolve
  type inference failures (create_dir_all, write, read_to_string)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 18:10:05 +01:00
Giulio V 9eac6bafef fix(config): fix pre-existing test compilation errors in schema.rs
- Remove #[cfg(unix)] gate on `use tempfile::TempDir` import since
  TempDir is used unconditionally in bootstrap file tests
- Add explicit type annotations on tokio::fs::* calls to resolve
  type inference failures (create_dir_all, write, read_to_string)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 18:09:48 +01:00
Giulio V a12f2ff439 fix(config): fix pre-existing test compilation errors in schema.rs
- Remove #[cfg(unix)] gate on `use tempfile::TempDir` import since
  TempDir is used unconditionally in bootstrap file tests
- Add explicit type annotations on tokio::fs::* calls to resolve
  type inference failures (create_dir_all, write, read_to_string)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 18:09:36 +01:00
Argenis a38a4d132e fix(hardware): drain stdin in subprocess test to prevent broken pipe flake (#4161)
* fix(hardware): drain stdin in subprocess test to prevent broken pipe flake

The test script did not consume stdin, so SubprocessTool's stdin write
raced against the process exit, causing intermittent EPIPE failures.
Add `cat > /dev/null` to drain stdin before producing output.

* style: format subprocess test
2026-03-21 12:19:53 -04:00
Argenis 48aba73d3a fix(install): always check Xcode license on macOS, not just with --install-system-deps (#4153)
The Xcode license test-compile was inside install_system_deps(), which
only runs when --install-system-deps is passed. On macOS the default
path skipped this entirely, so users hit `cc` exit code 69 deep in
cargo build. Move the check into the unconditional main flow so it
always fires on Darwin.
2026-03-21 11:29:36 -04:00
Argenis a1ab1e1a11 fix(install): use test-compile instead of xcrun for Xcode license detection (#4151)
xcrun --show-sdk-path can succeed even when the Xcode/CLT license has
not been accepted, so the previous check was ineffective. Replace it
with an actual test-compilation of a trivial C file, which reliably
triggers the exit-code-69 failure when the license is pending.
2026-03-21 11:03:07 -04:00
Giulio V f394abf35c feat(tools): add standalone image generation tool via fal.ai
Add ImageGenTool that exposes fal.ai Flux model image generation as a
standalone tool, decoupled from the LinkedIn client. The tool accepts a
text prompt, optional filename/size/model parameters, calls the fal.ai
synchronous API, downloads the result, and saves to workspace/images/.

- New src/tools/image_gen.rs with full Tool trait implementation
- New ImageGenConfig in schema.rs (enabled, default_model, api_key_env)
- Config-gated registration in all_tools_with_runtime
- Security: checks can_act() and record_action() before execution
- Comprehensive unit tests (prompt validation, API key, size enum,
  autonomy blocking, tool spec)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 15:17:28 +01:00
Giulio V 52e0271bd5 feat(tools): add emoji reaction tool for cross-channel reactions
Add ReactionTool that exposes Channel::add_reaction and
Channel::remove_reaction as an agent-callable tool. Uses a
late-binding ChannelMapHandle (Arc<RwLock<HashMap>>) pattern
so the tool can be constructed during tool registry init and
populated once channels are available in start_channels.

Parameters: channel, message_id, emoji, action (add/remove).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 15:15:25 +01:00
Giulio V 6c0a48efff feat(tools): add session list, history, and send tools for inter-agent messaging
Add three new tools in src/tools/sessions.rs:
- sessions_list: lists active sessions with channel, message count, last activity
- sessions_history: reads last N messages from a session by ID
- sessions_send: appends a message to a session for inter-agent communication

All tools operate on the SessionBackend trait, using the JSONL SessionStore
by default. Registered unconditionally in all_tools_with_runtime when the
sessions directory is accessible.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-21 15:07:18 +01:00
SimianAstronaut7 87b5bca449 feat(config): add configurable pacing controls for slow/local LLM workloads (#3343)
* feat(config): add configurable pacing controls for slow/local LLM workloads (#2963)

Add a new `[pacing]` config section with four opt-in parameters that
let users tune timeout and loop-detection behavior for local LLMs
(Ollama, llama.cpp, vLLM) without disabling safety features entirely:

- `step_timeout_secs`: per-step LLM inference timeout independent of
  the overall message budget, catching hung model responses early.
- `loop_detection_min_elapsed_secs`: time-gated loop detection that
  only activates after a configurable grace period, avoiding false
  positives on long-running browser/research workflows.
- `loop_ignore_tools`: per-tool loop-detection exclusions so tools
  like `browser_screenshot` that structurally resemble loops are not
  counted toward identical-output detection.
- `message_timeout_scale_max`: overrides the hardcoded 4x ceiling in
  the channel message timeout scaling formula.

All parameters are strictly optional with no effect when absent,
preserving full backwards compatibility.

Closes #2963
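The `message_timeout_scale_max` override can be sketched as follows; the formula shape is assumed, with only the hardcoded 4x ceiling taken from the message:

```rust
/// Scale a channel message timeout, clamping the scale factor to a
/// configurable ceiling (default 4x when the override is absent).
fn scaled_message_timeout(base_secs: u64, scale: f64, scale_max: Option<f64>) -> u64 {
    let ceiling = scale_max.unwrap_or(4.0);
    (base_secs as f64 * scale.clamp(1.0, ceiling)).round() as u64
}
```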

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix(config): add missing pacing fields in tests and call sites

* fix(config): add pacing arg to remaining cost-tracking test call sites

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: argenis de la rosa <theonlyhennygod@gmail.com>
2026-03-21 08:54:08 -04:00
Argenis be40c0c5a5 Merge pull request #4145 from zeroclaw-labs/feat/gateway-path-prefix
feat(gateway): add path_prefix for reverse-proxy deployments
2026-03-21 08:48:56 -04:00
argenis de la rosa 6527871928 fix: add path_prefix to test AppState in gateway/api.rs 2026-03-21 08:14:28 -04:00
argenis de la rosa 0bda80de9c feat(gateway): add path_prefix for reverse-proxy deployments
Adopted from #3709 by @slayer with minor cleanup.
Supersedes #3709
2026-03-21 08:14:28 -04:00
120 changed files with 15977 additions and 1909 deletions
@@ -1,97 +0,0 @@
# Mem0 Integration: Dual-Scope Recall + Per-Turn Memory
## Context
Mem0 auto-save works but the integration is missing key features from mem0 best practices: per-turn recall, multi-level scoping, and proper context injection. This causes the bot to "forget" on follow-up turns and not differentiate users.
## What's Missing (vs mem0 docs)
1. **Per-turn recall** — only first turn gets memory context, follow-ups get nothing
2. **Dual-scope** — no sender vs group distinction. All memories use single hardcoded `user_id`
3. **System prompt injection** — memory prepended to user message (pollutes session history)
4. **`agent_id` scoping** — mem0 supports agent-level patterns, not used
## Changes
### 1. `src/memory/mem0.rs` — Use session_id for multi-level scoping
Map zeroclaw's `session_id` param to mem0's `user_id`. This enables per-user and per-group memory namespaces without changing the `Memory` trait.
```rust
// Add helper:
fn effective_user_id(&self, session_id: Option<&str>) -> &str {
session_id.filter(|s| !s.is_empty()).unwrap_or(&self.user_id)
}
// In store(): use effective_user_id(session_id) as mem0 user_id
// In recall(): use effective_user_id(session_id) as mem0 user_id
// In list(): use effective_user_id(session_id) as mem0 user_id
```
### 2. `src/channels/mod.rs` ~line 2229 — Per-turn dual-scope recall
Remove `if !had_prior_history` gate. Always recall from both sender scope and group scope (for group chats).
```rust
// Detect group chat
let is_group = msg.reply_target.contains("@g.us")
|| msg.reply_target.starts_with("group:");
// Sender-scope recall (always)
let sender_context = build_memory_context(
ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score,
Some(&msg.sender),
).await;
// Group-scope recall (groups only)
let group_context = if is_group {
build_memory_context(
ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score,
Some(&history_key),
).await
} else {
String::new()
};
// Merge (deduplicate by checking substring overlap)
let memory_context = merge_memory_contexts(&sender_context, &group_context);
```
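The `merge_memory_contexts` helper called above is not defined in this plan; a minimal body consistent with the "deduplicate by substring overlap" comment might look like:

```rust
/// Hypothetical merge: keep both scopes, dropping whichever context
/// is empty or fully contained in the other.
fn merge_memory_contexts(sender: &str, group: &str) -> String {
    if group.is_empty() || sender.contains(group) {
        sender.to_string()
    } else if sender.is_empty() || group.contains(sender) {
        group.to_string()
    } else {
        format!("{sender}\n\n{group}")
    }
}
```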
### 3. `src/channels/mod.rs` ~line 2244 — Inject into system prompt
Move memory context from user message to system prompt. Re-fetched each turn, doesn't pollute session.
```rust
let mut system_prompt = build_channel_system_prompt(...);
if !memory_context.is_empty() {
system_prompt.push_str(&format!("\n\n{memory_context}"));
}
let mut history = vec![ChatMessage::system(system_prompt)];
```
### 4. `src/channels/mod.rs` — Dual-scope auto-save
Find existing auto-save call. For group messages, store twice:
- `store(key, content, category, Some(&msg.sender))` — personal facts
- `store(key, content, category, Some(&history_key))` — group context
Both async, non-blocking. DMs only store to sender scope.
### 5. `src/memory/mem0.rs` — Add `agent_id` support (optional)
Pass `self.app_name` as `agent_id` param to mem0 API for agent behavior tracking.
## Files to Modify
1. `src/memory/mem0.rs` — session_id → user_id mapping
2. `src/channels/mod.rs` — per-turn recall, dual-scope, system prompt injection, dual-scope save
## Verification
1. `cargo check --features whatsapp-web,memory-mem0`
2. `cargo test --features whatsapp-web,memory-mem0`
3. Deploy to Synology
4. Test DM: "我鍾意食壽司" ("I like eating sushi") → next turn "我鍾意食咩" ("what do I like to eat?") → should recall
5. Test group: Joe says "我鍾意食壽司" ("I like eating sushi") → someone else asks "Joe 鍾意食咩" ("what does Joe like to eat?") → should recall from group scope
6. Check mem0 server logs: GET with `user_id=sender` AND `user_id=group_key`
7. Check mem0 server logs: POST with both user_ids for group messages
@@ -118,3 +118,7 @@ PROVIDER=openrouter
# Optional: Brave Search (requires API key from https://brave.com/search/api)
# WEB_SEARCH_PROVIDER=brave
# BRAVE_API_KEY=your-brave-search-api-key
#
# Optional: SearXNG (self-hosted, requires instance URL)
# WEB_SEARCH_PROVIDER=searxng
# SEARXNG_INSTANCE_URL=https://searx.example.com
@@ -154,7 +154,7 @@ jobs:
run: mkdir -p web/dist && touch web/dist/.gitkeep
- name: Check all features
run: cargo check --all-features --locked
run: cargo check --features ci-all --locked
docs-quality:
name: Docs Quality
@@ -1,6 +1,22 @@
name: Pub Homebrew Core
on:
workflow_call:
inputs:
release_tag:
description: "Existing release tag to publish (vX.Y.Z)"
required: true
type: string
dry_run:
description: "Patch formula only (no push/PR)"
required: false
default: false
type: boolean
secrets:
HOMEBREW_UPSTREAM_PR_TOKEN:
required: false
HOMEBREW_CORE_BOT_TOKEN:
required: false
workflow_dispatch:
inputs:
release_tag:
@@ -19,6 +19,7 @@ env:
jobs:
detect-version-change:
name: Detect Version Bump
if: github.repository == 'zeroclaw-labs/zeroclaw'
runs-on: ubuntu-latest
outputs:
changed: ${{ steps.check.outputs.changed }}
@@ -40,6 +41,14 @@ jobs:
echo "Current version: ${current}"
echo "Previous version: ${previous}"
# Skip if stable release workflow will handle this version
# (indicated by an existing or imminent stable tag)
if git ls-remote --exit-code --tags origin "refs/tags/v${current}" >/dev/null 2>&1; then
echo "Stable tag v${current} exists — stable release workflow handles crates.io"
echo "changed=false" >> "$GITHUB_OUTPUT"
exit 0
fi
if [[ "$current" != "$previous" && -n "$current" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
echo "version=${current}" >> "$GITHUB_OUTPUT"
@@ -102,6 +111,22 @@ jobs:
- name: Clean web build artifacts
run: rm -rf web/node_modules web/src web/package.json web/package-lock.json web/tsconfig*.json web/vite.config.ts web/index.html
- name: Publish aardvark-sys to crates.io
shell: bash
env:
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
run: |
OUTPUT=$(cargo publish --locked --allow-dirty --no-verify -p aardvark-sys 2>&1) && exit 0
echo "$OUTPUT"
if echo "$OUTPUT" | grep -q 'already exists'; then
echo "::notice::aardvark-sys already on crates.io — skipping"
exit 0
fi
exit 1
- name: Wait for aardvark-sys to index
run: sleep 15
- name: Publish to crates.io
shell: bash
env:
@@ -67,6 +67,24 @@ jobs:
- name: Clean web build artifacts
run: rm -rf web/node_modules web/src web/package.json web/package-lock.json web/tsconfig*.json web/vite.config.ts web/index.html
- name: Publish aardvark-sys to crates.io
if: "!inputs.dry_run"
shell: bash
env:
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
run: |
OUTPUT=$(cargo publish --locked --allow-dirty --no-verify -p aardvark-sys 2>&1) && exit 0
echo "$OUTPUT"
if echo "$OUTPUT" | grep -q 'already exists'; then
echo "::notice::aardvark-sys already on crates.io — skipping"
exit 0
fi
exit 1
- name: Wait for aardvark-sys to index
if: "!inputs.dry_run"
run: sleep 15
- name: Publish (dry run)
if: inputs.dry_run
run: cargo publish --dry-run --locked --allow-dirty --no-verify
@@ -21,25 +21,48 @@ env:
jobs:
version:
name: Resolve Version
if: github.repository == 'zeroclaw-labs/zeroclaw'
runs-on: ubuntu-latest
outputs:
version: ${{ steps.ver.outputs.version }}
tag: ${{ steps.ver.outputs.tag }}
skip: ${{ steps.ver.outputs.skip }}
steps:
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
with:
fetch-depth: 2
- name: Compute beta version
id: ver
shell: bash
run: |
set -euo pipefail
base_version=$(sed -n 's/^version = "\([^"]*\)"/\1/p' Cargo.toml | head -1)
# Skip beta if this is a version bump commit (stable release handles it)
commit_msg=$(git log -1 --pretty=format:"%s")
if [[ "$commit_msg" =~ ^chore:\ bump\ version ]]; then
echo "Version bump commit detected — skipping beta release"
echo "skip=true" >> "$GITHUB_OUTPUT"
exit 0
fi
# Skip beta if a stable tag already exists for this version
if git ls-remote --exit-code --tags origin "refs/tags/v${base_version}" >/dev/null 2>&1; then
echo "Stable tag v${base_version} exists — skipping beta release"
echo "skip=true" >> "$GITHUB_OUTPUT"
exit 0
fi
beta_tag="v${base_version}-beta.${GITHUB_RUN_NUMBER}"
echo "version=${base_version}" >> "$GITHUB_OUTPUT"
echo "tag=${beta_tag}" >> "$GITHUB_OUTPUT"
echo "skip=false" >> "$GITHUB_OUTPUT"
echo "Beta release: ${beta_tag}"
release-notes:
name: Generate Release Notes
needs: [version]
if: github.repository == 'zeroclaw-labs/zeroclaw' && needs.version.outputs.skip != 'true'
runs-on: ubuntu-latest
outputs:
notes: ${{ steps.notes.outputs.body }}
@@ -130,6 +153,8 @@ jobs:
web:
name: Build Web Dashboard
needs: [version]
if: github.repository == 'zeroclaw-labs/zeroclaw' && needs.version.outputs.skip != 'true'
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
@@ -1,6 +1,9 @@
name: Release Stable
on:
push:
tags:
- "v[0-9]+.[0-9]+.[0-9]+" # stable tags only (no -beta suffix)
workflow_dispatch:
inputs:
version:
@@ -33,11 +36,22 @@ jobs:
- name: Validate semver and Cargo.toml match
id: check
shell: bash
env:
INPUT_VERSION: ${{ inputs.version || '' }}
REF_NAME: ${{ github.ref_name }}
EVENT_NAME: ${{ github.event_name }}
run: |
set -euo pipefail
input_version="${{ inputs.version }}"
cargo_version=$(sed -n 's/^version = "\([^"]*\)"/\1/p' Cargo.toml | head -1)
# Resolve version from tag push or manual input
if [[ "$EVENT_NAME" == "push" ]]; then
# Tag push: extract version from tag name (v0.5.9 -> 0.5.9)
input_version="${REF_NAME#v}"
else
input_version="$INPUT_VERSION"
fi
if [[ ! "$input_version" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "::error::Version must be semver (X.Y.Z). Got: ${input_version}"
exit 1
@@ -49,9 +63,13 @@ jobs:
fi
tag="v${input_version}"
if git ls-remote --exit-code --tags origin "refs/tags/${tag}" >/dev/null 2>&1; then
echo "::error::Tag ${tag} already exists."
exit 1
# Only check tag existence for manual dispatch (tag push means it already exists)
if [[ "$EVENT_NAME" != "push" ]]; then
if git ls-remote --exit-code --tags origin "refs/tags/${tag}" >/dev/null 2>&1; then
echo "::error::Tag ${tag} already exists."
exit 1
fi
fi
echo "tag=${tag}" >> "$GITHUB_OUTPUT"
@@ -286,6 +304,14 @@ jobs:
NOTES: ${{ needs.release-notes.outputs.notes }}
run: printf '%s\n' "$NOTES" > release-notes.md
- name: Create tag if manual dispatch
if: github.event_name == 'workflow_dispatch'
env:
TAG: ${{ needs.validate.outputs.tag }}
run: |
git tag -a "$TAG" -m "zeroclaw $TAG"
git push origin "$TAG"
- name: Create GitHub Release
env:
GH_TOKEN: ${{ secrets.RELEASE_TOKEN }}
@@ -323,6 +349,21 @@ jobs:
- name: Clean web build artifacts
run: rm -rf web/node_modules web/src web/package.json web/package-lock.json web/tsconfig*.json web/vite.config.ts web/index.html
- name: Publish aardvark-sys to crates.io
env:
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
run: |
OUTPUT=$(cargo publish --locked --allow-dirty --no-verify -p aardvark-sys 2>&1) && exit 0
echo "$OUTPUT"
if echo "$OUTPUT" | grep -q 'already exists'; then
echo "::notice::aardvark-sys already on crates.io — skipping"
exit 0
fi
exit 1
- name: Wait for aardvark-sys to index
run: sleep 15
- name: Publish to crates.io
env:
CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }}
@@ -446,6 +487,16 @@ jobs:
dry_run: false
secrets: inherit
homebrew:
name: Update Homebrew Core
needs: [validate, publish]
if: ${{ !cancelled() && needs.publish.result == 'success' }}
uses: ./.github/workflows/pub-homebrew-core.yml
with:
release_tag: ${{ needs.validate.outputs.tag }}
dry_run: false
secrets: inherit
# ── Post-publish: tweet after release + website are live ──────────────
# Docker push can be slow; don't let it block the tweet.
tweet:
@@ -117,6 +117,28 @@ version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
[[package]]
name = "alsa"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed7572b7ba83a31e20d1b48970ee402d2e3e0537dcfe0a3ff4d6eb7508617d43"
dependencies = [
"alsa-sys",
"bitflags 2.11.0",
"cfg-if",
"libc",
]
[[package]]
name = "alsa-sys"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db8fee663d06c4e303404ef5f40488a53e062f89ba8bfed81f42325aafad1527"
dependencies = [
"libc",
"pkg-config",
]
[[package]]
name = "ambient-authority"
version = "0.0.2"
@@ -583,6 +605,24 @@ dependencies = [
"virtue",
]
[[package]]
name = "bindgen"
version = "0.72.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895"
dependencies = [
"bitflags 2.11.0",
"cexpr",
"clang-sys",
"itertools 0.13.0",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn 2.0.117",
]
[[package]]
name = "bip39"
version = "2.2.2"
@@ -878,6 +918,21 @@ dependencies = [
"shlex",
]
[[package]]
name = "cesu8"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom 7.1.3",
]
[[package]]
name = "cff-parser"
version = "0.1.0"
@@ -1003,6 +1058,17 @@ dependencies = [
"zeroize",
]
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]]
name = "clap"
version = "4.6.0"
@@ -1086,6 +1152,16 @@ version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570"
[[package]]
name = "combine"
version = "4.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
dependencies = [
"bytes",
"memchr",
]
[[package]]
name = "compression-codecs"
version = "0.4.37"
@@ -1209,6 +1285,49 @@ dependencies = [
"libm",
]
[[package]]
name = "coreaudio-rs"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "321077172d79c662f64f5071a03120748d5bb652f5231570141be24cfcd2bace"
dependencies = [
"bitflags 1.3.2",
"core-foundation-sys",
"coreaudio-sys",
]
[[package]]
name = "coreaudio-sys"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ceec7a6067e62d6f931a2baf6f3a751f4a892595bcec1461a3c94ef9949864b6"
dependencies = [
"bindgen",
]
[[package]]
name = "cpal"
version = "0.15.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "873dab07c8f743075e57f524c583985fbaf745602acbe916a01539364369a779"
dependencies = [
"alsa",
"core-foundation-sys",
"coreaudio-rs",
"dasp_sample",
"jni",
"js-sys",
"libc",
"mach2 0.4.3",
"ndk",
"ndk-context",
"oboe",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"windows",
]
[[package]]
name = "cpp_demangle"
version = "0.4.5"
@@ -1588,6 +1707,12 @@ dependencies = [
"parking_lot_core",
]
[[package]]
name = "dasp_sample"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c87e182de0887fd5361989c677c4e8f5000cd9491d6d563161a8f3a5519fc7f"
[[package]]
name = "data-encoding"
version = "2.10.0"
@@ -2944,7 +3069,7 @@ dependencies = [
"js-sys",
"log",
"wasm-bindgen",
"windows-core",
"windows-core 0.62.2",
]
[[package]]
@@ -3319,9 +3444,9 @@ checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2"
[[package]]
name = "iri-string"
version = "0.7.10"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a"
checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb"
dependencies = [
"memchr",
"serde",
@@ -3362,9 +3487,9 @@ dependencies = [
[[package]]
name = "itoa"
version = "1.0.17"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682"
[[package]]
name = "ittapi"
@@ -3395,6 +3520,50 @@ dependencies = [
"serde",
]
[[package]]
name = "jni"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97"
dependencies = [
"cesu8",
"cfg-if",
"combine",
"jni-sys 0.3.1",
"log",
"thiserror 1.0.69",
"walkdir",
"windows-sys 0.45.0",
]
[[package]]
name = "jni-sys"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258"
dependencies = [
"jni-sys 0.4.1",
]
[[package]]
name = "jni-sys"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2"
dependencies = [
"jni-sys-macros",
]
[[package]]
name = "jni-sys-macros"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264"
dependencies = [
"quote",
"syn 2.0.117",
]
[[package]]
name = "jobserver"
version = "0.1.34"
@@ -4202,9 +4371,9 @@ dependencies = [
[[package]]
name = "moka"
version = "0.12.14"
version = "0.12.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b"
checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046"
dependencies = [
"async-lock",
"crossbeam-channel",
@@ -4242,6 +4411,35 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11ec1bc47d34ae756616f387c11fd0595f86f2cc7e6473bde9e3ded30cb902a1"
[[package]]
name = "ndk"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2076a31b7010b17a38c01907c45b945e8f11495ee4dd588309718901b1f7a5b7"
dependencies = [
"bitflags 2.11.0",
"jni-sys 0.3.1",
"log",
"ndk-sys",
"num_enum",
"thiserror 1.0.69",
]
[[package]]
name = "ndk-context"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b"
[[package]]
name = "ndk-sys"
version = "0.5.0+25.2.9519653"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c196769dd60fd4f363e11d948139556a344e79d451aeb2fa2fd040738ef7691"
dependencies = [
"jni-sys 0.3.1",
]
[[package]]
name = "negentropy"
version = "0.5.0"
@@ -4421,6 +4619,17 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050"
[[package]]
name = "num-derive"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "num-traits"
version = "0.2.19"
@@ -4440,6 +4649,28 @@ dependencies = [
"libc",
]
[[package]]
name = "num_enum"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d0bca838442ec211fa11de3a8b0e0e8f3a4522575b5c4c06ed722e005036f26"
dependencies = [
"num_enum_derive",
"rustversion",
]
[[package]]
name = "num_enum_derive"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "680998035259dcfcafe653688bf2aa6d3e2dc05e98be6ab46afb089dc84f1df8"
dependencies = [
"proc-macro-crate",
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "nusb"
version = "0.2.3"
@@ -4519,6 +4750,29 @@ dependencies = [
"ruzstd",
]
[[package]]
name = "oboe"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8b61bebd49e5d43f5f8cc7ee2891c16e0f41ec7954d36bcb6c14c5e0de867fb"
dependencies = [
"jni",
"ndk",
"ndk-context",
"num-derive",
"num-traits",
"oboe-sys",
]
[[package]]
name = "oboe-sys"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8bb09a4a2b1d668170cfe0a7d5bc103f8999fb316c98099b6a9939c9f2e79d"
dependencies = [
"cc",
]
[[package]]
name = "once_cell"
version = "1.21.4"
@@ -5281,9 +5535,9 @@ dependencies = [
[[package]]
name = "pulldown-cmark"
version = "0.13.1"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83c41efbf8f90ac44de7f3a868f0867851d261b56291732d0cbf7cceaaeb55a6"
checksum = "7c3a14896dfa883796f1cb410461aef38810ea05f2b2c33c5aded3649095fdad"
dependencies = [
"bitflags 2.11.0",
"memchr",
@@ -5401,9 +5655,9 @@ dependencies = [
[[package]]
name = "quoted_printable"
version = "0.5.1"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "640c9bd8497b02465aeef5375144c26062e0dcd5939dfcbb0f5db76cb8c17c73"
checksum = "478e0585659a122aa407eb7e3c0e1fa51b1d8a870038bd29f0cf4a8551eea972"
[[package]]
name = "r-efi"
@@ -7577,9 +7831,9 @@ checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae"
[[package]]
name = "ureq"
version = "3.2.0"
version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdc97a28575b85cfedf2a7e7d3cc64b3e11bd8ac766666318003abbacc7a21fc"
checksum = "dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0"
dependencies = [
"base64",
"cookie_store",
@@ -7591,15 +7845,15 @@ dependencies = [
"serde",
"serde_json",
"ureq-proto",
"utf-8",
"utf8-zero",
"webpki-roots 1.0.6",
]
[[package]]
name = "ureq-proto"
version = "0.5.3"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f"
checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c"
dependencies = [
"base64",
"http 1.4.0",
@@ -7632,6 +7886,12 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "utf8-zero"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e"
[[package]]
name = "utf8_iter"
version = "1.0.4"
@@ -8691,6 +8951,26 @@ dependencies = [
"wasmtime-internal-math",
]
[[package]]
name = "windows"
version = "0.54.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9252e5725dbed82865af151df558e754e4a3c2c30818359eb17465f1346a1b49"
dependencies = [
"windows-core 0.54.0",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-core"
version = "0.54.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12661b9c89351d684a50a8a643ce5f608e20243b9fb84687800163429f161d65"
dependencies = [
"windows-result 0.1.2",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-core"
version = "0.62.2"
@@ -8700,7 +8980,7 @@ dependencies = [
"windows-implement",
"windows-interface",
"windows-link",
"windows-result",
"windows-result 0.4.1",
"windows-strings",
]
@@ -8732,6 +9012,15 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
[[package]]
name = "windows-result"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-result"
version = "0.4.1"
@@ -8750,6 +9039,15 @@ dependencies = [
"windows-link",
]
[[package]]
name = "windows-sys"
version = "0.45.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0"
dependencies = [
"windows-targets 0.42.2",
]
[[package]]
name = "windows-sys"
version = "0.52.0"
@@ -8786,6 +9084,21 @@ dependencies = [
"windows-link",
]
[[package]]
name = "windows-targets"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071"
dependencies = [
"windows_aarch64_gnullvm 0.42.2",
"windows_aarch64_msvc 0.42.2",
"windows_i686_gnu 0.42.2",
"windows_i686_msvc 0.42.2",
"windows_x86_64_gnu 0.42.2",
"windows_x86_64_gnullvm 0.42.2",
"windows_x86_64_msvc 0.42.2",
]
[[package]]
name = "windows-targets"
version = "0.52.6"
@@ -8819,6 +9132,12 @@ dependencies = [
"windows_x86_64_msvc 0.53.1",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
@@ -8831,6 +9150,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53"
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
@@ -8843,6 +9168,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006"
[[package]]
name = "windows_i686_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
@@ -8867,6 +9198,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c"
[[package]]
name = "windows_i686_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
@@ -8879,6 +9216,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2"
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
@@ -8891,6 +9234,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
@@ -8903,6 +9252,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1"
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"
@@ -9203,7 +9558,7 @@ dependencies = [
[[package]]
name = "zeroclawlabs"
version = "0.5.5"
version = "0.5.9"
dependencies = [
"aardvark-sys",
"anyhow",
@@ -9217,6 +9572,7 @@ dependencies = [
"clap",
"clap_complete",
"console",
"cpal",
"criterion",
"cron",
"dialoguer",
@@ -9298,18 +9654,18 @@ dependencies = [
[[package]]
name = "zerocopy"
version = "0.8.42"
version = "0.8.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2578b716f8a7a858b7f02d5bd870c14bf4ddbbcf3a4c05414ba6503640505e3"
checksum = "efbb2a062be311f2ba113ce66f697a4dc589f85e78a4aea276200804cea0ed87"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.42"
version = "0.8.47"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e6cc098ea4d3bd6246687de65af3f920c430e236bee1e3bf2e441463f08a02f"
checksum = "0e8bc7269b54418e7aeeef514aa68f8690b8c0489a06b0136e5f57c4c5ccab89"
dependencies = [
"proc-macro2",
"quote",
@@ -9414,9 +9770,9 @@ dependencies = [
[[package]]
name = "zip"
version = "8.3.0"
version = "8.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a243cfad17427fc077f529da5a95abe4e94fd2bfdb601611870a6557cc67657"
checksum = "5c546feb4481b0fbafb4ef0d79b6204fc41c6f9884b1b73b1d73f82442fc0845"
dependencies = [
"crc32fast",
"flate2",
@@ -9486,9 +9842,9 @@ checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9"
[[package]]
name = "zune-jpeg"
version = "0.5.13"
version = "0.5.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec5f41c76397b7da451efd19915684f727d7e1d516384ca6bd0ec43ec94de23c"
checksum = "0b7a1c0af6e5d8d1363f4994b7a091ccf963d8b694f7da5b0b9cceb82da2c0a6"
dependencies = [
"zune-core",
]
@@ -4,7 +4,7 @@ resolver = "2"
[package]
name = "zeroclawlabs"
version = "0.5.5"
version = "0.5.9"
edition = "2021"
authors = ["theonlyhennygod"]
license = "MIT OR Apache-2.0"
@@ -97,7 +97,7 @@ anyhow = "1.0"
thiserror = "2.0"
# Aardvark I2C/SPI/GPIO USB adapter (Total Phase) — stub when SDK absent
aardvark-sys = { path = "crates/aardvark-sys" }
aardvark-sys = { path = "crates/aardvark-sys", version = "0.1.0" }
# UUID generation
uuid = { version = "1.22", default-features = false, features = ["v4", "std"] }
@@ -199,6 +199,9 @@ pdf-extract = { version = "0.10", optional = true }
# WASM plugin runtime (extism)
extism = { version = "1.20", optional = true }
# Cross-platform audio capture for voice wake word detection (optional, enable with --features voice-wake)
cpal = { version = "0.15", optional = true }
# Terminal QR rendering for WhatsApp Web pairing flow.
qrcode = { version = "0.14", optional = true }
@@ -228,8 +231,6 @@ channel-matrix = ["dep:matrix-sdk"]
channel-lark = ["dep:prost"]
channel-feishu = ["channel-lark"] # Alias for Feishu users (Lark and Feishu are the same platform)
memory-postgres = ["dep:postgres"]
# memory-mem0 = Mem0 (OpenMemory) memory backend via REST API
memory-mem0 = []
observability-prometheus = ["dep:prometheus"]
observability-otel = ["dep:opentelemetry", "dep:opentelemetry_sdk", "dep:opentelemetry-otlp"]
peripheral-rpi = ["rppal"]
@@ -252,8 +253,30 @@ rag-pdf = ["dep:pdf-extract"]
skill-creation = []
# whatsapp-web = Native WhatsApp Web client with custom rusqlite storage backend
whatsapp-web = ["dep:wa-rs", "dep:wa-rs-core", "dep:wa-rs-binary", "dep:wa-rs-proto", "dep:wa-rs-ureq-http", "dep:wa-rs-tokio-transport", "dep:serde-big-array", "dep:prost", "dep:qrcode"]
# voice-wake = Voice wake word detection via microphone (cpal)
voice-wake = ["dep:cpal"]
# WASM plugin system (extism-based)
plugins-wasm = ["dep:extism"]
# Meta-feature for CI: all features except those requiring system C libraries
# not available on standard CI runners (e.g., voice-wake needs libasound2-dev).
ci-all = [
"channel-nostr",
"hardware",
"channel-matrix",
"channel-lark",
"memory-postgres",
"observability-prometheus",
"observability-otel",
"peripheral-rpi",
"browser-native",
"sandbox-landlock",
"sandbox-bubblewrap",
"probe",
"rag-pdf",
"skill-creation",
"whatsapp-web",
"plugins-wasm",
]
[profile.release]
opt-level = "z" # Optimize for size
@@ -27,6 +27,7 @@ COPY Cargo.toml Cargo.lock ./
# Previously we used sed to drop `crates/robot-kit`, which made the manifest disagree
# with the lockfile and caused `cargo --locked` to fail (Cargo refused to rewrite the lock).
COPY crates/robot-kit/ crates/robot-kit/
COPY crates/aardvark-sys/ crates/aardvark-sys/
# Create dummy targets declared in Cargo.toml so manifest parsing succeeds.
RUN mkdir -p src benches \
&& echo "fn main() {}" > src/main.rs \
@@ -1,80 +0,0 @@
#!/bin/bash
# Start mem0 + reranker GPU container for ZeroClaw memory backend.
#
# Required env vars:
# MEM0_LLM_API_KEY or ZAI_API_KEY — API key for the LLM used in fact extraction
#
# Optional env vars (with defaults):
# MEM0_LLM_PROVIDER — mem0 LLM provider (default: "openai" i.e. OpenAI-compatible)
# MEM0_LLM_MODEL — LLM model for fact extraction (default: "glm-5-turbo")
# MEM0_LLM_BASE_URL — LLM API base URL (default: "https://api.z.ai/api/coding/paas/v4")
# MEM0_EMBEDDER_MODEL — embedding model (default: "BAAI/bge-m3")
# MEM0_EMBEDDER_DIMS — embedding dimensions (default: "1024")
# MEM0_EMBEDDER_DEVICE — "cuda", "cpu", or "auto" (default: "cuda")
# MEM0_VECTOR_COLLECTION — Qdrant collection name (default: "zeroclaw_mem0")
# RERANKER_MODEL — reranker model (default: "BAAI/bge-reranker-v2-m3")
# RERANKER_DEVICE — "cuda" or "cpu" (default: "cuda")
# MEM0_PORT — mem0 server port (default: 8765)
# RERANKER_PORT — reranker server port (default: 8678)
# CONTAINER_IMAGE — base container image (default: docker.io/kyuz0/amd-strix-halo-comfyui:latest)
# CONTAINER_NAME — container name (default: mem0-gpu)
# DATA_DIR — host path for Qdrant data (default: ~/mem0-data)
# SCRIPT_DIR — host path for server scripts (default: directory of this script)
set -e
# Resolve script directory for mounting server scripts
SCRIPT_DIR="${SCRIPT_DIR:-$(cd "$(dirname "$0")" && pwd)}"
# API key — accept either name
export MEM0_LLM_API_KEY="${MEM0_LLM_API_KEY:-${ZAI_API_KEY:?MEM0_LLM_API_KEY or ZAI_API_KEY must be set}}"
# Defaults
MEM0_LLM_MODEL="${MEM0_LLM_MODEL:-glm-5-turbo}"
MEM0_LLM_BASE_URL="${MEM0_LLM_BASE_URL:-https://api.z.ai/api/coding/paas/v4}"
MEM0_PORT="${MEM0_PORT:-8765}"
RERANKER_PORT="${RERANKER_PORT:-8678}"
CONTAINER_IMAGE="${CONTAINER_IMAGE:-docker.io/kyuz0/amd-strix-halo-comfyui:latest}"
CONTAINER_NAME="${CONTAINER_NAME:-mem0-gpu}"
DATA_DIR="${DATA_DIR:-$HOME/mem0-data}"
# Stop existing CPU services (if any)
kill -9 $(pgrep -f "mem0-server.py") 2>/dev/null || true
kill -9 $(pgrep -f "reranker-server.py") 2>/dev/null || true
# Stop existing container
podman stop "$CONTAINER_NAME" 2>/dev/null || true
podman rm "$CONTAINER_NAME" 2>/dev/null || true
podman run -d --name "$CONTAINER_NAME" \
--device /dev/dri --device /dev/kfd \
--group-add video --group-add render \
--restart unless-stopped \
-p "$MEM0_PORT:$MEM0_PORT" -p "$RERANKER_PORT:$RERANKER_PORT" \
-v "$DATA_DIR":/root/mem0-data:Z \
-v "$SCRIPT_DIR/mem0-server.py":/app/mem0-server.py:ro,Z \
-v "$SCRIPT_DIR/reranker-server.py":/app/reranker-server.py:ro,Z \
-v "$HOME/.cache/huggingface":/root/.cache/huggingface:Z \
-e MEM0_LLM_API_KEY="$MEM0_LLM_API_KEY" \
-e ZAI_API_KEY="$MEM0_LLM_API_KEY" \
-e MEM0_LLM_MODEL="$MEM0_LLM_MODEL" \
-e MEM0_LLM_BASE_URL="$MEM0_LLM_BASE_URL" \
${MEM0_LLM_PROVIDER:+-e MEM0_LLM_PROVIDER="$MEM0_LLM_PROVIDER"} \
${MEM0_EMBEDDER_MODEL:+-e MEM0_EMBEDDER_MODEL="$MEM0_EMBEDDER_MODEL"} \
${MEM0_EMBEDDER_DIMS:+-e MEM0_EMBEDDER_DIMS="$MEM0_EMBEDDER_DIMS"} \
${MEM0_EMBEDDER_DEVICE:+-e MEM0_EMBEDDER_DEVICE="$MEM0_EMBEDDER_DEVICE"} \
${MEM0_VECTOR_COLLECTION:+-e MEM0_VECTOR_COLLECTION="$MEM0_VECTOR_COLLECTION"} \
${RERANKER_MODEL:+-e RERANKER_MODEL="$RERANKER_MODEL"} \
${RERANKER_DEVICE:+-e RERANKER_DEVICE="$RERANKER_DEVICE"} \
-e RERANKER_PORT="$RERANKER_PORT" \
-e RERANKER_URL="http://127.0.0.1:$RERANKER_PORT/rerank" \
-e TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 \
-e HOME=/root \
"$CONTAINER_IMAGE" \
bash -c "pip install -q FlagEmbedding mem0ai flask httpx qdrant-client 2>&1 | tail -3; echo '=== Starting reranker (GPU) on :$RERANKER_PORT ==='; python3 /app/reranker-server.py & sleep 3; echo '=== Starting mem0 (GPU) on :$MEM0_PORT ==='; exec python3 /app/mem0-server.py"
echo "Container started, waiting for init..."
sleep 15
echo "=== Container logs ==="
podman logs "$CONTAINER_NAME" 2>&1 | tail -25
echo "=== Port check ==="
ss -tlnp | grep "$MEM0_PORT\|$RERANKER_PORT" || echo "Ports not yet ready, check: podman logs $CONTAINER_NAME"
@@ -1,288 +0,0 @@
"""Minimal OpenMemory-compatible REST server wrapping mem0 Python SDK."""
import asyncio
import json, os, uuid, httpx
from datetime import datetime, timezone
from fastapi import FastAPI, Query
from pydantic import BaseModel
from typing import Optional
from mem0 import Memory
app = FastAPI()
RERANKER_URL = os.environ.get("RERANKER_URL", "http://127.0.0.1:8678/rerank")
CUSTOM_EXTRACTION_PROMPT = """You are a memory extraction specialist for a Cantonese/Chinese chat assistant.
Extract ONLY important, persistent facts from the conversation. Rules:
1. Extract personal preferences, habits, relationships, names, locations
2. Extract decisions, plans, and commitments people make
3. SKIP small talk, greetings, reactions ("ok", "哈哈", "係呀")
4. SKIP temporary states ("我依家食緊飯") unless they reveal a habit
5. Keep facts in the ORIGINAL language (Cantonese/Chinese/English)
6. For each fact, note WHO it's about (use their name or identifier if known)
7. Merge/update existing facts rather than creating duplicates
Return a list of facts in JSON format: {"facts": ["fact1", "fact2", ...]}
"""
PROCEDURAL_EXTRACTION_PROMPT = """You are a procedural memory specialist for an AI assistant.
Extract HOW-TO patterns and reusable procedures from the conversation trace. Rules:
1. Identify step-by-step procedures the assistant followed to accomplish a task
2. Extract tool usage patterns: which tools were called, in what order, with what arguments
3. Capture decision points: why the assistant chose one approach over another
4. Note error-recovery patterns: what failed, how it was fixed
5. Keep the procedure generic enough to apply to similar future tasks
6. Preserve technical details (commands, file paths, API calls) that are reusable
7. SKIP greetings, small talk, and conversational filler
8. Format each procedure as: "To [goal]: [step1] -> [step2] -> ... -> [result]"
Return a list of procedures in JSON format: {"facts": ["procedure1", "procedure2", ...]}
"""
# ── Configurable via environment variables ─────────────────────────
# LLM (for fact extraction when infer=true)
MEM0_LLM_PROVIDER = os.environ.get("MEM0_LLM_PROVIDER", "openai") # "openai" (compatible), "anthropic", etc.
MEM0_LLM_MODEL = os.environ.get("MEM0_LLM_MODEL", "glm-5-turbo")
MEM0_LLM_API_KEY = os.environ.get("MEM0_LLM_API_KEY") or os.environ.get("ZAI_API_KEY", "")
MEM0_LLM_BASE_URL = os.environ.get("MEM0_LLM_BASE_URL", "https://api.z.ai/api/coding/paas/v4")
# Embedder
MEM0_EMBEDDER_PROVIDER = os.environ.get("MEM0_EMBEDDER_PROVIDER", "huggingface") # "huggingface", "openai", etc.
MEM0_EMBEDDER_MODEL = os.environ.get("MEM0_EMBEDDER_MODEL", "BAAI/bge-m3")
MEM0_EMBEDDER_DIMS = int(os.environ.get("MEM0_EMBEDDER_DIMS", "1024"))
MEM0_EMBEDDER_DEVICE = os.environ.get("MEM0_EMBEDDER_DEVICE", "cuda") # "cuda", "cpu", "auto"
# Vector store
MEM0_VECTOR_PROVIDER = os.environ.get("MEM0_VECTOR_PROVIDER", "qdrant") # "qdrant", "chroma", etc.
MEM0_VECTOR_COLLECTION = os.environ.get("MEM0_VECTOR_COLLECTION", "zeroclaw_mem0")
MEM0_VECTOR_PATH = os.environ.get("MEM0_VECTOR_PATH", os.path.expanduser("~/mem0-data/qdrant"))
config = {
"llm": {
"provider": MEM0_LLM_PROVIDER,
"config": {
"model": MEM0_LLM_MODEL,
"api_key": MEM0_LLM_API_KEY,
"openai_base_url": MEM0_LLM_BASE_URL,
},
},
"embedder": {
"provider": MEM0_EMBEDDER_PROVIDER,
"config": {
"model": MEM0_EMBEDDER_MODEL,
"embedding_dims": MEM0_EMBEDDER_DIMS,
"model_kwargs": {"device": MEM0_EMBEDDER_DEVICE},
},
},
"vector_store": {
"provider": MEM0_VECTOR_PROVIDER,
"config": {
"collection_name": MEM0_VECTOR_COLLECTION,
"embedding_model_dims": MEM0_EMBEDDER_DIMS,
"path": MEM0_VECTOR_PATH,
},
},
"custom_fact_extraction_prompt": CUSTOM_EXTRACTION_PROMPT,
}
m = Memory.from_config(config)
def rerank_results(query: str, items: list, top_k: int = 10) -> list:
"""Rerank search results using bge-reranker-v2-m3."""
if not items:
return items
documents = [item.get("memory", "") for item in items]
try:
resp = httpx.post(
RERANKER_URL,
json={"query": query, "documents": documents, "top_k": top_k},
timeout=10.0,
)
resp.raise_for_status()
ranked = resp.json().get("results", [])
return [items[r["index"]] for r in ranked]
except Exception as e:
print(f"Reranker failed, using original order: {e}")
return items
class AddMemoryRequest(BaseModel):
user_id: str
text: str
metadata: Optional[dict] = None
infer: bool = True
app: Optional[str] = None
custom_instructions: Optional[str] = None
@app.post("/api/v1/memories/")
async def add_memory(req: AddMemoryRequest):
# Use client-supplied prompt, fall back to server default, then mem0 SDK default
prompt = req.custom_instructions or CUSTOM_EXTRACTION_PROMPT
result = await asyncio.to_thread(m.add, req.text, user_id=req.user_id, metadata=req.metadata or {}, prompt=prompt)
return {"id": str(uuid.uuid4()), "status": "ok", "result": result}
class ProceduralMemoryRequest(BaseModel):
user_id: str
messages: list[dict]
metadata: Optional[dict] = None
@app.post("/api/v1/memories/procedural")
async def add_procedural_memory(req: ProceduralMemoryRequest):
"""Store a conversation trace as procedural memory.
Accepts a list of messages (role/content dicts) representing a full
conversation turn including tool calls, then uses mem0's native
procedural memory extraction to learn reusable "how to" patterns.
"""
# Build metadata with procedural type marker
meta = {"type": "procedural"}
if req.metadata:
meta.update(req.metadata)
# Use mem0's native message list support + procedural prompt
result = await asyncio.to_thread(m.add,
req.messages,
user_id=req.user_id,
metadata=meta,
prompt=PROCEDURAL_EXTRACTION_PROMPT,
)
return {"id": str(uuid.uuid4()), "status": "ok", "result": result}
def _parse_mem0_results(raw_results) -> list:
raw = raw_results.get("results", raw_results) if isinstance(raw_results, dict) else raw_results
items = []
for r in raw:
item = r if isinstance(r, dict) else {"memory": str(r)}
items.append({
"id": item.get("id", str(uuid.uuid4())),
"memory": item.get("memory", item.get("text", "")),
"created_at": item.get("created_at", datetime.now(timezone.utc).isoformat()),
"metadata_": item.get("metadata", {}),
})
return items
def _parse_iso_timestamp(value: str) -> Optional[datetime]:
"""Parse an ISO 8601 timestamp string, returning None on failure."""
try:
dt = datetime.fromisoformat(value)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt
except (ValueError, TypeError):
return None
def _item_created_at(item: dict) -> Optional[datetime]:
"""Extract created_at from an item as a timezone-aware datetime."""
raw = item.get("created_at")
if raw is None:
return None
if isinstance(raw, datetime):
if raw.tzinfo is None:
raw = raw.replace(tzinfo=timezone.utc)
return raw
return _parse_iso_timestamp(str(raw))
def _apply_post_filters(
items: list,
created_after: Optional[str],
created_before: Optional[str],
) -> list:
"""Filter items by created_after / created_before timestamps (post-query)."""
after_dt = _parse_iso_timestamp(created_after) if created_after else None
before_dt = _parse_iso_timestamp(created_before) if created_before else None
if after_dt is None and before_dt is None:
return items
filtered = []
for item in items:
ts = _item_created_at(item)
if ts is None:
# Keep items without a parseable timestamp
filtered.append(item)
continue
if after_dt and ts < after_dt:
continue
if before_dt and ts > before_dt:
continue
filtered.append(item)
return filtered
@app.get("/api/v1/memories/")
async def list_or_search_memories(
    user_id: str = Query(...),
    search_query: Optional[str] = Query(None),
    size: int = Query(10),
    rerank: bool = Query(True),
    created_after: Optional[str] = Query(None),
    created_before: Optional[str] = Query(None),
    metadata_filter: Optional[str] = Query(None),
):
    # Build mem0 SDK filters dict from the metadata_filter JSON param
    sdk_filters = None
    if metadata_filter:
        try:
            sdk_filters = json.loads(metadata_filter)
        except json.JSONDecodeError:
            sdk_filters = None
    if search_query:
        # Fetch more results than needed so the reranker has candidates to work with
        fetch_size = min(size * 3, 50)
        results = await asyncio.to_thread(
            m.search,
            search_query,
            user_id=user_id,
            limit=fetch_size,
            filters=sdk_filters,
        )
        items = _parse_mem0_results(results)
        items = _apply_post_filters(items, created_after, created_before)
        if rerank and items:
            items = rerank_results(search_query, items, top_k=size)
        else:
            items = items[:size]
        return {"items": items, "total": len(items)}
    else:
        results = await asyncio.to_thread(m.get_all, user_id=user_id, filters=sdk_filters)
        items = _parse_mem0_results(results)
        items = _apply_post_filters(items, created_after, created_before)
        return {"items": items, "total": len(items)}
@app.delete("/api/v1/memories/{memory_id}")
async def delete_memory(memory_id: str):
    try:
        await asyncio.to_thread(m.delete, memory_id)
    except Exception:
        pass  # Deletion is idempotent: a missing memory counts as already deleted
    return {"status": "ok"}
@app.get("/api/v1/memories/{memory_id}/history")
async def get_memory_history(memory_id: str):
    """Return the edit history of a specific memory."""
    try:
        history = await asyncio.to_thread(m.history, memory_id)
        # Normalize to a list of dicts
        if isinstance(history, list):
            raw = history
        elif isinstance(history, dict):
            raw = history.get("results", history)
        else:
            raw = [history]
        entries = []
        for h in raw:
            entry = h if isinstance(h, dict) else {"event": str(h)}
            entries.append(entry)
        return {"memory_id": memory_id, "history": entries}
    except Exception as e:
        return {"memory_id": memory_id, "history": [], "error": str(e)}
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8765)
@@ -1,50 +0,0 @@
from flask import Flask, request, jsonify
from FlagEmbedding import FlagReranker
import os
import torch

app = Flask(__name__)
reranker = None

# ── Configurable via environment variables ─────────────────────────
RERANKER_MODEL = os.environ.get("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")
RERANKER_DEVICE = os.environ.get("RERANKER_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
RERANKER_PORT = int(os.environ.get("RERANKER_PORT", "8678"))

def get_reranker():
    global reranker
    if reranker is None:
        reranker = FlagReranker(RERANKER_MODEL, use_fp16=True, device=RERANKER_DEVICE)
    return reranker

@app.route('/rerank', methods=['POST'])
def rerank():
    data = request.json
    query = data.get('query', '')
    documents = data.get('documents', [])
    top_k = data.get('top_k', len(documents))
    if not query or not documents:
        return jsonify({'error': 'query and documents required'}), 400
    pairs = [[query, doc] for doc in documents]
    scores = get_reranker().compute_score(pairs)
    if isinstance(scores, float):
        scores = [scores]
    results = sorted(
        [{'index': i, 'document': doc, 'score': score}
         for i, (doc, score) in enumerate(zip(documents, scores))],
        key=lambda x: x['score'], reverse=True
    )[:top_k]
    return jsonify({'results': results})

@app.route('/health', methods=['GET'])
def health():
    return jsonify({'status': 'ok', 'model': RERANKER_MODEL, 'device': RERANKER_DEVICE})

if __name__ == '__main__':
    print(f'Loading reranker model ({RERANKER_MODEL}) on {RERANKER_DEVICE}...')
    get_reranker()
    print(f'Reranker server ready on :{RERANKER_PORT}')
    app.run(host='0.0.0.0', port=RERANKER_PORT)
@@ -1,6 +1,6 @@
pkgbase = zeroclaw
pkgdesc = Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.
pkgver = 0.5.5
pkgver = 0.5.9
pkgrel = 1
url = https://github.com/zeroclaw-labs/zeroclaw
arch = x86_64
@@ -10,7 +10,7 @@ pkgbase = zeroclaw
makedepends = git
depends = gcc-libs
depends = openssl
source = zeroclaw-0.5.5.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.5.tar.gz
source = zeroclaw-0.5.9.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.5.9.tar.gz
sha256sums = SKIP
pkgname = zeroclaw
@@ -1,6 +1,6 @@
# Maintainer: zeroclaw-labs <bot@zeroclaw.dev>
pkgname=zeroclaw
pkgver=0.5.5
pkgver=0.5.9
pkgrel=1
pkgdesc="Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
arch=('x86_64')
@@ -1,11 +1,11 @@
{
"version": "0.5.5",
"version": "0.5.9",
"description": "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.",
"homepage": "https://github.com/zeroclaw-labs/zeroclaw",
"license": "MIT|Apache-2.0",
"architecture": {
"64bit": {
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.5/zeroclaw-x86_64-pc-windows-msvc.zip",
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.5.9/zeroclaw-x86_64-pc-windows-msvc.zip",
"hash": "",
"bin": "zeroclaw.exe"
}
@@ -0,0 +1,202 @@
# ADR-004: Tool Shared State Ownership Contract
**Status:** Accepted
**Date:** 2026-03-22
**Issue:** [#4057](https://github.com/zeroclaw/zeroclaw/issues/4057)
## Context
ZeroClaw tools execute in a multi-client environment where a single daemon
process serves requests from multiple connected clients simultaneously. Several
tools already maintain long-lived shared state:
- **`DelegateParentToolsHandle`** (`src/tools/mod.rs`):
`Arc<RwLock<Vec<Arc<dyn Tool>>>>` — holds parent tools for delegate agents
with no per-client isolation.
- **`ChannelMapHandle`** (`src/tools/reaction.rs`):
`Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>` — global channel map shared
across all clients.
- **`CanvasStore`** (`src/tools/canvas.rs`):
`Arc<RwLock<HashMap<String, CanvasEntry>>>` — canvas IDs are plain strings
with no client namespace.
These patterns emerged organically. As the tool surface grows and more clients
connect concurrently, we need a clear contract governing ownership, identity,
isolation, lifecycle, and reload behavior for tool-held shared state. Without
this contract, new tools risk introducing data leaks between clients, stale
state after config reloads, or inconsistent initialization timing.
Additional context:
- The tool registry is immutable after startup, built once in
`all_tools_with_runtime()`.
- Client identity is currently derived from IP address only
(`src/gateway/mod.rs`), which is insufficient for reliable namespacing.
- `SecurityPolicy` is scoped per agent, not per client.
- `WorkspaceManager` provides some isolation but workspace switching is global.
## Decision
### 1. Ownership: May tools own long-lived shared state?
**Yes.** Tools MAY own long-lived shared state, provided they follow the
established **handle pattern**: wrap the state in `Arc<RwLock<T>>` (or
`Arc<parking_lot::RwLock<T>>`) and expose a cloneable handle type.
This pattern is already proven by three independent implementations:
| Handle | Location | Inner type |
|--------|----------|-----------|
| `DelegateParentToolsHandle` | `src/tools/mod.rs` | `Vec<Arc<dyn Tool>>` |
| `ChannelMapHandle` | `src/tools/reaction.rs` | `HashMap<String, Arc<dyn Channel>>` |
| `CanvasStore` | `src/tools/canvas.rs` | `HashMap<String, CanvasEntry>` |
Tools that need shared state MUST:
- Define a named handle type alias (e.g., `pub type FooHandle = Arc<RwLock<T>>`).
- Accept the handle at construction time rather than creating global state.
- Document the concurrency contract in the handle type's doc comment.
Tools MUST NOT use static mutable state (`lazy_static!`, `OnceCell` with
interior mutability) for per-request or per-client data.
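The handle pattern the contract mandates can be sketched as follows. This is an illustrative standalone example, not the actual ZeroClaw code: `Entry`, `EntryStoreHandle`, and `ExampleTool` are hypothetical names standing in for types like `CanvasEntry` and `CanvasStore`.

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Hypothetical entry type standing in for CanvasEntry and friends.
#[derive(Clone, Debug, PartialEq)]
struct Entry {
    body: String,
}

// Named handle type alias, as the contract requires.
pub type EntryStoreHandle = Arc<RwLock<HashMap<String, Entry>>>;

// A tool accepts the handle at construction time; it never creates
// global or static mutable state of its own.
struct ExampleTool {
    store: EntryStoreHandle,
}

impl ExampleTool {
    fn new(store: EntryStoreHandle) -> Self {
        Self { store }
    }

    fn put(&self, key: &str, body: &str) {
        self.store
            .write()
            .unwrap()
            .insert(key.to_string(), Entry { body: body.to_string() });
    }

    fn get(&self, key: &str) -> Option<Entry> {
        self.store.read().unwrap().get(key).cloned()
    }
}

fn main() {
    let store: EntryStoreHandle = Arc::new(RwLock::new(HashMap::new()));
    // Two tools share the same state through cloned handles.
    let a = ExampleTool::new(store.clone());
    let b = ExampleTool::new(store.clone());
    a.put("greeting", "hello");
    assert_eq!(b.get("greeting").map(|e| e.body), Some("hello".to_string()));
}
```

Because the handle is just a cheap `Arc` clone, the daemon can construct the state once and hand a handle to every tool that needs it, which is what `all_tools_with_runtime()` already does for the three existing handles.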
### 2. Identity assignment: Who constructs identity keys?
**The daemon SHOULD provide identity.** Tools MUST NOT construct their own
client identity keys.
A new `ClientId` type should be introduced (opaque, `Clone + Eq + Hash + Send + Sync`)
that the daemon assigns at connection time. This replaces the current approach
of using raw IP addresses (`src/gateway/mod.rs:259-306`), which breaks when
multiple clients share a NAT address or when proxied connections arrive.
`ClientId` is passed to tools that require per-client state namespacing as part
of the tool execution context. Tools that do not need per-client isolation
(e.g., the immutable tool registry) may ignore it.
The `ClientId` contract:
- Generated by the gateway layer at connection establishment.
- Opaque to tools — tools must not parse or derive meaning from the value.
- Stable for the lifetime of a single client session.
- Passed through the execution context, not stored globally.
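One possible shape for the proposed `ClientId` satisfying this contract (a sketch only; the type is not yet implemented, and the counter-based generation is an assumption):

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Opaque per-connection identity assigned by the gateway.
/// Tools may clone, compare, and hash it — never parse it.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ClientId(u64);

static NEXT_CLIENT_ID: AtomicU64 = AtomicU64::new(1);

impl ClientId {
    /// Called by the gateway at connection establishment;
    /// tools never construct identities themselves.
    pub fn next() -> Self {
        ClientId(NEXT_CLIENT_ID.fetch_add(1, Ordering::Relaxed))
    }
}

fn main() {
    let a = ClientId::next();
    let b = ClientId::next();
    // Stable within a session, distinct across connections.
    assert_eq!(a.clone(), a);
    assert_ne!(a, b);
}
```

Keeping the inner value private is what makes the type opaque: tools can only use it as a map key, so later swapping the counter for a token- or certificate-derived value changes nothing downstream.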
### 3. Lifecycle: When may tools run startup-style validation?
**Validation runs once at first registration, and again when config changes
are detected.**
The lifecycle phases are:
1. **Construction** — tool is instantiated with handles and config. No I/O or
validation occurs here.
2. **Registration** — tool is registered in the tool registry via
`all_tools_with_runtime()`. At this point the tool MAY perform one-time
startup validation (e.g., checking that required credentials exist, verifying
external service connectivity).
3. **Execution** — tool handles individual requests. No re-validation unless
the config-change signal fires (see Reload Semantics below).
4. **Shutdown** — daemon is stopping. Tools with open resources SHOULD clean up
gracefully via `Drop` or an explicit shutdown method.
Tools MUST NOT perform blocking validation during execution-phase calls.
Validation results SHOULD be cached in the tool's handle state and checked
via a fast path during execution.
### 4. Isolation: What must be isolated per client?
State falls into two categories with different isolation requirements:
**MUST be isolated per client:**
- Security-sensitive state: credentials, API keys, quotas, rate-limit counters,
per-client authorization decisions.
- User-specific session data: conversation context, user preferences,
workspace-scoped file paths.
Isolation mechanism: tools holding per-client state MUST key their internal
maps by `ClientId`. The handle pattern naturally supports this by using
`HashMap<ClientId, T>` inside the `RwLock`.
**MAY be shared across clients (with namespace prefixing):**
- Broadcast/display state: canvas frames (`CanvasStore`), notification channels
(`ChannelMapHandle`).
- Read-only reference data: tool registry, static configuration, model
metadata.
When shared state uses string keys (e.g., canvas IDs, channel names), tools
SHOULD support optional namespace prefixing (e.g., `{client_id}:{canvas_name}`)
to allow per-client isolation when needed without mandating it for broadcast
use cases.
Tools MUST NOT store per-client secrets in shared (non-isolated) state
structures.
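The two isolation categories reduce to two small patterns, sketched below with hypothetical names (`RateLimits`, `canvas_key`; a plain `u64` stands in for the opaque `ClientId`):

```rust
use std::collections::HashMap;

// Stand-in for the opaque ClientId type.
type ClientId = u64;

// MUST-isolate state: keyed per client inside the handle's map.
struct RateLimits {
    counters: HashMap<ClientId, u32>,
}

impl RateLimits {
    fn hit(&mut self, client: ClientId) -> u32 {
        let c = self.counters.entry(client).or_insert(0);
        *c += 1;
        *c
    }
}

// MAY-share state: optional namespace prefixing for string-keyed stores.
fn canvas_key(client: Option<ClientId>, name: &str) -> String {
    match client {
        Some(id) => format!("{id}:{name}"), // per-client isolation
        None => name.to_string(),           // broadcast use case
    }
}

fn main() {
    let mut rl = RateLimits { counters: HashMap::new() };
    assert_eq!(rl.hit(1), 1);
    assert_eq!(rl.hit(1), 2);
    assert_eq!(rl.hit(2), 1); // client 2 is unaffected by client 1
    assert_eq!(canvas_key(Some(7), "dashboard"), "7:dashboard");
    assert_eq!(canvas_key(None, "dashboard"), "dashboard");
}
```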
### 5. Reload semantics: What invalidates prior shared state on config change?
**Config changes detected via hash comparison MUST invalidate cached
validation state.**
The reload contract:
- The daemon computes a hash of the tool-relevant config section at startup and
after each config reload event.
- When the hash changes, the daemon signals affected tools to re-run their
registration-phase validation.
- Tools MUST treat their cached validation result as stale when signaled and
re-validate before the next execution.
Specific invalidation rules:
| Config change | Invalidation scope |
|--------------|-------------------|
| Credential/secret rotation | Per-tool validation cache; per-client credential state |
| Tool enable/disable | Full tool registry rebuild via `all_tools_with_runtime()` |
| Security policy change | `SecurityPolicy` re-derivation; per-agent policy state |
| Workspace directory change | `WorkspaceManager` state; file-path-dependent tool state |
| Provider config change | Provider-dependent tools re-validate connectivity |
Tools MAY retain non-security shared state (e.g., canvas content, channel
subscriptions) across config reloads unless the reload explicitly affects that
state's validity.
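The hash-guarded validation cache described above can be sketched as follows. This is illustrative, not the daemon's actual implementation; `ValidationCache` is a hypothetical name, and `DefaultHasher` stands in for whatever hash the daemon uses on the config section.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

fn config_hash(section: &str) -> u64 {
    let mut h = DefaultHasher::new();
    section.hash(&mut h);
    h.finish()
}

/// Cached validation state guarded by a hash of the tool-relevant config.
struct ValidationCache {
    config_hash: u64,
    validated: bool,
}

impl ValidationCache {
    fn new(section: &str) -> Self {
        Self { config_hash: config_hash(section), validated: false }
    }

    /// Fast path at execution time: re-validate only when the hash
    /// changed or validation never ran. Returns true when validation ran.
    fn ensure_valid(&mut self, current_section: &str) -> bool {
        let h = config_hash(current_section);
        let must_revalidate = h != self.config_hash || !self.validated;
        if must_revalidate {
            // ... run registration-phase validation here ...
            self.config_hash = h;
            self.validated = true;
        }
        must_revalidate
    }
}

fn main() {
    let mut cache = ValidationCache::new("api_key = \"old\"");
    assert!(cache.ensure_valid("api_key = \"old\""));  // first run validates
    assert!(!cache.ensure_valid("api_key = \"old\"")); // unchanged: fast path
    assert!(cache.ensure_valid("api_key = \"new\""));  // rotation invalidates
}
```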
## Consequences
### Positive
- **Consistency:** All new tools follow the same handle pattern, making shared
state discoverable and auditable.
- **Safety:** Per-client isolation of security-sensitive state prevents data
leaks in multi-tenant scenarios.
- **Clarity:** Explicit lifecycle phases eliminate ambiguity about when
validation runs.
- **Evolvability:** The `ClientId` abstraction decouples tools from transport
details, supporting future identity mechanisms (tokens, certificates).
### Negative
- **Migration cost:** Existing tools (`CanvasStore`, `ReactionTool`) may need
refactoring to accept `ClientId` and namespace their state.
- **Complexity:** Tools that were simple singletons now need to consider
multi-client semantics even if they currently have one client.
- **Performance:** Per-client keying adds a hash lookup on each access, though
this is negligible compared to I/O costs.
### Neutral
- The tool registry remains immutable after startup; this ADR does not change
that invariant.
- `SecurityPolicy` remains per-agent; this ADR documents that client isolation
is orthogonal to agent-level policy.
## References
- `src/tools/mod.rs` — `DelegateParentToolsHandle`, `all_tools_with_runtime()`
- `src/tools/reaction.rs` — `ChannelMapHandle`, `ReactionTool`
- `src/tools/canvas.rs` — `CanvasStore`, `CanvasEntry`
- `src/tools/traits.rs` — `Tool` trait
- `src/gateway/mod.rs` — client IP extraction (`forwarded_client_ip`, `resolve_client_ip`)
- `src/security/` — `SecurityPolicy`
@@ -0,0 +1,215 @@
# Browser Automation Setup Guide
This guide covers setting up browser automation capabilities in ZeroClaw, including both headless automation and GUI access via VNC.
## Overview
ZeroClaw supports multiple browser access methods:
| Method | Use Case | Requirements |
|--------|----------|--------------|
| **agent-browser CLI** | Headless automation, AI agents | npm, Chrome |
| **VNC + noVNC** | GUI access, debugging | Xvfb, x11vnc, noVNC |
| **Chrome Remote Desktop** | Remote GUI via Google | XFCE, Google account |
## Quick Start: Headless Automation
### 1. Install agent-browser
```bash
# Install CLI
npm install -g agent-browser
# Download Chrome for Testing
agent-browser install --with-deps # Linux (includes system deps)
agent-browser install # macOS/Windows
```
### 2. Verify ZeroClaw Config
The browser tool is enabled by default. To verify or customize, edit
`~/.zeroclaw/config.toml`:
```toml
[browser]
enabled = true # default: true
allowed_domains = ["*"] # default: ["*"] (all public hosts)
backend = "agent_browser" # default: "agent_browser"
native_headless = true # default: true
```
To restrict domains or disable the browser tool:
```toml
[browser]
enabled = false # disable entirely
# or restrict to specific domains:
allowed_domains = ["example.com", "docs.example.com"]
```
### 3. Test
```bash
echo "Open https://example.com and tell me what it says" | zeroclaw agent
```
## VNC Setup (GUI Access)
For debugging or when you need visual browser access:
### Install Dependencies
```bash
# Ubuntu/Debian
apt-get install -y xvfb x11vnc fluxbox novnc websockify
# Optional: Desktop environment for Chrome Remote Desktop
apt-get install -y xfce4 xfce4-goodies
```
### Start VNC Server
```bash
#!/bin/bash
# Start virtual display with VNC access
DISPLAY_NUM=99
VNC_PORT=5900
NOVNC_PORT=6080
RESOLUTION=1920x1080x24
# Start Xvfb
Xvfb :$DISPLAY_NUM -screen 0 $RESOLUTION -ac &
sleep 1
# Start window manager
fluxbox -display :$DISPLAY_NUM &
sleep 1
# Start x11vnc
x11vnc -display :$DISPLAY_NUM -rfbport $VNC_PORT -forever -shared -nopw -bg
sleep 1
# Start noVNC (web-based VNC)
websockify --web=/usr/share/novnc $NOVNC_PORT localhost:$VNC_PORT &
echo "VNC available at:"
echo " VNC Client: localhost:$VNC_PORT"
echo " Web Browser: http://localhost:$NOVNC_PORT/vnc.html"
```
### VNC Access
- **VNC Client**: Connect to `localhost:5900`
- **Web Browser**: Open `http://localhost:6080/vnc.html`
### Start Browser on VNC Display
```bash
DISPLAY=:99 google-chrome --no-sandbox https://example.com &
```
## Chrome Remote Desktop
### Install
```bash
# Download and install
wget https://dl.google.com/linux/direct/chrome-remote-desktop_current_amd64.deb
apt-get install -y ./chrome-remote-desktop_current_amd64.deb
# Configure session
echo "xfce4-session" > ~/.chrome-remote-desktop-session
chmod +x ~/.chrome-remote-desktop-session
```
### Setup
1. Visit <https://remotedesktop.google.com/headless>
2. Copy the "Debian Linux" setup command
3. Run it on your server
4. Start the service: `systemctl --user start chrome-remote-desktop`
### Remote Access
Go to <https://remotedesktop.google.com/access> from any device.
## Testing
### CLI Tests
```bash
# Basic open and close
agent-browser open https://example.com
agent-browser get title
agent-browser close
# Snapshot with refs
agent-browser open https://example.com
agent-browser snapshot -i
agent-browser close
# Screenshot
agent-browser open https://example.com
agent-browser screenshot /tmp/test.png
agent-browser close
```
### ZeroClaw Integration Tests
```bash
# Content extraction
echo "Open https://example.com and summarize it" | zeroclaw agent
# Navigation
echo "Go to https://github.com/trending and list the top 3 repos" | zeroclaw agent
# Form interaction
echo "Go to Wikipedia, search for 'Rust programming language', and summarize" | zeroclaw agent
```
## Troubleshooting
### "Element not found"
The page may not be fully loaded. Add a wait:
```bash
agent-browser open https://slow-site.com
agent-browser wait --load networkidle
agent-browser snapshot -i
```
### Cookie dialogs blocking access
Handle cookie consent first:
```bash
agent-browser open https://site-with-cookies.com
agent-browser snapshot -i
agent-browser click @accept_cookies # Click the accept button
agent-browser snapshot -i # Now get the actual content
```
### Docker sandbox network restrictions
If `web_fetch` fails inside Docker sandbox, use agent-browser instead:
```bash
# Instead of web_fetch, use:
agent-browser open https://example.com
agent-browser get text body
```
## Security Notes
- `agent-browser` runs Chrome in headless mode with sandboxing
- For sensitive sites, use `--session-name` to persist auth state
- The `--allowed-domains` config restricts navigation to specific domains
- VNC ports (5900, 6080) should be behind a firewall or Tailscale
## Related
- [agent-browser Documentation](https://github.com/vercel-labs/agent-browser)
- [ZeroClaw Configuration Reference](./config-reference.md)
- [Skills Documentation](../skills/)
@@ -45,6 +45,15 @@ For complete code examples of each extension trait, see [extension-examples.md](
- Keep multilingual entry-point parity for all supported locales (`en`, `zh-CN`, `ja`, `ru`, `fr`, `vi`) when nav or key wording changes.
- When shared docs wording changes, sync corresponding localized docs in the same PR (or explicitly document deferral and follow-up PR).
## Tool Shared State
- Follow the `Arc<RwLock<T>>` handle pattern for any tool that owns long-lived shared state.
- Accept handles at construction; do not create global/static mutable state.
- Use `ClientId` (provided by the daemon) to namespace per-client state — never construct identity keys inside the tool.
- Isolate security-sensitive state (credentials, quotas) per client; broadcast/display state may be shared with optional namespace prefixing.
- Cached validation is invalidated on config change — tools must re-validate before the next execution when signaled.
- See [ADR-004: Tool Shared State Ownership](../architecture/adr-004-tool-shared-state-ownership.md) for the full contract.
## Architecture Boundary Rules
- Extend capabilities by adding trait implementations + factory wiring first; avoid cross-module rewrites for isolated features.
@@ -411,30 +411,6 @@ allowed_roots = [\"~/Desktop/projects\", \"/opt/shared-repo\"]
- Memory context injection ignores legacy `assistant_resp*` auto-save keys to prevent old model-generated summaries from being treated as facts.
### `[memory.mem0]`
Mem0 (OpenMemory) backend — connects to a self-hosted mem0 server for vector-based memory storage with LLM-powered fact extraction. Requires the `memory-mem0` feature flag at build time and `backend = "mem0"` in config.
| Key | Default | Env var | Purpose |
|---|---|---|---|
| `url` | `http://localhost:8765` | `MEM0_URL` | OpenMemory server URL |
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | User ID for scoping memories |
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | Application name registered in mem0 |
| `infer` | `true` | — | Use LLM to extract facts from stored text (`true`) or store raw (`false`) |
| `extraction_prompt` | unset | `MEM0_EXTRACTION_PROMPT` | Custom prompt for LLM fact extraction (e.g. for non-English content) |
```toml
[memory]
backend = "mem0"
[memory.mem0]
url = "http://192.168.0.171:8765"
user_id = "zeroclaw-bot"
extraction_prompt = "Extract facts in the original language..."
```
Server deployment scripts are in `deploy/mem0/`.
## `[[model_routes]]` and `[[embedding_routes]]`
Use route hints so integrations can keep stable names while model IDs evolve.
@@ -122,6 +122,34 @@ tools = ["mcp_browser_*"]
keywords = ["browse", "navigate", "open url", "screenshot"]
```
## `[pacing]`
Pacing controls for slow/local LLM workloads (Ollama, llama.cpp, vLLM). All keys are optional; when absent, existing behavior is preserved.
| Key | Default | Purpose |
|---|---|---|
| `step_timeout_secs` | _none_ | Per-step timeout: maximum seconds for a single LLM inference turn. Catches a truly hung model without terminating the overall task loop |
| `loop_detection_min_elapsed_secs` | _none_ | Minimum elapsed seconds before loop detection activates. Tasks completing under this threshold get aggressive loop protection; longer-running tasks receive a grace period |
| `loop_ignore_tools` | `[]` | Tool names excluded from identical-output loop detection. Useful for browser workflows where `browser_screenshot` structurally resembles a loop |
| `message_timeout_scale_max` | `4` | Override for the hardcoded timeout scaling cap. The channel message timeout budget is `message_timeout_secs * min(max_tool_iterations, message_timeout_scale_max)` |
Notes:
- These settings are intended for local/slow LLM deployments. Cloud-provider users typically do not need them.
- `step_timeout_secs` operates independently of the total channel message timeout budget. A step timeout abort does not consume the overall budget; the loop simply stops.
- `loop_detection_min_elapsed_secs` delays loop-detection counting, not the task itself. Loop protection remains fully active for short tasks (the default).
- `loop_ignore_tools` only suppresses tool-output-based loop detection for the listed tools. Other safety features (max iterations, overall timeout) remain active.
- `message_timeout_scale_max` must be >= 1. Setting it higher than `max_tool_iterations` has no additional effect (the formula uses `min()`).
- Example configuration for a slow local Ollama deployment:
```toml
[pacing]
step_timeout_secs = 120
loop_detection_min_elapsed_secs = 60
loop_ignore_tools = ["browser_screenshot", "browser_navigate"]
message_timeout_scale_max = 8
```
## `[security.otp]`
| Key | Default | Purpose |
@@ -425,6 +453,12 @@ Notes:
| `port` | `42617` | gateway listen port |
| `require_pairing` | `true` | require pairing before bearer auth |
| `allow_public_bind` | `false` | block accidental public exposure |
| `path_prefix` | _(none)_ | URL path prefix for reverse-proxy deployments (e.g. `"/zeroclaw"`) |
When deploying behind a reverse proxy that maps ZeroClaw to a sub-path,
set `path_prefix` to that sub-path (e.g. `"/zeroclaw"`). All gateway
routes will be served under this prefix. The value must start with `/`
and must not end with `/`.
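For example, a deployment behind a reverse proxy that maps ZeroClaw to `/zeroclaw` might use (values illustrative):

```toml
[gateway]
port = 42617
path_prefix = "/zeroclaw"  # must start with "/", must not end with "/"
```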
## `[autonomy]`
@@ -474,30 +508,6 @@ Notes:
- Memory context injection ignores legacy `assistant_resp*` auto-save keys to prevent old model-authored summaries from being treated as facts.
### `[memory.mem0]`
Mem0 (OpenMemory) backend — connects to a self-hosted mem0 server for vector-based memory with LLM-powered fact extraction. Requires feature flag `memory-mem0` at build time and `backend = "mem0"` in config.
| Key | Default | Env var | Purpose |
|---|---|---|---|
| `url` | `http://localhost:8765` | `MEM0_URL` | OpenMemory server URL |
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | User ID for scoping memories |
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | Application name registered in mem0 |
| `infer` | `true` | — | Use LLM to extract facts from stored text (`true`) or store raw (`false`) |
| `extraction_prompt` | unset | `MEM0_EXTRACTION_PROMPT` | Custom prompt for LLM fact extraction (e.g. for non-English content) |
```toml
[memory]
backend = "mem0"
[memory.mem0]
url = "http://192.168.0.171:8765"
user_id = "zeroclaw-bot"
extraction_prompt = "Extract facts in the original language..."
```
Server deployment scripts are in `deploy/mem0/`.
## `[[model_routes]]` and `[[embedding_routes]]`
Use route hints so integrations can keep stable names while model IDs evolve.
@@ -597,7 +607,7 @@ Top-level channel options are configured under `channels_config`.
| Key | Default | Purpose |
|---|---|---|
| `message_timeout_secs` | `300` | Base timeout in seconds for channel message processing; runtime scales this with tool-loop depth (up to 4x) |
| `message_timeout_secs` | `300` | Base timeout in seconds for channel message processing; runtime scales this with tool-loop depth (up to 4x, overridable via `[pacing].message_timeout_scale_max`) |
Examples:
@@ -612,7 +622,7 @@ Examples:
Notes:
- Default `300s` is optimized for on-device LLMs (Ollama) which are slower than cloud APIs.
- Runtime timeout budget is `message_timeout_secs * scale`, where `scale = min(max_tool_iterations, 4)` and a minimum of `1`.
- Runtime timeout budget is `message_timeout_secs * scale`, where `scale = min(max_tool_iterations, cap)` and a minimum of `1`. The default cap is `4`; override with `[pacing].message_timeout_scale_max`.
- This scaling avoids false timeouts when the first LLM turn is slow/retried but later tool-loop turns still need to complete.
- If using cloud APIs (OpenAI, Anthropic, etc.), you can reduce this to `60` or lower.
- Values below `30` are clamped to `30` to avoid immediate timeout churn.
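The budget formula from the notes above can be written out as a small sketch (`timeout_budget` is an illustrative name, not the runtime's actual function):

```rust
/// Channel message timeout budget:
/// message_timeout_secs * min(max_tool_iterations, cap), with scale >= 1.
fn timeout_budget(message_timeout_secs: u64, max_tool_iterations: u64, cap: u64) -> u64 {
    let scale = max_tool_iterations.min(cap).max(1);
    message_timeout_secs * scale
}

fn main() {
    // Default cap of 4:
    assert_eq!(timeout_budget(300, 10, 4), 1200);
    // Cap raised via [pacing].message_timeout_scale_max = 8:
    assert_eq!(timeout_budget(300, 10, 8), 2400);
    // Fewer iterations than the cap:
    assert_eq!(timeout_budget(300, 2, 4), 600);
}
```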
@@ -337,30 +337,6 @@ Lưu ý:
- Memory context injection skips legacy `assistant_resp*` auto-save keys to avoid model-generated summaries being treated as facts.
### `[memory.mem0]`
Mem0 (OpenMemory) backend — connects to a self-hosted mem0 server providing vector memory with LLM fact extraction. Requires the `memory-mem0` feature flag at build time and `backend = "mem0"` in config.
| Key | Default | Env var | Purpose |
|---|---|---|---|
| `url` | `http://localhost:8765` | `MEM0_URL` | OpenMemory server URL |
| `user_id` | `zeroclaw` | `MEM0_USER_ID` | User ID for scoping memories |
| `app_name` | `zeroclaw` | `MEM0_APP_NAME` | Application name registered in mem0 |
| `infer` | `true` | — | Use LLM to extract facts from text (`true`) or store raw (`false`) |
| `extraction_prompt` | unset | `MEM0_EXTRACTION_PROMPT` | Custom prompt for LLM fact extraction (e.g. for non-English content) |
```toml
[memory]
backend = "mem0"
[memory.mem0]
url = "http://192.168.0.171:8765"
user_id = "zeroclaw-bot"
extraction_prompt = "Extract facts in the original language..."
```
Server deployment scripts live in `deploy/mem0/`.
## `[[model_routes]]` and `[[embedding_routes]]`
Route hints keep integration names stable as model IDs change.
@@ -38,3 +38,46 @@ allowed_tools = ["read", "edit", "exec"]
max_iterations = 15
# Optional: use longer timeout for complex coding tasks
agentic_timeout_secs = 600
# ── Cron Configuration ────────────────────────────────────────
[cron]
# Enable the cron subsystem. Default: true
enabled = true
# Run all overdue jobs at scheduler startup. Default: true
catch_up_on_startup = true
# Maximum number of historical cron run records to retain. Default: 50
max_run_history = 50
# ── Declarative Cron Jobs ─────────────────────────────────────
# Define cron jobs directly in config. These are synced to the database
# at scheduler startup. Each job needs a stable `id` for merge semantics.
# Shell job: runs a shell command on a cron schedule
[[cron.jobs]]
id = "daily-backup"
name = "Daily Backup"
job_type = "shell"
command = "tar czf /tmp/backup.tar.gz /data"
schedule = { kind = "cron", expr = "0 2 * * *" }
# Agent job: runs an agent prompt on an interval
[[cron.jobs]]
id = "health-check"
name = "Health Check"
job_type = "agent"
prompt = "Check server health: disk space, memory, CPU load"
model = "anthropic/claude-sonnet-4"
allowed_tools = ["shell", "file_read"]
schedule = { kind = "every", every_ms = 300000 }
# Cron job with timezone and delivery
# [[cron.jobs]]
# id = "morning-report"
# name = "Morning Report"
# job_type = "agent"
# prompt = "Generate a daily summary of system metrics"
# schedule = { kind = "cron", expr = "0 9 * * 1-5", tz = "America/New_York" }
# [cron.jobs.delivery]
# mode = "announce"
# channel = "telegram"
# to = "123456789"
@@ -569,11 +569,29 @@ MSG
exit 0
fi
# Detect un-accepted Xcode/CLT license (causes `cc` to exit 69).
if ! /usr/bin/xcrun --show-sdk-path >/dev/null 2>&1; then
warn "Xcode license has not been accepted. Run:"
warn " sudo xcodebuild -license accept"
warn "then re-run this installer."
exit 1
# xcrun --show-sdk-path can succeed even without an accepted license,
# so we test-compile a trivial C file which reliably triggers the error.
_xcode_test_file="$(mktemp /tmp/zeroclaw-xcode-check.XXXXXX.c)"
printf 'int main(){return 0;}\n' > "$_xcode_test_file"
if ! cc -x c "$_xcode_test_file" -o /dev/null 2>/dev/null; then
rm -f "$_xcode_test_file"
warn "Xcode/CLT license has not been accepted. Attempting to accept it now..."
_xcode_accept_ok=false
if [[ "$(id -u)" -eq 0 ]]; then
xcodebuild -license accept && _xcode_accept_ok=true
elif [[ -c /dev/tty ]] && have_cmd sudo; then
sudo xcodebuild -license accept < /dev/tty && _xcode_accept_ok=true
fi
if [[ "$_xcode_accept_ok" == true ]]; then
step_ok "Xcode license accepted"
else
error "Could not accept Xcode license. Run manually:"
error " sudo xcodebuild -license accept"
error "then re-run this installer."
exit 1
fi
else
rm -f "$_xcode_test_file"
fi
if ! have_cmd git; then
warn "git is not available. Install git (e.g., Homebrew) and re-run bootstrap."
@@ -1175,6 +1193,43 @@ else
install_system_deps
fi
# Always check Xcode/CLT license on macOS, regardless of --install-system-deps.
# An un-accepted license causes `cc` to exit 69, breaking all Rust builds.
if [[ "$OS_NAME" == "Darwin" ]]; then
_xcode_test_file="$(mktemp /tmp/zeroclaw-xcode-check.XXXXXX.c)"
printf 'int main(){return 0;}\n' > "$_xcode_test_file"
if ! cc -x c "$_xcode_test_file" -o /dev/null 2>/dev/null; then
rm -f "$_xcode_test_file"
warn "Xcode/CLT license has not been accepted. Attempting to accept it now..."
# Use /dev/tty so sudo can prompt for a password even in a curl|bash pipe.
_xcode_accept_ok=false
if [[ "$(id -u)" -eq 0 ]]; then
xcodebuild -license accept && _xcode_accept_ok=true
elif [[ -c /dev/tty ]] && have_cmd sudo; then
sudo xcodebuild -license accept < /dev/tty && _xcode_accept_ok=true
fi
if [[ "$_xcode_accept_ok" == true ]]; then
step_ok "Xcode license accepted"
# Re-test compilation to confirm it's fixed.
_xcode_test_file="$(mktemp /tmp/zeroclaw-xcode-check.XXXXXX.c)"
printf 'int main(){return 0;}\n' > "$_xcode_test_file"
if ! cc -x c "$_xcode_test_file" -o /dev/null 2>/dev/null; then
rm -f "$_xcode_test_file"
error "C compiler still failing after license accept. Check your Xcode/CLT installation."
exit 1
fi
rm -f "$_xcode_test_file"
else
error "Could not accept Xcode license. Run manually:"
error " sudo xcodebuild -license accept"
error "then re-run this installer."
exit 1
fi
else
rm -f "$_xcode_test_file"
fi
fi
if [[ "$INSTALL_RUST" == true ]]; then
install_rust_toolchain
fi
@@ -1393,6 +1448,25 @@ else
step_dot "Skipping install"
fi
# --- Build web dashboard ---
if [[ "$SKIP_BUILD" == false && -d "$WORK_DIR/web" ]]; then
if have_cmd node && have_cmd npm; then
step_dot "Building web dashboard"
if (cd "$WORK_DIR/web" && npm ci --ignore-scripts 2>/dev/null && npm run build 2>/dev/null); then
step_ok "Web dashboard built"
else
warn "Web dashboard build failed — dashboard will not be available"
fi
else
warn "node/npm not found — skipping web dashboard build"
warn "Install Node.js (>=18) and re-run, or build manually: cd web && npm ci && npm run build"
fi
else
if [[ "$SKIP_BUILD" == true ]]; then
step_dot "Skipping web dashboard build"
fi
fi
ZEROCLAW_BIN=""
if [[ -x "$HOME/.cargo/bin/zeroclaw" ]]; then
ZEROCLAW_BIN="$HOME/.cargo/bin/zeroclaw"
@@ -1467,25 +1541,6 @@ if [[ -n "$ZEROCLAW_BIN" ]]; then
if "$ZEROCLAW_BIN" service restart 2>/dev/null; then
step_ok "Gateway service restarted"
# Fetch and display pairing code from running gateway
PAIR_CODE=""
for i in 1 2 3 4 5; do
sleep 2
if PAIR_CODE=$("$ZEROCLAW_BIN" gateway get-paircode 2>/dev/null | grep -oEm1 '[0-9]{6}'); then
break
fi
done
if [[ -n "$PAIR_CODE" ]]; then
echo
echo -e " ${BOLD_BLUE}🔐 Gateway Pairing Code${RESET}"
echo
echo -e " ${BOLD_BLUE}┌──────────────┐${RESET}"
echo -e " ${BOLD_BLUE}${RESET} ${BOLD}${PAIR_CODE}${RESET} ${BOLD_BLUE}${RESET}"
echo -e " ${BOLD_BLUE}└──────────────┘${RESET}"
echo
echo -e " ${DIM}Enter this code in the dashboard to pair your device.${RESET}"
echo -e " ${DIM}Run 'zeroclaw gateway get-paircode --new' anytime to generate a fresh code.${RESET}"
fi
else
step_fail "Gateway service restart failed — re-run with zeroclaw service start"
fi
@@ -1532,7 +1587,6 @@ GATEWAY_PORT=42617
DASHBOARD_URL="http://127.0.0.1:${GATEWAY_PORT}"
echo
echo -e "${BOLD}Dashboard URL:${RESET} ${BLUE}${DASHBOARD_URL}${RESET}"
echo -e "${DIM} Run 'zeroclaw gateway get-paircode' to get your pairing code.${RESET}"
# --- Copy to clipboard ---
COPIED_TO_CLIPBOARD=false
+21
@@ -0,0 +1,21 @@
#!/bin/bash
# Start a browser on a virtual display
# Usage: ./start-browser.sh [display_num] [url]
set -e
DISPLAY_NUM=${1:-99}
URL=${2:-"https://google.com"}
export DISPLAY=:$DISPLAY_NUM
# Check if display is running
if ! xdpyinfo -display :$DISPLAY_NUM &>/dev/null; then
echo "Error: Display :$DISPLAY_NUM not running."
echo "Start VNC first: ./start-vnc.sh"
exit 1
fi
google-chrome --no-sandbox --disable-gpu --disable-setuid-sandbox "$URL" &
echo "Chrome started on display :$DISPLAY_NUM"
echo "View via VNC or noVNC"
+52
@@ -0,0 +1,52 @@
#!/bin/bash
# Start virtual display with VNC access for browser GUI
# Usage: ./start-vnc.sh [display_num] [vnc_port] [novnc_port] [resolution]
set -e
DISPLAY_NUM=${1:-99}
VNC_PORT=${2:-5900}
NOVNC_PORT=${3:-6080}
RESOLUTION=${4:-1920x1080x24}
echo "Starting virtual display :$DISPLAY_NUM at $RESOLUTION"
# Kill any existing sessions
pkill -f "Xvfb :$DISPLAY_NUM" 2>/dev/null || true
pkill -f "x11vnc.*:$DISPLAY_NUM" 2>/dev/null || true
pkill -f "websockify.*$NOVNC_PORT" 2>/dev/null || true
sleep 1
# Start Xvfb (virtual framebuffer)
Xvfb :$DISPLAY_NUM -screen 0 $RESOLUTION -ac &
XVFB_PID=$!
sleep 1
# Set DISPLAY
export DISPLAY=:$DISPLAY_NUM
# Start window manager
fluxbox -display :$DISPLAY_NUM 2>/dev/null &
sleep 1
# Start x11vnc
x11vnc -display :$DISPLAY_NUM -rfbport $VNC_PORT -forever -shared -nopw -bg 2>/dev/null
sleep 1
# Start noVNC (web-based VNC client)
websockify --web=/usr/share/novnc $NOVNC_PORT localhost:$VNC_PORT &
NOVNC_PID=$!
echo ""
echo "==================================="
echo "VNC Server started!"
echo "==================================="
echo "VNC Direct: localhost:$VNC_PORT"
echo "noVNC Web: http://localhost:$NOVNC_PORT/vnc.html"
echo "Display: :$DISPLAY_NUM"
echo "==================================="
echo ""
echo "To start a browser, run:"
echo " DISPLAY=:$DISPLAY_NUM google-chrome &"
echo ""
echo "To stop, run: pkill -f 'Xvfb :$DISPLAY_NUM'"
+11
@@ -0,0 +1,11 @@
#!/bin/bash
# Stop virtual display and VNC server
# Usage: ./stop-vnc.sh [display_num] [novnc_port]
DISPLAY_NUM=${1:-99}
NOVNC_PORT=${2:-6080}
pkill -f "Xvfb :$DISPLAY_NUM" 2>/dev/null || true
pkill -f "x11vnc.*:$DISPLAY_NUM" 2>/dev/null || true
pkill -f "websockify.*$NOVNC_PORT" 2>/dev/null || true
echo "VNC server stopped"
+3 -1
@@ -77,7 +77,9 @@ echo "Created annotated tag: $TAG"
if [[ "$PUSH_TAG" == "true" ]]; then
git push origin "$TAG"
echo "Pushed tag to origin: $TAG"
echo "GitHub release pipeline will run via .github/workflows/pub-release.yml"
echo "Release Stable workflow will auto-trigger via tag push."
echo "Monitor: gh workflow view 'Release Stable' --web"
else
echo "Next step: git push origin $TAG"
echo "This will auto-trigger the Release Stable workflow (builds, Docker, crates.io, website, Scoop, AUR, Homebrew, tweet)."
fi
+122
@@ -0,0 +1,122 @@
---
name: browser
description: Headless browser automation using agent-browser CLI
metadata: {"zeroclaw":{"emoji":"🌐","requires":{"bins":["agent-browser"]}}}
---
# Browser Skill
Control a headless browser for web automation, scraping, and testing.
## Prerequisites
- `agent-browser` CLI installed globally (`npm install -g agent-browser`)
- Chrome downloaded (`agent-browser install`)
## Installation
```bash
# Install agent-browser CLI
npm install -g agent-browser
# Download Chrome for Testing
agent-browser install --with-deps # Linux
agent-browser install # macOS/Windows
```
## Usage
### Navigate and snapshot
```bash
agent-browser open https://example.com
agent-browser snapshot -i
```
### Interact with elements
```bash
agent-browser click @e1 # Click by ref
agent-browser fill @e2 "text" # Fill input
agent-browser press Enter # Press key
```
### Extract data
```bash
agent-browser get text @e1 # Get text content
agent-browser get url # Get current URL
agent-browser screenshot page.png # Take screenshot
```
### Session management
```bash
agent-browser close # Close browser
```
## Common Workflows
### Login flow
```bash
agent-browser open https://site.com/login
agent-browser snapshot -i
agent-browser fill @email "user@example.com"
agent-browser fill @password "secretpass"
agent-browser click @submit
agent-browser wait --text "Welcome"
```
### Scrape page content
```bash
agent-browser open https://news.ycombinator.com
agent-browser snapshot -i
agent-browser get text @e1
```
### Take screenshots
```bash
agent-browser open https://google.com
agent-browser screenshot --full page.png
```
## Options
- `--json` - JSON output for parsing
- `--headed` - Show browser window (for debugging)
- `--session-name <name>` - Persist session cookies
- `--profile <path>` - Use persistent browser profile
## Configuration
The browser tool is enabled by default with `allowed_domains = ["*"]` and
`backend = "agent_browser"`. To customize, edit `~/.zeroclaw/config.toml`:
```toml
[browser]
enabled = true # default: true
allowed_domains = ["*"] # default: ["*"] (all public hosts)
backend = "agent_browser" # default: "agent_browser"
native_headless = true # default: true
```
To restrict domains or disable the browser tool:
```toml
[browser]
enabled = false # disable entirely
# or restrict to specific domains:
allowed_domains = ["example.com", "docs.example.com"]
```
## Full Command Reference
Run `agent-browser --help` for all available commands.
## Related
- [agent-browser GitHub](https://github.com/vercel-labs/agent-browser)
- [VNC Setup Guide](../docs/browser-setup.md)
+2 -1
@@ -359,7 +359,7 @@ impl Agent {
None
};
let (mut tools, delegate_handle) = tools::all_tools_with_runtime(
let (mut tools, delegate_handle, _reaction_handle) = tools::all_tools_with_runtime(
Arc::new(config.clone()),
&security,
runtime,
@@ -373,6 +373,7 @@ impl Agent {
&config.agents,
config.api_key.as_deref(),
config,
None,
);
// ── Wire MCP tools (non-fatal) ─────────────────────────────
+657 -82
File diff suppressed because it is too large
+14 -2
@@ -1,4 +1,4 @@
use crate::memory::{self, Memory};
use crate::memory::{self, decay, Memory};
use async_trait::async_trait;
use std::fmt::Write;
@@ -43,13 +43,16 @@ impl MemoryLoader for DefaultMemoryLoader {
user_message: &str,
session_id: Option<&str>,
) -> anyhow::Result<String> {
let entries = memory
let mut entries = memory
.recall(user_message, self.limit, session_id, None, None)
.await?;
if entries.is_empty() {
return Ok(String::new());
}
// Apply time decay: older non-Core memories score lower
decay::apply_time_decay(&mut entries, decay::DEFAULT_HALF_LIFE_DAYS);
let mut context = String::from("[Memory context]\n");
for entry in entries {
if memory::is_assistant_autosave_key(&entry.key) {
@@ -118,6 +121,9 @@ mod tests {
timestamp: "now".into(),
session_id: None,
score: None,
namespace: "default".into(),
importance: None,
superseded_by: None,
}])
}
@@ -226,6 +232,9 @@ mod tests {
timestamp: "now".into(),
session_id: None,
score: Some(0.95),
namespace: "default".into(),
importance: None,
superseded_by: None,
},
MemoryEntry {
id: "2".into(),
@@ -235,6 +244,9 @@ mod tests {
timestamp: "now".into(),
session_id: None,
score: Some(0.9),
namespace: "default".into(),
importance: None,
superseded_by: None,
},
]),
};
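The `decay::apply_time_decay(&mut entries, decay::DEFAULT_HALF_LIFE_DAYS)` call added above suggests exponential half-life scoring. The function below is a hypothetical sketch of that idea, not the crate's actual `decay` module; the real implementation also exempts Core memories from decay.

```rust
/// Hypothetical sketch: a relevance score loses half its weight for
/// every `half_life_days` of age.
fn decayed_score(score: f64, age_days: f64, half_life_days: f64) -> f64 {
    score * 0.5_f64.powf(age_days / half_life_days)
}

fn main() {
    // A 30-day-old entry with a 30-day half-life keeps half its score.
    let s = decayed_score(0.9, 30.0, 30.0);
    println!("{s:.2}"); // 0.45
}
```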
+1
@@ -5,6 +5,7 @@ pub mod dispatcher;
pub mod loop_;
pub mod memory_loader;
pub mod prompt;
pub mod thinking;
#[cfg(test)]
mod tests;
+7 -6
@@ -473,8 +473,9 @@ mod tests {
assert!(output.contains("<available_skills>"));
assert!(output.contains("<name>deploy</name>"));
assert!(output.contains("<instruction>Run smoke tests before deploy.</instruction>"));
assert!(output.contains("<name>release_checklist</name>"));
assert!(output.contains("<kind>shell</kind>"));
// Registered tools (shell kind) appear under <callable_tools> with prefixed names
assert!(output.contains("<callable_tools"));
assert!(output.contains("<name>deploy.release_checklist</name>"));
}
#[test]
@@ -516,10 +517,10 @@ mod tests {
assert!(output.contains("<location>skills/deploy/SKILL.md</location>"));
assert!(output.contains("read_skill(name)"));
assert!(!output.contains("<instruction>Run smoke tests before deploy.</instruction>"));
// Compact mode should still include tools so the LLM knows about them
assert!(output.contains("<tools>"));
assert!(output.contains("<name>release_checklist</name>"));
assert!(output.contains("<kind>shell</kind>"));
// Compact mode should still include tools so the LLM knows about them.
// Registered tools (shell kind) appear under <callable_tools> with prefixed names.
assert!(output.contains("<callable_tools"));
assert!(output.contains("<name>deploy.release_checklist</name>"));
}
#[test]
+424
@@ -0,0 +1,424 @@
//! Thinking/Reasoning Level Control
//!
//! Allows users to control how deeply the model reasons per message,
//! trading speed for depth. Levels range from `Off` (fastest, most concise)
//! to `Max` (deepest reasoning, slowest).
//!
//! Users can set the level via:
//! - Inline directive: `/think:high` at the start of a message
//! - Agent config: `[agent.thinking]` section with `default_level`
//!
//! Resolution hierarchy (highest priority first):
//! 1. Inline directive (`/think:<level>`)
//! 2. Session override (reserved for future use)
//! 3. Agent config (`agent.thinking.default_level`)
//! 4. Global default (`Medium`)
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
/// How deeply the model should reason for a given message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ThinkingLevel {
/// No chain-of-thought. Fastest, most concise responses.
Off,
/// Minimal reasoning. Brief, direct answers.
Minimal,
/// Light reasoning. Short explanations when needed.
Low,
/// Balanced reasoning (default). Moderate depth.
#[default]
Medium,
/// Deep reasoning. Thorough analysis and step-by-step thinking.
High,
/// Maximum reasoning depth. Exhaustive analysis.
Max,
}
impl ThinkingLevel {
/// Parse a thinking level from a string (case-insensitive).
pub fn from_str_insensitive(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"off" | "none" => Some(Self::Off),
"minimal" | "min" => Some(Self::Minimal),
"low" => Some(Self::Low),
"medium" | "med" | "default" => Some(Self::Medium),
"high" => Some(Self::High),
"max" | "maximum" => Some(Self::Max),
_ => None,
}
}
}
/// Configuration for thinking/reasoning level control.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ThinkingConfig {
/// Default thinking level when no directive is present.
#[serde(default)]
pub default_level: ThinkingLevel,
}
impl Default for ThinkingConfig {
fn default() -> Self {
Self {
default_level: ThinkingLevel::Medium,
}
}
}
/// Parameters derived from a thinking level, applied to the LLM request.
#[derive(Debug, Clone, PartialEq)]
pub struct ThinkingParams {
/// Temperature adjustment (added to the base temperature, clamped to 0.0..=2.0).
pub temperature_adjustment: f64,
/// Maximum tokens adjustment (added to any existing max_tokens setting).
pub max_tokens_adjustment: i64,
/// Optional system prompt prefix injected before the existing system prompt.
pub system_prompt_prefix: Option<String>,
}
/// Parse a `/think:<level>` directive from the start of a message.
///
/// Returns `Some((level, remaining_message))` if a directive is found,
/// or `None` if no directive is present. The remaining message has
/// leading whitespace after the directive trimmed.
pub fn parse_thinking_directive(message: &str) -> Option<(ThinkingLevel, String)> {
let trimmed = message.trim_start();
if !trimmed.starts_with("/think:") {
return None;
}
// Extract the level token (everything between `/think:` and the next whitespace or end).
let after_prefix = &trimmed["/think:".len()..];
let level_end = after_prefix
.find(|c: char| c.is_whitespace())
.unwrap_or(after_prefix.len());
let level_str = &after_prefix[..level_end];
let level = ThinkingLevel::from_str_insensitive(level_str)?;
let remaining = after_prefix[level_end..].trim_start().to_string();
Some((level, remaining))
}
/// Convert a `ThinkingLevel` into concrete parameters for the LLM request.
pub fn apply_thinking_level(level: ThinkingLevel) -> ThinkingParams {
match level {
ThinkingLevel::Off => ThinkingParams {
temperature_adjustment: -0.2,
max_tokens_adjustment: -1000,
system_prompt_prefix: Some(
"Be extremely concise. Give direct answers without explanation \
unless explicitly asked. No preamble."
.into(),
),
},
ThinkingLevel::Minimal => ThinkingParams {
temperature_adjustment: -0.1,
max_tokens_adjustment: -500,
system_prompt_prefix: Some(
"Be concise and fast. Keep explanations brief. \
Prioritize speed over thoroughness."
.into(),
),
},
ThinkingLevel::Low => ThinkingParams {
temperature_adjustment: -0.05,
max_tokens_adjustment: 0,
system_prompt_prefix: Some("Keep reasoning light. Explain only when helpful.".into()),
},
ThinkingLevel::Medium => ThinkingParams {
temperature_adjustment: 0.0,
max_tokens_adjustment: 0,
system_prompt_prefix: None,
},
ThinkingLevel::High => ThinkingParams {
temperature_adjustment: 0.05,
max_tokens_adjustment: 1000,
system_prompt_prefix: Some(
"Think step by step. Provide thorough analysis and \
consider edge cases before answering."
.into(),
),
},
ThinkingLevel::Max => ThinkingParams {
temperature_adjustment: 0.1,
max_tokens_adjustment: 2000,
system_prompt_prefix: Some(
"Think very carefully and exhaustively. Break down the problem \
into sub-problems, consider all angles, verify your reasoning, \
and provide the most thorough analysis possible."
.into(),
),
},
}
}
/// Resolve the effective thinking level using the priority hierarchy:
/// 1. Inline directive (if present)
/// 2. Session override (reserved, currently always `None`)
/// 3. Agent config default
/// 4. Global default (`Medium`)
pub fn resolve_thinking_level(
inline_directive: Option<ThinkingLevel>,
session_override: Option<ThinkingLevel>,
config: &ThinkingConfig,
) -> ThinkingLevel {
inline_directive
.or(session_override)
.unwrap_or(config.default_level)
}
/// Clamp a temperature value to the valid range `[0.0, 2.0]`.
pub fn clamp_temperature(temp: f64) -> f64 {
temp.clamp(0.0, 2.0)
}
#[cfg(test)]
mod tests {
use super::*;
// ── ThinkingLevel parsing ────────────────────────────────────
#[test]
fn thinking_level_from_str_canonical_names() {
assert_eq!(
ThinkingLevel::from_str_insensitive("off"),
Some(ThinkingLevel::Off)
);
assert_eq!(
ThinkingLevel::from_str_insensitive("minimal"),
Some(ThinkingLevel::Minimal)
);
assert_eq!(
ThinkingLevel::from_str_insensitive("low"),
Some(ThinkingLevel::Low)
);
assert_eq!(
ThinkingLevel::from_str_insensitive("medium"),
Some(ThinkingLevel::Medium)
);
assert_eq!(
ThinkingLevel::from_str_insensitive("high"),
Some(ThinkingLevel::High)
);
assert_eq!(
ThinkingLevel::from_str_insensitive("max"),
Some(ThinkingLevel::Max)
);
}
#[test]
fn thinking_level_from_str_aliases() {
assert_eq!(
ThinkingLevel::from_str_insensitive("none"),
Some(ThinkingLevel::Off)
);
assert_eq!(
ThinkingLevel::from_str_insensitive("min"),
Some(ThinkingLevel::Minimal)
);
assert_eq!(
ThinkingLevel::from_str_insensitive("med"),
Some(ThinkingLevel::Medium)
);
assert_eq!(
ThinkingLevel::from_str_insensitive("default"),
Some(ThinkingLevel::Medium)
);
assert_eq!(
ThinkingLevel::from_str_insensitive("maximum"),
Some(ThinkingLevel::Max)
);
}
#[test]
fn thinking_level_from_str_case_insensitive() {
assert_eq!(
ThinkingLevel::from_str_insensitive("HIGH"),
Some(ThinkingLevel::High)
);
assert_eq!(
ThinkingLevel::from_str_insensitive("Max"),
Some(ThinkingLevel::Max)
);
assert_eq!(
ThinkingLevel::from_str_insensitive("OFF"),
Some(ThinkingLevel::Off)
);
}
#[test]
fn thinking_level_from_str_invalid_returns_none() {
assert_eq!(ThinkingLevel::from_str_insensitive("turbo"), None);
assert_eq!(ThinkingLevel::from_str_insensitive(""), None);
assert_eq!(ThinkingLevel::from_str_insensitive("super-high"), None);
}
// ── Directive parsing ────────────────────────────────────────
#[test]
fn parse_directive_extracts_level_and_remaining_message() {
let result = parse_thinking_directive("/think:high What is Rust?");
assert!(result.is_some());
let (level, remaining) = result.unwrap();
assert_eq!(level, ThinkingLevel::High);
assert_eq!(remaining, "What is Rust?");
}
#[test]
fn parse_directive_handles_directive_only() {
let result = parse_thinking_directive("/think:off");
assert!(result.is_some());
let (level, remaining) = result.unwrap();
assert_eq!(level, ThinkingLevel::Off);
assert_eq!(remaining, "");
}
#[test]
fn parse_directive_strips_leading_whitespace() {
let result = parse_thinking_directive(" /think:low Tell me about Rust");
assert!(result.is_some());
let (level, remaining) = result.unwrap();
assert_eq!(level, ThinkingLevel::Low);
assert_eq!(remaining, "Tell me about Rust");
}
#[test]
fn parse_directive_returns_none_for_no_directive() {
assert!(parse_thinking_directive("Hello world").is_none());
assert!(parse_thinking_directive("").is_none());
assert!(parse_thinking_directive("/think").is_none());
}
#[test]
fn parse_directive_returns_none_for_invalid_level() {
assert!(parse_thinking_directive("/think:turbo What?").is_none());
}
#[test]
fn parse_directive_not_triggered_mid_message() {
assert!(parse_thinking_directive("Hello /think:high world").is_none());
}
// ── Level application ────────────────────────────────────────
#[test]
fn apply_thinking_level_off_is_concise() {
let params = apply_thinking_level(ThinkingLevel::Off);
assert!(params.temperature_adjustment < 0.0);
assert!(params.max_tokens_adjustment < 0);
assert!(params.system_prompt_prefix.is_some());
assert!(params
.system_prompt_prefix
.unwrap()
.to_lowercase()
.contains("concise"));
}
#[test]
fn apply_thinking_level_medium_is_neutral() {
let params = apply_thinking_level(ThinkingLevel::Medium);
assert!((params.temperature_adjustment - 0.0).abs() < f64::EPSILON);
assert_eq!(params.max_tokens_adjustment, 0);
assert!(params.system_prompt_prefix.is_none());
}
#[test]
fn apply_thinking_level_high_adds_step_by_step() {
let params = apply_thinking_level(ThinkingLevel::High);
assert!(params.temperature_adjustment > 0.0);
assert!(params.max_tokens_adjustment > 0);
let prefix = params.system_prompt_prefix.unwrap();
assert!(prefix.to_lowercase().contains("step by step"));
}
#[test]
fn apply_thinking_level_max_is_most_thorough() {
let params = apply_thinking_level(ThinkingLevel::Max);
assert!(params.temperature_adjustment > 0.0);
assert!(params.max_tokens_adjustment > 0);
let prefix = params.system_prompt_prefix.unwrap();
assert!(prefix.to_lowercase().contains("exhaustively"));
}
// ── Resolution hierarchy ─────────────────────────────────────
#[test]
fn resolve_inline_directive_takes_priority() {
let config = ThinkingConfig {
default_level: ThinkingLevel::Low,
};
let result =
resolve_thinking_level(Some(ThinkingLevel::Max), Some(ThinkingLevel::High), &config);
assert_eq!(result, ThinkingLevel::Max);
}
#[test]
fn resolve_session_override_takes_priority_over_config() {
let config = ThinkingConfig {
default_level: ThinkingLevel::Low,
};
let result = resolve_thinking_level(None, Some(ThinkingLevel::High), &config);
assert_eq!(result, ThinkingLevel::High);
}
#[test]
fn resolve_falls_back_to_config_default() {
let config = ThinkingConfig {
default_level: ThinkingLevel::Minimal,
};
let result = resolve_thinking_level(None, None, &config);
assert_eq!(result, ThinkingLevel::Minimal);
}
#[test]
fn resolve_default_config_uses_medium() {
let config = ThinkingConfig::default();
let result = resolve_thinking_level(None, None, &config);
assert_eq!(result, ThinkingLevel::Medium);
}
// ── Temperature clamping ─────────────────────────────────────
#[test]
fn clamp_temperature_within_range() {
assert!((clamp_temperature(0.7) - 0.7).abs() < f64::EPSILON);
assert!((clamp_temperature(0.0) - 0.0).abs() < f64::EPSILON);
assert!((clamp_temperature(2.0) - 2.0).abs() < f64::EPSILON);
}
#[test]
fn clamp_temperature_below_minimum() {
assert!((clamp_temperature(-0.5) - 0.0).abs() < f64::EPSILON);
}
#[test]
fn clamp_temperature_above_maximum() {
assert!((clamp_temperature(3.0) - 2.0).abs() < f64::EPSILON);
}
// ── Serde round-trip ─────────────────────────────────────────
#[test]
fn thinking_config_deserializes_from_toml() {
let toml_str = r#"default_level = "high""#;
let config: ThinkingConfig = toml::from_str(toml_str).unwrap();
assert_eq!(config.default_level, ThinkingLevel::High);
}
#[test]
fn thinking_config_default_level_deserializes() {
let toml_str = "";
let config: ThinkingConfig = toml::from_str(toml_str).unwrap();
assert_eq!(config.default_level, ThinkingLevel::Medium);
}
#[test]
fn thinking_level_serializes_lowercase() {
let level = ThinkingLevel::High;
let json = serde_json::to_string(&level).unwrap();
assert_eq!(json, "\"high\"");
}
}
+48 -2
@@ -122,7 +122,7 @@ impl ApprovalManager {
}
// always_ask overrides everything.
if self.always_ask.contains(tool_name) {
if self.always_ask.contains("*") || self.always_ask.contains(tool_name) {
return true;
}
@@ -136,7 +136,7 @@ impl ApprovalManager {
}
// auto_approve skips the prompt.
if self.auto_approve.contains(tool_name) {
if self.auto_approve.contains("*") || self.auto_approve.contains(tool_name) {
return false;
}
@@ -562,4 +562,50 @@ mod tests {
let parsed: ApprovalRequest = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.tool_name, "shell");
}
// ── Regression: #4247 default approved tools in channels ──
#[test]
fn non_interactive_allows_default_auto_approve_tools() {
let config = AutonomyConfig::default();
let mgr = ApprovalManager::for_non_interactive(&config);
for tool in &config.auto_approve {
assert!(
!mgr.needs_approval(tool),
"default auto_approve tool '{tool}' should not need approval in non-interactive mode"
);
}
}
#[test]
fn non_interactive_denies_unknown_tools() {
let config = AutonomyConfig::default();
let mgr = ApprovalManager::for_non_interactive(&config);
assert!(
mgr.needs_approval("some_unknown_tool"),
"unknown tool should need approval"
);
}
#[test]
fn non_interactive_weather_is_auto_approved() {
let config = AutonomyConfig::default();
let mgr = ApprovalManager::for_non_interactive(&config);
assert!(
!mgr.needs_approval("weather"),
"weather tool must not need approval — it is in the default auto_approve list"
);
}
#[test]
fn always_ask_overrides_auto_approve() {
let mut config = AutonomyConfig::default();
config.always_ask = vec!["weather".into()];
let mgr = ApprovalManager::for_non_interactive(&config);
assert!(
mgr.needs_approval("weather"),
"always_ask must override auto_approve"
);
}
}
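The wildcard support added above could be expressed in config roughly as follows. The `[autonomy]` table name and key spellings are inferred from `AutonomyConfig` and should be treated as assumptions:

```toml
[autonomy]
# Prompt for every tool; "*" in always_ask overrides auto_approve.
always_ask = ["*"]

# Or the inverse: skip approval prompts for all tools.
# auto_approve = ["*"]
```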
+116 -1
@@ -20,6 +20,9 @@ pub struct DiscordChannel {
typing_handles: Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
/// Per-channel proxy URL override.
proxy_url: Option<String>,
/// Voice transcription config — when set, audio attachments are
/// downloaded, transcribed, and their text inlined into the message.
transcription: Option<crate::config::TranscriptionConfig>,
}
impl DiscordChannel {
@@ -38,6 +41,7 @@ impl DiscordChannel {
mention_only,
typing_handles: Mutex::new(HashMap::new()),
proxy_url: None,
transcription: None,
}
}
@@ -47,6 +51,14 @@ impl DiscordChannel {
self
}
/// Configure voice transcription for audio attachments.
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
if config.enabled {
self.transcription = Some(config);
}
self
}
fn http_client(&self) -> reqwest::Client {
crate::config::build_channel_proxy_client("channel.discord", self.proxy_url.as_deref())
}
@@ -113,6 +125,88 @@ async fn process_attachments(
parts.join("\n---\n")
}
/// Audio file extensions accepted for voice transcription.
const DISCORD_AUDIO_EXTENSIONS: &[&str] = &[
"flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
];
/// Check if a content type or filename indicates an audio file.
fn is_discord_audio_attachment(content_type: &str, filename: &str) -> bool {
if content_type.starts_with("audio/") {
return true;
}
// rsplit('.').next() returns the whole name when there is no dot,
// so a bare filename like "wav" would false-positive; rsplit_once
// only matches when an actual extension is present.
if let Some((_, ext)) = filename.rsplit_once('.') {
return DISCORD_AUDIO_EXTENSIONS.contains(&ext.to_ascii_lowercase().as_str());
}
false
}
/// Download and transcribe audio attachments from a Discord message.
///
/// Returns transcribed text blocks for any audio attachments found.
/// Non-audio attachments and failures are silently skipped.
async fn transcribe_discord_audio_attachments(
attachments: &[serde_json::Value],
client: &reqwest::Client,
config: &crate::config::TranscriptionConfig,
) -> String {
let mut parts: Vec<String> = Vec::new();
for att in attachments {
let ct = att
.get("content_type")
.and_then(|v| v.as_str())
.unwrap_or("");
let name = att
.get("filename")
.and_then(|v| v.as_str())
.unwrap_or("file");
if !is_discord_audio_attachment(ct, name) {
continue;
}
let Some(url) = att.get("url").and_then(|v| v.as_str()) else {
continue;
};
let audio_data = match client.get(url).send().await {
Ok(resp) if resp.status().is_success() => match resp.bytes().await {
Ok(bytes) => bytes.to_vec(),
Err(e) => {
tracing::warn!(name, error = %e, "discord: failed to read audio attachment bytes");
continue;
}
},
Ok(resp) => {
tracing::warn!(name, status = %resp.status(), "discord: audio attachment download failed");
continue;
}
Err(e) => {
tracing::warn!(name, error = %e, "discord: audio attachment fetch error");
continue;
}
};
match super::transcription::transcribe_audio(audio_data, name, config).await {
Ok(text) => {
let trimmed = text.trim();
if !trimmed.is_empty() {
tracing::info!(
"Discord: transcribed audio attachment {} ({} chars)",
name,
trimmed.len()
);
parts.push(format!("[Voice] {trimmed}"));
}
}
Err(e) => {
tracing::warn!(name, error = %e, "discord: voice transcription failed");
}
}
}
parts.join("\n")
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum DiscordAttachmentKind {
Image,
@@ -737,7 +831,28 @@ impl Channel for DiscordChannel {
.and_then(|a| a.as_array())
.cloned()
.unwrap_or_default();
process_attachments(&atts, &self.http_client()).await
let client = self.http_client();
let mut text_parts = process_attachments(&atts, &client).await;
// Transcribe audio attachments when transcription is configured
if let Some(ref transcription_config) = self.transcription {
let voice_text = transcribe_discord_audio_attachments(
&atts,
&client,
transcription_config,
)
.await;
if !voice_text.is_empty() {
if text_parts.is_empty() {
text_parts = voice_text;
} else {
text_parts = format!("{text_parts}\n{voice_text}");
}
}
}
text_parts
};
let final_content = if attachment_text.is_empty() {
clean_content
+549
@@ -0,0 +1,549 @@
use super::traits::{Channel, ChannelMessage, SendMessage};
use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use parking_lot::Mutex;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use tokio_tungstenite::tungstenite::Message;
use uuid::Uuid;
use crate::memory::{Memory, MemoryCategory};
/// Discord History channel — connects via Gateway WebSocket, stores ALL non-bot messages
/// to a dedicated discord.db, and forwards @mention messages to the agent.
pub struct DiscordHistoryChannel {
bot_token: String,
guild_id: Option<String>,
allowed_users: Vec<String>,
/// Channel IDs to watch. Empty = watch all channels.
channel_ids: Vec<String>,
/// Dedicated discord.db memory backend.
discord_memory: Arc<dyn Memory>,
typing_handles: Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
proxy_url: Option<String>,
/// When false, DM messages are not stored in discord.db.
store_dms: bool,
/// When false, @mentions in DMs are not forwarded to the agent.
respond_to_dms: bool,
}
impl DiscordHistoryChannel {
pub fn new(
bot_token: String,
guild_id: Option<String>,
allowed_users: Vec<String>,
channel_ids: Vec<String>,
discord_memory: Arc<dyn Memory>,
store_dms: bool,
respond_to_dms: bool,
) -> Self {
Self {
bot_token,
guild_id,
allowed_users,
channel_ids,
discord_memory,
typing_handles: Mutex::new(HashMap::new()),
proxy_url: None,
store_dms,
respond_to_dms,
}
}
pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
self.proxy_url = proxy_url;
self
}
fn http_client(&self) -> reqwest::Client {
crate::config::build_channel_proxy_client(
"channel.discord_history",
self.proxy_url.as_deref(),
)
}
fn is_user_allowed(&self, user_id: &str) -> bool {
if self.allowed_users.is_empty() {
return true; // default open for logging channel
}
self.allowed_users.iter().any(|u| u == "*" || u == user_id)
}
fn is_channel_watched(&self, channel_id: &str) -> bool {
self.channel_ids.is_empty() || self.channel_ids.iter().any(|c| c == channel_id)
}
fn bot_user_id_from_token(token: &str) -> Option<String> {
let part = token.split('.').next()?;
base64_decode(part)
}
async fn resolve_channel_name(&self, channel_id: &str) -> String {
// 1. Check persistent database (via discord_memory)
let cache_key = format!("cache:channel_name:{}", channel_id);
if let Ok(Some(cached_mem)) = self.discord_memory.get(&cache_key).await {
// Check if it's still fresh (e.g., less than 24 hours old)
// Note: cached_mem.timestamp is an RFC3339 string
let is_fresh =
if let Ok(ts) = chrono::DateTime::parse_from_rfc3339(&cached_mem.timestamp) {
chrono::Utc::now().signed_duration_since(ts.with_timezone(&chrono::Utc))
< chrono::Duration::hours(24)
} else {
false
};
if is_fresh {
return cached_mem.content.clone();
}
}
// 2. Fetch from API (either not in DB or stale)
let url = format!("https://discord.com/api/v10/channels/{channel_id}");
let resp = self
.http_client()
.get(&url)
.header("Authorization", format!("Bot {}", self.bot_token))
.send()
.await;
let name = if let Ok(r) = resp {
if let Ok(json) = r.json::<serde_json::Value>().await {
json.get("name")
.and_then(|n| n.as_str())
.map(|s| s.to_string())
.or_else(|| {
// For DMs, there might not be a 'name', use the recipient's username if available
json.get("recipients")
.and_then(|r| r.as_array())
.and_then(|a| a.first())
.and_then(|u| u.get("username"))
.and_then(|un| un.as_str())
.map(|s| format!("dm-{}", s))
})
} else {
None
}
} else {
None
};
let resolved = name.unwrap_or_else(|| channel_id.to_string());
// 3. Store in persistent database
let _ = self
.discord_memory
.store(
&cache_key,
&resolved,
crate::memory::MemoryCategory::Custom("channel_cache".to_string()),
Some(channel_id),
)
.await;
resolved
}
}
const BASE64_ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
#[allow(clippy::cast_possible_truncation)]
fn base64_decode(input: &str) -> Option<String> {
let padded = match input.len() % 4 {
2 => format!("{input}=="),
3 => format!("{input}="),
_ => input.to_string(),
};
let mut bytes = Vec::new();
let chars: Vec<u8> = padded.bytes().collect();
for chunk in chars.chunks(4) {
if chunk.len() < 4 {
break;
}
let mut v = [0usize; 4];
for (i, &b) in chunk.iter().enumerate() {
if b == b'=' {
v[i] = 0;
} else {
v[i] = BASE64_ALPHABET.iter().position(|&a| a == b)?;
}
}
bytes.push(((v[0] << 2) | (v[1] >> 4)) as u8);
if chunk[2] != b'=' {
bytes.push((((v[1] & 0xF) << 4) | (v[2] >> 2)) as u8);
}
if chunk[3] != b'=' {
bytes.push((((v[2] & 0x3) << 6) | v[3]) as u8);
}
}
String::from_utf8(bytes).ok()
}
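The `% 4` padding rule at the top of `base64_decode` can be exercised in isolation. This is a minimal standalone sketch of just that rule (hypothetical `pad` helper), not the full decoder:

```rust
// Pad an unpadded base64 string out to a multiple of four characters,
// mirroring the match at the top of base64_decode: a remainder of 2 means
// two characters are missing, a remainder of 3 means one is missing.
fn pad(input: &str) -> String {
    match input.len() % 4 {
        2 => format!("{input}=="),
        3 => format!("{input}="),
        _ => input.to_string(), // already aligned (or an invalid length)
    }
}

fn main() {
    assert_eq!(pad("aGk"), "aGk=");          // "hi"
    assert_eq!(pad("aGVsbG8"), "aGVsbG8=");  // "hello"
    assert_eq!(pad("aGV5YQ"), "aGV5YQ==");   // "heya"
    assert_eq!(pad("dGVzdA=="), "dGVzdA=="); // already padded: unchanged
}
```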
fn contains_bot_mention(content: &str, bot_user_id: &str) -> bool {
if bot_user_id.is_empty() {
return false;
}
content.contains(&format!("<@{bot_user_id}>"))
|| content.contains(&format!("<@!{bot_user_id}>"))
}
fn strip_bot_mention(content: &str, bot_user_id: &str) -> String {
let mut result = content.to_string();
for tag in [format!("<@{bot_user_id}>"), format!("<@!{bot_user_id}>")] {
result = result.replace(&tag, " ");
}
result.trim().to_string()
}
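A quick self-contained check of the two Discord mention syntaxes these helpers handle, the plain `<@id>` form and the legacy nickname form `<@!id>`, using hypothetical IDs:

```rust
// Verbatim copies of the two helpers above, exercised with sample inputs.
fn contains_bot_mention(content: &str, bot_user_id: &str) -> bool {
    if bot_user_id.is_empty() {
        return false;
    }
    content.contains(&format!("<@{bot_user_id}>"))
        || content.contains(&format!("<@!{bot_user_id}>"))
}

fn strip_bot_mention(content: &str, bot_user_id: &str) -> String {
    let mut result = content.to_string();
    for tag in [format!("<@{bot_user_id}>"), format!("<@!{bot_user_id}>")] {
        result = result.replace(&tag, " ");
    }
    result.trim().to_string()
}

fn main() {
    assert!(contains_bot_mention("<@42> ping", "42"));
    assert!(contains_bot_mention("<@!42> ping", "42"));
    assert!(!contains_bot_mention("<@43> ping", "42")); // different ID
    assert!(!contains_bot_mention("anything", ""));     // empty ID never matches
    assert_eq!(strip_bot_mention("<@42> ping", "42"), "ping");
    assert_eq!(strip_bot_mention("<@!42>", "42"), "");
}
```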
#[async_trait]
impl Channel for DiscordHistoryChannel {
fn name(&self) -> &str {
"discord_history"
}
/// Send a reply back to Discord (used when agent responds to @mention).
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
let content = super::strip_tool_call_tags(&message.content);
let url = format!(
"https://discord.com/api/v10/channels/{}/messages",
message.recipient
);
self.http_client()
.post(&url)
.header("Authorization", format!("Bot {}", self.bot_token))
.json(&json!({"content": content}))
.send()
.await?;
Ok(())
}
#[allow(clippy::too_many_lines)]
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
let bot_user_id = Self::bot_user_id_from_token(&self.bot_token).unwrap_or_default();
// Get Gateway URL
let gw_resp: serde_json::Value = self
.http_client()
.get("https://discord.com/api/v10/gateway/bot")
.header("Authorization", format!("Bot {}", self.bot_token))
.send()
.await?
.json()
.await?;
let gw_url = gw_resp
.get("url")
.and_then(|u| u.as_str())
.unwrap_or("wss://gateway.discord.gg");
let ws_url = format!("{gw_url}/?v=10&encoding=json");
tracing::info!("DiscordHistory: connecting to gateway...");
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
let (mut write, mut read) = ws_stream.split();
// Read Hello (opcode 10)
let hello = read.next().await.ok_or(anyhow::anyhow!("No hello"))??;
let hello_data: serde_json::Value = serde_json::from_str(&hello.to_string())?;
let heartbeat_interval = hello_data
.get("d")
.and_then(|d| d.get("heartbeat_interval"))
.and_then(serde_json::Value::as_u64)
.unwrap_or(41250);
// Identify with intents 37377 = GUILDS (1<<0) | GUILD_MESSAGES (1<<9)
// | DIRECT_MESSAGES (1<<12) | MESSAGE_CONTENT (1<<15)
let identify = json!({
"op": 2,
"d": {
"token": self.bot_token,
"intents": 37377,
"properties": {
"os": "linux",
"browser": "zeroclaw",
"device": "zeroclaw"
}
}
});
write
.send(Message::Text(identify.to_string().into()))
.await?;
tracing::info!("DiscordHistory: connected and identified");
let mut sequence: i64 = -1;
let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1);
tokio::spawn(async move {
let mut interval =
tokio::time::interval(std::time::Duration::from_millis(heartbeat_interval));
loop {
interval.tick().await;
if hb_tx.send(()).await.is_err() {
break;
}
}
});
let guild_filter = self.guild_id.clone();
let discord_memory = Arc::clone(&self.discord_memory);
let store_dms = self.store_dms;
let respond_to_dms = self.respond_to_dms;
loop {
tokio::select! {
_ = hb_rx.recv() => {
let d = if sequence >= 0 { json!(sequence) } else { json!(null) };
let hb = json!({"op": 1, "d": d});
if write.send(Message::Text(hb.to_string().into())).await.is_err() {
break;
}
}
msg = read.next() => {
let msg = match msg {
Some(Ok(Message::Text(t))) => t,
Some(Ok(Message::Ping(payload))) => {
if write.send(Message::Pong(payload)).await.is_err() {
break;
}
continue;
}
Some(Ok(Message::Close(_))) | None => break,
Some(Err(e)) => {
tracing::warn!("DiscordHistory: websocket error: {e}");
break;
}
_ => continue,
};
let event: serde_json::Value = match serde_json::from_str(msg.as_ref()) {
Ok(e) => e,
Err(_) => continue,
};
if let Some(s) = event.get("s").and_then(serde_json::Value::as_i64) {
sequence = s;
}
let op = event.get("op").and_then(serde_json::Value::as_u64).unwrap_or(0);
match op {
1 => {
let d = if sequence >= 0 { json!(sequence) } else { json!(null) };
let hb = json!({"op": 1, "d": d});
if write.send(Message::Text(hb.to_string().into())).await.is_err() {
break;
}
continue;
}
7 => { tracing::warn!("DiscordHistory: Reconnect (op 7)"); break; }
9 => { tracing::warn!("DiscordHistory: Invalid Session (op 9)"); break; }
_ => {}
}
let event_type = event.get("t").and_then(|t| t.as_str()).unwrap_or("");
if event_type != "MESSAGE_CREATE" {
continue;
}
let Some(d) = event.get("d") else { continue };
// Skip messages from the bot itself
let author_id = d
.get("author")
.and_then(|a| a.get("id"))
.and_then(|i| i.as_str())
.unwrap_or("");
let username = d
.get("author")
.and_then(|a| a.get("username"))
.and_then(|i| i.as_str())
.unwrap_or(author_id);
if author_id == bot_user_id {
continue;
}
// Skip other bots
if d.get("author")
.and_then(|a| a.get("bot"))
.and_then(serde_json::Value::as_bool)
.unwrap_or(false)
{
continue;
}
let channel_id = d
.get("channel_id")
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string();
// DM detection: DMs have no guild_id
let is_dm_event = d.get("guild_id").and_then(serde_json::Value::as_str).is_none();
// Resolve channel name (with cache)
let channel_display = if is_dm_event {
"dm".to_string()
} else {
self.resolve_channel_name(&channel_id).await
};
if is_dm_event && !store_dms && !respond_to_dms {
continue;
}
// Guild filter
if let Some(ref gid) = guild_filter {
let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str);
if let Some(g) = msg_guild {
if g != gid {
continue;
}
}
}
// Channel filter
if !self.is_channel_watched(&channel_id) {
continue;
}
if !self.is_user_allowed(author_id) {
continue;
}
let content = d.get("content").and_then(|c| c.as_str()).unwrap_or("");
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
let is_mention = contains_bot_mention(content, &bot_user_id);
// Collect attachment URLs
let attachments: Vec<String> = d
.get("attachments")
.and_then(|a| a.as_array())
.map(|arr| {
arr.iter()
.filter_map(|a| a.get("url").and_then(|u| u.as_str()))
.map(|u| u.to_string())
.collect()
})
.unwrap_or_default();
// Store messages to discord.db (skip DMs if store_dms=false)
if (!is_dm_event || store_dms) && (!content.is_empty() || !attachments.is_empty()) {
let ts = chrono::Utc::now().to_rfc3339();
let mut mem_content = format!(
"@{username} in #{channel_display} at {ts}: {content}"
);
if !attachments.is_empty() {
mem_content.push_str(" [attachments: ");
mem_content.push_str(&attachments.join(", "));
mem_content.push(']');
}
let mem_key = format!(
"discord_{}",
if message_id.is_empty() {
Uuid::new_v4().to_string()
} else {
message_id.to_string()
}
);
let channel_id_for_session = if channel_id.is_empty() {
None
} else {
Some(channel_id.as_str())
};
if let Err(err) = discord_memory
.store(
&mem_key,
&mem_content,
MemoryCategory::Custom("discord".to_string()),
channel_id_for_session,
)
.await
{
tracing::warn!("discord_history: failed to store message: {err}");
} else {
tracing::debug!(
"discord_history: stored message from @{username} in #{channel_display}"
);
}
}
// Forward @mention to agent (skip DMs if respond_to_dms=false)
if is_mention && (!is_dm_event || respond_to_dms) {
let clean_content = strip_bot_mention(content, &bot_user_id);
if clean_content.is_empty() {
continue;
}
let channel_msg = ChannelMessage {
id: if message_id.is_empty() {
Uuid::new_v4().to_string()
} else {
format!("discord_{message_id}")
},
sender: author_id.to_string(),
reply_target: if channel_id.is_empty() {
author_id.to_string()
} else {
channel_id.clone()
},
content: clean_content,
channel: "discord_history".to_string(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
thread_ts: None,
interruption_scope_id: None,
};
if tx.send(channel_msg).await.is_err() {
break;
}
}
}
}
}
Ok(())
}
async fn health_check(&self) -> bool {
self.http_client()
.get("https://discord.com/api/v10/users/@me")
.header("Authorization", format!("Bot {}", self.bot_token))
.send()
.await
.map(|r| r.status().is_success())
.unwrap_or(false)
}
async fn start_typing(&self, recipient: &str) -> anyhow::Result<()> {
let mut guard = self.typing_handles.lock();
if let Some(h) = guard.remove(recipient) {
h.abort();
}
let client = self.http_client();
let token = self.bot_token.clone();
let channel_id = recipient.to_string();
let handle = tokio::spawn(async move {
let url = format!("https://discord.com/api/v10/channels/{channel_id}/typing");
loop {
let _ = client
.post(&url)
.header("Authorization", format!("Bot {token}"))
.send()
.await;
tokio::time::sleep(std::time::Duration::from_secs(8)).await;
}
});
guard.insert(recipient.to_string(), handle);
Ok(())
}
async fn stop_typing(&self, recipient: &str) -> anyhow::Result<()> {
let mut guard = self.typing_handles.lock();
if let Some(handle) = guard.remove(recipient) {
handle.abort();
}
Ok(())
}
}
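As a sanity check on the `intents: 37377` value sent in the Identify payload above, the number is the OR of the four Gateway intent bits the code relies on (bit positions per Discord's Gateway intents):

```rust
fn main() {
    const GUILDS: u32 = 1 << 0;
    const GUILD_MESSAGES: u32 = 1 << 9;
    const DIRECT_MESSAGES: u32 = 1 << 12;
    const MESSAGE_CONTENT: u32 = 1 << 15;
    // guild events + guild messages + DMs + message content
    assert_eq!(
        GUILDS | GUILD_MESSAGES | DIRECT_MESSAGES | MESSAGE_CONTENT,
        37377
    );
}
```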
File diff suppressed because it is too large
@@ -1,5 +1,6 @@
use super::traits::{Channel, ChannelMessage, SendMessage};
use async_trait::async_trait;
use base64::Engine as _;
use futures_util::{SinkExt, StreamExt};
use prost::Message as ProstMessage;
use std::collections::HashMap;
@@ -221,6 +222,21 @@ const LARK_INVALID_ACCESS_TOKEN_CODE: i64 = 99_991_663;
/// Lark card payloads have a ~30 KB limit; leave margin for JSON envelope.
const LARK_CARD_MARKDOWN_MAX_BYTES: usize = 28_000;
/// Maximum image size we will download and inline (5 MiB).
const LARK_IMAGE_MAX_BYTES: usize = 5 * 1024 * 1024;
/// Maximum file size we will download and present as text (512 KiB).
const LARK_FILE_MAX_BYTES: usize = 512 * 1024;
/// Image MIME types we support for inline base64 encoding.
const LARK_SUPPORTED_IMAGE_MIMES: &[&str] = &[
"image/png",
"image/jpeg",
"image/gif",
"image/webp",
"image/bmp",
];
/// Returns true when the WebSocket frame indicates live traffic that should
/// refresh the heartbeat watchdog.
fn should_refresh_last_recv(msg: &WsMsg) -> bool {
@@ -520,6 +536,17 @@ impl LarkChannel {
format!("{}/im/v1/messages/{message_id}/reactions", self.api_base())
}
fn image_download_url(&self, image_key: &str) -> String {
format!("{}/im/v1/images/{image_key}", self.api_base())
}
fn file_download_url(&self, message_id: &str, file_key: &str) -> String {
format!(
"{}/im/v1/messages/{message_id}/resources/{file_key}?type=file",
self.api_base()
)
}
fn resolved_bot_open_id(&self) -> Option<String> {
self.resolved_bot_open_id
.read()
@@ -866,6 +893,44 @@ impl LarkChannel {
Some(details) => (details.text, details.mentioned_open_ids),
None => continue,
},
"image" => {
let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
Ok(v) => v,
Err(_) => continue,
};
let image_key = match v.get("image_key").and_then(|k| k.as_str()) {
Some(k) => k.to_string(),
None => { tracing::debug!("Lark WS: image message missing image_key"); continue; }
};
match self.download_image_as_marker(&image_key).await {
Some(marker) => (marker, Vec::new()),
None => {
tracing::warn!("Lark WS: failed to download image {image_key}");
(format!("[IMAGE:{image_key} | download failed]"), Vec::new())
}
}
}
"file" => {
let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
Ok(v) => v,
Err(_) => continue,
};
let file_key = match v.get("file_key").and_then(|k| k.as_str()) {
Some(k) => k.to_string(),
None => { tracing::debug!("Lark WS: file message missing file_key"); continue; }
};
let file_name = v.get("file_name")
.and_then(|n| n.as_str())
.unwrap_or("unknown_file")
.to_string();
match self.download_file_as_content(&lark_msg.message_id, &file_key, &file_name).await {
Some(content) => (content, Vec::new()),
None => {
tracing::warn!("Lark WS: failed to download file {file_key}");
(format!("[ATTACHMENT:{file_name} | download failed]"), Vec::new())
}
}
}
_ => { tracing::debug!("Lark WS: skipping unsupported type '{}'", lark_msg.message_type); continue; }
};
@@ -986,6 +1051,183 @@ impl LarkChannel {
*cached = None;
}
/// Download an image from the Lark API and return an `[IMAGE:data:...]` marker string.
async fn download_image_as_marker(&self, image_key: &str) -> Option<String> {
let token = match self.get_tenant_access_token().await {
Ok(t) => t,
Err(e) => {
tracing::warn!("Lark: failed to get token for image download: {e}");
return None;
}
};
let url = self.image_download_url(image_key);
let resp = match self
.http_client()
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await
{
Ok(r) => r,
Err(e) => {
tracing::warn!("Lark: image download request failed for {image_key}: {e}");
return None;
}
};
if !resp.status().is_success() {
tracing::warn!(
"Lark: image download failed for {image_key}: status={}",
resp.status()
);
return None;
}
if let Some(cl) = resp.content_length() {
if cl > LARK_IMAGE_MAX_BYTES as u64 {
tracing::warn!("Lark: image too large for {image_key}: {cl} bytes exceeds limit");
return None;
}
}
let content_type = resp
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(str::to_string);
let bytes = match resp.bytes().await {
Ok(b) => b,
Err(e) => {
tracing::warn!("Lark: image body read failed for {image_key}: {e}");
return None;
}
};
if bytes.is_empty() || bytes.len() > LARK_IMAGE_MAX_BYTES {
tracing::warn!(
"Lark: image body empty or too large for {image_key}: {} bytes",
bytes.len()
);
return None;
}
let mime = lark_detect_image_mime(content_type.as_deref(), &bytes)?;
if !LARK_SUPPORTED_IMAGE_MIMES.contains(&mime.as_str()) {
tracing::warn!("Lark: unsupported image MIME for {image_key}: {mime}");
return None;
}
let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
Some(format!("[IMAGE:data:{mime};base64,{encoded}]"))
}
/// Download a file from the Lark API and return a text content marker.
/// For text-like files, the content is inlined. For binary files, a summary is returned.
async fn download_file_as_content(
&self,
message_id: &str,
file_key: &str,
file_name: &str,
) -> Option<String> {
let token = match self.get_tenant_access_token().await {
Ok(t) => t,
Err(e) => {
tracing::warn!("Lark: failed to get token for file download: {e}");
return None;
}
};
let url = self.file_download_url(message_id, file_key);
let resp = match self
.http_client()
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await
{
Ok(r) => r,
Err(e) => {
tracing::warn!("Lark: file download request failed for {file_key}: {e}");
return None;
}
};
if !resp.status().is_success() {
tracing::warn!(
"Lark: file download failed for {file_key}: status={}",
resp.status()
);
return None;
}
if let Some(cl) = resp.content_length() {
if cl > LARK_FILE_MAX_BYTES as u64 {
tracing::warn!("Lark: file too large for {file_key}: {cl} bytes exceeds limit");
return Some(format!(
"[ATTACHMENT:{file_name} | size={cl} bytes | too large to inline]"
));
}
}
let content_type = resp
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let bytes = match resp.bytes().await {
Ok(b) => b,
Err(e) => {
tracing::warn!("Lark: file body read failed for {file_key}: {e}");
return None;
}
};
if bytes.is_empty() {
tracing::warn!("Lark: file body is empty for {file_key}");
return None;
}
// If the content is image-like, return as image marker
if content_type.starts_with("image/") && bytes.len() <= LARK_IMAGE_MAX_BYTES {
if let Some(mime) = lark_detect_image_mime(Some(&content_type), &bytes) {
if LARK_SUPPORTED_IMAGE_MIMES.contains(&mime.as_str()) {
let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
return Some(format!("[IMAGE:data:{mime};base64,{encoded}]"));
}
}
}
// If the file looks like text, inline it
if bytes.len() <= LARK_FILE_MAX_BYTES
&& !bytes.contains(&0)
&& (content_type.starts_with("text/")
|| content_type.contains("json")
|| content_type.contains("xml")
|| content_type.contains("yaml")
|| content_type.contains("javascript")
|| content_type.contains("csv")
|| lark_is_text_filename(file_name))
{
let text = String::from_utf8_lossy(&bytes);
let truncated = if text.len() > 50_000 {
format!("{}...\n[truncated]", &text[..50_000])
} else {
text.into_owned()
};
let ext = file_name.rsplit('.').next().unwrap_or("text");
return Some(format!("[FILE:{file_name}]\n```{ext}\n{truncated}\n```"));
}
Some(format!(
"[ATTACHMENT:{file_name} | mime={content_type} | size={} bytes]",
bytes.len()
))
}
async fn fetch_bot_open_id_with_token(
&self,
token: &str,
@@ -1085,8 +1327,9 @@ impl LarkChannel {
Ok((status, parsed))
}
/// Parse an event callback payload and extract text messages
pub fn parse_event_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
/// Parse an event callback payload and extract messages.
/// Supports text, post, image, and file message types.
pub async fn parse_event_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
let mut messages = Vec::new();
// Lark event v2 structure:
@@ -1143,6 +1386,11 @@ impl LarkChannel {
.and_then(|c| c.as_str())
.unwrap_or("");
let evt_message_id = event
.pointer("/message/message_id")
.and_then(|m| m.as_str())
.unwrap_or("");
let (text, post_mentioned_open_ids): (String, Vec<String>) = match msg_type {
"text" => {
let extracted = serde_json::from_str::<serde_json::Value>(content_str)
@@ -1162,6 +1410,62 @@ impl LarkChannel {
Some(details) => (details.text, details.mentioned_open_ids),
None => return messages,
},
"image" => {
let image_key = serde_json::from_str::<serde_json::Value>(content_str)
.ok()
.and_then(|v| {
v.get("image_key")
.and_then(|k| k.as_str())
.map(String::from)
});
match image_key {
Some(key) => {
let marker = match self.download_image_as_marker(&key).await {
Some(m) => m,
None => {
tracing::warn!("Lark: failed to download image {key}");
format!("[IMAGE:{key} | download failed]")
}
};
(marker, Vec::new())
}
None => {
tracing::debug!("Lark: image message missing image_key");
return messages;
}
}
}
"file" => {
let parsed = serde_json::from_str::<serde_json::Value>(content_str).ok();
let file_key = parsed
.as_ref()
.and_then(|v| v.get("file_key").and_then(|k| k.as_str()))
.map(String::from);
let file_name = parsed
.as_ref()
.and_then(|v| v.get("file_name").and_then(|n| n.as_str()))
.unwrap_or("unknown_file")
.to_string();
match file_key {
Some(key) => {
let content = match self
.download_file_as_content(evt_message_id, &key, &file_name)
.await
{
Some(c) => c,
None => {
tracing::warn!("Lark: failed to download file {key}");
format!("[ATTACHMENT:{file_name} | download failed]")
}
};
(content, Vec::new())
}
None => {
tracing::debug!("Lark: file message missing file_key");
return messages;
}
}
}
_ => {
tracing::debug!("Lark: skipping unsupported message type: {msg_type}");
return messages;
@@ -1305,7 +1609,7 @@ impl LarkChannel {
}
// Parse event messages
let messages = state.channel.parse_event_payload(&payload);
let messages = state.channel.parse_event_payload(&payload).await;
if !messages.is_empty() {
if let Some(message_id) = payload
.pointer("/event/message/message_id")
@@ -1556,6 +1860,72 @@ fn detect_lark_ack_locale(
detect_locale_from_text(fallback_text).unwrap_or(LarkAckLocale::En)
}
/// Detect image MIME type from magic bytes, falling back to Content-Type header.
fn lark_detect_image_mime(content_type: Option<&str>, bytes: &[u8]) -> Option<String> {
if bytes.len() >= 8 && bytes.starts_with(&[0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n']) {
return Some("image/png".to_string());
}
if bytes.len() >= 3 && bytes.starts_with(&[0xff, 0xd8, 0xff]) {
return Some("image/jpeg".to_string());
}
if bytes.len() >= 6 && (bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a")) {
return Some("image/gif".to_string());
}
if bytes.len() >= 12 && &bytes[0..4] == b"RIFF" && &bytes[8..12] == b"WEBP" {
return Some("image/webp".to_string());
}
if bytes.len() >= 2 && bytes.starts_with(b"BM") {
return Some("image/bmp".to_string());
}
content_type
.and_then(|ct| ct.split(';').next())
.map(|ct| ct.trim().to_lowercase())
.filter(|ct| ct.starts_with("image/"))
}
/// Check if a filename looks like a text file based on extension.
fn lark_is_text_filename(name: &str) -> bool {
let ext = name.rsplit('.').next().unwrap_or("").to_ascii_lowercase();
matches!(
ext.as_str(),
"txt"
| "md"
| "rs"
| "py"
| "js"
| "ts"
| "tsx"
| "jsx"
| "java"
| "c"
| "h"
| "cpp"
| "hpp"
| "go"
| "rb"
| "sh"
| "bash"
| "zsh"
| "toml"
| "yaml"
| "yml"
| "json"
| "xml"
| "html"
| "css"
| "sql"
| "csv"
| "tsv"
| "log"
| "cfg"
| "ini"
| "conf"
| "env"
| "dockerfile"
| "makefile"
)
}
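One detail worth noting: `rsplit('.').next()` returns the whole name when there is no dot, which is why extension-less names like `Dockerfile` and `Makefile` can match the list after `to_ascii_lowercase()`. A stdlib-only check:

```rust
fn main() {
    // With a dot, rsplit('.') yields the extension first.
    assert_eq!("notes.txt".rsplit('.').next(), Some("txt"));
    assert_eq!("archive.tar.gz".rsplit('.').next(), Some("gz"));
    // Without a dot, the whole name comes back, so "Dockerfile" reaches
    // the "dockerfile" arm once lowercased.
    assert_eq!("Dockerfile".rsplit('.').next(), Some("Dockerfile"));
    assert_eq!("Makefile".rsplit('.').next(), Some("Makefile"));
}
```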
fn random_lark_ack_reaction(
payload: Option<&serde_json::Value>,
fallback_text: &str,
@@ -1892,8 +2262,8 @@ mod tests {
assert!(!ch.is_user_allowed("ou_anyone"));
}
#[test]
fn lark_parse_challenge() {
#[tokio::test]
async fn lark_parse_challenge() {
let ch = make_channel();
let payload = serde_json::json!({
"challenge": "abc123",
@@ -1901,12 +2271,12 @@ mod tests {
"type": "url_verification"
});
// Challenge payloads should not produce messages
let msgs = ch.parse_event_payload(&payload);
let msgs = ch.parse_event_payload(&payload).await;
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_valid_text_message() {
#[tokio::test]
async fn lark_parse_valid_text_message() {
let ch = make_channel();
let payload = serde_json::json!({
"header": {
@@ -1927,7 +2297,7 @@ mod tests {
}
});
let msgs = ch.parse_event_payload(&payload);
let msgs = ch.parse_event_payload(&payload).await;
assert_eq!(msgs.len(), 1);
assert_eq!(msgs[0].content, "Hello ZeroClaw!");
assert_eq!(msgs[0].sender, "oc_chat123");
@@ -1935,8 +2305,8 @@ mod tests {
assert_eq!(msgs[0].timestamp, 1_699_999_999);
}
#[test]
fn lark_parse_unauthorized_user() {
#[tokio::test]
async fn lark_parse_unauthorized_user() {
let ch = make_channel();
let payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" },
@@ -1951,12 +2321,38 @@ mod tests {
}
});
let msgs = ch.parse_event_payload(&payload);
let msgs = ch.parse_event_payload(&payload).await;
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_non_text_message_skipped() {
#[tokio::test]
async fn lark_parse_unsupported_message_type_skipped() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
"token".into(),
None,
vec!["*".into()],
true,
);
let payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" },
"event": {
"sender": { "sender_id": { "open_id": "ou_user" } },
"message": {
"message_type": "sticker",
"content": "{}",
"chat_id": "oc_chat"
}
}
});
let msgs = ch.parse_event_payload(&payload).await;
assert!(msgs.is_empty());
}
#[tokio::test]
async fn lark_parse_image_missing_key_skipped() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
@@ -1977,12 +2373,38 @@ mod tests {
}
});
let msgs = ch.parse_event_payload(&payload);
let msgs = ch.parse_event_payload(&payload).await;
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_empty_text_skipped() {
#[tokio::test]
async fn lark_parse_file_missing_key_skipped() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
"token".into(),
None,
vec!["*".into()],
true,
);
let payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" },
"event": {
"sender": { "sender_id": { "open_id": "ou_user" } },
"message": {
"message_type": "file",
"content": "{}",
"chat_id": "oc_chat"
}
}
});
let msgs = ch.parse_event_payload(&payload).await;
assert!(msgs.is_empty());
}
#[tokio::test]
async fn lark_parse_empty_text_skipped() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
@@ -2003,24 +2425,24 @@ mod tests {
}
});
let msgs = ch.parse_event_payload(&payload);
let msgs = ch.parse_event_payload(&payload).await;
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_wrong_event_type() {
#[tokio::test]
async fn lark_parse_wrong_event_type() {
let ch = make_channel();
let payload = serde_json::json!({
"header": { "event_type": "im.chat.disbanded_v1" },
"event": {}
});
let msgs = ch.parse_event_payload(&payload);
let msgs = ch.parse_event_payload(&payload).await;
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_missing_sender() {
#[tokio::test]
async fn lark_parse_missing_sender() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
@@ -2040,12 +2462,12 @@ mod tests {
}
});
let msgs = ch.parse_event_payload(&payload);
let msgs = ch.parse_event_payload(&payload).await;
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_unicode_message() {
#[tokio::test]
async fn lark_parse_unicode_message() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
@@ -2067,24 +2489,24 @@ mod tests {
}
});
let msgs = ch.parse_event_payload(&payload);
let msgs = ch.parse_event_payload(&payload).await;
assert_eq!(msgs.len(), 1);
assert_eq!(msgs[0].content, "Hello world 🌍");
}
#[test]
fn lark_parse_missing_event() {
#[tokio::test]
async fn lark_parse_missing_event() {
let ch = make_channel();
let payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" }
});
let msgs = ch.parse_event_payload(&payload);
let msgs = ch.parse_event_payload(&payload).await;
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_invalid_content_json() {
#[tokio::test]
async fn lark_parse_invalid_content_json() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
@@ -2105,7 +2527,7 @@ mod tests {
}
});
let msgs = ch.parse_event_payload(&payload);
let msgs = ch.parse_event_payload(&payload).await;
assert!(msgs.is_empty());
}
@@ -2237,8 +2659,8 @@ mod tests {
assert_eq!(ch.name(), "feishu");
}
#[test]
fn lark_parse_fallback_sender_to_open_id() {
#[tokio::test]
async fn lark_parse_fallback_sender_to_open_id() {
// When chat_id is missing, sender should fall back to open_id
let ch = LarkChannel::new(
"id".into(),
@@ -2260,13 +2682,13 @@ mod tests {
}
});
let msgs = ch.parse_event_payload(&payload);
let msgs = ch.parse_event_payload(&payload).await;
assert_eq!(msgs.len(), 1);
assert_eq!(msgs[0].sender, "ou_user");
}
#[test]
fn lark_parse_group_message_requires_bot_mention_when_enabled() {
#[tokio::test]
async fn lark_parse_group_message_requires_bot_mention_when_enabled() {
let ch = with_bot_open_id(
LarkChannel::new(
"cli_app123".into(),
@@ -2292,7 +2714,7 @@ mod tests {
}
}
});
assert!(ch.parse_event_payload(&no_mention_payload).is_empty());
assert!(ch.parse_event_payload(&no_mention_payload).await.is_empty());
let wrong_mention_payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" },
@@ -2307,7 +2729,10 @@ mod tests {
}
}
});
assert!(ch.parse_event_payload(&wrong_mention_payload).is_empty());
assert!(ch
.parse_event_payload(&wrong_mention_payload)
.await
.is_empty());
let bot_mention_payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" },
@@ -2322,11 +2747,11 @@ mod tests {
}
}
});
assert_eq!(ch.parse_event_payload(&bot_mention_payload).len(), 1);
assert_eq!(ch.parse_event_payload(&bot_mention_payload).await.len(), 1);
}
#[test]
fn lark_parse_group_post_message_accepts_at_when_top_level_mentions_empty() {
#[tokio::test]
async fn lark_parse_group_post_message_accepts_at_when_top_level_mentions_empty() {
let ch = with_bot_open_id(
LarkChannel::new(
"cli_app123".into(),
@@ -2353,11 +2778,11 @@ mod tests {
}
});
assert_eq!(ch.parse_event_payload(&payload).len(), 1);
assert_eq!(ch.parse_event_payload(&payload).await.len(), 1);
}
#[test]
fn lark_parse_group_message_allows_without_mention_when_disabled() {
#[tokio::test]
async fn lark_parse_group_message_allows_without_mention_when_disabled() {
let ch = LarkChannel::new(
"cli_app123".into(),
"secret".into(),
@@ -2381,7 +2806,7 @@ mod tests {
}
});
assert_eq!(ch.parse_event_payload(&payload).len(), 1);
assert_eq!(ch.parse_event_payload(&payload).await.len(), 1);
}
#[test]
@@ -2409,6 +2834,69 @@ mod tests {
);
}
#[test]
fn lark_image_download_url_matches_region() {
let ch = make_channel();
assert_eq!(
ch.image_download_url("img_abc123"),
"https://open.larksuite.com/open-apis/im/v1/images/img_abc123"
);
}
#[test]
fn lark_file_download_url_matches_region() {
let ch = make_channel();
assert_eq!(
ch.file_download_url("om_msg123", "file_abc"),
"https://open.larksuite.com/open-apis/im/v1/messages/om_msg123/resources/file_abc?type=file"
);
}
#[test]
fn lark_detect_image_mime_from_magic_bytes() {
let png = [0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n'];
assert_eq!(
lark_detect_image_mime(None, &png).as_deref(),
Some("image/png")
);
let jpeg = [0xff, 0xd8, 0xff, 0xe0];
assert_eq!(
lark_detect_image_mime(None, &jpeg).as_deref(),
Some("image/jpeg")
);
let gif = b"GIF89a...";
assert_eq!(
lark_detect_image_mime(None, gif).as_deref(),
Some("image/gif")
);
// Unknown bytes should fall back to content-type header
let unknown = [0x00, 0x01, 0x02];
assert_eq!(
lark_detect_image_mime(Some("image/webp"), &unknown).as_deref(),
Some("image/webp")
);
// Non-image content-type should be rejected
assert_eq!(lark_detect_image_mime(Some("text/html"), &unknown), None);
// No info at all should return None
assert_eq!(lark_detect_image_mime(None, &unknown), None);
}
#[test]
fn lark_is_text_filename_recognizes_common_extensions() {
assert!(lark_is_text_filename("script.py"));
assert!(lark_is_text_filename("config.toml"));
assert!(lark_is_text_filename("data.csv"));
assert!(lark_is_text_filename("README.md"));
assert!(!lark_is_text_filename("image.png"));
assert!(!lark_is_text_filename("archive.zip"));
assert!(!lark_is_text_filename("binary.exe"));
}
#[test]
fn lark_reaction_locale_explicit_language_tags() {
assert_eq!(map_locale_tag("zh-CN"), Some(LarkAckLocale::ZhCn));
@@ -0,0 +1,462 @@
//! Link enricher: auto-detects URLs in inbound messages, fetches their content,
//! and prepends summaries so the agent has link context without explicit tool calls.
use regex::Regex;
use std::net::IpAddr;
use std::sync::LazyLock;
use std::time::Duration;
/// Configuration for the link enricher pipeline stage.
#[derive(Debug, Clone)]
pub struct LinkEnricherConfig {
pub enabled: bool,
pub max_links: usize,
pub timeout_secs: u64,
}
impl Default for LinkEnricherConfig {
fn default() -> Self {
Self {
enabled: false,
max_links: 3,
timeout_secs: 10,
}
}
}
/// URL regex: matches http:// and https:// URLs, stopping at whitespace, angle
/// brackets, or double-quotes.
static URL_RE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r#"https?://[^\s<>"']+"#).expect("URL regex must compile"));
/// Extract URLs from message text, returning up to `max` unique URLs.
pub fn extract_urls(text: &str, max: usize) -> Vec<String> {
let mut seen = Vec::new();
for m in URL_RE.find_iter(text) {
let url = m.as_str().to_string();
if !seen.contains(&url) {
seen.push(url);
if seen.len() >= max {
break;
}
}
}
seen
}
/// Returns `true` if the URL points to a private/local address that should be
/// blocked for SSRF protection.
pub fn is_ssrf_target(url: &str) -> bool {
let host = match extract_host(url) {
Some(h) => h,
None => return true, // unparseable URLs are rejected
};
// Check hostname-based locals
if host == "localhost"
|| host.ends_with(".localhost")
|| host.ends_with(".local")
|| host == "local"
{
return true;
}
// Check IP-based private ranges
if let Ok(ip) = host.parse::<IpAddr>() {
return is_private_ip(ip);
}
false
}
/// Extract the host portion from a URL string.
fn extract_host(url: &str) -> Option<String> {
let rest = url
.strip_prefix("https://")
.or_else(|| url.strip_prefix("http://"))?;
let authority = rest.split(['/', '?', '#']).next()?;
if authority.is_empty() {
return None;
}
// Strip port
let host = if authority.starts_with('[') {
// IPv6 in brackets — reject for simplicity
return None;
} else {
authority.split(':').next().unwrap_or(authority)
};
Some(host.to_lowercase())
}
/// Check if an IP address falls within private/reserved ranges.
fn is_private_ip(ip: IpAddr) -> bool {
match ip {
IpAddr::V4(v4) => {
v4.is_loopback() // 127.0.0.0/8
|| v4.is_private() // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
|| v4.is_link_local() // 169.254.0.0/16
|| v4.is_unspecified() // 0.0.0.0
|| v4.is_broadcast() // 255.255.255.255
|| v4.is_multicast() // 224.0.0.0/4
}
IpAddr::V6(v6) => {
v6.is_loopback() // ::1
|| v6.is_unspecified() // ::
|| v6.is_multicast()
// Check for IPv4-mapped IPv6 addresses
|| v6.to_ipv4_mapped().is_some_and(|v4| {
v4.is_loopback()
|| v4.is_private()
|| v4.is_link_local()
|| v4.is_unspecified()
})
}
}
}
/// Extract the `<title>` tag content from HTML.
pub fn extract_title(html: &str) -> Option<String> {
// Case-insensitive search for <title>...</title>. Both the find and the
// slice below operate on the same lowercased copy, so byte offsets stay
// valid even when lowercasing changes a character's byte length (e.g. 'İ');
// as a consequence, the returned title is lowercased.
let lower = html.to_lowercase();
let start = lower.find("<title")? + "<title".len();
// Skip attributes if any (e.g. <title lang="en">)
let start = lower[start..].find('>')? + start + 1;
let end = lower[start..].find("</title")? + start;
let title = lower[start..end].trim().to_string();
if title.is_empty() {
None
} else {
Some(html_entity_decode_basic(&title))
}
}
/// Extract the first `max_chars` of visible body text from HTML.
pub fn extract_body_text(html: &str, max_chars: usize) -> String {
let text = nanohtml2text::html2text(html);
let trimmed = text.trim();
// Truncate by character count (not byte length) so multibyte text is
// neither split mid-character nor given a spurious "..." suffix.
let mut chars = trimmed.chars();
let result: String = chars.by_ref().take(max_chars).collect();
if chars.next().is_some() {
format!("{result}...")
} else {
result
}
}
/// Basic HTML entity decoding for title content.
fn html_entity_decode_basic(s: &str) -> String {
s.replace("&amp;", "&")
.replace("&lt;", "<")
.replace("&gt;", ">")
.replace("&quot;", "\"")
.replace("&#39;", "'")
.replace("&apos;", "'")
}
/// Summary of a fetched link.
struct LinkSummary {
title: String,
snippet: String,
}
/// Fetch a single URL and extract a summary. Returns `None` on any failure.
async fn fetch_link_summary(url: &str, timeout_secs: u64) -> Option<LinkSummary> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.connect_timeout(Duration::from_secs(5))
.redirect(reqwest::redirect::Policy::limited(5))
.user_agent("ZeroClaw/0.1 (link-enricher)")
.build()
.ok()?;
let response = client.get(url).send().await.ok()?;
if !response.status().is_success() {
return None;
}
// Only process text/html responses; a missing Content-Type header is
// given the benefit of the doubt.
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_lowercase();
if !content_type.contains("text/html") && !content_type.is_empty() {
return None;
}
// Buffer the body, then keep at most 256KB for title/snippet extraction
// (the cap bounds what we parse, not what we download).
let max_bytes: usize = 256 * 1024;
let bytes = response.bytes().await.ok()?;
let body = if bytes.len() > max_bytes {
String::from_utf8_lossy(&bytes[..max_bytes]).into_owned()
} else {
String::from_utf8_lossy(&bytes).into_owned()
};
let title = extract_title(&body).unwrap_or_else(|| "Untitled".to_string());
let snippet = extract_body_text(&body, 200);
Some(LinkSummary { title, snippet })
}
/// Enrich a message by prepending link summaries for any URLs found in the text.
///
/// This is the main entry point called from the channel message processing pipeline.
/// If the enricher is disabled or no URLs are found, the original message is returned
/// unchanged.
pub async fn enrich_message(content: &str, config: &LinkEnricherConfig) -> String {
if !config.enabled || config.max_links == 0 {
return content.to_string();
}
let urls = extract_urls(content, config.max_links);
if urls.is_empty() {
return content.to_string();
}
// Filter out SSRF targets
let safe_urls: Vec<&str> = urls
.iter()
.filter(|u| !is_ssrf_target(u))
.map(|u| u.as_str())
.collect();
if safe_urls.is_empty() {
return content.to_string();
}
let mut enrichments = Vec::new();
for url in safe_urls {
match fetch_link_summary(url, config.timeout_secs).await {
Some(summary) => {
enrichments.push(format!("[Link: {} | {}]", summary.title, summary.snippet));
}
None => {
tracing::debug!(url, "Link enricher: failed to fetch or extract summary");
}
}
}
if enrichments.is_empty() {
return content.to_string();
}
let prefix = enrichments.join("\n");
format!("{prefix}\n{content}")
}
#[cfg(test)]
mod tests {
use super::*;
// ── URL extraction ──────────────────────────────────────────────
#[test]
fn extract_urls_finds_http_and_https() {
let text = "Check https://example.com and http://test.org/page for info";
let urls = extract_urls(text, 10);
assert_eq!(urls, vec!["https://example.com", "http://test.org/page",]);
}
#[test]
fn extract_urls_respects_max() {
let text = "https://a.com https://b.com https://c.com https://d.com";
let urls = extract_urls(text, 2);
assert_eq!(urls.len(), 2);
assert_eq!(urls[0], "https://a.com");
assert_eq!(urls[1], "https://b.com");
}
#[test]
fn extract_urls_deduplicates() {
let text = "Visit https://example.com and https://example.com again";
let urls = extract_urls(text, 10);
assert_eq!(urls.len(), 1);
}
#[test]
fn extract_urls_handles_no_urls() {
let text = "Just a normal message without links";
let urls = extract_urls(text, 10);
assert!(urls.is_empty());
}
#[test]
fn extract_urls_stops_at_angle_brackets() {
let text = "Link: <https://example.com/path> done";
let urls = extract_urls(text, 10);
assert_eq!(urls, vec!["https://example.com/path"]);
}
#[test]
fn extract_urls_stops_at_quotes() {
let text = r#"href="https://example.com/page" end"#;
let urls = extract_urls(text, 10);
assert_eq!(urls, vec!["https://example.com/page"]);
}
// ── SSRF protection ─────────────────────────────────────────────
#[test]
fn ssrf_blocks_localhost() {
assert!(is_ssrf_target("http://localhost/admin"));
assert!(is_ssrf_target("https://localhost:8080/api"));
}
#[test]
fn ssrf_blocks_loopback_ip() {
assert!(is_ssrf_target("http://127.0.0.1/secret"));
assert!(is_ssrf_target("http://127.0.0.2:9090"));
}
#[test]
fn ssrf_blocks_private_10_network() {
assert!(is_ssrf_target("http://10.0.0.1/internal"));
assert!(is_ssrf_target("http://10.255.255.255"));
}
#[test]
fn ssrf_blocks_private_172_network() {
assert!(is_ssrf_target("http://172.16.0.1/admin"));
assert!(is_ssrf_target("http://172.31.255.255"));
}
#[test]
fn ssrf_blocks_private_192_168_network() {
assert!(is_ssrf_target("http://192.168.1.1/router"));
assert!(is_ssrf_target("http://192.168.0.100:3000"));
}
#[test]
fn ssrf_blocks_link_local() {
assert!(is_ssrf_target("http://169.254.0.1/metadata"));
assert!(is_ssrf_target("http://169.254.169.254/latest"));
}
#[test]
fn ssrf_blocks_ipv6_loopback() {
// IPv6 in brackets is rejected by extract_host
assert!(is_ssrf_target("http://[::1]/admin"));
}
#[test]
fn ssrf_blocks_dot_local() {
assert!(is_ssrf_target("http://myhost.local/api"));
}
#[test]
fn ssrf_allows_public_urls() {
assert!(!is_ssrf_target("https://example.com/page"));
assert!(!is_ssrf_target("https://www.google.com"));
assert!(!is_ssrf_target("http://93.184.216.34/resource"));
}
// ── Title extraction ────────────────────────────────────────────
#[test]
fn extract_title_basic() {
let html = "<html><head><title>My Page Title</title></head><body>Hello</body></html>";
assert_eq!(extract_title(html), Some("my page title".to_string()));
}
#[test]
fn extract_title_with_entities() {
let html = "<title>Tom &amp; Jerry&#39;s Page</title>";
assert_eq!(extract_title(html), Some("tom & jerry's page".to_string()));
}
#[test]
fn extract_title_case_insensitive() {
let html = "<HTML><HEAD><TITLE>Upper Case</TITLE></HEAD></HTML>";
assert_eq!(extract_title(html), Some("upper case".to_string()));
}
#[test]
fn extract_title_multibyte_chars_no_panic() {
// İ (U+0130) lowercases to 2 chars, changing byte length.
// This must not panic or produce wrong offsets.
let html = "<title>İstanbul Guide</title>";
let result = extract_title(html);
assert!(result.is_some());
let title = result.unwrap();
assert!(title.contains("stanbul"));
}
#[test]
fn extract_title_missing() {
let html = "<html><body>No title here</body></html>";
assert_eq!(extract_title(html), None);
}
#[test]
fn extract_title_empty() {
let html = "<title> </title>";
assert_eq!(extract_title(html), None);
}
// ── Body text extraction ────────────────────────────────────────
#[test]
fn extract_body_text_strips_html() {
let html = "<html><body><h1>Header</h1><p>Some content here</p></body></html>";
let text = extract_body_text(html, 200);
assert!(text.contains("Header"));
assert!(text.contains("Some content"));
assert!(!text.contains("<h1>"));
}
#[test]
fn extract_body_text_truncates() {
let html = "<p>A very long paragraph that should be truncated to fit within the limit.</p>";
let text = extract_body_text(html, 20);
assert!(text.len() <= 25); // 20 chars + "..."
assert!(text.ends_with("..."));
}
// ── Config toggle ───────────────────────────────────────────────
#[tokio::test]
async fn enrich_message_disabled_returns_original() {
let config = LinkEnricherConfig {
enabled: false,
max_links: 3,
timeout_secs: 10,
};
let msg = "Check https://example.com for details";
let result = enrich_message(msg, &config).await;
assert_eq!(result, msg);
}
#[tokio::test]
async fn enrich_message_no_urls_returns_original() {
let config = LinkEnricherConfig {
enabled: true,
max_links: 3,
timeout_secs: 10,
};
let msg = "No links in this message";
let result = enrich_message(msg, &config).await;
assert_eq!(result, msg);
}
#[tokio::test]
async fn enrich_message_ssrf_urls_returns_original() {
let config = LinkEnricherConfig {
enabled: true,
max_links: 3,
timeout_secs: 10,
};
let msg = "Try http://127.0.0.1/admin and http://192.168.1.1/router";
let result = enrich_message(msg, &config).await;
assert_eq!(result, msg);
}
#[test]
fn default_config_is_disabled() {
let config = LinkEnricherConfig::default();
assert!(!config.enabled);
assert_eq!(config.max_links, 3);
assert_eq!(config.timeout_secs, 10);
}
}
+210 -7
@@ -8,6 +8,7 @@ use matrix_sdk::{
events::reaction::ReactionEventContent,
events::receipt::ReceiptThread,
events::relation::{Annotation, Thread},
events::room::member::StrippedRoomMemberEvent,
events::room::message::{
MessageType, OriginalSyncRoomMessageEvent, Relation, RoomMessageEventContent,
},
@@ -32,6 +33,7 @@ pub struct MatrixChannel {
access_token: String,
room_id: String,
allowed_users: Vec<String>,
allowed_rooms: Vec<String>,
session_owner_hint: Option<String>,
session_device_id_hint: Option<String>,
zeroclaw_dir: Option<PathBuf>,
@@ -48,6 +50,7 @@ impl std::fmt::Debug for MatrixChannel {
.field("homeserver", &self.homeserver)
.field("room_id", &self.room_id)
.field("allowed_users", &self.allowed_users)
.field("allowed_rooms", &self.allowed_rooms)
.finish_non_exhaustive()
}
}
@@ -121,7 +124,16 @@ impl MatrixChannel {
room_id: String,
allowed_users: Vec<String>,
) -> Self {
-Self::new_with_session_hint(homeserver, access_token, room_id, allowed_users, None, None)
+Self::new_full(
+homeserver,
+access_token,
+room_id,
+allowed_users,
+vec![],
+None,
+None,
+None,
+)
}
pub fn new_with_session_hint(
@@ -132,11 +144,12 @@ impl MatrixChannel {
owner_hint: Option<String>,
device_id_hint: Option<String>,
) -> Self {
-Self::new_with_session_hint_and_zeroclaw_dir(
+Self::new_full(
homeserver,
access_token,
room_id,
allowed_users,
+vec![],
owner_hint,
device_id_hint,
None,
@@ -151,6 +164,28 @@ impl MatrixChannel {
owner_hint: Option<String>,
device_id_hint: Option<String>,
zeroclaw_dir: Option<PathBuf>,
) -> Self {
Self::new_full(
homeserver,
access_token,
room_id,
allowed_users,
vec![],
owner_hint,
device_id_hint,
zeroclaw_dir,
)
}
pub fn new_full(
homeserver: String,
access_token: String,
room_id: String,
allowed_users: Vec<String>,
allowed_rooms: Vec<String>,
owner_hint: Option<String>,
device_id_hint: Option<String>,
zeroclaw_dir: Option<PathBuf>,
) -> Self {
let homeserver = homeserver.trim_end_matches('/').to_string();
let access_token = access_token.trim().to_string();
@@ -160,12 +195,18 @@ impl MatrixChannel {
.map(|user| user.trim().to_string())
.filter(|user| !user.is_empty())
.collect();
let allowed_rooms = allowed_rooms
.into_iter()
.map(|room| room.trim().to_string())
.filter(|room| !room.is_empty())
.collect();
Self {
homeserver,
access_token,
room_id,
allowed_users,
allowed_rooms,
session_owner_hint: Self::normalize_optional_field(owner_hint),
session_device_id_hint: Self::normalize_optional_field(device_id_hint),
zeroclaw_dir,
@@ -220,6 +261,21 @@ impl MatrixChannel {
allowed_users.iter().any(|u| u.eq_ignore_ascii_case(sender))
}
/// Check whether a room (by its canonical ID) is in the allowed_rooms list.
/// If allowed_rooms is empty, all rooms are allowed.
fn is_room_allowed_static(allowed_rooms: &[String], room_id: &str) -> bool {
if allowed_rooms.is_empty() {
return true;
}
allowed_rooms
.iter()
.any(|r| r.eq_ignore_ascii_case(room_id))
}
fn is_room_allowed(&self, room_id: &str) -> bool {
Self::is_room_allowed_static(&self.allowed_rooms, room_id)
}
fn is_supported_message_type(msgtype: &str) -> bool {
matches!(msgtype, "m.text" | "m.notice")
}
@@ -228,6 +284,10 @@ impl MatrixChannel {
!body.trim().is_empty()
}
fn room_matches_target(target_room_id: &str, incoming_room_id: &str) -> bool {
target_room_id == incoming_room_id
}
fn cache_event_id(
event_id: &str,
recent_order: &mut std::collections::VecDeque<String>,
@@ -526,8 +586,9 @@ impl MatrixChannel {
if client.encryption().backups().are_enabled().await {
tracing::info!("Matrix room-key backup is enabled for this device.");
} else {
+client.encryption().backups().disable().await;
tracing::warn!(
-"Matrix room-key backup is not enabled for this device; `matrix_sdk_crypto::backups` warnings about missing backup keys may appear until recovery is configured."
+"Matrix room-key backup is not enabled for this device; automatic backup attempts have been disabled to suppress recurring warnings. To enable backups, configure server-side key backup and recovery for this device."
);
}
}
@@ -697,6 +758,7 @@ impl Channel for MatrixChannel {
let target_room_for_handler = target_room.clone();
let my_user_id_for_handler = my_user_id.clone();
let allowed_users_for_handler = self.allowed_users.clone();
let allowed_rooms_for_handler = self.allowed_rooms.clone();
let dedupe_for_handler = Arc::clone(&recent_event_cache);
let homeserver_for_handler = self.homeserver.clone();
let access_token_for_handler = self.access_token.clone();
@@ -704,18 +766,29 @@ impl Channel for MatrixChannel {
client.add_event_handler(move |event: OriginalSyncRoomMessageEvent, room: Room| {
let tx = tx_handler.clone();
-let _target_room = target_room_for_handler.clone();
+let target_room = target_room_for_handler.clone();
let my_user_id = my_user_id_for_handler.clone();
let allowed_users = allowed_users_for_handler.clone();
let allowed_rooms = allowed_rooms_for_handler.clone();
let dedupe = Arc::clone(&dedupe_for_handler);
let homeserver = homeserver_for_handler.clone();
let access_token = access_token_for_handler.clone();
let voice_mode = Arc::clone(&voice_mode_for_handler);
async move {
-if false
-/* multi-room: room_id filter disabled */
-{
+if !MatrixChannel::room_matches_target(
+target_room.as_str(),
+room.room_id().as_str(),
+) {
return;
}
// Room allowlist: skip messages from rooms not in the configured list
if !MatrixChannel::is_room_allowed_static(&allowed_rooms, room.room_id().as_ref()) {
tracing::debug!(
"Matrix: ignoring message from room {} (not in allowed_rooms)",
room.room_id()
);
return;
}
@@ -907,6 +980,45 @@ impl Channel for MatrixChannel {
}
});
// Invite handler: auto-accept invites for allowed rooms, auto-reject others
let allowed_rooms_for_invite = self.allowed_rooms.clone();
client.add_event_handler(move |event: StrippedRoomMemberEvent, room: Room| {
let allowed_rooms = allowed_rooms_for_invite.clone();
async move {
// Only process invite events targeting us
if event.content.membership
!= matrix_sdk::ruma::events::room::member::MembershipState::Invite
{
return;
}
let room_id_str = room.room_id().to_string();
if MatrixChannel::is_room_allowed_static(&allowed_rooms, &room_id_str) {
// Room is allowed (or no allowlist configured): auto-accept
tracing::info!(
"Matrix: auto-accepting invite for allowed room {}",
room_id_str
);
if let Err(error) = room.join().await {
tracing::warn!("Matrix: failed to auto-join room {}: {error}", room_id_str);
}
} else {
// Room is NOT in allowlist: auto-reject
tracing::info!(
"Matrix: auto-rejecting invite for room {} (not in allowed_rooms)",
room_id_str
);
if let Err(error) = room.leave().await {
tracing::warn!(
"Matrix: failed to reject invite for room {}: {error}",
room_id_str
);
}
}
}
});
let sync_settings = SyncSettings::new().timeout(std::time::Duration::from_secs(30));
client
.sync_with_result_callback(sync_settings, |sync_result| {
@@ -1294,6 +1406,22 @@ mod tests {
assert_eq!(value["room"]["timeline"]["limit"], 1);
}
#[test]
fn room_scope_matches_configured_room() {
assert!(MatrixChannel::room_matches_target(
"!ops:matrix.org",
"!ops:matrix.org"
));
}
#[test]
fn room_scope_rejects_other_rooms() {
assert!(!MatrixChannel::room_matches_target(
"!ops:matrix.org",
"!other:matrix.org"
));
}
#[test]
fn event_id_cache_deduplicates_and_evicts_old_entries() {
let mut recent_order = std::collections::VecDeque::new();
@@ -1549,4 +1677,79 @@ mod tests {
let resp: SyncResponse = serde_json::from_str(json).unwrap();
assert!(resp.rooms.join.is_empty());
}
#[test]
fn empty_allowed_rooms_permits_all() {
let ch = make_channel();
assert!(ch.is_room_allowed("!any:matrix.org"));
assert!(ch.is_room_allowed("!other:evil.org"));
}
#[test]
fn allowed_rooms_filters_by_id() {
let ch = MatrixChannel::new_full(
"https://m.org".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec!["@user:m".to_string()],
vec!["!allowed:matrix.org".to_string()],
None,
None,
None,
);
assert!(ch.is_room_allowed("!allowed:matrix.org"));
assert!(!ch.is_room_allowed("!forbidden:matrix.org"));
}
#[test]
fn allowed_rooms_supports_aliases() {
let ch = MatrixChannel::new_full(
"https://m.org".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec!["@user:m".to_string()],
vec![
"#ops:matrix.org".to_string(),
"!direct:matrix.org".to_string(),
],
None,
None,
None,
);
assert!(ch.is_room_allowed("!direct:matrix.org"));
assert!(ch.is_room_allowed("#ops:matrix.org"));
assert!(!ch.is_room_allowed("!other:matrix.org"));
}
#[test]
fn allowed_rooms_case_insensitive() {
let ch = MatrixChannel::new_full(
"https://m.org".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec![],
vec!["!Room:Matrix.org".to_string()],
None,
None,
None,
);
assert!(ch.is_room_allowed("!room:matrix.org"));
assert!(ch.is_room_allowed("!ROOM:MATRIX.ORG"));
}
#[test]
fn allowed_rooms_trims_whitespace() {
let ch = MatrixChannel::new_full(
"https://m.org".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec![],
vec![" !room:matrix.org ".to_string(), " ".to_string()],
None,
None,
None,
);
assert_eq!(ch.allowed_rooms.len(), 1);
assert!(ch.is_room_allowed("!room:matrix.org"));
}
}
+272 -50
@@ -19,11 +19,14 @@ pub mod clawdtalk;
pub mod cli;
pub mod dingtalk;
pub mod discord;
pub mod discord_history;
pub mod email_channel;
pub mod gmail_push;
pub mod imessage;
pub mod irc;
#[cfg(feature = "channel-lark")]
pub mod lark;
pub mod link_enricher;
pub mod linq;
#[cfg(feature = "channel-matrix")]
pub mod matrix;
@@ -45,6 +48,8 @@ pub mod traits;
pub mod transcription;
pub mod tts;
pub mod twitter;
#[cfg(feature = "voice-wake")]
pub mod voice_wake;
pub mod wati;
pub mod webhook;
pub mod wecom;
@@ -59,7 +64,9 @@ pub use clawdtalk::{ClawdTalkChannel, ClawdTalkConfig};
pub use cli::CliChannel;
pub use dingtalk::DingTalkChannel;
pub use discord::DiscordChannel;
pub use discord_history::DiscordHistoryChannel;
pub use email_channel::EmailChannel;
pub use gmail_push::GmailPushChannel;
pub use imessage::IMessageChannel;
pub use irc::IrcChannel;
#[cfg(feature = "channel-lark")]
@@ -82,6 +89,8 @@ pub use traits::{Channel, SendMessage};
#[allow(unused_imports)]
pub use tts::{TtsManager, TtsProvider};
pub use twitter::TwitterChannel;
#[cfg(feature = "voice-wake")]
pub use voice_wake::VoiceWakeChannel;
pub use wati::WatiChannel;
pub use webhook::WebhookChannel;
pub use wecom::WeComChannel;
@@ -222,9 +231,21 @@ fn effective_channel_message_timeout_secs(configured: u64) -> u64 {
fn channel_message_timeout_budget_secs(
message_timeout_secs: u64,
max_tool_iterations: usize,
) -> u64 {
channel_message_timeout_budget_secs_with_cap(
message_timeout_secs,
max_tool_iterations,
CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP,
)
}
fn channel_message_timeout_budget_secs_with_cap(
message_timeout_secs: u64,
max_tool_iterations: usize,
scale_cap: u64,
) -> u64 {
let iterations = max_tool_iterations.max(1) as u64;
-let scale = iterations.min(CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP);
+let scale = iterations.min(scale_cap);
message_timeout_secs.saturating_mul(scale)
}
@@ -362,6 +383,7 @@ struct ChannelRuntimeContext {
approval_manager: Arc<ApprovalManager>,
activated_tools: Option<std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
cost_tracking: Option<ChannelCostTrackingState>,
pacing: crate::config::PacingConfig,
}
#[derive(Clone)]
@@ -2045,6 +2067,25 @@ async fn process_channel_message(
msg
};
// ── Link enricher: prepend URL summaries before agent sees the message ──
let le_config = &ctx.prompt_config.link_enricher;
if le_config.enabled {
let enricher_cfg = link_enricher::LinkEnricherConfig {
enabled: le_config.enabled,
max_links: le_config.max_links,
timeout_secs: le_config.timeout_secs,
};
let enriched = link_enricher::enrich_message(&msg.content, &enricher_cfg).await;
if enriched != msg.content {
tracing::info!(
channel = %msg.channel,
sender = %msg.sender,
"Link enricher: prepended URL summaries to message"
);
msg.content = enriched;
}
}
let target_channel = ctx
.channels_by_name
.get(&msg.channel)
@@ -2402,8 +2443,15 @@ async fn process_channel_message(
}
let model_switch_callback = get_model_switch_state();
-let timeout_budget_secs =
-channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations);
+let scale_cap = ctx
+.pacing
+.message_timeout_scale_max
+.unwrap_or(CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP);
+let timeout_budget_secs = channel_message_timeout_budget_secs_with_cap(
+ctx.message_timeout_secs,
+ctx.max_tool_iterations,
+scale_cap,
+);
let cost_tracking_context = ctx.cost_tracking.clone().map(|state| {
crate::agent::loop_::ToolLoopCostTrackingContext::new(state.tracker, state.prices)
});
@@ -2445,6 +2493,7 @@ async fn process_channel_message(
ctx.tool_call_dedup_exempt.as_ref(),
ctx.activated_tools.as_ref(),
Some(model_switch_callback.clone()),
&ctx.pacing,
),
),
) => LlmExecutionResult::Completed(result),
@@ -3107,9 +3156,12 @@ pub fn build_system_prompt_with_mode(
Some(&autonomy_cfg),
native_tools,
skills_prompt_mode,
false,
0,
)
}
#[allow(clippy::too_many_arguments)]
pub fn build_system_prompt_with_mode_and_autonomy(
workspace_dir: &std::path::Path,
model_name: &str,
@@ -3120,6 +3172,8 @@ pub fn build_system_prompt_with_mode_and_autonomy(
autonomy_config: Option<&crate::config::AutonomyConfig>,
native_tools: bool,
skills_prompt_mode: crate::config::SkillsPromptInjectionMode,
compact_context: bool,
max_system_prompt_chars: usize,
) -> String {
use std::fmt::Write;
let mut prompt = String::with_capacity(8192);
@@ -3146,11 +3200,19 @@ pub fn build_system_prompt_with_mode_and_autonomy(
// ── 1. Tooling ──────────────────────────────────────────────
if !tools.is_empty() {
prompt.push_str("## Tools\n\n");
-prompt.push_str("You have access to the following tools:\n\n");
-for (name, desc) in tools {
-let _ = writeln!(prompt, "- **{name}**: {desc}");
+if compact_context {
+// Compact mode: tool names only, no descriptions/schemas
+prompt.push_str("Available tools: ");
+let names: Vec<&str> = tools.iter().map(|(name, _)| *name).collect();
+prompt.push_str(&names.join(", "));
+prompt.push_str("\n\n");
+} else {
+prompt.push_str("You have access to the following tools:\n\n");
+for (name, desc) in tools {
+let _ = writeln!(prompt, "- **{name}**: {desc}");
+}
+prompt.push('\n');
+}
prompt.push('\n');
}
// ── 1b. Hardware (when gpio/arduino tools present) ───────────
@@ -3294,11 +3356,13 @@ pub fn build_system_prompt_with_mode_and_autonomy(
std::env::consts::OS,
);
-// ── 8. Channel Capabilities ─────────────────────────────────────
-prompt.push_str("## Channel Capabilities\n\n");
-prompt.push_str("- You are running as a messaging bot. Your response is automatically sent back to the user's channel.\n");
-prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
-prompt.push_str(match autonomy_config.map(|cfg| cfg.level) {
+// ── 8. Channel Capabilities (skipped in compact_context mode) ──
+if !compact_context {
+prompt.push_str("## Channel Capabilities\n\n");
+prompt.push_str("- You are running as a messaging bot. Your response is automatically sent back to the user's channel.\n");
+prompt
+.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
+prompt.push_str(match autonomy_config.map(|cfg| cfg.level) {
Some(crate::security::AutonomyLevel::Full) => {
"- If the runtime policy already allows a tool, use it directly; do not ask the user for extra approval.\n\
- Never pretend you are waiting for a human approval click or confirmation when the runtime policy already permits the action.\n\
@@ -3312,10 +3376,23 @@ pub fn build_system_prompt_with_mode_and_autonomy(
- If there is no approval path for this channel or the runtime blocks an action, explain that restriction directly instead of simulating an approval flow.\n"
}
});
-prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
-prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n");
-prompt.push_str("- When a user sends a voice note, it is automatically transcribed to text. Your text reply is automatically converted to a voice note and sent back. Do NOT attempt to generate audio yourself — TTS is handled by the channel.\n");
-prompt.push_str("- NEVER narrate or describe your tool usage. Do NOT say 'Let me fetch...', 'I will use...', 'Searching...', or similar. Give the FINAL ANSWER only — no intermediate steps, no tool mentions, no progress updates.\n\n");
+prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
+prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n");
+prompt.push_str("- When a user sends a voice note, it is automatically transcribed to text. Your text reply is automatically converted to a voice note and sent back. Do NOT attempt to generate audio yourself — TTS is handled by the channel.\n");
+prompt.push_str("- NEVER narrate or describe your tool usage. Do NOT say 'Let me fetch...', 'I will use...', 'Searching...', or similar. Give the FINAL ANSWER only — no intermediate steps, no tool mentions, no progress updates.\n\n");
+} // end if !compact_context (Channel Capabilities)
// ── 9. Truncation (max_system_prompt_chars budget) ──────────
if max_system_prompt_chars > 0 && prompt.len() > max_system_prompt_chars {
// Truncate on a char boundary, keeping the top portion (identity + safety).
let mut end = max_system_prompt_chars;
// Ensure we don't split a multi-byte UTF-8 character.
while !prompt.is_char_boundary(end) && end > 0 {
end -= 1;
}
prompt.truncate(end);
prompt.push_str("\n\n[System prompt truncated to fit context budget]\n");
}
if prompt.is_empty() {
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct."
@@ -3613,13 +3690,16 @@ fn build_channel_by_id(config: &Config, channel_id: &str) -> Result<Arc<dyn Chan
.discord
.as_ref()
.context("Discord channel is not configured")?;
-Ok(Arc::new(DiscordChannel::new(
-dc.bot_token.clone(),
-dc.guild_id.clone(),
-dc.allowed_users.clone(),
-dc.listen_to_bots,
-dc.mention_only,
-)))
+Ok(Arc::new(
+DiscordChannel::new(
+dc.bot_token.clone(),
+dc.guild_id.clone(),
+dc.allowed_users.clone(),
+dc.listen_to_bots,
+dc.mention_only,
+)
+.with_transcription(config.transcription.clone()),
+))
}
"slack" => {
let sl = config
@@ -3635,7 +3715,8 @@ fn build_channel_by_id(config: &Config, channel_id: &str) -> Result<Arc<dyn Chan
Vec::new(),
sl.allowed_users.clone(),
)
-.with_workspace_dir(config.workspace_dir.clone()),
+.with_workspace_dir(config.workspace_dir.clone())
+.with_transcription(config.transcription.clone()),
))
}
other => anyhow::bail!("Unknown channel '{other}'. Supported: telegram, discord, slack"),
@@ -3721,11 +3802,37 @@ fn collect_configured_channels(
dc.listen_to_bots,
dc.mention_only,
)
-.with_proxy_url(dc.proxy_url.clone()),
+.with_proxy_url(dc.proxy_url.clone())
+.with_transcription(config.transcription.clone()),
),
});
}
if let Some(ref dh) = config.channels_config.discord_history {
match crate::memory::SqliteMemory::new_named(&config.workspace_dir, "discord") {
Ok(discord_mem) => {
channels.push(ConfiguredChannel {
display_name: "Discord History",
channel: Arc::new(
DiscordHistoryChannel::new(
dh.bot_token.clone(),
dh.guild_id.clone(),
dh.allowed_users.clone(),
dh.channel_ids.clone(),
Arc::new(discord_mem),
dh.store_dms,
dh.respond_to_dms,
)
.with_proxy_url(dh.proxy_url.clone()),
),
});
}
Err(e) => {
tracing::error!("discord_history: failed to open discord.db: {e}");
}
}
}
if let Some(ref sl) = config.channels_config.slack {
channels.push(ConfiguredChannel {
display_name: "Slack",
@@ -3740,7 +3847,8 @@ fn collect_configured_channels(
.with_thread_replies(sl.thread_replies.unwrap_or(true))
.with_group_reply_policy(sl.mention_only, Vec::new())
.with_workspace_dir(config.workspace_dir.clone())
-.with_proxy_url(sl.proxy_url.clone()),
+.with_proxy_url(sl.proxy_url.clone())
+.with_transcription(config.transcription.clone()),
),
});
}
@@ -3773,11 +3881,12 @@ fn collect_configured_channels(
if let Some(ref mx) = config.channels_config.matrix {
channels.push(ConfiguredChannel {
display_name: "Matrix",
-channel: Arc::new(MatrixChannel::new_with_session_hint_and_zeroclaw_dir(
+channel: Arc::new(MatrixChannel::new_full(
mx.homeserver.clone(),
mx.access_token.clone(),
mx.room_id.clone(),
mx.allowed_users.clone(),
+mx.allowed_rooms.clone(),
mx.user_id.clone(),
mx.device_id.clone(),
config.config_path.parent().map(|path| path.to_path_buf()),
@@ -3917,6 +4026,15 @@ fn collect_configured_channels(
});
}
if let Some(ref gp_cfg) = config.channels_config.gmail_push {
if gp_cfg.enabled {
channels.push(ConfiguredChannel {
display_name: "Gmail Push",
channel: Arc::new(GmailPushChannel::new(gp_cfg.clone())),
});
}
}
if let Some(ref irc) = config.channels_config.irc {
channels.push(ConfiguredChannel {
display_name: "IRC",
@@ -4093,6 +4211,17 @@ fn collect_configured_channels(
});
}
#[cfg(feature = "voice-wake")]
if let Some(ref vw) = config.channels_config.voice_wake {
channels.push(ConfiguredChannel {
display_name: "VoiceWake",
channel: Arc::new(VoiceWakeChannel::new(
vw.clone(),
config.transcription.clone(),
)),
});
}
if let Some(ref wh) = config.channels_config.webhook {
channels.push(ConfiguredChannel {
display_name: "Webhook",
@@ -4243,22 +4372,22 @@ pub async fn start_channels(config: Config) -> Result<()> {
};
// Build system prompt from workspace identity files + skills
let workspace = config.workspace_dir.clone();
-let (mut built_tools, delegate_handle_ch): (Vec<Box<dyn Tool>>, _) =
-tools::all_tools_with_runtime(
-Arc::new(config.clone()),
-&security,
-runtime,
-Arc::clone(&mem),
-composio_key,
-composio_entity_id,
-&config.browser,
-&config.http_request,
-&config.web_fetch,
-&workspace,
-&config.agents,
-config.api_key.as_deref(),
-&config,
-);
+let (mut built_tools, delegate_handle_ch, reaction_handle_ch) = tools::all_tools_with_runtime(
+Arc::new(config.clone()),
+&security,
+runtime,
+Arc::clone(&mem),
+composio_key,
+composio_entity_id,
+&config.browser,
+&config.http_request,
+&config.web_fetch,
+&workspace,
+&config.agents,
+config.api_key.as_deref(),
+&config,
+None,
+);
// Wire MCP tools into the registry before freezing — non-fatal.
// When `deferred_loading` is enabled, MCP tools are NOT added eagerly.
@@ -4431,6 +4560,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
Some(&config.autonomy),
native_tools,
config.skills.prompt_injection_mode,
config.agent.compact_context,
config.agent.max_system_prompt_chars,
);
if !native_tools {
system_prompt.push_str(&build_tool_instructions(
@@ -4530,6 +4661,15 @@ pub async fn start_channels(config: Config) -> Result<()> {
.map(|ch| (ch.name().to_string(), Arc::clone(ch)))
.collect::<HashMap<_, _>>(),
);
// Populate the reaction tool's channel map now that channels are initialized.
if let Some(ref handle) = reaction_handle_ch {
let mut map = handle.write();
for (name, ch) in channels_by_name.as_ref() {
map.insert(name.clone(), Arc::clone(ch));
}
}
let max_in_flight_messages = compute_max_in_flight_messages(channels.len());
println!(" 🚦 In-flight message limit: {max_in_flight_messages}");
@@ -4641,6 +4781,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
tracker,
prices: Arc::new(config.cost.prices.clone()),
}),
pacing: config.pacing.clone(),
});
// Hydrate in-memory conversation histories from persisted JSONL session files.
@@ -4737,6 +4878,49 @@ mod tests {
);
}
#[test]
fn channel_message_timeout_budget_with_custom_scale_cap() {
assert_eq!(
channel_message_timeout_budget_secs_with_cap(300, 8, 8),
300 * 8
);
assert_eq!(
channel_message_timeout_budget_secs_with_cap(300, 20, 8),
300 * 8
);
assert_eq!(
channel_message_timeout_budget_secs_with_cap(300, 10, 1),
300
);
}
#[test]
fn pacing_config_defaults_preserve_existing_behavior() {
let pacing = crate::config::PacingConfig::default();
assert!(pacing.step_timeout_secs.is_none());
assert!(pacing.loop_detection_min_elapsed_secs.is_none());
assert!(pacing.loop_ignore_tools.is_empty());
assert!(pacing.message_timeout_scale_max.is_none());
}
#[test]
fn pacing_message_timeout_scale_max_overrides_default_cap() {
// Custom cap of 8 scales budget proportionally
assert_eq!(
channel_message_timeout_budget_secs_with_cap(300, 10, 8),
300 * 8
);
// Default cap produces the standard behavior
assert_eq!(
channel_message_timeout_budget_secs_with_cap(
300,
10,
CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP
),
300 * CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP
);
}
#[test]
fn context_window_overflow_error_detector_matches_known_messages() {
let overflow_err = anyhow::anyhow!(
@@ -4941,6 +5125,7 @@ mod tests {
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
};
assert!(compact_sender_history(&ctx, &sender));
@@ -5057,6 +5242,7 @@ mod tests {
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
};
append_sender_turn(&ctx, &sender, ChatMessage::user("hello"));
@@ -5129,6 +5315,7 @@ mod tests {
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
};
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
@@ -5220,6 +5407,7 @@ mod tests {
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
};
assert!(rollback_orphan_user_turn(
@@ -5761,6 +5949,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -5842,6 +6031,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -5937,6 +6127,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6017,6 +6208,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6107,6 +6299,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6218,6 +6411,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6310,6 +6504,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6417,6 +6612,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6509,6 +6705,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6591,6 +6788,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -6704,6 +6902,9 @@ BTC is currently around $65,000 based on latest tool output."#
timestamp: "2026-02-20T00:00:00Z".to_string(),
session_id: None,
score: Some(0.9),
namespace: "default".into(),
importance: None,
superseded_by: None,
}])
}
@@ -6788,6 +6989,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
@@ -6890,6 +7092,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
@@ -7007,6 +7210,7 @@ BTC is currently around $65,000 based on latest tool output."#
activated_tools: None,
cost_tracking: None,
query_classification: crate::config::QueryClassificationConfig::default(),
pacing: crate::config::PacingConfig::default(),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
@@ -7121,6 +7325,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
@@ -7217,6 +7422,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -7297,6 +7503,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -7519,9 +7726,9 @@ BTC is currently around $65,000 based on latest tool output."#
assert!(prompt.contains("<instructions>"));
assert!(prompt
.contains("<instruction>Always run cargo test before final response.</instruction>"));
assert!(prompt.contains("<tools>"));
assert!(prompt.contains("<name>lint</name>"));
assert!(prompt.contains("<kind>shell</kind>"));
// Registered tools (shell kind) appear under <callable_tools> with prefixed names
assert!(prompt.contains("<callable_tools"));
assert!(prompt.contains("<name>code-review.lint</name>"));
assert!(!prompt.contains("loaded on demand"));
}
@@ -7564,10 +7771,10 @@ BTC is currently around $65,000 based on latest tool output."#
assert!(!prompt.contains("<instructions>"));
assert!(!prompt
.contains("<instruction>Always run cargo test before final response.</instruction>"));
// Compact mode should still include tools so the LLM knows about them
assert!(prompt.contains("<tools>"));
assert!(prompt.contains("<name>lint</name>"));
assert!(prompt.contains("<kind>shell</kind>"));
// Compact mode should still include tools so the LLM knows about them.
// Registered tools (shell kind) appear under <callable_tools> with prefixed names.
assert!(prompt.contains("<callable_tools"));
assert!(prompt.contains("<name>code-review.lint</name>"));
}
#[test]
@@ -7689,6 +7896,8 @@ BTC is currently around $65,000 based on latest tool output."#
Some(&config),
false,
crate::config::SkillsPromptInjectionMode::Full,
false,
0,
);
assert!(
@@ -7718,6 +7927,8 @@ BTC is currently around $65,000 based on latest tool output."#
Some(&config),
false,
crate::config::SkillsPromptInjectionMode::Full,
false,
0,
);
assert!(
@@ -8063,6 +8274,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -8194,6 +8406,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -8365,6 +8578,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -8473,6 +8687,7 @@ BTC is currently around $65,000 based on latest tool output."#
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -9045,6 +9260,7 @@ This is an example JSON object for profile settings."#;
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
// Simulate a photo attachment message with [IMAGE:] marker.
@@ -9132,6 +9348,7 @@ This is an example JSON object for profile settings."#;
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -9294,6 +9511,7 @@ This is an example JSON object for profile settings."#;
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -9405,6 +9623,7 @@ This is an example JSON object for profile settings."#;
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -9508,6 +9727,7 @@ This is an example JSON object for profile settings."#;
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -9631,6 +9851,7 @@ This is an example JSON object for profile settings."#;
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
process_channel_message(
@@ -9892,6 +10113,7 @@ This is an example JSON object for profile settings."#;
)),
activated_tools: None,
cost_tracking: None,
pacing: crate::config::PacingConfig::default(),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
+14
@@ -12,6 +12,8 @@ use chrono::{DateTime, Utc};
pub struct SessionMetadata {
/// Session key (e.g. `telegram_user123`).
pub key: String,
/// Optional human-readable name (e.g. `eyrie-commander-briefing`).
pub name: Option<String>,
/// When the session was first created.
pub created_at: DateTime<Utc>,
/// When the last message was appended.
@@ -54,6 +56,7 @@ pub trait SessionBackend: Send + Sync {
let messages = self.load(&key);
SessionMetadata {
key,
name: None,
created_at: Utc::now(),
last_activity: Utc::now(),
message_count: messages.len(),
@@ -81,6 +84,16 @@ pub trait SessionBackend: Send + Sync {
fn delete_session(&self, _session_key: &str) -> std::io::Result<bool> {
Ok(false)
}
/// Set or update the human-readable name for a session.
fn set_session_name(&self, _session_key: &str, _name: &str) -> std::io::Result<()> {
Ok(())
}
/// Get the human-readable name for a session (if set).
fn get_session_name(&self, _session_key: &str) -> std::io::Result<Option<String>> {
Ok(None)
}
}
#[cfg(test)]
@@ -91,6 +104,7 @@ mod tests {
fn session_metadata_is_constructible() {
let meta = SessionMetadata {
key: "test".into(),
name: None,
created_at: Utc::now(),
last_activity: Utc::now(),
message_count: 5,
+92 -3
@@ -51,7 +51,8 @@ impl SqliteSessionBackend {
session_key TEXT PRIMARY KEY,
created_at TEXT NOT NULL,
last_activity TEXT NOT NULL,
message_count INTEGER NOT NULL DEFAULT 0
message_count INTEGER NOT NULL DEFAULT 0,
name TEXT
);
CREATE VIRTUAL TABLE IF NOT EXISTS sessions_fts USING fts5(
@@ -69,6 +70,18 @@ impl SqliteSessionBackend {
)
.context("Failed to initialize session schema")?;
// Migration: add name column to existing databases
let has_name: bool = conn
.query_row(
"SELECT COUNT(*) > 0 FROM pragma_table_info('session_metadata') WHERE name = 'name'",
[],
|row| row.get(0),
)
.unwrap_or(false);
if !has_name {
let _ = conn.execute("ALTER TABLE session_metadata ADD COLUMN name TEXT", []);
}
Ok(Self {
conn: Mutex::new(conn),
db_path,
@@ -226,7 +239,7 @@ impl SessionBackend for SqliteSessionBackend {
fn list_sessions_with_metadata(&self) -> Vec<SessionMetadata> {
let conn = self.conn.lock();
let mut stmt = match conn.prepare(
"SELECT session_key, created_at, last_activity, message_count
"SELECT session_key, created_at, last_activity, message_count, name
FROM session_metadata ORDER BY last_activity DESC",
) {
Ok(s) => s,
@@ -238,6 +251,7 @@ impl SessionBackend for SqliteSessionBackend {
let created_str: String = row.get(1)?;
let activity_str: String = row.get(2)?;
let count: i64 = row.get(3)?;
let name: Option<String> = row.get(4)?;
let created = DateTime::parse_from_rfc3339(&created_str)
.map(|dt| dt.with_timezone(&Utc))
@@ -249,6 +263,7 @@ impl SessionBackend for SqliteSessionBackend {
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
Ok(SessionMetadata {
key,
name,
created_at: created,
last_activity: activity,
message_count: count as usize,
@@ -321,6 +336,27 @@ impl SessionBackend for SqliteSessionBackend {
Ok(true)
}
fn set_session_name(&self, session_key: &str, name: &str) -> std::io::Result<()> {
let conn = self.conn.lock();
let name_val = if name.is_empty() { None } else { Some(name) };
conn.execute(
"UPDATE session_metadata SET name = ?1 WHERE session_key = ?2",
params![name_val, session_key],
)
.map_err(std::io::Error::other)?;
Ok(())
}
fn get_session_name(&self, session_key: &str) -> std::io::Result<Option<String>> {
use rusqlite::OptionalExtension;
let conn = self.conn.lock();
// Treat a missing session row the same as an unset name: Ok(None).
conn.query_row(
"SELECT name FROM session_metadata WHERE session_key = ?1",
params![session_key],
|row| row.get(0),
)
.optional()
.map(Option::flatten)
.map_err(std::io::Error::other)
}
fn search(&self, query: &SessionQuery) -> Vec<SessionMetadata> {
let Some(keyword) = &query.keyword else {
return self.list_sessions_with_metadata();
@@ -357,14 +393,16 @@ impl SessionBackend for SqliteSessionBackend {
keys.iter()
.filter_map(|key| {
conn.query_row(
"SELECT created_at, last_activity, message_count FROM session_metadata WHERE session_key = ?1",
"SELECT created_at, last_activity, message_count, name FROM session_metadata WHERE session_key = ?1",
params![key],
|row| {
let created_str: String = row.get(0)?;
let activity_str: String = row.get(1)?;
let count: i64 = row.get(2)?;
let name: Option<String> = row.get(3)?;
Ok(SessionMetadata {
key: key.clone(),
name,
created_at: DateTime::parse_from_rfc3339(&created_str)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now()),
@@ -555,4 +593,55 @@ mod tests {
assert_eq!(msgs.len(), 2);
assert_eq!(msgs[0].content, "hello");
}
#[test]
fn set_session_name_persists() {
let tmp = TempDir::new().unwrap();
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
backend.append("s1", &ChatMessage::user("hello")).unwrap();
backend.set_session_name("s1", "My Session").unwrap();
let meta = backend.list_sessions_with_metadata();
assert_eq!(meta.len(), 1);
assert_eq!(meta[0].name.as_deref(), Some("My Session"));
}
#[test]
fn set_session_name_updates_existing() {
let tmp = TempDir::new().unwrap();
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
backend.append("s1", &ChatMessage::user("hello")).unwrap();
backend.set_session_name("s1", "First").unwrap();
backend.set_session_name("s1", "Second").unwrap();
let meta = backend.list_sessions_with_metadata();
assert_eq!(meta[0].name.as_deref(), Some("Second"));
}
#[test]
fn sessions_without_name_return_none() {
let tmp = TempDir::new().unwrap();
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
backend.append("s1", &ChatMessage::user("hello")).unwrap();
let meta = backend.list_sessions_with_metadata();
assert_eq!(meta.len(), 1);
assert!(meta[0].name.is_none());
}
#[test]
fn empty_name_clears_to_none() {
let tmp = TempDir::new().unwrap();
let backend = SqliteSessionBackend::new(tmp.path()).unwrap();
backend.append("s1", &ChatMessage::user("hello")).unwrap();
backend.set_session_name("s1", "Named").unwrap();
backend.set_session_name("s1", "").unwrap();
let meta = backend.list_sessions_with_metadata();
assert!(meta[0].name.is_none());
}
}
+119
@@ -34,6 +34,9 @@ pub struct SlackChannel {
active_assistant_thread: Mutex<HashMap<String, String>>,
/// Per-channel proxy URL override.
proxy_url: Option<String>,
/// Voice transcription config — when set, audio file attachments are
/// downloaded, transcribed, and their text inlined into the message.
transcription: Option<crate::config::TranscriptionConfig>,
}
const SLACK_HISTORY_MAX_RETRIES: u32 = 3;
@@ -125,6 +128,7 @@ impl SlackChannel {
workspace_dir: None,
active_assistant_thread: Mutex::new(HashMap::new()),
proxy_url: None,
transcription: None,
}
}
@@ -158,6 +162,14 @@ impl SlackChannel {
self
}
/// Configure voice transcription for audio file attachments.
pub fn with_transcription(mut self, config: crate::config::TranscriptionConfig) -> Self {
if config.enabled {
self.transcription = Some(config);
}
self
}
fn http_client(&self) -> reqwest::Client {
crate::config::build_channel_proxy_client_with_timeouts(
"channel.slack",
@@ -558,6 +570,13 @@ impl SlackChannel {
.await
.unwrap_or_else(|| raw_file.clone());
// Voice / audio transcription: if transcription is configured and the
// file looks like an audio attachment, download and transcribe it.
if Self::is_audio_file(&file) {
if let Some(transcribed) = self.try_transcribe_audio_file(&file).await {
return Some(transcribed);
}
}
if Self::is_image_file(&file) {
if let Some(marker) = self.fetch_image_marker(&file).await {
return Some(marker);
@@ -1449,6 +1468,106 @@ impl SlackChannel {
.is_some_and(|ext| Self::mime_from_extension(ext).is_some())
}
/// Audio file extensions accepted for voice transcription.
const AUDIO_EXTENSIONS: &[&str] = &[
"flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
];
/// Check whether a Slack file object looks like an audio attachment
/// (voice memo, audio message, or uploaded audio file).
fn is_audio_file(file: &serde_json::Value) -> bool {
// Slack voice messages use subtype "slack_audio"
if let Some(subtype) = file.get("subtype").and_then(|v| v.as_str()) {
if subtype == "slack_audio" {
return true;
}
}
if Self::slack_file_mime(file)
.as_deref()
.is_some_and(|mime| mime.starts_with("audio/"))
{
return true;
}
if let Some(ft) = file
.get("filetype")
.and_then(|v| v.as_str())
.map(|v| v.to_ascii_lowercase())
{
if Self::AUDIO_EXTENSIONS.contains(&ft.as_str()) {
return true;
}
}
Self::file_extension(&Self::slack_file_name(file))
.as_deref()
.is_some_and(|ext| Self::AUDIO_EXTENSIONS.contains(&ext))
}
/// Download an audio file attachment and transcribe it using the configured
/// transcription provider. Returns `None` if transcription is not configured
/// or if the download/transcription fails.
async fn try_transcribe_audio_file(&self, file: &serde_json::Value) -> Option<String> {
let config = self.transcription.as_ref()?;
let url = Self::slack_file_download_url(file)?;
let file_name = Self::slack_file_name(file);
let redacted_url = Self::redact_raw_slack_url(url);
let resp = self.fetch_slack_private_file(url).await?;
let status = resp.status();
if !status.is_success() {
tracing::warn!(
"Slack voice file download failed for {} ({status})",
redacted_url
);
return None;
}
let audio_data = match resp.bytes().await {
Ok(bytes) => bytes.to_vec(),
Err(e) => {
tracing::warn!("Slack voice file read failed for {}: {e}", redacted_url);
return None;
}
};
// Determine a filename with extension for the transcription API.
let transcription_filename = if Self::file_extension(&file_name).is_some() {
file_name.clone()
} else {
// Fall back to extension from mimetype or default to .ogg
let mime_ext = Self::slack_file_mime(file)
.and_then(|mime| mime.rsplit('/').next().map(|s| s.to_string()))
.unwrap_or_else(|| "ogg".to_string());
format!("voice.{mime_ext}")
};
match super::transcription::transcribe_audio(audio_data, &transcription_filename, config)
.await
{
Ok(text) => {
let trimmed = text.trim();
if trimmed.is_empty() {
tracing::info!("Slack voice transcription returned empty text, skipping");
None
} else {
tracing::info!(
"Slack: transcribed voice file {} ({} chars)",
file_name,
trimmed.len()
);
Some(format!("[Voice] {trimmed}"))
}
}
Err(e) => {
tracing::warn!("Slack voice transcription failed for {}: {e}", file_name);
Some(Self::format_attachment_summary(file))
}
}
}
async fn download_text_snippet(&self, file: &serde_json::Value) -> Option<String> {
let url = Self::slack_file_download_url(file)?;
let redacted_url = Self::redact_raw_slack_url(url);
+203
@@ -1140,6 +1140,11 @@ Allowlist Telegram username (without '@') or numeric user ID.",
content = format!("{quote}\n\n{content}");
}
// Prepend forwarding attribution when the message was forwarded
if let Some(attr) = Self::format_forward_attribution(message) {
content = format!("{attr}{content}");
}
Some(ChannelMessage {
id: format!("telegram_{chat_id}_{message_id}"),
sender: sender_identity,
@@ -1263,6 +1268,13 @@ Allowlist Telegram username (without '@') or numeric user ID.",
format!("[Voice] {text}")
};
// Prepend forwarding attribution when the message was forwarded
let content = if let Some(attr) = Self::format_forward_attribution(message) {
format!("{attr}{content}")
} else {
content
};
Some(ChannelMessage {
id: format!("telegram_{chat_id}_{message_id}"),
sender: sender_identity,
@@ -1299,6 +1311,41 @@ Allowlist Telegram username (without '@') or numeric user ID.",
(username, sender_id, sender_identity)
}
/// Build a forwarding attribution prefix from Telegram forward fields.
///
/// Returns `Some("[Forwarded from ...] ")` when the message is forwarded,
/// `None` otherwise.
fn format_forward_attribution(message: &serde_json::Value) -> Option<String> {
if let Some(from_chat) = message.get("forward_from_chat") {
// Forwarded from a channel or group
let title = from_chat
.get("title")
.and_then(serde_json::Value::as_str)
.unwrap_or("unknown channel");
Some(format!("[Forwarded from channel: {title}] "))
} else if let Some(from_user) = message.get("forward_from") {
// Forwarded from a user (privacy allows identity)
let label = from_user
.get("username")
.and_then(serde_json::Value::as_str)
.map(|u| format!("@{u}"))
.or_else(|| {
from_user
.get("first_name")
.and_then(serde_json::Value::as_str)
.map(String::from)
})
.unwrap_or_else(|| "unknown".to_string());
Some(format!("[Forwarded from {label}] "))
} else {
// Forwarded from a user who hides their identity
message
.get("forward_sender_name")
.and_then(serde_json::Value::as_str)
.map(|name| format!("[Forwarded from {name}] "))
}
}
/// Extract reply context from a Telegram `reply_to_message`, if present.
fn extract_reply_context(&self, message: &serde_json::Value) -> Option<String> {
let reply = message.get("reply_to_message")?;
@@ -1420,6 +1467,13 @@ Allowlist Telegram username (without '@') or numeric user ID.",
content
};
// Prepend forwarding attribution when the message was forwarded
let content = if let Some(attr) = Self::format_forward_attribution(message) {
format!("{attr}{content}")
} else {
content
};
// Exit voice-chat mode when user switches back to typing
if let Ok(mut vc) = self.voice_chats.lock() {
vc.remove(&reply_target);
@@ -4871,4 +4925,153 @@ mod tests {
TelegramChannel::new("token".into(), vec!["*".into()], false).with_ack_reactions(true);
assert!(ch.ack_reactions);
}
// ── Forwarded message tests ─────────────────────────────────────
#[test]
fn parse_update_message_forwarded_from_user_with_username() {
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
let update = serde_json::json!({
"update_id": 100,
"message": {
"message_id": 50,
"text": "Check this out",
"from": { "id": 1, "username": "alice" },
"chat": { "id": 999 },
"forward_from": {
"id": 42,
"first_name": "Bob",
"username": "bob"
},
"forward_date": 1_700_000_000
}
});
let msg = ch
.parse_update_message(&update)
.expect("forwarded message should parse");
assert_eq!(msg.content, "[Forwarded from @bob] Check this out");
}
#[test]
fn parse_update_message_forwarded_from_channel() {
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
let update = serde_json::json!({
"update_id": 101,
"message": {
"message_id": 51,
"text": "Breaking news",
"from": { "id": 1, "username": "alice" },
"chat": { "id": 999 },
"forward_from_chat": {
"id": -1_001_234_567_890_i64,
"title": "Daily News",
"username": "dailynews",
"type": "channel"
},
"forward_date": 1_700_000_000
}
});
let msg = ch
.parse_update_message(&update)
.expect("channel-forwarded message should parse");
assert_eq!(
msg.content,
"[Forwarded from channel: Daily News] Breaking news"
);
}
#[test]
fn parse_update_message_forwarded_hidden_sender() {
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
let update = serde_json::json!({
"update_id": 102,
"message": {
"message_id": 52,
"text": "Secret tip",
"from": { "id": 1, "username": "alice" },
"chat": { "id": 999 },
"forward_sender_name": "Hidden User",
"forward_date": 1_700_000_000
}
});
let msg = ch
.parse_update_message(&update)
.expect("hidden-sender forwarded message should parse");
assert_eq!(msg.content, "[Forwarded from Hidden User] Secret tip");
}
#[test]
fn parse_update_message_non_forwarded_unaffected() {
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
let update = serde_json::json!({
"update_id": 103,
"message": {
"message_id": 53,
"text": "Normal message",
"from": { "id": 1, "username": "alice" },
"chat": { "id": 999 }
}
});
let msg = ch
.parse_update_message(&update)
.expect("non-forwarded message should parse");
assert_eq!(msg.content, "Normal message");
}
#[test]
fn parse_update_message_forwarded_from_user_no_username() {
let ch = TelegramChannel::new("token".into(), vec!["*".into()], false);
let update = serde_json::json!({
"update_id": 104,
"message": {
"message_id": 54,
"text": "Hello there",
"from": { "id": 1, "username": "alice" },
"chat": { "id": 999 },
"forward_from": {
"id": 77,
"first_name": "Charlie"
},
"forward_date": 1_700_000_000
}
});
let msg = ch
.parse_update_message(&update)
.expect("forwarded message without username should parse");
assert_eq!(msg.content, "[Forwarded from Charlie] Hello there");
}
#[test]
fn forwarded_photo_attachment_has_attribution() {
// Verify that format_forward_attribution produces correct prefix
// for a photo message (the actual download is async, so we test the
// helper directly with a photo-bearing message structure).
let message = serde_json::json!({
"message_id": 60,
"from": { "id": 1, "username": "alice" },
"chat": { "id": 999 },
"photo": [
{ "file_id": "abc123", "file_unique_id": "u1", "width": 320, "height": 240 }
],
"forward_from": {
"id": 42,
"username": "bob"
},
"forward_date": 1_700_000_000
});
let attr =
TelegramChannel::format_forward_attribution(&message).expect("should detect forward");
assert_eq!(attr, "[Forwarded from @bob] ");
// Simulate what try_parse_attachment_message does after building content
let photo_content = "[IMAGE:/tmp/photo.jpg]".to_string();
let content = format!("{attr}{photo_content}");
assert_eq!(content, "[Forwarded from @bob] [IMAGE:/tmp/photo.jpg]");
}
}
+130 -1
@@ -1,6 +1,7 @@
//! Multi-provider Text-to-Speech (TTS) subsystem.
//!
//! Supports OpenAI, ElevenLabs, Google Cloud TTS, and Edge TTS (free, subprocess-based).
//! Supports OpenAI, ElevenLabs, Google Cloud TTS, Edge TTS (free, subprocess-based),
//! and Piper TTS (local GPU-accelerated, OpenAI-compatible endpoint).
//! Provider selection is driven by [`TtsConfig`] in `config.toml`.
use std::collections::HashMap;
@@ -451,6 +452,80 @@ impl TtsProvider for EdgeTtsProvider {
}
}
// ── Piper TTS (local, OpenAI-compatible) ─────────────────────────
/// Piper TTS provider — local GPU-accelerated server with an OpenAI-compatible endpoint.
pub struct PiperTtsProvider {
client: reqwest::Client,
api_url: String,
}
impl PiperTtsProvider {
/// Create a new Piper TTS provider pointing at the given API URL.
pub fn new(api_url: &str) -> Self {
Self {
client: reqwest::Client::builder()
.timeout(TTS_HTTP_TIMEOUT)
.build()
.expect("Failed to build HTTP client for Piper TTS"),
api_url: api_url.to_string(),
}
}
}
#[async_trait::async_trait]
impl TtsProvider for PiperTtsProvider {
fn name(&self) -> &str {
"piper"
}
async fn synthesize(&self, text: &str, voice: &str) -> Result<Vec<u8>> {
let body = serde_json::json!({
"model": "tts-1",
"input": text,
"voice": voice,
});
let resp = self
.client
.post(&self.api_url)
.json(&body)
.send()
.await
.context("Failed to send Piper TTS request")?;
let status = resp.status();
if !status.is_success() {
let error_body: serde_json::Value = resp
.json()
.await
.unwrap_or_else(|_| serde_json::json!({"error": "unknown"}));
let msg = error_body["error"]["message"]
.as_str()
.unwrap_or("unknown error");
bail!("Piper TTS API error ({}): {}", status, msg);
}
let bytes = resp
.bytes()
.await
.context("Failed to read Piper TTS response body")?;
Ok(bytes.to_vec())
}
fn supported_voices(&self) -> Vec<String> {
// Piper voices depend on installed models; return empty (dynamic).
Vec::new()
}
fn supported_formats(&self) -> Vec<String> {
["mp3", "wav", "opus"]
.iter()
.map(|s| (*s).to_string())
.collect()
}
}
// ── TtsManager ───────────────────────────────────────────────────
/// Central manager for multi-provider TTS synthesis.
@@ -510,6 +585,11 @@ impl TtsManager {
}
}
if let Some(ref piper_cfg) = config.piper {
let provider = PiperTtsProvider::new(&piper_cfg.api_url);
providers.insert("piper".to_string(), Box::new(provider));
}
let max_text_length = if config.max_text_length == 0 {
DEFAULT_MAX_TEXT_LENGTH
} else {
@@ -652,6 +732,54 @@ mod tests {
);
}
#[test]
fn piper_provider_creation() {
let provider = PiperTtsProvider::new("http://127.0.0.1:5000/v1/audio/speech");
assert_eq!(provider.name(), "piper");
assert_eq!(provider.api_url, "http://127.0.0.1:5000/v1/audio/speech");
assert_eq!(provider.supported_formats(), vec!["mp3", "wav", "opus"]);
// Piper voices depend on installed models; list is empty.
assert!(provider.supported_voices().is_empty());
}
#[test]
fn tts_manager_with_piper_provider() {
let mut config = default_tts_config();
config.default_provider = "piper".to_string();
config.piper = Some(crate::config::PiperTtsConfig {
api_url: "http://127.0.0.1:5000/v1/audio/speech".into(),
});
let manager = TtsManager::new(&config).unwrap();
assert_eq!(manager.available_providers(), vec!["piper"]);
}
#[tokio::test]
async fn tts_rejects_empty_text_for_piper() {
let mut config = default_tts_config();
config.default_provider = "piper".to_string();
config.piper = Some(crate::config::PiperTtsConfig {
api_url: "http://127.0.0.1:5000/v1/audio/speech".into(),
});
let manager = TtsManager::new(&config).unwrap();
let err = manager
.synthesize_with_provider("", "piper", "default")
.await
.unwrap_err();
assert!(
err.to_string().contains("must not be empty"),
"expected empty-text error, got: {err}"
);
}
#[test]
fn piper_not_registered_when_config_is_none() {
let config = default_tts_config();
let manager = TtsManager::new(&config).unwrap();
assert!(!manager.available_providers().contains(&"piper".to_string()));
}
#[test]
fn tts_config_defaults() {
let config = TtsConfig::default();
@@ -664,6 +792,7 @@ mod tests {
assert!(config.elevenlabs.is_none());
assert!(config.google.is_none());
assert!(config.edge.is_none());
assert!(config.piper.is_none());
}
#[test]
+531
@@ -0,0 +1,531 @@
//! Voice Wake Word detection channel.
//!
//! Listens on the default microphone via `cpal`, detects a configurable wake
//! word using energy-based VAD followed by transcription-based keyword matching,
//! then captures the subsequent utterance and dispatches it as a channel message.
//!
//! Gated behind the `voice-wake` Cargo feature.
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use anyhow::{bail, Result};
use async_trait::async_trait;
use tokio::sync::mpsc;
use tracing::{debug, info, warn};
use crate::channels::transcription::transcribe_audio;
use crate::config::schema::VoiceWakeConfig;
use crate::config::TranscriptionConfig;
use super::traits::{Channel, ChannelMessage, SendMessage};
// ── State machine ──────────────────────────────────────────────
/// Internal states for the wake-word detector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WakeState {
/// Passively monitoring microphone energy levels.
Listening,
/// Energy spike detected — capturing a short window to check for wake word.
Triggered,
/// Wake word confirmed — capturing the full utterance that follows.
Capturing,
/// Captured audio is being transcribed.
Processing,
}
impl std::fmt::Display for WakeState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Listening => write!(f, "Listening"),
Self::Triggered => write!(f, "Triggered"),
Self::Capturing => write!(f, "Capturing"),
Self::Processing => write!(f, "Processing"),
}
}
}
// ── Channel implementation ─────────────────────────────────────
/// Voice wake-word channel that activates on a spoken keyword.
pub struct VoiceWakeChannel {
config: VoiceWakeConfig,
transcription_config: TranscriptionConfig,
}
impl VoiceWakeChannel {
/// Create a new `VoiceWakeChannel` from its config sections.
pub fn new(config: VoiceWakeConfig, transcription_config: TranscriptionConfig) -> Self {
Self {
config,
transcription_config,
}
}
}
#[async_trait]
impl Channel for VoiceWakeChannel {
fn name(&self) -> &str {
"voice_wake"
}
async fn send(&self, _message: &SendMessage) -> Result<()> {
// Voice wake is input-only; outbound messages are not supported.
Ok(())
}
async fn listen(&self, tx: mpsc::Sender<ChannelMessage>) -> Result<()> {
let config = self.config.clone();
let transcription_config = self.transcription_config.clone();
// Run the blocking audio capture loop on a dedicated thread.
let (audio_tx, mut audio_rx) = mpsc::channel::<Vec<f32>>(4);
let energy_threshold = config.energy_threshold;
let silence_timeout = Duration::from_millis(u64::from(config.silence_timeout_ms));
let max_capture = Duration::from_secs(u64::from(config.max_capture_secs));
let sample_rate: u32;
let channels_count: u16;
// ── Initialise cpal stream ────────────────────────────
{
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
let host = cpal::default_host();
let device = host
.default_input_device()
.ok_or_else(|| anyhow::anyhow!("No default audio input device available"))?;
let supported = device.default_input_config()?;
sample_rate = supported.sample_rate().0;
channels_count = supported.channels();
info!(
device = ?device.name().unwrap_or_default(),
sample_rate,
channels = channels_count,
"VoiceWake: opening audio input"
);
let stream_config: cpal::StreamConfig = supported.into();
let audio_tx_clone = audio_tx.clone();
let stream = device.build_input_stream(
&stream_config,
move |data: &[f32], _: &cpal::InputCallbackInfo| {
// Non-blocking: try_send and drop if full.
let _ = audio_tx_clone.try_send(data.to_vec());
},
move |err| {
warn!("VoiceWake: audio stream error: {err}");
},
None,
)?;
stream.play()?;
// Keep the stream alive for the lifetime of the channel.
// We leak it intentionally — the channel runs until the daemon shuts down.
std::mem::forget(stream);
}
// Drop the extra sender so the channel closes when the stream sender drops.
drop(audio_tx);
// ── Main detection loop ───────────────────────────────
let wake_word = config.wake_word.to_lowercase();
let mut state = WakeState::Listening;
let mut capture_buf: Vec<f32> = Vec::new();
let mut last_voice_at = Instant::now();
let mut capture_start = Instant::now();
let mut msg_counter: u64 = 0;
info!(wake_word = %wake_word, "VoiceWake: entering listen loop");
while let Some(chunk) = audio_rx.recv().await {
let energy = compute_rms_energy(&chunk);
match state {
WakeState::Listening => {
if energy >= energy_threshold {
debug!(
energy,
"VoiceWake: energy spike — transitioning to Triggered"
);
state = WakeState::Triggered;
capture_buf.clear();
capture_buf.extend_from_slice(&chunk);
last_voice_at = Instant::now();
capture_start = Instant::now();
}
}
WakeState::Triggered => {
capture_buf.extend_from_slice(&chunk);
if energy >= energy_threshold {
last_voice_at = Instant::now();
}
let since_voice = last_voice_at.elapsed();
let since_start = capture_start.elapsed();
// After enough silence or max time, transcribe to check for wake word.
if since_voice >= silence_timeout || since_start >= max_capture {
debug!("VoiceWake: Triggered window closed — transcribing for wake word");
let wav_bytes =
encode_wav_from_f32(&capture_buf, sample_rate, channels_count);
match transcribe_audio(wav_bytes, "wake_check.wav", &transcription_config)
.await
{
Ok(text) => {
let lower = text.to_lowercase();
if lower.contains(&wake_word) {
info!(text = %text, "VoiceWake: wake word detected — capturing utterance");
state = WakeState::Capturing;
capture_buf.clear();
last_voice_at = Instant::now();
capture_start = Instant::now();
} else {
debug!(text = %text, "VoiceWake: no wake word — back to Listening");
state = WakeState::Listening;
capture_buf.clear();
}
}
Err(e) => {
warn!("VoiceWake: transcription error during wake check: {e}");
state = WakeState::Listening;
capture_buf.clear();
}
}
}
}
WakeState::Capturing => {
capture_buf.extend_from_slice(&chunk);
if energy >= energy_threshold {
last_voice_at = Instant::now();
}
let since_voice = last_voice_at.elapsed();
let since_start = capture_start.elapsed();
if since_voice >= silence_timeout || since_start >= max_capture {
debug!("VoiceWake: utterance capture complete — transcribing");
let wav_bytes =
encode_wav_from_f32(&capture_buf, sample_rate, channels_count);
match transcribe_audio(wav_bytes, "utterance.wav", &transcription_config)
.await
{
Ok(text) => {
let trimmed = text.trim().to_string();
if !trimmed.is_empty() {
msg_counter += 1;
let ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let msg = ChannelMessage {
id: format!("voice_wake_{msg_counter}"),
sender: "voice_user".into(),
reply_target: "voice_user".into(),
content: trimmed,
channel: "voice_wake".into(),
timestamp: ts,
thread_ts: None,
interruption_scope_id: None,
};
if let Err(e) = tx.send(msg).await {
warn!("VoiceWake: failed to dispatch message: {e}");
}
}
}
Err(e) => {
warn!("VoiceWake: transcription error for utterance: {e}");
}
}
state = WakeState::Listening;
capture_buf.clear();
}
}
WakeState::Processing => {
// Transcription runs synchronously in the arms above, so this
// state is never observed here; any chunks that do arrive are dropped.
}
}
}
bail!("VoiceWake: audio stream ended unexpectedly");
}
}
// ── Audio utilities ────────────────────────────────────────────
/// Compute RMS (root-mean-square) energy of an audio chunk.
pub fn compute_rms_energy(samples: &[f32]) -> f32 {
if samples.is_empty() {
return 0.0;
}
let sum_sq: f32 = samples.iter().map(|s| s * s).sum();
(sum_sq / samples.len() as f32).sqrt()
}
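As a standalone check of the energy gate: for a constant signal the RMS equals its amplitude, and for a full-scale sine sampled over whole periods it is amplitude/√2. The sketch below copies the body of `compute_rms_energy` so it compiles on its own:

```rust
// Standalone copy of the RMS helper above, so this sketch is self-contained.
fn compute_rms_energy(samples: &[f32]) -> f32 {
    if samples.is_empty() {
        return 0.0;
    }
    let sum_sq: f32 = samples.iter().map(|s| s * s).sum();
    (sum_sq / samples.len() as f32).sqrt()
}

fn main() {
    // Constant 0.5 signal: RMS is exactly the amplitude.
    let constant = vec![0.5f32; 1024];
    assert!((compute_rms_energy(&constant) - 0.5).abs() < 1e-5);

    // Full-scale sine over 16 whole periods: RMS ≈ 1/√2 ≈ 0.7071.
    let sine: Vec<f32> = (0..1024)
        .map(|i| (i as f32 * std::f32::consts::TAU / 64.0).sin())
        .collect();
    assert!((compute_rms_energy(&sine) - std::f32::consts::FRAC_1_SQRT_2).abs() < 1e-3);

    // Both comfortably clear the default 0.01 trigger threshold.
    assert!(compute_rms_energy(&sine) >= 0.01);
    println!("ok");
}
```

Both signals sit far above the default `energy_threshold` of 0.01, which is why the quiet/loud tests further down use 0.001 and 0.8 as their extremes.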
/// Encode raw f32 PCM samples as a WAV byte buffer (16-bit PCM).
///
/// This produces a minimal valid WAV file that Whisper-compatible APIs accept.
pub fn encode_wav_from_f32(samples: &[f32], sample_rate: u32, channels: u16) -> Vec<u8> {
let bits_per_sample: u16 = 16;
let byte_rate = u32::from(channels) * sample_rate * u32::from(bits_per_sample) / 8;
let block_align = channels * bits_per_sample / 8;
#[allow(clippy::cast_possible_truncation)]
let data_len = (samples.len() * 2) as u32; // 16-bit = 2 bytes per sample; max ~25 MB
let file_len = 36 + data_len;
let mut buf = Vec::with_capacity(file_len as usize + 8);
// RIFF header
buf.extend_from_slice(b"RIFF");
buf.extend_from_slice(&file_len.to_le_bytes());
buf.extend_from_slice(b"WAVE");
// fmt chunk
buf.extend_from_slice(b"fmt ");
buf.extend_from_slice(&16u32.to_le_bytes()); // chunk size
buf.extend_from_slice(&1u16.to_le_bytes()); // PCM format
buf.extend_from_slice(&channels.to_le_bytes());
buf.extend_from_slice(&sample_rate.to_le_bytes());
buf.extend_from_slice(&byte_rate.to_le_bytes());
buf.extend_from_slice(&block_align.to_le_bytes());
buf.extend_from_slice(&bits_per_sample.to_le_bytes());
// data chunk
buf.extend_from_slice(b"data");
buf.extend_from_slice(&data_len.to_le_bytes());
for &sample in samples {
let clamped = sample.clamp(-1.0, 1.0);
#[allow(clippy::cast_possible_truncation)]
let pcm16 = (clamped * 32767.0) as i16; // clamped to [-1,1] so fits i16
buf.extend_from_slice(&pcm16.to_le_bytes());
}
buf
}
// ── Tests ──────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use crate::config::traits::ChannelConfig;
// ── State machine tests ────────────────────────────────
#[test]
fn wake_state_display() {
assert_eq!(WakeState::Listening.to_string(), "Listening");
assert_eq!(WakeState::Triggered.to_string(), "Triggered");
assert_eq!(WakeState::Capturing.to_string(), "Capturing");
assert_eq!(WakeState::Processing.to_string(), "Processing");
}
#[test]
fn wake_state_equality() {
assert_eq!(WakeState::Listening, WakeState::Listening);
assert_ne!(WakeState::Listening, WakeState::Triggered);
}
// ── Energy computation tests ───────────────────────────
#[test]
fn rms_energy_of_silence_is_zero() {
let silence = vec![0.0f32; 1024];
assert_eq!(compute_rms_energy(&silence), 0.0);
}
#[test]
fn rms_energy_of_empty_is_zero() {
assert_eq!(compute_rms_energy(&[]), 0.0);
}
#[test]
fn rms_energy_of_constant_signal() {
// Constant signal at 0.5 → RMS should be 0.5
let signal = vec![0.5f32; 100];
let energy = compute_rms_energy(&signal);
assert!((energy - 0.5).abs() < 1e-5);
}
#[test]
fn rms_energy_above_threshold() {
let loud = vec![0.8f32; 256];
let energy = compute_rms_energy(&loud);
assert!(energy > 0.01, "Loud signal should exceed default threshold");
}
#[test]
fn rms_energy_below_threshold_for_quiet() {
let quiet = vec![0.001f32; 256];
let energy = compute_rms_energy(&quiet);
assert!(
energy < 0.01,
"Very quiet signal should be below default threshold"
);
}
// ── WAV encoding tests ─────────────────────────────────
#[test]
fn wav_header_is_valid() {
let samples = vec![0.0f32; 100];
let wav = encode_wav_from_f32(&samples, 16000, 1);
// RIFF header
assert_eq!(&wav[0..4], b"RIFF");
assert_eq!(&wav[8..12], b"WAVE");
// fmt chunk
assert_eq!(&wav[12..16], b"fmt ");
let fmt_size = u32::from_le_bytes(wav[16..20].try_into().unwrap());
assert_eq!(fmt_size, 16);
// PCM format
let format = u16::from_le_bytes(wav[20..22].try_into().unwrap());
assert_eq!(format, 1);
// Channels
let channels = u16::from_le_bytes(wav[22..24].try_into().unwrap());
assert_eq!(channels, 1);
// Sample rate
let sr = u32::from_le_bytes(wav[24..28].try_into().unwrap());
assert_eq!(sr, 16000);
// data chunk
assert_eq!(&wav[36..40], b"data");
let data_size = u32::from_le_bytes(wav[40..44].try_into().unwrap());
assert_eq!(data_size, 200); // 100 samples * 2 bytes each
}
#[test]
fn wav_total_size_correct() {
let samples = vec![0.0f32; 50];
let wav = encode_wav_from_f32(&samples, 44100, 2);
// header (44 bytes) + data (50 * 2 = 100 bytes)
assert_eq!(wav.len(), 144);
}
#[test]
fn wav_encodes_clipped_samples() {
// Samples outside [-1, 1] should be clamped
let samples = vec![-2.0f32, 2.0, 0.0];
let wav = encode_wav_from_f32(&samples, 16000, 1);
let s0 = i16::from_le_bytes(wav[44..46].try_into().unwrap());
let s1 = i16::from_le_bytes(wav[46..48].try_into().unwrap());
let s2 = i16::from_le_bytes(wav[48..50].try_into().unwrap());
assert_eq!(s0, -32767); // clamped to -1.0
assert_eq!(s1, 32767); // clamped to 1.0
assert_eq!(s2, 0);
}
// ── Config parsing tests ───────────────────────────────
#[test]
fn voice_wake_config_defaults() {
let config = VoiceWakeConfig::default();
assert_eq!(config.wake_word, "hey zeroclaw");
assert_eq!(config.silence_timeout_ms, 2000);
assert!((config.energy_threshold - 0.01).abs() < f32::EPSILON);
assert_eq!(config.max_capture_secs, 30);
}
#[test]
fn voice_wake_config_deserialize_partial() {
let toml_str = r#"
wake_word = "okay agent"
max_capture_secs = 60
"#;
let config: VoiceWakeConfig = toml::from_str(toml_str).unwrap();
assert_eq!(config.wake_word, "okay agent");
assert_eq!(config.max_capture_secs, 60);
// Defaults preserved for unset fields
assert_eq!(config.silence_timeout_ms, 2000);
assert!((config.energy_threshold - 0.01).abs() < f32::EPSILON);
}
#[test]
fn voice_wake_config_deserialize_all_fields() {
let toml_str = r#"
wake_word = "hello bot"
silence_timeout_ms = 3000
energy_threshold = 0.05
max_capture_secs = 15
"#;
let config: VoiceWakeConfig = toml::from_str(toml_str).unwrap();
assert_eq!(config.wake_word, "hello bot");
assert_eq!(config.silence_timeout_ms, 3000);
assert!((config.energy_threshold - 0.05).abs() < f32::EPSILON);
assert_eq!(config.max_capture_secs, 15);
}
#[test]
fn voice_wake_config_channel_config_trait() {
assert_eq!(VoiceWakeConfig::name(), "VoiceWake");
assert_eq!(VoiceWakeConfig::desc(), "voice wake word detection");
}
// ── State transition logic tests ───────────────────────
#[test]
fn energy_threshold_determines_trigger() {
let threshold = 0.01f32;
let quiet_energy = compute_rms_energy(&vec![0.005f32; 256]);
let loud_energy = compute_rms_energy(&vec![0.5f32; 256]);
assert!(quiet_energy < threshold, "Quiet should not trigger");
assert!(loud_energy >= threshold, "Loud should trigger");
}
#[test]
fn state_transitions_are_deterministic() {
// Verify that the state enum values are distinct and copyable
let states = [
WakeState::Listening,
WakeState::Triggered,
WakeState::Capturing,
WakeState::Processing,
];
for (i, a) in states.iter().enumerate() {
for (j, b) in states.iter().enumerate() {
if i == j {
assert_eq!(a, b);
} else {
assert_ne!(a, b);
}
}
}
}
#[test]
fn channel_config_impl() {
// VoiceWakeConfig implements ChannelConfig
assert_eq!(VoiceWakeConfig::name(), "VoiceWake");
assert!(!VoiceWakeConfig::desc().is_empty());
}
#[test]
fn voice_wake_channel_name() {
let config = VoiceWakeConfig::default();
let transcription_config = TranscriptionConfig::default();
let channel = VoiceWakeChannel::new(config, transcription_config);
assert_eq!(channel.name(), "voice_wake");
}
}
@@ -10,26 +10,28 @@ pub use schema::{
AgentConfig, AssemblyAiSttConfig, AuditConfig, AutonomyConfig, BackupConfig,
BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
ClassificationRule, ClaudeCodeConfig, CloudOpsConfig, ComposioConfig, Config,
ConversationalAiConfig, CostConfig, CronConfig, DataRetentionConfig, DeepgramSttConfig,
DelegateAgentConfig, DelegateToolConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig,
ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation, GoogleWorkspaceConfig,
HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
IMessageConfig, IdentityConfig, ImageProviderDalleConfig, ImageProviderFluxConfig,
ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig, KnowledgeConfig,
LarkConfig, LinkedInConfig, LinkedInContentConfig, LinkedInImageConfig, LocalWhisperConfig,
MatrixConfig, McpConfig, McpServerConfig, McpTransport, MemoryConfig, Microsoft365Config,
ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig,
NotionConfig, ObservabilityConfig, OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig,
OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig, PluginsConfig,
ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig,
ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig,
SchedulerConfig, SecretsConfig, SecurityConfig, SecurityOpsConfig, SkillCreationConfig,
SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig, StorageProviderConfig,
StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy, TelegramConfig,
TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig,
TunnelConfig, VerifiableIntentConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig, DEFAULT_GWS_SERVICES,
ConversationalAiConfig, CostConfig, CronConfig, CronJobDecl, CronScheduleDecl,
DataRetentionConfig, DeepgramSttConfig, DelegateAgentConfig, DelegateToolConfig, DiscordConfig,
DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig,
FeishuConfig, GatewayConfig, GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceAllowedOperation,
GoogleWorkspaceConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig,
HttpRequestConfig, IMessageConfig, IdentityConfig, ImageGenConfig, ImageProviderDalleConfig,
ImageProviderFluxConfig, ImageProviderImagenConfig, ImageProviderStabilityConfig, JiraConfig,
KnowledgeConfig, LarkConfig, LinkEnricherConfig, LinkedInConfig, LinkedInContentConfig,
LinkedInImageConfig, LocalWhisperConfig, MatrixConfig, McpConfig, McpServerConfig,
McpTransport, MemoryConfig, MemoryPolicyConfig, Microsoft365Config, ModelRouteConfig,
MultimodalConfig, NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig,
ObservabilityConfig, OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig,
OtpMethod, PacingConfig, PeripheralBoardConfig, PeripheralsConfig, PiperTtsConfig,
PluginsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig,
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
SecurityOpsConfig, SkillCreationConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig,
StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig,
SwarmStrategy, TelegramConfig, TextBrowserConfig, ToolFilterGroup, ToolFilterGroupMode,
TranscriptionConfig, TtsConfig, TunnelConfig, VerifiableIntentConfig, WebFetchConfig,
WebSearchConfig, WebhookConfig, WhatsAppChatPolicy, WhatsAppWebMode, WorkspaceConfig,
DEFAULT_GWS_SERVICES,
};
pub fn name_and_presence<T: traits::ChannelConfig>(channel: Option<&T>) -> (&'static str, bool) {
File diff suppressed because it is too large.
@@ -15,7 +15,7 @@ pub use schedule::{
#[allow(unused_imports)]
pub use store::{
add_agent_job, all_overdue_jobs, due_jobs, get_job, list_jobs, list_runs, record_last_run,
record_run, remove_job, reschedule_after_run, update_job,
record_run, remove_job, reschedule_after_run, sync_declarative_jobs, update_job,
};
pub use types::{
deserialize_maybe_stringified, CronJob, CronJobPatch, CronRun, DeliveryConfig, JobType,
@@ -1,5 +1,7 @@
#[cfg(feature = "channel-matrix")]
use crate::channels::MatrixChannel;
#[cfg(feature = "whatsapp-web")]
use crate::channels::WhatsAppWebChannel;
use crate::channels::{
Channel, DiscordChannel, MattermostChannel, QQChannel, SendMessage, SignalChannel,
SlackChannel, TelegramChannel,
@@ -7,8 +9,8 @@ use crate::channels::{
use crate::config::Config;
use crate::cron::{
all_overdue_jobs, due_jobs, next_run_for_schedule, record_last_run, record_run, remove_job,
reschedule_after_run, update_job, CronJob, CronJobPatch, DeliveryConfig, JobType, Schedule,
SessionTarget,
reschedule_after_run, sync_declarative_jobs, update_job, CronJob, CronJobPatch, DeliveryConfig,
JobType, Schedule, SessionTarget,
};
use crate::security::SecurityPolicy;
use anyhow::Result;
@@ -34,6 +36,19 @@ pub async fn run(config: Config) -> Result<()> {
crate::health::mark_component_ok(SCHEDULER_COMPONENT);
// ── Declarative job sync: reconcile config-defined jobs with the DB.
match sync_declarative_jobs(&config, &config.cron.jobs) {
Ok(()) => {
if !config.cron.jobs.is_empty() {
tracing::info!(
count = config.cron.jobs.len(),
"Synced declarative cron jobs from config"
);
}
}
Err(e) => tracing::warn!("Failed to sync declarative cron jobs: {e}"),
}
// ── Startup catch-up: run ALL overdue jobs before entering the
// normal polling loop. The regular loop is capped by `max_tasks`,
// which could leave some overdue jobs waiting across many cycles
@@ -483,6 +498,36 @@ pub(crate) async fn deliver_announcement(
anyhow::bail!("matrix delivery channel requires `channel-matrix` feature");
}
}
"whatsapp" | "whatsapp-web" | "whatsapp_web" => {
#[cfg(feature = "whatsapp-web")]
{
let wa = config
.channels_config
.whatsapp
.as_ref()
.ok_or_else(|| anyhow::anyhow!("whatsapp channel not configured"))?;
if !wa.is_web_config() {
anyhow::bail!(
"whatsapp cron delivery requires Web mode (session_path must be set)"
);
}
let channel = WhatsAppWebChannel::new(
wa.session_path.clone().unwrap_or_default(),
wa.pair_phone.clone(),
wa.pair_code.clone(),
wa.allowed_numbers.clone(),
wa.mode.clone(),
wa.dm_policy.clone(),
wa.group_policy.clone(),
wa.self_chat_mode,
);
channel.send(&SendMessage::new(output, target)).await?;
}
#[cfg(not(feature = "whatsapp-web"))]
{
anyhow::bail!("whatsapp delivery channel requires `whatsapp-web` feature");
}
}
"qq" => {
let qq = config
.channels_config
@@ -657,6 +702,7 @@ mod tests {
delivery: DeliveryConfig::default(),
delete_after_run: false,
allowed_tools: None,
source: "imperative".into(),
created_at: Utc::now(),
next_run: Utc::now(),
last_run: None,
@@ -124,7 +124,7 @@ pub fn list_jobs(config: &Config) -> Result<Vec<CronJob>> {
let mut stmt = conn.prepare(
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
allowed_tools
allowed_tools, source
FROM cron_jobs ORDER BY next_run ASC",
)?;
@@ -143,7 +143,7 @@ pub fn get_job(config: &Config, job_id: &str) -> Result<CronJob> {
let mut stmt = conn.prepare(
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
allowed_tools
allowed_tools, source
FROM cron_jobs WHERE id = ?1",
)?;
@@ -177,7 +177,7 @@ pub fn due_jobs(config: &Config, now: DateTime<Utc>) -> Result<Vec<CronJob>> {
let mut stmt = conn.prepare(
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
allowed_tools
allowed_tools, source
FROM cron_jobs
WHERE enabled = 1 AND next_run <= ?1
ORDER BY next_run ASC
@@ -206,7 +206,8 @@ pub fn all_overdue_jobs(config: &Config, now: DateTime<Utc>) -> Result<Vec<CronJ
with_connection(config, |conn| {
let mut stmt = conn.prepare(
"SELECT id, expression, command, schedule, job_type, prompt, name, session_target, model,
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output, allowed_tools
enabled, delivery, delete_after_run, created_at, next_run, last_run, last_status, last_output,
allowed_tools, source
FROM cron_jobs
WHERE enabled = 1 AND next_run <= ?1
ORDER BY next_run ASC",
@@ -488,6 +489,7 @@ fn map_cron_job_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<CronJob> {
let last_run_raw: Option<String> = row.get(14)?;
let created_at_raw: String = row.get(12)?;
let allowed_tools_raw: Option<String> = row.get(17)?;
let source: Option<String> = row.get(18)?;
Ok(CronJob {
id: row.get(0)?,
@@ -502,6 +504,7 @@ fn map_cron_job_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<CronJob> {
enabled: row.get::<_, i64>(9)? != 0,
delivery,
delete_after_run: row.get::<_, i64>(11)? != 0,
source: source.unwrap_or_else(|| "imperative".to_string()),
created_at: parse_rfc3339(&created_at_raw).map_err(sql_conversion_error)?,
next_run: parse_rfc3339(&next_run_raw).map_err(sql_conversion_error)?,
last_run: match last_run_raw {
@@ -564,6 +567,277 @@ fn decode_allowed_tools(raw: Option<&str>) -> Result<Option<Vec<String>>> {
Ok(None)
}
/// Synchronize declarative cron job definitions from config into the database.
///
/// For each declarative job (identified by `id`):
/// - If the job exists in DB: update it to match the config definition.
/// - If the job does not exist: insert it.
///
/// Jobs created imperatively (via CLI/API) are never modified or deleted.
/// Declarative jobs that are no longer present in config are removed.
pub fn sync_declarative_jobs(
config: &Config,
decls: &[crate::config::schema::CronJobDecl],
) -> Result<()> {
use crate::config::schema::CronScheduleDecl;
if decls.is_empty() {
// If no declarative jobs are defined, clean up any previously
// synced declarative jobs that are no longer in config.
with_connection(config, |conn| {
let deleted = conn
.execute("DELETE FROM cron_jobs WHERE source = 'declarative'", [])
.context("Failed to remove stale declarative cron jobs")?;
if deleted > 0 {
tracing::info!(
count = deleted,
"Removed declarative cron jobs no longer in config"
);
}
Ok(())
})?;
return Ok(());
}
// Validate declarations before touching the DB.
for decl in decls {
validate_decl(decl)?;
}
let now = Utc::now();
with_connection(config, |conn| {
// Collect IDs of all declarative jobs currently defined in config.
let config_ids: std::collections::HashSet<&str> =
decls.iter().map(|d| d.id.as_str()).collect();
// Remove declarative jobs no longer in config.
{
let mut stmt = conn.prepare("SELECT id FROM cron_jobs WHERE source = 'declarative'")?;
let db_ids: Vec<String> = stmt
.query_map([], |row| row.get(0))?
.filter_map(|r| r.ok())
.collect();
for db_id in &db_ids {
if !config_ids.contains(db_id.as_str()) {
conn.execute("DELETE FROM cron_jobs WHERE id = ?1", params![db_id])
.with_context(|| {
format!("Failed to remove stale declarative cron job '{db_id}'")
})?;
tracing::info!(
job_id = %db_id,
"Removed declarative cron job no longer in config"
);
}
}
}
for decl in decls {
let schedule = convert_schedule_decl(&decl.schedule)?;
let expression = schedule_cron_expression(&schedule).unwrap_or_default();
let schedule_json = serde_json::to_string(&schedule)?;
let job_type = &decl.job_type;
let session_target = decl.session_target.as_deref().unwrap_or("isolated");
let delivery = match &decl.delivery {
Some(d) => convert_delivery_decl(d),
None => DeliveryConfig::default(),
};
let delivery_json = serde_json::to_string(&delivery)?;
let allowed_tools_json = encode_allowed_tools(decl.allowed_tools.as_ref())?;
let command = decl.command.as_deref().unwrap_or("");
let delete_after_run = matches!(decl.schedule, CronScheduleDecl::At { .. });
// Check if job already exists.
let exists: bool = conn
.prepare("SELECT COUNT(*) FROM cron_jobs WHERE id = ?1")?
.query_row(params![decl.id], |row| row.get::<_, i64>(0))
.map(|c| c > 0)
.unwrap_or(false);
if exists {
// Update existing declarative job — preserve runtime state
// (next_run, last_run, last_status, last_output, created_at).
// Only update the schedule's next_run if the schedule itself changed.
let current_schedule_raw: Option<String> = conn
.prepare("SELECT schedule FROM cron_jobs WHERE id = ?1")?
.query_row(params![decl.id], |row| row.get(0))
.ok();
let schedule_changed = current_schedule_raw.as_deref() != Some(&schedule_json);
if schedule_changed {
let next_run = next_run_for_schedule(&schedule, now)?;
conn.execute(
"UPDATE cron_jobs
SET expression = ?1, command = ?2, schedule = ?3, job_type = ?4,
prompt = ?5, name = ?6, session_target = ?7, model = ?8,
enabled = ?9, delivery = ?10, delete_after_run = ?11,
allowed_tools = ?12, source = 'declarative', next_run = ?13
WHERE id = ?14",
params![
expression,
command,
schedule_json,
job_type,
decl.prompt,
decl.name,
session_target,
decl.model,
if decl.enabled { 1 } else { 0 },
delivery_json,
if delete_after_run { 1 } else { 0 },
allowed_tools_json,
next_run.to_rfc3339(),
decl.id,
],
)
.with_context(|| {
format!("Failed to update declarative cron job '{}'", decl.id)
})?;
} else {
conn.execute(
"UPDATE cron_jobs
SET expression = ?1, command = ?2, schedule = ?3, job_type = ?4,
prompt = ?5, name = ?6, session_target = ?7, model = ?8,
enabled = ?9, delivery = ?10, delete_after_run = ?11,
allowed_tools = ?12, source = 'declarative'
WHERE id = ?13",
params![
expression,
command,
schedule_json,
job_type,
decl.prompt,
decl.name,
session_target,
decl.model,
if decl.enabled { 1 } else { 0 },
delivery_json,
if delete_after_run { 1 } else { 0 },
allowed_tools_json,
decl.id,
],
)
.with_context(|| {
format!("Failed to update declarative cron job '{}'", decl.id)
})?;
}
tracing::debug!(job_id = %decl.id, "Updated declarative cron job");
} else {
// Insert new declarative job.
let next_run = next_run_for_schedule(&schedule, now)?;
conn.execute(
"INSERT INTO cron_jobs (
id, expression, command, schedule, job_type, prompt, name,
session_target, model, enabled, delivery, delete_after_run,
allowed_tools, source, created_at, next_run
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, 'declarative', ?14, ?15)",
params![
decl.id,
expression,
command,
schedule_json,
job_type,
decl.prompt,
decl.name,
session_target,
decl.model,
if decl.enabled { 1 } else { 0 },
delivery_json,
if delete_after_run { 1 } else { 0 },
allowed_tools_json,
now.to_rfc3339(),
next_run.to_rfc3339(),
],
)
.with_context(|| {
format!(
"Failed to insert declarative cron job '{}'",
decl.id
)
})?;
tracing::info!(job_id = %decl.id, "Inserted declarative cron job from config");
}
}
Ok(())
})
}
/// Validate a declarative cron job definition.
fn validate_decl(decl: &crate::config::schema::CronJobDecl) -> Result<()> {
if decl.id.trim().is_empty() {
anyhow::bail!("Declarative cron job has empty id");
}
match decl.job_type.to_lowercase().as_str() {
"shell" => {
if decl
.command
.as_deref()
.map_or(true, |c| c.trim().is_empty())
{
anyhow::bail!(
"Declarative cron job '{}': shell job requires a non-empty 'command'",
decl.id
);
}
}
"agent" => {
if decl.prompt.as_deref().map_or(true, |p| p.trim().is_empty()) {
anyhow::bail!(
"Declarative cron job '{}': agent job requires a non-empty 'prompt'",
decl.id
);
}
}
other => {
anyhow::bail!(
"Declarative cron job '{}': invalid job_type '{}', expected 'shell' or 'agent'",
decl.id,
other
);
}
}
Ok(())
}
/// Convert a `CronScheduleDecl` to the runtime `Schedule` type.
fn convert_schedule_decl(decl: &crate::config::schema::CronScheduleDecl) -> Result<Schedule> {
use crate::config::schema::CronScheduleDecl;
match decl {
CronScheduleDecl::Cron { expr, tz } => Ok(Schedule::Cron {
expr: expr.clone(),
tz: tz.clone(),
}),
CronScheduleDecl::Every { every_ms } => Ok(Schedule::Every {
every_ms: *every_ms,
}),
CronScheduleDecl::At { at } => {
let parsed = DateTime::parse_from_rfc3339(at)
.with_context(|| {
format!("Invalid RFC3339 timestamp in declarative cron 'at': {at}")
})?
.with_timezone(&Utc);
Ok(Schedule::At { at: parsed })
}
}
}
/// Convert a `DeliveryConfigDecl` to the runtime `DeliveryConfig`.
fn convert_delivery_decl(decl: &crate::config::schema::DeliveryConfigDecl) -> DeliveryConfig {
DeliveryConfig {
mode: decl.mode.clone(),
channel: decl.channel.clone(),
to: decl.to.clone(),
best_effort: decl.best_effort,
}
}
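For reference, the declarative jobs these helpers reconcile would be written in the TOML config. The sketch below is hypothetical: the field names follow `CronJobDecl`/`CronScheduleDecl` as used in the tests further down, but the exact table layout (in particular how the schedule variants `expr`/`every_ms`/`at` are represented) is an assumption not confirmed by this diff:

```toml
# Hypothetical sketch — shape inferred from CronJobDecl; the serde
# representation of `schedule` is assumed, not verified.
[[cron.jobs]]
id = "daily-backup"
name = "decl-daily-backup"
job_type = "shell"
command = "echo backup"
enabled = true
schedule = { expr = "0 2 * * *" }

[[cron.jobs]]
id = "morning-brief"
job_type = "agent"
prompt = "Summarize overnight alerts"
# One-shot RFC3339 schedule; sync marks such jobs delete_after_run.
schedule = { at = "2026-04-01T07:00:00Z" }
```

Jobs with these ids get `source = 'declarative'` in the DB and are inserted, updated, or removed to track the config, while CLI/API-created (`imperative`) jobs are left alone.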
fn add_column_if_missing(conn: &Connection, name: &str, sql_type: &str) -> Result<()> {
let mut stmt = conn.prepare("PRAGMA table_info(cron_jobs)")?;
let mut rows = stmt.query([])?;
@@ -654,6 +928,7 @@ fn with_connection<T>(config: &Config, f: impl FnOnce(&Connection) -> Result<T>)
add_column_if_missing(&conn, "delivery", "TEXT")?;
add_column_if_missing(&conn, "delete_after_run", "INTEGER NOT NULL DEFAULT 0")?;
add_column_if_missing(&conn, "allowed_tools", "TEXT")?;
add_column_if_missing(&conn, "source", "TEXT DEFAULT 'imperative'")?;
f(&conn)
}
@@ -1170,4 +1445,246 @@ mod tests {
assert!(last_output.ends_with(TRUNCATED_OUTPUT_MARKER));
assert!(last_output.len() <= MAX_CRON_OUTPUT_BYTES);
}
// ── Declarative cron job sync tests ──────────────────────────
fn make_shell_decl(id: &str, expr: &str, cmd: &str) -> crate::config::schema::CronJobDecl {
crate::config::schema::CronJobDecl {
id: id.to_string(),
name: Some(format!("decl-{id}")),
job_type: "shell".to_string(),
schedule: crate::config::schema::CronScheduleDecl::Cron {
expr: expr.to_string(),
tz: None,
},
command: Some(cmd.to_string()),
prompt: None,
enabled: true,
model: None,
allowed_tools: None,
session_target: None,
delivery: None,
}
}
fn make_agent_decl(id: &str, expr: &str, prompt: &str) -> crate::config::schema::CronJobDecl {
crate::config::schema::CronJobDecl {
id: id.to_string(),
name: Some(format!("decl-{id}")),
job_type: "agent".to_string(),
schedule: crate::config::schema::CronScheduleDecl::Cron {
expr: expr.to_string(),
tz: None,
},
command: None,
prompt: Some(prompt.to_string()),
enabled: true,
model: None,
allowed_tools: None,
session_target: None,
delivery: None,
}
}
#[test]
fn sync_inserts_new_declarative_job() {
let tmp = TempDir::new().unwrap();
let config = test_config(&tmp);
let decls = vec![make_shell_decl("daily-backup", "0 2 * * *", "echo backup")];
sync_declarative_jobs(&config, &decls).unwrap();
let job = get_job(&config, "daily-backup").unwrap();
assert_eq!(job.command, "echo backup");
assert_eq!(job.source, "declarative");
assert_eq!(job.name.as_deref(), Some("decl-daily-backup"));
}
#[test]
fn sync_updates_existing_declarative_job() {
let tmp = TempDir::new().unwrap();
let config = test_config(&tmp);
let decls = vec![make_shell_decl("updatable", "0 2 * * *", "echo v1")];
sync_declarative_jobs(&config, &decls).unwrap();
let job_v1 = get_job(&config, "updatable").unwrap();
assert_eq!(job_v1.command, "echo v1");
let decls_v2 = vec![make_shell_decl("updatable", "0 3 * * *", "echo v2")];
sync_declarative_jobs(&config, &decls_v2).unwrap();
let job_v2 = get_job(&config, "updatable").unwrap();
assert_eq!(job_v2.command, "echo v2");
assert_eq!(job_v2.expression, "0 3 * * *");
assert_eq!(job_v2.source, "declarative");
}
#[test]
fn sync_does_not_delete_imperative_jobs() {
let tmp = TempDir::new().unwrap();
let config = test_config(&tmp);
// Create an imperative job via the normal API.
let imperative = add_job(&config, "*/10 * * * *", "echo imperative").unwrap();
// Sync declarative jobs (none of which match the imperative job).
let decls = vec![make_shell_decl("my-decl", "0 2 * * *", "echo decl")];
sync_declarative_jobs(&config, &decls).unwrap();
// Imperative job should still exist.
let still_there = get_job(&config, &imperative.id).unwrap();
assert_eq!(still_there.command, "echo imperative");
assert_eq!(still_there.source, "imperative");
// Declarative job should also exist.
let decl_job = get_job(&config, "my-decl").unwrap();
assert_eq!(decl_job.command, "echo decl");
}
#[test]
fn sync_removes_stale_declarative_jobs() {
let tmp = TempDir::new().unwrap();
let config = test_config(&tmp);
// Insert two declarative jobs.
let decls = vec![
make_shell_decl("keeper", "0 2 * * *", "echo keep"),
make_shell_decl("stale", "0 3 * * *", "echo stale"),
];
sync_declarative_jobs(&config, &decls).unwrap();
// Now sync with only "keeper" — "stale" should be removed.
let decls_v2 = vec![make_shell_decl("keeper", "0 2 * * *", "echo keep")];
sync_declarative_jobs(&config, &decls_v2).unwrap();
assert!(get_job(&config, "stale").is_err());
assert!(get_job(&config, "keeper").is_ok());
}
#[test]
fn sync_empty_removes_all_declarative_jobs() {
let tmp = TempDir::new().unwrap();
let config = test_config(&tmp);
let decls = vec![make_shell_decl("to-remove", "0 2 * * *", "echo bye")];
sync_declarative_jobs(&config, &decls).unwrap();
assert!(get_job(&config, "to-remove").is_ok());
// Sync with empty list.
sync_declarative_jobs(&config, &[]).unwrap();
assert!(get_job(&config, "to-remove").is_err());
}
#[test]
fn sync_validates_shell_job_requires_command() {
let tmp = TempDir::new().unwrap();
let config = test_config(&tmp);
let mut decl = make_shell_decl("bad", "0 2 * * *", "echo ok");
decl.command = None;
let result = sync_declarative_jobs(&config, &[decl]);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("command"));
}
#[test]
fn sync_validates_agent_job_requires_prompt() {
let tmp = TempDir::new().unwrap();
let config = test_config(&tmp);
let mut decl = make_agent_decl("bad-agent", "0 2 * * *", "do stuff");
decl.prompt = None;
let result = sync_declarative_jobs(&config, &[decl]);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("prompt"));
}
#[test]
fn sync_agent_job_inserts_correctly() {
let tmp = TempDir::new().unwrap();
let config = test_config(&tmp);
let decls = vec![make_agent_decl(
"agent-check",
"*/15 * * * *",
"check health",
)];
sync_declarative_jobs(&config, &decls).unwrap();
let job = get_job(&config, "agent-check").unwrap();
assert_eq!(job.job_type, JobType::Agent);
assert_eq!(job.prompt.as_deref(), Some("check health"));
assert_eq!(job.source, "declarative");
}
#[test]
fn sync_every_schedule_works() {
let tmp = TempDir::new().unwrap();
let config = test_config(&tmp);
let decl = crate::config::schema::CronJobDecl {
id: "interval-job".to_string(),
name: None,
job_type: "shell".to_string(),
schedule: crate::config::schema::CronScheduleDecl::Every { every_ms: 60000 },
command: Some("echo interval".to_string()),
prompt: None,
enabled: true,
model: None,
allowed_tools: None,
session_target: None,
delivery: None,
};
sync_declarative_jobs(&config, &[decl]).unwrap();
let job = get_job(&config, "interval-job").unwrap();
assert!(matches!(job.schedule, Schedule::Every { every_ms: 60000 }));
assert_eq!(job.command, "echo interval");
}
#[test]
fn declarative_config_parses_from_toml() {
let toml_str = r#"
enabled = true
[[jobs]]
id = "daily-report"
name = "Daily Report"
job_type = "shell"
command = "echo report"
schedule = { kind = "cron", expr = "0 9 * * *" }
[[jobs]]
id = "health-check"
job_type = "agent"
prompt = "Check server health"
schedule = { kind = "every", every_ms = 300000 }
"#;
let parsed: crate::config::schema::CronConfig = toml::from_str(toml_str).unwrap();
assert!(parsed.enabled);
assert_eq!(parsed.jobs.len(), 2);
assert_eq!(parsed.jobs[0].id, "daily-report");
assert_eq!(parsed.jobs[0].command.as_deref(), Some("echo report"));
assert!(matches!(
parsed.jobs[0].schedule,
crate::config::schema::CronScheduleDecl::Cron { ref expr, .. } if expr == "0 9 * * *"
));
assert_eq!(parsed.jobs[1].id, "health-check");
assert_eq!(parsed.jobs[1].job_type, "agent");
assert_eq!(
parsed.jobs[1].prompt.as_deref(),
Some("Check server health")
);
assert!(matches!(
parsed.jobs[1].schedule,
crate::config::schema::CronScheduleDecl::Every { every_ms: 300_000 }
));
}
}
@@ -127,6 +127,10 @@ fn default_true() -> bool {
true
}
fn default_source() -> String {
"imperative".to_string()
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CronJob {
pub id: String,
@@ -146,6 +150,9 @@ pub struct CronJob {
/// When `None`, all tools are available (backward compatible default).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub allowed_tools: Option<Vec<String>>,
/// How the job was created: `"imperative"` (CLI/API) or `"declarative"` (config).
#[serde(default = "default_source")]
pub source: String,
pub created_at: DateTime<Utc>,
pub next_run: DateTime<Utc>,
pub last_run: Option<DateTime<Utc>>,
@@ -362,10 +362,22 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> {
};
// ── Phase 2: Execute selected tasks ─────────────────────
// Re-read session context on every tick so we pick up messages
// that arrived since the daemon started.
let session_context = if config.heartbeat.load_session_context {
load_heartbeat_session_context(&config)
} else {
None
};
let mut tick_had_error = false;
for task in &tasks_to_run {
let task_start = std::time::Instant::now();
let task_prompt = format!("[Heartbeat Task | {}] {}", task.priority, task.text);
let prompt = match &session_context {
Some(ctx) => format!("{ctx}\n\n{task_prompt}"),
None => task_prompt,
};
let temp = config.default_temperature;
match Box::pin(crate::agent::run(
config.clone(),
@@ -497,6 +509,186 @@ fn resolve_heartbeat_delivery(config: &Config) -> Result<Option<(String, String)
}
}
const HEARTBEAT_SESSION_CONTEXT_MESSAGES: usize = 20;
/// Load recent conversation history for the heartbeat's delivery target and
/// format it as a text preamble to inject into the task prompt.
///
/// Scans `{workspace}/sessions/` for JSONL files whose name starts with
/// `{channel}_` and ends with `_{to}.jsonl` (or exactly `{channel}_{to}.jsonl`),
/// then picks the most recently modified match. This handles session key
/// formats such as `telegram_diskiller.jsonl` and
/// `telegram_5673725398_diskiller.jsonl`.
/// Returns `None` when `target`/`to` are not configured or no session exists.
fn load_heartbeat_session_context(config: &Config) -> Option<String> {
use crate::providers::traits::ChatMessage;
let channel = config
.heartbeat
.target
.as_deref()
.map(str::trim)
.filter(|v| !v.is_empty())?;
let to = config
.heartbeat
.to
.as_deref()
.map(str::trim)
.filter(|v| !v.is_empty())?;
if channel.contains('/') || channel.contains('\\') || to.contains('/') || to.contains('\\') {
tracing::warn!("heartbeat session context: channel/to contains path separators, skipping");
return None;
}
let sessions_dir = config.workspace_dir.join("sessions");
// Find the most recently modified JSONL file that belongs to this target.
// Matches both `{channel}_{to}.jsonl` and `{channel}_{anything}_{to}.jsonl`.
let prefix = format!("{channel}_");
let suffix = format!("_{to}.jsonl");
let exact = format!("{channel}_{to}.jsonl");
let mid_prefix = format!("{channel}_{to}_");
let path = std::fs::read_dir(&sessions_dir)
.ok()?
.filter_map(|e| e.ok())
.filter(|e| {
let name = e.file_name();
let name = name.to_string_lossy();
name.ends_with(".jsonl")
&& (name == exact
|| (name.starts_with(&prefix) && name.ends_with(&suffix))
|| name.starts_with(&mid_prefix))
})
.max_by_key(|e| {
e.metadata()
.and_then(|m| m.modified())
.unwrap_or(std::time::SystemTime::UNIX_EPOCH)
})
.map(|e| e.path())?;
if !path.exists() {
tracing::debug!("💓 Heartbeat session context: no session file found for {channel}/{to}");
return None;
}
let messages = load_jsonl_messages(&path);
if messages.is_empty() {
return None;
}
let recent: Vec<&ChatMessage> = messages
.iter()
.filter(|m| m.role == "user" || m.role == "assistant")
.rev()
.take(HEARTBEAT_SESSION_CONTEXT_MESSAGES)
.collect::<Vec<_>>()
.into_iter()
.rev()
.collect();
// Only inject context if there is at least one real user message in the
// window. If the JSONL contains only assistant messages (e.g. previous
// heartbeat outputs with no reply yet), skip context to avoid feeding
// Monika's own messages back to her in a loop.
let has_user_message = recent.iter().any(|m| m.role == "user");
if !has_user_message {
tracing::debug!(
"💓 Heartbeat session context: no user messages in recent history — skipping"
);
return None;
}
// Use the session file's mtime as a proxy for when the last message arrived.
let last_message_age = std::fs::metadata(&path)
.ok()
.and_then(|m| m.modified().ok())
.and_then(|mtime| mtime.elapsed().ok());
let silence_note = match last_message_age {
Some(age) => {
let mins = age.as_secs() / 60;
if mins < 60 {
format!("(last message ~{mins} minutes ago)\n")
} else {
let hours = mins / 60;
let rem = mins % 60;
if rem == 0 {
format!("(last message ~{hours}h ago)\n")
} else {
format!("(last message ~{hours}h {rem}m ago)\n")
}
}
}
None => String::new(),
};
tracing::debug!(
"💓 Heartbeat session context: {} messages from {}, silence: {}",
recent.len(),
path.display(),
silence_note.trim(),
);
let mut ctx = format!(
"[Recent conversation history — use this for context when composing your message] {silence_note}",
);
for msg in &recent {
let label = if msg.role == "user" { "User" } else { "You" };
// Truncate very long messages to avoid bloating the prompt.
// Use char_indices to avoid panicking on multi-byte UTF-8 characters.
let content = if msg.content.len() > 500 {
let truncate_at = msg
.content
.char_indices()
.map(|(i, _)| i)
.take_while(|&i| i <= 500)
.last()
.unwrap_or(0);
msg.content[..truncate_at].to_string()
} else {
msg.content.clone()
};
ctx.push_str(label);
ctx.push_str(": ");
ctx.push_str(&content);
ctx.push('\n');
}
Some(ctx)
}
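The UTF-8-safe truncation used for long messages above can be isolated into a small helper. This is a sketch of the same boundary-scanning technique, not code from the diff; `truncate_utf8` is a hypothetical name.

```rust
/// Truncate a string to at most `max` bytes without splitting a UTF-8
/// character, mirroring the char_indices scan in the heartbeat context code.
fn truncate_utf8(s: &str, max: usize) -> &str {
    if s.len() <= max {
        return s;
    }
    // Largest char boundary that does not exceed `max`.
    let cut = s
        .char_indices()
        .map(|(i, _)| i)
        .take_while(|&i| i <= max)
        .last()
        .unwrap_or(0);
    &s[..cut]
}
```

Slicing at a `char_indices` boundary can never panic, which is why the original code prefers this over a plain `&s[..500]`.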
/// Read the last `HEARTBEAT_SESSION_CONTEXT_MESSAGES` `ChatMessage` lines from
/// a JSONL session file using a bounded rolling window so we never hold the
/// entire file in memory.
fn load_jsonl_messages(path: &std::path::Path) -> Vec<crate::providers::traits::ChatMessage> {
use std::collections::VecDeque;
use std::io::BufRead;
let file = match std::fs::File::open(path) {
Ok(f) => f,
Err(_) => return Vec::new(),
};
let reader = std::io::BufReader::new(file);
let mut window: VecDeque<crate::providers::traits::ChatMessage> =
VecDeque::with_capacity(HEARTBEAT_SESSION_CONTEXT_MESSAGES + 1);
for line in reader.lines() {
let Ok(line) = line else { continue };
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
if let Ok(msg) = serde_json::from_str::<crate::providers::traits::ChatMessage>(trimmed) {
window.push_back(msg);
if window.len() > HEARTBEAT_SESSION_CONTEXT_MESSAGES {
window.pop_front();
}
}
}
window.into_iter().collect()
}
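The bounded rolling window that `load_jsonl_messages` uses generalizes to any line stream: push each item, evict from the front once the capacity is exceeded, so memory stays O(cap) regardless of file size. A minimal stdlib-only sketch (hypothetical `last_n` helper):

```rust
use std::collections::VecDeque;

/// Keep only the last `cap` non-empty lines from an iterator without
/// ever buffering the whole input, as the JSONL reader above does.
fn last_n<I: IntoIterator<Item = String>>(lines: I, cap: usize) -> Vec<String> {
    let mut window: VecDeque<String> = VecDeque::with_capacity(cap + 1);
    for line in lines {
        if line.trim().is_empty() {
            continue; // skip blank lines, matching the reader's behavior
        }
        window.push_back(line);
        if window.len() > cap {
            window.pop_front(); // evict the oldest entry
        }
    }
    window.into_iter().collect()
}
```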
/// Auto-detect the best channel for heartbeat delivery by checking which
/// channels are configured. Returns the first match in priority order.
fn auto_detect_heartbeat_channel(config: &Config) -> Option<(String, String)> {
@@ -23,7 +23,7 @@ fn extract_bearer_token(headers: &HeaderMap) -> Option<&str> {
}
/// Verify bearer token against PairingGuard. Returns error response if unauthorized.
pub(super) fn require_auth(
state: &AppState,
headers: &HeaderMap,
) -> Result<(), (StatusCode, Json<serde_json::Value>)> {
@@ -1280,12 +1280,16 @@ pub async fn handle_api_sessions_list(
.into_iter()
.filter_map(|meta| {
let session_id = meta.key.strip_prefix("gw_")?;
let mut entry = serde_json::json!({
"session_id": session_id,
"created_at": meta.created_at.to_rfc3339(),
"last_activity": meta.last_activity.to_rfc3339(),
"message_count": meta.message_count,
});
if let Some(name) = meta.name {
entry["name"] = serde_json::Value::String(name);
}
Some(entry)
})
.collect();
@@ -1326,6 +1330,56 @@ pub async fn handle_api_session_delete(
}
}
/// PUT /api/sessions/{id} — rename a gateway session
pub async fn handle_api_session_rename(
State(state): State<AppState>,
headers: HeaderMap,
Path(id): Path<String>,
Json(body): Json<serde_json::Value>,
) -> impl IntoResponse {
if let Err(e) = require_auth(&state, &headers) {
return e.into_response();
}
let Some(ref backend) = state.session_backend else {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({"error": "Session persistence is disabled"})),
)
.into_response();
};
let name = body["name"].as_str().unwrap_or("").trim();
if name.is_empty() {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({"error": "name is required"})),
)
.into_response();
}
let session_key = format!("gw_{id}");
// Verify the session exists before renaming
let sessions = backend.list_sessions();
if !sessions.contains(&session_key) {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({"error": "Session not found"})),
)
.into_response();
}
match backend.set_session_name(&session_key, name) {
Ok(()) => Json(serde_json::json!({"session_id": id, "name": name})).into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"error": format!("Failed to rename session: {e}")})),
)
.into_response(),
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -1429,6 +1483,7 @@ mod tests {
nextcloud_talk: None,
nextcloud_talk_webhook_secret: None,
wati: None,
gmail_push: None,
observer: Arc::new(crate::observability::NoopObserver),
tools_registry: Arc::new(Vec::new()),
cost_tracker: None,
@@ -1438,6 +1493,8 @@ mod tests {
session_backend: None,
device_registry: None,
pending_pairings: None,
path_prefix: String::new(),
canvas_store: crate::tools::canvas::CanvasStore::new(),
}
}
@@ -0,0 +1,278 @@
//! Live Canvas gateway routes — REST + WebSocket for real-time canvas updates.
//!
//! - `GET /api/canvas/:id` — get current canvas content (JSON)
//! - `POST /api/canvas/:id` — push content programmatically
//! - `GET /api/canvas` — list all active canvases
//! - `WS /ws/canvas/:id` — real-time canvas updates via WebSocket
use super::api::require_auth;
use super::AppState;
use axum::{
extract::{
ws::{Message, WebSocket},
Path, State, WebSocketUpgrade,
},
http::{header, HeaderMap, StatusCode},
response::{IntoResponse, Json},
};
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
/// POST /api/canvas/:id request body.
#[derive(Deserialize)]
pub struct CanvasPostBody {
pub content_type: Option<String>,
pub content: String,
}
/// GET /api/canvas — list all active canvases.
pub async fn handle_canvas_list(
State(state): State<AppState>,
headers: HeaderMap,
) -> impl IntoResponse {
if let Err(e) = require_auth(&state, &headers) {
return e.into_response();
}
let ids = state.canvas_store.list();
Json(serde_json::json!({ "canvases": ids })).into_response()
}
/// GET /api/canvas/:id — get current canvas content.
pub async fn handle_canvas_get(
State(state): State<AppState>,
headers: HeaderMap,
Path(id): Path<String>,
) -> impl IntoResponse {
if let Err(e) = require_auth(&state, &headers) {
return e.into_response();
}
match state.canvas_store.snapshot(&id) {
Some(frame) => Json(serde_json::json!({
"canvas_id": id,
"frame": frame,
}))
.into_response(),
None => (
StatusCode::NOT_FOUND,
Json(serde_json::json!({ "error": format!("Canvas '{}' not found", id) })),
)
.into_response(),
}
}
/// GET /api/canvas/:id/history — get canvas frame history.
pub async fn handle_canvas_history(
State(state): State<AppState>,
headers: HeaderMap,
Path(id): Path<String>,
) -> impl IntoResponse {
if let Err(e) = require_auth(&state, &headers) {
return e.into_response();
}
let history = state.canvas_store.history(&id);
Json(serde_json::json!({
"canvas_id": id,
"frames": history,
}))
.into_response()
}
/// POST /api/canvas/:id — push content to a canvas.
pub async fn handle_canvas_post(
State(state): State<AppState>,
headers: HeaderMap,
Path(id): Path<String>,
Json(body): Json<CanvasPostBody>,
) -> impl IntoResponse {
if let Err(e) = require_auth(&state, &headers) {
return e.into_response();
}
let content_type = body.content_type.as_deref().unwrap_or("html");
// Validate content_type against allowed set (prevent injecting "eval" frames via REST).
if !crate::tools::canvas::ALLOWED_CONTENT_TYPES.contains(&content_type) {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"error": format!(
"Invalid content_type '{}'. Allowed: {:?}",
content_type,
crate::tools::canvas::ALLOWED_CONTENT_TYPES
)
})),
)
.into_response();
}
// Enforce content size limit (same as tool-side validation).
if body.content.len() > crate::tools::canvas::MAX_CONTENT_SIZE {
return (
StatusCode::PAYLOAD_TOO_LARGE,
Json(serde_json::json!({
"error": format!(
"Content exceeds maximum size of {} bytes",
crate::tools::canvas::MAX_CONTENT_SIZE
)
})),
)
.into_response();
}
match state.canvas_store.render(&id, content_type, &body.content) {
Some(frame) => (
StatusCode::CREATED,
Json(serde_json::json!({
"canvas_id": id,
"frame": frame,
})),
)
.into_response(),
None => (
StatusCode::TOO_MANY_REQUESTS,
Json(serde_json::json!({
"error": "Maximum canvas count reached. Clear unused canvases first."
})),
)
.into_response(),
}
}
/// DELETE /api/canvas/:id — clear a canvas.
pub async fn handle_canvas_clear(
State(state): State<AppState>,
headers: HeaderMap,
Path(id): Path<String>,
) -> impl IntoResponse {
if let Err(e) = require_auth(&state, &headers) {
return e.into_response();
}
state.canvas_store.clear(&id);
Json(serde_json::json!({
"canvas_id": id,
"status": "cleared",
}))
.into_response()
}
/// WS /ws/canvas/:id — real-time canvas updates.
pub async fn handle_ws_canvas(
State(state): State<AppState>,
Path(id): Path<String>,
headers: HeaderMap,
ws: WebSocketUpgrade,
) -> impl IntoResponse {
// Auth check (same pattern as ws::handle_ws_chat)
if state.pairing.require_pairing() {
let token = headers
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|auth| auth.strip_prefix("Bearer "))
.or_else(|| {
// Fallback: check query params in the upgrade request URI
headers
.get("sec-websocket-protocol")
.and_then(|v| v.to_str().ok())
.and_then(|protos| {
protos
.split(',')
.map(|p| p.trim())
.find_map(|p| p.strip_prefix("bearer."))
})
})
.unwrap_or("");
if !state.pairing.is_authenticated(token) {
return (
StatusCode::UNAUTHORIZED,
"Unauthorized — provide Authorization header or Sec-WebSocket-Protocol bearer",
)
.into_response();
}
}
ws.on_upgrade(move |socket| handle_canvas_socket(socket, state, id))
.into_response()
}
async fn handle_canvas_socket(socket: WebSocket, state: AppState, canvas_id: String) {
let (mut sender, mut receiver) = socket.split();
// Subscribe to canvas updates
let mut rx = match state.canvas_store.subscribe(&canvas_id) {
Some(rx) => rx,
None => {
let msg = serde_json::json!({
"type": "error",
"error": "Maximum canvas count reached",
});
let _ = sender.send(Message::Text(msg.to_string().into())).await;
return;
}
};
// Send current state immediately if available
if let Some(frame) = state.canvas_store.snapshot(&canvas_id) {
let msg = serde_json::json!({
"type": "frame",
"canvas_id": canvas_id,
"frame": frame,
});
let _ = sender.send(Message::Text(msg.to_string().into())).await;
}
// Send a connected acknowledgement
let ack = serde_json::json!({
"type": "connected",
"canvas_id": canvas_id,
});
let _ = sender.send(Message::Text(ack.to_string().into())).await;
// Spawn a task that forwards broadcast updates to the WebSocket
let canvas_id_clone = canvas_id.clone();
let send_task = tokio::spawn(async move {
loop {
match rx.recv().await {
Ok(frame) => {
let msg = serde_json::json!({
"type": "frame",
"canvas_id": canvas_id_clone,
"frame": frame,
});
if sender
.send(Message::Text(msg.to_string().into()))
.await
.is_err()
{
break;
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
// Client fell behind — notify and continue rather than disconnecting.
let msg = serde_json::json!({
"type": "lagged",
"canvas_id": canvas_id_clone,
"missed_frames": n,
});
let _ = sender.send(Message::Text(msg.to_string().into())).await;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
}
}
});
// Read loop: we mostly ignore incoming messages but handle close/ping
while let Some(msg) = receiver.next().await {
match msg {
Ok(Message::Close(_)) | Err(_) => break,
_ => {} // Ignore all other messages (pings are handled by axum)
}
}
// Abort the send task when the connection is closed
send_task.abort();
}
@@ -11,14 +11,15 @@ pub mod api;
pub mod api_pairing;
#[cfg(feature = "plugins-wasm")]
pub mod api_plugins;
pub mod canvas;
pub mod nodes;
pub mod sse;
pub mod static_files;
pub mod ws;
use crate::channels::{
session_backend::SessionBackend, session_sqlite::SqliteSessionBackend, Channel,
GmailPushChannel, LinqChannel, NextcloudTalkChannel, SendMessage, WatiChannel, WhatsAppChannel,
};
use crate::config::Config;
use crate::cost::CostTracker;
@@ -28,6 +29,7 @@ use crate::runtime;
use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard};
use crate::security::SecurityPolicy;
use crate::tools;
use crate::tools::canvas::CanvasStore;
use crate::tools::traits::ToolSpec;
use crate::util::truncate_with_ellipsis;
use anyhow::{Context, Result};
@@ -336,6 +338,8 @@ pub struct AppState {
/// Nextcloud Talk webhook secret for signature verification
pub nextcloud_talk_webhook_secret: Option<Arc<str>>,
pub wati: Option<Arc<WatiChannel>>,
/// Gmail Pub/Sub push notification channel
pub gmail_push: Option<Arc<GmailPushChannel>>,
/// Observability backend for metrics scraping
pub observer: Arc<dyn crate::observability::Observer>,
/// Registered tool specs (for web dashboard tools page)
@@ -348,12 +352,16 @@ pub struct AppState {
pub shutdown_tx: tokio::sync::watch::Sender<bool>,
/// Registry of dynamically connected nodes
pub node_registry: Arc<nodes::NodeRegistry>,
/// Path prefix for reverse-proxy deployments (empty string = no prefix)
pub path_prefix: String,
/// Session backend for persisting gateway WS chat sessions
pub session_backend: Option<Arc<dyn SessionBackend>>,
/// Device registry for paired device management
pub device_registry: Option<Arc<api_pairing::DeviceRegistry>>,
/// Pending pairing request store
pub pending_pairings: Option<Arc<api_pairing::PairingStore>>,
/// Shared canvas store for Live Canvas (A2UI) system
pub canvas_store: CanvasStore,
}
/// Run the HTTP gateway using axum with proper HTTP/1.1 compliance.
@@ -430,21 +438,25 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
(None, None)
};
let canvas_store = tools::CanvasStore::new();
let (mut tools_registry_raw, delegate_handle_gw, _reaction_handle_gw) =
tools::all_tools_with_runtime(
Arc::new(config.clone()),
&security,
runtime,
Arc::clone(&mem),
composio_key,
composio_entity_id,
&config.browser,
&config.http_request,
&config.web_fetch,
&config.workspace_dir,
&config.agents,
config.api_key.as_deref(),
&config,
Some(canvas_store.clone()),
);
// ── Wire MCP tools into the gateway tool registry (non-fatal) ───
// Without this, the `/api/tools` endpoint misses MCP tools.
@@ -627,6 +639,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
})
.map(Arc::from);
// Gmail Push channel (if configured and enabled)
let gmail_push_channel: Option<Arc<GmailPushChannel>> = config
.channels_config
.gmail_push
.as_ref()
.filter(|gp| gp.enabled)
.map(|gp| Arc::new(GmailPushChannel::new(gp.clone())));
// ── Session persistence for WS chat ─────────────────────
let session_backend: Option<Arc<dyn SessionBackend>> = if config.gateway.session_persistence {
match SqliteSessionBackend::new(&config.workspace_dir) {
@@ -673,6 +693,13 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
idempotency_max_keys,
));
// Resolve optional path prefix for reverse-proxy deployments.
let path_prefix: Option<&str> = config
.gateway
.path_prefix
.as_deref()
.filter(|p| !p.is_empty());
// ── Tunnel ────────────────────────────────────────────────
let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?;
let mut tunnel_url: Option<String> = None;
@@ -691,18 +718,19 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
}
}
let pfx = path_prefix.unwrap_or("");
println!("🦀 ZeroClaw Gateway listening on http://{display_addr}{pfx}");
if let Some(ref url) = tunnel_url {
println!(" 🌐 Public URL: {url}");
}
println!(" 🌐 Web Dashboard: http://{display_addr}{pfx}/");
if let Some(code) = pairing.pairing_code() {
println!();
println!(" 🔐 PAIRING REQUIRED — use this one-time code:");
println!(" ┌──────────────┐");
println!("{code}");
println!(" └──────────────┘");
println!();
println!(" Send: POST {pfx}/pair with header X-Pairing-Code: {code}");
} else if pairing.require_pairing() {
println!(" 🔒 Pairing: ACTIVE (bearer token required)");
println!(" To pair a new device: zeroclaw gateway get-paircode --new");
@@ -711,29 +739,29 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
println!();
}
println!(" POST {pfx}/pair — pair a new client (X-Pairing-Code header)");
println!(" POST {pfx}/webhook — {{\"message\": \"your prompt\"}}");
if whatsapp_channel.is_some() {
println!(" GET {pfx}/whatsapp — Meta webhook verification");
println!(" POST {pfx}/whatsapp — WhatsApp message webhook");
}
if linq_channel.is_some() {
println!(" POST {pfx}/linq — Linq message webhook (iMessage/RCS/SMS)");
}
if wati_channel.is_some() {
println!(" GET {pfx}/wati — WATI webhook verification");
println!(" POST {pfx}/wati — WATI message webhook");
}
if nextcloud_talk_channel.is_some() {
println!(" POST {pfx}/nextcloud-talk — Nextcloud Talk bot webhook");
}
println!(" GET {pfx}/api/* — REST API (bearer token required)");
println!(" GET {pfx}/ws/chat — WebSocket agent chat");
if config.nodes.enabled {
println!(" GET {pfx}/ws/nodes — WebSocket node discovery");
}
println!(" GET {pfx}/health — health check");
println!(" GET {pfx}/metrics — Prometheus metrics");
println!(" Press Ctrl+C to stop.\n");
crate::health::mark_component_ok("gateway");
@@ -790,6 +818,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
nextcloud_talk: nextcloud_talk_channel,
nextcloud_talk_webhook_secret,
wati: wati_channel,
gmail_push: gmail_push_channel,
observer: broadcast_observer,
tools_registry,
cost_tracker,
@@ -799,6 +828,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
session_backend,
device_registry,
pending_pairings,
path_prefix: path_prefix.unwrap_or("").to_string(),
canvas_store,
};
// Config PUT needs larger body limit (1MB)
@@ -807,7 +838,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
.layer(RequestBodyLimitLayer::new(1_048_576));
// Build router with middleware
let inner = Router::new()
// ── Admin routes (for CLI management) ──
.route("/admin/shutdown", post(handle_admin_shutdown))
.route("/admin/paircode", get(handle_admin_paircode))
@@ -823,6 +854,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
.route("/wati", get(handle_wati_verify))
.route("/wati", post(handle_wati_webhook))
.route("/nextcloud-talk", post(handle_nextcloud_talk_webhook))
.route("/webhook/gmail", post(handle_gmail_push_webhook))
// ── Web Dashboard API routes ──
.route("/api/status", get(api::handle_api_status))
.route("/api/config", get(api::handle_api_config_get))
@@ -854,7 +886,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
.route("/api/cli-tools", get(api::handle_api_cli_tools))
.route("/api/health", get(api::handle_api_health))
.route("/api/sessions", get(api::handle_api_sessions_list))
.route("/api/sessions/{id}", delete(api::handle_api_session_delete))
.route(
"/api/sessions/{id}",
delete(api::handle_api_session_delete).put(api::handle_api_session_rename),
)
// ── Pairing + Device management API ──
.route("/api/pairing/initiate", post(api_pairing::initiate_pairing))
.route("/api/pair", post(api_pairing::submit_pairing_enhanced))
@@ -863,34 +895,61 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
.route(
"/api/devices/{id}/token/rotate",
post(api_pairing::rotate_token),
)
// ── Live Canvas (A2UI) routes ──
.route("/api/canvas", get(canvas::handle_canvas_list))
.route(
"/api/canvas/{id}",
get(canvas::handle_canvas_get)
.post(canvas::handle_canvas_post)
.delete(canvas::handle_canvas_clear),
)
.route(
"/api/canvas/{id}/history",
get(canvas::handle_canvas_history),
);
// ── Plugin management API (requires plugins-wasm feature) ──
#[cfg(feature = "plugins-wasm")]
let app = app.route(
let inner = inner.route(
"/api/plugins",
get(api_plugins::plugin_routes::list_plugins),
);
let app = app
let inner = inner
// ── SSE event stream ──
.route("/api/events", get(sse::handle_sse_events))
// ── WebSocket agent chat ──
.route("/ws/chat", get(ws::handle_ws_chat))
// ── WebSocket canvas updates ──
.route("/ws/canvas/{id}", get(canvas::handle_ws_canvas))
// ── WebSocket node discovery ──
.route("/ws/nodes", get(nodes::handle_ws_nodes))
// ── Static assets (web dashboard) ──
.route("/_app/{*path}", get(static_files::handle_static))
// ── Config PUT with larger body limit ──
.merge(config_put_router)
// ── SPA fallback: non-API GET requests serve index.html ──
.fallback(get(static_files::handle_spa_fallback))
.with_state(state)
.layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE))
.layer(TimeoutLayer::with_status_code(
StatusCode::REQUEST_TIMEOUT,
Duration::from_secs(gateway_request_timeout_secs()),
))
// ── SPA fallback: non-API GET requests serve index.html ──
.fallback(get(static_files::handle_spa_fallback));
));
// Nest under path prefix when configured (axum strips prefix before routing).
// nest() at "/prefix" handles both "/prefix" and "/prefix/*" but not "/prefix/"
// with a trailing slash, so we add a fallback redirect for that case.
let app = if let Some(prefix) = path_prefix {
let redirect_target = prefix.to_string();
Router::new().nest(prefix, inner).route(
&format!("{prefix}/"),
get(|| async move { axum::response::Redirect::permanent(&redirect_target) }),
)
} else {
inner
};
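The nesting rule described in the comment above can be sketched as a plain dispatch function. This is a hypothetical stand-in for axum's routing (not axum itself), shown only to illustrate which paths reach the inner router, which hit the trailing-slash redirect, and which fall through:

```rust
/// Hypothetical sketch of the routing behavior described above:
/// `/prefix` and `/prefix/*` reach the inner router, `/prefix/` is
/// redirected by the extra route, and everything else is not found.
fn dispatch(path: &str, prefix: &str) -> &'static str {
    if path == format!("{prefix}/") {
        "redirect" // bare trailing slash: handled by the added fallback route
    } else if path == prefix || path.starts_with(&format!("{prefix}/")) {
        "inner" // axum strips the prefix before routing into `inner`
    } else {
        "not_found"
    }
}
```

The order matters: the exact `"{prefix}/"` match must be tested before the `starts_with` check, since it satisfies both.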
// Run the server with graceful shutdown
axum::serve(
@@ -1788,6 +1847,74 @@ async fn handle_nextcloud_talk_webhook(
(StatusCode::OK, Json(serde_json::json!({"status": "ok"})))
}
/// Maximum request body size for the Gmail webhook endpoint (1 MB).
/// Google Pub/Sub messages are typically under 10 KB.
const GMAIL_WEBHOOK_MAX_BODY: usize = 1024 * 1024;
/// POST /webhook/gmail — incoming Gmail Pub/Sub push notification
async fn handle_gmail_push_webhook(
State(state): State<AppState>,
headers: HeaderMap,
body: Bytes,
) -> impl IntoResponse {
let Some(ref gmail_push) = state.gmail_push else {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({"error": "Gmail push not configured"})),
);
};
// Enforce body size limit.
if body.len() > GMAIL_WEBHOOK_MAX_BODY {
return (
StatusCode::PAYLOAD_TOO_LARGE,
Json(serde_json::json!({"error": "Request body too large"})),
);
}
// Authenticate the webhook request using a shared secret.
let secret = gmail_push.resolve_webhook_secret();
if !secret.is_empty() {
let provided = headers
.get(axum::http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|auth| auth.strip_prefix("Bearer "))
.unwrap_or("");
if provided != secret {
tracing::warn!("Gmail push webhook: unauthorized request");
return (
StatusCode::UNAUTHORIZED,
Json(serde_json::json!({"error": "Unauthorized"})),
);
}
}
let body_str = String::from_utf8_lossy(&body);
let envelope: crate::channels::gmail_push::PubSubEnvelope =
match serde_json::from_str(&body_str) {
Ok(e) => e,
Err(e) => {
tracing::warn!("Gmail push webhook: invalid payload: {e}");
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({"error": "Invalid Pub/Sub envelope"})),
);
}
};
// Process the notification asynchronously (non-blocking for the webhook response)
let channel = Arc::clone(gmail_push);
tokio::spawn(async move {
if let Err(e) = channel.handle_notification(&envelope).await {
tracing::error!("Gmail push notification processing failed: {e:#}");
}
});
// Acknowledge immediately — Google Pub/Sub requires a 2xx within ~10s
(StatusCode::OK, Json(serde_json::json!({"status": "ok"})))
}
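The auth check in the handler above boils down to extracting a bearer token from the `Authorization` header and comparing it to the configured secret. A minimal stdlib sketch of the extraction step (the function name is hypothetical; it mirrors the `strip_prefix` chain in the handler):

```rust
/// Extract a bearer token from an Authorization header value, if any.
/// Returns "" for a missing header or a non-Bearer scheme, matching the
/// `unwrap_or("")` behavior in the webhook handler above.
fn extract_bearer(auth: Option<&str>) -> &str {
    auth.and_then(|v| v.strip_prefix("Bearer ")).unwrap_or("")
}
```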
// ══════════════════════════════════════════════════════════════════════════════
// ADMIN HANDLERS (for CLI management)
// ══════════════════════════════════════════════════════════════════════════════
@@ -1976,15 +2103,18 @@ mod tests {
nextcloud_talk: None,
nextcloud_talk_webhook_secret: None,
wati: None,
gmail_push: None,
observer: Arc::new(crate::observability::NoopObserver),
tools_registry: Arc::new(Vec::new()),
cost_tracker: None,
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
canvas_store: CanvasStore::new(),
};
let response = handle_metrics(State(state)).await.into_response();
@@ -2031,15 +2161,18 @@ mod tests {
nextcloud_talk: None,
nextcloud_talk_webhook_secret: None,
wati: None,
gmail_push: None,
observer,
tools_registry: Arc::new(Vec::new()),
cost_tracker: None,
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
canvas_store: CanvasStore::new(),
};
let response = handle_metrics(State(state)).await.into_response();
@@ -2415,15 +2548,18 @@ mod tests {
nextcloud_talk: None,
nextcloud_talk_webhook_secret: None,
wati: None,
gmail_push: None,
observer: Arc::new(crate::observability::NoopObserver),
tools_registry: Arc::new(Vec::new()),
cost_tracker: None,
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
canvas_store: CanvasStore::new(),
};
let mut headers = HeaderMap::new();
@@ -2484,15 +2620,18 @@ mod tests {
nextcloud_talk: None,
nextcloud_talk_webhook_secret: None,
wati: None,
gmail_push: None,
observer: Arc::new(crate::observability::NoopObserver),
tools_registry: Arc::new(Vec::new()),
cost_tracker: None,
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
canvas_store: CanvasStore::new(),
};
let headers = HeaderMap::new();
@@ -2565,15 +2704,18 @@ mod tests {
nextcloud_talk: None,
nextcloud_talk_webhook_secret: None,
wati: None,
gmail_push: None,
observer: Arc::new(crate::observability::NoopObserver),
tools_registry: Arc::new(Vec::new()),
cost_tracker: None,
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
canvas_store: CanvasStore::new(),
};
let response = handle_webhook(
@@ -2618,15 +2760,18 @@ mod tests {
nextcloud_talk: None,
nextcloud_talk_webhook_secret: None,
wati: None,
gmail_push: None,
observer: Arc::new(crate::observability::NoopObserver),
tools_registry: Arc::new(Vec::new()),
cost_tracker: None,
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
canvas_store: CanvasStore::new(),
};
let mut headers = HeaderMap::new();
@@ -2676,15 +2821,18 @@ mod tests {
nextcloud_talk: None,
nextcloud_talk_webhook_secret: None,
wati: None,
gmail_push: None,
observer: Arc::new(crate::observability::NoopObserver),
tools_registry: Arc::new(Vec::new()),
cost_tracker: None,
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
canvas_store: CanvasStore::new(),
};
let mut headers = HeaderMap::new();
@@ -2739,15 +2887,18 @@ mod tests {
nextcloud_talk: None,
nextcloud_talk_webhook_secret: None,
wati: None,
gmail_push: None,
observer: Arc::new(crate::observability::NoopObserver),
tools_registry: Arc::new(Vec::new()),
cost_tracker: None,
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
canvas_store: CanvasStore::new(),
};
let response = Box::pin(handle_nextcloud_talk_webhook(
@@ -2798,15 +2949,18 @@ mod tests {
nextcloud_talk: Some(channel),
nextcloud_talk_webhook_secret: Some(Arc::from(secret)),
wati: None,
gmail_push: None,
observer: Arc::new(crate::observability::NoopObserver),
tools_registry: Arc::new(Vec::new()),
cost_tracker: None,
event_tx: tokio::sync::broadcast::channel(16).0,
shutdown_tx: tokio::sync::watch::channel(false).0,
node_registry: Arc::new(nodes::NodeRegistry::new(16)),
path_prefix: String::new(),
session_backend: None,
device_registry: None,
pending_pairings: None,
canvas_store: CanvasStore::new(),
};
let mut headers = HeaderMap::new();

@@ -3,11 +3,14 @@
//! Uses `rust-embed` to bundle the `web/dist/` directory into the binary at compile time.
use axum::{
extract::State,
http::{header, StatusCode, Uri},
response::{IntoResponse, Response},
};
use rust_embed::Embed;
use super::AppState;
#[derive(Embed)]
#[folder = "web/dist/"]
struct WebAssets;
@@ -23,16 +26,41 @@ pub async fn handle_static(uri: Uri) -> Response {
serve_embedded_file(path)
}
/// SPA fallback: serve index.html for any non-API, non-static GET request
pub async fn handle_spa_fallback() -> Response {
if WebAssets::get("index.html").is_none() {
/// SPA fallback: serve index.html for any non-API, non-static GET request.
/// Injects `window.__ZEROCLAW_BASE__` so the frontend knows the path prefix.
pub async fn handle_spa_fallback(State(state): State<AppState>) -> Response {
let Some(content) = WebAssets::get("index.html") else {
return (
StatusCode::SERVICE_UNAVAILABLE,
"Web dashboard not available. Build it with: cd web && npm ci && npm run build",
)
.into_response();
}
serve_embedded_file("index.html")
};
let html = String::from_utf8_lossy(&content.data);
// Inject path prefix for the SPA and rewrite asset paths in the HTML
let html = if state.path_prefix.is_empty() {
html.into_owned()
} else {
let pfx = &state.path_prefix;
// JSON-encode the prefix to safely embed in a <script> block
let json_pfx = serde_json::to_string(pfx).unwrap_or_else(|_| "\"\"".to_string());
let script = format!("<script>window.__ZEROCLAW_BASE__={json_pfx};</script>");
// Rewrite absolute /_app/ references so the browser requests {prefix}/_app/...
html.replace("/_app/", &format!("{pfx}/_app/"))
.replace("<head>", &format!("<head>{script}"))
};
(
StatusCode::OK,
[
(header::CONTENT_TYPE, "text/html; charset=utf-8".to_string()),
(header::CACHE_CONTROL, "no-cache".to_string()),
],
html,
)
.into_response()
}
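The prefix-injection step above is a pure string rewrite, so it can be sketched in isolation. This sketch assumes a simple ASCII prefix and uses Rust's `Debug` formatting in place of `serde_json::to_string` to produce the quoted literal:

```rust
/// Rewrite index.html for a path prefix: inject the base variable into
/// <head> and repoint absolute /_app/ asset URLs, as the handler above does.
fn inject_prefix(html: &str, pfx: &str) -> String {
    // {pfx:?} quotes and escapes the prefix, standing in for JSON encoding.
    let script = format!("<script>window.__ZEROCLAW_BASE__={pfx:?};</script>");
    html.replace("/_app/", &format!("{pfx}/_app/"))
        .replace("<head>", &format!("<head>{script}"))
}
```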
fn serve_embedded_file(path: &str) -> Response {
@@ -1,13 +1,21 @@
//! WebSocket agent chat handler.
//!
//! Connect: `ws://host:port/ws/chat?session_id=ID&name=My+Session`
//!
//! Protocol:
//! ```text
//! Server -> Client: {"type":"session_start","session_id":"...","name":"...","resumed":true,"message_count":42}
//! Client -> Server: {"type":"message","content":"Hello"}
//! Server -> Client: {"type":"chunk","content":"Hi! "}
//! Server -> Client: {"type":"tool_call","name":"shell","args":{...}}
//! Server -> Client: {"type":"tool_result","name":"shell","output":"..."}
//! Server -> Client: {"type":"done","full_response":"..."}
//! ```
//!
//! Query params:
//! - `session_id` — resume or create a session (default: new UUID)
//! - `name` — optional human-readable label for the session
//! - `token` — bearer auth token (alternative to Authorization header)
use super::AppState;
use axum::{
@@ -53,6 +61,8 @@ const BEARER_SUBPROTO_PREFIX: &str = "bearer.";
pub struct WsQuery {
pub token: Option<String>,
pub session_id: Option<String>,
/// Optional human-readable name for the session.
pub name: Option<String>,
}
/// Extract a bearer token from WebSocket-compatible sources.
@@ -134,14 +144,20 @@ pub async fn handle_ws_chat(
};
let session_id = params.session_id;
ws.on_upgrade(move |socket| handle_socket(socket, state, session_id))
let session_name = params.name;
ws.on_upgrade(move |socket| handle_socket(socket, state, session_id, session_name))
.into_response()
}
/// Gateway session key prefix to avoid collisions with channel sessions.
const GW_SESSION_PREFIX: &str = "gw_";
async fn handle_socket(socket: WebSocket, state: AppState, session_id: Option<String>) {
async fn handle_socket(
socket: WebSocket,
state: AppState,
session_id: Option<String>,
session_name: Option<String>,
) {
let (mut sender, mut receiver) = socket.split();
// Resolve session ID: use provided or generate a new UUID
@@ -163,6 +179,7 @@ async fn handle_socket(socket: WebSocket, state: AppState, session_id: Option<St
// Hydrate agent from persisted session (if available)
let mut resumed = false;
let mut message_count: usize = 0;
let mut effective_name: Option<String> = None;
if let Some(ref backend) = state.session_backend {
let messages = backend.load(&session_key);
if !messages.is_empty() {
@@ -170,15 +187,29 @@ async fn handle_socket(socket: WebSocket, state: AppState, session_id: Option<St
agent.seed_history(&messages);
resumed = true;
}
// Set session name if provided (non-empty) on connect
if let Some(ref name) = session_name {
if !name.is_empty() {
let _ = backend.set_session_name(&session_key, name);
effective_name = Some(name.clone());
}
}
// If no name was provided via query param, load the stored name
if effective_name.is_none() {
effective_name = backend.get_session_name(&session_key).unwrap_or(None);
}
}
// Send session_start message to client
let session_start = serde_json::json!({
let mut session_start = serde_json::json!({
"type": "session_start",
"session_id": session_id,
"resumed": resumed,
"message_count": message_count,
});
if let Some(ref name) = effective_name {
session_start["name"] = serde_json::Value::String(name.clone());
}
let _ = sender
.send(Message::Text(session_start.to_string().into()))
.await;
@@ -407,7 +407,11 @@ mod tests {
// Simpler: write a temp script.
let dir = tempfile::tempdir().unwrap();
let script_path = dir.path().join("tool.sh");
std::fs::write(&script_path, format!("#!/bin/sh\necho '{}'\n", result_json)).unwrap();
std::fs::write(
&script_path,
format!("#!/bin/sh\ncat > /dev/null\necho '{}'\n", result_json),
)
.unwrap();
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
@@ -891,6 +891,7 @@ mod tests {
device_id: None,
room_id: "!r:m".into(),
allowed_users: vec![],
allowed_rooms: vec![],
interrupt_on_new_message: false,
});
let entries = all_integrations();
@@ -0,0 +1,293 @@
//! Audit trail for memory operations.
//!
//! Provides a decorator `AuditedMemory<M>` that wraps any `Memory` backend
//! and logs all operations to a `memory_audit` table. Opt-in via
//! `[memory] audit_enabled = true`.
use super::traits::{Memory, MemoryCategory, MemoryEntry, ProceduralMessage};
use async_trait::async_trait;
use chrono::Local;
use parking_lot::Mutex;
use rusqlite::{params, Connection};
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Audit log entry operations.
#[derive(Debug, Clone, Copy)]
pub enum AuditOp {
Store,
Recall,
Get,
List,
Forget,
StoreProcedural,
}
impl std::fmt::Display for AuditOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Store => write!(f, "store"),
Self::Recall => write!(f, "recall"),
Self::Get => write!(f, "get"),
Self::List => write!(f, "list"),
Self::Forget => write!(f, "forget"),
Self::StoreProcedural => write!(f, "store_procedural"),
}
}
}
/// Decorator that wraps a `Memory` backend with audit logging.
pub struct AuditedMemory<M: Memory> {
inner: M,
audit_conn: Arc<Mutex<Connection>>,
#[allow(dead_code)]
db_path: PathBuf,
}
impl<M: Memory> AuditedMemory<M> {
pub fn new(inner: M, workspace_dir: &Path) -> anyhow::Result<Self> {
let db_path = workspace_dir.join("memory").join("audit.db");
if let Some(parent) = db_path.parent() {
std::fs::create_dir_all(parent)?;
}
let conn = Connection::open(&db_path)?;
conn.execute_batch(
"PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
CREATE TABLE IF NOT EXISTS memory_audit (
id INTEGER PRIMARY KEY AUTOINCREMENT,
operation TEXT NOT NULL,
key TEXT,
namespace TEXT,
session_id TEXT,
timestamp TEXT NOT NULL,
metadata TEXT
);
CREATE INDEX IF NOT EXISTS idx_audit_timestamp ON memory_audit(timestamp);
CREATE INDEX IF NOT EXISTS idx_audit_operation ON memory_audit(operation);",
)?;
Ok(Self {
inner,
audit_conn: Arc::new(Mutex::new(conn)),
db_path,
})
}
fn log_audit(
&self,
op: AuditOp,
key: Option<&str>,
namespace: Option<&str>,
session_id: Option<&str>,
metadata: Option<&str>,
) {
let conn = self.audit_conn.lock();
let now = Local::now().to_rfc3339();
let op_str = op.to_string();
let _ = conn.execute(
"INSERT INTO memory_audit (operation, key, namespace, session_id, timestamp, metadata)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![op_str, key, namespace, session_id, now, metadata],
);
}
/// Prune audit entries older than the given number of days.
pub fn prune_older_than(&self, retention_days: u32) -> anyhow::Result<u64> {
let conn = self.audit_conn.lock();
let cutoff =
(Local::now() - chrono::Duration::days(i64::from(retention_days))).to_rfc3339();
let affected = conn.execute(
"DELETE FROM memory_audit WHERE timestamp < ?1",
params![cutoff],
)?;
Ok(u64::try_from(affected).unwrap_or(0))
}
/// Count total audit entries.
pub fn audit_count(&self) -> anyhow::Result<usize> {
let conn = self.audit_conn.lock();
let count: i64 =
conn.query_row("SELECT COUNT(*) FROM memory_audit", [], |row| row.get(0))?;
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
Ok(count as usize)
}
}
#[async_trait]
impl<M: Memory> Memory for AuditedMemory<M> {
fn name(&self) -> &str {
self.inner.name()
}
async fn store(
&self,
key: &str,
content: &str,
category: MemoryCategory,
session_id: Option<&str>,
) -> anyhow::Result<()> {
self.log_audit(AuditOp::Store, Some(key), None, session_id, None);
self.inner.store(key, content, category, session_id).await
}
async fn recall(
&self,
query: &str,
limit: usize,
session_id: Option<&str>,
since: Option<&str>,
until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
self.log_audit(
AuditOp::Recall,
None,
None,
session_id,
Some(&format!("query={query}")),
);
self.inner
.recall(query, limit, session_id, since, until)
.await
}
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
self.log_audit(AuditOp::Get, Some(key), None, None, None);
self.inner.get(key).await
}
async fn list(
&self,
category: Option<&MemoryCategory>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
self.log_audit(AuditOp::List, None, None, session_id, None);
self.inner.list(category, session_id).await
}
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
self.log_audit(AuditOp::Forget, Some(key), None, None, None);
self.inner.forget(key).await
}
async fn count(&self) -> anyhow::Result<usize> {
self.inner.count().await
}
async fn health_check(&self) -> bool {
self.inner.health_check().await
}
async fn store_procedural(
&self,
messages: &[ProceduralMessage],
session_id: Option<&str>,
) -> anyhow::Result<()> {
self.log_audit(
AuditOp::StoreProcedural,
None,
None,
session_id,
Some(&format!("messages={}", messages.len())),
);
self.inner.store_procedural(messages, session_id).await
}
async fn recall_namespaced(
&self,
namespace: &str,
query: &str,
limit: usize,
session_id: Option<&str>,
since: Option<&str>,
until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
self.log_audit(
AuditOp::Recall,
None,
Some(namespace),
session_id,
Some(&format!("query={query}")),
);
self.inner
.recall_namespaced(namespace, query, limit, session_id, since, until)
.await
}
async fn store_with_metadata(
&self,
key: &str,
content: &str,
category: MemoryCategory,
session_id: Option<&str>,
namespace: Option<&str>,
importance: Option<f64>,
) -> anyhow::Result<()> {
self.log_audit(AuditOp::Store, Some(key), namespace, session_id, None);
self.inner
.store_with_metadata(key, content, category, session_id, namespace, importance)
.await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::memory::NoneMemory;
use tempfile::TempDir;
#[tokio::test]
async fn audited_memory_logs_store_operation() {
let tmp = TempDir::new().unwrap();
let inner = NoneMemory::new();
let audited = AuditedMemory::new(inner, tmp.path()).unwrap();
audited
.store("test_key", "test_value", MemoryCategory::Core, None)
.await
.unwrap();
assert_eq!(audited.audit_count().unwrap(), 1);
}
#[tokio::test]
async fn audited_memory_logs_recall_operation() {
let tmp = TempDir::new().unwrap();
let inner = NoneMemory::new();
let audited = AuditedMemory::new(inner, tmp.path()).unwrap();
let _ = audited.recall("query", 10, None, None, None).await;
assert_eq!(audited.audit_count().unwrap(), 1);
}
#[tokio::test]
async fn audited_memory_prune_works() {
let tmp = TempDir::new().unwrap();
let inner = NoneMemory::new();
let audited = AuditedMemory::new(inner, tmp.path()).unwrap();
audited
.store("k1", "v1", MemoryCategory::Core, None)
.await
.unwrap();
// Pruning with a 0-day retention should remove the just-created entry.
let pruned = audited.prune_older_than(0).unwrap();
// pruned is u64 (always >= 0); the unwrap above is the real assertion.
let _ = pruned;
}
#[tokio::test]
async fn audited_memory_delegates_correctly() {
let tmp = TempDir::new().unwrap();
let inner = NoneMemory::new();
let audited = AuditedMemory::new(inner, tmp.path()).unwrap();
assert_eq!(audited.name(), "none");
assert!(audited.health_check().await);
assert_eq!(audited.count().await.unwrap(), 0);
}
}
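The decorator shape used by `AuditedMemory` — wrap the backend, record the operation, delegate — can be shown with a minimal synchronous sketch. The trait and types here are hypothetical stand-ins; the real implementation is async and writes to a SQLite `memory_audit` table:

```rust
/// Minimal synchronous sketch of the audit-decorator pattern above.
trait Store {
    fn put(&mut self, key: &str, value: &str);
}

struct Plain;
impl Store for Plain {
    fn put(&mut self, _key: &str, _value: &str) {}
}

struct Audited<S: Store> {
    inner: S,
    log: Vec<String>, // stands in for the memory_audit table
}

impl<S: Store> Store for Audited<S> {
    fn put(&mut self, key: &str, value: &str) {
        self.log.push(format!("store:{key}")); // log first, then delegate
        self.inner.put(key, value);
    }
}
```

Because the wrapper implements the same trait, callers are unaware of the audit layer — the same property `AuditedMemory<M>` relies on when wrapping any `Memory` backend.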
@@ -4,7 +4,6 @@ pub enum MemoryBackendKind {
Lucid,
Postgres,
Qdrant,
Mem0,
Markdown,
None,
Unknown,
@@ -66,15 +65,6 @@ const QDRANT_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
optional_dependency: false,
};
const MEM0_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
key: "mem0",
label: "Mem0 (OpenMemory) — semantic memory with LLM fact extraction via [memory.mem0]",
auto_save_default: true,
uses_sqlite_hygiene: false,
sqlite_based: false,
optional_dependency: true,
};
const NONE_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
key: "none",
label: "None — disable persistent memory",
@@ -114,7 +104,6 @@ pub fn classify_memory_backend(backend: &str) -> MemoryBackendKind {
"lucid" => MemoryBackendKind::Lucid,
"postgres" => MemoryBackendKind::Postgres,
"qdrant" => MemoryBackendKind::Qdrant,
"mem0" | "openmemory" => MemoryBackendKind::Mem0,
"markdown" => MemoryBackendKind::Markdown,
"none" => MemoryBackendKind::None,
_ => MemoryBackendKind::Unknown,
@@ -127,7 +116,6 @@ pub fn memory_backend_profile(backend: &str) -> MemoryBackendProfile {
MemoryBackendKind::Lucid => LUCID_PROFILE,
MemoryBackendKind::Postgres => POSTGRES_PROFILE,
MemoryBackendKind::Qdrant => QDRANT_PROFILE,
MemoryBackendKind::Mem0 => MEM0_PROFILE,
MemoryBackendKind::Markdown => MARKDOWN_PROFILE,
MemoryBackendKind::None => NONE_PROFILE,
MemoryBackendKind::Unknown => CUSTOM_PROFILE,
File diff suppressed because it is too large.
@@ -0,0 +1,173 @@
//! Conflict resolution for memory entries.
//!
//! Before storing Core memories, performs a semantic similarity check against
//! existing entries. If cosine similarity exceeds a threshold but content
//! differs, the old entry is marked as superseded.
use super::traits::{Memory, MemoryCategory, MemoryEntry};
/// Check for conflicting memories and mark old ones as superseded.
///
/// Returns the list of entry IDs that were superseded.
pub async fn check_and_resolve_conflicts(
memory: &dyn Memory,
key: &str,
content: &str,
category: &MemoryCategory,
threshold: f64,
) -> anyhow::Result<Vec<String>> {
// Only check conflicts for Core memories
if !matches!(category, MemoryCategory::Core) {
return Ok(Vec::new());
}
// Search for similar existing entries
let candidates = memory.recall(content, 10, None, None, None).await?;
let mut superseded = Vec::new();
for candidate in &candidates {
if candidate.key == key {
continue; // Same key = update, not conflict
}
if !matches!(candidate.category, MemoryCategory::Core) {
continue;
}
if let Some(score) = candidate.score {
if score > threshold && candidate.content != content {
superseded.push(candidate.id.clone());
}
}
}
Ok(superseded)
}
/// Mark entries as superseded in SQLite by setting their `superseded_by` column.
pub fn mark_superseded(
conn: &rusqlite::Connection,
superseded_ids: &[String],
new_id: &str,
) -> anyhow::Result<()> {
if superseded_ids.is_empty() {
return Ok(());
}
for id in superseded_ids {
conn.execute(
"UPDATE memories SET superseded_by = ?1 WHERE id = ?2",
rusqlite::params![new_id, id],
)?;
}
Ok(())
}
/// Simple text-based conflict detection without embeddings.
///
/// Uses token overlap (Jaccard similarity) as a fast approximation
/// when vector embeddings are unavailable.
pub fn jaccard_similarity(a: &str, b: &str) -> f64 {
let words_a: std::collections::HashSet<&str> = a.split_whitespace().collect();
let words_b: std::collections::HashSet<&str> = b.split_whitespace().collect();
if words_a.is_empty() && words_b.is_empty() {
return 1.0;
}
if words_a.is_empty() || words_b.is_empty() {
return 0.0;
}
let intersection = words_a.intersection(&words_b).count();
let union = words_a.union(&words_b).count();
if union == 0 {
0.0
} else {
intersection as f64 / union as f64
}
}
/// Find potentially conflicting entries using text similarity when embeddings
/// are not available. Returns entries above the threshold.
pub fn find_text_conflicts(
entries: &[MemoryEntry],
new_content: &str,
threshold: f64,
) -> Vec<String> {
entries
.iter()
.filter(|e| {
matches!(e.category, MemoryCategory::Core)
&& e.superseded_by.is_none()
&& jaccard_similarity(&e.content, new_content) > threshold
&& e.content != new_content
})
.map(|e| e.id.clone())
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn jaccard_identical_strings() {
let sim = jaccard_similarity("hello world", "hello world");
assert!((sim - 1.0).abs() < f64::EPSILON);
}
#[test]
fn jaccard_disjoint_strings() {
let sim = jaccard_similarity("hello world", "foo bar");
assert!(sim.abs() < f64::EPSILON);
}
#[test]
fn jaccard_partial_overlap() {
let sim = jaccard_similarity("the quick brown fox", "the slow brown dog");
// overlap: "the", "brown" = 2; union: "the", "quick", "brown", "fox", "slow", "dog" = 6
assert!((sim - 2.0 / 6.0).abs() < 0.01);
}
#[test]
fn jaccard_empty_strings() {
assert!((jaccard_similarity("", "") - 1.0).abs() < f64::EPSILON);
assert!(jaccard_similarity("hello", "").abs() < f64::EPSILON);
assert!(jaccard_similarity("", "hello").abs() < f64::EPSILON);
}
#[test]
fn find_text_conflicts_filters_correctly() {
let entries = vec![
MemoryEntry {
id: "1".into(),
key: "pref".into(),
content: "User prefers Rust for systems work".into(),
category: MemoryCategory::Core,
timestamp: "now".into(),
session_id: None,
score: None,
namespace: "default".into(),
importance: Some(0.7),
superseded_by: None,
},
MemoryEntry {
id: "2".into(),
key: "daily1".into(),
content: "User prefers Rust for systems work".into(),
category: MemoryCategory::Daily,
timestamp: "now".into(),
session_id: None,
score: None,
namespace: "default".into(),
importance: Some(0.3),
superseded_by: None,
},
];
// Only Core entries should be flagged
let conflicts = find_text_conflicts(&entries, "User now prefers Go for systems work", 0.3);
assert_eq!(conflicts.len(), 1);
assert_eq!(conflicts[0], "1");
}
}
@@ -8,6 +8,8 @@
//! This two-phase approach replaces the naive raw-message auto-save with
//! semantic extraction, similar to Nanobot's `save_memory` tool call pattern.
use crate::memory::conflict;
use crate::memory::importance;
use crate::memory::traits::{Memory, MemoryCategory};
use crate::providers::traits::Provider;
@@ -78,8 +80,33 @@ pub async fn consolidate_turn(
if let Some(ref update) = result.memory_update {
if !update.trim().is_empty() {
let mem_key = format!("core_{}", uuid::Uuid::new_v4());
// Compute importance score heuristically.
let imp = importance::compute_importance(update, &MemoryCategory::Core);
// Check for conflicts with existing Core memories.
if let Err(e) = conflict::check_and_resolve_conflicts(
memory,
&mem_key,
update,
&MemoryCategory::Core,
0.85,
)
.await
{
tracing::debug!("conflict check skipped: {e}");
}
// Store with importance metadata.
memory
.store(&mem_key, update, MemoryCategory::Core, None)
.store_with_metadata(
&mem_key,
update,
MemoryCategory::Core,
None,
None,
Some(imp),
)
.await?;
}
}
@@ -0,0 +1,151 @@
use super::traits::{MemoryCategory, MemoryEntry};
use chrono::{DateTime, Utc};
/// Default half-life in days for time-decay scoring.
/// After this many days, a non-Core memory's score drops to 50%.
pub const DEFAULT_HALF_LIFE_DAYS: f64 = 7.0;
/// Apply exponential time decay to memory entry scores.
///
/// - `Core` memories are exempt ("evergreen") — their scores are never decayed.
/// - Entries without a parseable RFC3339 timestamp are left unchanged.
/// - Entries without a score (`None`) are left unchanged.
///
/// Decay formula: `score * 2^(-age_days / half_life_days)`
pub fn apply_time_decay(entries: &mut [MemoryEntry], half_life_days: f64) {
let half_life = if half_life_days <= 0.0 {
DEFAULT_HALF_LIFE_DAYS
} else {
half_life_days
};
let now = Utc::now();
for entry in entries.iter_mut() {
// Core memories are evergreen — never decay
if entry.category == MemoryCategory::Core {
continue;
}
let score = match entry.score {
Some(s) => s,
None => continue,
};
let ts = match DateTime::parse_from_rfc3339(&entry.timestamp) {
Ok(dt) => dt.with_timezone(&Utc),
Err(_) => continue,
};
let age_days = now.signed_duration_since(ts).num_seconds().max(0) as f64 / 86_400.0;
let decay_factor = (-age_days / half_life * std::f64::consts::LN_2).exp();
entry.score = Some(score * decay_factor);
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_entry(category: MemoryCategory, score: Option<f64>, timestamp: &str) -> MemoryEntry {
MemoryEntry {
id: "1".into(),
key: "test".into(),
content: "value".into(),
category,
timestamp: timestamp.into(),
session_id: None,
score,
namespace: "default".into(),
importance: None,
superseded_by: None,
}
}
fn recent_rfc3339() -> String {
Utc::now().to_rfc3339()
}
fn days_ago_rfc3339(days: i64) -> String {
(Utc::now() - chrono::Duration::days(days)).to_rfc3339()
}
#[test]
fn core_memories_are_never_decayed() {
let mut entries = vec![make_entry(
MemoryCategory::Core,
Some(0.9),
&days_ago_rfc3339(30),
)];
apply_time_decay(&mut entries, 7.0);
assert_eq!(entries[0].score, Some(0.9));
}
#[test]
fn recent_entry_score_barely_changes() {
let mut entries = vec![make_entry(
MemoryCategory::Conversation,
Some(0.8),
&recent_rfc3339(),
)];
apply_time_decay(&mut entries, 7.0);
let decayed = entries[0].score.unwrap();
assert!(
(decayed - 0.8).abs() < 0.01,
"recent entry should barely decay, got {decayed}"
);
}
#[test]
fn one_half_life_halves_score() {
let mut entries = vec![make_entry(
MemoryCategory::Conversation,
Some(1.0),
&days_ago_rfc3339(7),
)];
apply_time_decay(&mut entries, 7.0);
let decayed = entries[0].score.unwrap();
assert!(
(decayed - 0.5).abs() < 0.05,
"score after one half-life should be ~0.5, got {decayed}"
);
}
#[test]
fn two_half_lives_quarters_score() {
let mut entries = vec![make_entry(
MemoryCategory::Conversation,
Some(1.0),
&days_ago_rfc3339(14),
)];
apply_time_decay(&mut entries, 7.0);
let decayed = entries[0].score.unwrap();
assert!(
(decayed - 0.25).abs() < 0.05,
"score after two half-lives should be ~0.25, got {decayed}"
);
}
#[test]
fn no_score_entry_is_unchanged() {
let mut entries = vec![make_entry(
MemoryCategory::Conversation,
None,
&days_ago_rfc3339(30),
)];
apply_time_decay(&mut entries, 7.0);
assert_eq!(entries[0].score, None);
}
#[test]
fn unparseable_timestamp_is_unchanged() {
let mut entries = vec![make_entry(
MemoryCategory::Conversation,
Some(0.9),
"not-a-date",
)];
apply_time_decay(&mut entries, 7.0);
assert_eq!(entries[0].score, Some(0.9));
}
}
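The decay tests above all reduce to the exponential half-life formula from `apply_time_decay`: `score' = score * exp(-age_days / half_life * ln 2)`. A standalone sketch (function name `decay_factor` is illustrative, not the crate's API):

```rust
// Half-life decay: after one half-life the factor is 0.5, after two it is 0.25.
fn decay_factor(age_days: f64, half_life: f64) -> f64 {
    (-age_days / half_life * std::f64::consts::LN_2).exp()
}

fn main() {
    let half_life = 7.0;
    // A fresh entry is untouched; a 7-day-old entry is halved; 14 days quarters it.
    println!("fresh:  {:.3}", 0.8 * decay_factor(0.0, half_life));
    println!("1 HL:   {:.3}", 1.0 * decay_factor(7.0, half_life));
    println!("2 HLs:  {:.3}", 1.0 * decay_factor(14.0, half_life));
}
```

This is why the tests tolerate `< 0.05`: the entry ages by wall-clock time between creation and scoring, so the factor is only approximately 0.5 or 0.25.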
+42 -4
@@ -1,4 +1,5 @@
use crate::config::MemoryConfig;
use crate::memory::policy::PolicyEnforcer;
use anyhow::Result;
use chrono::{DateTime, Duration, Local, NaiveDate, Utc};
use rusqlite::{params, Connection};
@@ -47,6 +48,13 @@ pub fn run_if_due(config: &MemoryConfig, workspace_dir: &Path) -> Result<()> {
return Ok(());
}
// Use policy engine for per-category retention overrides.
let enforcer = PolicyEnforcer::new(&config.policy);
let conversation_retention = enforcer.retention_days_for_category(
&crate::memory::traits::MemoryCategory::Conversation,
config.conversation_retention_days,
);
let report = HygieneReport {
archived_memory_files: archive_daily_memory_files(
workspace_dir,
@@ -55,12 +63,16 @@ pub fn run_if_due(config: &MemoryConfig, workspace_dir: &Path) -> Result<()> {
archived_session_files: archive_session_files(workspace_dir, config.archive_after_days)?,
purged_memory_archives: purge_memory_archives(workspace_dir, config.purge_after_days)?,
purged_session_archives: purge_session_archives(workspace_dir, config.purge_after_days)?,
pruned_conversation_rows: prune_conversation_rows(
workspace_dir,
config.conversation_retention_days,
)?,
pruned_conversation_rows: prune_conversation_rows(workspace_dir, conversation_retention)?,
};
// Prune audit entries if audit is enabled.
if config.audit_enabled {
if let Err(e) = prune_audit_entries(workspace_dir, config.audit_retention_days) {
tracing::debug!("audit pruning skipped: {e}");
}
}
write_state(workspace_dir, &report)?;
if report.total_actions() > 0 {
@@ -318,6 +330,32 @@ fn prune_conversation_rows(workspace_dir: &Path, retention_days: u32) -> Result<
Ok(u64::try_from(affected).unwrap_or(0))
}
fn prune_audit_entries(workspace_dir: &Path, retention_days: u32) -> Result<()> {
if retention_days == 0 {
return Ok(());
}
let db_path = workspace_dir.join("memory").join("audit.db");
if !db_path.exists() {
return Ok(());
}
let conn = Connection::open(db_path)?;
conn.execute_batch("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")?;
let cutoff = (Local::now() - Duration::days(i64::from(retention_days))).to_rfc3339();
let affected = conn.execute(
"DELETE FROM memory_audit WHERE timestamp < ?1",
params![cutoff],
)?;
if affected > 0 {
tracing::debug!("pruned {affected} audit entries older than {retention_days} days");
}
Ok(())
}
fn memory_date_from_filename(filename: &str) -> Option<NaiveDate> {
let stem = filename.strip_suffix(".md")?;
let date_part = stem.split('_').next().unwrap_or(stem);
+107
@@ -0,0 +1,107 @@
//! Heuristic importance scorer for non-LLM paths.
//!
//! Assigns importance scores (0.0 to 1.0) based on memory category and keyword
//! signals. Used when LLM-based consolidation is unavailable or as a fast
//! first-pass scorer.
use super::traits::MemoryCategory;
/// Base importance by category.
fn category_base_score(category: &MemoryCategory) -> f64 {
match category {
MemoryCategory::Core => 0.7,
MemoryCategory::Daily => 0.3,
MemoryCategory::Conversation => 0.2,
MemoryCategory::Custom(_) => 0.4,
}
}
/// Keyword boost: if the content contains high-signal keywords, bump importance.
fn keyword_boost(content: &str) -> f64 {
const HIGH_SIGNAL_KEYWORDS: &[&str] = &[
"decision",
"always",
"never",
"important",
"critical",
"must",
"requirement",
"policy",
"rule",
"principle",
];
let lowered = content.to_ascii_lowercase();
let matches = HIGH_SIGNAL_KEYWORDS
.iter()
.filter(|kw| lowered.contains(**kw))
.count();
// Cap at +0.2
(matches as f64 * 0.1).min(0.2)
}
/// Compute heuristic importance score for a memory entry.
pub fn compute_importance(content: &str, category: &MemoryCategory) -> f64 {
let base = category_base_score(category);
let boost = keyword_boost(content);
(base + boost).min(1.0)
}
/// Compute final retrieval score incorporating importance and recency.
///
/// `hybrid_score`: raw retrieval score from FTS/vector (0.0 to 1.0)
/// `importance`: importance score (0.0 to 1.0)
/// `recency_decay`: recency factor (0.0 to 1.0, 1.0 = very recent)
pub fn weighted_final_score(hybrid_score: f64, importance: f64, recency_decay: f64) -> f64 {
hybrid_score * 0.7 + importance * 0.2 + recency_decay * 0.1
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn core_category_has_high_base_score() {
let score = compute_importance("some fact", &MemoryCategory::Core);
assert!((score - 0.7).abs() < f64::EPSILON);
}
#[test]
fn conversation_category_has_low_base_score() {
let score = compute_importance("chat message", &MemoryCategory::Conversation);
assert!((score - 0.2).abs() < f64::EPSILON);
}
#[test]
fn keywords_boost_importance() {
let score = compute_importance(
"This is a critical decision that must always be followed",
&MemoryCategory::Core,
);
// base 0.7 + boost for "critical", "decision", "must", "always" = 0.7 + 0.2 (capped) = 0.9
assert!(score > 0.85);
}
#[test]
fn boost_capped_at_point_two() {
let score = compute_importance(
"important critical decision rule policy must always never requirement principle",
&MemoryCategory::Conversation,
);
// base 0.2 + max boost 0.2 = 0.4
assert!((score - 0.4).abs() < f64::EPSILON);
}
#[test]
fn weighted_final_score_formula() {
let score = weighted_final_score(1.0, 1.0, 1.0);
assert!((score - 1.0).abs() < f64::EPSILON);
let score = weighted_final_score(0.0, 0.0, 0.0);
assert!(score.abs() < f64::EPSILON);
let score = weighted_final_score(0.5, 0.5, 0.5);
assert!((score - 0.5).abs() < f64::EPSILON);
}
}
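The scorer composes two pieces: a capped keyword boost on top of a per-category base, then a fixed 0.7/0.2/0.1 blend of retrieval score, importance, and recency. A self-contained sketch mirroring those formulas (standalone functions, not the crate's `MemoryCategory`-based API):

```rust
// Keyword boost: +0.1 per high-signal keyword, capped at +0.2, matching the scorer.
fn keyword_boost(content: &str, keywords: &[&str]) -> f64 {
    let lowered = content.to_ascii_lowercase();
    let matches = keywords.iter().filter(|kw| lowered.contains(**kw)).count();
    (matches as f64 * 0.1).min(0.2)
}

// Final blend: retrieval relevance dominates, importance and recency are tiebreakers.
fn weighted_final_score(hybrid: f64, importance: f64, recency: f64) -> f64 {
    hybrid * 0.7 + importance * 0.2 + recency * 0.1
}

fn main() {
    let boost = keyword_boost("a critical decision", &["decision", "critical", "must"]);
    let importance = (0.2_f64 + boost).min(1.0); // conversation base 0.2 + boost 0.2
    println!("importance: {importance:.2}");
    println!("final: {:.2}", weighted_final_score(0.9, importance, 1.0));
}
```

Because the blend weights sum to 1.0, the final score stays in 0.0 to 1.0 whenever all three inputs do.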
+3
@@ -226,6 +226,9 @@ impl LucidMemory {
timestamp: now.clone(),
session_id: None,
score: Some((1.0 - rank as f64 * 0.05).max(0.1)),
namespace: "default".into(),
importance: None,
superseded_by: None,
});
}
+3
@@ -91,6 +91,9 @@ impl MarkdownMemory {
timestamp: filename.to_string(),
session_id: None,
score: None,
namespace: "default".into(),
importance: None,
superseded_by: None,
}
})
.collect()
-635
@@ -1,635 +0,0 @@
//! Mem0 (OpenMemory) memory backend.
//!
//! Connects to a self-hosted OpenMemory server via its REST API
//! and implements the [`Memory`] trait for seamless integration with
//! ZeroClaw's auto-save, auto-recall, and hygiene lifecycle.
//!
//! Deploy OpenMemory: `docker compose up` from the mem0 repo.
//! Default endpoint: `http://localhost:8765`.
use super::traits::{Memory, MemoryCategory, MemoryEntry, ProceduralMessage};
use crate::config::schema::Mem0Config;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// Memory backend backed by a mem0 (OpenMemory) REST API.
pub struct Mem0Memory {
client: Client,
base_url: String,
user_id: String,
app_name: String,
infer: bool,
extraction_prompt: Option<String>,
}
// ── mem0 API request/response types ────────────────────────────────
#[derive(Serialize)]
struct AddMemoryRequest<'a> {
user_id: &'a str,
text: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
metadata: Option<Mem0Metadata<'a>>,
infer: bool,
#[serde(skip_serializing_if = "Option::is_none")]
app: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
custom_instructions: Option<&'a str>,
}
#[derive(Serialize)]
struct Mem0Metadata<'a> {
key: &'a str,
category: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
session_id: Option<&'a str>,
}
#[derive(Serialize)]
struct AddProceduralRequest<'a> {
user_id: &'a str,
messages: &'a [ProceduralMessage],
#[serde(skip_serializing_if = "Option::is_none")]
metadata: Option<serde_json::Value>,
}
#[derive(Serialize)]
struct DeleteMemoriesRequest<'a> {
memory_ids: Vec<&'a str>,
user_id: &'a str,
}
#[derive(Deserialize)]
struct Mem0MemoryItem {
id: String,
#[serde(alias = "content", alias = "text", default)]
memory: String,
#[serde(default)]
created_at: Option<serde_json::Value>,
#[serde(default, rename = "metadata_")]
metadata: Option<Mem0ResponseMetadata>,
#[serde(alias = "relevance_score", default)]
score: Option<f64>,
}
#[derive(Deserialize, Default)]
struct Mem0ResponseMetadata {
#[serde(default)]
key: Option<String>,
#[serde(default)]
category: Option<String>,
#[serde(default)]
session_id: Option<String>,
}
#[derive(Deserialize)]
struct Mem0ListResponse {
#[serde(default)]
items: Vec<Mem0MemoryItem>,
#[serde(default)]
total: usize,
}
// ── Implementation ─────────────────────────────────────────────────
impl Mem0Memory {
/// Create a new mem0 memory backend from config.
pub fn new(config: &Mem0Config) -> anyhow::Result<Self> {
let base_url = config.url.trim_end_matches('/').to_string();
if base_url.is_empty() {
anyhow::bail!("mem0 URL is empty; set [memory.mem0] url or MEM0_URL env var");
}
let client = Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
client,
base_url,
user_id: config.user_id.clone(),
app_name: config.app_name.clone(),
infer: config.infer,
extraction_prompt: config.extraction_prompt.clone(),
})
}
fn api_url(&self, path: &str) -> String {
format!("{}/api/v1{}", self.base_url, path)
}
/// Use `session_id` as the effective mem0 `user_id` when provided,
/// falling back to the configured default. This enables per-user
/// and per-group memory scoping via the existing `Memory` trait.
fn effective_user_id<'a>(&'a self, session_id: Option<&'a str>) -> &'a str {
session_id
.filter(|s| !s.trim().is_empty())
.unwrap_or(&self.user_id)
}
/// Recall memories with optional search filters.
///
/// - `created_after` / `created_before`: ISO 8601 timestamps for time-range filtering.
/// - `metadata_filter`: arbitrary JSON object passed to the mem0 SDK `filters` param.
pub async fn recall_filtered(
&self,
query: &str,
limit: usize,
session_id: Option<&str>,
created_after: Option<&str>,
created_before: Option<&str>,
metadata_filter: Option<&serde_json::Value>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let effective_user = self.effective_user_id(session_id);
let limit_str = limit.to_string();
let mut params: Vec<(&str, &str)> = vec![
("user_id", effective_user),
("search_query", query),
("size", &limit_str),
];
if let Some(after) = created_after {
params.push(("created_after", after));
}
if let Some(before) = created_before {
params.push(("created_before", before));
}
let meta_json;
if let Some(mf) = metadata_filter {
meta_json = serde_json::to_string(mf)?;
params.push(("metadata_filter", &meta_json));
}
let resp = self
.client
.get(self.api_url("/memories/"))
.query(&params)
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
anyhow::bail!("mem0 recall failed ({status}): {text}");
}
let list: Mem0ListResponse = resp.json().await?;
Ok(list.items.into_iter().map(|i| self.to_entry(i)).collect())
}
fn to_entry(&self, item: Mem0MemoryItem) -> MemoryEntry {
let meta = item.metadata.unwrap_or_default();
let timestamp = match item.created_at {
Some(serde_json::Value::Number(n)) => {
// Unix timestamp → ISO 8601
if let Some(ts) = n.as_i64() {
chrono::DateTime::from_timestamp(ts, 0)
.map(|dt| dt.to_rfc3339())
.unwrap_or_default()
} else {
String::new()
}
}
Some(serde_json::Value::String(s)) => s,
_ => String::new(),
};
let category = match meta.category.as_deref() {
Some("daily") => MemoryCategory::Daily,
Some("conversation") => MemoryCategory::Conversation,
Some(other) if other != "core" => MemoryCategory::Custom(other.to_string()),
// "core" or None → default
_ => MemoryCategory::Core,
};
MemoryEntry {
id: item.id,
key: meta.key.unwrap_or_default(),
content: item.memory,
category,
timestamp,
session_id: meta.session_id,
score: item.score,
}
}
/// Store a conversation trace as procedural memory.
///
/// Sends the message history (user input, tool calls, assistant response)
/// to the mem0 procedural endpoint so that "how to" patterns can be
/// extracted and stored for future recall.
pub async fn store_procedural(
&self,
messages: &[ProceduralMessage],
session_id: Option<&str>,
) -> anyhow::Result<()> {
if messages.is_empty() {
return Ok(());
}
let effective_user = self.effective_user_id(session_id);
let body = AddProceduralRequest {
user_id: effective_user,
messages,
metadata: None,
};
let resp = self
.client
.post(self.api_url("/memories/procedural"))
.json(&body)
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
anyhow::bail!("mem0 store_procedural failed ({status}): {text}");
}
Ok(())
}
}
// ── History API types ─────────────────────────────────────────────
#[derive(Deserialize)]
struct Mem0HistoryResponse {
#[serde(default)]
history: Vec<serde_json::Value>,
#[serde(default)]
error: Option<String>,
}
impl Mem0Memory {
/// Retrieve the edit history (audit trail) for a specific memory by ID.
pub async fn history(&self, memory_id: &str) -> anyhow::Result<String> {
let url = self.api_url(&format!("/memories/{memory_id}/history"));
let resp = self.client.get(&url).send().await?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
anyhow::bail!("mem0 history failed ({status}): {text}");
}
let body: Mem0HistoryResponse = resp.json().await?;
if let Some(err) = body.error {
anyhow::bail!("mem0 history error: {err}");
}
if body.history.is_empty() {
return Ok(format!("No history found for memory {memory_id}."));
}
let mut lines = Vec::with_capacity(body.history.len() + 1);
lines.push(format!("History for memory {memory_id}:"));
for (i, entry) in body.history.iter().enumerate() {
let event = entry
.get("event")
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let old_memory = entry
.get("old_memory")
.and_then(|v| v.as_str())
.unwrap_or("-");
let new_memory = entry
.get("new_memory")
.and_then(|v| v.as_str())
.unwrap_or("-");
let timestamp = entry
.get("created_at")
.or_else(|| entry.get("timestamp"))
.and_then(|v| v.as_str())
.unwrap_or("unknown");
lines.push(format!(
" {idx}. [{event}] at {timestamp}\n old: {old_memory}\n new: {new_memory}",
idx = i + 1,
));
}
Ok(lines.join("\n"))
}
}
#[async_trait]
impl Memory for Mem0Memory {
fn name(&self) -> &str {
"mem0"
}
async fn store(
&self,
key: &str,
content: &str,
category: MemoryCategory,
session_id: Option<&str>,
) -> anyhow::Result<()> {
let cat_str = category.to_string();
let effective_user = self.effective_user_id(session_id);
let body = AddMemoryRequest {
user_id: effective_user,
text: content,
metadata: Some(Mem0Metadata {
key,
category: &cat_str,
session_id,
}),
infer: self.infer,
app: Some(&self.app_name),
custom_instructions: self.extraction_prompt.as_deref(),
};
let resp = self
.client
.post(self.api_url("/memories/"))
.json(&body)
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
anyhow::bail!("mem0 store failed ({status}): {text}");
}
Ok(())
}
async fn recall(
&self,
query: &str,
limit: usize,
session_id: Option<&str>,
_since: Option<&str>,
_until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
// mem0 handles filtering server-side; since/until are not yet
// supported by the mem0 API, so they are ignored here.
self.recall_filtered(query, limit, session_id, None, None, None)
.await
}
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
// mem0 doesn't have a get-by-key API, so we search by key in metadata
let results = self.recall(key, 1, None, None, None).await?;
Ok(results.into_iter().find(|e| e.key == key))
}
async fn list(
&self,
category: Option<&MemoryCategory>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let effective_user = self.effective_user_id(session_id);
let resp = self
.client
.get(self.api_url("/memories/"))
.query(&[("user_id", effective_user), ("size", "100")])
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
anyhow::bail!("mem0 list failed ({status}): {text}");
}
let list: Mem0ListResponse = resp.json().await?;
let entries: Vec<MemoryEntry> = list.items.into_iter().map(|i| self.to_entry(i)).collect();
// Client-side category filter (mem0 API doesn't filter by metadata)
match category {
Some(cat) => Ok(entries.into_iter().filter(|e| &e.category == cat).collect()),
None => Ok(entries),
}
}
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
// Find the memory ID by key first
let entry = self.get(key).await?;
let entry = match entry {
Some(e) => e,
None => return Ok(false),
};
let body = DeleteMemoriesRequest {
memory_ids: vec![&entry.id],
user_id: &self.user_id,
};
let resp = self
.client
.delete(self.api_url("/memories/"))
.json(&body)
.send()
.await?;
Ok(resp.status().is_success())
}
async fn count(&self) -> anyhow::Result<usize> {
let resp = self
.client
.get(self.api_url("/memories/"))
.query(&[
("user_id", self.user_id.as_str()),
("size", "1"),
("page", "1"),
])
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
anyhow::bail!("mem0 count failed ({status}): {text}");
}
let list: Mem0ListResponse = resp.json().await?;
Ok(list.total)
}
async fn health_check(&self) -> bool {
self.client
.get(self.api_url("/memories/"))
.query(&[
("user_id", self.user_id.as_str()),
("size", "1"),
("page", "1"),
])
.send()
.await
.is_ok_and(|r| r.status().is_success())
}
async fn store_procedural(
&self,
messages: &[ProceduralMessage],
session_id: Option<&str>,
) -> anyhow::Result<()> {
Mem0Memory::store_procedural(self, messages, session_id).await
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_config() -> Mem0Config {
Mem0Config {
url: "http://localhost:8765".into(),
user_id: "test-user".into(),
app_name: "test-app".into(),
infer: true,
extraction_prompt: None,
}
}
#[test]
fn new_rejects_empty_url() {
let config = Mem0Config {
url: String::new(),
..test_config()
};
assert!(Mem0Memory::new(&config).is_err());
}
#[test]
fn new_trims_trailing_slash() {
let config = Mem0Config {
url: "http://localhost:8765/".into(),
..test_config()
};
let mem = Mem0Memory::new(&config).unwrap();
assert_eq!(mem.base_url, "http://localhost:8765");
}
#[test]
fn api_url_builds_correct_path() {
let mem = Mem0Memory::new(&test_config()).unwrap();
assert_eq!(
mem.api_url("/memories/"),
"http://localhost:8765/api/v1/memories/"
);
}
#[test]
fn to_entry_maps_unix_timestamp() {
let mem = Mem0Memory::new(&test_config()).unwrap();
let item = Mem0MemoryItem {
id: "id-1".into(),
memory: "hello".into(),
created_at: Some(serde_json::json!(1_700_000_000)),
metadata: Some(Mem0ResponseMetadata {
key: Some("k1".into()),
category: Some("core".into()),
session_id: None,
}),
score: Some(0.95),
};
let entry = mem.to_entry(item);
assert_eq!(entry.id, "id-1");
assert_eq!(entry.key, "k1");
assert_eq!(entry.category, MemoryCategory::Core);
assert!(!entry.timestamp.is_empty());
assert_eq!(entry.score, Some(0.95));
}
#[test]
fn to_entry_maps_string_timestamp() {
let mem = Mem0Memory::new(&test_config()).unwrap();
let item = Mem0MemoryItem {
id: "id-2".into(),
memory: "world".into(),
created_at: Some(serde_json::json!("2024-01-01T00:00:00Z")),
metadata: None,
score: None,
};
let entry = mem.to_entry(item);
assert_eq!(entry.timestamp, "2024-01-01T00:00:00Z");
assert_eq!(entry.category, MemoryCategory::Core); // default
}
#[test]
fn to_entry_handles_missing_metadata() {
let mem = Mem0Memory::new(&test_config()).unwrap();
let item = Mem0MemoryItem {
id: "id-3".into(),
memory: "bare".into(),
created_at: None,
metadata: None,
score: None,
};
let entry = mem.to_entry(item);
assert_eq!(entry.key, "");
assert_eq!(entry.category, MemoryCategory::Core);
assert!(entry.timestamp.is_empty());
assert_eq!(entry.score, None);
}
#[test]
fn to_entry_custom_category() {
let mem = Mem0Memory::new(&test_config()).unwrap();
let item = Mem0MemoryItem {
id: "id-4".into(),
memory: "custom".into(),
created_at: None,
metadata: Some(Mem0ResponseMetadata {
key: Some("k".into()),
category: Some("project_notes".into()),
session_id: Some("s1".into()),
}),
score: None,
};
let entry = mem.to_entry(item);
assert_eq!(
entry.category,
MemoryCategory::Custom("project_notes".into())
);
assert_eq!(entry.session_id.as_deref(), Some("s1"));
}
#[test]
fn name_returns_mem0() {
let mem = Mem0Memory::new(&test_config()).unwrap();
assert_eq!(mem.name(), "mem0");
}
#[test]
fn procedural_request_serializes_messages() {
let messages = vec![
ProceduralMessage {
role: "user".into(),
content: "How do I deploy?".into(),
name: None,
},
ProceduralMessage {
role: "tool".into(),
content: "deployment started".into(),
name: Some("shell".into()),
},
ProceduralMessage {
role: "assistant".into(),
content: "Deployment complete.".into(),
name: None,
},
];
let req = AddProceduralRequest {
user_id: "test-user",
messages: &messages,
metadata: None,
};
let json = serde_json::to_value(&req).unwrap();
assert_eq!(json["user_id"], "test-user");
let msgs = json["messages"].as_array().unwrap();
assert_eq!(msgs.len(), 3);
assert_eq!(msgs[0]["role"], "user");
assert_eq!(msgs[1]["name"], "shell");
// metadata should be absent when None
assert!(json.get("metadata").is_none());
}
}
+16 -27
@@ -1,24 +1,33 @@
pub mod audit;
pub mod backend;
pub mod chunker;
pub mod cli;
pub mod conflict;
pub mod consolidation;
pub mod decay;
pub mod embeddings;
pub mod hygiene;
pub mod importance;
pub mod knowledge_graph;
pub mod lucid;
pub mod markdown;
#[cfg(feature = "memory-mem0")]
pub mod mem0;
pub mod none;
pub mod policy;
#[cfg(feature = "memory-postgres")]
pub mod postgres;
pub mod qdrant;
pub mod response_cache;
pub mod retrieval;
pub mod snapshot;
pub mod sqlite;
pub mod traits;
pub mod vector;
#[cfg(test)]
mod battle_tests;
#[allow(unused_imports)]
pub use audit::AuditedMemory;
#[allow(unused_imports)]
pub use backend::{
classify_memory_backend, default_memory_backend_key, memory_backend_profile,
@@ -26,13 +35,15 @@ pub use backend::{
};
pub use lucid::LucidMemory;
pub use markdown::MarkdownMemory;
#[cfg(feature = "memory-mem0")]
pub use mem0::Mem0Memory;
pub use none::NoneMemory;
#[allow(unused_imports)]
pub use policy::PolicyEnforcer;
#[cfg(feature = "memory-postgres")]
pub use postgres::PostgresMemory;
pub use qdrant::QdrantMemory;
pub use response_cache::ResponseCache;
#[allow(unused_imports)]
pub use retrieval::{RetrievalConfig, RetrievalPipeline};
pub use sqlite::SqliteMemory;
pub use traits::Memory;
#[allow(unused_imports)]
@@ -61,7 +72,7 @@ where
Ok(Box::new(LucidMemory::new(workspace_dir, local)))
}
MemoryBackendKind::Postgres => postgres_builder(),
MemoryBackendKind::Qdrant | MemoryBackendKind::Mem0 | MemoryBackendKind::Markdown => {
MemoryBackendKind::Qdrant | MemoryBackendKind::Markdown => {
Ok(Box::new(MarkdownMemory::new(workspace_dir)))
}
MemoryBackendKind::None => Ok(Box::new(NoneMemory::new())),
@@ -340,28 +351,6 @@ pub fn create_memory_with_storage_and_routes(
);
}
#[cfg(feature = "memory-mem0")]
fn build_mem0_memory(config: &crate::config::MemoryConfig) -> anyhow::Result<Box<dyn Memory>> {
let mem = Mem0Memory::new(&config.mem0)?;
tracing::info!(
"📦 Mem0 memory backend configured (url: {}, user: {})",
config.mem0.url,
config.mem0.user_id
);
Ok(Box::new(mem))
}
#[cfg(not(feature = "memory-mem0"))]
fn build_mem0_memory(_config: &crate::config::MemoryConfig) -> anyhow::Result<Box<dyn Memory>> {
anyhow::bail!(
"memory backend 'mem0' requested but this build was compiled without `memory-mem0`; rebuild with `--features memory-mem0`"
);
}
if matches!(backend_kind, MemoryBackendKind::Mem0) {
return build_mem0_memory(config);
}
if matches!(backend_kind, MemoryBackendKind::Qdrant) {
let url = config
.qdrant
+192
@@ -0,0 +1,192 @@
//! Policy engine for memory operations.
//!
//! Validates operations against configurable rules before they reach the
//! backend. Enforces namespace quotas, category limits, read-only namespaces,
//! and per-category retention rules.
use super::traits::MemoryCategory;
use crate::config::MemoryPolicyConfig;
/// Policy enforcer that validates memory operations.
pub struct PolicyEnforcer {
config: MemoryPolicyConfig,
}
impl PolicyEnforcer {
pub fn new(config: &MemoryPolicyConfig) -> Self {
Self {
config: config.clone(),
}
}
/// Check if a namespace is read-only.
pub fn is_read_only(&self, namespace: &str) -> bool {
self.config
.read_only_namespaces
.iter()
.any(|ns| ns == namespace)
}
/// Validate a store operation against policy rules.
pub fn validate_store(
&self,
namespace: &str,
_category: &MemoryCategory,
) -> Result<(), PolicyViolation> {
if self.is_read_only(namespace) {
return Err(PolicyViolation::ReadOnlyNamespace(namespace.to_string()));
}
Ok(())
}
/// Check if adding an entry would exceed namespace limits.
pub fn check_namespace_limit(&self, current_count: usize) -> Result<(), PolicyViolation> {
if self.config.max_entries_per_namespace > 0
&& current_count >= self.config.max_entries_per_namespace
{
return Err(PolicyViolation::NamespaceQuotaExceeded {
max: self.config.max_entries_per_namespace,
current: current_count,
});
}
Ok(())
}
/// Check if adding an entry would exceed category limits.
pub fn check_category_limit(&self, current_count: usize) -> Result<(), PolicyViolation> {
if self.config.max_entries_per_category > 0
&& current_count >= self.config.max_entries_per_category
{
return Err(PolicyViolation::CategoryQuotaExceeded {
max: self.config.max_entries_per_category,
current: current_count,
});
}
Ok(())
}
/// Get the retention days for a specific category, falling back to the
/// provided default if no per-category override exists.
pub fn retention_days_for_category(&self, category: &MemoryCategory, default_days: u32) -> u32 {
let key = category.to_string();
self.config
.retention_days_by_category
.get(&key)
.copied()
.unwrap_or(default_days)
}
}
/// Policy violation errors.
#[derive(Debug, Clone)]
pub enum PolicyViolation {
ReadOnlyNamespace(String),
NamespaceQuotaExceeded { max: usize, current: usize },
CategoryQuotaExceeded { max: usize, current: usize },
}
impl std::fmt::Display for PolicyViolation {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ReadOnlyNamespace(ns) => write!(f, "namespace '{ns}' is read-only"),
Self::NamespaceQuotaExceeded { max, current } => {
write!(f, "namespace quota exceeded: {current}/{max} entries")
}
Self::CategoryQuotaExceeded { max, current } => {
write!(f, "category quota exceeded: {current}/{max} entries")
}
}
}
}
impl std::error::Error for PolicyViolation {}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
fn empty_policy() -> MemoryPolicyConfig {
MemoryPolicyConfig::default()
}
#[test]
fn default_policy_allows_everything() {
let enforcer = PolicyEnforcer::new(&empty_policy());
assert!(!enforcer.is_read_only("default"));
assert!(enforcer
.validate_store("default", &MemoryCategory::Core)
.is_ok());
assert!(enforcer.check_namespace_limit(100).is_ok());
assert!(enforcer.check_category_limit(100).is_ok());
}
#[test]
fn read_only_namespace_blocks_writes() {
let policy = MemoryPolicyConfig {
read_only_namespaces: vec!["archive".into()],
..empty_policy()
};
let enforcer = PolicyEnforcer::new(&policy);
assert!(enforcer.is_read_only("archive"));
assert!(!enforcer.is_read_only("default"));
assert!(enforcer
.validate_store("archive", &MemoryCategory::Core)
.is_err());
assert!(enforcer
.validate_store("default", &MemoryCategory::Core)
.is_ok());
}
#[test]
fn namespace_quota_enforced() {
let policy = MemoryPolicyConfig {
max_entries_per_namespace: 10,
..empty_policy()
};
let enforcer = PolicyEnforcer::new(&policy);
assert!(enforcer.check_namespace_limit(5).is_ok());
assert!(enforcer.check_namespace_limit(10).is_err());
assert!(enforcer.check_namespace_limit(15).is_err());
}
#[test]
fn category_quota_enforced() {
let policy = MemoryPolicyConfig {
max_entries_per_category: 50,
..empty_policy()
};
let enforcer = PolicyEnforcer::new(&policy);
assert!(enforcer.check_category_limit(25).is_ok());
assert!(enforcer.check_category_limit(50).is_err());
}
#[test]
fn per_category_retention_overrides_default() {
let mut retention = HashMap::new();
retention.insert("core".into(), 365);
retention.insert("conversation".into(), 7);
let policy = MemoryPolicyConfig {
retention_days_by_category: retention,
..empty_policy()
};
let enforcer = PolicyEnforcer::new(&policy);
assert_eq!(
enforcer.retention_days_for_category(&MemoryCategory::Core, 30),
365
);
assert_eq!(
enforcer.retention_days_for_category(&MemoryCategory::Conversation, 30),
7
);
assert_eq!(
enforcer.retention_days_for_category(&MemoryCategory::Daily, 30),
30
);
}
}
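All three quota paths share one pattern: a limit of `0` means "unlimited", and the check fails once the current count reaches the limit (note `>=`, so the check runs *before* inserting). A minimal standalone sketch of that convention (not the crate's `PolicyViolation` types):

```rust
// Quota check with the enforcer's convention: max == 0 disables the limit,
// and current >= max rejects, so the check is done before the insert.
fn check_limit(current: usize, max: usize) -> Result<(), String> {
    if max > 0 && current >= max {
        return Err(format!("quota exceeded: {current}/{max} entries"));
    }
    Ok(())
}

fn main() {
    assert!(check_limit(5, 10).is_ok());
    assert!(check_limit(10, 10).is_err()); // at capacity: reject the next insert
    assert!(check_limit(1_000_000, 0).is_ok()); // 0 means unlimited
    println!("quota checks ok");
}
```

The same zero-means-disabled convention appears in `prune_audit_entries`, which returns early when `retention_days == 0`.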
+12 -3
@@ -100,6 +100,8 @@ impl PostgresMemory {
CREATE INDEX IF NOT EXISTS idx_memories_category ON {qualified_table}(category);
CREATE INDEX IF NOT EXISTS idx_memories_session_id ON {qualified_table}(session_id);
CREATE INDEX IF NOT EXISTS idx_memories_updated_at ON {qualified_table}(updated_at DESC);
CREATE INDEX IF NOT EXISTS idx_memories_content_fts ON {qualified_table} USING gin(to_tsvector('simple', content));
CREATE INDEX IF NOT EXISTS idx_memories_key_fts ON {qualified_table} USING gin(to_tsvector('simple', key));
"
))?;
@@ -135,6 +137,9 @@ impl PostgresMemory {
timestamp: timestamp.to_rfc3339(),
session_id: row.get(5),
score: row.try_get(6).ok(),
namespace: "default".into(),
importance: None,
superseded_by: None,
})
}
}
@@ -267,12 +272,16 @@ impl Memory for PostgresMemory {
"
SELECT id, key, content, category, created_at, session_id,
(
CASE WHEN key ILIKE '%' || $1 || '%' THEN 2.0 ELSE 0.0 END +
CASE WHEN content ILIKE '%' || $1 || '%' THEN 1.0 ELSE 0.0 END
CASE WHEN to_tsvector('simple', key) @@ plainto_tsquery('simple', $1)
THEN ts_rank_cd(to_tsvector('simple', key), plainto_tsquery('simple', $1)) * 2.0
ELSE 0.0 END +
CASE WHEN to_tsvector('simple', content) @@ plainto_tsquery('simple', $1)
THEN ts_rank_cd(to_tsvector('simple', content), plainto_tsquery('simple', $1))
ELSE 0.0 END
) AS score
FROM {qualified_table}
WHERE ($2::TEXT IS NULL OR session_id = $2)
AND ($1 = '' OR key ILIKE '%' || $1 || '%' OR content ILIKE '%' || $1 || '%')
AND ($1 = '' OR to_tsvector('simple', key || ' ' || content) @@ plainto_tsquery('simple', $1))
{time_filter}
ORDER BY score DESC, updated_at DESC
LIMIT $3
+9
@@ -373,6 +373,9 @@ impl Memory for QdrantMemory {
timestamp: payload.timestamp,
session_id: payload.session_id,
score: Some(point.score),
namespace: "default".into(),
importance: None,
superseded_by: None,
})
})
.collect();
@@ -437,6 +440,9 @@ impl Memory for QdrantMemory {
timestamp: payload.timestamp,
session_id: payload.session_id,
score: None,
namespace: "default".into(),
importance: None,
superseded_by: None,
})
});
@@ -514,6 +520,9 @@ impl Memory for QdrantMemory {
timestamp: payload.timestamp,
session_id: payload.session_id,
score: None,
namespace: "default".into(),
importance: None,
superseded_by: None,
})
})
.collect();
+267
@@ -0,0 +1,267 @@
//! Multi-stage retrieval pipeline.
//!
//! Wraps a `Memory` trait object with staged retrieval:
//! - **Stage 1 (Hot cache):** In-memory LRU of recent recall results.
//! - **Stage 2 (FTS):** FTS5 keyword search with optional early-return.
//! - **Stage 3 (Vector):** Vector similarity search + hybrid merge.
//!
//! Configurable via `[memory]` settings: `retrieval_stages`, `fts_early_return_score`.
use super::traits::{Memory, MemoryEntry};
use parking_lot::Mutex;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// A cached recall result.
struct CachedResult {
entries: Vec<MemoryEntry>,
created_at: Instant,
}
/// Multi-stage retrieval pipeline configuration.
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
/// Ordered list of stages: "cache", "fts", "vector".
pub stages: Vec<String>,
/// FTS score above which to early-return without vector stage.
pub fts_early_return_score: f64,
/// Max entries in the hot cache.
pub cache_max_entries: usize,
/// TTL for cached results.
pub cache_ttl: Duration,
}
impl Default for RetrievalConfig {
fn default() -> Self {
Self {
stages: vec!["cache".into(), "fts".into(), "vector".into()],
fts_early_return_score: 0.85,
cache_max_entries: 256,
cache_ttl: Duration::from_secs(300),
}
}
}
/// Multi-stage retrieval pipeline wrapping a `Memory` backend.
pub struct RetrievalPipeline {
memory: Arc<dyn Memory>,
config: RetrievalConfig,
hot_cache: Mutex<HashMap<String, CachedResult>>,
}
impl RetrievalPipeline {
pub fn new(memory: Arc<dyn Memory>, config: RetrievalConfig) -> Self {
Self {
memory,
config,
hot_cache: Mutex::new(HashMap::new()),
}
}
/// Build a cache key from query parameters.
fn cache_key(
query: &str,
limit: usize,
session_id: Option<&str>,
namespace: Option<&str>,
) -> String {
format!(
"{}:{}:{}:{}",
query,
limit,
session_id.unwrap_or(""),
namespace.unwrap_or("")
)
}
/// Check the hot cache for a previous result.
fn check_cache(&self, key: &str) -> Option<Vec<MemoryEntry>> {
let cache = self.hot_cache.lock();
if let Some(cached) = cache.get(key) {
if cached.created_at.elapsed() < self.config.cache_ttl {
return Some(cached.entries.clone());
}
}
None
}
/// Store a result in the hot cache with LRU eviction.
fn store_in_cache(&self, key: String, entries: Vec<MemoryEntry>) {
let mut cache = self.hot_cache.lock();
// LRU eviction: remove oldest entries if at capacity
if cache.len() >= self.config.cache_max_entries {
let oldest_key = cache
.iter()
.min_by_key(|(_, v)| v.created_at)
.map(|(k, _)| k.clone());
if let Some(k) = oldest_key {
cache.remove(&k);
}
}
cache.insert(
key,
CachedResult {
entries,
created_at: Instant::now(),
},
);
}
/// Execute the multi-stage retrieval pipeline.
pub async fn recall(
&self,
query: &str,
limit: usize,
session_id: Option<&str>,
namespace: Option<&str>,
since: Option<&str>,
until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let ck = Self::cache_key(query, limit, session_id, namespace);
for stage in &self.config.stages {
match stage.as_str() {
"cache" => {
if let Some(cached) = self.check_cache(&ck) {
tracing::debug!("retrieval pipeline: cache hit for '{query}'");
return Ok(cached);
}
}
"fts" | "vector" => {
// Both FTS and vector are handled by the backend's recall method
// which already does hybrid merge. We delegate to it.
let results = if let Some(ns) = namespace {
self.memory
.recall_namespaced(ns, query, limit, session_id, since, until)
.await?
} else {
self.memory
.recall(query, limit, session_id, since, until)
.await?
};
if !results.is_empty() {
// Check for FTS early-return: if top score exceeds threshold
// and we're in the FTS stage, we can skip further stages
if stage == "fts" {
if let Some(top_score) = results.first().and_then(|e| e.score) {
if top_score >= self.config.fts_early_return_score {
tracing::debug!(
"retrieval pipeline: FTS early return (score={top_score:.3})"
);
self.store_in_cache(ck, results.clone());
return Ok(results);
}
}
}
self.store_in_cache(ck, results.clone());
return Ok(results);
}
}
other => {
tracing::warn!("retrieval pipeline: unknown stage '{other}', skipping");
}
}
}
// No results from any stage
Ok(Vec::new())
}
/// Invalidate the hot cache (e.g. after a store operation).
pub fn invalidate_cache(&self) {
self.hot_cache.lock().clear();
}
/// Get the number of entries in the hot cache.
pub fn cache_size(&self) -> usize {
self.hot_cache.lock().len()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::memory::NoneMemory;
#[tokio::test]
async fn pipeline_returns_empty_from_none_backend() {
let memory = Arc::new(NoneMemory::new());
let pipeline = RetrievalPipeline::new(memory, RetrievalConfig::default());
let results = pipeline
.recall("test", 10, None, None, None, None)
.await
.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn pipeline_cache_invalidation() {
let memory = Arc::new(NoneMemory::new());
let pipeline = RetrievalPipeline::new(memory, RetrievalConfig::default());
// Force a cache entry
let ck = RetrievalPipeline::cache_key("test", 10, None, None);
pipeline.store_in_cache(ck, vec![]);
assert_eq!(pipeline.cache_size(), 1);
pipeline.invalidate_cache();
assert_eq!(pipeline.cache_size(), 0);
}
#[test]
fn cache_key_includes_all_params() {
let k1 = RetrievalPipeline::cache_key("hello", 10, Some("sess-a"), Some("ns1"));
let k2 = RetrievalPipeline::cache_key("hello", 10, Some("sess-b"), Some("ns1"));
let k3 = RetrievalPipeline::cache_key("hello", 10, Some("sess-a"), Some("ns2"));
assert_ne!(k1, k2);
assert_ne!(k1, k3);
}
#[tokio::test]
async fn pipeline_caches_results() {
let memory = Arc::new(NoneMemory::new());
let config = RetrievalConfig {
stages: vec!["cache".into()],
..Default::default()
};
let pipeline = RetrievalPipeline::new(memory, config);
// First call: cache miss, no results
let results = pipeline
.recall("test", 10, None, None, None, None)
.await
.unwrap();
assert!(results.is_empty());
// Manually insert a cache entry
let ck = RetrievalPipeline::cache_key("cached_query", 5, None, None);
let fake_entry = MemoryEntry {
id: "1".into(),
key: "k".into(),
content: "cached content".into(),
category: crate::memory::MemoryCategory::Core,
timestamp: "now".into(),
session_id: None,
score: Some(0.9),
namespace: "default".into(),
importance: None,
superseded_by: None,
};
pipeline.store_in_cache(ck, vec![fake_entry]);
// Cache hit
let results = pipeline
.recall("cached_query", 5, None, None, None, None)
.await
.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].content, "cached content");
}
}
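The Stage-1 hot cache above combines a TTL check on read with oldest-entry eviction on write. A std-only sketch of that behavior, with illustrative names (the real pipeline stores `MemoryEntry` values behind a `parking_lot::Mutex`):

```rust
use std::collections::HashMap;
use std::thread::sleep;
use std::time::{Duration, Instant};

// Minimal sketch of the hot cache: reads ignore entries older than the TTL,
// writes evict the entry with the oldest creation time once at capacity.
struct HotCache {
    entries: HashMap<String, (Instant, Vec<String>)>,
    max_entries: usize,
    ttl: Duration,
}

impl HotCache {
    fn get(&self, key: &str) -> Option<&Vec<String>> {
        self.entries
            .get(key)
            .filter(|(created, _)| created.elapsed() < self.ttl)
            .map(|(_, v)| v)
    }

    fn put(&mut self, key: String, value: Vec<String>) {
        if self.entries.len() >= self.max_entries {
            // Evict the entry with the oldest creation time.
            let oldest = self
                .entries
                .iter()
                .min_by_key(|(_, (created, _))| *created)
                .map(|(k, _)| k.clone());
            if let Some(k) = oldest {
                self.entries.remove(&k);
            }
        }
        self.entries.insert(key, (Instant::now(), value));
    }
}

fn main() {
    let mut cache = HotCache {
        entries: HashMap::new(),
        max_entries: 2,
        ttl: Duration::from_secs(300),
    };
    cache.put("q1".into(), vec!["r1".into()]);
    sleep(Duration::from_millis(5)); // ensure distinct creation Instants
    cache.put("q2".into(), vec!["r2".into()]);
    sleep(Duration::from_millis(5));
    cache.put("q3".into(), vec!["r3".into()]); // at capacity: evicts q1 (oldest)
    assert!(cache.get("q1").is_none());
    assert!(cache.get("q3").is_some());
    assert_eq!(cache.entries.len(), 2);
}
```

Note that evicting by creation time is first-in-first-out rather than strict LRU: reads do not refresh an entry's timestamp.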
+154 -23
View File
@@ -46,6 +46,31 @@ impl SqliteMemory {
)
}
/// Like `new`, but stores data in `{db_name}.db` instead of `brain.db`.
pub fn new_named(workspace_dir: &Path, db_name: &str) -> anyhow::Result<Self> {
let db_path = workspace_dir.join("memory").join(format!("{db_name}.db"));
if let Some(parent) = db_path.parent() {
std::fs::create_dir_all(parent)?;
}
let conn = Self::open_connection(&db_path, None)?;
conn.execute_batch(
"PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
PRAGMA mmap_size = 8388608;
PRAGMA cache_size = -2000;
PRAGMA temp_store = MEMORY;",
)?;
Self::init_schema(&conn)?;
Ok(Self {
conn: Arc::new(Mutex::new(conn)),
db_path,
embedder: Arc::new(super::embeddings::NoopEmbedding),
vector_weight: 0.7,
keyword_weight: 0.3,
cache_max: 10_000,
})
}
/// Build SQLite memory with optional open timeout.
///
/// If `open_timeout_secs` is `Some(n)`, opening the database is limited to `n` seconds
@@ -172,17 +197,35 @@ impl SqliteMemory {
)?;
// Migration: add session_id column if not present (safe to run repeatedly)
let has_session_id: bool = conn
let schema_sql: String = conn
.prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")?
.query_row([], |row| row.get::<_, String>(0))?
.contains("session_id");
if !has_session_id {
.query_row([], |row| row.get::<_, String>(0))?;
if !schema_sql.contains("session_id") {
conn.execute_batch(
"ALTER TABLE memories ADD COLUMN session_id TEXT;
CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
)?;
}
// Migration: add namespace column
if !schema_sql.contains("namespace") {
conn.execute_batch(
"ALTER TABLE memories ADD COLUMN namespace TEXT DEFAULT 'default';
CREATE INDEX IF NOT EXISTS idx_memories_namespace ON memories(namespace);",
)?;
}
// Migration: add importance column
if !schema_sql.contains("importance") {
conn.execute_batch("ALTER TABLE memories ADD COLUMN importance REAL DEFAULT 0.5;")?;
}
// Migration: add superseded_by column
if !schema_sql.contains("superseded_by") {
conn.execute_batch("ALTER TABLE memories ADD COLUMN superseded_by TEXT;")?;
}
Ok(())
}
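The migrations above key each `ALTER TABLE` off the stored `CREATE TABLE` text, so running them repeatedly is a no-op: after the first run the column name appears in the schema SQL. A std-only sketch of that check, with a plain string standing in for the `sqlite_master` row:

```rust
// Sketch: decide which migrations still need to run by inspecting the
// schema SQL. Idempotent, because each added column makes its own
// `contains` check succeed on the next run.
fn pending_migrations(schema_sql: &str) -> Vec<&'static str> {
    let mut stmts = Vec::new();
    if !schema_sql.contains("session_id") {
        stmts.push("ALTER TABLE memories ADD COLUMN session_id TEXT;");
    }
    if !schema_sql.contains("namespace") {
        stmts.push("ALTER TABLE memories ADD COLUMN namespace TEXT DEFAULT 'default';");
    }
    if !schema_sql.contains("importance") {
        stmts.push("ALTER TABLE memories ADD COLUMN importance REAL DEFAULT 0.5;");
    }
    stmts
}

fn main() {
    let old = "CREATE TABLE memories (id TEXT, key TEXT, content TEXT, session_id TEXT)";
    let migrated = "CREATE TABLE memories (id TEXT, key TEXT, content TEXT, \
                    session_id TEXT, namespace TEXT, importance REAL)";
    assert_eq!(pending_migrations(old).len(), 2); // namespace + importance missing
    assert!(pending_migrations(migrated).is_empty());
}
```

One caveat of the substring approach: it can false-positive if a column name happens to appear elsewhere in the schema text (e.g. inside another column's name), so new column names should be chosen to avoid such overlaps.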
@@ -221,8 +264,13 @@ impl SqliteMemory {
)
}
/// Provide access to the connection for advanced queries (e.g. retrieval pipeline).
pub fn connection(&self) -> &Arc<Mutex<Connection>> {
&self.conn
}
/// Get embedding from cache, or compute + cache it
async fn get_or_compute_embedding(&self, text: &str) -> anyhow::Result<Option<Vec<f32>>> {
pub async fn get_or_compute_embedding(&self, text: &str) -> anyhow::Result<Option<Vec<f32>>> {
if self.embedder.dimensions() == 0 {
return Ok(None); // Noop embedder
}
@@ -285,7 +333,7 @@ impl SqliteMemory {
}
/// FTS5 BM25 keyword search
fn fts5_search(
pub fn fts5_search(
conn: &Connection,
query: &str,
limit: usize,
@@ -331,7 +379,7 @@ impl SqliteMemory {
///
/// Optional `category` and `session_id` filters reduce full-table scans
/// when the caller already knows the scope of relevant memories.
fn vector_search(
pub fn vector_search(
conn: &Connection,
query_embedding: &[f32],
limit: usize,
@@ -448,8 +496,8 @@ impl SqliteMemory {
let until_ref = until_owned.as_deref();
let mut sql =
"SELECT id, key, content, category, created_at, session_id FROM memories \
WHERE 1=1"
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories \
WHERE superseded_by IS NULL AND 1=1"
.to_string();
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
let mut idx = 1;
@@ -485,6 +533,9 @@ impl SqliteMemory {
timestamp: row.get(4)?,
session_id: row.get(5)?,
score: None,
namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
importance: row.get(7)?,
superseded_by: row.get(8)?,
})
})?;
@@ -529,8 +580,8 @@ impl Memory for SqliteMemory {
let id = Uuid::new_v4().to_string();
conn.execute(
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id, namespace, importance)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, 'default', 0.5)
ON CONFLICT(key) DO UPDATE SET
content = excluded.content,
category = excluded.category,
@@ -616,8 +667,8 @@ impl Memory for SqliteMemory {
.collect::<Vec<_>>()
.join(", ");
let sql = format!(
"SELECT id, key, content, category, created_at, session_id \
FROM memories WHERE id IN ({placeholders})"
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by \
FROM memories WHERE superseded_by IS NULL AND id IN ({placeholders})"
);
let mut stmt = conn.prepare(&sql)?;
let id_params: Vec<Box<dyn rusqlite::types::ToSql>> = merged
@@ -634,17 +685,20 @@ impl Memory for SqliteMemory {
row.get::<_, String>(3)?,
row.get::<_, String>(4)?,
row.get::<_, Option<String>>(5)?,
row.get::<_, Option<String>>(6)?,
row.get::<_, Option<f64>>(7)?,
row.get::<_, Option<String>>(8)?,
))
})?;
let mut entry_map = std::collections::HashMap::new();
for row in rows {
let (id, key, content, cat, ts, sid) = row?;
entry_map.insert(id, (key, content, cat, ts, sid));
let (id, key, content, cat, ts, sid, ns, imp, sup) = row?;
entry_map.insert(id, (key, content, cat, ts, sid, ns, imp, sup));
}
for scored in &merged {
if let Some((key, content, cat, ts, sid)) = entry_map.remove(&scored.id) {
if let Some((key, content, cat, ts, sid, ns, imp, sup)) = entry_map.remove(&scored.id) {
if let Some(s) = since_ref {
if ts.as_str() < s {
continue;
@@ -663,6 +717,9 @@ impl Memory for SqliteMemory {
timestamp: ts,
session_id: sid,
score: Some(f64::from(scored.final_score)),
namespace: ns.unwrap_or_else(|| "default".into()),
importance: imp,
superseded_by: sup,
};
if let Some(filter_sid) = session_ref {
if entry.session_id.as_deref() != Some(filter_sid) {
@@ -702,8 +759,8 @@ impl Memory for SqliteMemory {
param_idx += 1;
}
let sql = format!(
"SELECT id, key, content, category, created_at, session_id FROM memories
WHERE {where_clause}{time_conditions}
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories
WHERE superseded_by IS NULL AND ({where_clause}){time_conditions}
ORDER BY updated_at DESC
LIMIT ?{param_idx}"
);
@@ -732,6 +789,9 @@ impl Memory for SqliteMemory {
timestamp: row.get(4)?,
session_id: row.get(5)?,
score: Some(1.0),
namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
importance: row.get(7)?,
superseded_by: row.get(8)?,
})
})?;
for row in rows {
@@ -759,7 +819,7 @@ impl Memory for SqliteMemory {
tokio::task::spawn_blocking(move || -> anyhow::Result<Option<MemoryEntry>> {
let conn = conn.lock();
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories WHERE key = ?1",
)?;
let mut rows = stmt.query_map(params![key], |row| {
@@ -771,6 +831,9 @@ impl Memory for SqliteMemory {
timestamp: row.get(4)?,
session_id: row.get(5)?,
score: None,
namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
importance: row.get(7)?,
superseded_by: row.get(8)?,
})
})?;
@@ -807,14 +870,17 @@ impl Memory for SqliteMemory {
timestamp: row.get(4)?,
session_id: row.get(5)?,
score: None,
namespace: row.get::<_, Option<String>>(6)?.unwrap_or_else(|| "default".into()),
importance: row.get(7)?,
superseded_by: row.get(8)?,
})
};
if let Some(ref cat) = category {
let cat_str = Self::category_to_str(cat);
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at, session_id FROM memories
WHERE category = ?1 ORDER BY updated_at DESC LIMIT ?2",
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories
WHERE superseded_by IS NULL AND category = ?1 ORDER BY updated_at DESC LIMIT ?2",
)?;
let rows = stmt.query_map(params![cat_str, DEFAULT_LIST_LIMIT], row_mapper)?;
for row in rows {
@@ -828,8 +894,8 @@ impl Memory for SqliteMemory {
}
} else {
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at, session_id FROM memories
ORDER BY updated_at DESC LIMIT ?1",
"SELECT id, key, content, category, created_at, session_id, namespace, importance, superseded_by FROM memories
WHERE superseded_by IS NULL ORDER BY updated_at DESC LIMIT ?1",
)?;
let rows = stmt.query_map(params![DEFAULT_LIST_LIMIT], row_mapper)?;
for row in rows {
@@ -879,6 +945,71 @@ impl Memory for SqliteMemory {
.await
.unwrap_or(false)
}
async fn recall_namespaced(
&self,
namespace: &str,
query: &str,
limit: usize,
session_id: Option<&str>,
since: Option<&str>,
until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let entries = self
.recall(query, limit * 2, session_id, since, until)
.await?;
let filtered: Vec<MemoryEntry> = entries
.into_iter()
.filter(|e| e.namespace == namespace)
.take(limit)
.collect();
Ok(filtered)
}
async fn store_with_metadata(
&self,
key: &str,
content: &str,
category: MemoryCategory,
session_id: Option<&str>,
namespace: Option<&str>,
importance: Option<f64>,
) -> anyhow::Result<()> {
let embedding_bytes = self
.get_or_compute_embedding(content)
.await?
.map(|emb| vector::vec_to_bytes(&emb));
let conn = self.conn.clone();
let key = key.to_string();
let content = content.to_string();
let sid = session_id.map(String::from);
let ns = namespace.unwrap_or("default").to_string();
let imp = importance.unwrap_or(0.5);
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
let conn = conn.lock();
let now = Local::now().to_rfc3339();
let cat = Self::category_to_str(&category);
let id = Uuid::new_v4().to_string();
conn.execute(
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id, namespace, importance)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
ON CONFLICT(key) DO UPDATE SET
content = excluded.content,
category = excluded.category,
embedding = excluded.embedding,
updated_at = excluded.updated_at,
session_id = excluded.session_id,
namespace = excluded.namespace,
importance = excluded.importance",
params![id, key, content, cat, embedding_bytes, now, now, sid, ns, imp],
)?;
Ok(())
})
.await?
}
}
#[cfg(test)]
+63 -2
View File
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
/// A single message in a conversation trace for procedural memory.
///
/// Used to capture "how to" patterns from tool-calling turns so that
/// backends that support procedural storage (e.g. mem0) can learn from them.
/// backends that support procedural storage can learn from them.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ProceduralMessage {
pub role: String,
@@ -23,6 +23,19 @@ pub struct MemoryEntry {
pub timestamp: String,
pub session_id: Option<String>,
pub score: Option<f64>,
/// Namespace for isolation between agents/contexts.
#[serde(default = "default_namespace")]
pub namespace: String,
/// Importance score (0.0 to 1.0) for prioritized retrieval.
#[serde(default)]
pub importance: Option<f64>,
/// If this entry was superseded by a newer conflicting entry.
#[serde(default)]
pub superseded_by: Option<String>,
}
fn default_namespace() -> String {
"default".into()
}
impl std::fmt::Debug for MemoryEntry {
@@ -34,6 +47,8 @@ impl std::fmt::Debug for MemoryEntry {
.field("category", &self.category)
.field("timestamp", &self.timestamp)
.field("score", &self.score)
.field("namespace", &self.namespace)
.field("importance", &self.importance)
.finish_non_exhaustive()
}
}
@@ -128,7 +143,7 @@ pub trait Memory: Send + Sync {
/// Store a conversation trace as procedural memory.
///
/// Backends that support procedural storage (e.g. mem0) override this
/// Backends that support procedural storage override this
/// to extract "how to" patterns from tool-calling turns. The default
/// implementation is a no-op.
async fn store_procedural(
@@ -138,6 +153,46 @@ pub trait Memory: Send + Sync {
) -> anyhow::Result<()> {
Ok(())
}
/// Recall memories scoped to a specific namespace.
///
/// Default implementation delegates to `recall()` and filters by namespace.
/// Backends with native namespace support should override for efficiency.
async fn recall_namespaced(
&self,
namespace: &str,
query: &str,
limit: usize,
session_id: Option<&str>,
since: Option<&str>,
until: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let entries = self
.recall(query, limit * 2, session_id, since, until)
.await?;
let filtered: Vec<MemoryEntry> = entries
.into_iter()
.filter(|e| e.namespace == namespace)
.take(limit)
.collect();
Ok(filtered)
}
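The default `recall_namespaced` above over-fetches (`limit * 2`) and then filters, which can still return fewer than `limit` hits when most results belong to other namespaces; backends with native namespace columns avoid this by filtering in the query. A std-only sketch of the filter step, with an illustrative `Entry` type standing in for `MemoryEntry`:

```rust
// Sketch of the default namespace filter: over-fetch upstream, keep only
// entries in the requested namespace, cap at the original limit.
#[derive(Clone)]
struct Entry {
    namespace: String,
    content: String,
}

fn filter_namespace(entries: Vec<Entry>, namespace: &str, limit: usize) -> Vec<Entry> {
    entries
        .into_iter()
        .filter(|e| e.namespace == namespace)
        .take(limit)
        .collect()
}

fn main() {
    let entries = vec![
        Entry { namespace: "default".into(), content: "a".into() },
        Entry { namespace: "agent-2".into(), content: "b".into() },
        Entry { namespace: "default".into(), content: "c".into() },
    ];
    let hits = filter_namespace(entries, "default", 10);
    assert_eq!(hits.len(), 2);
    assert_eq!(hits[0].content, "a");
}
```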
/// Store a memory entry with namespace and importance.
///
/// Default implementation delegates to `store()`. Backends with native
/// namespace/importance support should override.
async fn store_with_metadata(
&self,
key: &str,
content: &str,
category: MemoryCategory,
session_id: Option<&str>,
_namespace: Option<&str>,
_importance: Option<f64>,
) -> anyhow::Result<()> {
self.store(key, content, category, session_id).await
}
}
#[cfg(test)]
@@ -185,6 +240,9 @@ mod tests {
timestamp: "2026-02-16T00:00:00Z".into(),
session_id: Some("session-abc".into()),
score: Some(0.98),
namespace: "default".into(),
importance: Some(0.7),
superseded_by: None,
};
let json = serde_json::to_string(&entry).unwrap();
@@ -196,5 +254,8 @@ mod tests {
assert_eq!(parsed.category, MemoryCategory::Core);
assert_eq!(parsed.session_id.as_deref(), Some("session-abc"));
assert_eq!(parsed.score, Some(0.98));
assert_eq!(parsed.namespace, "default");
assert_eq!(parsed.importance, Some(0.7));
assert!(parsed.superseded_by.is_none());
}
}
+1
View File
@@ -126,6 +126,7 @@ pub fn hybrid_merge(
b.final_score
.partial_cmp(&a.final_score)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.id.cmp(&b.id))
});
results.truncate(limit);
results
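The added `.then_with` tie-break makes the ordering deterministic when two results share a `final_score`; `partial_cmp` alone reports such pairs as equal, leaving their relative order up to the sort. A standalone sketch with an illustrative `Scored` type:

```rust
// Sketch: sort descending by score, breaking exact score ties by id so
// the result order is stable across runs.
struct Scored {
    id: String,
    final_score: f32,
}

fn sort_results(results: &mut Vec<Scored>) {
    results.sort_by(|a, b| {
        b.final_score
            .partial_cmp(&a.final_score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.id.cmp(&b.id))
    });
}

fn main() {
    let mut results = vec![
        Scored { id: "b".into(), final_score: 0.5 },
        Scored { id: "a".into(), final_score: 0.5 },
        Scored { id: "c".into(), final_score: 0.9 },
    ];
    sort_results(&mut results);
    let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
    assert_eq!(ids, ["c", "a", "b"]); // highest score first, ties by id
}
```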
+2
View File
@@ -507,6 +507,7 @@ mod tests {
max_images: 1,
max_image_size_mb: 5,
allow_remote_fetch: false,
..Default::default()
};
let error = prepare_messages_for_provider(&messages, &config)
@@ -549,6 +550,7 @@ mod tests {
max_images: 4,
max_image_size_mb: 1,
allow_remote_fetch: false,
..Default::default()
};
let error = prepare_messages_for_provider(&messages, &config)
+16 -1
View File
@@ -154,6 +154,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
reliability: crate::config::ReliabilityConfig::default(),
scheduler: crate::config::schema::SchedulerConfig::default(),
agent: crate::config::schema::AgentConfig::default(),
pacing: crate::config::PacingConfig::default(),
skills: crate::config::SkillsConfig::default(),
model_routes: Vec::new(),
embedding_routes: Vec::new(),
@@ -172,6 +173,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
http_request: crate::config::HttpRequestConfig::default(),
multimodal: crate::config::MultimodalConfig::default(),
web_fetch: crate::config::WebFetchConfig::default(),
link_enricher: crate::config::LinkEnricherConfig::default(),
text_browser: crate::config::TextBrowserConfig::default(),
web_search: crate::config::WebSearchConfig::default(),
project_intel: crate::config::ProjectIntelConfig::default(),
@@ -196,6 +198,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
node_transport: crate::config::NodeTransportConfig::default(),
knowledge: crate::config::KnowledgeConfig::default(),
linkedin: crate::config::LinkedInConfig::default(),
image_gen: crate::config::ImageGenConfig::default(),
plugins: crate::config::PluginsConfig::default(),
locale: None,
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
@@ -418,9 +421,17 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
snapshot_enabled: false,
snapshot_on_hygiene: false,
auto_hydrate: true,
retrieval_stages: vec!["cache".into(), "fts".into(), "vector".into()],
rerank_enabled: false,
rerank_threshold: 5,
fts_early_return_score: 0.85,
default_namespace: "default".into(),
conflict_threshold: 0.85,
audit_enabled: false,
audit_retention_days: 30,
policy: crate::config::MemoryPolicyConfig::default(),
sqlite_open_timeout_secs: None,
qdrant: crate::config::QdrantConfig::default(),
mem0: crate::config::schema::Mem0Config::default(),
}
}
@@ -576,6 +587,7 @@ async fn run_quick_setup_with_home(
reliability: crate::config::ReliabilityConfig::default(),
scheduler: crate::config::schema::SchedulerConfig::default(),
agent: crate::config::schema::AgentConfig::default(),
pacing: crate::config::PacingConfig::default(),
skills: crate::config::SkillsConfig::default(),
model_routes: Vec::new(),
embedding_routes: Vec::new(),
@@ -594,6 +606,7 @@ async fn run_quick_setup_with_home(
http_request: crate::config::HttpRequestConfig::default(),
multimodal: crate::config::MultimodalConfig::default(),
web_fetch: crate::config::WebFetchConfig::default(),
link_enricher: crate::config::LinkEnricherConfig::default(),
text_browser: crate::config::TextBrowserConfig::default(),
web_search: crate::config::WebSearchConfig::default(),
project_intel: crate::config::ProjectIntelConfig::default(),
@@ -618,6 +631,7 @@ async fn run_quick_setup_with_home(
node_transport: crate::config::NodeTransportConfig::default(),
knowledge: crate::config::KnowledgeConfig::default(),
linkedin: crate::config::LinkedInConfig::default(),
image_gen: crate::config::ImageGenConfig::default(),
plugins: crate::config::PluginsConfig::default(),
locale: None,
verifiable_intent: crate::config::VerifiableIntentConfig::default(),
@@ -4181,6 +4195,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
device_id: detected_device_id,
room_id,
allowed_users,
allowed_rooms: vec![],
interrupt_on_new_message: false,
});
}
+1
View File
@@ -412,6 +412,7 @@ mod tests {
));
let mut f = std::fs::File::create(&path).unwrap();
writeln!(f, "#!/bin/sh\ncat /dev/stdin").unwrap();
f.sync_all().unwrap();
drop(f);
#[cfg(unix)]
{
+10
View File
@@ -767,6 +767,12 @@ impl Provider for OpenAiCodexProvider {
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex;
/// Mutex that serializes all tests which mutate process-global env vars
/// (`std::env::set_var` / `remove_var`). Each such test must hold this
/// lock for its entire duration so that parallel test threads don't race.
static ENV_MUTEX: Mutex<()> = Mutex::new(());
struct EnvGuard {
key: &'static str,
@@ -841,6 +847,7 @@ mod tests {
#[test]
fn resolve_responses_url_prefers_explicit_endpoint_env() {
let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
let _endpoint_guard = EnvGuard::set(
CODEX_RESPONSES_URL_ENV,
Some("https://env.example.com/v1/responses"),
@@ -856,6 +863,7 @@ mod tests {
#[test]
fn resolve_responses_url_uses_provider_api_url_override() {
let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
let _endpoint_guard = EnvGuard::set(CODEX_RESPONSES_URL_ENV, None);
let _base_guard = EnvGuard::set(CODEX_BASE_URL_ENV, None);
@@ -959,6 +967,7 @@ mod tests {
#[test]
fn resolve_reasoning_effort_prefers_configured_override() {
let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
let _guard = EnvGuard::set("ZEROCLAW_CODEX_REASONING_EFFORT", Some("low"));
assert_eq!(
resolve_reasoning_effort("gpt-5-codex", Some("high")),
@@ -968,6 +977,7 @@ mod tests {
#[test]
fn resolve_reasoning_effort_uses_legacy_env_when_unconfigured() {
let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
let _guard = EnvGuard::set("ZEROCLAW_CODEX_REASONING_EFFORT", Some("minimal"));
assert_eq!(
resolve_reasoning_effort("gpt-5-codex", None),
+1
View File
@@ -108,6 +108,7 @@ fn is_context_window_exceeded(err: &anyhow::Error) -> bool {
"token limit exceeded",
"prompt is too long",
"input is too long",
"prompt exceeds max length",
];
hints.iter().any(|hint| lower.contains(hint))
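The hint list is matched case-insensitively against the error text. A minimal standalone version of the check (the real function takes an `anyhow::Error` and lowercases its display string; this sketch takes the message directly):

```rust
// Sketch: case-insensitive substring match against known
// context-window-exceeded phrasings from different providers.
fn is_context_window_exceeded(message: &str) -> bool {
    let lower = message.to_lowercase();
    let hints = [
        "context length",
        "token limit exceeded",
        "prompt is too long",
        "input is too long",
        "prompt exceeds max length",
    ];
    hints.iter().any(|hint| lower.contains(hint))
}

fn main() {
    assert!(is_context_window_exceeded(
        "Error: Prompt exceeds max length (131072 tokens)"
    ));
    assert!(!is_context_window_exceeded("connection reset by peer"));
}
```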
+17 -2
View File
@@ -97,7 +97,8 @@ pub struct SecurityPolicy {
/// Default allowed commands for Unix platforms.
#[cfg(not(target_os = "windows"))]
fn default_allowed_commands() -> Vec<String> {
vec![
#[allow(unused_mut)]
let mut cmds = vec![
"git".into(),
"npm".into(),
"cargo".into(),
@@ -111,7 +112,16 @@ fn default_allowed_commands() -> Vec<String> {
"head".into(),
"tail".into(),
"date".into(),
]
"df".into(),
"du".into(),
"uname".into(),
"uptime".into(),
"hostname".into(),
];
// `free` is Linux-only; it does not exist on macOS or other BSDs.
#[cfg(target_os = "linux")]
cmds.push("free".into());
cmds
}
/// Default allowed commands for Windows platforms.
@@ -142,6 +152,11 @@ fn default_allowed_commands() -> Vec<String> {
"wc".into(),
"head".into(),
"tail".into(),
"df".into(),
"du".into(),
"uname".into(),
"uptime".into(),
"hostname".into(),
]
}
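The `#[cfg(target_os = "linux")]` push above gates `free` at compile time. An alternative sketch of the same platform gating using `cfg!`, which, unlike the attribute form, keeps the non-matching branch type-checked (command list abbreviated, names illustrative):

```rust
// Sketch: platform-gated command allowlist. cfg! evaluates to a compile-time
// bool, so the push is optimized away on non-Linux targets while the code
// still compiles everywhere.
fn default_allowed_commands() -> Vec<String> {
    let mut cmds: Vec<String> = ["git", "ls", "cat", "df", "du", "uname", "uptime", "hostname"]
        .iter()
        .map(|s| (*s).to_string())
        .collect();
    // `free` reads /proc/meminfo and exists only on Linux.
    if cfg!(target_os = "linux") {
        cmds.push("free".into());
    }
    cmds
}

fn main() {
    let cmds = default_allowed_commands();
    assert!(cmds.contains(&"git".to_string()));
    assert_eq!(cmds.contains(&"free".to_string()), cfg!(target_os = "linux"));
}
```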
+3 -3
View File
@@ -1236,7 +1236,7 @@ mod tests {
#[cfg(not(target_os = "windows"))]
#[test]
fn run_capture_reads_stdout() {
let out = run_capture(Command::new("sh").args(["-lc", "echo hello"]))
let out = run_capture(Command::new("sh").args(["-c", "echo hello"]))
.expect("stdout capture should succeed");
assert_eq!(out.trim(), "hello");
}
@@ -1244,7 +1244,7 @@ mod tests {
#[cfg(not(target_os = "windows"))]
#[test]
fn run_capture_falls_back_to_stderr() {
let out = run_capture(Command::new("sh").args(["-lc", "echo warn 1>&2"]))
let out = run_capture(Command::new("sh").args(["-c", "echo warn 1>&2"]))
.expect("stderr capture should succeed");
assert_eq!(out.trim(), "warn");
}
@@ -1252,7 +1252,7 @@ mod tests {
#[cfg(not(target_os = "windows"))]
#[test]
fn run_checked_errors_on_non_zero_status() {
let err = run_checked(Command::new("sh").args(["-lc", "exit 17"]))
let err = run_checked(Command::new("sh").args(["-c", "exit 17"]))
.expect_err("non-zero exit should error");
assert!(err.to_string().contains("Command failed"));
}
+8 -2
View File
@@ -106,8 +106,14 @@ impl SkillCreator {
// Trim leading/trailing hyphens, then truncate.
let trimmed = collapsed.trim_matches('-');
if trimmed.len() > 64 {
// Truncate at a hyphen boundary if possible.
let truncated = &trimmed[..64];
// Find the nearest valid character boundary at or before 64 bytes.
let safe_index = trimmed
.char_indices()
.map(|(i, _)| i)
.take_while(|&i| i <= 64)
.last()
.unwrap_or(0);
let truncated = &trimmed[..safe_index];
truncated.trim_end_matches('-').to_string()
} else {
trimmed.to_string()
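The fix above replaces a raw byte slice (`&trimmed[..64]`), which panics whenever byte 64 falls inside a multi-byte UTF-8 character, with a scan for the last character boundary at or before 64. The same logic as a standalone helper:

```rust
// Sketch: truncate a string to at most `max` bytes without slicing
// mid-character. A plain `&s[..max]` panics if `max` is not a char boundary.
fn truncate_at_char_boundary(s: &str, max: usize) -> &str {
    if s.len() <= max {
        return s;
    }
    let safe_index = s
        .char_indices()
        .map(|(i, _)| i)
        .take_while(|&i| i <= max)
        .last()
        .unwrap_or(0);
    &s[..safe_index]
}

fn main() {
    // ASCII: a byte cut at 3 is already a char boundary.
    assert_eq!(truncate_at_char_boundary("abcdef", 3), "abc");
    // 'é' is 2 bytes: byte 3 is a boundary here ('h' + 'é'), so nothing is split.
    assert_eq!(truncate_at_char_boundary("héllo", 3), "hé");
    // Shorter than the limit: returned unchanged.
    assert_eq!(truncate_at_char_boundary("hé", 10), "hé");
}
```

Since nightly Rust has `str::floor_char_boundary` for exactly this, the manual scan is the stable-channel equivalent.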
+89 -14
View File
@@ -738,15 +738,47 @@ pub fn skills_to_prompt_with_mode(
}
if !skill.tools.is_empty() {
let _ = writeln!(prompt, " <tools>");
for tool in &skill.tools {
let _ = writeln!(prompt, " <tool>");
write_xml_text_element(&mut prompt, 8, "name", &tool.name);
write_xml_text_element(&mut prompt, 8, "description", &tool.description);
write_xml_text_element(&mut prompt, 8, "kind", &tool.kind);
let _ = writeln!(prompt, " </tool>");
// Tools with known kinds (shell, script, http) are registered as
// callable tool specs and can be invoked directly via function calling.
// We list them here for context and mark them as callable.
let registered: Vec<_> = skill
.tools
.iter()
.filter(|t| matches!(t.kind.as_str(), "shell" | "script" | "http"))
.collect();
let unregistered: Vec<_> = skill
.tools
.iter()
.filter(|t| !matches!(t.kind.as_str(), "shell" | "script" | "http"))
.collect();
if !registered.is_empty() {
let _ = writeln!(prompt, " <callable_tools hint=\"These are registered as callable tool specs. Invoke them directly by name ({{}}.{{}}) instead of using shell.\">");
for tool in &registered {
let _ = writeln!(prompt, " <tool>");
write_xml_text_element(
&mut prompt,
8,
"name",
&format!("{}.{}", skill.name, tool.name),
);
write_xml_text_element(&mut prompt, 8, "description", &tool.description);
let _ = writeln!(prompt, " </tool>");
}
let _ = writeln!(prompt, " </callable_tools>");
}
if !unregistered.is_empty() {
let _ = writeln!(prompt, " <tools>");
for tool in &unregistered {
let _ = writeln!(prompt, " <tool>");
write_xml_text_element(&mut prompt, 8, "name", &tool.name);
write_xml_text_element(&mut prompt, 8, "description", &tool.description);
write_xml_text_element(&mut prompt, 8, "kind", &tool.kind);
let _ = writeln!(prompt, " </tool>");
}
let _ = writeln!(prompt, " </tools>");
}
let _ = writeln!(prompt, " </tools>");
}
let _ = writeln!(prompt, " </skill>");
@@ -756,6 +788,47 @@ pub fn skills_to_prompt_with_mode(
prompt
}
/// Convert skill tools into callable `Tool` trait objects.
///
/// Each skill's `[[tools]]` entries are converted to either `SkillShellTool`
/// (for `shell`/`script` kinds) or `SkillHttpTool` (for `http` kind),
/// enabling them to appear as first-class callable tool specs rather than
/// only as XML in the system prompt.
pub fn skills_to_tools(
skills: &[Skill],
security: std::sync::Arc<crate::security::SecurityPolicy>,
) -> Vec<Box<dyn crate::tools::traits::Tool>> {
let mut tools: Vec<Box<dyn crate::tools::traits::Tool>> = Vec::new();
for skill in skills {
for tool in &skill.tools {
match tool.kind.as_str() {
"shell" | "script" => {
tools.push(Box::new(crate::tools::skill_tool::SkillShellTool::new(
&skill.name,
tool,
security.clone(),
)));
}
"http" => {
tools.push(Box::new(crate::tools::skill_http::SkillHttpTool::new(
&skill.name,
tool,
)));
}
other => {
tracing::warn!(
"Unknown skill tool kind '{}' for {}.{}, skipping",
other,
skill.name,
tool.name
);
}
}
}
}
tools
}
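`skills_to_tools` dispatches on the `kind` string, and the prompt rendering above performs the same split: registered kinds become callable tool specs, anything else stays as XML context. A std-only sketch of that partition with an illustrative `SkillTool` type:

```rust
// Sketch: split skill tools into registered kinds (become callable tool
// specs) and everything else (kept as XML context only).
struct SkillTool {
    name: String,
    kind: String,
}

fn partition_tools(tools: &[SkillTool]) -> (Vec<&SkillTool>, Vec<&SkillTool>) {
    tools
        .iter()
        .partition(|t| matches!(t.kind.as_str(), "shell" | "script" | "http"))
}

fn main() {
    let tools = vec![
        SkillTool { name: "run".into(), kind: "shell".into() },
        SkillTool { name: "fetch".into(), kind: "http".into() },
        SkillTool { name: "mystery".into(), kind: "wasm".into() },
    ];
    let (registered, unregistered) = partition_tools(&tools);
    assert_eq!(registered.len(), 2);
    assert_eq!(unregistered[0].name, "mystery");
}
```

Using `Iterator::partition` keeps the two passes over `skill.tools` in the original as a single pass, at the cost of collecting both sides eagerly.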
/// Get the skills directory path
pub fn skills_dir(workspace_dir: &Path) -> PathBuf {
workspace_dir.join("skills")
@@ -1517,10 +1590,10 @@ command = "echo hello"
assert!(prompt.contains("read_skill(name)"));
assert!(!prompt.contains("<instructions>"));
assert!(!prompt.contains("<instruction>Do the thing.</instruction>"));
// Compact mode should still include tools so the LLM knows about them
assert!(prompt.contains("<tools>"));
assert!(prompt.contains("<name>run</name>"));
assert!(prompt.contains("<kind>shell</kind>"));
// Compact mode should still include tools so the LLM knows about them.
// Registered tools (shell/script/http) appear under <callable_tools>.
assert!(prompt.contains("<callable_tools"));
assert!(prompt.contains("<name>test.run</name>"));
}
#[test]
@@ -1710,9 +1783,11 @@ description = "Bare minimum"
}];
let prompt = skills_to_prompt(&skills, Path::new("/tmp"));
assert!(prompt.contains("weather"));
assert!(prompt.contains("<name>get_weather</name>"));
// Registered tools (shell kind) now appear under <callable_tools> with
// prefixed names (skill_name.tool_name).
assert!(prompt.contains("<callable_tools"));
assert!(prompt.contains("<name>weather.get_weather</name>"));
assert!(prompt.contains("<description>Fetch forecast</description>"));
assert!(prompt.contains("<kind>shell</kind>"));
}
#[test]
+636
View File
@@ -0,0 +1,636 @@
//! Live Canvas (A2UI) tool — push rendered content to a web canvas in real time.
//!
//! The agent can render HTML/SVG/Markdown to a named canvas, snapshot its
//! current state, clear it, or evaluate a JavaScript expression in the canvas
//! context. Content is stored in a shared [`CanvasStore`] and broadcast to
//! connected WebSocket clients via per-canvas channels.
use super::traits::{Tool, ToolResult};
use async_trait::async_trait;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::broadcast;
/// Maximum content size per canvas frame (256 KB).
pub const MAX_CONTENT_SIZE: usize = 256 * 1024;
/// Maximum number of history frames kept per canvas.
const MAX_HISTORY_FRAMES: usize = 50;
/// Broadcast channel capacity per canvas.
const BROADCAST_CAPACITY: usize = 64;
/// Maximum number of concurrent canvases to prevent memory exhaustion.
const MAX_CANVAS_COUNT: usize = 100;
/// Allowed content types for canvas frames via the REST API.
pub const ALLOWED_CONTENT_TYPES: &[&str] = &["html", "svg", "markdown", "text"];
/// A single canvas frame (one render).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CanvasFrame {
/// Unique frame identifier.
pub frame_id: String,
/// Content type: `html`, `svg`, `markdown`, or `text`.
pub content_type: String,
/// The rendered content.
pub content: String,
/// ISO-8601 timestamp of when the frame was created.
pub timestamp: String,
}
/// Per-canvas state: current content + history + broadcast sender.
struct CanvasEntry {
current: Option<CanvasFrame>,
history: Vec<CanvasFrame>,
tx: broadcast::Sender<CanvasFrame>,
}
/// Shared canvas store — holds all active canvases.
///
/// Thread-safe and cheaply cloneable (wraps `Arc`).
#[derive(Clone)]
pub struct CanvasStore {
inner: Arc<RwLock<HashMap<String, CanvasEntry>>>,
}
impl Default for CanvasStore {
fn default() -> Self {
Self::new()
}
}
impl CanvasStore {
pub fn new() -> Self {
Self {
inner: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Push a new frame to a canvas. Creates the canvas if it does not exist.
/// Returns `None` if the maximum canvas count has been reached and this is a new canvas.
pub fn render(
&self,
canvas_id: &str,
content_type: &str,
content: &str,
) -> Option<CanvasFrame> {
let frame = CanvasFrame {
frame_id: uuid::Uuid::new_v4().to_string(),
content_type: content_type.to_string(),
content: content.to_string(),
timestamp: chrono::Utc::now().to_rfc3339(),
};
let mut store = self.inner.write();
// Enforce canvas count limit for new canvases.
if !store.contains_key(canvas_id) && store.len() >= MAX_CANVAS_COUNT {
return None;
}
let entry = store
.entry(canvas_id.to_string())
.or_insert_with(|| CanvasEntry {
current: None,
history: Vec::new(),
tx: broadcast::channel(BROADCAST_CAPACITY).0,
});
entry.current = Some(frame.clone());
entry.history.push(frame.clone());
if entry.history.len() > MAX_HISTORY_FRAMES {
let excess = entry.history.len() - MAX_HISTORY_FRAMES;
entry.history.drain(..excess);
}
// Best-effort broadcast — ignore errors (no receivers is fine).
let _ = entry.tx.send(frame.clone());
Some(frame)
}
/// Get the current (most recent) frame for a canvas.
pub fn snapshot(&self, canvas_id: &str) -> Option<CanvasFrame> {
let store = self.inner.read();
store.get(canvas_id).and_then(|entry| entry.current.clone())
}
/// Get the frame history for a canvas.
pub fn history(&self, canvas_id: &str) -> Vec<CanvasFrame> {
let store = self.inner.read();
store
.get(canvas_id)
.map(|entry| entry.history.clone())
.unwrap_or_default()
}
/// Clear a canvas (removes current content and history).
pub fn clear(&self, canvas_id: &str) -> bool {
let mut store = self.inner.write();
if let Some(entry) = store.get_mut(canvas_id) {
entry.current = None;
entry.history.clear();
// Send an empty frame to signal clear to subscribers.
let clear_frame = CanvasFrame {
frame_id: uuid::Uuid::new_v4().to_string(),
content_type: "clear".to_string(),
content: String::new(),
timestamp: chrono::Utc::now().to_rfc3339(),
};
let _ = entry.tx.send(clear_frame);
true
} else {
false
}
}
/// Subscribe to real-time updates for a canvas.
/// Creates the canvas entry if it does not exist (subject to canvas count limit).
/// Returns `None` if the canvas does not exist and the limit has been reached.
pub fn subscribe(&self, canvas_id: &str) -> Option<broadcast::Receiver<CanvasFrame>> {
let mut store = self.inner.write();
// Enforce canvas count limit for new entries.
if !store.contains_key(canvas_id) && store.len() >= MAX_CANVAS_COUNT {
return None;
}
let entry = store
.entry(canvas_id.to_string())
.or_insert_with(|| CanvasEntry {
current: None,
history: Vec::new(),
tx: broadcast::channel(BROADCAST_CAPACITY).0,
});
Some(entry.tx.subscribe())
}
/// List all canvas IDs that currently have content.
pub fn list(&self) -> Vec<String> {
let store = self.inner.read();
store.keys().cloned().collect()
}
}
/// `CanvasTool` — agent-callable tool for the Live Canvas (A2UI) system.
pub struct CanvasTool {
store: CanvasStore,
}
impl CanvasTool {
pub fn new(store: CanvasStore) -> Self {
Self { store }
}
}
#[async_trait]
impl Tool for CanvasTool {
fn name(&self) -> &str {
"canvas"
}
fn description(&self) -> &str {
"Push rendered content (HTML, SVG, Markdown) to a live web canvas that users can see \
in real-time. Actions: render (push content), snapshot (get current content), \
clear (reset canvas), eval (evaluate JS expression in canvas context). \
Each canvas is identified by a canvas_id string."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"action": {
"type": "string",
"description": "Action to perform on the canvas.",
"enum": ["render", "snapshot", "clear", "eval"]
},
"canvas_id": {
"type": "string",
"description": "Unique identifier for the canvas. Defaults to 'default'."
},
"content_type": {
"type": "string",
"description": "Content type for render action: html, svg, markdown, or text.",
"enum": ["html", "svg", "markdown", "text"]
},
"content": {
"type": "string",
"description": "Content to render (for render action)."
},
"expression": {
"type": "string",
"description": "JavaScript expression to evaluate (for eval action). \
The result is returned as text. Evaluated client-side in the canvas iframe."
}
},
"required": ["action"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let action = match args.get("action").and_then(|v| v.as_str()) {
Some(a) => a,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: action".to_string()),
});
}
};
let canvas_id = args
.get("canvas_id")
.and_then(|v| v.as_str())
.unwrap_or("default");
match action {
"render" => {
let content_type = args
.get("content_type")
.and_then(|v| v.as_str())
.unwrap_or("html");
let content = match args.get("content").and_then(|v| v.as_str()) {
Some(c) => c,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(
"Missing required parameter: content (for render action)"
.to_string(),
),
});
}
};
if content.len() > MAX_CONTENT_SIZE {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Content exceeds maximum size of {} bytes",
MAX_CONTENT_SIZE
)),
});
}
match self.store.render(canvas_id, content_type, content) {
Some(frame) => Ok(ToolResult {
success: true,
output: format!(
"Rendered {} content to canvas '{}' (frame: {})",
content_type, canvas_id, frame.frame_id
),
error: None,
}),
None => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Maximum canvas count ({}) reached. Clear unused canvases first.",
MAX_CANVAS_COUNT
)),
}),
}
}
"snapshot" => match self.store.snapshot(canvas_id) {
Some(frame) => Ok(ToolResult {
success: true,
output: serde_json::to_string_pretty(&frame)
.unwrap_or_else(|_| frame.content.clone()),
error: None,
}),
None => Ok(ToolResult {
success: true,
output: format!("Canvas '{}' is empty", canvas_id),
error: None,
}),
},
"clear" => {
let existed = self.store.clear(canvas_id);
Ok(ToolResult {
success: true,
output: if existed {
format!("Canvas '{}' cleared", canvas_id)
} else {
format!("Canvas '{}' was already empty", canvas_id)
},
error: None,
})
}
"eval" => {
// Eval is handled client-side. We store an eval request as a special frame
// that the web viewer interprets.
let expression = match args.get("expression").and_then(|v| v.as_str()) {
Some(e) => e,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(
"Missing required parameter: expression (for eval action)"
.to_string(),
),
});
}
};
// Push a special eval frame so connected clients know to evaluate it.
match self.store.render(canvas_id, "eval", expression) {
Some(frame) => Ok(ToolResult {
success: true,
output: format!(
"Eval request sent to canvas '{}' (frame: {}). \
Result will be available to connected viewers.",
canvas_id, frame.frame_id
),
error: None,
}),
None => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Maximum canvas count ({}) reached. Clear unused canvases first.",
MAX_CANVAS_COUNT
)),
}),
}
}
other => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Unknown action: '{}'. Valid actions: render, snapshot, clear, eval",
other
)),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn canvas_store_render_and_snapshot() {
let store = CanvasStore::new();
let frame = store.render("test", "html", "<h1>Hello</h1>").unwrap();
assert_eq!(frame.content_type, "html");
assert_eq!(frame.content, "<h1>Hello</h1>");
let snapshot = store.snapshot("test").unwrap();
assert_eq!(snapshot.frame_id, frame.frame_id);
assert_eq!(snapshot.content, "<h1>Hello</h1>");
}
#[test]
fn canvas_store_snapshot_empty_returns_none() {
let store = CanvasStore::new();
assert!(store.snapshot("nonexistent").is_none());
}
#[test]
fn canvas_store_clear_removes_content() {
let store = CanvasStore::new();
store.render("test", "html", "<p>content</p>");
assert!(store.snapshot("test").is_some());
let cleared = store.clear("test");
assert!(cleared);
assert!(store.snapshot("test").is_none());
}
#[test]
fn canvas_store_clear_nonexistent_returns_false() {
let store = CanvasStore::new();
assert!(!store.clear("nonexistent"));
}
#[test]
fn canvas_store_history_tracks_frames() {
let store = CanvasStore::new();
store.render("test", "html", "frame1");
store.render("test", "html", "frame2");
store.render("test", "html", "frame3");
let history = store.history("test");
assert_eq!(history.len(), 3);
assert_eq!(history[0].content, "frame1");
assert_eq!(history[2].content, "frame3");
}
#[test]
fn canvas_store_history_limit_enforced() {
let store = CanvasStore::new();
for i in 0..60 {
store.render("test", "html", &format!("frame{i}"));
}
let history = store.history("test");
assert_eq!(history.len(), MAX_HISTORY_FRAMES);
// Oldest frames should have been dropped
assert_eq!(history[0].content, "frame10");
}
#[test]
fn canvas_store_list_returns_canvas_ids() {
let store = CanvasStore::new();
store.render("alpha", "html", "a");
store.render("beta", "svg", "b");
let mut ids = store.list();
ids.sort();
assert_eq!(ids, vec!["alpha", "beta"]);
}
#[test]
fn canvas_store_subscribe_receives_updates() {
let store = CanvasStore::new();
let mut rx = store.subscribe("test").unwrap();
store.render("test", "html", "<p>live</p>");
let frame = rx.try_recv().unwrap();
assert_eq!(frame.content, "<p>live</p>");
}
#[tokio::test]
async fn canvas_tool_render_action() {
let store = CanvasStore::new();
let tool = CanvasTool::new(store.clone());
let result = tool
.execute(json!({
"action": "render",
"canvas_id": "test",
"content_type": "html",
"content": "<h1>Hello World</h1>"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Rendered html content"));
let snapshot = store.snapshot("test").unwrap();
assert_eq!(snapshot.content, "<h1>Hello World</h1>");
}
#[tokio::test]
async fn canvas_tool_snapshot_action() {
let store = CanvasStore::new();
store.render("test", "html", "<p>snap</p>");
let tool = CanvasTool::new(store);
let result = tool
.execute(json!({"action": "snapshot", "canvas_id": "test"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("<p>snap</p>"));
}
#[tokio::test]
async fn canvas_tool_snapshot_empty() {
let store = CanvasStore::new();
let tool = CanvasTool::new(store);
let result = tool
.execute(json!({"action": "snapshot", "canvas_id": "empty"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("empty"));
}
#[tokio::test]
async fn canvas_tool_clear_action() {
let store = CanvasStore::new();
store.render("test", "html", "<p>clear me</p>");
let tool = CanvasTool::new(store.clone());
let result = tool
.execute(json!({"action": "clear", "canvas_id": "test"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("cleared"));
assert!(store.snapshot("test").is_none());
}
#[tokio::test]
async fn canvas_tool_eval_action() {
let store = CanvasStore::new();
let tool = CanvasTool::new(store.clone());
let result = tool
.execute(json!({
"action": "eval",
"canvas_id": "test",
"expression": "document.title"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Eval request sent"));
let snapshot = store.snapshot("test").unwrap();
assert_eq!(snapshot.content_type, "eval");
assert_eq!(snapshot.content, "document.title");
}
#[tokio::test]
async fn canvas_tool_unknown_action() {
let store = CanvasStore::new();
let tool = CanvasTool::new(store);
let result = tool.execute(json!({"action": "invalid"})).await.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("Unknown action"));
}
#[tokio::test]
async fn canvas_tool_missing_action() {
let store = CanvasStore::new();
let tool = CanvasTool::new(store);
let result = tool.execute(json!({})).await.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("action"));
}
#[tokio::test]
async fn canvas_tool_render_missing_content() {
let store = CanvasStore::new();
let tool = CanvasTool::new(store);
let result = tool
.execute(json!({"action": "render", "canvas_id": "test"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("content"));
}
#[tokio::test]
async fn canvas_tool_render_content_too_large() {
let store = CanvasStore::new();
let tool = CanvasTool::new(store);
let big_content = "x".repeat(MAX_CONTENT_SIZE + 1);
let result = tool
.execute(json!({
"action": "render",
"canvas_id": "test",
"content": big_content
}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("maximum size"));
}
#[tokio::test]
async fn canvas_tool_default_canvas_id() {
let store = CanvasStore::new();
let tool = CanvasTool::new(store.clone());
let result = tool
.execute(json!({
"action": "render",
"content_type": "html",
"content": "<p>default</p>"
}))
.await
.unwrap();
assert!(result.success);
assert!(store.snapshot("default").is_some());
}
#[test]
fn canvas_store_enforces_max_canvas_count() {
let store = CanvasStore::new();
// Create MAX_CANVAS_COUNT canvases
for i in 0..MAX_CANVAS_COUNT {
assert!(store
.render(&format!("canvas_{i}"), "html", "content")
.is_some());
}
// The next new canvas should be rejected
assert!(store.render("one_too_many", "html", "content").is_none());
// But rendering to an existing canvas should still work
assert!(store.render("canvas_0", "html", "updated").is_some());
}
#[tokio::test]
async fn canvas_tool_eval_missing_expression() {
let store = CanvasStore::new();
let tool = CanvasTool::new(store);
let result = tool
.execute(json!({"action": "eval", "canvas_id": "test"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("expression"));
}
}
@@ -530,6 +530,7 @@ impl DelegateTool {
&[],
None,
None,
&crate::config::PacingConfig::default(),
),
)
.await;
@@ -0,0 +1,204 @@
use super::traits::{Tool, ToolResult};
use crate::memory::Memory;
use async_trait::async_trait;
use serde_json::json;
use std::fmt::Write;
use std::sync::Arc;
/// Search Discord message history stored in discord.db.
pub struct DiscordSearchTool {
discord_memory: Arc<dyn Memory>,
}
impl DiscordSearchTool {
pub fn new(discord_memory: Arc<dyn Memory>) -> Self {
Self { discord_memory }
}
}
#[async_trait]
impl Tool for DiscordSearchTool {
fn name(&self) -> &str {
"discord_search"
}
fn description(&self) -> &str {
"Search Discord message history. Returns messages matching a keyword query, optionally filtered by channel_id, author_id, or time range."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Keywords or phrase to search for in Discord messages (optional if since/until provided)"
},
"limit": {
"type": "integer",
"description": "Max results to return (default: 10)"
},
"channel_id": {
"type": "string",
"description": "Filter results to a specific Discord channel ID"
},
"since": {
"type": "string",
"description": "Filter messages at or after this time (RFC 3339, e.g. 2025-03-01T00:00:00Z)"
},
"until": {
"type": "string",
"description": "Filter messages at or before this time (RFC 3339)"
}
}
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
let channel_id = args.get("channel_id").and_then(|v| v.as_str());
let since = args.get("since").and_then(|v| v.as_str());
let until = args.get("until").and_then(|v| v.as_str());
if query.trim().is_empty() && since.is_none() && until.is_none() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(
"Provide at least 'query' (keywords) or time range ('since'/'until')".into(),
),
});
}
if let Some(s) = since {
if chrono::DateTime::parse_from_rfc3339(s).is_err() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Invalid 'since' date: {s}. Expected RFC 3339, e.g. 2025-03-01T00:00:00Z"
)),
});
}
}
if let Some(u) = until {
if chrono::DateTime::parse_from_rfc3339(u).is_err() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Invalid 'until' date: {u}. Expected RFC 3339, e.g. 2025-03-01T00:00:00Z"
)),
});
}
}
if let (Some(s), Some(u)) = (since, until) {
if let (Ok(s_dt), Ok(u_dt)) = (
chrono::DateTime::parse_from_rfc3339(s),
chrono::DateTime::parse_from_rfc3339(u),
) {
if s_dt >= u_dt {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("'since' must be before 'until'".into()),
});
}
}
}
#[allow(clippy::cast_possible_truncation)]
let limit = args
.get("limit")
.and_then(serde_json::Value::as_u64)
.map_or(10, |v| v as usize);
match self
.discord_memory
.recall(query, limit, channel_id, since, until)
.await
{
Ok(entries) if entries.is_empty() => Ok(ToolResult {
success: true,
output: "No Discord messages found.".into(),
error: None,
}),
Ok(entries) => {
let mut output = format!("Found {} Discord messages:\n", entries.len());
for entry in &entries {
let score = entry
.score
.map_or_else(String::new, |s| format!(" [{s:.0}%]"));
let _ = writeln!(output, "- {}{score}", entry.content);
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Discord search failed: {e}")),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::memory::{MemoryCategory, SqliteMemory};
use tempfile::TempDir;
fn seeded_discord_mem() -> (TempDir, Arc<dyn Memory>) {
let tmp = TempDir::new().unwrap();
let mem = SqliteMemory::new_named(tmp.path(), "discord").unwrap();
(tmp, Arc::new(mem))
}
#[tokio::test]
async fn search_empty() {
let (_tmp, mem) = seeded_discord_mem();
let tool = DiscordSearchTool::new(mem);
let result = tool.execute(json!({"query": "hello"})).await.unwrap();
assert!(result.success);
assert!(result.output.contains("No Discord messages found"));
}
#[tokio::test]
async fn search_finds_match() {
let (_tmp, mem) = seeded_discord_mem();
mem.store(
"discord_001",
"@user1 in #general at 2025-01-01T00:00:00Z: hello world",
MemoryCategory::Custom("discord".to_string()),
Some("general"),
)
.await
.unwrap();
let tool = DiscordSearchTool::new(mem);
let result = tool.execute(json!({"query": "hello"})).await.unwrap();
assert!(result.success);
assert!(result.output.contains("hello"));
}
#[tokio::test]
async fn search_requires_query_or_time() {
let (_tmp, mem) = seeded_discord_mem();
let tool = DiscordSearchTool::new(mem);
let result = tool.execute(json!({})).await.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("at least"));
}
#[test]
fn name_and_schema() {
let (_tmp, mem) = seeded_discord_mem();
let tool = DiscordSearchTool::new(mem);
assert_eq!(tool.name(), "discord_search");
assert!(tool.parameters_schema()["properties"]["query"].is_object());
}
}
@@ -0,0 +1,494 @@
use super::traits::{Tool, ToolResult};
use crate::security::policy::ToolOperation;
use crate::security::SecurityPolicy;
use anyhow::Context;
use async_trait::async_trait;
use serde_json::json;
use std::path::PathBuf;
use std::sync::Arc;
/// Standalone image generation tool using fal.ai (Flux / Nano Banana models).
///
/// Reads the API key from an environment variable (default: `FAL_API_KEY`),
/// calls the fal.ai synchronous endpoint, downloads the resulting image,
/// and saves it to `{workspace}/images/{filename}.png`.
pub struct ImageGenTool {
security: Arc<SecurityPolicy>,
workspace_dir: PathBuf,
default_model: String,
api_key_env: String,
}
impl ImageGenTool {
pub fn new(
security: Arc<SecurityPolicy>,
workspace_dir: PathBuf,
default_model: String,
api_key_env: String,
) -> Self {
Self {
security,
workspace_dir,
default_model,
api_key_env,
}
}
/// Build a reusable HTTP client with reasonable timeouts.
fn http_client() -> reqwest::Client {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(120))
.build()
.unwrap_or_default()
}
/// Read an API key from the environment.
fn read_api_key(env_var: &str) -> Result<String, String> {
std::env::var(env_var)
.map(|v| v.trim().to_string())
.ok()
.filter(|v| !v.is_empty())
.ok_or_else(|| format!("Missing API key: set the {env_var} environment variable"))
}
/// Core generation logic: call fal.ai, download image, save to disk.
async fn generate(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
// ── Parse parameters ───────────────────────────────────────
let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
Some(p) if !p.trim().is_empty() => p.trim().to_string(),
_ => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Missing required parameter: 'prompt'".into()),
});
}
};
let filename = args
.get("filename")
.and_then(|v| v.as_str())
.filter(|s| !s.trim().is_empty())
.unwrap_or("generated_image");
// Sanitize filename — strip path components to prevent traversal.
let safe_name = PathBuf::from(filename).file_name().map_or_else(
|| "generated_image".to_string(),
|n| n.to_string_lossy().to_string(),
);
let size = args
.get("size")
.and_then(|v| v.as_str())
.unwrap_or("square_hd");
// Validate size enum.
const VALID_SIZES: &[&str] = &[
"square_hd",
"landscape_4_3",
"portrait_4_3",
"landscape_16_9",
"portrait_16_9",
];
if !VALID_SIZES.contains(&size) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Invalid size '{size}'. Valid values: {}",
VALID_SIZES.join(", ")
)),
});
}
let model = args
.get("model")
.and_then(|v| v.as_str())
.filter(|s| !s.trim().is_empty())
.unwrap_or(&self.default_model);
// Validate model identifier: must look like a fal.ai model path
// (e.g. "fal-ai/flux/schnell"). Reject values with "..", query
// strings, or fragments that could redirect the HTTP request.
if model.contains("..")
|| model.contains('?')
|| model.contains('#')
|| model.contains('\\')
|| model.starts_with('/')
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Invalid model identifier '{model}'. \
Must be a fal.ai model path (e.g. 'fal-ai/flux/schnell')."
)),
});
}
// ── Read API key ───────────────────────────────────────────
let api_key = match Self::read_api_key(&self.api_key_env) {
Ok(k) => k,
Err(msg) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(msg),
});
}
};
// ── Call fal.ai ────────────────────────────────────────────
let client = Self::http_client();
let url = format!("https://fal.run/{model}");
let body = json!({
"prompt": prompt,
"image_size": size,
"num_images": 1
});
let resp = client
.post(&url)
.header("Authorization", format!("Key {api_key}"))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.context("fal.ai request failed")?;
let status = resp.status();
if !status.is_success() {
let body_text = resp.text().await.unwrap_or_default();
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("fal.ai API error ({status}): {body_text}")),
});
}
let resp_json: serde_json::Value = resp
.json()
.await
.context("Failed to parse fal.ai response as JSON")?;
let image_url = resp_json
.pointer("/images/0/url")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("No image URL in fal.ai response"))?;
// ── Download image ─────────────────────────────────────────
let img_resp = client
.get(image_url)
.send()
.await
.context("Failed to download generated image")?;
if !img_resp.status().is_success() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Failed to download image from {image_url} ({})",
img_resp.status()
)),
});
}
let bytes = img_resp
.bytes()
.await
.context("Failed to read image bytes")?;
// ── Save to disk ───────────────────────────────────────────
let images_dir = self.workspace_dir.join("images");
tokio::fs::create_dir_all(&images_dir)
.await
.context("Failed to create images directory")?;
let output_path = images_dir.join(format!("{safe_name}.png"));
tokio::fs::write(&output_path, &bytes)
.await
.context("Failed to write image file")?;
let size_kb = bytes.len() / 1024;
Ok(ToolResult {
success: true,
output: format!(
"Image generated successfully.\n\
File: {}\n\
Size: {} KB\n\
Model: {}\n\
Prompt: {}",
output_path.display(),
size_kb,
model,
prompt,
),
error: None,
})
}
}
#[async_trait]
impl Tool for ImageGenTool {
fn name(&self) -> &str {
"image_gen"
}
fn description(&self) -> &str {
"Generate an image from a text prompt using fal.ai (Flux models). \
Saves the result to the workspace images directory and returns the file path."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"required": ["prompt"],
"properties": {
"prompt": {
"type": "string",
"description": "Text prompt describing the image to generate."
},
"filename": {
"type": "string",
"description": "Output filename without extension (default: 'generated_image'). Saved as PNG in workspace/images/."
},
"size": {
"type": "string",
"enum": ["square_hd", "landscape_4_3", "portrait_4_3", "landscape_16_9", "portrait_16_9"],
"description": "Image aspect ratio / size preset (default: 'square_hd')."
},
"model": {
"type": "string",
"description": "fal.ai model identifier (default: 'fal-ai/flux/schnell')."
}
}
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
// Security: image generation is a side-effecting action (HTTP + file write).
if let Err(error) = self
.security
.enforce_tool_operation(ToolOperation::Act, "image_gen")
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(error),
});
}
self.generate(args).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::security::{AutonomyLevel, SecurityPolicy};
fn test_security() -> Arc<SecurityPolicy> {
Arc::new(SecurityPolicy {
autonomy: AutonomyLevel::Full,
workspace_dir: std::env::temp_dir(),
..SecurityPolicy::default()
})
}
fn test_tool() -> ImageGenTool {
ImageGenTool::new(
test_security(),
std::env::temp_dir(),
"fal-ai/flux/schnell".into(),
"FAL_API_KEY".into(),
)
}
#[test]
fn tool_name() {
let tool = test_tool();
assert_eq!(tool.name(), "image_gen");
}
#[test]
fn tool_description_is_nonempty() {
let tool = test_tool();
assert!(!tool.description().is_empty());
assert!(tool.description().contains("image"));
}
#[test]
fn tool_schema_has_required_prompt() {
let tool = test_tool();
let schema = tool.parameters_schema();
assert_eq!(schema["required"], json!(["prompt"]));
assert!(schema["properties"]["prompt"].is_object());
}
#[test]
fn tool_schema_has_optional_params() {
let tool = test_tool();
let schema = tool.parameters_schema();
assert!(schema["properties"]["filename"].is_object());
assert!(schema["properties"]["size"].is_object());
assert!(schema["properties"]["model"].is_object());
}
#[test]
fn tool_spec_roundtrip() {
let tool = test_tool();
let spec = tool.spec();
assert_eq!(spec.name, "image_gen");
assert!(spec.parameters.is_object());
}
#[tokio::test]
async fn missing_prompt_returns_error() {
let tool = test_tool();
let result = tool.execute(json!({})).await.unwrap();
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("prompt"));
}
#[tokio::test]
async fn empty_prompt_returns_error() {
let tool = test_tool();
let result = tool.execute(json!({"prompt": " "})).await.unwrap();
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("prompt"));
}
#[tokio::test]
async fn missing_api_key_returns_error() {
// Temporarily ensure the env var is unset.
let original = std::env::var("FAL_API_KEY_TEST_IMAGE_GEN").ok();
std::env::remove_var("FAL_API_KEY_TEST_IMAGE_GEN");
let tool = ImageGenTool::new(
test_security(),
std::env::temp_dir(),
"fal-ai/flux/schnell".into(),
"FAL_API_KEY_TEST_IMAGE_GEN".into(),
);
let result = tool
.execute(json!({"prompt": "a sunset over the ocean"}))
.await
.unwrap();
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap()
.contains("FAL_API_KEY_TEST_IMAGE_GEN"));
// Restore if it was set.
if let Some(val) = original {
std::env::set_var("FAL_API_KEY_TEST_IMAGE_GEN", val);
}
}
#[tokio::test]
async fn invalid_size_returns_error() {
// Set a dummy key so we get past the key check.
std::env::set_var("FAL_API_KEY_TEST_SIZE", "dummy_key");
let tool = ImageGenTool::new(
test_security(),
std::env::temp_dir(),
"fal-ai/flux/schnell".into(),
"FAL_API_KEY_TEST_SIZE".into(),
);
let result = tool
.execute(json!({"prompt": "test", "size": "invalid_size"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("Invalid size"));
std::env::remove_var("FAL_API_KEY_TEST_SIZE");
}
#[tokio::test]
async fn read_only_autonomy_blocks_execution() {
let security = Arc::new(SecurityPolicy {
autonomy: AutonomyLevel::ReadOnly,
workspace_dir: std::env::temp_dir(),
..SecurityPolicy::default()
});
let tool = ImageGenTool::new(
security,
std::env::temp_dir(),
"fal-ai/flux/schnell".into(),
"FAL_API_KEY".into(),
);
let result = tool.execute(json!({"prompt": "test image"})).await.unwrap();
assert!(!result.success);
let err = result.error.as_deref().unwrap();
assert!(
err.contains("read-only") || err.contains("image_gen"),
"expected read-only or image_gen in error, got: {err}"
);
}
#[tokio::test]
async fn invalid_model_with_traversal_returns_error() {
std::env::set_var("FAL_API_KEY_TEST_MODEL", "dummy_key");
let tool = ImageGenTool::new(
test_security(),
std::env::temp_dir(),
"fal-ai/flux/schnell".into(),
"FAL_API_KEY_TEST_MODEL".into(),
);
let result = tool
.execute(json!({"prompt": "test", "model": "../../evil-endpoint"}))
.await
.unwrap();
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap()
.contains("Invalid model identifier"));
std::env::remove_var("FAL_API_KEY_TEST_MODEL");
}
#[test]
fn read_api_key_missing() {
let result = ImageGenTool::read_api_key("DEFINITELY_NOT_SET_ZC_TEST_12345");
assert!(result.is_err());
assert!(result
.unwrap_err()
.contains("DEFINITELY_NOT_SET_ZC_TEST_12345"));
}
#[test]
fn filename_traversal_is_sanitized() {
// Verify that path traversal in filenames is stripped to just the final component.
let sanitized = PathBuf::from("../../etc/passwd").file_name().map_or_else(
|| "generated_image".to_string(),
|n| n.to_string_lossy().to_string(),
);
assert_eq!(sanitized, "passwd");
// ".." alone has no file_name, falls back to default.
let sanitized = PathBuf::from("..").file_name().map_or_else(
|| "generated_image".to_string(),
|n| n.to_string_lossy().to_string(),
);
assert_eq!(sanitized, "generated_image");
}
#[test]
fn read_api_key_present() {
std::env::set_var("ZC_IMAGE_GEN_TEST_KEY", "test_value_123");
let result = ImageGenTool::read_api_key("ZC_IMAGE_GEN_TEST_KEY");
assert!(result.is_ok());
assert_eq!(result.unwrap(), "test_value_123");
std::env::remove_var("ZC_IMAGE_GEN_TEST_KEY");
}
}
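The sanitization behavior exercised by `filename_traversal_is_sanitized` above can be factored into a standalone helper; a minimal sketch (the `sanitize_filename` name is illustrative, not from the codebase):

```rust
use std::path::Path;

/// Keep only the final path component of an untrusted filename, falling back
/// to a default when no component exists (e.g. for ".."). Mirrors the
/// `file_name().map_or_else(..)` pattern the tests above verify.
fn sanitize_filename(raw: &str) -> String {
    Path::new(raw).file_name().map_or_else(
        || "generated_image".to_string(),
        |n| n.to_string_lossy().to_string(),
    )
}

fn main() {
    assert_eq!(sanitize_filename("../../etc/passwd"), "passwd");
    assert_eq!(sanitize_filename(".."), "generated_image");
    assert_eq!(sanitize_filename("image.png"), "image.png");
}
```

Because `Path::file_name` returns the last component only, any number of `../` segments is stripped in one step rather than iteratively.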
+119 -10
@@ -20,6 +20,7 @@ pub mod browser;
pub mod browser_delegate;
pub mod browser_open;
pub mod calculator;
pub mod canvas;
pub mod claude_code;
pub mod cli_discovery;
pub mod cloud_ops;
@@ -34,6 +35,7 @@ pub mod cron_runs;
pub mod cron_update;
pub mod data_management;
pub mod delegate;
pub mod discord_search;
pub mod file_edit;
pub mod file_read;
pub mod file_write;
@@ -47,6 +49,7 @@ pub mod hardware_memory_map;
#[cfg(feature = "hardware")]
pub mod hardware_memory_read;
pub mod http_request;
pub mod image_gen;
pub mod image_info;
pub mod jira_tool;
pub mod knowledge_tool;
@@ -69,13 +72,17 @@ pub mod pdf_read;
pub mod project_intel;
pub mod proxy_config;
pub mod pushover;
pub mod reaction;
pub mod read_skill;
pub mod report_templates;
pub mod schedule;
pub mod schema;
pub mod screenshot;
pub mod security_ops;
pub mod sessions;
pub mod shell;
pub mod skill_http;
pub mod skill_tool;
pub mod swarm;
pub mod text_browser;
pub mod tool_search;
@@ -93,6 +100,7 @@ pub use browser::{BrowserTool, ComputerUseConfig};
pub use browser_delegate::{BrowserDelegateConfig, BrowserDelegateTool};
pub use browser_open::BrowserOpenTool;
pub use calculator::CalculatorTool;
pub use canvas::{CanvasStore, CanvasTool};
pub use claude_code::ClaudeCodeTool;
pub use cloud_ops::CloudOpsTool;
pub use cloud_patterns::CloudPatternsTool;
@@ -106,6 +114,7 @@ pub use cron_runs::CronRunsTool;
pub use cron_update::CronUpdateTool;
pub use data_management::DataManagementTool;
pub use delegate::DelegateTool;
pub use discord_search::DiscordSearchTool;
pub use file_edit::FileEditTool;
pub use file_read::FileReadTool;
pub use file_write::FileWriteTool;
@@ -119,6 +128,7 @@ pub use hardware_memory_map::HardwareMemoryMapTool;
#[cfg(feature = "hardware")]
pub use hardware_memory_read::HardwareMemoryReadTool;
pub use http_request::HttpRequestTool;
pub use image_gen::ImageGenTool;
pub use image_info::ImageInfoTool;
pub use jira_tool::JiraTool;
pub use knowledge_tool::KnowledgeTool;
@@ -139,13 +149,19 @@ pub use pdf_read::PdfReadTool;
pub use project_intel::ProjectIntelTool;
pub use proxy_config::ProxyConfigTool;
pub use pushover::PushoverTool;
pub use reaction::{ChannelMapHandle, ReactionTool};
pub use read_skill::ReadSkillTool;
pub use schedule::ScheduleTool;
#[allow(unused_imports)]
pub use schema::{CleaningStrategy, SchemaCleanr};
pub use screenshot::ScreenshotTool;
pub use security_ops::SecurityOpsTool;
pub use sessions::{SessionsHistoryTool, SessionsListTool, SessionsSendTool};
pub use shell::ShellTool;
#[allow(unused_imports)]
pub use skill_http::SkillHttpTool;
#[allow(unused_imports)]
pub use skill_tool::SkillShellTool;
pub use swarm::SwarmTool;
pub use text_browser::TextBrowserTool;
pub use tool_search::ToolSearchTool;
@@ -247,6 +263,33 @@ pub fn default_tools_with_runtime(
]
}
/// Register skill-defined tools into an existing tool registry.
///
/// Converts each skill's `[[tools]]` entries into callable `Tool` implementations
/// and appends them to the registry. Skill tools that would shadow a built-in tool
/// name are skipped with a warning.
pub fn register_skill_tools(
tools_registry: &mut Vec<Box<dyn Tool>>,
skills: &[crate::skills::Skill],
security: Arc<SecurityPolicy>,
) {
let skill_tools = crate::skills::skills_to_tools(skills, security);
let existing_names: std::collections::HashSet<String> = tools_registry
.iter()
.map(|t| t.name().to_string())
.collect();
for tool in skill_tools {
if existing_names.contains(tool.name()) {
tracing::warn!(
"Skill tool '{}' shadows built-in tool, skipping",
tool.name()
);
} else {
tools_registry.push(tool);
}
}
}
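The shadow-skip logic in `register_skill_tools` can be sketched standalone with a simplified `Tool` trait (all names below are illustrative; the real registry uses the crate's `Tool` trait and `tracing::warn!` for the skip notice):

```rust
use std::collections::HashSet;

trait Tool {
    fn name(&self) -> &str;
}

struct NamedTool(&'static str);
impl Tool for NamedTool {
    fn name(&self) -> &str {
        self.0
    }
}

// Append incoming tools, skipping any whose name collides with an
// already-registered tool (the real code logs a warning instead).
fn register(registry: &mut Vec<Box<dyn Tool>>, incoming: Vec<Box<dyn Tool>>) {
    let existing: HashSet<String> = registry.iter().map(|t| t.name().to_string()).collect();
    for tool in incoming {
        if existing.contains(tool.name()) {
            continue; // would shadow an existing tool: skip
        }
        registry.push(tool);
    }
}

fn main() {
    let mut registry: Vec<Box<dyn Tool>> = vec![Box::new(NamedTool("shell"))];
    register(
        &mut registry,
        vec![Box::new(NamedTool("shell")), Box::new(NamedTool("weather"))],
    );
    let names: Vec<&str> = registry.iter().map(|t| t.name()).collect();
    assert_eq!(names, ["shell", "weather"]);
}
```

Note the existing-name set is snapshotted once before the loop, so two incoming skill tools with the same name would both be appended; only collisions with pre-existing tools are skipped.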
/// Create full tool registry including memory tools and optional Composio
#[allow(clippy::implicit_hasher, clippy::too_many_arguments)]
pub fn all_tools(
@@ -262,7 +305,12 @@ pub fn all_tools(
agents: &HashMap<String, DelegateAgentConfig>,
fallback_api_key: Option<&str>,
root_config: &crate::config::Config,
-) -> (Vec<Box<dyn Tool>>, Option<DelegateParentToolsHandle>) {
+canvas_store: Option<CanvasStore>,
+) -> (
+Vec<Box<dyn Tool>>,
+Option<DelegateParentToolsHandle>,
+Option<ChannelMapHandle>,
+) {
all_tools_with_runtime(
config,
security,
@@ -277,6 +325,7 @@ pub fn all_tools(
agents,
fallback_api_key,
root_config,
canvas_store,
)
}
@@ -296,7 +345,12 @@ pub fn all_tools_with_runtime(
agents: &HashMap<String, DelegateAgentConfig>,
fallback_api_key: Option<&str>,
root_config: &crate::config::Config,
) -> (Vec<Box<dyn Tool>>, Option<DelegateParentToolsHandle>) {
canvas_store: Option<CanvasStore>,
) -> (
Vec<Box<dyn Tool>>,
Option<DelegateParentToolsHandle>,
Option<ChannelMapHandle>,
) {
let has_shell_access = runtime.has_shell_access();
let sandbox = create_sandbox(&root_config.security);
let mut tool_arcs: Vec<Arc<dyn Tool>> = vec![
@@ -336,8 +390,21 @@ pub fn all_tools_with_runtime(
)),
Arc::new(CalculatorTool::new()),
Arc::new(WeatherTool::new()),
Arc::new(CanvasTool::new(canvas_store.unwrap_or_default())),
];
// Register discord_search if discord_history channel is configured
if root_config.channels_config.discord_history.is_some() {
match crate::memory::SqliteMemory::new_named(workspace_dir, "discord") {
Ok(discord_mem) => {
tool_arcs.push(Arc::new(DiscordSearchTool::new(Arc::new(discord_mem))));
}
Err(e) => {
tracing::warn!("discord_search: failed to open discord.db: {e}");
}
}
}
if matches!(
root_config.skills.prompt_injection_mode,
crate::config::SkillsPromptInjectionMode::Compact
@@ -424,6 +491,7 @@ pub fn all_tools_with_runtime(
tool_arcs.push(Arc::new(WebSearchTool::new_with_config(
root_config.web_search.provider.clone(),
root_config.web_search.brave_api_key.clone(),
root_config.web_search.searxng_instance_url.clone(),
root_config.web_search.max_results,
root_config.web_search.timeout_secs,
root_config.config_path.clone(),
@@ -545,6 +613,18 @@ pub fn all_tools_with_runtime(
tool_arcs.push(Arc::new(ScreenshotTool::new(security.clone())));
tool_arcs.push(Arc::new(ImageInfoTool::new(security.clone())));
// Session-to-session messaging tools (always available when sessions dir exists)
if let Ok(session_store) = crate::channels::session_store::SessionStore::new(workspace_dir) {
let backend: Arc<dyn crate::channels::session_backend::SessionBackend> =
Arc::new(session_store);
tool_arcs.push(Arc::new(SessionsListTool::new(backend.clone())));
tool_arcs.push(Arc::new(SessionsHistoryTool::new(
backend.clone(),
security.clone(),
)));
tool_arcs.push(Arc::new(SessionsSendTool::new(backend, security.clone())));
}
// LinkedIn integration (config-gated)
if root_config.linkedin.enabled {
tool_arcs.push(Arc::new(LinkedInTool::new(
@@ -556,6 +636,16 @@ pub fn all_tools_with_runtime(
)));
}
// Standalone image generation tool (config-gated)
if root_config.image_gen.enabled {
tool_arcs.push(Arc::new(ImageGenTool::new(
security.clone(),
workspace_dir.to_path_buf(),
root_config.image_gen.default_model.clone(),
root_config.image_gen.api_key_env.clone(),
)));
}
if let Some(key) = composio_key {
if !key.is_empty() {
tool_arcs.push(Arc::new(ComposioTool::new(
@@ -566,6 +656,11 @@ pub fn all_tools_with_runtime(
}
}
// Emoji reaction tool — always registered; channel map populated later by start_channels.
let reaction_tool = ReactionTool::new(security.clone());
let reaction_handle = reaction_tool.channel_map_handle();
tool_arcs.push(Arc::new(reaction_tool));
// Microsoft 365 Graph API integration
if root_config.microsoft365.enabled {
let ms_cfg = &root_config.microsoft365;
@@ -592,7 +687,11 @@ pub fn all_tools_with_runtime(
tracing::error!(
"microsoft365: client_credentials auth_flow requires a non-empty client_secret"
);
-return (boxed_registry_from_arcs(tool_arcs), None);
+return (
+boxed_registry_from_arcs(tool_arcs),
+None,
+Some(reaction_handle),
+);
}
let resolved = microsoft365::types::Microsoft365ResolvedConfig {
@@ -776,7 +875,11 @@ pub fn all_tools_with_runtime(
}
}
-(boxed_registry_from_arcs(tool_arcs), delegate_handle)
+(
+boxed_registry_from_arcs(tool_arcs),
+delegate_handle,
+Some(reaction_handle),
+)
}
#[cfg(test)]
@@ -820,7 +923,7 @@ mod tests {
let http = crate::config::HttpRequestConfig::default();
let cfg = test_config(&tmp);
-let (tools, _) = all_tools(
+let (tools, _, _) = all_tools(
Arc::new(Config::default()),
&security,
mem,
@@ -833,6 +936,7 @@ mod tests {
&HashMap::new(),
None,
&cfg,
None,
);
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
assert!(!names.contains(&"browser_open"));
@@ -862,7 +966,7 @@ mod tests {
let http = crate::config::HttpRequestConfig::default();
let cfg = test_config(&tmp);
-let (tools, _) = all_tools(
+let (tools, _, _) = all_tools(
Arc::new(Config::default()),
&security,
mem,
@@ -875,6 +979,7 @@ mod tests {
&HashMap::new(),
None,
&cfg,
None,
);
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
assert!(names.contains(&"browser_open"));
@@ -1015,7 +1120,7 @@ mod tests {
},
);
-let (tools, _) = all_tools(
+let (tools, _, _) = all_tools(
Arc::new(Config::default()),
&security,
mem,
@@ -1028,6 +1133,7 @@ mod tests {
&agents,
Some("delegate-test-credential"),
&cfg,
None,
);
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
assert!(names.contains(&"delegate"));
@@ -1048,7 +1154,7 @@ mod tests {
let http = crate::config::HttpRequestConfig::default();
let cfg = test_config(&tmp);
-let (tools, _) = all_tools(
+let (tools, _, _) = all_tools(
Arc::new(Config::default()),
&security,
mem,
@@ -1061,6 +1167,7 @@ mod tests {
&HashMap::new(),
None,
&cfg,
None,
);
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
assert!(!names.contains(&"delegate"));
@@ -1082,7 +1189,7 @@ mod tests {
let mut cfg = test_config(&tmp);
cfg.skills.prompt_injection_mode = crate::config::SkillsPromptInjectionMode::Compact;
-let (tools, _) = all_tools(
+let (tools, _, _) = all_tools(
Arc::new(cfg.clone()),
&security,
mem,
@@ -1095,6 +1202,7 @@ mod tests {
&HashMap::new(),
None,
&cfg,
None,
);
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
assert!(names.contains(&"read_skill"));
@@ -1116,7 +1224,7 @@ mod tests {
let mut cfg = test_config(&tmp);
cfg.skills.prompt_injection_mode = crate::config::SkillsPromptInjectionMode::Full;
-let (tools, _) = all_tools(
+let (tools, _, _) = all_tools(
Arc::new(cfg.clone()),
&security,
mem,
@@ -1129,6 +1237,7 @@ mod tests {
&HashMap::new(),
None,
&cfg,
None,
);
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
assert!(!names.contains(&"read_skill"));
+546
@@ -0,0 +1,546 @@
//! Emoji reaction tool for cross-channel message reactions.
//!
//! Exposes `add_reaction` and `remove_reaction` from the [`Channel`] trait as an
//! agent-callable tool. The tool holds a late-binding channel map handle that is
//! populated once channels are initialized (after tool construction). This mirrors
//! the pattern used by [`DelegateTool`] for its parent-tools handle.
use super::traits::{Tool, ToolResult};
use crate::channels::traits::Channel;
use crate::security::policy::ToolOperation;
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use parking_lot::RwLock;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
/// Shared handle to the channel map. Starts empty; populated once channels boot.
pub type ChannelMapHandle = Arc<RwLock<HashMap<String, Arc<dyn Channel>>>>;
/// Agent-callable tool for adding or removing emoji reactions on messages.
pub struct ReactionTool {
channels: ChannelMapHandle,
security: Arc<SecurityPolicy>,
}
impl ReactionTool {
/// Create a new reaction tool with an empty channel map.
/// Call [`populate`] or write to the returned [`ChannelMapHandle`] once channels
/// are available.
pub fn new(security: Arc<SecurityPolicy>) -> Self {
Self {
channels: Arc::new(RwLock::new(HashMap::new())),
security,
}
}
/// Return the shared handle so callers can populate it after channel init.
pub fn channel_map_handle(&self) -> ChannelMapHandle {
Arc::clone(&self.channels)
}
/// Convenience: populate the channel map from a pre-built map.
pub fn populate(&self, map: HashMap<String, Arc<dyn Channel>>) {
*self.channels.write() = map;
}
}
#[async_trait]
impl Tool for ReactionTool {
fn name(&self) -> &str {
"reaction"
}
fn description(&self) -> &str {
"Add or remove an emoji reaction on a message in any active channel. \
Provide the channel name (e.g. 'discord', 'slack'), the platform channel ID, \
the platform message ID, and the emoji (Unicode character or platform shortcode)."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"channel": {
"type": "string",
"description": "Name of the channel to react in (e.g. 'discord', 'slack', 'telegram')"
},
"channel_id": {
"type": "string",
"description": "Platform-specific channel/conversation identifier (e.g. Discord channel snowflake, Slack channel ID)"
},
"message_id": {
"type": "string",
"description": "Platform-scoped message identifier to react to"
},
"emoji": {
"type": "string",
"description": "Emoji to react with (Unicode character or platform shortcode)"
},
"action": {
"type": "string",
"enum": ["add", "remove"],
"description": "Whether to add or remove the reaction (default: 'add')"
}
},
"required": ["channel", "channel_id", "message_id", "emoji"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
// Security gate
if let Err(error) = self
.security
.enforce_tool_operation(ToolOperation::Act, "reaction")
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(error),
});
}
let channel_name = args
.get("channel")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'channel' parameter"))?;
let channel_id = args
.get("channel_id")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'channel_id' parameter"))?;
let message_id = args
.get("message_id")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'message_id' parameter"))?;
let emoji = args
.get("emoji")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'emoji' parameter"))?;
let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("add");
if action != "add" && action != "remove" {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Invalid action '{action}': must be 'add' or 'remove'"
)),
});
}
// Read-lock the channel map to find the target channel.
let channel = {
let map = self.channels.read();
if map.is_empty() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("No channels available yet (channels not initialized)".to_string()),
});
}
match map.get(channel_name) {
Some(ch) => Arc::clone(ch),
None => {
let available: Vec<String> = map.keys().cloned().collect();
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Channel '{channel_name}' not found. Available channels: {}",
available.join(", ")
)),
});
}
}
};
let result = if action == "add" {
channel.add_reaction(channel_id, message_id, emoji).await
} else {
channel.remove_reaction(channel_id, message_id, emoji).await
};
let past_tense = if action == "remove" {
"removed"
} else {
"added"
};
match result {
Ok(()) => Ok(ToolResult {
success: true,
output: format!(
"Reaction {past_tense}: {emoji} on message {message_id} in {channel_name}"
),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to {action} reaction: {e}")),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::channels::traits::{ChannelMessage, SendMessage};
use std::sync::atomic::{AtomicBool, Ordering};
struct MockChannel {
reaction_added: AtomicBool,
reaction_removed: AtomicBool,
last_channel_id: parking_lot::Mutex<Option<String>>,
fail_on_add: bool,
}
impl MockChannel {
fn new() -> Self {
Self {
reaction_added: AtomicBool::new(false),
reaction_removed: AtomicBool::new(false),
last_channel_id: parking_lot::Mutex::new(None),
fail_on_add: false,
}
}
fn failing() -> Self {
Self {
reaction_added: AtomicBool::new(false),
reaction_removed: AtomicBool::new(false),
last_channel_id: parking_lot::Mutex::new(None),
fail_on_add: true,
}
}
}
#[async_trait]
impl Channel for MockChannel {
fn name(&self) -> &str {
"mock"
}
async fn send(&self, _message: &SendMessage) -> anyhow::Result<()> {
Ok(())
}
async fn listen(
&self,
_tx: tokio::sync::mpsc::Sender<ChannelMessage>,
) -> anyhow::Result<()> {
Ok(())
}
async fn add_reaction(
&self,
channel_id: &str,
_message_id: &str,
_emoji: &str,
) -> anyhow::Result<()> {
if self.fail_on_add {
return Err(anyhow::anyhow!("API error: rate limited"));
}
*self.last_channel_id.lock() = Some(channel_id.to_string());
self.reaction_added.store(true, Ordering::SeqCst);
Ok(())
}
async fn remove_reaction(
&self,
channel_id: &str,
_message_id: &str,
_emoji: &str,
) -> anyhow::Result<()> {
*self.last_channel_id.lock() = Some(channel_id.to_string());
self.reaction_removed.store(true, Ordering::SeqCst);
Ok(())
}
}
fn make_tool_with_channels(channels: Vec<(&str, Arc<dyn Channel>)>) -> ReactionTool {
let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
let map: HashMap<String, Arc<dyn Channel>> = channels
.into_iter()
.map(|(name, ch)| (name.to_string(), ch))
.collect();
tool.populate(map);
tool
}
#[test]
fn tool_metadata() {
let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
assert_eq!(tool.name(), "reaction");
assert!(!tool.description().is_empty());
let schema = tool.parameters_schema();
assert_eq!(schema["type"], "object");
assert!(schema["properties"]["channel"].is_object());
assert!(schema["properties"]["channel_id"].is_object());
assert!(schema["properties"]["message_id"].is_object());
assert!(schema["properties"]["emoji"].is_object());
assert!(schema["properties"]["action"].is_object());
let required = schema["required"].as_array().unwrap();
assert!(required.iter().any(|v| v == "channel"));
assert!(required.iter().any(|v| v == "channel_id"));
assert!(required.iter().any(|v| v == "message_id"));
assert!(required.iter().any(|v| v == "emoji"));
// action is optional (defaults to "add")
assert!(!required.iter().any(|v| v == "action"));
}
#[tokio::test]
async fn add_reaction_success() {
let mock: Arc<dyn Channel> = Arc::new(MockChannel::new());
let tool = make_tool_with_channels(vec![("discord", Arc::clone(&mock))]);
let result = tool
.execute(json!({
"channel": "discord",
"channel_id": "ch_001",
"message_id": "msg_123",
"emoji": "\u{2705}"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("added"));
assert!(result.error.is_none());
}
#[tokio::test]
async fn remove_reaction_success() {
let mock: Arc<dyn Channel> = Arc::new(MockChannel::new());
let tool = make_tool_with_channels(vec![("slack", Arc::clone(&mock))]);
let result = tool
.execute(json!({
"channel": "slack",
"channel_id": "C0123SLACK",
"message_id": "msg_456",
"emoji": "\u{1F440}",
"action": "remove"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("removed"));
}
#[tokio::test]
async fn unknown_channel_returns_error() {
let tool = make_tool_with_channels(vec![(
"discord",
Arc::new(MockChannel::new()) as Arc<dyn Channel>,
)]);
let result = tool
.execute(json!({
"channel": "nonexistent",
"channel_id": "ch_x",
"message_id": "msg_1",
"emoji": "\u{2705}"
}))
.await
.unwrap();
assert!(!result.success);
let err = result.error.as_deref().unwrap();
assert!(err.contains("not found"));
assert!(err.contains("discord"));
}
#[tokio::test]
async fn invalid_action_returns_error() {
let tool = make_tool_with_channels(vec![(
"discord",
Arc::new(MockChannel::new()) as Arc<dyn Channel>,
)]);
let result = tool
.execute(json!({
"channel": "discord",
"channel_id": "ch_001",
"message_id": "msg_1",
"emoji": "\u{2705}",
"action": "toggle"
}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("toggle"));
}
#[tokio::test]
async fn channel_error_propagated() {
let mock: Arc<dyn Channel> = Arc::new(MockChannel::failing());
let tool = make_tool_with_channels(vec![("discord", mock)]);
let result = tool
.execute(json!({
"channel": "discord",
"channel_id": "ch_001",
"message_id": "msg_1",
"emoji": "\u{2705}"
}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("rate limited"));
}
#[tokio::test]
async fn missing_required_params() {
let tool = make_tool_with_channels(vec![(
"test",
Arc::new(MockChannel::new()) as Arc<dyn Channel>,
)]);
// Missing channel
let result = tool
.execute(json!({"channel_id": "c1", "message_id": "1", "emoji": "x"}))
.await;
assert!(result.is_err());
// Missing channel_id
let result = tool
.execute(json!({"channel": "test", "message_id": "1", "emoji": "x"}))
.await;
assert!(result.is_err());
// Missing message_id
let result = tool
.execute(json!({"channel": "a", "channel_id": "c1", "emoji": "x"}))
.await;
assert!(result.is_err());
// Missing emoji
let result = tool
.execute(json!({"channel": "a", "channel_id": "c1", "message_id": "1"}))
.await;
assert!(result.is_err());
}
#[tokio::test]
async fn empty_channels_returns_not_initialized() {
let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
// No channels populated
let result = tool
.execute(json!({
"channel": "discord",
"channel_id": "ch_001",
"message_id": "msg_1",
"emoji": "\u{2705}"
}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_deref().unwrap().contains("not initialized"));
}
#[tokio::test]
async fn default_action_is_add() {
let mock = Arc::new(MockChannel::new());
let mock_ch: Arc<dyn Channel> = Arc::clone(&mock) as Arc<dyn Channel>;
let tool = make_tool_with_channels(vec![("test", mock_ch)]);
let result = tool
.execute(json!({
"channel": "test",
"channel_id": "ch_test",
"message_id": "msg_1",
"emoji": "\u{2705}"
}))
.await
.unwrap();
assert!(result.success);
assert!(mock.reaction_added.load(Ordering::SeqCst));
assert!(!mock.reaction_removed.load(Ordering::SeqCst));
}
#[tokio::test]
async fn channel_id_passed_to_trait_not_channel_name() {
let mock = Arc::new(MockChannel::new());
let mock_ch: Arc<dyn Channel> = Arc::clone(&mock) as Arc<dyn Channel>;
let tool = make_tool_with_channels(vec![("discord", mock_ch)]);
let result = tool
.execute(json!({
"channel": "discord",
"channel_id": "123456789",
"message_id": "msg_1",
"emoji": "\u{2705}"
}))
.await
.unwrap();
assert!(result.success);
// The trait must receive the platform channel_id, not the channel name
assert_eq!(
mock.last_channel_id.lock().as_deref(),
Some("123456789"),
"add_reaction must receive channel_id, not channel name"
);
}
#[tokio::test]
async fn channel_map_handle_allows_late_binding() {
let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
let handle = tool.channel_map_handle();
// Initially empty — tool reports not initialized
let result = tool
.execute(json!({
"channel": "slack",
"channel_id": "C0123",
"message_id": "msg_1",
"emoji": "\u{2705}"
}))
.await
.unwrap();
assert!(!result.success);
// Populate via the handle
{
let mut map = handle.write();
map.insert(
"slack".to_string(),
Arc::new(MockChannel::new()) as Arc<dyn Channel>,
);
}
// Now the tool can route to the channel
let result = tool
.execute(json!({
"channel": "slack",
"channel_id": "C0123",
"message_id": "msg_1",
"emoji": "\u{2705}"
}))
.await
.unwrap();
assert!(result.success);
}
#[test]
fn spec_matches_metadata() {
let tool = ReactionTool::new(Arc::new(SecurityPolicy::default()));
let spec = tool.spec();
assert_eq!(spec.name, "reaction");
assert_eq!(spec.description, tool.description());
assert!(spec.parameters["required"].is_array());
}
}
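The late-binding handle pattern that `ReactionTool` and its `channel_map_handle_allows_late_binding` test rely on can be reduced to a few lines; a sketch using `std::sync::RwLock` in place of `parking_lot` (the `LateBound` type and its method names are illustrative):

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Shared, initially empty map; the same Arc is written to after boot.
type ChannelMap = Arc<RwLock<HashMap<String, String>>>;

struct LateBound {
    channels: ChannelMap,
}

impl LateBound {
    fn new() -> Self {
        Self {
            channels: Arc::new(RwLock::new(HashMap::new())),
        }
    }
    /// Hand out a clone of the Arc so a later phase can populate the map.
    fn handle(&self) -> ChannelMap {
        Arc::clone(&self.channels)
    }
    fn lookup(&self, name: &str) -> Option<String> {
        self.channels.read().unwrap().get(name).cloned()
    }
}

fn main() {
    let tool = LateBound::new();
    let handle = tool.handle();
    assert!(tool.lookup("slack").is_none()); // empty before channels boot
    handle.write().unwrap().insert("slack".into(), "C0123".into());
    assert_eq!(tool.lookup("slack").as_deref(), Some("C0123")); // visible after
}
```

The tool and the channel bootstrapper share one `Arc`, so no re-registration step is needed once channels come up; readers simply see the populated map on the next lock acquisition.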
+573
@@ -0,0 +1,573 @@
//! Session-to-session messaging tools for inter-agent communication.
//!
//! Provides three tools:
//! - `sessions_list` — list active sessions with metadata
//! - `sessions_history` — read message history from a specific session
//! - `sessions_send` — send a message to a specific session
use super::traits::{Tool, ToolResult};
use crate::channels::session_backend::SessionBackend;
use crate::security::policy::ToolOperation;
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
use std::fmt::Write;
use std::sync::Arc;
/// Validate that a session ID is non-empty and contains at least one
/// alphanumeric character (prevents blank keys after sanitization).
fn validate_session_id(session_id: &str) -> Result<(), ToolResult> {
let trimmed = session_id.trim();
if trimmed.is_empty() || !trimmed.chars().any(|c| c.is_alphanumeric()) {
return Err(ToolResult {
success: false,
output: String::new(),
error: Some(
"Invalid 'session_id': must be non-empty and contain at least one alphanumeric character.".into(),
),
});
}
Ok(())
}
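The validation rule above boils down to a single predicate; a standalone sketch (`is_valid_session_id` is an illustrative name):

```rust
// A session id is accepted when, after trimming, it is non-empty and
// contains at least one alphanumeric character, matching the check in
// validate_session_id above.
fn is_valid_session_id(session_id: &str) -> bool {
    let trimmed = session_id.trim();
    !trimmed.is_empty() && trimmed.chars().any(|c| c.is_alphanumeric())
}

fn main() {
    assert!(is_valid_session_id("telegram__user123"));
    assert!(!is_valid_session_id("   "));
    assert!(!is_valid_session_id("__--__"));
}
```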
// ── SessionsListTool ────────────────────────────────────────────────
/// Lists active sessions with their channel, last activity time, and message count.
pub struct SessionsListTool {
backend: Arc<dyn SessionBackend>,
}
impl SessionsListTool {
pub fn new(backend: Arc<dyn SessionBackend>) -> Self {
Self { backend }
}
}
#[async_trait]
impl Tool for SessionsListTool {
fn name(&self) -> &str {
"sessions_list"
}
fn description(&self) -> &str {
"List all active conversation sessions with their channel, last activity time, and message count."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"limit": {
"type": "integer",
"description": "Max sessions to return (default: 50)"
}
}
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
#[allow(clippy::cast_possible_truncation)]
let limit = args
.get("limit")
.and_then(serde_json::Value::as_u64)
.map_or(50, |v| v as usize);
let metadata = self.backend.list_sessions_with_metadata();
if metadata.is_empty() {
return Ok(ToolResult {
success: true,
output: "No active sessions found.".into(),
error: None,
});
}
let capped: Vec<_> = metadata.into_iter().take(limit).collect();
let mut output = format!("Found {} session(s):\n", capped.len());
for meta in &capped {
// Extract channel from key (convention: channel__identifier)
let channel = meta.key.split("__").next().unwrap_or(&meta.key);
let _ = writeln!(
output,
"- {}: channel={}, messages={}, last_activity={}",
meta.key, channel, meta.message_count, meta.last_activity
);
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
}
// ── SessionsHistoryTool ─────────────────────────────────────────────
/// Reads the message history of a specific session by ID.
pub struct SessionsHistoryTool {
backend: Arc<dyn SessionBackend>,
security: Arc<SecurityPolicy>,
}
impl SessionsHistoryTool {
pub fn new(backend: Arc<dyn SessionBackend>, security: Arc<SecurityPolicy>) -> Self {
Self { backend, security }
}
}
#[async_trait]
impl Tool for SessionsHistoryTool {
fn name(&self) -> &str {
"sessions_history"
}
fn description(&self) -> &str {
"Read the message history of a specific session by its session ID. Returns the last N messages."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"session_id": {
"type": "string",
"description": "The session ID to read history from (e.g. telegram__user123)"
},
"limit": {
"type": "integer",
"description": "Max messages to return, from most recent (default: 20)"
}
},
"required": ["session_id"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
if let Err(error) = self
.security
.enforce_tool_operation(ToolOperation::Read, "sessions_history")
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(error),
});
}
let session_id = args
.get("session_id")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'session_id' parameter"))?;
if let Err(result) = validate_session_id(session_id) {
return Ok(result);
}
#[allow(clippy::cast_possible_truncation)]
let limit = args
.get("limit")
.and_then(serde_json::Value::as_u64)
.map_or(20, |v| v as usize);
let messages = self.backend.load(session_id);
if messages.is_empty() {
return Ok(ToolResult {
success: true,
output: format!("No messages found for session '{session_id}'."),
error: None,
});
}
// Take the last `limit` messages
let start = messages.len().saturating_sub(limit);
let tail = &messages[start..];
let mut output = format!(
"Session '{}': showing {}/{} messages\n",
session_id,
tail.len(),
messages.len()
);
for msg in tail {
let _ = writeln!(output, "[{}] {}", msg.role, msg.content);
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
}
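The "last N messages" windowing used by `sessions_history` above is just a saturating slice; a minimal sketch (`tail` is an illustrative name):

```rust
// Return the last `limit` items of a slice, or the whole slice when it is
// shorter than `limit`; saturating_sub keeps the start index from underflowing.
fn tail<T>(items: &[T], limit: usize) -> &[T] {
    let start = items.len().saturating_sub(limit);
    &items[start..]
}

fn main() {
    let msgs = [1, 2, 3, 4, 5];
    assert_eq!(tail(&msgs, 2).to_vec(), vec![4, 5]);
    assert_eq!(tail(&msgs, 10).to_vec(), vec![1, 2, 3, 4, 5]);
}
```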
// ── SessionsSendTool ────────────────────────────────────────────────
/// Sends a message to a specific session, enabling inter-agent communication.
pub struct SessionsSendTool {
backend: Arc<dyn SessionBackend>,
security: Arc<SecurityPolicy>,
}
impl SessionsSendTool {
pub fn new(backend: Arc<dyn SessionBackend>, security: Arc<SecurityPolicy>) -> Self {
Self { backend, security }
}
}
#[async_trait]
impl Tool for SessionsSendTool {
fn name(&self) -> &str {
"sessions_send"
}
fn description(&self) -> &str {
"Send a message to a specific session by its session ID. The message is appended to the session's conversation history as a 'user' message, enabling inter-agent communication."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"session_id": {
"type": "string",
"description": "The target session ID (e.g. telegram__user123)"
},
"message": {
"type": "string",
"description": "The message content to send"
}
},
"required": ["session_id", "message"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
if let Err(error) = self
.security
.enforce_tool_operation(ToolOperation::Act, "sessions_send")
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(error),
});
}
let session_id = args
.get("session_id")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'session_id' parameter"))?;
if let Err(result) = validate_session_id(session_id) {
return Ok(result);
}
let message = args
.get("message")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))?;
if message.trim().is_empty() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Message content must not be empty.".into()),
});
}
let chat_msg = crate::providers::traits::ChatMessage::user(message);
match self.backend.append(session_id, &chat_msg) {
Ok(()) => Ok(ToolResult {
success: true,
output: format!("Message sent to session '{session_id}'."),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to send message: {e}")),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::channels::session_store::SessionStore;
use crate::providers::traits::ChatMessage;
use tempfile::TempDir;
fn test_security() -> Arc<SecurityPolicy> {
Arc::new(SecurityPolicy::default())
}
fn test_backend() -> (TempDir, Arc<dyn SessionBackend>) {
let tmp = TempDir::new().unwrap();
let store = SessionStore::new(tmp.path()).unwrap();
(tmp, Arc::new(store))
}
fn seeded_backend() -> (TempDir, Arc<dyn SessionBackend>) {
let tmp = TempDir::new().unwrap();
let store = SessionStore::new(tmp.path()).unwrap();
store
.append("telegram__alice", &ChatMessage::user("Hello from Alice"))
.unwrap();
store
.append(
"telegram__alice",
&ChatMessage::assistant("Hi Alice, how can I help?"),
)
.unwrap();
store
.append("discord__bob", &ChatMessage::user("Hey from Bob"))
.unwrap();
(tmp, Arc::new(store))
}
// ── SessionsListTool tests ──────────────────────────────────────
#[tokio::test]
async fn list_empty_sessions() {
let (_tmp, backend) = test_backend();
let tool = SessionsListTool::new(backend);
let result = tool.execute(json!({})).await.unwrap();
assert!(result.success);
assert!(result.output.contains("No active sessions"));
}
#[tokio::test]
async fn list_sessions_shows_all() {
let (_tmp, backend) = seeded_backend();
let tool = SessionsListTool::new(backend);
let result = tool.execute(json!({})).await.unwrap();
assert!(result.success);
assert!(result.output.contains("2 session(s)"));
assert!(result.output.contains("telegram__alice"));
assert!(result.output.contains("discord__bob"));
}
#[tokio::test]
async fn list_sessions_respects_limit() {
let (_tmp, backend) = seeded_backend();
let tool = SessionsListTool::new(backend);
let result = tool.execute(json!({"limit": 1})).await.unwrap();
assert!(result.success);
assert!(result.output.contains("1 session(s)"));
}
#[tokio::test]
async fn list_sessions_extracts_channel() {
let (_tmp, backend) = seeded_backend();
let tool = SessionsListTool::new(backend);
let result = tool.execute(json!({})).await.unwrap();
assert!(result.output.contains("channel=telegram"));
assert!(result.output.contains("channel=discord"));
}
#[test]
fn list_tool_name_and_schema() {
let (_tmp, backend) = test_backend();
let tool = SessionsListTool::new(backend);
assert_eq!(tool.name(), "sessions_list");
assert!(tool.parameters_schema()["properties"]["limit"].is_object());
}
// ── SessionsHistoryTool tests ───────────────────────────────────
#[tokio::test]
async fn history_empty_session() {
let (_tmp, backend) = test_backend();
let tool = SessionsHistoryTool::new(backend, test_security());
let result = tool
.execute(json!({"session_id": "nonexistent"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("No messages found"));
}
#[tokio::test]
async fn history_returns_messages() {
let (_tmp, backend) = seeded_backend();
let tool = SessionsHistoryTool::new(backend, test_security());
let result = tool
.execute(json!({"session_id": "telegram__alice"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("showing 2/2 messages"));
assert!(result.output.contains("[user] Hello from Alice"));
assert!(result.output.contains("[assistant] Hi Alice"));
}
#[tokio::test]
async fn history_respects_limit() {
let (_tmp, backend) = seeded_backend();
let tool = SessionsHistoryTool::new(backend, test_security());
let result = tool
.execute(json!({"session_id": "telegram__alice", "limit": 1}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("showing 1/2 messages"));
// Should show only the last message
assert!(result.output.contains("[assistant]"));
assert!(!result.output.contains("[user] Hello from Alice"));
}
#[tokio::test]
async fn history_missing_session_id() {
let (_tmp, backend) = test_backend();
let tool = SessionsHistoryTool::new(backend, test_security());
let result = tool.execute(json!({})).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("session_id"));
}
#[tokio::test]
async fn history_rejects_empty_session_id() {
let (_tmp, backend) = test_backend();
let tool = SessionsHistoryTool::new(backend, test_security());
let result = tool.execute(json!({"session_id": " "})).await.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Invalid"));
}
#[test]
fn history_tool_name_and_schema() {
let (_tmp, backend) = test_backend();
let tool = SessionsHistoryTool::new(backend, test_security());
assert_eq!(tool.name(), "sessions_history");
let schema = tool.parameters_schema();
assert!(schema["properties"]["session_id"].is_object());
assert!(schema["required"]
.as_array()
.unwrap()
.contains(&json!("session_id")));
}
// ── SessionsSendTool tests ──────────────────────────────────────
#[tokio::test]
async fn send_appends_message() {
let (_tmp, backend) = test_backend();
let tool = SessionsSendTool::new(backend.clone(), test_security());
let result = tool
.execute(json!({
"session_id": "telegram__alice",
"message": "Hello from another agent"
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Message sent"));
// Verify message was appended
let messages = backend.load("telegram__alice");
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].role, "user");
assert_eq!(messages[0].content, "Hello from another agent");
}
#[tokio::test]
async fn send_to_existing_session() {
let (_tmp, backend) = seeded_backend();
let tool = SessionsSendTool::new(backend.clone(), test_security());
let result = tool
.execute(json!({
"session_id": "telegram__alice",
"message": "Inter-agent message"
}))
.await
.unwrap();
assert!(result.success);
let messages = backend.load("telegram__alice");
assert_eq!(messages.len(), 3);
assert_eq!(messages[2].content, "Inter-agent message");
}
#[tokio::test]
async fn send_rejects_empty_message() {
let (_tmp, backend) = test_backend();
let tool = SessionsSendTool::new(backend, test_security());
let result = tool
.execute(json!({
"session_id": "telegram__alice",
"message": " "
}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("empty"));
}
#[tokio::test]
async fn send_rejects_empty_session_id() {
let (_tmp, backend) = test_backend();
let tool = SessionsSendTool::new(backend, test_security());
let result = tool
.execute(json!({
"session_id": "",
"message": "hello"
}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Invalid"));
}
#[tokio::test]
async fn send_rejects_non_alphanumeric_session_id() {
let (_tmp, backend) = test_backend();
let tool = SessionsSendTool::new(backend, test_security());
let result = tool
.execute(json!({
"session_id": "///",
"message": "hello"
}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.unwrap().contains("Invalid"));
}
#[tokio::test]
async fn send_missing_session_id() {
let (_tmp, backend) = test_backend();
let tool = SessionsSendTool::new(backend, test_security());
let result = tool.execute(json!({"message": "hi"})).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("session_id"));
}
#[tokio::test]
async fn send_missing_message() {
let (_tmp, backend) = test_backend();
let tool = SessionsSendTool::new(backend, test_security());
let result = tool.execute(json!({"session_id": "telegram__alice"})).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("message"));
}
#[test]
fn send_tool_name_and_schema() {
let (_tmp, backend) = test_backend();
let tool = SessionsSendTool::new(backend, test_security());
assert_eq!(tool.name(), "sessions_send");
let schema = tool.parameters_schema();
assert!(schema["required"]
.as_array()
.unwrap()
.contains(&json!("session_id")));
assert!(schema["required"]
.as_array()
.unwrap()
.contains(&json!("message")));
}
}
@@ -0,0 +1,224 @@
//! HTTP-based tool derived from a skill's `[[tools]]` section.
//!
//! Each `SkillTool` with `kind = "http"` is converted into a `SkillHttpTool`
//! that implements the `Tool` trait. The command field is used as the URL
//! template and args are substituted as query parameters or path segments.
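//!
//! An illustrative `[[tools]]` entry (field names mirror the `SkillTool`
//! struct exercised in the tests below; the surrounding skill-file layout
//! is assumed):
//!
//! ```toml
//! [[tools]]
//! name = "get_weather"
//! description = "Fetch weather for a city"
//! kind = "http"
//! command = "https://api.example.com/weather?city={{city}}"
//! args = { city = "City name to look up" }
//! ```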
use super::traits::{Tool, ToolResult};
use async_trait::async_trait;
use std::collections::HashMap;
use std::time::Duration;
/// Maximum response body size (1 MB).
const MAX_RESPONSE_BYTES: usize = 1_048_576;
/// HTTP request timeout (seconds).
const HTTP_TIMEOUT_SECS: u64 = 30;
/// A tool derived from a skill's `[[tools]]` section that makes HTTP requests.
pub struct SkillHttpTool {
tool_name: String,
tool_description: String,
url_template: String,
args: HashMap<String, String>,
}
impl SkillHttpTool {
/// Create a new skill HTTP tool.
///
/// The tool name is prefixed with the skill name (`skill_name.tool_name`)
/// to prevent collisions with built-in tools.
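///
/// For example, skill `weather_skill` with tool `get_weather` registers
/// as `weather_skill.get_weather`.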
pub fn new(skill_name: &str, tool: &crate::skills::SkillTool) -> Self {
Self {
tool_name: format!("{}.{}", skill_name, tool.name),
tool_description: tool.description.clone(),
url_template: tool.command.clone(),
args: tool.args.clone(),
}
}
fn build_parameters_schema(&self) -> serde_json::Value {
let mut properties = serde_json::Map::new();
let mut required = Vec::new();
for (name, description) in &self.args {
properties.insert(
name.clone(),
serde_json::json!({
"type": "string",
"description": description
}),
);
required.push(serde_json::Value::String(name.clone()));
}
serde_json::json!({
"type": "object",
"properties": properties,
"required": required
})
}
/// Substitute `{{arg_name}}` placeholders in the URL template with
/// the provided argument values.
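///
/// For example, the template `https://api.example.com/weather?city={{city}}`
/// with args `{"city": "London"}` yields
/// `https://api.example.com/weather?city=London`.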
fn substitute_args(&self, args: &serde_json::Value) -> String {
let mut url = self.url_template.clone();
if let Some(obj) = args.as_object() {
for (key, value) in obj {
let placeholder = format!("{{{{{}}}}}", key);
// Non-string values substitute as empty; string values are inserted
// verbatim (no percent-encoding is applied).
let replacement = value.as_str().unwrap_or_default();
url = url.replace(&placeholder, replacement);
}
}
url
}
}
#[async_trait]
impl Tool for SkillHttpTool {
fn name(&self) -> &str {
&self.tool_name
}
fn description(&self) -> &str {
&self.tool_description
}
fn parameters_schema(&self) -> serde_json::Value {
self.build_parameters_schema()
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let url = self.substitute_args(&args);
// Validate URL scheme
if !url.starts_with("http://") && !url.starts_with("https://") {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Only http:// and https:// URLs are allowed, got: {url}"
)),
});
}
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(HTTP_TIMEOUT_SECS))
.build()
.map_err(|e| anyhow::anyhow!("Failed to build HTTP client: {e}"))?;
let response = match client.get(&url).send().await {
Ok(resp) => resp,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("HTTP request failed: {e}")),
});
}
};
let status = response.status();
let body = match response.bytes().await {
Ok(bytes) => {
let mut text = String::from_utf8_lossy(&bytes).to_string();
if text.len() > MAX_RESPONSE_BYTES {
// We are inside the oversize branch, so start at the cap and back off
// to the nearest char boundary so `truncate` cannot panic inside a
// multi-byte UTF-8 sequence.
let mut b = MAX_RESPONSE_BYTES;
while b > 0 && !text.is_char_boundary(b) {
b -= 1;
}
text.truncate(b);
text.push_str("\n... [response truncated at 1MB]");
}
text
}
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to read response body: {e}")),
});
}
};
Ok(ToolResult {
success: status.is_success(),
output: body,
error: if status.is_success() {
None
} else {
Some(format!("HTTP {status}"))
},
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::skills::SkillTool;
fn sample_http_tool() -> SkillTool {
let mut args = HashMap::new();
args.insert("city".to_string(), "City name to look up".to_string());
SkillTool {
name: "get_weather".to_string(),
description: "Fetch weather for a city".to_string(),
kind: "http".to_string(),
command: "https://api.example.com/weather?city={{city}}".to_string(),
args,
}
}
#[test]
fn skill_http_tool_name_is_prefixed() {
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
assert_eq!(tool.name(), "weather_skill.get_weather");
}
#[test]
fn skill_http_tool_description() {
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
assert_eq!(tool.description(), "Fetch weather for a city");
}
#[test]
fn skill_http_tool_parameters_schema() {
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
let schema = tool.parameters_schema();
assert_eq!(schema["type"], "object");
assert!(schema["properties"]["city"].is_object());
assert_eq!(schema["properties"]["city"]["type"], "string");
}
#[test]
fn skill_http_tool_substitute_args() {
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
let result = tool.substitute_args(&serde_json::json!({"city": "London"}));
assert_eq!(result, "https://api.example.com/weather?city=London");
}
#[test]
fn skill_http_tool_spec_roundtrip() {
let tool = SkillHttpTool::new("weather_skill", &sample_http_tool());
let spec = tool.spec();
assert_eq!(spec.name, "weather_skill.get_weather");
assert_eq!(spec.description, "Fetch weather for a city");
assert_eq!(spec.parameters["type"], "object");
}
#[test]
fn skill_http_tool_empty_args() {
let st = SkillTool {
name: "ping".to_string(),
description: "Ping endpoint".to_string(),
kind: "http".to_string(),
command: "https://api.example.com/ping".to_string(),
args: HashMap::new(),
};
let tool = SkillHttpTool::new("s", &st);
let schema = tool.parameters_schema();
assert!(schema["properties"].as_object().unwrap().is_empty());
}
}
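The `{{arg}}` placeholder substitution used by `SkillHttpTool::substitute_args` can be sketched standalone. This is a minimal re-implementation for illustration, not the crate's API; the function name `substitute` is hypothetical:

```rust
use std::collections::HashMap;

/// Replace each `{{key}}` placeholder in `template` with its value.
/// Unknown placeholders are left untouched; values are inserted verbatim.
fn substitute(template: &str, args: &HashMap<&str, &str>) -> String {
    let mut url = template.to_string();
    for (key, value) in args {
        // `{{{{{key}}}}}` renders as the literal `{{key}}`.
        let placeholder = format!("{{{{{key}}}}}");
        url = url.replace(&placeholder, value);
    }
    url
}

fn main() {
    let mut args = HashMap::new();
    args.insert("city", "London");
    let url = substitute("https://api.example.com/weather?city={{city}}", &args);
    println!("{url}");
}
```

Because values are spliced in without percent-encoding, templates should place arguments only in positions where raw text is safe.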
