illegal memory access with FSDP2 and AdamW8bit #1633

@llllvvuu

Description

System Info

Docker base: nvcr.io/nvidia/pytorch:24.07-py3
Python 3.10.12
torch 2.6.0+cu124
bitsandbytes 0.45.5
2x NVIDIA H100 80GB HBM3

Reproduction

torchrun --standalone --nnodes=1 --nproc_per_node=2 repro.py

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
    fully_shard,
    MixedPrecisionPolicy,
)

import bitsandbytes as bnb

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

class SimpleLayer(nn.Module):
    def __init__(self, hidden_size: int = 512):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.linear(x))


class SimpleModel(nn.Module):
    def __init__(
        self, vocab_size: int = 100000, hidden_size: int = 512, num_layers: int = 2
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList(
            [SimpleLayer(hidden_size) for _ in range(num_layers)]
        )
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        return self.lm_head(x)

# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", use_cache=False)
model = SimpleModel()

mesh = init_device_mesh(device_type="cuda", mesh_shape=(1, 2), mesh_dim_names=("replicate", "shard"))
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, cast_forward_inputs=True)
for module in model.modules():
    if isinstance(module, SimpleLayer):
        fully_shard(module, mesh=mesh, mp_policy=mp_policy)
fully_shard(model, mesh=mesh, mp_policy=mp_policy)

optimizer = bnb.optim.AdamW8bit(model.parameters())
for _ in range(10):
    batch = torch.randint(0, 99999, (1, 10), device=torch.cuda.current_device())
    output = model(batch)
    loss = output.sum()
    optimizer.zero_grad()
    loss.backward()
    print("loss", loss.item())
    optimizer.step()

dist.destroy_process_group()
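
A possibly relevant detail, not part of the original report: after fully_shard, each parameter is a DTensor whose local shard differs per rank, and AdamW8bit's CUDA kernels may not expect that layout. A minimal sketch to inspect what the optimizer actually receives, reusing model and dist from the repro above and placed before optimizer.step():

# Hypothetical diagnostic (not in the original repro): show global vs. local shapes per parameter.
from torch.distributed.tensor import DTensor  # public module in torch >= 2.5

for name, p in model.named_parameters():
    local = p.to_local() if isinstance(p, DTensor) else p
    if dist.get_rank() == 0:
        print(f"{name}: {type(p).__name__}, global {tuple(p.shape)}, local {tuple(local.shape)}")

This is only for inspection; whether the illegal access actually stems from the sharded DTensor layout is for the maintainers to confirm.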

Expected behavior

The script runs to completion instead of crashing with:

[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/optim/optimizer.py", line 493, in wrapper
[rank0]:     out = func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/bitsandbytes/optim/optimizer.py", line 292, in step
[rank0]:     torch.cuda.synchronize()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 985, in synchronize
[rank0]:     return torch._C._cuda_synchronize()
[rank0]: RuntimeError: CUDA error: an illegal memory access was encountered
[rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank0]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
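
As the error message itself suggests, a synchronous rerun usually points at the faulting kernel. One hedged way to enable that from inside repro.py (equivalent to prefixing the torchrun command with CUDA_LAUNCH_BLOCKING=1), assuming it runs before any CUDA work:

import os

# Must be set before the process first touches CUDA; kernel launches then block,
# so the error is reported at the offending call site rather than at synchronize().
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"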

Labels: Contributions Welcome, FSDP, Optimizers
