feat(security): add otp and estop phase-1 foundation
This commit is contained in:
@@ -0,0 +1,259 @@
|
||||
use anyhow::{bail, Result};
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
const BANKING_DOMAINS: &[&str] = &[
|
||||
"*.chase.com",
|
||||
"*.bankofamerica.com",
|
||||
"*.wellsfargo.com",
|
||||
"*.fidelity.com",
|
||||
"*.schwab.com",
|
||||
"*.venmo.com",
|
||||
"*.paypal.com",
|
||||
"*.robinhood.com",
|
||||
"*.coinbase.com",
|
||||
];
|
||||
|
||||
const MEDICAL_DOMAINS: &[&str] = &[
|
||||
"*.mychart.com",
|
||||
"*.epic.com",
|
||||
"*.patient.portal.*",
|
||||
"*.healthrecords.*",
|
||||
];
|
||||
|
||||
const GOVERNMENT_DOMAINS: &[&str] = &["*.ssa.gov", "*.irs.gov", "*.login.gov", "*.id.me"];
|
||||
|
||||
const IDENTITY_PROVIDER_DOMAINS: &[&str] = &[
|
||||
"accounts.google.com",
|
||||
"login.microsoftonline.com",
|
||||
"appleid.apple.com",
|
||||
];
|
||||
|
||||
const DOMAIN_CATEGORIES: &[(&str, &[&str])] = &[
|
||||
("banking", BANKING_DOMAINS),
|
||||
("medical", MEDICAL_DOMAINS),
|
||||
("government", GOVERNMENT_DOMAINS),
|
||||
("identity_providers", IDENTITY_PROVIDER_DOMAINS),
|
||||
];
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct DomainMatcher {
|
||||
patterns: Vec<String>,
|
||||
}
|
||||
|
||||
impl DomainMatcher {
|
||||
pub fn new(gated_domains: &[String], categories: &[String]) -> Result<Self> {
|
||||
let mut set = BTreeSet::new();
|
||||
|
||||
for domain in gated_domains {
|
||||
set.insert(normalize_pattern(domain)?);
|
||||
}
|
||||
|
||||
for domain in Self::expand_categories(categories)? {
|
||||
set.insert(domain);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
patterns: set.into_iter().collect(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn patterns(&self) -> &[String] {
|
||||
&self.patterns
|
||||
}
|
||||
|
||||
pub fn is_gated(&self, domain: &str) -> bool {
|
||||
let Some(normalized_domain) = normalize_domain(domain) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
self.patterns
|
||||
.iter()
|
||||
.any(|pattern| domain_matches_pattern(pattern, &normalized_domain))
|
||||
}
|
||||
|
||||
pub fn expand_categories(categories: &[String]) -> Result<Vec<String>> {
|
||||
let mut expanded = Vec::new();
|
||||
for category in categories {
|
||||
let normalized = category.trim().to_ascii_lowercase();
|
||||
let Some((_, domains)) = DOMAIN_CATEGORIES
|
||||
.iter()
|
||||
.find(|(name, _)| *name == normalized.as_str())
|
||||
else {
|
||||
let known = DOMAIN_CATEGORIES
|
||||
.iter()
|
||||
.map(|(name, _)| *name)
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
bail!("Unknown OTP domain category '{category}'. Known categories: {known}");
|
||||
};
|
||||
expanded.extend(domains.iter().map(|domain| (*domain).to_string()));
|
||||
}
|
||||
Ok(expanded)
|
||||
}
|
||||
|
||||
pub fn validate_pattern(pattern: &str) -> Result<()> {
|
||||
let _ = normalize_pattern(pattern)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_domain(raw: &str) -> Option<String> {
|
||||
let mut domain = raw.trim().to_ascii_lowercase();
|
||||
if domain.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some((_, rest)) = domain.split_once("://") {
|
||||
domain = rest.to_string();
|
||||
}
|
||||
|
||||
domain = domain
|
||||
.split(['/', '?', '#'])
|
||||
.next()
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
if let Some((_, host)) = domain.rsplit_once('@') {
|
||||
domain = host.to_string();
|
||||
}
|
||||
if let Some((host, _port)) = domain.split_once(':') {
|
||||
domain = host.to_string();
|
||||
}
|
||||
domain = domain.trim_end_matches('.').to_string();
|
||||
|
||||
if domain.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(domain)
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_pattern(raw: &str) -> Result<String> {
|
||||
let pattern = raw.trim().to_ascii_lowercase();
|
||||
if pattern.is_empty() {
|
||||
bail!("Domain pattern must not be empty");
|
||||
}
|
||||
if pattern == "*" {
|
||||
return Ok(pattern);
|
||||
}
|
||||
if pattern.starts_with('.') || pattern.ends_with('.') {
|
||||
bail!("Domain pattern '{raw}' must not start or end with '.'");
|
||||
}
|
||||
if pattern.contains("..") {
|
||||
bail!("Domain pattern '{raw}' must not contain consecutive dots");
|
||||
}
|
||||
if pattern.contains("**") {
|
||||
bail!("Domain pattern '{raw}' must not contain consecutive '*'");
|
||||
}
|
||||
if !pattern
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '.' || c == '-' || c == '*')
|
||||
{
|
||||
bail!(
|
||||
"Domain pattern '{raw}' contains invalid characters; allowed: a-z, 0-9, '.', '-', '*'"
|
||||
);
|
||||
}
|
||||
if pattern.split('.').any(|label| label.is_empty()) {
|
||||
bail!("Domain pattern '{raw}' contains an empty label");
|
||||
}
|
||||
if pattern.starts_with("*.") && pattern.len() <= 2 {
|
||||
bail!("Domain pattern '{raw}' is incomplete");
|
||||
}
|
||||
Ok(pattern)
|
||||
}
|
||||
|
||||
fn domain_matches_pattern(pattern: &str, domain: &str) -> bool {
|
||||
if pattern == "*" {
|
||||
return true;
|
||||
}
|
||||
if !pattern.contains('*') {
|
||||
return pattern == domain;
|
||||
}
|
||||
wildcard_match(pattern.as_bytes(), domain.as_bytes())
|
||||
}
|
||||
|
||||
fn wildcard_match(pattern: &[u8], value: &[u8]) -> bool {
|
||||
let mut p = 0usize;
|
||||
let mut v = 0usize;
|
||||
let mut star_idx: Option<usize> = None;
|
||||
let mut match_idx = 0usize;
|
||||
|
||||
while v < value.len() {
|
||||
if p < pattern.len() && pattern[p] == value[v] {
|
||||
p += 1;
|
||||
v += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if p < pattern.len() && pattern[p] == b'*' {
|
||||
star_idx = Some(p);
|
||||
p += 1;
|
||||
match_idx = v;
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(star) = star_idx {
|
||||
p = star + 1;
|
||||
match_idx += 1;
|
||||
v = match_idx;
|
||||
continue;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
while p < pattern.len() && pattern[p] == b'*' {
|
||||
p += 1;
|
||||
}
|
||||
p == pattern.len()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn exact_match_works() {
|
||||
let matcher =
|
||||
DomainMatcher::new(&["accounts.google.com".to_string()], &[] as &[String]).unwrap();
|
||||
assert!(matcher.is_gated("accounts.google.com"));
|
||||
assert!(matcher.is_gated("https://accounts.google.com/login"));
|
||||
assert!(!matcher.is_gated("mail.google.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wildcard_match_works() {
|
||||
let matcher = DomainMatcher::new(&["*.chase.com".to_string()], &[] as &[String]).unwrap();
|
||||
assert!(matcher.is_gated("www.chase.com"));
|
||||
assert!(matcher.is_gated("secure.chase.com"));
|
||||
assert!(!matcher.is_gated("chase.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn category_preset_expands_and_matches() {
|
||||
let matcher = DomainMatcher::new(&[] as &[String], &["banking".to_string()]).unwrap();
|
||||
assert!(matcher.is_gated("login.paypal.com"));
|
||||
assert!(matcher.is_gated("api.coinbase.com"));
|
||||
assert!(!matcher.is_gated("developer.mozilla.org"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_matching_domain_returns_false() {
|
||||
let matcher =
|
||||
DomainMatcher::new(&["accounts.google.com".to_string()], &[] as &[String]).unwrap();
|
||||
assert!(!matcher.is_gated("example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn malformed_domain_pattern_is_rejected() {
|
||||
let err = DomainMatcher::new(&["bad domain.com".to_string()], &[] as &[String])
|
||||
.expect_err("expected invalid pattern");
|
||||
assert!(err.to_string().contains("invalid characters"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_category_is_rejected() {
|
||||
let err = DomainMatcher::new(&[] as &[String], &["unknown".to_string()])
|
||||
.expect_err("expected unknown category rejection");
|
||||
assert!(err.to_string().contains("Unknown OTP domain category"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,422 @@
|
||||
use crate::config::EstopConfig;
|
||||
use crate::security::domain_matcher::DomainMatcher;
|
||||
use crate::security::otp::OtpValidator;
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum EstopLevel {
|
||||
KillAll,
|
||||
NetworkKill,
|
||||
DomainBlock(Vec<String>),
|
||||
ToolFreeze(Vec<String>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ResumeSelector {
|
||||
KillAll,
|
||||
Network,
|
||||
Domains(Vec<String>),
|
||||
Tools(Vec<String>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
|
||||
pub struct EstopState {
|
||||
#[serde(default)]
|
||||
pub kill_all: bool,
|
||||
#[serde(default)]
|
||||
pub network_kill: bool,
|
||||
#[serde(default)]
|
||||
pub blocked_domains: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub frozen_tools: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub updated_at: Option<String>,
|
||||
}
|
||||
|
||||
impl EstopState {
|
||||
pub fn fail_closed() -> Self {
|
||||
Self {
|
||||
kill_all: true,
|
||||
network_kill: false,
|
||||
blocked_domains: Vec::new(),
|
||||
frozen_tools: Vec::new(),
|
||||
updated_at: Some(now_rfc3339()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_engaged(&self) -> bool {
|
||||
self.kill_all
|
||||
|| self.network_kill
|
||||
|| !self.blocked_domains.is_empty()
|
||||
|| !self.frozen_tools.is_empty()
|
||||
}
|
||||
|
||||
fn normalize(&mut self) {
|
||||
self.blocked_domains = dedup_sort(&self.blocked_domains);
|
||||
self.frozen_tools = dedup_sort(&self.frozen_tools);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EstopManager {
|
||||
config: EstopConfig,
|
||||
state_path: PathBuf,
|
||||
state: EstopState,
|
||||
}
|
||||
|
||||
impl EstopManager {
|
||||
pub fn load(config: &EstopConfig, config_dir: &Path) -> Result<Self> {
|
||||
let state_path = resolve_state_file_path(config_dir, &config.state_file);
|
||||
let mut should_fail_closed = false;
|
||||
let mut state = if state_path.exists() {
|
||||
match fs::read_to_string(&state_path) {
|
||||
Ok(raw) => match serde_json::from_str::<EstopState>(&raw) {
|
||||
Ok(mut parsed) => {
|
||||
parsed.normalize();
|
||||
parsed
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::warn!(
|
||||
path = %state_path.display(),
|
||||
"Failed to parse estop state file; entering fail-closed mode: {error}"
|
||||
);
|
||||
should_fail_closed = true;
|
||||
EstopState::fail_closed()
|
||||
}
|
||||
},
|
||||
Err(error) => {
|
||||
tracing::warn!(
|
||||
path = %state_path.display(),
|
||||
"Failed to read estop state file; entering fail-closed mode: {error}"
|
||||
);
|
||||
should_fail_closed = true;
|
||||
EstopState::fail_closed()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
EstopState::default()
|
||||
};
|
||||
|
||||
state.normalize();
|
||||
|
||||
let mut manager = Self {
|
||||
config: config.clone(),
|
||||
state_path,
|
||||
state,
|
||||
};
|
||||
|
||||
if should_fail_closed {
|
||||
let _ = manager.persist_state();
|
||||
}
|
||||
|
||||
Ok(manager)
|
||||
}
|
||||
|
||||
pub fn state_path(&self) -> &Path {
|
||||
&self.state_path
|
||||
}
|
||||
|
||||
pub fn status(&self) -> EstopState {
|
||||
self.state.clone()
|
||||
}
|
||||
|
||||
pub fn engage(&mut self, level: EstopLevel) -> Result<()> {
|
||||
match level {
|
||||
EstopLevel::KillAll => {
|
||||
self.state.kill_all = true;
|
||||
}
|
||||
EstopLevel::NetworkKill => {
|
||||
self.state.network_kill = true;
|
||||
}
|
||||
EstopLevel::DomainBlock(domains) => {
|
||||
for domain in domains {
|
||||
let normalized = domain.trim().to_ascii_lowercase();
|
||||
DomainMatcher::validate_pattern(&normalized)?;
|
||||
self.state.blocked_domains.push(normalized);
|
||||
}
|
||||
}
|
||||
EstopLevel::ToolFreeze(tools) => {
|
||||
for tool in tools {
|
||||
let normalized = normalize_tool_name(&tool)?;
|
||||
self.state.frozen_tools.push(normalized);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.state.updated_at = Some(now_rfc3339());
|
||||
self.state.normalize();
|
||||
self.persist_state()
|
||||
}
|
||||
|
||||
pub fn resume(
|
||||
&mut self,
|
||||
selector: ResumeSelector,
|
||||
otp_code: Option<&str>,
|
||||
otp_validator: Option<&OtpValidator>,
|
||||
) -> Result<()> {
|
||||
self.ensure_resume_is_authorized(otp_code, otp_validator)?;
|
||||
|
||||
match selector {
|
||||
ResumeSelector::KillAll => {
|
||||
self.state.kill_all = false;
|
||||
}
|
||||
ResumeSelector::Network => {
|
||||
self.state.network_kill = false;
|
||||
}
|
||||
ResumeSelector::Domains(domains) => {
|
||||
let normalized = domains
|
||||
.iter()
|
||||
.map(|domain| domain.trim().to_ascii_lowercase())
|
||||
.collect::<Vec<_>>();
|
||||
self.state
|
||||
.blocked_domains
|
||||
.retain(|existing| !normalized.iter().any(|target| target == existing));
|
||||
}
|
||||
ResumeSelector::Tools(tools) => {
|
||||
let normalized = tools
|
||||
.iter()
|
||||
.map(|tool| normalize_tool_name(tool))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
self.state
|
||||
.frozen_tools
|
||||
.retain(|existing| !normalized.iter().any(|target| target == existing));
|
||||
}
|
||||
}
|
||||
|
||||
self.state.updated_at = Some(now_rfc3339());
|
||||
self.state.normalize();
|
||||
self.persist_state()
|
||||
}
|
||||
|
||||
fn ensure_resume_is_authorized(
|
||||
&self,
|
||||
otp_code: Option<&str>,
|
||||
otp_validator: Option<&OtpValidator>,
|
||||
) -> Result<()> {
|
||||
if !self.config.require_otp_to_resume {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let code = otp_code
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.context("OTP code is required to resume estop state")?;
|
||||
let validator = otp_validator
|
||||
.context("OTP validator is required to resume estop state with OTP enabled")?;
|
||||
let valid = validator.validate(code)?;
|
||||
if !valid {
|
||||
anyhow::bail!("Invalid OTP code; estop resume denied");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn persist_state(&mut self) -> Result<()> {
|
||||
if let Some(parent) = self.state_path.parent() {
|
||||
fs::create_dir_all(parent).with_context(|| {
|
||||
format!("Failed to create estop state dir {}", parent.display())
|
||||
})?;
|
||||
}
|
||||
|
||||
let body =
|
||||
serde_json::to_string_pretty(&self.state).context("Failed to serialize estop state")?;
|
||||
|
||||
let temp_path = self
|
||||
.state_path
|
||||
.with_extension(format!("tmp-{}", uuid::Uuid::new_v4()));
|
||||
fs::write(&temp_path, body).with_context(|| {
|
||||
format!(
|
||||
"Failed to write temporary estop state file {}",
|
||||
temp_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = fs::set_permissions(&temp_path, fs::Permissions::from_mode(0o600));
|
||||
}
|
||||
|
||||
fs::rename(&temp_path, &self.state_path).with_context(|| {
|
||||
format!(
|
||||
"Failed to atomically replace estop state file {}",
|
||||
self.state_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resolve_state_file_path(config_dir: &Path, state_file: &str) -> PathBuf {
|
||||
let expanded = shellexpand::tilde(state_file).into_owned();
|
||||
let path = PathBuf::from(expanded);
|
||||
if path.is_absolute() {
|
||||
path
|
||||
} else {
|
||||
config_dir.join(path)
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_tool_name(raw: &str) -> Result<String> {
|
||||
let value = raw.trim().to_ascii_lowercase();
|
||||
if value.is_empty() {
|
||||
anyhow::bail!("Tool name must not be empty");
|
||||
}
|
||||
if !value
|
||||
.chars()
|
||||
.all(|ch| ch.is_ascii_alphanumeric() || ch == '_' || ch == '-')
|
||||
{
|
||||
anyhow::bail!("Tool name '{raw}' contains invalid characters");
|
||||
}
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
fn dedup_sort(values: &[String]) -> Vec<String> {
|
||||
let mut deduped = values
|
||||
.iter()
|
||||
.map(|value| value.trim())
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(ToString::to_string)
|
||||
.collect::<Vec<_>>();
|
||||
deduped.sort_unstable();
|
||||
deduped.dedup();
|
||||
deduped
|
||||
}
|
||||
|
||||
fn now_rfc3339() -> String {
|
||||
let secs = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map(|duration| duration.as_secs())
|
||||
.unwrap_or(0);
|
||||
chrono::DateTime::<chrono::Utc>::from_timestamp(secs as i64, 0)
|
||||
.unwrap_or(chrono::DateTime::<chrono::Utc>::UNIX_EPOCH)
|
||||
.to_rfc3339()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::OtpConfig;
|
||||
use crate::security::otp::OtpValidator;
|
||||
use crate::security::SecretStore;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn estop_config(path: &Path) -> EstopConfig {
|
||||
EstopConfig {
|
||||
enabled: true,
|
||||
state_file: path.display().to_string(),
|
||||
require_otp_to_resume: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn estop_levels_compose_and_resume() {
|
||||
let dir = tempdir().unwrap();
|
||||
let state_path = dir.path().join("estop-state.json");
|
||||
let cfg = estop_config(&state_path);
|
||||
let mut manager = EstopManager::load(&cfg, dir.path()).unwrap();
|
||||
|
||||
manager
|
||||
.engage(EstopLevel::DomainBlock(vec!["*.chase.com".into()]))
|
||||
.unwrap();
|
||||
manager
|
||||
.engage(EstopLevel::ToolFreeze(vec!["shell".into()]))
|
||||
.unwrap();
|
||||
manager.engage(EstopLevel::NetworkKill).unwrap();
|
||||
assert!(manager.status().network_kill);
|
||||
assert_eq!(manager.status().blocked_domains, vec!["*.chase.com"]);
|
||||
assert_eq!(manager.status().frozen_tools, vec!["shell"]);
|
||||
|
||||
manager
|
||||
.resume(
|
||||
ResumeSelector::Domains(vec!["*.chase.com".into()]),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
assert!(manager.status().blocked_domains.is_empty());
|
||||
assert!(manager.status().network_kill);
|
||||
|
||||
manager
|
||||
.resume(ResumeSelector::Tools(vec!["shell".into()]), None, None)
|
||||
.unwrap();
|
||||
assert!(manager.status().frozen_tools.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn estop_state_survives_reload() {
|
||||
let dir = tempdir().unwrap();
|
||||
let state_path = dir.path().join("estop-state.json");
|
||||
let cfg = estop_config(&state_path);
|
||||
|
||||
{
|
||||
let mut manager = EstopManager::load(&cfg, dir.path()).unwrap();
|
||||
manager.engage(EstopLevel::KillAll).unwrap();
|
||||
manager
|
||||
.engage(EstopLevel::DomainBlock(vec!["*.paypal.com".into()]))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let reloaded = EstopManager::load(&cfg, dir.path()).unwrap();
|
||||
let state = reloaded.status();
|
||||
assert!(state.kill_all);
|
||||
assert_eq!(state.blocked_domains, vec!["*.paypal.com"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn corrupted_state_defaults_to_fail_closed_kill_all() {
|
||||
let dir = tempdir().unwrap();
|
||||
let state_path = dir.path().join("estop-state.json");
|
||||
fs::write(&state_path, "{not-valid-json").unwrap();
|
||||
let cfg = estop_config(&state_path);
|
||||
let manager = EstopManager::load(&cfg, dir.path()).unwrap();
|
||||
assert!(manager.status().kill_all);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resume_requires_valid_otp_when_enabled() {
|
||||
let dir = tempdir().unwrap();
|
||||
let state_path = dir.path().join("estop-state.json");
|
||||
let mut cfg = estop_config(&state_path);
|
||||
cfg.require_otp_to_resume = true;
|
||||
|
||||
let mut manager = EstopManager::load(&cfg, dir.path()).unwrap();
|
||||
manager.engage(EstopLevel::KillAll).unwrap();
|
||||
|
||||
let err = manager
|
||||
.resume(ResumeSelector::KillAll, None, None)
|
||||
.expect_err("resume should require OTP");
|
||||
assert!(err.to_string().contains("OTP code is required"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resume_accepts_valid_otp_code() {
|
||||
let dir = tempdir().unwrap();
|
||||
let state_path = dir.path().join("estop-state.json");
|
||||
let mut cfg = estop_config(&state_path);
|
||||
cfg.require_otp_to_resume = true;
|
||||
|
||||
let otp_cfg = OtpConfig {
|
||||
enabled: true,
|
||||
..OtpConfig::default()
|
||||
};
|
||||
let store = SecretStore::new(dir.path(), true);
|
||||
let (validator, _) = OtpValidator::from_config(&otp_cfg, dir.path(), &store).unwrap();
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map(|duration| duration.as_secs())
|
||||
.unwrap_or(0);
|
||||
let code = validator.code_for_timestamp(now);
|
||||
|
||||
let mut manager = EstopManager::load(&cfg, dir.path()).unwrap();
|
||||
manager.engage(EstopLevel::KillAll).unwrap();
|
||||
manager
|
||||
.resume(ResumeSelector::KillAll, Some(&code), Some(&validator))
|
||||
.unwrap();
|
||||
assert!(!manager.status().kill_all);
|
||||
}
|
||||
}
|
||||
@@ -23,10 +23,13 @@ pub mod audit;
|
||||
pub mod bubblewrap;
|
||||
pub mod detect;
|
||||
pub mod docker;
|
||||
pub mod domain_matcher;
|
||||
pub mod estop;
|
||||
#[cfg(target_os = "linux")]
|
||||
pub mod firejail;
|
||||
#[cfg(feature = "sandbox-landlock")]
|
||||
pub mod landlock;
|
||||
pub mod otp;
|
||||
pub mod pairing;
|
||||
pub mod policy;
|
||||
pub mod secrets;
|
||||
@@ -36,6 +39,11 @@ pub mod traits;
|
||||
pub use audit::{AuditEvent, AuditEventType, AuditLogger};
|
||||
#[allow(unused_imports)]
|
||||
pub use detect::create_sandbox;
|
||||
pub use domain_matcher::DomainMatcher;
|
||||
#[allow(unused_imports)]
|
||||
pub use estop::{EstopLevel, EstopManager, EstopState, ResumeSelector};
|
||||
#[allow(unused_imports)]
|
||||
pub use otp::OtpValidator;
|
||||
#[allow(unused_imports)]
|
||||
pub use pairing::PairingGuard;
|
||||
pub use policy::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
@@ -0,0 +1,318 @@
|
||||
use crate::config::OtpConfig;
|
||||
use crate::security::secrets::SecretStore;
|
||||
use anyhow::{Context, Result};
|
||||
use parking_lot::Mutex;
|
||||
use ring::hmac;
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
const OTP_SECRET_FILE: &str = "otp-secret";
|
||||
const OTP_DIGITS: u32 = 6;
|
||||
const OTP_ISSUER: &str = "ZeroClaw";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OtpValidator {
|
||||
config: OtpConfig,
|
||||
secret: Vec<u8>,
|
||||
cached_codes: Mutex<HashMap<String, u64>>,
|
||||
}
|
||||
|
||||
impl OtpValidator {
|
||||
pub fn from_config(
|
||||
config: &OtpConfig,
|
||||
zeroclaw_dir: &Path,
|
||||
store: &SecretStore,
|
||||
) -> Result<(Self, Option<String>)> {
|
||||
let secret_path = secret_file_path(zeroclaw_dir);
|
||||
let (secret, generated) = if secret_path.exists() {
|
||||
let encoded = fs::read_to_string(&secret_path).with_context(|| {
|
||||
format!("Failed to read OTP secret file {}", secret_path.display())
|
||||
})?;
|
||||
let decrypted = store
|
||||
.decrypt(encoded.trim())
|
||||
.context("Failed to decrypt OTP secret file")?;
|
||||
(decode_base32_secret(&decrypted)?, false)
|
||||
} else {
|
||||
let raw: [u8; 20] = rand::random();
|
||||
let encoded_secret = encode_base32_secret(&raw);
|
||||
let encrypted = store
|
||||
.encrypt(&encoded_secret)
|
||||
.context("Failed to encrypt OTP secret")?;
|
||||
write_secret_file(&secret_path, &encrypted)?;
|
||||
(raw.to_vec(), true)
|
||||
};
|
||||
|
||||
let validator = Self {
|
||||
config: config.clone(),
|
||||
secret,
|
||||
cached_codes: Mutex::new(HashMap::new()),
|
||||
};
|
||||
let uri = if generated {
|
||||
Some(validator.otpauth_uri())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok((validator, uri))
|
||||
}
|
||||
|
||||
pub fn validate(&self, code: &str) -> Result<bool> {
|
||||
self.validate_at(code, unix_timestamp_now())
|
||||
}
|
||||
|
||||
fn validate_at(&self, code: &str, now_secs: u64) -> Result<bool> {
|
||||
let normalized = code.trim();
|
||||
if normalized.len() != OTP_DIGITS as usize
|
||||
|| !normalized.chars().all(|ch| ch.is_ascii_digit())
|
||||
{
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
{
|
||||
let mut cache = self.cached_codes.lock();
|
||||
cache.retain(|_, expiry| *expiry >= now_secs);
|
||||
if cache
|
||||
.get(normalized)
|
||||
.is_some_and(|expiry| *expiry >= now_secs)
|
||||
{
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
|
||||
let step = self.config.token_ttl_secs.max(1);
|
||||
let counter = now_secs / step;
|
||||
let counters = [
|
||||
counter.saturating_sub(1),
|
||||
counter,
|
||||
counter.saturating_add(1),
|
||||
];
|
||||
|
||||
let is_valid = counters
|
||||
.iter()
|
||||
.map(|c| compute_totp_code(&self.secret, *c))
|
||||
.any(|candidate| candidate == normalized);
|
||||
|
||||
if is_valid {
|
||||
let mut cache = self.cached_codes.lock();
|
||||
cache.insert(
|
||||
normalized.to_string(),
|
||||
now_secs.saturating_add(self.config.cache_valid_secs),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(is_valid)
|
||||
}
|
||||
|
||||
pub fn otpauth_uri(&self) -> String {
|
||||
let secret = encode_base32_secret(&self.secret);
|
||||
let account = "zeroclaw";
|
||||
format!(
|
||||
"otpauth://totp/{issuer}:{account}?secret={secret}&issuer={issuer}&period={period}",
|
||||
issuer = OTP_ISSUER,
|
||||
period = self.config.token_ttl_secs.max(1)
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn code_for_timestamp(&self, timestamp: u64) -> String {
|
||||
let counter = timestamp / self.config.token_ttl_secs.max(1);
|
||||
compute_totp_code(&self.secret, counter)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn secret_file_path(zeroclaw_dir: &Path) -> PathBuf {
|
||||
zeroclaw_dir.join(OTP_SECRET_FILE)
|
||||
}
|
||||
|
||||
fn write_secret_file(path: &Path, value: &str) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)
|
||||
.with_context(|| format!("Failed to create directory {}", parent.display()))?;
|
||||
}
|
||||
|
||||
let temp_path = path.with_extension(format!("tmp-{}", uuid::Uuid::new_v4()));
|
||||
fs::write(&temp_path, value).with_context(|| {
|
||||
format!(
|
||||
"Failed to write temporary OTP secret {}",
|
||||
temp_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = fs::set_permissions(&temp_path, fs::Permissions::from_mode(0o600));
|
||||
}
|
||||
|
||||
fs::rename(&temp_path, path).with_context(|| {
|
||||
format!(
|
||||
"Failed to atomically replace OTP secret file {}",
|
||||
path.display()
|
||||
)
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn unix_timestamp_now() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map(|duration| duration.as_secs())
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
fn compute_totp_code(secret: &[u8], counter: u64) -> String {
|
||||
let key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, secret);
|
||||
let counter_bytes = counter.to_be_bytes();
|
||||
let digest = hmac::sign(&key, &counter_bytes);
|
||||
let hash = digest.as_ref();
|
||||
|
||||
let offset = (hash[19] & 0x0f) as usize;
|
||||
let binary = ((u32::from(hash[offset]) & 0x7f) << 24)
|
||||
| (u32::from(hash[offset + 1]) << 16)
|
||||
| (u32::from(hash[offset + 2]) << 8)
|
||||
| u32::from(hash[offset + 3]);
|
||||
|
||||
let code = binary % 10_u32.pow(OTP_DIGITS);
|
||||
format!("{code:0>6}")
|
||||
}
|
||||
|
||||
fn encode_base32_secret(input: &[u8]) -> String {
|
||||
const ALPHABET: &[u8; 32] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
|
||||
if input.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let mut result = String::new();
|
||||
let mut buffer = 0u16;
|
||||
let mut bits_left = 0u8;
|
||||
|
||||
for byte in input {
|
||||
buffer = (buffer << 8) | u16::from(*byte);
|
||||
bits_left += 8;
|
||||
|
||||
while bits_left >= 5 {
|
||||
let index = ((buffer >> (bits_left - 5)) & 0x1f) as usize;
|
||||
result.push(ALPHABET[index] as char);
|
||||
bits_left -= 5;
|
||||
}
|
||||
}
|
||||
|
||||
if bits_left > 0 {
|
||||
let index = ((buffer << (5 - bits_left)) & 0x1f) as usize;
|
||||
result.push(ALPHABET[index] as char);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn decode_base32_secret(raw: &str) -> Result<Vec<u8>> {
|
||||
fn decode_char(ch: char) -> Option<u8> {
|
||||
match ch {
|
||||
'A'..='Z' => Some((ch as u8) - b'A'),
|
||||
'2'..='7' => Some((ch as u8) - b'2' + 26),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
let mut cleaned = raw
|
||||
.chars()
|
||||
.filter(|ch| !matches!(ch, ' ' | '\t' | '\n' | '\r' | '-'))
|
||||
.collect::<String>()
|
||||
.to_ascii_uppercase();
|
||||
while cleaned.ends_with('=') {
|
||||
cleaned.pop();
|
||||
}
|
||||
if cleaned.is_empty() {
|
||||
anyhow::bail!("OTP secret is empty");
|
||||
}
|
||||
|
||||
let mut output = Vec::new();
|
||||
let mut buffer = 0u32;
|
||||
let mut bits_left = 0u8;
|
||||
|
||||
for ch in cleaned.chars() {
|
||||
let value = decode_char(ch)
|
||||
.with_context(|| format!("OTP secret contains invalid base32 character '{ch}'"))?;
|
||||
buffer = (buffer << 5) | u32::from(value);
|
||||
bits_left += 5;
|
||||
|
||||
if bits_left >= 8 {
|
||||
let byte = ((buffer >> (bits_left - 8)) & 0xff) as u8;
|
||||
output.push(byte);
|
||||
bits_left -= 8;
|
||||
}
|
||||
}
|
||||
|
||||
if output.is_empty() {
|
||||
anyhow::bail!("OTP secret did not decode to any bytes");
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn test_config() -> OtpConfig {
|
||||
OtpConfig {
|
||||
enabled: true,
|
||||
token_ttl_secs: 30,
|
||||
cache_valid_secs: 120,
|
||||
..OtpConfig::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_totp_code_is_accepted() {
|
||||
let dir = tempdir().unwrap();
|
||||
let store = SecretStore::new(dir.path(), true);
|
||||
let (validator, _) = OtpValidator::from_config(&test_config(), dir.path(), &store).unwrap();
|
||||
|
||||
let now = 1_700_000_000u64;
|
||||
let code = validator.code_for_timestamp(now);
|
||||
assert!(validator.validate_at(&code, now).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expired_totp_code_is_rejected() {
|
||||
let dir = tempdir().unwrap();
|
||||
let store = SecretStore::new(dir.path(), true);
|
||||
let (validator, _) = OtpValidator::from_config(&test_config(), dir.path(), &store).unwrap();
|
||||
|
||||
let stale = 1_700_000_000u64;
|
||||
let now = stale + 300;
|
||||
let code = validator.code_for_timestamp(stale);
|
||||
assert!(!validator.validate_at(&code, now).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_totp_code_is_rejected() {
|
||||
let dir = tempdir().unwrap();
|
||||
let store = SecretStore::new(dir.path(), true);
|
||||
let (validator, _) = OtpValidator::from_config(&test_config(), dir.path(), &store).unwrap();
|
||||
assert!(!validator.validate_at("123456", 1_700_000_000).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secret_is_generated_and_reused() {
|
||||
let dir = tempdir().unwrap();
|
||||
let store = SecretStore::new(dir.path(), true);
|
||||
|
||||
let (first, first_uri) =
|
||||
OtpValidator::from_config(&test_config(), dir.path(), &store).unwrap();
|
||||
assert!(first_uri.is_some());
|
||||
|
||||
let secret_path = secret_file_path(dir.path());
|
||||
let stored = fs::read_to_string(&secret_path).unwrap();
|
||||
assert!(SecretStore::is_encrypted(stored.trim()));
|
||||
|
||||
let (second, second_uri) =
|
||||
OtpValidator::from_config(&test_config(), dir.path(), &store).unwrap();
|
||||
assert!(second_uri.is_none());
|
||||
|
||||
let ts = 1_700_000_000u64;
|
||||
assert_eq!(first.code_for_timestamp(ts), second.code_for_timestamp(ts));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user