
Commit 5529307

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2688f36 commit 5529307

2 files changed, +23 -23 lines changed


tests/pytorch/test_sanity.py

Lines changed: 17 additions & 16 deletions
@@ -899,6 +899,7 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
     )
     torch.cuda.synchronize()
 
+
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.parametrize("N", [32])
 @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
@@ -912,23 +913,23 @@ def test_sanity_gemm_with_fp8quantization_and_unalignment(N, datatype):
     outp_type = datatype
     quantizer = Float8Quantizer(scales, amaxes, tex.DType.kFloat8E4M3)
     quantized_out, *_ = general_gemm(
-            weight,
-            inp,
-            get_workspace(),
-            outp_type,
-            quantization_params=quantizer,
-            bias=None,
-            use_split_accumulator=False,
-            )
+        weight,
+        inp,
+        get_workspace(),
+        outp_type,
+        quantization_params=quantizer,
+        bias=None,
+        use_split_accumulator=False,
+    )
     out, *_ = general_gemm(
-            weight,
-            inp,
-            get_workspace(),
-            outp_type,
-            quantization_params=None,
-            bias=None,
-            use_split_accumulator=False,
-            )
+        weight,
+        inp,
+        get_workspace(),
+        outp_type,
+        quantization_params=None,
+        bias=None,
+        use_split_accumulator=False,
+    )
     expected_quantized_out = quantizer(out)
     torch.testing.assert_close(expected_quantized_out, quantized_out)
 

transformer_engine/pytorch/csrc/extensions/gemm.cpp

Lines changed: 6 additions & 7 deletions
@@ -144,21 +144,21 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   TensorWrapper unquantized_D_tensor;
   py::object unquantized_out;
   std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  bool quantization_needed = !quantizer.is_none(); // TODO: Another use-case:
+  bool quantization_needed = !quantizer.is_none();  // TODO: Another use-case:
   // If already output is FP8, then no need for quantization, since cublas is gonna take care of it.
-  if(quantization_needed)
-  {
+  if (quantization_needed) {
     NoneQuantizer q{none};
     std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, gemm_output_dtype);
   }
 
-  TensorWrapper &out_tensor = quantization_needed ? unquantized_D_tensor : D_tensor;
+  TensorWrapper& out_tensor = quantization_needed ? unquantized_D_tensor : D_tensor;
   // Bias tensor
   TensorWrapper bias_tensor;
   MaybeTensor bias_grad = std::nullopt;
   if (bias.has_value()) {
     if (grad) {
-      auto opts = torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA);
+      auto opts =
+          torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA);
       bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
       bias_tensor = makeTransformerEngineTensor(*bias_grad);
     } else {
@@ -281,8 +281,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
       }
     }
   }
-  if(quantization_needed)
-    my_quantizer->quantize(unquantized_D_tensor, D_tensor);
+  if (quantization_needed) my_quantizer->quantize(unquantized_D_tensor, D_tensor);
   // Pack outputs
   std::vector<py::object> out;
   out.emplace_back(std::move(D));
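
Taken together, the two hunks touch the quantize-after-GEMM fallback path: when a quantizer is supplied, gemm() runs the GEMM into an unquantized tensor and applies the quantizer at the end, which is the equivalence the updated Python test asserts. The sketch below restates that check in minimal form; the import paths, tensor shapes, and scale/amax initialization are illustrative assumptions, and the deliberately unaligned buffer setup from the original test is omitted.

import torch
import transformer_engine_torch as tex
# Assumed import locations; the actual test pulls these helpers from
# Transformer Engine's PyTorch package and the exact modules may differ.
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

N = 32                     # the test parametrizes N over [32]
datatype = torch.bfloat16  # and datatype over fp16/bf16
outp_type = datatype

# Plain GEMM operands (shapes are assumptions for this sketch).
inp = torch.randn(N, N, device="cuda", dtype=datatype)
weight = torch.randn(N, N, device="cuda", dtype=datatype)

# FP8 scaling state for the quantizer (single-element tensors, assumed init).
scales = torch.ones(1, device="cuda", dtype=torch.float32)
amaxes = torch.zeros(1, device="cuda", dtype=torch.float32)
quantizer = Float8Quantizer(scales, amaxes, tex.DType.kFloat8E4M3)

# Path 1: general_gemm quantizes its own output via quantization_params.
quantized_out, *_ = general_gemm(
    weight, inp, get_workspace(), outp_type,
    quantization_params=quantizer, bias=None, use_split_accumulator=False,
)

# Path 2: run the GEMM unquantized, then quantize afterwards
# (the fallback that the gemm.cpp hunks reformat).
out, *_ = general_gemm(
    weight, inp, get_workspace(), outp_type,
    quantization_params=None, bias=None, use_split_accumulator=False,
)

# Both paths should produce the same quantized result.
torch.testing.assert_close(quantizer(out), quantized_out)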
