Commit d15f761

Merge pull request #5 from Infectious-Disease-Modeling-Hubs/bsweger/cleanup

Cleanup styles and dependencies
bsweger authored Apr 30, 2024
2 parents 14233c5 + 948be19 commit d15f761
Showing 6 changed files with 177 additions and 153 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
+.DS_Store
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
91 changes: 59 additions & 32 deletions pdm.lock

Some generated files are not rendered by default.

7 changes: 1 addition & 6 deletions pyproject.toml
@@ -12,15 +12,14 @@ classifiers = [

 dependencies = [
     'boto3',
-    'pyarrow',
+    "pyarrow>=16.0.0",
 ]
 authors = [
     {name = "Becky Sweger", email = "[email protected]"},
 ]

 [project.optional-dependencies]
 dev = [
-    'coverage',
     'mypy',
     'pytest',
     'pytest-mock',
@@ -35,10 +34,6 @@ build-backend = "setuptools.build_meta"
 line-length = 120
 lint.extend-select = ['I']

-[tool.ruff.format]
-quote-style = 'single'
-
-
 [tool.pdm]
 distribution = true
 [tools.setuptools]
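Removing the [tool.ruff.format] quote-style = 'single' override reverts ruff format to its default double-quote style, which is what flips every string literal in src/hubverse_transform/model_output.py below from single to double quotes. A minimal before/after sketch (hypothetical module name):

    # demo.py, as formatted while quote-style = 'single' was in effect
    greeting = 'hello'

    # the same line after running `ruff format` with the override removed
    # (ruff's default quote-style is "double")
    greeting = "hello"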
56 changes: 28 additions & 28 deletions src/hubverse_transform/model_output.py
@@ -9,7 +9,7 @@
 # Log to stdout
 logger = logging.getLogger(__name__)
 handler = logging.StreamHandler()
-formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')
+formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p")
 handler.setFormatter(formatter)
 logger.addHandler(handler)
 logger.setLevel(logging.INFO)
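For reference, a minimal standalone sketch of what this logging configuration emits (logger name and message are hypothetical, timestamp illustrative):

    import logging

    logger = logging.getLogger("hubverse_transform.model_output")
    handler = logging.StreamHandler()
    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %I:%M:%S %p",
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

    logger.info("Reading file: example.csv")
    # 04/30/2024 01:23:45 PM - INFO - hubverse_transform.model_output - Reading file: example.csv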
@@ -31,24 +31,24 @@ def __init__(self, input_uri: str, output_uri: str):
         self.file_type = path.suffix

         # TODO: Add other input file types as needed
-        if self.file_type not in ['.csv', '.parquet']:
-            raise NotImplementedError(f'Unsupported file type: {path.suffix}')
+        if self.file_type not in [".csv", ".parquet"]:
+            raise NotImplementedError(f"Unsupported file type: {path.suffix}")

         # Parse model-output file name into individual parts
         # (round_id, team, model)
         file_parts = self.parse_file(self.file_name)
-        self.round_id = file_parts['round_id']
-        self.team = file_parts['team']
-        self.model = file_parts['model']
+        self.round_id = file_parts["round_id"]
+        self.team = file_parts["team"]
+        self.model = file_parts["model"]

     def __repr__(self):
         return f"ModelOutputHandler('{self.fs_input.type_name}', '{self.input_file}', '{self.output_path}')"

     def __str__(self):
-        return f'Handle model-output data transforms for {self.input_file}.'
+        return f"Handle model-output data transforms for {self.input_file}."

     @classmethod
-    def from_s3(cls, bucket_name: str, s3_key: str, origin_prefix: str = 'raw') -> 'ModelOutputHandler':
+    def from_s3(cls, bucket_name: str, s3_key: str, origin_prefix: str = "raw") -> "ModelOutputHandler":
         """Instantiate ModelOutputHandler for file on AWS S3."""

         # ModelOutputHandler is designed to operate on original versions of model-output
@@ -57,13 +57,13 @@ def from_s3(cls, bucket_name: str, s3_key: str, origin_prefix: str = 'raw') -> 'ModelOutputHandler':
         # model-outputs.
         path = pathlib.Path(s3_key)
         if path.parts[0] != origin_prefix:
-            raise ValueError(f'Model output path {s3_key} does not begin with {origin_prefix}.')
+            raise ValueError(f"Model output path {s3_key} does not begin with {origin_prefix}.")

-        s3_input_uri = f's3://{bucket_name}/{s3_key}'
+        s3_input_uri = f"s3://{bucket_name}/{s3_key}"

         # Destination path = origin path w/o the origin prefix
         destination_path = str(path.relative_to(origin_prefix).parent)
-        s3_output_uri = f's3://{bucket_name}/{destination_path}'
+        s3_output_uri = f"s3://{bucket_name}/{destination_path}"

         return cls(s3_input_uri, s3_output_uri)
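A usage sketch of from_s3 (bucket and key are hypothetical); the output URI is the input path with the origin prefix stripped:

    handler = ModelOutputHandler.from_s3("example-hub-bucket", "raw/flu/2024-04-30-team-model.csv")
    # reads from:   s3://example-hub-bucket/raw/flu/2024-04-30-team-model.csv
    # writes under: s3://example-hub-bucket/flu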
@@ -74,34 +74,34 @@ def parse_file(cls, file_name: str) -> dict:
         # Code below assumes [round_id likely yyyy-mm-dd]-[team]-[model] AND
         # that there are no hyphens in team or model names
         # https://github.com/orgs/Infectious-Disease-Modeling-Hubs/discussions/10
-        file_name_split = file_name.rsplit('-', 2)
+        file_name_split = file_name.rsplit("-", 2)

         if (
-            file_name.count('-') > 4
+            file_name.count("-") > 4
             or len(file_name_split) != 3
-            or not re.match(r'^\d{4}-\d{2}-\d{2}$', file_name_split[0])
+            or not re.match(r"^\d{4}-\d{2}-\d{2}$", file_name_split[0])
         ):
-            raise ValueError(f'File name {file_name} not in expected model-output format: yyyy-mm-dd-team-model.')
+            raise ValueError(f"File name {file_name} not in expected model-output format: yyyy-mm-dd-team-model.")

         file_parts = {}
-        file_parts['round_id'] = file_name_split[0]
-        file_parts['team'] = file_name_split[1]
-        file_parts['model'] = file_name_split[2]
+        file_parts["round_id"] = file_name_split[0]
+        file_parts["team"] = file_name_split[1]
+        file_parts["model"] = file_name_split[2]

         # TODO: why so many logs?
-        logger.info(f'Parsed model-output filename: {file_parts}')
+        logger.info(f"Parsed model-output filename: {file_parts}")
         return file_parts

     def read_file(self) -> pa.table:
         """Read model-output file into PyArrow table."""

-        logger.info(f'Reading file: {self.input_file}')
+        logger.info(f"Reading file: {self.input_file}")

-        if self.file_type == '.csv':
+        if self.file_type == ".csv":
             model_output_file = self.fs_input.open_input_stream(self.input_file)
             # normalize incoming missing data values to null, regardless of data type
             options = csv.ConvertOptions(
-                null_values=['na', 'NA', '', ' ', 'null', 'Null', 'NaN', 'nan'], strings_can_be_null=True
+                null_values=["na", "NA", "", " ", "null", "Null", "NaN", "nan"], strings_can_be_null=True
             )
             model_output_table = csv.read_csv(model_output_file, convert_options=options)
         else:
@@ -116,16 +116,16 @@ def add_columns(self, model_output_table: pa.table) -> pa.table:
         """Add model-output metadata columns to PyArrow table."""

         num_rows = model_output_table.num_rows
-        logger.info(f'Adding columns to table with {num_rows} rows')
+        logger.info(f"Adding columns to table with {num_rows} rows")

         # Create a dictionary of the existing columns
         existing_columns = {name: model_output_table[name] for name in model_output_table.column_names}

         # Create arrays that we'll use to append columns to the table
         new_columns = {
-            'round_id': pa.array([self.round_id for i in range(0, num_rows)]),
-            'team_abbr': pa.array([self.team for i in range(0, num_rows)]),
-            'model_abbr': pa.array([self.model for i in range(0, num_rows)]),
+            "round_id": pa.array([self.round_id for i in range(0, num_rows)]),
+            "team_abbr": pa.array([self.team for i in range(0, num_rows)]),
+            "model_abbr": pa.array([self.model for i in range(0, num_rows)]),
         }

         # Merge the new columns with the existing columns
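The merge step itself falls below the fold of this diff; the following standalone sketch shows the pattern on a toy table, assuming the method rebuilds the table from the two column dictionaries:

    import pyarrow as pa

    table = pa.table({"location": ["MA", "VT"], "value": [1.5, 2.0]})
    num_rows = table.num_rows

    existing_columns = {name: table[name] for name in table.column_names}
    new_columns = {
        "round_id": pa.array(["2024-04-30"] * num_rows),
        "team_abbr": pa.array(["team"] * num_rows),
        "model_abbr": pa.array(["model"] * num_rows),
    }

    # one plausible merge: build a new table from both dicts, existing columns first
    updated = pa.table({**existing_columns, **new_columns})
    # updated.column_names
    # -> ['location', 'value', 'round_id', 'team_abbr', 'model_abbr']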
@@ -137,12 +137,12 @@ def add_columns(self, model_output_table: pa.table) -> pa.table:
     def write_parquet(self, updated_model_output_table: pa.table) -> str:
         """Write transformed model-output table to parquet file."""

-        transformed_file_path = f'{self.output_path}/{self.file_name}.parquet'
+        transformed_file_path = f"{self.output_path}/{self.file_name}.parquet"

         with self.fs_output.open_output_stream(transformed_file_path) as parquet_file:
             pq.write_table(updated_model_output_table, parquet_file)

-        logger.info(f'Finished writing parquet file: {transformed_file_path}')
+        logger.info(f"Finished writing parquet file: {transformed_file_path}")

         return transformed_file_path
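Putting the methods in this diff together, a hedged end-to-end sketch (hypothetical bucket, prefix, and file name):

    handler = ModelOutputHandler.from_s3("example-hub-bucket", "raw/flusight/2024-04-30-team-model.csv")
    table = handler.read_file()           # CSV -> PyArrow table, NA-ish strings nulled
    updated = handler.add_columns(table)  # appends round_id, team_abbr, model_abbr
    path = handler.write_parquet(updated)
    # path is something like s3://example-hub-bucket/flusight/2024-04-30-team-model.parquet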
(Diffs for the remaining 2 changed files are not shown.)