From fceff07a59bacd517baaf6a0f9cb0fb087f117ea Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Thu, 20 Feb 2025 05:55:41 +0800
Subject: [PATCH] [PyTorch] Fix fuse_wgrad_accumulation for GroupedLinear
 (#1488)

* fix fuse_wgrad_accumulation for GroupedLinear

Signed-off-by: Xin Yao

* fix fuse_wgrad_accumulation for GroupedLinear

Signed-off-by: Xin Yao

* update tests

Signed-off-by: Xin Yao

---------

Signed-off-by: Xin Yao
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
---
 tests/pytorch/test_numerics.py           | 33 +++++++++++--
 .../pytorch/module/grouped_linear.py     | 49 ++++++++++---------
 2 files changed, 53 insertions(+), 29 deletions(-)

diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 22735c5292..a72ba097a1 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -1400,7 +1400,9 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
     assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])
 
 
-def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
+def _test_grouped_linear_accuracy(
+    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+):
     reset_rng_states()
     if fp8:
         FP8GlobalStateManager.reset()
@@ -1447,7 +1449,11 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f
     outputs = [out, inp_hidden_states.grad]
     for p in block.parameters():
         if p.requires_grad:
-            outputs.append(p.grad)
+            if getattr(p, "main_grad", None) is not None:
+                outputs.append(p.main_grad)
+                assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
+            else:
+                outputs.append(p.grad)
     return outputs
 
 
@@ -1458,8 +1464,17 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f
 @pytest.mark.parametrize("fp8", all_boolean)
 @pytest.mark.parametrize("recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
+@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
 def test_grouped_linear_accuracy(
-    dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None
+    dtype,
+    num_gemms,
+    bs,
+    model,
+    fp8,
+    recipe,
+    fp8_model_params,
+    fuse_wgrad_accumulation,
+    parallel_mode=None,
 ):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -1481,6 +1496,7 @@ def test_grouped_linear_accuracy(
             params_dtype=dtype,
             parallel_mode=parallel_mode,
             device="cuda",
+            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
         ).eval()
         sequential_linear = torch.nn.ModuleList(
             [
@@ -1491,6 +1507,7 @@ def test_grouped_linear_accuracy(
                     params_dtype=dtype,
                     parallel_mode=parallel_mode,
                     device="cuda",
+                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                 ).eval()
                 for _ in range(num_gemms)
             ]
@@ -1501,12 +1518,16 @@ def test_grouped_linear_accuracy(
         for i in range(num_gemms):
             sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
             sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
+            if fuse_wgrad_accumulation:
+                weight_i = getattr(grouped_linear, f"weight{i}")
+                weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
+                sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
 
     outputs_ref = _test_grouped_linear_accuracy(
-        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8
+        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
     )
     outputs = _test_grouped_linear_accuracy(
-        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
+        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
     )
 
     # Shoule be bit-wise match
@@ -1527,6 +1548,7 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe):
         recipe=recipe,
         fp8_model_params=True,
         parallel_mode=parallel_mode,
+        fuse_wgrad_accumulation=True,
     )
 
 
@@ -1541,6 +1563,7 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
         fp8=True,
         recipe=recipe,
         fp8_model_params=True,
+        fuse_wgrad_accumulation=True,
     )
 
 
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index cab8dff7c2..10b21f25c6 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -178,7 +178,6 @@ def forward(
 
         if is_grad_enabled:
-            saved_inputs, saved_weights = [], []
             ctx.weights_shape_1 = weights[0].shape[1]
 
             tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
@@ -186,9 +185,11 @@ def forward(
             ctx.tensor_objects = tensor_objects
 
             ctx.weights_requires_grad = weights[0].requires_grad
+            if fuse_wgrad_accumulation and ctx.weights_requires_grad:
+                ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)]
+            else:
+                ctx.main_grads = [None] * num_gemms
             ctx.device = device
-            ctx.saved_inputs = saved_inputs
-            ctx.saved_weights = saved_weights
             ctx.grad_output_quantizers = grad_output_quantizers
             ctx.m_splits = m_splits
             ctx.num_gemms = num_gemms
@@ -220,7 +221,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         inputmats = saved_tensors[:N]
         weights = saved_tensors[N : 2 * N]
         biases = saved_tensors[2 * N : 3 * N]
-        main_grads = saved_tensors[3 * N :]
+        main_grads = ctx.main_grads
         if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
             # TOSO
            for i in ctx.num_gemms:
@@ -281,31 +282,31 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
 
         if ctx.weights_requires_grad:
             if ctx.fuse_wgrad_accumulation:
-                wgrad_list = [w.main_grad for w in weights]
+                wgrad_list = main_grads
             else:
                 wgrad_list = [
                     torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
                     for w in weights
                 ]
-        # WGRAD
-        _, grad_biases_, _ = general_grouped_gemm(
-            inputmats,
-            grad_output,
-            wgrad_list,
-            ctx.activation_dtype,
-            get_multi_stream_cublas_workspace(),
-            layout="NT",
-            grad=True,
-            m_splits=ctx.m_splits,
-            use_bias=ctx.use_bias if grad_biases[0] is None else None,
-            bias=biases,
-            use_split_accumulator=_2X_ACC_WGRAD,
-            accumulate=accumulate_wgrad_into_param_main_grad,
-        )
-        for i in range(ctx.num_gemms):
-            if grad_biases[i] is None:
-                grad_biases[i] = grad_biases_[i]
-        del grad_biases_
+            # WGRAD
+            _, grad_biases_, _ = general_grouped_gemm(
+                inputmats,
+                grad_output,
+                wgrad_list,
+                ctx.activation_dtype,
+                get_multi_stream_cublas_workspace(),
+                layout="NT",
+                grad=True,
+                m_splits=ctx.m_splits,
+                use_bias=ctx.use_bias if grad_biases[0] is None else None,
+                bias=biases,
+                use_split_accumulator=_2X_ACC_WGRAD,
+                accumulate=accumulate_wgrad_into_param_main_grad,
+            )
+            for i in range(ctx.num_gemms):
+                if grad_biases[i] is None:
+                    grad_biases[i] = grad_biases_[i]
+            del grad_biases_
 
         # Deallocate input tensor
         clear_tensor_data(*inputmats)
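
For context, a minimal usage sketch of the behavior this patch fixes and the new test exercises: with fuse_wgrad_accumulation=True, each weight{i} of a GroupedLinear must carry a main_grad buffer, the weight gradient is accumulated into that buffer during backward, and weight.grad stays None. The sizes, dtype, and m_splits below are illustrative (not taken from the patch), and the snippet assumes a CUDA build of Transformer Engine.

import torch
from transformer_engine.pytorch import GroupedLinear

num_gemms, in_features, out_features = 4, 128, 256
layer = GroupedLinear(
    num_gemms,
    in_features,
    out_features,
    bias=True,
    params_dtype=torch.bfloat16,
    device="cuda",
    fuse_wgrad_accumulation=True,
)

# The gradient-accumulation buffers are normally owned by the training framework;
# here they are created by hand, mirroring what the updated test does with torch.rand_like.
for i in range(num_gemms):
    weight = getattr(layer, f"weight{i}")
    weight.main_grad = torch.zeros_like(weight, dtype=torch.float32)

m_splits = [32] * num_gemms  # rows of the input handled by each GEMM; must sum to the batch size
inp = torch.randn(sum(m_splits), in_features, dtype=torch.bfloat16, device="cuda")
out = layer(inp, m_splits)
out.sum().backward()

for i in range(num_gemms):
    weight = getattr(layer, f"weight{i}")
    assert weight.grad is None  # wgrad bypasses .grad when fuse_wgrad_accumulation=True
    assert weight.main_grad.abs().sum() > 0  # and is accumulated into main_grad instead

Before this patch, the main_grad handles were taken from the weights saved for backward, which dropped the fused-accumulation path for GroupedLinear; the fix captures them in ctx.main_grads at forward time, as the diff above shows.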