Skip to content

Commit

Permalink
Test a wider array of data
Browse files Browse the repository at this point in the history
Add parquet files directly to the unit-test data directory rather than generating them at test time by using arrow to convert their csv counterparts
  • Loading branch information
bsweger committed Aug 20, 2024
1 parent 31fedd9 commit 2be4ffb
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 39 deletions.
Binary file modified test/integration/data/2024-05-04-teamabc-locations_numeric.parquet
Binary file not shown.
44 changes: 36 additions & 8 deletions test/integration/test_model_output_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,9 @@ def test_missing_model_output_id_mixture(tmpdir, test_file_path):
assert len(null_output_type_rows) == 8


@pytest.mark.parametrize(
"mo_path",
[
"2024-05-04-teamabc-locations_numeric.parquet",
"2024-05-04-teamabc-locations_numeric.csv",
],
)
def test_model_output_parquet_schema(tmpdir, test_file_path, mo_path, expected_model_output_schema):
def test_model_output_csv_schema(tmpdir, test_file_path, expected_model_output_schema):
"""Test the parquet schema on files written by ModelOutputHandler."""
mo_path = "2024-05-04-teamabc-locations_numeric.csv"
mo_full_path = test_file_path.joinpath(mo_path)
output_path = pathlib.Path(tmpdir.mkdir("model-output"))
mo = ModelOutputHandler(pathlib.Path(tmpdir), mo_full_path, output_path)
Expand All @@ -91,3 +85,37 @@ def test_model_output_parquet_schema(tmpdir, test_file_path, mo_path, expected_m
# read the output parquet file
transformed_output = parquet.read_table(output_uri)
assert len(transformed_output) == 23

# location is transformed to string
assert set(transformed_output["location"].to_pylist()) == {"02"}

# output_type_id transformed to string
assert transformed_output["output_type_id"].to_pylist()[0] == "0.01"


def test_model_output_parquet_schema(tmpdir, test_file_path, expected_model_output_schema):
    """Test the parquet schema on files written by ModelOutputHandler.

    Reads a parquet model-output file whose location column is int and whose
    output_type_id column is float, transforms it, and verifies that both
    columns are written back as strings with the expected values.
    """
    mo_path = "2024-05-04-teamabc-locations_numeric.parquet"
    mo_full_path = test_file_path.joinpath(mo_path)
    output_path = pathlib.Path(tmpdir.mkdir("model-output"))
    mo = ModelOutputHandler(pathlib.Path(tmpdir), mo_full_path, output_path)
    output_uri = mo.transform_model_output()

    # check schema of the original model_output parquet file so we can verify that
    # output_type_id and location are correctly transformed to string
    original_schema = pq.read_metadata(mo_full_path).schema.to_arrow_schema()
    # fix: these two type checks were bare expressions and never asserted anything
    assert pa.types.is_float64(original_schema.field("output_type_id").type)
    assert pa.types.is_int64(original_schema.field("location").type)

    actual_schema = pq.read_metadata(output_uri).schema.to_arrow_schema()
    assert expected_model_output_schema.equals(actual_schema)

    # read the output parquet file
    transformed_output = parquet.read_table(output_uri)
    assert len(transformed_output) == 23

    # location is transformed to string (no leading zeroes because original parquet column is int)
    assert set(transformed_output["location"].to_pylist()) == {"2"}

    # output_type_id transformed to string (leading zeroes retained because original parquet column is float)
    assert transformed_output["output_type_id"].to_pylist()[0] == "0.01"
6 changes: 4 additions & 2 deletions test/unit/data/2024-07-07-teamabc-output_type_ids_numeric.csv
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"origin_date","target","horizon","location","output_type","output_type_id","value"
2022-10-08,"wk inc flu hosp",1,"","quantile",0.99,203
2022-10-08,"wk inc flu hosp",1,"02","mean",,173
2022-10-08,"wk inc flu hosp",1,"02","mean",NA,173
2022-10-08,"wk inc flu hosp",1,02,"mean",,173
2022-10-08,"wk inc flu hosp",1,"02","mean",NA,174
2022-10-08,wk inc flu hosp,1,NaN,mean,0.0,175
2022-10-08,wk inc flu hosp,1,string location,mean,null,176
2023-10-21,wk flu hosp rate change,-1,27,pmf,111,0.0018554857403307722
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
origin_date,target,horizon,output_type,output_type_id,value
2022-10-08,wk inc flu hosp,1,quantile,0.99,203
2022-10-08,wk inc flu hosp,1,mean,,173
2022-10-08,wk inc flu hosp,1,mean,NA,173
2022-10-08,wk inc flu hosp,1,mean,NA,174
2022-10-08,wk inc flu hosp,1,mean,0.0,175
2022-10-08,wk inc flu hosp,1,mean,null,176
2023-10-21,wk flu hosp rate change,-1,pmf,111,0.0018554857403307722
Binary file not shown.
68 changes: 40 additions & 28 deletions test/unit/test_model_output.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from cloudpathlib import AnyPath
from hubverse_transform.model_output import ModelOutputHandler
from pyarrow import csv as pyarrow_csv
from pyarrow import fs

# the mocker fixture used throughout is provided by pytest-mock
# see conftest.py for definition of other fixtures (e.g., s3_bucket_name)
Expand Down Expand Up @@ -280,36 +277,44 @@ def test_file_path() -> Path:
def test_location_or_output_type_id_column_schema_csv(tmpdir, test_file_path):
    """Verify that location and output_type_id columns read from CSV become string-typed.

    Uses ModelOutputHandler.read_file() on two CSV fixtures: one with both a
    location and an output_type_id column, and one with no location column.
    """
    hub_path = AnyPath(tmpdir)

    expected_column_names = [
        "origin_date",
        "target",
        "horizon",
        "location",
        "output_type",
        "output_type_id",
        "value",
    ]
    # leading zeroes (e.g. "02", "0.99") are expected to survive because the
    # values are read from CSV as strings rather than parsed as numbers
    expected_location_values = [None, "02", "02", None, "string location", "27"]
    expected_output_type_id_values = ["0.99", None, None, "0.0", None, "111"]

    # case 1: location and output_type_id with numeric value types
    mo_path = test_file_path.joinpath("2024-07-07-teamabc-output_type_ids_numeric.csv")
    mo = ModelOutputHandler(hub_path, mo_path, hub_path)
    pyarrow_table = mo.read_file()
    assert len(pyarrow_table) == 6
    assert pyarrow_table.column_names == expected_column_names
    assert pa.types.is_string(pyarrow_table.schema.field("location").type)
    assert pa.types.is_string(pyarrow_table.schema.field("output_type_id").type)
    assert pyarrow_table["location"].to_pylist() == expected_location_values
    assert pyarrow_table["output_type_id"].to_pylist() == expected_output_type_id_values

    # case 2: no location just output_type_id with numeric value types
    mo_path = test_file_path.joinpath("2024-07-07-teamabc-output_type_ids_numeric_no_location.csv")
    mo = ModelOutputHandler(hub_path, mo_path, hub_path)
    pyarrow_table = mo.read_file()
    assert len(pyarrow_table) == 6
    # the no-location fixture shares the same expected columns minus "location"
    expected_column_names.remove("location")
    assert pyarrow_table.column_names == expected_column_names
    assert pa.types.is_string(pyarrow_table.schema.field("output_type_id").type)
    assert pyarrow_table["output_type_id"].to_pylist() == expected_output_type_id_values


def test_location_or_output_type_id_column_schema_parquet(tmpdir, test_file_path):
hub_path = AnyPath(tmpdir)

# case 1: location and output_type_id with numeric value types
mo_csv_file = test_file_path.joinpath("2024-07-07-teamabc-output_type_ids_numeric.csv")
parquet_file = tmpdir / "2024-07-07-teamabc-output_type_ids_numeric.parquet"
mo_csv_table = pyarrow_csv.read_csv(mo_csv_file)
local = fs.LocalFileSystem()
with local.open_output_stream(str(parquet_file)) as parquet_file_stream:
pq.write_table(mo_csv_table, parquet_file_stream)

mo = ModelOutputHandler(hub_path, parquet_file, hub_path)
pyarrow_table = mo.read_file()
assert pa.types.is_string(pyarrow_table.schema.field("location").type)
assert pa.types.is_string(pyarrow_table.schema.field("output_type_id").type)
assert pyarrow_table.column_names == [
expected_column_names = [
"origin_date",
"target",
"horizon",
Expand All @@ -318,19 +323,26 @@ def test_location_or_output_type_id_column_schema_parquet(tmpdir, test_file_path
"output_type_id",
"value",
]
expected_location_values = [None, "02", "02", None, "string location", "27"]
expected_output_type_id_values = ["0.99", None, None, "0", None, "111"]

# case 1: location and output_type_id with numeric value types
mo_path = test_file_path.joinpath("2024-07-07-teamabc-output_type_ids_numeric.parquet")
mo = ModelOutputHandler(hub_path, mo_path, hub_path)
pyarrow_table = mo.read_file()
assert len(pyarrow_table) == 6
assert pyarrow_table.column_names == expected_column_names
assert pa.types.is_string(pyarrow_table.schema.field("location").type)
assert pa.types.is_string(pyarrow_table.schema.field("output_type_id").type)
assert pyarrow_table["location"].to_pylist() == expected_location_values
assert pyarrow_table["output_type_id"].to_pylist() == expected_output_type_id_values

# case 2: no location just output_type_id with numeric value types
mo_csv_file = test_file_path.joinpath("2024-07-07-teamabc-output_type_ids_numeric_no_location.csv")
parquet_file = tmpdir / "2024-07-07-teamabc-output_type_ids_numeric.parquet"
mo_csv_table = pyarrow_csv.read_csv(mo_csv_file)
local = fs.LocalFileSystem()
with local.open_output_stream(str(parquet_file)) as parquet_file_stream:
pq.write_table(mo_csv_table, parquet_file_stream)

mo = ModelOutputHandler(hub_path, parquet_file, hub_path)
mo_path = test_file_path.joinpath("2024-07-07-teamabc-output_type_ids_numeric_no_location.parquet")
mo = ModelOutputHandler(hub_path, mo_path, hub_path)
pyarrow_table = mo.read_file()
assert len(pyarrow_table) == 6
expected_column_names.remove("location")
assert pyarrow_table.column_names == expected_column_names
assert pa.types.is_string(pyarrow_table.schema.field("output_type_id").type)
assert pyarrow_table.column_names == ["origin_date", "target", "horizon", "output_type", "output_type_id", "value"]
assert len(pyarrow_table.filter(pa.compute.field("output_type_id") == "0.99"))
with pytest.raises(Exception):
pyarrow_table.filter(pa.compute.field("output_type_id") == 0.99)
assert pyarrow_table["output_type_id"].to_pylist() == expected_output_type_id_values

0 comments on commit 2be4ffb

Please sign in to comment.