forked from tdrussell/qlora-pipe
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpipeline_model.py
428 lines (373 loc) · 19.4 KB
/
pipeline_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
import os
import math
from inspect import signature
import torch
from torch import nn
import torch.nn.functional as F
import transformers
from deepspeed.accelerator import get_accelerator
from transformers.integrations import get_keys_to_not_convert
from deepspeed.runtime.pipe import module as ds_pipe_module
import bitsandbytes as bnb
import accelerate
from hqq.core import quantize as hqq_quantize
from utils import is_main_process
from kernels.cross_entropy_loss import Fast_CrossEntropyLoss
import hqq_utils
def move_data_to_device(module, device):
    """Move the weight tensor of ``module`` to ``device`` and return the tensor it held before.

    LoRA-wrapped layers (anything exposing ``base_layer``) are unwrapped first. HQQ layers store their
    quantized weight in ``W_q``; every other layer is expected to have a ``weight`` parameter.
    The copy is issued with ``non_blocking=True``.
    """
    target = getattr(module, 'base_layer', module)  # unwrap LoRA
    param = target.W_q if hasattr(target, 'W_q') else target.weight  # HQQ vs regular layer
    previous = param.data
    param.data = previous.to(device, non_blocking=True)
    return previous
def set_data(module, data):
    """Install ``data`` as the weight tensor of ``module`` (the inverse of ``move_data_to_device``).

    LoRA wrappers (``base_layer``) are unwrapped; HQQ layers receive the tensor in ``W_q``,
    all other layers in ``weight``.
    """
    target = getattr(module, 'base_layer', module)  # unwrap LoRA
    param = target.W_q if hasattr(target, 'W_q') else target.weight  # HQQ vs regular layer
    param.data = data
def entropy_fn(logits, chunk_size=128):
    """Return the entropy of the categorical distribution defined by each row of ``logits``, as float32.

    Args:
        logits: tensor whose last dimension indexes categories; it is split along dim 0.
        chunk_size: number of rows processed per ``Categorical`` evaluation (default 128).

    Returns:
        The per-row entropies concatenated along dim 0, cast to float32.
    """
    # There is a very wide range of chunk sizes that cause no increase in memory reported by
    # nvidia-smi (Torch re-using blocks of memory?). If you try to compute it as one tensor,
    # memory usage is huge. A chunk size of 128 seems good enough for now.
    chunk_entropies = [
        torch.distributions.Categorical(logits=logits_chunk).entropy()
        for logits_chunk in torch.split(logits, chunk_size)
    ]
    return torch.cat(chunk_entropies).float()
def top_k_accuracy(logits, labels, k_list, ignore_index=-100):
    """Return one top-k accuracy per entry of ``k_list``, skipping positions labelled ``ignore_index``.

    Args:
        logits: scores with the vocabulary on the last dimension; ``labels`` must index all other
            dimensions (e.g. logits ``(N, V)`` with labels ``(N,)``).
        labels: target indices, same shape as ``logits`` without its last dimension.
        k_list: k values to evaluate.
        ignore_index: label value excluded from the accuracy.

    Returns:
        A list of float32 scalar tensors, in ``k_list`` order. If every position is ignored,
        the means are taken over an empty tensor.

    k values larger than the vocabulary size are clamped for ``torch.topk`` (which would otherwise
    raise). For such k every valid label is within the top predictions, so they report 1.0.
    """
    keep = labels != ignore_index
    kept_labels = labels[keep].view(-1, 1)
    max_k = min(max(k_list), logits.size(-1))
    _, top_k_predictions = torch.topk(logits, max_k, dim=-1, sorted=True)
    top_k_predictions = top_k_predictions[keep]
    # Slicing with k > max_k simply takes all max_k columns.
    return [
        torch.any(top_k_predictions[:, :k] == kept_labels, dim=-1).to(torch.float32).mean()
        for k in k_list
    ]
