Skip to content

Commit

Permalink
Adds ingest dbt tests (sodadata#577)
Browse files Browse the repository at this point in the history
A new CLI command `ingest` is added to parse dbt test results and send them to Soda Cloud.
  • Loading branch information
JCZuurmond authored Dec 10, 2021
1 parent 6ba9300 commit f30f3a8
Show file tree
Hide file tree
Showing 11 changed files with 1,263 additions and 375 deletions.
38 changes: 38 additions & 0 deletions core/sodasql/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
import sys
from datetime import datetime, timezone
from math import ceil
from pathlib import Path
from typing import Optional

import click
import yaml

from sodasql.__version__ import SODA_SQL_VERSION
from sodasql.cli.indenting_yaml_dumper import IndentingDumper
from sodasql.cli.ingest import ingest as _ingest
from sodasql.common.logging_helper import LoggingHelper
from sodasql.dataset_analyzer import DatasetAnalyzer
from sodasql.scan.file_system import FileSystemSingleton
Expand Down Expand Up @@ -489,3 +491,39 @@ def scan(scan_yml_file: str, warehouse_yml_file: str, variables: tuple, time: st
"https://github.com/sodadata/soda-sql/issues/new/choose")
logger.info(f'Exiting with code 1')
sys.exit(1)


# Registered on ``main`` as the ``ingest`` command. It takes one positional
# ``tool`` argument (only "dbt" is accepted), a required warehouse yml file
# and two optional dbt artifact paths that are converted to ``Path``.
@main.command(short_help="Ingest test information from different tools")
@click.argument("tool", required=True, type=click.Choice(["dbt"]))
@click.option("--warehouse-yml-file", help="The warehouse yml file.", required=True)
@click.option("--dbt-manifest", help="The path to the dbt manifest file", default=None, type=Path)
@click.option("--dbt-run-results", help="The path to the dbt run results file", default=None, type=Path)
def ingest(*args, **kwargs):
    """
    Ingest test information from different tools.
    For more details see :func:sodasql.cli.ingest.ingest
    """
    # Thin entry point: every received argument is forwarded unchanged to the
    # implementation imported as ``_ingest`` (sodasql.cli.ingest.ingest).
    _ingest(*args, **kwargs)


# Invoke the CLI entry point only when this module is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()
228 changes: 228 additions & 0 deletions core/sodasql/cli/ingest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
"""
CLI commands to ingest test results from various sources into the Soda cloud.
"""

from __future__ import annotations

import dataclasses
import datetime as dt
import json
import logging
import os
from pathlib import Path
from typing import Iterator

from sodasql.__version__ import SODA_SQL_VERSION
from sodasql.scan.scan_builder import (
build_warehouse_yml_parser,
create_soda_server_client,
)
from sodasql.scan.test import Test
from sodasql.scan.test_result import TestResult
from sodasql.soda_server_client.soda_server_client import SodaServerClient


@dataclasses.dataclass(frozen=True)
class Table:
    """
    Immutable identifier of a warehouse table.

    The declaration order of the fields (``name``, ``schema``, ``database``)
    is also the positional order of the generated ``__init__``; prefer
    keyword arguments when constructing instances to avoid mixing them up.
    """

    # The table name.
    name: str
    # The schema the table lives in.
    schema: str
    # The database the table lives in.
    database: str


def map_dbt_run_result_to_test_result(
    test_nodes: dict[str, "DbtTestNode"],
    run_results: list["RunResultOutput"],
) -> dict[str, TestResult]:
    """
    Map dbt run results to Soda test results.

    Parameters
    ----------
    test_nodes : dict[str, DbtTestNode]
        The schema test nodes. The mapping is expected to be keyed by each
        node's ``unique_id``: run results are filtered on these keys and then
        looked up by ``unique_id``.
    run_results : list[RunResultOutput]
        The run results. Results whose ``unique_id`` is not a key of
        `test_nodes` are ignored.

    Returns
    -------
    out : dict[str, TestResult]
        A mapping from the unique id of a dbt test to its Soda test result.
    """
    # Imported inside the function, presumably so that this module can be
    # imported without dbt being installed.
    from dbt.contracts.results import TestStatus

    # One Soda ``Test`` per dbt test node, keyed by the node's unique id.
    soda_tests = {
        test_node.unique_id: Test(
            id=test_node.unique_id,
            title=f"Number of failures for {test_node.unique_id}",
            expression=test_node.raw_sql,
            metrics=None,
            column=test_node.column_name,
        )
        for test_node in test_nodes.values()
    }

    return {
        run_result.unique_id: TestResult(
            soda_tests[run_result.unique_id],
            passed=run_result.status == TestStatus.Pass,
            skipped=run_result.status == TestStatus.Skipped,
            values={"failures": run_result.failures},
        )
        for run_result in run_results
        if run_result.unique_id in test_nodes
    }


def map_dbt_test_results_iterator(
    manifest_file: Path, run_results_file: Path
) -> Iterator[tuple[Table, list[TestResult]]]:
    """
    Create an iterator for the dbt test results.

    This is a generator: nothing is imported, read or parsed until the first
    item is requested.

    Parameters
    ----------
    manifest_file : Path
        The path to the manifest file.
    run_results_file : Path
        The path to the run results file.

    Yields
    ------
    out : tuple[Table, list[TestResult]]
        The table and its corresponding test results.

    Raises
    ------
    RuntimeError :
        If the Soda SQL dbt extension cannot be imported.
    """
    try:
        from sodasql import dbt as soda_dbt
    except ImportError as e:
        raise RuntimeError(
            "Soda SQL dbt extension is not installed: $ pip install soda-sql-dbt"
        ) from e

    with manifest_file.open("r") as file:
        manifest = json.load(file)
    with run_results_file.open("r") as file:
        run_results = json.load(file)

    model_nodes, seed_nodes, test_nodes = soda_dbt.parse_manifest(manifest)
    parsed_run_results = soda_dbt.parse_run_results(run_results)
    tests_with_test_result = map_dbt_run_result_to_test_result(
        test_nodes, parsed_run_results
    )
    model_and_seed_nodes = {**model_nodes, **seed_nodes}
    models_with_tests = soda_dbt.create_nodes_to_tests_mapping(
        model_and_seed_nodes, test_nodes, parsed_run_results
    )

    for unique_id, test_unique_ids in models_with_tests.items():
        node = model_and_seed_nodes[unique_id]
        # Keyword arguments on purpose: ``Table`` declares its fields as
        # (name, schema, database), so passing (alias, database, schema)
        # positionally would swap the schema and the database.
        table = Table(
            name=node.alias,
            schema=node.schema,
            database=node.database,
        )
        test_results = [
            tests_with_test_result[test_unique_id] for test_unique_id in test_unique_ids
        ]

        yield table, test_results


def flush_test_results(
    test_results_iterator: Iterator[tuple[Table, list[TestResult]]],
    soda_server_client: SodaServerClient,
    *,
    warehouse_name: str,
    warehouse_type: str,
) -> None:
    """
    Send the test results to the Soda server, one scan per table.

    For every table that has at least one non-skipped test result a scan is
    started, the results are sent and the scan is ended. Tables whose results
    are all skipped (or empty) produce no calls at all.

    Parameters
    ----------
    test_results_iterator : Iterator[tuple[Table, list[TestResult]]]
        The tables with their test results.
    soda_server_client : SodaServerClient
        The soda server client.
    warehouse_name : str
        The warehouse name.
    warehouse_type : str
        The warehouse (and dialect) type.
    """
    for table, results in test_results_iterator:
        # Skipped tests are not reported.
        payload = [result.to_dict() for result in results if not result.skipped]
        if not payload:
            continue

        # NOTE(review): the scan time is a naive local timestamp; confirm
        # whether the server expects a timezone-aware (UTC) value.
        response = soda_server_client.scan_start(
            warehouse_name=warehouse_name,
            warehouse_type=warehouse_type,
            warehouse_database_name=table.database,
            warehouse_database_schema=table.schema,
            table_name=table.name,
            scan_yml_columns=None,
            scan_time=dt.datetime.now().isoformat(),
            origin=os.environ.get("SODA_SCAN_ORIGIN", "external"),
        )
        scan_reference = response["scanReference"]
        soda_server_client.scan_test_results(scan_reference, payload)
        soda_server_client.scan_ended(scan_reference)


def ingest(
    tool: str,
    warehouse_yml_file: str,
    dbt_manifest: Path | None = None,
    dbt_run_results: Path | None = None,
) -> None:
    """
    Ingest test information from different tools and send it to Soda cloud.

    Arguments
    ---------
    tool : str {'dbt'}
        The tool name.
    warehouse_yml_file : str
        The warehouse yml file.
    dbt_manifest : Optional[Path]
        The path to the dbt manifest. Required when `tool` is 'dbt'.
    dbt_run_results : Optional[Path]
        The path to the dbt run results. Required when `tool` is 'dbt'.

    Raises
    ------
    ValueError :
        If the Soda cloud api key id or secret is missing, if the tool is
        unrecognized, or if a dbt path required for the 'dbt' tool is not
        given.
    """
    log = logging.getLogger(__name__)
    log.info(SODA_SQL_VERSION)

    warehouse_yml = build_warehouse_yml_parser(warehouse_yml_file).warehouse_yml

    # The credentials are checked before any tool-specific validation.
    soda_server_client = create_soda_server_client(warehouse_yml)
    has_credentials = (
        soda_server_client.api_key_id and soda_server_client.api_key_secret
    )
    if not has_credentials:
        raise ValueError("Missing Soda cloud api key id and/or secret.")

    if tool != "dbt":
        raise ValueError(f"Unknown tool: {tool}")
    if dbt_manifest is None:
        raise ValueError(f"Dbt manifest is required: {dbt_manifest}")
    if dbt_run_results is None:
        raise ValueError(f"Dbt run results is required: {dbt_run_results}")

    flush_test_results(
        map_dbt_test_results_iterator(dbt_manifest, dbt_run_results),
        soda_server_client,
        warehouse_name=warehouse_yml.name,
        warehouse_type=warehouse_yml.dialect.type,
    )
9 changes: 7 additions & 2 deletions core/sodasql/scan/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,10 +899,15 @@ def _flush_test_results(self, test_results: List[TestResult]):

def _ensure_scan_reference(self):
if self.soda_server_client and not self.scan_reference:
database_and_schema = self.warehouse.dialect.get_warehouse_name_and_schema()
try:
self.start_scan_response = self.soda_server_client.scan_start(
self.warehouse,
self.scan_yml,
self.warehouse.name,
self.warehouse.dialect.type,
database_and_schema.get("database_name"),
database_and_schema.get("database_schema"),
self.scan_yml.table_name,
self.scan_yml.columns,
self.time,
origin=os.environ.get('SODA_SCAN_ORIGIN', 'external'))
self.scan_reference = self.start_scan_response['scanReference']
Expand Down
Loading

0 comments on commit f30f3a8

Please sign in to comment.