diff --git a/src/boltz/model/models/boltz2.py b/src/boltz/model/models/boltz2.py index d42f3400c..34728cccd 100644 --- a/src/boltz/model/models/boltz2.py +++ b/src/boltz/model/models/boltz2.py @@ -35,6 +35,7 @@ ) from boltz.model.optim.ema import EMA from boltz.model.optim.scheduler import AlphaFoldLRScheduler +from boltz.model.modules.utils import autocast_device_type class Boltz2(LightningModule): @@ -529,7 +530,7 @@ def forward( "token_trans_bias": token_trans_bias, } - with torch.autocast("cuda", enabled=False): + with torch.autocast(autocast_device_type(s.device.type), enabled=False): struct_out = self.structure_module.sample( s_trunk=s.float(), s_inputs=s_inputs.float(), @@ -568,7 +569,7 @@ def forward( feats["coords"] = atom_coords # (multiplicity, L, 3) assert len(feats["coords"].shape) == 3 - with torch.autocast("cuda", enabled=False): + with torch.autocast(autocast_device_type(s.device.type), enabled=False): struct_out = self.structure_module( s_trunk=s.float(), s_inputs=s_inputs.float(), @@ -625,7 +626,7 @@ def forward( ] s_inputs = self.input_embedder(feats, affinity=True) - with torch.autocast("cuda", enabled=False): + with torch.autocast(autocast_device_type(s.device.type), enabled=False): if self.affinity_ensemble: dict_out_affinity1 = self.affinity_module1( s_inputs=s_inputs.detach(), diff --git a/src/boltz/model/modules/utils.py b/src/boltz/model/modules/utils.py index a5a1f2e25..98df46395 100644 --- a/src/boltz/model/modules/utils.py +++ b/src/boltz/model/modules/utils.py @@ -13,6 +13,18 @@ LinearNoBias = partial(Linear, bias=False) +def autocast_device_type(device_type: str) -> str: + """Return a device_type string accepted by ``torch.autocast``. + + When autocast is used with ``enabled=False`` (to disable autocasting), + PyTorch still validates the device_type. MPS was not a valid autocast + device type until PyTorch 2.4. Since ``enabled=False`` is a no-op, we + fall back to ``"cpu"``, which is always accepted. + """ + from torch.amp.autocast_mode import is_autocast_available + + return device_type if is_autocast_available(device_type) else "cpu" + def exists(v): return v is not None