diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py index c57fa7980de45..c66de89820676 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py @@ -269,6 +269,7 @@ class AwsConnectionConfig(ConfigModel): """ _credentials_expiration: Optional[datetime] = None + _cached_credentials: Optional[dict] = None aws_access_key_id: Optional[str] = Field( default=None, @@ -353,10 +354,25 @@ def get_session(self) -> Session: ) else: # Use boto3's credential autodetection - session = Session(region_name=self.aws_region) - target_roles = self._normalized_aws_roles() if target_roles: + # If we have cached credentials that are still valid, use them + if ( + self._cached_credentials is not None + and not self._should_refresh_credentials() + ): + logger.debug("Using cached assumed role credentials") + return Session( + aws_access_key_id=self._cached_credentials["AccessKeyId"], + aws_secret_access_key=self._cached_credentials[ + "SecretAccessKey" + ], + aws_session_token=self._cached_credentials["SessionToken"], + region_name=self.aws_region, + ) + + # Need to assume role (either first time or credentials expired) + session = Session(region_name=self.aws_region) current_role_arn, credential_source = get_current_identity() # Only assume role if: @@ -368,7 +384,10 @@ def get_session(self) -> Session: if should_assume_role: env = detect_aws_environment() - logger.debug(f"Assuming role(s) from {env.value} environment") + role_arns = [role.RoleArn for role in target_roles] + logger.debug( + f"Assuming {role_arns} role(s) from {env.value} environment" + ) current_credentials = session.get_credentials() if current_credentials is None: @@ -381,14 +400,15 @@ def get_session(self) -> Session: } for role in target_roles: - if self._should_refresh_credentials(): - credentials = assume_role( - role=role, - aws_region=self.aws_region, - credentials=credentials, - ) - if isinstance(credentials["Expiration"], datetime): - self._credentials_expiration = credentials["Expiration"] + credentials = assume_role( + role=role, + aws_region=self.aws_region, + credentials=credentials, + ) + + if isinstance(credentials["Expiration"], datetime): + self._credentials_expiration = credentials["Expiration"] + self._cached_credentials = credentials session = Session( aws_access_key_id=credentials["AccessKeyId"], @@ -398,12 +418,17 @@ def get_session(self) -> Session: ) else: logger.debug(f"Using existing role from {credential_source}") + else: + session = Session(region_name=self.aws_region) return session def _should_refresh_credentials(self) -> bool: if self._credentials_expiration is None: return True + # Refresh credentials when less than 5 minutes remain before expiration. + # This buffer helps avoid race conditions where credentials expire between + # the cache check and actual AWS API usage. remaining_time = self._credentials_expiration - datetime.now(timezone.utc) return remaining_time < timedelta(minutes=5) diff --git a/metadata-ingestion/tests/unit/test_aws_common.py b/metadata-ingestion/tests/unit/test_aws_common.py index ec7538f4433a4..c85fa24d403c7 100644 --- a/metadata-ingestion/tests/unit/test_aws_common.py +++ b/metadata-ingestion/tests/unit/test_aws_common.py @@ -340,3 +340,318 @@ def test_environment_detection_parametrized( """Parametrized test for environment detection with different configurations""" with patch.dict(os.environ, env_vars, clear=True): assert detect_aws_environment() == expected_environment + + @mock_sts + def test_role_assumption_credentials_cached_across_sessions(self): + """ + Test that assumed role credentials are cached and reused across multiple + get_session() calls within the same AwsConnectionConfig instance. + + This is the main test for the bug fix where the second/third get_session() + calls were using base credentials instead of assumed role credentials. + """ + config = AwsConnectionConfig( + aws_region="us-east-1", + aws_role="arn:aws:iam::339713033063:role/dh-executor-s3-access-role", + ) + + with ( + patch( + "datahub.ingestion.source.aws.aws_common.get_current_identity" + ) as mock_identity, + patch( + "datahub.ingestion.source.aws.aws_common.assume_role" + ) as mock_assume_role, + patch( + "datahub.ingestion.source.aws.aws_common.detect_aws_environment" + ) as mock_detect_env, + ): + # Setup mocks + mock_identity.return_value = ( + "arn:aws:sts::024848452848:assumed-role/dh-remote-executor/session", + "ecs.amazonaws.com", + ) + mock_detect_env.return_value = AwsEnvironment.ECS + + # Mock assume_role to return credentials with expiration + from datetime import datetime, timedelta, timezone + + expiration = datetime.now(timezone.utc) + timedelta(hours=1) + mock_assume_role.return_value = { + "AccessKeyId": "ASSUMED_KEY_ID", + "SecretAccessKey": "ASSUMED_SECRET_KEY", + "SessionToken": "ASSUMED_SESSION_TOKEN", + "Expiration": expiration, + } + + # First call - should assume role + session1 = config.get_session() + creds1 = session1.get_credentials() + assert creds1 is not None + assert creds1.access_key == "ASSUMED_KEY_ID" + assert creds1.secret_key == "ASSUMED_SECRET_KEY" + assert creds1.token == "ASSUMED_SESSION_TOKEN" + assert mock_assume_role.call_count == 1 + + # Second call - should use cached credentials, NOT assume role again + session2 = config.get_session() + creds2 = session2.get_credentials() + assert creds2 is not None + assert creds2.access_key == "ASSUMED_KEY_ID" + assert creds2.secret_key == "ASSUMED_SECRET_KEY" + assert creds2.token == "ASSUMED_SESSION_TOKEN" + # assume_role should still only be called once + assert mock_assume_role.call_count == 1 + + # Third call - should still use cached credentials + session3 = config.get_session() + creds3 = session3.get_credentials() + assert creds3 is not None + assert creds3.access_key == "ASSUMED_KEY_ID" + assert creds3.secret_key == "ASSUMED_SECRET_KEY" + assert creds3.token == "ASSUMED_SESSION_TOKEN" + # assume_role should still only be called once + assert mock_assume_role.call_count == 1 + + @mock_sts + def test_role_assumption_refreshes_expired_credentials(self): + """ + Test that expired credentials trigger a new role assumption. + """ + config = AwsConnectionConfig( + aws_region="us-east-1", + aws_role="arn:aws:iam::123456789012:role/test-role", + ) + + with ( + patch( + "datahub.ingestion.source.aws.aws_common.get_current_identity" + ) as mock_identity, + patch( + "datahub.ingestion.source.aws.aws_common.assume_role" + ) as mock_assume_role, + patch( + "datahub.ingestion.source.aws.aws_common.detect_aws_environment" + ) as mock_detect_env, + ): + mock_identity.return_value = (None, None) + mock_detect_env.return_value = AwsEnvironment.ECS + + from datetime import datetime, timedelta, timezone + + # First call - credentials that will expire soon (4 minutes remaining) + expiration1 = datetime.now(timezone.utc) + timedelta(minutes=4) + mock_assume_role.return_value = { + "AccessKeyId": "KEY_1", + "SecretAccessKey": "SECRET_1", + "SessionToken": "TOKEN_1", + "Expiration": expiration1, + } + + session1 = config.get_session() + creds1 = session1.get_credentials() + assert creds1 is not None + assert creds1.access_key == "KEY_1" + assert mock_assume_role.call_count == 1 + + # Second call - credentials are about to expire (< 5 min threshold) + # Should trigger refresh + expiration2 = datetime.now(timezone.utc) + timedelta(hours=1) + mock_assume_role.return_value = { + "AccessKeyId": "KEY_2", + "SecretAccessKey": "SECRET_2", + "SessionToken": "TOKEN_2", + "Expiration": expiration2, + } + + session2 = config.get_session() + creds2 = session2.get_credentials() + assert creds2 is not None + assert creds2.access_key == "KEY_2" + # assume_role should be called again due to expiration + assert mock_assume_role.call_count == 2 + + @mock_sts + def test_multiple_clients_use_same_cached_credentials(self): + """ + Test that multiple AWS clients (glue, s3, lakeformation) created from + the same config instance all use the same assumed role credentials. + + This simulates the real-world scenario where GlueSource.__init__ + creates multiple clients in quick succession. + """ + config = AwsConnectionConfig( + aws_region="us-east-1", + aws_role="arn:aws:iam::339713033063:role/cross-account-role", + ) + + with ( + patch( + "datahub.ingestion.source.aws.aws_common.get_current_identity" + ) as mock_identity, + patch( + "datahub.ingestion.source.aws.aws_common.assume_role" + ) as mock_assume_role, + patch( + "datahub.ingestion.source.aws.aws_common.detect_aws_environment" + ) as mock_detect_env, + ): + mock_identity.return_value = ( + "arn:aws:sts::024848452848:assumed-role/base-role/session", + "ecs.amazonaws.com", + ) + mock_detect_env.return_value = AwsEnvironment.ECS + + from datetime import datetime, timedelta, timezone + + expiration = datetime.now(timezone.utc) + timedelta(hours=1) + mock_assume_role.return_value = { + "AccessKeyId": "CROSS_ACCOUNT_KEY", + "SecretAccessKey": "CROSS_ACCOUNT_SECRET", + "SessionToken": "CROSS_ACCOUNT_TOKEN", + "Expiration": expiration, + } + + # Create multiple clients simulating GlueSource.__init__ + config.get_glue_client() + config.get_s3_client() + config.get_lakeformation_client() + + # Verify all clients use the same assumed credentials + # assume_role should only be called once + assert mock_assume_role.call_count == 1 + + @mock_sts + def test_role_assumption_without_caching_before_fix(self): + """ + This test demonstrates the bug that existed before the fix. + Without credential caching, the second get_session() call would + skip role assumption and return base credentials. + + This is a regression test to ensure the bug doesn't come back. + """ + config = AwsConnectionConfig( + aws_region="us-east-1", + aws_role="arn:aws:iam::123456789012:role/test-role", + ) + + # Simulate the old buggy behavior by clearing cached credentials + # after first call + with ( + patch( + "datahub.ingestion.source.aws.aws_common.get_current_identity" + ) as mock_identity, + patch( + "datahub.ingestion.source.aws.aws_common.assume_role" + ) as mock_assume_role, + patch( + "datahub.ingestion.source.aws.aws_common.detect_aws_environment" + ) as mock_detect_env, + ): + mock_identity.return_value = (None, None) + mock_detect_env.return_value = AwsEnvironment.ECS + + from datetime import datetime, timedelta, timezone + + expiration = datetime.now(timezone.utc) + timedelta(hours=1) + mock_assume_role.return_value = { + "AccessKeyId": "ASSUMED_KEY", + "SecretAccessKey": "ASSUMED_SECRET", + "SessionToken": "ASSUMED_TOKEN", + "Expiration": expiration, + } + + # First call + session1 = config.get_session() + creds1 = session1.get_credentials() + assert creds1 is not None + assert creds1.access_key == "ASSUMED_KEY" + + # Verify that second call also gets assumed credentials + # (not base credentials as in the bug) + session2 = config.get_session() + creds2 = session2.get_credentials() + assert creds2 is not None + assert creds2.access_key == "ASSUMED_KEY" + assert creds2.secret_key == "ASSUMED_SECRET" + assert creds2.token == "ASSUMED_TOKEN" + + @mock_sts + def test_role_assumption_with_explicit_credentials(self): + """ + Test that explicit credentials (aws_access_key_id/aws_secret_access_key) + take precedence and role assumption works with them. + """ + config = AwsConnectionConfig( + aws_access_key_id="EXPLICIT_KEY", + aws_secret_access_key="EXPLICIT_SECRET", + aws_region="us-east-1", + aws_role="arn:aws:iam::123456789012:role/test-role", + ) + + # When explicit credentials are provided, get_session() doesn't + # go through the role assumption logic, so this is just a sanity check + session = config.get_session() + creds = session.get_credentials() + assert creds is not None + assert creds.access_key == "EXPLICIT_KEY" + assert creds.secret_key == "EXPLICIT_SECRET" + + @mock_sts + def test_role_assumption_chain(self): + """ + Test assuming multiple roles in a chain. + """ + config = AwsConnectionConfig( + aws_region="us-east-1", + aws_role=[ + "arn:aws:iam::111111111111:role/role1", + "arn:aws:iam::222222222222:role/role2", + ], + ) + + with ( + patch( + "datahub.ingestion.source.aws.aws_common.get_current_identity" + ) as mock_identity, + patch( + "datahub.ingestion.source.aws.aws_common.assume_role" + ) as mock_assume_role, + patch( + "datahub.ingestion.source.aws.aws_common.detect_aws_environment" + ) as mock_detect_env, + ): + mock_identity.return_value = (None, None) + mock_detect_env.return_value = AwsEnvironment.ECS + + from datetime import datetime, timedelta, timezone + + expiration = datetime.now(timezone.utc) + timedelta(hours=1) + + # First role assumption + mock_assume_role.side_effect = [ + { + "AccessKeyId": "KEY_1", + "SecretAccessKey": "SECRET_1", + "SessionToken": "TOKEN_1", + "Expiration": expiration, + }, + { + "AccessKeyId": "KEY_2", + "SecretAccessKey": "SECRET_2", + "SessionToken": "TOKEN_2", + "Expiration": expiration, + }, + ] + + session = config.get_session() + creds = session.get_credentials() + + # Should use credentials from second role in chain + assert creds is not None + assert creds.access_key == "KEY_2" + assert creds.secret_key == "SECRET_2" + assert creds.token == "TOKEN_2" + + # Both roles should have been assumed + assert mock_assume_role.call_count == 2