diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 25e7da738..98965046a 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -32,7 +32,8 @@ use anyhow::{Context, Result}; use axum::{ body::{Body, Bytes}, extract::{ConnectInfo, Query, State}, - http::{header, HeaderMap, StatusCode}, + http::{header, HeaderMap, HeaderValue, StatusCode}, + middleware::{self, Next}, response::{IntoResponse, Json, Response}, routing::{delete, get, post, put}, Router, @@ -58,6 +59,24 @@ pub const RATE_LIMIT_MAX_KEYS_DEFAULT: usize = 10_000; /// Fallback max distinct idempotency keys retained in gateway memory. pub const IDEMPOTENCY_MAX_KEYS_DEFAULT: usize = 10_000; +/// Middleware that injects security headers on every HTTP response. +async fn security_headers_middleware(req: axum::extract::Request, next: Next) -> Response { + let mut response = next.run(req).await; + let headers = response.headers_mut(); + headers.insert( + header::X_CONTENT_TYPE_OPTIONS, + HeaderValue::from_static("nosniff"), + ); + headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("DENY")); + headers.insert(header::CACHE_CONTROL, HeaderValue::from_static("no-store")); + headers.insert(header::X_XSS_PROTECTION, HeaderValue::from_static("0")); + headers.insert( + header::REFERRER_POLICY, + HeaderValue::from_static("strict-origin-when-cross-origin"), + ); + response +} + fn webhook_memory_key() -> String { format!("webhook_msg_{}", Uuid::new_v4()) } @@ -877,6 +896,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { .merge(config_put_router) .with_state(state) .layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE)) + .layer(middleware::from_fn(security_headers_middleware)) .layer(TimeoutLayer::with_status_code( StatusCode::REQUEST_TIMEOUT, Duration::from_secs(REQUEST_TIMEOUT_SECS), @@ -5710,4 +5730,46 @@ Reminder set successfully."#; // Should be allowed again assert!(limiter.allow("burst-ip")); } + + #[tokio::test] + async fn security_headers_are_set_on_responses() { + use axum::body::Body; + use axum::http::Request; + use tower::ServiceExt; + + let app = + Router::new() + .route("/test", get(|| async { "ok" })) + .layer(axum::middleware::from_fn( + super::security_headers_middleware, + )); + + let req = Request::builder().uri("/test").body(Body::empty()).unwrap(); + + let response = app.oneshot(req).await.unwrap(); + + assert_eq!( + response + .headers() + .get(header::X_CONTENT_TYPE_OPTIONS) + .unwrap(), + "nosniff" + ); + assert_eq!( + response.headers().get(header::X_FRAME_OPTIONS).unwrap(), + "DENY" + ); + assert_eq!( + response.headers().get(header::CACHE_CONTROL).unwrap(), + "no-store" + ); + assert_eq!( + response.headers().get(header::X_XSS_PROTECTION).unwrap(), + "0" + ); + assert_eq!( + response.headers().get(header::REFERRER_POLICY).unwrap(), + "strict-origin-when-cross-origin" + ); + } }