Skip to content

Commit 71277be

Browse files
committed
Set default output name for model weights and file downloads
1 parent 7fcbbd0 commit 71277be

File tree

4 files changed

+28
-6
lines changed

4 files changed

+28
-6
lines changed

src/together/commands/files.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,13 @@ def _add_retrieve_content(
6868
type=str,
6969
)
7070
retrieve_file_content_parser.add_argument(
71-
"output",
71+
"--output",
72+
"-o",
73+
default=None,
7274
metavar="OUT_FILENAME",
7375
help="Output filename",
7476
type=str,
77+
required=False,
7578
)
7679

7780
retrieve_file_content_parser.set_defaults(func=_run_retrieve_content)

src/together/commands/finetune.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,12 @@ def _add_download(
175175
type=str,
176176
)
177177
download_parser.add_argument(
178-
"output",
179-
metavar="OUT_FILENAME",
178+
"--output",
179+
"-o",
180+
default=None,
180181
help="Output filename",
181182
type=str,
183+
required=False,
182184
)
183185
download_parser.add_argument(
184186
"--checkpoint-num",

src/together/files.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,12 @@ def retrieve_file(self, file_id: str) -> Dict[str, Union[str, int]]:
129129

130130
return response_json
131131

132-
def retrieve_file_content(self, file_id: str, output: str) -> str:
132+
def retrieve_file_content(
133+
self, file_id: str, output: Union[str, None] = None
134+
) -> str:
135+
if output is None:
136+
output = file_id + ".jsonl"
137+
133138
relative_path = posixpath.join(file_id, "content")
134139
retrieve_url = urllib.parse.urljoin(self.endpoint_url, relative_path)
135140

src/together/finetune.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import posixpath
33
import urllib.parse
4-
from typing import Any, Dict, List, Optional
4+
from typing import Any, Dict, List, Optional, Union
55

66
import requests
77
from tqdm import tqdm
@@ -206,7 +206,19 @@ def is_final_model_available(self, fine_tune_id: str) -> bool:
206206
return True
207207
return False
208208

209-
def download(self, fine_tune_id: str, output: str, checkpoint_num: int = -1) -> str:
209+
def download(
210+
self,
211+
fine_tune_id: str,
212+
output: Union[str, None] = None,
213+
checkpoint_num: int = -1,
214+
) -> str:
215+
# default to model_output_path name
216+
if output is None:
217+
output = (
218+
self.retrieve_finetune(fine_tune_id)["model_output_path"].split("/")[-1]
219+
+ ".tar.gz"
220+
)
221+
210222
model_file_path = urllib.parse.urljoin(
211223
self.endpoint_url,
212224
f"/api/finetune/downloadfinetunefile?ft_id={fine_tune_id}",

0 commit comments

Comments
 (0)