use anyhow::Context; use parking_lot::RwLock; use serde::{Deserialize, Serialize}; use std::path::PathBuf; /// Cached OAuth2 token state persisted to disk between runs. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CachedTokenState { pub access_token: String, pub refresh_token: Option, /// Unix timestamp (seconds) when the access token expires. pub expires_at: i64, } impl CachedTokenState { /// Returns `true` when the token is expired or will expire within 60 seconds. pub fn is_expired(&self) -> bool { let now = chrono::Utc::now().timestamp(); self.expires_at <= now + 60 } } /// Thread-safe token cache with disk persistence. pub struct TokenCache { inner: RwLock>, config: super::types::Microsoft365ResolvedConfig, cache_path: PathBuf, } impl TokenCache { pub fn new( config: super::types::Microsoft365ResolvedConfig, zeroclaw_dir: &std::path::Path, ) -> Self { let cache_path = zeroclaw_dir.join("ms365_token_cache.json"); let cached = Self::load_from_disk(&cache_path); Self { inner: RwLock::new(cached), config, cache_path, } } /// Get a valid access token, refreshing or re-authenticating as needed. pub async fn get_token(&self, client: &reqwest::Client) -> anyhow::Result { // Fast path: cached and not expired. { let guard = self.inner.read(); if let Some(ref state) = *guard { if !state.is_expired() { return Ok(state.access_token.clone()); } } } // Slow path: need to acquire or refresh. let new_state = self.acquire_token(client).await?; let token = new_state.access_token.clone(); self.persist_to_disk(&new_state); *self.inner.write() = Some(new_state); Ok(token) } async fn acquire_token( &self, client: &reqwest::Client, ) -> anyhow::Result { // Try refresh first if we have a refresh token. // Clone the token out so the RwLock guard is dropped before the await. let refresh_token_copy = { let guard = self.inner.read(); guard .as_ref() .and_then(|state| state.refresh_token.clone()) }; if let Some(refresh_tok) = refresh_token_copy { match self.refresh_token(client, &refresh_tok).await { Ok(new_state) => return Ok(new_state), Err(e) => { tracing::debug!("ms365: refresh token failed, re-authenticating: {e}"); } } } match self.config.auth_flow.as_str() { "client_credentials" => self.client_credentials_flow(client).await, "device_code" => self.device_code_flow(client).await, other => anyhow::bail!("Unsupported auth flow: {other}"), } } async fn client_credentials_flow( &self, client: &reqwest::Client, ) -> anyhow::Result { let client_secret = self .config .client_secret .as_deref() .context("client_credentials flow requires client_secret")?; let token_url = format!( "https://login.microsoftonline.com/{}/oauth2/v2.0/token", self.config.tenant_id ); let scope = self.config.scopes.join(" "); let resp = client .post(&token_url) .form(&[ ("grant_type", "client_credentials"), ("client_id", &self.config.client_id), ("client_secret", client_secret), ("scope", &scope), ]) .send() .await .context("ms365: failed to request client_credentials token")?; if !resp.status().is_success() { let status = resp.status(); let body = resp.text().await.unwrap_or_default(); anyhow::bail!("ms365: client_credentials token request failed ({status}): {body}"); } let token_resp: TokenResponse = resp .json() .await .context("ms365: failed to parse token response")?; Ok(CachedTokenState { access_token: token_resp.access_token, refresh_token: token_resp.refresh_token, expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in, }) } async fn device_code_flow( &self, client: &reqwest::Client, ) -> anyhow::Result { let device_code_url = format!( "https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode", self.config.tenant_id ); let scope = self.config.scopes.join(" "); let resp = client .post(&device_code_url) .form(&[ ("client_id", self.config.client_id.as_str()), ("scope", &scope), ]) .send() .await .context("ms365: failed to request device code")?; if !resp.status().is_success() { let status = resp.status(); let body = resp.text().await.unwrap_or_default(); anyhow::bail!("ms365: device code request failed ({status}): {body}"); } let device_resp: DeviceCodeResponse = resp .json() .await .context("ms365: failed to parse device code response")?; tracing::info!( "ms365: device code auth required. {}", device_resp.message ); let token_url = format!( "https://login.microsoftonline.com/{}/oauth2/v2.0/token", self.config.tenant_id ); let interval = device_resp.interval.max(5); let max_polls = (device_resp.expires_in / interval as i64).max(1) as u32; for _ in 0..max_polls { tokio::time::sleep(std::time::Duration::from_secs(interval)).await; let poll_resp = client .post(&token_url) .form(&[ ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), ("client_id", self.config.client_id.as_str()), ("device_code", &device_resp.device_code), ]) .send() .await .context("ms365: failed to poll device code token")?; if poll_resp.status().is_success() { let token_resp: TokenResponse = poll_resp .json() .await .context("ms365: failed to parse token response")?; return Ok(CachedTokenState { access_token: token_resp.access_token, refresh_token: token_resp.refresh_token, expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in, }); } let body = poll_resp.text().await.unwrap_or_default(); if body.contains("authorization_pending") { continue; } if body.contains("slow_down") { tokio::time::sleep(std::time::Duration::from_secs(5)).await; continue; } anyhow::bail!("ms365: device code polling failed: {body}"); } anyhow::bail!("ms365: device code flow timed out waiting for user authorization") } async fn refresh_token( &self, client: &reqwest::Client, refresh_token: &str, ) -> anyhow::Result { let token_url = format!( "https://login.microsoftonline.com/{}/oauth2/v2.0/token", self.config.tenant_id ); let mut params = vec![ ("grant_type", "refresh_token"), ("client_id", self.config.client_id.as_str()), ("refresh_token", refresh_token), ]; let secret_ref; if let Some(ref secret) = self.config.client_secret { secret_ref = secret.as_str(); params.push(("client_secret", secret_ref)); } let resp = client .post(&token_url) .form(¶ms) .send() .await .context("ms365: failed to refresh token")?; if !resp.status().is_success() { let status = resp.status(); let body = resp.text().await.unwrap_or_default(); anyhow::bail!("ms365: token refresh failed ({status}): {body}"); } let token_resp: TokenResponse = resp .json() .await .context("ms365: failed to parse refresh token response")?; Ok(CachedTokenState { access_token: token_resp.access_token, refresh_token: token_resp .refresh_token .or_else(|| Some(refresh_token.to_string())), expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in, }) } fn load_from_disk(path: &std::path::Path) -> Option { let data = std::fs::read_to_string(path).ok()?; serde_json::from_str(&data).ok() } fn persist_to_disk(&self, state: &CachedTokenState) { if let Ok(json) = serde_json::to_string_pretty(state) { if let Err(e) = std::fs::write(&self.cache_path, json) { tracing::warn!("ms365: failed to persist token cache: {e}"); } } } } #[derive(Deserialize)] struct TokenResponse { access_token: String, #[serde(default)] refresh_token: Option, #[serde(default = "default_expires_in")] expires_in: i64, } fn default_expires_in() -> i64 { 3600 } #[derive(Deserialize)] struct DeviceCodeResponse { device_code: String, message: String, #[serde(default = "default_device_interval")] interval: u64, #[serde(default = "default_device_expires_in")] expires_in: i64, } fn default_device_interval() -> u64 { 5 } fn default_device_expires_in() -> i64 { 900 } #[cfg(test)] mod tests { use super::*; #[test] fn token_is_expired_when_past_deadline() { let state = CachedTokenState { access_token: "test".into(), refresh_token: None, expires_at: chrono::Utc::now().timestamp() - 10, }; assert!(state.is_expired()); } #[test] fn token_is_expired_within_buffer() { let state = CachedTokenState { access_token: "test".into(), refresh_token: None, expires_at: chrono::Utc::now().timestamp() + 30, }; assert!(state.is_expired()); } #[test] fn token_is_valid_when_far_from_expiry() { let state = CachedTokenState { access_token: "test".into(), refresh_token: None, expires_at: chrono::Utc::now().timestamp() + 3600, }; assert!(!state.is_expired()); } #[test] fn load_from_disk_returns_none_for_missing_file() { let path = std::path::Path::new("/nonexistent/ms365_token_cache.json"); assert!(TokenCache::load_from_disk(path).is_none()); } }