Skip to content

Commit 3b930f1

Browse files
committed
Make sure incoming file URIs are encoded
This ensures we can accommodate spaces in file names (and other special characters).
1 parent 70e7257 commit 3b930f1

File tree

3 files changed

+73
-6
lines changed

3 files changed

+73
-6
lines changed

src/hubverse_transform/model_output.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import pathlib
33
import re
4+
from urllib.parse import quote
45

56
import pyarrow as pa
67
import pyarrow.parquet as pq
@@ -17,11 +18,11 @@
1718

1819
class ModelOutputHandler:
1920
def __init__(self, input_uri: str, output_uri: str):
20-
input_filesystem = fs.FileSystem.from_uri(input_uri)
21+
input_filesystem = fs.FileSystem.from_uri(self.sanitize_uri(input_uri))
2122
self.fs_input = input_filesystem[0]
2223
self.input_file = input_filesystem[1]
2324

24-
output_filesystem = fs.FileSystem.from_uri(output_uri)
25+
output_filesystem = fs.FileSystem.from_uri(self.sanitize_uri(output_uri))
2526
self.fs_output = output_filesystem[0]
2627
self.output_path = output_filesystem[1]
2728

@@ -66,6 +67,22 @@ def from_s3(cls, bucket_name: str, s3_key: str, origin_prefix: str = "raw") -> "
6667

6768
return cls(s3_input_uri, s3_output_uri)
6869

70+
def sanitize_uri(self, uri: str, safe=":/") -> str:
    """Sanitize a URI (or plain path) for use with pyarrow's filesystem.

    Leading/trailing whitespace is removed from the whole string and from
    the stem of the final path component (e.g., "my-model-output .csv"
    becomes "my-model-output.csv"). The result is then percent-encoded so
    that remaining spaces and other special characters are URI-safe.

    Args:
        uri: The URI or path to clean.
        safe: Characters that are left unencoded by ``quote``. The default
            keeps the scheme delimiter and path separators intact.

    Returns:
        The cleaned, percent-encoded URI string.
    """
    cleaned = str(uri).strip()

    # The earlier pathlib-based implementation dropped trailing separators;
    # keep doing that, but leave root- or scheme-only inputs such as "/" or
    # "s3://" untouched.
    trimmed = cleaned.rstrip("/")
    if trimmed and not trimmed.endswith(":"):
        cleaned = trimmed

    # Work on the string directly: a pathlib round trip collapses the "//"
    # that follows a URI scheme ("s3://bucket" -> "s3:/bucket") and rewrites
    # separators on Windows.
    parent, sep, name = cleaned.rpartition("/")

    # Strip whitespace around the stem of the final component only, so that
    # directories that happen to contain the same text are not modified.
    stem, dot, suffix = name.rpartition(".")
    if dot and stem:
        name = f"{stem.strip()}{dot}{suffix}"
    else:
        # no extension (or a dot-file): the whole component is the stem
        name = name.strip()

    # encode the cleaned path (for example, any remaining spaces) so we can
    # safely use it as a URI
    return quote(f"{parent}{sep}{name}", safe=safe)
85+
6986
def parse_file(cls, file_name: str) -> dict:
7087
"""Parse model-output file name into individual parts."""
7188

@@ -84,13 +101,12 @@ def parse_file(cls, file_name: str) -> dict:
84101
model_id_split = re.split(rf"{round_id}[-_]*", file_name)
85102
if not model_id_split or len(model_id_split) <= 1 or not model_id_split[-1]:
86103
raise ValueError(f"Unable to get model_id from file name {file_name}.")
87-
model_id = "".join(model_id_split[-1].split())
104+
model_id = model_id_split[-1].strip()
88105

89106
file_parts = {}
90107
file_parts["round_id"] = round_id
91108
file_parts["model_id"] = model_id
92109

93-
# TODO: why so many logs?
94110
logger.info(f"Parsed model-output filename: {file_parts}")
95111
return file_parts
96112

