diff --git a/openchem/models/openchem_model.py b/openchem/models/openchem_model.py index e2fd376..102b55b 100644 --- a/openchem/models/openchem_model.py +++ b/openchem/models/openchem_model.py @@ -83,7 +83,7 @@ def cast_inputs(sample, task, use_cuda): def load_model(self, path): weights = torch.load(path) - self.load_state_dict(weights) + self.load_state_dict(weights, strict=False) def save_model(self, path): torch.save(self.state_dict(), path)