Merge remote-tracking branch 'origin/main'
This commit is contained in:
+350
-118
@@ -60,6 +60,20 @@ const DEFAULT_MAX_TOOL_ITERATIONS: usize = 20;
|
||||
/// Matches the channel-side constant in `channels/mod.rs`.
|
||||
const AUTOSAVE_MIN_MESSAGE_CHARS: usize = 20;
|
||||
|
||||
fn should_treat_provider_as_vision_capable(provider_name: &str, provider: &dyn Provider) -> bool {
|
||||
if provider.supports_vision() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Guardrail for issue #2107: some anthropic setups have reported false
|
||||
// negatives from provider capability probing even though Claude models
|
||||
// accept image inputs. Keep the preflight permissive for anthropic routes
|
||||
// and rely on upstream API validation if a specific model cannot handle
|
||||
// vision.
|
||||
let normalized = provider_name.trim().to_ascii_lowercase();
|
||||
normalized == "anthropic" || normalized.starts_with("anthropic-custom:")
|
||||
}
|
||||
|
||||
/// Slash-command definitions for interactive-mode completion.
|
||||
/// Each entry: (trigger aliases, display label, description).
|
||||
const SLASH_COMMANDS: &[(&[&str], &str, &str)] = &[
|
||||
@@ -173,6 +187,43 @@ static SENSITIVE_KV_REGEX: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(r#"(?i)(token|api[_-]?key|password|secret|user[_-]?key|bearer|credential)["']?\s*[:=]\s*(?:"([^"]{8,})"|'([^']{8,})'|([a-zA-Z0-9_\-\.]{8,}))"#).unwrap()
|
||||
});
|
||||
|
||||
/// Detect "I'll do X" style deferred-action replies that often indicate a missing
|
||||
/// follow-up tool call in agentic flows.
|
||||
static DEFERRED_ACTION_WITHOUT_TOOL_CALL_REGEX: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(
|
||||
r"(?ix)
|
||||
\b(
|
||||
i(?:'ll|\s+will)|
|
||||
i\s+am\s+going\s+to|
|
||||
let\s+me|
|
||||
let(?:'s|\s+us)|
|
||||
we(?:'ll|\s+will)
|
||||
)\b
|
||||
[^.!?\n]{0,160}
|
||||
\b(
|
||||
check|look|search|browse|open|read|write|run|execute|call|
|
||||
inspect|analy(?:s|z)e|verify|list|fetch|try|see|continue
|
||||
)\b",
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
/// Detect common CJK deferred-action phrases (e.g., Chinese "让我…查看")
|
||||
/// that imply a follow-up tool call should occur.
|
||||
static CJK_DEFERRED_ACTION_CUE_REGEX: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r"(让我|我来|我会|我们来|我们会|我先|先让我|马上)").unwrap());
|
||||
|
||||
/// Action verbs commonly used when promising to perform tool-backed work in CJK text.
|
||||
static CJK_DEFERRED_ACTION_VERB_REGEX: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(r"(查看|检查|搜索|查找|浏览|打开|读取|写入|运行|执行|调用|分析|验证|列出|获取|尝试|试试|继续|处理|修复|看看|看一看|看一下)").unwrap()
|
||||
});
|
||||
|
||||
/// Fast check for CJK scripts (Han/Hiragana/Katakana/Hangul) so we only run
|
||||
/// additional regexes when non-Latin text is present.
|
||||
static CJK_SCRIPT_REGEX: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(r"[\p{Script=Han}\p{Script=Hiragana}\p{Script=Katakana}\p{Script=Hangul}]").unwrap()
|
||||
});
|
||||
|
||||
/// Scrub credentials from tool output to prevent accidental exfiltration.
|
||||
/// Replaces known credential patterns with a redacted placeholder while preserving
|
||||
/// a small prefix for context.
|
||||
@@ -241,6 +292,7 @@ const AUTO_CRON_DELIVERY_CHANNELS: &[&str] = &[
|
||||
|
||||
const NON_CLI_APPROVAL_WAIT_TIMEOUT_SECS: u64 = 300;
|
||||
const NON_CLI_APPROVAL_POLL_INTERVAL_MS: u64 = 250;
|
||||
const MISSING_TOOL_CALL_RETRY_PROMPT: &str = "Internal correction: your last reply indicated you were about to take an action, but no valid tool call was emitted. If a tool is needed, emit it now using the required <tool_call>...</tool_call> format. If no tool is needed, provide the complete final answer now and do not defer action.";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct NonCliApprovalPrompt {
|
||||
@@ -276,6 +328,21 @@ fn truncate_tool_args_for_progress(name: &str, args: &serde_json::Value, max_len
|
||||
}
|
||||
}
|
||||
|
||||
fn looks_like_deferred_action_without_tool_call(text: &str) -> bool {
|
||||
let trimmed = text.trim();
|
||||
if trimmed.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if DEFERRED_ACTION_WITHOUT_TOOL_CALL_REGEX.is_match(trimmed) {
|
||||
return true;
|
||||
}
|
||||
|
||||
CJK_SCRIPT_REGEX.is_match(trimmed)
|
||||
&& CJK_DEFERRED_ACTION_CUE_REGEX.is_match(trimmed)
|
||||
&& CJK_DEFERRED_ACTION_VERB_REGEX.is_match(trimmed)
|
||||
}
|
||||
|
||||
fn maybe_inject_cron_add_delivery(
|
||||
tool_name: &str,
|
||||
tool_args: &mut serde_json::Value,
|
||||
@@ -726,6 +793,8 @@ pub(crate) async fn run_tool_call_loop(
|
||||
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
|
||||
let turn_id = Uuid::new_v4().to_string();
|
||||
let mut seen_tool_signatures: HashSet<(String, String)> = HashSet::new();
|
||||
let mut missing_tool_call_retry_used = false;
|
||||
let mut missing_tool_call_retry_prompt: Option<String> = None;
|
||||
let bypass_non_cli_approval_for_turn =
|
||||
approval.is_some_and(|mgr| channel_name != "cli" && mgr.consume_non_cli_allow_all_once());
|
||||
if bypass_non_cli_approval_for_turn {
|
||||
@@ -750,7 +819,9 @@ pub(crate) async fn run_tool_call_loop(
|
||||
}
|
||||
|
||||
let image_marker_count = multimodal::count_image_markers(history);
|
||||
if image_marker_count > 0 && !provider.supports_vision() {
|
||||
let provider_supports_vision =
|
||||
should_treat_provider_as_vision_capable(provider_name, provider);
|
||||
if image_marker_count > 0 && !provider_supports_vision {
|
||||
return Err(ProviderCapabilityError {
|
||||
provider: provider_name.to_string(),
|
||||
capability: "vision".to_string(),
|
||||
@@ -763,6 +834,10 @@ pub(crate) async fn run_tool_call_loop(
|
||||
|
||||
let prepared_messages =
|
||||
multimodal::prepare_messages_for_provider(history, multimodal_config).await?;
|
||||
let mut request_messages = prepared_messages.messages.clone();
|
||||
if let Some(prompt) = missing_tool_call_retry_prompt.take() {
|
||||
request_messages.push(ChatMessage::user(prompt));
|
||||
}
|
||||
|
||||
// ── Progress: LLM thinking ────────────────────────────
|
||||
if let Some(ref tx) = on_delta {
|
||||
@@ -810,7 +885,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||
|
||||
let chat_future = provider.chat(
|
||||
ChatRequest {
|
||||
messages: &prepared_messages.messages,
|
||||
messages: &request_messages,
|
||||
tools: request_tools,
|
||||
},
|
||||
model,
|
||||
@@ -826,138 +901,145 @@ pub(crate) async fn run_tool_call_loop(
|
||||
chat_future.await
|
||||
};
|
||||
|
||||
let (response_text, parsed_text, tool_calls, assistant_history_content, native_tool_calls) =
|
||||
match chat_result {
|
||||
Ok(resp) => {
|
||||
let (resp_input_tokens, resp_output_tokens) = resp
|
||||
.usage
|
||||
.as_ref()
|
||||
.map(|u| (u.input_tokens, u.output_tokens))
|
||||
.unwrap_or((None, None));
|
||||
let (
|
||||
response_text,
|
||||
parsed_text,
|
||||
tool_calls,
|
||||
assistant_history_content,
|
||||
native_tool_calls,
|
||||
parse_issue_detected,
|
||||
) = match chat_result {
|
||||
Ok(resp) => {
|
||||
let (resp_input_tokens, resp_output_tokens) = resp
|
||||
.usage
|
||||
.as_ref()
|
||||
.map(|u| (u.input_tokens, u.output_tokens))
|
||||
.unwrap_or((None, None));
|
||||
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: true,
|
||||
error_message: None,
|
||||
input_tokens: resp_input_tokens,
|
||||
output_tokens: resp_output_tokens,
|
||||
});
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: true,
|
||||
error_message: None,
|
||||
input_tokens: resp_input_tokens,
|
||||
output_tokens: resp_output_tokens,
|
||||
});
|
||||
|
||||
let response_text = resp.text_or_empty().to_string();
|
||||
// First try native structured tool calls (OpenAI-format).
|
||||
// Fall back to text-based parsing (XML tags, markdown blocks,
|
||||
// GLM format) only if the provider returned no native calls —
|
||||
// this ensures we support both native and prompt-guided models.
|
||||
let mut calls = parse_structured_tool_calls(&resp.tool_calls);
|
||||
let mut parsed_text = String::new();
|
||||
let response_text = resp.text_or_empty().to_string();
|
||||
// First try native structured tool calls (OpenAI-format).
|
||||
// Fall back to text-based parsing (XML tags, markdown blocks,
|
||||
// GLM format) only if the provider returned no native calls —
|
||||
// this ensures we support both native and prompt-guided models.
|
||||
let mut calls = parse_structured_tool_calls(&resp.tool_calls);
|
||||
let mut parsed_text = String::new();
|
||||
|
||||
if calls.is_empty() {
|
||||
let (fallback_text, fallback_calls) = parse_tool_calls(&response_text);
|
||||
if !fallback_text.is_empty() {
|
||||
parsed_text = fallback_text;
|
||||
}
|
||||
calls = fallback_calls;
|
||||
if calls.is_empty() {
|
||||
let (fallback_text, fallback_calls) = parse_tool_calls(&response_text);
|
||||
if !fallback_text.is_empty() {
|
||||
parsed_text = fallback_text;
|
||||
}
|
||||
|
||||
if let Some(parse_issue) = detect_tool_call_parse_issue(&response_text, &calls)
|
||||
{
|
||||
runtime_trace::record_event(
|
||||
"tool_call_parse_issue",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&parse_issue),
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"response_excerpt": truncate_with_ellipsis(
|
||||
&scrub_credentials(&response_text),
|
||||
600
|
||||
),
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
runtime_trace::record_event(
|
||||
"llm_response",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
None,
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"duration_ms": llm_started_at.elapsed().as_millis(),
|
||||
"input_tokens": resp_input_tokens,
|
||||
"output_tokens": resp_output_tokens,
|
||||
"raw_response": scrub_credentials(&response_text),
|
||||
"native_tool_calls": resp.tool_calls.len(),
|
||||
"parsed_tool_calls": calls.len(),
|
||||
}),
|
||||
);
|
||||
|
||||
// Preserve native tool call IDs in assistant history so role=tool
|
||||
// follow-up messages can reference the exact call id.
|
||||
let reasoning_content = resp.reasoning_content.clone();
|
||||
let assistant_history_content = if resp.tool_calls.is_empty() {
|
||||
if use_native_tools {
|
||||
build_native_assistant_history_from_parsed_calls(
|
||||
&response_text,
|
||||
&calls,
|
||||
reasoning_content.as_deref(),
|
||||
)
|
||||
.unwrap_or_else(|| response_text.clone())
|
||||
} else {
|
||||
response_text.clone()
|
||||
}
|
||||
} else {
|
||||
build_native_assistant_history(
|
||||
&response_text,
|
||||
&resp.tool_calls,
|
||||
reasoning_content.as_deref(),
|
||||
)
|
||||
};
|
||||
|
||||
let native_calls = resp.tool_calls;
|
||||
(
|
||||
response_text,
|
||||
parsed_text,
|
||||
calls,
|
||||
assistant_history_content,
|
||||
native_calls,
|
||||
)
|
||||
calls = fallback_calls;
|
||||
}
|
||||
Err(e) => {
|
||||
let safe_error = crate::providers::sanitize_api_error(&e.to_string());
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: false,
|
||||
error_message: Some(safe_error.clone()),
|
||||
input_tokens: None,
|
||||
output_tokens: None,
|
||||
});
|
||||
|
||||
let parse_issue = detect_tool_call_parse_issue(&response_text, &calls);
|
||||
if let Some(parse_issue) = parse_issue.as_deref() {
|
||||
runtime_trace::record_event(
|
||||
"llm_response",
|
||||
"tool_call_parse_issue",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&safe_error),
|
||||
Some(&parse_issue),
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"duration_ms": llm_started_at.elapsed().as_millis(),
|
||||
"response_excerpt": truncate_with_ellipsis(
|
||||
&scrub_credentials(&response_text),
|
||||
600
|
||||
),
|
||||
}),
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
runtime_trace::record_event(
|
||||
"llm_response",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
None,
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"duration_ms": llm_started_at.elapsed().as_millis(),
|
||||
"input_tokens": resp_input_tokens,
|
||||
"output_tokens": resp_output_tokens,
|
||||
"raw_response": scrub_credentials(&response_text),
|
||||
"native_tool_calls": resp.tool_calls.len(),
|
||||
"parsed_tool_calls": calls.len(),
|
||||
}),
|
||||
);
|
||||
|
||||
// Preserve native tool call IDs in assistant history so role=tool
|
||||
// follow-up messages can reference the exact call id.
|
||||
let reasoning_content = resp.reasoning_content.clone();
|
||||
let assistant_history_content = if resp.tool_calls.is_empty() {
|
||||
if use_native_tools {
|
||||
build_native_assistant_history_from_parsed_calls(
|
||||
&response_text,
|
||||
&calls,
|
||||
reasoning_content.as_deref(),
|
||||
)
|
||||
.unwrap_or_else(|| response_text.clone())
|
||||
} else {
|
||||
response_text.clone()
|
||||
}
|
||||
} else {
|
||||
build_native_assistant_history(
|
||||
&response_text,
|
||||
&resp.tool_calls,
|
||||
reasoning_content.as_deref(),
|
||||
)
|
||||
};
|
||||
|
||||
let native_calls = resp.tool_calls;
|
||||
(
|
||||
response_text,
|
||||
parsed_text,
|
||||
calls,
|
||||
assistant_history_content,
|
||||
native_calls,
|
||||
parse_issue.is_some(),
|
||||
)
|
||||
}
|
||||
Err(e) => {
|
||||
let safe_error = crate::providers::sanitize_api_error(&e.to_string());
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: false,
|
||||
error_message: Some(safe_error.clone()),
|
||||
input_tokens: None,
|
||||
output_tokens: None,
|
||||
});
|
||||
runtime_trace::record_event(
|
||||
"llm_response",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(&turn_id),
|
||||
Some(false),
|
||||
Some(&safe_error),
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"duration_ms": llm_started_at.elapsed().as_millis(),
|
||||
}),
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
let display_text = if parsed_text.is_empty() {
|
||||
response_text.clone()
|
||||
@@ -979,6 +1061,46 @@ pub(crate) async fn run_tool_call_loop(
|
||||
}
|
||||
|
||||
if tool_calls.is_empty() {
|
||||
let missing_tool_call_followthrough = !missing_tool_call_retry_used
|
||||
&& iteration + 1 < max_iterations
|
||||
&& !tool_specs.is_empty()
|
||||
&& (parse_issue_detected
|
||||
|| looks_like_deferred_action_without_tool_call(&display_text));
|
||||
if missing_tool_call_followthrough {
|
||||
missing_tool_call_retry_used = true;
|
||||
missing_tool_call_retry_prompt = Some(MISSING_TOOL_CALL_RETRY_PROMPT.to_string());
|
||||
let retry_reason = if parse_issue_detected {
|
||||
"parse_issue_detected"
|
||||
} else {
|
||||
"deferred_action_text_detected"
|
||||
};
|
||||
|
||||
runtime_trace::record_event(
|
||||
"tool_call_followthrough_retry",
|
||||
Some(channel_name),
|
||||
Some(provider_name),
|
||||
Some(model),
|
||||
Some(&turn_id),
|
||||
Some(true),
|
||||
Some("llm response implied follow-up action but emitted no tool call"),
|
||||
serde_json::json!({
|
||||
"iteration": iteration + 1,
|
||||
"reason": retry_reason,
|
||||
"response_excerpt": truncate_with_ellipsis(&scrub_credentials(&display_text), 600),
|
||||
}),
|
||||
);
|
||||
|
||||
if let Some(ref tx) = on_delta {
|
||||
let _ = tx
|
||||
.send(format!(
|
||||
"{DRAFT_PROGRESS_SENTINEL}\u{21bb} Retrying: response deferred action without a tool call\n"
|
||||
))
|
||||
.await;
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
runtime_trace::record_event(
|
||||
"turn_final_response",
|
||||
Some(channel_name),
|
||||
@@ -2641,6 +2763,39 @@ mod tests {
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_allows_anthropic_route_on_vision_probe_false_negative() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec!["vision-ok"]);
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"please inspect [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"anthropic",
|
||||
"opus-4-6",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
3,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.expect("anthropic route should not fail on a false-negative vision capability probe");
|
||||
|
||||
assert_eq!(result, "vision-ok");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_rejects_oversized_image_payload() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
@@ -3249,6 +3404,57 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_retries_once_when_response_defers_action_without_tool_call() {
|
||||
let provider = ScriptedProvider::from_text_responses(vec![
|
||||
"I'll check that right away.",
|
||||
r#"<tool_call>
|
||||
{"name":"count_tool","arguments":{"value":"retry"}}
|
||||
</tool_call>"#,
|
||||
"done after tool",
|
||||
]);
|
||||
|
||||
let invocations = Arc::new(AtomicUsize::new(0));
|
||||
let tools_registry: Vec<Box<dyn Tool>> = vec![Box::new(CountingTool::new(
|
||||
"count_tool",
|
||||
Arc::clone(&invocations),
|
||||
))];
|
||||
|
||||
let mut history = vec![
|
||||
ChatMessage::system("test-system"),
|
||||
ChatMessage::user("please check the workspace"),
|
||||
];
|
||||
let observer = NoopObserver;
|
||||
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
5,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.expect("loop should recover after one deferred-action reply");
|
||||
|
||||
assert_eq!(result, "done after tool");
|
||||
assert_eq!(
|
||||
invocations.load(Ordering::SeqCst),
|
||||
1,
|
||||
"the fallback retry should lead to an actual tool execution"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tool_calls_extracts_single_call() {
|
||||
let response = r#"Let me check that.
|
||||
@@ -4148,6 +4354,32 @@ Done."#;
|
||||
assert!(issue.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn looks_like_deferred_action_without_tool_call_detects_action_promises() {
|
||||
assert!(looks_like_deferred_action_without_tool_call(
|
||||
"Webpage opened, let's see what's new here."
|
||||
));
|
||||
assert!(looks_like_deferred_action_without_tool_call(
|
||||
"It seems absolute paths are blocked. Let me try using a relative path."
|
||||
));
|
||||
assert!(looks_like_deferred_action_without_tool_call(
|
||||
"看起来绝对路径不可用,让我尝试使用当前目录的相对路径。"
|
||||
));
|
||||
assert!(looks_like_deferred_action_without_tool_call(
|
||||
"页面已打开,让我获取快照查看详细信息。"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn looks_like_deferred_action_without_tool_call_ignores_final_answers() {
|
||||
assert!(!looks_like_deferred_action_without_tool_call(
|
||||
"The latest update is already shown above."
|
||||
));
|
||||
assert!(!looks_like_deferred_action_without_tool_call(
|
||||
"最新结果已经在上面整理完成。"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tool_calls_handles_whitespace_only_name() {
|
||||
// Recovery: Whitespace-only tool name should return None
|
||||
|
||||
@@ -115,6 +115,13 @@ impl PromptSection for IdentitySection {
|
||||
inject_workspace_file(&mut prompt, ctx.workspace_dir, "MEMORY.md");
|
||||
}
|
||||
|
||||
let extra_files = ctx.identity_config.map_or(&[][..], |cfg| cfg.extra_files.as_slice());
|
||||
for file in extra_files {
|
||||
if let Some(safe_relative) = normalize_openclaw_identity_extra_file(file) {
|
||||
inject_workspace_file(&mut prompt, ctx.workspace_dir, safe_relative);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
}
|
||||
@@ -260,6 +267,29 @@ fn inject_workspace_file(prompt: &mut String, workspace_dir: &Path, filename: &s
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_openclaw_identity_extra_file(raw: &str) -> Option<&str> {
|
||||
use std::path::{Component, Path};
|
||||
|
||||
let trimmed = raw.trim();
|
||||
if trimmed.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let path = Path::new(trimmed);
|
||||
if path.is_absolute() {
|
||||
return None;
|
||||
}
|
||||
|
||||
for component in path.components() {
|
||||
match component {
|
||||
Component::Normal(_) | Component::CurDir => {}
|
||||
Component::ParentDir | Component::RootDir | Component::Prefix(_) => return None,
|
||||
}
|
||||
}
|
||||
|
||||
Some(trimmed)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -307,6 +337,7 @@ mod tests {
|
||||
|
||||
let identity_config = crate::config::IdentityConfig {
|
||||
format: "aieos".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: None,
|
||||
aieos_inline: Some(r#"{"identity":{"names":{"first":"Nova"}}}"#.into()),
|
||||
};
|
||||
@@ -337,6 +368,96 @@ mod tests {
|
||||
let _ = std::fs::remove_dir_all(workspace);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_section_openclaw_injects_extra_files() {
|
||||
let workspace = std::env::temp_dir().join(format!(
|
||||
"zeroclaw_prompt_extra_files_test_{}",
|
||||
uuid::Uuid::new_v4()
|
||||
));
|
||||
std::fs::create_dir_all(workspace.join("memory")).unwrap();
|
||||
std::fs::write(workspace.join("AGENTS.md"), "agent baseline").unwrap();
|
||||
std::fs::write(workspace.join("SOUL.md"), "soul baseline").unwrap();
|
||||
std::fs::write(workspace.join("TOOLS.md"), "tools baseline").unwrap();
|
||||
std::fs::write(workspace.join("IDENTITY.md"), "identity baseline").unwrap();
|
||||
std::fs::write(workspace.join("USER.md"), "user baseline").unwrap();
|
||||
std::fs::write(workspace.join("FRAMEWORK.md"), "framework context").unwrap();
|
||||
std::fs::write(workspace.join("memory").join("notes.md"), "memory notes").unwrap();
|
||||
|
||||
let identity_config = crate::config::IdentityConfig {
|
||||
format: "openclaw".into(),
|
||||
extra_files: vec!["FRAMEWORK.md".into(), "memory/notes.md".into()],
|
||||
aieos_path: None,
|
||||
aieos_inline: None,
|
||||
};
|
||||
|
||||
let tools: Vec<Box<dyn Tool>> = vec![];
|
||||
let ctx = PromptContext {
|
||||
workspace_dir: &workspace,
|
||||
model_name: "test-model",
|
||||
tools: &tools,
|
||||
skills: &[],
|
||||
skills_prompt_mode: crate::config::SkillsPromptInjectionMode::Full,
|
||||
identity_config: Some(&identity_config),
|
||||
dispatcher_instructions: "",
|
||||
};
|
||||
|
||||
let section = IdentitySection;
|
||||
let output = section.build(&ctx).unwrap();
|
||||
|
||||
assert!(output.contains("### FRAMEWORK.md"));
|
||||
assert!(output.contains("framework context"));
|
||||
assert!(output.contains("### memory/notes.md"));
|
||||
assert!(output.contains("memory notes"));
|
||||
|
||||
let _ = std::fs::remove_dir_all(workspace);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn identity_section_openclaw_rejects_unsafe_extra_files() {
|
||||
let workspace = std::env::temp_dir().join(format!(
|
||||
"zeroclaw_prompt_extra_files_unsafe_test_{}",
|
||||
uuid::Uuid::new_v4()
|
||||
));
|
||||
std::fs::create_dir_all(&workspace).unwrap();
|
||||
std::fs::write(workspace.join("AGENTS.md"), "agent baseline").unwrap();
|
||||
std::fs::write(workspace.join("SOUL.md"), "soul baseline").unwrap();
|
||||
std::fs::write(workspace.join("TOOLS.md"), "tools baseline").unwrap();
|
||||
std::fs::write(workspace.join("IDENTITY.md"), "identity baseline").unwrap();
|
||||
std::fs::write(workspace.join("USER.md"), "user baseline").unwrap();
|
||||
std::fs::write(workspace.join("SAFE.md"), "safe context").unwrap();
|
||||
|
||||
let identity_config = crate::config::IdentityConfig {
|
||||
format: "openclaw".into(),
|
||||
extra_files: vec![
|
||||
"SAFE.md".into(),
|
||||
"../outside.md".into(),
|
||||
"/tmp/absolute.md".into(),
|
||||
],
|
||||
aieos_path: None,
|
||||
aieos_inline: None,
|
||||
};
|
||||
|
||||
let tools: Vec<Box<dyn Tool>> = vec![];
|
||||
let ctx = PromptContext {
|
||||
workspace_dir: &workspace,
|
||||
model_name: "test-model",
|
||||
tools: &tools,
|
||||
skills: &[],
|
||||
skills_prompt_mode: crate::config::SkillsPromptInjectionMode::Full,
|
||||
identity_config: Some(&identity_config),
|
||||
dispatcher_instructions: "",
|
||||
};
|
||||
|
||||
let section = IdentitySection;
|
||||
let output = section.build(&ctx).unwrap();
|
||||
|
||||
assert!(output.contains("### SAFE.md"));
|
||||
assert!(!output.contains("outside.md"));
|
||||
assert!(!output.contains("absolute.md"));
|
||||
|
||||
let _ = std::fs::remove_dir_all(workspace);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_builder_assembles_sections() {
|
||||
let tools: Vec<Box<dyn Tool>> = vec![Box::new(TestTool)];
|
||||
|
||||
+156
-4
@@ -130,6 +130,36 @@ const CHANNEL_HOOK_MAX_OUTBOUND_CHARS: usize = 20_000;
|
||||
type ProviderCacheMap = Arc<Mutex<HashMap<String, Arc<dyn Provider>>>>;
|
||||
type RouteSelectionMap = Arc<Mutex<HashMap<String, ChannelRouteSelection>>>;
|
||||
|
||||
fn live_channels_registry() -> &'static Mutex<HashMap<String, Arc<dyn Channel>>> {
|
||||
static REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn Channel>>>> = OnceLock::new();
|
||||
REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
|
||||
}
|
||||
|
||||
fn register_live_channels(channels_by_name: &HashMap<String, Arc<dyn Channel>>) {
|
||||
let mut guard = live_channels_registry()
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
guard.clear();
|
||||
for (name, channel) in channels_by_name {
|
||||
guard.insert(name.to_ascii_lowercase(), Arc::clone(channel));
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_live_channels() {
|
||||
live_channels_registry()
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn get_live_channel(name: &str) -> Option<Arc<dyn Channel>> {
|
||||
live_channels_registry()
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.get(&name.to_ascii_lowercase())
|
||||
.cloned()
|
||||
}
|
||||
|
||||
fn effective_channel_message_timeout_secs(configured: u64) -> u64 {
|
||||
configured.max(MIN_CHANNEL_MESSAGE_TIMEOUT_SECS)
|
||||
}
|
||||
@@ -1079,6 +1109,11 @@ async fn load_runtime_defaults_from_config_file(
|
||||
if let Some(zeroclaw_dir) = path.parent() {
|
||||
let store = crate::security::SecretStore::new(zeroclaw_dir, parsed.secrets.encrypt);
|
||||
decrypt_optional_secret_for_runtime_reload(&store, &mut parsed.api_key, "config.api_key")?;
|
||||
decrypt_optional_secret_for_runtime_reload(
|
||||
&store,
|
||||
&mut parsed.transcription.api_key,
|
||||
"config.transcription.api_key",
|
||||
)?;
|
||||
}
|
||||
|
||||
parsed.apply_env_overrides();
|
||||
@@ -3815,6 +3850,7 @@ fn load_openclaw_bootstrap_files(
|
||||
prompt: &mut String,
|
||||
workspace_dir: &std::path::Path,
|
||||
max_chars_per_file: usize,
|
||||
identity_config: Option<&crate::config::IdentityConfig>,
|
||||
) {
|
||||
prompt.push_str(
|
||||
"The following workspace files define your identity, behavior, and context. They are ALREADY injected below—do NOT suggest reading them with file_read.\n\n",
|
||||
@@ -3837,6 +3873,44 @@ fn load_openclaw_bootstrap_files(
|
||||
if memory_path.exists() {
|
||||
inject_workspace_file(prompt, workspace_dir, "MEMORY.md", max_chars_per_file);
|
||||
}
|
||||
|
||||
let extra_files = identity_config.map_or(&[][..], |cfg| cfg.extra_files.as_slice());
|
||||
for file in extra_files {
|
||||
match normalize_openclaw_identity_extra_file(file) {
|
||||
Some(safe_relative) => {
|
||||
inject_workspace_file(prompt, workspace_dir, safe_relative, max_chars_per_file);
|
||||
}
|
||||
None => {
|
||||
tracing::warn!(
|
||||
file = file.as_str(),
|
||||
"Ignoring unsafe identity.extra_files entry; expected workspace-relative path without traversal"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_openclaw_identity_extra_file(raw: &str) -> Option<&str> {
|
||||
use std::path::{Component, Path};
|
||||
|
||||
let trimmed = raw.trim();
|
||||
if trimmed.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let path = Path::new(trimmed);
|
||||
if path.is_absolute() {
|
||||
return None;
|
||||
}
|
||||
|
||||
for component in path.components() {
|
||||
match component {
|
||||
Component::Normal(_) | Component::CurDir => {}
|
||||
Component::ParentDir | Component::RootDir | Component::Prefix(_) => return None,
|
||||
}
|
||||
}
|
||||
|
||||
Some(trimmed)
|
||||
}
|
||||
|
||||
/// Load workspace identity files and build a system prompt.
|
||||
@@ -3982,7 +4056,12 @@ pub fn build_system_prompt_with_mode(
|
||||
// No AIEOS identity loaded (shouldn't happen if is_aieos_configured returned true)
|
||||
// Fall back to OpenClaw bootstrap files
|
||||
let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS);
|
||||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars);
|
||||
load_openclaw_bootstrap_files(
|
||||
&mut prompt,
|
||||
workspace_dir,
|
||||
max_chars,
|
||||
identity_config,
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
// Log error but don't fail - fall back to OpenClaw
|
||||
@@ -3990,18 +4069,23 @@ pub fn build_system_prompt_with_mode(
|
||||
"Warning: Failed to load AIEOS identity: {e}. Using OpenClaw format."
|
||||
);
|
||||
let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS);
|
||||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars);
|
||||
load_openclaw_bootstrap_files(
|
||||
&mut prompt,
|
||||
workspace_dir,
|
||||
max_chars,
|
||||
identity_config,
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// OpenClaw format
|
||||
let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS);
|
||||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars);
|
||||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars, identity_config);
|
||||
}
|
||||
} else {
|
||||
// No identity config - use OpenClaw format
|
||||
let max_chars = bootstrap_max_chars.unwrap_or(BOOTSTRAP_MAX_CHARS);
|
||||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars);
|
||||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir, max_chars, identity_config);
|
||||
}
|
||||
|
||||
// ── 6. Date & Time ──────────────────────────────────────────
|
||||
@@ -4714,6 +4798,9 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
|
||||
/// Start all configured channels and route messages to the agent
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn start_channels(config: Config) -> Result<()> {
|
||||
// Ensure stale channel handles are never reused across restarts.
|
||||
clear_live_channels();
|
||||
|
||||
let provider_name = resolved_default_provider(&config);
|
||||
let provider_runtime_options = providers::ProviderRuntimeOptions {
|
||||
auth_profile_override: None,
|
||||
@@ -5039,6 +5126,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
.map(|ch| (ch.name().to_string(), Arc::clone(ch)))
|
||||
.collect::<HashMap<_, _>>(),
|
||||
);
|
||||
register_live_channels(channels_by_name.as_ref());
|
||||
let max_in_flight_messages = compute_max_in_flight_messages(channels.len());
|
||||
|
||||
println!(" 🚦 In-flight message limit: {max_in_flight_messages}");
|
||||
@@ -5113,6 +5201,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||
let _ = h.await;
|
||||
}
|
||||
|
||||
clear_live_channels();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -9965,6 +10055,7 @@ BTC is currently around $65,000 based on latest tool output."#;
|
||||
// Create identity config pointing to the file
|
||||
let config = IdentityConfig {
|
||||
format: "aieos".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: Some("aieos_identity.json".into()),
|
||||
aieos_inline: None,
|
||||
};
|
||||
@@ -9999,6 +10090,7 @@ BTC is currently around $65,000 based on latest tool output."#;
|
||||
|
||||
let config = IdentityConfig {
|
||||
format: "aieos".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: None,
|
||||
aieos_inline: Some(r#"{"identity":{"names":{"first":"Claw"}}}"#.into()),
|
||||
};
|
||||
@@ -10022,6 +10114,7 @@ BTC is currently around $65,000 based on latest tool output."#;
|
||||
|
||||
let config = IdentityConfig {
|
||||
format: "aieos".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: Some("nonexistent.json".into()),
|
||||
aieos_inline: None,
|
||||
};
|
||||
@@ -10041,6 +10134,7 @@ BTC is currently around $65,000 based on latest tool output."#;
|
||||
// Format is "aieos" but neither path nor inline is set
|
||||
let config = IdentityConfig {
|
||||
format: "aieos".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: None,
|
||||
aieos_inline: None,
|
||||
};
|
||||
@@ -10059,6 +10153,7 @@ BTC is currently around $65,000 based on latest tool output."#;
|
||||
|
||||
let config = IdentityConfig {
|
||||
format: "openclaw".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: Some("identity.json".into()),
|
||||
aieos_inline: None,
|
||||
};
|
||||
@@ -10072,6 +10167,63 @@ BTC is currently around $65,000 based on latest tool output."#;
|
||||
assert!(!prompt.contains("## Identity"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn openclaw_extra_files_are_injected() {
|
||||
use crate::config::IdentityConfig;
|
||||
|
||||
let ws = make_workspace();
|
||||
std::fs::write(
|
||||
ws.path().join("FRAMEWORK.md"),
|
||||
"# Framework\nSession-level context.",
|
||||
)
|
||||
.unwrap();
|
||||
std::fs::create_dir_all(ws.path().join("memory")).unwrap();
|
||||
std::fs::write(
|
||||
ws.path().join("memory").join("notes.md"),
|
||||
"# Notes\nSupplemental context.",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = IdentityConfig {
|
||||
format: "openclaw".into(),
|
||||
extra_files: vec!["FRAMEWORK.md".into(), "memory/notes.md".into()],
|
||||
aieos_path: None,
|
||||
aieos_inline: None,
|
||||
};
|
||||
|
||||
let prompt = build_system_prompt(ws.path(), "model", &[], &[], Some(&config), None);
|
||||
|
||||
assert!(prompt.contains("### FRAMEWORK.md"));
|
||||
assert!(prompt.contains("Session-level context."));
|
||||
assert!(prompt.contains("### memory/notes.md"));
|
||||
assert!(prompt.contains("Supplemental context."));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn openclaw_extra_files_reject_unsafe_paths() {
|
||||
use crate::config::IdentityConfig;
|
||||
|
||||
let ws = make_workspace();
|
||||
std::fs::write(ws.path().join("SAFE.md"), "safe").unwrap();
|
||||
|
||||
let config = IdentityConfig {
|
||||
format: "openclaw".into(),
|
||||
extra_files: vec![
|
||||
"SAFE.md".into(),
|
||||
"../outside.md".into(),
|
||||
"/tmp/absolute.md".into(),
|
||||
],
|
||||
aieos_path: None,
|
||||
aieos_inline: None,
|
||||
};
|
||||
|
||||
let prompt = build_system_prompt(ws.path(), "model", &[], &[], Some(&config), None);
|
||||
|
||||
assert!(prompt.contains("### SAFE.md"));
|
||||
assert!(!prompt.contains("outside.md"));
|
||||
assert!(!prompt.contains("absolute.md"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn none_identity_config_uses_openclaw() {
|
||||
let ws = make_workspace();
|
||||
|
||||
@@ -33,9 +33,14 @@ fn normalize_audio_filename(file_name: &str) -> String {
|
||||
|
||||
/// Transcribe audio bytes via a Whisper-compatible transcription API.
|
||||
///
|
||||
/// Returns the transcribed text on success. Requires `GROQ_API_KEY` in the
|
||||
/// environment. The caller is responsible for enforcing duration limits
|
||||
/// *before* downloading the file; this function enforces the byte-size cap.
|
||||
/// Returns the transcribed text on success.
|
||||
///
|
||||
/// Credential resolution order:
|
||||
/// 1. `config.transcription.api_key`
|
||||
/// 2. `GROQ_API_KEY` environment variable (backward compatibility)
|
||||
///
|
||||
/// The caller is responsible for enforcing duration limits *before* downloading
|
||||
/// the file; this function enforces the byte-size cap.
|
||||
pub async fn transcribe_audio(
|
||||
audio_data: Vec<u8>,
|
||||
file_name: &str,
|
||||
@@ -59,9 +64,21 @@ pub async fn transcribe_audio(
|
||||
)
|
||||
})?;
|
||||
|
||||
let api_key = std::env::var("GROQ_API_KEY").context(
|
||||
"GROQ_API_KEY environment variable is not set — required for voice transcription",
|
||||
)?;
|
||||
let api_key = config
|
||||
.api_key
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
.or_else(|| {
|
||||
std::env::var("GROQ_API_KEY")
|
||||
.ok()
|
||||
.map(|value| value.trim().to_string())
|
||||
.filter(|value| !value.is_empty())
|
||||
})
|
||||
.context(
|
||||
"Missing transcription API key: set [transcription].api_key or GROQ_API_KEY environment variable",
|
||||
)?;
|
||||
|
||||
let client = crate::config::build_runtime_proxy_client("transcription.groq");
|
||||
|
||||
@@ -125,7 +142,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn rejects_missing_api_key() {
|
||||
// Ensure the key is absent for this test
|
||||
// Ensure fallback env key is absent for this test.
|
||||
std::env::remove_var("GROQ_API_KEY");
|
||||
|
||||
let data = vec![0u8; 100];
|
||||
@@ -135,11 +152,29 @@ mod tests {
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("GROQ_API_KEY"),
|
||||
err.to_string().contains("transcription API key"),
|
||||
"expected missing-key error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn uses_config_api_key_without_groq_env() {
|
||||
std::env::remove_var("GROQ_API_KEY");
|
||||
|
||||
let data = vec![0u8; 100];
|
||||
let mut config = TranscriptionConfig::default();
|
||||
config.api_key = Some("transcription-key".to_string());
|
||||
|
||||
// Keep invalid extension so we fail before network, but after key resolution.
|
||||
let err = transcribe_audio(data, "recording.aac", &config)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("Unsupported audio format"),
|
||||
"expected unsupported-format error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mime_for_audio_maps_accepted_formats() {
|
||||
let cases = [
|
||||
|
||||
+9
-9
@@ -12,15 +12,15 @@ pub use schema::{
|
||||
DockerRuntimeConfig, EmbeddingRouteConfig, EstopConfig, FeishuConfig, GatewayConfig,
|
||||
GroupReplyConfig, GroupReplyMode, HardwareConfig, HardwareTransport, HeartbeatConfig,
|
||||
HooksConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig,
|
||||
MemoryConfig, ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig,
|
||||
NonCliNaturalLanguageApprovalMode, ObservabilityConfig, OtpChallengeDelivery, OtpConfig,
|
||||
OtpMethod, PeripheralBoardConfig, PeripheralsConfig, PerplexityFilterConfig, PluginEntryConfig,
|
||||
PluginsConfig, ProviderConfig, ProxyConfig, ProxyScope, QdrantConfig,
|
||||
QueryClassificationConfig, ReliabilityConfig, ResearchPhaseConfig, ResearchTrigger,
|
||||
ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig,
|
||||
SecretsConfig, SecurityConfig, SecurityRoleConfig, SkillsConfig, SkillsPromptInjectionMode,
|
||||
SlackConfig, StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode,
|
||||
SyscallAnomalyConfig, TelegramConfig, TranscriptionConfig, TunnelConfig, UrlAccessConfig,
|
||||
MemoryConfig, ModelRouteConfig, MultimodalConfig, NextcloudTalkConfig, ObservabilityConfig,
|
||||
OtpChallengeDelivery, OtpConfig, OtpMethod, PeripheralBoardConfig, PeripheralsConfig,
|
||||
NonCliNaturalLanguageApprovalMode, PerplexityFilterConfig, PluginEntryConfig, PluginsConfig,
|
||||
ProviderConfig, ProxyConfig, ProxyScope, QdrantConfig, QueryClassificationConfig,
|
||||
ReliabilityConfig, ResearchPhaseConfig, ResearchTrigger, ResourceLimitsConfig, RuntimeConfig,
|
||||
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig,
|
||||
SecurityRoleConfig, SkillsConfig, SkillsPromptInjectionMode, SlackConfig, StorageConfig,
|
||||
StorageProviderConfig, StorageProviderSection, StreamMode, SyscallAnomalyConfig,
|
||||
TelegramConfig, TranscriptionConfig, TunnelConfig, UrlAccessConfig,
|
||||
WasmCapabilityEscalationMode, WasmConfig, WasmModuleHashPolicy, WasmRuntimeConfig,
|
||||
WasmSecurityConfig, WebFetchConfig, WebSearchConfig, WebhookConfig,
|
||||
};
|
||||
|
||||
+42
-4
@@ -508,6 +508,11 @@ pub struct TranscriptionConfig {
|
||||
/// Enable voice transcription for channels that support it.
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
/// API key used for transcription requests.
|
||||
///
|
||||
/// If unset, runtime falls back to `GROQ_API_KEY` for backward compatibility.
|
||||
#[serde(default)]
|
||||
pub api_key: Option<String>,
|
||||
/// Whisper API endpoint URL.
|
||||
#[serde(default = "default_transcription_api_url")]
|
||||
pub api_url: String,
|
||||
@@ -526,6 +531,7 @@ impl Default for TranscriptionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
api_key: None,
|
||||
api_url: default_transcription_api_url(),
|
||||
model: default_transcription_model(),
|
||||
language: None,
|
||||
@@ -948,6 +954,11 @@ pub struct IdentityConfig {
|
||||
/// Identity format: "openclaw" (default) or "aieos"
|
||||
#[serde(default = "default_identity_format")]
|
||||
pub format: String,
|
||||
/// Additional workspace files injected for the OpenClaw identity format.
|
||||
///
|
||||
/// Paths are resolved relative to the workspace root.
|
||||
#[serde(default)]
|
||||
pub extra_files: Vec<String>,
|
||||
/// Path to AIEOS JSON file (relative to workspace)
|
||||
#[serde(default)]
|
||||
pub aieos_path: Option<String>,
|
||||
@@ -964,6 +975,7 @@ impl Default for IdentityConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
format: default_identity_format(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: None,
|
||||
aieos_inline: None,
|
||||
}
|
||||
@@ -2187,7 +2199,7 @@ impl Default for StorageProviderConfig {
|
||||
/// Controls conversation memory storage, embeddings, hybrid search, response caching,
|
||||
/// and memory snapshot/hydration.
|
||||
/// Configuration for Qdrant vector database backend (`[memory.qdrant]`).
|
||||
/// Used when `[memory].backend = "qdrant"`.
|
||||
/// Used when `[memory].backend = "qdrant"` or `"sqlite_qdrant_hybrid"`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct QdrantConfig {
|
||||
/// Qdrant server URL (e.g. "http://localhost:6333").
|
||||
@@ -2221,10 +2233,10 @@ impl Default for QdrantConfig {
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
pub struct MemoryConfig {
|
||||
/// "sqlite" | "lucid" | "postgres" | "qdrant" | "markdown" | "none" (`none` = explicit no-op memory)
|
||||
/// "sqlite" | "sqlite_qdrant_hybrid" | "lucid" | "postgres" | "qdrant" | "markdown" | "none" (`none` = explicit no-op memory)
|
||||
///
|
||||
/// `postgres` requires `[storage.provider.config]` with `db_url` (`dbURL` alias supported).
|
||||
/// `qdrant` uses `[memory.qdrant]` config or `QDRANT_URL` env var.
|
||||
/// `qdrant` and `sqlite_qdrant_hybrid` use `[memory.qdrant]` config or `QDRANT_URL` env var.
|
||||
pub backend: String,
|
||||
/// Auto-save user-stated conversation input to memory (assistant output is excluded)
|
||||
pub auto_save: bool,
|
||||
@@ -2297,7 +2309,7 @@ pub struct MemoryConfig {
|
||||
|
||||
// ── Qdrant backend options ─────────────────────────────────
|
||||
/// Configuration for Qdrant vector database backend.
|
||||
/// Only used when `backend = "qdrant"`.
|
||||
/// Used when `backend = "qdrant"` or `backend = "sqlite_qdrant_hybrid"`.
|
||||
#[serde(default)]
|
||||
pub qdrant: QdrantConfig,
|
||||
}
|
||||
@@ -5985,6 +5997,11 @@ impl Config {
|
||||
config.workspace_dir = workspace_dir;
|
||||
let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt);
|
||||
decrypt_optional_secret(&store, &mut config.api_key, "config.api_key")?;
|
||||
decrypt_optional_secret(
|
||||
&store,
|
||||
&mut config.transcription.api_key,
|
||||
"config.transcription.api_key",
|
||||
)?;
|
||||
decrypt_optional_secret(
|
||||
&store,
|
||||
&mut config.composio.api_key,
|
||||
@@ -6997,6 +7014,11 @@ impl Config {
|
||||
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt);
|
||||
|
||||
encrypt_optional_secret(&store, &mut config_to_save.api_key, "config.api_key")?;
|
||||
encrypt_optional_secret(
|
||||
&store,
|
||||
&mut config_to_save.transcription.api_key,
|
||||
"config.transcription.api_key",
|
||||
)?;
|
||||
encrypt_optional_secret(
|
||||
&store,
|
||||
&mut config_to_save.composio.api_key,
|
||||
@@ -8087,6 +8109,7 @@ tool_dispatcher = "xml"
|
||||
config.workspace_dir = dir.join("workspace");
|
||||
config.config_path = dir.join("config.toml");
|
||||
config.api_key = Some("root-credential".into());
|
||||
config.transcription.api_key = Some("transcription-credential".into());
|
||||
config.composio.api_key = Some("composio-credential".into());
|
||||
config.proxy.http_proxy = Some("http://user:pass@proxy.internal:8080".into());
|
||||
config.proxy.https_proxy = Some("https://user:pass@proxy.internal:8443".into());
|
||||
@@ -8134,6 +8157,15 @@ tool_dispatcher = "xml"
|
||||
assert!(crate::security::SecretStore::is_encrypted(root_encrypted));
|
||||
assert_eq!(store.decrypt(root_encrypted).unwrap(), "root-credential");
|
||||
|
||||
let transcription_encrypted = stored.transcription.api_key.as_deref().unwrap();
|
||||
assert!(crate::security::SecretStore::is_encrypted(
|
||||
transcription_encrypted
|
||||
));
|
||||
assert_eq!(
|
||||
store.decrypt(transcription_encrypted).unwrap(),
|
||||
"transcription-credential"
|
||||
);
|
||||
|
||||
let composio_encrypted = stored.composio.api_key.as_deref().unwrap();
|
||||
assert!(crate::security::SecretStore::is_encrypted(
|
||||
composio_encrypted
|
||||
@@ -10665,6 +10697,7 @@ default_model = "legacy-model"
|
||||
async fn transcription_config_defaults() {
|
||||
let tc = TranscriptionConfig::default();
|
||||
assert!(!tc.enabled);
|
||||
assert!(tc.api_key.is_none());
|
||||
assert!(tc.api_url.contains("groq.com"));
|
||||
assert_eq!(tc.model, "whisper-large-v3-turbo");
|
||||
assert!(tc.language.is_none());
|
||||
@@ -10675,12 +10708,17 @@ default_model = "legacy-model"
|
||||
async fn config_roundtrip_with_transcription() {
|
||||
let mut config = Config::default();
|
||||
config.transcription.enabled = true;
|
||||
config.transcription.api_key = Some("transcription-key".into());
|
||||
config.transcription.language = Some("en".into());
|
||||
|
||||
let toml_str = toml::to_string_pretty(&config).unwrap();
|
||||
let parsed: Config = toml::from_str(&toml_str).unwrap();
|
||||
|
||||
assert!(parsed.transcription.enabled);
|
||||
assert_eq!(
|
||||
parsed.transcription.api_key.as_deref(),
|
||||
Some("transcription-key")
|
||||
);
|
||||
assert_eq!(parsed.transcription.language.as_deref(), Some("en"));
|
||||
assert_eq!(parsed.transcription.model, "whisper-large-v3-turbo");
|
||||
}
|
||||
|
||||
+57
-2
@@ -2,7 +2,7 @@
|
||||
use crate::channels::LarkChannel;
|
||||
use crate::channels::{
|
||||
Channel, DiscordChannel, EmailChannel, MattermostChannel, QQChannel, SendMessage, SlackChannel,
|
||||
TelegramChannel,
|
||||
TelegramChannel, WhatsAppChannel,
|
||||
};
|
||||
use crate::config::Config;
|
||||
use crate::cron::{
|
||||
@@ -308,7 +308,8 @@ pub(crate) async fn deliver_announcement(
|
||||
target: &str,
|
||||
output: &str,
|
||||
) -> Result<()> {
|
||||
match channel.to_ascii_lowercase().as_str() {
|
||||
let normalized = channel.to_ascii_lowercase();
|
||||
match normalized.as_str() {
|
||||
"telegram" => {
|
||||
let tg = config
|
||||
.channels_config
|
||||
@@ -383,6 +384,31 @@ pub(crate) async fn deliver_announcement(
|
||||
);
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
}
|
||||
"whatsapp_web" | "whatsapp" => {
|
||||
let wa = config
|
||||
.channels_config
|
||||
.whatsapp
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("whatsapp channel not configured"))?;
|
||||
|
||||
// WhatsApp Web requires the connected channel instance from the
|
||||
// channel runtime. Fall back to cloud mode if configured.
|
||||
if let Some(live_channel) = crate::channels::get_live_channel("whatsapp") {
|
||||
live_channel.send(&SendMessage::new(output, target)).await?;
|
||||
} else if wa.is_cloud_config() {
|
||||
let channel = WhatsAppChannel::new(
|
||||
wa.access_token.clone().unwrap_or_default(),
|
||||
wa.phone_number_id.clone().unwrap_or_default(),
|
||||
wa.verify_token.clone().unwrap_or_default(),
|
||||
wa.allowed_numbers.clone(),
|
||||
);
|
||||
channel.send(&SendMessage::new(output, target)).await?;
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"whatsapp_web delivery requires an active channels runtime session; start daemon/channels with whatsapp web enabled"
|
||||
);
|
||||
}
|
||||
}
|
||||
"lark" => {
|
||||
#[cfg(feature = "channel-lark")]
|
||||
{
|
||||
@@ -1106,4 +1132,33 @@ mod tests {
|
||||
let err = deliver_if_configured(&config, &job, "x").await.unwrap_err();
|
||||
assert!(err.to_string().contains("unsupported delivery channel"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn deliver_if_configured_whatsapp_web_requires_live_session_in_web_mode() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mut config = test_config(&tmp).await;
|
||||
config.channels_config.whatsapp = Some(crate::config::schema::WhatsAppConfig {
|
||||
access_token: None,
|
||||
phone_number_id: None,
|
||||
verify_token: None,
|
||||
app_secret: None,
|
||||
session_path: Some("~/.zeroclaw/state/whatsapp-web/session.db".into()),
|
||||
pair_phone: None,
|
||||
pair_code: None,
|
||||
allowed_numbers: vec!["*".into()],
|
||||
});
|
||||
|
||||
let mut job = test_job("echo ok");
|
||||
job.delivery = DeliveryConfig {
|
||||
mode: "announce".into(),
|
||||
channel: Some("whatsapp_web".into()),
|
||||
to: Some("+15551234567".into()),
|
||||
best_effort: true,
|
||||
};
|
||||
|
||||
let err = deliver_if_configured(&config, &job, "x").await.unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("requires an active channels runtime session"));
|
||||
}
|
||||
}
|
||||
|
||||
+58
-1
@@ -299,7 +299,8 @@ fn heartbeat_delivery_target(config: &Config) -> Result<Option<(String, String)>
|
||||
}
|
||||
|
||||
fn validate_heartbeat_channel_config(config: &Config, channel: &str) -> Result<()> {
|
||||
match channel.to_ascii_lowercase().as_str() {
|
||||
let normalized = channel.to_ascii_lowercase();
|
||||
match normalized.as_str() {
|
||||
"telegram" => {
|
||||
if config.channels_config.telegram.is_none() {
|
||||
anyhow::bail!(
|
||||
@@ -328,6 +329,19 @@ fn validate_heartbeat_channel_config(config: &Config, channel: &str) -> Result<(
|
||||
);
|
||||
}
|
||||
}
|
||||
"whatsapp" | "whatsapp_web" => {
|
||||
let wa = config.channels_config.whatsapp.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"heartbeat.target is set to {channel} but channels_config.whatsapp is not configured"
|
||||
)
|
||||
})?;
|
||||
|
||||
if normalized == "whatsapp_web" && wa.is_cloud_config() && !wa.is_web_config() {
|
||||
anyhow::bail!(
|
||||
"heartbeat.target is set to whatsapp_web but channels_config.whatsapp is configured for cloud mode (set session_path for web mode)"
|
||||
);
|
||||
}
|
||||
}
|
||||
other => anyhow::bail!("unsupported heartbeat.target channel: {other}"),
|
||||
}
|
||||
|
||||
@@ -607,4 +621,47 @@ mod tests {
|
||||
let target = heartbeat_delivery_target(&config).unwrap();
|
||||
assert_eq!(target, Some(("telegram".to_string(), "123456".to_string())));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_accepts_whatsapp_web_target_in_web_mode() {
|
||||
let mut config = Config::default();
|
||||
config.heartbeat.target = Some("whatsapp_web".into());
|
||||
config.heartbeat.to = Some("+15551234567".into());
|
||||
config.channels_config.whatsapp = Some(crate::config::schema::WhatsAppConfig {
|
||||
access_token: None,
|
||||
phone_number_id: None,
|
||||
verify_token: None,
|
||||
app_secret: None,
|
||||
session_path: Some("~/.zeroclaw/state/whatsapp-web/session.db".into()),
|
||||
pair_phone: None,
|
||||
pair_code: None,
|
||||
allowed_numbers: vec!["*".into()],
|
||||
});
|
||||
|
||||
let target = heartbeat_delivery_target(&config).unwrap();
|
||||
assert_eq!(
|
||||
target,
|
||||
Some(("whatsapp_web".to_string(), "+15551234567".to_string()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn heartbeat_delivery_target_rejects_whatsapp_web_target_in_cloud_mode() {
|
||||
let mut config = Config::default();
|
||||
config.heartbeat.target = Some("whatsapp_web".into());
|
||||
config.heartbeat.to = Some("+15551234567".into());
|
||||
config.channels_config.whatsapp = Some(crate::config::schema::WhatsAppConfig {
|
||||
access_token: Some("token".into()),
|
||||
phone_number_id: Some("123456".into()),
|
||||
verify_token: Some("verify".into()),
|
||||
app_secret: None,
|
||||
session_path: None,
|
||||
pair_phone: None,
|
||||
pair_code: None,
|
||||
allowed_numbers: vec!["*".into()],
|
||||
});
|
||||
|
||||
let err = heartbeat_delivery_target(&config).unwrap_err();
|
||||
assert!(err.to_string().contains("configured for cloud mode"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -607,6 +607,7 @@ fn mask_sensitive_fields(config: &crate::config::Config) -> crate::config::Confi
|
||||
mask_optional_secret(&mut masked.proxy.http_proxy);
|
||||
mask_optional_secret(&mut masked.proxy.https_proxy);
|
||||
mask_optional_secret(&mut masked.proxy.all_proxy);
|
||||
mask_optional_secret(&mut masked.transcription.api_key);
|
||||
mask_optional_secret(&mut masked.browser.computer_use.api_key);
|
||||
mask_optional_secret(&mut masked.web_fetch.api_key);
|
||||
mask_optional_secret(&mut masked.web_search.api_key);
|
||||
@@ -705,6 +706,7 @@ fn restore_masked_sensitive_fields(
|
||||
restore_optional_secret(&mut incoming.proxy.http_proxy, ¤t.proxy.http_proxy);
|
||||
restore_optional_secret(&mut incoming.proxy.https_proxy, ¤t.proxy.https_proxy);
|
||||
restore_optional_secret(&mut incoming.proxy.all_proxy, ¤t.proxy.all_proxy);
|
||||
restore_optional_secret(&mut incoming.transcription.api_key, ¤t.transcription.api_key);
|
||||
restore_optional_secret(
|
||||
&mut incoming.browser.computer_use.api_key,
|
||||
¤t.browser.computer_use.api_key,
|
||||
@@ -917,6 +919,7 @@ mod tests {
|
||||
current.config_path = std::path::PathBuf::from("/tmp/current/config.toml");
|
||||
current.workspace_dir = std::path::PathBuf::from("/tmp/current/workspace");
|
||||
current.api_key = Some("real-key".to_string());
|
||||
current.transcription.api_key = Some("transcription-real-key".to_string());
|
||||
current.reliability.api_keys = vec!["r1".to_string(), "r2".to_string()];
|
||||
|
||||
let mut incoming = mask_sensitive_fields(¤t);
|
||||
@@ -929,6 +932,7 @@ mod tests {
|
||||
assert_eq!(hydrated.config_path, current.config_path);
|
||||
assert_eq!(hydrated.workspace_dir, current.workspace_dir);
|
||||
assert_eq!(hydrated.api_key, current.api_key);
|
||||
assert_eq!(hydrated.transcription.api_key, current.transcription.api_key);
|
||||
assert_eq!(hydrated.default_model.as_deref(), Some("gpt-4.1-mini"));
|
||||
assert_eq!(
|
||||
hydrated.reliability.api_keys,
|
||||
@@ -964,6 +968,7 @@ mod tests {
|
||||
cfg.proxy.http_proxy = Some("http://user:pass@proxy.internal:8080".to_string());
|
||||
cfg.proxy.https_proxy = Some("https://user:pass@proxy.internal:8443".to_string());
|
||||
cfg.proxy.all_proxy = Some("socks5://user:pass@proxy.internal:1080".to_string());
|
||||
cfg.transcription.api_key = Some("transcription-real-key".to_string());
|
||||
cfg.tunnel.cloudflare = Some(CloudflareTunnelConfig {
|
||||
token: "cloudflare-real-token".to_string(),
|
||||
});
|
||||
@@ -998,6 +1003,7 @@ mod tests {
|
||||
assert_eq!(masked.proxy.http_proxy.as_deref(), Some(MASKED_SECRET));
|
||||
assert_eq!(masked.proxy.https_proxy.as_deref(), Some(MASKED_SECRET));
|
||||
assert_eq!(masked.proxy.all_proxy.as_deref(), Some(MASKED_SECRET));
|
||||
assert_eq!(masked.transcription.api_key.as_deref(), Some(MASKED_SECRET));
|
||||
assert_eq!(
|
||||
masked
|
||||
.tunnel
|
||||
|
||||
@@ -1316,6 +1316,7 @@ mod tests {
|
||||
fn is_aieos_configured_true_with_path() {
|
||||
let config = IdentityConfig {
|
||||
format: "aieos".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: Some("identity.json".into()),
|
||||
aieos_inline: None,
|
||||
};
|
||||
@@ -1326,6 +1327,7 @@ mod tests {
|
||||
fn is_aieos_configured_true_with_inline() {
|
||||
let config = IdentityConfig {
|
||||
format: "aieos".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: None,
|
||||
aieos_inline: Some("{\"identity\":{}}".into()),
|
||||
};
|
||||
@@ -1336,6 +1338,7 @@ mod tests {
|
||||
fn is_aieos_configured_false_openclaw_format() {
|
||||
let config = IdentityConfig {
|
||||
format: "openclaw".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: Some("identity.json".into()),
|
||||
aieos_inline: None,
|
||||
};
|
||||
@@ -1346,6 +1349,7 @@ mod tests {
|
||||
fn is_aieos_configured_false_no_config() {
|
||||
let config = IdentityConfig {
|
||||
format: "aieos".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: None,
|
||||
aieos_inline: None,
|
||||
};
|
||||
@@ -1520,6 +1524,7 @@ mod tests {
|
||||
|
||||
let config = IdentityConfig {
|
||||
format: "aieos".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: Some("identity.json".into()),
|
||||
aieos_inline: None,
|
||||
};
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
|
||||
pub enum MemoryBackendKind {
|
||||
Sqlite,
|
||||
SqliteQdrantHybrid,
|
||||
Lucid,
|
||||
Postgres,
|
||||
Qdrant,
|
||||
@@ -65,6 +66,15 @@ const QDRANT_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
optional_dependency: false,
|
||||
};
|
||||
|
||||
const SQLITE_QDRANT_HYBRID_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
key: "sqlite_qdrant_hybrid",
|
||||
label: "SQLite + Qdrant hybrid — SQLite metadata/FTS with Qdrant semantic ranking",
|
||||
auto_save_default: true,
|
||||
uses_sqlite_hygiene: true,
|
||||
sqlite_based: true,
|
||||
optional_dependency: false,
|
||||
};
|
||||
|
||||
const NONE_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
|
||||
key: "none",
|
||||
label: "None — disable persistent memory",
|
||||
@@ -101,6 +111,7 @@ pub fn default_memory_backend_key() -> &'static str {
|
||||
pub fn classify_memory_backend(backend: &str) -> MemoryBackendKind {
|
||||
match backend {
|
||||
"sqlite" => MemoryBackendKind::Sqlite,
|
||||
"sqlite_qdrant_hybrid" | "hybrid" => MemoryBackendKind::SqliteQdrantHybrid,
|
||||
"lucid" => MemoryBackendKind::Lucid,
|
||||
"postgres" => MemoryBackendKind::Postgres,
|
||||
"qdrant" => MemoryBackendKind::Qdrant,
|
||||
@@ -113,6 +124,7 @@ pub fn classify_memory_backend(backend: &str) -> MemoryBackendKind {
|
||||
pub fn memory_backend_profile(backend: &str) -> MemoryBackendProfile {
|
||||
match classify_memory_backend(backend) {
|
||||
MemoryBackendKind::Sqlite => SQLITE_PROFILE,
|
||||
MemoryBackendKind::SqliteQdrantHybrid => SQLITE_QDRANT_HYBRID_PROFILE,
|
||||
MemoryBackendKind::Lucid => LUCID_PROFILE,
|
||||
MemoryBackendKind::Postgres => POSTGRES_PROFILE,
|
||||
MemoryBackendKind::Qdrant => QDRANT_PROFILE,
|
||||
@@ -129,6 +141,10 @@ mod tests {
|
||||
#[test]
|
||||
fn classify_known_backends() {
|
||||
assert_eq!(classify_memory_backend("sqlite"), MemoryBackendKind::Sqlite);
|
||||
assert_eq!(
|
||||
classify_memory_backend("sqlite_qdrant_hybrid"),
|
||||
MemoryBackendKind::SqliteQdrantHybrid
|
||||
);
|
||||
assert_eq!(classify_memory_backend("lucid"), MemoryBackendKind::Lucid);
|
||||
assert_eq!(
|
||||
classify_memory_backend("postgres"),
|
||||
@@ -141,6 +157,14 @@ mod tests {
|
||||
assert_eq!(classify_memory_backend("none"), MemoryBackendKind::None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hybrid_profile_is_sqlite_based() {
|
||||
let profile = memory_backend_profile("sqlite_qdrant_hybrid");
|
||||
assert_eq!(profile.key, "sqlite_qdrant_hybrid");
|
||||
assert!(profile.sqlite_based);
|
||||
assert!(profile.uses_sqlite_hygiene);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_unknown_backend() {
|
||||
assert_eq!(classify_memory_backend("redis"), MemoryBackendKind::Unknown);
|
||||
|
||||
@@ -0,0 +1,332 @@
|
||||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Composite memory backend:
|
||||
/// - SQLite remains authoritative for metadata/content/filtering.
|
||||
/// - Qdrant provides semantic ranking candidates.
|
||||
pub struct SqliteQdrantHybridMemory {
|
||||
sqlite: Arc<dyn Memory>,
|
||||
qdrant: Arc<dyn Memory>,
|
||||
}
|
||||
|
||||
impl SqliteQdrantHybridMemory {
|
||||
pub fn new(sqlite: Arc<dyn Memory>, qdrant: Arc<dyn Memory>) -> Self {
|
||||
Self { sqlite, qdrant }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for SqliteQdrantHybridMemory {
|
||||
fn name(&self) -> &str {
|
||||
"sqlite_qdrant_hybrid"
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<()> {
|
||||
// SQLite is authoritative. Fail only if local persistence fails.
|
||||
self.sqlite
|
||||
.store(key, content, category.clone(), session_id)
|
||||
.await?;
|
||||
|
||||
// Best-effort vector sync to Qdrant.
|
||||
if let Err(err) = self.qdrant.store(key, content, category, session_id).await {
|
||||
tracing::warn!(
|
||||
key,
|
||||
error = %err,
|
||||
"Hybrid memory vector sync failed; SQLite entry was stored"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<Vec<MemoryEntry>> {
|
||||
let trimmed_query = query.trim();
|
||||
if trimmed_query.is_empty() {
|
||||
return self.sqlite.recall(query, limit, session_id).await;
|
||||
}
|
||||
|
||||
let qdrant_candidates = match self
|
||||
.qdrant
|
||||
.recall(trimmed_query, limit.max(1).saturating_mul(3), session_id)
|
||||
.await
|
||||
{
|
||||
Ok(candidates) => candidates,
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
query = trimmed_query,
|
||||
error = %err,
|
||||
"Hybrid memory semantic recall failed; falling back to SQLite recall"
|
||||
);
|
||||
return self.sqlite.recall(trimmed_query, limit, session_id).await;
|
||||
}
|
||||
};
|
||||
|
||||
if qdrant_candidates.is_empty() {
|
||||
return self.sqlite.recall(trimmed_query, limit, session_id).await;
|
||||
}
|
||||
|
||||
let mut seen_keys = HashSet::new();
|
||||
let mut merged = Vec::with_capacity(limit);
|
||||
|
||||
for candidate in qdrant_candidates {
|
||||
if !seen_keys.insert(candidate.key.clone()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
match self.sqlite.get(&candidate.key).await {
|
||||
Ok(Some(mut entry)) => {
|
||||
if let Some(filter_sid) = session_id {
|
||||
if entry.session_id.as_deref() != Some(filter_sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
entry.score = candidate.score;
|
||||
merged.push(entry);
|
||||
if merged.len() >= limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
// Ignore Qdrant candidates that no longer exist in SQLite.
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
key = candidate.key,
|
||||
error = %err,
|
||||
"Hybrid memory failed to load SQLite row for Qdrant candidate"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if merged.is_empty() {
|
||||
return self.sqlite.recall(trimmed_query, limit, session_id).await;
|
||||
}
|
||||
|
||||
Ok(merged)
|
||||
}
|
||||
|
||||
async fn get(&self, key: &str) -> Result<Option<MemoryEntry>> {
|
||||
self.sqlite.get(key).await
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> Result<Vec<MemoryEntry>> {
|
||||
self.sqlite.list(category, session_id).await
|
||||
}
|
||||
|
||||
async fn forget(&self, key: &str) -> Result<bool> {
|
||||
let removed = self.sqlite.forget(key).await?;
|
||||
if let Err(err) = self.qdrant.forget(key).await {
|
||||
tracing::warn!(
|
||||
key,
|
||||
error = %err,
|
||||
"Hybrid memory vector delete failed; SQLite delete result preserved"
|
||||
);
|
||||
}
|
||||
Ok(removed)
|
||||
}
|
||||
|
||||
async fn count(&self) -> Result<usize> {
|
||||
self.sqlite.count().await
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
let sqlite_ok = self.sqlite.health_check().await;
|
||||
if !sqlite_ok {
|
||||
return false;
|
||||
}
|
||||
|
||||
if !self.qdrant.health_check().await {
|
||||
tracing::warn!("Hybrid memory Qdrant health check failed; SQLite remains available");
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{Memory, MemoryCategory, MemoryEntry, SqliteMemory};
|
||||
use std::sync::Mutex;
|
||||
use tempfile::TempDir;
|
||||
|
||||
struct StubQdrantMemory {
|
||||
recall_results: Vec<MemoryEntry>,
|
||||
fail_store: bool,
|
||||
fail_recall: bool,
|
||||
forget_calls: Mutex<Vec<String>>,
|
||||
}
|
||||
|
||||
impl StubQdrantMemory {
|
||||
fn new(recall_results: Vec<MemoryEntry>, fail_store: bool, fail_recall: bool) -> Self {
|
||||
Self {
|
||||
recall_results,
|
||||
fail_store,
|
||||
fail_recall,
|
||||
forget_calls: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for StubQdrantMemory {
|
||||
fn name(&self) -> &str {
|
||||
"qdrant_stub"
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
_key: &str,
|
||||
_content: &str,
|
||||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> Result<()> {
|
||||
if self.fail_store {
|
||||
anyhow::bail!("simulated qdrant store failure");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(
|
||||
&self,
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> Result<Vec<MemoryEntry>> {
|
||||
if self.fail_recall {
|
||||
anyhow::bail!("simulated qdrant recall failure");
|
||||
}
|
||||
Ok(self.recall_results.clone())
|
||||
}
|
||||
|
||||
async fn get(&self, _key: &str) -> Result<Option<MemoryEntry>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
_category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
async fn forget(&self, key: &str) -> Result<bool> {
|
||||
self.forget_calls
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.push(key.to_string());
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn count(&self) -> Result<usize> {
|
||||
Ok(self.recall_results.len())
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn temp_sqlite() -> (TempDir, Arc<dyn Memory>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let sqlite = SqliteMemory::new(tmp.path()).unwrap();
|
||||
(tmp, Arc::new(sqlite))
|
||||
}
|
||||
|
||||
fn make_qdrant_entry(key: &str, score: f64) -> MemoryEntry {
|
||||
MemoryEntry {
|
||||
id: format!("vec-{key}"),
|
||||
key: key.to_string(),
|
||||
content: "vector payload".to_string(),
|
||||
category: MemoryCategory::Core,
|
||||
timestamp: "2026-02-27T00:00:00Z".to_string(),
|
||||
session_id: None,
|
||||
score: Some(score),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn store_keeps_sqlite_when_qdrant_sync_fails() {
|
||||
let (_tmp, sqlite) = temp_sqlite();
|
||||
let qdrant: Arc<dyn Memory> = Arc::new(StubQdrantMemory::new(Vec::new(), true, false));
|
||||
let hybrid = SqliteQdrantHybridMemory::new(Arc::clone(&sqlite), qdrant);
|
||||
|
||||
hybrid
|
||||
.store("fav_lang", "Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let stored = sqlite.get("fav_lang").await.unwrap();
|
||||
assert!(stored.is_some(), "SQLite should remain authoritative");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_joins_qdrant_ranking_with_sqlite_rows() {
|
||||
let (_tmp, sqlite) = temp_sqlite();
|
||||
sqlite
|
||||
.store("a", "alpha from sqlite", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
sqlite
|
||||
.store("b", "beta from sqlite", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let qdrant: Arc<dyn Memory> = Arc::new(StubQdrantMemory::new(
|
||||
vec![make_qdrant_entry("b", 0.91), make_qdrant_entry("a", 0.72)],
|
||||
false,
|
||||
false,
|
||||
));
|
||||
let hybrid = SqliteQdrantHybridMemory::new(Arc::clone(&sqlite), qdrant);
|
||||
|
||||
let recalled = hybrid.recall("rank semantically", 2, None).await.unwrap();
|
||||
assert_eq!(recalled.len(), 2);
|
||||
assert_eq!(recalled[0].key, "b");
|
||||
assert_eq!(recalled[0].content, "beta from sqlite");
|
||||
assert_eq!(recalled[0].score, Some(0.91));
|
||||
assert_eq!(recalled[1].key, "a");
|
||||
assert_eq!(recalled[1].score, Some(0.72));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_falls_back_to_sqlite_when_qdrant_fails() {
|
||||
let (_tmp, sqlite) = temp_sqlite();
|
||||
sqlite
|
||||
.store(
|
||||
"topic",
|
||||
"hybrid fallback should still find this",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let qdrant: Arc<dyn Memory> = Arc::new(StubQdrantMemory::new(Vec::new(), false, true));
|
||||
let hybrid = SqliteQdrantHybridMemory::new(Arc::clone(&sqlite), qdrant);
|
||||
|
||||
let recalled = hybrid.recall("fallback", 5, None).await.unwrap();
|
||||
assert!(
|
||||
recalled.iter().any(|entry| entry.key == "topic"),
|
||||
"SQLite fallback should provide recall results when Qdrant is unavailable"
|
||||
);
|
||||
}
|
||||
}
|
||||
+66
-6
@@ -2,6 +2,7 @@ pub mod backend;
|
||||
pub mod chunker;
|
||||
pub mod cli;
|
||||
pub mod embeddings;
|
||||
pub mod hybrid;
|
||||
pub mod hygiene;
|
||||
pub mod lucid;
|
||||
pub mod markdown;
|
||||
@@ -20,6 +21,7 @@ pub use backend::{
|
||||
classify_memory_backend, default_memory_backend_key, memory_backend_profile,
|
||||
selectable_memory_backends, MemoryBackendKind, MemoryBackendProfile,
|
||||
};
|
||||
pub use hybrid::SqliteQdrantHybridMemory;
|
||||
pub use lucid::LucidMemory;
|
||||
pub use markdown::MarkdownMemory;
|
||||
pub use none::NoneMemory;
|
||||
@@ -49,7 +51,9 @@ where
|
||||
G: FnMut() -> anyhow::Result<Box<dyn Memory>>,
|
||||
{
|
||||
match classify_memory_backend(backend_name) {
|
||||
MemoryBackendKind::Sqlite => Ok(Box::new(sqlite_builder()?)),
|
||||
MemoryBackendKind::Sqlite | MemoryBackendKind::SqliteQdrantHybrid => {
|
||||
Ok(Box::new(sqlite_builder()?))
|
||||
}
|
||||
MemoryBackendKind::Lucid => {
|
||||
let local = sqlite_builder()?;
|
||||
Ok(Box::new(LucidMemory::new(workspace_dir, local)))
|
||||
@@ -210,7 +214,9 @@ pub fn create_memory_with_storage_and_routes(
|
||||
&& config.snapshot_on_hygiene
|
||||
&& matches!(
|
||||
backend_kind,
|
||||
MemoryBackendKind::Sqlite | MemoryBackendKind::Lucid
|
||||
MemoryBackendKind::Sqlite
|
||||
| MemoryBackendKind::SqliteQdrantHybrid
|
||||
| MemoryBackendKind::Lucid
|
||||
)
|
||||
{
|
||||
if let Err(e) = snapshot::export_snapshot(workspace_dir) {
|
||||
@@ -223,7 +229,9 @@ pub fn create_memory_with_storage_and_routes(
|
||||
if config.auto_hydrate
|
||||
&& matches!(
|
||||
backend_kind,
|
||||
MemoryBackendKind::Sqlite | MemoryBackendKind::Lucid
|
||||
MemoryBackendKind::Sqlite
|
||||
| MemoryBackendKind::SqliteQdrantHybrid
|
||||
| MemoryBackendKind::Lucid
|
||||
)
|
||||
&& snapshot::should_hydrate(workspace_dir)
|
||||
{
|
||||
@@ -299,7 +307,10 @@ pub fn create_memory_with_storage_and_routes(
|
||||
);
|
||||
}
|
||||
|
||||
if matches!(backend_kind, MemoryBackendKind::Qdrant) {
|
||||
fn build_qdrant_memory(
|
||||
config: &MemoryConfig,
|
||||
resolved_embedding: &ResolvedEmbeddingConfig,
|
||||
) -> anyhow::Result<QdrantMemory> {
|
||||
let url = config
|
||||
.qdrant
|
||||
.url
|
||||
@@ -332,12 +343,26 @@ pub fn create_memory_with_storage_and_routes(
|
||||
url,
|
||||
collection
|
||||
);
|
||||
return Ok(Box::new(QdrantMemory::new_lazy(
|
||||
Ok(QdrantMemory::new_lazy(
|
||||
&url,
|
||||
&collection,
|
||||
qdrant_api_key,
|
||||
embedder,
|
||||
)));
|
||||
))
|
||||
}
|
||||
|
||||
if matches!(backend_kind, MemoryBackendKind::Qdrant) {
|
||||
return Ok(Box::new(build_qdrant_memory(config, &resolved_embedding)?));
|
||||
}
|
||||
|
||||
if matches!(backend_kind, MemoryBackendKind::SqliteQdrantHybrid) {
|
||||
let sqlite: Arc<dyn Memory> = Arc::new(build_sqlite_memory(
|
||||
config,
|
||||
workspace_dir,
|
||||
&resolved_embedding,
|
||||
)?);
|
||||
let qdrant: Arc<dyn Memory> = Arc::new(build_qdrant_memory(config, &resolved_embedding)?);
|
||||
return Ok(Box::new(SqliteQdrantHybridMemory::new(sqlite, qdrant)));
|
||||
}
|
||||
|
||||
create_memory_with_builders(
|
||||
@@ -451,6 +476,21 @@ mod tests {
|
||||
assert_eq!(mem.name(), "lucid");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_sqlite_qdrant_hybrid() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = MemoryConfig {
|
||||
backend: "sqlite_qdrant_hybrid".into(),
|
||||
qdrant: crate::config::QdrantConfig {
|
||||
url: Some("http://localhost:6333".into()),
|
||||
..crate::config::QdrantConfig::default()
|
||||
},
|
||||
..MemoryConfig::default()
|
||||
};
|
||||
let mem = create_memory(&cfg, tmp.path(), None).unwrap();
|
||||
assert_eq!(mem.name(), "sqlite_qdrant_hybrid");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_none_uses_noop_memory() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
@@ -526,6 +566,26 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_hybrid_requires_qdrant_url() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = MemoryConfig {
|
||||
backend: "sqlite_qdrant_hybrid".into(),
|
||||
qdrant: crate::config::QdrantConfig {
|
||||
url: None,
|
||||
..crate::config::QdrantConfig::default()
|
||||
},
|
||||
..MemoryConfig::default()
|
||||
};
|
||||
|
||||
let error = create_memory(&cfg, tmp.path(), None)
|
||||
.err()
|
||||
.expect("hybrid backend should require qdrant url");
|
||||
assert!(error
|
||||
.to_string()
|
||||
.contains("Qdrant memory backend requires url"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_embedding_config_uses_base_config_when_model_is_not_hint() {
|
||||
let cfg = MemoryConfig {
|
||||
|
||||
+58
-3
@@ -779,6 +779,7 @@ fn default_model_for_provider(provider: &str) -> String {
|
||||
"ollama" => "llama3.2".into(),
|
||||
"llamacpp" => "ggml-org/gpt-oss-20b-GGUF".into(),
|
||||
"sglang" | "vllm" | "osaurus" => "default".into(),
|
||||
"copilot" => "default".into(),
|
||||
"gemini" => "gemini-2.5-pro".into(),
|
||||
"kimi-code" => "kimi-for-coding".into(),
|
||||
"bedrock" => "anthropic.claude-sonnet-4-5-20250929-v1:0".into(),
|
||||
@@ -1225,6 +1226,10 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> {
|
||||
"Gemini 2.5 Flash-Lite (lowest cost)".to_string(),
|
||||
),
|
||||
],
|
||||
"copilot" => vec![(
|
||||
"default".to_string(),
|
||||
"Copilot default model (recommended)".to_string(),
|
||||
)],
|
||||
_ => vec![("default".to_string(), "Default model".to_string())],
|
||||
}
|
||||
}
|
||||
@@ -2213,7 +2218,7 @@ async fn setup_workspace() -> Result<(PathBuf, PathBuf)> {
|
||||
async fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String, Option<String>)> {
|
||||
// ── Tier selection ──
|
||||
let tiers = vec![
|
||||
"⭐ Recommended (OpenRouter, Venice, Anthropic, OpenAI, Gemini)",
|
||||
"⭐ Recommended (OpenRouter, Venice, Anthropic, OpenAI, Gemini, GitHub Copilot)",
|
||||
"⚡ Fast inference (Groq, Fireworks, Together AI, NVIDIA NIM)",
|
||||
"🌐 Gateway / proxy (Vercel AI, Cloudflare AI, Amazon Bedrock)",
|
||||
"🔬 Specialized (Moonshot/Kimi, GLM/Zhipu, MiniMax, Qwen/DashScope, Qianfan, Z.AI, Synthetic, OpenCode Zen, Cohere)",
|
||||
@@ -2240,6 +2245,10 @@ async fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String,
|
||||
"openai-codex",
|
||||
"OpenAI Codex (ChatGPT subscription OAuth, no API key)",
|
||||
),
|
||||
(
|
||||
"copilot",
|
||||
"GitHub Copilot — OAuth device flow (Copilot subscription)",
|
||||
),
|
||||
("deepseek", "DeepSeek — V3 & R1 (affordable)"),
|
||||
("mistral", "Mistral — Large & Codestral"),
|
||||
("xai", "xAI — Grok 3 & 4"),
|
||||
@@ -2536,6 +2545,24 @@ async fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String,
|
||||
));
|
||||
}
|
||||
|
||||
key
|
||||
} else if canonical_provider_name(provider_name) == "copilot" {
|
||||
print_bullet("GitHub Copilot uses GitHub OAuth device flow.");
|
||||
print_bullet("Press Enter to keep setup keyless and authenticate on first run.");
|
||||
print_bullet("Optional: paste a GitHub token now to skip the first-run device prompt.");
|
||||
println!();
|
||||
|
||||
let key: String = Input::new()
|
||||
.with_prompt(" Paste your GitHub token (optional; Enter = device flow)")
|
||||
.allow_empty(true)
|
||||
.interact_text()?;
|
||||
|
||||
if key.trim().is_empty() {
|
||||
print_bullet(
|
||||
"No token provided. ZeroClaw will open the GitHub device login flow on first use.",
|
||||
);
|
||||
}
|
||||
|
||||
key
|
||||
} else if canonical_provider_name(provider_name) == "gemini" {
|
||||
// Special handling for Gemini: check for CLI auth first
|
||||
@@ -3649,6 +3676,7 @@ fn setup_identity_backend() -> Result<IdentityConfig> {
|
||||
);
|
||||
IdentityConfig {
|
||||
format: "aieos".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: Some(default_path),
|
||||
aieos_inline: None,
|
||||
}
|
||||
@@ -3660,6 +3688,7 @@ fn setup_identity_backend() -> Result<IdentityConfig> {
|
||||
);
|
||||
IdentityConfig {
|
||||
format: "openclaw".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: None,
|
||||
aieos_inline: None,
|
||||
}
|
||||
@@ -6094,8 +6123,19 @@ fn print_summary(config: &Config) {
|
||||
let mut step = 1u8;
|
||||
|
||||
let provider = config.default_provider.as_deref().unwrap_or("openrouter");
|
||||
let canonical_provider = canonical_provider_name(provider);
|
||||
if config.api_key.is_none() && !provider_supports_keyless_local_usage(provider) {
|
||||
if provider == "openai-codex" {
|
||||
if canonical_provider == "copilot" {
|
||||
println!(
|
||||
" {} Authenticate GitHub Copilot:",
|
||||
style(format!("{step}.")).cyan().bold()
|
||||
);
|
||||
println!(" {}", style("zeroclaw agent -m \"Hello!\"").yellow());
|
||||
println!(
|
||||
" {}",
|
||||
style("(device/OAuth prompt appears automatically on first run)").dim()
|
||||
);
|
||||
} else if canonical_provider == "openai-codex" {
|
||||
println!(
|
||||
" {} Authenticate OpenAI Codex:",
|
||||
style(format!("{step}.")).cyan().bold()
|
||||
@@ -6104,7 +6144,7 @@ fn print_summary(config: &Config) {
|
||||
" {}",
|
||||
style("zeroclaw auth login --provider openai-codex --device-code").yellow()
|
||||
);
|
||||
} else if provider == "anthropic" {
|
||||
} else if canonical_provider == "anthropic" {
|
||||
println!(
|
||||
" {} Configure Anthropic auth:",
|
||||
style(format!("{step}.")).cyan().bold()
|
||||
@@ -6576,6 +6616,7 @@ mod tests {
|
||||
};
|
||||
let identity_config = crate::config::IdentityConfig {
|
||||
format: "aieos".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: Some("identity.aieos.json".into()),
|
||||
aieos_inline: None,
|
||||
};
|
||||
@@ -6605,6 +6646,7 @@ mod tests {
|
||||
let ctx = ProjectContext::default();
|
||||
let identity_config = crate::config::IdentityConfig {
|
||||
format: "aieos".into(),
|
||||
extra_files: Vec::new(),
|
||||
aieos_path: Some("identity.aieos.json".into()),
|
||||
aieos_inline: None,
|
||||
};
|
||||
@@ -7250,6 +7292,7 @@ mod tests {
|
||||
assert_eq!(default_model_for_provider("zai-cn"), "glm-5");
|
||||
assert_eq!(default_model_for_provider("gemini"), "gemini-2.5-pro");
|
||||
assert_eq!(default_model_for_provider("google"), "gemini-2.5-pro");
|
||||
assert_eq!(default_model_for_provider("copilot"), "default");
|
||||
assert_eq!(default_model_for_provider("kimi-code"), "kimi-for-coding");
|
||||
assert_eq!(
|
||||
default_model_for_provider("bedrock"),
|
||||
@@ -7343,6 +7386,18 @@ mod tests {
|
||||
assert!(ids.contains(&"gpt-5.2-codex".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn curated_models_for_copilot_have_default_entry() {
|
||||
let models = curated_models_for_provider("copilot");
|
||||
assert_eq!(
|
||||
models,
|
||||
vec![(
|
||||
"default".to_string(),
|
||||
"Copilot default model (recommended)".to_string(),
|
||||
)]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn curated_models_for_openrouter_use_valid_anthropic_id() {
|
||||
let ids: Vec<String> = curated_models_for_provider("openrouter")
|
||||
|
||||
Reference in New Issue
Block a user