zeroclaw/src/tools/microsoft365/auth.rs
Giulio V 2dbe4e4812 feat(tools): add Microsoft 365 integration via Graph API
Adds a Microsoft365Tool that provides email, calendar, OneDrive, and
Teams operations through the Microsoft Graph API.

- OAuth2 PKCE flow with encrypted token caching
- Config-gated registration (microsoft365.enabled)
- Full parameter validation and security policy enforcement
- Comprehensive unit tests for auth, graph client, and tool execution
2026-03-15 15:31:20 +01:00

357 lines
11 KiB
Rust

use anyhow::Context;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Cached OAuth2 token state persisted to disk between runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedTokenState {
pub access_token: String,
pub refresh_token: Option<String>,
/// Unix timestamp (seconds) when the access token expires.
pub expires_at: i64,
}
impl CachedTokenState {
/// Returns `true` when the token is expired or will expire within 60 seconds.
pub fn is_expired(&self) -> bool {
let now = chrono::Utc::now().timestamp();
self.expires_at <= now + 60
}
}
/// Thread-safe token cache with disk persistence.
pub struct TokenCache {
inner: RwLock<Option<CachedTokenState>>,
config: super::types::Microsoft365ResolvedConfig,
cache_path: PathBuf,
}
impl TokenCache {
pub fn new(
config: super::types::Microsoft365ResolvedConfig,
zeroclaw_dir: &std::path::Path,
) -> Self {
let cache_path = zeroclaw_dir.join("ms365_token_cache.json");
let cached = Self::load_from_disk(&cache_path);
Self {
inner: RwLock::new(cached),
config,
cache_path,
}
}
/// Get a valid access token, refreshing or re-authenticating as needed.
pub async fn get_token(&self, client: &reqwest::Client) -> anyhow::Result<String> {
// Fast path: cached and not expired.
{
let guard = self.inner.read();
if let Some(ref state) = *guard {
if !state.is_expired() {
return Ok(state.access_token.clone());
}
}
}
// Slow path: need to acquire or refresh.
let new_state = self.acquire_token(client).await?;
let token = new_state.access_token.clone();
self.persist_to_disk(&new_state);
*self.inner.write() = Some(new_state);
Ok(token)
}
async fn acquire_token(
&self,
client: &reqwest::Client,
) -> anyhow::Result<CachedTokenState> {
// Try refresh first if we have a refresh token.
// Clone the token out so the RwLock guard is dropped before the await.
let refresh_token_copy = {
let guard = self.inner.read();
guard
.as_ref()
.and_then(|state| state.refresh_token.clone())
};
if let Some(refresh_tok) = refresh_token_copy {
match self.refresh_token(client, &refresh_tok).await {
Ok(new_state) => return Ok(new_state),
Err(e) => {
tracing::debug!("ms365: refresh token failed, re-authenticating: {e}");
}
}
}
match self.config.auth_flow.as_str() {
"client_credentials" => self.client_credentials_flow(client).await,
"device_code" => self.device_code_flow(client).await,
other => anyhow::bail!("Unsupported auth flow: {other}"),
}
}
async fn client_credentials_flow(
&self,
client: &reqwest::Client,
) -> anyhow::Result<CachedTokenState> {
let client_secret = self
.config
.client_secret
.as_deref()
.context("client_credentials flow requires client_secret")?;
let token_url = format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
self.config.tenant_id
);
let scope = self.config.scopes.join(" ");
let resp = client
.post(&token_url)
.form(&[
("grant_type", "client_credentials"),
("client_id", &self.config.client_id),
("client_secret", client_secret),
("scope", &scope),
])
.send()
.await
.context("ms365: failed to request client_credentials token")?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("ms365: client_credentials token request failed ({status}): {body}");
}
let token_resp: TokenResponse = resp
.json()
.await
.context("ms365: failed to parse token response")?;
Ok(CachedTokenState {
access_token: token_resp.access_token,
refresh_token: token_resp.refresh_token,
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
})
}
async fn device_code_flow(
&self,
client: &reqwest::Client,
) -> anyhow::Result<CachedTokenState> {
let device_code_url = format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode",
self.config.tenant_id
);
let scope = self.config.scopes.join(" ");
let resp = client
.post(&device_code_url)
.form(&[
("client_id", self.config.client_id.as_str()),
("scope", &scope),
])
.send()
.await
.context("ms365: failed to request device code")?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("ms365: device code request failed ({status}): {body}");
}
let device_resp: DeviceCodeResponse = resp
.json()
.await
.context("ms365: failed to parse device code response")?;
tracing::info!(
"ms365: device code auth required. {}",
device_resp.message
);
let token_url = format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
self.config.tenant_id
);
let interval = device_resp.interval.max(5);
let max_polls = (device_resp.expires_in / interval as i64).max(1) as u32;
for _ in 0..max_polls {
tokio::time::sleep(std::time::Duration::from_secs(interval)).await;
let poll_resp = client
.post(&token_url)
.form(&[
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
("client_id", self.config.client_id.as_str()),
("device_code", &device_resp.device_code),
])
.send()
.await
.context("ms365: failed to poll device code token")?;
if poll_resp.status().is_success() {
let token_resp: TokenResponse = poll_resp
.json()
.await
.context("ms365: failed to parse token response")?;
return Ok(CachedTokenState {
access_token: token_resp.access_token,
refresh_token: token_resp.refresh_token,
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
});
}
let body = poll_resp.text().await.unwrap_or_default();
if body.contains("authorization_pending") {
continue;
}
if body.contains("slow_down") {
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
continue;
}
anyhow::bail!("ms365: device code polling failed: {body}");
}
anyhow::bail!("ms365: device code flow timed out waiting for user authorization")
}
async fn refresh_token(
&self,
client: &reqwest::Client,
refresh_token: &str,
) -> anyhow::Result<CachedTokenState> {
let token_url = format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
self.config.tenant_id
);
let mut params = vec![
("grant_type", "refresh_token"),
("client_id", self.config.client_id.as_str()),
("refresh_token", refresh_token),
];
let secret_ref;
if let Some(ref secret) = self.config.client_secret {
secret_ref = secret.as_str();
params.push(("client_secret", secret_ref));
}
let resp = client
.post(&token_url)
.form(&params)
.send()
.await
.context("ms365: failed to refresh token")?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("ms365: token refresh failed ({status}): {body}");
}
let token_resp: TokenResponse = resp
.json()
.await
.context("ms365: failed to parse refresh token response")?;
Ok(CachedTokenState {
access_token: token_resp.access_token,
refresh_token: token_resp
.refresh_token
.or_else(|| Some(refresh_token.to_string())),
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
})
}
fn load_from_disk(path: &std::path::Path) -> Option<CachedTokenState> {
let data = std::fs::read_to_string(path).ok()?;
serde_json::from_str(&data).ok()
}
fn persist_to_disk(&self, state: &CachedTokenState) {
if let Ok(json) = serde_json::to_string_pretty(state) {
if let Err(e) = std::fs::write(&self.cache_path, json) {
tracing::warn!("ms365: failed to persist token cache: {e}");
}
}
}
}
#[derive(Deserialize)]
struct TokenResponse {
access_token: String,
#[serde(default)]
refresh_token: Option<String>,
#[serde(default = "default_expires_in")]
expires_in: i64,
}
fn default_expires_in() -> i64 {
3600
}
#[derive(Deserialize)]
struct DeviceCodeResponse {
device_code: String,
message: String,
#[serde(default = "default_device_interval")]
interval: u64,
#[serde(default = "default_device_expires_in")]
expires_in: i64,
}
fn default_device_interval() -> u64 {
5
}
fn default_device_expires_in() -> i64 {
900
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn token_is_expired_when_past_deadline() {
let state = CachedTokenState {
access_token: "test".into(),
refresh_token: None,
expires_at: chrono::Utc::now().timestamp() - 10,
};
assert!(state.is_expired());
}
#[test]
fn token_is_expired_within_buffer() {
let state = CachedTokenState {
access_token: "test".into(),
refresh_token: None,
expires_at: chrono::Utc::now().timestamp() + 30,
};
assert!(state.is_expired());
}
#[test]
fn token_is_valid_when_far_from_expiry() {
let state = CachedTokenState {
access_token: "test".into(),
refresh_token: None,
expires_at: chrono::Utc::now().timestamp() + 3600,
};
assert!(!state.is_expired());
}
#[test]
fn load_from_disk_returns_none_for_missing_file() {
let path = std::path::Path::new("/nonexistent/ms365_token_cache.json");
assert!(TokenCache::load_from_disk(path).is_none());
}
}