Compare commits


12 Commits

Author SHA1 Message Date
Giulio V 675a5c9af0 feat(tools): add Google Workspace CLI (gws) integration (#3616)
* feat(tools): add Google Workspace CLI (gws) integration

Adds GoogleWorkspaceTool for interacting with Google Drive, Sheets,
Gmail, Calendar, Docs, and other Workspace services via CLI.

- Config-gated (google_workspace.enabled)
- Service allowlist for restricted access
- Requires shell access for CLI delegation
- Input validation against shell injection
- Wrong-type rejection for all optional parameters
- Config validation for allowed_services (empty, duplicate, malformed)
- Registered in integrations registry and CLI discovery

Closes #2986

* style: fix cargo fmt + clippy violations

* feat(google-workspace): expand config with auth, rate limits, and audit settings

* fix(tools): define missing GWS_TIMEOUT_SECS constant

* fix: Box::pin large futures and resolve duplicate Default impl

---------

Co-authored-by: argenis de la rosa <theonlyhennygod@gmail.com>
2026-03-17 00:52:59 -04:00
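The commit above mentions "input validation against shell injection" for arguments passed to the delegated CLI. A minimal sketch of that kind of check, assuming a deny-list approach; `is_safe_cli_arg` is an illustrative name, not the project's actual API:

```rust
/// Reject arguments containing shell metacharacters before they reach the
/// subprocess CLI. A hypothetical deny-list sketch, not the project's code.
fn is_safe_cli_arg(arg: &str) -> bool {
    // Characters that could break out of a single argument if a command
    // line is ever assembled for a shell.
    const FORBIDDEN: &[char] = &[';', '|', '&', '$', '`', '>', '<', '\n', '\r'];
    !arg.is_empty() && !arg.chars().any(|c| FORBIDDEN.contains(&c))
}

fn main() {
    assert!(is_safe_cli_arg("drive list"));
    assert!(!is_safe_cli_arg("x; rm -rf /"));
    assert!(!is_safe_cli_arg("$(whoami)"));
    println!("validation checks passed");
}
```

A real implementation would also cover argument length limits and the wrong-type rejection the commit lists; this only shows the metacharacter gate.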
Giulio V b099728c27 feat(stt): multi-provider STT with TranscriptionProvider trait (#3614)
* feat(stt): add multi-provider STT with TranscriptionProvider trait

Refactors single-endpoint transcription to support multiple providers:
Groq (existing), OpenAI Whisper, Deepgram, AssemblyAI, and Google Cloud
Speech-to-Text. Adds TranscriptionManager for provider routing with
backward-compatible config fields.

* style: fix cargo fmt + clippy violations

* fix: Box::pin large futures and resolve merge conflicts with master

---------

Co-authored-by: argenis de la rosa <theonlyhennygod@gmail.com>
2026-03-17 00:33:41 -04:00
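The multi-provider refactor above centers on a `TranscriptionProvider` trait with a manager routing to a named backend. A simplified sketch of that shape; the trait and manager names follow the commit, but the method signatures and the `route` API are illustrative assumptions:

```rust
// Sketch of trait-based provider routing, as the commit describes.
// Signatures are assumptions; the real trait is presumably async and
// returns richer types.
trait TranscriptionProvider {
    fn name(&self) -> &'static str;
    fn transcribe(&self, audio: &[u8]) -> Result<String, String>;
}

struct Groq;
impl TranscriptionProvider for Groq {
    fn name(&self) -> &'static str { "groq" }
    fn transcribe(&self, _audio: &[u8]) -> Result<String, String> {
        Ok("transcript from groq".into())
    }
}

struct TranscriptionManager {
    providers: Vec<Box<dyn TranscriptionProvider>>,
}

impl TranscriptionManager {
    /// Dispatch to the provider selected by config, falling back to an error
    /// for unknown names so misconfiguration fails loudly.
    fn route(&self, provider: &str, audio: &[u8]) -> Result<String, String> {
        self.providers
            .iter()
            .find(|p| p.name() == provider)
            .ok_or_else(|| format!("unknown provider: {provider}"))?
            .transcribe(audio)
    }
}

fn main() {
    let mgr = TranscriptionManager { providers: vec![Box::new(Groq)] };
    assert!(mgr.route("groq", b"audio").is_ok());
    assert!(mgr.route("deepgram", b"audio").is_err());
}
```

Backward-compatible config fields would map the old single-endpoint settings onto the default provider entry.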
Argenis 1ca2092ca0 test(channel): add QQ markdown msg_type regression test (#3752)
Verify that QQ send body uses msg_type 2 with nested markdown object
instead of msg_type 0 with top-level content. Adapted from #3668.
2026-03-16 22:03:43 -04:00
Giulio V 5e3308eaaa feat(providers): add Claude Code, Gemini CLI, and KiloCLI subprocess providers (#3615)
* feat(providers): add Claude Code, Gemini CLI, and KiloCLI subprocess providers

Adds three new local subprocess-based providers for AI CLI tools.
Each provider spawns the CLI as a child process, communicates via
stdin/stdout pipes, and parses responses into ChatResponse format.

* fix: resolve clippy unnecessary_debug_formatting and rustfmt violations

* fix: resolve remaining clippy unnecessary_debug_formatting in CLI providers

* fix(providers): add AiAgent CLI category for subprocess providers
2026-03-16 21:51:05 -04:00
Chris Hengge ec255ad788 fix(tool): expand cron_add and cron_update parameter schemas (#3671)
The schedule field in cron_add used a bare {"type":"object"} with a
description string encoding a tagged union in pseudo-notation. The patch
field in cron_update was an opaque {"type":"object"} despite CronJobPatch
having nine fully-typed fields. Both gaps cause weaker instruction-following
models to produce malformed or missing nested JSON when invoking these tools.

Changes:
- cron_add: expand schedule into a oneOf discriminated union with explicit
  properties and required fields for each variant (cron/at/every), matching
  the Schedule enum in src/cron/types.rs exactly
- cron_add: add descriptions to all previously undocumented top-level fields
- cron_add: expand delivery from a bare inline comment to fully-specified
  properties with per-field descriptions
- cron_update: expand patch from opaque object to full properties matching
  CronJobPatch (name, enabled, command, prompt, model, session_target,
  delete_after_run, schedule, delivery)
- cron_update: schedule inside patch mirrors the same oneOf expansion
- Both: add inline NOTE comments flagging that oneOf is correct for
  OpenAI-compatible APIs but SchemaCleanr::clean_for_gemini must be
  applied if Gemini native tool calling is ever wired up
- Both: add schema-shape tests using the existing test_config/test_security
  helper pattern, covering oneOf variant structure, required fields, and
  delivery channel enum completeness

No behavior changes. No new dependencies. Backward compatible: the runtime
deserialization path (serde on Schedule/CronJobPatch) is unchanged.

Co-authored-by: Argenis <theonlyhennygod@gmail.com>
2026-03-16 21:45:49 -04:00
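The schema expansion above replaces opaque `{"type":"object"}` parameters with a `oneOf` discriminated union. A minimal before/after illustration of that shape, embedded as JSON text; the variant names (cron/at/every) come from the commit, while the property types shown are assumptions about the real schema in `src/cron/types.rs`:

```rust
// Before: an opaque object schema gives weaker models nothing to follow.
const BEFORE: &str = r#"{"type":"object"}"#;

// After: one explicit variant per Schedule arm, each with its own
// required fields, as the commit describes. Property types here are
// illustrative, not copied from the repository.
const AFTER: &str = r#"{
  "oneOf": [
    { "type": "object", "properties": { "cron":  { "type": "string" } }, "required": ["cron"]  },
    { "type": "object", "properties": { "at":    { "type": "string" } }, "required": ["at"]    },
    { "type": "object", "properties": { "every": { "type": "string" } }, "required": ["every"] }
  ]
}"#;

fn main() {
    assert!(!BEFORE.contains("oneOf"));
    assert!(AFTER.contains("oneOf"));
    assert_eq!(AFTER.matches("required").count(), 3);
}
```

The commit's inline NOTE about Gemini matters because `oneOf` is valid JSON Schema for OpenAI-compatible tool definitions but is not accepted by Gemini's native function-calling schema subset.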
Sid Jain 7182f659ce fix(slack): honor mention_only in runtime channel wiring (#3715)
* feat(slack): wire mention_only group reply policy

* feat(slack): expose mention_only in config and wizard defaults
2026-03-16 21:40:47 -04:00
Ericsunsk ae7681209d fix(openai-codex): decode utf-8 safely across stream chunks (#3723) 2026-03-16 21:40:45 -04:00
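Decoding UTF-8 "safely across stream chunks", as this fix's title says, usually means carrying incomplete trailing bytes between chunks instead of decoding each chunk in isolation (which corrupts any multi-byte character split at a chunk boundary). A sketch of that pattern under that assumption; `decode_chunk` is an illustrative name, not the provider's actual function:

```rust
/// Decode streamed UTF-8 that may be split mid-codepoint across chunks.
/// Bytes that do not yet form a complete character stay in `carry` until
/// the next chunk arrives.
fn decode_chunk(carry: &mut Vec<u8>, chunk: &[u8]) -> String {
    carry.extend_from_slice(chunk);
    match std::str::from_utf8(carry) {
        Ok(s) => {
            let out = s.to_string();
            carry.clear();
            out
        }
        Err(e) => {
            // Emit the valid prefix; keep the incomplete tail for later.
            let valid = e.valid_up_to();
            let out = std::str::from_utf8(&carry[..valid]).unwrap().to_string();
            carry.drain(..valid);
            out
        }
    }
}

fn main() {
    let bytes = "héllo".as_bytes(); // 'é' is two bytes: 0xC3 0xA9
    let mut carry = Vec::new();
    let first = decode_chunk(&mut carry, &bytes[..2]);  // splits inside 'é'
    let second = decode_chunk(&mut carry, &bytes[2..]);
    assert_eq!(first, "h");
    assert_eq!(second, "éllo");
}
```

Naively calling a lossy decoder per chunk would instead emit replacement characters (U+FFFD) at every unlucky boundary.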
Markus Bergholz ee3469e912 Fix: Support Nextcloud Talk Activity Streams 2.0 webhook format (#3737)
* fix

* fix

* format
2026-03-16 21:40:42 -04:00
Argenis fec81d8e75 ci: auto-sync Scoop and AUR on stable release (#3743)
Add workflow_call triggers to pub-scoop.yml and pub-aur.yml so the
stable release workflow can invoke them automatically after publish.

Wire scoop and aur jobs into release-stable-manual.yml as post-publish
steps (parallel with tweet), gated on publish success.

Update ci-map.md trigger docs to reflect auto-called behavior.
2026-03-16 21:34:29 -04:00
Ricardo Madriz 9a073fae1a fix(tools) Wire activated toolset into dispatch (#3747)
* fix(tools): wire ActivatedToolSet into tool dispatch and spec advertisement

When deferred MCP tools are activated via tool_search, they are stored
in ActivatedToolSet but never consulted by the tool call loop.
tool_specs is built once before the iteration loop and never refreshed,
so the provider API tools[] parameter never includes activated tools.
find_tool only searches the static registry, so execution dispatch also
fails silently.

Thread Arc<Mutex<ActivatedToolSet>> from creation sites through to
run_tool_call_loop. Rebuild tool_specs each iteration to merge base
registry specs with activated specs. Add fallback in execute_one_tool
to check the activated set when the static registry lookup misses.

Change ActivatedToolSet internal storage from Box<dyn Tool> to
Arc<dyn Tool> so we can clone the Arc out of the mutex guard before
awaiting tool.execute() (std::sync::MutexGuard is not Send).

* fix(tools): add activated_tools field to new ChannelRuntimeContext test site
2026-03-16 21:34:08 -04:00
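The storage change in the commit above (`Box<dyn Tool>` to `Arc<dyn Tool>`) exists so a handle can be cloned out of the mutex guard and the guard dropped before any `.await`, since `std::sync::MutexGuard` is not `Send`. A simplified synchronous sketch of that pattern; the trait and struct here are stand-ins, not the project's definitions:

```rust
use std::sync::{Arc, Mutex};

// Stand-in for the project's Tool trait.
trait Tool: Send + Sync {
    fn name(&self) -> &str;
}

struct EchoTool;
impl Tool for EchoTool {
    fn name(&self) -> &str { "echo" }
}

struct ActivatedToolSet {
    tools: Vec<Arc<dyn Tool>>,
}

impl ActivatedToolSet {
    /// Cloning the Arc (a cheap refcount bump) hands the caller an owned
    /// handle, so the surrounding MutexGuard can be released before any
    /// await point — the guard itself must never be held across one.
    fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
        self.tools.iter().find(|t| t.name() == name).cloned()
    }
}

fn main() {
    let set = Arc::new(Mutex::new(ActivatedToolSet {
        tools: vec![Arc::new(EchoTool)],
    }));
    // The temporary guard is dropped at the end of this statement;
    // only the cloned Arc survives.
    let tool = set.lock().unwrap().get("echo");
    assert_eq!(tool.unwrap().name(), "echo");
}
```

With `Box<dyn Tool>` storage there is no cheap owned handle to extract, which is why the guard would otherwise have to live across `tool.execute().await`.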
Chris Hengge f0db63e53c fix(integrations): wire Cron and Browser status to config fields (#3750)
Both entries had hardcoded |_| IntegrationStatus::Available, ignoring
the live config entirely. Users with cron.enabled = true or
browser.enabled = true saw 'Available' on the /integrations dashboard
card instead of 'Active'.

Root cause: status_fn closures did not capture the Config argument.

Fix: replace the |_| stubs with |c| closures that check c.cron.enabled
and c.browser.enabled respectively, matching the pattern used by every
other wired entry in the registry (Telegram, Discord, Shell, etc.).

What did NOT change: ComingSoon entries, always-Active entries (Shell,
File System), platform entries, or any other registry logic.
2026-03-16 21:34:06 -04:00
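The bug class above — a `|_|` closure that discards its `Config` argument — is easy to see in miniature. A sketch with simplified stand-in types (the real `Config` and `IntegrationStatus` are richer):

```rust
// Simplified stand-ins for the registry's config and status types.
#[derive(Default)]
struct Config { cron_enabled: bool }

#[derive(Debug, PartialEq)]
enum IntegrationStatus { Available, Active }

// Before: the closure ignores the live config, so the dashboard always
// shows 'Available'.
fn status_stub(_: &Config) -> IntegrationStatus {
    IntegrationStatus::Available
}

// After: the closure captures the argument and inspects the enabled flag,
// matching the pattern of the other wired registry entries.
fn status_wired(c: &Config) -> IntegrationStatus {
    if c.cron_enabled { IntegrationStatus::Active } else { IntegrationStatus::Available }
}

fn main() {
    let cfg = Config { cron_enabled: true };
    assert_eq!(status_stub(&cfg), IntegrationStatus::Available); // the bug
    assert_eq!(status_wired(&cfg), IntegrationStatus::Active);   // the fix
}
```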
Argenis df4dfeaf66 chore: bump version to 0.4.3 (#3749)
Update version across Cargo.toml, Cargo.lock, Scoop manifest,
and AUR PKGBUILD/.SRCINFO for the v0.4.3 stable release.
2026-03-16 21:23:04 -04:00
34 changed files with 3709 additions and 151 deletions
+14
@@ -1,6 +1,20 @@
name: Pub AUR Package
on:
workflow_call:
inputs:
release_tag:
description: "Existing release tag (vX.Y.Z)"
required: true
type: string
dry_run:
description: "Generate PKGBUILD only (no push)"
required: false
default: false
type: boolean
secrets:
AUR_SSH_KEY:
required: false
workflow_dispatch:
inputs:
release_tag:
+14
@@ -1,6 +1,20 @@
name: Pub Scoop Manifest
on:
workflow_call:
inputs:
release_tag:
description: "Existing release tag (vX.Y.Z)"
required: true
type: string
dry_run:
description: "Generate manifest only (no push)"
required: false
default: false
type: boolean
secrets:
SCOOP_BUCKET_TOKEN:
required: false
workflow_dispatch:
inputs:
release_tag:
@@ -361,6 +361,27 @@ jobs:
cache-from: type=gha
cache-to: type=gha,mode=max
# ── Post-publish: package manager auto-sync ─────────────────────────
scoop:
name: Update Scoop Manifest
needs: [validate, publish]
if: ${{ !cancelled() && needs.publish.result == 'success' }}
uses: ./.github/workflows/pub-scoop.yml
with:
release_tag: ${{ needs.validate.outputs.tag }}
dry_run: false
secrets: inherit
aur:
name: Update AUR Package
needs: [validate, publish]
if: ${{ !cancelled() && needs.publish.result == 'success' }}
uses: ./.github/workflows/pub-aur.yml
with:
release_tag: ${{ needs.validate.outputs.tag }}
dry_run: false
secrets: inherit
# ── Post-publish: tweet after release + website are live ──────────────
# Docker push can be slow; don't let it block the tweet.
tweet:
+1 -1
@@ -7945,7 +7945,7 @@ dependencies = [
[[package]]
name = "zeroclawlabs"
-version = "0.4.2"
+version = "0.4.3"
dependencies = [
"anyhow",
"async-imap",
+1 -1
@@ -4,7 +4,7 @@ resolver = "2"
[package]
name = "zeroclawlabs"
-version = "0.4.2"
+version = "0.4.3"
edition = "2021"
authors = ["theonlyhennygod"]
license = "MIT OR Apache-2.0"
+2 -2
@@ -1,6 +1,6 @@
pkgbase = zeroclaw
pkgdesc = Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.
-pkgver = 0.4.2
+pkgver = 0.4.3
pkgrel = 1
url = https://github.com/zeroclaw-labs/zeroclaw
arch = x86_64
@@ -10,7 +10,7 @@ pkgbase = zeroclaw
makedepends = git
depends = gcc-libs
depends = openssl
-source = zeroclaw-0.4.2.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.4.2.tar.gz
+source = zeroclaw-0.4.3.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.4.3.tar.gz
sha256sums = SKIP
pkgname = zeroclaw
+1 -1
@@ -1,6 +1,6 @@
# Maintainer: zeroclaw-labs <bot@zeroclaw.dev>
pkgname=zeroclaw
-pkgver=0.4.2
+pkgver=0.4.3
pkgrel=1
pkgdesc="Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
arch=('x86_64')
+2 -2
@@ -1,11 +1,11 @@
{
-"version": "0.4.2",
+"version": "0.4.3",
"description": "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.",
"homepage": "https://github.com/zeroclaw-labs/zeroclaw",
"license": "MIT|Apache-2.0",
"architecture": {
"64bit": {
-"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.4.2/zeroclaw-x86_64-pc-windows-msvc.zip",
+"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.4.3/zeroclaw-x86_64-pc-windows-msvc.zip",
"hash": "",
"bin": "zeroclaw.exe"
}
+4 -4
@@ -38,10 +38,10 @@ Merge-blocking checks should stay small and deterministic. Optional checks are u
- Purpose: manual, bot-owned Homebrew core formula bump PR flow for tagged releases
- Guardrail: release tag must match `Cargo.toml` version
- `.github/workflows/pub-scoop.yml` (`Pub Scoop Manifest`)
-- Purpose: manual Scoop bucket manifest update for Windows distribution
+- Purpose: Scoop bucket manifest update for Windows; auto-called by stable release, also manual dispatch
- Guardrail: release tag must be `vX.Y.Z` format; Windows binary hash extracted from `SHA256SUMS`
- `.github/workflows/pub-aur.yml` (`Pub AUR Package`)
-- Purpose: manual AUR PKGBUILD push for Arch Linux distribution
+- Purpose: AUR PKGBUILD push for Arch Linux; auto-called by stable release, also manual dispatch
- Guardrail: release tag must be `vX.Y.Z` format; source tarball SHA256 computed at publish time
- `.github/workflows/pr-label-policy-check.yml` (`Label Policy Sanity`)
- Purpose: validate shared contributor-tier policy in `.github/label-policy.json` and ensure label workflows consume that policy
@@ -81,8 +81,8 @@ Merge-blocking checks should stay small and deterministic. Optional checks are u
- `Docker`: tag push (`v*`) for publish, matching PRs to `master` for smoke build, manual dispatch for smoke only
- `Release`: tag push (`v*`), weekly schedule (verification-only), manual dispatch (verification or publish)
- `Pub Homebrew Core`: manual dispatch only
-- `Pub Scoop Manifest`: manual dispatch only
-- `Pub AUR Package`: manual dispatch only
+- `Pub Scoop Manifest`: auto-called by stable release, also manual dispatch
+- `Pub AUR Package`: auto-called by stable release, also manual dispatch
- `Security Audit`: push to `master`, PRs to `master`, weekly schedule
- `Sec Vorpal Reviewdog`: manual dispatch only
- `Workflow Sanity`: PR/push when `.github/workflows/**`, `.github/*.yml`, or `.github/*.yaml` change
+48 -8
@@ -2152,6 +2152,7 @@ pub(crate) async fn agent_turn(
None,
&[],
&[],
None,
)
.await
}
@@ -2160,6 +2161,7 @@ async fn execute_one_tool(
call_name: &str,
call_arguments: serde_json::Value,
tools_registry: &[Box<dyn Tool>],
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
observer: &dyn Observer,
cancellation_token: Option<&CancellationToken>,
) -> Result<ToolExecutionOutcome> {
@@ -2170,7 +2172,13 @@ async fn execute_one_tool(
});
let start = Instant::now();
let Some(tool) = find_tool(tools_registry, call_name) else {
let static_tool = find_tool(tools_registry, call_name);
let activated_arc = if static_tool.is_none() {
activated_tools.and_then(|at| at.lock().unwrap().get(call_name))
} else {
None
};
let Some(tool) = static_tool.or(activated_arc.as_deref()) else {
let reason = format!("Unknown tool: {call_name}");
let duration = start.elapsed();
observer.record_event(&ObserverEvent::ToolCall {
@@ -2268,6 +2276,7 @@ fn should_execute_tools_in_parallel(
async fn execute_tools_parallel(
tool_calls: &[ParsedToolCall],
tools_registry: &[Box<dyn Tool>],
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
observer: &dyn Observer,
cancellation_token: Option<&CancellationToken>,
) -> Result<Vec<ToolExecutionOutcome>> {
@@ -2278,6 +2287,7 @@ async fn execute_tools_parallel(
&call.name,
call.arguments.clone(),
tools_registry,
activated_tools,
observer,
cancellation_token,
)
@@ -2291,6 +2301,7 @@ async fn execute_tools_parallel(
async fn execute_tools_sequential(
tool_calls: &[ParsedToolCall],
tools_registry: &[Box<dyn Tool>],
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
observer: &dyn Observer,
cancellation_token: Option<&CancellationToken>,
) -> Result<Vec<ToolExecutionOutcome>> {
@@ -2302,6 +2313,7 @@ async fn execute_tools_sequential(
&call.name,
call.arguments.clone(),
tools_registry,
activated_tools,
observer,
cancellation_token,
)
@@ -2345,6 +2357,7 @@ pub(crate) async fn run_tool_call_loop(
hooks: Option<&crate::hooks::HookRunner>,
excluded_tools: &[String],
dedup_exempt_tools: &[String],
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
) -> Result<String> {
let max_iterations = if max_tool_iterations == 0 {
DEFAULT_MAX_TOOL_ITERATIONS
@@ -2352,12 +2365,6 @@ pub(crate) async fn run_tool_call_loop(
max_tool_iterations
};
let tool_specs: Vec<crate::tools::ToolSpec> = tools_registry
.iter()
.filter(|tool| !excluded_tools.iter().any(|ex| ex == tool.name()))
.map(|tool| tool.spec())
.collect();
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
let turn_id = Uuid::new_v4().to_string();
let mut seen_tool_signatures: HashSet<(String, String)> = HashSet::new();
@@ -2369,6 +2376,21 @@ pub(crate) async fn run_tool_call_loop(
return Err(ToolLoopCancelled.into());
}
// Rebuild tool_specs each iteration so newly activated deferred tools appear.
let mut tool_specs: Vec<crate::tools::ToolSpec> = tools_registry
.iter()
.filter(|tool| !excluded_tools.iter().any(|ex| ex == tool.name()))
.map(|tool| tool.spec())
.collect();
if let Some(at) = activated_tools {
for spec in at.lock().unwrap().tool_specs() {
if !excluded_tools.iter().any(|ex| ex == &spec.name) {
tool_specs.push(spec);
}
}
}
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
let image_marker_count = multimodal::count_image_markers(history);
if image_marker_count > 0 && !provider.supports_vision() {
return Err(ProviderCapabilityError {
@@ -2847,6 +2869,7 @@ pub(crate) async fn run_tool_call_loop(
execute_tools_parallel(
&executable_calls,
tools_registry,
activated_tools,
observer,
cancellation_token.as_ref(),
)
@@ -2855,6 +2878,7 @@ pub(crate) async fn run_tool_call_loop(
execute_tools_sequential(
&executable_calls,
tools_registry,
activated_tools,
observer,
cancellation_token.as_ref(),
)
@@ -3106,6 +3130,9 @@ pub async fn run(
// eagerly. Instead, a `tool_search` built-in is registered so the LLM can
// fetch schemas on demand. This reduces context window waste.
let mut deferred_section = String::new();
let mut activated_handle: Option<
std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>,
> = None;
if config.mcp.enabled && !config.mcp.servers.is_empty() {
tracing::info!(
"Initializing MCP client — {} server(s) configured",
@@ -3130,6 +3157,7 @@ pub async fn run(
let activated = std::sync::Arc::new(std::sync::Mutex::new(
crate::tools::ActivatedToolSet::new(),
));
activated_handle = Some(std::sync::Arc::clone(&activated));
tools_registry.push(Box::new(crate::tools::ToolSearchTool::new(
deferred_set,
activated,
@@ -3442,6 +3470,7 @@ pub async fn run(
None,
&excluded_tools,
&config.agent.tool_call_dedup_exempt,
activated_handle.as_ref(),
)
.await?;
final_output = response.clone();
@@ -3603,6 +3632,7 @@ pub async fn run(
None,
&excluded_tools,
&config.agent.tool_call_dedup_exempt,
activated_handle.as_ref(),
)
.await
{
@@ -3982,7 +4012,8 @@ mod tests {
.expect("should produce a sample whose byte index 300 is not a char boundary");
let observer = NoopObserver;
let result = execute_one_tool("unknown_tool", call_arguments, &[], &observer, None).await;
let result =
execute_one_tool("unknown_tool", call_arguments, &[], None, &observer, None).await;
assert!(result.is_ok(), "execute_one_tool should not panic or error");
let outcome = result.unwrap();
@@ -4319,6 +4350,7 @@ mod tests {
None,
&[],
&[],
None,
)
.await
.expect_err("provider without vision support should fail");
@@ -4366,6 +4398,7 @@ mod tests {
None,
&[],
&[],
None,
)
.await
.expect_err("oversized payload must fail");
@@ -4407,6 +4440,7 @@ mod tests {
None,
&[],
&[],
None,
)
.await
.expect("valid multimodal payload should pass");
@@ -4534,6 +4568,7 @@ mod tests {
None,
&[],
&[],
None,
)
.await
.expect("parallel execution should complete");
@@ -4604,6 +4639,7 @@ mod tests {
None,
&[],
&[],
None,
)
.await
.expect("loop should finish after deduplicating repeated calls");
@@ -4666,6 +4702,7 @@ mod tests {
None,
&[],
&exempt,
None,
)
.await
.expect("loop should finish with exempt tool executing twice");
@@ -4743,6 +4780,7 @@ mod tests {
None,
&[],
&exempt,
None,
)
.await
.expect("loop should complete");
@@ -4797,6 +4835,7 @@ mod tests {
None,
&[],
&[],
None,
)
.await
.expect("native fallback id flow should complete");
@@ -6698,6 +6737,7 @@ Let me check the result."#;
None,
&[],
&[],
None,
)
.await
.expect("tool loop should complete");
+37
@@ -329,6 +329,7 @@ struct ChannelRuntimeContext {
/// `[autonomy]` config; auto-denies tools that would need interactive
/// approval since no operator is present on channel runs.
approval_manager: Arc<ApprovalManager>,
activated_tools: Option<std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
}
#[derive(Clone)]
@@ -2140,6 +2141,7 @@ async fn process_channel_message(
ctx.non_cli_excluded_tools.as_ref()
},
ctx.tool_call_dedup_exempt.as_ref(),
ctx.activated_tools.as_ref(),
),
) => LlmExecutionResult::Completed(result),
};
@@ -3236,6 +3238,7 @@ fn collect_configured_channels(
Vec::new(),
sl.allowed_users.clone(),
)
.with_group_reply_policy(sl.mention_only, Vec::new())
.with_workspace_dir(config.workspace_dir.clone()),
),
});
@@ -3699,6 +3702,9 @@ pub async fn start_channels(config: Config) -> Result<()> {
// When `deferred_loading` is enabled, MCP tools are NOT added eagerly.
// Instead, a `tool_search` built-in is registered for on-demand loading.
let mut deferred_section = String::new();
let mut ch_activated_handle: Option<
std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>,
> = None;
if config.mcp.enabled && !config.mcp.servers.is_empty() {
tracing::info!(
"Initializing MCP client — {} server(s) configured",
@@ -3722,6 +3728,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
let activated = std::sync::Arc::new(std::sync::Mutex::new(
crate::tools::ActivatedToolSet::new(),
));
ch_activated_handle = Some(std::sync::Arc::clone(&activated));
built_tools.push(Box::new(crate::tools::ToolSearchTool::new(
deferred_set,
activated,
@@ -4016,6 +4023,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
None
},
approval_manager: Arc::new(ApprovalManager::for_non_interactive(&config.autonomy)),
activated_tools: ch_activated_handle,
});
// Hydrate in-memory conversation histories from persisted JSONL session files.
@@ -4308,6 +4316,7 @@ mod tests {
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
};
assert!(compact_sender_history(&ctx, &sender));
@@ -4416,6 +4425,7 @@ mod tests {
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
};
append_sender_turn(&ctx, &sender, ChatMessage::user("hello"));
@@ -4480,6 +4490,7 @@ mod tests {
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
};
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
@@ -4563,6 +4574,7 @@ mod tests {
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
};
assert!(rollback_orphan_user_turn(
@@ -5096,6 +5108,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -5168,6 +5181,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -5254,6 +5268,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -5325,6 +5340,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -5406,6 +5422,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -5507,6 +5524,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -5590,6 +5608,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -5688,6 +5707,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -5771,6 +5791,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -5844,6 +5865,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -6028,6 +6050,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
@@ -6120,6 +6143,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
@@ -6226,6 +6250,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
query_classification: crate::config::QueryClassificationConfig::default(),
});
@@ -6331,6 +6356,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
@@ -6417,6 +6443,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -6488,6 +6515,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -7117,6 +7145,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -7214,6 +7243,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -7311,6 +7341,7 @@ BTC is currently around $65,000 based on latest tool output."#
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -7872,6 +7903,7 @@ This is an example JSON object for profile settings."#;
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
// Simulate a photo attachment message with [IMAGE:] marker.
@@ -7950,6 +7982,7 @@ This is an example JSON object for profile settings."#;
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -8102,6 +8135,7 @@ This is an example JSON object for profile settings."#;
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -8204,6 +8238,7 @@ This is an example JSON object for profile settings."#;
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -8298,6 +8333,7 @@ This is an example JSON object for profile settings."#;
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
@@ -8412,6 +8448,7 @@ This is an example JSON object for profile settings."#;
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
&crate::config::AutonomyConfig::default(),
)),
activated_tools: None,
});
process_channel_message(
+211 -32
@@ -62,24 +62,146 @@ impl NextcloudTalkChannel {
/// Parse a Nextcloud Talk webhook payload into channel messages.
///
/// Relevant payload fields:
/// - `type` (accepts `message` or `Create`)
/// - `object.token` (room token for reply routing)
/// - `message.actorType`, `message.actorId`, `message.message`, `message.timestamp`
/// Two payload formats are supported:
///
/// **Format A — legacy/custom** (`type: "message"`):
/// ```json
/// {
/// "type": "message",
/// "object": { "token": "<room>" },
/// "message": { "actorId": "...", "message": "...", ... }
/// }
/// ```
///
/// **Format B — Activity Streams 2.0** (`type: "Create"`):
/// This is the format actually sent by Nextcloud Talk bot webhooks.
/// ```json
/// {
/// "type": "Create",
/// "actor": { "type": "Person", "id": "users/alice", "name": "Alice" },
/// "object": { "type": "Note", "id": "177", "content": "{\"message\":\"hi\",\"parameters\":[]}", "mediaType": "text/markdown" },
/// "target": { "type": "Collection", "id": "<room_token>", "name": "Room Name" }
/// }
/// ```
pub fn parse_webhook_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
let messages = Vec::new();
let event_type = match payload.get("type").and_then(|v| v.as_str()) {
Some(t) => t,
None => return messages,
};
// Activity Streams 2.0 format sent by Nextcloud Talk bot webhooks.
if event_type.eq_ignore_ascii_case("create") {
return self.parse_as2_payload(payload);
}
// Legacy/custom format.
if !event_type.eq_ignore_ascii_case("message") {
tracing::debug!("Nextcloud Talk: skipping non-message event: {event_type}");
return messages;
}
self.parse_message_payload(payload)
}
/// Parse Activity Streams 2.0 `Create` payload (real Nextcloud Talk bot webhook format).
fn parse_as2_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
let mut messages = Vec::new();
if let Some(event_type) = payload.get("type").and_then(|v| v.as_str()) {
// Nextcloud Talk bot webhooks send "Create" for new chat messages,
// but some setups may use "message". Accept both.
let is_message_event = event_type.eq_ignore_ascii_case("message")
|| event_type.eq_ignore_ascii_case("create");
if !is_message_event {
tracing::debug!("Nextcloud Talk: skipping non-message event: {event_type}");
return messages;
}
let obj = match payload.get("object") {
Some(o) => o,
None => return messages,
};
// Only handle Note objects (= chat messages). Ignore reactions, etc.
let object_type = obj.get("type").and_then(|v| v.as_str()).unwrap_or("");
if !object_type.eq_ignore_ascii_case("note") {
tracing::debug!("Nextcloud Talk: skipping AS2 Create with object.type={object_type}");
return messages;
}
// Room token is in target.id.
let room_token = payload
.get("target")
.and_then(|t| t.get("id"))
.and_then(|v| v.as_str())
.map(str::trim)
.filter(|t| !t.is_empty());
let Some(room_token) = room_token else {
tracing::warn!("Nextcloud Talk: missing target.id (room token) in AS2 payload");
return messages;
};
// Actor — skip bot-originated messages to prevent feedback loops.
let actor = payload.get("actor").cloned().unwrap_or_default();
let actor_type = actor.get("type").and_then(|v| v.as_str()).unwrap_or("");
if actor_type.eq_ignore_ascii_case("application") {
tracing::debug!("Nextcloud Talk: skipping bot-originated AS2 message");
return messages;
}
// actor.id is "users/<id>" — strip the prefix.
let actor_id = actor
.get("id")
.and_then(|v| v.as_str())
.map(|id| id.trim_start_matches("users/").trim())
.filter(|id| !id.is_empty());
let Some(actor_id) = actor_id else {
tracing::warn!("Nextcloud Talk: missing actor.id in AS2 payload");
return messages;
};
if !self.is_user_allowed(actor_id) {
tracing::warn!(
"Nextcloud Talk: ignoring message from unauthorized actor: {actor_id}. \
Add to channels.nextcloud_talk.allowed_users in config.toml, \
or run `zeroclaw onboard --channels-only` to configure interactively."
);
return messages;
}
// Message text is JSON-encoded inside object.content.
// e.g. content = "{\"message\":\"hello\",\"parameters\":[]}"
let content = obj
.get("content")
.and_then(|v| v.as_str())
.and_then(|s| serde_json::from_str::<serde_json::Value>(s).ok())
.and_then(|v| {
v.get("message")
.and_then(|m| m.as_str())
.map(str::trim)
.map(str::to_string)
})
.filter(|s| !s.is_empty());
let Some(content) = content else {
tracing::debug!("Nextcloud Talk: empty or unparseable AS2 message content");
return messages;
};
let message_id =
Self::value_to_string(obj.get("id")).unwrap_or_else(|| Uuid::new_v4().to_string());
messages.push(ChannelMessage {
id: message_id,
reply_target: room_token.to_string(),
sender: actor_id.to_string(),
content,
channel: "nextcloud_talk".to_string(),
timestamp: Self::now_unix_secs(),
thread_ts: None,
});
messages
}
/// Parse legacy `type: "message"` payload format.
fn parse_message_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
let mut messages = Vec::new();
let Some(message_obj) = payload.get("message") else {
return messages;
};
@@ -343,33 +465,90 @@ mod tests {
}
    #[test]
    fn nextcloud_talk_parse_as2_create_payload() {
        let channel = NextcloudTalkChannel::new(
            "https://cloud.example.com".into(),
            "app-token".into(),
            vec!["*".into()],
        );
        // Real payload format sent by Nextcloud Talk bot webhooks.
        let payload = serde_json::json!({
            "type": "Create",
            "actor": {
                "type": "Person",
                "id": "users/user_a",
                "name": "User A",
                "talkParticipantType": "1"
            },
            "object": {
                "type": "Note",
                "id": "177",
                "name": "message",
                "content": "{\"message\":\"hallo, bist du da?\",\"parameters\":[]}",
                "mediaType": "text/markdown"
            },
            "target": {
                "type": "Collection",
                "id": "room-token-123",
                "name": "HOME"
            }
        });
        let messages = channel.parse_webhook_payload(&payload);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].id, "177");
        assert_eq!(messages[0].reply_target, "room-token-123");
        assert_eq!(messages[0].sender, "user_a");
        assert_eq!(messages[0].content, "hallo, bist du da?");
        assert_eq!(messages[0].channel, "nextcloud_talk");
    }
#[test]
fn nextcloud_talk_parse_as2_skips_bot_originated() {
let channel = NextcloudTalkChannel::new(
"https://cloud.example.com".into(),
"app-token".into(),
vec!["*".into()],
);
let payload = serde_json::json!({
"type": "Create",
"actor": {
"type": "Application",
"id": "bots/jarvis",
"name": "jarvis"
},
"object": {
"type": "Note",
"id": "178",
"content": "{\"message\":\"I am the bot\",\"parameters\":[]}",
"mediaType": "text/markdown"
},
"target": {
"type": "Collection",
"id": "room-token-123",
"name": "HOME"
}
});
let messages = channel.parse_webhook_payload(&payload);
assert!(messages.is_empty());
}
#[test]
fn nextcloud_talk_parse_as2_skips_non_note_objects() {
let channel = NextcloudTalkChannel::new(
"https://cloud.example.com".into(),
"app-token".into(),
vec!["*".into()],
);
let payload = serde_json::json!({
"type": "Create",
"actor": { "type": "Person", "id": "users/user_a" },
"object": { "type": "Reaction", "id": "5" },
"target": { "type": "Collection", "id": "room-token-123" }
});
let messages = channel.parse_webhook_payload(&payload);
assert!(messages.is_empty());
}
#[test]
@@ -671,4 +671,35 @@ allowed_users = ["user1"]
assert_eq!(compose_message_content(&payload), None);
}
#[test]
fn test_send_body_uses_markdown_msg_type() {
// Verify the expected JSON shape for both group and user send paths.
// msg_type 2 with a nested markdown object is required by the QQ API
// for markdown rendering; msg_type 0 (plain text) causes markdown
// syntax to appear literally in the client.
let content = "**bold** and `code`";
let group_body = json!({
"markdown": { "content": content },
"msg_type": 2,
});
assert_eq!(group_body["msg_type"], 2);
assert_eq!(group_body["markdown"]["content"], content);
assert!(
group_body.get("content").is_none(),
"top-level 'content' must not be present"
);
let user_body = json!({
"markdown": { "content": content },
"msg_type": 2,
});
assert_eq!(user_body["msg_type"], 2);
assert_eq!(user_body["markdown"]["content"], content);
assert!(
user_body.get("content").is_none(),
"top-level 'content' must not be present"
);
}
}
@@ -2685,7 +2685,7 @@ Ensure only one `zeroclaw` process is using this bot token."
} else if let Some(m) = self.try_parse_attachment_message(update).await {
m
} else {
Box::pin(self.handle_unauthorized_message(update)).await;
continue;
};
@@ -1,11 +1,19 @@
use std::collections::HashMap;
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use reqwest::multipart::{Form, Part};
use crate::config::TranscriptionConfig;
/// Maximum upload size accepted by most Whisper-compatible APIs (25 MB).
const MAX_AUDIO_BYTES: usize = 25 * 1024 * 1024;
/// Request timeout for transcription API calls (seconds).
const TRANSCRIPTION_TIMEOUT_SECS: u64 = 120;
// ── Audio utilities ─────────────────────────────────────────────
/// Map file extension to MIME type for Whisper-compatible transcription APIs.
fn mime_for_audio(extension: &str) -> Option<&'static str> {
match extension.to_ascii_lowercase().as_str() {
@@ -31,16 +39,10 @@ fn normalize_audio_filename(file_name: &str) -> String {
}
}
/// Validate audio data and resolve MIME type from file name.
///
/// Returns `(normalized_filename, mime_type)` on success.
fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'static str)> {
if audio_data.len() > MAX_AUDIO_BYTES {
bail!(
"Audio file too large ({} bytes, max {MAX_AUDIO_BYTES})",
@@ -59,37 +61,494 @@ pub async fn transcribe_audio(
)
})?;
    Ok((normalized_name, mime))
}

// ── TranscriptionProvider trait ─────────────────────────────────

/// Trait for speech-to-text provider implementations.
#[async_trait]
pub trait TranscriptionProvider: Send + Sync {
    /// Human-readable provider name (e.g. "groq", "openai").
    fn name(&self) -> &str;

    /// Transcribe raw audio bytes. `file_name` includes the extension for
    /// format detection (e.g. "voice.ogg").
    async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String>;

    /// List of supported audio file extensions.
    fn supported_formats(&self) -> Vec<String> {
        vec![
            "flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
        ]
        .into_iter()
        .map(String::from)
        .collect()
    }
}
// ── GroqProvider ────────────────────────────────────────────────
/// Groq Whisper API provider (default, backward-compatible with existing config).
pub struct GroqProvider {
api_url: String,
model: String,
api_key: String,
language: Option<String>,
}
impl GroqProvider {
/// Build from the existing `TranscriptionConfig` fields.
///
/// Credential resolution order:
/// 1. `config.api_key`
/// 2. `GROQ_API_KEY` environment variable (backward compatibility)
pub fn from_config(config: &TranscriptionConfig) -> Result<Self> {
let api_key = config
.api_key
.as_deref()
.map(str::trim)
.filter(|v| !v.is_empty())
.map(ToOwned::to_owned)
.or_else(|| {
std::env::var("GROQ_API_KEY")
.ok()
.map(|v| v.trim().to_string())
.filter(|v| !v.is_empty())
})
.context(
"Missing transcription API key: set [transcription].api_key or GROQ_API_KEY environment variable",
)?;
Ok(Self {
api_url: config.api_url.clone(),
model: config.model.clone(),
api_key,
language: config.language.clone(),
})
}
}
#[async_trait]
impl TranscriptionProvider for GroqProvider {
fn name(&self) -> &str {
"groq"
}
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
let (normalized_name, mime) = validate_audio(audio_data, file_name)?;
let client = crate::config::build_runtime_proxy_client("transcription.groq");
let file_part = Part::bytes(audio_data.to_vec())
.file_name(normalized_name)
.mime_str(mime)?;
let mut form = Form::new()
.part("file", file_part)
.text("model", self.model.clone())
.text("response_format", "json");
if let Some(ref lang) = self.language {
form = form.text("language", lang.clone());
}
let resp = client
.post(&self.api_url)
.bearer_auth(&self.api_key)
.multipart(form)
.timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
.send()
.await
.context("Failed to send transcription request to Groq")?;
parse_whisper_response(resp).await
}
}
// ── OpenAiWhisperProvider ───────────────────────────────────────
/// OpenAI Whisper API provider.
pub struct OpenAiWhisperProvider {
api_key: String,
model: String,
}
impl OpenAiWhisperProvider {
pub fn from_config(config: &crate::config::OpenAiSttConfig) -> Result<Self> {
let api_key = config
.api_key
.as_deref()
.map(str::trim)
.filter(|v| !v.is_empty())
.map(ToOwned::to_owned)
.context("Missing OpenAI STT API key: set [transcription.openai].api_key")?;
Ok(Self {
api_key,
model: config.model.clone(),
})
}
}
#[async_trait]
impl TranscriptionProvider for OpenAiWhisperProvider {
fn name(&self) -> &str {
"openai"
}
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
let (normalized_name, mime) = validate_audio(audio_data, file_name)?;
let client = crate::config::build_runtime_proxy_client("transcription.openai");
let file_part = Part::bytes(audio_data.to_vec())
.file_name(normalized_name)
.mime_str(mime)?;
let form = Form::new()
.part("file", file_part)
.text("model", self.model.clone())
.text("response_format", "json");
let resp = client
.post("https://api.openai.com/v1/audio/transcriptions")
.bearer_auth(&self.api_key)
.multipart(form)
.timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
.send()
.await
.context("Failed to send transcription request to OpenAI")?;
parse_whisper_response(resp).await
}
}
// ── DeepgramProvider ────────────────────────────────────────────
/// Deepgram STT API provider.
pub struct DeepgramProvider {
api_key: String,
model: String,
}
impl DeepgramProvider {
pub fn from_config(config: &crate::config::DeepgramSttConfig) -> Result<Self> {
let api_key = config
.api_key
.as_deref()
.map(str::trim)
.filter(|v| !v.is_empty())
.map(ToOwned::to_owned)
.context("Missing Deepgram API key: set [transcription.deepgram].api_key")?;
Ok(Self {
api_key,
model: config.model.clone(),
})
}
}
#[async_trait]
impl TranscriptionProvider for DeepgramProvider {
fn name(&self) -> &str {
"deepgram"
}
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
let (_, mime) = validate_audio(audio_data, file_name)?;
let client = crate::config::build_runtime_proxy_client("transcription.deepgram");
let url = format!(
"https://api.deepgram.com/v1/listen?model={}&punctuate=true",
self.model
);
let resp = client
.post(&url)
.header("Authorization", format!("Token {}", self.api_key))
.header("Content-Type", mime)
.body(audio_data.to_vec())
.timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
.send()
.await
.context("Failed to send transcription request to Deepgram")?;
let status = resp.status();
let body: serde_json::Value = resp
.json()
.await
.context("Failed to parse Deepgram response")?;
if !status.is_success() {
let error_msg = body["err_msg"]
.as_str()
.or_else(|| body["error"].as_str())
.unwrap_or("unknown error");
bail!("Deepgram API error ({}): {}", status, error_msg);
}
let text = body["results"]["channels"][0]["alternatives"][0]["transcript"]
.as_str()
.context("Deepgram response missing transcript field")?
.to_string();
Ok(text)
}
}
// ── AssemblyAiProvider ──────────────────────────────────────────
/// AssemblyAI STT API provider.
pub struct AssemblyAiProvider {
api_key: String,
}
impl AssemblyAiProvider {
pub fn from_config(config: &crate::config::AssemblyAiSttConfig) -> Result<Self> {
let api_key = config
.api_key
.as_deref()
.map(str::trim)
.filter(|v| !v.is_empty())
.map(ToOwned::to_owned)
.context("Missing AssemblyAI API key: set [transcription.assemblyai].api_key")?;
Ok(Self { api_key })
}
}
#[async_trait]
impl TranscriptionProvider for AssemblyAiProvider {
fn name(&self) -> &str {
"assemblyai"
}
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
let (_, _) = validate_audio(audio_data, file_name)?;
let client = crate::config::build_runtime_proxy_client("transcription.assemblyai");
// Step 1: Upload the audio file.
let upload_resp = client
.post("https://api.assemblyai.com/v2/upload")
.header("Authorization", &self.api_key)
.header("Content-Type", "application/octet-stream")
.body(audio_data.to_vec())
.timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
.send()
.await
.context("Failed to upload audio to AssemblyAI")?;
let upload_status = upload_resp.status();
let upload_body: serde_json::Value = upload_resp
.json()
.await
.context("Failed to parse AssemblyAI upload response")?;
if !upload_status.is_success() {
let error_msg = upload_body["error"].as_str().unwrap_or("unknown error");
bail!("AssemblyAI upload error ({}): {}", upload_status, error_msg);
}
let upload_url = upload_body["upload_url"]
.as_str()
.context("AssemblyAI upload response missing 'upload_url'")?;
// Step 2: Create transcription job.
let transcript_req = serde_json::json!({
"audio_url": upload_url,
});
let create_resp = client
.post("https://api.assemblyai.com/v2/transcript")
.header("Authorization", &self.api_key)
.json(&transcript_req)
.timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
.send()
.await
.context("Failed to create AssemblyAI transcription")?;
let create_status = create_resp.status();
let create_body: serde_json::Value = create_resp
.json()
.await
.context("Failed to parse AssemblyAI create response")?;
if !create_status.is_success() {
let error_msg = create_body["error"].as_str().unwrap_or("unknown error");
bail!(
"AssemblyAI transcription error ({}): {}",
create_status,
error_msg
);
}
let transcript_id = create_body["id"]
.as_str()
.context("AssemblyAI response missing 'id'")?;
// Step 3: Poll for completion.
let poll_url = format!("https://api.assemblyai.com/v2/transcript/{transcript_id}");
let poll_interval = std::time::Duration::from_secs(3);
let poll_deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(180);
while tokio::time::Instant::now() < poll_deadline {
tokio::time::sleep(poll_interval).await;
let poll_resp = client
.get(&poll_url)
.header("Authorization", &self.api_key)
.timeout(std::time::Duration::from_secs(30))
.send()
.await
.context("Failed to poll AssemblyAI transcription")?;
let poll_status = poll_resp.status();
let poll_body: serde_json::Value = poll_resp
.json()
.await
.context("Failed to parse AssemblyAI poll response")?;
if !poll_status.is_success() {
let error_msg = poll_body["error"].as_str().unwrap_or("unknown poll error");
bail!("AssemblyAI poll error ({}): {}", poll_status, error_msg);
}
let status_str = poll_body["status"].as_str().unwrap_or("unknown");
match status_str {
"completed" => {
let text = poll_body["text"]
.as_str()
.context("AssemblyAI response missing 'text'")?
.to_string();
return Ok(text);
}
"error" => {
let error_msg = poll_body["error"]
.as_str()
.unwrap_or("unknown transcription error");
bail!("AssemblyAI transcription failed: {}", error_msg);
}
_ => {}
}
}
bail!("AssemblyAI transcription timed out after 180s")
}
}
// ── GoogleSttProvider ───────────────────────────────────────────
/// Google Cloud Speech-to-Text API provider.
pub struct GoogleSttProvider {
api_key: String,
language_code: String,
}
impl GoogleSttProvider {
pub fn from_config(config: &crate::config::GoogleSttConfig) -> Result<Self> {
let api_key = config
.api_key
.as_deref()
.map(str::trim)
.filter(|v| !v.is_empty())
.map(ToOwned::to_owned)
.context("Missing Google STT API key: set [transcription.google].api_key")?;
Ok(Self {
api_key,
language_code: config.language_code.clone(),
})
}
}
#[async_trait]
impl TranscriptionProvider for GoogleSttProvider {
fn name(&self) -> &str {
"google"
}
fn supported_formats(&self) -> Vec<String> {
// Google Cloud STT supports a subset of formats.
vec!["flac", "wav", "ogg", "opus", "mp3", "webm"]
.into_iter()
.map(String::from)
.collect()
}
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
let (normalized_name, _) = validate_audio(audio_data, file_name)?;
let client = crate::config::build_runtime_proxy_client("transcription.google");
let encoding = match normalized_name
.rsplit_once('.')
.map(|(_, e)| e.to_ascii_lowercase())
.as_deref()
{
Some("flac") => "FLAC",
Some("wav") => "LINEAR16",
Some("ogg" | "opus") => "OGG_OPUS",
Some("mp3") => "MP3",
Some("webm") => "WEBM_OPUS",
Some(ext) => bail!("Google STT does not support '.{ext}' input"),
None => bail!("Google STT requires a file extension"),
};
let audio_content =
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, audio_data);
let request_body = serde_json::json!({
"config": {
"encoding": encoding,
"languageCode": &self.language_code,
"enableAutomaticPunctuation": true,
},
"audio": {
"content": audio_content,
}
});
let url = format!(
"https://speech.googleapis.com/v1/speech:recognize?key={}",
self.api_key
);
let resp = client
.post(&url)
.json(&request_body)
.timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
.send()
.await
.context("Failed to send transcription request to Google STT")?;
let status = resp.status();
let body: serde_json::Value = resp
.json()
.await
.context("Failed to parse Google STT response")?;
if !status.is_success() {
let error_msg = body["error"]["message"].as_str().unwrap_or("unknown error");
bail!("Google STT API error ({}): {}", status, error_msg);
}
let text = body["results"][0]["alternatives"][0]["transcript"]
.as_str()
.unwrap_or("")
.to_string();
Ok(text)
}
}
// ── Shared response parsing ─────────────────────────────────────
/// Parse a standard Whisper-compatible JSON response (`{ "text": "..." }`).
async fn parse_whisper_response(resp: reqwest::Response) -> Result<String> {
let status = resp.status();
let body: serde_json::Value = resp
.json()
@@ -109,6 +568,128 @@ pub async fn transcribe_audio(
Ok(text)
}
// ── TranscriptionManager ────────────────────────────────────────
/// Manages multiple STT providers and routes transcription requests.
pub struct TranscriptionManager {
providers: HashMap<String, Box<dyn TranscriptionProvider>>,
default_provider: String,
}
impl TranscriptionManager {
/// Build a `TranscriptionManager` from config.
///
/// Always attempts to register the Groq provider from existing config fields.
/// Additional providers are registered when their config sections are present.
///
    /// Providers with missing API keys are silently skipped; the error
    /// surfaces at transcribe time, so callers targeting a different default
    /// provider are not blocked.
pub fn new(config: &TranscriptionConfig) -> Result<Self> {
let mut providers: HashMap<String, Box<dyn TranscriptionProvider>> = HashMap::new();
if let Ok(groq) = GroqProvider::from_config(config) {
providers.insert("groq".to_string(), Box::new(groq));
}
if let Some(ref openai_cfg) = config.openai {
if let Ok(p) = OpenAiWhisperProvider::from_config(openai_cfg) {
providers.insert("openai".to_string(), Box::new(p));
}
}
if let Some(ref deepgram_cfg) = config.deepgram {
if let Ok(p) = DeepgramProvider::from_config(deepgram_cfg) {
providers.insert("deepgram".to_string(), Box::new(p));
}
}
if let Some(ref assemblyai_cfg) = config.assemblyai {
if let Ok(p) = AssemblyAiProvider::from_config(assemblyai_cfg) {
providers.insert("assemblyai".to_string(), Box::new(p));
}
}
if let Some(ref google_cfg) = config.google {
if let Ok(p) = GoogleSttProvider::from_config(google_cfg) {
providers.insert("google".to_string(), Box::new(p));
}
}
let default_provider = config.default_provider.clone();
if config.enabled && !providers.contains_key(&default_provider) {
let available: Vec<&str> = providers.keys().map(|k| k.as_str()).collect();
bail!(
"Default transcription provider '{}' is not configured. Available: {available:?}",
default_provider
);
}
Ok(Self {
providers,
default_provider,
})
}
/// Transcribe audio using the default provider.
pub async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
self.transcribe_with_provider(audio_data, file_name, &self.default_provider)
.await
}
/// Transcribe audio using a specific named provider.
pub async fn transcribe_with_provider(
&self,
audio_data: &[u8],
file_name: &str,
provider: &str,
) -> Result<String> {
let p = self.providers.get(provider).ok_or_else(|| {
let available: Vec<&str> = self.providers.keys().map(|k| k.as_str()).collect();
anyhow::anyhow!(
"Transcription provider '{provider}' not configured. Available: {available:?}"
)
})?;
p.transcribe(audio_data, file_name).await
}
/// List registered provider names.
pub fn available_providers(&self) -> Vec<&str> {
self.providers.keys().map(|k| k.as_str()).collect()
}
}
// ── Backward-compatible convenience function ────────────────────
/// Transcribe audio bytes via a Whisper-compatible transcription API.
///
/// Returns the transcribed text on success.
///
/// This is the backward-compatible entry point that preserves the original
/// function signature. It uses the Groq provider directly, matching the
/// original single-provider behavior.
///
/// Credential resolution order:
/// 1. `config.transcription.api_key`
/// 2. `GROQ_API_KEY` environment variable (backward compatibility)
///
/// The caller is responsible for enforcing duration limits *before* downloading
/// the file; this function enforces the byte-size cap.
pub async fn transcribe_audio(
audio_data: Vec<u8>,
file_name: &str,
config: &TranscriptionConfig,
) -> Result<String> {
// Validate audio before resolving credentials so that size/format errors
// are reported before missing-key errors (preserves original behavior).
validate_audio(&audio_data, file_name)?;
let groq = GroqProvider::from_config(config)?;
groq.transcribe(&audio_data, file_name).await
}
#[cfg(test)]
mod tests {
use super::*;
@@ -129,7 +710,7 @@ mod tests {
#[tokio::test]
async fn rejects_missing_api_key() {
// Ensure fallback env key is absent for this test.
std::env::remove_var("GROQ_API_KEY");
let data = vec![0u8; 100];
@@ -139,11 +720,29 @@ mod tests {
.await
.unwrap_err();
assert!(
err.to_string().contains("transcription API key"),
"expected missing-key error, got: {err}"
);
}
#[tokio::test]
async fn uses_config_api_key_without_groq_env() {
std::env::remove_var("GROQ_API_KEY");
let data = vec![0u8; 100];
let mut config = TranscriptionConfig::default();
config.api_key = Some("transcription-key".to_string());
        // Keep an invalid extension so the call fails before any network request.
let err = transcribe_audio(data, "recording.aac", &config)
.await
.unwrap_err();
assert!(
err.to_string().contains("Unsupported audio format"),
"expected unsupported-format error, got: {err}"
);
}
#[test]
fn mime_for_audio_maps_accepted_formats() {
let cases = [
@@ -219,4 +818,128 @@ mod tests {
"error should mention the rejected extension, got: {msg}"
);
}
// ── TranscriptionManager tests ──────────────────────────────
#[test]
fn manager_creation_with_default_config() {
std::env::remove_var("GROQ_API_KEY");
let config = TranscriptionConfig::default();
let manager = TranscriptionManager::new(&config).unwrap();
assert_eq!(manager.default_provider, "groq");
// Groq won't be registered without a key.
assert!(manager.providers.is_empty());
}
#[test]
fn manager_registers_groq_with_key() {
std::env::remove_var("GROQ_API_KEY");
let mut config = TranscriptionConfig::default();
config.api_key = Some("test-groq-key".to_string());
let manager = TranscriptionManager::new(&config).unwrap();
assert!(manager.providers.contains_key("groq"));
assert_eq!(manager.providers["groq"].name(), "groq");
}
#[test]
fn manager_registers_multiple_providers() {
std::env::remove_var("GROQ_API_KEY");
let mut config = TranscriptionConfig::default();
config.api_key = Some("test-groq-key".to_string());
config.openai = Some(crate::config::OpenAiSttConfig {
api_key: Some("test-openai-key".to_string()),
model: "whisper-1".to_string(),
});
config.deepgram = Some(crate::config::DeepgramSttConfig {
api_key: Some("test-deepgram-key".to_string()),
model: "nova-2".to_string(),
});
let manager = TranscriptionManager::new(&config).unwrap();
assert!(manager.providers.contains_key("groq"));
assert!(manager.providers.contains_key("openai"));
assert!(manager.providers.contains_key("deepgram"));
assert_eq!(manager.available_providers().len(), 3);
}
#[tokio::test]
async fn manager_rejects_unconfigured_provider() {
std::env::remove_var("GROQ_API_KEY");
let mut config = TranscriptionConfig::default();
config.api_key = Some("test-groq-key".to_string());
let manager = TranscriptionManager::new(&config).unwrap();
let err = manager
.transcribe_with_provider(&[0u8; 100], "test.ogg", "nonexistent")
.await
.unwrap_err();
assert!(
err.to_string().contains("not configured"),
"expected not-configured error, got: {err}"
);
}
#[test]
fn manager_default_provider_from_config() {
std::env::remove_var("GROQ_API_KEY");
let mut config = TranscriptionConfig::default();
config.default_provider = "openai".to_string();
config.openai = Some(crate::config::OpenAiSttConfig {
api_key: Some("test-openai-key".to_string()),
model: "whisper-1".to_string(),
});
let manager = TranscriptionManager::new(&config).unwrap();
assert_eq!(manager.default_provider, "openai");
}
#[test]
fn validate_audio_rejects_oversized() {
let big = vec![0u8; MAX_AUDIO_BYTES + 1];
let err = validate_audio(&big, "test.ogg").unwrap_err();
assert!(err.to_string().contains("too large"));
}
#[test]
fn validate_audio_rejects_unsupported_format() {
let data = vec![0u8; 100];
let err = validate_audio(&data, "test.aac").unwrap_err();
assert!(err.to_string().contains("Unsupported audio format"));
}
#[test]
fn validate_audio_accepts_supported_format() {
let data = vec![0u8; 100];
let (name, mime) = validate_audio(&data, "test.ogg").unwrap();
assert_eq!(name, "test.ogg");
assert_eq!(mime, "audio/ogg");
}
#[test]
fn validate_audio_normalizes_oga() {
let data = vec![0u8; 100];
let (name, mime) = validate_audio(&data, "voice.oga").unwrap();
assert_eq!(name, "voice.ogg");
assert_eq!(mime, "audio/ogg");
}
#[test]
fn backward_compat_config_defaults_unchanged() {
let config = TranscriptionConfig::default();
assert!(!config.enabled);
assert!(config.api_key.is_none());
assert!(config.api_url.contains("groq.com"));
assert_eq!(config.model, "whisper-large-v3-turbo");
assert_eq!(config.default_provider, "groq");
assert!(config.openai.is_none());
assert!(config.deepgram.is_none());
assert!(config.assemblyai.is_none());
assert!(config.google.is_none());
}
}
@@ -6,19 +6,20 @@ pub mod workspace;
pub use schema::{
apply_runtime_proxy_to_builder, build_runtime_proxy_client,
build_runtime_proxy_client_with_timeouts, runtime_proxy_config, set_runtime_proxy_config,
AgentConfig, AssemblyAiSttConfig, AuditConfig, AutonomyConfig, BackupConfig,
BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
ClassificationRule, CloudOpsConfig, ComposioConfig, Config, ConversationalAiConfig, CostConfig,
CronConfig, DataRetentionConfig, DeepgramSttConfig, DelegateAgentConfig, DiscordConfig,
DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig,
FeishuConfig, GatewayConfig, GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceConfig,
HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig, McpServerConfig,
McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, MultimodalConfig,
NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig,
OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod,
PeripheralBoardConfig, PeripheralsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope,
QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
SecurityOpsConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig,
StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy,
TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig,
@@ -264,6 +264,10 @@ pub struct Config {
#[serde(default)]
pub project_intel: ProjectIntelConfig,
/// Google Workspace CLI (`gws`) tool configuration (`[google_workspace]`).
#[serde(default)]
pub google_workspace: GoogleWorkspaceConfig,
/// Proxy configuration for outbound HTTP/HTTPS/SOCKS5 traffic (`[proxy]`).
#[serde(default)]
pub proxy: ProxyConfig,
@@ -604,19 +608,46 @@ fn default_transcription_max_duration_secs() -> u64 {
120
}
fn default_transcription_provider() -> String {
"groq".into()
}
fn default_openai_stt_model() -> String {
"whisper-1".into()
}
fn default_deepgram_stt_model() -> String {
"nova-2".into()
}
fn default_google_stt_language_code() -> String {
"en-US".into()
}
/// Voice transcription configuration with multi-provider support.
///
/// The top-level `api_url`, `model`, and `api_key` fields remain for backward
/// compatibility with existing Groq-based configurations.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TranscriptionConfig {
/// Enable voice transcription for channels that support it.
#[serde(default)]
pub enabled: bool,
/// Default STT provider: "groq", "openai", "deepgram", "assemblyai", "google".
#[serde(default = "default_transcription_provider")]
pub default_provider: String,
/// API key used for transcription requests (Groq provider).
///
/// If unset, runtime falls back to `GROQ_API_KEY` for backward compatibility.
#[serde(default)]
pub api_key: Option<String>,
/// Whisper API endpoint URL (Groq provider).
#[serde(default = "default_transcription_api_url")]
pub api_url: String,
/// Whisper model name (Groq provider).
#[serde(default = "default_transcription_model")]
pub model: String,
/// Optional language hint (ISO-639-1, e.g. "en", "ru") for Groq provider.
#[serde(default)]
pub language: Option<String>,
/// Optional initial prompt to bias transcription toward expected vocabulary
@@ -627,17 +658,35 @@ pub struct TranscriptionConfig {
/// Maximum voice duration in seconds (messages longer than this are skipped).
#[serde(default = "default_transcription_max_duration_secs")]
pub max_duration_secs: u64,
/// OpenAI Whisper STT provider configuration.
#[serde(default)]
pub openai: Option<OpenAiSttConfig>,
/// Deepgram STT provider configuration.
#[serde(default)]
pub deepgram: Option<DeepgramSttConfig>,
/// AssemblyAI STT provider configuration.
#[serde(default)]
pub assemblyai: Option<AssemblyAiSttConfig>,
/// Google Cloud Speech-to-Text provider configuration.
#[serde(default)]
pub google: Option<GoogleSttConfig>,
}
impl Default for TranscriptionConfig {
fn default() -> Self {
Self {
enabled: false,
default_provider: default_transcription_provider(),
api_key: None,
api_url: default_transcription_api_url(),
model: default_transcription_model(),
language: None,
initial_prompt: None,
max_duration_secs: default_transcription_max_duration_secs(),
openai: None,
deepgram: None,
assemblyai: None,
google: None,
}
}
}
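The `default_provider` string above is what the runtime routes on. A minimal, self-contained sketch of that routing (the `SttProvider` enum and `route_provider` function are illustrative names, not the crate's actual `TranscriptionManager` API):

```rust
// Illustrative string-to-provider routing sketch; names are hypothetical,
// not the crate's real API. Mirrors the validated provider set.
#[derive(Debug, PartialEq)]
enum SttProvider {
    Groq,
    OpenAi,
    Deepgram,
    AssemblyAi,
    Google,
}

fn route_provider(name: &str) -> Option<SttProvider> {
    // Trim first, matching how the config validation normalizes the value.
    match name.trim() {
        "groq" => Some(SttProvider::Groq),
        "openai" => Some(SttProvider::OpenAi),
        "deepgram" => Some(SttProvider::Deepgram),
        "assemblyai" => Some(SttProvider::AssemblyAi),
        "google" => Some(SttProvider::Google),
        _ => None,
    }
}

fn main() {
    assert_eq!(route_provider("groq"), Some(SttProvider::Groq));
    assert_eq!(route_provider(" openai "), Some(SttProvider::OpenAi));
    assert_eq!(route_provider("whisperx"), None);
    println!("routing ok");
}
```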
@@ -938,6 +987,47 @@ pub struct ToolFilterGroup {
pub keywords: Vec<String>,
}
/// OpenAI Whisper STT provider configuration (`[transcription.openai]`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OpenAiSttConfig {
/// OpenAI API key for Whisper transcription.
#[serde(default)]
pub api_key: Option<String>,
/// Whisper model name (default: "whisper-1").
#[serde(default = "default_openai_stt_model")]
pub model: String,
}
/// Deepgram STT provider configuration (`[transcription.deepgram]`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DeepgramSttConfig {
/// Deepgram API key.
#[serde(default)]
pub api_key: Option<String>,
/// Deepgram model name (default: "nova-2").
#[serde(default = "default_deepgram_stt_model")]
pub model: String,
}
/// AssemblyAI STT provider configuration (`[transcription.assemblyai]`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AssemblyAiSttConfig {
/// AssemblyAI API key.
#[serde(default)]
pub api_key: Option<String>,
}
/// Google Cloud Speech-to-Text provider configuration (`[transcription.google]`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct GoogleSttConfig {
/// Google Cloud API key.
#[serde(default)]
pub api_key: Option<String>,
/// BCP-47 language code (default: "en-US").
#[serde(default = "default_google_stt_language_code")]
pub language_code: String,
}
/// Agent orchestration configuration (`[agent]` section).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct AgentConfig {
@@ -1994,6 +2084,94 @@ impl Default for DataRetentionConfig {
}
}
// ── Google Workspace ─────────────────────────────────────────────
/// Google Workspace CLI (`gws`) tool configuration (`[google_workspace]` section).
///
/// ## Defaults
/// - `enabled`: `false` (tool is not registered unless explicitly opted-in).
/// - `allowed_services`: empty vector, which grants access to the full default
/// service set: `drive`, `sheets`, `gmail`, `calendar`, `docs`, `slides`,
/// `tasks`, `people`, `chat`, `classroom`, `forms`, `keep`, `meet`, `events`.
/// - `credentials_path`: `None` (uses default `gws` credential discovery).
/// - `default_account`: `None` (uses the `gws` active account).
/// - `rate_limit_per_minute`: `60`.
/// - `timeout_secs`: `30`.
/// - `audit_log`: `false`.
///
/// ## Compatibility
/// Configs that omit the `[google_workspace]` section entirely are treated as
/// `GoogleWorkspaceConfig::default()` (disabled, all defaults allowed). Adding
/// the section is purely opt-in and does not affect other config sections.
///
/// ## Rollback / Migration
/// To revert, remove the `[google_workspace]` section from the config file (or
/// set `enabled = false`). No data migration is required; the tool simply stops
/// being registered.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct GoogleWorkspaceConfig {
/// Enable the `google_workspace` tool. Default: `false`.
#[serde(default)]
pub enabled: bool,
/// Restrict which Google Workspace services the agent can access.
///
/// When empty (the default), the full default service set is allowed (see
/// struct-level docs). When non-empty, only the listed service IDs are
/// permitted. Each entry must be non-empty, lowercase alphanumeric with
/// optional underscores/hyphens, and unique.
#[serde(default)]
pub allowed_services: Vec<String>,
/// Path to service account JSON or OAuth client credentials file.
///
/// When `None`, the tool relies on the default `gws` credential discovery
/// (`gws auth login`). Set this to point at a service-account key or an
/// OAuth client-secrets JSON for headless / CI environments.
#[serde(default)]
pub credentials_path: Option<String>,
/// Default Google account email to pass to `gws --account`.
///
/// When `None`, the currently active `gws` account is used.
#[serde(default)]
pub default_account: Option<String>,
/// Maximum number of `gws` API calls allowed per minute. Default: `60`.
#[serde(default = "default_gws_rate_limit")]
pub rate_limit_per_minute: u32,
/// Command execution timeout in seconds. Default: `30`.
#[serde(default = "default_gws_timeout_secs")]
pub timeout_secs: u64,
/// Enable audit logging of every `gws` invocation (service, resource,
/// method, timestamp). Default: `false`.
#[serde(default)]
pub audit_log: bool,
}
fn default_gws_rate_limit() -> u32 {
60
}
fn default_gws_timeout_secs() -> u64 {
30
}
impl Default for GoogleWorkspaceConfig {
fn default() -> Self {
Self {
enabled: false,
allowed_services: Vec::new(),
credentials_path: None,
default_account: None,
rate_limit_per_minute: default_gws_rate_limit(),
timeout_secs: default_gws_timeout_secs(),
audit_log: false,
}
}
}
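The `rate_limit_per_minute` default of 60 implies some per-minute call accounting at runtime. A minimal fixed-window sketch of that idea, using only std (illustrative only; the actual tool may use a different scheme):

```rust
use std::time::{Duration, Instant};

// Fixed-window, per-minute rate limiter sketch for `rate_limit_per_minute`.
// Illustrative; names and algorithm are assumptions, not the tool's code.
struct MinuteRateLimiter {
    limit: u32,
    window_start: Instant,
    count: u32,
}

impl MinuteRateLimiter {
    fn new(limit: u32) -> Self {
        Self { limit, window_start: Instant::now(), count: 0 }
    }

    /// Returns true if a call is allowed in the current window and records it.
    fn try_acquire(&mut self) -> bool {
        if self.window_start.elapsed() >= Duration::from_secs(60) {
            // Roll over to a fresh one-minute window.
            self.window_start = Instant::now();
            self.count = 0;
        }
        if self.count < self.limit {
            self.count += 1;
            true
        } else {
            false
        }
    }
}

fn main() {
    let mut limiter = MinuteRateLimiter::new(2);
    assert!(limiter.try_acquire());
    assert!(limiter.try_acquire());
    assert!(!limiter.try_acquire()); // third call in the same window is rejected
    println!("rate limiter ok");
}
```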
// ── Proxy ───────────────────────────────────────────────────────
/// Proxy application scope — determines which outbound traffic uses the proxy.
@@ -3920,6 +4098,10 @@ pub struct SlackConfig {
/// cancels the in-flight request and starts a fresh response with preserved history.
#[serde(default)]
pub interrupt_on_new_message: bool,
/// When true, only respond to messages that @-mention the bot in groups.
/// Direct messages remain allowed.
#[serde(default)]
pub mention_only: bool,
}
impl ChannelConfig for SlackConfig {
@@ -5260,6 +5442,7 @@ impl Default for Config {
web_fetch: WebFetchConfig::default(),
web_search: WebSearchConfig::default(),
project_intel: ProjectIntelConfig::default(),
google_workspace: GoogleWorkspaceConfig::default(),
proxy: ProxyConfig::default(),
identity: IdentityConfig::default(),
cost: CostConfig::default(),
@@ -5794,6 +5977,41 @@ impl Config {
decrypt_optional_secret(&store, &mut google.api_key, "config.tts.google.api_key")?;
}
// Decrypt nested STT provider API keys
decrypt_optional_secret(
&store,
&mut config.transcription.api_key,
"config.transcription.api_key",
)?;
if let Some(ref mut openai) = config.transcription.openai {
decrypt_optional_secret(
&store,
&mut openai.api_key,
"config.transcription.openai.api_key",
)?;
}
if let Some(ref mut deepgram) = config.transcription.deepgram {
decrypt_optional_secret(
&store,
&mut deepgram.api_key,
"config.transcription.deepgram.api_key",
)?;
}
if let Some(ref mut assemblyai) = config.transcription.assemblyai {
decrypt_optional_secret(
&store,
&mut assemblyai.api_key,
"config.transcription.assemblyai.api_key",
)?;
}
if let Some(ref mut google) = config.transcription.google {
decrypt_optional_secret(
&store,
&mut google.api_key,
"config.transcription.google.api_key",
)?;
}
#[cfg(feature = "channel-nostr")]
if let Some(ref mut ns) = config.channels_config.nostr {
decrypt_secret(
@@ -6417,6 +6635,28 @@ impl Config {
validate_mcp_config(&self.mcp)?;
}
// Google Workspace allowed_services validation
let mut seen_gws_services = std::collections::HashSet::new();
for (i, service) in self.google_workspace.allowed_services.iter().enumerate() {
let normalized = service.trim();
if normalized.is_empty() {
anyhow::bail!("google_workspace.allowed_services[{i}] must not be empty");
}
if !normalized
.chars()
.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_' || c == '-')
{
anyhow::bail!(
"google_workspace.allowed_services[{i}] contains invalid characters: {normalized}"
);
}
if !seen_gws_services.insert(normalized.to_string()) {
anyhow::bail!(
"google_workspace.allowed_services contains duplicate entry: {normalized}"
);
}
}
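The loop above enforces three rules per entry: non-empty after trimming, a restricted character set, and uniqueness. Extracted into a standalone function for clarity (the function name is illustrative; the logic mirrors the validation above):

```rust
use std::collections::HashSet;

// Standalone sketch of the allowed_services rules: non-empty, lowercase
// ASCII alphanumeric plus '_'/'-', and no duplicates after trimming.
fn validate_services(services: &[&str]) -> Result<(), String> {
    let mut seen = HashSet::new();
    for (i, service) in services.iter().enumerate() {
        let normalized = service.trim();
        if normalized.is_empty() {
            return Err(format!("allowed_services[{i}] must not be empty"));
        }
        if !normalized
            .chars()
            .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_' || c == '-')
        {
            return Err(format!("allowed_services[{i}] contains invalid characters"));
        }
        if !seen.insert(normalized.to_string()) {
            return Err(format!("duplicate entry: {normalized}"));
        }
    }
    Ok(())
}

fn main() {
    assert!(validate_services(&["drive", "gmail"]).is_ok());
    assert!(validate_services(&["Drive"]).is_err()); // uppercase rejected
    assert!(validate_services(&["drive", "drive"]).is_err()); // duplicate rejected
    println!("validation ok");
}
```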
// Project intelligence
if self.project_intel.enabled {
let lang = &self.project_intel.default_language;
@@ -6470,6 +6710,19 @@ impl Config {
anyhow::bail!("security.nevis: {msg}");
}
// Transcription
{
let dp = self.transcription.default_provider.trim();
match dp {
"groq" | "openai" | "deepgram" | "assemblyai" | "google" => {}
other => {
anyhow::bail!(
"transcription.default_provider must be one of: groq, openai, deepgram, assemblyai, google (got '{other}')"
);
}
}
}
Ok(())
}
@@ -6879,6 +7132,41 @@ impl Config {
encrypt_optional_secret(&store, &mut google.api_key, "config.tts.google.api_key")?;
}
// Encrypt nested STT provider API keys
encrypt_optional_secret(
&store,
&mut config_to_save.transcription.api_key,
"config.transcription.api_key",
)?;
if let Some(ref mut openai) = config_to_save.transcription.openai {
encrypt_optional_secret(
&store,
&mut openai.api_key,
"config.transcription.openai.api_key",
)?;
}
if let Some(ref mut deepgram) = config_to_save.transcription.deepgram {
encrypt_optional_secret(
&store,
&mut deepgram.api_key,
"config.transcription.deepgram.api_key",
)?;
}
if let Some(ref mut assemblyai) = config_to_save.transcription.assemblyai {
encrypt_optional_secret(
&store,
&mut assemblyai.api_key,
"config.transcription.assemblyai.api_key",
)?;
}
if let Some(ref mut google) = config_to_save.transcription.google {
encrypt_optional_secret(
&store,
&mut google.api_key,
"config.transcription.google.api_key",
)?;
}
#[cfg(feature = "channel-nostr")]
if let Some(ref mut ns) = config_to_save.channels_config.nostr {
encrypt_secret(
@@ -7564,6 +7852,7 @@ default_temperature = 0.7
web_fetch: WebFetchConfig::default(),
web_search: WebSearchConfig::default(),
project_intel: ProjectIntelConfig::default(),
google_workspace: GoogleWorkspaceConfig::default(),
proxy: ProxyConfig::default(),
agent: AgentConfig::default(),
identity: IdentityConfig::default(),
@@ -7867,6 +8156,7 @@ tool_dispatcher = "xml"
web_fetch: WebFetchConfig::default(),
web_search: WebSearchConfig::default(),
project_intel: ProjectIntelConfig::default(),
google_workspace: GoogleWorkspaceConfig::default(),
proxy: ProxyConfig::default(),
agent: AgentConfig::default(),
identity: IdentityConfig::default(),
@@ -8326,6 +8616,7 @@ allowed_users = ["@ops:matrix.org"]
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
assert!(parsed.allowed_users.is_empty());
assert!(!parsed.interrupt_on_new_message);
assert!(!parsed.mention_only);
}
#[test]
@@ -8334,6 +8625,15 @@ allowed_users = ["@ops:matrix.org"]
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
assert_eq!(parsed.allowed_users, vec!["U111"]);
assert!(!parsed.interrupt_on_new_message);
assert!(!parsed.mention_only);
}
#[test]
    fn slack_config_deserializes_with_mention_only() {
let json = r#"{"bot_token":"xoxb-tok","mention_only":true}"#;
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
assert!(parsed.mention_only);
assert!(!parsed.interrupt_on_new_message);
}
#[test]
@@ -8341,6 +8641,7 @@ allowed_users = ["@ops:matrix.org"]
let json = r#"{"bot_token":"xoxb-tok","interrupt_on_new_message":true}"#;
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
assert!(parsed.interrupt_on_new_message);
assert!(!parsed.mention_only);
}
#[test]
@@ -8363,6 +8664,7 @@ channel_id = "C123"
let parsed: SlackConfig = toml::from_str(toml_str).unwrap();
assert!(parsed.allowed_users.is_empty());
assert!(!parsed.interrupt_on_new_message);
assert!(!parsed.mention_only);
assert_eq!(parsed.channel_id.as_deref(), Some("C123"));
}
+74 -2
View File
@@ -509,6 +509,18 @@ pub fn all_integrations() -> Vec<IntegrationEntry> {
},
},
// ── Productivity ────────────────────────────────────────
IntegrationEntry {
name: "Google Workspace",
description: "Drive, Gmail, Calendar, Sheets, Docs via gws CLI",
category: IntegrationCategory::Productivity,
status_fn: |c| {
if c.google_workspace.enabled {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "GitHub",
description: "Code, issues, PRs",
@@ -606,7 +618,13 @@ pub fn all_integrations() -> Vec<IntegrationEntry> {
name: "Browser",
description: "Chrome/Chromium control",
category: IntegrationCategory::ToolsAutomation,
status_fn: |_| IntegrationStatus::Available,
status_fn: |c| {
if c.browser.enabled {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Shell",
@@ -624,7 +642,13 @@ pub fn all_integrations() -> Vec<IntegrationEntry> {
name: "Cron",
description: "Scheduled tasks",
category: IntegrationCategory::ToolsAutomation,
status_fn: |_| IntegrationStatus::Available,
status_fn: |c| {
if c.cron.enabled {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Voice",
@@ -917,6 +941,54 @@ mod tests {
));
}
#[test]
fn cron_active_when_enabled() {
let mut config = Config::default();
config.cron.enabled = true;
let entries = all_integrations();
let cron = entries.iter().find(|e| e.name == "Cron").unwrap();
assert!(matches!(
(cron.status_fn)(&config),
IntegrationStatus::Active
));
}
#[test]
fn cron_available_when_disabled() {
let mut config = Config::default();
config.cron.enabled = false;
let entries = all_integrations();
let cron = entries.iter().find(|e| e.name == "Cron").unwrap();
assert!(matches!(
(cron.status_fn)(&config),
IntegrationStatus::Available
));
}
#[test]
fn browser_active_when_enabled() {
let mut config = Config::default();
config.browser.enabled = true;
let entries = all_integrations();
let browser = entries.iter().find(|e| e.name == "Browser").unwrap();
assert!(matches!(
(browser.status_fn)(&config),
IntegrationStatus::Active
));
}
#[test]
fn browser_available_when_disabled() {
let mut config = Config::default();
config.browser.enabled = false;
let entries = all_integrations();
let browser = entries.iter().find(|e| e.name == "Browser").unwrap();
assert!(matches!(
(browser.status_fn)(&config),
IntegrationStatus::Available
));
}
#[test]
fn shell_and_filesystem_always_active() {
let config = Config::default();
+8 -4
View File
@@ -806,13 +806,13 @@ async fn main() -> Result<()> {
} else if is_tty && !has_provider_flags {
Box::pin(onboard::run_wizard(force)).await
} else {
onboard::run_quick_setup(
Box::pin(onboard::run_quick_setup(
api_key.as_deref(),
provider.as_deref(),
model.as_deref(),
memory.as_deref(),
force,
)
))
.await
}?;
@@ -1191,7 +1191,7 @@ async fn main() -> Result<()> {
Commands::Channel { channel_command } => match channel_command {
ChannelCommands::Start => Box::pin(channels::start_channels(config)).await,
ChannelCommands::Doctor => Box::pin(channels::doctor_channels(config)).await,
other => channels::handle_command(other, &config).await,
other => Box::pin(channels::handle_command(other, &config)).await,
},
Commands::Integrations {
@@ -1215,7 +1215,11 @@ async fn main() -> Result<()> {
}
Commands::Peripheral { peripheral_command } => {
peripherals::handle_command(peripheral_command.clone(), &config).await
Box::pin(peripherals::handle_command(
peripheral_command.clone(),
&config,
))
.await
}
Commands::Config { config_command } => match config_command {
+15 -12
View File
@@ -173,6 +173,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
web_fetch: crate::config::WebFetchConfig::default(),
web_search: crate::config::WebSearchConfig::default(),
project_intel: crate::config::ProjectIntelConfig::default(),
google_workspace: crate::config::GoogleWorkspaceConfig::default(),
proxy: crate::config::ProxyConfig::default(),
identity: crate::config::IdentityConfig::default(),
cost: crate::config::CostConfig::default(),
@@ -424,14 +425,14 @@ pub async fn run_quick_setup(
.map(|u| u.home_dir().to_path_buf())
.context("Could not find home directory")?;
run_quick_setup_with_home(
Box::pin(run_quick_setup_with_home(
credential_override,
provider,
model_override,
memory_backend,
force,
&home,
)
))
.await
}
@@ -543,6 +544,7 @@ async fn run_quick_setup_with_home(
web_fetch: crate::config::WebFetchConfig::default(),
web_search: crate::config::WebSearchConfig::default(),
project_intel: crate::config::ProjectIntelConfig::default(),
google_workspace: crate::config::GoogleWorkspaceConfig::default(),
proxy: crate::config::ProxyConfig::default(),
identity: crate::config::IdentityConfig::default(),
cost: crate::config::CostConfig::default(),
@@ -3902,6 +3904,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
},
allowed_users,
interrupt_on_new_message: false,
mention_only: false,
});
}
ChannelMenuChoice::IMessage => {
@@ -5911,14 +5914,14 @@ mod tests {
let _config_env = EnvVarGuard::unset("ZEROCLAW_CONFIG_DIR");
let tmp = TempDir::new().unwrap();
let config = run_quick_setup_with_home(
let config = Box::pin(run_quick_setup_with_home(
Some("sk-issue946"),
Some("openrouter"),
Some("custom-model-946"),
Some("sqlite"),
false,
tmp.path(),
)
))
.await
.unwrap();
@@ -5938,14 +5941,14 @@ mod tests {
let _config_env = EnvVarGuard::unset("ZEROCLAW_CONFIG_DIR");
let tmp = TempDir::new().unwrap();
let config = run_quick_setup_with_home(
let config = Box::pin(run_quick_setup_with_home(
Some("sk-issue946"),
Some("anthropic"),
None,
Some("sqlite"),
false,
tmp.path(),
)
))
.await
.unwrap();
@@ -5968,14 +5971,14 @@ mod tests {
.await
.unwrap();
let err = run_quick_setup_with_home(
let err = Box::pin(run_quick_setup_with_home(
Some("sk-existing"),
Some("openrouter"),
Some("custom-model"),
Some("sqlite"),
false,
tmp.path(),
)
))
.await
.expect_err("quick setup should refuse overwrite without --force");
@@ -6001,14 +6004,14 @@ mod tests {
.await
.unwrap();
let config = run_quick_setup_with_home(
let config = Box::pin(run_quick_setup_with_home(
Some("sk-force"),
Some("openrouter"),
Some("custom-model-fresh"),
Some("sqlite"),
true,
tmp.path(),
)
))
.await
.expect("quick setup should overwrite existing config with --force");
@@ -6035,14 +6038,14 @@ mod tests {
);
let _config_env = EnvVarGuard::unset("ZEROCLAW_CONFIG_DIR");
let config = run_quick_setup_with_home(
let config = Box::pin(run_quick_setup_with_home(
Some("sk-env"),
Some("openrouter"),
Some("model-env"),
Some("sqlite"),
false,
tmp.path(),
)
))
.await
.expect("quick setup should honor ZEROCLAW_WORKSPACE");
+330
View File
@@ -0,0 +1,330 @@
//! Claude Code headless CLI provider.
//!
//! Integrates with the Claude Code CLI, spawning the `claude` binary
//! as a subprocess for each inference request. This allows using Claude's AI
//! models without an interactive UI session.
//!
//! # Usage
//!
//! The `claude` binary must be available in `PATH`, or its location must be
//! set via the `CLAUDE_CODE_PATH` environment variable.
//!
//! Claude Code is invoked as:
//! ```text
//! claude --print -
//! ```
//! with prompt content written to stdin.
//!
//! # Limitations
//!
//! - **Conversation history**: Only the system prompt (if present) and the last
//! user message are forwarded. Full multi-turn history is not preserved because
//! the CLI accepts a single prompt per invocation.
//! - **System prompt**: The system prompt is prepended to the user message with a
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
//! - **Temperature**: The CLI does not expose a temperature parameter.
//! Only default values are accepted; custom values return an explicit error.
//!
//! # Authentication
//!
//! Authentication is handled by Claude Code itself (its own credential store).
//! No explicit API key is required by this provider.
//!
//! # Environment variables
//!
//! - `CLAUDE_CODE_PATH` — override the path to the `claude` binary (default: `"claude"`)
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
use async_trait::async_trait;
use std::path::PathBuf;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use tokio::time::{timeout, Duration};
/// Environment variable for overriding the path to the `claude` binary.
pub const CLAUDE_CODE_PATH_ENV: &str = "CLAUDE_CODE_PATH";
/// Default `claude` binary name (resolved via `PATH`).
const DEFAULT_CLAUDE_CODE_BINARY: &str = "claude";
/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// Claude Code requests are bounded to avoid hung subprocesses.
const CLAUDE_CODE_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Avoid leaking oversized stderr payloads.
const MAX_CLAUDE_CODE_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls; allow only baseline defaults.
const CLAUDE_CODE_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
const TEMP_EPSILON: f64 = 1e-9;
/// Provider that invokes the Claude Code CLI as a subprocess.
///
/// Each inference request spawns a fresh `claude` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
pub struct ClaudeCodeProvider {
/// Path to the `claude` binary.
binary_path: PathBuf,
}
impl ClaudeCodeProvider {
/// Create a new `ClaudeCodeProvider`.
///
/// The binary path is resolved from `CLAUDE_CODE_PATH` env var if set,
/// otherwise defaults to `"claude"` (found via `PATH`).
pub fn new() -> Self {
let binary_path = std::env::var(CLAUDE_CODE_PATH_ENV)
.ok()
.filter(|path| !path.trim().is_empty())
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from(DEFAULT_CLAUDE_CODE_BINARY));
Self { binary_path }
}
/// Returns true if the model argument should be forwarded to the CLI.
fn should_forward_model(model: &str) -> bool {
let trimmed = model.trim();
!trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
}
fn supports_temperature(temperature: f64) -> bool {
CLAUDE_CODE_SUPPORTED_TEMPERATURES
.iter()
.any(|v| (temperature - v).abs() < TEMP_EPSILON)
}
fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
if !temperature.is_finite() {
anyhow::bail!("Claude Code provider received non-finite temperature value");
}
if !Self::supports_temperature(temperature) {
anyhow::bail!(
"temperature unsupported by Claude Code CLI: {temperature}. \
Supported values: 0.7 or 1.0"
);
}
Ok(())
}
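The epsilon compare in `supports_temperature` matters because 0.7 has no exact f64 representation, so values that arrive via arithmetic can differ in the last bits. A self-contained illustration (constants mirror the provider's; the function name is illustrative):

```rust
// Why an epsilon compare: exact `==` on f64 would reject values that are
// "0.7" only after rounding. Constants mirror the provider's; the function
// name here is illustrative.
const SUPPORTED: [f64; 2] = [0.7, 1.0];
const EPS: f64 = 1e-9;

fn is_supported(t: f64) -> bool {
    t.is_finite() && SUPPORTED.iter().any(|v| (t - v).abs() < EPS)
}

fn main() {
    assert!(is_supported(0.7));
    assert!(is_supported(0.1 + 0.6)); // accumulated rounding still matches
    assert!(!is_supported(0.2));
    assert!(!is_supported(f64::NAN)); // non-finite values are rejected
    println!("epsilon ok");
}
```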
fn redact_stderr(stderr: &[u8]) -> String {
let text = String::from_utf8_lossy(stderr);
let trimmed = text.trim();
if trimmed.is_empty() {
return String::new();
}
if trimmed.chars().count() <= MAX_CLAUDE_CODE_STDERR_CHARS {
return trimmed.to_string();
}
let clipped: String = trimmed.chars().take(MAX_CLAUDE_CODE_STDERR_CHARS).collect();
format!("{clipped}...")
}
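`redact_stderr` clips by characters rather than bytes, so a multi-byte UTF-8 sequence is never split mid-codepoint. The same rule in isolation (function name is illustrative):

```rust
// Sketch of the stderr-clipping rule: truncate by characters, not bytes,
// so multi-byte UTF-8 sequences are never split. Name is illustrative.
fn clip_chars(text: &str, max_chars: usize) -> String {
    let trimmed = text.trim();
    if trimmed.chars().count() <= max_chars {
        return trimmed.to_string();
    }
    let clipped: String = trimmed.chars().take(max_chars).collect();
    format!("{clipped}...")
}

fn main() {
    assert_eq!(clip_chars("  short  ", 10), "short");
    assert_eq!(clip_chars("abcdef", 3), "abc...");
    // Multi-byte characters count as one each; no byte-boundary panic.
    assert_eq!(clip_chars("héllo", 2), "hé...");
    println!("clip ok");
}
```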
/// Invoke the claude binary with the given prompt and optional model.
/// Returns the trimmed stdout output as the assistant response.
async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
let mut cmd = Command::new(&self.binary_path);
cmd.arg("--print");
if Self::should_forward_model(model) {
cmd.arg("--model").arg(model);
}
// Read prompt from stdin to avoid exposing sensitive content in process args.
cmd.arg("-");
cmd.kill_on_drop(true);
cmd.stdin(std::process::Stdio::piped());
cmd.stdout(std::process::Stdio::piped());
cmd.stderr(std::process::Stdio::piped());
let mut child = cmd.spawn().map_err(|err| {
anyhow::anyhow!(
"Failed to spawn Claude Code binary at {}: {err}. \
Ensure `claude` is installed and in PATH, or set CLAUDE_CODE_PATH.",
self.binary_path.display()
)
})?;
if let Some(mut stdin) = child.stdin.take() {
stdin.write_all(message.as_bytes()).await.map_err(|err| {
anyhow::anyhow!("Failed to write prompt to Claude Code stdin: {err}")
})?;
stdin.shutdown().await.map_err(|err| {
anyhow::anyhow!("Failed to finalize Claude Code stdin stream: {err}")
})?;
}
let output = timeout(CLAUDE_CODE_REQUEST_TIMEOUT, child.wait_with_output())
.await
.map_err(|_| {
anyhow::anyhow!(
"Claude Code request timed out after {:?} (binary: {})",
CLAUDE_CODE_REQUEST_TIMEOUT,
self.binary_path.display()
)
})?
.map_err(|err| anyhow::anyhow!("Claude Code process failed: {err}"))?;
if !output.status.success() {
let code = output.status.code().unwrap_or(-1);
let stderr_excerpt = Self::redact_stderr(&output.stderr);
let stderr_note = if stderr_excerpt.is_empty() {
String::new()
} else {
format!(" Stderr: {stderr_excerpt}")
};
anyhow::bail!(
"Claude Code exited with non-zero status {code}. \
Check that Claude Code is authenticated and the CLI is supported.{stderr_note}"
);
}
let text = String::from_utf8(output.stdout)
.map_err(|err| anyhow::anyhow!("Claude Code produced non-UTF-8 output: {err}"))?;
Ok(text.trim().to_string())
}
}
impl Default for ClaudeCodeProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Provider for ClaudeCodeProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
Self::validate_temperature(temperature)?;
let full_message = match system_prompt {
Some(system) if !system.is_empty() => {
format!("{system}\n\n{message}")
}
_ => message.to_string(),
};
self.invoke_cli(&full_message, model).await
}
async fn chat(
&self,
request: ChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
let text = self
.chat_with_history(request.messages, model, temperature)
.await?;
Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
usage: Some(TokenUsage::default()),
reasoning_content: None,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Mutex, OnceLock};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.expect("env lock poisoned")
}
#[test]
fn new_uses_env_override() {
let _guard = env_lock();
let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
std::env::set_var(CLAUDE_CODE_PATH_ENV, "/usr/local/bin/claude");
let provider = ClaudeCodeProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/claude"));
match orig {
Some(v) => std::env::set_var(CLAUDE_CODE_PATH_ENV, v),
None => std::env::remove_var(CLAUDE_CODE_PATH_ENV),
}
}
#[test]
fn new_defaults_to_claude() {
let _guard = env_lock();
let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
std::env::remove_var(CLAUDE_CODE_PATH_ENV);
let provider = ClaudeCodeProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("claude"));
if let Some(v) = orig {
std::env::set_var(CLAUDE_CODE_PATH_ENV, v);
}
}
#[test]
fn new_ignores_blank_env_override() {
let _guard = env_lock();
let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
std::env::set_var(CLAUDE_CODE_PATH_ENV, " ");
let provider = ClaudeCodeProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("claude"));
match orig {
Some(v) => std::env::set_var(CLAUDE_CODE_PATH_ENV, v),
None => std::env::remove_var(CLAUDE_CODE_PATH_ENV),
}
}
#[test]
fn should_forward_model_standard() {
assert!(ClaudeCodeProvider::should_forward_model(
"claude-sonnet-4-20250514"
));
assert!(ClaudeCodeProvider::should_forward_model(
"claude-3.5-sonnet"
));
}
#[test]
fn should_not_forward_default_model() {
assert!(!ClaudeCodeProvider::should_forward_model(
DEFAULT_MODEL_MARKER
));
assert!(!ClaudeCodeProvider::should_forward_model(""));
assert!(!ClaudeCodeProvider::should_forward_model(" "));
}
#[test]
fn validate_temperature_allows_defaults() {
assert!(ClaudeCodeProvider::validate_temperature(0.7).is_ok());
assert!(ClaudeCodeProvider::validate_temperature(1.0).is_ok());
}
#[test]
fn validate_temperature_rejects_custom_value() {
let err = ClaudeCodeProvider::validate_temperature(0.2).unwrap_err();
assert!(err
.to_string()
.contains("temperature unsupported by Claude Code CLI"));
}
#[tokio::test]
async fn invoke_missing_binary_returns_error() {
let provider = ClaudeCodeProvider {
binary_path: PathBuf::from("/nonexistent/path/to/claude"),
};
let result = provider.invoke_cli("hello", "default").await;
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("Failed to spawn Claude Code binary"),
"unexpected error message: {msg}"
);
}
}
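The spawn/pipe/collect pattern `invoke_cli` uses (write the prompt to stdin, close the pipe, then collect stdout) can be sketched with blocking `std::process`, using `cat` as a stand-in for the `claude` binary. This sketch assumes a POSIX environment and deliberately omits the timeout and stderr handling the async implementation adds:

```rust
use std::io::Write;
use std::process::{Command, Stdio};

// Minimal stdin-piping sketch of the provider's subprocess pattern, with
// `cat` standing in for `claude` (assumes a POSIX environment). The real
// provider adds a timeout, kill_on_drop, and stderr redaction on top.
fn pipe_through(binary: &str, input: &str) -> std::io::Result<String> {
    let mut child = Command::new(binary)
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .spawn()?;
    // Write the prompt to stdin, then drop the handle to close the pipe
    // so the child sees EOF and can exit.
    child
        .stdin
        .take()
        .expect("stdin was piped")
        .write_all(input.as_bytes())?;
    let output = child.wait_with_output()?;
    Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
}

fn main() {
    let echoed = pipe_through("cat", "hello prompt").expect("cat should be available");
    assert_eq!(echoed, "hello prompt");
    println!("subprocess ok");
}
```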
+326
View File
@@ -0,0 +1,326 @@
//! Gemini CLI subprocess provider.
//!
//! Integrates with the Gemini CLI, spawning the `gemini` binary
//! as a subprocess for each inference request. This allows using Google's
//! Gemini models via the CLI without an interactive UI session.
//!
//! # Usage
//!
//! The `gemini` binary must be available in `PATH`, or its location must be
//! set via the `GEMINI_CLI_PATH` environment variable.
//!
//! Gemini CLI is invoked as:
//! ```text
//! gemini --print -
//! ```
//! with prompt content written to stdin.
//!
//! # Limitations
//!
//! - **Conversation history**: Only the system prompt (if present) and the last
//! user message are forwarded. Full multi-turn history is not preserved because
//! the CLI accepts a single prompt per invocation.
//! - **System prompt**: The system prompt is prepended to the user message with a
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
//! - **Temperature**: The CLI does not expose a temperature parameter.
//! Only default values are accepted; custom values return an explicit error.
//!
//! # Authentication
//!
//! Authentication is handled by the Gemini CLI itself (its own credential store).
//! No explicit API key is required by this provider.
//!
//! # Environment variables
//!
//! - `GEMINI_CLI_PATH` — override the path to the `gemini` binary (default: `"gemini"`)
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
use async_trait::async_trait;
use std::path::PathBuf;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use tokio::time::{timeout, Duration};
/// Environment variable for overriding the path to the `gemini` binary.
pub const GEMINI_CLI_PATH_ENV: &str = "GEMINI_CLI_PATH";
/// Default `gemini` binary name (resolved via `PATH`).
const DEFAULT_GEMINI_CLI_BINARY: &str = "gemini";
/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// Gemini CLI requests are bounded to avoid hung subprocesses.
const GEMINI_CLI_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Avoid leaking oversized stderr payloads.
const MAX_GEMINI_CLI_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls; allow only baseline defaults.
const GEMINI_CLI_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
const TEMP_EPSILON: f64 = 1e-9;
/// Provider that invokes the Gemini CLI as a subprocess.
///
/// Each inference request spawns a fresh `gemini` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
pub struct GeminiCliProvider {
/// Path to the `gemini` binary.
binary_path: PathBuf,
}
impl GeminiCliProvider {
/// Create a new `GeminiCliProvider`.
///
/// The binary path is resolved from `GEMINI_CLI_PATH` env var if set,
/// otherwise defaults to `"gemini"` (found via `PATH`).
pub fn new() -> Self {
let binary_path = std::env::var(GEMINI_CLI_PATH_ENV)
.ok()
.filter(|path| !path.trim().is_empty())
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from(DEFAULT_GEMINI_CLI_BINARY));
Self { binary_path }
}
/// Returns true if the model argument should be forwarded to the CLI.
fn should_forward_model(model: &str) -> bool {
let trimmed = model.trim();
!trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
}
fn supports_temperature(temperature: f64) -> bool {
GEMINI_CLI_SUPPORTED_TEMPERATURES
.iter()
.any(|v| (temperature - v).abs() < TEMP_EPSILON)
}
fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
if !temperature.is_finite() {
anyhow::bail!("Gemini CLI provider received non-finite temperature value");
}
if !Self::supports_temperature(temperature) {
anyhow::bail!(
"temperature unsupported by Gemini CLI: {temperature}. \
Supported values: 0.7 or 1.0"
);
}
Ok(())
}
fn redact_stderr(stderr: &[u8]) -> String {
let text = String::from_utf8_lossy(stderr);
let trimmed = text.trim();
if trimmed.is_empty() {
return String::new();
}
if trimmed.chars().count() <= MAX_GEMINI_CLI_STDERR_CHARS {
return trimmed.to_string();
}
let clipped: String = trimmed.chars().take(MAX_GEMINI_CLI_STDERR_CHARS).collect();
format!("{clipped}...")
}
/// Invoke the gemini binary with the given prompt and optional model.
/// Returns the trimmed stdout output as the assistant response.
async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
let mut cmd = Command::new(&self.binary_path);
cmd.arg("--print");
if Self::should_forward_model(model) {
cmd.arg("--model").arg(model);
}
// Read prompt from stdin to avoid exposing sensitive content in process args.
cmd.arg("-");
cmd.kill_on_drop(true);
cmd.stdin(std::process::Stdio::piped());
cmd.stdout(std::process::Stdio::piped());
cmd.stderr(std::process::Stdio::piped());
let mut child = cmd.spawn().map_err(|err| {
anyhow::anyhow!(
"Failed to spawn Gemini CLI binary at {}: {err}. \
Ensure `gemini` is installed and in PATH, or set GEMINI_CLI_PATH.",
self.binary_path.display()
)
})?;
if let Some(mut stdin) = child.stdin.take() {
stdin.write_all(message.as_bytes()).await.map_err(|err| {
anyhow::anyhow!("Failed to write prompt to Gemini CLI stdin: {err}")
})?;
stdin.shutdown().await.map_err(|err| {
anyhow::anyhow!("Failed to finalize Gemini CLI stdin stream: {err}")
})?;
}
let output = timeout(GEMINI_CLI_REQUEST_TIMEOUT, child.wait_with_output())
.await
.map_err(|_| {
anyhow::anyhow!(
"Gemini CLI request timed out after {:?} (binary: {})",
GEMINI_CLI_REQUEST_TIMEOUT,
self.binary_path.display()
)
})?
.map_err(|err| anyhow::anyhow!("Gemini CLI process failed: {err}"))?;
if !output.status.success() {
let code = output.status.code().unwrap_or(-1);
let stderr_excerpt = Self::redact_stderr(&output.stderr);
let stderr_note = if stderr_excerpt.is_empty() {
String::new()
} else {
format!(" Stderr: {stderr_excerpt}")
};
anyhow::bail!(
"Gemini CLI exited with non-zero status {code}. \
Check that the Gemini CLI is authenticated and the requested model is supported.{stderr_note}"
);
}
let text = String::from_utf8(output.stdout)
.map_err(|err| anyhow::anyhow!("Gemini CLI produced non-UTF-8 output: {err}"))?;
Ok(text.trim().to_string())
}
}
impl Default for GeminiCliProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Provider for GeminiCliProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
Self::validate_temperature(temperature)?;
let full_message = match system_prompt {
Some(system) if !system.is_empty() => {
format!("{system}\n\n{message}")
}
_ => message.to_string(),
};
self.invoke_cli(&full_message, model).await
}
async fn chat(
&self,
request: ChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
let text = self
.chat_with_history(request.messages, model, temperature)
.await?;
Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
usage: Some(TokenUsage::default()),
reasoning_content: None,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Mutex, OnceLock};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.expect("env lock poisoned")
}
#[test]
fn new_uses_env_override() {
let _guard = env_lock();
let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
std::env::set_var(GEMINI_CLI_PATH_ENV, "/usr/local/bin/gemini");
let provider = GeminiCliProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/gemini"));
match orig {
Some(v) => std::env::set_var(GEMINI_CLI_PATH_ENV, v),
None => std::env::remove_var(GEMINI_CLI_PATH_ENV),
}
}
#[test]
fn new_defaults_to_gemini() {
let _guard = env_lock();
let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
std::env::remove_var(GEMINI_CLI_PATH_ENV);
let provider = GeminiCliProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("gemini"));
if let Some(v) = orig {
std::env::set_var(GEMINI_CLI_PATH_ENV, v);
}
}
#[test]
fn new_ignores_blank_env_override() {
let _guard = env_lock();
let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
std::env::set_var(GEMINI_CLI_PATH_ENV, " ");
let provider = GeminiCliProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("gemini"));
match orig {
Some(v) => std::env::set_var(GEMINI_CLI_PATH_ENV, v),
None => std::env::remove_var(GEMINI_CLI_PATH_ENV),
}
}
#[test]
fn should_forward_model_standard() {
assert!(GeminiCliProvider::should_forward_model("gemini-2.5-pro"));
assert!(GeminiCliProvider::should_forward_model("gemini-2.5-flash"));
}
#[test]
fn should_not_forward_default_model() {
assert!(!GeminiCliProvider::should_forward_model(
DEFAULT_MODEL_MARKER
));
assert!(!GeminiCliProvider::should_forward_model(""));
assert!(!GeminiCliProvider::should_forward_model(" "));
}
#[test]
fn validate_temperature_allows_defaults() {
assert!(GeminiCliProvider::validate_temperature(0.7).is_ok());
assert!(GeminiCliProvider::validate_temperature(1.0).is_ok());
}
#[test]
fn validate_temperature_rejects_custom_value() {
let err = GeminiCliProvider::validate_temperature(0.2).unwrap_err();
assert!(err
.to_string()
.contains("temperature unsupported by Gemini CLI"));
}
#[tokio::test]
async fn invoke_missing_binary_returns_error() {
let provider = GeminiCliProvider {
binary_path: PathBuf::from("/nonexistent/path/to/gemini"),
};
let result = provider.invoke_cli("hello", "default").await;
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("Failed to spawn Gemini CLI binary"),
"unexpected error message: {msg}"
);
}
}
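Both CLI providers gate temperature with the same epsilon comparison; direct `==` on `f64` would be fragile because 0.7 has no exact binary representation. A minimal standalone sketch of that check (constants copied from the file above, function names illustrative):

```rust
// Allowlist check mirroring supports_temperature()/validate_temperature():
// a value passes only if it is finite and within epsilon of a supported default.
const SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
const TEMP_EPSILON: f64 = 1e-9;

fn supports_temperature(t: f64) -> bool {
    t.is_finite()
        && SUPPORTED_TEMPERATURES
            .iter()
            .any(|v| (t - v).abs() < TEMP_EPSILON)
}

fn main() {
    assert!(supports_temperature(0.7));
    assert!(supports_temperature(1.0));
    assert!(!supports_temperature(0.2)); // custom sampling values are rejected
    assert!(!supports_temperature(f64::NAN)); // non-finite values are rejected
    println!("ok");
}
```

Note that `0.7 + 0.1 - 0.1` style round-trips would still pass here, which is exactly why the provider compares within a tolerance rather than exactly.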
@@ -0,0 +1,326 @@
//! KiloCLI subprocess provider.
//!
//! Integrates with the KiloCLI tool, spawning the `kilo` binary
//! as a subprocess for each inference request. This allows using KiloCLI's AI
//! models without an interactive UI session.
//!
//! # Usage
//!
//! The `kilo` binary must be available in `PATH`, or its location must be
//! set via the `KILO_CLI_PATH` environment variable.
//!
//! KiloCLI is invoked as:
//! ```text
//! kilo --print -
//! ```
//! with prompt content written to stdin.
//!
//! # Limitations
//!
//! - **Conversation history**: Only the system prompt (if present) and the last
//! user message are forwarded. Full multi-turn history is not preserved because
//! the CLI accepts a single prompt per invocation.
//! - **System prompt**: The system prompt is prepended to the user message with a
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
//! - **Temperature**: The CLI does not expose a temperature parameter.
//! Only default values are accepted; custom values return an explicit error.
//!
//! # Authentication
//!
//! Authentication is handled by KiloCLI itself (its own credential store).
//! No explicit API key is required by this provider.
//!
//! # Environment variables
//!
//! - `KILO_CLI_PATH` — override the path to the `kilo` binary (default: `"kilo"`)
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
use async_trait::async_trait;
use std::path::PathBuf;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use tokio::time::{timeout, Duration};
/// Environment variable for overriding the path to the `kilo` binary.
pub const KILO_CLI_PATH_ENV: &str = "KILO_CLI_PATH";
/// Default `kilo` binary name (resolved via `PATH`).
const DEFAULT_KILO_CLI_BINARY: &str = "kilo";
/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// KiloCLI requests are bounded to avoid hung subprocesses.
const KILO_CLI_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Avoid leaking oversized stderr payloads.
const MAX_KILO_CLI_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls; allow only baseline defaults.
const KILO_CLI_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
const TEMP_EPSILON: f64 = 1e-9;
/// Provider that invokes the KiloCLI as a subprocess.
///
/// Each inference request spawns a fresh `kilo` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
pub struct KiloCliProvider {
/// Path to the `kilo` binary.
binary_path: PathBuf,
}
impl KiloCliProvider {
/// Create a new `KiloCliProvider`.
///
/// The binary path is resolved from `KILO_CLI_PATH` env var if set,
/// otherwise defaults to `"kilo"` (found via `PATH`).
pub fn new() -> Self {
let binary_path = std::env::var(KILO_CLI_PATH_ENV)
.ok()
.filter(|path| !path.trim().is_empty())
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from(DEFAULT_KILO_CLI_BINARY));
Self { binary_path }
}
/// Returns true if the model argument should be forwarded to the CLI.
fn should_forward_model(model: &str) -> bool {
let trimmed = model.trim();
!trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
}
fn supports_temperature(temperature: f64) -> bool {
KILO_CLI_SUPPORTED_TEMPERATURES
.iter()
.any(|v| (temperature - v).abs() < TEMP_EPSILON)
}
fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
if !temperature.is_finite() {
anyhow::bail!("KiloCLI provider received non-finite temperature value");
}
if !Self::supports_temperature(temperature) {
anyhow::bail!(
"temperature unsupported by KiloCLI: {temperature}. \
Supported values: 0.7 or 1.0"
);
}
Ok(())
}
fn redact_stderr(stderr: &[u8]) -> String {
let text = String::from_utf8_lossy(stderr);
let trimmed = text.trim();
if trimmed.is_empty() {
return String::new();
}
if trimmed.chars().count() <= MAX_KILO_CLI_STDERR_CHARS {
return trimmed.to_string();
}
let clipped: String = trimmed.chars().take(MAX_KILO_CLI_STDERR_CHARS).collect();
format!("{clipped}...")
}
/// Invoke the kilo binary with the given prompt and optional model.
/// Returns the trimmed stdout output as the assistant response.
async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
let mut cmd = Command::new(&self.binary_path);
cmd.arg("--print");
if Self::should_forward_model(model) {
cmd.arg("--model").arg(model);
}
// Read prompt from stdin to avoid exposing sensitive content in process args.
cmd.arg("-");
cmd.kill_on_drop(true);
cmd.stdin(std::process::Stdio::piped());
cmd.stdout(std::process::Stdio::piped());
cmd.stderr(std::process::Stdio::piped());
let mut child = cmd.spawn().map_err(|err| {
anyhow::anyhow!(
"Failed to spawn KiloCLI binary at {}: {err}. \
Ensure `kilo` is installed and in PATH, or set KILO_CLI_PATH.",
self.binary_path.display()
)
})?;
if let Some(mut stdin) = child.stdin.take() {
stdin
.write_all(message.as_bytes())
.await
.map_err(|err| anyhow::anyhow!("Failed to write prompt to KiloCLI stdin: {err}"))?;
stdin
.shutdown()
.await
.map_err(|err| anyhow::anyhow!("Failed to finalize KiloCLI stdin stream: {err}"))?;
}
let output = timeout(KILO_CLI_REQUEST_TIMEOUT, child.wait_with_output())
.await
.map_err(|_| {
anyhow::anyhow!(
"KiloCLI request timed out after {:?} (binary: {})",
KILO_CLI_REQUEST_TIMEOUT,
self.binary_path.display()
)
})?
.map_err(|err| anyhow::anyhow!("KiloCLI process failed: {err}"))?;
if !output.status.success() {
let code = output.status.code().unwrap_or(-1);
let stderr_excerpt = Self::redact_stderr(&output.stderr);
let stderr_note = if stderr_excerpt.is_empty() {
String::new()
} else {
format!(" Stderr: {stderr_excerpt}")
};
anyhow::bail!(
"KiloCLI exited with non-zero status {code}. \
Check that KiloCLI is authenticated and the requested model is supported.{stderr_note}"
);
}
let text = String::from_utf8(output.stdout)
.map_err(|err| anyhow::anyhow!("KiloCLI produced non-UTF-8 output: {err}"))?;
Ok(text.trim().to_string())
}
}
impl Default for KiloCliProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Provider for KiloCliProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
Self::validate_temperature(temperature)?;
let full_message = match system_prompt {
Some(system) if !system.is_empty() => {
format!("{system}\n\n{message}")
}
_ => message.to_string(),
};
self.invoke_cli(&full_message, model).await
}
async fn chat(
&self,
request: ChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
let text = self
.chat_with_history(request.messages, model, temperature)
.await?;
Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
usage: Some(TokenUsage::default()),
reasoning_content: None,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Mutex, OnceLock};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.expect("env lock poisoned")
}
#[test]
fn new_uses_env_override() {
let _guard = env_lock();
let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
std::env::set_var(KILO_CLI_PATH_ENV, "/usr/local/bin/kilo");
let provider = KiloCliProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/kilo"));
match orig {
Some(v) => std::env::set_var(KILO_CLI_PATH_ENV, v),
None => std::env::remove_var(KILO_CLI_PATH_ENV),
}
}
#[test]
fn new_defaults_to_kilo() {
let _guard = env_lock();
let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
std::env::remove_var(KILO_CLI_PATH_ENV);
let provider = KiloCliProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("kilo"));
if let Some(v) = orig {
std::env::set_var(KILO_CLI_PATH_ENV, v);
}
}
#[test]
fn new_ignores_blank_env_override() {
let _guard = env_lock();
let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
std::env::set_var(KILO_CLI_PATH_ENV, " ");
let provider = KiloCliProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("kilo"));
match orig {
Some(v) => std::env::set_var(KILO_CLI_PATH_ENV, v),
None => std::env::remove_var(KILO_CLI_PATH_ENV),
}
}
#[test]
fn should_forward_model_standard() {
assert!(KiloCliProvider::should_forward_model("some-model"));
assert!(KiloCliProvider::should_forward_model("gpt-4o"));
}
#[test]
fn should_not_forward_default_model() {
assert!(!KiloCliProvider::should_forward_model(DEFAULT_MODEL_MARKER));
assert!(!KiloCliProvider::should_forward_model(""));
assert!(!KiloCliProvider::should_forward_model(" "));
}
#[test]
fn validate_temperature_allows_defaults() {
assert!(KiloCliProvider::validate_temperature(0.7).is_ok());
assert!(KiloCliProvider::validate_temperature(1.0).is_ok());
}
#[test]
fn validate_temperature_rejects_custom_value() {
let err = KiloCliProvider::validate_temperature(0.2).unwrap_err();
assert!(err
.to_string()
.contains("temperature unsupported by KiloCLI"));
}
#[tokio::test]
async fn invoke_missing_binary_returns_error() {
let provider = KiloCliProvider {
binary_path: PathBuf::from("/nonexistent/path/to/kilo"),
};
let result = provider.invoke_cli("hello", "default").await;
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("Failed to spawn KiloCLI binary"),
"unexpected error message: {msg}"
);
}
}
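The `new()` constructors in both provider files resolve the binary path identically: a non-blank environment override wins, otherwise the bare binary name is kept so the OS resolves it via `PATH`. A standalone sketch of that shared pattern (the env var name here is illustrative, not one the crate defines):

```rust
use std::env;
use std::path::PathBuf;

// Mirrors the env-override resolution in GeminiCliProvider::new() and
// KiloCliProvider::new(): filter out blank/whitespace-only overrides so an
// accidentally empty variable does not shadow the PATH-based default.
fn resolve_binary(env_var: &str, default: &str) -> PathBuf {
    env::var(env_var)
        .ok()
        .filter(|path| !path.trim().is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from(default))
}

fn main() {
    // Unset variable falls back to the default name.
    env::remove_var("EXAMPLE_CLI_PATH");
    assert_eq!(resolve_binary("EXAMPLE_CLI_PATH", "kilo"), PathBuf::from("kilo"));

    // A real override is used verbatim.
    env::set_var("EXAMPLE_CLI_PATH", "/opt/bin/kilo");
    assert_eq!(
        resolve_binary("EXAMPLE_CLI_PATH", "kilo"),
        PathBuf::from("/opt/bin/kilo")
    );

    // A whitespace-only override is treated as unset.
    env::set_var("EXAMPLE_CLI_PATH", "   ");
    assert_eq!(resolve_binary("EXAMPLE_CLI_PATH", "kilo"), PathBuf::from("kilo"));
    println!("ok");
}
```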
@@ -19,9 +19,12 @@
pub mod anthropic;
pub mod azure_openai;
pub mod bedrock;
pub mod claude_code;
pub mod compatible;
pub mod copilot;
pub mod gemini;
pub mod gemini_cli;
pub mod kilocli;
pub mod ollama;
pub mod openai;
pub mod openai_codex;
@@ -1251,6 +1254,9 @@ fn create_provider_with_url_and_options(
"Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
))),
"copilot" | "github-copilot" => Ok(Box::new(copilot::CopilotProvider::new(key))),
"claude-code" => Ok(Box::new(claude_code::ClaudeCodeProvider::new())),
"gemini-cli" => Ok(Box::new(gemini_cli::GeminiCliProvider::new())),
"kilocli" | "kilo" => Ok(Box::new(kilocli::KiloCliProvider::new())),
"lmstudio" | "lm-studio" => {
let lm_studio_key = key
.map(str::trim)
@@ -1902,6 +1908,24 @@ pub fn list_providers() -> Vec<ProviderInfo> {
aliases: &["github-copilot"],
local: false,
},
ProviderInfo {
name: "claude-code",
display_name: "Claude Code (CLI)",
aliases: &[],
local: true,
},
ProviderInfo {
name: "gemini-cli",
display_name: "Gemini CLI",
aliases: &[],
local: true,
},
ProviderInfo {
name: "kilocli",
display_name: "KiloCLI",
aliases: &["kilo"],
local: true,
},
ProviderInfo {
name: "lmstudio",
display_name: "LM Studio",
@@ -2720,6 +2744,22 @@ mod tests {
assert!(create_provider("github-copilot", Some("key")).is_ok());
}
#[test]
fn factory_claude_code() {
assert!(create_provider("claude-code", None).is_ok());
}
#[test]
fn factory_gemini_cli() {
assert!(create_provider("gemini-cli", None).is_ok());
}
#[test]
fn factory_kilocli() {
assert!(create_provider("kilocli", None).is_ok());
assert!(create_provider("kilo", None).is_ok());
}
#[test]
fn factory_nvidia() {
assert!(create_provider("nvidia", Some("nvapi-test")).is_ok());
@@ -3053,6 +3093,9 @@ mod tests {
"perplexity",
"cohere",
"copilot",
"claude-code",
"gemini-cli",
"kilocli",
"nvidia",
"astrai",
"ovhcloud",
@@ -12,6 +12,8 @@ pub enum CliCategory {
Container,
Build,
Cloud,
AiAgent,
Productivity,
}
impl std::fmt::Display for CliCategory {
@@ -23,6 +25,8 @@ impl std::fmt::Display for CliCategory {
Self::Container => write!(f, "Container"),
Self::Build => write!(f, "Build"),
Self::Cloud => write!(f, "Cloud"),
Self::AiAgent => write!(f, "AI Agent"),
Self::Productivity => write!(f, "Productivity"),
}
}
}
@@ -104,6 +108,26 @@ const KNOWN_CLIS: &[KnownCli] = &[
version_args: &["--version"],
category: CliCategory::Language,
},
KnownCli {
name: "claude",
version_args: &["--version"],
category: CliCategory::AiAgent,
},
KnownCli {
name: "gemini",
version_args: &["--version"],
category: CliCategory::AiAgent,
},
KnownCli {
name: "kilo",
version_args: &["--version"],
category: CliCategory::AiAgent,
},
KnownCli {
name: "gws",
version_args: &["--version"],
category: CliCategory::Productivity,
},
];
/// Discover available CLI tools on the system.
@@ -235,5 +259,7 @@ mod tests {
assert_eq!(CliCategory::Container.to_string(), "Container");
assert_eq!(CliCategory::Build.to_string(), "Build");
assert_eq!(CliCategory::Cloud.to_string(), "Cloud");
assert_eq!(CliCategory::AiAgent.to_string(), "AI Agent");
assert_eq!(CliCategory::Productivity.to_string(), "Productivity");
}
}
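The discovery pass has to locate each `KnownCli` (`claude`, `gemini`, `kilo`, `gws`) on the system before probing it. A rough standalone sketch of the `PATH`-walk half of that lookup; the actual discovery code is not shown in this diff, and the real pass additionally runs each tool's `version_args` to confirm it responds:

```rust
use std::env;
use std::path::PathBuf;

// Walk each directory in PATH and return the first one containing the
// named binary as a regular file. Illustrative only: does not check the
// executable bit or Windows PATHEXT suffixes.
fn find_in_path(binary: &str) -> Option<PathBuf> {
    env::var_os("PATH").and_then(|paths| {
        env::split_paths(&paths)
            .map(|dir| dir.join(binary))
            .find(|candidate| candidate.is_file())
    })
}

fn main() {
    // A name that certainly does not exist is reported as absent.
    assert!(find_in_path("definitely-missing-binary-xyz").is_none());
    // Whether real tools are found depends on the host system.
    println!("sh -> {:?}", find_in_path("sh"));
}
```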
@@ -65,27 +65,97 @@ impl Tool for CronAddTool {
json!({
"type": "object",
"properties": {
"name": { "type": "string" },
"schedule": {
"type": "object",
"description": "Schedule object: {kind:'cron',expr,tz?} | {kind:'at',at} | {kind:'every',every_ms}"
"name": {
"type": "string",
"description": "Optional human-readable name for the job"
},
// NOTE: oneOf is correct for OpenAI-compatible APIs (including OpenRouter).
// Gemini does not support oneOf in tool schemas; if Gemini native tool calling
// is ever wired up, SchemaCleanr::clean_for_gemini must be applied before
// tool specs are sent. See src/tools/schema.rs.
"schedule": {
"description": "When to run the job. Exactly one of three forms must be used.",
"oneOf": [
{
"type": "object",
"description": "Cron expression schedule (repeating). Example: {\"kind\":\"cron\",\"expr\":\"0 9 * * 1-5\",\"tz\":\"America/New_York\"}",
"properties": {
"kind": { "type": "string", "enum": ["cron"] },
"expr": { "type": "string", "description": "Standard 5-field cron expression, e.g. '*/5 * * * *'" },
"tz": { "type": "string", "description": "Optional IANA timezone name, e.g. 'America/New_York'. Defaults to UTC." }
},
"required": ["kind", "expr"]
},
{
"type": "object",
"description": "One-shot schedule at a specific UTC datetime. Example: {\"kind\":\"at\",\"at\":\"2025-12-31T23:59:00Z\"}",
"properties": {
"kind": { "type": "string", "enum": ["at"] },
"at": { "type": "string", "description": "ISO 8601 UTC datetime string, e.g. '2025-12-31T23:59:00Z'" }
},
"required": ["kind", "at"]
},
{
"type": "object",
"description": "Repeating interval schedule in milliseconds. Example: {\"kind\":\"every\",\"every_ms\":3600000} runs every hour.",
"properties": {
"kind": { "type": "string", "enum": ["every"] },
"every_ms": { "type": "integer", "description": "Interval in milliseconds, e.g. 3600000 for every hour" }
},
"required": ["kind", "every_ms"]
}
]
},
"job_type": {
"type": "string",
"enum": ["shell", "agent"],
"description": "Type of job: 'shell' runs a command, 'agent' runs the AI agent with a prompt"
},
"command": {
"type": "string",
"description": "Shell command to run (required when job_type is 'shell')"
},
"prompt": {
"type": "string",
"description": "Agent prompt to run on schedule (required when job_type is 'agent')"
},
"session_target": {
"type": "string",
"enum": ["isolated", "main"],
"description": "Agent session context: 'isolated' starts a fresh session each run, 'main' reuses the primary session"
},
"model": {
"type": "string",
"description": "Optional model override for agent jobs, e.g. 'x-ai/grok-4-1-fast'"
},
"job_type": { "type": "string", "enum": ["shell", "agent"] },
"command": { "type": "string" },
"prompt": { "type": "string" },
"session_target": { "type": "string", "enum": ["isolated", "main"] },
"model": { "type": "string" },
"delivery": {
"type": "object",
"description": "Delivery config to send job output to a channel. Example: {\"mode\":\"announce\",\"channel\":\"discord\",\"to\":\"<channel_id>\"}",
"description": "Optional delivery config to send job output to a channel after each run. When provided, all three of mode, channel, and to are expected.",
"properties": {
"mode": { "type": "string", "enum": ["none", "announce"], "description": "Set to 'announce' to deliver output to a channel" },
"channel": { "type": "string", "enum": ["telegram", "discord", "slack", "mattermost", "matrix"], "description": "Channel type to deliver to" },
"to": { "type": "string", "description": "Target: Discord channel ID, Telegram chat ID, Slack channel, etc." },
"best_effort": { "type": "boolean", "description": "If true, delivery failure does not fail the job" }
"mode": {
"type": "string",
"enum": ["none", "announce"],
"description": "'announce' sends output to the specified channel; 'none' disables delivery"
},
"channel": {
"type": "string",
"enum": ["telegram", "discord", "slack", "mattermost", "matrix"],
"description": "Channel type to deliver output to"
},
"to": {
"type": "string",
"description": "Destination ID: Discord channel ID, Telegram chat ID, Slack channel name, etc."
},
"best_effort": {
"type": "boolean",
"description": "If true, a delivery failure does not fail the job itself. Defaults to true."
}
}
},
"delete_after_run": { "type": "boolean" },
"delete_after_run": {
"type": "boolean",
"description": "If true, the job is automatically deleted after its first successful run. Defaults to true for 'at' schedules."
},
"approved": {
"type": "boolean",
"description": "Set true to explicitly approve medium/high-risk shell commands in supervised mode",
@@ -497,4 +567,71 @@ mod tests {
assert!(values.iter().any(|value| value == "matrix"));
}
#[test]
fn schedule_schema_is_oneof_with_cron_at_every_variants() {
let tmp = tempfile::TempDir::new().unwrap();
let cfg = Arc::new(Config {
workspace_dir: tmp.path().join("workspace"),
config_path: tmp.path().join("config.toml"),
..Config::default()
});
let security = Arc::new(SecurityPolicy::from_config(
&cfg.autonomy,
&cfg.workspace_dir,
));
let tool = CronAddTool::new(cfg, security);
let schema = tool.parameters_schema();
// Top-level: schedule is required
let top_required = schema["required"].as_array().expect("top-level required");
assert!(top_required.iter().any(|v| v == "schedule"));
// schedule is a oneOf with exactly 3 variants: cron, at, every
let one_of = schema["properties"]["schedule"]["oneOf"]
.as_array()
.expect("schedule.oneOf must be an array");
assert_eq!(one_of.len(), 3, "expected cron, at, and every variants");
let kinds: Vec<&str> = one_of
.iter()
.filter_map(|v| v["properties"]["kind"]["enum"][0].as_str())
.collect();
assert!(kinds.contains(&"cron"), "missing cron variant");
assert!(kinds.contains(&"at"), "missing at variant");
assert!(kinds.contains(&"every"), "missing every variant");
// Each variant declares its required fields and every_ms is typed integer
for variant in one_of {
let kind = variant["properties"]["kind"]["enum"][0]
.as_str()
.expect("variant kind");
let req: Vec<&str> = variant["required"]
.as_array()
.unwrap_or_else(|| panic!("{kind} variant must have required"))
.iter()
.filter_map(|v| v.as_str())
.collect();
assert!(
req.contains(&"kind"),
"{kind} variant missing 'kind' in required"
);
match kind {
"cron" => assert!(req.contains(&"expr"), "cron variant missing 'expr'"),
"at" => assert!(req.contains(&"at"), "at variant missing 'at'"),
"every" => {
assert!(
req.contains(&"every_ms"),
"every variant missing 'every_ms'"
);
assert_eq!(
variant["properties"]["every_ms"]["type"].as_str(),
Some("integer"),
"every_ms must be typed as integer"
);
}
_ => panic!("unexpected kind: {kind}"),
}
}
}
}
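Taken together, the expanded `cron_add` schema accepts payloads like the following sketch, which combines the cron-variant schedule with agent-job and delivery fields (the `to` value is a placeholder, and field values here are illustrative, drawn from the schema descriptions above):

```json
{
  "name": "weekday-standup-reminder",
  "schedule": { "kind": "cron", "expr": "0 9 * * 1-5", "tz": "America/New_York" },
  "job_type": "agent",
  "prompt": "Summarize overnight CI failures",
  "session_target": "isolated",
  "delivery": {
    "mode": "announce",
    "channel": "discord",
    "to": "<channel_id>",
    "best_effort": true
  }
}
```

A one-shot job would instead use `{"kind":"at","at":"2025-12-31T23:59:00Z"}` (with `delete_after_run` defaulting to true), and a fixed interval `{"kind":"every","every_ms":3600000}`.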
@@ -61,8 +61,106 @@ impl Tool for CronUpdateTool {
json!({
"type": "object",
"properties": {
"job_id": { "type": "string" },
"patch": { "type": "object" },
"job_id": {
"type": "string",
"description": "ID of the cron job to update, as returned by cron_add or cron_list"
},
"patch": {
"type": "object",
"description": "Fields to update. Only include fields you want to change; omitted fields are left as-is.",
"properties": {
"name": {
"type": "string",
"description": "New human-readable name for the job"
},
"enabled": {
"type": "boolean",
"description": "Enable or disable the job without deleting it"
},
"command": {
"type": "string",
"description": "New shell command (for shell jobs)"
},
"prompt": {
"type": "string",
"description": "New agent prompt (for agent jobs)"
},
"model": {
"type": "string",
"description": "Model override for agent jobs, e.g. 'x-ai/grok-4-1-fast'"
},
"session_target": {
"type": "string",
"enum": ["isolated", "main"],
"description": "Agent session context: 'isolated' starts fresh each run, 'main' reuses the primary session"
},
"delete_after_run": {
"type": "boolean",
"description": "If true, delete the job automatically after its first successful run"
},
// NOTE: oneOf is correct for OpenAI-compatible APIs (including OpenRouter).
// Gemini does not support oneOf in tool schemas; if Gemini native tool calling
// is ever wired up, SchemaCleanr::clean_for_gemini must be applied before
// tool specs are sent. See src/tools/schema.rs.
"schedule": {
"description": "New schedule for the job. Exactly one of three forms must be used.",
"oneOf": [
{
"type": "object",
"description": "Cron expression schedule (repeating). Example: {\"kind\":\"cron\",\"expr\":\"0 9 * * 1-5\",\"tz\":\"America/New_York\"}",
"properties": {
"kind": { "type": "string", "enum": ["cron"] },
"expr": { "type": "string", "description": "Standard 5-field cron expression, e.g. '*/5 * * * *'" },
"tz": { "type": "string", "description": "Optional IANA timezone name, e.g. 'America/New_York'. Defaults to UTC." }
},
"required": ["kind", "expr"]
},
{
"type": "object",
"description": "One-shot schedule at a specific UTC datetime. Example: {\"kind\":\"at\",\"at\":\"2025-12-31T23:59:00Z\"}",
"properties": {
"kind": { "type": "string", "enum": ["at"] },
"at": { "type": "string", "description": "ISO 8601 UTC datetime string, e.g. '2025-12-31T23:59:00Z'" }
},
"required": ["kind", "at"]
},
{
"type": "object",
"description": "Repeating interval schedule in milliseconds. Example: {\"kind\":\"every\",\"every_ms\":3600000} runs every hour.",
"properties": {
"kind": { "type": "string", "enum": ["every"] },
"every_ms": { "type": "integer", "description": "Interval in milliseconds, e.g. 3600000 for every hour" }
},
"required": ["kind", "every_ms"]
}
]
},
"delivery": {
"type": "object",
"description": "Delivery config to send job output to a channel after each run. When provided, mode, channel, and to are all expected.",
"properties": {
"mode": {
"type": "string",
"enum": ["none", "announce"],
"description": "'announce' sends output to the specified channel; 'none' disables delivery"
},
"channel": {
"type": "string",
"enum": ["telegram", "discord", "slack", "mattermost", "matrix"],
"description": "Channel type to deliver output to"
},
"to": {
"type": "string",
"description": "Destination ID: Discord channel ID, Telegram chat ID, Slack channel name, etc."
},
"best_effort": {
"type": "boolean",
"description": "If true, a delivery failure does not fail the job itself. Defaults to true."
}
}
}
}
},
"approved": {
"type": "boolean",
"description": "Set true to explicitly approve medium/high-risk shell commands in supervised mode",
@@ -274,6 +372,106 @@ mod tests {
assert!(approved.success, "{:?}", approved.error);
}
#[test]
fn patch_schema_covers_all_cronjobpatch_fields_and_schedule_is_oneof() {
let tmp = TempDir::new().unwrap();
let cfg = Arc::new(Config {
workspace_dir: tmp.path().join("workspace"),
config_path: tmp.path().join("config.toml"),
..Config::default()
});
let security = Arc::new(SecurityPolicy::from_config(
&cfg.autonomy,
&cfg.workspace_dir,
));
let tool = CronUpdateTool::new(cfg, security);
let schema = tool.parameters_schema();
// Top-level: job_id and patch are required
let top_required = schema["required"].as_array().expect("top-level required");
let top_req_strs: Vec<&str> = top_required.iter().filter_map(|v| v.as_str()).collect();
assert!(top_req_strs.contains(&"job_id"));
assert!(top_req_strs.contains(&"patch"));
// patch exposes all CronJobPatch fields
let patch_props = schema["properties"]["patch"]["properties"]
.as_object()
.expect("patch must have a properties object");
for field in &[
"name",
"enabled",
"command",
"prompt",
"model",
"session_target",
"delete_after_run",
"schedule",
"delivery",
] {
assert!(
patch_props.contains_key(*field),
"patch schema missing field: {field}"
);
}
// patch.schedule is a oneOf with exactly 3 variants: cron, at, every
let one_of = schema["properties"]["patch"]["properties"]["schedule"]["oneOf"]
.as_array()
.expect("patch.schedule.oneOf must be an array");
assert_eq!(one_of.len(), 3, "expected cron, at, and every variants");
let kinds: Vec<&str> = one_of
.iter()
.filter_map(|v| v["properties"]["kind"]["enum"][0].as_str())
.collect();
assert!(kinds.contains(&"cron"), "missing cron variant");
assert!(kinds.contains(&"at"), "missing at variant");
assert!(kinds.contains(&"every"), "missing every variant");
// Each variant declares its required fields and every_ms is typed integer
for variant in one_of {
let kind = variant["properties"]["kind"]["enum"][0]
.as_str()
.expect("variant kind");
let req: Vec<&str> = variant["required"]
.as_array()
.unwrap_or_else(|| panic!("{kind} variant must have required"))
.iter()
.filter_map(|v| v.as_str())
.collect();
assert!(
req.contains(&"kind"),
"{kind} variant missing 'kind' in required"
);
match kind {
"cron" => assert!(req.contains(&"expr"), "cron variant missing 'expr'"),
"at" => assert!(req.contains(&"at"), "at variant missing 'at'"),
"every" => {
assert!(
req.contains(&"every_ms"),
"every variant missing 'every_ms'"
);
assert_eq!(
variant["properties"]["every_ms"]["type"].as_str(),
Some("integer"),
"every_ms must be typed as integer"
);
}
_ => panic!("unexpected schedule kind: {kind}"),
}
}
// patch.delivery.channel enum covers all supported channels
let channel_enum = schema["properties"]["patch"]["properties"]["delivery"]["properties"]
["channel"]["enum"]
.as_array()
.expect("patch.delivery.channel must have an enum");
let channel_strs: Vec<&str> = channel_enum.iter().filter_map(|v| v.as_str()).collect();
for ch in &["telegram", "discord", "slack", "mattermost", "matrix"] {
assert!(channel_strs.contains(ch), "delivery.channel missing: {ch}");
}
}
#[tokio::test]
async fn blocks_update_when_rate_limited() {
let tmp = TempDir::new().unwrap();
+1
View File
@@ -421,6 +421,7 @@ impl DelegateTool {
None,
&[],
&[],
None,
),
)
.await;
+716
View File
@@ -0,0 +1,716 @@
use super::traits::{Tool, ToolResult};
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
use std::time::Duration;
/// Default time limit for `gws` command execution before the process is killed (overridable via config).
const DEFAULT_GWS_TIMEOUT_SECS: u64 = 30;
/// Maximum output size in bytes (1MB).
const MAX_OUTPUT_BYTES: usize = 1_048_576;
/// Default allowlist of Google Workspace services that `gws` may target (used when the config provides none).
const DEFAULT_ALLOWED_SERVICES: &[&str] = &[
"drive",
"sheets",
"gmail",
"calendar",
"docs",
"slides",
"tasks",
"people",
"chat",
"classroom",
"forms",
"keep",
"meet",
"events",
];
/// Google Workspace CLI (`gws`) integration tool.
///
/// Wraps the `gws` CLI binary to give the agent structured access to
/// Google Workspace services (Drive, Gmail, Calendar, Sheets, etc.).
/// Requires `gws` to be installed and authenticated (`gws auth login`).
pub struct GoogleWorkspaceTool {
security: Arc<SecurityPolicy>,
allowed_services: Vec<String>,
credentials_path: Option<String>,
default_account: Option<String>,
rate_limit_per_minute: u32,
timeout_secs: u64,
audit_log: bool,
}
impl GoogleWorkspaceTool {
/// Create a new `GoogleWorkspaceTool`.
///
/// If `allowed_services` is empty, the default service set is used.
pub fn new(
security: Arc<SecurityPolicy>,
allowed_services: Vec<String>,
credentials_path: Option<String>,
default_account: Option<String>,
rate_limit_per_minute: u32,
timeout_secs: u64,
audit_log: bool,
) -> Self {
let services = if allowed_services.is_empty() {
DEFAULT_ALLOWED_SERVICES
.iter()
.map(|s| (*s).to_string())
.collect()
} else {
allowed_services
};
Self {
security,
allowed_services: services,
credentials_path,
default_account,
rate_limit_per_minute,
timeout_secs,
audit_log,
}
}
}
#[async_trait]
impl Tool for GoogleWorkspaceTool {
fn name(&self) -> &str {
"google_workspace"
}
fn description(&self) -> &str {
"Interact with Google Workspace services (Drive, Gmail, Calendar, Sheets, Docs, etc.) \
via the gws CLI. Requires gws to be installed and authenticated."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"service": {
"type": "string",
"description": "Google Workspace service (e.g. drive, gmail, calendar, sheets, docs, slides, tasks, people, chat, forms, keep, meet)"
},
"resource": {
"type": "string",
"description": "Service resource (e.g. files, messages, events, spreadsheets)"
},
"method": {
"type": "string",
"description": "Method to call on the resource (e.g. list, get, create, update, delete)"
},
"sub_resource": {
"type": "string",
"description": "Optional sub-resource for nested operations"
},
"params": {
"type": "object",
"description": "URL/query parameters as key-value pairs (passed as --params JSON)"
},
"body": {
"type": "object",
"description": "Request body for POST/PATCH/PUT operations (passed as --json JSON)"
},
"format": {
"type": "string",
"enum": ["json", "table", "yaml", "csv"],
"description": "Output format (default: json)"
},
"page_all": {
"type": "boolean",
"description": "Auto-paginate through all results"
},
"page_limit": {
"type": "integer",
"description": "Max pages to fetch when using page_all (default: 10)"
}
},
"required": ["service", "resource", "method"]
})
}
/// Execute a Google Workspace CLI command with input validation and security enforcement.
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let service = args
.get("service")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'service' parameter"))?;
let resource = args
.get("resource")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'resource' parameter"))?;
let method = args
.get("method")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'method' parameter"))?;
// Security checks
if self.security.is_rate_limited() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Rate limit exceeded: too many actions in the last hour".into()),
});
}
// Validate service is in the allowlist
if !self.allowed_services.iter().any(|s| s == service) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Service '{service}' is not in the allowed services list. \
Allowed: {}",
self.allowed_services.join(", ")
)),
});
}
// Validate inputs contain no shell metacharacters
for (label, value) in [
("service", service),
("resource", resource),
("method", method),
] {
if !value
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Invalid characters in '{label}': only alphanumeric, underscore, and hyphen are allowed"
)),
});
}
}
// Build the gws command — validate all optional fields before consuming budget
let mut cmd_args: Vec<String> = vec![service.to_string(), resource.to_string()];
if let Some(sub_resource_value) = args.get("sub_resource") {
let sub_resource = match sub_resource_value.as_str() {
Some(s) => s,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("'sub_resource' must be a string".into()),
})
}
};
if !sub_resource
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
{
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(
"Invalid characters in 'sub_resource': only alphanumeric, underscore, and hyphen are allowed"
.into(),
),
});
}
cmd_args.push(sub_resource.to_string());
}
cmd_args.push(method.to_string());
if let Some(params) = args.get("params") {
if !params.is_object() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("'params' must be an object".into()),
});
}
cmd_args.push("--params".into());
cmd_args.push(params.to_string());
}
if let Some(body) = args.get("body") {
if !body.is_object() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("'body' must be an object".into()),
});
}
cmd_args.push("--json".into());
cmd_args.push(body.to_string());
}
if let Some(format_value) = args.get("format") {
let format = match format_value.as_str() {
Some(s) => s,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("'format' must be a string".into()),
})
}
};
match format {
"json" | "table" | "yaml" | "csv" => {
cmd_args.push("--format".into());
cmd_args.push(format.to_string());
}
_ => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Invalid format '{format}': must be json, table, yaml, or csv"
)),
});
}
}
}
let page_all = match args.get("page_all") {
Some(v) => match v.as_bool() {
Some(b) => b,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("'page_all' must be a boolean".into()),
})
}
},
None => false,
};
if page_all {
cmd_args.push("--page-all".into());
}
let page_limit = match args.get("page_limit") {
Some(v) => match v.as_u64() {
Some(n) => Some(n),
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("'page_limit' must be a non-negative integer".into()),
})
}
},
None => None,
};
if page_all || page_limit.is_some() {
cmd_args.push("--page-limit".into());
cmd_args.push(page_limit.unwrap_or(10).to_string());
}
// Charge action budget only after all validation passes
if !self.security.record_action() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Rate limit exceeded: action budget exhausted".into()),
});
}
let mut cmd = tokio::process::Command::new("gws");
cmd.args(&cmd_args);
cmd.env_clear();
// gws needs PATH to find itself and HOME/APPDATA for credential storage
for key in &["PATH", "HOME", "APPDATA", "USERPROFILE", "LANG", "TERM"] {
if let Ok(val) = std::env::var(key) {
cmd.env(key, val);
}
}
// Apply credential path if configured
if let Some(ref creds) = self.credentials_path {
cmd.env("GOOGLE_APPLICATION_CREDENTIALS", creds);
}
// Apply default account if configured
if let Some(ref account) = self.default_account {
cmd.args(["--account", account]);
}
if self.audit_log {
tracing::info!(
tool = "google_workspace",
service = service,
resource = resource,
method = method,
"gws audit: executing API call"
);
}
let result =
tokio::time::timeout(Duration::from_secs(self.timeout_secs), cmd.output()).await;
match result {
Ok(Ok(output)) => {
let mut stdout = String::from_utf8_lossy(&output.stdout).to_string();
let mut stderr = String::from_utf8_lossy(&output.stderr).to_string();
if stdout.len() > MAX_OUTPUT_BYTES {
// Find a valid char boundary at or before MAX_OUTPUT_BYTES
let mut boundary = MAX_OUTPUT_BYTES;
while boundary > 0 && !stdout.is_char_boundary(boundary) {
boundary -= 1;
}
stdout.truncate(boundary);
stdout.push_str("\n... [output truncated at 1MB]");
}
if stderr.len() > MAX_OUTPUT_BYTES {
let mut boundary = MAX_OUTPUT_BYTES;
while boundary > 0 && !stderr.is_char_boundary(boundary) {
boundary -= 1;
}
stderr.truncate(boundary);
stderr.push_str("\n... [stderr truncated at 1MB]");
}
Ok(ToolResult {
success: output.status.success(),
output: stdout,
error: if stderr.is_empty() {
None
} else {
Some(stderr)
},
})
}
Ok(Err(e)) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Failed to execute gws: {e}. Is gws installed? Run: npm install -g @googleworkspace/cli"
)),
}),
Err(_) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"gws command timed out after {}s and was killed", self.timeout_secs
)),
}),
}
}
}
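The char-boundary backoff in the output handling above is easy to get wrong in isolation (`String::truncate` panics if the cut lands inside a multi-byte codepoint), so here is the same logic as a self-contained sketch; `truncate_at_char_boundary` is an illustrative name, not a function from this PR:

```rust
/// UTF-8-safe truncation as used for gws stdout/stderr: back up from the
/// byte cap to the nearest char boundary so `truncate` cannot panic
/// inside a multi-byte codepoint.
fn truncate_at_char_boundary(s: &mut String, max_bytes: usize) {
    if s.len() > max_bytes {
        let mut boundary = max_bytes;
        while boundary > 0 && !s.is_char_boundary(boundary) {
            boundary -= 1;
        }
        s.truncate(boundary);
        s.push_str("\n... [output truncated]");
    }
}

fn main() {
    // "héllo" is 6 bytes; byte index 2 falls inside the 2-byte 'é',
    // so the boundary backs up to 1 and only "h" survives.
    let mut s = String::from("héllo");
    truncate_at_char_boundary(&mut s, 2);
    assert!(s.starts_with('h'));
    assert!(!s.contains('é'));
}
```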
#[cfg(test)]
mod tests {
use super::*;
use crate::security::{AutonomyLevel, SecurityPolicy};
fn test_security() -> Arc<SecurityPolicy> {
Arc::new(SecurityPolicy {
autonomy: AutonomyLevel::Full,
workspace_dir: std::env::temp_dir(),
..SecurityPolicy::default()
})
}
#[test]
fn tool_name() {
let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
assert_eq!(tool.name(), "google_workspace");
}
#[test]
fn tool_description_non_empty() {
let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
assert!(!tool.description().is_empty());
}
#[test]
fn tool_schema_has_required_fields() {
let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
let schema = tool.parameters_schema();
assert!(schema["properties"]["service"].is_object());
assert!(schema["properties"]["resource"].is_object());
assert!(schema["properties"]["method"].is_object());
let required = schema["required"]
.as_array()
.expect("required should be an array");
assert!(required.contains(&json!("service")));
assert!(required.contains(&json!("resource")));
assert!(required.contains(&json!("method")));
}
#[test]
fn default_allowed_services_populated() {
let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
assert!(!tool.allowed_services.is_empty());
assert!(tool.allowed_services.contains(&"drive".to_string()));
assert!(tool.allowed_services.contains(&"gmail".to_string()));
assert!(tool.allowed_services.contains(&"calendar".to_string()));
}
#[test]
fn custom_allowed_services_override_defaults() {
let tool = GoogleWorkspaceTool::new(
test_security(),
vec!["drive".into(), "sheets".into()],
None,
None,
60,
30,
false,
);
assert_eq!(tool.allowed_services.len(), 2);
assert!(tool.allowed_services.contains(&"drive".to_string()));
assert!(tool.allowed_services.contains(&"sheets".to_string()));
assert!(!tool.allowed_services.contains(&"gmail".to_string()));
}
#[tokio::test]
async fn rejects_disallowed_service() {
let tool = GoogleWorkspaceTool::new(
test_security(),
vec!["drive".into()],
None,
None,
60,
30,
false,
);
let result = tool
.execute(json!({
"service": "gmail",
"resource": "users",
"method": "list"
}))
.await
.expect("disallowed service should return a result");
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap_or("")
.contains("not in the allowed"));
}
#[tokio::test]
async fn rejects_shell_injection_in_service() {
let tool = GoogleWorkspaceTool::new(
test_security(),
vec!["drive; rm -rf /".into()],
None,
None,
60,
30,
false,
);
let result = tool
.execute(json!({
"service": "drive; rm -rf /",
"resource": "files",
"method": "list"
}))
.await
.expect("shell injection should return a result");
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap_or("")
.contains("Invalid characters"));
}
#[tokio::test]
async fn rejects_shell_injection_in_resource() {
let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
let result = tool
.execute(json!({
"service": "drive",
"resource": "files$(whoami)",
"method": "list"
}))
.await
.expect("shell injection should return a result");
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap_or("")
.contains("Invalid characters"));
}
#[tokio::test]
async fn rejects_invalid_format() {
let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
let result = tool
.execute(json!({
"service": "drive",
"resource": "files",
"method": "list",
"format": "xml"
}))
.await
.expect("invalid format should return a result");
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap_or("")
.contains("Invalid format"));
}
#[tokio::test]
async fn rejects_wrong_type_params() {
let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
let result = tool
.execute(json!({
"service": "drive",
"resource": "files",
"method": "list",
"params": "not_an_object"
}))
.await
.expect("wrong type params should return a result");
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap_or("")
.contains("'params' must be an object"));
}
#[tokio::test]
async fn rejects_wrong_type_body() {
let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
let result = tool
.execute(json!({
"service": "drive",
"resource": "files",
"method": "create",
"body": "not_an_object"
}))
.await
.expect("wrong type body should return a result");
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap_or("")
.contains("'body' must be an object"));
}
#[tokio::test]
async fn rejects_wrong_type_page_all() {
let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
let result = tool
.execute(json!({
"service": "drive",
"resource": "files",
"method": "list",
"page_all": "yes"
}))
.await
.expect("wrong type page_all should return a result");
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap_or("")
.contains("'page_all' must be a boolean"));
}
#[tokio::test]
async fn rejects_wrong_type_page_limit() {
let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
let result = tool
.execute(json!({
"service": "drive",
"resource": "files",
"method": "list",
"page_limit": "ten"
}))
.await
.expect("wrong type page_limit should return a result");
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap_or("")
.contains("'page_limit' must be a non-negative integer"));
}
#[tokio::test]
async fn rejects_wrong_type_sub_resource() {
let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
let result = tool
.execute(json!({
"service": "drive",
"resource": "files",
"method": "list",
"sub_resource": 123
}))
.await
.expect("wrong type sub_resource should return a result");
assert!(!result.success);
assert!(result
.error
.as_deref()
.unwrap_or("")
.contains("'sub_resource' must be a string"));
}
#[tokio::test]
async fn missing_required_param_returns_error() {
let tool = GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false);
let result = tool.execute(json!({"service": "drive"})).await;
assert!(result.is_err());
}
#[tokio::test]
async fn rate_limited_returns_error() {
let security = Arc::new(SecurityPolicy {
autonomy: AutonomyLevel::Full,
max_actions_per_hour: 0,
workspace_dir: std::env::temp_dir(),
..SecurityPolicy::default()
});
let tool = GoogleWorkspaceTool::new(security, vec![], None, None, 60, 30, false);
let result = tool
.execute(json!({
"service": "drive",
"resource": "files",
"method": "list"
}))
.await
.expect("rate-limited should return a result");
assert!(!result.success);
assert!(result.error.as_deref().unwrap_or("").contains("Rate limit"));
}
#[test]
fn gws_timeout_is_reasonable() {
assert_eq!(DEFAULT_GWS_TIMEOUT_SECS, 30);
}
}
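The injection guard that `execute` applies to `service`, `resource`, `method`, and `sub_resource` is a strict character allowlist, so shell metacharacters like `;`, `$()`, backticks, and spaces never reach the subprocess argv. The check in isolation (`is_safe_component` is an illustrative name, not part of the tool):

```rust
// Standalone sketch of the allowlist validation: only ASCII
// alphanumerics, underscore, and hyphen pass.
fn is_safe_component(value: &str) -> bool {
    value
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
}

fn main() {
    assert!(is_safe_component("drive"));
    assert!(is_safe_component("spreadsheets_v4-beta"));
    assert!(!is_safe_component("drive; rm -rf /"));
    assert!(!is_safe_component("files$(whoami)"));
}
```

Note that an allowlist is stronger than trying to blocklist known metacharacters: anything not explicitly permitted is rejected, including characters a blocklist might forget.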
+6 -11
View File
@@ -161,8 +161,7 @@ impl DeferredMcpToolSet {
/// The agent loop consults this each iteration to decide which tool_specs
/// to include in the LLM request.
pub struct ActivatedToolSet {
/// name -> activated Tool
tools: HashMap<String, Box<dyn Tool>>,
tools: HashMap<String, Arc<dyn Tool>>,
}
impl ActivatedToolSet {
@@ -172,27 +171,23 @@ impl ActivatedToolSet {
}
}
/// Mark a tool as activated, storing its live wrapper.
pub fn activate(&mut self, name: String, tool: Box<dyn Tool>) {
pub fn activate(&mut self, name: String, tool: Arc<dyn Tool>) {
self.tools.insert(name, tool);
}
/// Whether a tool has been activated.
pub fn is_activated(&self, name: &str) -> bool {
self.tools.contains_key(name)
}
/// Get an activated tool for execution.
pub fn get(&self, name: &str) -> Option<&dyn Tool> {
self.tools.get(name).map(|t| t.as_ref())
/// Clone the Arc so the caller can drop the mutex guard before awaiting.
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
self.tools.get(name).cloned()
}
/// All currently activated tool specs (to include in LLM requests).
pub fn tool_specs(&self) -> Vec<ToolSpec> {
self.tools.values().map(|t| t.spec()).collect()
}
/// All activated tools for execution dispatch.
pub fn tool_names(&self) -> Vec<&str> {
self.tools.keys().map(|s| s.as_str()).collect()
}
@@ -280,7 +275,7 @@ mod tests {
let mut set = ActivatedToolSet::new();
assert!(!set.is_activated("fake"));
set.activate("fake".into(), Box::new(FakeTool));
set.activate("fake".into(), Arc::new(FakeTool));
assert!(set.is_activated("fake"));
assert!(set.get("fake").is_some());
assert_eq!(set.tool_specs().len(), 1);
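The switch from `Box<dyn Tool>` to `Arc<dyn Tool>` exists so `get` can hand the caller an owned clone, letting it drop the registry's mutex guard before awaiting the tool. A minimal sketch of that pattern with std types only (`Registry` is an illustrative stand-in, with `Arc<String>` playing the role of `Arc<dyn Tool>`):

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// Minimal stand-in for ActivatedToolSet.
struct Registry {
    tools: HashMap<String, Arc<String>>,
}

impl Registry {
    // Clone the Arc out so the caller can release the lock before awaiting.
    fn get(&self, name: &str) -> Option<Arc<String>> {
        self.tools.get(name).cloned()
    }
}

fn main() {
    let mut tools = HashMap::new();
    tools.insert("fake".to_string(), Arc::new("tool".to_string()));
    let registry = Mutex::new(Registry { tools });

    let tool = {
        let guard = registry.lock().unwrap();
        guard.get("fake") // clones the Arc; guard drops at block end
    };
    // The mutex is released here; real code can now `.await` the tool
    // without holding the registry lock across the suspension point.
    let tool = tool.expect("activated");
    assert_eq!(Arc::strong_count(&tool), 2); // one in registry, one here
}
```

Returning `Option<&dyn Tool>` instead would pin the borrow to the guard's lifetime, forcing the lock to stay held for the whole (possibly long) tool execution.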
+19
View File
@@ -37,6 +37,7 @@ pub mod file_read;
pub mod file_write;
pub mod git_operations;
pub mod glob_search;
pub mod google_workspace;
#[cfg(feature = "hardware")]
pub mod hardware_board_info;
#[cfg(feature = "hardware")]
@@ -96,6 +97,7 @@ pub use file_read::FileReadTool;
pub use file_write::FileWriteTool;
pub use git_operations::GitOperationsTool;
pub use glob_search::GlobSearchTool;
pub use google_workspace::GoogleWorkspaceTool;
#[cfg(feature = "hardware")]
pub use hardware_board_info::HardwareBoardInfoTool;
#[cfg(feature = "hardware")]
@@ -433,6 +435,23 @@ pub fn all_tools_with_runtime(
tool_arcs.push(Arc::new(CloudPatternsTool::new()));
}
// Google Workspace CLI (gws) integration — requires shell access
if root_config.google_workspace.enabled && has_shell_access {
tool_arcs.push(Arc::new(GoogleWorkspaceTool::new(
security.clone(),
root_config.google_workspace.allowed_services.clone(),
root_config.google_workspace.credentials_path.clone(),
root_config.google_workspace.default_account.clone(),
root_config.google_workspace.rate_limit_per_minute,
root_config.google_workspace.timeout_secs,
root_config.google_workspace.audit_log,
)));
} else if root_config.google_workspace.enabled {
tracing::warn!(
"google_workspace: skipped registration because shell access is unavailable"
);
}
// PDF extraction (feature-gated at compile time via rag-pdf)
tool_arcs.push(Arc::new(PdfReadTool::new(security.clone())));
+3 -3
View File
@@ -940,10 +940,10 @@ impl Tool for ModelRoutingConfigTool {
}
match action.as_str() {
"set_default" => self.handle_set_default(&args).await,
"upsert_scenario" => self.handle_upsert_scenario(&args).await,
"set_default" => Box::pin(self.handle_set_default(&args)).await,
"upsert_scenario" => Box::pin(self.handle_upsert_scenario(&args)).await,
"remove_scenario" => self.handle_remove_scenario(&args).await,
"upsert_agent" => self.handle_upsert_agent(&args).await,
"upsert_agent" => Box::pin(self.handle_upsert_agent(&args)).await,
"remove_agent" => self.handle_remove_agent(&args).await,
_ => unreachable!("validated above"),
}
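The `Box::pin` changes above shrink the caller's future: an `async fn`'s state machine inlines every local that lives across an `.await`, so one large handler inflates every match arm that stores it inline, while boxing moves that state to the heap. A self-contained sketch (`big_handler`/`noop` are illustrative, not functions from this PR):

```rust
async fn noop() {}

// A handler whose future must carry an 8 KiB buffer, because `buf`
// stays live across the suspension point.
async fn big_handler() -> usize {
    let buf = [0u8; 8192];
    noop().await;
    buf.len()
}

fn main() {
    let inline_fut = big_handler();          // state machine on the stack
    let boxed_fut = Box::pin(big_handler()); // state machine on the heap
    // The inline future embeds the buffer; the boxed one is just a pointer.
    assert!(std::mem::size_of_val(&inline_fut) >= 8192);
    assert!(std::mem::size_of_val(&boxed_fut) <= 16);
}
```

This is the same condition clippy's `large_futures` lint flags, and why only the arms with large handler futures were boxed rather than every arm.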
+1 -1
View File
@@ -412,7 +412,7 @@ impl Tool for ProxyConfigTool {
}
match action.as_str() {
"set" => self.handle_set(&args).await,
"set" => Box::pin(self.handle_set(&args)).await,
"disable" => self.handle_disable(&args).await,
"apply_env" => self.handle_apply_env(),
"clear_env" => self.handle_clear_env(),
+2 -2
View File
@@ -107,7 +107,7 @@ impl Tool for ToolSearchTool {
if let Some(spec) = self.deferred.tool_spec(&stub.prefixed_name) {
if !guard.is_activated(&stub.prefixed_name) {
if let Some(tool) = self.deferred.activate(&stub.prefixed_name) {
guard.activate(stub.prefixed_name.clone(), tool);
guard.activate(stub.prefixed_name.clone(), Arc::from(tool));
activated_count += 1;
}
}
@@ -152,7 +152,7 @@ impl ToolSearchTool {
Some(spec) => {
if !guard.is_activated(name) {
if let Some(tool) = self.deferred.activate(name) {
guard.activate(name.to_string(), tool);
guard.activate(name.to_string(), Arc::from(tool));
activated_count += 1;
}
}