@@ -269,6 +269,7 @@ class AwsConnectionConfig(ConfigModel):
269269 """
270270
271271 _credentials_expiration : Optional [datetime ] = None
272+ _cached_credentials : Optional [dict ] = None
272273
273274 aws_access_key_id : Optional [str ] = Field (
274275 default = None ,
@@ -353,10 +354,25 @@ def get_session(self) -> Session:
353354 )
354355 else :
355356 # Use boto3's credential autodetection
356- session = Session (region_name = self .aws_region )
357-
358357 target_roles = self ._normalized_aws_roles ()
359358 if target_roles :
359+ # If we have cached credentials that are still valid, use them
360+ if (
361+ self ._cached_credentials is not None
362+ and not self ._should_refresh_credentials ()
363+ ):
364+ logger .debug ("Using cached assumed role credentials" )
365+ return Session (
366+ aws_access_key_id = self ._cached_credentials ["AccessKeyId" ],
367+ aws_secret_access_key = self ._cached_credentials [
368+ "SecretAccessKey"
369+ ],
370+ aws_session_token = self ._cached_credentials ["SessionToken" ],
371+ region_name = self .aws_region ,
372+ )
373+
374+ # Need to assume role (either first time or credentials expired)
375+ session = Session (region_name = self .aws_region )
360376 current_role_arn , credential_source = get_current_identity ()
361377
362378 # Only assume role if:
@@ -368,7 +384,10 @@ def get_session(self) -> Session:
368384
369385 if should_assume_role :
370386 env = detect_aws_environment ()
371- logger .debug (f"Assuming role(s) from { env .value } environment" )
387+ role_arns = [role .RoleArn for role in target_roles ]
388+ logger .debug (
389+ f"Assuming { role_arns } role(s) from { env .value } environment"
390+ )
372391
373392 current_credentials = session .get_credentials ()
374393 if current_credentials is None :
@@ -381,14 +400,15 @@ def get_session(self) -> Session:
381400 }
382401
383402 for role in target_roles :
384- if self ._should_refresh_credentials ():
385- credentials = assume_role (
386- role = role ,
387- aws_region = self .aws_region ,
388- credentials = credentials ,
389- )
390- if isinstance (credentials ["Expiration" ], datetime ):
391- self ._credentials_expiration = credentials ["Expiration" ]
403+ credentials = assume_role (
404+ role = role ,
405+ aws_region = self .aws_region ,
406+ credentials = credentials ,
407+ )
408+
409+ if isinstance (credentials ["Expiration" ], datetime ):
410+ self ._credentials_expiration = credentials ["Expiration" ]
411+ self ._cached_credentials = credentials
392412
393413 session = Session (
394414 aws_access_key_id = credentials ["AccessKeyId" ],
@@ -398,12 +418,17 @@ def get_session(self) -> Session:
398418 )
399419 else :
400420 logger .debug (f"Using existing role from { credential_source } " )
421+ else :
422+ session = Session (region_name = self .aws_region )
401423
402424 return session
403425
404426 def _should_refresh_credentials (self ) -> bool :
405427 if self ._credentials_expiration is None :
406428 return True
429+ # Refresh credentials when less than 5 minutes remain before expiration.
430+ # This buffer helps avoid race conditions where credentials expire between
431+ # the cache check and actual AWS API usage.
407432 remaining_time = self ._credentials_expiration - datetime .now (timezone .utc )
408433 return remaining_time < timedelta (minutes = 5 )
409434
0 commit comments