Skip to content

Commit 6d19c27

Browse files
committed
Work around lack of inheritance in TorchScript.
1 parent 1d9f02f commit 6d19c27

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

src/sire/qm/_emle.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def emle(
156156
)
157157

158158
try:
159+
import torch as _torch
159160
from emle.models import EMLE as _EMLE
160161

161162
has_model = True
@@ -181,7 +182,17 @@ def emle(
181182
raise ValueError("Unable to select 'qm_atoms' from 'mols'")
182183

183184
if has_model:
184-
if not isinstance(calculator, (_EMLECalculator, _EMLE)):
185+
# EMLECalculator.
186+
if isinstance(calculator, _EMLECalculator):
187+
pass
188+
# EMLE model. Note that TorchScript doesn't support inheritance, so
189+
# we need to check whether this is a torch.nn.Module and whether it
190+
# has the "_is_emle" attribute, which is added to all EMLE models.
191+
elif isinstance(calculator, _torch.nn.Module) and hasattr(
192+
calculator, "_is_emle"
193+
):
194+
pass
195+
else:
185196
raise TypeError(
186197
"'calculator' must be a of type 'emle.calculator.EMLECalculator' or 'emle.models.EMLE'"
187198
)

0 commit comments

Comments
 (0)