
Commit c3d2a35

Make fine-tuning job validation messages more informative (#292)
* more verbose messages
* update test
1 parent cd154f8

File tree: 2 files changed, +64 -66 lines

src/together/resources/finetune.py  (+33 -21)
@@ -2,45 +2,46 @@
 
 import re
 from pathlib import Path
-from typing import Literal, List
+from typing import List, Literal
 
 from rich import print as rprint
 
 from together.abstract import api_requestor
 from together.filemanager import DownloadManager
 from together.together_response import TogetherResponse
 from together.types import (
+    CosineLRScheduler,
+    CosineLRSchedulerArgs,
+    FinetuneCheckpoint,
     FinetuneDownloadResult,
     FinetuneList,
     FinetuneListEvents,
+    FinetuneLRScheduler,
     FinetuneRequest,
     FinetuneResponse,
     FinetuneTrainingLimits,
     FullTrainingType,
+    LinearLRScheduler,
+    LinearLRSchedulerArgs,
     LoRATrainingType,
     TogetherClient,
     TogetherRequest,
-    TrainingType,
-    FinetuneLRScheduler,
-    LinearLRScheduler,
-    CosineLRScheduler,
-    LinearLRSchedulerArgs,
-    CosineLRSchedulerArgs,
     TrainingMethodDPO,
     TrainingMethodSFT,
-    FinetuneCheckpoint,
+    TrainingType,
 )
 from together.types.finetune import (
     DownloadCheckpointType,
-    FinetuneEventType,
     FinetuneEvent,
+    FinetuneEventType,
 )
 from together.utils import (
+    get_event_step,
     log_warn_once,
     normalize_key,
-    get_event_step,
 )
 
+
 _FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
 
 
@@ -63,7 +64,7 @@ def create_finetune_request(
     lr_scheduler_type: Literal["linear", "cosine"] = "linear",
     min_lr_ratio: float = 0.0,
     scheduler_num_cycles: float = 0.5,
-    warmup_ratio: float = 0.0,
+    warmup_ratio: float | None = None,
     max_grad_norm: float = 1.0,
     weight_decay: float = 0.0,
     lora: bool = False,
@@ -81,7 +82,6 @@ def create_finetune_request(
     dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
-
     if model is not None and from_checkpoint is not None:
         raise ValueError(
             "You must specify either a model or a checkpoint to start a job from, not both"
@@ -90,6 +90,8 @@ def create_finetune_request(
     if model is None and from_checkpoint is None:
         raise ValueError("You must specify either a model or a checkpoint")
 
+    model_or_checkpoint = model or from_checkpoint
+
     if batch_size == "max":
         log_warn_once(
             "Starting from together>=1.3.0, "
@@ -103,7 +105,9 @@ def create_finetune_request(
     min_batch_size: int = 0
     if lora:
         if model_limits.lora_training is None:
-            raise ValueError("LoRA adapters are not supported for the selected model.")
+            raise ValueError(
+                f"LoRA adapters are not supported for the selected model ({model_or_checkpoint})."
+            )
         lora_r = lora_r if lora_r is not None else model_limits.lora_training.max_rank
         lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
         training_type = LoRATrainingType(
@@ -118,7 +122,9 @@ def create_finetune_request(
 
     else:
         if model_limits.full_training is None:
-            raise ValueError("Full training is not supported for the selected model.")
+            raise ValueError(
+                f"Full training is not supported for the selected model ({model_or_checkpoint})."
+            )
 
         max_batch_size = model_limits.full_training.max_batch_size
         min_batch_size = model_limits.full_training.min_batch_size
@@ -127,25 +133,29 @@ def create_finetune_request(
 
     if batch_size > max_batch_size:
         raise ValueError(
-            "Requested batch size is higher that the maximum allowed value."
+            f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}."
         )
 
     if batch_size < min_batch_size:
         raise ValueError(
-            "Requested batch size is lower that the minimum allowed value."
+            f"Requested batch size of {batch_size} is lower that the minimum allowed value of {min_batch_size}."
        )
 
     if warmup_ratio > 1 or warmup_ratio < 0:
-        raise ValueError("Warmup ratio should be between 0 and 1")
+        raise ValueError(f"Warmup ratio should be between 0 and 1 (got {warmup_ratio})")
 
     if min_lr_ratio is not None and (min_lr_ratio > 1 or min_lr_ratio < 0):
-        raise ValueError("Min learning rate ratio should be between 0 and 1")
+        raise ValueError(
+            f"Min learning rate ratio should be between 0 and 1 (got {min_lr_ratio})"
+        )
 
     if max_grad_norm < 0:
-        raise ValueError("Max gradient norm should be non-negative")
+        raise ValueError(
+            f"Max gradient norm should be non-negative (got {max_grad_norm})"
+        )
 
     if weight_decay is not None and (weight_decay < 0):
-        raise ValueError("Weight decay should be non-negative")
+        raise ValueError(f"Weight decay should be non-negative (got {weight_decay})")
 
     if training_method not in AVAILABLE_TRAINING_METHODS:
         raise ValueError(
@@ -155,7 +165,9 @@ def create_finetune_request(
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
         if scheduler_num_cycles <= 0.0:
-            raise ValueError("Number of cycles should be greater than 0")
+            raise ValueError(
+                f"Number of cycles should be greater than 0 (got {scheduler_num_cycles})"
+            )
 
         lr_scheduler = CosineLRScheduler(
             lr_scheduler_args=CosineLRSchedulerArgs(
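
Note: the net effect of the changes above is that every validation error now reports the rejected value, plus the allowed bound and the model or checkpoint name where relevant. A minimal standalone sketch of that message style (plain Python, not the SDK itself; the limit values below are arbitrary examples):

# Standalone sketch of the message style this commit adopts; it does not import the
# together SDK, and min/max batch sizes here are made-up example values.
def check_batch_size(batch_size: int, min_batch_size: int = 8, max_batch_size: int = 96) -> None:
    if batch_size > max_batch_size:
        # Previously: "Requested batch size is higher that the maximum allowed value."
        raise ValueError(
            f"Requested batch size of {batch_size} is higher that the maximum "
            f"allowed value of {max_batch_size}."
        )
    if batch_size < min_batch_size:
        # Previously: "Requested batch size is lower that the minimum allowed value."
        raise ValueError(
            f"Requested batch size of {batch_size} is lower that the minimum "
            f"allowed value of {min_batch_size}."
        )

try:
    check_batch_size(256)
except ValueError as err:
    print(err)  # Requested batch size of 256 is higher that the maximum allowed value of 96.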

tests/unit/test_finetune_resources.py  (+31 -45)
@@ -2,9 +2,9 @@
 
 from together.resources.finetune import create_finetune_request
 from together.types.finetune import (
-    FinetuneTrainingLimits,
     FinetuneFullTrainingLimits,
     FinetuneLoraTrainingLimits,
+    FinetuneTrainingLimits,
 )
 
 
@@ -117,50 +117,36 @@ def test_no_from_checkpoint_no_model_name():
     )
 
 
-def test_batch_size_limit():
-    with pytest.raises(
-        ValueError,
-        match="Requested batch size is higher that the maximum allowed value",
-    ):
-        _ = create_finetune_request(
-            model_limits=_MODEL_LIMITS,
-            model=_MODEL_NAME,
-            training_file=_TRAINING_FILE,
-            batch_size=128,
-        )
-
-    with pytest.raises(
-        ValueError, match="Requested batch size is lower that the minimum allowed value"
-    ):
-        _ = create_finetune_request(
-            model_limits=_MODEL_LIMITS,
-            model=_MODEL_NAME,
-            training_file=_TRAINING_FILE,
-            batch_size=1,
-        )
-
-    with pytest.raises(
-        ValueError,
-        match="Requested batch size is higher that the maximum allowed value",
-    ):
-        _ = create_finetune_request(
-            model_limits=_MODEL_LIMITS,
-            model=_MODEL_NAME,
-            training_file=_TRAINING_FILE,
-            batch_size=256,
-            lora=True,
-        )
-
-    with pytest.raises(
-        ValueError, match="Requested batch size is lower that the minimum allowed value"
-    ):
-        _ = create_finetune_request(
-            model_limits=_MODEL_LIMITS,
-            model=_MODEL_NAME,
-            training_file=_TRAINING_FILE,
-            batch_size=1,
-            lora=True,
-        )
+@pytest.mark.parametrize("batch_size", [256, 1])
+@pytest.mark.parametrize("use_lora", [False, True])
+def test_batch_size_limit(batch_size, use_lora):
+    model_limits = (
+        _MODEL_LIMITS.full_training if not use_lora else _MODEL_LIMITS.lora_training
+    )
+    max_batch_size = model_limits.max_batch_size
+    min_batch_size = model_limits.min_batch_size
+
+    if batch_size > max_batch_size:
+        error_message = f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}"
+        with pytest.raises(ValueError, match=error_message):
+            _ = create_finetune_request(
+                model_limits=_MODEL_LIMITS,
+                model=_MODEL_NAME,
+                training_file=_TRAINING_FILE,
+                batch_size=batch_size,
+                lora=use_lora,
+            )
+
+    if batch_size < min_batch_size:
+        error_message = f"Requested batch size of {batch_size} is lower that the minimum allowed value of {min_batch_size}"
+        with pytest.raises(ValueError, match=error_message):
+            _ = create_finetune_request(
+                model_limits=_MODEL_LIMITS,
+                model=_MODEL_NAME,
+                training_file=_TRAINING_FILE,
+                batch_size=batch_size,
+                lora=use_lora,
+            )
 
 
 def test_non_lora_model():
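
Note: the rewritten test relies on stacked pytest.mark.parametrize decorators, which run the function once per combination of the two parameter lists, so the four hand-written scenarios become four generated cases (each case only exercises the bound its batch size actually violates). A tiny self-contained illustration of that mechanism, unrelated to the SDK:

# Standalone illustration of stacked parametrize decorators: pytest collects the
# cross product of the two lists, i.e. (256, False), (256, True), (1, False), (1, True).
import pytest


@pytest.mark.parametrize("batch_size", [256, 1])
@pytest.mark.parametrize("use_lora", [False, True])
def test_cross_product(batch_size: int, use_lora: bool) -> None:
    assert batch_size in (256, 1)
    assert isinstance(use_lora, bool)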
