diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index ca9103532d..6dad57cdfa 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -28,6 +28,9 @@
 namespace transformer_engine {
 
+std::string to_string(DType type);
+std::string to_string(const NVTEScalingMode &mode);
+
 inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
   NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
              end, " in a vector with ", shape.size(), " entries");
@@ -95,17 +98,8 @@ struct Tensor {
         scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
 
   int numel() const {
-    NVTE_CHECK(data.dptr != nullptr || columnwise_data.dptr != nullptr,
-               "Tensor does not hold any data!");
     size_t acc = 1;
-    if (data.dptr != nullptr) {
-      for (const auto &dim : data.shape) {
-        acc *= dim;
-      }
-      return acc;
-    }
-    // data is empty, use columnwise_data
-    for (const auto &dim : columnwise_data.shape) {
+    for (const auto dim : shape()) {
       acc *= dim;
     }
     return acc;
@@ -122,24 +116,54 @@ struct Tensor {
     return data.dtype;
   }
 
+  std::vector<size_t> shape() const {
+    /* Note: We sometimes experience spurious compiler errors
+     * (-Wstringop-overflow) from this function. It appears that GCC
+     * has some bugs with std::vector (see
+     * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
+     */
+    switch (scaling_mode) {
+      case NVTE_DELAYED_TENSOR_SCALING:
+        if (!has_data() && has_columnwise_data()) {
+          std::vector<size_t> ret;
+          if (!columnwise_data.shape.empty()) {
+            for (size_t i = 1; i < columnwise_data.shape.size(); i++) {
+              ret.push_back(columnwise_data.shape[i]);
+            }
+            ret.push_back(columnwise_data.shape.front());
+          }
+          return ret;
+        } else {
+          return data.shape;
+        }
+        break;
+      case NVTE_MXFP8_1D_SCALING:
+        if (!has_data() && has_columnwise_data()) {
+          return columnwise_data.shape;
+        } else {
+          return data.shape;
+        }
+        break;
+      default:
+        NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode),
+                   "\"");
+        return {};
+    }
+  }
+
   /*! Matrix height after tensor is flattened to 2D
    *
    * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
    * as a (D1*D2*...*D(n-1), Dn) matrix.
    */
   size_t flat_first_dim() const {
-    if (!has_data() && has_columnwise_data()) {
-      const auto &data_shape = columnwise_data.shape;
-      if (data_shape.empty()) return 1;
-      if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
-        return product(data_shape, 1, data_shape.size());
-      } else {
-        return product(data_shape, 0, data_shape.size() - 1);
+    const auto &full_shape = shape();
+    size_t ret = 1;
+    if (!full_shape.empty()) {
+      for (size_t i = 0; i < full_shape.size() - 1; i++) {
+        ret *= full_shape[i];
       }
     }
-    const auto &data_shape = data.shape;
-    if (data_shape.empty()) return 1;
-    return product(data_shape, 0, data_shape.size() - 1);
+    return ret;
   }
 
   /*! Matrix width after tensor is flattened to 2D
    *
    * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
@@ -148,18 +172,12 @@ struct Tensor {
    * as a (D1*D2*...*D(n-1), Dn) matrix.
   */
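
For delayed tensor scaling, the new `Tensor::shape()` recovers the logical shape of a transpose-only FP8 tensor by rotating the leading dimension of the transpose's shape to the back. A minimal sketch of that permutation in plain Python (an illustration of the C++ logic above, not TE code):

```python
# A transpose-only FP8 tensor stores the shape (Dn, D1, ..., D(n-1));
# the logical shape (D1, ..., D(n-1), Dn) moves the first entry to the end.
def logical_shape_from_transpose(transpose_shape):
    if not transpose_shape:
        return []
    return list(transpose_shape[1:]) + [transpose_shape[0]]

assert logical_shape_from_transpose([64, 2, 3, 32]) == [2, 3, 32, 64]
```

MXFP8 data is stored untransposed for both usages, which is why the `NVTE_MXFP8_1D_SCALING` case returns whichever shape is available as-is.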
   size_t flat_last_dim() const {
-    if (!has_data() && has_columnwise_data()) {
-      const auto &data_shape = columnwise_data.shape;
-      if (data_shape.empty()) return 1;
-      if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
-        return data_shape.front();
-      } else {
-        return data_shape.back();
-      }
+    const auto &full_shape = shape();
+    if (full_shape.empty()) {
+      return 1;
+    } else {
+      return full_shape.back();
     }
-    const auto &data_shape = data.shape;
-    if (data_shape.empty()) return 1;
-    return data_shape.back();
   }
 };
 
diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h
index e393dbffc4..3523d960e6 100644
--- a/transformer_engine/common/include/transformer_engine/transformer_engine.h
+++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -501,7 +501,7 @@ class TensorWrapper {
    *  \return Number of elements in the tensor.
    */
   size_t numel() const noexcept {
-    if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
+    if (tensor_ == nullptr) return 0;
     return nvte_tensor_numel(tensor_);
   }
 
diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index faf6ec990d..cc935f2d75 100644
--- a/transformer_engine/common/transformer_engine.cpp
+++ b/transformer_engine/common/transformer_engine.cpp
@@ -212,37 +212,58 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) {
 }
 
 NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
-  if (tensor == nullptr) return {nullptr, 0};
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
-  NVTEShape ret;
-
-  // FP8 tensor keeps shape in rowwise data
-  if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
-    ret.data = t.data.shape.data();
-    ret.ndim = t.data.shape.size();
-    return ret;
+  if (tensor == nullptr) {
+    NVTE_ERROR("Invalid tensor");
   }
+  NVTEShape ret;
 
-  // Get shape based on what data is available
-  if (t.has_data()) {
-    ret.data = t.data.shape.data();
-    ret.ndim = t.data.shape.size();
-    return ret;
-  }
-  if (t.has_columnwise_data()) {
-    ret.data = t.columnwise_data.shape.data();
-    ret.ndim = t.columnwise_data.shape.size();
-    return ret;
+  // Determine tensor shape depending on tensor format
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
+  switch (t.scaling_mode) {
+    case NVTE_DELAYED_TENSOR_SCALING: {
+      if (!t.has_data() && t.has_columnwise_data()) {
+        // We can infer tensor shape if FP8 tensor only has FP8 data
+        // transpose. However, NVTEShape only contains a pointer and
+        // cannot store temporary data. We hack around this by caching
+        // the tensor shape within the empty FP8 data.
+        auto &shape_cache = const_cast<std::vector<size_t> &>(t.data.shape);
+        shape_cache.clear();
+        if (!t.columnwise_data.shape.empty()) {
+          for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) {
+            shape_cache.push_back(t.columnwise_data.shape[i]);
+          }
+          shape_cache.push_back(t.columnwise_data.shape.front());
+        }
+        ret.data = shape_cache.data();
+        ret.ndim = shape_cache.size();
+      } else {
+        ret.data = t.data.shape.data();
+        ret.ndim = t.data.shape.size();
+      }
+      break;
+    }
+    case NVTE_MXFP8_1D_SCALING: {
+      if (!t.has_data() && t.has_columnwise_data()) {
+        ret.data = t.columnwise_data.shape.data();
+        ret.ndim = t.columnwise_data.shape.size();
+      } else {
+        ret.data = t.data.shape.data();
+        ret.ndim = t.data.shape.size();
+      }
+      break;
+    }
+    default:
+      NVTE_ERROR("Cannot parse tensor shape with scaling mode \"",
+                 transformer_engine::to_string(t.scaling_mode), "\"");
   }
-
-  // Tensor has no data
-  ret.data = t.data.shape.data();
-  ret.ndim = t.data.shape.size();
   return ret;
 }
 
 NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
-  if (tensor == nullptr) return {nullptr, 0};
+  if (tensor == nullptr) {
+    NVTE_ERROR("Invalid tensor");
+  }
   const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
   NVTEShape ret;
   ret.data = t.columnwise_data.shape.data();
@@ -250,25 +271,20 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
   ret.ndim = t.columnwise_data.shape.size();
   return ret;
 }
 
-size_t nvte_tensor_ndim(const NVTETensor tensor) {
-  if (tensor == nullptr) return 0;
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
-  return t.data.shape.size();
-}
+size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }
 
 size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
-  if (tensor == nullptr) return 0;
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
-  NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim);
-  return t.data.shape[dim];
+  const auto &shape = nvte_tensor_shape(tensor);
+  NVTE_CHECK(0 <= dim && dim < shape.ndim, "Attempted to access index ", dim,
+             " in a shape array with ", shape.ndim, " entries");
+  return shape.data[dim];
 }
 
 size_t nvte_tensor_numel(const NVTETensor tensor) {
-  if (tensor == nullptr) return 0;
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
+  const auto &shape = nvte_tensor_shape(tensor);
   size_t numel = 1;
-  for (auto size : t.data.shape) {
-    numel *= size;
+  for (size_t i = 0; i < shape.ndim; i++) {
+    numel *= shape.data[i];
   }
   return numel;
 }
diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h
index 10a4ec28dc..173aad52af 100644
--- a/transformer_engine/common/util/logging.h
+++ b/transformer_engine/common/util/logging.h
@@ -12,10 +12,18 @@
 
 #include <cuda_runtime_api.h>
 #include <stdexcept>
+#include <iostream>
 #include <string>
 
 #include "../util/string.h"
 
+#define NVTE_WARN(...)                                                \
+  do {                                                                \
+    std::cerr << ::transformer_engine::concat_strings(                \
+        __FILE__ ":", __LINE__, " in function ", __func__, ": ",      \
+        ::transformer_engine::concat_strings(__VA_ARGS__), "\n");     \
+  } while (false)
+
 #define NVTE_ERROR(...)                                               \
   do {                                                                \
     throw ::std::runtime_error(::transformer_engine::concat_strings(  \
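
`NVTE_WARN` mirrors the formatting of `NVTE_ERROR` but writes to `stderr` instead of throwing. A rough Python rendering of the message layout the macro produces (illustrative only):

```python
import sys

def nvte_warn(file, line, func, *args):
    # __FILE__ ":" __LINE__ " in function " __func__ ": " <message> "\n"
    print(f"{file}:{line} in function {func}: {''.join(map(str, args))}", file=sys.stderr)

nvte_warn("distributed.py", 929, "_all_gather_mxfp8", "usages do not match")
```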
diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp
index d2607e4ed0..ed29aab703 100644
--- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp
@@ -4,6 +4,10 @@
  * See LICENSE for license information.
  ************************************************************************/
 
+#include <ATen/ATen.h>
+#include <pybind11/pybind11.h>
+#include <transformer_engine/transformer_engine.h>
+
 #include "common.h"
 #include "pybind.h"
 
@@ -11,66 +15,61 @@ namespace transformer_engine::pytorch {
 namespace detail {
 
 TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer) {
-  const at::Tensor &data = tensor.attr("_data").cast<at::Tensor>();
-  const at::Tensor &scale_inv = tensor.attr("_scale_inv").cast<at::Tensor>();
-  float *scale_inv_dptr = reinterpret_cast<float *>(scale_inv.data_ptr());
-  const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
-
-  const auto &shape = getTensorShape(data);
+  auto ret = TensorWrapper(quantizer->get_scaling_mode());
 
-  bool transpose_valid = !tensor.attr("_transpose_invalid").cast<bool>();
-  std::optional<at::Tensor> transpose = std::nullopt;
-  if (transpose_valid) {
-    transpose = tensor.attr("_transpose").cast<std::optional<at::Tensor>>();
+  // FP8 data
+  const DType fp8_dtype = tensor.attr("_fp8_dtype").cast<DType>();
+  if (!tensor.attr("_data").is_none()) {
+    const auto &data = tensor.attr("_data").cast<at::Tensor>();
+    ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
   }
 
-  auto ret = TensorWrapper(quantizer->get_scaling_mode());
+  // FP8 data transpose
+  if (!tensor.attr("_transpose_invalid").cast<bool>() && !tensor.attr("_transpose").is_none()) {
+    const auto &data_transpose = tensor.attr("_transpose").cast<at::Tensor>();
+    ret.set_columnwise_data(data_transpose.data_ptr(), fp8_dtype, getTensorShape(data_transpose));
+  }
 
-  ret.set_rowwise_data(data.data_ptr(), dtype, shape);
-  if (transpose_valid && transpose != std::nullopt) {
-    const auto &transpose_shape = getTensorShape(*transpose);
-    ret.set_columnwise_data(transpose->data_ptr(), dtype, transpose_shape);
+  // Scale-inverse
+  {
+    const auto &scale_inv = tensor.attr("_scale_inv").cast<at::Tensor>();
+    float *dptr = reinterpret_cast<float *>(scale_inv.data_ptr());
+    const auto &dtype = GetTransformerEngineDType(scale_inv.scalar_type());
+    const auto &shape = getTensorShape(scale_inv);
+    ret.set_rowwise_scale_inv(dptr, dtype, shape);
+    ret.set_columnwise_scale_inv(dptr, dtype, shape);
   }
 
-  const auto scale_inv_dtype = GetTransformerEngineDType(scale_inv.scalar_type());
-  const auto scale_inv_shape = getTensorShape(scale_inv);
-  ret.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
-  ret.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
+  // Quantizer state
   quantizer->set_quantization_params(&ret);
+
   return ret;
 }
 
 TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) {
-  const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
   auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING);
 
-  bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
-  bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
-
-  if (rowwise_usage) {
-    const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast<at::Tensor>();
-    const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
-    void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr();
-    const auto &shape = getTensorShape(data_rowwise);
-    ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape);
-
-    const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise);
-    ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat8E8M0, scale_inv_rowwise_shape);
+  // Row-scaled data
+  const DType fp8_dtype = tensor.attr("_fp8_dtype").cast<DType>();
+  if (!tensor.attr("_rowwise_data").is_none()) {
+    const auto &data = tensor.attr("_rowwise_data").cast<at::Tensor>();
+    const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
+    ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
+    ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, getTensorShape(scale_inv));
   }
 
-  if (columnwise_usage) {
-    const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast<at::Tensor>();
-    const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
-    void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr();
-    const auto &shape = getTensorShape(data_colwise);
-    ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape);
-
-    const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
-    ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat8E8M0,
-                                 scale_inv_colwise_shape);
+  // Column-scaled data
+  if (!tensor.attr("_columnwise_data").is_none()) {
+    const auto &data = tensor.attr("_columnwise_data").cast<at::Tensor>();
+    const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
+    ret.set_columnwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
+    ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0,
+                                 getTensorShape(scale_inv));
   }
 
+  // Quantizer state
   quantizer->set_quantization_params(&ret);
+
   return ret;
 }
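
The rewritten converters attach each buffer only if it is present on the Python object, so a `Float8Tensor` holding just a transpose now round-trips instead of failing on `_data`. A standalone sketch of that branch logic (`SimpleNamespace` stands in for the real tensor class; illustration, not TE code):

```python
from types import SimpleNamespace

def available_usages(t):
    """Mirror the checks in NVTETensorFromFloat8Tensor above."""
    usages = []
    if t._data is not None:
        usages.append("rowwise")
    if not t._transpose_invalid and t._transpose is not None:
        usages.append("columnwise")
    return usages

# Transpose-only tensor: previously rejected, now handled
t = SimpleNamespace(_data=None, _transpose=object(), _transpose_invalid=False)
assert available_usages(t) == ["columnwise"]
```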
diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index fe023208d1..04fa8aa1ed 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -923,20 +929,29 @@ def _all_gather_mxfp8(
     if out_shape is None:
         out_shape = [in_shape[0] * world_size] + in_shape[1:]
 
-    # Gather MXFP8 data for row-wise usage
-    if quantizer.rowwise_usage and not quantizer.columnwise_usage:
+    # Cast input tensor to MXFP8 with required data
+    if not isinstance(input_, MXFP8TensorBase):
+        input_ = quantizer(input_)
+    elif (
+        (input_.rowwise_data is None and quantizer.rowwise_usage)
+        or (input_.columnwise_data is None and quantizer.columnwise_usage)
+    ):
+        warnings.warn(
+            "Input and quantizer do not have matching usages. "
+            "Dequantizing and requantizing to MXFP8."
+        )
+        input_ = quantizer(input_.dequantize())
 
-        # Cast input tensor to MXFP8 if needed
-        if not isinstance(input_, MXFP8TensorBase):
-            input_ = quantizer(input_)
+    # Construct MXFP8 output tensor
+    out = quantizer.make_empty(out_shape, dtype=input_.dtype, device=input_.device)
 
-        # Construct MXFP8 output tensor
-        dtype = torch.float32
-        device = "cuda"
-        if isinstance(input_, MXFP8Tensor):
-            dtype = input_.dtype
-            device = input_.device
-        out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
+    # Async op handle
+    handle = None
+
+    # Gather MXFP8 data for row-wise usage
+    if quantizer.rowwise_usage:
 
         # Remove padding from MXFP8 scale-inverses
         in_scale_inv = input_._rowwise_scale_inv
@@ -948,36 +957,48 @@ def _all_gather_mxfp8(
         out_scale_inv = out._rowwise_scale_inv
         flattened_in_shape0 = math.prod(in_shape[:-1])
         if in_scale_inv.size(0) != flattened_in_shape0:
            in_scale_inv = in_scale_inv[:flattened_in_shape0]
            out_scale_inv[flattened_in_shape0 * world_size :].zero_()
            out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
 
         # Launch all-gathers
-        with torch.distributed._coalescing_manager(
+        if handle is not None:
+            handle.wait()
+        torch.distributed.all_gather_into_tensor(
+            out_scale_inv,
+            in_scale_inv,
             group=process_group,
-            device=device,
-            async_ops=async_op,
-        ) as coalescing_manager:
-            torch.distributed.all_gather_into_tensor(
-                out._rowwise_data,
-                input_._rowwise_data,
-                group=process_group,
-            )
-            torch.distributed.all_gather_into_tensor(
-                out_scale_inv,
-                in_scale_inv,
-                group=process_group,
-            )
-        handle = coalescing_manager if async_op else None
-        return out, handle
+        )
+        handle = torch.distributed.all_gather_into_tensor(
+            out._rowwise_data,
+            input_._rowwise_data,
+            group=process_group,
+            async_op=async_op,
+        )
 
-    # Gather in high precision and quantize for column-wise usage
-    if isinstance(input_, QuantizedTensor):
-        input_ = input_.dequantize(dtype=torch.bfloat16)
-    out = torch.empty(
-        out_shape,
-        dtype=input_.dtype,
-        device=input_.device,
-        memory_format=torch.contiguous_format,
-    )
-    torch.distributed.all_gather_into_tensor(out, input_, group=process_group)
-    out = quantizer(out)
-    return out, None
+    # Gather MXFP8 data for column-wise usage
+    if quantizer.columnwise_usage:
+
+        # Remove padding from MXFP8 scale-inverses
+        in_scale_inv = input_._columnwise_scale_inv
+        out_scale_inv = out._columnwise_scale_inv
+        flattened_in_shape0 = math.prod(in_shape[:-1]) // 32
+        if in_scale_inv.size(0) != flattened_in_shape0:
+            in_scale_inv = in_scale_inv[:flattened_in_shape0]
+            out_scale_inv[flattened_in_shape0 * world_size :].zero_()
+            out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size]
+
+        # Launch all-gathers
+        if handle is not None:
+            handle.wait()
+        torch.distributed.all_gather_into_tensor(
+            out_scale_inv,
+            in_scale_inv,
+            group=process_group,
+        )
+        handle = torch.distributed.all_gather_into_tensor(
+            out._columnwise_data,
+            input_._columnwise_data,
+            group=process_group,
+            async_op=async_op,
+        )
+
+    return out, handle
 
 
 def gather_along_first_dim(
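
The control flow above keeps at most one asynchronous collective in flight: the small scale-inverse gather runs synchronously, the large data gather may be asynchronous, and the handle from the previous usage is waited on before the next pair is enqueued. A sketch of the pattern (assumes an initialized `torch.distributed` process group; the helper name is hypothetical):

```python
import torch.distributed as dist

def gather_one_usage(out_data, in_data, out_scale_inv, in_scale_inv, group, handle, async_op):
    if handle is not None:
        handle.wait()  # finish the previously launched async gather first
    # Scale-inverses are small; gather them synchronously
    dist.all_gather_into_tensor(out_scale_inv, in_scale_inv, group=group)
    # The data gather may run asynchronously; its handle is returned to the caller
    return dist.all_gather_into_tensor(out_data, in_data, group=group, async_op=async_op)
```

Calling this once for the row-wise buffers and once for the column-wise buffers reproduces the structure of `_all_gather_mxfp8`, with the final handle covering the last collective launched.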
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 01bda64101..55f7ba7653 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -310,6 +310,12 @@ def forward(
             clear_tensor_data(ln_out, ln_out_total)
 
         if is_grad_enabled:
+
+            # Input with column-wise usage is needed for wgrad GEMM
+            if backward_needs_input:
+                if isinstance(ln_out, QuantizedTensor):
+                    ln_out.update_usage(rowwise_usage=False)
+
             if cpu_offloading:
                 if fp8 and weightmat is not None:
                     set_offloading_param(weightmat, "weight_offloading", True)
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index bae21eebfd..7b511a186a 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -154,6 +154,7 @@ def forward(
                 )
                 if not isinstance(inputmat, QuantizedTensor):
                     inputmat = input_quantizer(inputmat)
+                    own_quantized_input = True
                 elif backward_needs_input:
                     inputmat.update_usage(rowwise_usage=True, columnwise_usage=True)
             inputmat_total = inputmat
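
Background for these two module changes: in backward, dgrad consumes the incoming gradient and the weight, while wgrad consumes the saved input only through its transpose. Keeping just the column-wise (transposed) copy of a quantized input is therefore enough for the backward pass. In plain PyTorch terms:

```python
import torch

x = torch.randn(16, 8)   # forward input
w = torch.randn(8, 4)    # weight
dy = torch.randn(16, 4)  # gradient of the output y = x @ w

dgrad = dy @ w.t()       # input gradient: needs dy and w, never x
wgrad = x.t() @ dy       # weight gradient: needs x only in transposed form
assert dgrad.shape == x.shape and wgrad.shape == w.shape
```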
diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py
index 892e120da1..cb93eb5e6b 100644
--- a/transformer_engine/pytorch/ops/basic/basic_linear.py
+++ b/transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -523,9 +523,7 @@ def _functional_forward(
 
         # Configure input tensor for backward pass
         if own_quantized_x_local:
-            ### TODO Restore once column-wise usage is supported by itself  # pylint: disable=fixme
-            # x_local.update_usage(rowwise_usage=False)
-            pass
+            x_local.update_usage(rowwise_usage=False)
 
         # Detach input tensor if needed
         # Note: PyTorch autograd produces esoteric errors if we save
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index da788182a0..7210217856 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -267,21 +267,40 @@ def _create_transpose(self):
         self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
         self._transpose_invalid = False
 
-    def update_usage(self, rowwise_usage=True, columnwise_usage=True):
-        assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor"
-        if rowwise_usage:
-            assert self._data is not None, "Rowwise usage of the tensor was already disabled"
-        else:
-            if not non_tn_fp8_gemm_supported():
-                if self._transpose is None or self._transpose_invalid:
-                    self._create_transpose()
-            self._data = None
-        if columnwise_usage:
-            if self._transpose is None or self._transpose_invalid:
-                assert self._data is not None, "The tensor does not hold any data anymore"
-                if not non_tn_fp8_gemm_supported():
-                    self._create_transpose()
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
+
+        # Figure out what data is available and what is required
+        has_data = self._data is not None
+        has_data_transpose = self._transpose is not None and not self._transpose_invalid
+        needs_data = has_data
+        needs_data_transpose = has_data_transpose
+        if non_tn_fp8_gemm_supported():
+            if rowwise_usage is not None and rowwise_usage:
+                needs_data = True
+            if columnwise_usage is not None and columnwise_usage:
+                needs_data = True
+            needs_data_transpose = False
         else:
+            if rowwise_usage is not None:
+                needs_data = rowwise_usage
+            if columnwise_usage is not None:
+                needs_data_transpose = columnwise_usage
+
+        # Generate data that is required
+        if needs_data and not has_data:
+            raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
+        if needs_data_transpose and not has_data_transpose:
+            if not has_data:
+                raise RuntimeError("FP8 data is required to generate FP8 data transpose")
+            self._create_transpose()
+
+        # Delete data that is not required
+        if not needs_data:
+            self._data = None
+        if not needs_data_transpose:
             self._transpose = None
             self._transpose_invalid = True
diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
index 86b13415a1..0db213684c 100644
--- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py
+++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
@@ -207,36 +207,50 @@ def detach(self) -> MXFP8Tensor:
         # TODO(ksivamani): Fix the detach bug
         return MXFP8Tensor.make_like(self)
 
-    def update_usage(self, rowwise_usage=True, columnwise_usage=True):
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
         """
         For MXFP8, columnwise scaled output is only produced by x2 scaling
         kernels, so this function only disables usages.
         """
-        assert rowwise_usage or columnwise_usage, "Could not disable all usages of the tensor."
-
-        if columnwise_usage and rowwise_usage:
-            assert (
-                self._rowwise_data is not None
-                and self._rowwise_scale_inv is not None
-                and self._columnwise_data is not None
-                and self._columnwise_scale_inv is not None
-            ), "Cannot update to rowwise and columnwise usage."
-            return
 
+        # Default usage is based on available data
+        if rowwise_usage is None:
+            rowwise_usage = self._rowwise_data is not None
+        if columnwise_usage is None:
+            columnwise_usage = self._columnwise_data is not None
+
+        # Update row-scaled data
         if rowwise_usage:
-            assert (
-                self._rowwise_data is not None and self._rowwise_scale_inv is not None
-            ), "Cannot update to rowwise usage."
+            if self._rowwise_data is None:
+                raise RuntimeError(
+                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data"
+                )
+            if self._rowwise_scale_inv is None:
+                raise RuntimeError(
+                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses"
+                )
+        else:
+            self._rowwise_data = None
+            self._rowwise_scale_inv = None
+
+        # Update column-scaled data
+        if columnwise_usage:
+            if self._columnwise_data is None:
+                raise RuntimeError(
+                    "Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data"
+                )
+            if self._columnwise_scale_inv is None:
+                raise RuntimeError(
+                    "Requested column-wise usage, "
+                    "but MXFP8Tensor is missing column-scaled scale-inverses"
+                )
+        else:
             self._columnwise_data = None
             self._columnwise_scale_inv = None
-            return
-
-        assert (
-            self._columnwise_data is not None and self._columnwise_scale_inv is not None
-        ), "Cannot update to columnwise usage."
-        self._rowwise_data = None
-        self._rowwise_scale_inv = None
-        return
 
     def clone(self) -> MXFP8Tensor:
         # pylint: disable=missing-function-docstring
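
Under the new signature, `None` means "keep whatever the tensor already holds", `True` means "this usage must be available", and `False` means "drop it". A standalone sketch of the resulting decision table for `Float8Tensor` on the non-TN-GEMM path (a mock of the logic above, not the real class):

```python
def plan_usage(has_data, has_transpose, rowwise_usage=None, columnwise_usage=None):
    needs_data, needs_transpose = has_data, has_transpose
    if rowwise_usage is not None:
        needs_data = rowwise_usage
    if columnwise_usage is not None:
        needs_transpose = columnwise_usage
    return needs_data, needs_transpose

# Keep only the transpose, e.g. for a later wgrad GEMM
assert plan_usage(True, True, rowwise_usage=False) == (False, True)
# No arguments: a no-op that preserves the current state
assert plan_usage(True, False) == (True, False)
```

`MXFP8Tensor.update_usage` follows the same convention, except that a missing usage can only be reported as an error, never created.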
diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py
index ef21412ca7..2f449a3b09 100644
--- a/transformer_engine/pytorch/tensor/quantized_tensor.py
+++ b/transformer_engine/pytorch/tensor/quantized_tensor.py
@@ -289,7 +289,11 @@ def detach(self) -> QuantizedTensor:
             f"{self.__class__.__name__} class does not implement detach function"
         )
 
-    def update_usage(self, rowwise_usage=True, columnwise_usage=True):
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
         """Indicate to the tensor how it is going to be used
 
         This enables optimizations to memory usage in some cases