
Commit d4b66d2

full train fixes (instructlab#2382)
- change LR to 2e-5 (learns better)
- honor the gradient accumulation given by the multipack sampler

The LR is the big change here. The previous version of full train has little to no improvement after 8 epochs. This version knows about the new information after as little as 1 epoch. This is a major improvement.

Also add logging to print information about the data.

**Checklist:**

- [ ] **Commit Message Formatting**: Commit titles and messages follow the [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/#summary) guidelines.
- [ ] [Changelog](https://github.com/instructlab/instructlab/blob/main/CHANGELOG.md) updated with breaking and/or notable changes for the next minor release.
- [ ] Documentation has been updated, if necessary.
- [ ] Unit tests have been added, if necessary.
- [ ] Functional tests have been added, if necessary.
- [ ] E2E Workflow tests have been added, if necessary.

Approved-by: jaideepr97
Approved-by: RobotSail
2 parents: 59052f1 + 5d04a5e
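For context, here is a minimal sketch (not the actual instructlab loop) of what the two changes amount to at training time: Adafactor driven at `lr=2e-5`, and an optimizer step taken only once every `accum` packed micro-batches. `model`, `dataloader`, `accum`, and `device` stand in for the objects full_train.py builds.

```python
# Minimal sketch of the two changes, NOT the actual instructlab training loop.
# `model`, `dataloader`, `accum`, and `device` stand in for objects full_train.py builds.
from transformers.optimization import Adafactor


def run_epoch(model, dataloader, accum, device):
    optimizer = Adafactor(
        model.parameters(), lr=2e-5, scale_parameter=False, relative_step=False
    )
    model.train()
    optimizer.zero_grad()
    for step, batch in enumerate(dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        # scale the loss so the accumulated gradient averages over the micro-batches
        loss = model(**batch).loss / accum
        loss.backward()
        # only step the optimizer every `accum` packed micro-batches
        if (step + 1) % accum == 0:
            optimizer.step()
            optimizer.zero_grad()
```

Dividing the loss by `accum` keeps the accumulated gradient on the same scale as a single large batch, which is what lets the sampler-chosen accumulation factor be honored without retuning the rest of the setup.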

File tree

1 file changed (+5, -4 lines)


src/instructlab/model/full_train.py

@@ -64,7 +64,7 @@ def train(train_args, device):
     )
 
     # based on the length of the dataset, figure out the max batch len
-    packing_max_batch_len, _ = (
+    packing_max_batch_len, accum = (
         multipack_sampler.find_packing_max_batch_len_and_grad_accum(
             num_gpus=1,
             avg_sample_len=dataset.get_lengths().mean(),
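The only change here is that the gradient-accumulation factor returned by the sampler is now captured as `accum` instead of being discarded. As a rough, hypothetical illustration of the arithmetic such a helper performs (the real logic lives in the multipack sampler's `find_packing_max_batch_len_and_grad_accum()`; `estimate_packing` below is not its implementation), the effective batch size in samples becomes a per-step token budget, and the accumulation factor grows until that budget fits under `max_batch_len`:

```python
import math


# Hypothetical illustration only; not the multipack sampler's actual code.
def estimate_packing(effective_batch_size, avg_sample_len, max_batch_len, num_gpus=1):
    grad_accum = 1
    while True:
        # token budget each GPU would need per micro-step at this accumulation factor
        packing_max_batch_len = math.ceil(
            effective_batch_size * avg_sample_len / (grad_accum * num_gpus)
        )
        if packing_max_batch_len <= max_batch_len:
            return packing_max_batch_len, grad_accum
        grad_accum += 1
```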
@@ -99,6 +99,9 @@ def train(train_args, device):
     )
     dataloader.multiprocessing_context = "spawn"
 
+    logger.info(
+        f"avg_sample_len: {dataset.get_lengths().mean()}\n effective_batch_size: {train_args.effective_batch_size}\n max_batch_len: {train_args.max_batch_len}\n packing_max_batch_len: {packing_max_batch_len} \n grad_accum: {accum}\n num_batches: {len(dataloader)}\n avg_samples_per_batch: {len(dataset) / len(dataloader)}"
+    )
     # set device based off argument given
     dev = torch.device(device)
     # auto model based on model path
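The new log line exposes the packing bookkeeping. As an illustrative sanity check on those numbers (not part of the patch), they should roughly satisfy effective_batch_size ≈ avg_samples_per_batch × grad_accum, since each optimizer step consumes `accum` packed batches:

```python
# Illustrative check, not part of the patch: `dataset`, `dataloader`, `accum`, and
# `train_args` are the objects already built in full_train.py.
num_batches = len(dataloader)
avg_samples_per_batch = len(dataset) / num_batches
approx_effective_bs = avg_samples_per_batch * accum
logger.info(
    f"sanity: ~{avg_samples_per_batch:.1f} samples/packed batch * grad_accum {accum} "
    f"~= {approx_effective_bs:.1f} (configured effective_batch_size: {train_args.effective_batch_size})"
)
```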
@@ -151,20 +154,18 @@ def train(train_args, device):
     total_ram = memory_info.total / (1024**3)  # Convert to GB
     logger.info(f"Total RAM: {total_ram:.2f} GB")
     # if RAM is <= 16, we need to use fp16 not fp32. This will yield a worse model but will allow the full pipeline to run
-    accum = 1
     if total_ram <= 16:
         # if <= 16GB ram, use gradinent accum and hald precision
         logger.warning(
             f"Your system has {total_ram:.2f} GB of RAM. This is below our reccomendation of 32GB for this type of training. Using half precision."
         )
         model = model.to(dev).half()  # Convert model to float16
-        accum = 4
     else:
         model = model.to(dev)
 
     # adafactor and gradient checkpointing are memory friendly, we opt to use these in the CPU/MPS loop to fit 7b models.
     optimizer = Adafactor(
-        model.parameters(), lr=1e-5, scale_parameter=True, relative_step=False
+        model.parameters(), lr=2e-5, scale_parameter=False, relative_step=False
     )
     model.gradient_checkpointing_enable()
 
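With the hard-coded `accum = 1` / `accum = 4` pair removed, the low-RAM branch now only decides precision, while the sampler-derived `accum` drives accumulation. A condensed sketch of that branch, assuming `memory_info` comes from `psutil.virtual_memory()` (the diff only shows `memory_info.total`, so treat that and the helper name as assumptions):

```python
import psutil


def place_model(model, device, low_ram_threshold_gb=16):
    # below ~16 GB of system RAM, fall back to float16: a worse model,
    # but the full pipeline can still run with a 7B model
    total_ram_gb = psutil.virtual_memory().total / (1024**3)
    if total_ram_gb <= low_ram_threshold_gb:
        return model.to(device).half()
    return model.to(device)
```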
