Open
Labels: Contributions Welcome, FSDP, Optimizers
Description
System Info
Docker base: nvcr.io/nvidia/pytorch:24.07-py3
Python 3.10.12
torch 2.6.0+cu124
bitsandbytes 0.45.5
2x NVIDIA H100 80GB HBM3
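For comparison on other machines, a minimal snippet (my own addition, not part of the original report) that prints the same environment details:

import torch
import bitsandbytes as bnb

# Print library versions and visible GPUs, mirroring the System Info above.
print("torch", torch.__version__)
print("bitsandbytes", bnb.__version__)
for i in range(torch.cuda.device_count()):
    print("gpu", i, torch.cuda.get_device_name(i))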
Reproduction
torchrun --standalone --nnodes=1 --nproc_per_node=2 repro.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
    fully_shard,
    MixedPrecisionPolicy,
)

import bitsandbytes as bnb

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())


class SimpleLayer(nn.Module):
    def __init__(self, hidden_size: int = 512):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.linear(x))


class SimpleModel(nn.Module):
    def __init__(
        self, vocab_size: int = 100000, hidden_size: int = 512, num_layers: int = 2
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList(
            [SimpleLayer(hidden_size) for _ in range(num_layers)]
        )
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(x)


# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", use_cache=False)
model = SimpleModel()

mesh = init_device_mesh(device_type="cuda", mesh_shape=(1, 2), mesh_dim_names=("replicate", "shard"))
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, cast_forward_inputs=True)

for module in model.modules():
    if isinstance(module, SimpleLayer):
        fully_shard(module, mesh=mesh, mp_policy=mp_policy)
fully_shard(model, mesh=mesh, mp_policy=mp_policy)

optimizer = bnb.optim.AdamW8bit(model.parameters())

for _ in range(10):
    batch = torch.randint(0, 99999, (1, 10), device=torch.cuda.current_device())
    output = model(batch)
    loss = output.sum()
    optimizer.zero_grad()
    loss.backward()
    print("loss", loss.item())
    optimizer.step()

dist.destroy_process_group()
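As a control experiment (my suggestion, not part of the original report), swapping the 8-bit optimizer for the stock PyTorch AdamW in the script above should show whether the illegal memory access is specific to bnb.optim.AdamW8bit or to the FSDP2 setup itself:

# Hypothetical variation of the repro, not from the original report:
# replace the bitsandbytes optimizer with torch.optim.AdamW and re-run
# the same training loop to check whether the crash still occurs.
optimizer = torch.optim.AdamW(model.parameters())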
Expected behavior
The script should run to completion. Instead, it crashes with:
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py", line 493, in wrapper
[rank0]: out = func(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/bitsandbytes/optim/optimizer.py", line 292, in step
[rank0]: torch.cuda.synchronize()
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 985, in synchronize
[rank0]: return torch._C._cuda_synchronize()
[rank0]: RuntimeError: CUDA error: an illegal memory access was encountered
[rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank0]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
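As the error message itself suggests, re-running the repro with CUDA_LAUNCH_BLOCKING=1 should make kernel failures report synchronously and give a more accurate stack trace (a debugging step, not part of the original report):

CUDA_LAUNCH_BLOCKING=1 torchrun --standalone --nnodes=1 --nproc_per_node=2 repro.py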