 from textwrap import wrap
 
 import click
+from click.core import ParameterSource  # type: ignore[attr-defined]
 from tabulate import tabulate
 
 from together import Together
@@ -26,7 +27,22 @@ def fine_tuning(ctx: click.Context) -> None:
     "--n-checkpoints", type=int, default=1, help="Number of checkpoints to save"
 )
 @click.option("--batch-size", type=int, default=32, help="Train batch size")
-@click.option("--learning-rate", type=float, default=3e-5, help="Learning rate")
+@click.option("--learning-rate", type=float, default=1e-5, help="Learning rate")
+@click.option(
+    "--lora/--no-lora",
+    type=bool,
+    default=False,
+    help="Whether to use LoRA adapters for fine-tuning",
+)
+@click.option("--lora-r", type=int, default=8, help="LoRA adapters' rank")
+@click.option("--lora-dropout", type=float, default=0, help="LoRA adapters' dropout")
+@click.option("--lora-alpha", type=float, default=8, help="LoRA adapters' alpha")
+@click.option(
+    "--lora-trainable-modules",
+    type=str,
+    default="all-linear",
+    help="Trainable modules for LoRA adapters. For example, 'all-linear', 'q_proj,v_proj'",
+)
 @click.option(
     "--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
 )
@@ -39,19 +55,44 @@ def create(
     n_checkpoints: int,
     batch_size: int,
     learning_rate: float,
+    lora: bool,
+    lora_r: int,
+    lora_dropout: float,
+    lora_alpha: float,
+    lora_trainable_modules: str,
     suffix: str,
     wandb_api_key: str,
 ) -> None:
     """Start fine-tuning"""
     client: Together = ctx.obj
 
+    if lora:
+        learning_rate_source = click.get_current_context().get_parameter_source(  # type: ignore[attr-defined]
+            "learning_rate"
+        )
+        if learning_rate_source == ParameterSource.DEFAULT:
+            learning_rate = 1e-3
+    else:
+        for param in ["lora_r", "lora_dropout", "lora_alpha", "lora_trainable_modules"]:
+            param_source = click.get_current_context().get_parameter_source(param)  # type: ignore[attr-defined]
+            if param_source != ParameterSource.DEFAULT:
+                raise click.BadParameter(
+                    f"You set LoRA parameter `{param}` for a full fine-tuning job. "
+                    f"Please change the job type with --lora or remove `{param}` from the arguments"
+                )
+
     response = client.fine_tuning.create(
         training_file=training_file,
         model=model,
         n_epochs=n_epochs,
         n_checkpoints=n_checkpoints,
         batch_size=batch_size,
         learning_rate=learning_rate,
+        lora=lora,
+        lora_r=lora_r,
+        lora_dropout=lora_dropout,
+        lora_alpha=lora_alpha,
+        lora_trainable_modules=lora_trainable_modules,
        suffix=suffix,
         wandb_api_key=wandb_api_key,
     )
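For reference, a minimal sketch of the equivalent request made directly through the Python SDK, passing the same keyword arguments the updated CLI forwards. The training file ID, model name, and epoch count are illustrative placeholders rather than values from this diff, and the sketch assumes TOGETHER_API_KEY is set in the environment.

from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

# Placeholder file ID and model name; substitute real values.
response = client.fine_tuning.create(
    training_file="file-xxxxxxxxxxxx",
    model="meta-llama/Llama-2-7b-hf",
    n_epochs=1,
    n_checkpoints=1,
    batch_size=32,
    learning_rate=1e-3,  # same value the CLI falls back to when --lora is set and no rate is given
    lora=True,
    lora_r=8,
    lora_dropout=0.0,
    lora_alpha=8.0,
    lora_trainable_modules="all-linear",
    suffix=None,          # mirrors the CLI defaults
    wandb_api_key=None,
)
print(response)

Passing any of the lora_* arguments without lora=True is the CLI-side misuse this diff guards against: the ParameterSource check rejects explicitly set LoRA flags on a full fine-tuning job instead of silently ignoring them.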