Skip to content

Commit

Permalink
fix(ingest): okta - better use of asyncio and additional debug logging
Browse files Browse the repository at this point in the history
  • Loading branch information
Aditya Radhakrishnan authored Feb 11, 2022
1 parent 9bdc9af commit b331106
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 27 deletions.
1 change: 1 addition & 0 deletions metadata-ingestion/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ disallow_untyped_defs = yes
disallow_untyped_defs = yes

[tool:pytest]
asyncio_mode = auto
addopts = --cov=src --cov-report term-missing --cov-config setup.cfg --strict-markers
markers =
integration: marks tests to only run in integration (deselect with '-m "not integration"')
Expand Down
1 change: 1 addition & 0 deletions metadata-ingestion/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def get_long_description():
# Waiting for https://github.com/samuelcolvin/pydantic/pull/3175 before allowing mypy 0.920.
"mypy>=0.901,<0.920",
"pytest>=6.2.2",
"pytest-asyncio>=0.16.0",
"pytest-cov>=2.8.1",
"pytest-docker>=0.10.3",
"tox",
Expand Down
105 changes: 78 additions & 27 deletions metadata-ingestion/src/datahub/ingestion/source/identity/okta.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict, Iterable, List, Union

from okta.client import Client as OktaClient
from okta.exceptions import OktaAPIException
from okta.models import Group, GroupProfile, User, UserProfile, UserStatus

from datahub.configuration import ConfigModel
Expand Down Expand Up @@ -96,9 +97,12 @@ def __init__(self, config: OktaConfig, ctx: PipelineContext):

def get_workunits(self) -> Iterable[MetadataWorkUnit]:

# Step 0: create the event loop
event_loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()

# Step 1: Produce MetadataWorkUnits for CorpGroups.
if self.config.ingest_groups:
okta_groups = list(self._get_okta_groups())
okta_groups = list(self._get_okta_groups(event_loop))
datahub_corp_group_snapshots = self._map_okta_groups(okta_groups)
for datahub_corp_group_snapshot in datahub_corp_group_snapshots:
mce = MetadataChangeEvent(proposedSnapshot=datahub_corp_group_snapshot)
Expand All @@ -122,7 +126,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
continue

