
Commit c026b7b

Migrate train_on_inputs to sft-specific params (#297)
* migrate to sft_on_inputs, and change defaults to match
* add validation to dpo_beta
* tests
* remove redundant 'automatically'

Co-authored-by: Artem Chumachenko <[email protected]>
1 parent 4eef896 commit c026b7b
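In effect, train_on_inputs becomes an SFT-only argument and dpo_beta a DPO-only one, with both validated inside create_finetune_request. A minimal sketch of a client-side call after this change (the file ID and model name below are placeholders, not values taken from this commit):

```python
from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

# SFT job: train_on_inputs is accepted; leaving it unset now falls back to "auto"
# inside create_finetune_request (with a one-time warning).
job = client.fine_tuning.create(
    training_file="file-placeholder-id",      # placeholder upload ID
    model="example-org/example-base-model",   # placeholder model name
    training_method="sft",
    train_on_inputs=False,
)

# Passing train_on_inputs together with training_method="dpo" (or dpo_beta with
# training_method="sft") now raises a ValueError rather than building the request.
```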

File tree

5 files changed: +64 additions, -23 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.5.9"
+version = "1.5.10"
 authors = ["Together AI <[email protected]>"]
 description = "Python client for Together's Cloud Platform!"
 readme = "README.md"

src/together/cli/api/finetune.py

Lines changed: 10 additions & 10 deletions
@@ -1,10 +1,10 @@
 from __future__ import annotations
 
 import json
+import re
 from datetime import datetime, timezone
 from textwrap import wrap
 from typing import Any, Literal
-import re
 
 import click
 from click.core import ParameterSource  # type: ignore[attr-defined]
@@ -13,17 +13,17 @@
 
 from together import Together
 from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
+from together.types.finetune import (
+    DownloadCheckpointType,
+    FinetuneEventType,
+    FinetuneTrainingLimits,
+)
 from together.utils import (
     finetune_price_to_dollars,
+    format_timestamp,
     log_warn,
     log_warn_once,
     parse_timestamp,
-    format_timestamp,
-)
-from together.types.finetune import (
-    DownloadCheckpointType,
-    FinetuneTrainingLimits,
-    FinetuneEventType,
 )
 
 
@@ -348,9 +348,9 @@ def list(ctx: click.Context) -> None:
                 "Model Output Name": "\n".join(wrap(i.output_name or "", width=30)),
                 "Status": i.status,
                 "Created At": i.created_at,
-                "Price": f"""${finetune_price_to_dollars(
-                    float(str(i.total_price))
-                )}""",  # convert to string for mypy typing
+                "Price": f"""${
+                    finetune_price_to_dollars(float(str(i.total_price)))
+                }""",  # convert to string for mypy typing
             }
         )
     table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True)

src/together/resources/finetune.py

Lines changed: 22 additions & 9 deletions
@@ -69,7 +69,7 @@ def create_finetune_request(
     wandb_base_url: str | None = None,
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
-    train_on_inputs: bool | Literal["auto"] = "auto",
+    train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
     from_checkpoint: str | None = None,
@@ -166,6 +166,18 @@ def create_finetune_request(
             f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
         )
 
+    if train_on_inputs is not None and training_method != "sft":
+        raise ValueError("train_on_inputs is only supported for SFT training")
+
+    if train_on_inputs is None and training_method == "sft":
+        log_warn_once(
+            "train_on_inputs is not set for SFT training, it will be set to 'auto'"
+        )
+        train_on_inputs = "auto"
+
+    if dpo_beta is not None and training_method != "dpo":
+        raise ValueError("dpo_beta is only supported for DPO training")
+
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
         if scheduler_num_cycles <= 0.0:
@@ -183,8 +195,10 @@ def create_finetune_request(
             lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
         )
 
-    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
-    if training_method == "dpo":
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO
+    if training_method == "sft":
+        training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
+    elif training_method == "dpo":
         training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
 
     finetune_request = FinetuneRequest(
@@ -206,7 +220,6 @@ def create_finetune_request(
         wandb_base_url=wandb_base_url,
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
-        train_on_inputs=train_on_inputs,
         training_method=training_method_cls,
         from_checkpoint=from_checkpoint,
     )
@@ -281,7 +294,7 @@ def create(
         wandb_name: str | None = None,
        verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -326,12 +339,12 @@ def create(
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                 Defaults to None.
-            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+            train_on_inputs (bool or "auto", optional): Whether to mask the user messages in conversational data or prompts in instruction data.
                 "auto" will automatically determine whether to mask the inputs based on the data format.
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
@@ -693,7 +706,7 @@ async def create(
         wandb_name: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
-        train_on_inputs: bool | Literal["auto"] = "auto",
+        train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
         from_checkpoint: str | None = None,
@@ -743,7 +756,7 @@ async def create(
                 For datasets with the "text" field (general format), inputs will not be masked.
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
-                Defaults to "auto".
+                Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
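The docstring above spells out how "auto" is resolved per data format. A rough illustration of the three record shapes it refers to (field names come from the docstring; the values are invented for this example):

```python
# General format: a "text" field -> with train_on_inputs="auto", inputs are not masked.
general_record = {"text": "Once upon a time ..."}

# Conversational format: a "messages" field -> user messages are masked.
conversational_record = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]
}

# Instruction format: "prompt" and "completion" fields -> the prompt is masked.
instruction_record = {"prompt": "Translate to French: cheese", "completion": "fromage"}
```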

src/together/types/finetune.py

Lines changed: 2 additions & 3 deletions
@@ -3,7 +3,7 @@
 from enum import Enum
 from typing import List, Literal, Any
 
-from pydantic import StrictBool, Field, field_validator
+from pydantic import Field, StrictBool, field_validator
 
 from together.types.abstract import BaseModel
 from together.types.common import (
@@ -149,6 +149,7 @@ class TrainingMethodSFT(TrainingMethod):
     """
 
     method: Literal["sft"] = "sft"
+    train_on_inputs: StrictBool | Literal["auto"] = "auto"
 
 
 class TrainingMethodDPO(TrainingMethod):
@@ -201,8 +202,6 @@ class FinetuneRequest(BaseModel):
     wandb_name: str | None = None
     # training type
     training_type: FullTrainingType | LoRATrainingType | None = None
-    # train on inputs
-    train_on_inputs: StrictBool | Literal["auto"] = "auto"
     # training method
     training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
         default_factory=TrainingMethodSFT
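Because the flag now lives on TrainingMethodSFT rather than on FinetuneRequest, serialized requests nest it under training_method. A small sketch of the resulting shape (assuming these pydantic models dump as usual; the exact output may contain additional fields inherited from TrainingMethod):

```python
from together.types.finetune import TrainingMethodSFT

sft = TrainingMethodSFT(train_on_inputs=False)
print(sft.model_dump())  # roughly: {'method': 'sft', 'train_on_inputs': False}
```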

tests/unit/test_finetune_resources.py

Lines changed: 29 additions & 0 deletions
@@ -281,3 +281,32 @@ def test_bad_training_method():
             training_file=_TRAINING_FILE,
             training_method="NON_SFT",
         )
+
+
+@pytest.mark.parametrize("train_on_inputs", [True, False, "auto", None])
+def test_train_on_inputs_for_sft(train_on_inputs):
+    request = create_finetune_request(
+        model_limits=_MODEL_LIMITS,
+        model=_MODEL_NAME,
+        training_file=_TRAINING_FILE,
+        training_method="sft",
+        train_on_inputs=train_on_inputs,
+    )
+    assert request.training_method.method == "sft"
+    if isinstance(train_on_inputs, bool):
+        assert request.training_method.train_on_inputs is train_on_inputs
+    else:
+        assert request.training_method.train_on_inputs == "auto"
+
+
+def test_train_on_inputs_not_supported_for_dpo():
+    with pytest.raises(
+        ValueError, match="train_on_inputs is only supported for SFT training"
+    ):
+        _ = create_finetune_request(
+            model_limits=_MODEL_LIMITS,
+            model=_MODEL_NAME,
+            training_file=_TRAINING_FILE,
+            training_method="dpo",
+            train_on_inputs=True,
+        )
