File tree 1 file changed +12
-1
lines changed
1 file changed +12
-1
lines changed Original file line number Diff line number Diff line change @@ -156,6 +156,7 @@ def emle(
156
156
)
157
157
158
158
try :
159
+ import torch as _torch
159
160
from emle .models import EMLE as _EMLE
160
161
161
162
has_model = True
@@ -181,7 +182,17 @@ def emle(
181
182
raise ValueError ("Unable to select 'qm_atoms' from 'mols'" )
182
183
183
184
if has_model :
184
- if not isinstance (calculator , (_EMLECalculator , _EMLE )):
185
+ # EMLECalculator.
186
+ if isinstance (calculator , _EMLECalculator ):
187
+ pass
188
+ # EMLE model. Note that TorchScript doesn't support inheritance, so
189
+ # we need to check whether this is a torch.nn.Module and whether it
190
+ # has the "_is_emle" attribute, which is added to all EMLE models.
191
+ elif isinstance (calculator , _torch .nn .Module ) and hasattr (
192
+ calculator , "_is_emle"
193
+ ):
194
+ pass
195
+ else :
185
196
raise TypeError (
186
197
"'calculator' must be a of type 'emle.calculator.EMLECalculator' or 'emle.models.EMLE'"
187
198
)
You can’t perform that action at this time.
0 commit comments