
Commit 5bf0d17

artek0chumak and mryab authored

Add LoRA support for FT API (#156)

* Add lora support
* bump up version
* fix docs
* fix style and typing
* Ignore spurious typing errors

Co-authored-by: Max Ryabinin <[email protected]>
1 parent 0409565 commit 5bf0d17

File tree

5 files changed: +105 -7 lines changed

pyproject.toml (+1 -1)
src/together/cli/api/finetune.py (+42 -1)
src/together/resources/finetune.py (+23)
src/together/types/__init__.py (+6)
src/together/types/finetune.py (+33 -5)
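
Taken together, the commit surfaces LoRA fine-tuning through both the CLI and the Python client. Below is a minimal sketch of the new client call, assuming a valid TOGETHER_API_KEY is available in the environment; the training file ID and model name are placeholders, and the lora_* values shown are simply the defaults introduced in this commit.

from together import Together

client = Together()  # assumes the API key is picked up from the environment

# Placeholder file ID and model name; substitute your own uploaded file and base model.
response = client.fine_tuning.create(
    training_file="file-abc123",
    model="example-org/example-base-model",
    lora=True,                            # request a LoRA job instead of full fine-tuning
    lora_r=8,                             # adapter rank (default)
    lora_alpha=8,                         # adapter alpha (default)
    lora_dropout=0,                       # adapter dropout (default)
    lora_trainable_modules="all-linear",  # or e.g. "q_proj,v_proj"
)

# The response now carries a typed training_type (Full or LoRA), per the schema changes below.
print(response.training_type)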

pyproject.toml

+1 -1 lines changed

@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.2.2"
+version = "1.2.3"
 authors = [
     "Together AI <[email protected]>"
 ]

src/together/cli/api/finetune.py

+42 -1 lines changed

@@ -2,6 +2,7 @@
 from textwrap import wrap
 
 import click
+from click.core import ParameterSource  # type: ignore[attr-defined]
 from tabulate import tabulate
 
 from together import Together
@@ -26,7 +27,22 @@ def fine_tuning(ctx: click.Context) -> None:
     "--n-checkpoints", type=int, default=1, help="Number of checkpoints to save"
 )
 @click.option("--batch-size", type=int, default=32, help="Train batch size")
-@click.option("--learning-rate", type=float, default=3e-5, help="Learning rate")
+@click.option("--learning-rate", type=float, default=1e-5, help="Learning rate")
+@click.option(
+    "--lora/--no-lora",
+    type=bool,
+    default=False,
+    help="Whether to use LoRA adapters for fine-tuning",
+)
+@click.option("--lora-r", type=int, default=8, help="LoRA adapters' rank")
+@click.option("--lora-dropout", type=float, default=0, help="LoRA adapters' dropout")
+@click.option("--lora-alpha", type=float, default=8, help="LoRA adapters' alpha")
+@click.option(
+    "--lora-trainable-modules",
+    type=str,
+    default="all-linear",
+    help="Trainable modules for LoRA adapters. For example, 'all-linear', 'q_proj,v_proj'",
+)
 @click.option(
     "--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
 )
@@ -39,19 +55,44 @@ def create(
     n_checkpoints: int,
     batch_size: int,
     learning_rate: float,
+    lora: bool,
+    lora_r: int,
+    lora_dropout: float,
+    lora_alpha: float,
+    lora_trainable_modules: str,
     suffix: str,
     wandb_api_key: str,
 ) -> None:
     """Start fine-tuning"""
     client: Together = ctx.obj
 
+    if lora:
+        learning_rate_source = click.get_current_context().get_parameter_source(  # type: ignore[attr-defined]
+            "learning_rate"
+        )
+        if learning_rate_source == ParameterSource.DEFAULT:
+            learning_rate = 1e-3
+    else:
+        for param in ["lora_r", "lora_dropout", "lora_alpha", "lora_trainable_modules"]:
+            param_source = click.get_current_context().get_parameter_source(param)  # type: ignore[attr-defined]
+            if param_source != ParameterSource.DEFAULT:
+                raise click.BadParameter(
+                    f"You set LoRA parameter `{param}` for a full fine-tuning job. "
+                    f"Please change the job type with --lora or remove `{param}` from the arguments"
+                )
+
     response = client.fine_tuning.create(
         training_file=training_file,
         model=model,
         n_epochs=n_epochs,
         n_checkpoints=n_checkpoints,
         batch_size=batch_size,
         learning_rate=learning_rate,
+        lora=lora,
+        lora_r=lora_r,
+        lora_dropout=lora_dropout,
+        lora_alpha=lora_alpha,
+        lora_trainable_modules=lora_trainable_modules,
        suffix=suffix,
        wandb_api_key=wandb_api_key,
    )
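
The only non-obvious part of the CLI change is how it decides whether the user explicitly passed --learning-rate: click records a ParameterSource for every parameter, and ParameterSource.DEFAULT means the value came from the option's default rather than the command line. Below is a distilled, standalone sketch of that pattern, not the repository's code; the command name and echoed output are illustrative only.

import click
from click.core import ParameterSource


@click.command()
@click.option("--lora/--no-lora", default=False)
@click.option("--learning-rate", type=float, default=1e-5)
def train(lora: bool, learning_rate: float) -> None:
    # DEFAULT means --learning-rate was not passed explicitly,
    # so a LoRA job silently switches to the LoRA-friendly default.
    source = click.get_current_context().get_parameter_source("learning_rate")
    if lora and source == ParameterSource.DEFAULT:
        learning_rate = 1e-3
    click.echo(f"lora={lora} learning_rate={learning_rate}")


if __name__ == "__main__":
    train()

Invoked with --lora alone, this sketch would report 0.001, while --lora --learning-rate 2e-4 keeps the explicit value.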

src/together/resources/finetune.py

+23 lines changed

@@ -11,8 +11,11 @@
     FinetuneListEvents,
     FinetuneRequest,
     FinetuneResponse,
+    FullTrainingType,
+    LoRATrainingType,
     TogetherClient,
     TogetherRequest,
+    TrainingType,
 )
 from together.utils import normalize_key
 
@@ -30,6 +33,11 @@ def create(
         n_checkpoints: int | None = 1,
         batch_size: int | None = 32,
         learning_rate: float | None = 0.00001,
+        lora: bool = True,
+        lora_r: int | None = 8,
+        lora_dropout: float | None = 0,
+        lora_alpha: float | None = 8,
+        lora_trainable_modules: str | None = "all-linear",
         suffix: str | None = None,
         wandb_api_key: str | None = None,
     ) -> FinetuneResponse:
@@ -45,6 +53,11 @@ def create(
            batch_size (int, optional): Batch size for fine-tuning. Defaults to 32.
            learning_rate (float, optional): Learning rate multiplier to use for training
                Defaults to 0.00001.
+            lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
+            lora_r (int, optional): Rank of LoRA adapters. Defaults to 8.
+            lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
+            lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
+            lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
            suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
                Defaults to None.
            wandb_api_key (str, optional): API key for Weights & Biases integration.
@@ -58,13 +71,23 @@ def create(
             client=self._client,
         )
 
+        training_type: TrainingType = FullTrainingType()
+        if lora:
+            training_type = LoRATrainingType(
+                lora_r=lora_r,
+                lora_alpha=lora_alpha,
+                lora_dropout=lora_dropout,
+                lora_trainable_modules=lora_trainable_modules,
+            )
+
         parameter_payload = FinetuneRequest(
             model=model,
             training_file=training_file,
             n_epochs=n_epochs,
             n_checkpoints=n_checkpoints,
             batch_size=batch_size,
             learning_rate=learning_rate,
+            training_type=training_type,
             suffix=suffix,
             wandb_key=wandb_api_key,
         ).model_dump()
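
In other words, the client no longer forwards loose lora_* keyword arguments to the API; it folds them into a typed training_type object before building the request payload. A small sketch of that mapping using the new pydantic models, with the default values from this commit:

from together.types import FullTrainingType, LoRATrainingType, TrainingType

lora = True  # mirrors the `lora` argument of fine_tuning.create

training_type: TrainingType = FullTrainingType()
if lora:
    training_type = LoRATrainingType(
        lora_r=8,
        lora_alpha=8,
        lora_dropout=0,
        lora_trainable_modules="all-linear",
    )

# Serializes with a "type" discriminator of "Lora" (or "Full"), plus the
# lora_* fields when applicable; key order may differ.
print(training_type.model_dump())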

src/together/types/__init__.py

+6 lines changed

@@ -26,6 +26,9 @@
     FinetuneListEvents,
     FinetuneRequest,
     FinetuneResponse,
+    FullTrainingType,
+    LoRATrainingType,
+    TrainingType,
 )
 from together.types.images import (
     ImageRequest,
@@ -60,4 +63,7 @@
     "ImageRequest",
     "ImageResponse",
     "ModelObject",
+    "TrainingType",
+    "FullTrainingType",
+    "LoRATrainingType",
 ]
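
The re-exports simply make the new training-type models importable from together.types (and, via __all__, part of the public surface). A quick check, assuming the package is installed:

from together.types import FullTrainingType, LoRATrainingType, TrainingType

# Both concrete types derive from the abstract TrainingType base model.
assert issubclass(FullTrainingType, TrainingType)
assert issubclass(LoRATrainingType, TrainingType)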

src/together/types/finetune.py

+33 -5 lines changed

@@ -100,6 +100,34 @@ class FinetuneEvent(BaseModel):
     hash: str | None = None
 
 
+class TrainingType(BaseModel):
+    """
+    Abstract training type
+    """
+
+    type: str
+
+
+class FullTrainingType(TrainingType):
+    """
+    Training type for full fine-tuning
+    """
+
+    type: str = "Full"
+
+
+class LoRATrainingType(TrainingType):
+    """
+    Training type for LoRA adapters training
+    """
+
+    lora_r: int
+    lora_alpha: int
+    lora_dropout: float
+    lora_trainable_modules: str
+    type: str = "Lora"
+
+
 class FinetuneRequest(BaseModel):
     """
     Fine-tune request type
@@ -121,6 +149,7 @@ class FinetuneRequest(BaseModel):
     suffix: str | None = None
     # weights & biases api key
     wandb_key: str | None = None
+    training_type: FullTrainingType | LoRATrainingType | None = None
 
 
 class FinetuneResponse(BaseModel):
@@ -138,6 +167,8 @@ class FinetuneResponse(BaseModel):
     model: str | None = None
     # output model name
     output_name: str | None = Field(None, alias="model_output_name")
+    # adapter output name
+    adapter_output_name: str | None = None
     # number of epochs
     n_epochs: int | None = None
     # number of checkpoints to save
@@ -148,11 +179,8 @@ class FinetuneResponse(BaseModel):
     learning_rate: float | None = None
     # number of steps between evals
     eval_steps: int | None = None
-    # is LoRA finetune boolean
-    lora: bool | None = None
-    lora_r: int | None = None
-    lora_alpha: int | None = None
-    lora_dropout: int | None = None
+    # training type
+    training_type: FullTrainingType | LoRATrainingType | None = None
     # created/updated datetime stamps
     created_at: str | None = None
     updated_at: str | None = None
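
On the response side, the flat lora/lora_r/lora_alpha/lora_dropout fields are replaced by the same training_type union, and LoRA jobs additionally expose adapter_output_name. A sketch of consuming the reshaped response follows; describe_job is a hypothetical helper written for illustration, not part of the SDK:

from together.types.finetune import FinetuneResponse, LoRATrainingType


def describe_job(response: FinetuneResponse) -> str:
    """Summarize a fine-tuning job using the reshaped response fields."""
    if isinstance(response.training_type, LoRATrainingType):
        return (
            f"LoRA job (r={response.training_type.lora_r}), "
            f"adapter: {response.adapter_output_name}"
        )
    return f"Full fine-tuning job, output model: {response.output_name}"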
