Skip to content

Commit df03f70

Browse files
committed
Add integration tests
These small integration tests in their current form are a check to make sure we're getting the expected behavior for model-output files that contain a variety of model_output_id_types
1 parent ac542ca commit df03f70

4 files changed

+72
-9
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"origin_date","target","horizon","location","output_type","output_type_id","value"
2+
2022-10-08,"wk inc flu hosp",1,"02","quantile",0.99,203
3+
2022-10-08,"wk inc flu hosp",1,"02","mean",,173
4+
2023-10-21,wk flu hosp rate change,-1,US,pmf,large,0.0018554857403307722
5+
2023-10-21,wk flu hosp rate change,-1,US,pmf,"large",0.0018554857403307722
6+
2023-10-21,wk flu hosp rate change,-1,US,pmf,"large",what if this is a big string with no quotes
7+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"origin_date","target","horizon","location","output_type","output_type_id","value"
2+
2022-10-08,"wk inc flu hosp",1,"02","quantile",0.99,203
3+
2022-10-08,"wk inc flu hosp",1,"02","mean",,173
4+
2023-10-21,wk flu hosp rate change,-1,US,pmf,111,0.0018554857403307722
5+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pathlib
2+
3+
import pyarrow.compute as pc
4+
import pytest
5+
from hubverse_transform.model_output import ModelOutputHandler
6+
from pyarrow import parquet
7+
8+
9+
@pytest.fixture()
10+
def test_file_path() -> pathlib.Path:
11+
"""
12+
Return path to the integration test files.
13+
"""
14+
test_file_path = pathlib.Path(__file__).parent.joinpath('data')
15+
return test_file_path
16+
17+
18+
def test_missing_model_output_id_numeric(tmpdir, test_file_path):
19+
"""Test behavior of model_output_id columns when there are a mix of numeric and missing output_type_ids."""
20+
output_dir = str(tmpdir.mkdir('model-output'))
21+
file_path = test_file_path.joinpath('2024-07-07-teamabc-output_type_ids_numeric.csv')
22+
mo = ModelOutputHandler(file_path, output_dir)
23+
output_uri = mo.transform_model_output()
24+
25+
# read the output parquet file
26+
transformed_output = parquet.read_table(output_uri)
27+
28+
# when the rest of the model_output_types are numeric, the empty one should be null
29+
expr = pc.field('output_type_id').is_null()
30+
null_output_type_rows = transformed_output.filter(expr)
31+
assert len(null_output_type_rows) == 1
32+
33+
34+
def test_missing_model_output_id_mixture(tmpdir, test_file_path):
35+
"""Test behavior of model_output_id columns when there are a mix of numeric, string, and missing output_type_ids."""
36+
output_dir = str(tmpdir.mkdir('model-output'))
37+
file_path = test_file_path.joinpath('2024-07-07-teamabc-output_type_ids_mixed.csv')
38+
mo = ModelOutputHandler(file_path, output_dir)
39+
output_uri = mo.transform_model_output()
40+
41+
# read the output parquet file
42+
transformed_output = parquet.read_table(output_uri)
43+
44+
# where there are a mix of string and numeric output_type_ids, the column is cast to string
45+
# and, therefore, missing values should be empty strings
46+
expr = pc.field('output_type_id').is_null()
47+
null_output_type_rows = transformed_output.filter(expr)
48+
assert len(null_output_type_rows) == 0
49+
expr = pc.field('output_type_id') == ''
50+
empty_output_type_rows = transformed_output.filter(expr)
51+
assert len(empty_output_type_rows) == 1

test/test_model_output.py test/unit/test_model_output.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def model_output_data() -> list[dict[str, Any]]:
2929
Fixture that returns a list of model-output data representing multiple output types.
3030
This fixture is used as input for other fixtures that generate temporary .csv and .parquest files for testing.
3131
"""
32+
3233
model_output_fieldnames = [
3334
'reference_date',
3435
'location',
@@ -41,8 +42,6 @@ def model_output_data() -> list[dict[str, Any]]:
4142
model_output_list = [
4243
['2420-01-01', 'US', '1 light year', 'hospitalizations', 'quantile', 0.5, 62],
4344
['2420-01-01', 'US', '1 light year', 'hospitalizations', 'quantile', 0.75, 50.1],
44-
['2420-01-01', '02', 3, 'hospitalizations', 'mean', 'NA', 11],
45-
['2420-01-01', '03', 3, 'hospitalizations', 'mean', 'NA', 'a string value for some reason'],
4645
['2420-01-01', '03', 3, 'hospitalizations', 'mean', None, 33],
4746
['1999-12-31', 'US', 'last month', 'hospitalizations', 'pmf', 'large_increase', 2.597827508665773e-9],
4847
]
@@ -203,25 +202,26 @@ def test_added_column_values(model_output_table):
203202
def test_read_file_csv(test_csv_file, model_output_table):
204203
mo = ModelOutputHandler(test_csv_file, 'mock:fake-output-uri')
205204
pyarrow_table = mo.read_file()
206-
assert len(pyarrow_table) == 6
205+
assert len(pyarrow_table) == 4
207206

208207
# output_type_id should retain the value from the .csv file, even when the value is empty or "NA"
208+
# NA values generate
209209
output_type_id_col = pyarrow_table.column('output_type_id')
210210
assert str(output_type_id_col[0]) == '0.5'
211-
assert str(output_type_id_col[2]) == 'NA'
212-
assert str(output_type_id_col[4]) == ''
211+
assert str(output_type_id_col[2]) == ''
212+
assert str(output_type_id_col[3]) == 'large_increase'
213213

214214

215215
def test_read_file_parquet(test_parquet_file, model_output_table):
216216
mo = ModelOutputHandler(test_parquet_file, 'mock:fake-output-uri')
217217
pyarrow_table = mo.read_file()
218-
assert len(pyarrow_table) == 6
218+
assert len(pyarrow_table) == 4
219219

220220
# output_type_id should retain the value from the .csv file, even when the value is empty or "NA"
221221
output_type_id_col = pyarrow_table.column('output_type_id')
222222
assert str(output_type_id_col[0]) == '0.5'
223-
assert str(output_type_id_col[2]) == 'NA'
224-
assert str(output_type_id_col[4]) == ''
223+
assert str(output_type_id_col[2]) == ''
224+
assert str(output_type_id_col[3]) == 'large_increase'
225225

226226

227227
def test_write_parquet(tmpdir, model_output_table):
@@ -235,7 +235,7 @@ def test_write_parquet(tmpdir, model_output_table):
235235
assert actual_output_file_path == expected_output_file_path
236236

237237

238-
def test_transform_model_output(test_csv_file, tmpdir):
238+
def test_transform_model_output_path(test_csv_file, tmpdir):
239239
output_dir = str(tmpdir.mkdir('model-output'))
240240
mo = ModelOutputHandler(test_csv_file, output_dir)
241241
output_uri = mo.transform_model_output()

0 commit comments

Comments
 (0)