diff --git a/metadata-ingestion/setup.cfg b/metadata-ingestion/setup.cfg index c241fd5de046a..6966fd4df6fd0 100644 --- a/metadata-ingestion/setup.cfg +++ b/metadata-ingestion/setup.cfg @@ -51,6 +51,7 @@ disallow_untyped_defs = yes disallow_untyped_defs = yes [tool:pytest] +asyncio_mode = auto addopts = --cov=src --cov-report term-missing --cov-config setup.cfg --strict-markers markers = integration: marks tests to only run in integration (deselect with '-m "not integration"') diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index 5b8958b34fb2b..a6e6f0369e639 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -189,6 +189,7 @@ def get_long_description(): # Waiting for https://github.com/samuelcolvin/pydantic/pull/3175 before allowing mypy 0.920. "mypy>=0.901,<0.920", "pytest>=6.2.2", + "pytest-asyncio>=0.16.0", "pytest-cov>=2.8.1", "pytest-docker>=0.10.3", "tox", diff --git a/metadata-ingestion/src/datahub/ingestion/source/identity/okta.py b/metadata-ingestion/src/datahub/ingestion/source/identity/okta.py index d4b6eef89d7f1..ced3c307e6102 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/identity/okta.py +++ b/metadata-ingestion/src/datahub/ingestion/source/identity/okta.py @@ -7,6 +7,7 @@ from typing import Dict, Iterable, List, Union from okta.client import Client as OktaClient +from okta.exceptions import OktaAPIException from okta.models import Group, GroupProfile, User, UserProfile, UserStatus from datahub.configuration import ConfigModel @@ -96,9 +97,12 @@ def __init__(self, config: OktaConfig, ctx: PipelineContext): def get_workunits(self) -> Iterable[MetadataWorkUnit]: + # Step 0: create the event loop + event_loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() + # Step 1: Produce MetadataWorkUnits for CorpGroups. if self.config.ingest_groups: - okta_groups = list(self._get_okta_groups()) + okta_groups = list(self._get_okta_groups(event_loop)) datahub_corp_group_snapshots = self._map_okta_groups(okta_groups) for datahub_corp_group_snapshot in datahub_corp_group_snapshots: mce = MetadataChangeEvent(proposedSnapshot=datahub_corp_group_snapshot) @@ -122,7 +126,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]: continue # Extract and map users for each group. - okta_group_users = self._get_okta_group_users(okta_group) + okta_group_users = self._get_okta_group_users(okta_group, event_loop) for okta_user in okta_group_users: datahub_corp_user_urn = self._map_okta_user_profile_to_urn( okta_user.profile @@ -150,7 +154,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]: # Step 3: Produce MetadataWorkUnits for CorpUsers. if self.config.ingest_users: - okta_users = self._get_okta_users() + okta_users = self._get_okta_users(event_loop) filtered_okta_users = filter(self._filter_okta_user, okta_users) datahub_corp_user_snapshots = self._map_okta_users(filtered_okta_users) for datahub_corp_user_snapshot in datahub_corp_user_snapshots: @@ -172,6 +176,9 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]: self.report.report_workunit(wu) yield wu + # Step 4: Close the event loop + event_loop.close() + def get_report(self): return self.report @@ -183,69 +190,113 @@ def _create_okta_client(self): config = { "orgUrl": f"https://{self.config.okta_domain}", "token": f"{self.config.okta_api_token}", + "raiseException": True, } return OktaClient(config) # Retrieves all Okta Group Objects in batches. - def _get_okta_groups(self) -> Iterable[Group]: + def _get_okta_groups( + self, event_loop: asyncio.AbstractEventLoop + ) -> Iterable[Group]: + logger.debug("Extracting all Okta groups") + # Note that this is not taking full advantage of Python AsyncIO, as we are blocking on calls. query_parameters = {"limit": self.config.page_size} - groups, resp, err = asyncio.get_event_loop().run_until_complete( - self.okta_client.list_groups(query_parameters) - ) + try: + groups, resp, err = event_loop.run_until_complete( + self.okta_client.list_groups(query_parameters) + ) + except OktaAPIException as api_err: + self.report.report_failure( + "okta_groups", f"Failed to fetch Groups from Okta API: {api_err}" + ) while True: - if err is not None: + if err: self.report.report_failure( "okta_groups", f"Failed to fetch Groups from Okta API: {err}" ) - if groups is not None: + if groups: for group in groups: yield group - if resp is not None and resp.has_next(): + if resp and resp.has_next(): sleep(self.config.delay_seconds) - groups, err = asyncio.get_event_loop().run_until_complete(resp.next()) + try: + groups, err = event_loop.run_until_complete(resp.next()) + except OktaAPIException as api_err: + self.report.report_failure( + "okta_groups", + f"Failed to fetch Groups from Okta API: {api_err}", + ) else: break # Retrieves Okta User Objects in a particular Okta Group in batches. - def _get_okta_group_users(self, group: Group) -> Iterable[User]: + def _get_okta_group_users( + self, group: Group, event_loop: asyncio.AbstractEventLoop + ) -> Iterable[User]: + logger.debug(f"Extracting users from Okta group named {group.profile.name}") + # Note that this is not taking full advantage of Python AsyncIO; we are blocking on calls. query_parameters = {"limit": self.config.page_size} - users, resp, err = asyncio.get_event_loop().run_until_complete( - self.okta_client.list_group_users(group.id, query_parameters) - ) + try: + users, resp, err = event_loop.run_until_complete( + self.okta_client.list_group_users(group.id, query_parameters) + ) + except OktaAPIException as api_err: + self.report.report_failure( + "okta_group_users", + f"Failed to fetch Users of Group {group.profile.name} from Okta API: {api_err}", + ) while True: - if err is not None: + if err: self.report.report_failure( "okta_group_users", f"Failed to fetch Users of Group {group.profile.name} from Okta API: {err}", ) - if users is not None: + if users: for user in users: yield user - if resp is not None and resp.has_next(): + if resp and resp.has_next(): sleep(self.config.delay_seconds) - users, err = asyncio.get_event_loop().run_until_complete(resp.next()) + try: + users, err = event_loop.run_until_complete(resp.next()) + except OktaAPIException as api_err: + self.report.report_failure( + "okta_group_users", + f"Failed to fetch Users of Group {group.profile.name} from Okta API: {api_err}", + ) else: break # Retrieves all Okta User Objects in batches. - def _get_okta_users(self) -> Iterable[User]: + def _get_okta_users(self, event_loop: asyncio.AbstractEventLoop) -> Iterable[User]: + logger.debug("Extracting all Okta users") + query_parameters = {"limit": self.config.page_size} - users, resp, err = asyncio.get_event_loop().run_until_complete( - self.okta_client.list_users(query_parameters) - ) + try: + users, resp, err = event_loop.run_until_complete( + self.okta_client.list_users(query_parameters) + ) + except OktaAPIException as api_err: + self.report.report_failure( + "okta_users", f"Failed to fetch Users from Okta API: {api_err}" + ) while True: - if err is not None: + if err: self.report.report_failure( "okta_users", f"Failed to fetch Users from Okta API: {err}" ) - if users is not None: + if users: for user in users: yield user - if resp is not None and resp.has_next(): + if resp and resp.has_next(): sleep(self.config.delay_seconds) - users, err = asyncio.get_event_loop().run_until_complete(resp.next()) + try: + users, err = event_loop.run_until_complete(resp.next()) + except OktaAPIException as api_err: + self.report.report_failure( + "okta_users", f"Failed to fetch Users from Okta API: {api_err}" + ) else: break diff --git a/metadata-ingestion/tests/integration/okta/test_okta.py b/metadata-ingestion/tests/integration/okta/test_okta.py index e4fa6f75b59f5..189af3eea7428 100644 --- a/metadata-ingestion/tests/integration/okta/test_okta.py +++ b/metadata-ingestion/tests/integration/okta/test_okta.py @@ -3,6 +3,7 @@ from unittest.mock import Mock, patch import jsonpickle +import pytest from freezegun import freeze_time from okta.models import Group, User @@ -135,6 +136,7 @@ def test_okta_source_ingestion_disabled(pytestconfig, tmp_path): @freeze_time(FROZEN_TIME) +@pytest.mark.asyncio def test_okta_source_include_deprovisioned_suspended_users(pytestconfig, tmp_path): test_resources_dir: pathlib.Path = pytestconfig.rootpath / "tests/integration/okta" @@ -185,6 +187,7 @@ def test_okta_source_include_deprovisioned_suspended_users(pytestconfig, tmp_pat @freeze_time(FROZEN_TIME) +@pytest.mark.asyncio def test_okta_source_custom_user_name_regex(pytestconfig, tmp_path): test_resources_dir: pathlib.Path = pytestconfig.rootpath / "tests/integration/okta" diff --git a/metadata-ingestion/tox_requirements/py36-full_requirements.txt b/metadata-ingestion/tox_requirements/py36-full_requirements.txt index 169165d2947fe..eb6109b34fba4 100644 --- a/metadata-ingestion/tox_requirements/py36-full_requirements.txt +++ b/metadata-ingestion/tox_requirements/py36-full_requirements.txt @@ -219,6 +219,7 @@ pyparsing==2.4.7 pyrsistent==0.16.1 pyspark==3.0.3 pytest==6.2.5 +pytest-asyncio==0.16.0 pytest-cov==3.0.0 pytest-docker==0.10.3 python-daemon==2.3.0 diff --git a/metadata-ingestion/tox_requirements/py39-full_requirements.txt b/metadata-ingestion/tox_requirements/py39-full_requirements.txt index a584e52511723..325db150b788f 100644 --- a/metadata-ingestion/tox_requirements/py39-full_requirements.txt +++ b/metadata-ingestion/tox_requirements/py39-full_requirements.txt @@ -215,6 +215,7 @@ pyparsing==2.4.7 pyrsistent==0.18.1 pyspark==3.0.3 pytest==6.2.5 +pytest-asyncio==0.16.0 pytest-cov==3.0.0 pytest-docker==0.10.3 python-daemon==2.3.0