Clean up ModelOutputHandler attributes and methods
As preparation for adding a new method to the ModelOutputHandler
class, incorporate the convention of using the _ prefix to denote
private attributes and methods. Also change some of the class's
public attributes (and display the cleaned version of URIs instead
of the non-clean version)
bsweger committed Aug 30, 2024
1 parent a52e249 commit f8c6b24
Showing 5 changed files with 158 additions and 76 deletions.
48 changes: 42 additions & 6 deletions README.md
@@ -13,18 +13,20 @@ To install this package:
pip install git+https://github.com/hubverse-org/hubverse-transform.git
```

### Using with a model-output file in a hub on the local filesystem

Sample usage:

```python
from pathlib import Path
from hubverse_transform.model_output import ModelOutputHandler

# to use with a local model-output file

mo = ModelOutputHandler(
'~/code/hubverse-cloud/model-output/UMass-flusion/2023-10-14-UMass-flusion.csv',
'/.'

Path('~/code/hubverse-cloud'),
Path('raw/2024-05-06-UMass-flusion.csv'),
Path('~/code/hubverse-cloud/model-output')
)

# read the original model-output file into an Arrow table
original_file = mo.read_file()

@@ -36,6 +38,28 @@ transformed_data = mo.add_columns(original_file)
# mo.write(transformed_data)
```

### Using with a model-output file in an S3-based hub

Sample usage:

```python
from hubverse_transform.model_output import ModelOutputHandler

mo = ModelOutputHandler.from_s3(
'example-complex-forecast-hub', # S3 bucket name
'raw/model-output/Flusight-baseline/2022-10-22-Flusight-baseline.csv', # S3 key
)

# read the original model-output file into an Arrow table
original_file = mo.read_file()

# add new columns to the original model_output data
transformed_data = mo.add_columns(original_file)

# write transformed data to parquet (requires S3 write access)
mo.write(transformed_data)
```
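By default, `from_s3` expects incoming files under the hub's `raw` prefix (the `origin_prefix` parameter). If a hub stages raw files elsewhere, the prefix can be passed explicitly. A minimal sketch, using a hypothetical bucket and key:

```python
from hubverse_transform.model_output import ModelOutputHandler

# hypothetical bucket and key, for illustration only
mo = ModelOutputHandler.from_s3(
    'my-hub-bucket',                                                 # S3 bucket name
    'staging/model-output/TeamA-model/2024-05-04-TeamA-model.csv',  # S3 key
    origin_prefix='staging',
)
```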

Sample output of the original and transformed data:
```
In [31]: original_file.take([0,1])
@@ -175,4 +199,16 @@ To package the hubverse_transform code for deployment to the `hubverse-transform
3. From the root of this project, run the deploy script:
```bash
source deploy_lambda.sh
```
```

### Re-processing model-output files that have already been transformed

If you need to re-run the hubverse-transform function on model-output files that have already been uploaded to S3,
you can use the `lambda_retrigger_model_output_add.py` script in this repo's `faas/` folder.
This manual action should be done with care but can be handy if data needs to be re-processed (in the event of a
hubverse-transform bug fix, for example). The script works by updating the S3 metadata for every file in the
`raw/model-output` folder of the hub's S3 bucket. The metadata update then triggers the lambda function that runs
when new incoming model-output files are detected.

**Note:** You will need write access to the hub's S3 bucket to use this script.
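A rough sketch of that metadata-update approach (not the script itself; it assumes `boto3` and a hypothetical bucket name) copies each raw object onto itself with replaced metadata, which generates the S3 object event that re-triggers the transform lambda:

```python
import boto3

s3 = boto3.client('s3')
bucket = 'example-hub-bucket'  # hypothetical bucket name

# Copy every raw model-output object onto itself with replaced metadata;
# the copy produces a new S3 object event, re-triggering the transform lambda.
paginator = s3.get_paginator('list_objects_v2')
for page in paginator.paginate(Bucket=bucket, Prefix='raw/model-output/'):
    for obj in page.get('Contents', []):
        s3.copy_object(
            Bucket=bucket,
            Key=obj['Key'],
            CopySource={'Bucket': bucket, 'Key': obj['Key']},
            Metadata={'hubverse-retrigger': 'true'},
            MetadataDirective='REPLACE',
        )
```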
98 changes: 55 additions & 43 deletions src/hubverse_transform/model_output.py
@@ -1,7 +1,9 @@
# mypy: disable-error-code="operator,attr-defined"

import logging
import os
import re
from urllib.parse import quote
from urllib.parse import quote, unquote

import pyarrow as pa
import pyarrow.parquet as pq
@@ -14,7 +16,7 @@
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.setLevel(logging.WARNING)


class ModelOutputHandler:
@@ -23,11 +25,17 @@ class ModelOutputHandler:
Attributes
----------
input_uri: str
URI of the incoming model-output file to be transformed.
file_name : str
Name of the incoming model-output file to be transformed.
file_type : str
Type of file to be transformed
(.parquet, .pqt, .csv currently supported).
fs: str
The filesystem type of the incoming model-output file.
output_uri: str
URI of the transformed model-output file.
round_id : str
Name of the round_id associated with the model-output file.
model_id : int
@@ -48,43 +56,49 @@ def __init__(self, hub_path: os.PathLike, mo_path: os.PathLike, output_path: os.
Where the transformed model-output file will be saved.
"""

input_path = hub_path / mo_path # type: ignore
sanitized_input_uri = self.sanitize_uri(input_path)
input_filesystem = fs.FileSystem.from_uri(sanitized_input_uri)
self.fs_input = input_filesystem[0]
self.input_file = input_filesystem[1]

output_filesystem = fs.FileSystem.from_uri(self.sanitize_uri(output_path))
self.fs_output = output_filesystem[0]
self.output_path = output_filesystem[1]

# get file name and type from input file (use the sanitized version)
file_path = AnyPath(self.input_file)
self.file_name = file_path.stem # file name without extension
self.file_type = file_path.suffix
# create sanitized input/output URIs
sanitized_input_path = AnyPath(self._sanitize_uri(hub_path / mo_path))
self.input_uri = str(sanitized_input_path)
self.file_name = sanitized_input_path.stem
self.file_type = sanitized_input_path.suffix
sanitized_output_path = AnyPath(self._sanitize_uri(output_path / f"{self.file_name}.parquet"))
self.output_uri = str(sanitized_output_path)

# create pyarrow filesystem objects to read and write files
input_filesystem = fs.FileSystem.from_uri(self.input_uri)
self._fs_input = input_filesystem[0]
self._input_file = input_filesystem[1]
output_filesystem = fs.FileSystem.from_uri(self.output_uri)
self._fs_output = output_filesystem[0]
self._output_file = output_filesystem[1]

# parse model-output file name into individual parts
file_parts = self._parse_file(self.file_name)
self.round_id = file_parts["round_id"]
self.model_id = unquote(file_parts["model_id"])

# handle case when the function is triggered without a file
# (e.g., if someone manually creates a folder in an S3 bucket)
if not input_path.suffix:
msg = "Input file has no extension"
self.raise_invalid_file_warning(str(input_path), msg)
# filesystem must be supported
if (file_system := self._fs_input.type_name) not in ["local", "s3"]:
raise ValueError(f"Unsupported filesystem: {file_system}")
else:
self.fs = file_system

# TODO: Add other input file types as needed
# file must be a supported type
if self.file_type not in [".csv", ".parquet", ".pqt"]:
msg = f"Input file type {self.file_type} is not supported"
self.raise_invalid_file_warning(str(input_path), msg)
self._raise_invalid_file_warning(self.input_uri, msg)

# Parse model-output file name into individual parts
# (round_id, model_id)
file_parts = self.parse_file(self.file_name)
self.round_id = file_parts["round_id"]
self.model_id = file_parts["model_id"]
# handle case when object creation is triggered without a file
# (e.g., if someone manually creates a folder in an S3 bucket)
if not mo_path.suffix:
msg = "Input file has no extension"
self._raise_invalid_file_warning(self.input_uri, msg)

def __repr__(self):
return f"ModelOutputHandler('{self.fs_input.type_name}', '{self.input_file}', '{self.output_path}')"
return f"ModelOutputHandler('{self.input_uri}')"

def __str__(self):
return f"Handle model-output data transforms for {self.input_file}."
return f"Handle model-output data transforms for {self._input_file}."

@classmethod
def from_s3(cls, bucket_name: str, s3_key: str, origin_prefix: str = "raw") -> "ModelOutputHandler":
@@ -93,7 +107,7 @@ def from_s3(cls, bucket_name: str, s3_key: str, origin_prefix: str = "raw") -> "
Use this method to instantiate a ModelOutputHandler object for
model-output files stored in an S3 bucket (for example, when
transformations are invoked via an AWS lambda function).
transitions are invoked via an AWS lambda function).
Parameters
----------
@@ -140,7 +154,7 @@ def from_s3(cls, bucket_name: str, s3_key: str, origin_prefix: str = "raw") -> "

return cls(s3_bucket_path, s3_mo_path, s3_output_path) # type: ignore

def raise_invalid_file_warning(self, path: str, msg: str) -> None:
def _raise_invalid_file_warning(self, path: str, msg: str) -> None:
"""Raise a warning if the class was instantiated with an invalid file."""

logger.warning(
@@ -151,7 +165,7 @@ def raise_invalid_file_warning(self, path: str, msg: str) -> None:
)
raise UserWarning(msg)

def sanitize_uri(self, path: os.PathLike, safe=":/") -> str:
def _sanitize_uri(self, path: os.PathLike, safe=":/") -> str:
"""Sanitize URIs for use with pyarrow's filesystem."""

# remove spaces at the end of a filename (e.g., my-model-output .csv) and
@@ -164,7 +178,7 @@ def sanitize_uri(self, path: os.PathLike, safe=":/") -> str:

return clean_uri

def parse_file(cls, file_name: str) -> dict:
def _parse_file(cls, file_name: str) -> dict:
"""Parse model-output file name into individual parts."""

# In practice, Hubverse hubs are formatting round_id as dates in YYYY-MM-DD format.
@@ -194,10 +208,10 @@ def parse_file(cls, file_name: str) -> dict:
def read_file(self) -> pa.table:
"""Read model-output file into PyArrow table."""

logger.info(f"Reading file: {self.input_file}")
logger.info(f"Reading file: {self.input_uri}")

if self.file_type == ".csv":
model_output_file = self.fs_input.open_input_stream(self.input_file)
model_output_file = self._fs_input.open_input_stream(self._input_file)
# normalize incoming missing data values to null, regardless of data type
options = csv.ConvertOptions(
null_values=["na", "NA", "", " ", "null", "Null", "NaN", "nan"],
@@ -208,12 +222,12 @@ def read_file(self) -> pa.table:
model_output_table = csv.read_csv(model_output_file, convert_options=options)
else:
# temp fix: force location and output_type_id columns to string
schema_new = pq.read_schema(self.input_file)
schema_new = pq.read_schema(self._input_file)
for field_name in ["location", "output_type_id"]:
field_idx = schema_new.get_field_index(field_name)
if field_idx >= 0:
schema_new = schema_new.set(field_idx, pa.field(field_name, pa.string()))
model_output_file = self.fs_input.open_input_file(self.input_file)
model_output_file = self._fs_input.open_input_file(self._input_file)
model_output_table = pq.read_table(model_output_file, schema=schema_new)

return model_output_table
@@ -242,14 +256,12 @@ def add_columns(self, model_output_table: pa.table) -> pa.table:
def write_parquet(self, updated_model_output_table: pa.table) -> str:
"""Write transformed model-output table to parquet file."""

transformed_file_path = f"{self.output_path}/{self.file_name}.parquet"

with self.fs_output.open_output_stream(transformed_file_path) as parquet_file:
with self._fs_output.open_output_stream(self._output_file) as parquet_file:
pq.write_table(updated_model_output_table, parquet_file)

logger.info(f"Finished writing parquet file: {transformed_file_path}")
logger.info(f"Finished writing parquet file: {self.output_uri}")

return transformed_file_path
return self.output_uri

def transform_model_output(self) -> str:
"""Transform model-output data and write to parquet file."""
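With the `_`-prefixed internals above, downstream code needs only the public surface (`from_s3`, `read_file`, `add_columns`, `write_parquet`, `transform_model_output`). A minimal end-to-end sketch, assuming a hypothetical bucket and key:

```python
from hubverse_transform.model_output import ModelOutputHandler

# hypothetical bucket and key, for illustration only
mo = ModelOutputHandler.from_s3(
    'example-hub-bucket',
    'raw/model-output/TeamA-model/2024-05-04-TeamA-model.csv',
)

# read the file, add columns, and write the parquet output in one call;
# the return value is the location of the transformed file
output_uri = mo.transform_model_output()
```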
2 changes: 1 addition & 1 deletion test/conftest.py
@@ -20,7 +20,7 @@ def s3_bucket_name() -> str: # type: ignore
Ultimately, this difficulty probably warrants a closer look at the code structure, because it shouldn't be this hard to test!
"""

bucket_list = ["hubverse-assets", "test-bucket", "bucket123"]
bucket_list = ["hubverse", "hubverse-assets", "test-bucket", "bucket123"]
for bucket in bucket_list:
try:
fs.FileSystem.from_uri(f"s3://{bucket}")
12 changes: 12 additions & 0 deletions test/integration/test_model_output_integration.py
@@ -71,6 +71,18 @@ def test_missing_model_output_id_mixture(tmpdir, test_file_path):
assert len(null_output_type_rows) == 8


def test_s3_transform():
"""Test the entire transform operation on a model output file stored in S3."""
bucket = "hubutils"
key = "testhubs/simple/model-output/baseline/2022-10-08-simple_hub-baseline.csv"
mo = ModelOutputHandler.from_s3(bucket, key, "testhubs")

# Until we incorporate S3 mocks, we'll test everything up until the write operation,
# which will fail on a permissions error
with pytest.raises(OSError):
mo.transform_model_output()


def test_model_output_csv_schema(tmpdir, test_file_path, expected_model_output_schema):
"""Test the parquet schema on files written by ModelOutputHandler."""
mo_path = "2024-05-04-teamabc-locations_numeric.csv"
Expand Down
