feat(security): add otp and estop phase-1 foundation

This commit is contained in:
Chummy
2026-02-21 17:19:18 +08:00
parent 9098b379dd
commit a36b1466ff
11 changed files with 1612 additions and 11 deletions
+259
View File
@@ -0,0 +1,259 @@
use anyhow::{bail, Result};
use std::collections::BTreeSet;
const BANKING_DOMAINS: &[&str] = &[
"*.chase.com",
"*.bankofamerica.com",
"*.wellsfargo.com",
"*.fidelity.com",
"*.schwab.com",
"*.venmo.com",
"*.paypal.com",
"*.robinhood.com",
"*.coinbase.com",
];
const MEDICAL_DOMAINS: &[&str] = &[
"*.mychart.com",
"*.epic.com",
"*.patient.portal.*",
"*.healthrecords.*",
];
const GOVERNMENT_DOMAINS: &[&str] = &["*.ssa.gov", "*.irs.gov", "*.login.gov", "*.id.me"];
const IDENTITY_PROVIDER_DOMAINS: &[&str] = &[
"accounts.google.com",
"login.microsoftonline.com",
"appleid.apple.com",
];
const DOMAIN_CATEGORIES: &[(&str, &[&str])] = &[
("banking", BANKING_DOMAINS),
("medical", MEDICAL_DOMAINS),
("government", GOVERNMENT_DOMAINS),
("identity_providers", IDENTITY_PROVIDER_DOMAINS),
];
#[derive(Debug, Clone, Default)]
pub struct DomainMatcher {
patterns: Vec<String>,
}
impl DomainMatcher {
pub fn new(gated_domains: &[String], categories: &[String]) -> Result<Self> {
let mut set = BTreeSet::new();
for domain in gated_domains {
set.insert(normalize_pattern(domain)?);
}
for domain in Self::expand_categories(categories)? {
set.insert(domain);
}
Ok(Self {
patterns: set.into_iter().collect(),
})
}
pub fn patterns(&self) -> &[String] {
&self.patterns
}
pub fn is_gated(&self, domain: &str) -> bool {
let Some(normalized_domain) = normalize_domain(domain) else {
return false;
};
self.patterns
.iter()
.any(|pattern| domain_matches_pattern(pattern, &normalized_domain))
}
pub fn expand_categories(categories: &[String]) -> Result<Vec<String>> {
let mut expanded = Vec::new();
for category in categories {
let normalized = category.trim().to_ascii_lowercase();
let Some((_, domains)) = DOMAIN_CATEGORIES
.iter()
.find(|(name, _)| *name == normalized.as_str())
else {
let known = DOMAIN_CATEGORIES
.iter()
.map(|(name, _)| *name)
.collect::<Vec<_>>()
.join(", ");
bail!("Unknown OTP domain category '{category}'. Known categories: {known}");
};
expanded.extend(domains.iter().map(|domain| (*domain).to_string()));
}
Ok(expanded)
}
pub fn validate_pattern(pattern: &str) -> Result<()> {
let _ = normalize_pattern(pattern)?;
Ok(())
}
}
fn normalize_domain(raw: &str) -> Option<String> {
let mut domain = raw.trim().to_ascii_lowercase();
if domain.is_empty() {
return None;
}
if let Some((_, rest)) = domain.split_once("://") {
domain = rest.to_string();
}
domain = domain
.split(['/', '?', '#'])
.next()
.unwrap_or_default()
.to_string();
if let Some((_, host)) = domain.rsplit_once('@') {
domain = host.to_string();
}
if let Some((host, _port)) = domain.split_once(':') {
domain = host.to_string();
}
domain = domain.trim_end_matches('.').to_string();
if domain.is_empty() {
None
} else {
Some(domain)
}
}
fn normalize_pattern(raw: &str) -> Result<String> {
let pattern = raw.trim().to_ascii_lowercase();
if pattern.is_empty() {
bail!("Domain pattern must not be empty");
}
if pattern == "*" {
return Ok(pattern);
}
if pattern.starts_with('.') || pattern.ends_with('.') {
bail!("Domain pattern '{raw}' must not start or end with '.'");
}
if pattern.contains("..") {
bail!("Domain pattern '{raw}' must not contain consecutive dots");
}
if pattern.contains("**") {
bail!("Domain pattern '{raw}' must not contain consecutive '*'");
}
if !pattern
.chars()
.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '.' || c == '-' || c == '*')
{
bail!(
"Domain pattern '{raw}' contains invalid characters; allowed: a-z, 0-9, '.', '-', '*'"
);
}
if pattern.split('.').any(|label| label.is_empty()) {
bail!("Domain pattern '{raw}' contains an empty label");
}
if pattern.starts_with("*.") && pattern.len() <= 2 {
bail!("Domain pattern '{raw}' is incomplete");
}
Ok(pattern)
}
fn domain_matches_pattern(pattern: &str, domain: &str) -> bool {
if pattern == "*" {
return true;
}
if !pattern.contains('*') {
return pattern == domain;
}
wildcard_match(pattern.as_bytes(), domain.as_bytes())
}
fn wildcard_match(pattern: &[u8], value: &[u8]) -> bool {
let mut p = 0usize;
let mut v = 0usize;
let mut star_idx: Option<usize> = None;
let mut match_idx = 0usize;
while v < value.len() {
if p < pattern.len() && pattern[p] == value[v] {
p += 1;
v += 1;
continue;
}
if p < pattern.len() && pattern[p] == b'*' {
star_idx = Some(p);
p += 1;
match_idx = v;
continue;
}
if let Some(star) = star_idx {
p = star + 1;
match_idx += 1;
v = match_idx;
continue;
}
return false;
}
while p < pattern.len() && pattern[p] == b'*' {
p += 1;
}
p == pattern.len()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn exact_match_works() {
let matcher =
DomainMatcher::new(&["accounts.google.com".to_string()], &[] as &[String]).unwrap();
assert!(matcher.is_gated("accounts.google.com"));
assert!(matcher.is_gated("https://accounts.google.com/login"));
assert!(!matcher.is_gated("mail.google.com"));
}
#[test]
fn wildcard_match_works() {
let matcher = DomainMatcher::new(&["*.chase.com".to_string()], &[] as &[String]).unwrap();
assert!(matcher.is_gated("www.chase.com"));
assert!(matcher.is_gated("secure.chase.com"));
assert!(!matcher.is_gated("chase.com"));
}
#[test]
fn category_preset_expands_and_matches() {
let matcher = DomainMatcher::new(&[] as &[String], &["banking".to_string()]).unwrap();
assert!(matcher.is_gated("login.paypal.com"));
assert!(matcher.is_gated("api.coinbase.com"));
assert!(!matcher.is_gated("developer.mozilla.org"));
}
#[test]
fn non_matching_domain_returns_false() {
let matcher =
DomainMatcher::new(&["accounts.google.com".to_string()], &[] as &[String]).unwrap();
assert!(!matcher.is_gated("example.com"));
}
#[test]
fn malformed_domain_pattern_is_rejected() {
let err = DomainMatcher::new(&["bad domain.com".to_string()], &[] as &[String])
.expect_err("expected invalid pattern");
assert!(err.to_string().contains("invalid characters"));
}
#[test]
fn unknown_category_is_rejected() {
let err = DomainMatcher::new(&[] as &[String], &["unknown".to_string()])
.expect_err("expected unknown category rejection");
assert!(err.to_string().contains("Unknown OTP domain category"));
}
}
+422
View File
@@ -0,0 +1,422 @@
use crate::config::EstopConfig;
use crate::security::domain_matcher::DomainMatcher;
use crate::security::otp::OtpValidator;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EstopLevel {
KillAll,
NetworkKill,
DomainBlock(Vec<String>),
ToolFreeze(Vec<String>),
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResumeSelector {
KillAll,
Network,
Domains(Vec<String>),
Tools(Vec<String>),
}
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
pub struct EstopState {
#[serde(default)]
pub kill_all: bool,
#[serde(default)]
pub network_kill: bool,
#[serde(default)]
pub blocked_domains: Vec<String>,
#[serde(default)]
pub frozen_tools: Vec<String>,
#[serde(default)]
pub updated_at: Option<String>,
}
impl EstopState {
pub fn fail_closed() -> Self {
Self {
kill_all: true,
network_kill: false,
blocked_domains: Vec::new(),
frozen_tools: Vec::new(),
updated_at: Some(now_rfc3339()),
}
}
pub fn is_engaged(&self) -> bool {
self.kill_all
|| self.network_kill
|| !self.blocked_domains.is_empty()
|| !self.frozen_tools.is_empty()
}
fn normalize(&mut self) {
self.blocked_domains = dedup_sort(&self.blocked_domains);
self.frozen_tools = dedup_sort(&self.frozen_tools);
}
}
#[derive(Debug, Clone)]
pub struct EstopManager {
config: EstopConfig,
state_path: PathBuf,
state: EstopState,
}
impl EstopManager {
pub fn load(config: &EstopConfig, config_dir: &Path) -> Result<Self> {
let state_path = resolve_state_file_path(config_dir, &config.state_file);
let mut should_fail_closed = false;
let mut state = if state_path.exists() {
match fs::read_to_string(&state_path) {
Ok(raw) => match serde_json::from_str::<EstopState>(&raw) {
Ok(mut parsed) => {
parsed.normalize();
parsed
}
Err(error) => {
tracing::warn!(
path = %state_path.display(),
"Failed to parse estop state file; entering fail-closed mode: {error}"
);
should_fail_closed = true;
EstopState::fail_closed()
}
},
Err(error) => {
tracing::warn!(
path = %state_path.display(),
"Failed to read estop state file; entering fail-closed mode: {error}"
);
should_fail_closed = true;
EstopState::fail_closed()
}
}
} else {
EstopState::default()
};
state.normalize();
let mut manager = Self {
config: config.clone(),
state_path,
state,
};
if should_fail_closed {
let _ = manager.persist_state();
}
Ok(manager)
}
pub fn state_path(&self) -> &Path {
&self.state_path
}
pub fn status(&self) -> EstopState {
self.state.clone()
}
pub fn engage(&mut self, level: EstopLevel) -> Result<()> {
match level {
EstopLevel::KillAll => {
self.state.kill_all = true;
}
EstopLevel::NetworkKill => {
self.state.network_kill = true;
}
EstopLevel::DomainBlock(domains) => {
for domain in domains {
let normalized = domain.trim().to_ascii_lowercase();
DomainMatcher::validate_pattern(&normalized)?;
self.state.blocked_domains.push(normalized);
}
}
EstopLevel::ToolFreeze(tools) => {
for tool in tools {
let normalized = normalize_tool_name(&tool)?;
self.state.frozen_tools.push(normalized);
}
}
}
self.state.updated_at = Some(now_rfc3339());
self.state.normalize();
self.persist_state()
}
pub fn resume(
&mut self,
selector: ResumeSelector,
otp_code: Option<&str>,
otp_validator: Option<&OtpValidator>,
) -> Result<()> {
self.ensure_resume_is_authorized(otp_code, otp_validator)?;
match selector {
ResumeSelector::KillAll => {
self.state.kill_all = false;
}
ResumeSelector::Network => {
self.state.network_kill = false;
}
ResumeSelector::Domains(domains) => {
let normalized = domains
.iter()
.map(|domain| domain.trim().to_ascii_lowercase())
.collect::<Vec<_>>();
self.state
.blocked_domains
.retain(|existing| !normalized.iter().any(|target| target == existing));
}
ResumeSelector::Tools(tools) => {
let normalized = tools
.iter()
.map(|tool| normalize_tool_name(tool))
.collect::<Result<Vec<_>>>()?;
self.state
.frozen_tools
.retain(|existing| !normalized.iter().any(|target| target == existing));
}
}
self.state.updated_at = Some(now_rfc3339());
self.state.normalize();
self.persist_state()
}
fn ensure_resume_is_authorized(
&self,
otp_code: Option<&str>,
otp_validator: Option<&OtpValidator>,
) -> Result<()> {
if !self.config.require_otp_to_resume {
return Ok(());
}
let code = otp_code
.map(str::trim)
.filter(|value| !value.is_empty())
.context("OTP code is required to resume estop state")?;
let validator = otp_validator
.context("OTP validator is required to resume estop state with OTP enabled")?;
let valid = validator.validate(code)?;
if !valid {
anyhow::bail!("Invalid OTP code; estop resume denied");
}
Ok(())
}
fn persist_state(&mut self) -> Result<()> {
if let Some(parent) = self.state_path.parent() {
fs::create_dir_all(parent).with_context(|| {
format!("Failed to create estop state dir {}", parent.display())
})?;
}
let body =
serde_json::to_string_pretty(&self.state).context("Failed to serialize estop state")?;
let temp_path = self
.state_path
.with_extension(format!("tmp-{}", uuid::Uuid::new_v4()));
fs::write(&temp_path, body).with_context(|| {
format!(
"Failed to write temporary estop state file {}",
temp_path.display()
)
})?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = fs::set_permissions(&temp_path, fs::Permissions::from_mode(0o600));
}
fs::rename(&temp_path, &self.state_path).with_context(|| {
format!(
"Failed to atomically replace estop state file {}",
self.state_path.display()
)
})?;
Ok(())
}
}
pub fn resolve_state_file_path(config_dir: &Path, state_file: &str) -> PathBuf {
let expanded = shellexpand::tilde(state_file).into_owned();
let path = PathBuf::from(expanded);
if path.is_absolute() {
path
} else {
config_dir.join(path)
}
}
fn normalize_tool_name(raw: &str) -> Result<String> {
let value = raw.trim().to_ascii_lowercase();
if value.is_empty() {
anyhow::bail!("Tool name must not be empty");
}
if !value
.chars()
.all(|ch| ch.is_ascii_alphanumeric() || ch == '_' || ch == '-')
{
anyhow::bail!("Tool name '{raw}' contains invalid characters");
}
Ok(value)
}
fn dedup_sort(values: &[String]) -> Vec<String> {
let mut deduped = values
.iter()
.map(|value| value.trim())
.filter(|value| !value.is_empty())
.map(ToString::to_string)
.collect::<Vec<_>>();
deduped.sort_unstable();
deduped.dedup();
deduped
}
fn now_rfc3339() -> String {
let secs = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|duration| duration.as_secs())
.unwrap_or(0);
chrono::DateTime::<chrono::Utc>::from_timestamp(secs as i64, 0)
.unwrap_or(chrono::DateTime::<chrono::Utc>::UNIX_EPOCH)
.to_rfc3339()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::OtpConfig;
use crate::security::otp::OtpValidator;
use crate::security::SecretStore;
use tempfile::tempdir;
fn estop_config(path: &Path) -> EstopConfig {
EstopConfig {
enabled: true,
state_file: path.display().to_string(),
require_otp_to_resume: false,
}
}
#[test]
fn estop_levels_compose_and_resume() {
let dir = tempdir().unwrap();
let state_path = dir.path().join("estop-state.json");
let cfg = estop_config(&state_path);
let mut manager = EstopManager::load(&cfg, dir.path()).unwrap();
manager
.engage(EstopLevel::DomainBlock(vec!["*.chase.com".into()]))
.unwrap();
manager
.engage(EstopLevel::ToolFreeze(vec!["shell".into()]))
.unwrap();
manager.engage(EstopLevel::NetworkKill).unwrap();
assert!(manager.status().network_kill);
assert_eq!(manager.status().blocked_domains, vec!["*.chase.com"]);
assert_eq!(manager.status().frozen_tools, vec!["shell"]);
manager
.resume(
ResumeSelector::Domains(vec!["*.chase.com".into()]),
None,
None,
)
.unwrap();
assert!(manager.status().blocked_domains.is_empty());
assert!(manager.status().network_kill);
manager
.resume(ResumeSelector::Tools(vec!["shell".into()]), None, None)
.unwrap();
assert!(manager.status().frozen_tools.is_empty());
}
#[test]
fn estop_state_survives_reload() {
let dir = tempdir().unwrap();
let state_path = dir.path().join("estop-state.json");
let cfg = estop_config(&state_path);
{
let mut manager = EstopManager::load(&cfg, dir.path()).unwrap();
manager.engage(EstopLevel::KillAll).unwrap();
manager
.engage(EstopLevel::DomainBlock(vec!["*.paypal.com".into()]))
.unwrap();
}
let reloaded = EstopManager::load(&cfg, dir.path()).unwrap();
let state = reloaded.status();
assert!(state.kill_all);
assert_eq!(state.blocked_domains, vec!["*.paypal.com"]);
}
#[test]
fn corrupted_state_defaults_to_fail_closed_kill_all() {
let dir = tempdir().unwrap();
let state_path = dir.path().join("estop-state.json");
fs::write(&state_path, "{not-valid-json").unwrap();
let cfg = estop_config(&state_path);
let manager = EstopManager::load(&cfg, dir.path()).unwrap();
assert!(manager.status().kill_all);
}
#[test]
fn resume_requires_valid_otp_when_enabled() {
let dir = tempdir().unwrap();
let state_path = dir.path().join("estop-state.json");
let mut cfg = estop_config(&state_path);
cfg.require_otp_to_resume = true;
let mut manager = EstopManager::load(&cfg, dir.path()).unwrap();
manager.engage(EstopLevel::KillAll).unwrap();
let err = manager
.resume(ResumeSelector::KillAll, None, None)
.expect_err("resume should require OTP");
assert!(err.to_string().contains("OTP code is required"));
}
#[test]
fn resume_accepts_valid_otp_code() {
let dir = tempdir().unwrap();
let state_path = dir.path().join("estop-state.json");
let mut cfg = estop_config(&state_path);
cfg.require_otp_to_resume = true;
let otp_cfg = OtpConfig {
enabled: true,
..OtpConfig::default()
};
let store = SecretStore::new(dir.path(), true);
let (validator, _) = OtpValidator::from_config(&otp_cfg, dir.path(), &store).unwrap();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|duration| duration.as_secs())
.unwrap_or(0);
let code = validator.code_for_timestamp(now);
let mut manager = EstopManager::load(&cfg, dir.path()).unwrap();
manager.engage(EstopLevel::KillAll).unwrap();
manager
.resume(ResumeSelector::KillAll, Some(&code), Some(&validator))
.unwrap();
assert!(!manager.status().kill_all);
}
}
+8
View File
@@ -23,10 +23,13 @@ pub mod audit;
pub mod bubblewrap;
pub mod detect;
pub mod docker;
pub mod domain_matcher;
pub mod estop;
#[cfg(target_os = "linux")]
pub mod firejail;
#[cfg(feature = "sandbox-landlock")]
pub mod landlock;
pub mod otp;
pub mod pairing;
pub mod policy;
pub mod secrets;
@@ -36,6 +39,11 @@ pub mod traits;
pub use audit::{AuditEvent, AuditEventType, AuditLogger};
#[allow(unused_imports)]
pub use detect::create_sandbox;
pub use domain_matcher::DomainMatcher;
#[allow(unused_imports)]
pub use estop::{EstopLevel, EstopManager, EstopState, ResumeSelector};
#[allow(unused_imports)]
pub use otp::OtpValidator;
#[allow(unused_imports)]
pub use pairing::PairingGuard;
pub use policy::{AutonomyLevel, SecurityPolicy};
+318
View File
@@ -0,0 +1,318 @@
use crate::config::OtpConfig;
use crate::security::secrets::SecretStore;
use anyhow::{Context, Result};
use parking_lot::Mutex;
use ring::hmac;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
const OTP_SECRET_FILE: &str = "otp-secret";
const OTP_DIGITS: u32 = 6;
const OTP_ISSUER: &str = "ZeroClaw";
#[derive(Debug)]
pub struct OtpValidator {
config: OtpConfig,
secret: Vec<u8>,
cached_codes: Mutex<HashMap<String, u64>>,
}
impl OtpValidator {
pub fn from_config(
config: &OtpConfig,
zeroclaw_dir: &Path,
store: &SecretStore,
) -> Result<(Self, Option<String>)> {
let secret_path = secret_file_path(zeroclaw_dir);
let (secret, generated) = if secret_path.exists() {
let encoded = fs::read_to_string(&secret_path).with_context(|| {
format!("Failed to read OTP secret file {}", secret_path.display())
})?;
let decrypted = store
.decrypt(encoded.trim())
.context("Failed to decrypt OTP secret file")?;
(decode_base32_secret(&decrypted)?, false)
} else {
let raw: [u8; 20] = rand::random();
let encoded_secret = encode_base32_secret(&raw);
let encrypted = store
.encrypt(&encoded_secret)
.context("Failed to encrypt OTP secret")?;
write_secret_file(&secret_path, &encrypted)?;
(raw.to_vec(), true)
};
let validator = Self {
config: config.clone(),
secret,
cached_codes: Mutex::new(HashMap::new()),
};
let uri = if generated {
Some(validator.otpauth_uri())
} else {
None
};
Ok((validator, uri))
}
pub fn validate(&self, code: &str) -> Result<bool> {
self.validate_at(code, unix_timestamp_now())
}
fn validate_at(&self, code: &str, now_secs: u64) -> Result<bool> {
let normalized = code.trim();
if normalized.len() != OTP_DIGITS as usize
|| !normalized.chars().all(|ch| ch.is_ascii_digit())
{
return Ok(false);
}
{
let mut cache = self.cached_codes.lock();
cache.retain(|_, expiry| *expiry >= now_secs);
if cache
.get(normalized)
.is_some_and(|expiry| *expiry >= now_secs)
{
return Ok(true);
}
}
let step = self.config.token_ttl_secs.max(1);
let counter = now_secs / step;
let counters = [
counter.saturating_sub(1),
counter,
counter.saturating_add(1),
];
let is_valid = counters
.iter()
.map(|c| compute_totp_code(&self.secret, *c))
.any(|candidate| candidate == normalized);
if is_valid {
let mut cache = self.cached_codes.lock();
cache.insert(
normalized.to_string(),
now_secs.saturating_add(self.config.cache_valid_secs),
);
}
Ok(is_valid)
}
pub fn otpauth_uri(&self) -> String {
let secret = encode_base32_secret(&self.secret);
let account = "zeroclaw";
format!(
"otpauth://totp/{issuer}:{account}?secret={secret}&issuer={issuer}&period={period}",
issuer = OTP_ISSUER,
period = self.config.token_ttl_secs.max(1)
)
}
#[cfg(test)]
pub(crate) fn code_for_timestamp(&self, timestamp: u64) -> String {
let counter = timestamp / self.config.token_ttl_secs.max(1);
compute_totp_code(&self.secret, counter)
}
}
pub fn secret_file_path(zeroclaw_dir: &Path) -> PathBuf {
zeroclaw_dir.join(OTP_SECRET_FILE)
}
fn write_secret_file(path: &Path, value: &str) -> Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)
.with_context(|| format!("Failed to create directory {}", parent.display()))?;
}
let temp_path = path.with_extension(format!("tmp-{}", uuid::Uuid::new_v4()));
fs::write(&temp_path, value).with_context(|| {
format!(
"Failed to write temporary OTP secret {}",
temp_path.display()
)
})?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = fs::set_permissions(&temp_path, fs::Permissions::from_mode(0o600));
}
fs::rename(&temp_path, path).with_context(|| {
format!(
"Failed to atomically replace OTP secret file {}",
path.display()
)
})?;
Ok(())
}
fn unix_timestamp_now() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|duration| duration.as_secs())
.unwrap_or(0)
}
fn compute_totp_code(secret: &[u8], counter: u64) -> String {
let key = hmac::Key::new(hmac::HMAC_SHA1_FOR_LEGACY_USE_ONLY, secret);
let counter_bytes = counter.to_be_bytes();
let digest = hmac::sign(&key, &counter_bytes);
let hash = digest.as_ref();
let offset = (hash[19] & 0x0f) as usize;
let binary = ((u32::from(hash[offset]) & 0x7f) << 24)
| (u32::from(hash[offset + 1]) << 16)
| (u32::from(hash[offset + 2]) << 8)
| u32::from(hash[offset + 3]);
let code = binary % 10_u32.pow(OTP_DIGITS);
format!("{code:0>6}")
}
fn encode_base32_secret(input: &[u8]) -> String {
const ALPHABET: &[u8; 32] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
if input.is_empty() {
return String::new();
}
let mut result = String::new();
let mut buffer = 0u16;
let mut bits_left = 0u8;
for byte in input {
buffer = (buffer << 8) | u16::from(*byte);
bits_left += 8;
while bits_left >= 5 {
let index = ((buffer >> (bits_left - 5)) & 0x1f) as usize;
result.push(ALPHABET[index] as char);
bits_left -= 5;
}
}
if bits_left > 0 {
let index = ((buffer << (5 - bits_left)) & 0x1f) as usize;
result.push(ALPHABET[index] as char);
}
result
}
fn decode_base32_secret(raw: &str) -> Result<Vec<u8>> {
fn decode_char(ch: char) -> Option<u8> {
match ch {
'A'..='Z' => Some((ch as u8) - b'A'),
'2'..='7' => Some((ch as u8) - b'2' + 26),
_ => None,
}
}
let mut cleaned = raw
.chars()
.filter(|ch| !matches!(ch, ' ' | '\t' | '\n' | '\r' | '-'))
.collect::<String>()
.to_ascii_uppercase();
while cleaned.ends_with('=') {
cleaned.pop();
}
if cleaned.is_empty() {
anyhow::bail!("OTP secret is empty");
}
let mut output = Vec::new();
let mut buffer = 0u32;
let mut bits_left = 0u8;
for ch in cleaned.chars() {
let value = decode_char(ch)
.with_context(|| format!("OTP secret contains invalid base32 character '{ch}'"))?;
buffer = (buffer << 5) | u32::from(value);
bits_left += 5;
if bits_left >= 8 {
let byte = ((buffer >> (bits_left - 8)) & 0xff) as u8;
output.push(byte);
bits_left -= 8;
}
}
if output.is_empty() {
anyhow::bail!("OTP secret did not decode to any bytes");
}
Ok(output)
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
fn test_config() -> OtpConfig {
OtpConfig {
enabled: true,
token_ttl_secs: 30,
cache_valid_secs: 120,
..OtpConfig::default()
}
}
#[test]
fn valid_totp_code_is_accepted() {
let dir = tempdir().unwrap();
let store = SecretStore::new(dir.path(), true);
let (validator, _) = OtpValidator::from_config(&test_config(), dir.path(), &store).unwrap();
let now = 1_700_000_000u64;
let code = validator.code_for_timestamp(now);
assert!(validator.validate_at(&code, now).unwrap());
}
#[test]
fn expired_totp_code_is_rejected() {
let dir = tempdir().unwrap();
let store = SecretStore::new(dir.path(), true);
let (validator, _) = OtpValidator::from_config(&test_config(), dir.path(), &store).unwrap();
let stale = 1_700_000_000u64;
let now = stale + 300;
let code = validator.code_for_timestamp(stale);
assert!(!validator.validate_at(&code, now).unwrap());
}
#[test]
fn wrong_totp_code_is_rejected() {
let dir = tempdir().unwrap();
let store = SecretStore::new(dir.path(), true);
let (validator, _) = OtpValidator::from_config(&test_config(), dir.path(), &store).unwrap();
assert!(!validator.validate_at("123456", 1_700_000_000).unwrap());
}
#[test]
fn secret_is_generated_and_reused() {
let dir = tempdir().unwrap();
let store = SecretStore::new(dir.path(), true);
let (first, first_uri) =
OtpValidator::from_config(&test_config(), dir.path(), &store).unwrap();
assert!(first_uri.is_some());
let secret_path = secret_file_path(dir.path());
let stored = fs::read_to_string(&secret_path).unwrap();
assert!(SecretStore::is_encrypted(stored.trim()));
let (second, second_uri) =
OtpValidator::from_config(&test_config(), dir.path(), &store).unwrap();
assert!(second_uri.is_none());
let ts = 1_700_000_000u64;
assert_eq!(first.code_for_timestamp(ts), second.code_for_timestamp(ts));
}
}