Adds a Microsoft365Tool that provides email, calendar, OneDrive, and Teams operations through the Microsoft Graph API. - OAuth2 PKCE flow with encrypted token caching - Config-gated registration (microsoft365.enabled) - Full parameter validation and security policy enforcement - Comprehensive unit tests for auth, graph client, and tool execution
357 lines
11 KiB
Rust
357 lines
11 KiB
Rust
use anyhow::Context;
|
|
use parking_lot::RwLock;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::path::PathBuf;
|
|
|
|
/// Cached OAuth2 token state persisted to disk between runs.
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct CachedTokenState {
|
|
pub access_token: String,
|
|
pub refresh_token: Option<String>,
|
|
/// Unix timestamp (seconds) when the access token expires.
|
|
pub expires_at: i64,
|
|
}
|
|
|
|
impl CachedTokenState {
|
|
/// Returns `true` when the token is expired or will expire within 60 seconds.
|
|
pub fn is_expired(&self) -> bool {
|
|
let now = chrono::Utc::now().timestamp();
|
|
self.expires_at <= now + 60
|
|
}
|
|
}
|
|
|
|
/// Thread-safe token cache with disk persistence.
|
|
pub struct TokenCache {
|
|
inner: RwLock<Option<CachedTokenState>>,
|
|
config: super::types::Microsoft365ResolvedConfig,
|
|
cache_path: PathBuf,
|
|
}
|
|
|
|
impl TokenCache {
|
|
pub fn new(
|
|
config: super::types::Microsoft365ResolvedConfig,
|
|
zeroclaw_dir: &std::path::Path,
|
|
) -> Self {
|
|
let cache_path = zeroclaw_dir.join("ms365_token_cache.json");
|
|
let cached = Self::load_from_disk(&cache_path);
|
|
Self {
|
|
inner: RwLock::new(cached),
|
|
config,
|
|
cache_path,
|
|
}
|
|
}
|
|
|
|
/// Get a valid access token, refreshing or re-authenticating as needed.
|
|
pub async fn get_token(&self, client: &reqwest::Client) -> anyhow::Result<String> {
|
|
// Fast path: cached and not expired.
|
|
{
|
|
let guard = self.inner.read();
|
|
if let Some(ref state) = *guard {
|
|
if !state.is_expired() {
|
|
return Ok(state.access_token.clone());
|
|
}
|
|
}
|
|
}
|
|
|
|
// Slow path: need to acquire or refresh.
|
|
let new_state = self.acquire_token(client).await?;
|
|
let token = new_state.access_token.clone();
|
|
self.persist_to_disk(&new_state);
|
|
*self.inner.write() = Some(new_state);
|
|
Ok(token)
|
|
}
|
|
|
|
async fn acquire_token(
|
|
&self,
|
|
client: &reqwest::Client,
|
|
) -> anyhow::Result<CachedTokenState> {
|
|
// Try refresh first if we have a refresh token.
|
|
// Clone the token out so the RwLock guard is dropped before the await.
|
|
let refresh_token_copy = {
|
|
let guard = self.inner.read();
|
|
guard
|
|
.as_ref()
|
|
.and_then(|state| state.refresh_token.clone())
|
|
};
|
|
if let Some(refresh_tok) = refresh_token_copy {
|
|
match self.refresh_token(client, &refresh_tok).await {
|
|
Ok(new_state) => return Ok(new_state),
|
|
Err(e) => {
|
|
tracing::debug!("ms365: refresh token failed, re-authenticating: {e}");
|
|
}
|
|
}
|
|
}
|
|
|
|
match self.config.auth_flow.as_str() {
|
|
"client_credentials" => self.client_credentials_flow(client).await,
|
|
"device_code" => self.device_code_flow(client).await,
|
|
other => anyhow::bail!("Unsupported auth flow: {other}"),
|
|
}
|
|
}
|
|
|
|
async fn client_credentials_flow(
|
|
&self,
|
|
client: &reqwest::Client,
|
|
) -> anyhow::Result<CachedTokenState> {
|
|
let client_secret = self
|
|
.config
|
|
.client_secret
|
|
.as_deref()
|
|
.context("client_credentials flow requires client_secret")?;
|
|
|
|
let token_url = format!(
|
|
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
|
|
self.config.tenant_id
|
|
);
|
|
|
|
let scope = self.config.scopes.join(" ");
|
|
|
|
let resp = client
|
|
.post(&token_url)
|
|
.form(&[
|
|
("grant_type", "client_credentials"),
|
|
("client_id", &self.config.client_id),
|
|
("client_secret", client_secret),
|
|
("scope", &scope),
|
|
])
|
|
.send()
|
|
.await
|
|
.context("ms365: failed to request client_credentials token")?;
|
|
|
|
if !resp.status().is_success() {
|
|
let status = resp.status();
|
|
let body = resp.text().await.unwrap_or_default();
|
|
anyhow::bail!("ms365: client_credentials token request failed ({status}): {body}");
|
|
}
|
|
|
|
let token_resp: TokenResponse = resp
|
|
.json()
|
|
.await
|
|
.context("ms365: failed to parse token response")?;
|
|
|
|
Ok(CachedTokenState {
|
|
access_token: token_resp.access_token,
|
|
refresh_token: token_resp.refresh_token,
|
|
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
|
|
})
|
|
}
|
|
|
|
async fn device_code_flow(
|
|
&self,
|
|
client: &reqwest::Client,
|
|
) -> anyhow::Result<CachedTokenState> {
|
|
let device_code_url = format!(
|
|
"https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode",
|
|
self.config.tenant_id
|
|
);
|
|
let scope = self.config.scopes.join(" ");
|
|
|
|
let resp = client
|
|
.post(&device_code_url)
|
|
.form(&[
|
|
("client_id", self.config.client_id.as_str()),
|
|
("scope", &scope),
|
|
])
|
|
.send()
|
|
.await
|
|
.context("ms365: failed to request device code")?;
|
|
|
|
if !resp.status().is_success() {
|
|
let status = resp.status();
|
|
let body = resp.text().await.unwrap_or_default();
|
|
anyhow::bail!("ms365: device code request failed ({status}): {body}");
|
|
}
|
|
|
|
let device_resp: DeviceCodeResponse = resp
|
|
.json()
|
|
.await
|
|
.context("ms365: failed to parse device code response")?;
|
|
|
|
tracing::info!(
|
|
"ms365: device code auth required. {}",
|
|
device_resp.message
|
|
);
|
|
|
|
let token_url = format!(
|
|
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
|
|
self.config.tenant_id
|
|
);
|
|
|
|
let interval = device_resp.interval.max(5);
|
|
let max_polls = (device_resp.expires_in / interval as i64).max(1) as u32;
|
|
|
|
for _ in 0..max_polls {
|
|
tokio::time::sleep(std::time::Duration::from_secs(interval)).await;
|
|
|
|
let poll_resp = client
|
|
.post(&token_url)
|
|
.form(&[
|
|
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
|
|
("client_id", self.config.client_id.as_str()),
|
|
("device_code", &device_resp.device_code),
|
|
])
|
|
.send()
|
|
.await
|
|
.context("ms365: failed to poll device code token")?;
|
|
|
|
if poll_resp.status().is_success() {
|
|
let token_resp: TokenResponse = poll_resp
|
|
.json()
|
|
.await
|
|
.context("ms365: failed to parse token response")?;
|
|
return Ok(CachedTokenState {
|
|
access_token: token_resp.access_token,
|
|
refresh_token: token_resp.refresh_token,
|
|
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
|
|
});
|
|
}
|
|
|
|
let body = poll_resp.text().await.unwrap_or_default();
|
|
if body.contains("authorization_pending") {
|
|
continue;
|
|
}
|
|
if body.contains("slow_down") {
|
|
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
|
continue;
|
|
}
|
|
anyhow::bail!("ms365: device code polling failed: {body}");
|
|
}
|
|
|
|
anyhow::bail!("ms365: device code flow timed out waiting for user authorization")
|
|
}
|
|
|
|
async fn refresh_token(
|
|
&self,
|
|
client: &reqwest::Client,
|
|
refresh_token: &str,
|
|
) -> anyhow::Result<CachedTokenState> {
|
|
let token_url = format!(
|
|
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
|
|
self.config.tenant_id
|
|
);
|
|
|
|
let mut params = vec![
|
|
("grant_type", "refresh_token"),
|
|
("client_id", self.config.client_id.as_str()),
|
|
("refresh_token", refresh_token),
|
|
];
|
|
|
|
let secret_ref;
|
|
if let Some(ref secret) = self.config.client_secret {
|
|
secret_ref = secret.as_str();
|
|
params.push(("client_secret", secret_ref));
|
|
}
|
|
|
|
let resp = client
|
|
.post(&token_url)
|
|
.form(¶ms)
|
|
.send()
|
|
.await
|
|
.context("ms365: failed to refresh token")?;
|
|
|
|
if !resp.status().is_success() {
|
|
let status = resp.status();
|
|
let body = resp.text().await.unwrap_or_default();
|
|
anyhow::bail!("ms365: token refresh failed ({status}): {body}");
|
|
}
|
|
|
|
let token_resp: TokenResponse = resp
|
|
.json()
|
|
.await
|
|
.context("ms365: failed to parse refresh token response")?;
|
|
|
|
Ok(CachedTokenState {
|
|
access_token: token_resp.access_token,
|
|
refresh_token: token_resp
|
|
.refresh_token
|
|
.or_else(|| Some(refresh_token.to_string())),
|
|
expires_at: chrono::Utc::now().timestamp() + token_resp.expires_in,
|
|
})
|
|
}
|
|
|
|
fn load_from_disk(path: &std::path::Path) -> Option<CachedTokenState> {
|
|
let data = std::fs::read_to_string(path).ok()?;
|
|
serde_json::from_str(&data).ok()
|
|
}
|
|
|
|
fn persist_to_disk(&self, state: &CachedTokenState) {
|
|
if let Ok(json) = serde_json::to_string_pretty(state) {
|
|
if let Err(e) = std::fs::write(&self.cache_path, json) {
|
|
tracing::warn!("ms365: failed to persist token cache: {e}");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct TokenResponse {
|
|
access_token: String,
|
|
#[serde(default)]
|
|
refresh_token: Option<String>,
|
|
#[serde(default = "default_expires_in")]
|
|
expires_in: i64,
|
|
}
|
|
|
|
fn default_expires_in() -> i64 {
|
|
3600
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct DeviceCodeResponse {
|
|
device_code: String,
|
|
message: String,
|
|
#[serde(default = "default_device_interval")]
|
|
interval: u64,
|
|
#[serde(default = "default_device_expires_in")]
|
|
expires_in: i64,
|
|
}
|
|
|
|
fn default_device_interval() -> u64 {
|
|
5
|
|
}
|
|
|
|
fn default_device_expires_in() -> i64 {
|
|
900
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn token_is_expired_when_past_deadline() {
|
|
let state = CachedTokenState {
|
|
access_token: "test".into(),
|
|
refresh_token: None,
|
|
expires_at: chrono::Utc::now().timestamp() - 10,
|
|
};
|
|
assert!(state.is_expired());
|
|
}
|
|
|
|
#[test]
|
|
fn token_is_expired_within_buffer() {
|
|
let state = CachedTokenState {
|
|
access_token: "test".into(),
|
|
refresh_token: None,
|
|
expires_at: chrono::Utc::now().timestamp() + 30,
|
|
};
|
|
assert!(state.is_expired());
|
|
}
|
|
|
|
#[test]
|
|
fn token_is_valid_when_far_from_expiry() {
|
|
let state = CachedTokenState {
|
|
access_token: "test".into(),
|
|
refresh_token: None,
|
|
expires_at: chrono::Utc::now().timestamp() + 3600,
|
|
};
|
|
assert!(!state.is_expired());
|
|
}
|
|
|
|
#[test]
|
|
fn load_from_disk_returns_none_for_missing_file() {
|
|
let path = std::path::Path::new("/nonexistent/ms365_token_cache.json");
|
|
assert!(TokenCache::load_from_disk(path).is_none());
|
|
}
|
|
}
|