Added Python formatting and pre-commit.
elongl committed Jul 9, 2023
1 parent b8ce82c commit a648ae6
Showing 8 changed files with 121 additions and 73 deletions.
7 changes: 7 additions & 0 deletions .flake8
@@ -0,0 +1,7 @@
[flake8]
ignore =
W503
E203
E501
E402
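
A side note on this ignore list (not stated in the commit, but the usual rationale when flake8 runs alongside black): W503 flags line breaks before binary operators and E203 flags whitespace before a slice colon, both of which black's output produces by design, while ignoring E501 defers line-length enforcement to black and E402 tolerates module-level imports that sit below the top of the file. A minimal, hypothetical snippet illustrating the first two:

first_term, second_term = 1, 2
data = list(range(10))
start, size = 2, 3

total = (
    first_term
    + second_term  # the wrapped form black uses: operator leads the line, which W503 flags
)
chunk = data[start : start + size]  # black's slice spacing around ':' is what E203 flags
print(total, chunk)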

18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
@@ -1,4 +1,22 @@
exclude: '^dbt_packages$'

repos:
- repo: https://github.com/psf/black
rev: 22.12.0
hooks:
- id: black

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black"]

- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8

- repo: local
hooks:
- id: no_commit
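Assuming the standard pre-commit workflow (not shown in this diff), these hooks run on every `git commit` once activated with `pre-commit install`, and the whole tree can be checked in one pass with `pre-commit run --all-files`. Pinning `rev` for black (22.12.0), isort (5.12.0), and flake8 (6.0.0) keeps local and CI runs on the same tool versions, and isort, run here with the black-compatible profile, is what reorders and regroups the imports in the files below.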
80 changes: 52 additions & 28 deletions integration_tests/generate_data.py
@@ -10,6 +10,7 @@
EPOCH = datetime.utcfromtimestamp(0)
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"


def generate_fake_data():
generate_string_anomalies_training_and_validation_files()
generate_numeric_anomalies_training_and_validation_files()
@@ -330,31 +331,57 @@ def get_validation_row(date, row_index, rows_count):


def generate_seasonality_volume_anomalies_files():
# seasonal_data : should trigger volume anomalies and not trigger
columns = ["updated_at",
"user_id"]
dates = generate_rows_timestamps(base_date=EPOCH-timedelta(days=7), days_back=210) # 7 * 30 days backwards
# seasonal_data : should trigger volume anomalies and not trigger
columns = ["updated_at", "user_id"]
dates = generate_rows_timestamps(
base_date=EPOCH - timedelta(days=7), days_back=210
) # 7 * 30 days backwards
training_rows = []
for ix,date in enumerate(dates):
for ix, date in enumerate(dates):
if (ix % 7) == 1:
training_rows.extend([{"updated_at": date.strftime(DATE_FORMAT),
"user_id": random.randint(1000,9999)}
for _ in range(700)])
training_rows.extend(
[
{
"updated_at": date.strftime(DATE_FORMAT),
"user_id": random.randint(1000, 9999),
}
for _ in range(700)
]
)
else:
continue
write_rows_to_csv(csv_path=os.path.join(FILE_DIR, "data", "training", "users_per_day_weekly_seasonal_training.csv"),
rows=training_rows,
header=columns)
write_rows_to_csv(
csv_path=os.path.join(
FILE_DIR, "data", "training", "users_per_day_weekly_seasonal_training.csv"
),
rows=training_rows,
header=columns,
)

validation_dates = generate_rows_timestamps(base_date=EPOCH, days_back=7) # one week
validation_dates = generate_rows_timestamps(
base_date=EPOCH, days_back=7
) # one week
validation_rows = []
for ix, date in enumerate(validation_dates):
validation_rows.extend([{"updated_at": date.strftime(DATE_FORMAT),
"user_id": random.randint(1000,9999)}
for _ in range(100)])
write_rows_to_csv(csv_path=os.path.join(FILE_DIR, "data", "validation", "users_per_day_weekly_seasonal_validation.csv"),
rows=validation_rows,
header=columns)
validation_rows.extend(
[
{
"updated_at": date.strftime(DATE_FORMAT),
"user_id": random.randint(1000, 9999),
}
for _ in range(100)
]
)
write_rows_to_csv(
csv_path=os.path.join(
FILE_DIR,
"data",
"validation",
"users_per_day_weekly_seasonal_validation.csv",
),
rows=validation_rows,
header=columns,
)
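
For readers without the rest of generate_data.py, here is a rough sketch of the two helpers the function above relies on, inferred only from its call sites; the real implementations may differ:

import csv
import os
from datetime import timedelta

def generate_rows_timestamps(base_date, days_back=30):
    # Hypothetical reconstruction: one timestamp per day, walking back from base_date.
    return [base_date - timedelta(days=offset) for offset in range(days_back)]

def write_rows_to_csv(csv_path, rows, header):
    # Hypothetical reconstruction: dump the generated dicts with a header row.
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)
    with open(csv_path, "w", newline="") as out:
        writer = csv.DictWriter(out, fieldnames=header)
        writer.writeheader()
        writer.writerows(rows)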


def generate_backfill_days_training_and_validation_files(rows_count_per_day=100):
@@ -364,7 +391,7 @@ def get_training_row(date, row_index, rows_count):
"occurred_at": (date - timedelta(hours=1)).strftime(DATE_FORMAT),
"min_length": "".join(
random.choices(string.ascii_lowercase, k=random.randint(5, 10))
)
),
}

def get_validation_row(date, row_index, rows_count):
@@ -373,14 +400,10 @@ def get_validation_row(date, row_index, rows_count):
"occurred_at": (date - timedelta(hours=7)).strftime(DATE_FORMAT),
"min_length": "".join(
random.choices(string.ascii_lowercase, k=random.randint(1, 10))
)
),
}

string_columns = [
"updated_at",
"occurred_at",
"min_length"
]
string_columns = ["updated_at", "occurred_at", "min_length"]
dates = generate_rows_timestamps(base_date=EPOCH - timedelta(days=1))
training_rows = generate_rows(rows_count_per_day, dates, get_training_row)
write_rows_to_csv(
@@ -397,13 +420,14 @@ def get_validation_row(date, row_index, rows_count):
)
write_rows_to_csv(
os.path.join(
FILE_DIR, "data", "validation", "backfill_days_column_anomalies_validation.csv"
FILE_DIR,
"data",
"validation",
"backfill_days_column_anomalies_validation.csv",
),
validation_rows,
string_columns,
)
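
Similarly, `generate_rows` is not shown in this diff; a plausible sketch based on how it is called above with a per-date row factory (again, hypothetical):

def generate_rows(rows_count, dates, get_row):
    # Hypothetical reconstruction: rows_count rows per date, built by the row factory.
    rows = []
    for date in dates:
        for row_index in range(rows_count):
            rows.append(get_row(date, row_index, rows_count))
    return rows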




def main():
6 changes: 4 additions & 2 deletions integration_tests/integration_tests/dbt_project.py
@@ -7,7 +7,6 @@

import uuid
from typing import Optional
from packaging import version

from dbt.adapters.base import BaseRelation
from dbt.adapters.factory import get_adapter_class_by_name, register_adapter
@@ -19,6 +18,7 @@
from dbt.task.sql import SqlCompileRunner
from dbt.tracking import disable_tracking
from dbt.version import __version__
from packaging import version
from pydantic import BaseModel

dbt_version = version.parse(__version__)
@@ -41,7 +41,9 @@ def default_project_dir() -> Path:
def default_profiles_dir() -> Path:
if "DBT_PROFILES_DIR" in os.environ:
return Path(os.environ["DBT_PROFILES_DIR"]).resolve()
return Path.cwd() if (Path.cwd() / "profiles.yml").exists() else Path.home() / ".dbt"
return (
Path.cwd() if (Path.cwd() / "profiles.yml").exists() else Path.home() / ".dbt"
)


DEFAULT_PROFILES_DIR = str(default_profiles_dir())
42 changes: 15 additions & 27 deletions integration_tests/integration_tests/test_anomaly_scores_query.py
@@ -1,18 +1,18 @@
from datetime import datetime
import os
from datetime import datetime

import pandas as pd
from parametrization import Parametrization


from .dbt_project import DbtProject
from .utils import (
agate_table_to_pandas_dataframe,
assert_dfs_equal,
create_test_table,
get_package_database_and_schema,
insert_rows,
update_var,
lowercase_column_names,
assert_dfs_equal,
agate_table_to_pandas_dataframe,
get_package_database_and_schema,
update_var,
)

MIN_BUCKET_START = datetime(2022, 1, 1, 0, 0, 0)
@@ -47,36 +47,24 @@
name="freshness_1_day_mostly_non_anomalous",
metric_name="freshness",
input_rows=os.path.join(BASE_DIR, "freshness_1_day_test.csv"),
input_rows_history=os.path.join(
BASE_DIR, "freshness_1_day_history.csv"
),
expected_output=os.path.join(
BASE_DIR, "freshness_1_day_after.csv"
),
input_rows_history=os.path.join(BASE_DIR, "freshness_1_day_history.csv"),
expected_output=os.path.join(BASE_DIR, "freshness_1_day_after.csv"),
alias_table_name="numeric_column_anomalies",
)
@Parametrization.case(
name="row_count_1_day_1_row_no_anomalies",
metric_name="row_count",
input_rows=os.path.join(BASE_DIR, "row_count_1_day_1_row_test.csv"),
input_rows_history=os.path.join(
BASE_DIR, "row_count_1_day_1_row_history.csv"
),
expected_output=os.path.join(
BASE_DIR, "row_count_1_day_1_row_after.csv"
),
input_rows_history=os.path.join(BASE_DIR, "row_count_1_day_1_row_history.csv"),
expected_output=os.path.join(BASE_DIR, "row_count_1_day_1_row_after.csv"),
alias_table_name="any_type_column_anomalies_training",
)
@Parametrization.case(
name="row_count_1_day_mostly_non_anomalous",
metric_name="row_count",
input_rows=os.path.join(BASE_DIR, "row_count_1_day_test.csv"),
input_rows_history=os.path.join(
BASE_DIR, "row_count_1_day_history.csv"
),
expected_output=os.path.join(
BASE_DIR, "row_count_1_day_after.csv"
),
input_rows_history=os.path.join(BASE_DIR, "row_count_1_day_history.csv"),
expected_output=os.path.join(BASE_DIR, "row_count_1_day_after.csv"),
alias_table_name="string_column_anomalies_training",
)
@Parametrization.case(
@@ -157,7 +145,7 @@ def test_anomaly_scores_query(
"timestamp_column": timestamp_column,
"where_expression": where_expression,
"freshness_column": None,
"event_timestamp_column": None
"event_timestamp_column": None,
}
if dbt_project.adapter_name not in ["postgres", "redshift"]:
database, schema = get_package_database_and_schema(dbt_project)
@@ -168,7 +156,7 @@
"resource_type": "dummy_type_not_incremental",
"alias": alias_table_name,
"database": database,
"schema": schema
"schema": schema,
}

test_configuration = {
Expand All @@ -178,7 +166,7 @@ def test_anomaly_scores_query(
"dimensions": dimensions, # should be the same as in the input
"time_bucket": time_bucket,
"timestamp_column": timestamp_column,
"where_expression": where_expression
"where_expression": where_expression,
}

query = dbt_project.execute_macro(
@@ -2,9 +2,8 @@

from parametrization import Parametrization


from .dbt_project import DbtProject
from .utils import create_test_table, insert_rows, update_var, lowercase_column_names, get_package_database_and_schema
from .utils import create_test_table, insert_rows, lowercase_column_names, update_var

MIN_BUCKET_START = datetime(2022, 1, 1, 0, 0, 0)
MAX_BUCKET_END = datetime(2022, 1, 4, 0, 0, 0)
@@ -169,7 +168,7 @@ def test_table_monitoring_query(
min_bucket_start=MIN_BUCKET_START.strftime("'%Y-%m-%d %H:%M:%S'"),
max_bucket_end=MAX_BUCKET_END.strftime("'%Y-%m-%d %H:%M:%S'"),
table_monitors=[metric],
days_back=(MAX_BUCKET_END-MIN_BUCKET_START).days,
days_back=(MAX_BUCKET_END - MIN_BUCKET_START).days,
metric_properties=metric_properties,
)
res_table = dbt_project.execute_sql(query)
8 changes: 4 additions & 4 deletions integration_tests/integration_tests/utils.py
@@ -1,13 +1,13 @@
import csv
import uuid
from functools import lru_cache
from typing import Dict, List, Any, Union
import csv
import pandas as pd
from typing import Any, Dict, List, Union

import agate
import pandas as pd
from dbt.adapters.base import BaseRelation

from .dbt_project import DbtProject, dbt_version
from .dbt_project import DbtProject


def create_test_table(
28 changes: 19 additions & 9 deletions integration_tests/run_e2e_tests.py
@@ -3,14 +3,13 @@
import sys
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List, Optional
from typing import List

import click
from dbt.version import __version__
from packaging import version

from elementary.clients.dbt.dbt_runner import DbtRunner
from generate_data import generate_fake_data
from packaging import version

FILE_DIR = os.path.dirname(os.path.realpath(__file__))
DBT_VERSION = version.parse(version.parse(__version__).base_version)
@@ -365,17 +364,28 @@ def e2e_tests(
model = "non_dbt_model"
try:
row = get_row(model, dbt_runner)
if row['depends_on_nodes'] != '["model.elementary_integration_tests.one"]' \
or row['materialization'] != "non_dbt":
result = TestResult(type="non_dbt_models", message="FAILED: non_dbt model not materialized as expected")
if (
row["depends_on_nodes"] != '["model.elementary_integration_tests.one"]'
or row["materialization"] != "non_dbt"
):
result = TestResult(
type="non_dbt_models",
message="FAILED: non_dbt model not materialized as expected",
)
else:
result = TestResult(
type="non_dbt_models",
message=dbt_runner.run_operation('assert_table_doesnt_exist',
macro_args={"model_name": "non_dbt_model"}, should_log=False)[0]
message=dbt_runner.run_operation(
"assert_table_doesnt_exist",
macro_args={"model_name": "non_dbt_model"},
should_log=False,
)[0],
)
except ValueError:
result = TestResult(type="non_dbt_models", message="FAILED: we need to see the non_dbt model in the run")
result = TestResult(
type="non_dbt_models",
message="FAILED: we need to see the non_dbt model in the run",
)
test_results.append(result)

return test_results
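
`TestResult` itself is defined earlier in run_e2e_tests.py, outside this hunk; judging from the keyword construction above and the `dataclass` import, it is presumably a small dataclass along these lines, shown here only as an assumption:

from dataclasses import dataclass

@dataclass
class TestResult:
    # Hypothetical reconstruction of the result container used above.
    type: str
    message: str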
