from together.utils import log_warn_once, normalize_key


+ def createFinetuneRequest(
+     model_limits: FinetuneTrainingLimits,
+     training_file: str,
+     model: str,
+     n_epochs: int = 1,
+     validation_file: str | None = "",
+     n_evals: int | None = 0,
+     n_checkpoints: int | None = 1,
+     batch_size: int | Literal["max"] = "max",
+     learning_rate: float | None = 0.00001,
+     warmup_ratio: float | None = 0.0,
+     lora: bool = False,
+     lora_r: int | None = None,
+     lora_dropout: float | None = 0,
+     lora_alpha: float | None = None,
+     lora_trainable_modules: str | None = "all-linear",
+     suffix: str | None = None,
+     wandb_api_key: str | None = None,
+ ) -> FinetuneRequest:
+     if batch_size == "max":
+         log_warn_once(
+             "Starting from together>=1.3.0, "
+             "the default batch size is set to the maximum allowed value for each model."
+         )
+     if warmup_ratio is None:
+         warmup_ratio = 0.0
+
+     training_type: TrainingType = FullTrainingType()
+     if lora:
+         if model_limits.lora_training is None:
+             raise ValueError("LoRA adapters are not supported for the selected model.")
+         lora_r = lora_r if lora_r is not None else model_limits.lora_training.max_rank
+         lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
+         training_type = LoRATrainingType(
+             lora_r=lora_r,
+             lora_alpha=lora_alpha,
+             lora_dropout=lora_dropout,
+             lora_trainable_modules=lora_trainable_modules,
+         )
+
+         batch_size = (
+             batch_size
+             if batch_size != "max"
+             else model_limits.lora_training.max_batch_size
+         )
+     else:
+         if model_limits.full_training is None:
+             raise ValueError("Full training is not supported for the selected model.")
+         batch_size = (
+             batch_size
+             if batch_size != "max"
+             else model_limits.full_training.max_batch_size
+         )
+
+     if warmup_ratio > 1 or warmup_ratio < 0:
+         raise ValueError("Warmup ratio should be between 0 and 1")
+
+     finetune_request = FinetuneRequest(
+         model=model,
+         training_file=training_file,
+         validation_file=validation_file,
+         n_epochs=n_epochs,
+         n_evals=n_evals,
+         n_checkpoints=n_checkpoints,
+         batch_size=batch_size,
+         learning_rate=learning_rate,
+         warmup_ratio=warmup_ratio,
+         training_type=training_type,
+         suffix=suffix,
+         wandb_key=wandb_api_key,
+     )
+
+     return finetune_request
+
+
class FineTuning:
    def __init__(self, client: TogetherClient) -> None:
        self._client = client
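For orientation, a minimal usage sketch of the new helper (illustrative only, not part of the diff; the client.fine_tuning accessor, model name, and file ID are assumptions):

from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment
limits = client.fine_tuning.get_model_limits(model="meta-llama/Meta-Llama-3-8B")  # assumed accessor path

request = createFinetuneRequest(
    model_limits=limits,
    training_file="file-abc123",          # hypothetical ID of an already-uploaded training file
    model="meta-llama/Meta-Llama-3-8B",   # hypothetical base model
    n_epochs=3,
    lora=True,          # lora_r falls back to the model's max LoRA rank, lora_alpha to 2 * lora_r
    batch_size="max",   # resolved to model_limits.lora_training.max_batch_size
    warmup_ratio=0.1,   # validated to lie between 0 and 1
)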
@@ -40,6 +115,7 @@ def create(
        n_checkpoints: int | None = 1,
        batch_size: int | Literal["max"] = "max",
        learning_rate: float | None = 0.00001,
+         warmup_ratio: float | None = 0.0,
        lora: bool = False,
        lora_r: int | None = None,
        lora_dropout: float | None = 0,
@@ -64,6 +140,7 @@ def create(
            batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
            learning_rate (float, optional): Learning rate multiplier to use for training
                Defaults to 0.00001.
+             warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
            lora (bool, optional): Whether to use LoRA adapters. Defaults to False.
            lora_r (int, optional): Rank of LoRA adapters. Defaults to 8.
            lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
@@ -82,65 +159,33 @@ def create(
            FinetuneResponse: Object containing information about fine-tuning job.
        """

-         if batch_size == "max":
-             log_warn_once(
-                 "Starting from together>=1.3.0, "
-                 "the default batch size is set to the maximum allowed value for each model."
-             )
-
        requestor = api_requestor.APIRequestor(
            client=self._client,
        )

        if model_limits is None:
            model_limits = self.get_model_limits(model=model)

-         training_type: TrainingType = FullTrainingType()
-         if lora:
-             if model_limits.lora_training is None:
-                 raise ValueError(
-                     "LoRA adapters are not supported for the selected model."
-                 )
-             lora_r = (
-                 lora_r if lora_r is not None else model_limits.lora_training.max_rank
-             )
-             lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
-             training_type = LoRATrainingType(
-                 lora_r=lora_r,
-                 lora_alpha=lora_alpha,
-                 lora_dropout=lora_dropout,
-                 lora_trainable_modules=lora_trainable_modules,
-             )
-
-             batch_size = (
-                 batch_size
-                 if batch_size != "max"
-                 else model_limits.lora_training.max_batch_size
-             )
-         else:
-             if model_limits.full_training is None:
-                 raise ValueError(
-                     "Full training is not supported for the selected model."
-                 )
-             batch_size = (
-                 batch_size
-                 if batch_size != "max"
-                 else model_limits.full_training.max_batch_size
-             )
-
-         finetune_request = FinetuneRequest(
-             model=model,
+         finetune_request = createFinetuneRequest(
+             model_limits=model_limits,
            training_file=training_file,
-             validation_file=validation_file,
+             model=model,
            n_epochs=n_epochs,
+             validation_file=validation_file,
            n_evals=n_evals,
            n_checkpoints=n_checkpoints,
            batch_size=batch_size,
            learning_rate=learning_rate,
-             training_type=training_type,
+             warmup_ratio=warmup_ratio,
+             lora=lora,
+             lora_r=lora_r,
+             lora_dropout=lora_dropout,
+             lora_alpha=lora_alpha,
+             lora_trainable_modules=lora_trainable_modules,
            suffix=suffix,
-             wandb_key=wandb_api_key,
+             wandb_api_key=wandb_api_key,
        )
+
        if verbose:
            rprint(
                "Submitting a fine-tuning job with the following parameters:",
@@ -377,12 +422,20 @@ async def create(
        model: str,
        n_epochs: int = 1,
        validation_file: str | None = "",
-         n_evals: int = 0,
+         n_evals: int | None = 0,
        n_checkpoints: int | None = 1,
-         batch_size: int | None = 32,
-         learning_rate: float = 0.00001,
+         batch_size: int | Literal["max"] = "max",
+         learning_rate: float | None = 0.00001,
+         warmup_ratio: float | None = 0.0,
+         lora: bool = False,
+         lora_r: int | None = None,
+         lora_dropout: float | None = 0,
+         lora_alpha: float | None = None,
+         lora_trainable_modules: str | None = "all-linear",
        suffix: str | None = None,
        wandb_api_key: str | None = None,
+         verbose: bool = False,
+         model_limits: FinetuneTrainingLimits | None = None,
    ) -> FinetuneResponse:
        """
        Async method to initiate a fine-tuning job
@@ -395,13 +448,23 @@ async def create(
            n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
            n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
                Defaults to 1.
-             batch_size (int, optional): Batch size for fine-tuning. Defaults to 32.
+             batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
            learning_rate (float, optional): Learning rate multiplier to use for training
                Defaults to 0.00001.
+             warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
+             lora (bool, optional): Whether to use LoRA adapters. Defaults to False.
+             lora_r (int, optional): Rank of LoRA adapters. Defaults to 8.
+             lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
+             lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
+             lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
            suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
                Defaults to None.
            wandb_api_key (str, optional): API key for Weights & Biases integration.
                Defaults to None.
+             verbose (bool, optional): Whether to print the job parameters before submitting a request.
+                 Defaults to False.
+             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters of the model in fine-tuning.
+                 Defaults to None.

        Returns:
            FinetuneResponse: Object containing information about fine-tuning job.
@@ -411,18 +474,35 @@ async def create(
            client=self._client,
        )

-         parameter_payload = FinetuneRequest(
-             model=model,
+         if model_limits is None:
+             model_limits = await self.get_model_limits(model=model)
+
+         finetune_request = createFinetuneRequest(
+             model_limits=model_limits,
            training_file=training_file,
-             validation_file=validation_file,
+             model=model,
            n_epochs=n_epochs,
+             validation_file=validation_file,
            n_evals=n_evals,
            n_checkpoints=n_checkpoints,
            batch_size=batch_size,
            learning_rate=learning_rate,
+             warmup_ratio=warmup_ratio,
+             lora=lora,
+             lora_r=lora_r,
+             lora_dropout=lora_dropout,
+             lora_alpha=lora_alpha,
+             lora_trainable_modules=lora_trainable_modules,
            suffix=suffix,
-             wandb_key=wandb_api_key,
-         ).model_dump(exclude_none=True)
+             wandb_api_key=wandb_api_key,
+         )
+
+         if verbose:
+             rprint(
+                 "Submitting a fine-tuning job with the following parameters:",
+                 finetune_request,
+             )
+         parameter_payload = finetune_request.model_dump(exclude_none=True)

        response, _, _ = await requestor.arequest(
            options=TogetherRequest(
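For completeness, a rough sketch of how the updated create call might look from the client side after this change (only the parameter names come from the diff; the model name, file ID, and environment setup are assumptions):

import os
from together import Together

client = Together(api_key=os.environ["TOGETHER_API_KEY"])

job = client.fine_tuning.create(
    model="meta-llama/Meta-Llama-3-8B",  # hypothetical base model
    training_file="file-abc123",         # hypothetical uploaded file ID
    n_epochs=1,
    warmup_ratio=0.1,  # new parameter in this diff; must lie between 0 and 1
    lora=True,
    verbose=True,      # prints the resolved FinetuneRequest before submission
)
print(job)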