Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions MPS_SUPPORT.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# MPS (Metal Performance Shaders) Support for BoltzGen

This document describes the MPS support that has been added to BoltzGen, enabling the use of PyTorch on Apple Silicon (M1, M2, M3, etc.) GPUs.

## What Changed

The repository has been updated to support PyTorch's MPS (Metal Performance Shaders) backend, which allows BoltzGen to run on Apple Silicon GPUs alongside the existing CUDA and CPU support.

### Key Changes

1. **Device Utility Module** (`src/boltzgen/utils/device.py`)
- Added centralized device detection and management
- Automatically detects the best available device: CUDA > MPS > CPU
- Provides device-agnostic cache clearing and autocast support

2. **CLI Updates** (`src/boltzgen/cli/boltzgen.py`)
- Replaced hardcoded CUDA device detection with device-agnostic functions
- `get_device_capability()` - works for CUDA/MPS/CPU
- `get_device_count()` - returns correct device count for all backends

3. **Model Updates** (`src/boltzgen/model/models/boltz.py`)
- Updated all `torch.autocast("cuda", ...)` calls to use `get_autocast_device_type()`
- Replaced `torch.cuda.empty_cache()` with `empty_cache()`
- Updated device tensor creation to use `get_device_type()`

4. **Validation Updates** (`src/boltzgen/model/validation/refolding.py`)
- Updated cache clearing to support MPS
- Added conditional CUDA-specific cleanup (only runs on CUDA devices)

5. **Module Updates** (`src/boltzgen/model/modules/trunk.py`)
- Updated autocast calls in template and token distance modules

## Usage

### Running on Apple Silicon (MPS)

BoltzGen will automatically detect and use MPS when running on Apple Silicon:

```bash
# No special flags needed - MPS will be auto-detected
boltzgen run design_spec.yaml --output results/
```

### Device Selection Priority

The device selection follows this priority:
1. **CUDA** - If NVIDIA GPU is available
2. **MPS** - If Apple Silicon GPU is available
3. **CPU** - Fallback

### Checking Device

You can verify which device is being used by checking the logs during execution. The CLI will print:
```
Using kernels: True/False [device capability: (X, Y)]
Using N devices
```

### Configuration Files

The YAML configuration files use `accelerator: gpu` which works for both CUDA and MPS:
- PyTorch Lightning automatically detects the appropriate GPU backend
- No changes needed to existing config files

## Limitations and Considerations

### MPS vs CUDA Performance

1. **Single Device**: MPS currently supports only a single device, while CUDA can use multiple GPUs
2. **Kernel Support**: Some CUDA-specific kernels may not be available on MPS
3. **Memory Management**: MPS memory management differs from CUDA; you may need to adjust batch sizes

### Known Issues

1. **Mixed Precision**: MPS autocast support requires PyTorch 2.1+
2. **Some Operations**: A few operations may fall back to CPU on MPS
3. **Memory**: MPS shares memory with the system, unlike dedicated CUDA GPUs

## Requirements

- **PyTorch**: 2.0+ (2.1+ recommended for full MPS autocast support)
- **macOS**: 12.3+ (Monterey or later)
- **Apple Silicon**: M1, M2, M3, or later

## Testing

To verify MPS support is working:

```python
import torch
from boltzgen.utils.device import get_device_type, get_device_count

print(f"Device type: {get_device_type()}") # Should print "mps" on Apple Silicon
print(f"Device count: {get_device_count()}") # Should print 1 on MPS
print(f"MPS available: {torch.backends.mps.is_available()}") # Should be True
```

## Migration Notes

If you have custom code or scripts that reference CUDA explicitly:

### Before
```python
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.empty_cache()
```

### After
```python
from boltzgen.utils.device import get_device_type, empty_cache

device = get_device_type()
empty_cache()
```

## Performance Tips

1. **Batch Size**: Start with smaller batch sizes on MPS and adjust based on available memory
2. **Precision**: Use `bf16-mixed` precision (already configured) for best performance
3. **Kernels**: The `--use_kernels` flag works automatically based on device capability

## Troubleshooting

### "MPS backend out of memory"
- Reduce batch size in config files
- Close other applications to free up memory
- MPS shares system memory, unlike dedicated GPUs

### Slower than expected
- Ensure PyTorch 2.1+ is installed for optimal MPS support
- Check that `torch.backends.mps.is_available()` returns `True`
- Some operations may still fall back to CPU

### Import errors
- Verify PyTorch installation: `pip install torch>=2.1.0`
- Check macOS version: `sw_vers` (should be 12.3+)

## Contributing

When adding new PyTorch code:
1. Use `boltzgen.utils.device` functions instead of hardcoded device strings
2. Use `get_autocast_device_type()` for autocast contexts
3. Use `empty_cache()` instead of `torch.cuda.empty_cache()`
4. Test on both CUDA and MPS if possible

## References

- [PyTorch MPS Documentation](https://pytorch.org/docs/stable/notes/mps.html)
- [PyTorch Lightning MPS Support](https://lightning.ai/docs/pytorch/stable/accelerators/mps.html)
5 changes: 3 additions & 2 deletions src/boltzgen/cli/boltzgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from boltzgen.data.parse.schema import YamlDesignParser
from boltzgen.data.write.mmcif import to_mmcif
from boltzgen.task.task import Task
from boltzgen.utils.device import get_device_capability, get_device_count

### Paths and constants ####
# Get the path to the project root (where main.py and configs/ are located)
Expand Down Expand Up @@ -860,7 +861,7 @@ def __init__(self, args: argparse.Namespace, moldir: Path):
)

# Handle use_kernels argument
device_capability = torch.cuda.get_device_capability()
device_capability = get_device_capability()
use_kernels = None
if args.use_kernels == "auto":
use_kernels = device_capability[0] >= 8
Expand All @@ -881,7 +882,7 @@ def __init__(self, args: argparse.Namespace, moldir: Path):
)