class LayerSpec(ds_pipe_module.LayerSpec):
    """DeepSpeed ``LayerSpec`` that can carry an optional size hint.

    The hint is passed as the private keyword argument ``_estimated_size`` and exposed through the
    ``estimated_size`` property (default 1). It is never forwarded to the layer's constructor.
    """

    def __init__(self, typename, *module_args, **module_kwargs):
        super().__init__(typename, *module_args, **module_kwargs)

    def build(self, log=False):
        """Construct the layer from the stored arguments, excluding ``_estimated_size``.

        The hint is filtered out of a copy of the kwargs instead of being popped, so
        ``estimated_size`` still reports the real value after ``build()`` and repeated builds behave the
        same. ``log`` is accepted (and ignored) so calls using the base-class ``build(log=...)``
        signature keep working.
        """
        build_kwargs = {key: value for key, value in self.module_kwargs.items() if key != '_estimated_size'}
        return self.typename(*self.module_args, **build_kwargs)

    @property
    def estimated_size(self):
        """Size hint given as ``_estimated_size`` at construction, or 1 if none was given."""
        return self.module_kwargs.get('_estimated_size', 1)
# TODO: consider using Liger-Kernel fused loss implementations. The inputs are already set up to support this.
# Would save VRAM, but some metrics could no longer be computed (e.g. entropy, accuracies).
class OutputLayer(nn.Module):
    """Final pipeline stage: applies the LM head, then computes the training loss and metrics.

    ``forward`` takes ``(hidden_states, labels)`` and shifts the labels left by one position, padding the
    last position with -100. It then computes the loss selected by ``loss_type``. Positions whose shifted
    label is negative are excluded from the loss and from the metrics.

    Except in DPO reference mode (which returns a zero scalar tensor), ``forward`` returns::

        (loss, loss_unreduced, entropy, normalised_entropy, log_likelihood, mcfaddens_pseudo_r2,
         top1_accuracy, top5_accuracy, top20_accuracy)
    """

    # Loss types whose formula in forward() divides by focal_loss_gamma; gamma == 0 would yield inf/nan.
    _GAMMA_DIVIDING_LOSS_TYPES = ('focal_loss_star', 'exponentiated_cross_entropy_loss')

    def __init__(
        self,
        pipeline_model,
        loader_util,
        lm_head,
        logit_scale=1.0,
        loss_type='cross_entropy_loss',
        focal_loss_gamma=0,
        tie_weights=None,
        logit_softcapping=None,
    ):
        """
        Args:
            pipeline_model: owning pipeline model. Only the DPO path reads it
                (``train_config['rl']`` and ``dpo_reference_mode``).
            loader_util: object whose ``load_state_dict_into_module(self)`` loads this layer's weights.
            lm_head: output projection, kept under the attribute name ``lm_head``.
            logit_scale: multiplier applied to the hidden states before ``lm_head`` (skipped when 1.0).
            loss_type: name of the loss computed in ``forward`` (lower-cased).
            focal_loss_gamma: gamma for the focal / inverse-focal / exponentiated losses.
            tie_weights: if truthy, stored as ``lm_head.weight.original_name``. This is the checkpoint key the
                loader reads into ``lm_head.weight`` (presumably the tied embedding's name).
            logit_softcapping: if positive, logits are soft-capped as ``cap * tanh(logits / cap)``.

        Raises:
            ValueError: for an invalid ``loss_type`` / ``focal_loss_gamma`` combination. This is checked
                before any weights are loaded, so a bad config fails fast.
        """
        super().__init__()
        # Assign list to prevent registering the nn.Module
        self.pipeline_model = [pipeline_model]
        # Unlike the other wrapper classes, this is called lm_head and not orig. Because this is directly a
        # nn.Linear layer, it needs to keep the same attribute name so quantization knows not to quantize it.
        self.lm_head = lm_head
        self.logit_scale = logit_scale
        self.loss_type = loss_type.lower()
        self.focal_loss_gamma = focal_loss_gamma
        if self.loss_type == 'cross_entropy_loss' and self.focal_loss_gamma != 0:
            raise ValueError("focal_loss_gamma can't be used with 'cross_entropy_loss' function")
        if self.loss_type in self._GAMMA_DIVIDING_LOSS_TYPES and self.focal_loss_gamma == 0:
            raise ValueError(f"'{self.loss_type}' requires a non-zero focal_loss_gamma")
        if tie_weights:
            self.lm_head.weight.original_name = tie_weights
        self.logit_softcapping = logit_softcapping
        loader_util.load_state_dict_into_module(self)

    def forward(self, inputs):
        """Compute the loss and metrics for one micro-batch.

        Args:
            inputs: ``(hidden_states, labels)``. NOTE(review): the label shift below assumes labels are
                ``(batch, seq)`` — confirm against the data pipeline.

        Raises:
            NotImplementedError: for an unknown ``loss_type``.
        """
        hidden_states, labels = inputs
        labels = labels.to(hidden_states.device)
        if self.logit_scale != 1.0:
            hidden_states = hidden_states * self.logit_scale
        logits = self.lm_head(hidden_states)
        if self.logit_softcapping is not None and self.logit_softcapping > 0:
            # Soft-capping: cap * tanh(logits / cap)
            logits = logits / self.logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.logit_softcapping

        # Shift labels left by one so position t is scored against token t+1; the last position gets -100.
        extra_ignored_labels = torch.full((labels.shape[0], 1), -100, device=logits.device)
        labels = torch.hstack((labels[..., 1:], extra_ignored_labels))
        # Flatten the tokens
        vocab_size = logits.size(-1)
        flat_logits = logits.view(-1, vocab_size)
        flat_labels = labels.view(-1)
        flat_loss_mask = (flat_labels >= 0)
        # One value per flattened position; each branch below applies the mask itself.
        cross_entropy_loss = Fast_CrossEntropyLoss.apply(flat_logits, flat_labels)

        loss = None
        if self.loss_type == 'cross_entropy_loss':
            cross_entropy_loss = cross_entropy_loss[flat_loss_mask]
            loss_unreduced = cross_entropy_loss
        elif self.loss_type == 'focal_loss':
            cross_entropy_loss = cross_entropy_loss[flat_loss_mask]
            # See https://arxiv.org/abs/1708.02002 (Section 3)
            p = torch.exp(-cross_entropy_loss)
            loss_unreduced = (1-p)**self.focal_loss_gamma * cross_entropy_loss
        elif self.loss_type == 'focal_loss_star':
            cross_entropy_loss = cross_entropy_loss[flat_loss_mask]
            # See https://arxiv.org/abs/1708.02002 (Appendix A/B)
            # NOTE: The use of Beta makes no sense for the multinomial case as it's invariant to translation
            loss_unreduced = Fast_CrossEntropyLoss.apply(flat_logits, flat_labels, self.focal_loss_gamma)
            loss_unreduced = loss_unreduced[flat_loss_mask]
            loss_unreduced = loss_unreduced / self.focal_loss_gamma
        elif self.loss_type == 'inverse_focal_loss':
            cross_entropy_loss = cross_entropy_loss[flat_loss_mask]
            # See "Rethinking Calibration of Deep Neural Networks: Do Not Be Afraid of Overconfidence" (Section 5.2)
            # NOTE: The alternative of p^gamma (instead of (1+p)^gamma) might be useful for gradient ascent...
            p = torch.exp(-cross_entropy_loss)
            loss_unreduced = (1+p)**self.focal_loss_gamma * cross_entropy_loss
        elif self.loss_type == 'exponentiated_cross_entropy_loss':
            cross_entropy_loss = cross_entropy_loss[flat_loss_mask]
            # See "Gradient as a Foundation for Building a Loss Function" (Section III.B)
            # NOTE: This is a generalisation of their "Quadratic Cross-Entropy" loss (QCE: gamma=2, CE: gamma=1, etc).
            loss_unreduced = cross_entropy_loss**self.focal_loss_gamma / self.focal_loss_gamma
        elif self.loss_type == 'dpo':
            rl_config = self.pipeline_model[0].train_config['rl']
            cross_entropy_loss = cross_entropy_loss.view_as(labels)  # unflatten
            loss_mask = (labels >= 0)
            # Sequence log-prob = -(sum of per-token CE over unmasked positions).
            logps = -(cross_entropy_loss * loss_mask).sum(-1)
            # The first half of the batch is treated as chosen completions, the second half as rejected.
            half = cross_entropy_loss.size(0) // 2
            chosen_logps = logps[:half]
            rejected_logps = logps[half:]
            if self.pipeline_model[0].dpo_reference_mode:
                # Reference pass: remember the reference log-probs for the following policy pass.
                self.reference_chosen_logps = chosen_logps.detach()
                self.reference_rejected_logps = rejected_logps.detach()
                return torch.tensor(0., device=logits.device)
            # log the language modeling loss metrics on the chosen completion
            cross_entropy_loss = cross_entropy_loss[:half].flatten()[loss_mask[:half].flatten()]
            loss_unreduced = cross_entropy_loss
            flat_logits = logits[:half].view(-1, vocab_size)
            flat_labels = labels[:half].view(-1)
            flat_loss_mask = (flat_labels >= 0)
            policy_chosen_logps = chosen_logps
            policy_rejected_logps = rejected_logps
            pi_logratios = policy_chosen_logps - policy_rejected_logps
            # Requires a preceding reference-mode call; otherwise these attributes do not exist.
            ref_logratios = self.reference_chosen_logps - self.reference_rejected_logps
            del self.reference_chosen_logps
            del self.reference_rejected_logps
            dpo_logits = pi_logratios - ref_logratios
            loss = -F.logsigmoid(rl_config['dpo_beta'] * dpo_logits).mean()
        else:
            raise NotImplementedError(self.loss_type)

        with torch.no_grad():
            log_vocab_size = math.log(logits.size(-1))
            entropy = entropy_fn(flat_logits)[flat_loss_mask]
            # Compute normalised entropy so we can compare between models with different vocab sizes
            normalised_entropy = entropy / log_vocab_size
            # Compute the (negative) log-likelihood using the original *UNADJUSTED* Cross-Entropy loss.
            log_likelihood = cross_entropy_loss.mean()
            # Compute McFadden's Pseudo-R² metric using log(vocab_size) as the null log-likelihood.
            mcfaddens_pseudo_r2 = 1 - (log_likelihood / log_vocab_size)
            accuracies = top_k_accuracy(flat_logits, flat_labels, k_list=[1, 5, 20])
        if loss is None:
            # Normal language modeling loss types (e.g. not DPO)
            loss = loss_unreduced.mean()
        loss_unreduced = loss_unreduced.detach()
        return loss, loss_unreduced, entropy, normalised_entropy, log_likelihood, mcfaddens_pseudo_r2, *accuracies