# Extract and map users for each group.
okta_group_users = self._get_okta_group_users(okta_group)
okta_group_users = self._get_okta_group_users(okta_group, event_loop)
for okta_user in okta_group_users:
datahub_corp_user_urn = self._map_okta_user_profile_to_urn(
okta_user.profile
Expand Down Expand Up @@ -150,7 +154,7 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:

# Step 3: Produce MetadataWorkUnits for CorpUsers.
if self.config.ingest_users:
okta_users = self._get_okta_users()
okta_users = self._get_okta_users(event_loop)
filtered_okta_users = filter(self._filter_okta_user, okta_users)
datahub_corp_user_snapshots = self._map_okta_users(filtered_okta_users)
for datahub_corp_user_snapshot in datahub_corp_user_snapshots:
Expand All @@ -172,6 +176,9 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
self.report.report_workunit(wu)
yield wu

# Step 4: Close the event loop
event_loop.close()

def get_report(self):
    """Return the report object stored on this source as ``self.report``."""
    return self.report

Expand All @@ -183,69 +190,113 @@ def _create_okta_client(self):
config = {
"orgUrl": f"https://{self.config.okta_domain}",
"token": f"{self.config.okta_api_token}",
"raiseException": True,
}
return OktaClient(config)

# Retrieves all Okta Group Objects in batches.
def _get_okta_groups(
    self, event_loop: asyncio.AbstractEventLoop
) -> Iterable[Group]:
    """Yield every Okta group, fetching one page at a time.

    Each page is requested by blocking on ``event_loop.run_until_complete``;
    between pages the method sleeps for ``self.config.delay_seconds``.

    Args:
        event_loop: Event loop used to drive the Okta client's coroutines.

    Yields:
        Each ``Group`` returned by the Okta API, in page order.

    Failures (either an ``err`` value returned by the client or a raised
    ``OktaAPIException``) are recorded via ``self.report.report_failure``
    under the ``"okta_groups"`` key. A raised exception ends the iteration,
    because there is no valid page to continue from.
    """
    logger.debug("Extracting all Okta groups")

    # Note that this is not taking full advantage of Python AsyncIO, as we are blocking on calls.
    query_parameters = {"limit": self.config.page_size}
    try:
        groups, resp, err = event_loop.run_until_complete(
            self.okta_client.list_groups(query_parameters)
        )
    except OktaAPIException as api_err:
        self.report.report_failure(
            "okta_groups", f"Failed to fetch Groups from Okta API: {api_err}"
        )
        # Nothing was fetched: groups/resp/err are unbound, so falling through
        # to the paging loop would raise UnboundLocalError.
        return

    while True:
        if err:
            self.report.report_failure(
                "okta_groups", f"Failed to fetch Groups from Okta API: {err}"
            )
        if groups:
            for group in groups:
                yield group
        if not (resp and resp.has_next()):
            break
        sleep(self.config.delay_seconds)
        try:
            groups, err = event_loop.run_until_complete(resp.next())
        except OktaAPIException as api_err:
            self.report.report_failure(
                "okta_groups",
                f"Failed to fetch Groups from Okta API: {api_err}",
            )
            # Stop paging: `groups` still holds the previous page, so looping
            # again would re-yield it (and could repeat indefinitely).
            break

# Retrieves Okta User Objects in a particular Okta Group in batches.
def _get_okta_group_users(
    self, group: Group, event_loop: asyncio.AbstractEventLoop
) -> Iterable[User]:
    """Yield every Okta user belonging to ``group``, one page at a time.

    Each page is requested by blocking on ``event_loop.run_until_complete``;
    between pages the method sleeps for ``self.config.delay_seconds``.

    Args:
        group: Okta group whose members are listed (looked up by ``group.id``).
        event_loop: Event loop used to drive the Okta client's coroutines.

    Yields:
        Each ``User`` returned by the Okta API for this group, in page order.

    Failures (either an ``err`` value returned by the client or a raised
    ``OktaAPIException``) are recorded via ``self.report.report_failure``
    under the ``"okta_group_users"`` key. A raised exception ends the
    iteration, because there is no valid page to continue from.
    """
    logger.debug(f"Extracting users from Okta group named {group.profile.name}")

    # Note that this is not taking full advantage of Python AsyncIO; we are blocking on calls.
    query_parameters = {"limit": self.config.page_size}
    try:
        users, resp, err = event_loop.run_until_complete(
            self.okta_client.list_group_users(group.id, query_parameters)
        )
    except OktaAPIException as api_err:
        self.report.report_failure(
            "okta_group_users",
            f"Failed to fetch Users of Group {group.profile.name} from Okta API: {api_err}",
        )
        # Nothing was fetched: users/resp/err are unbound, so falling through
        # to the paging loop would raise UnboundLocalError.
        return

    while True:
        if err:
            self.report.report_failure(
                "okta_group_users",
                f"Failed to fetch Users of Group {group.profile.name} from Okta API: {err}",
            )
        if users:
            for user in users:
                yield user
        if not (resp and resp.has_next()):
            break
        sleep(self.config.delay_seconds)
        try:
            users, err = event_loop.run_until_complete(resp.next())
        except OktaAPIException as api_err:
            self.report.report_failure(
                "okta_group_users",
                f"Failed to fetch Users of Group {group.profile.name} from Okta API: {api_err}",
            )
            # Stop paging: `users` still holds the previous page, so looping
            # again would re-yield it (and could repeat indefinitely).
            break

# Retrieves all Okta User Objects in batches.
def _get_okta_users(self, event_loop: asyncio.AbstractEventLoop) -> Iterable[User]:
    """Yield every Okta user, fetching one page at a time.

    Each page is requested by blocking on ``event_loop.run_until_complete``;
    between pages the method sleeps for ``self.config.delay_seconds``.

    Args:
        event_loop: Event loop used to drive the Okta client's coroutines.

    Yields:
        Each ``User`` returned by the Okta API, in page order.

    Failures (either an ``err`` value returned by the client or a raised
    ``OktaAPIException``) are recorded via ``self.report.report_failure``
    under the ``"okta_users"`` key. A raised exception ends the iteration,
    because there is no valid page to continue from.
    """
    logger.debug("Extracting all Okta users")

    query_parameters = {"limit": self.config.page_size}
    try:
        users, resp, err = event_loop.run_until_complete(
            self.okta_client.list_users(query_parameters)
        )
    except OktaAPIException as api_err:
        self.report.report_failure(
            "okta_users", f"Failed to fetch Users from Okta API: {api_err}"
        )
        # Nothing was fetched: users/resp/err are unbound, so falling through
        # to the paging loop would raise UnboundLocalError.
        return

    while True:
        if err:
            self.report.report_failure(
                "okta_users", f"Failed to fetch Users from Okta API: {err}"
            )
        if users:
            for user in users:
                yield user
        if not (resp and resp.has_next()):
            break
        sleep(self.config.delay_seconds)
        try:
            users, err = event_loop.run_until_complete(resp.next())
        except OktaAPIException as api_err:
            self.report.report_failure(
                "okta_users", f"Failed to fetch Users from Okta API: {api_err}"
            )
            # Stop paging: `users` still holds the previous page, so looping
            # again would re-yield it (and could repeat indefinitely).
            break

Expand Down
3 changes: 3 additions & 0 deletions metadata-ingestion/tests/integration/okta/test_okta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest.mock import Mock, patch

import jsonpickle
import pytest
from freezegun import freeze_time
from okta.models import Group, User

Expand Down Expand Up @@ -135,6 +136,7 @@ def test_okta_source_ingestion_disabled(pytestconfig, tmp_path):


@freeze_time(FROZEN_TIME)
@pytest.mark.asyncio
def test_okta_source_include_deprovisioned_suspended_users(pytestconfig, tmp_path):

test_resources_dir: pathlib.Path = pytestconfig.rootpath / "tests/integration/okta"
Expand Down Expand Up @@ -185,6 +187,7 @@ def test_okta_source_include_deprovisioned_suspended_users(pytestconfig, tmp_pat


@freeze_time(FROZEN_TIME)
@pytest.mark.asyncio
def test_okta_source_custom_user_name_regex(pytestconfig, tmp_path):

test_resources_dir: pathlib.Path = pytestconfig.rootpath / "tests/integration/okta"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ pyparsing==2.4.7
pyrsistent==0.16.1
pyspark==3.0.3
pytest==6.2.5
pytest-asyncio==0.16.0
pytest-cov==3.0.0
pytest-docker==0.10.3
python-daemon==2.3.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ pyparsing==2.4.7
pyrsistent==0.18.1
pyspark==3.0.3
pytest==6.2.5
pytest-asyncio==0.16.0
pytest-cov==3.0.0
pytest-docker==0.10.3
python-daemon==2.3.0
Expand Down

0 comments on commit b331106

Please sign in to comment.