@@ -123,28 +123,42 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
123123 " into D tensor. Beta has nothing to be applied to." );
124124 }
125125
126+ DType gemm_output_dtype = out_dtype ? *out_dtype : A_tensor.dtype ();
126127 // Output tensor
127128 TensorWrapper D_tensor;
128129 if (D.is_none ()) {
129- DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype ();
130- std::tie (D_tensor, D) = createOutputTensor (D_shape, output_dtype, quantizer);
130+ std::tie (D_tensor, D) = createOutputTensor (D_shape, gemm_output_dtype, quantizer);
131131 } else {
132132 D_tensor = makeTransformerEngineTensor (D, quantizer);
133133 NVTE_CHECK (detail::checkGemmShape (D_shape, D_tensor.shape ()),
134134 " GEMM output has invalid dims (expected " , std::to_string (D_shape), " , got " ,
135135 std::to_string (D_tensor.shape ()), " )" );
136- if (out_dtype) {
137- NVTE_CHECK (*out_dtype == D_tensor.dtype (), "GEMM output has invalid dtype (expected ",
138- static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype ()), ")");
139- }
136+ // TODO: Check the quantizer's output data type against out_dtype instead.
137+ // if (out_dtype) {
138+ // NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
139+ // static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()), ")");
140+ // }
140141 }
141142
143+ // Maintain an unquantized tensor for the case where an op consuming high-precision data must produce an FP8-quantized output.
144+ TensorWrapper unquantized_D_tensor;
145+ py::object unquantized_out;
146+ std::unique_ptr<Quantizer> my_quantizer = convert_quantizer (quantizer);
147+ bool quantization_needed = !quantizer.is_none (); // TODO: Another use case:
148+ // if the output is already FP8, no separate quantization is needed, since cuBLAS takes care of it.
149+ if (quantization_needed) {
150+ NoneQuantizer q{none};
151+ std::tie (unquantized_D_tensor, unquantized_out) =
152+     q.create_tensor (D_shape, gemm_output_dtype);
153+ }
154+
155+ TensorWrapper &out_tensor = quantization_needed ? unquantized_D_tensor : D_tensor;
142156 // Bias tensor
143157 TensorWrapper bias_tensor;
144158 MaybeTensor bias_grad = std::nullopt ;
145159 if (bias.has_value ()) {
146160 if (grad) {
147- auto opts = torch::TensorOptions ().dtype (GetATenDType (D_tensor.dtype ())).device (torch::kCUDA);
161+ auto opts = torch::TensorOptions ().dtype (GetATenDType (out_tensor.dtype ())).device (torch::kCUDA);
148162 bias_grad = at::empty ({static_cast <int64_t >(B_shape.data [B_shape.ndim - 1 ])}, opts);
149163 bias_tensor = makeTransformerEngineTensor (*bias_grad);
150164 } else {
@@ -157,7 +171,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
157171
158172 // Activation input tensor
159173 MaybeTensor pre_gelu_out = std::nullopt ;
160- DType gelu_type = low_precision ? bias_type : D_tensor.dtype ();
174+ DType gelu_type = low_precision ? bias_type : out_tensor.dtype ();
161175 if (gelu) {
162176 if (!grad) {
163177 auto dtype = GetATenDType (gelu_type);
@@ -210,22 +224,22 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
210224 // Direct GEMM call to the correct overlap
211225 if (bulk_overlap) {
212226 NVTE_SCOPED_GIL_RELEASE ({
213- comm_overlap->bulk_overlap (A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
227+ comm_overlap->bulk_overlap (A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor,
214228 te_pre_gelu_out, te_workspace, grad, accumulate,
215229 use_split_accumulator, comm_type.value (), extra_output_tensor,
216230 main_stream);
217231 });
218232 } else if (comm_type.value () == CommOverlapType::AG) {
219233 if (comm_overlap->is_atomic_gemm ()) {
220234 NVTE_SCOPED_GIL_RELEASE ({
221- comm_overlap->atomic_gemm_overlap_ag (A_tensor, transa, B_tensor, transb, D_tensor,
235+ comm_overlap->atomic_gemm_overlap_ag (A_tensor, transa, B_tensor, transb, out_tensor,
222236 bias_tensor, te_pre_gelu_out, te_workspace, grad,
223237 accumulate, use_split_accumulator,
224238 extra_output_tensor, main_stream);
225239 });
226240 } else {
227241 NVTE_SCOPED_GIL_RELEASE ({
228- comm_overlap->split_overlap_ag (A_tensor, transa, B_tensor, transb, D_tensor,
242+ comm_overlap->split_overlap_ag (A_tensor, transa, B_tensor, transb, out_tensor,
229243 bias_tensor, te_pre_gelu_out, te_workspace, grad,
230244 accumulate, use_split_accumulator, extra_output_tensor,
231245 main_stream);
@@ -234,14 +248,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
234248 } else {
235249 if (comm_overlap->is_atomic_gemm ()) {
236250 NVTE_SCOPED_GIL_RELEASE ({
237- comm_overlap->atomic_gemm_overlap_rs (A_tensor, transa, B_tensor, transb, D_tensor,
251+ comm_overlap->atomic_gemm_overlap_rs (A_tensor, transa, B_tensor, transb, out_tensor,
238252 bias_tensor, te_pre_gelu_out, te_workspace, grad,
239253 accumulate, use_split_accumulator,
240254 extra_output_tensor, main_stream);
241255 });
242256 } else {
243257 NVTE_SCOPED_GIL_RELEASE ({
244- comm_overlap->split_overlap_rs (A_tensor, transa, B_tensor, transb, D_tensor,
258+ comm_overlap->split_overlap_rs (A_tensor, transa, B_tensor, transb, out_tensor,
245259 bias_tensor, te_pre_gelu_out, te_workspace, grad,
246260 accumulate, use_split_accumulator, extra_output_tensor,
247261 main_stream);
@@ -251,23 +265,24 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
251265 } else {
252266 // Launch GEMM
253267 NVTE_SCOPED_GIL_RELEASE ({
254- nvte_cublas_gemm_scaled (A_tensor.data (), B_tensor.data (), D_tensor.data (),
268+ nvte_cublas_gemm_scaled (A_tensor.data (), B_tensor.data (), out_tensor.data (),
255269 bias_tensor.data (), te_pre_gelu_out.data (), transa, transb, grad,
256270 te_workspace.data (), alpha, *beta, use_split_accumulator,
257271 num_math_sms, main_stream);
258272 });
259273 }
260274 } else {
261- if (D_tensor.numel () != 0 && !accumulate) {
262- D_tensor.zero_ (main_stream);
275+ if (out_tensor.numel () != 0 && !accumulate) {
276+ out_tensor.zero_ (main_stream);
263277 }
264278 if (bias.has_value ()) {
265279 if (bias->numel () != 0 && grad) {
266280 bias_grad->zero_ ();
267281 }
268282 }
269283 }
270-
284+ if (quantization_needed)
285+ my_quantizer->quantize (unquantized_D_tensor, D_tensor);
271286 // Pack outputs
272287 std::vector<py::object> out;
273288 out.emplace_back (std::move (D));
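
For context, here is a standalone sketch of the flow this change introduces. It is not the Transformer Engine API: SimpleQuantizer and gemm_then_quantize are hypothetical stand-ins for the quantizer and the cuBLAS GEMM call above. The point is the ordering: the GEMM accumulates into a high-precision buffer (the role of unquantized_D_tensor), and only afterwards is that buffer quantized into the low-precision output, as in my_quantizer->quantize (unquantized_D_tensor, D_tensor).

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical per-tensor-scale quantizer: float -> int8.
struct SimpleQuantizer {
  float scale;
  void quantize(const std::vector<float> &src, std::vector<int8_t> &dst) const {
    dst.resize(src.size());
    for (std::size_t i = 0; i < src.size(); ++i) {
      const float q = std::round(src[i] / scale);
      dst[i] = static_cast<int8_t>(std::clamp(q, -128.0f, 127.0f));
    }
  }
};

// The GEMM accumulates in float (the "unquantized_D_tensor" role); quantization
// into the final low-precision output happens only after the GEMM completes.
void gemm_then_quantize(const std::vector<float> &A, const std::vector<float> &B,
                        int M, int N, int K, const SimpleQuantizer &q,
                        std::vector<int8_t> &D) {
  std::vector<float> unquantized_D(static_cast<std::size_t>(M) * N, 0.0f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        unquantized_D[m * N + n] += A[m * K + k] * B[k * N + n];
  q.quantize(unquantized_D, D);  // separate quantization step, as in the diff
}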