
Commit 5d04a5e

full train fixes
- change LR to 2e-5 (learns better)
- honor the gradient accumulation given by the multipack sampler

Signed-off-by: Charlie Doern <[email protected]>
1 parent 62d35eb commit 5d04a5e
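
Why honoring the sampler's value matters: the multipack sampler derives both the packing length and a matching gradient-accumulation count from the requested effective batch size, so overriding that count with a hard-coded constant silently changes the effective batch size the run actually trains with. Below is a minimal, illustrative sketch of how a sampler-provided accumulation value is typically consumed in a training loop; it is not the instructlab loop, and `run_epoch`, `loss_fn`, and the dataloader shape are assumptions for illustration only.

# Illustrative sketch (not instructlab's training loop): honor a
# sampler-provided gradient-accumulation value `accum` instead of a
# hard-coded constant.
def run_epoch(model, dataloader, optimizer, loss_fn, accum, device):
    model.train()
    optimizer.zero_grad()
    for step, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        loss = loss_fn(model(inputs), labels)
        # Scale the loss so the accumulated gradients match one large batch.
        (loss / accum).backward()
        # Step the optimizer only every `accum` micro-batches.
        if (step + 1) % accum == 0:
            optimizer.step()
            optimizer.zero_grad()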

File tree: 1 file changed (+5, -4)


src/instructlab/model/full_train.py

+5 -4
@@ -64,7 +64,7 @@ def train(train_args, device):
     )

     # based on the length of the dataset, figure out the max batch len
-    packing_max_batch_len, _ = (
+    packing_max_batch_len, accum = (
         multipack_sampler.find_packing_max_batch_len_and_grad_accum(
             num_gpus=1,
             avg_sample_len=dataset.get_lengths().mean(),
@@ -99,6 +99,9 @@ def train(train_args, device):
     )
     dataloader.multiprocessing_context = "spawn"

+    logger.info(
+        f"avg_sample_len: {dataset.get_lengths().mean()}\n effective_batch_size: {train_args.effective_batch_size}\n max_batch_len: {train_args.max_batch_len}\n packing_max_batch_len: {packing_max_batch_len} \n grad_accum: {accum}\n num_batches: {len(dataloader)}\n avg_samples_per_batch: {len(dataset) / len(dataloader)}"
+    )
     # set device based off argument given
     dev = torch.device(device)
     # auto model based on model path
@@ -151,20 +154,18 @@ def train(train_args, device):
     total_ram = memory_info.total / (1024**3) # Convert to GB
     logger.info(f"Total RAM: {total_ram:.2f} GB")
     # if RAM is <= 16, we need to use fp16 not fp32. This will yield a worse model but will allow the full pipeline to run
-    accum = 1
     if total_ram <= 16:
         # if <= 16GB ram, use gradinent accum and hald precision
         logger.warning(
             f"Your system has {total_ram:.2f} GB of RAM. This is below our reccomendation of 32GB for this type of training. Using half precision."
         )
         model = model.to(dev).half() # Convert model to float16
-        accum = 4
     else:
         model = model.to(dev)

     # adafactor and gradient checkpointing are memory friendly, we opt to use these in the CPU/MPS loop to fit 7b models.
     optimizer = Adafactor(
-        model.parameters(), lr=1e-5, scale_parameter=True, relative_step=False
+        model.parameters(), lr=2e-5, scale_parameter=False, relative_step=False
     )
     model.gradient_checkpointing_enable()
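
On the optimizer change: with HuggingFace's Adafactor, relative_step=False turns off the optimizer's internal learning-rate schedule so an explicit lr must be supplied, and scale_parameter=False stops that lr from being rescaled by each parameter's RMS, so the new 2e-5 is applied as given. A minimal sketch of the resulting setup, assuming `model` already exists and that Adafactor here is the HuggingFace transformers implementation:

# Sketch of the optimizer construction after this commit.
from transformers.optimization import Adafactor

optimizer = Adafactor(
    model.parameters(),
    lr=2e-5,                # explicit learning rate from this commit
    scale_parameter=False,  # use lr as given, no per-parameter RMS scaling
    relative_step=False,    # disable Adafactor's internal lr schedule
)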
