@@ -9,6 +9,7 @@
 from .utils import is_divisible, encode_safetensor_type, decode_safetensor_type
 from .optimize import optimize_weights_proximal
 from .bitpack import BitPack
+from termcolor import colored
 
 _META_TYPE = {
     "scale": torch.Tensor,
@@ -386,6 +387,8 @@ def __init__(
         self.ready = False
         self.in_gpu = False
         self.bias = None
+        self.axis = None
+        self.channel_wise = None
         self.device = device
         self.compute_dtype = compute_dtype
         self.quant_config = copy.deepcopy(quant_config)
@@ -408,6 +411,9 @@ def __init__(
         if initialize:
             self.initialize()
 
+    def is_initialized(self):
+        return False if (None in [self.W_q, self.meta]) else True
+
     def initialize(self):
         if self.linear_layer is not None:
             self.quantize(self.linear_layer.weight.data, **self.quant_config)
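The new is_initialized() helper makes the quantized state queryable before and after quantize() runs. A minimal sketch of the intended flow, assuming hqq's public HQQLinear / BaseQuantizeConfig API (not part of this diff) and a CUDA device for the default backend:

# Sketch only -- assumes hqq's public API; names outside this diff are assumptions.
import torch.nn as nn
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig

cfg = BaseQuantizeConfig(nbits=4, group_size=64)
layer = HQQLinear(nn.Linear(128, 128), quant_config=cfg, initialize=False)
print(layer.is_initialized())  # False: W_q and meta are still None
layer.initialize()             # quantizes the wrapped nn.Linear weights
print(layer.is_initialized())  # True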
@@ -524,9 +530,11 @@ def cuda(self, device):
             )
 
         if self.bias is not None:
-            if (isinstance(self.bias, torch.nn.Parameter)):
-                self.bias.data = self.bias.data.to(device=device, dtype=self.compute_dtype)
-            if (isinstance(self.bias, torch.Tensor)):
+            if isinstance(self.bias, torch.nn.Parameter):
+                self.bias.data = self.bias.data.to(
+                    device=device, dtype=self.compute_dtype
+                )
+            if isinstance(self.bias, torch.Tensor):
                 self.bias = self.bias.to(device=device, dtype=self.compute_dtype)
 
         self.W_q = nn.Parameter(self.W_q, requires_grad=False)
@@ -569,7 +577,36 @@ def cpu(self):
 
     # state_dict is encoded by default for safetensors support. You can get the raw dict by setting self.encoded_state_dict=False. \
     # Note: you can't change the state once it's done
+    def state_dict_keys(self):
+        return set(
+            [
+                "W_q",
+                "nbits",
+                "group_size",
+                "shape",
+                "scale",
+                "zero",
+                "axis",
+                "packing",
+                "unpack_view_dtype",
+                "view_as_float",
+                "quant_scale",
+                "quant_zero",
+                "compute_dtype",
+                "bias",
+                "offload_meta",
+                "encoded_state_dict",
+                "stores_quant_config",
+                "channel_wise",
+                "optimize",
+                "round_zero",
+            ]
+        )
+
     def state_dict(self, *args, **kwargs):  # nn.Module override compatible
+        if not self.is_initialized():
+            return {k: None for k in self.state_dict_keys()}
+
         if (
             self.quant_config["scale_quant_params"]
             or self.quant_config["zero_quant_params"]
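With this guard, calling state_dict() on a layer that has not been quantized yet returns every expected key mapped to None instead of failing on missing tensors. A hedged sketch of the resulting behavior, reusing the layer and config names from the previous example:

# Sketch only -- follows directly from the early-return added above.
layer = HQQLinear(nn.Linear(128, 128), quant_config=cfg, initialize=False)
sd = layer.state_dict()
assert set(sd) == layer.state_dict_keys()   # placeholder keys only
assert all(v is None for v in sd.values())  # nothing quantized yet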
@@ -1027,11 +1064,21 @@ def hqq_base_quant_config(
         "view_as_float": view_as_float,
     }
 
-    if (quant_zero or quant_scale):
-        print(colored('Warning: Quantized meta-data is deprecated and will be removed. It is not supported for quantized model serialization.', 'yellow'))
+    if quant_zero or quant_scale:
+        print(
+            colored(
+                "Warning: Quantized meta-data is deprecated and will be removed. It is not supported for quantized model serialization.",
+                "yellow",
+            )
+        )
 
-    if (offload_meta):
-        print(colored('Warning: Meta-data offloading is deprecated and will be removed. It is not supported for quantized model serialization.', 'yellow'))
+    if offload_meta:
+        print(
+            colored(
+                "Warning: Meta-data offloading is deprecated and will be removed. It is not supported for quantized model serialization.",
+                "yellow",
+            )
+        )
 
     if offload_meta:
         if quant_scale != quant_zero:
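For completeness, the reformatted warnings fire whenever a config still requests one of the deprecated options. A hedged sketch, assuming BaseQuantizeConfig forwards its keyword arguments to hqq_base_quant_config:

# Sketch only -- each call should print a yellow termcolor warning.
from hqq.core.quantize import BaseQuantizeConfig

BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True)    # quantized meta-data path
BaseQuantizeConfig(nbits=4, group_size=64, offload_meta=True)  # meta-data offloading path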