
Commit e157fcd

artek0chumak and mryab authored
Add instruction and conversational data support (#211)
* Add format checks
* add tests
* add train on inputs flag
* style
* PR feedback
* style
* more tests
* enhance logic
* enhance logic
* pr feedback part 1
* style and fixed
* pr feedback
* style
* style
* fix typing
* change to strict boolean
* error out on train_on_inputs
* use "auto" directly
* add system message
* version bump
* Update src/together/cli/api/finetune.py

Co-authored-by: Max Ryabinin <[email protected]>

---------

Co-authored-by: Max Ryabinin <[email protected]>
1 parent 296f2a5 commit e157fcd

8 files changed: +509 -47 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"

 [tool.poetry]
 name = "together"
-version = "1.3.3"
+version = "1.3.4"
 authors = [
     "Together AI <[email protected]>"
 ]

src/together/cli/api/finetune.py

Lines changed: 21 additions & 18 deletions
@@ -11,8 +11,13 @@
 from tabulate import tabulate

 from together import Together
-from together.cli.api.utils import INT_WITH_MAX
-from together.utils import finetune_price_to_dollars, log_warn, parse_timestamp
+from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
+from together.utils import (
+    finetune_price_to_dollars,
+    log_warn,
+    log_warn_once,
+    parse_timestamp,
+)
 from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits


@@ -93,6 +98,13 @@ def fine_tuning(ctx: click.Context) -> None:
     default=False,
     help="Whether to skip the launch confirmation message",
 )
+@click.option(
+    "--train-on-inputs",
+    type=BOOL_WITH_AUTO,
+    default="auto",
+    help="Whether to mask the user messages in conversational data or prompts in instruction data. "
+    "`auto` will automatically determine whether to mask the inputs based on the data format.",
+)
 def create(
     ctx: click.Context,
     training_file: str,
@@ -112,6 +124,7 @@ def create(
     suffix: str,
     wandb_api_key: str,
     confirm: bool,
+    train_on_inputs: bool | Literal["auto"],
 ) -> None:
     """Start fine-tuning"""
     client: Together = ctx.obj
@@ -133,6 +146,7 @@ def create(
         lora_trainable_modules=lora_trainable_modules,
         suffix=suffix,
         wandb_api_key=wandb_api_key,
+        train_on_inputs=train_on_inputs,
     )

     model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
@@ -150,6 +164,10 @@ def create(
             "batch_size": model_limits.lora_training.max_batch_size,
             "learning_rate": 1e-3,
         }
+        log_warn_once(
+            f"The default LoRA rank for {model} has been changed to {default_values['lora_r']} as the max available.\n"
+            f"Also, the default learning rate for LoRA fine-tuning has been changed to {default_values['learning_rate']}."
+        )
         for arg in default_values:
             arg_source = ctx.get_parameter_source("arg") # type: ignore[attr-defined]
             if arg_source == ParameterSource.DEFAULT:
@@ -186,22 +204,7 @@ def create(

     if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
         response = client.fine_tuning.create(
-            training_file=training_file,
-            model=model,
-            n_epochs=n_epochs,
-            validation_file=validation_file,
-            n_evals=n_evals,
-            n_checkpoints=n_checkpoints,
-            batch_size=batch_size,
-            learning_rate=learning_rate,
-            warmup_ratio=warmup_ratio,
-            lora=lora,
-            lora_r=lora_r,
-            lora_dropout=lora_dropout,
-            lora_alpha=lora_alpha,
-            lora_trainable_modules=lora_trainable_modules,
-            suffix=suffix,
-            wandb_api_key=wandb_api_key,
+            **training_args,
             verbose=True,
         )
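For illustration, here is a minimal standalone sketch (not the SDK's own code) of the option pattern added above: a custom click param type plus a string default lets a single flag accept a boolean or the sentinel "auto", which reaches the command unchanged. This simplified variant defers to click.BOOL for the string-to-bool step, which differs from the SDK's own converter shown in the next file.

import click
from click.testing import CliRunner


class BoolOrAuto(click.ParamType):
    # Simplified stand-in for the SDK's BOOL_WITH_AUTO param type.
    name = "boolean_or_auto"

    def convert(self, value, param, ctx):
        if value == "auto":
            return "auto"  # the sentinel passes through untouched
        return click.BOOL.convert(value, param, ctx)  # "true"/"false"/"1"/"0" -> bool


@click.command()
@click.option("--train-on-inputs", type=BoolOrAuto(), default="auto")
def create(train_on_inputs):
    click.echo(f"train_on_inputs={train_on_inputs!r}")


runner = CliRunner()
print(runner.invoke(create, []).output)                              # train_on_inputs='auto'
print(runner.invoke(create, ["--train-on-inputs", "false"]).output)  # train_on_inputs=False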

src/together/cli/api/utils.py

Lines changed: 21 additions & 0 deletions
@@ -27,4 +27,25 @@ def convert(
         )


+class BooleanWithAutoParamType(click.ParamType):
+    name = "boolean_or_auto"
+
+    def convert(
+        self, value: str, param: click.Parameter | None, ctx: click.Context | None
+    ) -> bool | Literal["auto"] | None:
+        if value == "auto":
+            return "auto"
+        try:
+            return bool(value)
+        except ValueError:
+            self.fail(
+                _("{value!r} is not a valid {type}.").format(
+                    value=value, type=self.name
+                ),
+                param,
+                ctx,
+            )
+
+
 INT_WITH_MAX = AutoIntParamType()
+BOOL_WITH_AUTO = BooleanWithAutoParamType()
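A quick sketch of exercising the new type directly, assuming a together build containing this change is installed. convert() receives the raw command-line string, so "auto" is returned as-is, while any other non-empty string is truthy under bool(value).

from together.cli.api.utils import BOOL_WITH_AUTO

print(BOOL_WITH_AUTO.convert("auto", None, None))  # 'auto' -- the sentinel passes through
print(BOOL_WITH_AUTO.convert("true", None, None))  # True   -- bool() of a non-empty string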

src/together/constants.py

Lines changed: 19 additions & 0 deletions
@@ -1,3 +1,5 @@
+import enum
+
 # Session constants
 TIMEOUT_SECS = 600
 MAX_SESSION_LIFETIME_SECS = 180
@@ -29,3 +31,20 @@

 # expected columns for Parquet files
 PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]
+
+
+class DatasetFormat(enum.Enum):
+    """Dataset format enum."""
+
+    GENERAL = "general"
+    CONVERSATION = "conversation"
+    INSTRUCTION = "instruction"
+
+
+JSONL_REQUIRED_COLUMNS_MAP = {
+    DatasetFormat.GENERAL: ["text"],
+    DatasetFormat.CONVERSATION: ["messages"],
+    DatasetFormat.INSTRUCTION: ["prompt", "completion"],
+}
+REQUIRED_COLUMNS_MESSAGE = ["role", "content"]
+POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"]
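To make the new constants concrete, here are illustrative records (not taken from the commit) showing the minimum keys each dataset format expects, written as Python dicts as they would appear per line in a JSONL file.

# One example record per dataset format; field names follow JSONL_REQUIRED_COLUMNS_MAP,
# message keys follow REQUIRED_COLUMNS_MESSAGE, and roles are drawn from
# POSSIBLE_ROLES_CONVERSATION above.
general_record = {"text": "Once upon a time, ..."}

instruction_record = {
    "prompt": "Translate to French: Hello",
    "completion": "Bonjour",
}

conversation_record = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]
}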

src/together/resources/finetune.py

Lines changed: 19 additions & 1 deletion
@@ -43,6 +43,7 @@ def createFinetuneRequest(
     lora_trainable_modules: str | None = "all-linear",
     suffix: str | None = None,
     wandb_api_key: str | None = None,
+    train_on_inputs: bool | Literal["auto"] = "auto",
 ) -> FinetuneRequest:
     if batch_size == "max":
         log_warn_once(
@@ -95,6 +96,7 @@
         training_type=training_type,
         suffix=suffix,
         wandb_key=wandb_api_key,
+        train_on_inputs=train_on_inputs,
     )

     return finetune_request
@@ -125,6 +127,7 @@ def create(
         wandb_api_key: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
+        train_on_inputs: bool | Literal["auto"] = "auto",
     ) -> FinetuneResponse:
         """
         Method to initiate a fine-tuning job
@@ -137,7 +140,7 @@ def create(
             n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
             n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
                 Defaults to 1.
-            batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
+            batch_size (int or "max"): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
             warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
@@ -154,6 +157,12 @@ def create(
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                 Defaults to None.
+            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+                "auto" will automatically determine whether to mask the inputs based on the data format.
+                For datasets with the "text" field (general format), inputs will not be masked.
+                For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
+                (Instruction format), inputs will be masked.
+                Defaults to "auto".

         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
@@ -184,6 +193,7 @@ def create(
             lora_trainable_modules=lora_trainable_modules,
             suffix=suffix,
             wandb_api_key=wandb_api_key,
+            train_on_inputs=train_on_inputs,
         )

         if verbose:
@@ -436,6 +446,7 @@ async def create(
         wandb_api_key: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
+        train_on_inputs: bool | Literal["auto"] = "auto",
     ) -> FinetuneResponse:
         """
         Async method to initiate a fine-tuning job
@@ -465,6 +476,12 @@ async def create(
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                 Defaults to None.
+            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+                "auto" will automatically determine whether to mask the inputs based on the data format.
+                For datasets with the "text" field (general format), inputs will not be masked.
+                For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
+                (Instruction format), inputs will be masked.
+                Defaults to "auto".

         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
@@ -495,6 +512,7 @@ async def create(
             lora_trainable_modules=lora_trainable_modules,
             suffix=suffix,
             wandb_api_key=wandb_api_key,
+            train_on_inputs=train_on_inputs,
         )

         if verbose:
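A hedged usage sketch of the updated create() call. The file ID and model name below are placeholders, and a TOGETHER_API_KEY environment variable is assumed; only the parameters shown in the diff above are used.

from together import Together

client = Together()  # reads TOGETHER_API_KEY from the environment

job = client.fine_tuning.create(
    training_file="file-xxxxxxxxxxxx",       # placeholder: ID of a previously uploaded file
    model="example-org/example-base-model",  # placeholder model name
    n_epochs=1,
    train_on_inputs="auto",                  # or True / False to force the masking behavior
)
print(job.id)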

src/together/types/finetune.py

Lines changed: 3 additions & 1 deletion
@@ -3,7 +3,7 @@
 from enum import Enum
 from typing import List, Literal

-from pydantic import Field, validator, field_validator
+from pydantic import StrictBool, Field, validator, field_validator

 from together.types.abstract import BaseModel
 from together.types.common import (
@@ -163,6 +163,7 @@ class FinetuneRequest(BaseModel):
     # weights & biases api key
     wandb_key: str | None = None
     training_type: FullTrainingType | LoRATrainingType | None = None
+    train_on_inputs: StrictBool | Literal["auto"] = "auto"


 class FinetuneResponse(BaseModel):
@@ -230,6 +231,7 @@ class FinetuneResponse(BaseModel):
     # training file metadata
     training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
     training_file_size: int | None = Field(None, alias="TrainingFileSize")
+    train_on_inputs: StrictBool | Literal["auto"] | None = "auto"

     @field_validator("training_type")
     @classmethod