class PipelineModel(nn.Module):
    """Base class holding training/loss configuration and weight loading for pipeline-parallel models.

    NOTE(review): ``__init__`` never calls ``super().__init__()``, yet it immediately uses
    ``self.named_parameters()`` and ``get_keys_to_not_convert(self)``. Presumably subclasses combine it
    with a transformers model class whose ``__init__`` has already run — confirm before changing this.
    """

    def __init__(self, config, quantization_config, model_config):
        """
        Args:
            config: training config dict. Reads ``full_fine_tune``, ``model``, ``loss_type``, ``rl`` and
                ``focal_loss_gamma``.
            quantization_config: forwarded to ``LoaderUtil``.
            model_config: model config; only ``tie_word_embeddings`` is read here.

        Raises:
            NotImplementedError: full fine-tuning was requested for a model with tied embeddings.
        """
        if config['full_fine_tune'] and model_config.tie_word_embeddings:
            raise NotImplementedError('FFT is not supported for models with tied embeddings')
        self.train_config = config
        self.modules_to_not_quantize = get_keys_to_not_convert(self)
        self.loader_util = LoaderUtil(config['model'], quantization_config, self.modules_to_not_quantize)

        # An 'rl' section, when present and truthy, overrides the configured loss type with its method.
        self.loss_type = config.get('loss_type', 'cross_entropy_loss').lower()
        rl_config = config.get('rl', None)
        if rl_config:
            self.loss_type = rl_config['method']

        self.focal_loss_gamma = config.get('focal_loss_gamma', 0)
        if self.focal_loss_gamma > 0 and is_main_process():
            print(f'Optimizing using \'{self.loss_type}\' with gamma={self.focal_loss_gamma}')
        self.dpo_reference_mode = False

        # Tag every parameter with its current name; LoaderUtil uses it as the checkpoint key.
        for param_name, param in self.named_parameters():
            param.original_name = param_name

    # need to override this method
    def to_layer_specs(self):
        """Return the pipeline layer specs for this model; subclasses must implement this."""
        raise NotImplementedError()

    def set_dpo_reference_mode(self, dpo_reference_mode):
        """Toggle DPO reference mode (read by ``OutputLayer`` to record reference log-probs)."""
        self.dpo_reference_mode = dpo_reference_mode
