We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 0bd8d2f commit c60218eCopy full SHA for c60218e
hqq/core/quantize.py
@@ -680,11 +680,13 @@ def _load_from_state_dict(
680
layer_state_dict[key] = state_dict.pop(prefix + key)
681
else:
682
if(key not in ['bias']):
683
- missing_keys.append(prefix + key)
684
-
685
- layer_state_dict['W_q'] = nn.Parameter(layer_state_dict['W_q'], requires_grad=False)
+ missing_keys.append(prefix + key)
686
687
- self.load_state_dict(layer_state_dict, strict=strict)
+ if 'W_q' in layer_state_dict:
+ layer_state_dict['W_q'] = nn.Parameter(layer_state_dict['W_q'], requires_grad=False)
+ self.load_state_dict(layer_state_dict, strict=strict)
688
+ else:
689
+ missing_keys.append(prefix + "W_q")
690
691
def load_state_dict(self, state_dict, strict=True, assign=False):
692
if "encoded_state_dict" in state_dict:
0 commit comments