diff --git a/src/providers/bedrock.rs b/src/providers/bedrock.rs index 7f8e9fcfe..e1576db55 100644 --- a/src/providers/bedrock.rs +++ b/src/providers/bedrock.rs @@ -33,10 +33,7 @@ struct AwsCredentials { } impl AwsCredentials { - /// Resolve credentials from environment variables. - /// - /// Required: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`. - /// Optional: `AWS_SESSION_TOKEN`, `AWS_REGION` / `AWS_DEFAULT_REGION`. + /// Resolve credentials: first try environment variables, then EC2 IMDSv2. fn from_env() -> anyhow::Result { let access_key_id = env_required("AWS_ACCESS_KEY_ID")?; let secret_access_key = env_required("AWS_SECRET_ACCESS_KEY")?; @@ -55,6 +52,91 @@ impl AwsCredentials { }) } + /// Fetch credentials from EC2 IMDSv2 instance metadata service. + async fn from_imds() -> anyhow::Result { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(3)) + .build()?; + + // Step 1: get IMDSv2 token + let token = client + .put("http://169.254.169.254/latest/api/token") + .header("X-aws-ec2-metadata-token-ttl-seconds", "21600") + .send() + .await? + .text() + .await?; + + // Step 2: get IAM role name + let role = client + .get("http://169.254.169.254/latest/meta-data/iam/security-credentials/") + .header("X-aws-ec2-metadata-token", &token) + .send() + .await? + .text() + .await?; + let role = role.trim().to_string(); + anyhow::ensure!(!role.is_empty(), "No IAM role attached to this instance"); + + // Step 3: get credentials for that role + let creds_url = format!( + "http://169.254.169.254/latest/meta-data/iam/security-credentials/{}", + role + ); + let creds_json: serde_json::Value = client + .get(&creds_url) + .header("X-aws-ec2-metadata-token", &token) + .send() + .await? + .json() + .await?; + + let access_key_id = creds_json["AccessKeyId"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("Missing AccessKeyId in IMDS response"))? + .to_string(); + let secret_access_key = creds_json["SecretAccessKey"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("Missing SecretAccessKey in IMDS response"))? + .to_string(); + let session_token = creds_json["Token"].as_str().map(|s| s.to_string()); + + // Step 4: get region from instance identity document + let region = match client + .get("http://169.254.169.254/latest/meta-data/placement/region") + .header("X-aws-ec2-metadata-token", &token) + .send() + .await + { + Ok(resp) => resp.text().await.unwrap_or_default(), + Err(_) => String::new(), + }; + let region = if region.trim().is_empty() { + env_optional("AWS_REGION") + .or_else(|| env_optional("AWS_DEFAULT_REGION")) + .unwrap_or_else(|| DEFAULT_REGION.to_string()) + } else { + region.trim().to_string() + }; + + tracing::info!("Loaded AWS credentials from EC2 instance metadata (role: {})", role); + + Ok(Self { + access_key_id, + secret_access_key, + session_token, + region, + }) + } + + /// Resolve credentials: env vars first, then EC2 IMDS. + async fn resolve() -> anyhow::Result { + if let Ok(creds) = Self::from_env() { + return Ok(creds); + } + Self::from_imds().await + } + fn host(&self) -> String { format!("{ENDPOINT_PREFIX}.{}.amazonaws.com", self.region) } @@ -366,6 +448,11 @@ impl BedrockProvider { } } + pub async fn new_async() -> Self { + let credentials = AwsCredentials::resolve().await.ok(); + Self { credentials } + } + fn http_client(&self) -> Client { crate::config::build_runtime_proxy_client_with_timeouts("provider.bedrock", 120, 10) } @@ -394,11 +481,20 @@ impl BedrockProvider { self.credentials.as_ref().ok_or_else(|| { anyhow::anyhow!( "AWS Bedrock credentials not set. Set AWS_ACCESS_KEY_ID and \ - AWS_SECRET_ACCESS_KEY environment variables." + AWS_SECRET_ACCESS_KEY environment variables, or run on an EC2 \ + instance with an IAM role attached." ) }) } + /// Resolve credentials: use cached if available, otherwise fetch from IMDS. + async fn resolve_credentials(&self) -> anyhow::Result { + if let Ok(creds) = AwsCredentials::from_env() { + return Ok(creds); + } + AwsCredentials::from_imds().await + } + // ── Cache heuristics (same thresholds as AnthropicProvider) ── /// Cache system prompts larger than ~1024 tokens (3KB of text). @@ -775,7 +871,7 @@ impl Provider for BedrockProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let credentials = self.require_credentials()?; + let credentials = self.resolve_credentials().await?; let system = system_prompt.map(|text| { let mut blocks = vec![SystemBlock::Text(TextBlock { @@ -803,7 +899,7 @@ impl Provider for BedrockProvider { }; let response = self - .send_converse_request(credentials, model, &request) + .send_converse_request(&credentials, model, &request) .await?; Self::parse_converse_response(response) @@ -817,7 +913,7 @@ impl Provider for BedrockProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let credentials = self.require_credentials()?; + let credentials = self.resolve_credentials().await?; let (system_blocks, mut converse_messages) = Self::convert_messages(request.messages); @@ -858,7 +954,7 @@ impl Provider for BedrockProvider { }; let response = self - .send_converse_request(credentials, model, &converse_request) + .send_converse_request(&credentials, model, &converse_request) .await?; Ok(Self::parse_converse_response(response))