Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 675a5c9af0 | |||
| b099728c27 | |||
| 1ca2092ca0 | |||
| 5e3308eaaa | |||
| ec255ad788 | |||
| 7182f659ce | |||
| ae7681209d | |||
| ee3469e912 | |||
| fec81d8e75 | |||
| 9a073fae1a | |||
| f0db63e53c | |||
| df4dfeaf66 |
@@ -1,6 +1,20 @@
|
||||
name: Pub AUR Package
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
release_tag:
|
||||
description: "Existing release tag (vX.Y.Z)"
|
||||
required: true
|
||||
type: string
|
||||
dry_run:
|
||||
description: "Generate PKGBUILD only (no push)"
|
||||
required: false
|
||||
default: false
|
||||
type: boolean
|
||||
secrets:
|
||||
AUR_SSH_KEY:
|
||||
required: false
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
release_tag:
|
||||
|
||||
@@ -1,6 +1,20 @@
|
||||
name: Pub Scoop Manifest
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
release_tag:
|
||||
description: "Existing release tag (vX.Y.Z)"
|
||||
required: true
|
||||
type: string
|
||||
dry_run:
|
||||
description: "Generate manifest only (no push)"
|
||||
required: false
|
||||
default: false
|
||||
type: boolean
|
||||
secrets:
|
||||
SCOOP_BUCKET_TOKEN:
|
||||
required: false
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
release_tag:
|
||||
|
||||
@@ -361,6 +361,27 @@ jobs:
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
# ── Post-publish: package manager auto-sync ─────────────────────────
|
||||
scoop:
|
||||
name: Update Scoop Manifest
|
||||
needs: [validate, publish]
|
||||
if: ${{ !cancelled() && needs.publish.result == 'success' }}
|
||||
uses: ./.github/workflows/pub-scoop.yml
|
||||
with:
|
||||
release_tag: ${{ needs.validate.outputs.tag }}
|
||||
dry_run: false
|
||||
secrets: inherit
|
||||
|
||||
aur:
|
||||
name: Update AUR Package
|
||||
needs: [validate, publish]
|
||||
if: ${{ !cancelled() && needs.publish.result == 'success' }}
|
||||
uses: ./.github/workflows/pub-aur.yml
|
||||
with:
|
||||
release_tag: ${{ needs.validate.outputs.tag }}
|
||||
dry_run: false
|
||||
secrets: inherit
|
||||
|
||||
# ── Post-publish: tweet after release + website are live ──────────────
|
||||
# Docker push can be slow; don't let it block the tweet.
|
||||
tweet:
|
||||
|
||||
Generated
+1
-1
@@ -7945,7 +7945,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.4.2"
|
||||
version = "0.4.3"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-imap",
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "zeroclawlabs"
|
||||
version = "0.4.2"
|
||||
version = "0.4.3"
|
||||
edition = "2021"
|
||||
authors = ["theonlyhennygod"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
Vendored
+2
-2
@@ -1,6 +1,6 @@
|
||||
pkgbase = zeroclaw
|
||||
pkgdesc = Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.
|
||||
pkgver = 0.4.2
|
||||
pkgver = 0.4.3
|
||||
pkgrel = 1
|
||||
url = https://github.com/zeroclaw-labs/zeroclaw
|
||||
arch = x86_64
|
||||
@@ -10,7 +10,7 @@ pkgbase = zeroclaw
|
||||
makedepends = git
|
||||
depends = gcc-libs
|
||||
depends = openssl
|
||||
source = zeroclaw-0.4.2.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.4.2.tar.gz
|
||||
source = zeroclaw-0.4.3.tar.gz::https://github.com/zeroclaw-labs/zeroclaw/archive/refs/tags/v0.4.3.tar.gz
|
||||
sha256sums = SKIP
|
||||
|
||||
pkgname = zeroclaw
|
||||
|
||||
Vendored
+1
-1
@@ -1,6 +1,6 @@
|
||||
# Maintainer: zeroclaw-labs <bot@zeroclaw.dev>
|
||||
pkgname=zeroclaw
|
||||
pkgver=0.4.2
|
||||
pkgver=0.4.3
|
||||
pkgrel=1
|
||||
pkgdesc="Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
||||
arch=('x86_64')
|
||||
|
||||
Vendored
+2
-2
@@ -1,11 +1,11 @@
|
||||
{
|
||||
"version": "0.4.2",
|
||||
"version": "0.4.3",
|
||||
"description": "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant.",
|
||||
"homepage": "https://github.com/zeroclaw-labs/zeroclaw",
|
||||
"license": "MIT|Apache-2.0",
|
||||
"architecture": {
|
||||
"64bit": {
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.4.2/zeroclaw-x86_64-pc-windows-msvc.zip",
|
||||
"url": "https://github.com/zeroclaw-labs/zeroclaw/releases/download/v0.4.3/zeroclaw-x86_64-pc-windows-msvc.zip",
|
||||
"hash": "",
|
||||
"bin": "zeroclaw.exe"
|
||||
}
|
||||
|
||||
@@ -38,10 +38,10 @@ Merge-blocking checks should stay small and deterministic. Optional checks are u
|
||||
- Purpose: manual, bot-owned Homebrew core formula bump PR flow for tagged releases
|
||||
- Guardrail: release tag must match `Cargo.toml` version
|
||||
- `.github/workflows/pub-scoop.yml` (`Pub Scoop Manifest`)
|
||||
- Purpose: manual Scoop bucket manifest update for Windows distribution
|
||||
- Purpose: Scoop bucket manifest update for Windows; auto-called by stable release, also manual dispatch
|
||||
- Guardrail: release tag must be `vX.Y.Z` format; Windows binary hash extracted from `SHA256SUMS`
|
||||
- `.github/workflows/pub-aur.yml` (`Pub AUR Package`)
|
||||
- Purpose: manual AUR PKGBUILD push for Arch Linux distribution
|
||||
- Purpose: AUR PKGBUILD push for Arch Linux; auto-called by stable release, also manual dispatch
|
||||
- Guardrail: release tag must be `vX.Y.Z` format; source tarball SHA256 computed at publish time
|
||||
- `.github/workflows/pr-label-policy-check.yml` (`Label Policy Sanity`)
|
||||
- Purpose: validate shared contributor-tier policy in `.github/label-policy.json` and ensure label workflows consume that policy
|
||||
@@ -81,8 +81,8 @@ Merge-blocking checks should stay small and deterministic. Optional checks are u
|
||||
- `Docker`: tag push (`v*`) for publish, matching PRs to `master` for smoke build, manual dispatch for smoke only
|
||||
- `Release`: tag push (`v*`), weekly schedule (verification-only), manual dispatch (verification or publish)
|
||||
- `Pub Homebrew Core`: manual dispatch only
|
||||
- `Pub Scoop Manifest`: manual dispatch only
|
||||
- `Pub AUR Package`: manual dispatch only
|
||||
- `Pub Scoop Manifest`: auto-called by stable release, also manual dispatch
|
||||
- `Pub AUR Package`: auto-called by stable release, also manual dispatch
|
||||
- `Security Audit`: push to `master`, PRs to `master`, weekly schedule
|
||||
- `Sec Vorpal Reviewdog`: manual dispatch only
|
||||
- `Workflow Sanity`: PR/push when `.github/workflows/**`, `.github/*.yml`, or `.github/*.yaml` change
|
||||
|
||||
+48
-8
@@ -2152,6 +2152,7 @@ pub(crate) async fn agent_turn(
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -2160,6 +2161,7 @@ async fn execute_one_tool(
|
||||
call_name: &str,
|
||||
call_arguments: serde_json::Value,
|
||||
tools_registry: &[Box<dyn Tool>],
|
||||
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
observer: &dyn Observer,
|
||||
cancellation_token: Option<&CancellationToken>,
|
||||
) -> Result<ToolExecutionOutcome> {
|
||||
@@ -2170,7 +2172,13 @@ async fn execute_one_tool(
|
||||
});
|
||||
let start = Instant::now();
|
||||
|
||||
let Some(tool) = find_tool(tools_registry, call_name) else {
|
||||
let static_tool = find_tool(tools_registry, call_name);
|
||||
let activated_arc = if static_tool.is_none() {
|
||||
activated_tools.and_then(|at| at.lock().unwrap().get(call_name))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let Some(tool) = static_tool.or(activated_arc.as_deref()) else {
|
||||
let reason = format!("Unknown tool: {call_name}");
|
||||
let duration = start.elapsed();
|
||||
observer.record_event(&ObserverEvent::ToolCall {
|
||||
@@ -2268,6 +2276,7 @@ fn should_execute_tools_in_parallel(
|
||||
async fn execute_tools_parallel(
|
||||
tool_calls: &[ParsedToolCall],
|
||||
tools_registry: &[Box<dyn Tool>],
|
||||
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
observer: &dyn Observer,
|
||||
cancellation_token: Option<&CancellationToken>,
|
||||
) -> Result<Vec<ToolExecutionOutcome>> {
|
||||
@@ -2278,6 +2287,7 @@ async fn execute_tools_parallel(
|
||||
&call.name,
|
||||
call.arguments.clone(),
|
||||
tools_registry,
|
||||
activated_tools,
|
||||
observer,
|
||||
cancellation_token,
|
||||
)
|
||||
@@ -2291,6 +2301,7 @@ async fn execute_tools_parallel(
|
||||
async fn execute_tools_sequential(
|
||||
tool_calls: &[ParsedToolCall],
|
||||
tools_registry: &[Box<dyn Tool>],
|
||||
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
observer: &dyn Observer,
|
||||
cancellation_token: Option<&CancellationToken>,
|
||||
) -> Result<Vec<ToolExecutionOutcome>> {
|
||||
@@ -2302,6 +2313,7 @@ async fn execute_tools_sequential(
|
||||
&call.name,
|
||||
call.arguments.clone(),
|
||||
tools_registry,
|
||||
activated_tools,
|
||||
observer,
|
||||
cancellation_token,
|
||||
)
|
||||
@@ -2345,6 +2357,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
hooks: Option<&crate::hooks::HookRunner>,
|
||||
excluded_tools: &[String],
|
||||
dedup_exempt_tools: &[String],
|
||||
activated_tools: Option<&std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
) -> Result<String> {
|
||||
let max_iterations = if max_tool_iterations == 0 {
|
||||
DEFAULT_MAX_TOOL_ITERATIONS
|
||||
@@ -2352,12 +2365,6 @@ pub(crate) async fn run_tool_call_loop(
|
||||
max_tool_iterations
|
||||
};
|
||||
|
||||
let tool_specs: Vec<crate::tools::ToolSpec> = tools_registry
|
||||
.iter()
|
||||
.filter(|tool| !excluded_tools.iter().any(|ex| ex == tool.name()))
|
||||
.map(|tool| tool.spec())
|
||||
.collect();
|
||||
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
|
||||
let turn_id = Uuid::new_v4().to_string();
|
||||
let mut seen_tool_signatures: HashSet<(String, String)> = HashSet::new();
|
||||
|
||||
@@ -2369,6 +2376,21 @@ pub(crate) async fn run_tool_call_loop(
|
||||
return Err(ToolLoopCancelled.into());
|
||||
}
|
||||
|
||||
// Rebuild tool_specs each iteration so newly activated deferred tools appear.
|
||||
let mut tool_specs: Vec<crate::tools::ToolSpec> = tools_registry
|
||||
.iter()
|
||||
.filter(|tool| !excluded_tools.iter().any(|ex| ex == tool.name()))
|
||||
.map(|tool| tool.spec())
|
||||
.collect();
|
||||
if let Some(at) = activated_tools {
|
||||
for spec in at.lock().unwrap().tool_specs() {
|
||||
if !excluded_tools.iter().any(|ex| ex == &spec.name) {
|
||||
tool_specs.push(spec);
|
||||
}
|
||||
}
|
||||
}
|
||||
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
|
||||
|
||||
let image_marker_count = multimodal::count_image_markers(history);
|
||||
if image_marker_count > 0 && !provider.supports_vision() {
|
||||
return Err(ProviderCapabilityError {
|
||||
@@ -2847,6 +2869,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
execute_tools_parallel(
|
||||
&executable_calls,
|
||||
tools_registry,
|
||||
activated_tools,
|
||||
observer,
|
||||
cancellation_token.as_ref(),
|
||||
)
|
||||
@@ -2855,6 +2878,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
execute_tools_sequential(
|
||||
&executable_calls,
|
||||
tools_registry,
|
||||
activated_tools,
|
||||
observer,
|
||||
cancellation_token.as_ref(),
|
||||
)
|
||||
@@ -3106,6 +3130,9 @@ pub async fn run(
|
||||
// eagerly. Instead, a `tool_search` built-in is registered so the LLM can
|
||||
// fetch schemas on demand. This reduces context window waste.
|
||||
let mut deferred_section = String::new();
|
||||
let mut activated_handle: Option<
|
||||
std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>,
|
||||
> = None;
|
||||
if config.mcp.enabled && !config.mcp.servers.is_empty() {
|
||||
tracing::info!(
|
||||
"Initializing MCP client — {} server(s) configured",
|
||||
@@ -3130,6 +3157,7 @@ pub async fn run(
|
||||
let activated = std::sync::Arc::new(std::sync::Mutex::new(
|
||||
crate::tools::ActivatedToolSet::new(),
|
||||
));
|
||||
activated_handle = Some(std::sync::Arc::clone(&activated));
|
||||
tools_registry.push(Box::new(crate::tools::ToolSearchTool::new(
|
||||
deferred_set,
|
||||
activated,
|
||||
@@ -3442,6 +3470,7 @@ pub async fn run(
|
||||
None,
|
||||
&excluded_tools,
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
activated_handle.as_ref(),
|
||||
)
|
||||
.await?;
|
||||
final_output = response.clone();
|
||||
@@ -3603,6 +3632,7 @@ pub async fn run(
|
||||
None,
|
||||
&excluded_tools,
|
||||
&config.agent.tool_call_dedup_exempt,
|
||||
activated_handle.as_ref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -3982,7 +4012,8 @@ mod tests {
|
||||
.expect("should produce a sample whose byte index 300 is not a char boundary");
|
||||
|
||||
let observer = NoopObserver;
|
||||
let result = execute_one_tool("unknown_tool", call_arguments, &[], &observer, None).await;
|
||||
let result =
|
||||
execute_one_tool("unknown_tool", call_arguments, &[], None, &observer, None).await;
|
||||
assert!(result.is_ok(), "execute_one_tool should not panic or error");
|
||||
|
||||
let outcome = result.unwrap();
|
||||
@@ -4319,6 +4350,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect_err("provider without vision support should fail");
|
||||
@@ -4366,6 +4398,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect_err("oversized payload must fail");
|
||||
@@ -4407,6 +4440,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("valid multimodal payload should pass");
|
||||
@@ -4534,6 +4568,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("parallel execution should complete");
|
||||
@@ -4604,6 +4639,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("loop should finish after deduplicating repeated calls");
|
||||
@@ -4666,6 +4702,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&exempt,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("loop should finish with exempt tool executing twice");
|
||||
@@ -4743,6 +4780,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&exempt,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("loop should complete");
|
||||
@@ -4797,6 +4835,7 @@ mod tests {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("native fallback id flow should complete");
|
||||
@@ -6698,6 +6737,7 @@ Let me check the result."#;
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("tool loop should complete");
|
||||
|
||||
@@ -329,6 +329,7 @@ struct ChannelRuntimeContext {
|
||||
/// `[autonomy]` config; auto-denies tools that would need interactive
|
||||
/// approval since no operator is present on channel runs.
|
||||
approval_manager: Arc<ApprovalManager>,
|
||||
activated_tools: Option<std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -2140,6 +2141,7 @@ async fn process_channel_message(
|
||||
ctx.non_cli_excluded_tools.as_ref()
|
||||
},
|
||||
ctx.tool_call_dedup_exempt.as_ref(),
|
||||
ctx.activated_tools.as_ref(),
|
||||
),
|
||||
) => LlmExecutionResult::Completed(result),
|
||||
};
|
||||
@@ -3236,6 +3238,7 @@ fn collect_configured_channels(
|
||||
Vec::new(),
|
||||
sl.allowed_users.clone(),
|
||||
)
|
||||
.with_group_reply_policy(sl.mention_only, Vec::new())
|
||||
.with_workspace_dir(config.workspace_dir.clone()),
|
||||
),
|
||||
});
|
||||
@@ -3699,6 +3702,9 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
// When `deferred_loading` is enabled, MCP tools are NOT added eagerly.
|
||||
// Instead, a `tool_search` built-in is registered for on-demand loading.
|
||||
let mut deferred_section = String::new();
|
||||
let mut ch_activated_handle: Option<
|
||||
std::sync::Arc<std::sync::Mutex<crate::tools::ActivatedToolSet>>,
|
||||
> = None;
|
||||
if config.mcp.enabled && !config.mcp.servers.is_empty() {
|
||||
tracing::info!(
|
||||
"Initializing MCP client — {} server(s) configured",
|
||||
@@ -3722,6 +3728,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
let activated = std::sync::Arc::new(std::sync::Mutex::new(
|
||||
crate::tools::ActivatedToolSet::new(),
|
||||
));
|
||||
ch_activated_handle = Some(std::sync::Arc::clone(&activated));
|
||||
built_tools.push(Box::new(crate::tools::ToolSearchTool::new(
|
||||
deferred_set,
|
||||
activated,
|
||||
@@ -4016,6 +4023,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
None
|
||||
},
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(&config.autonomy)),
|
||||
activated_tools: ch_activated_handle,
|
||||
});
|
||||
|
||||
// Hydrate in-memory conversation histories from persisted JSONL session files.
|
||||
@@ -4308,6 +4316,7 @@ mod tests {
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
};
|
||||
|
||||
assert!(compact_sender_history(&ctx, &sender));
|
||||
@@ -4416,6 +4425,7 @@ mod tests {
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
};
|
||||
|
||||
append_sender_turn(&ctx, &sender, ChatMessage::user("hello"));
|
||||
@@ -4480,6 +4490,7 @@ mod tests {
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(&ctx, &sender, "pending"));
|
||||
@@ -4563,6 +4574,7 @@ mod tests {
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
};
|
||||
|
||||
assert!(rollback_orphan_user_turn(
|
||||
@@ -5096,6 +5108,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5168,6 +5181,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5254,6 +5268,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5325,6 +5340,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5406,6 +5422,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5507,6 +5524,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5590,6 +5608,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5688,6 +5707,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5771,6 +5791,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -5844,6 +5865,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6028,6 +6050,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||||
@@ -6120,6 +6143,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -6226,6 +6250,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
});
|
||||
|
||||
@@ -6331,6 +6356,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(8);
|
||||
@@ -6417,6 +6443,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -6488,6 +6515,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7117,6 +7145,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7214,6 +7243,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7311,6 +7341,7 @@ BTC is currently around $65,000 based on latest tool output."#
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -7872,6 +7903,7 @@ This is an example JSON object for profile settings."#;
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
// Simulate a photo attachment message with [IMAGE:] marker.
|
||||
@@ -7950,6 +7982,7 @@ This is an example JSON object for profile settings."#;
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8102,6 +8135,7 @@ This is an example JSON object for profile settings."#;
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8204,6 +8238,7 @@ This is an example JSON object for profile settings."#;
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8298,6 +8333,7 @@ This is an example JSON object for profile settings."#;
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
@@ -8412,6 +8448,7 @@ This is an example JSON object for profile settings."#;
|
||||
approval_manager: Arc::new(ApprovalManager::for_non_interactive(
|
||||
&crate::config::AutonomyConfig::default(),
|
||||
)),
|
||||
activated_tools: None,
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
|
||||
+211
-32
@@ -62,24 +62,146 @@ impl NextcloudTalkChannel {
|
||||
|
||||
/// Parse a Nextcloud Talk webhook payload into channel messages.
|
||||
///
|
||||
/// Relevant payload fields:
|
||||
/// - `type` (accepts `message` or `Create`)
|
||||
/// - `object.token` (room token for reply routing)
|
||||
/// - `message.actorType`, `message.actorId`, `message.message`, `message.timestamp`
|
||||
/// Two payload formats are supported:
|
||||
///
|
||||
/// **Format A — legacy/custom** (`type: "message"`):
|
||||
/// ```json
|
||||
/// {
|
||||
/// "type": "message",
|
||||
/// "object": { "token": "<room>" },
|
||||
/// "message": { "actorId": "...", "message": "...", ... }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// **Format B — Activity Streams 2.0** (`type: "Create"`):
|
||||
/// This is the format actually sent by Nextcloud Talk bot webhooks.
|
||||
/// ```json
|
||||
/// {
|
||||
/// "type": "Create",
|
||||
/// "actor": { "type": "Person", "id": "users/alice", "name": "Alice" },
|
||||
/// "object": { "type": "Note", "id": "177", "content": "{\"message\":\"hi\",\"parameters\":[]}", "mediaType": "text/markdown" },
|
||||
/// "target": { "type": "Collection", "id": "<room_token>", "name": "Room Name" }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn parse_webhook_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
|
||||
let messages = Vec::new();
|
||||
|
||||
let event_type = match payload.get("type").and_then(|v| v.as_str()) {
|
||||
Some(t) => t,
|
||||
None => return messages,
|
||||
};
|
||||
|
||||
// Activity Streams 2.0 format sent by Nextcloud Talk bot webhooks.
|
||||
if event_type.eq_ignore_ascii_case("create") {
|
||||
return self.parse_as2_payload(payload);
|
||||
}
|
||||
|
||||
// Legacy/custom format.
|
||||
if !event_type.eq_ignore_ascii_case("message") {
|
||||
tracing::debug!("Nextcloud Talk: skipping non-message event: {event_type}");
|
||||
return messages;
|
||||
}
|
||||
|
||||
self.parse_message_payload(payload)
|
||||
}
|
||||
|
||||
/// Parse Activity Streams 2.0 `Create` payload (real Nextcloud Talk bot webhook format).
|
||||
fn parse_as2_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(event_type) = payload.get("type").and_then(|v| v.as_str()) {
|
||||
// Nextcloud Talk bot webhooks send "Create" for new chat messages,
|
||||
// but some setups may use "message". Accept both.
|
||||
let is_message_event = event_type.eq_ignore_ascii_case("message")
|
||||
|| event_type.eq_ignore_ascii_case("create");
|
||||
if !is_message_event {
|
||||
tracing::debug!("Nextcloud Talk: skipping non-message event: {event_type}");
|
||||
return messages;
|
||||
}
|
||||
let obj = match payload.get("object") {
|
||||
Some(o) => o,
|
||||
None => return messages,
|
||||
};
|
||||
|
||||
// Only handle Note objects (= chat messages). Ignore reactions, etc.
|
||||
let object_type = obj.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if !object_type.eq_ignore_ascii_case("note") {
|
||||
tracing::debug!("Nextcloud Talk: skipping AS2 Create with object.type={object_type}");
|
||||
return messages;
|
||||
}
|
||||
|
||||
// Room token is in target.id.
|
||||
let room_token = payload
|
||||
.get("target")
|
||||
.and_then(|t| t.get("id"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.filter(|t| !t.is_empty());
|
||||
|
||||
let Some(room_token) = room_token else {
|
||||
tracing::warn!("Nextcloud Talk: missing target.id (room token) in AS2 payload");
|
||||
return messages;
|
||||
};
|
||||
|
||||
// Actor — skip bot-originated messages to prevent feedback loops.
|
||||
let actor = payload.get("actor").cloned().unwrap_or_default();
|
||||
let actor_type = actor.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if actor_type.eq_ignore_ascii_case("application") {
|
||||
tracing::debug!("Nextcloud Talk: skipping bot-originated AS2 message");
|
||||
return messages;
|
||||
}
|
||||
|
||||
// actor.id is "users/<id>" — strip the prefix.
|
||||
let actor_id = actor
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|id| id.trim_start_matches("users/").trim())
|
||||
.filter(|id| !id.is_empty());
|
||||
|
||||
let Some(actor_id) = actor_id else {
|
||||
tracing::warn!("Nextcloud Talk: missing actor.id in AS2 payload");
|
||||
return messages;
|
||||
};
|
||||
|
||||
if !self.is_user_allowed(actor_id) {
|
||||
tracing::warn!(
|
||||
"Nextcloud Talk: ignoring message from unauthorized actor: {actor_id}. \
|
||||
Add to channels.nextcloud_talk.allowed_users in config.toml, \
|
||||
or run `zeroclaw onboard --channels-only` to configure interactively."
|
||||
);
|
||||
return messages;
|
||||
}
|
||||
|
||||
// Message text is JSON-encoded inside object.content.
|
||||
// e.g. content = "{\"message\":\"hello\",\"parameters\":[]}"
|
||||
let content = obj
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.and_then(|s| serde_json::from_str::<serde_json::Value>(s).ok())
|
||||
.and_then(|v| {
|
||||
v.get("message")
|
||||
.and_then(|m| m.as_str())
|
||||
.map(str::trim)
|
||||
.map(str::to_string)
|
||||
})
|
||||
.filter(|s| !s.is_empty());
|
||||
|
||||
let Some(content) = content else {
|
||||
tracing::debug!("Nextcloud Talk: empty or unparseable AS2 message content");
|
||||
return messages;
|
||||
};
|
||||
|
||||
let message_id =
|
||||
Self::value_to_string(obj.get("id")).unwrap_or_else(|| Uuid::new_v4().to_string());
|
||||
|
||||
messages.push(ChannelMessage {
|
||||
id: message_id,
|
||||
reply_target: room_token.to_string(),
|
||||
sender: actor_id.to_string(),
|
||||
content,
|
||||
channel: "nextcloud_talk".to_string(),
|
||||
timestamp: Self::now_unix_secs(),
|
||||
thread_ts: None,
|
||||
});
|
||||
|
||||
messages
|
||||
}
|
||||
|
||||
/// Parse legacy `type: "message"` payload format.
|
||||
fn parse_message_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
let Some(message_obj) = payload.get("message") else {
|
||||
return messages;
|
||||
};
|
||||
@@ -343,33 +465,90 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nextcloud_talk_parse_create_event_type() {
|
||||
let channel = make_channel();
|
||||
fn nextcloud_talk_parse_as2_create_payload() {
|
||||
let channel = NextcloudTalkChannel::new(
|
||||
"https://cloud.example.com".into(),
|
||||
"app-token".into(),
|
||||
vec!["*".into()],
|
||||
);
|
||||
// Real payload format sent by Nextcloud Talk bot webhooks.
|
||||
let payload = serde_json::json!({
|
||||
"type": "Create",
|
||||
"object": {
|
||||
"id": "42",
|
||||
"token": "room-token-123",
|
||||
"name": "Team Room",
|
||||
"type": "room"
|
||||
"actor": {
|
||||
"type": "Person",
|
||||
"id": "users/user_a",
|
||||
"name": "User A",
|
||||
"talkParticipantType": "1"
|
||||
},
|
||||
"message": {
|
||||
"id": 88,
|
||||
"token": "room-token-123",
|
||||
"actorType": "users",
|
||||
"actorId": "user_a",
|
||||
"actorDisplayName": "User A",
|
||||
"timestamp": 1_735_701_300,
|
||||
"messageType": "comment",
|
||||
"systemMessage": "",
|
||||
"message": "Hello via Create event"
|
||||
"object": {
|
||||
"type": "Note",
|
||||
"id": "177",
|
||||
"name": "message",
|
||||
"content": "{\"message\":\"hallo, bist du da?\",\"parameters\":[]}",
|
||||
"mediaType": "text/markdown"
|
||||
},
|
||||
"target": {
|
||||
"type": "Collection",
|
||||
"id": "room-token-123",
|
||||
"name": "HOME"
|
||||
}
|
||||
});
|
||||
|
||||
let messages = channel.parse_webhook_payload(&payload);
|
||||
assert_eq!(messages.len(), 1);
|
||||
assert_eq!(messages[0].id, "88");
|
||||
assert_eq!(messages[0].content, "Hello via Create event");
|
||||
assert_eq!(messages[0].reply_target, "room-token-123");
|
||||
assert_eq!(messages[0].sender, "user_a");
|
||||
assert_eq!(messages[0].content, "hallo, bist du da?");
|
||||
assert_eq!(messages[0].channel, "nextcloud_talk");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nextcloud_talk_parse_as2_skips_bot_originated() {
|
||||
let channel = NextcloudTalkChannel::new(
|
||||
"https://cloud.example.com".into(),
|
||||
"app-token".into(),
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"type": "Create",
|
||||
"actor": {
|
||||
"type": "Application",
|
||||
"id": "bots/jarvis",
|
||||
"name": "jarvis"
|
||||
},
|
||||
"object": {
|
||||
"type": "Note",
|
||||
"id": "178",
|
||||
"content": "{\"message\":\"I am the bot\",\"parameters\":[]}",
|
||||
"mediaType": "text/markdown"
|
||||
},
|
||||
"target": {
|
||||
"type": "Collection",
|
||||
"id": "room-token-123",
|
||||
"name": "HOME"
|
||||
}
|
||||
});
|
||||
|
||||
let messages = channel.parse_webhook_payload(&payload);
|
||||
assert!(messages.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nextcloud_talk_parse_as2_skips_non_note_objects() {
|
||||
let channel = NextcloudTalkChannel::new(
|
||||
"https://cloud.example.com".into(),
|
||||
"app-token".into(),
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
"type": "Create",
|
||||
"actor": { "type": "Person", "id": "users/user_a" },
|
||||
"object": { "type": "Reaction", "id": "5" },
|
||||
"target": { "type": "Collection", "id": "room-token-123" }
|
||||
});
|
||||
|
||||
let messages = channel.parse_webhook_payload(&payload);
|
||||
assert!(messages.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -671,4 +671,35 @@ allowed_users = ["user1"]
|
||||
|
||||
assert_eq!(compose_message_content(&payload), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_send_body_uses_markdown_msg_type() {
|
||||
// Verify the expected JSON shape for both group and user send paths.
|
||||
// msg_type 2 with a nested markdown object is required by the QQ API
|
||||
// for markdown rendering; msg_type 0 (plain text) causes markdown
|
||||
// syntax to appear literally in the client.
|
||||
let content = "**bold** and `code`";
|
||||
|
||||
let group_body = json!({
|
||||
"markdown": { "content": content },
|
||||
"msg_type": 2,
|
||||
});
|
||||
assert_eq!(group_body["msg_type"], 2);
|
||||
assert_eq!(group_body["markdown"]["content"], content);
|
||||
assert!(
|
||||
group_body.get("content").is_none(),
|
||||
"top-level 'content' must not be present"
|
||||
);
|
||||
|
||||
let user_body = json!({
|
||||
"markdown": { "content": content },
|
||||
"msg_type": 2,
|
||||
});
|
||||
assert_eq!(user_body["msg_type"], 2);
|
||||
assert_eq!(user_body["markdown"]["content"], content);
|
||||
assert!(
|
||||
user_body.get("content").is_none(),
|
||||
"top-level 'content' must not be present"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2685,7 +2685,7 @@ Ensure only one `zeroclaw` process is using this bot token."
|
||||
} else if let Some(m) = self.try_parse_attachment_message(update).await {
|
||||
m
|
||||
} else {
|
||||
self.handle_unauthorized_message(update).await;
|
||||
Box::pin(self.handle_unauthorized_message(update)).await;
|
||||
continue;
|
||||
};
|
||||
|
||||
|
||||
+757
-34
@@ -1,11 +1,19 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use reqwest::multipart::{Form, Part};
|
||||
|
||||
use crate::config::TranscriptionConfig;
|
||||
|
||||
/// Maximum upload size accepted by the Groq Whisper API (25 MB).
|
||||
/// Maximum upload size accepted by most Whisper-compatible APIs (25 MB).
|
||||
const MAX_AUDIO_BYTES: usize = 25 * 1024 * 1024;
|
||||
|
||||
/// Request timeout for transcription API calls (seconds).
|
||||
const TRANSCRIPTION_TIMEOUT_SECS: u64 = 120;
|
||||
|
||||
// ── Audio utilities ─────────────────────────────────────────────
|
||||
|
||||
/// Map file extension to MIME type for Whisper-compatible transcription APIs.
|
||||
fn mime_for_audio(extension: &str) -> Option<&'static str> {
|
||||
match extension.to_ascii_lowercase().as_str() {
|
||||
@@ -31,16 +39,10 @@ fn normalize_audio_filename(file_name: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
/// Transcribe audio bytes via a Whisper-compatible transcription API.
|
||||
/// Validate audio data and resolve MIME type from file name.
|
||||
///
|
||||
/// Returns the transcribed text on success. Requires `GROQ_API_KEY` in the
|
||||
/// environment. The caller is responsible for enforcing duration limits
|
||||
/// *before* downloading the file; this function enforces the byte-size cap.
|
||||
pub async fn transcribe_audio(
|
||||
audio_data: Vec<u8>,
|
||||
file_name: &str,
|
||||
config: &TranscriptionConfig,
|
||||
) -> Result<String> {
|
||||
/// Returns `(normalized_filename, mime_type)` on success.
|
||||
fn validate_audio(audio_data: &[u8], file_name: &str) -> Result<(String, &'static str)> {
|
||||
if audio_data.len() > MAX_AUDIO_BYTES {
|
||||
bail!(
|
||||
"Audio file too large ({} bytes, max {MAX_AUDIO_BYTES})",
|
||||
@@ -59,37 +61,494 @@ pub async fn transcribe_audio(
|
||||
)
|
||||
})?;
|
||||
|
||||
let api_key = std::env::var("GROQ_API_KEY").context(
|
||||
"GROQ_API_KEY environment variable is not set — required for voice transcription",
|
||||
)?;
|
||||
Ok((normalized_name, mime))
|
||||
}
|
||||
|
||||
let client = crate::config::build_runtime_proxy_client("transcription.groq");
|
||||
// ── TranscriptionProvider trait ─────────────────────────────────
|
||||
|
||||
let file_part = Part::bytes(audio_data)
|
||||
.file_name(normalized_name)
|
||||
.mime_str(mime)?;
|
||||
/// Trait for speech-to-text provider implementations.
|
||||
#[async_trait]
|
||||
pub trait TranscriptionProvider: Send + Sync {
|
||||
/// Human-readable provider name (e.g. "groq", "openai").
|
||||
fn name(&self) -> &str;
|
||||
|
||||
let mut form = Form::new()
|
||||
.part("file", file_part)
|
||||
.text("model", config.model.clone())
|
||||
.text("response_format", "json");
|
||||
/// Transcribe raw audio bytes. `file_name` includes the extension for
|
||||
/// format detection (e.g. "voice.ogg").
|
||||
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String>;
|
||||
|
||||
if let Some(ref lang) = config.language {
|
||||
form = form.text("language", lang.clone());
|
||||
/// List of supported audio file extensions.
|
||||
fn supported_formats(&self) -> Vec<String> {
|
||||
vec![
|
||||
"flac", "mp3", "mpeg", "mpga", "mp4", "m4a", "ogg", "oga", "opus", "wav", "webm",
|
||||
]
|
||||
.into_iter()
|
||||
.map(String::from)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ── GroqProvider ────────────────────────────────────────────────
|
||||
|
||||
/// Groq Whisper API provider (default, backward-compatible with existing config).
|
||||
pub struct GroqProvider {
|
||||
api_url: String,
|
||||
model: String,
|
||||
api_key: String,
|
||||
language: Option<String>,
|
||||
}
|
||||
|
||||
impl GroqProvider {
|
||||
/// Build from the existing `TranscriptionConfig` fields.
|
||||
///
|
||||
/// Credential resolution order:
|
||||
/// 1. `config.api_key`
|
||||
/// 2. `GROQ_API_KEY` environment variable (backward compatibility)
|
||||
pub fn from_config(config: &TranscriptionConfig) -> Result<Self> {
|
||||
let api_key = config
|
||||
.api_key
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.or_else(|| {
|
||||
std::env::var("GROQ_API_KEY")
|
||||
.ok()
|
||||
.map(|v| v.trim().to_string())
|
||||
.filter(|v| !v.is_empty())
|
||||
})
|
||||
.context(
|
||||
"Missing transcription API key: set [transcription].api_key or GROQ_API_KEY environment variable",
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
api_url: config.api_url.clone(),
|
||||
model: config.model.clone(),
|
||||
api_key,
|
||||
language: config.language.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TranscriptionProvider for GroqProvider {
|
||||
fn name(&self) -> &str {
|
||||
"groq"
|
||||
}
|
||||
|
||||
if let Some(ref prompt) = config.initial_prompt {
|
||||
form = form.text("prompt", prompt.clone());
|
||||
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
|
||||
let (normalized_name, mime) = validate_audio(audio_data, file_name)?;
|
||||
|
||||
let client = crate::config::build_runtime_proxy_client("transcription.groq");
|
||||
|
||||
let file_part = Part::bytes(audio_data.to_vec())
|
||||
.file_name(normalized_name)
|
||||
.mime_str(mime)?;
|
||||
|
||||
let mut form = Form::new()
|
||||
.part("file", file_part)
|
||||
.text("model", self.model.clone())
|
||||
.text("response_format", "json");
|
||||
|
||||
if let Some(ref lang) = self.language {
|
||||
form = form.text("language", lang.clone());
|
||||
}
|
||||
|
||||
let resp = client
|
||||
.post(&self.api_url)
|
||||
.bearer_auth(&self.api_key)
|
||||
.multipart(form)
|
||||
.timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send transcription request to Groq")?;
|
||||
|
||||
parse_whisper_response(resp).await
|
||||
}
|
||||
}
|
||||
|
||||
// ── OpenAiWhisperProvider ───────────────────────────────────────
|
||||
|
||||
/// OpenAI Whisper API provider.
|
||||
pub struct OpenAiWhisperProvider {
|
||||
api_key: String,
|
||||
model: String,
|
||||
}
|
||||
|
||||
impl OpenAiWhisperProvider {
|
||||
pub fn from_config(config: &crate::config::OpenAiSttConfig) -> Result<Self> {
|
||||
let api_key = config
|
||||
.api_key
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.context("Missing OpenAI STT API key: set [transcription.openai].api_key")?;
|
||||
|
||||
Ok(Self {
|
||||
api_key,
|
||||
model: config.model.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TranscriptionProvider for OpenAiWhisperProvider {
|
||||
fn name(&self) -> &str {
|
||||
"openai"
|
||||
}
|
||||
|
||||
let resp = client
|
||||
.post(&config.api_url)
|
||||
.bearer_auth(&api_key)
|
||||
.multipart(form)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send transcription request")?;
|
||||
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
|
||||
let (normalized_name, mime) = validate_audio(audio_data, file_name)?;
|
||||
|
||||
let client = crate::config::build_runtime_proxy_client("transcription.openai");
|
||||
|
||||
let file_part = Part::bytes(audio_data.to_vec())
|
||||
.file_name(normalized_name)
|
||||
.mime_str(mime)?;
|
||||
|
||||
let form = Form::new()
|
||||
.part("file", file_part)
|
||||
.text("model", self.model.clone())
|
||||
.text("response_format", "json");
|
||||
|
||||
let resp = client
|
||||
.post("https://api.openai.com/v1/audio/transcriptions")
|
||||
.bearer_auth(&self.api_key)
|
||||
.multipart(form)
|
||||
.timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send transcription request to OpenAI")?;
|
||||
|
||||
parse_whisper_response(resp).await
|
||||
}
|
||||
}
|
||||
|
||||
// ── DeepgramProvider ────────────────────────────────────────────
|
||||
|
||||
/// Deepgram STT API provider.
|
||||
pub struct DeepgramProvider {
|
||||
api_key: String,
|
||||
model: String,
|
||||
}
|
||||
|
||||
impl DeepgramProvider {
|
||||
pub fn from_config(config: &crate::config::DeepgramSttConfig) -> Result<Self> {
|
||||
let api_key = config
|
||||
.api_key
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.context("Missing Deepgram API key: set [transcription.deepgram].api_key")?;
|
||||
|
||||
Ok(Self {
|
||||
api_key,
|
||||
model: config.model.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TranscriptionProvider for DeepgramProvider {
|
||||
fn name(&self) -> &str {
|
||||
"deepgram"
|
||||
}
|
||||
|
||||
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
|
||||
let (_, mime) = validate_audio(audio_data, file_name)?;
|
||||
|
||||
let client = crate::config::build_runtime_proxy_client("transcription.deepgram");
|
||||
|
||||
let url = format!(
|
||||
"https://api.deepgram.com/v1/listen?model={}&punctuate=true",
|
||||
self.model
|
||||
);
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Token {}", self.api_key))
|
||||
.header("Content-Type", mime)
|
||||
.body(audio_data.to_vec())
|
||||
.timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send transcription request to Deepgram")?;
|
||||
|
||||
let status = resp.status();
|
||||
let body: serde_json::Value = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Deepgram response")?;
|
||||
|
||||
if !status.is_success() {
|
||||
let error_msg = body["err_msg"]
|
||||
.as_str()
|
||||
.or_else(|| body["error"].as_str())
|
||||
.unwrap_or("unknown error");
|
||||
bail!("Deepgram API error ({}): {}", status, error_msg);
|
||||
}
|
||||
|
||||
let text = body["results"]["channels"][0]["alternatives"][0]["transcript"]
|
||||
.as_str()
|
||||
.context("Deepgram response missing transcript field")?
|
||||
.to_string();
|
||||
|
||||
Ok(text)
|
||||
}
|
||||
}
|
||||
|
||||
// ── AssemblyAiProvider ──────────────────────────────────────────
|
||||
|
||||
/// AssemblyAI STT API provider.
|
||||
pub struct AssemblyAiProvider {
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl AssemblyAiProvider {
|
||||
pub fn from_config(config: &crate::config::AssemblyAiSttConfig) -> Result<Self> {
|
||||
let api_key = config
|
||||
.api_key
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.context("Missing AssemblyAI API key: set [transcription.assemblyai].api_key")?;
|
||||
|
||||
Ok(Self { api_key })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TranscriptionProvider for AssemblyAiProvider {
|
||||
fn name(&self) -> &str {
|
||||
"assemblyai"
|
||||
}
|
||||
|
||||
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
|
||||
let (_, _) = validate_audio(audio_data, file_name)?;
|
||||
|
||||
let client = crate::config::build_runtime_proxy_client("transcription.assemblyai");
|
||||
|
||||
// Step 1: Upload the audio file.
|
||||
let upload_resp = client
|
||||
.post("https://api.assemblyai.com/v2/upload")
|
||||
.header("Authorization", &self.api_key)
|
||||
.header("Content-Type", "application/octet-stream")
|
||||
.body(audio_data.to_vec())
|
||||
.timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to upload audio to AssemblyAI")?;
|
||||
|
||||
let upload_status = upload_resp.status();
|
||||
let upload_body: serde_json::Value = upload_resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse AssemblyAI upload response")?;
|
||||
|
||||
if !upload_status.is_success() {
|
||||
let error_msg = upload_body["error"].as_str().unwrap_or("unknown error");
|
||||
bail!("AssemblyAI upload error ({}): {}", upload_status, error_msg);
|
||||
}
|
||||
|
||||
let upload_url = upload_body["upload_url"]
|
||||
.as_str()
|
||||
.context("AssemblyAI upload response missing 'upload_url'")?;
|
||||
|
||||
// Step 2: Create transcription job.
|
||||
let transcript_req = serde_json::json!({
|
||||
"audio_url": upload_url,
|
||||
});
|
||||
|
||||
let create_resp = client
|
||||
.post("https://api.assemblyai.com/v2/transcript")
|
||||
.header("Authorization", &self.api_key)
|
||||
.json(&transcript_req)
|
||||
.timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to create AssemblyAI transcription")?;
|
||||
|
||||
let create_status = create_resp.status();
|
||||
let create_body: serde_json::Value = create_resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse AssemblyAI create response")?;
|
||||
|
||||
if !create_status.is_success() {
|
||||
let error_msg = create_body["error"].as_str().unwrap_or("unknown error");
|
||||
bail!(
|
||||
"AssemblyAI transcription error ({}): {}",
|
||||
create_status,
|
||||
error_msg
|
||||
);
|
||||
}
|
||||
|
||||
let transcript_id = create_body["id"]
|
||||
.as_str()
|
||||
.context("AssemblyAI response missing 'id'")?;
|
||||
|
||||
// Step 3: Poll for completion.
|
||||
let poll_url = format!("https://api.assemblyai.com/v2/transcript/{transcript_id}");
|
||||
let poll_interval = std::time::Duration::from_secs(3);
|
||||
let poll_deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(180);
|
||||
|
||||
while tokio::time::Instant::now() < poll_deadline {
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
|
||||
let poll_resp = client
|
||||
.get(&poll_url)
|
||||
.header("Authorization", &self.api_key)
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to poll AssemblyAI transcription")?;
|
||||
|
||||
let poll_status = poll_resp.status();
|
||||
let poll_body: serde_json::Value = poll_resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse AssemblyAI poll response")?;
|
||||
|
||||
if !poll_status.is_success() {
|
||||
let error_msg = poll_body["error"].as_str().unwrap_or("unknown poll error");
|
||||
bail!("AssemblyAI poll error ({}): {}", poll_status, error_msg);
|
||||
}
|
||||
|
||||
let status_str = poll_body["status"].as_str().unwrap_or("unknown");
|
||||
|
||||
match status_str {
|
||||
"completed" => {
|
||||
let text = poll_body["text"]
|
||||
.as_str()
|
||||
.context("AssemblyAI response missing 'text'")?
|
||||
.to_string();
|
||||
return Ok(text);
|
||||
}
|
||||
"error" => {
|
||||
let error_msg = poll_body["error"]
|
||||
.as_str()
|
||||
.unwrap_or("unknown transcription error");
|
||||
bail!("AssemblyAI transcription failed: {}", error_msg);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
bail!("AssemblyAI transcription timed out after 180s")
|
||||
}
|
||||
}
|
||||
|
||||
// ── GoogleSttProvider ───────────────────────────────────────────
|
||||
|
||||
/// Google Cloud Speech-to-Text API provider.
|
||||
pub struct GoogleSttProvider {
|
||||
api_key: String,
|
||||
language_code: String,
|
||||
}
|
||||
|
||||
impl GoogleSttProvider {
|
||||
pub fn from_config(config: &crate::config::GoogleSttConfig) -> Result<Self> {
|
||||
let api_key = config
|
||||
.api_key
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.context("Missing Google STT API key: set [transcription.google].api_key")?;
|
||||
|
||||
Ok(Self {
|
||||
api_key,
|
||||
language_code: config.language_code.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TranscriptionProvider for GoogleSttProvider {
|
||||
fn name(&self) -> &str {
|
||||
"google"
|
||||
}
|
||||
|
||||
fn supported_formats(&self) -> Vec<String> {
|
||||
// Google Cloud STT supports a subset of formats.
|
||||
vec!["flac", "wav", "ogg", "opus", "mp3", "webm"]
|
||||
.into_iter()
|
||||
.map(String::from)
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
|
||||
let (normalized_name, _) = validate_audio(audio_data, file_name)?;
|
||||
|
||||
let client = crate::config::build_runtime_proxy_client("transcription.google");
|
||||
|
||||
let encoding = match normalized_name
|
||||
.rsplit_once('.')
|
||||
.map(|(_, e)| e.to_ascii_lowercase())
|
||||
.as_deref()
|
||||
{
|
||||
Some("flac") => "FLAC",
|
||||
Some("wav") => "LINEAR16",
|
||||
Some("ogg" | "opus") => "OGG_OPUS",
|
||||
Some("mp3") => "MP3",
|
||||
Some("webm") => "WEBM_OPUS",
|
||||
Some(ext) => bail!("Google STT does not support '.{ext}' input"),
|
||||
None => bail!("Google STT requires a file extension"),
|
||||
};
|
||||
|
||||
let audio_content =
|
||||
base64::Engine::encode(&base64::engine::general_purpose::STANDARD, audio_data);
|
||||
|
||||
let request_body = serde_json::json!({
|
||||
"config": {
|
||||
"encoding": encoding,
|
||||
"languageCode": &self.language_code,
|
||||
"enableAutomaticPunctuation": true,
|
||||
},
|
||||
"audio": {
|
||||
"content": audio_content,
|
||||
}
|
||||
});
|
||||
|
||||
let url = format!(
|
||||
"https://speech.googleapis.com/v1/speech:recognize?key={}",
|
||||
self.api_key
|
||||
);
|
||||
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.json(&request_body)
|
||||
.timeout(std::time::Duration::from_secs(TRANSCRIPTION_TIMEOUT_SECS))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send transcription request to Google STT")?;
|
||||
|
||||
let status = resp.status();
|
||||
let body: serde_json::Value = resp
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Google STT response")?;
|
||||
|
||||
if !status.is_success() {
|
||||
let error_msg = body["error"]["message"].as_str().unwrap_or("unknown error");
|
||||
bail!("Google STT API error ({}): {}", status, error_msg);
|
||||
}
|
||||
|
||||
let text = body["results"][0]["alternatives"][0]["transcript"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
Ok(text)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Shared response parsing ─────────────────────────────────────
|
||||
|
||||
/// Parse a standard Whisper-compatible JSON response (`{ "text": "..." }`).
|
||||
async fn parse_whisper_response(resp: reqwest::Response) -> Result<String> {
|
||||
let status = resp.status();
|
||||
let body: serde_json::Value = resp
|
||||
.json()
|
||||
@@ -109,6 +568,128 @@ pub async fn transcribe_audio(
|
||||
Ok(text)
|
||||
}
|
||||
|
||||
// ── TranscriptionManager ────────────────────────────────────────
|
||||
|
||||
/// Manages multiple STT providers and routes transcription requests.
|
||||
pub struct TranscriptionManager {
|
||||
providers: HashMap<String, Box<dyn TranscriptionProvider>>,
|
||||
default_provider: String,
|
||||
}
|
||||
|
||||
impl TranscriptionManager {
|
||||
/// Build a `TranscriptionManager` from config.
|
||||
///
|
||||
/// Always attempts to register the Groq provider from existing config fields.
|
||||
/// Additional providers are registered when their config sections are present.
|
||||
///
|
||||
/// Provider keys with missing API keys are silently skipped — the error
|
||||
/// surfaces at transcribe-time so callers that target a different default
|
||||
/// provider are not blocked.
|
||||
pub fn new(config: &TranscriptionConfig) -> Result<Self> {
|
||||
let mut providers: HashMap<String, Box<dyn TranscriptionProvider>> = HashMap::new();
|
||||
|
||||
if let Ok(groq) = GroqProvider::from_config(config) {
|
||||
providers.insert("groq".to_string(), Box::new(groq));
|
||||
}
|
||||
|
||||
if let Some(ref openai_cfg) = config.openai {
|
||||
if let Ok(p) = OpenAiWhisperProvider::from_config(openai_cfg) {
|
||||
providers.insert("openai".to_string(), Box::new(p));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref deepgram_cfg) = config.deepgram {
|
||||
if let Ok(p) = DeepgramProvider::from_config(deepgram_cfg) {
|
||||
providers.insert("deepgram".to_string(), Box::new(p));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref assemblyai_cfg) = config.assemblyai {
|
||||
if let Ok(p) = AssemblyAiProvider::from_config(assemblyai_cfg) {
|
||||
providers.insert("assemblyai".to_string(), Box::new(p));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref google_cfg) = config.google {
|
||||
if let Ok(p) = GoogleSttProvider::from_config(google_cfg) {
|
||||
providers.insert("google".to_string(), Box::new(p));
|
||||
}
|
||||
}
|
||||
|
||||
let default_provider = config.default_provider.clone();
|
||||
|
||||
if config.enabled && !providers.contains_key(&default_provider) {
|
||||
let available: Vec<&str> = providers.keys().map(|k| k.as_str()).collect();
|
||||
bail!(
|
||||
"Default transcription provider '{}' is not configured. Available: {available:?}",
|
||||
default_provider
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
providers,
|
||||
default_provider,
|
||||
})
|
||||
}
|
||||
|
||||
/// Transcribe audio using the default provider.
|
||||
pub async fn transcribe(&self, audio_data: &[u8], file_name: &str) -> Result<String> {
|
||||
self.transcribe_with_provider(audio_data, file_name, &self.default_provider)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Transcribe audio using a specific named provider.
|
||||
pub async fn transcribe_with_provider(
|
||||
&self,
|
||||
audio_data: &[u8],
|
||||
file_name: &str,
|
||||
provider: &str,
|
||||
) -> Result<String> {
|
||||
let p = self.providers.get(provider).ok_or_else(|| {
|
||||
let available: Vec<&str> = self.providers.keys().map(|k| k.as_str()).collect();
|
||||
anyhow::anyhow!(
|
||||
"Transcription provider '{provider}' not configured. Available: {available:?}"
|
||||
)
|
||||
})?;
|
||||
|
||||
p.transcribe(audio_data, file_name).await
|
||||
}
|
||||
|
||||
/// List registered provider names.
|
||||
pub fn available_providers(&self) -> Vec<&str> {
|
||||
self.providers.keys().map(|k| k.as_str()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Backward-compatible convenience function ────────────────────
|
||||
|
||||
/// Transcribe audio bytes via a Whisper-compatible transcription API.
|
||||
///
|
||||
/// Returns the transcribed text on success.
|
||||
///
|
||||
/// This is the backward-compatible entry point that preserves the original
|
||||
/// function signature. It uses the Groq provider directly, matching the
|
||||
/// original single-provider behavior.
|
||||
///
|
||||
/// Credential resolution order:
|
||||
/// 1. `config.transcription.api_key`
|
||||
/// 2. `GROQ_API_KEY` environment variable (backward compatibility)
|
||||
///
|
||||
/// The caller is responsible for enforcing duration limits *before* downloading
|
||||
/// the file; this function enforces the byte-size cap.
|
||||
pub async fn transcribe_audio(
|
||||
audio_data: Vec<u8>,
|
||||
file_name: &str,
|
||||
config: &TranscriptionConfig,
|
||||
) -> Result<String> {
|
||||
// Validate audio before resolving credentials so that size/format errors
|
||||
// are reported before missing-key errors (preserves original behavior).
|
||||
validate_audio(&audio_data, file_name)?;
|
||||
|
||||
let groq = GroqProvider::from_config(config)?;
|
||||
groq.transcribe(&audio_data, file_name).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -129,7 +710,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn rejects_missing_api_key() {
|
||||
// Ensure the key is absent for this test
|
||||
// Ensure fallback env key is absent for this test.
|
||||
std::env::remove_var("GROQ_API_KEY");
|
||||
|
||||
let data = vec![0u8; 100];
|
||||
@@ -139,11 +720,29 @@ mod tests {
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("GROQ_API_KEY"),
|
||||
err.to_string().contains("transcription API key"),
|
||||
"expected missing-key error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn uses_config_api_key_without_groq_env() {
|
||||
std::env::remove_var("GROQ_API_KEY");
|
||||
|
||||
let data = vec![0u8; 100];
|
||||
let mut config = TranscriptionConfig::default();
|
||||
config.api_key = Some("transcription-key".to_string());
|
||||
|
||||
// Keep invalid extension so we fail before network, but after key resolution.
|
||||
let err = transcribe_audio(data, "recording.aac", &config)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("Unsupported audio format"),
|
||||
"expected unsupported-format error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mime_for_audio_maps_accepted_formats() {
|
||||
let cases = [
|
||||
@@ -219,4 +818,128 @@ mod tests {
|
||||
"error should mention the rejected extension, got: {msg}"
|
||||
);
|
||||
}
|
||||
|
||||
// ── TranscriptionManager tests ──────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn manager_creation_with_default_config() {
|
||||
std::env::remove_var("GROQ_API_KEY");
|
||||
|
||||
let config = TranscriptionConfig::default();
|
||||
let manager = TranscriptionManager::new(&config).unwrap();
|
||||
assert_eq!(manager.default_provider, "groq");
|
||||
// Groq won't be registered without a key.
|
||||
assert!(manager.providers.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_registers_groq_with_key() {
|
||||
std::env::remove_var("GROQ_API_KEY");
|
||||
|
||||
let mut config = TranscriptionConfig::default();
|
||||
config.api_key = Some("test-groq-key".to_string());
|
||||
|
||||
let manager = TranscriptionManager::new(&config).unwrap();
|
||||
assert!(manager.providers.contains_key("groq"));
|
||||
assert_eq!(manager.providers["groq"].name(), "groq");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_registers_multiple_providers() {
|
||||
std::env::remove_var("GROQ_API_KEY");
|
||||
|
||||
let mut config = TranscriptionConfig::default();
|
||||
config.api_key = Some("test-groq-key".to_string());
|
||||
config.openai = Some(crate::config::OpenAiSttConfig {
|
||||
api_key: Some("test-openai-key".to_string()),
|
||||
model: "whisper-1".to_string(),
|
||||
});
|
||||
config.deepgram = Some(crate::config::DeepgramSttConfig {
|
||||
api_key: Some("test-deepgram-key".to_string()),
|
||||
model: "nova-2".to_string(),
|
||||
});
|
||||
|
||||
let manager = TranscriptionManager::new(&config).unwrap();
|
||||
assert!(manager.providers.contains_key("groq"));
|
||||
assert!(manager.providers.contains_key("openai"));
|
||||
assert!(manager.providers.contains_key("deepgram"));
|
||||
assert_eq!(manager.available_providers().len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn manager_rejects_unconfigured_provider() {
|
||||
std::env::remove_var("GROQ_API_KEY");
|
||||
|
||||
let mut config = TranscriptionConfig::default();
|
||||
config.api_key = Some("test-groq-key".to_string());
|
||||
|
||||
let manager = TranscriptionManager::new(&config).unwrap();
|
||||
let err = manager
|
||||
.transcribe_with_provider(&[0u8; 100], "test.ogg", "nonexistent")
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("not configured"),
|
||||
"expected not-configured error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_default_provider_from_config() {
|
||||
std::env::remove_var("GROQ_API_KEY");
|
||||
|
||||
let mut config = TranscriptionConfig::default();
|
||||
config.default_provider = "openai".to_string();
|
||||
config.openai = Some(crate::config::OpenAiSttConfig {
|
||||
api_key: Some("test-openai-key".to_string()),
|
||||
model: "whisper-1".to_string(),
|
||||
});
|
||||
|
||||
let manager = TranscriptionManager::new(&config).unwrap();
|
||||
assert_eq!(manager.default_provider, "openai");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_audio_rejects_oversized() {
|
||||
let big = vec![0u8; MAX_AUDIO_BYTES + 1];
|
||||
let err = validate_audio(&big, "test.ogg").unwrap_err();
|
||||
assert!(err.to_string().contains("too large"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_audio_rejects_unsupported_format() {
|
||||
let data = vec![0u8; 100];
|
||||
let err = validate_audio(&data, "test.aac").unwrap_err();
|
||||
assert!(err.to_string().contains("Unsupported audio format"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_audio_accepts_supported_format() {
|
||||
let data = vec![0u8; 100];
|
||||
let (name, mime) = validate_audio(&data, "test.ogg").unwrap();
|
||||
assert_eq!(name, "test.ogg");
|
||||
assert_eq!(mime, "audio/ogg");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_audio_normalizes_oga() {
|
||||
let data = vec![0u8; 100];
|
||||
let (name, mime) = validate_audio(&data, "voice.oga").unwrap();
|
||||
assert_eq!(name, "voice.ogg");
|
||||
assert_eq!(mime, "audio/ogg");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn backward_compat_config_defaults_unchanged() {
|
||||
let config = TranscriptionConfig::default();
|
||||
assert!(!config.enabled);
|
||||
assert!(config.api_key.is_none());
|
||||
assert!(config.api_url.contains("groq.com"));
|
||||
assert_eq!(config.model, "whisper-large-v3-turbo");
|
||||
assert_eq!(config.default_provider, "groq");
|
||||
assert!(config.openai.is_none());
|
||||
assert!(config.deepgram.is_none());
|
||||
assert!(config.assemblyai.is_none());
|
||||
assert!(config.google.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
+10
-9
@@ -6,19 +6,20 @@ pub mod workspace;
|
||||
pub use schema::{
|
||||
apply_runtime_proxy_to_builder, build_runtime_proxy_client,
|
||||
build_runtime_proxy_client_with_timeouts, runtime_proxy_config, set_runtime_proxy_config,
|
||||
AgentConfig, AuditConfig, AutonomyConfig, BackupConfig, BrowserComputerUseConfig,
|
||||
BrowserConfig, BuiltinHooksConfig, ChannelsConfig, ClassificationRule, CloudOpsConfig,
|
||||
ComposioConfig, Config, ConversationalAiConfig, CostConfig, CronConfig, DataRetentionConfig,
|
||||
DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig,
|
||||
EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig, GoogleTtsConfig,
|
||||
AgentConfig, AssemblyAiSttConfig, AuditConfig, AutonomyConfig, BackupConfig,
|
||||
BrowserComputerUseConfig, BrowserConfig, BuiltinHooksConfig, ChannelsConfig,
|
||||
ClassificationRule, CloudOpsConfig, ComposioConfig, Config, ConversationalAiConfig, CostConfig,
|
||||
CronConfig, DataRetentionConfig, DeepgramSttConfig, DelegateAgentConfig, DiscordConfig,
|
||||
DockerRuntimeConfig, EdgeTtsConfig, ElevenLabsTtsConfig, EmbeddingRouteConfig, EstopConfig,
|
||||
FeishuConfig, GatewayConfig, GoogleSttConfig, GoogleTtsConfig, GoogleWorkspaceConfig,
|
||||
HardwareConfig, HardwareTransport, HeartbeatConfig, HooksConfig, HttpRequestConfig,
|
||||
IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, McpConfig, McpServerConfig,
|
||||
McpTransport, MemoryConfig, Microsoft365Config, ModelRouteConfig, MultimodalConfig,
|
||||
NextcloudTalkConfig, NodeTransportConfig, NodesConfig, NotionConfig, ObservabilityConfig,
|
||||
OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod, PeripheralBoardConfig,
|
||||
PeripheralsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
OpenAiSttConfig, OpenAiTtsConfig, OpenVpnTunnelConfig, OtpConfig, OtpMethod,
|
||||
PeripheralBoardConfig, PeripheralsConfig, ProjectIntelConfig, ProxyConfig, ProxyScope,
|
||||
QdrantConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig,
|
||||
RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SecurityOpsConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig,
|
||||
StorageProviderConfig, StorageProviderSection, StreamMode, SwarmConfig, SwarmStrategy,
|
||||
TelegramConfig, ToolFilterGroup, ToolFilterGroupMode, TranscriptionConfig, TtsConfig,
|
||||
|
||||
+306
-4
@@ -264,6 +264,10 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub project_intel: ProjectIntelConfig,
|
||||
|
||||
/// Google Workspace CLI (`gws`) tool configuration (`[google_workspace]`).
|
||||
#[serde(default)]
|
||||
pub google_workspace: GoogleWorkspaceConfig,
|
||||
|
||||
/// Proxy configuration for outbound HTTP/HTTPS/SOCKS5 traffic (`[proxy]`).
|
||||
#[serde(default)]
|
||||
pub proxy: ProxyConfig,
|
||||
@@ -604,19 +608,46 @@ fn default_transcription_max_duration_secs() -> u64 {
|
||||
120
|
||||
}
|
||||
|
||||
/// Voice transcription configuration (Whisper API via Groq).
|
||||
fn default_transcription_provider() -> String {
|
||||
"groq".into()
|
||||
}
|
||||
|
||||
fn default_openai_stt_model() -> String {
|
||||
"whisper-1".into()
|
||||
}
|
||||
|
||||
fn default_deepgram_stt_model() -> String {
|
||||
"nova-2".into()
|
||||
}
|
||||
|
||||
fn default_google_stt_language_code() -> String {
|
||||
"en-US".into()
|
||||
}
|
||||
|
||||
/// Voice transcription configuration with multi-provider support.
|
||||
///
|
||||
/// The top-level `api_url`, `model`, and `api_key` fields remain for backward
|
||||
/// compatibility with existing Groq-based configurations.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct TranscriptionConfig {
|
||||
/// Enable voice transcription for channels that support it.
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// Whisper API endpoint URL.
|
||||
/// Default STT provider: "groq", "openai", "deepgram", "assemblyai", "google".
|
||||
#[serde(default = "default_transcription_provider")]
|
||||
pub default_provider: String,
|
||||
/// API key used for transcription requests (Groq provider).
|
||||
///
|
||||
/// If unset, runtime falls back to `GROQ_API_KEY` for backward compatibility.
|
||||
#[serde(default)]
|
||||
pub api_key: Option<String>,
|
||||
/// Whisper API endpoint URL (Groq provider).
|
||||
#[serde(default = "default_transcription_api_url")]
|
||||
pub api_url: String,
|
||||
/// Whisper model name.
|
||||
/// Whisper model name (Groq provider).
|
||||
#[serde(default = "default_transcription_model")]
|
||||
pub model: String,
|
||||
/// Optional language hint (ISO-639-1, e.g. "en", "ru").
|
||||
/// Optional language hint (ISO-639-1, e.g. "en", "ru") for Groq provider.
|
||||
#[serde(default)]
|
||||
pub language: Option<String>,
|
||||
/// Optional initial prompt to bias transcription toward expected vocabulary
|
||||
@@ -627,17 +658,35 @@ pub struct TranscriptionConfig {
|
||||
/// Maximum voice duration in seconds (messages longer than this are skipped).
|
||||
#[serde(default = "default_transcription_max_duration_secs")]
|
||||
pub max_duration_secs: u64,
|
||||
/// OpenAI Whisper STT provider configuration.
|
||||
#[serde(default)]
|
||||
pub openai: Option<OpenAiSttConfig>,
|
||||
/// Deepgram STT provider configuration.
|
||||
#[serde(default)]
|
||||
pub deepgram: Option<DeepgramSttConfig>,
|
||||
/// AssemblyAI STT provider configuration.
|
||||
#[serde(default)]
|
||||
pub assemblyai: Option<AssemblyAiSttConfig>,
|
||||
/// Google Cloud Speech-to-Text provider configuration.
|
||||
#[serde(default)]
|
||||
pub google: Option<GoogleSttConfig>,
|
||||
}
|
||||
|
||||
impl Default for TranscriptionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
default_provider: default_transcription_provider(),
|
||||
api_key: None,
|
||||
api_url: default_transcription_api_url(),
|
||||
model: default_transcription_model(),
|
||||
language: None,
|
||||
initial_prompt: None,
|
||||
max_duration_secs: default_transcription_max_duration_secs(),
|
||||
openai: None,
|
||||
deepgram: None,
|
||||
assemblyai: None,
|
||||
google: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -938,6 +987,47 @@ pub struct ToolFilterGroup {
|
||||
pub keywords: Vec<String>,
|
||||
}
|
||||
|
||||
/// OpenAI Whisper STT provider configuration (`[transcription.openai]`).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct OpenAiSttConfig {
|
||||
/// OpenAI API key for Whisper transcription.
|
||||
#[serde(default)]
|
||||
pub api_key: Option<String>,
|
||||
/// Whisper model name (default: "whisper-1").
|
||||
#[serde(default = "default_openai_stt_model")]
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
/// Deepgram STT provider configuration (`[transcription.deepgram]`).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct DeepgramSttConfig {
|
||||
/// Deepgram API key.
|
||||
#[serde(default)]
|
||||
pub api_key: Option<String>,
|
||||
/// Deepgram model name (default: "nova-2").
|
||||
#[serde(default = "default_deepgram_stt_model")]
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
/// AssemblyAI STT provider configuration (`[transcription.assemblyai]`).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AssemblyAiSttConfig {
|
||||
/// AssemblyAI API key.
|
||||
#[serde(default)]
|
||||
pub api_key: Option<String>,
|
||||
}
|
||||
|
||||
/// Google Cloud Speech-to-Text provider configuration (`[transcription.google]`).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct GoogleSttConfig {
|
||||
/// Google Cloud API key.
|
||||
#[serde(default)]
|
||||
pub api_key: Option<String>,
|
||||
/// BCP-47 language code (default: "en-US").
|
||||
#[serde(default = "default_google_stt_language_code")]
|
||||
pub language_code: String,
|
||||
}
|
||||
|
||||
/// Agent orchestration configuration (`[agent]` section).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AgentConfig {
|
||||
@@ -1994,6 +2084,94 @@ impl Default for DataRetentionConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Google Workspace ─────────────────────────────────────────────
|
||||
|
||||
/// Google Workspace CLI (`gws`) tool configuration (`[google_workspace]` section).
|
||||
///
|
||||
/// ## Defaults
|
||||
/// - `enabled`: `false` (tool is not registered unless explicitly opted-in).
|
||||
/// - `allowed_services`: empty vector, which grants access to the full default
|
||||
/// service set: `drive`, `sheets`, `gmail`, `calendar`, `docs`, `slides`,
|
||||
/// `tasks`, `people`, `chat`, `classroom`, `forms`, `keep`, `meet`, `events`.
|
||||
/// - `credentials_path`: `None` (uses default `gws` credential discovery).
|
||||
/// - `default_account`: `None` (uses the `gws` active account).
|
||||
/// - `rate_limit_per_minute`: `60`.
|
||||
/// - `timeout_secs`: `30`.
|
||||
/// - `audit_log`: `false`.
|
||||
/// - `credentials_path`: `None` (uses default `gws` credential discovery).
|
||||
/// - `default_account`: `None` (uses the `gws` active account).
|
||||
/// - `rate_limit_per_minute`: `60`.
|
||||
/// - `timeout_secs`: `30`.
|
||||
/// - `audit_log`: `false`.
|
||||
///
|
||||
/// ## Compatibility
|
||||
/// Configs that omit the `[google_workspace]` section entirely are treated as
|
||||
/// `GoogleWorkspaceConfig::default()` (disabled, all defaults allowed). Adding
|
||||
/// the section is purely opt-in and does not affect other config sections.
|
||||
///
|
||||
/// ## Rollback / Migration
|
||||
/// To revert, remove the `[google_workspace]` section from the config file (or
|
||||
/// set `enabled = false`). No data migration is required; the tool simply stops
|
||||
/// being registered.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct GoogleWorkspaceConfig {
|
||||
/// Enable the `google_workspace` tool. Default: `false`.
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// Restrict which Google Workspace services the agent can access.
|
||||
///
|
||||
/// When empty (the default), the full default service set is allowed (see
|
||||
/// struct-level docs). When non-empty, only the listed service IDs are
|
||||
/// permitted. Each entry must be non-empty, lowercase alphanumeric with
|
||||
/// optional underscores/hyphens, and unique.
|
||||
#[serde(default)]
|
||||
pub allowed_services: Vec<String>,
|
||||
/// Path to service account JSON or OAuth client credentials file.
|
||||
///
|
||||
/// When `None`, the tool relies on the default `gws` credential discovery
|
||||
/// (`gws auth login`). Set this to point at a service-account key or an
|
||||
/// OAuth client-secrets JSON for headless / CI environments.
|
||||
#[serde(default)]
|
||||
pub credentials_path: Option<String>,
|
||||
/// Default Google account email to pass to `gws --account`.
|
||||
///
|
||||
/// When `None`, the currently active `gws` account is used.
|
||||
#[serde(default)]
|
||||
pub default_account: Option<String>,
|
||||
/// Maximum number of `gws` API calls allowed per minute. Default: `60`.
|
||||
#[serde(default = "default_gws_rate_limit")]
|
||||
pub rate_limit_per_minute: u32,
|
||||
/// Command execution timeout in seconds. Default: `30`.
|
||||
#[serde(default = "default_gws_timeout_secs")]
|
||||
pub timeout_secs: u64,
|
||||
/// Enable audit logging of every `gws` invocation (service, resource,
|
||||
/// method, timestamp). Default: `false`.
|
||||
#[serde(default)]
|
||||
pub audit_log: bool,
|
||||
}
|
||||
|
||||
fn default_gws_rate_limit() -> u32 {
|
||||
60
|
||||
}
|
||||
|
||||
fn default_gws_timeout_secs() -> u64 {
|
||||
30
|
||||
}
|
||||
|
||||
impl Default for GoogleWorkspaceConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
allowed_services: Vec::new(),
|
||||
credentials_path: None,
|
||||
default_account: None,
|
||||
rate_limit_per_minute: default_gws_rate_limit(),
|
||||
timeout_secs: default_gws_timeout_secs(),
|
||||
audit_log: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Proxy ───────────────────────────────────────────────────────
|
||||
|
||||
/// Proxy application scope — determines which outbound traffic uses the proxy.
|
||||
@@ -3920,6 +4098,10 @@ pub struct SlackConfig {
|
||||
/// cancels the in-flight request and starts a fresh response with preserved history.
|
||||
#[serde(default)]
|
||||
pub interrupt_on_new_message: bool,
|
||||
/// When true, only respond to messages that @-mention the bot in groups.
|
||||
/// Direct messages remain allowed.
|
||||
#[serde(default)]
|
||||
pub mention_only: bool,
|
||||
}
|
||||
|
||||
impl ChannelConfig for SlackConfig {
|
||||
@@ -5260,6 +5442,7 @@ impl Default for Config {
|
||||
web_fetch: WebFetchConfig::default(),
|
||||
web_search: WebSearchConfig::default(),
|
||||
project_intel: ProjectIntelConfig::default(),
|
||||
google_workspace: GoogleWorkspaceConfig::default(),
|
||||
proxy: ProxyConfig::default(),
|
||||
identity: IdentityConfig::default(),
|
||||
cost: CostConfig::default(),
|
||||
@@ -5794,6 +5977,41 @@ impl Config {
|
||||
decrypt_optional_secret(&store, &mut google.api_key, "config.tts.google.api_key")?;
|
||||
}
|
||||
|
||||
// Decrypt nested STT provider API keys
|
||||
decrypt_optional_secret(
|
||||
&store,
|
||||
&mut config.transcription.api_key,
|
||||
"config.transcription.api_key",
|
||||
)?;
|
||||
if let Some(ref mut openai) = config.transcription.openai {
|
||||
decrypt_optional_secret(
|
||||
&store,
|
||||
&mut openai.api_key,
|
||||
"config.transcription.openai.api_key",
|
||||
)?;
|
||||
}
|
||||
if let Some(ref mut deepgram) = config.transcription.deepgram {
|
||||
decrypt_optional_secret(
|
||||
&store,
|
||||
&mut deepgram.api_key,
|
||||
"config.transcription.deepgram.api_key",
|
||||
)?;
|
||||
}
|
||||
if let Some(ref mut assemblyai) = config.transcription.assemblyai {
|
||||
decrypt_optional_secret(
|
||||
&store,
|
||||
&mut assemblyai.api_key,
|
||||
"config.transcription.assemblyai.api_key",
|
||||
)?;
|
||||
}
|
||||
if let Some(ref mut google) = config.transcription.google {
|
||||
decrypt_optional_secret(
|
||||
&store,
|
||||
&mut google.api_key,
|
||||
"config.transcription.google.api_key",
|
||||
)?;
|
||||
}
|
||||
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
if let Some(ref mut ns) = config.channels_config.nostr {
|
||||
decrypt_secret(
|
||||
@@ -6417,6 +6635,28 @@ impl Config {
|
||||
validate_mcp_config(&self.mcp)?;
|
||||
}
|
||||
|
||||
// Google Workspace allowed_services validation
|
||||
let mut seen_gws_services = std::collections::HashSet::new();
|
||||
for (i, service) in self.google_workspace.allowed_services.iter().enumerate() {
|
||||
let normalized = service.trim();
|
||||
if normalized.is_empty() {
|
||||
anyhow::bail!("google_workspace.allowed_services[{i}] must not be empty");
|
||||
}
|
||||
if !normalized
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_' || c == '-')
|
||||
{
|
||||
anyhow::bail!(
|
||||
"google_workspace.allowed_services[{i}] contains invalid characters: {normalized}"
|
||||
);
|
||||
}
|
||||
if !seen_gws_services.insert(normalized.to_string()) {
|
||||
anyhow::bail!(
|
||||
"google_workspace.allowed_services contains duplicate entry: {normalized}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Project intelligence
|
||||
if self.project_intel.enabled {
|
||||
let lang = &self.project_intel.default_language;
|
||||
@@ -6470,6 +6710,19 @@ impl Config {
|
||||
anyhow::bail!("security.nevis: {msg}");
|
||||
}
|
||||
|
||||
// Transcription
|
||||
{
|
||||
let dp = self.transcription.default_provider.trim();
|
||||
match dp {
|
||||
"groq" | "openai" | "deepgram" | "assemblyai" | "google" => {}
|
||||
other => {
|
||||
anyhow::bail!(
|
||||
"transcription.default_provider must be one of: groq, openai, deepgram, assemblyai, google (got '{other}')"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -6879,6 +7132,41 @@ impl Config {
|
||||
encrypt_optional_secret(&store, &mut google.api_key, "config.tts.google.api_key")?;
|
||||
}
|
||||
|
||||
// Encrypt nested STT provider API keys
|
||||
encrypt_optional_secret(
|
||||
&store,
|
||||
&mut config_to_save.transcription.api_key,
|
||||
"config.transcription.api_key",
|
||||
)?;
|
||||
if let Some(ref mut openai) = config_to_save.transcription.openai {
|
||||
encrypt_optional_secret(
|
||||
&store,
|
||||
&mut openai.api_key,
|
||||
"config.transcription.openai.api_key",
|
||||
)?;
|
||||
}
|
||||
if let Some(ref mut deepgram) = config_to_save.transcription.deepgram {
|
||||
encrypt_optional_secret(
|
||||
&store,
|
||||
&mut deepgram.api_key,
|
||||
"config.transcription.deepgram.api_key",
|
||||
)?;
|
||||
}
|
||||
if let Some(ref mut assemblyai) = config_to_save.transcription.assemblyai {
|
||||
encrypt_optional_secret(
|
||||
&store,
|
||||
&mut assemblyai.api_key,
|
||||
"config.transcription.assemblyai.api_key",
|
||||
)?;
|
||||
}
|
||||
if let Some(ref mut google) = config_to_save.transcription.google {
|
||||
encrypt_optional_secret(
|
||||
&store,
|
||||
&mut google.api_key,
|
||||
"config.transcription.google.api_key",
|
||||
)?;
|
||||
}
|
||||
|
||||
#[cfg(feature = "channel-nostr")]
|
||||
if let Some(ref mut ns) = config_to_save.channels_config.nostr {
|
||||
encrypt_secret(
|
||||
@@ -7564,6 +7852,7 @@ default_temperature = 0.7
|
||||
web_fetch: WebFetchConfig::default(),
|
||||
web_search: WebSearchConfig::default(),
|
||||
project_intel: ProjectIntelConfig::default(),
|
||||
google_workspace: GoogleWorkspaceConfig::default(),
|
||||
proxy: ProxyConfig::default(),
|
||||
agent: AgentConfig::default(),
|
||||
identity: IdentityConfig::default(),
|
||||
@@ -7867,6 +8156,7 @@ tool_dispatcher = "xml"
|
||||
web_fetch: WebFetchConfig::default(),
|
||||
web_search: WebSearchConfig::default(),
|
||||
project_intel: ProjectIntelConfig::default(),
|
||||
google_workspace: GoogleWorkspaceConfig::default(),
|
||||
proxy: ProxyConfig::default(),
|
||||
agent: AgentConfig::default(),
|
||||
identity: IdentityConfig::default(),
|
||||
@@ -8326,6 +8616,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(parsed.allowed_users.is_empty());
|
||||
assert!(!parsed.interrupt_on_new_message);
|
||||
assert!(!parsed.mention_only);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -8334,6 +8625,15 @@ allowed_users = ["@ops:matrix.org"]
|
||||
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(parsed.allowed_users, vec!["U111"]);
|
||||
assert!(!parsed.interrupt_on_new_message);
|
||||
assert!(!parsed.mention_only);
|
||||
}
|
||||
|
||||
#[test]
|
||||
async fn slack_config_deserializes_with_mention_only() {
|
||||
let json = r#"{"bot_token":"xoxb-tok","mention_only":true}"#;
|
||||
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(parsed.mention_only);
|
||||
assert!(!parsed.interrupt_on_new_message);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -8341,6 +8641,7 @@ allowed_users = ["@ops:matrix.org"]
|
||||
let json = r#"{"bot_token":"xoxb-tok","interrupt_on_new_message":true}"#;
|
||||
let parsed: SlackConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(parsed.interrupt_on_new_message);
|
||||
assert!(!parsed.mention_only);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -8363,6 +8664,7 @@ channel_id = "C123"
|
||||
let parsed: SlackConfig = toml::from_str(toml_str).unwrap();
|
||||
assert!(parsed.allowed_users.is_empty());
|
||||
assert!(!parsed.interrupt_on_new_message);
|
||||
assert!(!parsed.mention_only);
|
||||
assert_eq!(parsed.channel_id.as_deref(), Some("C123"));
|
||||
}
|
||||
|
||||
|
||||
@@ -509,6 +509,18 @@ pub fn all_integrations() -> Vec<IntegrationEntry> {
|
||||
},
|
||||
},
|
||||
// ── Productivity ────────────────────────────────────────
|
||||
IntegrationEntry {
|
||||
name: "Google Workspace",
|
||||
description: "Drive, Gmail, Calendar, Sheets, Docs via gws CLI",
|
||||
category: IntegrationCategory::Productivity,
|
||||
status_fn: |c| {
|
||||
if c.google_workspace.enabled {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "GitHub",
|
||||
description: "Code, issues, PRs",
|
||||
@@ -606,7 +618,13 @@ pub fn all_integrations() -> Vec<IntegrationEntry> {
|
||||
name: "Browser",
|
||||
description: "Chrome/Chromium control",
|
||||
category: IntegrationCategory::ToolsAutomation,
|
||||
status_fn: |_| IntegrationStatus::Available,
|
||||
status_fn: |c| {
|
||||
if c.browser.enabled {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Shell",
|
||||
@@ -624,7 +642,13 @@ pub fn all_integrations() -> Vec<IntegrationEntry> {
|
||||
name: "Cron",
|
||||
description: "Scheduled tasks",
|
||||
category: IntegrationCategory::ToolsAutomation,
|
||||
status_fn: |_| IntegrationStatus::Available,
|
||||
status_fn: |c| {
|
||||
if c.cron.enabled {
|
||||
IntegrationStatus::Active
|
||||
} else {
|
||||
IntegrationStatus::Available
|
||||
}
|
||||
},
|
||||
},
|
||||
IntegrationEntry {
|
||||
name: "Voice",
|
||||
@@ -917,6 +941,54 @@ mod tests {
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_active_when_enabled() {
|
||||
let mut config = Config::default();
|
||||
config.cron.enabled = true;
|
||||
let entries = all_integrations();
|
||||
let cron = entries.iter().find(|e| e.name == "Cron").unwrap();
|
||||
assert!(matches!(
|
||||
(cron.status_fn)(&config),
|
||||
IntegrationStatus::Active
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_available_when_disabled() {
|
||||
let mut config = Config::default();
|
||||
config.cron.enabled = false;
|
||||
let entries = all_integrations();
|
||||
let cron = entries.iter().find(|e| e.name == "Cron").unwrap();
|
||||
assert!(matches!(
|
||||
(cron.status_fn)(&config),
|
||||
IntegrationStatus::Available
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn browser_active_when_enabled() {
|
||||
let mut config = Config::default();
|
||||
config.browser.enabled = true;
|
||||
let entries = all_integrations();
|
||||
let browser = entries.iter().find(|e| e.name == "Browser").unwrap();
|
||||
assert!(matches!(
|
||||
(browser.status_fn)(&config),
|
||||
IntegrationStatus::Active
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn browser_available_when_disabled() {
|
||||
let mut config = Config::default();
|
||||
config.browser.enabled = false;
|
||||
let entries = all_integrations();
|
||||
let browser = entries.iter().find(|e| e.name == "Browser").unwrap();
|
||||
assert!(matches!(
|
||||
(browser.status_fn)(&config),
|
||||
IntegrationStatus::Available
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shell_and_filesystem_always_active() {
|
||||
let config = Config::default();
|
||||
|
||||
+8
-4
@@ -806,13 +806,13 @@ async fn main() -> Result<()> {
|
||||
} else if is_tty && !has_provider_flags {
|
||||
Box::pin(onboard::run_wizard(force)).await
|
||||
} else {
|
||||
onboard::run_quick_setup(
|
||||
Box::pin(onboard::run_quick_setup(
|
||||
api_key.as_deref(),
|
||||
provider.as_deref(),
|
||||
model.as_deref(),
|
||||
memory.as_deref(),
|
||||
force,
|
||||
)
|
||||
))
|
||||
.await
|
||||
}?;
|
||||
|
||||
@@ -1191,7 +1191,7 @@ async fn main() -> Result<()> {
|
||||
Commands::Channel { channel_command } => match channel_command {
|
||||
ChannelCommands::Start => Box::pin(channels::start_channels(config)).await,
|
||||
ChannelCommands::Doctor => Box::pin(channels::doctor_channels(config)).await,
|
||||
other => channels::handle_command(other, &config).await,
|
||||
other => Box::pin(channels::handle_command(other, &config)).await,
|
||||
},
|
||||
|
||||
Commands::Integrations {
|
||||
@@ -1215,7 +1215,11 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
Commands::Peripheral { peripheral_command } => {
|
||||
peripherals::handle_command(peripheral_command.clone(), &config).await
|
||||
Box::pin(peripherals::handle_command(
|
||||
peripheral_command.clone(),
|
||||
&config,
|
||||
))
|
||||
.await
|
||||
}
|
||||
|
||||
Commands::Config { config_command } => match config_command {
|
||||
|
||||
+15
-12
@@ -173,6 +173,7 @@ pub async fn run_wizard(force: bool) -> Result<Config> {
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
google_workspace: crate::config::GoogleWorkspaceConfig::default(),
|
||||
proxy: crate::config::ProxyConfig::default(),
|
||||
identity: crate::config::IdentityConfig::default(),
|
||||
cost: crate::config::CostConfig::default(),
|
||||
@@ -424,14 +425,14 @@ pub async fn run_quick_setup(
|
||||
.map(|u| u.home_dir().to_path_buf())
|
||||
.context("Could not find home directory")?;
|
||||
|
||||
run_quick_setup_with_home(
|
||||
Box::pin(run_quick_setup_with_home(
|
||||
credential_override,
|
||||
provider,
|
||||
model_override,
|
||||
memory_backend,
|
||||
force,
|
||||
&home,
|
||||
)
|
||||
))
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -543,6 +544,7 @@ async fn run_quick_setup_with_home(
|
||||
web_fetch: crate::config::WebFetchConfig::default(),
|
||||
web_search: crate::config::WebSearchConfig::default(),
|
||||
project_intel: crate::config::ProjectIntelConfig::default(),
|
||||
google_workspace: crate::config::GoogleWorkspaceConfig::default(),
|
||||
proxy: crate::config::ProxyConfig::default(),
|
||||
identity: crate::config::IdentityConfig::default(),
|
||||
cost: crate::config::CostConfig::default(),
|
||||
@@ -3902,6 +3904,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||
},
|
||||
allowed_users,
|
||||
interrupt_on_new_message: false,
|
||||
mention_only: false,
|
||||
});
|
||||
}
|
||||
ChannelMenuChoice::IMessage => {
|
||||
@@ -5911,14 +5914,14 @@ mod tests {
|
||||
let _config_env = EnvVarGuard::unset("ZEROCLAW_CONFIG_DIR");
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
let config = run_quick_setup_with_home(
|
||||
let config = Box::pin(run_quick_setup_with_home(
|
||||
Some("sk-issue946"),
|
||||
Some("openrouter"),
|
||||
Some("custom-model-946"),
|
||||
Some("sqlite"),
|
||||
false,
|
||||
tmp.path(),
|
||||
)
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -5938,14 +5941,14 @@ mod tests {
|
||||
let _config_env = EnvVarGuard::unset("ZEROCLAW_CONFIG_DIR");
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
let config = run_quick_setup_with_home(
|
||||
let config = Box::pin(run_quick_setup_with_home(
|
||||
Some("sk-issue946"),
|
||||
Some("anthropic"),
|
||||
None,
|
||||
Some("sqlite"),
|
||||
false,
|
||||
tmp.path(),
|
||||
)
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -5968,14 +5971,14 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let err = run_quick_setup_with_home(
|
||||
let err = Box::pin(run_quick_setup_with_home(
|
||||
Some("sk-existing"),
|
||||
Some("openrouter"),
|
||||
Some("custom-model"),
|
||||
Some("sqlite"),
|
||||
false,
|
||||
tmp.path(),
|
||||
)
|
||||
))
|
||||
.await
|
||||
.expect_err("quick setup should refuse overwrite without --force");
|
||||
|
||||
@@ -6001,14 +6004,14 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let config = run_quick_setup_with_home(
|
||||
let config = Box::pin(run_quick_setup_with_home(
|
||||
Some("sk-force"),
|
||||
Some("openrouter"),
|
||||
Some("custom-model-fresh"),
|
||||
Some("sqlite"),
|
||||
true,
|
||||
tmp.path(),
|
||||
)
|
||||
))
|
||||
.await
|
||||
.expect("quick setup should overwrite existing config with --force");
|
||||
|
||||
@@ -6035,14 +6038,14 @@ mod tests {
|
||||
);
|
||||
let _config_env = EnvVarGuard::unset("ZEROCLAW_CONFIG_DIR");
|
||||
|
||||
let config = run_quick_setup_with_home(
|
||||
let config = Box::pin(run_quick_setup_with_home(
|
||||
Some("sk-env"),
|
||||
Some("openrouter"),
|
||||
Some("model-env"),
|
||||
Some("sqlite"),
|
||||
false,
|
||||
tmp.path(),
|
||||
)
|
||||
))
|
||||
.await
|
||||
.expect("quick setup should honor ZEROCLAW_WORKSPACE");
|
||||
|
||||
|
||||
@@ -0,0 +1,330 @@
|
||||
//! Claude Code headless CLI provider.
|
||||
//!
|
||||
//! Integrates with the Claude Code CLI, spawning the `claude` binary
|
||||
//! as a subprocess for each inference request. This allows using Claude's AI
|
||||
//! models without an interactive UI session.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! The `claude` binary must be available in `PATH`, or its location must be
|
||||
//! set via the `CLAUDE_CODE_PATH` environment variable.
|
||||
//!
|
||||
//! Claude Code is invoked as:
|
||||
//! ```text
|
||||
//! claude --print -
|
||||
//! ```
|
||||
//! with prompt content written to stdin.
|
||||
//!
|
||||
//! # Limitations
|
||||
//!
|
||||
//! - **Conversation history**: Only the system prompt (if present) and the last
|
||||
//! user message are forwarded. Full multi-turn history is not preserved because
|
||||
//! the CLI accepts a single prompt per invocation.
|
||||
//! - **System prompt**: The system prompt is prepended to the user message with a
|
||||
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
|
||||
//! - **Temperature**: The CLI does not expose a temperature parameter.
|
||||
//! Only default values are accepted; custom values return an explicit error.
|
||||
//!
|
||||
//! # Authentication
|
||||
//!
|
||||
//! Authentication is handled by Claude Code itself (its own credential store).
|
||||
//! No explicit API key is required by this provider.
|
||||
//!
|
||||
//! # Environment variables
|
||||
//!
|
||||
//! - `CLAUDE_CODE_PATH` — override the path to the `claude` binary (default: `"claude"`)
|
||||
|
||||
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
/// Environment variable for overriding the path to the `claude` binary.
pub const CLAUDE_CODE_PATH_ENV: &str = "CLAUDE_CODE_PATH";

/// Default `claude` binary name (resolved via `PATH`).
const DEFAULT_CLAUDE_CODE_BINARY: &str = "claude";

/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// Claude Code requests are bounded to avoid hung subprocesses.
const CLAUDE_CODE_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Avoid leaking oversized stderr payloads.
const MAX_CLAUDE_CODE_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls; allow only baseline defaults.
const CLAUDE_CODE_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
/// Absolute tolerance used when matching a requested temperature against the
/// supported list (exact float equality would be brittle).
const TEMP_EPSILON: f64 = 1e-9;

/// Provider that invokes the Claude Code CLI as a subprocess.
///
/// Each inference request spawns a fresh `claude` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
pub struct ClaudeCodeProvider {
    /// Path to the `claude` binary.
    binary_path: PathBuf,
}
|
||||
|
||||
impl ClaudeCodeProvider {
|
||||
/// Create a new `ClaudeCodeProvider`.
|
||||
///
|
||||
/// The binary path is resolved from `CLAUDE_CODE_PATH` env var if set,
|
||||
/// otherwise defaults to `"claude"` (found via `PATH`).
|
||||
pub fn new() -> Self {
|
||||
let binary_path = std::env::var(CLAUDE_CODE_PATH_ENV)
|
||||
.ok()
|
||||
.filter(|path| !path.trim().is_empty())
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from(DEFAULT_CLAUDE_CODE_BINARY));
|
||||
|
||||
Self { binary_path }
|
||||
}
|
||||
|
||||
/// Returns true if the model argument should be forwarded to the CLI.
|
||||
fn should_forward_model(model: &str) -> bool {
|
||||
let trimmed = model.trim();
|
||||
!trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
|
||||
}
|
||||
|
||||
fn supports_temperature(temperature: f64) -> bool {
|
||||
CLAUDE_CODE_SUPPORTED_TEMPERATURES
|
||||
.iter()
|
||||
.any(|v| (temperature - v).abs() < TEMP_EPSILON)
|
||||
}
|
||||
|
||||
fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
|
||||
if !temperature.is_finite() {
|
||||
anyhow::bail!("Claude Code provider received non-finite temperature value");
|
||||
}
|
||||
if !Self::supports_temperature(temperature) {
|
||||
anyhow::bail!(
|
||||
"temperature unsupported by Claude Code CLI: {temperature}. \
|
||||
Supported values: 0.7 or 1.0"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn redact_stderr(stderr: &[u8]) -> String {
|
||||
let text = String::from_utf8_lossy(stderr);
|
||||
let trimmed = text.trim();
|
||||
if trimmed.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
if trimmed.chars().count() <= MAX_CLAUDE_CODE_STDERR_CHARS {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
let clipped: String = trimmed.chars().take(MAX_CLAUDE_CODE_STDERR_CHARS).collect();
|
||||
format!("{clipped}...")
|
||||
}
|
||||
|
||||
/// Invoke the claude binary with the given prompt and optional model.
|
||||
/// Returns the trimmed stdout output as the assistant response.
|
||||
async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
|
||||
let mut cmd = Command::new(&self.binary_path);
|
||||
cmd.arg("--print");
|
||||
|
||||
if Self::should_forward_model(model) {
|
||||
cmd.arg("--model").arg(model);
|
||||
}
|
||||
|
||||
// Read prompt from stdin to avoid exposing sensitive content in process args.
|
||||
cmd.arg("-");
|
||||
cmd.kill_on_drop(true);
|
||||
cmd.stdin(std::process::Stdio::piped());
|
||||
cmd.stdout(std::process::Stdio::piped());
|
||||
cmd.stderr(std::process::Stdio::piped());
|
||||
|
||||
let mut child = cmd.spawn().map_err(|err| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to spawn Claude Code binary at {}: {err}. \
|
||||
Ensure `claude` is installed and in PATH, or set CLAUDE_CODE_PATH.",
|
||||
self.binary_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(mut stdin) = child.stdin.take() {
|
||||
stdin.write_all(message.as_bytes()).await.map_err(|err| {
|
||||
anyhow::anyhow!("Failed to write prompt to Claude Code stdin: {err}")
|
||||
})?;
|
||||
stdin.shutdown().await.map_err(|err| {
|
||||
anyhow::anyhow!("Failed to finalize Claude Code stdin stream: {err}")
|
||||
})?;
|
||||
}
|
||||
|
||||
let output = timeout(CLAUDE_CODE_REQUEST_TIMEOUT, child.wait_with_output())
|
||||
.await
|
||||
.map_err(|_| {
|
||||
anyhow::anyhow!(
|
||||
"Claude Code request timed out after {:?} (binary: {})",
|
||||
CLAUDE_CODE_REQUEST_TIMEOUT,
|
||||
self.binary_path.display()
|
||||
)
|
||||
})?
|
||||
.map_err(|err| anyhow::anyhow!("Claude Code process failed: {err}"))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let code = output.status.code().unwrap_or(-1);
|
||||
let stderr_excerpt = Self::redact_stderr(&output.stderr);
|
||||
let stderr_note = if stderr_excerpt.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" Stderr: {stderr_excerpt}")
|
||||
};
|
||||
anyhow::bail!(
|
||||
"Claude Code exited with non-zero status {code}. \
|
||||
Check that Claude Code is authenticated and the CLI is supported.{stderr_note}"
|
||||
);
|
||||
}
|
||||
|
||||
let text = String::from_utf8(output.stdout)
|
||||
.map_err(|err| anyhow::anyhow!("Claude Code produced non-UTF-8 output: {err}"))?;
|
||||
|
||||
Ok(text.trim().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Equivalent to [`ClaudeCodeProvider::new`]: honors the `CLAUDE_CODE_PATH`
/// override, otherwise resolves `claude` via `PATH`.
impl Default for ClaudeCodeProvider {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for ClaudeCodeProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Self::validate_temperature(temperature)?;
|
||||
|
||||
let full_message = match system_prompt {
|
||||
Some(system) if !system.is_empty() => {
|
||||
format!("{system}\n\n{message}")
|
||||
}
|
||||
_ => message.to_string(),
|
||||
};
|
||||
|
||||
self.invoke_cli(&full_message, model).await
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let text = self
|
||||
.chat_with_history(request.messages, model, temperature)
|
||||
.await?;
|
||||
|
||||
Ok(ChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
usage: Some(TokenUsage::default()),
|
||||
reasoning_content: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serializes tests that mutate process-global environment variables so
    /// they cannot race each other.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .expect("env lock poisoned")
    }

    // NOTE(review): these tests restore the env var only after the assertion;
    // a failing assert leaves the variable mutated and poisons the lock —
    // consider a drop guard. Confirm before changing.
    #[test]
    fn new_uses_env_override() {
        let _guard = env_lock();
        // Save the original value so it can be restored afterwards.
        let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
        std::env::set_var(CLAUDE_CODE_PATH_ENV, "/usr/local/bin/claude");
        let provider = ClaudeCodeProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/claude"));
        match orig {
            Some(v) => std::env::set_var(CLAUDE_CODE_PATH_ENV, v),
            None => std::env::remove_var(CLAUDE_CODE_PATH_ENV),
        }
    }

    #[test]
    fn new_defaults_to_claude() {
        let _guard = env_lock();
        let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
        std::env::remove_var(CLAUDE_CODE_PATH_ENV);
        let provider = ClaudeCodeProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("claude"));
        // Only restore when there was an original value; otherwise the var
        // is already absent, matching the pre-test state.
        if let Some(v) = orig {
            std::env::set_var(CLAUDE_CODE_PATH_ENV, v);
        }
    }

    #[test]
    fn new_ignores_blank_env_override() {
        let _guard = env_lock();
        let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
        // Whitespace-only overrides must fall back to the default binary.
        std::env::set_var(CLAUDE_CODE_PATH_ENV, "   ");
        let provider = ClaudeCodeProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("claude"));
        match orig {
            Some(v) => std::env::set_var(CLAUDE_CODE_PATH_ENV, v),
            None => std::env::remove_var(CLAUDE_CODE_PATH_ENV),
        }
    }

    #[test]
    fn should_forward_model_standard() {
        // Real model names are forwarded verbatim.
        assert!(ClaudeCodeProvider::should_forward_model(
            "claude-sonnet-4-20250514"
        ));
        assert!(ClaudeCodeProvider::should_forward_model(
            "claude-3.5-sonnet"
        ));
    }

    #[test]
    fn should_not_forward_default_model() {
        // The "default" marker, empty, and blank names are suppressed.
        assert!(!ClaudeCodeProvider::should_forward_model(
            DEFAULT_MODEL_MARKER
        ));
        assert!(!ClaudeCodeProvider::should_forward_model(""));
        assert!(!ClaudeCodeProvider::should_forward_model("   "));
    }

    #[test]
    fn validate_temperature_allows_defaults() {
        assert!(ClaudeCodeProvider::validate_temperature(0.7).is_ok());
        assert!(ClaudeCodeProvider::validate_temperature(1.0).is_ok());
    }

    #[test]
    fn validate_temperature_rejects_custom_value() {
        let err = ClaudeCodeProvider::validate_temperature(0.2).unwrap_err();
        assert!(err
            .to_string()
            .contains("temperature unsupported by Claude Code CLI"));
    }

    #[tokio::test]
    async fn invoke_missing_binary_returns_error() {
        // A nonexistent binary path must surface the spawn error, not hang.
        let provider = ClaudeCodeProvider {
            binary_path: PathBuf::from("/nonexistent/path/to/claude"),
        };
        let result = provider.invoke_cli("hello", "default").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Failed to spawn Claude Code binary"),
            "unexpected error message: {msg}"
        );
    }
}
|
||||
@@ -0,0 +1,326 @@
|
||||
//! Gemini CLI subprocess provider.
|
||||
//!
|
||||
//! Integrates with the Gemini CLI, spawning the `gemini` binary
|
||||
//! as a subprocess for each inference request. This allows using Google's
|
||||
//! Gemini models via the CLI without an interactive UI session.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! The `gemini` binary must be available in `PATH`, or its location must be
|
||||
//! set via the `GEMINI_CLI_PATH` environment variable.
|
||||
//!
|
||||
//! Gemini CLI is invoked as:
|
||||
//! ```text
|
||||
//! gemini --print -
|
||||
//! ```
|
||||
//! with prompt content written to stdin.
|
||||
//!
|
||||
//! # Limitations
|
||||
//!
|
||||
//! - **Conversation history**: Only the system prompt (if present) and the last
|
||||
//! user message are forwarded. Full multi-turn history is not preserved because
|
||||
//! the CLI accepts a single prompt per invocation.
|
||||
//! - **System prompt**: The system prompt is prepended to the user message with a
|
||||
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
|
||||
//! - **Temperature**: The CLI does not expose a temperature parameter.
|
||||
//! Only default values are accepted; custom values return an explicit error.
|
||||
//!
|
||||
//! # Authentication
|
||||
//!
|
||||
//! Authentication is handled by the Gemini CLI itself (its own credential store).
|
||||
//! No explicit API key is required by this provider.
|
||||
//!
|
||||
//! # Environment variables
|
||||
//!
|
||||
//! - `GEMINI_CLI_PATH` — override the path to the `gemini` binary (default: `"gemini"`)
|
||||
|
||||
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
/// Environment variable for overriding the path to the `gemini` binary.
pub const GEMINI_CLI_PATH_ENV: &str = "GEMINI_CLI_PATH";

/// Default `gemini` binary name (resolved via `PATH`).
const DEFAULT_GEMINI_CLI_BINARY: &str = "gemini";

/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// Gemini CLI requests are bounded to avoid hung subprocesses.
const GEMINI_CLI_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Avoid leaking oversized stderr payloads.
const MAX_GEMINI_CLI_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls; allow only baseline defaults.
const GEMINI_CLI_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
/// Absolute tolerance used when matching a requested temperature against the
/// supported list (exact float equality would be brittle).
const TEMP_EPSILON: f64 = 1e-9;

/// Provider that invokes the Gemini CLI as a subprocess.
///
/// Each inference request spawns a fresh `gemini` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
pub struct GeminiCliProvider {
    /// Path to the `gemini` binary.
    binary_path: PathBuf,
}
|
||||
|
||||
impl GeminiCliProvider {
|
||||
/// Create a new `GeminiCliProvider`.
|
||||
///
|
||||
/// The binary path is resolved from `GEMINI_CLI_PATH` env var if set,
|
||||
/// otherwise defaults to `"gemini"` (found via `PATH`).
|
||||
pub fn new() -> Self {
|
||||
let binary_path = std::env::var(GEMINI_CLI_PATH_ENV)
|
||||
.ok()
|
||||
.filter(|path| !path.trim().is_empty())
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from(DEFAULT_GEMINI_CLI_BINARY));
|
||||
|
||||
Self { binary_path }
|
||||
}
|
||||
|
||||
/// Returns true if the model argument should be forwarded to the CLI.
|
||||
fn should_forward_model(model: &str) -> bool {
|
||||
let trimmed = model.trim();
|
||||
!trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
|
||||
}
|
||||
|
||||
fn supports_temperature(temperature: f64) -> bool {
|
||||
GEMINI_CLI_SUPPORTED_TEMPERATURES
|
||||
.iter()
|
||||
.any(|v| (temperature - v).abs() < TEMP_EPSILON)
|
||||
}
|
||||
|
||||
fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
|
||||
if !temperature.is_finite() {
|
||||
anyhow::bail!("Gemini CLI provider received non-finite temperature value");
|
||||
}
|
||||
if !Self::supports_temperature(temperature) {
|
||||
anyhow::bail!(
|
||||
"temperature unsupported by Gemini CLI: {temperature}. \
|
||||
Supported values: 0.7 or 1.0"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn redact_stderr(stderr: &[u8]) -> String {
|
||||
let text = String::from_utf8_lossy(stderr);
|
||||
let trimmed = text.trim();
|
||||
if trimmed.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
if trimmed.chars().count() <= MAX_GEMINI_CLI_STDERR_CHARS {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
let clipped: String = trimmed.chars().take(MAX_GEMINI_CLI_STDERR_CHARS).collect();
|
||||
format!("{clipped}...")
|
||||
}
|
||||
|
||||
/// Invoke the gemini binary with the given prompt and optional model.
|
||||
/// Returns the trimmed stdout output as the assistant response.
|
||||
async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
|
||||
let mut cmd = Command::new(&self.binary_path);
|
||||
cmd.arg("--print");
|
||||
|
||||
if Self::should_forward_model(model) {
|
||||
cmd.arg("--model").arg(model);
|
||||
}
|
||||
|
||||
// Read prompt from stdin to avoid exposing sensitive content in process args.
|
||||
cmd.arg("-");
|
||||
cmd.kill_on_drop(true);
|
||||
cmd.stdin(std::process::Stdio::piped());
|
||||
cmd.stdout(std::process::Stdio::piped());
|
||||
cmd.stderr(std::process::Stdio::piped());
|
||||
|
||||
let mut child = cmd.spawn().map_err(|err| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to spawn Gemini CLI binary at {}: {err}. \
|
||||
Ensure `gemini` is installed and in PATH, or set GEMINI_CLI_PATH.",
|
||||
self.binary_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(mut stdin) = child.stdin.take() {
|
||||
stdin.write_all(message.as_bytes()).await.map_err(|err| {
|
||||
anyhow::anyhow!("Failed to write prompt to Gemini CLI stdin: {err}")
|
||||
})?;
|
||||
stdin.shutdown().await.map_err(|err| {
|
||||
anyhow::anyhow!("Failed to finalize Gemini CLI stdin stream: {err}")
|
||||
})?;
|
||||
}
|
||||
|
||||
let output = timeout(GEMINI_CLI_REQUEST_TIMEOUT, child.wait_with_output())
|
||||
.await
|
||||
.map_err(|_| {
|
||||
anyhow::anyhow!(
|
||||
"Gemini CLI request timed out after {:?} (binary: {})",
|
||||
GEMINI_CLI_REQUEST_TIMEOUT,
|
||||
self.binary_path.display()
|
||||
)
|
||||
})?
|
||||
.map_err(|err| anyhow::anyhow!("Gemini CLI process failed: {err}"))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let code = output.status.code().unwrap_or(-1);
|
||||
let stderr_excerpt = Self::redact_stderr(&output.stderr);
|
||||
let stderr_note = if stderr_excerpt.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" Stderr: {stderr_excerpt}")
|
||||
};
|
||||
anyhow::bail!(
|
||||
"Gemini CLI exited with non-zero status {code}. \
|
||||
Check that Gemini CLI is authenticated and the CLI is supported.{stderr_note}"
|
||||
);
|
||||
}
|
||||
|
||||
let text = String::from_utf8(output.stdout)
|
||||
.map_err(|err| anyhow::anyhow!("Gemini CLI produced non-UTF-8 output: {err}"))?;
|
||||
|
||||
Ok(text.trim().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Equivalent to [`GeminiCliProvider::new`]: honors the `GEMINI_CLI_PATH`
/// override, otherwise resolves `gemini` via `PATH`.
impl Default for GeminiCliProvider {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for GeminiCliProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Self::validate_temperature(temperature)?;
|
||||
|
||||
let full_message = match system_prompt {
|
||||
Some(system) if !system.is_empty() => {
|
||||
format!("{system}\n\n{message}")
|
||||
}
|
||||
_ => message.to_string(),
|
||||
};
|
||||
|
||||
self.invoke_cli(&full_message, model).await
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let text = self
|
||||
.chat_with_history(request.messages, model, temperature)
|
||||
.await?;
|
||||
|
||||
Ok(ChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
usage: Some(TokenUsage::default()),
|
||||
reasoning_content: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serializes tests that mutate process-global environment variables so
    /// they cannot race each other.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .expect("env lock poisoned")
    }

    // NOTE(review): these tests restore the env var only after the assertion;
    // a failing assert leaves the variable mutated and poisons the lock —
    // consider a drop guard. Confirm before changing.
    #[test]
    fn new_uses_env_override() {
        let _guard = env_lock();
        // Save the original value so it can be restored afterwards.
        let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
        std::env::set_var(GEMINI_CLI_PATH_ENV, "/usr/local/bin/gemini");
        let provider = GeminiCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/gemini"));
        match orig {
            Some(v) => std::env::set_var(GEMINI_CLI_PATH_ENV, v),
            None => std::env::remove_var(GEMINI_CLI_PATH_ENV),
        }
    }

    #[test]
    fn new_defaults_to_gemini() {
        let _guard = env_lock();
        let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
        std::env::remove_var(GEMINI_CLI_PATH_ENV);
        let provider = GeminiCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("gemini"));
        // Only restore when there was an original value; otherwise the var
        // is already absent, matching the pre-test state.
        if let Some(v) = orig {
            std::env::set_var(GEMINI_CLI_PATH_ENV, v);
        }
    }

    #[test]
    fn new_ignores_blank_env_override() {
        let _guard = env_lock();
        let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
        // Whitespace-only overrides must fall back to the default binary.
        std::env::set_var(GEMINI_CLI_PATH_ENV, "   ");
        let provider = GeminiCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("gemini"));
        match orig {
            Some(v) => std::env::set_var(GEMINI_CLI_PATH_ENV, v),
            None => std::env::remove_var(GEMINI_CLI_PATH_ENV),
        }
    }

    #[test]
    fn should_forward_model_standard() {
        // Real model names are forwarded verbatim.
        assert!(GeminiCliProvider::should_forward_model("gemini-2.5-pro"));
        assert!(GeminiCliProvider::should_forward_model("gemini-2.5-flash"));
    }

    #[test]
    fn should_not_forward_default_model() {
        // The "default" marker, empty, and blank names are suppressed.
        assert!(!GeminiCliProvider::should_forward_model(
            DEFAULT_MODEL_MARKER
        ));
        assert!(!GeminiCliProvider::should_forward_model(""));
        assert!(!GeminiCliProvider::should_forward_model("   "));
    }

    #[test]
    fn validate_temperature_allows_defaults() {
        assert!(GeminiCliProvider::validate_temperature(0.7).is_ok());
        assert!(GeminiCliProvider::validate_temperature(1.0).is_ok());
    }

    #[test]
    fn validate_temperature_rejects_custom_value() {
        let err = GeminiCliProvider::validate_temperature(0.2).unwrap_err();
        assert!(err
            .to_string()
            .contains("temperature unsupported by Gemini CLI"));
    }

    #[tokio::test]
    async fn invoke_missing_binary_returns_error() {
        // A nonexistent binary path must surface the spawn error, not hang.
        let provider = GeminiCliProvider {
            binary_path: PathBuf::from("/nonexistent/path/to/gemini"),
        };
        let result = provider.invoke_cli("hello", "default").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Failed to spawn Gemini CLI binary"),
            "unexpected error message: {msg}"
        );
    }
}
|
||||
@@ -0,0 +1,326 @@
|
||||
//! KiloCLI subprocess provider.
|
||||
//!
|
||||
//! Integrates with the KiloCLI tool, spawning the `kilo` binary
|
||||
//! as a subprocess for each inference request. This allows using KiloCLI's AI
|
||||
//! models without an interactive UI session.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! The `kilo` binary must be available in `PATH`, or its location must be
|
||||
//! set via the `KILO_CLI_PATH` environment variable.
|
||||
//!
|
||||
//! KiloCLI is invoked as:
|
||||
//! ```text
|
||||
//! kilo --print -
|
||||
//! ```
|
||||
//! with prompt content written to stdin.
|
||||
//!
|
||||
//! # Limitations
|
||||
//!
|
||||
//! - **Conversation history**: Only the system prompt (if present) and the last
|
||||
//! user message are forwarded. Full multi-turn history is not preserved because
|
||||
//! the CLI accepts a single prompt per invocation.
|
||||
//! - **System prompt**: The system prompt is prepended to the user message with a
|
||||
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
|
||||
//! - **Temperature**: The CLI does not expose a temperature parameter.
|
||||
//! Only default values are accepted; custom values return an explicit error.
|
||||
//!
|
||||
//! # Authentication
|
||||
//!
|
||||
//! Authentication is handled by KiloCLI itself (its own credential store).
|
||||
//! No explicit API key is required by this provider.
|
||||
//!
|
||||
//! # Environment variables
|
||||
//!
|
||||
//! - `KILO_CLI_PATH` — override the path to the `kilo` binary (default: `"kilo"`)
|
||||
|
||||
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
/// Environment variable for overriding the path to the `kilo` binary.
pub const KILO_CLI_PATH_ENV: &str = "KILO_CLI_PATH";

/// Default `kilo` binary name (resolved via `PATH`).
const DEFAULT_KILO_CLI_BINARY: &str = "kilo";

/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// KiloCLI requests are bounded to avoid hung subprocesses.
const KILO_CLI_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Avoid leaking oversized stderr payloads.
const MAX_KILO_CLI_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls; allow only baseline defaults.
const KILO_CLI_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
/// Absolute tolerance used when matching a requested temperature against the
/// supported list (exact float equality would be brittle).
const TEMP_EPSILON: f64 = 1e-9;

/// Provider that invokes the KiloCLI as a subprocess.
///
/// Each inference request spawns a fresh `kilo` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
pub struct KiloCliProvider {
    /// Path to the `kilo` binary.
    binary_path: PathBuf,
}
|
||||
|
||||
impl KiloCliProvider {
|
||||
/// Create a new `KiloCliProvider`.
|
||||
///
|
||||
/// The binary path is resolved from `KILO_CLI_PATH` env var if set,
|
||||
/// otherwise defaults to `"kilo"` (found via `PATH`).
|
||||
pub fn new() -> Self {
|
||||
let binary_path = std::env::var(KILO_CLI_PATH_ENV)
|
||||
.ok()
|
||||
.filter(|path| !path.trim().is_empty())
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|| PathBuf::from(DEFAULT_KILO_CLI_BINARY));
|
||||
|
||||
Self { binary_path }
|
||||
}
|
||||
|
||||
/// Returns true if the model argument should be forwarded to the CLI.
|
||||
fn should_forward_model(model: &str) -> bool {
|
||||
let trimmed = model.trim();
|
||||
!trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
|
||||
}
|
||||
|
||||
fn supports_temperature(temperature: f64) -> bool {
|
||||
KILO_CLI_SUPPORTED_TEMPERATURES
|
||||
.iter()
|
||||
.any(|v| (temperature - v).abs() < TEMP_EPSILON)
|
||||
}
|
||||
|
||||
fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
|
||||
if !temperature.is_finite() {
|
||||
anyhow::bail!("KiloCLI provider received non-finite temperature value");
|
||||
}
|
||||
if !Self::supports_temperature(temperature) {
|
||||
anyhow::bail!(
|
||||
"temperature unsupported by KiloCLI: {temperature}. \
|
||||
Supported values: 0.7 or 1.0"
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn redact_stderr(stderr: &[u8]) -> String {
|
||||
let text = String::from_utf8_lossy(stderr);
|
||||
let trimmed = text.trim();
|
||||
if trimmed.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
if trimmed.chars().count() <= MAX_KILO_CLI_STDERR_CHARS {
|
||||
return trimmed.to_string();
|
||||
}
|
||||
let clipped: String = trimmed.chars().take(MAX_KILO_CLI_STDERR_CHARS).collect();
|
||||
format!("{clipped}...")
|
||||
}
|
||||
|
||||
/// Invoke the kilo binary with the given prompt and optional model.
|
||||
/// Returns the trimmed stdout output as the assistant response.
|
||||
async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
|
||||
let mut cmd = Command::new(&self.binary_path);
|
||||
cmd.arg("--print");
|
||||
|
||||
if Self::should_forward_model(model) {
|
||||
cmd.arg("--model").arg(model);
|
||||
}
|
||||
|
||||
// Read prompt from stdin to avoid exposing sensitive content in process args.
|
||||
cmd.arg("-");
|
||||
cmd.kill_on_drop(true);
|
||||
cmd.stdin(std::process::Stdio::piped());
|
||||
cmd.stdout(std::process::Stdio::piped());
|
||||
cmd.stderr(std::process::Stdio::piped());
|
||||
|
||||
let mut child = cmd.spawn().map_err(|err| {
|
||||
anyhow::anyhow!(
|
||||
"Failed to spawn KiloCLI binary at {}: {err}. \
|
||||
Ensure `kilo` is installed and in PATH, or set KILO_CLI_PATH.",
|
||||
self.binary_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(mut stdin) = child.stdin.take() {
|
||||
stdin
|
||||
.write_all(message.as_bytes())
|
||||
.await
|
||||
.map_err(|err| anyhow::anyhow!("Failed to write prompt to KiloCLI stdin: {err}"))?;
|
||||
stdin
|
||||
.shutdown()
|
||||
.await
|
||||
.map_err(|err| anyhow::anyhow!("Failed to finalize KiloCLI stdin stream: {err}"))?;
|
||||
}
|
||||
|
||||
let output = timeout(KILO_CLI_REQUEST_TIMEOUT, child.wait_with_output())
|
||||
.await
|
||||
.map_err(|_| {
|
||||
anyhow::anyhow!(
|
||||
"KiloCLI request timed out after {:?} (binary: {})",
|
||||
KILO_CLI_REQUEST_TIMEOUT,
|
||||
self.binary_path.display()
|
||||
)
|
||||
})?
|
||||
.map_err(|err| anyhow::anyhow!("KiloCLI process failed: {err}"))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let code = output.status.code().unwrap_or(-1);
|
||||
let stderr_excerpt = Self::redact_stderr(&output.stderr);
|
||||
let stderr_note = if stderr_excerpt.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" Stderr: {stderr_excerpt}")
|
||||
};
|
||||
anyhow::bail!(
|
||||
"KiloCLI exited with non-zero status {code}. \
|
||||
Check that KiloCLI is authenticated and the CLI is supported.{stderr_note}"
|
||||
);
|
||||
}
|
||||
|
||||
let text = String::from_utf8(output.stdout)
|
||||
.map_err(|err| anyhow::anyhow!("KiloCLI produced non-UTF-8 output: {err}"))?;
|
||||
|
||||
Ok(text.trim().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Equivalent to [`KiloCliProvider::new`]: honors the `KILO_CLI_PATH`
/// override, otherwise resolves `kilo` via `PATH`.
impl Default for KiloCliProvider {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for KiloCliProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Self::validate_temperature(temperature)?;
|
||||
|
||||
let full_message = match system_prompt {
|
||||
Some(system) if !system.is_empty() => {
|
||||
format!("{system}\n\n{message}")
|
||||
}
|
||||
_ => message.to_string(),
|
||||
};
|
||||
|
||||
self.invoke_cli(&full_message, model).await
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let text = self
|
||||
.chat_with_history(request.messages, model, temperature)
|
||||
.await?;
|
||||
|
||||
Ok(ChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
usage: Some(TokenUsage::default()),
|
||||
reasoning_content: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serializes tests that mutate process-global environment variables so
    /// they cannot race each other.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .expect("env lock poisoned")
    }

    // NOTE(review): these tests restore the env var only after the assertion;
    // a failing assert leaves the variable mutated and poisons the lock —
    // consider a drop guard. Confirm before changing.
    #[test]
    fn new_uses_env_override() {
        let _guard = env_lock();
        // Save the original value so it can be restored afterwards.
        let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
        std::env::set_var(KILO_CLI_PATH_ENV, "/usr/local/bin/kilo");
        let provider = KiloCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/kilo"));
        match orig {
            Some(v) => std::env::set_var(KILO_CLI_PATH_ENV, v),
            None => std::env::remove_var(KILO_CLI_PATH_ENV),
        }
    }

    #[test]
    fn new_defaults_to_kilo() {
        let _guard = env_lock();
        let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
        std::env::remove_var(KILO_CLI_PATH_ENV);
        let provider = KiloCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("kilo"));
        // Only restore when there was an original value; otherwise the var
        // is already absent, matching the pre-test state.
        if let Some(v) = orig {
            std::env::set_var(KILO_CLI_PATH_ENV, v);
        }
    }

    #[test]
    fn new_ignores_blank_env_override() {
        let _guard = env_lock();
        let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
        // Whitespace-only overrides must fall back to the default binary.
        std::env::set_var(KILO_CLI_PATH_ENV, "   ");
        let provider = KiloCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("kilo"));
        match orig {
            Some(v) => std::env::set_var(KILO_CLI_PATH_ENV, v),
            None => std::env::remove_var(KILO_CLI_PATH_ENV),
        }
    }

    #[test]
    fn should_forward_model_standard() {
        // Real model names are forwarded verbatim.
        assert!(KiloCliProvider::should_forward_model("some-model"));
        assert!(KiloCliProvider::should_forward_model("gpt-4o"));
    }

    #[test]
    fn should_not_forward_default_model() {
        // The "default" marker, empty, and blank names are suppressed.
        assert!(!KiloCliProvider::should_forward_model(DEFAULT_MODEL_MARKER));
        assert!(!KiloCliProvider::should_forward_model(""));
        assert!(!KiloCliProvider::should_forward_model("   "));
    }

    #[test]
    fn validate_temperature_allows_defaults() {
        assert!(KiloCliProvider::validate_temperature(0.7).is_ok());
        assert!(KiloCliProvider::validate_temperature(1.0).is_ok());
    }

    #[test]
    fn validate_temperature_rejects_custom_value() {
        let err = KiloCliProvider::validate_temperature(0.2).unwrap_err();
        assert!(err
            .to_string()
            .contains("temperature unsupported by KiloCLI"));
    }

    #[tokio::test]
    async fn invoke_missing_binary_returns_error() {
        // A nonexistent binary path must surface the spawn error, not hang.
        let provider = KiloCliProvider {
            binary_path: PathBuf::from("/nonexistent/path/to/kilo"),
        };
        let result = provider.invoke_cli("hello", "default").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Failed to spawn KiloCLI binary"),
            "unexpected error message: {msg}"
        );
    }
}
|
||||
@@ -19,9 +19,12 @@
|
||||
pub mod anthropic;
|
||||
pub mod azure_openai;
|
||||
pub mod bedrock;
|
||||
pub mod claude_code;
|
||||
pub mod compatible;
|
||||
pub mod copilot;
|
||||
pub mod gemini;
|
||||
pub mod gemini_cli;
|
||||
pub mod kilocli;
|
||||
pub mod ollama;
|
||||
pub mod openai;
|
||||
pub mod openai_codex;
|
||||
@@ -1251,6 +1254,9 @@ fn create_provider_with_url_and_options(
|
||||
"Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"copilot" | "github-copilot" => Ok(Box::new(copilot::CopilotProvider::new(key))),
|
||||
"claude-code" => Ok(Box::new(claude_code::ClaudeCodeProvider::new())),
|
||||
"gemini-cli" => Ok(Box::new(gemini_cli::GeminiCliProvider::new())),
|
||||
"kilocli" | "kilo" => Ok(Box::new(kilocli::KiloCliProvider::new())),
|
||||
"lmstudio" | "lm-studio" => {
|
||||
let lm_studio_key = key
|
||||
.map(str::trim)
|
||||
@@ -1902,6 +1908,24 @@ pub fn list_providers() -> Vec<ProviderInfo> {
|
||||
aliases: &["github-copilot"],
|
||||
local: false,
|
||||
},
|
||||
ProviderInfo {
|
||||
name: "claude-code",
|
||||
display_name: "Claude Code (CLI)",
|
||||
aliases: &[],
|
||||
local: true,
|
||||
},
|
||||
ProviderInfo {
|
||||
name: "gemini-cli",
|
||||
display_name: "Gemini CLI",
|
||||
aliases: &[],
|
||||
local: true,
|
||||
},
|
||||
ProviderInfo {
|
||||
name: "kilocli",
|
||||
display_name: "KiloCLI",
|
||||
aliases: &["kilo"],
|
||||
local: true,
|
||||
},
|
||||
ProviderInfo {
|
||||
name: "lmstudio",
|
||||
display_name: "LM Studio",
|
||||
@@ -2720,6 +2744,22 @@ mod tests {
|
||||
assert!(create_provider("github-copilot", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_claude_code() {
|
||||
assert!(create_provider("claude-code", None).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_gemini_cli() {
|
||||
assert!(create_provider("gemini-cli", None).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_kilocli() {
|
||||
assert!(create_provider("kilocli", None).is_ok());
|
||||
assert!(create_provider("kilo", None).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_nvidia() {
|
||||
assert!(create_provider("nvidia", Some("nvapi-test")).is_ok());
|
||||
@@ -3053,6 +3093,9 @@ mod tests {
|
||||
"perplexity",
|
||||
"cohere",
|
||||
"copilot",
|
||||
"claude-code",
|
||||
"gemini-cli",
|
||||
"kilocli",
|
||||
"nvidia",
|
||||
"astrai",
|
||||
"ovhcloud",
|
||||
|
||||
@@ -12,6 +12,8 @@ pub enum CliCategory {
|
||||
Container,
|
||||
Build,
|
||||
Cloud,
|
||||
AiAgent,
|
||||
Productivity,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for CliCategory {
|
||||
@@ -23,6 +25,8 @@ impl std::fmt::Display for CliCategory {
|
||||
Self::Container => write!(f, "Container"),
|
||||
Self::Build => write!(f, "Build"),
|
||||
Self::Cloud => write!(f, "Cloud"),
|
||||
Self::AiAgent => write!(f, "AI Agent"),
|
||||
Self::Productivity => write!(f, "Productivity"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -104,6 +108,26 @@ const KNOWN_CLIS: &[KnownCli] = &[
|
||||
version_args: &["--version"],
|
||||
category: CliCategory::Language,
|
||||
},
|
||||
KnownCli {
|
||||
name: "claude",
|
||||
version_args: &["--version"],
|
||||
category: CliCategory::AiAgent,
|
||||
},
|
||||
KnownCli {
|
||||
name: "gemini",
|
||||
version_args: &["--version"],
|
||||
category: CliCategory::AiAgent,
|
||||
},
|
||||
KnownCli {
|
||||
name: "kilo",
|
||||
version_args: &["--version"],
|
||||
category: CliCategory::AiAgent,
|
||||
},
|
||||
KnownCli {
|
||||
name: "gws",
|
||||
version_args: &["--version"],
|
||||
category: CliCategory::Productivity,
|
||||
},
|
||||
];
|
||||
|
||||
/// Discover available CLI tools on the system.
|
||||
@@ -235,5 +259,7 @@ mod tests {
|
||||
assert_eq!(CliCategory::Container.to_string(), "Container");
|
||||
assert_eq!(CliCategory::Build.to_string(), "Build");
|
||||
assert_eq!(CliCategory::Cloud.to_string(), "Cloud");
|
||||
assert_eq!(CliCategory::AiAgent.to_string(), "AI Agent");
|
||||
assert_eq!(CliCategory::Productivity.to_string(), "Productivity");
|
||||
}
|
||||
}
|
||||
|
||||
+152
-15
@@ -65,27 +65,97 @@ impl Tool for CronAddTool {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"schedule": {
|
||||
"type": "object",
|
||||
"description": "Schedule object: {kind:'cron',expr,tz?} | {kind:'at',at} | {kind:'every',every_ms}"
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Optional human-readable name for the job"
|
||||
},
|
||||
// NOTE: oneOf is correct for OpenAI-compatible APIs (including OpenRouter).
|
||||
// Gemini does not support oneOf in tool schemas; if Gemini native tool calling
|
||||
// is ever wired up, SchemaCleanr::clean_for_gemini must be applied before
|
||||
// tool specs are sent. See src/tools/schema.rs.
|
||||
"schedule": {
|
||||
"description": "When to run the job. Exactly one of three forms must be used.",
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"description": "Cron expression schedule (repeating). Example: {\"kind\":\"cron\",\"expr\":\"0 9 * * 1-5\",\"tz\":\"America/New_York\"}",
|
||||
"properties": {
|
||||
"kind": { "type": "string", "enum": ["cron"] },
|
||||
"expr": { "type": "string", "description": "Standard 5-field cron expression, e.g. '*/5 * * * *'" },
|
||||
"tz": { "type": "string", "description": "Optional IANA timezone name, e.g. 'America/New_York'. Defaults to UTC." }
|
||||
},
|
||||
"required": ["kind", "expr"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"description": "One-shot schedule at a specific UTC datetime. Example: {\"kind\":\"at\",\"at\":\"2025-12-31T23:59:00Z\"}",
|
||||
"properties": {
|
||||
"kind": { "type": "string", "enum": ["at"] },
|
||||
"at": { "type": "string", "description": "ISO 8601 UTC datetime string, e.g. '2025-12-31T23:59:00Z'" }
|
||||
},
|
||||
"required": ["kind", "at"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"description": "Repeating interval schedule in milliseconds. Example: {\"kind\":\"every\",\"every_ms\":3600000} runs every hour.",
|
||||
"properties": {
|
||||
"kind": { "type": "string", "enum": ["every"] },
|
||||
"every_ms": { "type": "integer", "description": "Interval in milliseconds, e.g. 3600000 for every hour" }
|
||||
},
|
||||
"required": ["kind", "every_ms"]
|
||||
}
|
||||
]
|
||||
},
|
||||
"job_type": {
|
||||
"type": "string",
|
||||
"enum": ["shell", "agent"],
|
||||
"description": "Type of job: 'shell' runs a command, 'agent' runs the AI agent with a prompt"
|
||||
},
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Shell command to run (required when job_type is 'shell')"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Agent prompt to run on schedule (required when job_type is 'agent')"
|
||||
},
|
||||
"session_target": {
|
||||
"type": "string",
|
||||
"enum": ["isolated", "main"],
|
||||
"description": "Agent session context: 'isolated' starts a fresh session each run, 'main' reuses the primary session"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "Optional model override for agent jobs, e.g. 'x-ai/grok-4-1-fast'"
|
||||
},
|
||||
"job_type": { "type": "string", "enum": ["shell", "agent"] },
|
||||
"command": { "type": "string" },
|
||||
"prompt": { "type": "string" },
|
||||
"session_target": { "type": "string", "enum": ["isolated", "main"] },
|
||||
"model": { "type": "string" },
|
||||
"delivery": {
|
||||
"type": "object",
|
||||
"description": "Delivery config to send job output to a channel. Example: {\"mode\":\"announce\",\"channel\":\"discord\",\"to\":\"<channel_id>\"}",
|
||||
"description": "Optional delivery config to send job output to a channel after each run. When provided, all three of mode, channel, and to are expected.",
|
||||
"properties": {
|
||||
"mode": { "type": "string", "enum": ["none", "announce"], "description": "Set to 'announce' to deliver output to a channel" },
|
||||
"channel": { "type": "string", "enum": ["telegram", "discord", "slack", "mattermost", "matrix"], "description": "Channel type to deliver to" },
|
||||
"to": { "type": "string", "description": "Target: Discord channel ID, Telegram chat ID, Slack channel, etc." },
|
||||
"best_effort": { "type": "boolean", "description": "If true, delivery failure does not fail the job" }
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["none", "announce"],
|
||||
"description": "'announce' sends output to the specified channel; 'none' disables delivery"
|
||||
},
|
||||
"channel": {
|
||||
"type": "string",
|
||||
"enum": ["telegram", "discord", "slack", "mattermost", "matrix"],
|
||||
"description": "Channel type to deliver output to"
|
||||
},
|
||||
"to": {
|
||||
"type": "string",
|
||||
"description": "Destination ID: Discord channel ID, Telegram chat ID, Slack channel name, etc."
|
||||
},
|
||||
"best_effort": {
|
||||
"type": "boolean",
|
||||
"description": "If true, a delivery failure does not fail the job itself. Defaults to true."
|
||||
}
|
||||
}
|
||||
},
|
||||
"delete_after_run": { "type": "boolean" },
|
||||
"delete_after_run": {
|
||||
"type": "boolean",
|
||||
"description": "If true, the job is automatically deleted after its first successful run. Defaults to true for 'at' schedules."
|
||||
},
|
||||
"approved": {
|
||||
"type": "boolean",
|
||||
"description": "Set true to explicitly approve medium/high-risk shell commands in supervised mode",
|
||||
@@ -497,4 +567,71 @@ mod tests {
|
||||
|
||||
assert!(values.iter().any(|value| value == "matrix"));
|
||||
}
|
||||
|
||||
    /// Pins the shape of cron_add's `schedule` parameter schema: it must be
    /// a `oneOf` with exactly the cron/at/every variants, each declaring its
    /// required fields, so LLM tool-calling clients get a precise contract.
    #[test]
    fn schedule_schema_is_oneof_with_cron_at_every_variants() {
        // Build the tool against a throwaway workspace/config.
        let tmp = tempfile::TempDir::new().unwrap();
        let cfg = Arc::new(Config {
            workspace_dir: tmp.path().join("workspace"),
            config_path: tmp.path().join("config.toml"),
            ..Config::default()
        });
        let security = Arc::new(SecurityPolicy::from_config(
            &cfg.autonomy,
            &cfg.workspace_dir,
        ));
        let tool = CronAddTool::new(cfg, security);
        let schema = tool.parameters_schema();

        // Top-level: schedule is required
        let top_required = schema["required"].as_array().expect("top-level required");
        assert!(top_required.iter().any(|v| v == "schedule"));

        // schedule is a oneOf with exactly 3 variants: cron, at, every
        let one_of = schema["properties"]["schedule"]["oneOf"]
            .as_array()
            .expect("schedule.oneOf must be an array");
        assert_eq!(one_of.len(), 3, "expected cron, at, and every variants");

        // Each variant is identified by its single-value `kind` enum.
        let kinds: Vec<&str> = one_of
            .iter()
            .filter_map(|v| v["properties"]["kind"]["enum"][0].as_str())
            .collect();
        assert!(kinds.contains(&"cron"), "missing cron variant");
        assert!(kinds.contains(&"at"), "missing at variant");
        assert!(kinds.contains(&"every"), "missing every variant");

        // Each variant declares its required fields and every_ms is typed integer
        for variant in one_of {
            let kind = variant["properties"]["kind"]["enum"][0]
                .as_str()
                .expect("variant kind");
            let req: Vec<&str> = variant["required"]
                .as_array()
                .unwrap_or_else(|| panic!("{kind} variant must have required"))
                .iter()
                .filter_map(|v| v.as_str())
                .collect();
            assert!(
                req.contains(&"kind"),
                "{kind} variant missing 'kind' in required"
            );
            // Per-variant payload field must also be required.
            match kind {
                "cron" => assert!(req.contains(&"expr"), "cron variant missing 'expr'"),
                "at" => assert!(req.contains(&"at"), "at variant missing 'at'"),
                "every" => {
                    assert!(
                        req.contains(&"every_ms"),
                        "every variant missing 'every_ms'"
                    );
                    assert_eq!(
                        variant["properties"]["every_ms"]["type"].as_str(),
                        Some("integer"),
                        "every_ms must be typed as integer"
                    );
                }
                _ => panic!("unexpected kind: {kind}"),
            }
        }
    }
|
||||
}
|
||||
|
||||
+200
-2
@@ -61,8 +61,106 @@ impl Tool for CronUpdateTool {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"job_id": { "type": "string" },
|
||||
"patch": { "type": "object" },
|
||||
"job_id": {
|
||||
"type": "string",
|
||||
"description": "ID of the cron job to update, as returned by cron_add or cron_list"
|
||||
},
|
||||
"patch": {
|
||||
"type": "object",
|
||||
"description": "Fields to update. Only include fields you want to change; omitted fields are left as-is.",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "New human-readable name for the job"
|
||||
},
|
||||
"enabled": {
|
||||
"type": "boolean",
|
||||
"description": "Enable or disable the job without deleting it"
|
||||
},
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "New shell command (for shell jobs)"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "New agent prompt (for agent jobs)"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": "Model override for agent jobs, e.g. 'x-ai/grok-4-1-fast'"
|
||||
},
|
||||
"session_target": {
|
||||
"type": "string",
|
||||
"enum": ["isolated", "main"],
|
||||
"description": "Agent session context: 'isolated' starts fresh each run, 'main' reuses the primary session"
|
||||
},
|
||||
"delete_after_run": {
|
||||
"type": "boolean",
|
||||
"description": "If true, delete the job automatically after its first successful run"
|
||||
},
|
||||
// NOTE: oneOf is correct for OpenAI-compatible APIs (including OpenRouter).
|
||||
// Gemini does not support oneOf in tool schemas; if Gemini native tool calling
|
||||
// is ever wired up, SchemaCleanr::clean_for_gemini must be applied before
|
||||
// tool specs are sent. See src/tools/schema.rs.
|
||||
"schedule": {
|
||||
"description": "New schedule for the job. Exactly one of three forms must be used.",
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"description": "Cron expression schedule (repeating). Example: {\"kind\":\"cron\",\"expr\":\"0 9 * * 1-5\",\"tz\":\"America/New_York\"}",
|
||||
"properties": {
|
||||
"kind": { "type": "string", "enum": ["cron"] },
|
||||
"expr": { "type": "string", "description": "Standard 5-field cron expression, e.g. '*/5 * * * *'" },
|
||||
"tz": { "type": "string", "description": "Optional IANA timezone name, e.g. 'America/New_York'. Defaults to UTC." }
|
||||
},
|
||||
"required": ["kind", "expr"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"description": "One-shot schedule at a specific UTC datetime. Example: {\"kind\":\"at\",\"at\":\"2025-12-31T23:59:00Z\"}",
|
||||
"properties": {
|
||||
"kind": { "type": "string", "enum": ["at"] },
|
||||
"at": { "type": "string", "description": "ISO 8601 UTC datetime string, e.g. '2025-12-31T23:59:00Z'" }
|
||||
},
|
||||
"required": ["kind", "at"]
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"description": "Repeating interval schedule in milliseconds. Example: {\"kind\":\"every\",\"every_ms\":3600000} runs every hour.",
|
||||
"properties": {
|
||||
"kind": { "type": "string", "enum": ["every"] },
|
||||
"every_ms": { "type": "integer", "description": "Interval in milliseconds, e.g. 3600000 for every hour" }
|
||||
},
|
||||
"required": ["kind", "every_ms"]
|
||||
}
|
||||
]
|
||||
},
|
||||
"delivery": {
|
||||
"type": "object",
|
||||
"description": "Delivery config to send job output to a channel after each run. When provided, mode, channel, and to are all expected.",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["none", "announce"],
|
||||
"description": "'announce' sends output to the specified channel; 'none' disables delivery"
|
||||
},
|
||||
"channel": {
|
||||
"type": "string",
|
||||
"enum": ["telegram", "discord", "slack", "mattermost", "matrix"],
|
||||
"description": "Channel type to deliver output to"
|
||||
},
|
||||
"to": {
|
||||
"type": "string",
|
||||
"description": "Destination ID: Discord channel ID, Telegram chat ID, Slack channel name, etc."
|
||||
},
|
||||
"best_effort": {
|
||||
"type": "boolean",
|
||||
"description": "If true, a delivery failure does not fail the job itself. Defaults to true."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"approved": {
|
||||
"type": "boolean",
|
||||
"description": "Set true to explicitly approve medium/high-risk shell commands in supervised mode",
|
||||
@@ -274,6 +372,106 @@ mod tests {
|
||||
assert!(approved.success, "{:?}", approved.error);
|
||||
}
|
||||
|
||||
    /// Pins the cron_update `patch` schema: it must expose every
    /// CronJobPatch field, model `schedule` as the same cron/at/every
    /// `oneOf` used by cron_add, and enumerate all supported delivery
    /// channels — keeping the LLM-facing contract in sync with the types.
    #[test]
    fn patch_schema_covers_all_cronjobpatch_fields_and_schedule_is_oneof() {
        // Build the tool against a throwaway workspace/config.
        let tmp = TempDir::new().unwrap();
        let cfg = Arc::new(Config {
            workspace_dir: tmp.path().join("workspace"),
            config_path: tmp.path().join("config.toml"),
            ..Config::default()
        });
        let security = Arc::new(SecurityPolicy::from_config(
            &cfg.autonomy,
            &cfg.workspace_dir,
        ));
        let tool = CronUpdateTool::new(cfg, security);
        let schema = tool.parameters_schema();

        // Top-level: job_id and patch are required
        let top_required = schema["required"].as_array().expect("top-level required");
        let top_req_strs: Vec<&str> = top_required.iter().filter_map(|v| v.as_str()).collect();
        assert!(top_req_strs.contains(&"job_id"));
        assert!(top_req_strs.contains(&"patch"));

        // patch exposes all CronJobPatch fields
        let patch_props = schema["properties"]["patch"]["properties"]
            .as_object()
            .expect("patch must have a properties object");
        for field in &[
            "name",
            "enabled",
            "command",
            "prompt",
            "model",
            "session_target",
            "delete_after_run",
            "schedule",
            "delivery",
        ] {
            assert!(
                patch_props.contains_key(*field),
                "patch schema missing field: {field}"
            );
        }

        // patch.schedule is a oneOf with exactly 3 variants: cron, at, every
        let one_of = schema["properties"]["patch"]["properties"]["schedule"]["oneOf"]
            .as_array()
            .expect("patch.schedule.oneOf must be an array");
        assert_eq!(one_of.len(), 3, "expected cron, at, and every variants");

        // Each variant is identified by its single-value `kind` enum.
        let kinds: Vec<&str> = one_of
            .iter()
            .filter_map(|v| v["properties"]["kind"]["enum"][0].as_str())
            .collect();
        assert!(kinds.contains(&"cron"), "missing cron variant");
        assert!(kinds.contains(&"at"), "missing at variant");
        assert!(kinds.contains(&"every"), "missing every variant");

        // Each variant declares its required fields and every_ms is typed integer
        for variant in one_of {
            let kind = variant["properties"]["kind"]["enum"][0]
                .as_str()
                .expect("variant kind");
            let req: Vec<&str> = variant["required"]
                .as_array()
                .unwrap_or_else(|| panic!("{kind} variant must have required"))
                .iter()
                .filter_map(|v| v.as_str())
                .collect();
            assert!(
                req.contains(&"kind"),
                "{kind} variant missing 'kind' in required"
            );
            // Per-variant payload field must also be required.
            match kind {
                "cron" => assert!(req.contains(&"expr"), "cron variant missing 'expr'"),
                "at" => assert!(req.contains(&"at"), "at variant missing 'at'"),
                "every" => {
                    assert!(
                        req.contains(&"every_ms"),
                        "every variant missing 'every_ms'"
                    );
                    assert_eq!(
                        variant["properties"]["every_ms"]["type"].as_str(),
                        Some("integer"),
                        "every_ms must be typed as integer"
                    );
                }
                _ => panic!("unexpected schedule kind: {kind}"),
            }
        }

        // patch.delivery.channel enum covers all supported channels
        let channel_enum = schema["properties"]["patch"]["properties"]["delivery"]["properties"]
            ["channel"]["enum"]
            .as_array()
            .expect("patch.delivery.channel must have an enum");
        let channel_strs: Vec<&str> = channel_enum.iter().filter_map(|v| v.as_str()).collect();
        for ch in &["telegram", "discord", "slack", "mattermost", "matrix"] {
            assert!(channel_strs.contains(ch), "delivery.channel missing: {ch}");
        }
    }
|
||||
|
||||
#[tokio::test]
|
||||
async fn blocks_update_when_rate_limited() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
@@ -421,6 +421,7 @@ impl DelegateTool {
|
||||
None,
|
||||
&[],
|
||||
&[],
|
||||
None,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -0,0 +1,716 @@
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Default `gws` command execution time before kill (overridden by config).
const DEFAULT_GWS_TIMEOUT_SECS: u64 = 30;
/// Maximum output size in bytes (1MB).
/// Output longer than this is truncated before being returned to the agent.
const MAX_OUTPUT_BYTES: usize = 1_048_576;

/// Allowed Google Workspace services that gws can target.
/// Used as the fallback allowlist when the configured list is empty;
/// `execute` rejects any service not present in the active allowlist.
const DEFAULT_ALLOWED_SERVICES: &[&str] = &[
    "drive",
    "sheets",
    "gmail",
    "calendar",
    "docs",
    "slides",
    "tasks",
    "people",
    "chat",
    "classroom",
    "forms",
    "keep",
    "meet",
    "events",
];
|
||||
|
||||
/// Google Workspace CLI (`gws`) integration tool.
///
/// Wraps the `gws` CLI binary to give the agent structured access to
/// Google Workspace services (Drive, Gmail, Calendar, Sheets, etc.).
/// Requires `gws` to be installed and authenticated (`gws auth login`).
pub struct GoogleWorkspaceTool {
    // Security policy consulted for rate limiting and action budgeting.
    security: Arc<SecurityPolicy>,
    // Services the tool may target; falls back to DEFAULT_ALLOWED_SERVICES
    // when the configured list is empty (see `new`).
    allowed_services: Vec<String>,
    // Optional path exported as GOOGLE_APPLICATION_CREDENTIALS to the
    // spawned gws process.
    credentials_path: Option<String>,
    // Optional account passed to gws via `--account`.
    default_account: Option<String>,
    // Per-minute rate limit. NOTE(review): not referenced in the visible
    // part of `execute` (which uses SecurityPolicy's limiter) — confirm
    // where this field is enforced.
    rate_limit_per_minute: u32,
    // Seconds a gws invocation may run before being timed out.
    timeout_secs: u64,
    // When true, each API call is logged via `tracing::info!` before running.
    audit_log: bool,
}
|
||||
|
||||
impl GoogleWorkspaceTool {
|
||||
/// Create a new `GoogleWorkspaceTool`.
|
||||
///
|
||||
/// If `allowed_services` is empty, the default service set is used.
|
||||
pub fn new(
|
||||
security: Arc<SecurityPolicy>,
|
||||
allowed_services: Vec<String>,
|
||||
credentials_path: Option<String>,
|
||||
default_account: Option<String>,
|
||||
rate_limit_per_minute: u32,
|
||||
timeout_secs: u64,
|
||||
audit_log: bool,
|
||||
) -> Self {
|
||||
let services = if allowed_services.is_empty() {
|
||||
DEFAULT_ALLOWED_SERVICES
|
||||
.iter()
|
||||
.map(|s| (*s).to_string())
|
||||
.collect()
|
||||
} else {
|
||||
allowed_services
|
||||
};
|
||||
Self {
|
||||
security,
|
||||
allowed_services: services,
|
||||
credentials_path,
|
||||
default_account,
|
||||
rate_limit_per_minute,
|
||||
timeout_secs,
|
||||
audit_log,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for GoogleWorkspaceTool {
|
||||
    /// Stable tool identifier the agent uses to invoke this tool.
    fn name(&self) -> &str {
        "google_workspace"
    }
|
||||
|
||||
    /// Human/LLM-facing summary shown in the tool listing.
    fn description(&self) -> &str {
        "Interact with Google Workspace services (Drive, Gmail, Calendar, Sheets, Docs, etc.) \
         via the gws CLI. Requires gws to be installed and authenticated."
    }
|
||||
|
||||
    /// JSON Schema for the tool's arguments.
    ///
    /// Mirrors the validation performed in `execute`: `service`, `resource`,
    /// and `method` are required and character-restricted; `format` values
    /// must match the enum below. Keep the two in sync when editing either.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "service": {
                    "type": "string",
                    "description": "Google Workspace service (e.g. drive, gmail, calendar, sheets, docs, slides, tasks, people, chat, forms, keep, meet)"
                },
                "resource": {
                    "type": "string",
                    "description": "Service resource (e.g. files, messages, events, spreadsheets)"
                },
                "method": {
                    "type": "string",
                    "description": "Method to call on the resource (e.g. list, get, create, update, delete)"
                },
                "sub_resource": {
                    "type": "string",
                    "description": "Optional sub-resource for nested operations"
                },
                "params": {
                    "type": "object",
                    "description": "URL/query parameters as key-value pairs (passed as --params JSON)"
                },
                "body": {
                    "type": "object",
                    "description": "Request body for POST/PATCH/PUT operations (passed as --json JSON)"
                },
                "format": {
                    "type": "string",
                    "enum": ["json", "table", "yaml", "csv"],
                    "description": "Output format (default: json)"
                },
                "page_all": {
                    "type": "boolean",
                    "description": "Auto-paginate through all results"
                },
                "page_limit": {
                    "type": "integer",
                    "description": "Max pages to fetch when using page_all (default: 10)"
                }
            },
            "required": ["service", "resource", "method"]
        })
    }
|
||||
|
||||
/// Execute a Google Workspace CLI command with input validation and security enforcement.
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let service = args
|
||||
.get("service")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'service' parameter"))?;
|
||||
let resource = args
|
||||
.get("resource")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'resource' parameter"))?;
|
||||
let method = args
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'method' parameter"))?;
|
||||
|
||||
// Security checks
|
||||
if self.security.is_rate_limited() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Rate limit exceeded: too many actions in the last hour".into()),
|
||||
});
|
||||
}
|
||||
|
||||
// Validate service is in the allowlist
|
||||
if !self.allowed_services.iter().any(|s| s == service) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Service '{service}' is not in the allowed services list. \
|
||||
Allowed: {}",
|
||||
self.allowed_services.join(", ")
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
// Validate inputs contain no shell metacharacters
|
||||
for (label, value) in [
|
||||
("service", service),
|
||||
("resource", resource),
|
||||
("method", method),
|
||||
] {
|
||||
if !value
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid characters in '{label}': only alphanumeric, underscore, and hyphen are allowed"
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Build the gws command — validate all optional fields before consuming budget
|
||||
let mut cmd_args: Vec<String> = vec![service.to_string(), resource.to_string()];
|
||||
|
||||
if let Some(sub_resource_value) = args.get("sub_resource") {
|
||||
let sub_resource = match sub_resource_value.as_str() {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'sub_resource' must be a string".into()),
|
||||
})
|
||||
}
|
||||
};
|
||||
if !sub_resource
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"Invalid characters in 'sub_resource': only alphanumeric, underscore, and hyphen are allowed"
|
||||
.into(),
|
||||
),
|
||||
});
|
||||
}
|
||||
cmd_args.push(sub_resource.to_string());
|
||||
}
|
||||
|
||||
cmd_args.push(method.to_string());
|
||||
|
||||
if let Some(params) = args.get("params") {
|
||||
if !params.is_object() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'params' must be an object".into()),
|
||||
});
|
||||
}
|
||||
cmd_args.push("--params".into());
|
||||
cmd_args.push(params.to_string());
|
||||
}
|
||||
|
||||
if let Some(body) = args.get("body") {
|
||||
if !body.is_object() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'body' must be an object".into()),
|
||||
});
|
||||
}
|
||||
cmd_args.push("--json".into());
|
||||
cmd_args.push(body.to_string());
|
||||
}
|
||||
|
||||
if let Some(format_value) = args.get("format") {
|
||||
let format = match format_value.as_str() {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'format' must be a string".into()),
|
||||
})
|
||||
}
|
||||
};
|
||||
match format {
|
||||
"json" | "table" | "yaml" | "csv" => {
|
||||
cmd_args.push("--format".into());
|
||||
cmd_args.push(format.to_string());
|
||||
}
|
||||
_ => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid format '{format}': must be json, table, yaml, or csv"
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let page_all = match args.get("page_all") {
|
||||
Some(v) => match v.as_bool() {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'page_all' must be a boolean".into()),
|
||||
})
|
||||
}
|
||||
},
|
||||
None => false,
|
||||
};
|
||||
if page_all {
|
||||
cmd_args.push("--page-all".into());
|
||||
}
|
||||
|
||||
let page_limit = match args.get("page_limit") {
|
||||
Some(v) => match v.as_u64() {
|
||||
Some(n) => Some(n),
|
||||
None => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("'page_limit' must be a non-negative integer".into()),
|
||||
})
|
||||
}
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
if page_all || page_limit.is_some() {
|
||||
cmd_args.push("--page-limit".into());
|
||||
cmd_args.push(page_limit.unwrap_or(10).to_string());
|
||||
}
|
||||
|
||||
// Charge action budget only after all validation passes
|
||||
if !self.security.record_action() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Rate limit exceeded: action budget exhausted".into()),
|
||||
});
|
||||
}
|
||||
|
||||
let mut cmd = tokio::process::Command::new("gws");
|
||||
cmd.args(&cmd_args);
|
||||
cmd.env_clear();
|
||||
// gws needs PATH to find itself and HOME/APPDATA for credential storage
|
||||
for key in &["PATH", "HOME", "APPDATA", "USERPROFILE", "LANG", "TERM"] {
|
||||
if let Ok(val) = std::env::var(key) {
|
||||
cmd.env(key, val);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply credential path if configured
|
||||
if let Some(ref creds) = self.credentials_path {
|
||||
cmd.env("GOOGLE_APPLICATION_CREDENTIALS", creds);
|
||||
}
|
||||
|
||||
// Apply default account if configured
|
||||
if let Some(ref account) = self.default_account {
|
||||
cmd.args(["--account", account]);
|
||||
}
|
||||
|
||||
if self.audit_log {
|
||||
tracing::info!(
|
||||
tool = "google_workspace",
|
||||
service = service,
|
||||
resource = resource,
|
||||
method = method,
|
||||
"gws audit: executing API call"
|
||||
);
|
||||
}
|
||||
|
||||
// Apply credential path if configured
|
||||
if let Some(ref creds) = self.credentials_path {
|
||||
cmd.env("GOOGLE_APPLICATION_CREDENTIALS", creds);
|
||||
}
|
||||
|
||||
// Apply default account if configured
|
||||
if let Some(ref account) = self.default_account {
|
||||
cmd.args(["--account", account]);
|
||||
}
|
||||
|
||||
if self.audit_log {
|
||||
tracing::info!(
|
||||
tool = "google_workspace",
|
||||
service = service,
|
||||
resource = resource,
|
||||
method = method,
|
||||
"gws audit: executing API call"
|
||||
);
|
||||
}
|
||||
|
||||
let result =
|
||||
tokio::time::timeout(Duration::from_secs(self.timeout_secs), cmd.output()).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(output)) => {
|
||||
let mut stdout = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
let mut stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
|
||||
if stdout.len() > MAX_OUTPUT_BYTES {
|
||||
// Find a valid char boundary at or before MAX_OUTPUT_BYTES
|
||||
let mut boundary = MAX_OUTPUT_BYTES;
|
||||
while boundary > 0 && !stdout.is_char_boundary(boundary) {
|
||||
boundary -= 1;
|
||||
}
|
||||
stdout.truncate(boundary);
|
||||
stdout.push_str("\n... [output truncated at 1MB]");
|
||||
}
|
||||
if stderr.len() > MAX_OUTPUT_BYTES {
|
||||
let mut boundary = MAX_OUTPUT_BYTES;
|
||||
while boundary > 0 && !stderr.is_char_boundary(boundary) {
|
||||
boundary -= 1;
|
||||
}
|
||||
stderr.truncate(boundary);
|
||||
stderr.push_str("\n... [stderr truncated at 1MB]");
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
success: output.status.success(),
|
||||
output: stdout,
|
||||
error: if stderr.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(stderr)
|
||||
},
|
||||
})
|
||||
}
|
||||
Ok(Err(e)) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Failed to execute gws: {e}. Is gws installed? Run: npm install -g @googleworkspace/cli"
|
||||
)),
|
||||
}),
|
||||
Err(_) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"gws command timed out after {}s and was killed", self.timeout_secs
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::{AutonomyLevel, SecurityPolicy};

    /// Security policy granting full autonomy, rooted in the system temp dir.
    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    /// Tool with the default allowed-services list and the standard test
    /// settings: no credentials path, no default account, 60 rate-limit,
    /// 30s timeout, audit logging off.
    fn default_tool() -> GoogleWorkspaceTool {
        GoogleWorkspaceTool::new(test_security(), vec![], None, None, 60, 30, false)
    }

    /// Tool restricted to an explicit service allow-list; other settings as
    /// in `default_tool`.
    fn tool_with_services(services: Vec<String>) -> GoogleWorkspaceTool {
        GoogleWorkspaceTool::new(test_security(), services, None, None, 60, 30, false)
    }

    /// Assert the result is a failure whose error message contains `needle`.
    fn assert_err_contains(result: &ToolResult, needle: &str) {
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap_or("").contains(needle));
    }

    #[test]
    fn tool_name() {
        assert_eq!(default_tool().name(), "google_workspace");
    }

    #[test]
    fn tool_description_non_empty() {
        assert!(!default_tool().description().is_empty());
    }

    #[test]
    fn tool_schema_has_required_fields() {
        let schema = default_tool().parameters_schema();
        assert!(schema["properties"]["service"].is_object());
        assert!(schema["properties"]["resource"].is_object());
        assert!(schema["properties"]["method"].is_object());
        let required = schema["required"]
            .as_array()
            .expect("required should be an array");
        assert!(required.contains(&json!("service")));
        assert!(required.contains(&json!("resource")));
        assert!(required.contains(&json!("method")));
    }

    #[test]
    fn default_allowed_services_populated() {
        // An empty allow-list at construction falls back to the defaults,
        // which must include the core services.
        let tool = default_tool();
        assert!(!tool.allowed_services.is_empty());
        assert!(tool.allowed_services.contains(&"drive".to_string()));
        assert!(tool.allowed_services.contains(&"gmail".to_string()));
        assert!(tool.allowed_services.contains(&"calendar".to_string()));
    }

    #[test]
    fn custom_allowed_services_override_defaults() {
        let tool = tool_with_services(vec!["drive".into(), "sheets".into()]);
        assert_eq!(tool.allowed_services.len(), 2);
        assert!(tool.allowed_services.contains(&"drive".to_string()));
        assert!(tool.allowed_services.contains(&"sheets".to_string()));
        assert!(!tool.allowed_services.contains(&"gmail".to_string()));
    }

    #[tokio::test]
    async fn rejects_disallowed_service() {
        let tool = tool_with_services(vec!["drive".into()]);
        let result = tool
            .execute(json!({
                "service": "gmail",
                "resource": "users",
                "method": "list"
            }))
            .await
            .expect("disallowed service should return a result");
        assert_err_contains(&result, "not in the allowed");
    }

    #[tokio::test]
    async fn rejects_shell_injection_in_service() {
        // Even when the malicious string is in the allow-list, character
        // validation must still reject it.
        let tool = tool_with_services(vec!["drive; rm -rf /".into()]);
        let result = tool
            .execute(json!({
                "service": "drive; rm -rf /",
                "resource": "files",
                "method": "list"
            }))
            .await
            .expect("shell injection should return a result");
        assert_err_contains(&result, "Invalid characters");
    }

    #[tokio::test]
    async fn rejects_shell_injection_in_resource() {
        let result = default_tool()
            .execute(json!({
                "service": "drive",
                "resource": "files$(whoami)",
                "method": "list"
            }))
            .await
            .expect("shell injection should return a result");
        assert_err_contains(&result, "Invalid characters");
    }

    #[tokio::test]
    async fn rejects_invalid_format() {
        let result = default_tool()
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "list",
                "format": "xml"
            }))
            .await
            .expect("invalid format should return a result");
        assert_err_contains(&result, "Invalid format");
    }

    #[tokio::test]
    async fn rejects_wrong_type_params() {
        let result = default_tool()
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "list",
                "params": "not_an_object"
            }))
            .await
            .expect("wrong type params should return a result");
        assert_err_contains(&result, "'params' must be an object");
    }

    #[tokio::test]
    async fn rejects_wrong_type_body() {
        let result = default_tool()
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "create",
                "body": "not_an_object"
            }))
            .await
            .expect("wrong type body should return a result");
        assert_err_contains(&result, "'body' must be an object");
    }

    #[tokio::test]
    async fn rejects_wrong_type_page_all() {
        let result = default_tool()
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "list",
                "page_all": "yes"
            }))
            .await
            .expect("wrong type page_all should return a result");
        assert_err_contains(&result, "'page_all' must be a boolean");
    }

    #[tokio::test]
    async fn rejects_wrong_type_page_limit() {
        let result = default_tool()
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "list",
                "page_limit": "ten"
            }))
            .await
            .expect("wrong type page_limit should return a result");
        assert_err_contains(&result, "'page_limit' must be a non-negative integer");
    }

    #[tokio::test]
    async fn rejects_wrong_type_sub_resource() {
        let result = default_tool()
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "list",
                "sub_resource": 123
            }))
            .await
            .expect("wrong type sub_resource should return a result");
        assert_err_contains(&result, "'sub_resource' must be a string");
    }

    #[tokio::test]
    async fn missing_required_param_returns_error() {
        // Omitting "resource"/"method" is a hard error (Err), not a failed
        // ToolResult.
        let result = default_tool().execute(json!({"service": "drive"})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn rate_limited_returns_error() {
        // A zero action budget must trip the rate limiter before execution.
        let security = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            max_actions_per_hour: 0,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        });
        let tool = GoogleWorkspaceTool::new(security, vec![], None, None, 60, 30, false);
        let result = tool
            .execute(json!({
                "service": "drive",
                "resource": "files",
                "method": "list"
            }))
            .await
            .expect("rate-limited should return a result");
        assert_err_contains(&result, "Rate limit");
    }

    #[test]
    fn gws_timeout_is_reasonable() {
        assert_eq!(DEFAULT_GWS_TIMEOUT_SECS, 30);
    }
}
|
||||
@@ -161,8 +161,7 @@ impl DeferredMcpToolSet {
|
||||
/// The agent loop consults this each iteration to decide which tool_specs
|
||||
/// to include in the LLM request.
|
||||
pub struct ActivatedToolSet {
|
||||
/// name -> activated Tool
|
||||
tools: HashMap<String, Box<dyn Tool>>,
|
||||
tools: HashMap<String, Arc<dyn Tool>>,
|
||||
}
|
||||
|
||||
impl ActivatedToolSet {
|
||||
@@ -172,27 +171,23 @@ impl ActivatedToolSet {
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark a tool as activated, storing its live wrapper.
|
||||
pub fn activate(&mut self, name: String, tool: Box<dyn Tool>) {
|
||||
pub fn activate(&mut self, name: String, tool: Arc<dyn Tool>) {
|
||||
self.tools.insert(name, tool);
|
||||
}
|
||||
|
||||
/// Whether a tool has been activated.
|
||||
pub fn is_activated(&self, name: &str) -> bool {
|
||||
self.tools.contains_key(name)
|
||||
}
|
||||
|
||||
/// Get an activated tool for execution.
|
||||
pub fn get(&self, name: &str) -> Option<&dyn Tool> {
|
||||
self.tools.get(name).map(|t| t.as_ref())
|
||||
/// Clone the Arc so the caller can drop the mutex guard before awaiting.
|
||||
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
|
||||
self.tools.get(name).cloned()
|
||||
}
|
||||
|
||||
/// All currently activated tool specs (to include in LLM requests).
|
||||
pub fn tool_specs(&self) -> Vec<ToolSpec> {
|
||||
self.tools.values().map(|t| t.spec()).collect()
|
||||
}
|
||||
|
||||
/// All activated tools for execution dispatch.
|
||||
pub fn tool_names(&self) -> Vec<&str> {
|
||||
self.tools.keys().map(|s| s.as_str()).collect()
|
||||
}
|
||||
@@ -280,7 +275,7 @@ mod tests {
|
||||
|
||||
let mut set = ActivatedToolSet::new();
|
||||
assert!(!set.is_activated("fake"));
|
||||
set.activate("fake".into(), Box::new(FakeTool));
|
||||
set.activate("fake".into(), Arc::new(FakeTool));
|
||||
assert!(set.is_activated("fake"));
|
||||
assert!(set.get("fake").is_some());
|
||||
assert_eq!(set.tool_specs().len(), 1);
|
||||
|
||||
@@ -37,6 +37,7 @@ pub mod file_read;
|
||||
pub mod file_write;
|
||||
pub mod git_operations;
|
||||
pub mod glob_search;
|
||||
pub mod google_workspace;
|
||||
#[cfg(feature = "hardware")]
|
||||
pub mod hardware_board_info;
|
||||
#[cfg(feature = "hardware")]
|
||||
@@ -96,6 +97,7 @@ pub use file_read::FileReadTool;
|
||||
pub use file_write::FileWriteTool;
|
||||
pub use git_operations::GitOperationsTool;
|
||||
pub use glob_search::GlobSearchTool;
|
||||
pub use google_workspace::GoogleWorkspaceTool;
|
||||
#[cfg(feature = "hardware")]
|
||||
pub use hardware_board_info::HardwareBoardInfoTool;
|
||||
#[cfg(feature = "hardware")]
|
||||
@@ -433,6 +435,23 @@ pub fn all_tools_with_runtime(
|
||||
tool_arcs.push(Arc::new(CloudPatternsTool::new()));
|
||||
}
|
||||
|
||||
// Google Workspace CLI (gws) integration — requires shell access
|
||||
if root_config.google_workspace.enabled && has_shell_access {
|
||||
tool_arcs.push(Arc::new(GoogleWorkspaceTool::new(
|
||||
security.clone(),
|
||||
root_config.google_workspace.allowed_services.clone(),
|
||||
root_config.google_workspace.credentials_path.clone(),
|
||||
root_config.google_workspace.default_account.clone(),
|
||||
root_config.google_workspace.rate_limit_per_minute,
|
||||
root_config.google_workspace.timeout_secs,
|
||||
root_config.google_workspace.audit_log,
|
||||
)));
|
||||
} else if root_config.google_workspace.enabled {
|
||||
tracing::warn!(
|
||||
"google_workspace: skipped registration because shell access is unavailable"
|
||||
);
|
||||
}
|
||||
|
||||
// PDF extraction (feature-gated at compile time via rag-pdf)
|
||||
tool_arcs.push(Arc::new(PdfReadTool::new(security.clone())));
|
||||
|
||||
|
||||
@@ -940,10 +940,10 @@ impl Tool for ModelRoutingConfigTool {
|
||||
}
|
||||
|
||||
match action.as_str() {
|
||||
"set_default" => self.handle_set_default(&args).await,
|
||||
"upsert_scenario" => self.handle_upsert_scenario(&args).await,
|
||||
"set_default" => Box::pin(self.handle_set_default(&args)).await,
|
||||
"upsert_scenario" => Box::pin(self.handle_upsert_scenario(&args)).await,
|
||||
"remove_scenario" => self.handle_remove_scenario(&args).await,
|
||||
"upsert_agent" => self.handle_upsert_agent(&args).await,
|
||||
"upsert_agent" => Box::pin(self.handle_upsert_agent(&args)).await,
|
||||
"remove_agent" => self.handle_remove_agent(&args).await,
|
||||
_ => unreachable!("validated above"),
|
||||
}
|
||||
|
||||
@@ -412,7 +412,7 @@ impl Tool for ProxyConfigTool {
|
||||
}
|
||||
|
||||
match action.as_str() {
|
||||
"set" => self.handle_set(&args).await,
|
||||
"set" => Box::pin(self.handle_set(&args)).await,
|
||||
"disable" => self.handle_disable(&args).await,
|
||||
"apply_env" => self.handle_apply_env(),
|
||||
"clear_env" => self.handle_clear_env(),
|
||||
|
||||
@@ -107,7 +107,7 @@ impl Tool for ToolSearchTool {
|
||||
if let Some(spec) = self.deferred.tool_spec(&stub.prefixed_name) {
|
||||
if !guard.is_activated(&stub.prefixed_name) {
|
||||
if let Some(tool) = self.deferred.activate(&stub.prefixed_name) {
|
||||
guard.activate(stub.prefixed_name.clone(), tool);
|
||||
guard.activate(stub.prefixed_name.clone(), Arc::from(tool));
|
||||
activated_count += 1;
|
||||
}
|
||||
}
|
||||
@@ -152,7 +152,7 @@ impl ToolSearchTool {
|
||||
Some(spec) => {
|
||||
if !guard.is_activated(name) {
|
||||
if let Some(tool) = self.deferred.activate(name) {
|
||||
guard.activate(name.to_string(), tool);
|
||||
guard.activate(name.to_string(), Arc::from(tool));
|
||||
activated_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user