@@ -69,7 +69,7 @@ def create_finetune_request(
     wandb_base_url: str | None = None,
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
-    train_on_inputs: bool | Literal["auto"] = "auto",
+    train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
@@ -166,6 +166,18 @@ def create_finetune_request(
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )

+    if train_on_inputs is not None and training_method != "sft":
+        raise ValueError("train_on_inputs is only supported for SFT training")
+
+    if train_on_inputs is None and training_method == "sft":
+        log_warn_once(
+            "train_on_inputs is not set for SFT training, it will be set to 'auto'"
+        )
+        train_on_inputs = "auto"
+
+    if dpo_beta is not None and training_method != "dpo":
+        raise ValueError("dpo_beta is only supported for DPO training")
+
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
         if scheduler_num_cycles <= 0.0:
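Read on its own, the guard block added above fully determines how train_on_inputs is resolved. A minimal standalone sketch of that logic (illustration only, not code from this patch; the one-time warning is elided):

def resolve_train_on_inputs(
    train_on_inputs: bool | str | None, training_method: str
) -> bool | str | None:
    # Explicitly setting the flag outside SFT is now an error.
    if train_on_inputs is not None and training_method != "sft":
        raise ValueError("train_on_inputs is only supported for SFT training")
    # Left unset under SFT, it falls back to "auto" (the previous default),
    # so existing SFT callers see no behavior change.
    if train_on_inputs is None and training_method == "sft":
        return "auto"
    return train_on_inputs

resolve_train_on_inputs(None, "sft")   # -> "auto"
resolve_train_on_inputs(True, "dpo")   # -> raises ValueError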
@@ -183,8 +195,10 @@ def create_finetune_request(
             lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
         )

-    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
-    if training_method == "dpo":
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO
+    if training_method == "sft":
+        training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
+    elif training_method == "dpo":
         training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)

     finetune_request = FinetuneRequest(
@@ -206,7 +220,6 @@ def create_finetune_request(
         wandb_base_url=wandb_base_url,
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
-        train_on_inputs=train_on_inputs,
         training_method=training_method_cls,
         from_checkpoint=from_checkpoint,
     )
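Taken together, the two hunks above move train_on_inputs from a top-level FinetuneRequest field into the SFT training-method object. As a rough sketch of the resulting payload shape (an assumption for illustration; the exact serialized field names are not shown in this patch):

# Before: {"train_on_inputs": "auto", "training_method": {"method": "sft"}}
# After:  {"training_method": {"method": "sft", "train_on_inputs": "auto"}}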
@@ -281,7 +294,7 @@ def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -326,12 +339,12 @@ def create(
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                 Defaults to None.
-            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+            train_on_inputs (bool or "auto", optional): Whether to mask the user messages in conversational data or prompts in instruction data.
                 "auto" will automatically determine whether to mask the inputs based on the data format.
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
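For illustration, a hedged usage sketch of the public API after this change (assuming the standard Together client entry point; the file IDs and model name are placeholders):

from together import Together

client = Together()

# SFT: train_on_inputs can now be omitted entirely; it defaults to None
# and create_finetune_request resolves it to "auto" with a one-time warning.
sft_job = client.fine_tuning.create(
    training_file="file-xxxx",  # placeholder ID
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder
    training_method="sft",
)

# DPO: dpo_beta is valid here, but an explicit train_on_inputs would now
# raise ValueError("train_on_inputs is only supported for SFT training").
dpo_job = client.fine_tuning.create(
    training_file="file-yyyy",  # placeholder ID
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder
    training_method="dpo",
    dpo_beta=0.1,
)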
@@ -693,7 +706,7 @@ async def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -743,7 +756,7 @@ async def create(
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
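The async path mirrors the sync one; a short hedged sketch (assuming the AsyncTogether client as exposed by the current SDK, with placeholder IDs):

import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()
    # train_on_inputs omitted: None -> "auto" for SFT, as in the sync client.
    job = await client.fine_tuning.create(
        training_file="file-zzzz",  # placeholder ID
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder
        training_method="sft",
    )
    print(job.id)


asyncio.run(main())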