From f20d3ddb92c9a27d30f10edadc31eacb626feadc Mon Sep 17 00:00:00 2001
From: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Date: Tue, 5 Nov 2024 09:28:30 -0800
Subject: [PATCH] [PyTorch] Debug checkpointing with operation-based API
 (#1063)

* Debug checkpointing with operation-based API

Signed-off-by: Tim Moon

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Store checkpoint FP8 state on CPU

Signed-off-by: Tim Moon

* Fix bug where linear op was saving params multiple times

Signed-off-by: Tim Moon

* Fix linter warnings

Signed-off-by: Tim Moon

---------

Signed-off-by: Tim Moon
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 tests/pytorch/test_torch_save_load.py        | 185 +++++++++++++++++++++++
 transformer_engine/pytorch/ops/linear.py     |  38 ++++-
 transformer_engine/pytorch/ops/op.py         | 156 +++++++++++++++++++
 3 files changed, 376 insertions(+), 3 deletions(-)

diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py
index 6af7ede234..7bf8fb99d5 100644
--- a/tests/pytorch/test_torch_save_load.py
+++ b/tests/pytorch/test_torch_save_load.py
@@ -17,7 +17,9 @@
 import pytest
 import torch
 
+import transformer_engine.common
 import transformer_engine.pytorch as te
+import transformer_engine.pytorch.ops as te_ops
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
@@ -287,3 +289,186 @@ def test_fp8_model_checkpoint(
     torch.testing.assert_close(
         model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item()
     )
+
+
+@pytest.mark.parametrize("fp8", (False, True))
+@pytest.mark.parametrize("save_fp8_model", (False, True))
+@pytest.mark.parametrize("load_fp8_model", (False, True))
+def test_sequential_model(
+    *,
+    in_shape: Iterable[int] = (16, 16),
+    dtype: torch.dtype = torch.float32,
+    device: torch.device = "cuda",
+    save_steps: int = 2,
+    load_steps: int = 2,
+    fp8: bool,
+    save_fp8_model: bool,
+    load_fp8_model: bool,
+) -> None:
+
+    # Skip invalid configurations
+    if fp8 or save_fp8_model or load_fp8_model:
+        if not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if torch.device(device).type != "cuda":
+            pytest.skip("FP8 is only supported on CUDA devices")
+
+    # FP8 recipe
+    margin = 2
+    fp8_format = transformer_engine.common.recipe.Format.E4M3
+    recipe = transformer_engine.common.recipe.DelayedScaling(
+        margin=margin,
+        fp8_format=fp8_format,
+        amax_history_len=8,
+        amax_compute_algo="max",
+    )
+
+    # Construct model to save to checkpoint
+    with te.fp8_model_init(enabled=save_fp8_model):
+        model = te_ops.Sequential(
+            te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype),
+        )
+    with torch.no_grad():
+        torch.rand(model[0].weight.size(), out=model[0].weight)
+        torch.rand(model[0].bias.size(), out=model[0].bias)
+
+    # Synthetic data
+    xs_ref = [
+        torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps)
+    ]
+    dys_ref = [
+        torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps)
+    ]
+
+    def train_step(
+        model: te_ops.Sequential,
+        x: torch.Tensor,
+        dy: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Helper function to perform training step"""
+        x = x.detach().clone().requires_grad_()
+        dy = dy.detach().clone()
+        with te.fp8_autocast(enabled=fp8, fp8_recipe=recipe):
+            y = model(x)
+        y.backward(dy)
+        with torch.no_grad():
+            for param in model.parameters():
+                param += 0.125
+        return (
+            y.detach().clone(),
+            x.grad.detach().clone(),
+            model[0].weight.detach().float().clone(),
+        )
+
+    # Initial training steps with saved model
+    ys_ref = []
+    dxs_ref = []
+    ws_ref = []
+    for step in range(save_steps):
+        y, dx, w = train_step(model, xs_ref[step], dys_ref[step])
+        ys_ref.append(y)
+        dxs_ref.append(dx)
+        ws_ref.append(w)
+
+    # Keep track of FP8 metadata if needed
+    fp8_meta_ref = dict(input={}, param={}, grad_output={})
+    if fp8:
+        for fp8_meta_type, fp8_meta_key in (
+            ("input", "scaling_fwd"),
+            ("param", "scaling_fwd"),
+            ("grad_output", "scaling_bwd"),
+        ):
+            m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key]
+            m_ref = fp8_meta_ref[fp8_meta_type]
+            m_ref["amax"] = m_model.amax_history.detach().clone()
+            m_ref["scale"] = m_model.scale.detach().clone()
+            m_ref["scale_inv"] = m_model.scale_inv.detach().clone()
+            del m_model, m_ref
+
+    # Save checkpoint
+    byte_stream = io.BytesIO()
+    torch.save(model.state_dict(), byte_stream)
+    model_bytes = byte_stream.getvalue()
+    del byte_stream
+
+    # More training steps with saved model
+    for step in range(save_steps, save_steps + load_steps):
+        y, dx, w = train_step(model, xs_ref[step], dys_ref[step])
+        ys_ref.append(y)
+        dxs_ref.append(dx)
+        ws_ref.append(w)
+
+    # Disturb and destroy model
+    with torch.no_grad():
+        for param in model.parameters():
+            param.zero_()
+    model[0].basic_ops[0]._fp8_metas = None
+    del model
+
+    # Construct new model to load from checkpoint
+    with te.fp8_model_init(enabled=load_fp8_model):
+        model = te_ops.Sequential(
+            te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype),
+        )
+
+    # Tolerances for numerical checks
+    tols = {}
+    if fp8 or save_fp8_model or load_fp8_model:
+        tols = dict(rtol=0.125, atol=0.0675)  # fp8e4m3 epsilon = 0.0625
+    exact_tols = dict(rtol=0, atol=0)
+
+    # Training steps with dummy data
+    for step in range(save_steps):
+        y, dx, w = train_step(
+            model,
+            torch.zeros_like(xs_ref[step]),
+            torch.zeros_like(dys_ref[step]),
+        )
+
+        # Make sure results don't match saved model
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(y, ys_ref[step], **tols)
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(dx, dxs_ref[step], **tols)
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(w, ws_ref[step], **tols)
+
+    # Make sure new model's FP8 metadata doesn't match saved model
+    if fp8:
+        for fp8_meta_type, fp8_meta_key in (
+            ("input", "scaling_fwd"),
+            ("param", "scaling_fwd"),
+            ("grad_output", "scaling_bwd"),
+        ):
+            m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key]
+            m_ref = fp8_meta_ref[fp8_meta_type]
+            with pytest.raises(AssertionError):
+                torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols)
+            with pytest.raises(AssertionError):
+                torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols)
+            with pytest.raises(AssertionError):
+                torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols)
+
+    # Load checkpoint
+    model.load_state_dict(torch.load(io.BytesIO(model_bytes)))
+    del model_bytes
+
+    # Check that new model's FP8 metadata matches saved model
+    if fp8:
+        for fp8_meta_type, fp8_meta_key in (
+            ("input", "scaling_fwd"),
+            ("param", "scaling_fwd"),
+            ("grad_output", "scaling_bwd"),
+        ):
+            m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key]
+            m_ref = fp8_meta_ref[fp8_meta_type]
+            torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols)
+            torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols)
+            torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols)
+
+    # More training steps with loaded model
+    for step in range(save_steps, save_steps + load_steps):
+        y, dx, w = train_step(model, xs_ref[step], dys_ref[step])
+        torch.testing.assert_close(y, ys_ref[step], **tols)
+        torch.testing.assert_close(dx, dxs_ref[step], **tols)
+        torch.testing.assert_close(w, ws_ref[step], **tols)
diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py
index daa5a6952e..68472f171a 100644
--- a/transformer_engine/pytorch/ops/linear.py
+++ b/transformer_engine/pytorch/ops/linear.py
@@ -133,6 +133,38 @@ def __init__(
         # Initialize base class
         super().__init__(ops)
 
-        # Register parameters
-        self.register_parameter("weight", self.basic_ops[0].weight)
-        self.register_parameter("bias", self.basic_ops[1].bias if bias else None)
+        self._has_bias: bool = bias
+
+    @property
+    def weight(self) -> torch.nn.Parameter:
+        """Weight tensor
+
+        Parameter is owned by `BasicLinear` operation.
+
+        """
+        return self.basic_ops[0].weight
+
+    @weight.setter
+    def weight(self, value: Optional[torch.nn.Parameter]) -> None:
+        self.basic_ops[0].weight = value
+
+    @property
+    def bias(self) -> Optional[torch.nn.Parameter]:
+        """Bias tensor
+
+        Parameter is owned by `Bias` operation.
+
+        """
+        if self._has_bias:
+            return self.basic_ops[1].bias
+        return None
+
+    @bias.setter
+    def bias(self, value: Optional[torch.nn.Parameter]) -> None:
+        if self._has_bias:
+            self.basic_ops[1].bias = value
+        elif value is not None:
+            raise ValueError(
+                "Attempted to set bias parameter in Linear operation "
+                "that does not have bias enabled"
+            )
diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py
index 87a7e825bc..9e4963d52a 100644
--- a/transformer_engine/pytorch/ops/op.py
+++ b/transformer_engine/pytorch/ops/op.py
@@ -8,6 +8,7 @@
 import abc
 from collections.abc import Iterable
 import dataclasses
+import pickle
 from typing import Any, Optional
 
 import torch
@@ -504,6 +505,161 @@ def forward(
             basic_op_kwargs=[kwargs],
         )
 
+    def get_extra_state(self) -> Optional[torch.Tensor]:
+        """Serialize extra state
+
+        Contains metadata for FP8 casting.
+
+        """
+
+        # This implementation is working around a few issues:
+        #
+        # (1) PyTorch's "extra state" infrastructure might be able to
+        #     support any picklable type, but they make no guarantees.
+        #     It seems that ONNX export experiences issues with
+        #     non-tensor extra state.
+        # (2) PyTorch's checkpointing infrastructure does not remap
+        #     devices for "extra state" like it does for "state dict".
+        #     Thus, we want to avoid putting extra state on the GPU
+        #     since it may be loaded on the wrong device.
+        # (3) The extra state consists of many small tensors. If we
+        #     want to copy them all to CPU, then we need to avoid the
+        #     overhead of many GPU-CPU memory transfers.
+        #
+        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
+        # See: https://github.com/NVIDIA/TransformerEngine/pull/363
+
+        # Return immediately if op has no FP8 state
+        has_fp8_state = any(
+            self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output")
+        )
+        if not has_fp8_state:
+            return None
+
+        def to_cpu(src: torch.Tensor) -> torch.Tensor:
+            """Helper function to make CPU copy of tensor
+
+            Memory transfer is asynchronous w.r.t. host, so GPU should
+            be synchronized before using result.
+
+            """
+            dst = torch.empty_like(src, device="cpu")
+            dst.copy_(src, non_blocking=True)
+            return dst
+
+        # Store FP8 state
+        state = {}
+        for mode in ("input", "param", "grad_output"):
+
+            # Get state for a given FP8 tensor
+            if self.num_fp8_scales(mode) == 0:
+                state[mode] = None
+                continue
+            fp8_meta = self.get_fp8_meta(mode)
+            if fp8_meta is None:
+                continue
+            state[mode] = {}
+
+            # Store tensors
+            if "scaling_fwd" in fp8_meta:
+                state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale)
+                state[mode]["scale_inv_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale_inv)
+                state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history)
+            if "scaling_bwd" in fp8_meta:
+                state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale)
+                state[mode]["scale_inv_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale_inv)
+                state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history)
+
+            # Store other picklable items
+            extra = {}
+            for key, val in fp8_meta.items():
+                if key == "buffer_index_and_autocast_key":
+                    continue
+                if not isinstance(val, (bool, int, float, str, tuple, list)):
+                    continue
+                extra[key] = val
+            state[mode]["extra_fp8_variables"] = extra
+
+        # Serialize state into byte tensor
+        torch.cuda.synchronize()
+        state_serialized = bytearray(pickle.dumps(state))
+        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
+        return state_serialized
+
+    def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
+        """Load extra state"""
+        if state is None:
+            return
+
+        # Deserialize state from byte tensor
+        state = pickle.loads(state.detach().numpy(force=True).tobytes())
+        if state is None:
+            return
+
+        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
+            """Helper function to copy tensor from CPU
+
+            Memory transfer is asynchronous w.r.t. host, so GPU should
+            be synchronized before using result.
+
+            """
+            if src.size() != dst.size():
+                dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device)
+            dst.copy_(src, non_blocking=True)
+
+        # Load FP8 state
+        for mode in ("input", "param", "grad_output"):
+
+            # Get state for a given FP8 tensor
+            if mode not in state:
+                continue
+            if self.num_fp8_scales(mode) == 0:
+                continue
+            fp8_meta = self.get_fp8_meta(mode)
+            if fp8_meta is None:
+                continue
+
+            # Load extra state
+            fp8_meta.update(state[mode]["extra_fp8_variables"])
+            if "amax_history_fwd" in state[mode]:
+                fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0)
+            elif "amax_history_bwd" in state[mode]:
+                fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0)
+            if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta:
+                del fp8_meta["global_fp8_buffer_pos_fwd_recompute"]
+
+            # Load tensors
+            fp8_meta = self.get_fp8_meta(mode)
+            if "scaling_fwd" in fp8_meta:
+                fp8_meta_fwd = fp8_meta["scaling_fwd"]
+                copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale)
+                copy_tensor(state[mode]["scale_inv_fwd"], fp8_meta_fwd.scale_inv)
+                copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history)
+            if "scaling_bwd" in fp8_meta:
+                fp8_meta_bwd = fp8_meta["scaling_bwd"]
+                copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale)
+                copy_tensor(state[mode]["scale_inv_bwd"], fp8_meta_bwd.scale_inv)
+                copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history)
+
+        # Finish CPU-GPU memory transfers
+        torch.cuda.synchronize()
+
+    def _load_from_state_dict(self, *args, **kwargs) -> None:
+        """Load state"""
+
+        # In the base PyTorch module class, the extra state is loaded
+        # _after_ the parameters. However, copying values into FP8
+        # parameters requires an FP8 cast, which uses a scaling factor
+        # from the operation's FP8 metadata. The FP8 metadata is
+        # included in the operation's extra state, so we need to
+        # manually load the extra state before loading parameters.
+
+        state_dict, prefix = args[0], args[1]
+        extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
+        if extra_state_key in state_dict:
+            self.set_extra_state(state_dict[extra_state_key])
+        super()._load_from_state_dict(*args, **kwargs)
+
 
 class FusedOperation(FusibleOperation):
     """Compound tensor operation supported by the operation fuser
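Usage note (not part of the patch): the minimal sketch below shows the checkpoint round trip that test_sequential_model exercises with the operation-based API, assuming a CUDA device with FP8 support. It only uses calls that appear in the test above; the layer sizes and recipe settings are illustrative.

import io

import torch
import transformer_engine.common
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# Delayed-scaling FP8 recipe, mirroring the one constructed in the test
recipe = transformer_engine.common.recipe.DelayedScaling(
    margin=2,
    fp8_format=transformer_engine.common.recipe.Format.E4M3,
    amax_history_len=8,
    amax_compute_algo="max",
)

# Run one training step with the operation-based API so FP8 metadata is populated
model = te_ops.Sequential(te_ops.Linear(16, 16, device="cuda", dtype=torch.float32))
x = torch.rand(16, 16, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = model(x)
y.backward(torch.rand_like(y))

# Save: get_extra_state packs the FP8 scales and amax history into the state dict
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Load into a fresh model: _load_from_state_dict applies the extra state
# (FP8 metadata) before copying parameter values
new_model = te_ops.Sequential(te_ops.Linear(16, 16, device="cuda", dtype=torch.float32))
new_model.load_state_dict(torch.load(io.BytesIO(buffer.getvalue())))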