Skip to content

Commit 5f8f0d2

Browse files
committed
fix empty state_dict() and bump to 0.2.1
1 parent 293812c commit 5f8f0d2

File tree

3 files changed

+56
-9
lines changed

3 files changed

+56
-9
lines changed

hqq/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = "0.2.0"
1+
__version__ = "0.2.1"
22
__author__ = 'Dr. Hicham Badri'
33
__credits__ = 'Mobius Labs GmbH'

hqq/core/quantize.py

+54-7
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .utils import is_divisible, encode_safetensor_type, decode_safetensor_type
1010
from .optimize import optimize_weights_proximal
1111
from .bitpack import BitPack
12+
from termcolor import colored
1213

1314
_META_TYPE = {
1415
"scale": torch.Tensor,
@@ -386,6 +387,8 @@ def __init__(
386387
self.ready = False
387388
self.in_gpu = False
388389
self.bias = None
390+
self.axis = None
391+
self.channel_wise = None
389392
self.device = device
390393
self.compute_dtype = compute_dtype
391394
self.quant_config = copy.deepcopy(quant_config)
@@ -408,6 +411,9 @@ def __init__(
408411
if initialize:
409412
self.initialize()
410413

414+
def is_initialized(self):
    """Return True when the layer already holds quantized weights.

    A layer is considered initialized once both the packed weight tensor
    (``W_q``) and its quantization meta-data (``meta``) have been set.

    Returns:
        bool: True if both ``self.W_q`` and ``self.meta`` are set,
        False otherwise.
    """
    # Use identity checks instead of `None in [self.W_q, self.meta]`:
    # list membership falls back to __eq__, and calling __eq__ on a
    # tensor-valued W_q (torch.Tensor / nn.Parameter vs. None) has
    # subtle, torch-version-dependent semantics. `is not None` is
    # unambiguous and avoids invoking Tensor.__eq__ entirely.
    return self.W_q is not None and self.meta is not None
416+
411417
def initialize(self):
412418
if self.linear_layer is not None:
413419
self.quantize(self.linear_layer.weight.data, **self.quant_config)
@@ -524,9 +530,11 @@ def cuda(self, device):
524530
)
525531

526532
if self.bias is not None:
527-
if(isinstance(self.bias, torch.nn.Parameter)):
528-
self.bias.data = self.bias.data.to(device=device, dtype=self.compute_dtype)
529-
if(isinstance(self.bias, torch.Tensor)):
533+
if isinstance(self.bias, torch.nn.Parameter):
534+
self.bias.data = self.bias.data.to(
535+
device=device, dtype=self.compute_dtype
536+
)
537+
if isinstance(self.bias, torch.Tensor):
530538
self.bias = self.bias.to(device=device, dtype=self.compute_dtype)
531539

532540
self.W_q = nn.Parameter(self.W_q, requires_grad=False)
@@ -569,7 +577,36 @@ def cpu(self):
569577

570578
# state_dict is encoded by default for safetensors support. You can get the raw dict by setting self.encoded_state_dict=False. \
571579
# Note: you can't change the state once it's done
580+
def state_dict_keys(self):
    """Return the full set of keys a serialized layer state-dict may contain.

    Used by ``state_dict()`` to produce a placeholder dict (all values
    ``None``) when the layer has not been initialized yet, so callers
    always see a stable key set.

    Returns:
        set[str]: every key that ``state_dict()`` can emit, covering the
        packed weights, quantization meta-data, and quant-config fields.
    """
    # Set literal instead of set([...]): no intermediate list allocation.
    return {
        "W_q",
        "nbits",
        "group_size",
        "shape",
        "scale",
        "zero",
        "axis",
        "packing",
        "unpack_view_dtype",
        "view_as_float",
        "quant_scale",
        "quant_zero",
        "compute_dtype",
        "bias",
        "offload_meta",
        "encoded_state_dict",
        "stores_quant_config",
        "channel_wise",
        "optimize",
        "round_zero",
    }
605+
572606
def state_dict(self, *args, **kwargs): # nn.Module override compatible
607+
if not self.is_initialized():
608+
return {k: None for k in self.state_dict_keys()}
609+
573610
if (
574611
self.quant_config["scale_quant_params"]
575612
or self.quant_config["zero_quant_params"]
@@ -1027,11 +1064,21 @@ def hqq_base_quant_config(
10271064
"view_as_float": view_as_float,
10281065
}
10291066

1030-
if(quant_zero or quant_scale):
1031-
print(colored('Warning: Quantized meta-data is deprecated and will be removed. It is not supported for quantized model serialization.', 'yellow'))
1067+
if quant_zero or quant_scale:
1068+
print(
1069+
colored(
1070+
"Warning: Quantized meta-data is deprecated and will be removed. It is not supported for quantized model serialization.",
1071+
"yellow",
1072+
)
1073+
)
10321074

1033-
if(offload_meta):
1034-
print(colored('Warning: Meta-data offloading is deprecated and will be removed. It is not supported for quantized model serialization.', 'yellow'))
1075+
if offload_meta:
1076+
print(
1077+
colored(
1078+
"Warning: Meta-data offloading is deprecated and will be removed. It is not supported for quantized model serialization.",
1079+
"yellow",
1080+
)
1081+
)
10351082

10361083
if offload_meta:
10371084
if quant_scale != quant_zero:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def run(self):
4343

4444
setup(
4545
name="hqq",
46-
version="0.2.0",
46+
version="0.2.1",
4747
description="Half-Quadratic Quantization (HQQ)",
4848
url="https://github.com/mobiusml/hqq/",
4949
author="Dr. Hicham Badri",

0 commit comments

Comments
 (0)