devices = (
args.devices if args.devices is not None else torch.cuda.device_count()
args.devices if args.devices is not None else get_device_count()
)
print(f"Using {devices} devices")

Expand Down
25 changes: 13 additions & 12 deletions src/boltzgen/model/models/boltz.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@
InverseFoldingEncoder,
InverseFoldingDecoder,
)
from boltzgen.utils.device import (
get_autocast_device_type,
get_device_type,
empty_cache,
)

import torch

Expand Down Expand Up @@ -655,7 +660,7 @@ def forward(
):
if self.inference_logging:
print("\nRunning Structure Module.\n")
with torch.autocast("cuda", enabled=False):
with torch.autocast(get_autocast_device_type(), enabled=False):
if not self.inverse_fold:
struct_out = self.structure_module.sample(
s_trunk=s.float(),
Expand Down Expand Up @@ -711,7 +716,7 @@ def forward(
feats["coords"] = atom_coords # (multiplicity, L, 3)
assert len(feats["coords"].shape) == 3

with torch.autocast("cuda", enabled=False):
with torch.autocast(get_autocast_device_type(), enabled=False):
if not self.inverse_fold:
struct_out = self.structure_module(
s_trunk=s.float(),
Expand Down Expand Up @@ -769,7 +774,7 @@ def forward(
]
s_inputs = self.input_embedder(feats, affinity=True)

with torch.autocast("cuda", enabled=False):
with torch.autocast(get_autocast_device_type(), enabled=False):
if self.affinity_ensemble:
dict_out_affinity1 = self.affinity_module1(
s_inputs=s_inputs.detach(),
Expand Down Expand Up @@ -1102,18 +1107,14 @@ def gradient_norm(self, module):
if p.requires_grad and p.grad is not None
]
if len(parameters) == 0:
return torch.tensor(
0.0, device="cuda" if torch.cuda.is_available() else "cpu"
)
return torch.tensor(0.0, device=get_device_type())
norm = torch.stack(parameters).sum().sqrt()
return norm

def parameter_norm(self, module):
parameters = [p.norm(p=2) ** 2 for p in module.parameters() if p.requires_grad]
if len(parameters) == 0:
return torch.tensor(
0.0, device="cuda" if torch.cuda.is_available() else "cpu"
)
return torch.tensor(0.0, device=get_device_type())
norm = torch.stack(parameters).sum().sqrt()
return norm

Expand Down Expand Up @@ -1165,7 +1166,7 @@ def validation_step(
"res_type =",
batch["res_type"].shape,
)
torch.cuda.empty_cache()
empty_cache()
return
raise e
else:
Expand All @@ -1184,7 +1185,7 @@ def validation_step(
if "out of memory" in str(e):
msg = f"| WARNING: ran out of memory, skipping batch, {idx_dataset}"
print(msg)
torch.cuda.empty_cache()
empty_cache()
return
raise e

Expand Down Expand Up @@ -1368,7 +1369,7 @@ def predict_step(
except RuntimeError as e: # catch out of memory exceptions
if "out of memory" in str(e):
print("| WARNING: ran out of memory, skipping batch")
torch.cuda.empty_cache()
empty_cache()
return {"exception": True}
else:
raise e
Expand Down
8 changes: 6 additions & 2 deletions src/boltzgen/model/modules/trunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,12 +350,14 @@ def forward(
b_frame_mask = b_frame_mask[..., None]

# Compute asym mask, template features only attend within the same chain
from boltzgen.utils.device import get_autocast_device_type

B, T = res_type.shape[:2] # noqa: N806
asym_mask = (asym_id[:, :, None] == asym_id[:, None, :]).float()
asym_mask = asym_mask[:, None].expand(-1, T, -1, -1)

# Compute template features
with torch.autocast(device_type="cuda", enabled=False):
with torch.autocast(device_type=get_autocast_device_type(), enabled=False):
# Compute distogram
cb_dists = torch.cdist(cb_coords, cb_coords)
boundaries = torch.linspace(self.min_dist, self.max_dist, self.num_bins - 1)
Expand Down Expand Up @@ -499,12 +501,14 @@ def forward(
The updated pairwise embeddings.

"""
from boltzgen.utils.device import get_autocast_device_type

# Load relevant features
token_distance_mask = feats["token_distance_mask"]
token_coords = feats["center_coords"]

# Compute template features
with torch.autocast(device_type="cuda", enabled=False):
with torch.autocast(device_type=get_autocast_device_type(), enabled=False):
# Compute distogram
dists = torch.cdist(token_coords, token_coords)
boundaries = torch.linspace(self.min_dist, self.max_dist, self.num_bins - 1)
Expand Down
10 changes: 8 additions & 2 deletions src/boltzgen/model/validation/refolding.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,20 @@ def process(

def on_epoch_end(self, model):
# Cleanup
from boltzgen.utils.device import empty_cache, get_device_type

del self.folding_model
self.folding_model = None
del self.affinity_model
self.affinity_model = None
torch._C._cuda_clearCublasWorkspaces()

# CUDA-specific cleanup
if get_device_type() == "cuda":
torch._C._cuda_clearCublasWorkspaces()

torch._dynamo.reset()
gc.collect()
torch.cuda.empty_cache()
empty_cache()

# Compute standard metrics
self.common_on_epoch_end(model, logname="val_monomer_ligand")
Expand Down
Loading