
Commit bfc43de

Authored by azahed98 (Arsh Zahed) and orangetin
ENG 12851: Add warmup ratio and update AsyncFineTuning (#202)
* Add warmup ratio and update async
* Add warmup_ratio to args

Co-authored-by: Arsh Zahed <[email protected]>
Co-authored-by: orangetin <[email protected]>
1 parent e651fde commit bfc43de

File tree: 4 files changed, +148 −53 lines

  src/together/cli/api/finetune.py
  src/together/cli/api/utils.py
  src/together/resources/finetune.py
  src/together/types/finetune.py

src/together/cli/api/finetune.py

Lines changed: 9 additions & 0 deletions
@@ -60,6 +60,12 @@ def fine_tuning(ctx: click.Context) -> None:
 )
 @click.option("--batch-size", type=INT_WITH_MAX, default="max", help="Train batch size")
 @click.option("--learning-rate", type=float, default=1e-5, help="Learning rate")
+@click.option(
+    "--warmup-ratio",
+    type=float,
+    default=0.0,
+    help="Warmup ratio for learning rate scheduler.",
+)
 @click.option(
     "--lora/--no-lora",
     type=bool,
@@ -97,6 +103,7 @@ def create(
     n_checkpoints: int,
     batch_size: int | Literal["max"],
     learning_rate: float,
+    warmup_ratio: float,
     lora: bool,
     lora_r: int,
     lora_dropout: float,
@@ -118,6 +125,7 @@ def create(
         n_checkpoints=n_checkpoints,
         batch_size=batch_size,
         learning_rate=learning_rate,
+        warmup_ratio=warmup_ratio,
         lora=lora,
         lora_r=lora_r,
         lora_dropout=lora_dropout,
@@ -186,6 +194,7 @@ def create(
         n_checkpoints=n_checkpoints,
         batch_size=batch_size,
         learning_rate=learning_rate,
+        warmup_ratio=warmup_ratio,
         lora=lora,
         lora_r=lora_r,
         lora_dropout=lora_dropout,
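As a quick illustration of the new flag, a job could be launched from the CLI with a warmup ratio roughly like the command below; the file ID and model name are placeholders, and the exact subcommand spelling may differ between CLI versions:

    together fine-tuning create \
      --training-file file-abc123 \
      --model meta-llama/Meta-Llama-3-8B \
      --learning-rate 1e-5 \
      --warmup-ratio 0.05 \
      --lora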

src/together/cli/api/utils.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import click
 
 from typing import Literal
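The `from __future__ import annotations` import enables postponed evaluation of annotations, which is what lets signatures in this module use PEP 604 unions such as `int | Literal["max"]` on Python versions older than 3.10. A minimal sketch, assuming a hypothetical helper that is not part of this commit:

    from __future__ import annotations

    from typing import Literal

    # The union annotation below is stored as a string and never evaluated at
    # runtime, so this parses and runs on Python 3.8/3.9 as well as 3.10+.
    def resolve_batch_size(batch_size: int | Literal["max"], max_batch_size: int) -> int:
        return max_batch_size if batch_size == "max" else batch_size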

src/together/resources/finetune.py

Lines changed: 133 additions & 53 deletions
@@ -25,6 +25,81 @@
 from together.utils import log_warn_once, normalize_key
 
 
+def createFinetuneRequest(
+    model_limits: FinetuneTrainingLimits,
+    training_file: str,
+    model: str,
+    n_epochs: int = 1,
+    validation_file: str | None = "",
+    n_evals: int | None = 0,
+    n_checkpoints: int | None = 1,
+    batch_size: int | Literal["max"] = "max",
+    learning_rate: float | None = 0.00001,
+    warmup_ratio: float | None = 0.0,
+    lora: bool = False,
+    lora_r: int | None = None,
+    lora_dropout: float | None = 0,
+    lora_alpha: float | None = None,
+    lora_trainable_modules: str | None = "all-linear",
+    suffix: str | None = None,
+    wandb_api_key: str | None = None,
+) -> FinetuneRequest:
+    if batch_size == "max":
+        log_warn_once(
+            "Starting from together>=1.3.0, "
+            "the default batch size is set to the maximum allowed value for each model."
+        )
+    if warmup_ratio is None:
+        warmup_ratio = 0.0
+
+    training_type: TrainingType = FullTrainingType()
+    if lora:
+        if model_limits.lora_training is None:
+            raise ValueError("LoRA adapters are not supported for the selected model.")
+        lora_r = lora_r if lora_r is not None else model_limits.lora_training.max_rank
+        lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
+        training_type = LoRATrainingType(
+            lora_r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            lora_trainable_modules=lora_trainable_modules,
+        )
+
+        batch_size = (
+            batch_size
+            if batch_size != "max"
+            else model_limits.lora_training.max_batch_size
+        )
+    else:
+        if model_limits.full_training is None:
+            raise ValueError("Full training is not supported for the selected model.")
+        batch_size = (
+            batch_size
+            if batch_size != "max"
+            else model_limits.full_training.max_batch_size
+        )
+
+    if warmup_ratio > 1 or warmup_ratio < 0:
+        raise ValueError("Warmup ratio should be between 0 and 1")
+
+    finetune_request = FinetuneRequest(
+        model=model,
+        training_file=training_file,
+        validation_file=validation_file,
+        n_epochs=n_epochs,
+        n_evals=n_evals,
+        n_checkpoints=n_checkpoints,
+        batch_size=batch_size,
+        learning_rate=learning_rate,
+        warmup_ratio=warmup_ratio,
+        training_type=training_type,
+        suffix=suffix,
+        wandb_key=wandb_api_key,
+    )
+
+    return finetune_request
+
+
 class FineTuning:
     def __init__(self, client: TogetherClient) -> None:
         self._client = client
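The new createFinetuneRequest helper centralizes the request-building logic that was previously inlined in FineTuning.create, so the sync and async clients can share it. A sketch of calling it directly, assuming the public Together client class and using a hypothetical model name and file ID (the helper is internal, so this is illustration only):

    from together import Together
    from together.resources.finetune import createFinetuneRequest

    client = Together()  # assumes TOGETHER_API_KEY is set in the environment
    limits = client.fine_tuning.get_model_limits(model="meta-llama/Meta-Llama-3-8B")

    # Builds the request body, resolving batch_size="max" from the model limits
    # and validating that warmup_ratio lies in [0, 1].
    request = createFinetuneRequest(
        model_limits=limits,
        training_file="file-abc123",          # hypothetical uploaded file ID
        model="meta-llama/Meta-Llama-3-8B",   # hypothetical model name
        n_epochs=3,
        warmup_ratio=0.05,
        lora=True,
    )
    print(request.model_dump(exclude_none=True))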
@@ -40,6 +115,7 @@ def create(
         n_checkpoints: int | None = 1,
         batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
+        warmup_ratio: float | None = 0.0,
         lora: bool = False,
         lora_r: int | None = None,
         lora_dropout: float | None = 0,
@@ -64,6 +140,7 @@ def create(
             batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
+            warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
             lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
             lora_r (int, optional): Rank of LoRA adapters. Defaults to 8.
             lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
@@ -82,65 +159,33 @@ def create(
             FinetuneResponse: Object containing information about fine-tuning job.
         """
 
-        if batch_size == "max":
-            log_warn_once(
-                "Starting from together>=1.3.0, "
-                "the default batch size is set to the maximum allowed value for each model."
-            )
-
         requestor = api_requestor.APIRequestor(
             client=self._client,
         )
 
         if model_limits is None:
             model_limits = self.get_model_limits(model=model)
 
-        training_type: TrainingType = FullTrainingType()
-        if lora:
-            if model_limits.lora_training is None:
-                raise ValueError(
-                    "LoRA adapters are not supported for the selected model."
-                )
-            lora_r = (
-                lora_r if lora_r is not None else model_limits.lora_training.max_rank
-            )
-            lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
-            training_type = LoRATrainingType(
-                lora_r=lora_r,
-                lora_alpha=lora_alpha,
-                lora_dropout=lora_dropout,
-                lora_trainable_modules=lora_trainable_modules,
-            )
-
-            batch_size = (
-                batch_size
-                if batch_size != "max"
-                else model_limits.lora_training.max_batch_size
-            )
-        else:
-            if model_limits.full_training is None:
-                raise ValueError(
-                    "Full training is not supported for the selected model."
-                )
-            batch_size = (
-                batch_size
-                if batch_size != "max"
-                else model_limits.full_training.max_batch_size
-            )
-
-        finetune_request = FinetuneRequest(
-            model=model,
+        finetune_request = createFinetuneRequest(
+            model_limits=model_limits,
             training_file=training_file,
-            validation_file=validation_file,
+            model=model,
             n_epochs=n_epochs,
+            validation_file=validation_file,
             n_evals=n_evals,
             n_checkpoints=n_checkpoints,
             batch_size=batch_size,
             learning_rate=learning_rate,
-            training_type=training_type,
+            warmup_ratio=warmup_ratio,
+            lora=lora,
+            lora_r=lora_r,
+            lora_dropout=lora_dropout,
+            lora_alpha=lora_alpha,
+            lora_trainable_modules=lora_trainable_modules,
             suffix=suffix,
-            wandb_key=wandb_api_key,
+            wandb_api_key=wandb_api_key,
         )
+
         if verbose:
             rprint(
                 "Submitting a fine-tuning job with the following parameters:",
@@ -377,12 +422,20 @@ async def create(
         model: str,
         n_epochs: int = 1,
         validation_file: str | None = "",
-        n_evals: int = 0,
+        n_evals: int | None = 0,
         n_checkpoints: int | None = 1,
-        batch_size: int | None = 32,
-        learning_rate: float = 0.00001,
+        batch_size: int | Literal["max"] = "max",
+        learning_rate: float | None = 0.00001,
+        warmup_ratio: float | None = 0.0,
+        lora: bool = False,
+        lora_r: int | None = None,
+        lora_dropout: float | None = 0,
+        lora_alpha: float | None = None,
+        lora_trainable_modules: str | None = "all-linear",
         suffix: str | None = None,
         wandb_api_key: str | None = None,
+        verbose: bool = False,
+        model_limits: FinetuneTrainingLimits | None = None,
     ) -> FinetuneResponse:
         """
         Async method to initiate a fine-tuning job
@@ -395,13 +448,23 @@ async def create(
             n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
             n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
                 Defaults to 1.
-            batch_size (int, optional): Batch size for fine-tuning. Defaults to 32.
+            batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
+            warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
+            lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
+            lora_r (int, optional): Rank of LoRA adapters. Defaults to 8.
+            lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
+            lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
+            lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
             suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
                 Defaults to None.
             wandb_api_key (str, optional): API key for Weights & Biases integration.
                 Defaults to None.
+            verbose (bool, optional): whether to print the job parameters before submitting a request.
+                Defaults to False.
+            model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
+                Defaults to None.
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
@@ -411,18 +474,35 @@ async def create(
             client=self._client,
         )
 
-        parameter_payload = FinetuneRequest(
-            model=model,
+        if model_limits is None:
+            model_limits = await self.get_model_limits(model=model)
+
+        finetune_request = createFinetuneRequest(
+            model_limits=model_limits,
             training_file=training_file,
-            validation_file=validation_file,
+            model=model,
             n_epochs=n_epochs,
+            validation_file=validation_file,
             n_evals=n_evals,
             n_checkpoints=n_checkpoints,
             batch_size=batch_size,
             learning_rate=learning_rate,
+            warmup_ratio=warmup_ratio,
+            lora=lora,
+            lora_r=lora_r,
+            lora_dropout=lora_dropout,
+            lora_alpha=lora_alpha,
+            lora_trainable_modules=lora_trainable_modules,
             suffix=suffix,
-            wandb_key=wandb_api_key,
-        ).model_dump(exclude_none=True)
+            wandb_api_key=wandb_api_key,
+        )
+
+        if verbose:
+            rprint(
+                "Submitting a fine-tuning job with the following parameters:",
+                finetune_request,
+            )
+        parameter_payload = finetune_request.model_dump(exclude_none=True)
 
         response, _, _ = await requestor.arequest(
             options=TogetherRequest(
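The async client now mirrors the sync signature, including warmup_ratio, the LoRA options, verbose, and model_limits. A sketch of the async path, assuming the AsyncTogether client class and hypothetical IDs not taken from this diff:

    import asyncio

    from together import AsyncTogether

    async def main() -> None:
        client = AsyncTogether()  # assumes TOGETHER_API_KEY is set in the environment
        job = await client.fine_tuning.create(
            training_file="file-abc123",          # hypothetical uploaded file ID
            model="meta-llama/Meta-Llama-3-8B",   # hypothetical model name
            batch_size="max",                     # resolved against the model limits
            warmup_ratio=0.05,
            verbose=True,
        )
        print(job.id)

    asyncio.run(main())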

src/together/types/finetune.py

Lines changed: 4 additions & 0 deletions
@@ -150,6 +150,8 @@ class FinetuneRequest(BaseModel):
     n_epochs: int
     # training learning rate
     learning_rate: float
+    # learning rate warmup ratio
+    warmup_ratio: float
     # number of checkpoints to save
     n_checkpoints: int | None = None
     # number of evaluation loops to run
@@ -190,6 +192,8 @@ class FinetuneResponse(BaseModel):
     batch_size: int | None = None
     # training learning rate
     learning_rate: float | None = None
+    # learning rate warmup ratio
+    warmup_ratio: float | None = None
     # number of steps between evals
     eval_steps: int | None = None
     # training type
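For context, warmup_ratio is conventionally the fraction of total optimizer steps spent ramping the learning rate from zero up to its target value; the scheduler itself runs server-side and is not part of this diff, so the sketch below only illustrates the usual interpretation under that assumption:

    def warmup_steps(total_steps: int, warmup_ratio: float) -> int:
        # Mirrors the bounds check added in createFinetuneRequest
        if not 0.0 <= warmup_ratio <= 1.0:
            raise ValueError("Warmup ratio should be between 0 and 1")
        return int(total_steps * warmup_ratio)

    # e.g. 1,000 optimizer steps with warmup_ratio=0.05 -> 50 warmup steps
    assert warmup_steps(1_000, 0.05) == 50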
