[PyTorch] Debug checkpointing with operation-based API (#1063)
* Debug checkpointing with operation-based API

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Store checkpoint FP8 state on CPU

Signed-off-by: Tim Moon <[email protected]>

* Fix bug where linear op was saving params multiple times

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warnings

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
timmoon10 and pre-commit-ci[bot] authored Nov 5, 2024

1 parent 50b22da commit f20d3dd
Showing 3 changed files with 376 additions and 3 deletions.
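
Before the diffs, a minimal sketch of the save/load round trip this commit exercises, assuming a CUDA device with FP8 support; the shapes and the single-layer model are illustrative, not part of the change itself:

import io
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# Build a small operation-based model and run one FP8 step so that
# FP8 scaling metadata exists before checkpointing.
model = te_ops.Sequential(te_ops.Linear(16, 16, device="cuda"))
x = torch.rand(16, 16, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True):
    y = model(x)
y.sum().backward()

# Save: parameters plus each op's FP8 "extra state" land in the state dict.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Load into a freshly constructed model; the extra state restores the
# FP8 scales, scale_inv tensors, and amax histories on the new op.
new_model = te_ops.Sequential(te_ops.Linear(16, 16, device="cuda"))
new_model.load_state_dict(torch.load(io.BytesIO(buffer.getvalue())))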
185 changes: 185 additions & 0 deletions tests/pytorch/test_torch_save_load.py
@@ -17,7 +17,9 @@

import pytest
import torch
import transformer_engine.common
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
@@ -287,3 +289,186 @@ def test_fp8_model_checkpoint(
torch.testing.assert_close(
model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item()
)


@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("save_fp8_model", (False, True))
@pytest.mark.parametrize("load_fp8_model", (False, True))
def test_sequential_model(
*,
in_shape: Iterable[int] = (16, 16),
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
save_steps: int = 2,
load_steps: int = 2,
fp8: bool,
save_fp8_model: bool,
load_fp8_model: bool,
) -> None:

# Skip invalid configurations
if fp8 or save_fp8_model or load_fp8_model:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")

# FP8 recipe
margin = 2
fp8_format = transformer_engine.common.recipe.Format.E4M3
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
fp8_format=fp8_format,
amax_history_len=8,
amax_compute_algo="max",
)

# Construct model to save to checkpoint
with te.fp8_model_init(enabled=save_fp8_model):
model = te_ops.Sequential(
te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype),
)
with torch.no_grad():
torch.rand(model[0].weight.size(), out=model[0].weight)
torch.rand(model[0].bias.size(), out=model[0].bias)

# Synthetic data
xs_ref = [
torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps)
]
dys_ref = [
torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps)
]

def train_step(
model: te_ops.Sequential,
x: torch.Tensor,
dy: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Helper function to perform training step"""
x = x.detach().clone().requires_grad_()
dy = dy.detach().clone()
with te.fp8_autocast(enabled=fp8, fp8_recipe=recipe):
y = model(x)
y.backward(dy)
with torch.no_grad():
for param in model.parameters():
param += 0.125
return (
y.detach().clone(),
x.grad.detach().clone(),
model[0].weight.detach().float().clone(),
)

# Initial training steps with saved model
ys_ref = []
dxs_ref = []
ws_ref = []
for step in range(save_steps):
y, dx, w = train_step(model, xs_ref[step], dys_ref[step])
ys_ref.append(y)
dxs_ref.append(dx)
ws_ref.append(w)

# Keep track of FP8 metadata if needed
fp8_meta_ref = dict(input={}, param={}, grad_output={})
if fp8:
for fp8_meta_type, fp8_meta_key in (
("input", "scaling_fwd"),
("param", "scaling_fwd"),
("grad_output", "scaling_bwd"),
):
m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key]
m_ref = fp8_meta_ref[fp8_meta_type]
m_ref["amax"] = m_model.amax_history.detach().clone()
m_ref["scale"] = m_model.scale.detach().clone()
m_ref["scale_inv"] = m_model.scale_inv.detach().clone()
del m_model, m_ref

# Save checkpoint
byte_stream = io.BytesIO()
torch.save(model.state_dict(), byte_stream)
model_bytes = byte_stream.getvalue()
del byte_stream

# More training steps with saved model
for step in range(save_steps, save_steps + load_steps):
y, dx, w = train_step(model, xs_ref[step], dys_ref[step])
ys_ref.append(y)
dxs_ref.append(dx)
ws_ref.append(w)

# Disturb and destroy model
with torch.no_grad():
for param in model.parameters():
param.zero_()
model[0].basic_ops[0]._fp8_metas = None
del model

# Construct new model to load from checkpoint
with te.fp8_model_init(enabled=load_fp8_model):
model = te_ops.Sequential(
te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype),
)

# Tolerances for numerical checks
tols = {}
if fp8 or save_fp8_model or load_fp8_model:
tols = dict(rtol=0.125, atol=0.0675) # fp8 e4m3 epsilon = 0.0625
exact_tols = dict(rtol=0, atol=0)

# Training steps with dummy data
for step in range(save_steps):
y, dx, w = train_step(
model,
torch.zeros_like(xs_ref[step]),
torch.zeros_like(dys_ref[step]),
)

# Make sure results don't match saved model
with pytest.raises(AssertionError):
torch.testing.assert_close(y, ys_ref[step], **tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(dx, dxs_ref[step], **tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(w, ws_ref[step], **tols)

# Make sure new model's FP8 metadata doesn't match saved model
if fp8:
for fp8_meta_type, fp8_meta_key in (
("input", "scaling_fwd"),
("param", "scaling_fwd"),
("grad_output", "scaling_bwd"),
):
m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key]
m_ref = fp8_meta_ref[fp8_meta_type]
with pytest.raises(AssertionError):
torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols)
with pytest.raises(AssertionError):
torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols)

# Load checkpoint
model.load_state_dict(torch.load(io.BytesIO(model_bytes)))
del model_bytes

# Check that new model's FP8 metadata matches saved model
if fp8:
for fp8_meta_type, fp8_meta_key in (
("input", "scaling_fwd"),
("param", "scaling_fwd"),
("grad_output", "scaling_bwd"),
):
m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key]
m_ref = fp8_meta_ref[fp8_meta_type]
torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols)
torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols)
torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols)

# More training steps with loaded model
for step in range(save_steps, save_steps + load_steps):
y, dx, w = train_step(model, xs_ref[step], dys_ref[step])
torch.testing.assert_close(y, ys_ref[step], **tols)
torch.testing.assert_close(dx, dxs_ref[step], **tols)
torch.testing.assert_close(w, ws_ref[step], **tols)
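
The new test can presumably be run on its own with pytest's keyword filter, for example pytest tests/pytorch/test_torch_save_load.py -k test_sequential_model on a CUDA machine; the FP8-parametrized cases skip themselves when FP8 support is unavailable.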
38 changes: 35 additions & 3 deletions transformer_engine/pytorch/ops/linear.py
@@ -133,6 +133,38 @@ def __init__(
# Initialize base class
super().__init__(ops)

# Register parameters
self.register_parameter("weight", self.basic_ops[0].weight)
self.register_parameter("bias", self.basic_ops[1].bias if bias else None)
self._has_bias: bool = bias

@property
def weight(self) -> torch.nn.Parameter:
"""Weight tensor
Parameter is owned by `BasicLinear` operation.
"""
return self.basic_ops[0].weight

@weight.setter
def weight(self, value: Optional[torch.nn.Parameter]) -> None:
self.basic_ops[0].weight = value

@property
def bias(self) -> Optional[torch.nn.Parameter]:
"""Bias tensor
Parameter is owned by `Bias` operation.
"""
if self._has_bias:
return self.basic_ops[1].bias
return None

@bias.setter
def bias(self, value: Optional[torch.nn.Parameter]) -> None:
if self._has_bias:
self.basic_ops[1].bias = value
elif value is not None:
raise ValueError(
"Attempted to set bias parameter in Linear operation "
"that does not have bias enabled"
)
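
The properties above delegate to the parameters owned by the basic ops, which appears to be how the duplicate-save bug from the commit message is fixed: the fused Linear no longer carries its own copies of the same Parameter objects. A rough sketch of the intended invariant, assuming a CUDA device and the default bias=True; the checks are illustrative, not taken from the diff:

import transformer_engine.pytorch.ops as te_ops

linear = te_ops.Linear(16, 16, device="cuda")

# The fused op's weight/bias attributes resolve to the basic ops' parameters,
# not to separately registered copies on the fused module.
assert linear.weight is linear.basic_ops[0].weight
assert linear.bias is linear.basic_ops[1].bias

# Each parameter should therefore appear once in the checkpoint, under its
# basic op's key, rather than being saved twice.
print(sorted(linear.state_dict().keys()))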
156 changes: 156 additions & 0 deletions transformer_engine/pytorch/ops/op.py
@@ -8,6 +8,7 @@
import abc
from collections.abc import Iterable
import dataclasses
import pickle
from typing import Any, Optional

import torch
@@ -504,6 +505,161 @@ def forward(
basic_op_kwargs=[kwargs],
)

def get_extra_state(self) -> Optional[torch.Tensor]:
"""Serialize extra state
Contains metadata for FP8 casting.
"""

# This implementation is working around a few issues:
#
# (1) PyTorch's "extra state" infrastructure might be able to
# support any picklable type, but they make no guarantees.
# It seems that ONNX export experiences issues with
# non-tensor extra state.
# (2) PyTorch's checkpointing infrastructure does not remap
# devices for "extra state" like it does for "state dict".
# Thus, we want to avoid putting extra state on the GPU
# since it may be loaded on the wrong device.
# (3) The extra state consists of many small tensors. If we
# want to copy them all to CPU, then we need to avoid the
# overhead of many GPU-CPU memory transfers.
#
# See: https://github.com/NVIDIA/TransformerEngine/pull/351
# See: https://github.com/NVIDIA/TransformerEngine/pull/363

# Return immediately if op has no FP8 state
has_fp8_state = any(
self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output")
)
if not has_fp8_state:
return None

def to_cpu(src: torch.Tensor) -> torch.Tensor:
"""Helper function to make CPU copy of tensor
Memory transfer is asynchronous w.r.t. host, so GPU should
be synchronized before using result.
"""
dst = torch.empty_like(src, device="cpu")
dst.copy_(src, non_blocking=True)
return dst

# Store FP8 state
state = {}
for mode in ("input", "param", "grad_output"):

# Get state for a given FP8 tensor
if self.num_fp8_scales(mode) == 0:
state[mode] = None
continue
fp8_meta = self.get_fp8_meta(mode)
if fp8_meta is None:
continue
state[mode] = {}

# Store tensors
if "scaling_fwd" in fp8_meta:
state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
state[mode]["scale_inv_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale_inv)
state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
if "scaling_bwd" in fp8_meta:
state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
state[mode]["scale_inv_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale_inv)
state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)

# Store other picklable items
extra = {}
for key, val in fp8_meta.items():
if key == "buffer_index_and_autocast_key":
continue
if not isinstance(val, (bool, int, float, str, tuple, list)):
continue
extra[key] = val
state[mode]["extra_fp8_variables"] = extra

# Serialize state into byte tensor
torch.cuda.synchronize()
state_serialized = bytearray(pickle.dumps(state))
state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
return state_serialized

def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
"""Load extra state"""
if state is None:
return

# Deserialize state from byte tensor
state = pickle.loads(state.detach().numpy(force=True).tobytes())
if state is None:
return

def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
"""Helper function to copy tensor from CPU
Memory transfer is asynchronous w.r.t. host, so GPU should
be synchronized before using result.
"""
if src.size() != dst.size():
dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device)
dst.copy_(src, non_blocking=True)

# Load FP8 state
for mode in ("input", "param", "grad_output"):

# Get state for a given FP8 tensor
if mode not in state:
continue
if self.num_fp8_scales(mode) == 0:
continue
fp8_meta = self.get_fp8_meta(mode)
if fp8_meta is None:
continue

# Load extra state
fp8_meta.update(state[mode]["extra_fp8_variables"])
if "amax_history_fwd" in state[mode]:
fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0)
elif "amax_history_bwd" in state[mode]:
fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0)
if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta:
del fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

# Load tensors
fp8_meta = self.get_fp8_meta(mode)
if "scaling_fwd" in fp8_meta:
fp8_meta_fwd = fp8_meta["scaling_fwd"]
copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale)
copy_tensor(state[mode]["scale_inv_fwd"], fp8_meta_fwd.scale_inv)
copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history)
if "scaling_bwd" in fp8_meta:
fp8_meta_bwd = fp8_meta["scaling_bwd"]
copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale)
copy_tensor(state[mode]["scale_inv_bwd"], fp8_meta_bwd.scale_inv)
copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history)

# Finish CPU-GPU memory transfers
torch.cuda.synchronize()

def _load_from_state_dict(self, *args, **kwargs) -> None:
"""Load state"""

# In the base PyTorch module class, the extra state is loaded
# _after_ the parameters. However, copying values into FP8
# parameters requires an FP8 cast, which uses a scaling factor
# from the operation's FP8 metadata. The FP8 metadata is
# included in the operation's extra state, so we need to
# manually load the extra state before loading parameters.

state_dict, prefix = args[0], args[1]
extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
super()._load_from_state_dict(*args, **kwargs)


class FusedOperation(FusibleOperation):
"""Compound tensor operation supported by the operation fuser

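
The long comment in get_extra_state explains why the FP8 metadata is pickled into a CPU uint8 tensor instead of being stored as an arbitrary Python object or left on the GPU. A self-contained sketch of that encoding, using hypothetical helper names and an arbitrary picklable dict:

import pickle
import torch

def encode_extra_state(state: dict) -> torch.Tensor:
    """Pack a picklable dict into a 1-D uint8 CPU tensor (hypothetical helper)."""
    # bytearray gives torch.frombuffer a writable buffer, matching the diff above.
    return torch.frombuffer(bytearray(pickle.dumps(state)), dtype=torch.uint8)

def decode_extra_state(tensor: torch.Tensor) -> dict:
    """Recover the original dict from the uint8 tensor (hypothetical helper)."""
    return pickle.loads(tensor.detach().numpy(force=True).tobytes())

state = {"scale_fwd": torch.ones(4), "amax_history_len": 8}
restored = decode_extra_state(encode_extra_state(state))
assert restored["amax_history_len"] == 8
assert torch.equal(restored["scale_fwd"], state["scale_fwd"])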