
Commit 2479615

Fix dpo defaults for cli + turn on the checks removed previously (#330)
* Fix dpo defaults for cli + turn on the checks removed previously
1 parent 4655ead commit 2479615

File tree

3 files changed: +26 / -8 lines


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.5.16"
+version = "1.5.17"
 authors = ["Together AI <[email protected]>"]
 description = "Python client for Together's Cloud Platform!"
 readme = "README.md"

src/together/cli/api/finetune.py

Lines changed: 7 additions & 7 deletions
@@ -139,7 +139,7 @@ def fine_tuning(ctx: click.Context) -> None:
 @click.option(
     "--dpo-beta",
     type=float,
-    default=0.1,
+    default=None,
     help="Beta parameter for DPO training (only used when '--training-method' is 'dpo')",
 )
 @click.option(
@@ -154,7 +154,7 @@ def fine_tuning(ctx: click.Context) -> None:
 @click.option(
     "--rpo-alpha",
     type=float,
-    default=0.0,
+    default=None,
     help=(
         "RPO alpha parameter of DPO training to include NLL in the loss "
         "(only used when '--training-method' is 'dpo')"
@@ -163,7 +163,7 @@ def fine_tuning(ctx: click.Context) -> None:
 @click.option(
     "--simpo-gamma",
     type=float,
-    default=0.0,
+    default=None,
     help="SimPO gamma parameter (only used when '--training-method' is 'dpo')",
 )
 @click.option(
@@ -188,7 +188,7 @@ def fine_tuning(ctx: click.Context) -> None:
 @click.option(
     "--train-on-inputs",
     type=BOOL_WITH_AUTO,
-    default="auto",
+    default=None,
     help="Whether to mask the user messages in conversational data or prompts in instruction data. "
     "`auto` will automatically determine whether to mask the inputs based on the data format.",
 )
@@ -229,10 +229,10 @@ def create(
     confirm: bool,
     train_on_inputs: bool | Literal["auto"],
     training_method: str,
-    dpo_beta: float,
+    dpo_beta: float | None,
     dpo_normalize_logratios_by_length: bool,
-    rpo_alpha: float,
-    simpo_gamma: float,
+    rpo_alpha: float | None,
+    simpo_gamma: float | None,
     from_checkpoint: str,
 ) -> None:
     """Start fine-tuning"""

src/together/resources/finetune.py

Lines changed: 18 additions & 0 deletions
@@ -183,6 +183,24 @@ def create_finetune_request(
         )
         train_on_inputs = "auto"
 
+    if dpo_beta is not None and training_method != "dpo":
+        raise ValueError("dpo_beta is only supported for DPO training")
+    if dpo_normalize_logratios_by_length and training_method != "dpo":
+        raise ValueError(
+            "dpo_normalize_logratios_by_length=True is only supported for DPO training"
+        )
+    if rpo_alpha is not None:
+        if training_method != "dpo":
+            raise ValueError("rpo_alpha is only supported for DPO training")
+        if not rpo_alpha >= 0.0:
+            raise ValueError(f"rpo_alpha should be non-negative (got {rpo_alpha})")
+
+    if simpo_gamma is not None:
+        if training_method != "dpo":
+            raise ValueError("simpo_gamma is only supported for DPO training")
+        if not simpo_gamma >= 0.0:
+            raise ValueError(f"simpo_gamma should be non-negative (got {simpo_gamma})")
+
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
         if scheduler_num_cycles <= 0.0:
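
The practical effect of the re-enabled checks, as a hedged usage sketch: assuming the client's fine_tuning.create passes these keyword arguments through to create_finetune_request (as the signature above suggests), setting a DPO-only knob on a non-DPO run now fails fast on the client side. The file ID and model name below are placeholders:

    from together import Together

    client = Together()  # assumes TOGETHER_API_KEY is set in the environment

    try:
        client.fine_tuning.create(
            training_file="file-abc123",         # placeholder file ID
            model="meta-llama/Meta-Llama-3-8B",  # placeholder model name
            rpo_alpha=0.5,  # DPO-only; training_method defaults to SFT
        )
    except ValueError as err:
        print(err)  # rpo_alpha is only supported for DPO training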
