@@ -144,21 +144,21 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   TensorWrapper unquantized_D_tensor;
   py::object unquantized_out;
   std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  bool quantization_needed = !quantizer.is_none(); // TODO: Another use-case:
+  bool quantization_needed = !quantizer.is_none();  // TODO: Another use-case:
   // If already output is FP8, then no need for quantization, since cublas is gonna take care of it.
-  if (quantization_needed)
-  {
+  if (quantization_needed) {
     NoneQuantizer q{none};
     std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, gemm_output_dtype);
   }
 
-  TensorWrapper & out_tensor = quantization_needed ? unquantized_D_tensor : D_tensor;
+  TensorWrapper& out_tensor = quantization_needed ? unquantized_D_tensor : D_tensor;
   // Bias tensor
   TensorWrapper bias_tensor;
   MaybeTensor bias_grad = std::nullopt;
   if (bias.has_value()) {
     if (grad) {
-      auto opts = torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA);
+      auto opts =
+          torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA);
       bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
       bias_tensor = makeTransformerEngineTensor(*bias_grad);
     } else {
@@ -281,8 +281,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
       }
     }
   }
-  if (quantization_needed)
-    my_quantizer->quantize(unquantized_D_tensor, D_tensor);
+  if (quantization_needed) my_quantizer->quantize(unquantized_D_tensor, D_tensor);
   // Pack outputs
   std::vector<py::object> out;
   out.emplace_back(std::move(D));
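
For readers skimming the diff: the change routes the GEMM result through an unquantized buffer whenever a quantizer is supplied, and only quantizes into the final output tensor after the GEMM has run. Below is a minimal, self-contained C++ sketch of that control flow. The Tensor, Quantizer, and run_gemm names are hypothetical stand-ins for illustration only, not the real TensorWrapper, NoneQuantizer, or cuBLAS-backed calls used above.

// Minimal sketch of the quantize-after-GEMM pattern (hypothetical types;
// not the TransformerEngine API).
#include <iostream>
#include <vector>

struct Tensor {                 // stand-in for TensorWrapper
  std::vector<float> data;
};

struct Quantizer {              // stand-in for the Quantizer hierarchy
  void quantize(const Tensor& src, Tensor& dst) const {
    dst.data = src.data;        // the real code would convert to FP8 here
  }
};

// Stand-in for the cuBLAS-backed GEMM call; only the data flow matters here.
Tensor run_gemm(const Tensor& A, const Tensor& B) {
  Tensor out;
  out.data.resize(A.data.size() * B.data.size());
  return out;
}

Tensor gemm_with_optional_quantization(const Tensor& A, const Tensor& B,
                                       const Quantizer* quantizer) {
  Tensor D;                                        // final (possibly quantized) output
  const bool quantization_needed = (quantizer != nullptr);

  // Mirror of the diff: when quantization is requested, the GEMM writes into
  // a separate unquantized buffer first.
  Tensor unquantized_D;
  Tensor& out_tensor = quantization_needed ? unquantized_D : D;

  out_tensor = run_gemm(A, B);

  // After the GEMM, quantize into the real output tensor.
  if (quantization_needed) quantizer->quantize(unquantized_D, D);
  return D;
}

int main() {
  Tensor A{{1.0f, 2.0f}}, B{{3.0f, 4.0f}};
  Quantizer q;
  Tensor D = gemm_with_optional_quantization(A, B, &q);
  std::cout << "output elements: " << D.data.size() << '\n';
  return 0;
}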