
Commit cd154f8

Fix empty LR scheduler for old jobs, clean up argument validation (#288)
1 parent e0de91e commit cd154f8

File tree

5 files changed: +57 additions, -53 deletions


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.5.4"
+version = "1.5.5"
 authors = ["Together AI <[email protected]>"]
 description = "Python client for Together's Cloud Platform!"
 readme = "README.md"

src/together/resources/finetune.py

Lines changed: 13 additions & 15 deletions
@@ -22,10 +22,10 @@
     TogetherRequest,
     TrainingType,
     FinetuneLRScheduler,
-    FinetuneLinearLRScheduler,
-    FinetuneCosineLRScheduler,
-    FinetuneLinearLRSchedulerArgs,
-    FinetuneCosineLRSchedulerArgs,
+    LinearLRScheduler,
+    CosineLRScheduler,
+    LinearLRSchedulerArgs,
+    CosineLRSchedulerArgs,
     TrainingMethodDPO,
     TrainingMethodSFT,
     FinetuneCheckpoint,
@@ -50,7 +50,7 @@
 }
 
 
-def createFinetuneRequest(
+def create_finetune_request(
     model_limits: FinetuneTrainingLimits,
     training_file: str,
     model: str | None = None,
@@ -152,21 +152,19 @@ def createFinetuneRequest(
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )
 
-    # Default to generic lr scheduler
-    lrScheduler: FinetuneLRScheduler = FinetuneLRScheduler(lr_scheduler_type="linear")
-
+    lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
         if scheduler_num_cycles <= 0.0:
             raise ValueError("Number of cycles should be greater than 0")
 
-        lrScheduler = FinetuneCosineLRScheduler(
-            lr_scheduler_args=FinetuneCosineLRSchedulerArgs(
+        lr_scheduler = CosineLRScheduler(
+            lr_scheduler_args=CosineLRSchedulerArgs(
                 min_lr_ratio=min_lr_ratio, num_cycles=scheduler_num_cycles
             ),
         )
     else:
-        lrScheduler = FinetuneLinearLRScheduler(
-            lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
+        lr_scheduler = LinearLRScheduler(
+            lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
         )
 
     training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
@@ -182,7 +180,7 @@ def createFinetuneRequest(
         n_checkpoints=n_checkpoints,
         batch_size=batch_size,
         learning_rate=learning_rate,
-        lr_scheduler=lrScheduler,
+        lr_scheduler=lr_scheduler,
         warmup_ratio=warmup_ratio,
         max_grad_norm=max_grad_norm,
         weight_decay=weight_decay,
@@ -374,7 +372,7 @@ def create(
             pass
         model_limits = self.get_model_limits(model=model_name)
 
-        finetune_request = createFinetuneRequest(
+        finetune_request = create_finetune_request(
            model_limits=model_limits,
            training_file=training_file,
            model=model,
@@ -762,7 +760,7 @@ async def create(
             pass
         model_limits = await self.get_model_limits(model=model_name)
 
-        finetune_request = createFinetuneRequest(
+        finetune_request = create_finetune_request(
            model_limits=model_limits,
            training_file=training_file,
            model=model,

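As a quick orientation for the rename above, here is a minimal sketch of how the new scheduler classes are meant to be constructed, mirroring the cosine/linear branch in create_finetune_request; the numeric values are illustrative and this assumes a client that includes this commit:

from together.types.finetune import (
    CosineLRScheduler,
    CosineLRSchedulerArgs,
    LinearLRScheduler,
    LinearLRSchedulerArgs,
)

# Cosine schedule: min_lr_ratio and num_cycles travel inside lr_scheduler_args.
cosine = CosineLRScheduler(
    lr_scheduler_args=CosineLRSchedulerArgs(min_lr_ratio=0.1, num_cycles=0.5)
)

# Linear schedule: only min_lr_ratio is configurable.
linear = LinearLRScheduler(
    lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=0.0)
)

print(cosine.lr_scheduler_type)  # "cosine" (pinned by the Literal default)
print(linear.lr_scheduler_type)  # "linear"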
src/together/types/__init__.py

Lines changed: 8 additions & 8 deletions
@@ -34,11 +34,11 @@
     TrainingMethodDPO,
     TrainingMethodSFT,
     FinetuneCheckpoint,
-    FinetuneCosineLRScheduler,
-    FinetuneCosineLRSchedulerArgs,
+    CosineLRScheduler,
+    CosineLRSchedulerArgs,
     FinetuneDownloadResult,
-    FinetuneLinearLRScheduler,
-    FinetuneLinearLRSchedulerArgs,
+    LinearLRScheduler,
+    LinearLRSchedulerArgs,
     FinetuneLRScheduler,
     FinetuneList,
     FinetuneListEvents,
@@ -72,10 +72,10 @@
     "FinetuneListEvents",
     "FinetuneDownloadResult",
     "FinetuneLRScheduler",
-    "FinetuneLinearLRScheduler",
-    "FinetuneLinearLRSchedulerArgs",
-    "FinetuneCosineLRScheduler",
-    "FinetuneCosineLRSchedulerArgs",
+    "LinearLRScheduler",
+    "LinearLRSchedulerArgs",
+    "CosineLRScheduler",
+    "CosineLRSchedulerArgs",
     "FileRequest",
     "FileResponse",
     "FileList",

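Because the shorter names are re-exported at the package level, both import paths resolve to the same classes; a small sanity check, assuming the package from this commit is installed:

from together.types import CosineLRScheduler, LinearLRScheduler
from together.types import finetune

# The package-level re-exports and the module-level definitions are the same objects.
assert CosineLRScheduler is finetune.CosineLRScheduler
assert LinearLRScheduler is finetune.LinearLRScheduler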
src/together/types/finetune.py

Lines changed: 16 additions & 10 deletions
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
 from enum import Enum
-from typing import List, Literal, Union
+from typing import List, Literal
 
-from pydantic import StrictBool, Field, validator, field_validator, ValidationInfo
+from pydantic import StrictBool, Field, field_validator
 
 from together.types.abstract import BaseModel
 from together.types.common import (
@@ -176,7 +176,7 @@ class FinetuneRequest(BaseModel):
     # training learning rate
     learning_rate: float
     # learning rate scheduler type and args
-    lr_scheduler: FinetuneLinearLRScheduler | FinetuneCosineLRScheduler | None = None
+    lr_scheduler: LinearLRScheduler | CosineLRScheduler | None = None
     # learning rate warmup ratio
     warmup_ratio: float
     # max gradient norm
@@ -239,7 +239,7 @@ class FinetuneResponse(BaseModel):
     # training learning rate
     learning_rate: float | None = None
     # learning rate scheduler type and args
-    lr_scheduler: FinetuneLinearLRScheduler | FinetuneCosineLRScheduler | None = None
+    lr_scheduler: LinearLRScheduler | CosineLRScheduler | EmptyLRScheduler | None = None
     # learning rate warmup ratio
     warmup_ratio: float | None = None
     # max gradient norm
@@ -345,11 +345,11 @@ class FinetuneTrainingLimits(BaseModel):
     lora_training: FinetuneLoraTrainingLimits | None = None
 
 
-class FinetuneLinearLRSchedulerArgs(BaseModel):
+class LinearLRSchedulerArgs(BaseModel):
     min_lr_ratio: float | None = 0.0
 
 
-class FinetuneCosineLRSchedulerArgs(BaseModel):
+class CosineLRSchedulerArgs(BaseModel):
     min_lr_ratio: float | None = 0.0
     num_cycles: float | None = 0.5
 
@@ -358,14 +358,20 @@ class FinetuneLRScheduler(BaseModel):
     lr_scheduler_type: str
 
 
-class FinetuneLinearLRScheduler(FinetuneLRScheduler):
+class LinearLRScheduler(FinetuneLRScheduler):
     lr_scheduler_type: Literal["linear"] = "linear"
-    lr_scheduler: FinetuneLinearLRSchedulerArgs | None = None
+    lr_scheduler_args: LinearLRSchedulerArgs | None = None
 
 
-class FinetuneCosineLRScheduler(FinetuneLRScheduler):
+class CosineLRScheduler(FinetuneLRScheduler):
     lr_scheduler_type: Literal["cosine"] = "cosine"
-    lr_scheduler: FinetuneCosineLRSchedulerArgs | None = None
+    lr_scheduler_args: CosineLRSchedulerArgs | None = None
+
+
+# placeholder for old fine-tuning jobs with no lr_scheduler_type specified
+class EmptyLRScheduler(FinetuneLRScheduler):
+    lr_scheduler_type: Literal[""]
+    lr_scheduler_args: None = None
 
 
 class FinetuneCheckpoint(BaseModel):

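The new EmptyLRScheduler placeholder is what fixes responses from old fine-tuning jobs: those can report an lr_scheduler block whose lr_scheduler_type is an empty string, which now validates against the widened union on FinetuneResponse instead of raising. A minimal sketch of that behaviour, assuming pydantic v2 and Python 3.10+ (the payloads below are made up):

from pydantic import TypeAdapter

from together.types.finetune import (
    CosineLRScheduler,
    EmptyLRScheduler,
    LinearLRScheduler,
)

# The same union that FinetuneResponse.lr_scheduler now accepts (minus the None case).
adapter = TypeAdapter(LinearLRScheduler | CosineLRScheduler | EmptyLRScheduler)

old_job = {"lr_scheduler_type": ""}  # legacy job with no scheduler type recorded
new_job = {
    "lr_scheduler_type": "cosine",
    "lr_scheduler_args": {"min_lr_ratio": 0.1, "num_cycles": 0.5},
}

print(type(adapter.validate_python(old_job)).__name__)  # EmptyLRScheduler
print(type(adapter.validate_python(new_job)).__name__)  # CosineLRScheduler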
tests/unit/test_finetune_resources.py

Lines changed: 19 additions & 19 deletions
@@ -1,6 +1,6 @@
 import pytest
 
-from together.resources.finetune import createFinetuneRequest
+from together.resources.finetune import create_finetune_request
 from together.types.finetune import (
     FinetuneTrainingLimits,
     FinetuneFullTrainingLimits,
@@ -30,7 +30,7 @@
 
 
 def test_simple_request():
-    request = createFinetuneRequest(
+    request = create_finetune_request(
        model_limits=_MODEL_LIMITS,
        model=_MODEL_NAME,
        training_file=_TRAINING_FILE,
@@ -46,7 +46,7 @@ def test_simple_request():
 
 
 def test_validation_file():
-    request = createFinetuneRequest(
+    request = create_finetune_request(
        model_limits=_MODEL_LIMITS,
        model=_MODEL_NAME,
        training_file=_TRAINING_FILE,
@@ -61,14 +61,14 @@ def test_no_training_file():
     with pytest.raises(
         TypeError, match="missing 1 required positional argument: 'training_file'"
     ):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=_MODEL_LIMITS,
            model=_MODEL_NAME,
        )
 
 
 def test_lora_request():
-    request = createFinetuneRequest(
+    request = create_finetune_request(
        model_limits=_MODEL_LIMITS,
        model=_MODEL_NAME,
        training_file=_TRAINING_FILE,
@@ -84,7 +84,7 @@ def test_lora_request():
 
 
 def test_from_checkpoint_request():
-    request = createFinetuneRequest(
+    request = create_finetune_request(
        model_limits=_MODEL_LIMITS,
        training_file=_TRAINING_FILE,
        from_checkpoint=_FROM_CHECKPOINT,
@@ -99,7 +99,7 @@ def test_both_from_checkpoint_model_name():
         ValueError,
         match="You must specify either a model or a checkpoint to start a job from, not both",
     ):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=_MODEL_LIMITS,
            model=_MODEL_NAME,
            training_file=_TRAINING_FILE,
@@ -111,7 +111,7 @@ def test_no_from_checkpoint_no_model_name():
     with pytest.raises(
         ValueError, match="You must specify either a model or a checkpoint"
     ):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=_MODEL_LIMITS,
            training_file=_TRAINING_FILE,
        )
@@ -122,7 +122,7 @@ def test_batch_size_limit():
         ValueError,
         match="Requested batch size is higher that the maximum allowed value",
     ):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=_MODEL_LIMITS,
            model=_MODEL_NAME,
            training_file=_TRAINING_FILE,
@@ -132,7 +132,7 @@ def test_batch_size_limit():
     with pytest.raises(
         ValueError, match="Requested batch size is lower that the minimum allowed value"
     ):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=_MODEL_LIMITS,
            model=_MODEL_NAME,
            training_file=_TRAINING_FILE,
@@ -143,7 +143,7 @@ def test_batch_size_limit():
         ValueError,
         match="Requested batch size is higher that the maximum allowed value",
     ):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=_MODEL_LIMITS,
            model=_MODEL_NAME,
            training_file=_TRAINING_FILE,
@@ -154,7 +154,7 @@ def test_batch_size_limit():
     with pytest.raises(
         ValueError, match="Requested batch size is lower that the minimum allowed value"
     ):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=_MODEL_LIMITS,
            model=_MODEL_NAME,
            training_file=_TRAINING_FILE,
@@ -167,7 +167,7 @@ def test_non_lora_model():
     with pytest.raises(
         ValueError, match="LoRA adapters are not supported for the selected model."
     ):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=FinetuneTrainingLimits(
                max_num_epochs=20,
                max_learning_rate=1.0,
@@ -188,7 +188,7 @@ def test_non_full_model():
     with pytest.raises(
         ValueError, match="Full training is not supported for the selected model."
     ):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=FinetuneTrainingLimits(
                max_num_epochs=20,
                max_learning_rate=1.0,
@@ -210,7 +210,7 @@ def test_non_full_model():
 @pytest.mark.parametrize("warmup_ratio", [-1.0, 2.0])
 def test_bad_warmup(warmup_ratio):
     with pytest.raises(ValueError, match="Warmup ratio should be between 0 and 1"):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=_MODEL_LIMITS,
            model=_MODEL_NAME,
            training_file=_TRAINING_FILE,
@@ -223,7 +223,7 @@ def test_bad_min_lr_ratio(min_lr_ratio):
     with pytest.raises(
         ValueError, match="Min learning rate ratio should be between 0 and 1"
     ):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=_MODEL_LIMITS,
            model=_MODEL_NAME,
            training_file=_TRAINING_FILE,
@@ -234,7 +234,7 @@ def test_bad_min_lr_ratio(min_lr_ratio):
 @pytest.mark.parametrize("max_grad_norm", [-1.0, -0.01])
 def test_bad_max_grad_norm(max_grad_norm):
     with pytest.raises(ValueError, match="Max gradient norm should be non-negative"):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=_MODEL_LIMITS,
            model=_MODEL_NAME,
            training_file=_TRAINING_FILE,
@@ -245,7 +245,7 @@ def test_bad_max_grad_norm(max_grad_norm):
 @pytest.mark.parametrize("weight_decay", [-1.0, -0.01])
 def test_bad_weight_decay(weight_decay):
     with pytest.raises(ValueError, match="Weight decay should be non-negative"):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=_MODEL_LIMITS,
            model=_MODEL_NAME,
            training_file=_TRAINING_FILE,
@@ -255,7 +255,7 @@ def test_bad_weight_decay(weight_decay):
 
 def test_bad_training_method():
     with pytest.raises(ValueError, match="training_method must be one of .*"):
-        _ = createFinetuneRequest(
+        _ = create_finetune_request(
            model_limits=_MODEL_LIMITS,
            model=_MODEL_NAME,
            training_file=_TRAINING_FILE,

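One practical note on the rename: createFinetuneRequest no longer exists, so any code that imported the helper directly needs the snake_case name. A hedged compatibility shim for code that must work across client versions might look like this (the helper is internal, so relying on it at all is a judgment call):

try:
    # Older clients (the parent of this commit is versioned 1.5.4) exposed the camelCase name.
    from together.resources.finetune import createFinetuneRequest as create_finetune_request
except ImportError:
    # From this commit onward only the snake_case name is available.
    from together.resources.finetune import create_finetune_request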