@@ -269,6 +269,50 @@ def noop(*args, **kwargs):

     return model

+def simple_quantize_weights(weights, bits=4):
+    """Simple quantization function that rounds weights to the specified number of bits."""
+    w_min = weights.min()
+    w_max = weights.max()
+
+    # Calculate scale and zero point for the specified bit width
+    max_val = (2 ** bits) - 1
+    scale = (w_max - w_min) / max_val
+    zero_point = w_min
+
+    # Quantize
+    quantized = torch.round((weights - zero_point) / scale)
+    quantized = torch.clamp(quantized, 0, max_val)
+
+    return quantized.int(), scale, zero_point
+
+def simple_quantize_model(model, bits=4):
+    """Apply simple quantization to all linear layers in the model."""
+    print(f"Applying simple {bits}-bit quantization...")
+
+    for name, module in model.named_modules():
+        if isinstance(module, nn.Linear):
+            print(f"Quantizing {name}...")
+
+            # Quantize weights
+            quantized_weights, scale, zero_point = simple_quantize_weights(module.weight.data, bits)
+
+            # Store quantization parameters
+            module.register_buffer('weight_scale', torch.tensor(scale))
+            module.register_buffer('weight_zero_point', torch.tensor(zero_point))
+            module.register_buffer('weight_quantized', quantized_weights)
+
+            # Override forward method
+            def make_forward(module, scale, zero_point, quantized_weights):
+                def forward(x):
+                    # Dequantize weights on-the-fly
+                    dequantized_weights = quantized_weights.float() * scale + zero_point
+                    return nn.functional.linear(x, dequantized_weights, module.bias)
+                return forward
+
+            module.forward = make_forward(module, scale, zero_point, quantized_weights)
+
+    print("Simple quantization completed!")
+
 def opt_multigpu(model, gpus):
     model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(gpus[0])
     model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(gpus[0])
@@ -307,7 +351,8 @@ def forward(self, *inp, **kwargs):

 def benchmark(model, input_ids, check=False):
     input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
-    torch.cuda.synchronize()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()

     cache = {'past': None}
     def clear_past(i):
@@ -327,9 +372,11 @@ def tmp(layer, inp, out):
     def sync():
         if hasattr(model, 'gpus'):
             for gpu in model.gpus:
-                torch.cuda.synchronize(gpu)
+                if torch.cuda.is_available():
+                    torch.cuda.synchronize(gpu)
         else:
-            torch.cuda.synchronize()
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
     with torch.no_grad():
         attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
         times = []
@@ -432,6 +479,10 @@ def sync():
         '--static-groups', action='store_true',
         help='Whether to use static groups; recommended when using `--actorder` for more efficient inference.'
     )
+    parser.add_argument(
+        '--quantization-type', choices=['gptq', 'simple'], default='gptq',
+        help='Type of quantization to use: gptq (sophisticated) or simple (basic rounding)'
+    )

     args = parser.parse_args()

@@ -446,16 +497,28 @@ def sync():
     )

     if args.wbits < 16 and not args.nearest:
-        tick = time.time()
-        quantizers = opt_sequential(model, dataloader, DEV)
-        print(time.time() - tick)
+        if args.quantization_type == 'gptq':
+            print("Using GPTQ quantization...")
+            tick = time.time()
+            quantizers = opt_sequential(model, dataloader, DEV)
+            print(time.time() - tick)
+        elif args.quantization_type == 'simple':
+            print("Using simple quantization...")
+            simple_quantize_model(model, args.wbits)
+            quantizers = {}  # Empty dict for simple quantization

     if args.benchmark:
-        gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
-        if len(gpus) > 1:
-            opt_multigpu(model, gpus)
-        else:
-            model = model.to(DEV)
+        try:
+            gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
+            if len(gpus) > 1:
+                opt_multigpu(model, gpus)
+            else:
+                model = model.to(DEV)
+        except (AssertionError, RuntimeError):
+            print("CUDA not available, using CPU for benchmarking...")
+            model = model.to('cpu')
+            DEV = torch.device('cpu')
+
         if args.benchmark:
             input_ids = next(iter(dataloader))[0][:, :args.benchmark]
             benchmark(model, input_ids, check=args.check)
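
Not part of the diff above, but a quick standalone sketch of what the new `simple_quantize_weights` scheme does: each weight tensor is mapped onto the 2**bits evenly spaced levels between its min and max, and the patched `forward()` reconstructs weights as `q * scale + zero_point`. The snippet below re-implements that mapping outside `opt.py`; the function name, the random test tensor, and the `scale/2` error bound are illustrative and not taken from the PR.

```python
import torch

def minmax_quantize(weights, bits=4):
    # Same affine scheme as simple_quantize_weights in the patch:
    # integer q in [0, 2**bits - 1], with w ~= q * scale + zero_point.
    w_min, w_max = weights.min(), weights.max()
    max_val = (2 ** bits) - 1
    scale = (w_max - w_min) / max_val
    zero_point = w_min
    q = torch.clamp(torch.round((weights - zero_point) / scale), 0, max_val)
    return q.int(), scale, zero_point

w = torch.randn(4, 8)
q, scale, zero_point = minmax_quantize(w, bits=4)
w_hat = q.float() * scale + zero_point  # dequantize, as the overridden forward() does
print("max abs error:", (w - w_hat).abs().max().item(),
      "<= scale/2 =", (scale / 2).item())
```

Because the scale and zero point are computed once per tensor, a single outlier weight widens the quantization step for the whole layer, which is one reason accuracy from the two `--quantization-type` modes should not be expected to match.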