[PyTorch] Fix fuse_wgrad_accumulation for GroupedLinear (#1488)
* fix fuse_wgrad_accumulation for GroupedLinear

Signed-off-by: Xin Yao <[email protected]>

* fix fuse_wgrad_accumulation for GroupedLinear

Signed-off-by: Xin Yao <[email protected]>

* update tests

Signed-off-by: Xin Yao <[email protected]>

---------

Signed-off-by: Xin Yao <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
yaox12 and timmoon10 authored Feb 19, 2025
1 parent 56c0c07 commit fceff07
Showing 2 changed files with 53 additions and 29 deletions.
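
For context, here is a minimal, hypothetical usage sketch of the behavior this commit fixes, distilled from the updated test below: with fuse_wgrad_accumulation=True, each GroupedLinear weight carries a float32 main_grad buffer, the weight gradient is accumulated directly into that buffer during backward, and weight.grad stays None. The layer sizes, dtype, and m_splits values are illustrative assumptions, not taken from the commit.

import torch
import transformer_engine.pytorch as te

num_gemms, in_features, out_features = 4, 128, 128
layer = te.GroupedLinear(
    num_gemms,
    in_features,
    out_features,
    params_dtype=torch.bfloat16,
    device="cuda",
    fuse_wgrad_accumulation=True,
)

# The fused path expects a float32 main_grad buffer on every weight.
for i in range(num_gemms):
    weight = getattr(layer, f"weight{i}")
    weight.main_grad = torch.zeros_like(weight, dtype=torch.float32)

inp = torch.randn(num_gemms * 32, in_features, dtype=torch.bfloat16, device="cuda")
m_splits = [32] * num_gemms  # rows of inp routed to each of the num_gemms GEMMs
out = layer(inp, m_splits)
out.sum().backward()

for i in range(num_gemms):
    weight = getattr(layer, f"weight{i}")
    assert weight.grad is None                # wgrad is not written to .grad ...
    assert weight.main_grad.abs().sum() > 0   # ... it is accumulated into .main_grad
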
33 changes: 28 additions & 5 deletions tests/pytorch/test_numerics.py
@@ -1400,7 +1400,9 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
     assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


-def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):
+def _test_grouped_linear_accuracy(
+    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+):
     reset_rng_states()
     if fp8:
         FP8GlobalStateManager.reset()
@@ -1447,7 +1449,11 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f
     outputs = [out, inp_hidden_states.grad]
     for p in block.parameters():
         if p.requires_grad:
-            outputs.append(p.grad)
+            if getattr(p, "main_grad", None) is not None:
+                outputs.append(p.main_grad)
+                assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
+            else:
+                outputs.append(p.grad)
     return outputs


@@ -1458,8 +1464,17 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f
 @pytest.mark.parametrize("fp8", all_boolean)
 @pytest.mark.parametrize("recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
+@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
 def test_grouped_linear_accuracy(
-    dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None
+    dtype,
+    num_gemms,
+    bs,
+    model,
+    fp8,
+    recipe,
+    fp8_model_params,
+    fuse_wgrad_accumulation,
+    parallel_mode=None,
 ):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -1481,6 +1496,7 @@ def test_grouped_linear_accuracy(
             params_dtype=dtype,
             parallel_mode=parallel_mode,
             device="cuda",
+            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
         ).eval()
         sequential_linear = torch.nn.ModuleList(
             [
@@ -1491,6 +1507,7 @@ def test_grouped_linear_accuracy(
                     params_dtype=dtype,
                     parallel_mode=parallel_mode,
                     device="cuda",
+                    fuse_wgrad_accumulation=fuse_wgrad_accumulation,
                 ).eval()
                 for _ in range(num_gemms)
             ]
@@ -1501,12 +1518,16 @@ def test_grouped_linear_accuracy(
             for i in range(num_gemms):
                 sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
                 sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
+                if fuse_wgrad_accumulation:
+                    weight_i = getattr(grouped_linear, f"weight{i}")
+                    weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
+                    sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()

     outputs_ref = _test_grouped_linear_accuracy(
-        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8
+        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
     )
     outputs = _test_grouped_linear_accuracy(
-        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8
+        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
     )

     # Shoule be bit-wise match
@@ -1527,6 +1548,7 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe):
         recipe=recipe,
         fp8_model_params=True,
         parallel_mode=parallel_mode,
+        fuse_wgrad_accumulation=True,
     )


@@ -1541,6 +1563,7 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
         fp8=True,
         recipe=recipe,
         fp8_model_params=True,
+        fuse_wgrad_accumulation=True,
     )


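One hypothetical way to run just the tests touched above, assuming a CUDA machine with Transformer Engine built; the -k filter and flags are illustrative, not part of the commit:

import pytest

pytest.main([
    "tests/pytorch/test_numerics.py",
    "-k", "grouped_linear_accuracy",
    "-q",
])
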
49 changes: 25 additions & 24 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -178,17 +178,18 @@ def forward(

         if is_grad_enabled:

-            saved_inputs, saved_weights = [], []
             ctx.weights_shape_1 = weights[0].shape[1]

             tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
             ctx.save_for_backward(*tensors_to_save)
             ctx.tensor_objects = tensor_objects

             ctx.weights_requires_grad = weights[0].requires_grad
+            if fuse_wgrad_accumulation and ctx.weights_requires_grad:
+                ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)]
+            else:
+                ctx.main_grads = [None] * num_gemms
             ctx.device = device
-            ctx.saved_inputs = saved_inputs
-            ctx.saved_weights = saved_weights
             ctx.grad_output_quantizers = grad_output_quantizers
             ctx.m_splits = m_splits
             ctx.num_gemms = num_gemms
@@ -220,7 +221,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         inputmats = saved_tensors[:N]
         weights = saved_tensors[N : 2 * N]
         biases = saved_tensors[2 * N : 3 * N]
-        main_grads = saved_tensors[3 * N :]
+        main_grads = ctx.main_grads

         if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:  # TOSO
             for i in ctx.num_gemms:
@@ -281,31 +282,31 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],

         if ctx.weights_requires_grad:
             if ctx.fuse_wgrad_accumulation:
-                wgrad_list = [w.main_grad for w in weights]
+                wgrad_list = main_grads
             else:
                 wgrad_list = [
                     torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
                     for w in weights
                 ]
-        # WGRAD
-        _, grad_biases_, _ = general_grouped_gemm(
-            inputmats,
-            grad_output,
-            wgrad_list,
-            ctx.activation_dtype,
-            get_multi_stream_cublas_workspace(),
-            layout="NT",
-            grad=True,
-            m_splits=ctx.m_splits,
-            use_bias=ctx.use_bias if grad_biases[0] is None else None,
-            bias=biases,
-            use_split_accumulator=_2X_ACC_WGRAD,
-            accumulate=accumulate_wgrad_into_param_main_grad,
-        )
-        for i in range(ctx.num_gemms):
-            if grad_biases[i] is None:
-                grad_biases[i] = grad_biases_[i]
-        del grad_biases_
+            # WGRAD
+            _, grad_biases_, _ = general_grouped_gemm(
+                inputmats,
+                grad_output,
+                wgrad_list,
+                ctx.activation_dtype,
+                get_multi_stream_cublas_workspace(),
+                layout="NT",
+                grad=True,
+                m_splits=ctx.m_splits,
+                use_bias=ctx.use_bias if grad_biases[0] is None else None,
+                bias=biases,
+                use_split_accumulator=_2X_ACC_WGRAD,
+                accumulate=accumulate_wgrad_into_param_main_grad,
+            )
+            for i in range(ctx.num_gemms):
+                if grad_biases[i] is None:
+                    grad_biases[i] = grad_biases_[i]
+            del grad_biases_

         # Deallocate input tensor
         clear_tensor_data(*inputmats)
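
The core of the fix above is that main_grad is a plain Python attribute on each weight Parameter, so the backward pass now takes the accumulation buffers from ctx.main_grads (captured in forward) instead of expecting them to come back through save_for_backward. Below is a simplified, standalone sketch of that pattern; it is an assumption-laden toy (single GEMM, no FP8, no workspace handling), not Transformer Engine's actual implementation.

import torch

class _FusedWgradLinear(torch.autograd.Function):
    """Toy linear op whose weight gradient is accumulated into weight.main_grad."""

    @staticmethod
    def forward(ctx, inp, weight, fuse_wgrad_accumulation):
        ctx.save_for_backward(inp, weight)
        ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        # Keep an explicit reference to the accumulation buffer on ctx
        # (mirroring ctx.main_grads in the commit) rather than relying on
        # attributes attached to the saved weight tensor.
        ctx.main_grad = weight.main_grad if fuse_wgrad_accumulation else None
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        if ctx.fuse_wgrad_accumulation:
            # Accumulate wgrad in place; returning None leaves weight.grad unset.
            ctx.main_grad.add_(grad_output.t().to(torch.float32) @ inp.to(torch.float32))
            wgrad = None
        else:
            wgrad = grad_output.t() @ inp
        return grad_output @ weight, wgrad, None

# Usage: preallocate main_grad, run forward/backward, check where the wgrad landed.
w = torch.nn.Parameter(torch.randn(8, 16))
w.main_grad = torch.zeros(8, 16)
x = torch.randn(4, 16, requires_grad=True)
_FusedWgradLinear.apply(x, w, True).sum().backward()
assert w.grad is None and w.main_grad.abs().sum() > 0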
