@@ -64,7 +64,7 @@ def train(train_args, device):
     )
 
     # based on the length of the dataset, figure out the max batch len
-    packing_max_batch_len, _ = (
+    packing_max_batch_len, accum = (
         multipack_sampler.find_packing_max_batch_len_and_grad_accum(
             num_gpus=1,
             avg_sample_len=dataset.get_lengths().mean(),
@@ -99,6 +99,9 @@ def train(train_args, device):
     )
     dataloader.multiprocessing_context = "spawn"
 
+    logger.info(
+        f"avg_sample_len: {dataset.get_lengths().mean()}\n effective_batch_size: {train_args.effective_batch_size}\n max_batch_len: {train_args.max_batch_len}\n packing_max_batch_len: {packing_max_batch_len}\n grad_accum: {accum}\n num_batches: {len(dataloader)}\n avg_samples_per_batch: {len(dataset) / len(dataloader)}"
+    )
     # set device based off argument given
     dev = torch.device(device)
     # auto model based on model path
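The `accum` value surfaced in this new log line now comes from `multipack_sampler.find_packing_max_batch_len_and_grad_accum` rather than the hard-coded RAM heuristic removed further down. As a rough illustration only (not the sampler's actual implementation), the logged quantities relate roughly like this: pack about `effective_batch_size / grad_accum` samples of average length `avg_sample_len` into one batch, and raise `grad_accum` until that packed length fits under `max_batch_len`. The helper below is hypothetical and exists only to make that relationship concrete.

```python
# Hypothetical sketch -- NOT the real multipack_sampler logic, just the
# back-of-the-envelope relationship between the values logged above.
def estimate_packing_params(avg_sample_len, effective_batch_size, max_batch_len, num_gpus=1):
    grad_accum = 1
    while True:
        samples_per_minibatch = effective_batch_size / grad_accum
        packing_max_batch_len = int(avg_sample_len * samples_per_minibatch / num_gpus)
        if packing_max_batch_len <= max_batch_len:
            return packing_max_batch_len, grad_accum
        grad_accum += 1


# Example: ~350-token samples, effective batch size 128, 10k-token cap per batch.
print(estimate_packing_params(350.0, 128, 10_000))  # -> (8960, 5)
```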
@@ -151,20 +154,18 @@ def train(train_args, device):
     total_ram = memory_info.total / (1024**3)  # Convert to GB
     logger.info(f"Total RAM: {total_ram:.2f} GB")
     # if RAM is <= 16GB, we need to use fp16, not fp32. This will yield a worse model but will allow the full pipeline to run
-    accum = 1
     if total_ram <= 16:
         # if <= 16GB RAM, use gradient accumulation and half precision
         logger.warning(
             f"Your system has {total_ram:.2f} GB of RAM. This is below our recommendation of 32GB for this type of training. Using half precision."
         )
         model = model.to(dev).half()  # Convert model to float16
-        accum = 4
     else:
         model = model.to(dev)
 
     # adafactor and gradient checkpointing are memory friendly; we opt to use these in the CPU/MPS loop to fit 7b models.
     optimizer = Adafactor(
-        model.parameters(), lr=1e-5, scale_parameter=True, relative_step=False
+        model.parameters(), lr=2e-5, scale_parameter=False, relative_step=False
     )
     model.gradient_checkpointing_enable()
 
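On the optimizer change itself: in the `transformers` implementation of Adafactor, `relative_step=False` disables the optimizer's internal time-step learning-rate schedule so the explicit `lr` is used, and `scale_parameter=False` stops that rate from being rescaled by each parameter's RMS, so the bumped `lr=2e-5` is applied as-is. A minimal sketch of this configuration, using a toy linear layer as a stand-in for the actual model:

```python
import torch
from transformers.optimization import Adafactor

# Toy stand-in for the real causal LM loaded in train().
model = torch.nn.Linear(16, 16)

optimizer = Adafactor(
    model.parameters(),
    lr=2e-5,                # used directly because relative_step=False
    scale_parameter=False,  # do not rescale lr by each parameter's RMS
    relative_step=False,    # disable Adafactor's built-in lr schedule
)

# One dummy optimization step to show the wiring.
loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

`model.gradient_checkpointing_enable()` is a method on `transformers` pretrained models, so it is omitted from this toy example.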