
Commit 1faa861

Updated hyperparameters for finetuning

* Updated CLI hyperparameters for fine tuning
* minor fix
* Comment out parameters that have not yet been implemented
* minor fix: default value for n_epochs

1 parent 71277be commit 1faa861

File tree: 2 files changed, +76 -70 lines

src/together/commands/finetune.py (+59 -54)
@@ -33,13 +33,13 @@ def _add_create(parser: argparse._SubParsersAction[argparse.ArgumentParser]) ->
         required=True,
         type=str,
     )
-    create_finetune_parser.add_argument(
-        "--validation-file",
-        "-v",
-        default=None,
-        help="The ID of an uploaded file that contains validation data.",
-        type=str,
-    )
+    # create_finetune_parser.add_argument(
+    #     "--validation-file",
+    #     "-v",
+    #     default=None,
+    #     help="The ID of an uploaded file that contains validation data.",
+    #     type=str,
+    # )
     create_finetune_parser.add_argument(
         "--model",
         "-m",
@@ -57,53 +57,57 @@ def _add_create(parser: argparse._SubParsersAction[argparse.ArgumentParser]) ->
     create_finetune_parser.add_argument(
         "--batch-size",
         "-b",
-        default=None,
+        default=32,
         help="The batch size to use for training.",
         type=int,
     )
     create_finetune_parser.add_argument(
-        "--learning-rate-multiplier",
-        "-lrm",
-        default=None,
+        "--learning-rate",
+        "-lr",
+        default=0.00001,
         help="The learning rate multiplier to use for training.",
         type=float,
     )
-    create_finetune_parser.add_argument(
-        "--prompt-loss-weight",
-        "-plw",
-        default=0.01,
-        help="The weight to use for loss on the prompt tokens.",
-        type=float,
-    )
-    create_finetune_parser.add_argument(
-        "--compute-classification-metrics",
-        "-ccm",
-        default=False,
-        action="store_true",
-        help="Calculate classification-specific metrics using the validation set.",
-    )
-    create_finetune_parser.add_argument(
-        "--classification-n-classes",
-        "-cnc",
-        default=None,
-        help="The number of classes in a classification task.",
-        type=int,
-    )
-    create_finetune_parser.add_argument(
-        "--classification-positive-class",
-        "-cpc",
-        default=None,
-        help="The positive class in binary classification.",
-        type=str,
-    )
-    create_finetune_parser.add_argument(
-        "--classification-betas",
-        "-cb",
-        default=None,
-        help="Calculate F-beta scores at the specified beta values.",
-        nargs="+",
-        type=float,
-    )
+    # create_finetune_parser.add_argument(
+    #     "--warmup-steps",
+    #     "-ws",
+    #     default=0,
+    #     help="Warmup steps",
+    #     type=int,
+    # )
+    # create_finetune_parser.add_argument(
+    #     "--train-warmup-steps",
+    #     "-tws",
+    #     default=0,
+    #     help="Train warmup steps",
+    #     type=int,
+    # )
+    # create_finetune_parser.add_argument(
+    #     "--sequence-length",
+    #     "-sl",
+    #     default=2048,
+    #     help="Max sequence length",
+    #     type=int,
+    # )
+    # create_finetune_parser.add_argument(
+    #     "--seed",
+    #     default=42,
+    #     help="Training seed",
+    #     type=int,
+    # )
+    # create_finetune_parser.add_argument(
+    #     "--fp32",
+    #     help="Enable FP32 training. Defaults to false (FP16 training).",
+    #     default=False,
+    #     action="store_true",
+    # )
+    # create_finetune_parser.add_argument(
+    #     "--checkpoint-steps",
+    #     "-b",
+    #     default=0,
+    #     help="Number of steps between each checkpoint. Defaults to 0 = checkpoints per epoch.",
+    #     type=int,
+    # )
     create_finetune_parser.add_argument(
         "--suffix",
         "-s",
@@ -244,16 +248,17 @@ def _run_create(args: argparse.Namespace) -> None:
 
     response = finetune.create_finetune(
         training_file=args.training_file, # training file_id
-        validation_file=args.validation_file, # validation file_id
+        # validation_file=args.validation_file, # validation file_id
         model=args.model,
         n_epochs=args.n_epochs,
         batch_size=args.batch_size,
-        learning_rate_multiplier=args.learning_rate_multiplier,
-        prompt_loss_weight=args.prompt_loss_weight,
-        compute_classification_metrics=args.compute_classification_metrics,
-        classification_n_classes=args.classification_n_classes,
-        classification_positive_class=args.classification_positive_class,
-        classification_betas=args.classification_betas,
+        learning_rate=args.learning_rate,
+        # warmup_steps=args.warmup_steps,
+        # train_warmup_steps=args.train_warmup_steps,
+        # seq_length=args.sequence_length,
+        # seed=args.seed,
+        # fp16=not args.fp32,
+        # checkpoint_steps=args.checkpoint_steps,
         suffix=args.suffix,
     )

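The net effect on the CLI: --batch-size and --learning-rate now ship concrete defaults (32 and 0.00001) instead of None, and the classification-specific flags are removed rather than commented out. Below is a minimal, self-contained sketch for checking how the new defaults resolve. The --training-file and --n-epochs definitions fall outside the hunks above, so their spellings and short options are assumptions, as is the "together finetune create" invocation; the n_epochs default of 1 is taken from the finetune.py diff below.

import argparse

# Standalone sketch mirroring the updated flags from this commit; it is
# not the actual parser built in src/together/commands/finetune.py.
parser = argparse.ArgumentParser(prog="together finetune create")
parser.add_argument("--training-file", "-t", required=True, type=str)  # spelling assumed
parser.add_argument("--model", "-m", default=None, type=str)
parser.add_argument("--n-epochs", "-ne", default=1, type=int)  # flag assumed; default per finetune.py
parser.add_argument("--batch-size", "-b", default=32, type=int)
parser.add_argument("--learning-rate", "-lr", default=0.00001, type=float)
parser.add_argument("--suffix", "-s", default=None, type=str)

# With no overrides, the commit's new defaults apply.
args = parser.parse_args(["--training-file", "file-abc123"])
assert args.batch_size == 32 and args.learning_rate == 0.00001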
src/together/finetune.py (+17 -16)
@@ -29,30 +29,31 @@ def __init__(
     def create_finetune(
         self,
         training_file: str, # training file_id
-        validation_file: Optional[str] = None, # validation file_id
+        # validation_file: Optional[str] = None, # validation file_id
         model: Optional[str] = None,
-        n_epochs: Optional[int] = 4,
-        batch_size: Optional[int] = None,
-        learning_rate_multiplier: Optional[float] = None,
-        prompt_loss_weight: Optional[float] = 0.01,
-        compute_classification_metrics: Optional[bool] = False,
-        classification_n_classes: Optional[int] = None,
-        classification_positive_class: Optional[str] = None,
-        classification_betas: Optional[List[Any]] = None,
+        n_epochs: Optional[int] = 1,
+        batch_size: Optional[int] = 32,
+        learning_rate: Optional[float] = 0.00001,
+        # warmup_steps: Optional[int] = 0,
+        # train_warmup_steps: Optional[int] = 0,
+        # seq_length: Optional[int] = 2048,
+        # seed: Optional[int] = 42,
+        # fp16: Optional[bool] = True,
+        # checkpoint_steps: Optional[int] = None,
         suffix: Optional[str] = None,
     ) -> Dict[Any, Any]:
         parameter_payload = {
             "training_file": training_file,
-            "validation_file": validation_file,
+            # "validation_file": validation_file,
             "model": model,
             "n_epochs": n_epochs,
             "batch_size": batch_size,
-            "learning_rate_multiplier": learning_rate_multiplier,
-            "prompt_loss_weight": prompt_loss_weight,
-            "compute_classification_metrics": compute_classification_metrics,
-            "classification_n_classes": classification_n_classes,
-            "classification_positive_class": classification_positive_class,
-            "classification_betas": classification_betas,
+            "learning_rate": learning_rate,
+            # "warmup_steps": warmup_steps,
+            # "train_warmup_steps": train_warmup_steps,
+            # "seq_length": seq_length,
+            # "seed": seed,
+            # "fp16": fp16,
             "suffix": suffix,
         }

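On the API side, learning_rate replaces learning_rate_multiplier, n_epochs drops from 4 to 1, batch_size gains a default of 32, and validation_file is no longer accepted. A hedged call sketch against the new signature, assuming the enclosing class in src/together/finetune.py is importable as Finetune and needs no required constructor arguments (neither detail is shown in this diff); keyword names and defaults come straight from the hunk above.

from together.finetune import Finetune  # import path and class name assumed

client = Finetune()  # constructor arguments, if any, are not shown in this diff
response = client.create_finetune(
    training_file="file-abc123",  # placeholder ID of an uploaded training file
    model="example-base-model",   # placeholder model name
    n_epochs=1,                   # new default (was 4)
    batch_size=32,                # new default (was None)
    learning_rate=0.00001,        # replaces learning_rate_multiplier
    suffix="my-experiment",
)
print(response)  # returns Dict[Any, Any] per the annotation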