Skip to content

Commit c60218e

Browse files
committed
check W_q in state_dict
1 parent 0bd8d2f commit c60218e

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

hqq/core/quantize.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -680,11 +680,13 @@ def _load_from_state_dict(
680680
layer_state_dict[key] = state_dict.pop(prefix + key)
681681
else:
682682
if(key not in ['bias']):
683-
missing_keys.append(prefix + key)
684-
685-
layer_state_dict['W_q'] = nn.Parameter(layer_state_dict['W_q'], requires_grad=False)
683+
missing_keys.append(prefix + key)
686684

687-
self.load_state_dict(layer_state_dict, strict=strict)
685+
if 'W_q' in layer_state_dict:
686+
layer_state_dict['W_q'] = nn.Parameter(layer_state_dict['W_q'], requires_grad=False)
687+
self.load_state_dict(layer_state_dict, strict=strict)
688+
else:
689+
missing_keys.append(prefix + "W_q")
688690

689691
def load_state_dict(self, state_dict, strict=True, assign=False):
690692
if "encoded_state_dict" in state_dict:

0 commit comments

Comments
 (0)