
Commit fcd52fc

accept mxfp8 grad_output
Signed-off-by: Xin Yao <[email protected]>
1 parent 48746dc commit fcd52fc

File tree: 5 files changed (+52, -25 lines)


tests/pytorch/test_numerics.py

Lines changed: 25 additions & 18 deletions
@@ -4,7 +4,6 @@
 
 import math
 import os
-from pickletools import genops
 from typing import Dict, List, Tuple, Optional
 import pytest
 import random
@@ -108,8 +107,6 @@
     "swiglu": te.ops.SwiGLU,
 }
 
-gated_act_ops = (te.ops.GEGLU, te.ops.QGEGLU, te.ops.ReGLU, te.ops.SReGLU, te.ops.SwiGLU)
-
 all_normalizations = ["LayerNorm", "RMSNorm"]
 
 mask_types = ["causal", "no_mask"]
@@ -1729,20 +1726,34 @@ def _test_grouped_linear_accuracy(
     fp8,
     fuse_wgrad_accumulation,
     delay_wgrad_compute=False,
-    activation_func=te.ops.Identity(),
+    activation_func=None,
+    fp8_activation=False,
 ):
     reset_rng_states()
     if fp8:
         FP8GlobalStateManager.reset()
 
-    if isinstance(activation_func, gated_act_ops) or (
-        isinstance(activation_func, te.ops.Sequential)
-        and isinstance(activation_func[0], gated_act_ops)
-    ):
+    if activation_func in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"):
         hidden_size = 2 * config.hidden_size
     else:
         hidden_size = config.hidden_size
 
+    if activation_func is None:
+        input_act_func = te.ops.Identity()
+        output_act_func = te.ops.Identity()
+    elif fp8_activation:
+        input_act_func = te.ops.Sequential(
+            act_ops[activation_func](),
+            te.ops.Quantize(forward=True, backward=False),  # Output QuantizedTensor in forward
+        )
+        output_act_func = te.ops.Sequential(
+            te.ops.Quantize(forward=False, backward=True),  # Output QuantizedTensor in backward
+            act_ops[activation_func](),
+        )
+    else:
+        input_act_func = act_ops[activation_func]()
+        output_act_func = act_ops[activation_func]()
+
     inp_hidden_states = torch.randn(
         (config.max_seqlen_q, bs, hidden_size),
         dtype=dtype,
@@ -1769,11 +1780,11 @@ def _test_grouped_linear_accuracy(
     with fp8_autocast(enabled=fp8, fp8_recipe=recipe):
         if isinstance(block, GroupedLinear):
             m_splits = m_splits * bs
-            out = block(activation_func(inp_hidden_states), m_splits.tolist())
+            out = output_act_func(block(input_act_func(inp_hidden_states), m_splits.tolist()))
         else:
             out = torch.cat(
                 [
-                    block[i](activation_func(inp))
+                    output_act_func(block[i](input_act_func(inp)))
                     for i, inp in enumerate(torch.split(inp_hidden_states, m_splits.tolist()))
                 ]
             )
@@ -2063,12 +2074,6 @@ def test_grouped_linear_fp8_input(
         weight_i_copy = getattr(grouped_linear_fp8_input, f"weight{i}")
         weight_i_copy.main_grad = weight_i.main_grad.clone()
 
-    bf16_activation = act_ops[activation]()
-    fp8_activation = te.ops.Sequential(
-        bf16_activation,
-        te.ops.Quantize(forward=True, backward=False),  # Output QuantizedTensor in forward
-    )
-
    outputs_ref = _test_grouped_linear_accuracy(
        grouped_linear_bf16_input,
        num_gemms,
@@ -2078,7 +2083,8 @@ def test_grouped_linear_fp8_input(
        recipe,
        fp8=True,
        fuse_wgrad_accumulation=True,
-        activation_func=bf16_activation,
+        activation_func=activation,
+        fp8_activation=False,
    )
    outputs = _test_grouped_linear_accuracy(
        grouped_linear_fp8_input,
@@ -2089,7 +2095,8 @@
        recipe,
        fp8=True,
        fuse_wgrad_accumulation=True,
-        activation_func=fp8_activation,
+        activation_func=activation,
+        fp8_activation=True,
    )
    # Shoule be bit-wise match
    for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
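
For context on the te.ops.Quantize flags used in the new test wiring: forward=True, backward=False quantizes the activation on its way into GroupedLinear and passes gradients through unchanged, while forward=False, backward=True passes the forward output through and quantizes the incoming gradient instead, so GroupedLinear's backward receives a quantized grad_output (the case this commit adds support for). A minimal sketch of that wiring outside the test harness; the shapes, the device and params_dtype arguments, and the MXFP8BlockScaling recipe name are assumptions here, not taken from this diff:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling  # assumed MXFP8 recipe class

# Two GEMMs of 128 -> 128, tokens split evenly between them (hypothetical sizes).
linear = te.GroupedLinear(2, 128, 128, params_dtype=torch.bfloat16, device="cuda")

# Quantize the input activation in forward only.
pre = te.ops.Sequential(te.ops.SwiGLU(), te.ops.Quantize(forward=True, backward=False))
# Quantize the incoming gradient in backward only, so GroupedLinear's backward
# sees a quantized grad_output.
post = te.ops.Sequential(te.ops.Quantize(forward=False, backward=True), te.ops.SwiGLU())

x = torch.randn(64, 256, dtype=torch.bfloat16, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
    y = post(linear(pre(x), [32, 32]))  # SwiGLU halves 256 -> 128 before the grouped GEMM
y.sum().backward()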

transformer_engine/common/util/cast_gated_kernels.cuh

Lines changed: 0 additions & 2 deletions
@@ -44,8 +44,6 @@ constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8
 constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32
 static_assert(ITERATIONS >= 1);
 
-__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); }
-
 template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
           float (*DActOP)(float, const ParamOP &), typename IType, typename OType>
 __global__ void __launch_bounds__(THREADS_PER_CHUNK)

transformer_engine/common/util/math.h

Lines changed: 3 additions & 1 deletion
@@ -9,6 +9,8 @@
 
 namespace transformer_engine {
 
+__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); }
+
 struct Empty {};
 
 template <typename OType, typename IType>
@@ -28,7 +30,7 @@ __device__ inline OType dgelu(const IType val, const Empty&) {
 template <typename OType, typename IType>
 __device__ inline OType sigmoid(const IType val, const Empty&) {
   const float cval = val;
-  return 1.f / (1.f + expf(-cval));
+  return sigmoidf(cval);
 }
 
 template <typename OType, typename IType>

transformer_engine/common/util/vectorized_pointwise.h

Lines changed: 13 additions & 2 deletions
@@ -11,6 +11,7 @@
 
 #include "../common.h"
 #include "../utils.cuh"
+#include "math.h"
 
 namespace transformer_engine {
 
@@ -531,8 +532,18 @@ __launch_bounds__(unary_kernel_threads) __global__
       const ComputeType gelu_in = static_cast<ComputeType>(input_loader0.separate()[i]);
       const ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);
 
-      ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
-      ComputeType after_dgate = grad_val * Activation(gelu_in, p);
+      ComputeType act_in, dact_in;
+      if constexpr ((Activation == &silu<fp32, fp32>) && (Dactivation == &dsilu<fp32, fp32>)) {
+        const float s = sigmoidf(gelu_in);
+        dact_in = gelu_in * s * (1 - s) + s;
+        act_in = gelu_in * s;
+      } else {
+        dact_in = Dactivation(gelu_in, p);
+        act_in = Activation(gelu_in, p);
+      }
+
+      ComputeType after_dgelu = dact_in * grad_val * gate_in;
+      ComputeType after_dgate = grad_val * act_in;
 
       if (requires_amax) {
         __builtin_assume(max >= 0);
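
The if constexpr branch above computes the sigmoid once and reuses it for both the activation and its derivative, relying on the closed form dsilu(x) = x*s*(1 - s) + s with s = sigmoid(x). A quick numerical check of that identity (a standalone PyTorch sketch, independent of the CUDA kernel):

import torch

x = torch.randn(1024, dtype=torch.float64, requires_grad=True)
s = torch.sigmoid(x)
closed_form = x * s * (1 - s) + s  # matches dact_in in the kernel branch

# Compare against autograd's derivative of silu(x) = x * sigmoid(x).
(autograd_grad,) = torch.autograd.grad(torch.nn.functional.silu(x).sum(), x)
torch.testing.assert_close(closed_form, autograd_grad)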

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 11 additions & 2 deletions
@@ -282,10 +282,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 weights[i] = w
 
             # Preprocess grad output
-            grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
+            grad_output_view = grad_output
             grad_output = [None] * ctx.num_gemms
             grad_biases = [None] * ctx.num_gemms
-            if ctx.fp8:
+            if isinstance(grad_output_view, QuantizedTensorBase):
+                assert not ctx.use_bias, "Bias is not supported for quantized grad output"
+                grad_output = tex.split_quantized_tensor(grad_output_view, ctx.m_splits)
+            elif ctx.fp8:
+                grad_output_view = grad_output_view.contiguous().view(
+                    -1, grad_output_view.shape[-1]
+                )
                 if ctx.use_bias:
                     grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
                     recipe = ctx.fp8_recipe
@@ -313,6 +319,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                         ctx.grad_output_quantizers,
                     )
             else:
+                grad_output_view = grad_output_view.contiguous().view(
+                    -1, grad_output_view.shape[-1]
+                )
                 # Only split grad output. Grad bias is fused with
                 # wgrad GEMM.
                 grad_output = torch.split(
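
Reduced to its essentials, the new preprocessing dispatch looks like the sketch below. It is a simplified stand-in rather than the module's actual code; is_quantized and split_quantized are hypothetical callables standing in for the isinstance check on QuantizedTensorBase and tex.split_quantized_tensor:

import torch

def _split_grad_output(grad_output, m_splits, is_quantized, split_quantized):
    """Simplified stand-in for the grad_output preprocessing above."""
    if is_quantized(grad_output):
        # Quantized (e.g. MXFP8) grad_output: split directly without
        # .contiguous()/.view(); bias gradients are not supported on this path.
        return split_quantized(grad_output, m_splits)
    # Plain tensors (FP8 recipes and high precision) are flattened first and
    # then split per GEMM; the FP8 path additionally quantizes each chunk.
    flat = grad_output.contiguous().view(-1, grad_output.shape[-1])
    return list(torch.split(flat, m_splits))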
