diff --git a/MPS_SUPPORT.md b/MPS_SUPPORT.md new file mode 100644 index 00000000..576c4f02 --- /dev/null +++ b/MPS_SUPPORT.md @@ -0,0 +1,149 @@ +# MPS (Metal Performance Shaders) Support for BoltzGen + +This document describes the MPS support that has been added to BoltzGen, enabling the use of PyTorch on Apple Silicon (M1, M2, M3, etc.) GPUs. + +## What Changed + +The repository has been updated to support PyTorch's MPS (Metal Performance Shaders) backend, which allows BoltzGen to run on Apple Silicon GPUs alongside the existing CUDA and CPU support. + +### Key Changes + +1. **Device Utility Module** (`src/boltzgen/utils/device.py`) + - Added centralized device detection and management + - Automatically detects the best available device: CUDA > MPS > CPU + - Provides device-agnostic cache clearing and autocast support + +2. **CLI Updates** (`src/boltzgen/cli/boltzgen.py`) + - Replaced hardcoded CUDA device detection with device-agnostic functions + - `get_device_capability()` - works for CUDA/MPS/CPU + - `get_device_count()` - returns correct device count for all backends + +3. **Model Updates** (`src/boltzgen/model/models/boltz.py`) + - Updated all `torch.autocast("cuda", ...)` calls to use `get_autocast_device_type()` + - Replaced `torch.cuda.empty_cache()` with `empty_cache()` + - Updated device tensor creation to use `get_device_type()` + +4. **Validation Updates** (`src/boltzgen/model/validation/refolding.py`) + - Updated cache clearing to support MPS + - Added conditional CUDA-specific cleanup (only runs on CUDA devices) + +5. 
**Module Updates** (`src/boltzgen/model/modules/trunk.py`) + - Updated autocast calls in template and token distance modules + +## Usage + +### Running on Apple Silicon (MPS) + +BoltzGen will automatically detect and use MPS when running on Apple Silicon: + +```bash +# No special flags needed - MPS will be auto-detected +boltzgen run design_spec.yaml --output results/ +``` + +### Device Selection Priority + +The device selection follows this priority: +1. **CUDA** - If NVIDIA GPU is available +2. **MPS** - If Apple Silicon GPU is available +3. **CPU** - Fallback + +### Checking Device + +You can verify which device is being used by checking the logs during execution. The CLI will print: +``` +Using kernels: True/False [device capability: (X, Y)] +Using N devices +``` + +### Configuration Files + +The YAML configuration files use `accelerator: gpu` which works for both CUDA and MPS: +- PyTorch Lightning automatically detects the appropriate GPU backend +- No changes needed to existing config files + +## Limitations and Considerations + +### MPS vs CUDA Performance + +1. **Single Device**: MPS currently supports only a single device, while CUDA can use multiple GPUs +2. **Kernel Support**: Some CUDA-specific kernels may not be available on MPS +3. **Memory Management**: MPS memory management differs from CUDA; you may need to adjust batch sizes + +### Known Issues + +1. **Mixed Precision**: MPS autocast support requires PyTorch 2.1+ +2. **Some Operations**: A few operations may fall back to CPU on MPS +3. 
**Memory**: MPS shares memory with the system, unlike dedicated CUDA GPUs + +## Requirements + +- **PyTorch**: 2.0+ (2.1+ recommended for full MPS autocast support) +- **macOS**: 12.3+ (Monterey or later) +- **Apple Silicon**: M1, M2, M3, or later + +## Testing + +To verify MPS support is working: + +```python +import torch +from boltzgen.utils.device import get_device_type, get_device_count + +print(f"Device type: {get_device_type()}") # Should print "mps" on Apple Silicon +print(f"Device count: {get_device_count()}") # Should print 1 on MPS +print(f"MPS available: {torch.backends.mps.is_available()}") # Should be True +``` + +## Migration Notes + +If you have custom code or scripts that reference CUDA explicitly: + +### Before +```python +device = "cuda" if torch.cuda.is_available() else "cpu" +torch.cuda.empty_cache() +``` + +### After +```python +from boltzgen.utils.device import get_device_type, empty_cache + +device = get_device_type() +empty_cache() +``` + +## Performance Tips + +1. **Batch Size**: Start with smaller batch sizes on MPS and adjust based on available memory +2. **Precision**: Use `bf16-mixed` precision (already configured) for best performance +3. **Kernels**: The `--use_kernels` flag works automatically based on device capability + +## Troubleshooting + +### "MPS backend out of memory" +- Reduce batch size in config files +- Close other applications to free up memory +- MPS shares system memory, unlike dedicated GPUs + +### Slower than expected +- Ensure PyTorch 2.1+ is installed for optimal MPS support +- Check that `torch.backends.mps.is_available()` returns `True` +- Some operations may still fall back to CPU + +### Import errors +- Install or upgrade PyTorch: `pip install "torch>=2.1.0"` (keep the quotes; otherwise the shell treats `>` as an output redirect) +- Check macOS version: `sw_vers` (should be 12.3+) + +## Contributing + +When adding new PyTorch code: +1. Use `boltzgen.utils.device` functions instead of hardcoded device strings +2. Use `get_autocast_device_type()` for autocast contexts +3. 
Use `empty_cache()` instead of `torch.cuda.empty_cache()` +4. Test on both CUDA and MPS if possible + +## References + +- [PyTorch MPS Documentation](https://pytorch.org/docs/stable/notes/mps.html) +- [PyTorch Lightning MPS Support](https://lightning.ai/docs/pytorch/stable/accelerators/mps.html) diff --git a/src/boltzgen/cli/boltzgen.py b/src/boltzgen/cli/boltzgen.py index b7992705..6eb4b173 100644 --- a/src/boltzgen/cli/boltzgen.py +++ b/src/boltzgen/cli/boltzgen.py @@ -50,6 +50,7 @@ from boltzgen.data.parse.schema import YamlDesignParser from boltzgen.data.write.mmcif import to_mmcif from boltzgen.task.task import Task +from boltzgen.utils.device import get_device_capability, get_device_count ### Paths and constants #### # Get the path to the project root (where main.py and configs/ are located) @@ -860,7 +861,7 @@ def __init__(self, args: argparse.Namespace, moldir: Path): ) # Handle use_kernels argument - device_capability = torch.cuda.get_device_capability() + device_capability = get_device_capability() use_kernels = None if args.use_kernels == "auto": use_kernels = device_capability[0] >= 8 @@ -881,7 +882,7 @@ def __init__(self, args: argparse.Namespace, moldir: Path): ) devices = ( - args.devices if args.devices is not None else torch.cuda.device_count() + args.devices if args.devices is not None else get_device_count() ) print(f"Using {devices} devices") diff --git a/src/boltzgen/model/models/boltz.py b/src/boltzgen/model/models/boltz.py index 8214b9eb..973bcc60 100755 --- a/src/boltzgen/model/models/boltz.py +++ b/src/boltzgen/model/models/boltz.py @@ -49,6 +49,11 @@ InverseFoldingEncoder, InverseFoldingDecoder, ) +from boltzgen.utils.device import ( + get_autocast_device_type, + get_device_type, + empty_cache, +) import torch @@ -655,7 +660,7 @@ def forward( ): if self.inference_logging: print("\nRunning Structure Module.\n") - with torch.autocast("cuda", enabled=False): + with torch.autocast(get_autocast_device_type(), enabled=False): if not 
self.inverse_fold: struct_out = self.structure_module.sample( s_trunk=s.float(), @@ -711,7 +716,7 @@ def forward( feats["coords"] = atom_coords # (multiplicity, L, 3) assert len(feats["coords"].shape) == 3 - with torch.autocast("cuda", enabled=False): + with torch.autocast(get_autocast_device_type(), enabled=False): if not self.inverse_fold: struct_out = self.structure_module( s_trunk=s.float(), @@ -769,7 +774,7 @@ def forward( ] s_inputs = self.input_embedder(feats, affinity=True) - with torch.autocast("cuda", enabled=False): + with torch.autocast(get_autocast_device_type(), enabled=False): if self.affinity_ensemble: dict_out_affinity1 = self.affinity_module1( s_inputs=s_inputs.detach(), @@ -1102,18 +1107,14 @@ def gradient_norm(self, module): if p.requires_grad and p.grad is not None ] if len(parameters) == 0: - return torch.tensor( - 0.0, device="cuda" if torch.cuda.is_available() else "cpu" - ) + return torch.tensor(0.0, device=get_device_type()) norm = torch.stack(parameters).sum().sqrt() return norm def parameter_norm(self, module): parameters = [p.norm(p=2) ** 2 for p in module.parameters() if p.requires_grad] if len(parameters) == 0: - return torch.tensor( - 0.0, device="cuda" if torch.cuda.is_available() else "cpu" - ) + return torch.tensor(0.0, device=get_device_type()) norm = torch.stack(parameters).sum().sqrt() return norm @@ -1165,7 +1166,7 @@ def validation_step( "res_type =", batch["res_type"].shape, ) - torch.cuda.empty_cache() + empty_cache() return raise e else: @@ -1184,7 +1185,7 @@ def validation_step( if "out of memory" in str(e): msg = f"| WARNING: ran out of memory, skipping batch, {idx_dataset}" print(msg) - torch.cuda.empty_cache() + empty_cache() return raise e @@ -1368,7 +1369,7 @@ def predict_step( except RuntimeError as e: # catch out of memory exceptions if "out of memory" in str(e): print("| WARNING: ran out of memory, skipping batch") - torch.cuda.empty_cache() + empty_cache() return {"exception": True} else: raise e diff --git 
a/src/boltzgen/model/modules/trunk.py b/src/boltzgen/model/modules/trunk.py index a168b6be..38645382 100755 --- a/src/boltzgen/model/modules/trunk.py +++ b/src/boltzgen/model/modules/trunk.py @@ -350,12 +350,14 @@ def forward( b_frame_mask = b_frame_mask[..., None] # Compute asym mask, template features only attend within the same chain + from boltzgen.utils.device import get_autocast_device_type + B, T = res_type.shape[:2] # noqa: N806 asym_mask = (asym_id[:, :, None] == asym_id[:, None, :]).float() asym_mask = asym_mask[:, None].expand(-1, T, -1, -1) # Compute template features - with torch.autocast(device_type="cuda", enabled=False): + with torch.autocast(device_type=get_autocast_device_type(), enabled=False): # Compute distogram cb_dists = torch.cdist(cb_coords, cb_coords) boundaries = torch.linspace(self.min_dist, self.max_dist, self.num_bins - 1) @@ -499,12 +501,14 @@ def forward( The updated pairwise embeddings. """ + from boltzgen.utils.device import get_autocast_device_type + # Load relevant features token_distance_mask = feats["token_distance_mask"] token_coords = feats["center_coords"] # Compute template features - with torch.autocast(device_type="cuda", enabled=False): + with torch.autocast(device_type=get_autocast_device_type(), enabled=False): # Compute distogram dists = torch.cdist(token_coords, token_coords) boundaries = torch.linspace(self.min_dist, self.max_dist, self.num_bins - 1) diff --git a/src/boltzgen/model/validation/refolding.py b/src/boltzgen/model/validation/refolding.py index 1e50b8d3..0b1ef164 100755 --- a/src/boltzgen/model/validation/refolding.py +++ b/src/boltzgen/model/validation/refolding.py @@ -294,14 +294,20 @@ def process( def on_epoch_end(self, model): # Cleanup + from boltzgen.utils.device import empty_cache, get_device_type + del self.folding_model self.folding_model = None del self.affinity_model self.affinity_model = None - torch._C._cuda_clearCublasWorkspaces() + + # CUDA-specific cleanup + if get_device_type() == "cuda": 
+            torch._C._cuda_clearCublasWorkspaces()
+
         torch._dynamo.reset()
         gc.collect()
-        torch.cuda.empty_cache()
+        empty_cache()
 
         # Compute standard metrics
         self.common_on_epoch_end(model, logname="val_monomer_ligand")
diff --git a/src/boltzgen/utils/device.py b/src/boltzgen/utils/device.py
new file mode 100644
index 00000000..83ffbd82
--- /dev/null
+++ b/src/boltzgen/utils/device.py
@@ -0,0 +1,122 @@
+"""Device utilities for PyTorch device selection (CUDA, MPS, CPU)."""
+import functools
+from typing import Tuple
+
+import torch
+
+
+def get_device_type() -> str:
+    """
+    Get the best available device type.
+
+    Backends are probed in priority order: CUDA, then MPS, then CPU.
+
+    Returns
+    -------
+    str
+        Device type string: "cuda", "mps", or "cpu"
+    """
+    if torch.cuda.is_available():
+        return "cuda"
+    # torch.backends.mps is missing on PyTorch builds that predate MPS support.
+    mps_backend = getattr(torch.backends, "mps", None)
+    if mps_backend is not None and mps_backend.is_available():
+        return "mps"
+    return "cpu"
+
+
+def get_device() -> torch.device:
+    """
+    Get the best available PyTorch device.
+
+    Returns
+    -------
+    torch.device
+        PyTorch device object for the type chosen by get_device_type()
+    """
+    return torch.device(get_device_type())
+
+
+def get_device_count() -> int:
+    """
+    Get the number of available devices.
+
+    Returns
+    -------
+    int
+        Number of devices (1 for MPS/CPU, cuda.device_count() for CUDA)
+    """
+    if get_device_type() == "cuda":
+        return torch.cuda.device_count()
+    # MPS and CPU are treated as a single device
+    return 1
+
+
+def get_device_capability() -> Tuple[int, int]:
+    """
+    Get the CUDA compute capability of the current device.
+
+    Returns
+    -------
+    Tuple[int, int]
+        (major, minor) compute capability on CUDA. MPS and CPU have no
+        CUDA compute capability, so (0, 0) is returned; callers that gate
+        features on ``capability[0] >= 8`` therefore leave them disabled.
+    """
+    if get_device_type() == "cuda":
+        return torch.cuda.get_device_capability()
+    # Reporting a fake Ampere-level (8, 0) here would make the CLI's
+    # `--use_kernels auto` check switch kernels on for non-CUDA devices.
+    return (0, 0)
+
+
+def empty_cache() -> None:
+    """
+    Empty the device allocator cache if the active backend has one.
+
+    Calls torch.cuda.empty_cache() on CUDA and torch.mps.empty_cache() on
+    MPS. Does nothing on CPU, or on MPS when the installed PyTorch does not
+    provide torch.mps.empty_cache().
+    """
+    device_type = get_device_type()
+    if device_type == "cuda":
+        torch.cuda.empty_cache()
+    elif device_type == "mps":
+        # MPS can be available on PyTorch releases that lack torch.mps
+        mps_module = getattr(torch, "mps", None)
+        mps_empty_cache = getattr(mps_module, "empty_cache", None)
+        if mps_empty_cache is not None:
+            mps_empty_cache()
+    # CPU doesn't need cache clearing
+
+
+@functools.lru_cache(maxsize=None)
+def _mps_autocast_supported() -> bool:
+    """
+    Check whether torch.autocast accepts device_type="mps".
+
+    The probe enters an autocast context, so the result is cached to keep
+    the probe out of call sites that run on every forward pass.
+    """
+    try:
+        with torch.autocast(device_type="mps"):
+            pass
+    except (RuntimeError, TypeError):
+        return False
+    return True
+
+
+def get_autocast_device_type() -> str:
+    """
+    Get the device type for autocast context manager.
+
+    Returns
+    -------
+    str
+        "cuda", "mps", or "cpu". Falls back to "cpu" when running on MPS
+        with a PyTorch build whose autocast rejects the "mps" device type.
+    """
+    device_type = get_device_type()
+    if device_type == "mps" and not _mps_autocast_supported():
+        return "cpu"
+    return device_type