
Commit d2b5822

Added simple Quant and GPTQ separately via cmdline
1 parent 8980e47 commit d2b5822


opt.py

Lines changed: 74 additions & 11 deletions
@@ -269,6 +269,50 @@ def noop(*args, **kwargs):
 
     return model
 
+def simple_quantize_weights(weights, bits=4):
+    """Simple quantization function that rounds weights to specified bits"""
+    w_min = weights.min()
+    w_max = weights.max()
+
+    # Calculate scale and zero point for specified bits
+    max_val = (2 ** bits) - 1
+    scale = (w_max - w_min) / max_val
+    zero_point = w_min
+
+    # Quantize
+    quantized = torch.round((weights - zero_point) / scale)
+    quantized = torch.clamp(quantized, 0, max_val)
+
+    return quantized.int(), scale, zero_point
+
+def simple_quantize_model(model, bits=4):
+    """Apply simple quantization to all linear layers in the model"""
+    print(f"Applying simple {bits}-bit quantization...")
+
+    for name, module in model.named_modules():
+        if isinstance(module, nn.Linear):
+            print(f"Quantizing {name}...")
+
+            # Quantize weights
+            quantized_weights, scale, zero_point = simple_quantize_weights(module.weight.data, bits)
+
+            # Store quantization parameters
+            module.register_buffer('weight_scale', torch.tensor(scale))
+            module.register_buffer('weight_zero_point', torch.tensor(zero_point))
+            module.register_buffer('weight_quantized', quantized_weights)
+
+            # Override forward method
+            def make_forward(module, scale, zero_point, quantized_weights):
+                def forward(x):
+                    # Dequantize weights on-the-fly
+                    dequantized_weights = quantized_weights.float() * scale + zero_point
+                    return nn.functional.linear(x, dequantized_weights, module.bias)
+                return forward
+
+            module.forward = make_forward(module, scale, zero_point, quantized_weights)
+
+    print("Simple quantization completed!")
+
 def opt_multigpu(model, gpus):
     model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(gpus[0])
     model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(gpus[0])
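For reference, the added simple_quantize_weights and the overridden forward implement plain per-tensor affine (min-max) quantization. A minimal standalone sketch of the same quantize/dequantize round trip, outside the diff (illustrative only; the roundtrip function and names below are not part of the commit):

    import torch

    def roundtrip(weights, bits=4):
        # Per-tensor affine quantization: q = clamp(round((w - w_min) / scale), 0, 2**bits - 1)
        # with scale = (w_max - w_min) / (2**bits - 1) and zero_point = w_min.
        w_min, w_max = weights.min(), weights.max()
        max_val = (2 ** bits) - 1
        scale = (w_max - w_min) / max_val
        q = torch.clamp(torch.round((weights - w_min) / scale), 0, max_val)
        # Dequantize the way the overridden forward() does: w_hat = q * scale + zero_point.
        return q * scale + w_min

    w = torch.randn(8, 8)
    print((w - roundtrip(w, bits=4)).abs().max())  # round-trip error is bounded by scale / 2

The same arithmetic runs inside each patched nn.Linear forward, so activations are computed against dequantized float weights rather than through a true low-bit kernel.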
@@ -307,7 +351,8 @@ def forward(self, *inp, **kwargs):
 
 def benchmark(model, input_ids, check=False):
     input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
-    torch.cuda.synchronize()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
 
     cache = {'past': None}
     def clear_past(i):
@@ -327,9 +372,11 @@ def tmp(layer, inp, out):
     def sync():
         if hasattr(model, 'gpus'):
             for gpu in model.gpus:
-                torch.cuda.synchronize(gpu)
+                if torch.cuda.is_available():
+                    torch.cuda.synchronize()
         else:
-            torch.cuda.synchronize()
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
     with torch.no_grad():
        attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
        times = []
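Both benchmark hunks apply the same guard so the script no longer fails on CPU-only machines, where torch.cuda.synchronize() raises. A small helper expressing the same behavior (a sketch only; cuda_sync is a hypothetical name, not part of the commit):

    def cuda_sync(device=None):
        # No-op on machines without CUDA; otherwise block until queued GPU work finishes.
        if torch.cuda.is_available():
            torch.cuda.synchronize(device)

Note that the patched sync() drops the per-device argument the original torch.cuda.synchronize(gpu) call passed; the helper keeps it available.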
@@ -432,6 +479,10 @@ def sync():
         '--static-groups', action='store_true',
         help='Whether to use static groups; recommended when using `--actorder` for more efficient inference.'
     )
+    parser.add_argument(
+        '--quantization-type', choices=['gptq', 'simple'], default='gptq',
+        help='Type of quantization to use: gptq (sophisticated) or simple (basic rounding)'
+    )
 
     args = parser.parse_args()
 
@@ -446,16 +497,28 @@ def sync():
     )
 
     if args.wbits < 16 and not args.nearest:
-        tick = time.time()
-        quantizers = opt_sequential(model, dataloader, DEV)
-        print(time.time() - tick)
+        if args.quantization_type == 'gptq':
+            print("Using GPTQ quantization...")
+            tick = time.time()
+            quantizers = opt_sequential(model, dataloader, DEV)
+            print(time.time() - tick)
+        elif args.quantization_type == 'simple':
+            print("Using simple quantization...")
+            simple_quantize_model(model, args.wbits)
+            quantizers = {}  # Empty dict for simple quantization
 
     if args.benchmark:
-        gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
-        if len(gpus) > 1:
-            opt_multigpu(model, gpus)
-        else:
-            model = model.to(DEV)
+        try:
+            gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
+            if len(gpus) > 1:
+                opt_multigpu(model, gpus)
+            else:
+                model = model.to(DEV)
+        except (AssertionError, RuntimeError):
+            print("CUDA not available, using CPU for benchmarking...")
+            model = model.to('cpu')
+            DEV = torch.device('cpu')
+
     if args.benchmark:
         input_ids = next(iter(dataloader))[0][:, :args.benchmark]
         benchmark(model, input_ids, check=args.check)
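The new path is selected entirely from the command line, e.g. python opt.py … --wbits 4 --quantization-type simple, with the script's existing positional arguments left unchanged (the invocation shape is an assumption, not shown in the commit). Because the simple path sets quantizers to an empty dict, any later step that consumes GPTQ's per-layer quantizers sees nothing; the trace the simple path leaves is the buffers registered on each nn.Linear. A hedged sketch of detecting that (is_simple_quantized is a hypothetical helper, not part of the commit):

    import torch.nn as nn

    def is_simple_quantized(model):
        # True if at least one Linear layer carries the 'weight_quantized' buffer
        # that simple_quantize_model() registers.
        return any(
            isinstance(m, nn.Linear) and hasattr(m, 'weight_quantized')
            for m in model.modules()
        )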
