Skip to content

Commit 3583b0e

Browse files
authored
fix(ingest/aws): Fix for the aws credential caching (#15278)
1 parent ad46463 commit 3583b0e

File tree

2 files changed

+351
-11
lines changed

2 files changed

+351
-11
lines changed

metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ class AwsConnectionConfig(ConfigModel):
269269
"""
270270

271271
_credentials_expiration: Optional[datetime] = None
272+
_cached_credentials: Optional[dict] = None
272273

273274
aws_access_key_id: Optional[str] = Field(
274275
default=None,
@@ -353,10 +354,25 @@ def get_session(self) -> Session:
353354
)
354355
else:
355356
# Use boto3's credential autodetection
356-
session = Session(region_name=self.aws_region)
357-
358357
target_roles = self._normalized_aws_roles()
359358
if target_roles:
359+
# If we have cached credentials that are still valid, use them
360+
if (
361+
self._cached_credentials is not None
362+
and not self._should_refresh_credentials()
363+
):
364+
logger.debug("Using cached assumed role credentials")
365+
return Session(
366+
aws_access_key_id=self._cached_credentials["AccessKeyId"],
367+
aws_secret_access_key=self._cached_credentials[
368+
"SecretAccessKey"
369+
],
370+
aws_session_token=self._cached_credentials["SessionToken"],
371+
region_name=self.aws_region,
372+
)
373+
374+
# Need to assume role (either first time or credentials expired)
375+
session = Session(region_name=self.aws_region)
360376
current_role_arn, credential_source = get_current_identity()
361377

362378
# Only assume role if:
@@ -368,7 +384,10 @@ def get_session(self) -> Session:
368384

369385
if should_assume_role:
370386
env = detect_aws_environment()
371-
logger.debug(f"Assuming role(s) from {env.value} environment")
387+
role_arns = [role.RoleArn for role in target_roles]
388+
logger.debug(
389+
f"Assuming {role_arns} role(s) from {env.value} environment"
390+
)
372391

373392
current_credentials = session.get_credentials()
374393
if current_credentials is None:
@@ -381,14 +400,15 @@ def get_session(self) -> Session:
381400
}
382401

383402
for role in target_roles:
384-
if self._should_refresh_credentials():
385-
credentials = assume_role(
386-
role=role,
387-
aws_region=self.aws_region,
388-
credentials=credentials,
389-
)
390-
if isinstance(credentials["Expiration"], datetime):
391-
self._credentials_expiration = credentials["Expiration"]
403+
credentials = assume_role(
404+
role=role,
405+
aws_region=self.aws_region,
406+
credentials=credentials,
407+
)
408+
409+
if isinstance(credentials["Expiration"], datetime):
410+
self._credentials_expiration = credentials["Expiration"]
411+
self._cached_credentials = credentials
392412

393413
session = Session(
394414
aws_access_key_id=credentials["AccessKeyId"],
@@ -398,12 +418,17 @@ def get_session(self) -> Session:
398418
)
399419
else:
400420
logger.debug(f"Using existing role from {credential_source}")
421+
else:
422+
session = Session(region_name=self.aws_region)
401423

402424
return session
403425

404426
def _should_refresh_credentials(self) -> bool:
405427
if self._credentials_expiration is None:
406428
return True
429+
# Refresh credentials when less than 5 minutes remain before expiration.
430+
# This buffer helps avoid race conditions where credentials expire between
431+
# the cache check and actual AWS API usage.
407432
remaining_time = self._credentials_expiration - datetime.now(timezone.utc)
408433
return remaining_time < timedelta(minutes=5)
409434

0 commit comments

Comments
 (0)