
Commit 2688f36

Test working as I think it should work
1 parent 1398fa5 commit 2688f36

File tree

2 files changed (+66 -29 lines)


tests/pytorch/test_sanity.py

Lines changed: 34 additions & 12 deletions
@@ -104,18 +104,7 @@ def is_fp8_supported(config: ModelConfig):
 all_boolean = [True, False]
 batch_sizes_with_zero = [0, 1, 2]
 
-all_activations = [
-    "gelu",
-    "geglu",
-    "qgelu",
-    "qgeglu",
-    "relu",
-    "reglu",
-    "srelu",
-    "sreglu",
-    "silu",
-    "swiglu",
-]
+all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"]
 all_normalizations = ["LayerNorm", "RMSNorm"]
 
 
@@ -910,6 +899,39 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
     )
     torch.cuda.synchronize()
 
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+@pytest.mark.parametrize("N", [32])
+@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
+def test_sanity_gemm_with_fp8quantization_and_unalignment(N, datatype):
+    offset = 16
+    scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype)
+    inp = torch.reshape(scratchpad[0][:-offset], (N, N))
+    weight = torch.reshape(scratchpad[0][offset:], (N, N))
+    scales = torch.ones(1).cuda().squeeze()
+    amaxes = torch.ones(1).cuda().squeeze()
+    outp_type = datatype
+    quantizer = Float8Quantizer(scales, amaxes, tex.DType.kFloat8E4M3)
+    quantized_out, *_ = general_gemm(
+        weight,
+        inp,
+        get_workspace(),
+        outp_type,
+        quantization_params=quantizer,
+        bias=None,
+        use_split_accumulator=False,
+    )
+    out, *_ = general_gemm(
+        weight,
+        inp,
+        get_workspace(),
+        outp_type,
+        quantization_params=None,
+        bias=None,
+        use_split_accumulator=False,
+    )
+    expected_quantized_out = quantizer(out)
+    torch.testing.assert_close(expected_quantized_out, quantized_out)
+
 
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 def test_replace_raw_data_for_float8tensor():
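
For reference, the property the new test checks can be condensed as below. This is a minimal sketch, not part of the commit: the helper name check_fp8_quantized_gemm is made up here, and the import paths are assumptions about Transformer Engine's layout (inside test_sanity.py the names general_gemm, Float8Quantizer, get_workspace, and tex are already in scope). A GEMM run with a Float8Quantizer passed as quantization_params should produce the same FP8 tensor as running the GEMM in high precision and quantizing afterwards.

# Minimal sketch of the property exercised by test_sanity_gemm_with_fp8quantization_and_unalignment.
# Import paths are assumptions; adjust to the local Transformer Engine installation.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

def check_fp8_quantized_gemm(weight, inp, dtype):
    # Quantizer with unit scale/amax targeting FP8 E4M3, as in the new test.
    quantizer = Float8Quantizer(
        torch.ones(1, device="cuda").squeeze(),  # scale
        torch.ones(1, device="cuda").squeeze(),  # amax
        tex.DType.kFloat8E4M3,
    )
    # GEMM whose output is quantized inside the extension.
    quantized_out, *_ = general_gemm(
        weight, inp, get_workspace(), dtype,
        quantization_params=quantizer, bias=None, use_split_accumulator=False,
    )
    # The same GEMM in high precision, quantized afterwards on the Python side.
    out, *_ = general_gemm(
        weight, inp, get_workspace(), dtype,
        quantization_params=None, bias=None, use_split_accumulator=False,
    )
    torch.testing.assert_close(quantizer(out), quantized_out)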

transformer_engine/pytorch/csrc/extensions/gemm.cpp

Lines changed: 32 additions & 17 deletions
@@ -123,28 +123,42 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
                 "into D tensor. Beta has nothing to be applied to.");
   }
 
+  DType gemm_output_dtype = out_dtype ? *out_dtype : A_tensor.dtype();
   // Output tensor
   TensorWrapper D_tensor;
   if (D.is_none()) {
-    DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype();
-    std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer);
+    std::tie(D_tensor, D) = createOutputTensor(D_shape, gemm_output_dtype, quantizer);
   } else {
     D_tensor = makeTransformerEngineTensor(D, quantizer);
     NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()),
                "GEMM output has invalid dims (expected ", std::to_string(D_shape), ", got ",
                std::to_string(D_tensor.shape()), ")");
-    if (out_dtype) {
-      NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
-                 static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()), ")");
-    }
+    // TODO: Check the quantizer's output data type instead and compare it with out_dtype.
+    // if (out_dtype) {
+    //   NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
+    //              static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()), ")");
+    // }
   }
 
+  // Keep an unquantized tensor for the case where an op consuming high-precision data must still produce an FP8-quantized output.
+  TensorWrapper unquantized_D_tensor;
+  py::object unquantized_out;
+  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
+  bool quantization_needed = !quantizer.is_none();  // TODO: another use case:
+  // if the output is already FP8, no separate quantization is needed, since cuBLAS takes care of it.
+  if (quantization_needed)
+  {
+    NoneQuantizer q{none};
+    std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, gemm_output_dtype);
+  }
+
+  TensorWrapper &out_tensor = quantization_needed ? unquantized_D_tensor : D_tensor;
   // Bias tensor
   TensorWrapper bias_tensor;
   MaybeTensor bias_grad = std::nullopt;
   if (bias.has_value()) {
     if (grad) {
-      auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
+      auto opts = torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA);
       bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
       bias_tensor = makeTransformerEngineTensor(*bias_grad);
     } else {
@@ -157,7 +171,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
 
   // Activation input tensor
   MaybeTensor pre_gelu_out = std::nullopt;
-  DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
+  DType gelu_type = low_precision ? bias_type : out_tensor.dtype();
   if (gelu) {
     if (!grad) {
       auto dtype = GetATenDType(gelu_type);
@@ -210,22 +224,22 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
       // Direct GEMM call to the correct overlap
       if (bulk_overlap) {
         NVTE_SCOPED_GIL_RELEASE({
-          comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
+          comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor,
                                      te_pre_gelu_out, te_workspace, grad, accumulate,
                                      use_split_accumulator, comm_type.value(), extra_output_tensor,
                                      main_stream);
         });
       } else if (comm_type.value() == CommOverlapType::AG) {
         if (comm_overlap->is_atomic_gemm()) {
           NVTE_SCOPED_GIL_RELEASE({
-            comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
+            comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor,
                                                  bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                                  accumulate, use_split_accumulator,
                                                  extra_output_tensor, main_stream);
           });
         } else {
           NVTE_SCOPED_GIL_RELEASE({
-            comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
+            comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor,
                                            bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                            accumulate, use_split_accumulator, extra_output_tensor,
                                            main_stream);
@@ -234,14 +248,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
       } else {
         if (comm_overlap->is_atomic_gemm()) {
           NVTE_SCOPED_GIL_RELEASE({
-            comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
+            comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor,
                                                  bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                                  accumulate, use_split_accumulator,
                                                  extra_output_tensor, main_stream);
           });
         } else {
           NVTE_SCOPED_GIL_RELEASE({
-            comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
+            comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor,
                                            bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                            accumulate, use_split_accumulator, extra_output_tensor,
                                            main_stream);
@@ -251,23 +265,24 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
     } else {
       // Launch GEMM
       NVTE_SCOPED_GIL_RELEASE({
-        nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(),
+        nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(),
                                 bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
                                 te_workspace.data(), alpha, *beta, use_split_accumulator,
                                 num_math_sms, main_stream);
       });
     }
   } else {
-    if (D_tensor.numel() != 0 && !accumulate) {
-      D_tensor.zero_(main_stream);
+    if (out_tensor.numel() != 0 && !accumulate) {
+      out_tensor.zero_(main_stream);
     }
     if (bias.has_value()) {
       if (bias->numel() != 0 && grad) {
         bias_grad->zero_();
       }
     }
   }
-
+  if (quantization_needed)
+    my_quantizer->quantize(unquantized_D_tensor, D_tensor);
   // Pack outputs
   std::vector<py::object> out;
   out.emplace_back(std::move(D));
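
Taken together, the gemm.cpp changes route the result through a high-precision staging tensor whenever a quantizer is supplied: the cuBLAS call (or the comm-overlap variants) writes into out_tensor, which aliases unquantized_D_tensor in that case, and my_quantizer->quantize(unquantized_D_tensor, D_tensor) then produces the FP8 output returned as D. At the Python level that final step corresponds to applying the quantizer to the high-precision tensor, roughly as in this sketch (positional Float8Quantizer arguments mirror the test above; the import path is an assumption):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

# Stand-in for the high-precision GEMM result held in unquantized_D_tensor.
hp_out = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16)

# Unit scale/amax and an FP8 E4M3 target, as in the new sanity test.
quantizer = Float8Quantizer(
    torch.ones(1, device="cuda").squeeze(),
    torch.ones(1, device="cuda").squeeze(),
    tex.DType.kFloat8E4M3,
)

fp8_out = quantizer(hp_out)  # Python analogue of my_quantizer->quantize(unquantized_D_tensor, D_tensor)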
