diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 98965046a..d76a9e9c0 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -68,7 +68,10 @@ async fn security_headers_middleware(req: axum::extract::Request, next: Next) -> HeaderValue::from_static("nosniff"), ); headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("DENY")); - headers.insert(header::CACHE_CONTROL, HeaderValue::from_static("no-store")); + // Only set Cache-Control if not already set by handler (e.g., SSE uses no-cache) + headers + .entry(header::CACHE_CONTROL) + .or_insert(HeaderValue::from_static("no-store")); headers.insert(header::X_XSS_PROTECTION, HeaderValue::from_static("0")); headers.insert( header::REFERRER_POLICY, @@ -5772,4 +5775,39 @@ Reminder set successfully."#; "strict-origin-when-cross-origin" ); } + + #[tokio::test] + async fn security_headers_are_set_on_error_responses() { + use axum::body::Body; + use axum::http::{Request, StatusCode}; + use tower::ServiceExt; + + let app = Router::new() + .route( + "/error", + get(|| async { StatusCode::INTERNAL_SERVER_ERROR }), + ) + .layer(axum::middleware::from_fn( + super::security_headers_middleware, + )); + + let req = Request::builder() + .uri("/error") + .body(Body::empty()) + .unwrap(); + + let response = app.oneshot(req).await.unwrap(); + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + response + .headers() + .get(header::X_CONTENT_TYPE_OPTIONS) + .unwrap(), + "nosniff" + ); + assert_eq!( + response.headers().get(header::X_FRAME_OPTIONS).unwrap(), + "DENY" + ); + } }