test/integration/test_model_output_integration.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_file_path() -> pathlib.Path:
1818
def test_missing_model_output_id_numeric(tmpdir, test_file_path):
1919
"""Test behavior of model_output_id columns when there are a mix of numeric and missing output_type_ids."""
2020
output_dir = str(tmpdir.mkdir("model-output"))
21-
file_path = test_file_path.joinpath("2024-07-07-teamabc-output_type_ids_numeric.csv")
21+
file_path = str(test_file_path.joinpath("2024-07-07-teamabc-output_type_ids_numeric.csv"))
2222
mo = ModelOutputHandler(file_path, output_dir)
2323
output_uri = mo.transform_model_output()
2424

@@ -34,7 +34,7 @@ def test_missing_model_output_id_numeric(tmpdir, test_file_path):
3434
def test_missing_model_output_id_mixture(tmpdir, test_file_path):
3535
"""Test behavior of model_output_id columns when there are a mix of numeric, string, and missing output_type_ids."""
3636
output_dir = str(tmpdir.mkdir("model-output"))
37-
file_path = test_file_path.joinpath("2024-07-07-teamabc-output_type_ids_mixed.csv")
37+
file_path = str(test_file_path.joinpath("2024-07-07-teamabc-output_type_ids_mixed.csv"))
3838
mo = ModelOutputHandler(file_path, output_dir)
3939
output_uri = mo.transform_model_output()
4040

test/unit/test_model_output.py

+51
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from pyarrow import csv as pyarrow_csv
1010
from pyarrow import fs
1111

12+
# note: the mocker fixture used throughout is provided by pytest-mock
13+
1214

1315
@pytest.fixture()
1416
def model_output_table() -> pa.Table:
@@ -118,6 +120,55 @@ def test_parse_file(file_uri, expected_round_id, expected_model_id):
118120
assert mo.model_id == expected_model_id
119121

120122

123+
@pytest.mark.parametrize(
    (
        "input_uri",
        "output_uri",
        "expected_input_file",
        "expected_output_path",
        "expected_file_name",
        "expected_model_id",
    ),
    [
        (
            "mock:bucket123/raw/prefix1/prefix 2/2420-01-01-team-model name with spaces.csv",
            "mock:bucket123/prefix1/prefix 2",
            "bucket123/raw/prefix1/prefix 2/2420-01-01-team-model name with spaces.csv",
            "bucket123/prefix1/prefix 2",
            "2420-01-01-team-model name with spaces",
            "team-model name with spaces",
        ),
        (
            "mock:bucket1.2.3/raw/prefix1/prefix 2/2420-01-01-team-model.name.csv",
            "mock:bucket123/prefix1/~prefix 2",
            "bucket1.2.3/raw/prefix1/prefix 2/2420-01-01-team-model.name.csv",
            "bucket123/prefix1/~prefix 2",
            "2420-01-01-team-model.name",
            "team-model.name",
        ),
        (
            "mock:raw/prefix 1/prefix2/2420-01-01-spáces at end .csv",
            "mock:prefix 1/prefix2",
            "raw/prefix 1/prefix2/2420-01-01-spáces at end.csv",
            "prefix 1/prefix2",
            "2420-01-01-spáces at end",
            "spáces at end",
        ),
        (
            "mock:a space/prefix 1/prefix2/2420-01-01 look ma no hyphens.csv",
            "mock:prefix 1/prefix 🐍",
            "a space/prefix 1/prefix2/2420-01-01 look ma no hyphens.csv",
            "prefix 1/prefix 🐍",
            "2420-01-01 look ma no hyphens",
            "look ma no hyphens",
        ),
    ],
)
def test_new_instance_special_characters(
    input_uri, output_uri, expected_input_file, expected_output_path, expected_file_name, expected_model_id
):
    """Check that a handler built from URIs with special characters parses correctly.

    The cases cover spaces, dots, a tilde, an accented character, and an
    emoji in the directory, file name, and key portions of the URIs.
    """
    handler = ModelOutputHandler(input_uri, output_uri)

    assert handler.input_file == expected_input_file
    assert handler.output_path == expected_output_path
    assert handler.file_name == expected_file_name
    assert handler.model_id == expected_model_id
170+
171+
121172
@pytest.mark.parametrize(
122173
"s3_key, expected_input_uri, expected_output_uri",
123174
[

0 commit comments

Comments
 (0)