def _partial_module_name_match(full_name, list_to_match):
return any(key in full_name for key in list_to_match)
def _replace_with_quantized_linear(parent_modules_map, name, full_name, quantization_config):
    """Quantize ``parent_modules_map[name]`` with the backend matching ``quantization_config``.

    BitsAndBytes configs go to ``_replace_with_bnb_linear`` and ``hqq_utils.CustomHQQConfig`` to
    ``_replace_with_hqq_linear``.

    Raises:
        NotImplementedError: for any other config type.
    """
    if isinstance(quantization_config, transformers.BitsAndBytesConfig):
        replacer = _replace_with_bnb_linear
    elif isinstance(quantization_config, hqq_utils.CustomHQQConfig):
        replacer = _replace_with_hqq_linear
    else:
        raise NotImplementedError(f'Quantization config not implemented: {quantization_config}')
    replacer(parent_modules_map, name, full_name, quantization_config)
def _replace_with_bnb_linear(parent_modules_map, name, full_name, quantization_config):
'''Replace a Linear layer with a BNB quantized version.'''
if quantization_config.llm_int8_skip_modules is not None and _partial_module_name_match(full_name, quantization_config.llm_int8_skip_modules):
return
module = parent_modules_map[name]
with accelerate.init_empty_weights():
if isinstance(module, nn.Conv1d):
in_features, out_features = module.weight.shape
else:
in_features = module.in_features
out_features = module.out_features
if quantization_config.quantization_method() == "llm_int8":
parent_modules_map[name] = bnb.nn.Linear8bitLt(
in_features,
out_features,
module.bias is not None,
has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
threshold=quantization_config.llm_int8_threshold,
)
else:
extra_kwargs = (
{"quant_storage": quantization_config.bnb_4bit_quant_storage}
if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters)
else {}
)
parent_modules_map[name] = bnb.nn.Linear4bit(
in_features,
out_features,
module.bias is not None,
quantization_config.bnb_4bit_compute_dtype,
compress_statistics=quantization_config.bnb_4bit_use_double_quant,
quant_type=quantization_config.bnb_4bit_quant_type,
**extra_kwargs,
)
# Store the module class in case we need to transpose the weight later
parent_modules_map[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
parent_modules_map[name].requires_grad_(False)
def _replace_with_hqq_linear(parent_modules_map, name, full_name, quantization_config):
'''Replace a Linear layer with a HQQ quantized version.'''
if _partial_module_name_match(full_name, quantization_config.skip_modules):
return
module = parent_modules_map[name]
quant_config_dict = quantization_config.get_dict(full_name)
hqq_linear = hqq_quantize.HQQLinear(
module,
quant_config=quant_config_dict,
compute_dtype=quantization_config.compute_dtype,
device=module.weight.device,
initialize=True,
del_orig=True
)
# Quantization itself uses a decent amount of VRAM. Temporarily move each quantized parameter to the CPU as we
# finish, so the quant process doesn't OOM. Deepspeed will move everything to the correct device later.
hqq_linear.W_q.data = hqq_linear.W_q.data.to('cpu')
# Store the module class in case we need to transpose the weight later
hqq_linear.source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
hqq_linear.requires_grad_(False)
parent_modules_map[name] = hqq_linear
# modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/bitsandbytes.py
def _recursively_replace_with_quantized_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
):
"""
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if (isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d)) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)
if not any(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
_replace_with_quantized_linear(model._modules, name, current_key_name_str, quantization_config)
# copy over the original_name attribute we added earlier (needed for loading weights)
for orig_name, orig_p in module.named_parameters():
if hasattr(orig_p, 'original_name'):
for new_name, new_p in model._modules[name].named_parameters():
if new_name == orig_name:
new_p.original_name = orig_p.original_name
if len(list(module.children())) > 0:
_recursively_replace_with_quantized_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
)
# Remove the last key for recursion
current_key_name.pop(-1)
class LoaderUtil:
    """Loads weights from a local safetensors checkpoint into pipeline layers, quantizing them as configured.

    Args:
        model_path: local model directory, either sharded (with a safetensors index file) or holding a
            single safetensors file.
        quantization_config: a ``transformers.BitsAndBytesConfig``, an HQQ config, or None for no quantization.
        modules_to_not_quantize: module name(s) passed on as ``modules_to_not_convert``.

    Raises:
        RuntimeError: if the LOCAL_RANK environment variable is not set.
    """

    def __init__(self, model_path, quantization_config, modules_to_not_quantize):
        self.model_path = model_path
        self.quantization_config = quantization_config
        self.modules_to_not_quantize = modules_to_not_quantize
        # int(None) would raise an opaque TypeError, so check for a missing variable explicitly.
        local_rank = os.environ.get("LOCAL_RANK")
        if local_rank is None:
            raise RuntimeError("LOCAL_RANK environment variable is not set")
        self.local_rank = int(local_rank)
        self.device = get_accelerator().device_name(self.local_rank)
        index_file = os.path.join(model_path, transformers.utils.SAFE_WEIGHTS_INDEX_NAME)
        if os.path.exists(index_file):
            # Only the index metadata (its 'weight_map' is read later) is needed; the shard list is unused.
            _, self.checkpoint_metadata = transformers.utils.hub.get_checkpoint_shard_files(
                model_path,
                index_file,
                local_files_only=True,
            )
        else:
            self.checkpoint_metadata = None
        # (file name, state dict) of the most recently loaded checkpoint file; exactly one file is cached.
        self.loaded_state_dict = None

    def get_partial_state_dict(self, leaf_file):
        """Return the state dict stored in ``leaf_file`` (relative to ``model_path``).

        The result is reused when the same file is requested twice in a row.
        """
        if self.loaded_state_dict is None or leaf_file != self.loaded_state_dict[0]:
            print(f'loading checkpoint file {leaf_file}')
            state_dict = transformers.modeling_utils.load_state_dict(os.path.join(self.model_path, leaf_file))
            self.loaded_state_dict = (leaf_file, state_dict)
        return self.loaded_state_dict[1]

    def maybe_quantize(self, module):
        """Replace ``module``'s eligible linear layers with quantized ones; no-op without a quantization config."""
        if self.quantization_config is None:
            return
        modules_to_not_convert = self.modules_to_not_quantize
        if not isinstance(modules_to_not_convert, list):
            modules_to_not_convert = [modules_to_not_convert]
        _recursively_replace_with_quantized_linear(
            module, modules_to_not_convert=modules_to_not_convert, quantization_config=self.quantization_config
        )
        # Make sure to set this or PEFT (and probably other things) will break in strange ways.
        # We only need this because we do the loading and quanting ourselves.
        # NOTE(review): this sets the flag on the LoaderUtil instance, not on the model — confirm that is intended.
        self.is_loaded_in_4bit = True

    def load_state_dict_into_module(self, module):
        """Load ``module``'s parameters from the checkpoint, then move it to this rank's device.

        Each parameter's ``original_name`` attribute is used as its checkpoint key; the attribute is deleted
        afterwards. With a BitsAndBytes config the module is quantized before loading; with any other config
        it is quantized after the weights are loaded and moved to the device.
        """
        print(f'load params into module {type(module)}')
        if isinstance(self.quantization_config, transformers.BitsAndBytesConfig):
            # bnb needs to replace with quantized linear before weights are loaded
            self.maybe_quantize(module)

        param_renaming_map = {p.original_name: new_name for new_name, p in module.named_parameters()}
        expected_keys = [p.original_name for p in module.parameters()]
        # If we have any extra attributes on the parameter, loading with BNB 4bit params breaks, so delete them.
        for p in module.parameters():
            del p.original_name

        if self.checkpoint_metadata is not None:
            weight_map = self.checkpoint_metadata['weight_map']
            # NOTE(review): presumably strips a wrapper 'orig.' prefix from the key — confirm.
            # Sorted so the load order is deterministic across processes (set order over str is not). Shards
            # are then visited in order, so consecutive layers can reuse the one-file cache.
            needed_checkpoint_files = sorted({weight_map[key.replace('orig.', '')] for key in expected_keys})
        else:
            needed_checkpoint_files = [transformers.utils.SAFE_WEIGHTS_NAME]

        for checkpoint_file in needed_checkpoint_files:
            state_dict = self.get_partial_state_dict(checkpoint_file)
            renamed_state_dict = {param_renaming_map[k]: v for k, v in state_dict.items() if k in param_renaming_map}
            # Use some transformers internals to avoid writing a bunch of code ourselves.
            # Might be a bit brittle...
            transformers.modeling_utils._load_state_dict_into_meta_model(
                module,
                renamed_state_dict,
                '',
                list(renamed_state_dict.keys()),
            )

        module.to(self.device)
        if not isinstance(self.quantization_config, transformers.BitsAndBytesConfig):
            self.maybe_quantize(module)