Skip to content

Commit

Permalink
Ingest artifacts from dbt cloud (sodadata#607)
Browse files Browse the repository at this point in the history
This requires dbt_cloud_api_token to be set under the project name in
~/.soda/env_vars

```yaml
dbt_cloud_api_token: xxx
```
  • Loading branch information
JCZuurmond authored Dec 28, 2021
1 parent 0221886 commit 7644459
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 141 deletions.
12 changes: 12 additions & 0 deletions core/sodasql/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,18 @@ def scan(scan_yml_file: str, warehouse_yml_file: str, variables: tuple, time: st
default=None,
type=Path,
)
@click.option(
"--dbt-cloud-account-id",
help="The id of your dbt cloud account",
default=None,
type=Path,
)
@click.option(
"--dbt-cloud-run-id",
help="The id of the dbt job run of which you would like to ingest the test results",
default=None,
type=Path,
)
def ingest(*args, **kwargs):
"""
Ingest test information from different tools.
Expand Down
199 changes: 161 additions & 38 deletions core/sodasql/cli/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import json
import logging
import os
import requests
from pathlib import Path
from typing import Iterator, Optional, Tuple
from typing import Iterator
from requests.structures import CaseInsensitiveDict

from sodasql.__version__ import SODA_SQL_VERSION
from sodasql.scan.scan_builder import (
Expand All @@ -21,6 +23,7 @@
from sodasql.scan.test_result import TestResult
from sodasql.soda_server_client.soda_server_client import SodaServerClient

KEY_DBT_CLOUD_API_TOKEN = "dbt_cloud_api_token"

@dataclasses.dataclass(frozen=True)
class Table:
Expand Down Expand Up @@ -78,17 +81,17 @@ def map_dbt_run_result_to_test_result(


def map_dbt_test_results_iterator(
manifest_file: Path, run_results_file: Path
manifest: dict, run_results: dict
) -> Iterator[tuple[Table, list[TestResult]]]:
"""
Create an iterator for the dbt test results.
Parameters
----------
manifest_file : Path
The path to the manifest file.
run_results_file : Path
The path to the run results file.
manifest : dict
The manifest.
run_results : dict
The run results
Returns
-------
Expand All @@ -103,11 +106,6 @@ def map_dbt_test_results_iterator(
"Soda SQL dbt extension is not installed: $ pip install soda-sql-dbt"
) from e

with manifest_file.open("r") as file:
manifest = json.load(file)
with run_results_file.open("r") as file:
run_results = json.load(file)

model_nodes, seed_nodes, test_nodes, source_nodes = soda_dbt.parse_manifest(manifest)
parsed_run_results = soda_dbt.parse_run_results(run_results)
tests_with_test_result = map_dbt_run_result_to_test_result(test_nodes, parsed_run_results)
Expand Down Expand Up @@ -153,7 +151,9 @@ def flush_test_results(
"""
for table, test_results in test_results_iterator:
test_results_jsons = [
test_result.to_dict() for test_result in test_results if not test_result.skipped
test_result.to_dict()
for test_result in test_results
if not test_result.skipped
]
if len(test_results_jsons) == 0:
continue
Expand All @@ -174,25 +174,104 @@ def flush_test_results(
soda_server_client.scan_ended(start_scan_response["scanReference"])


def resolve_artifacts_paths(
    dbt_artifacts: Optional[Path] = None,
    dbt_manifest: Optional[Path] = None,
    dbt_run_results: Optional[Path] = None
) -> Tuple[Path, Path]:
    """
    Resolve the paths of the dbt manifest and run results files.

    Parameters
    ----------
    dbt_artifacts : Optional[Path]
        Directory holding ``manifest.json`` and ``run_results.json``. When
        given it takes precedence over the two explicit file paths.
    dbt_manifest : Optional[Path]
        The path to the dbt manifest.
    dbt_run_results : Optional[Path]
        The path to the dbt run results.

    Returns
    -------
    out : Tuple[Path, Path]
        The manifest path and the run results path.

    Raises
    ------
    ValueError
        When no artifacts directory is given and the manifest or the run
        results path is missing.
    """
    if dbt_artifacts:
        dbt_manifest = Path(dbt_artifacts) / 'manifest.json'
        dbt_run_results = Path(dbt_artifacts) / 'run_results.json'
    elif dbt_manifest is None:
        raise ValueError(
            "--dbt-manifest or --dbt-artifacts are required. "
            f"Currently, dbt_manifest={dbt_manifest} and dbt_artifacts={dbt_artifacts}"
        )
    elif dbt_run_results is None:
        # Report the run results value here; the message previously printed
        # the manifest path under the ``dbt_run_results`` label.
        raise ValueError(
            "--dbt-run-results or --dbt-artifacts are required. "
            f"Currently, dbt_run_results={dbt_run_results} and dbt_artifacts={dbt_artifacts}"
        )
    return dbt_manifest, dbt_run_results
def load_dbt_artifacts(
    manifest_file: Path,
    run_results_file: Path,
) -> tuple[dict, dict]:
    """
    Load the dbt manifest and run results from disk.

    Arguments
    ---------
    manifest_file : Path
        The manifest file.
    run_results_file : Path
        The run results file.

    Return
    ------
    out : tuple[dict, dict]
        The loaded manifest and run results.
    """
    # JSON artifacts are UTF-8; pin the encoding so parsing does not depend on
    # the locale of the machine running the ingest.
    with manifest_file.open("r", encoding="utf-8") as file:
        manifest = json.load(file)
    with run_results_file.open("r", encoding="utf-8") as file:
        run_results = json.load(file)
    return manifest, run_results


def download_dbt_artifact_from_cloud(
    artifact: str,
    api_token: str,
    account_id: str,
    run_id: str,
    timeout: float = 60.0,
) -> dict:
    """
    Download an artifact from the dbt cloud.

    Parameters
    ----------
    artifact : str
        The artifact name.
    api_token : str
        The dbt cloud API token.
    account_id : str
        The account id.
    run_id : str
        The run id.
    timeout : float
        Seconds to wait for dbt cloud before giving up. Without a timeout
        ``requests`` waits indefinitely on an unresponsive server.

    Returns
    -------
    out : dict
        The artifact.

    Raises
    ------
    requests.HTTPError
        When dbt cloud answers with a 4xx or 5xx status code.
    requests.Timeout
        When dbt cloud does not answer within ``timeout`` seconds.

    Sources
    -------
    https://docs.getdbt.com/dbt-cloud/api-v2#operation/getArtifactsByRunId
    """
    url = f"https://cloud.getdbt.com/api/v2/accounts/{account_id}/runs/{run_id}/artifacts/{artifact}"

    headers = CaseInsensitiveDict()
    headers["Authorization"] = f"Token {api_token}"
    headers["Content-Type"] = "application/json"

    response = requests.get(url, headers=headers, timeout=timeout)

    # raise_for_status() only raises on 4xx/5xx responses, so it can be called
    # unconditionally; no separate status code comparison is needed.
    response.raise_for_status()

    return response.json()


def download_dbt_artifacts_from_cloud(
    api_token: str,
    account_id: str,
    run_id: str,
) -> tuple[dict, dict]:
    """
    Fetch the manifest and the run results of a dbt cloud job run.

    Parameters
    ----------
    api_token : str
        The dbt cloud API token.
    account_id : str
        The account id.
    run_id : str
        The run id.

    Returns
    -------
    out : tuple[dict, dict]
        The downloaded manifest and run results, in that order.
    """
    artifact_names = ("manifest.json", "run_results.json")
    # The generator is consumed in order, so the manifest is requested first
    # and the run results second.
    manifest, run_results = (
        download_dbt_artifact_from_cloud(name, api_token, account_id, run_id)
        for name in artifact_names
    )
    return manifest, run_results


def ingest(
Expand All @@ -201,6 +280,8 @@ def ingest(
dbt_artifacts: Path | None = None,
dbt_manifest: Path | None = None,
dbt_run_results: Path | None = None,
dbt_cloud_account_id: str | None = None,
dbt_cloud_run_id: str | None = None,
) -> None:
"""
Ingest test information from different tools.
Expand All @@ -218,6 +299,10 @@ def ingest(
The path to the dbt manifest.
dbt_run_results : Optional[Path]
The path to the dbt run results.
dbt_cloud_account_id : str
The id of a dbt cloud account.
dbt_cloud_run_id : str
The id of a job run in the dbt cloud.
Raises
------
Expand All @@ -228,19 +313,57 @@ def ingest(
logger.info(SODA_SQL_VERSION)

warehouse_yml_parser = build_warehouse_yml_parser(warehouse_yml_file)
dbt_cloud_api_token = warehouse_yml_parser.get_str_required_env(KEY_DBT_CLOUD_API_TOKEN)
warehouse_yml = warehouse_yml_parser.warehouse_yml

soda_server_client = create_soda_server_client(warehouse_yml)
if not soda_server_client.api_key_id or not soda_server_client.api_key_secret:
raise ValueError("Missing Soda cloud api key id and/or secret.")

if tool == 'dbt':
dbt_manifest, dbt_run_results = resolve_artifacts_paths(
dbt_artifacts=dbt_artifacts,
dbt_manifest=dbt_manifest,
dbt_run_results=dbt_run_results
)
test_results_iterator = map_dbt_test_results_iterator(dbt_manifest, dbt_run_results)
if tool == "dbt":
if (
dbt_artifacts is not None
or dbt_manifest is not None
or dbt_run_results is not None
):
if dbt_artifacts is not None:
dbt_manifest = dbt_artifacts / "manifest.json"
dbt_run_results = dbt_artifacts / "run_results.json"

if dbt_manifest is None or not dbt_manifest.is_file():
raise ValueError(
f"dbt manifest ({dbt_manifest}) or artifacts ({dbt_artifacts}) "
"should point to an existing path."
)
elif dbt_run_results is None or not dbt_run_results.is_file():
raise ValueError(
f"dbt run results ({dbt_run_results}) or artifacts ({dbt_artifacts}) "
"should point to an existing path."
)

manifest, run_results = load_dbt_artifacts(
dbt_manifest,
dbt_run_results,
)
else:
error_values = [dbt_cloud_api_token, dbt_cloud_account_id, dbt_cloud_run_id]
error_messages = [
f"Expecting a dbt cloud api token: {dbt_cloud_api_token}",
f"Expecting a dbt cloud account id: {dbt_cloud_account_id}",
f"Expecting a dbt cloud job run id: {dbt_cloud_run_id}",
]
filtered_messages = [
message
for value, message in zip(error_values, error_messages)
if value is None
]
if len(filtered_messages) > 0:
raise ValueError("\n".join(filtered_messages))
manifest, run_results = download_dbt_artifacts_from_cloud(
dbt_cloud_api_token, dbt_cloud_account_id, dbt_cloud_run_id
)

test_results_iterator = map_dbt_test_results_iterator(manifest, run_results)
else:
raise NotImplementedError(f"Unknown tool: {tool}")

Expand Down
13 changes: 11 additions & 2 deletions core/sodasql/scan/warehouse_yml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@
KEY_NAME = 'name'
KEY_CONNECTION = 'connection'
KEY_SODA_ACCOUNT = 'soda_account'
KEY_INGEST = 'ingest'

SODA_KEY_HOST = 'host'
SODA_KEY_PORT = 'port'
SODA_KEY_PROTOCOL = 'protocol'
SODA_KEY_API_KEY_ID = 'api_key_id'
SODA_KEY_API_KEY_SECRET = 'api_key_secret'

VALID_WAREHOUSE_KEYS = [KEY_NAME, KEY_CONNECTION, KEY_SODA_ACCOUNT]
DBT_CLOUD_KEY_API_TOKEN = "dbt_cloud_api_token"

VALID_WAREHOUSE_KEYS = [KEY_NAME, KEY_CONNECTION, KEY_SODA_ACCOUNT, KEY_INGEST]

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -78,7 +81,13 @@ def __init__(self,
self.warehouse_yml.soda_api_key_secret = self.get_str_required_env(SODA_KEY_API_KEY_SECRET)
self._pop_context()

ingest_dict = self.get_dict_optional(KEY_INGEST)
if ingest_dict:
self._push_context(object=ingest_dict, name=KEY_INGEST)
self.warehouse_yml.dbt_cloud_api_token = self.get_str_optional(DBT_CLOUD_KEY_API_TOKEN)
self._pop_context()

self.check_invalid_keys(VALID_WAREHOUSE_KEYS)

else:
self.error('No warehouse configuration provided')
self.error('No warehouse configuration provided')
12 changes: 2 additions & 10 deletions packages/dbt/sodasql/dbt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""DBT integeration"""
"""dbt integeration"""

from __future__ import annotations

Expand All @@ -25,7 +25,7 @@
def parse_manifest(
manifest: dict[str, Any]
) -> tuple[
dict[str, ParsedModelNode | CompileModelNode],
dict[str, ParsedModelNode | CompiledModelNode],
dict[str, ParsedSeedNode | CompiledSeedNode],
dict[str, ParsedGenericTestNode | CompiledGenericTestNode],
dict[str, ParsedSourceDefinition],
Expand Down Expand Up @@ -59,10 +59,6 @@ def parse_manifest(
------
https://docs.getdbt.com/reference/artifacts/manifest-json
"""
dbt_v4_schema = "https://schemas.getdbt.com/dbt/manifest/v4.json"
if manifest["metadata"]["dbt_schema_version"] != dbt_v4_schema:
raise NotImplementedError("Dbt manifest parsing only supported for V4 schema.")

model_nodes = {
node_name: CompiledModelNode(**node)
if "compiled" in node.keys()
Expand Down Expand Up @@ -115,10 +111,6 @@ def parse_run_results(run_results: dict[str, Any]) -> list[RunResultOutput]:
------
https://docs.getdbt.com/reference/artifacts/run-results-json
"""
dbt_v4_schema = "https://schemas.getdbt.com/dbt/run-results/v4.json"
if run_results["metadata"]["dbt_schema_version"] != dbt_v4_schema:
raise NotImplementedError("Dbt run results parsing only supported for v4 schema.")

parsed_run_results = [RunResultOutput(**result) for result in run_results["results"]]
return parsed_run_results

Expand Down
Loading

0 comments on commit 7644459

Please sign in to comment.