Commit 3caf608

Check Parquet files in together-cli, supply filetype in header (#73)
* Check Parquet files in together-cli, supply filetype in header
* Fix typing annotations
* Post-rebase fixes
* Bump minor version due to the change of behavior
* Remove files.py
* Fix typing errors
* Reduce the diff
* Address review feedback
1 parent 2e8fea4 commit 3caf608

File tree: 8 files changed (+281 -29)

.pre-commit-config.yaml (+1 -1)

@@ -25,5 +25,5 @@ repos:
       hooks:
         - id: mypy
           args: [--strict]
-          additional_dependencies: [types-requests, types-tqdm, types-tabulate, types-click, types-filelock, types-Pillow, pydantic, aiohttp]
+          additional_dependencies: [types-requests, types-tqdm, types-tabulate, types-click, types-filelock, types-Pillow, pyarrow-stubs, pydantic, aiohttp]
           exclude: ^tests/

poetry.lock (+150 -10)

Some generated files are not rendered by default.

pyproject.toml (+17 -3)

@@ -1,12 +1,20 @@
 [build-system]
-requires = ["poetry"]
+requires = [
+    "poetry",
+    # Starting with NumPy 1.25, NumPy is (by default) compatible as far
+    # back as oldest-supported-numpy was (customizable with a
+    # NPY_TARGET_VERSION define). For older Python versions (where NumPy
+    # 1.25 is not yet available) continue using oldest-supported-numpy.
+    "oldest-supported-numpy>=0.14; python_version<'3.9'",
+    "numpy>=1.25; python_version>='3.9'",
+]
 build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.0.1"
+version = "1.1.0"
 authors = [
-    "Together AI <[email protected]>"
+    "Together AI <[email protected]>"
 ]
 description = "Python client for Together's Cloud Platform!"
 readme = "README.md"

@@ -31,6 +39,11 @@ filelock = "^3.13.1"
 eval-type-backport = "^0.1.3"
 click = "^8.1.7"
 pillow = "^10.3.0"
+pyarrow = ">=10.0.1"
+numpy = [
+    { version = ">=1.23.5", python = "<3.12" },
+    { version = ">=1.26.0", python = ">=3.12" },
+]
 
 [tool.poetry.group.quality]
 optional = true

@@ -42,6 +55,7 @@ types-tqdm = "^4.65.0.0"
 types-tabulate = "^0.9.0.3"
 pre-commit = "3.5.0"
 types-requests = "^2.31.0.20240218"
+pyarrow-stubs = "^10.0.1.7"
 mypy = "^1.9.0"
 
 [tool.poetry.group.tests]
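
The requires entries and the numpy constraint above both rely on PEP 508 environment markers to select a dependency per Python version. A quick illustration of how such markers evaluate (not part of the commit; uses the packaging library):

from packaging.markers import Marker

# The same marker syntax used in the pyproject.toml entries above.
marker = Marker("python_version < '3.9'")
assert marker.evaluate({"python_version": "3.8"}) is True    # selects oldest-supported-numpy
assert marker.evaluate({"python_version": "3.11"}) is False  # selects numpy>=1.25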

src/together/constants.py (+3)

@@ -26,3 +26,6 @@
 
 # maximum number of GB sized files we support finetuning for
 MAX_FILE_SIZE_GB = 4.9
+
+# expected columns for Parquet files
+PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]
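
For orientation, a minimal sketch of producing a Parquet file with exactly these columns (not part of the commit; the token IDs are made up):

import pyarrow as pa
from pyarrow import parquet

# One tokenized sample per row; each column holds a list of token IDs.
table = pa.table(
    {
        "input_ids": [[101, 2023, 102], [101, 2003, 102]],
        "attention_mask": [[1, 1, 1], [1, 1, 1]],
        "labels": [[101, 2023, 102], [101, 2003, 102]],
    }
)
parquet.write_table(table, "train.parquet")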

src/together/filemanager.py (+24 -4)

@@ -25,7 +25,13 @@
     FileTypeError,
 )
 from together.together_response import TogetherResponse
-from together.types import FilePurpose, FileResponse, TogetherClient, TogetherRequest
+from together.types import (
+    FilePurpose,
+    FileResponse,
+    FileType,
+    TogetherClient,
+    TogetherRequest,
+)
 
 
 def chmod_and_replace(src: Path, dst: Path) -> None:

@@ -260,12 +266,17 @@ def _redirect_error_handler(
             http_status=response.status_code,
         )
 
-    def redirect_policy(
-        self, url: str, file: Path, purpose: FilePurpose
+    def get_upload_url(
+        self,
+        url: str,
+        file: Path,
+        purpose: FilePurpose,
+        filetype: FileType,
     ) -> Tuple[str, str]:
         data = {
             "purpose": purpose.value,
             "file_name": file.name,
+            "file_type": filetype.value,
         }
 
         requestor = api_requestor.APIRequestor(

@@ -324,7 +335,16 @@ def upload(
 
         redirect_url = None
         if redirect:
-            redirect_url, file_id = self.redirect_policy(url, file, purpose)
+            if file.suffix == ".jsonl":
+                filetype = FileType.jsonl
+            elif file.suffix == ".parquet":
+                filetype = FileType.parquet
+            else:
+                raise FileTypeError(
+                    f"Unknown extension of file {file}. "
+                    "Only files with extensions .jsonl and .parquet are supported."
+                )
+            redirect_url, file_id = self.get_upload_url(url, file, purpose, filetype)
 
         file_size = os.stat(file.as_posix()).st_size
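
The practical effect in upload: the body sent when requesting an upload URL now names the file type next to the purpose and file name. A sketch with illustrative values:

# What get_upload_url now sends (example values, per the diff above):
data = {
    "purpose": "fine-tune",        # FilePurpose.FineTune.value
    "file_name": "train.parquet",  # file.name
    "file_type": "parquet",        # FileType.parquet.value, the new field
}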

src/together/types/__init__.py (+2)

@@ -18,6 +18,7 @@
     FilePurpose,
     FileRequest,
     FileResponse,
+    FileType,
 )
 from together.types.finetune import (
     FinetuneDownloadResult,

@@ -55,6 +56,7 @@
     "FileDeleteResponse",
     "FileObject",
     "FilePurpose",
+    "FileType",
     "ImageRequest",
     "ImageResponse",
     "ModelObject",

src/together/types/files.py (+6 -5)

@@ -15,6 +15,11 @@ class FilePurpose(str, Enum):
     FineTune = "fine-tune"
 
 
+class FileType(str, Enum):
+    jsonl = "jsonl"
+    parquet = "parquet"
+
+
 class FileRequest(BaseModel):
     """
     Files request type

@@ -43,21 +48,17 @@ class FileResponse(BaseModel):
     Files API response type
     """
 
-    # file id
     id: str
-    # object type
     object: Literal[ObjectType.File]
     # created timestamp
     created_at: int | None = None
-    # file purpose
+    type: FileType | None = None
     purpose: FilePurpose | None = None
-    # file-name
     filename: str | None = None
     # file byte size
     bytes: int | None = None
     # JSONL line count
     line_count: int | None = Field(None, alias="LineCount")
-    # is processed
     processed: bool | None = Field(None, alias="Processed")
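Because FileType subclasses str, its members compare equal to their raw values and serialize directly into the "file_type" request field. A small illustration (not from the commit):

from together.types import FileType

assert FileType.parquet == "parquet"        # str-enum members equal their values
assert FileType("jsonl") is FileType.jsonl  # lookup by value round-trips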

src/together/utils/files.py (+78 -6)

@@ -3,9 +3,17 @@
 import json
 import os
 from pathlib import Path
+from traceback import format_exc
 from typing import Any, Dict
 
-from together.constants import MAX_FILE_SIZE_GB, MIN_SAMPLES, NUM_BYTES_IN_GB
+from pyarrow import ArrowInvalid, parquet
+
+from together.constants import (
+    MAX_FILE_SIZE_GB,
+    MIN_SAMPLES,
+    NUM_BYTES_IN_GB,
+    PARQUET_EXPECTED_COLUMNS,
+)
 
 
 def check_file(

@@ -50,6 +58,25 @@ def check_file(
     else:
         report_dict["file_size"] = file_size
 
+    if file.suffix == ".jsonl":
+        report_dict["filetype"] = "jsonl"
+        data_report_dict = _check_jsonl(file)
+    elif file.suffix == ".parquet":
+        report_dict["filetype"] = "parquet"
+        data_report_dict = _check_parquet(file)
+    else:
+        report_dict["filetype"] = (
+            f"Unknown extension of file {file}. "
+            "Only files with extensions .jsonl and .parquet are supported."
+        )
+        report_dict["is_check_passed"] = False
+
+    report_dict.update(data_report_dict)
+    return report_dict
+
+
+def _check_jsonl(file: Path) -> Dict[str, Any]:
+    report_dict: Dict[str, Any] = {}
     # Check that the file is UTF-8 encoded. If not report where the error occurs.
     try:
         with file.open(encoding="utf-8") as f:

@@ -71,7 +98,7 @@
             if not isinstance(json_line, dict):
                 report_dict["line_type"] = False
                 report_dict["message"] = (
-                    f"Error parsing file. Invalid format on line {idx+1} of the input file. "
+                    f"Error parsing file. Invalid format on line {idx + 1} of the input file. "
                     'Example of valid json: {"text": "my sample string"}. '
                 )

@@ -80,7 +107,7 @@
             if "text" not in json_line.keys():
                 report_dict["text_field"] = False
                 report_dict["message"] = (
-                    f"Missing 'text' field was found on line {idx+1} of the input file. "
+                    f"Missing 'text' field was found on line {idx + 1} of the input file. "
                     "Expected format: {'text': 'my sample string'}. "
                 )
                 report_dict["is_check_passed"] = False

@@ -89,7 +116,7 @@
             if not isinstance(json_line["text"], str):
                 report_dict["key_value"] = False
                 report_dict["message"] = (
-                    f'Invalid value type for "text" key on line {idx+1}. '
+                    f'Invalid value type for "text" key on line {idx + 1}. '
                     f'Expected string. Found {type(json_line["text"])}.'
                 )

@@ -99,7 +126,7 @@
         if idx + 1 < MIN_SAMPLES:
             report_dict["min_samples"] = False
             report_dict["message"] = (
-                f"Processing {file} resulted in only {idx+1} samples. "
+                f"Processing {file} resulted in only {idx + 1} samples. "
                 f"Our minimum is {MIN_SAMPLES} samples. "
             )
             report_dict["is_check_passed"] = False

@@ -118,7 +145,7 @@
             )
         else:
             report_dict["message"] = (
-                f"Error parsing json payload. Unexpected format on line {idx+1}."
+                f"Error parsing json payload. Unexpected format on line {idx + 1}."
             )
             report_dict["is_check_passed"] = False

@@ -128,5 +155,50 @@
         report_dict["line_type"] = True
     if report_dict["key_value"] is not False:
         report_dict["key_value"] = True
+    return report_dict
+
+
+def _check_parquet(file: Path) -> Dict[str, Any]:
+    report_dict: Dict[str, Any] = {}
+
+    try:
+        table = parquet.read_table(str(file), memory_map=True)
+    except ArrowInvalid:
+        report_dict["load_parquet"] = (
+            f"An exception has occurred when loading the Parquet file {file}. Please check the file for corruption. "
+            f"Exception trace:\n{format_exc()}"
+        )
+        report_dict["is_check_passed"] = False
+        return report_dict
+
+    column_names = table.schema.names
+    if "input_ids" not in column_names:
+        report_dict["load_parquet"] = (
+            f"Parquet file {file} does not contain the `input_ids` column."
+        )
+        report_dict["is_check_passed"] = False
+        return report_dict
+
+    for column_name in column_names:
+        if column_name not in PARQUET_EXPECTED_COLUMNS:
+            report_dict["load_parquet"] = (
+                f"Parquet file {file} contains an unexpected column {column_name}. "
+                f"Only columns {PARQUET_EXPECTED_COLUMNS} are supported."
+            )
+            report_dict["is_check_passed"] = False
+            return report_dict
+
+    num_samples = len(table)
+    if num_samples < MIN_SAMPLES:
+        report_dict["min_samples"] = (
+            f"Processing {file} resulted in only {num_samples} samples. "
+            f"Our minimum is {MIN_SAMPLES} samples. "
+        )
+        report_dict["is_check_passed"] = False
+        return report_dict
+    else:
+        report_dict["num_samples"] = num_samples
+
+    report_dict["is_check_passed"] = True
 
     return report_dict
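
Putting it together, a sketch of exercising the new validation path (assumes a train.parquet with at least MIN_SAMPLES rows exists, and that check_file is called with a pathlib.Path, as its body suggests):

from pathlib import Path

from together.utils.files import check_file

report = check_file(Path("train.parquet"))
print(report["filetype"])         # "parquet"
print(report.get("num_samples"))  # row count, set only when the checks pass
print(report["is_check_passed"])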
