Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support tensors with only column-wise data #1505

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
80 changes: 49 additions & 31 deletions transformer_engine/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@

namespace transformer_engine {

std::string to_string(DType type);
std::string to_string(const NVTEScalingMode &mode);

inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
end, " in a vector with ", shape.size(), " entries");
Expand Down Expand Up @@ -95,17 +98,8 @@ struct Tensor {
scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}

int numel() const {
NVTE_CHECK(data.dptr != nullptr || columnwise_data.dptr != nullptr,
"Tensor does not hold any data!");
size_t acc = 1;
if (data.dptr != nullptr) {
for (const auto &dim : data.shape) {
acc *= dim;
}
return acc;
}
// data is empty, use columnwise_data
for (const auto &dim : columnwise_data.shape) {
for (const auto dim : shape()) {
acc *= dim;
}
return acc;
Expand All @@ -122,24 +116,54 @@ struct Tensor {
return data.dtype;
}

std::vector<size_t> shape() const {
  /* Note: We sometimes experience spurious compiler errors
   * (-Wstringop-overflow) from this function. It appears that GCC
   * has some bugs with std::vector (see
   * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
   */
  const bool only_columnwise = !has_data() && has_columnwise_data();
  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
    if (!only_columnwise) {
      return data.shape;
    }
    // Column-wise (transposed) storage keeps the logical last dim in
    // front; rotate it back to the end to recover the row-wise shape.
    const auto &tshape = columnwise_data.shape;
    std::vector<size_t> rowwise_shape;
    if (!tshape.empty()) {
      rowwise_shape.assign(tshape.begin() + 1, tshape.end());
      rowwise_shape.push_back(tshape.front());
    }
    return rowwise_shape;
  }
  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
    // MXFP8 column-wise data is stored untransposed, so its shape is
    // already the logical tensor shape.
    return only_columnwise ? columnwise_data.shape : data.shape;
  }
  NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
  return {};
}

/*! Matrix height after tensor is flattened to 2D
*
* If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
* as a (D1*D2*...*D(n-1), Dn) matrix.
*/
size_t flat_first_dim() const {
if (!has_data() && has_columnwise_data()) {
const auto &data_shape = columnwise_data.shape;
if (data_shape.empty()) return 1;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
return product(data_shape, 1, data_shape.size());
} else {
return product(data_shape, 0, data_shape.size() - 1);
const auto &full_shape = shape();
size_t ret = 1;
if (!full_shape.empty()) {
for (size_t i = 0; i < full_shape.size() - 1; i++) {
ret *= full_shape[i];
}
}
const auto &data_shape = data.shape;
if (data_shape.empty()) return 1;
return product(data_shape, 0, data_shape.size() - 1);
return ret;
}

/*! Matrix width after tensor is flattened to 2D
Expand All @@ -148,18 +172,12 @@ struct Tensor {
* as a (D1*D2*...*D(n-1), Dn) matrix.
*/
size_t flat_last_dim() const {
if (!has_data() && has_columnwise_data()) {
const auto &data_shape = columnwise_data.shape;
if (data_shape.empty()) return 1;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
return data_shape.front();
} else {
return data_shape.back();
}
const auto &full_shape = shape();
if (full_shape.empty()) {
return 1;
} else {
return full_shape.back();
}
const auto &data_shape = data.shape;
if (data_shape.empty()) return 1;
return data_shape.back();
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ class TensorWrapper {
* \return Number of elements in the tensor.
*/
size_t numel() const noexcept {
if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
if (tensor_ == nullptr) return 0;
return nvte_tensor_numel(tensor_);
}

Expand Down
88 changes: 52 additions & 36 deletions transformer_engine/common/transformer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,63 +212,79 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) {
}

NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
if (tensor == nullptr) return {nullptr, 0};
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret;

// FP8 tensor keeps shape in rowwise data
if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
ret.data = t.data.shape.data();
ret.ndim = t.data.shape.size();
return ret;
if (tensor == nullptr) {
NVTE_ERROR("Invalid tensor");
}
NVTEShape ret;

// Get shape based on what data is available
if (t.has_data()) {
ret.data = t.data.shape.data();
ret.ndim = t.data.shape.size();
return ret;
}
if (t.has_columnwise_data()) {
ret.data = t.columnwise_data.shape.data();
ret.ndim = t.columnwise_data.shape.size();
return ret;
// Determine tensor shape depending on tensor format
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
switch (t.scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
if (!t.has_data() && t.has_columnwise_data()) {
// We can infer tensor shape if FP8 tensor only has FP8 data
// transpose. However, NVTEShape only contains a pointer and
// cannot store temporary data. We hack around this by caching
// the tensor shape within the empty FP8 data.
auto &shape_cache = const_cast<std::vector<size_t> &>(t.data.shape);
shape_cache.clear();
if (!t.columnwise_data.shape.empty()) {
for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) {
shape_cache.push_back(t.columnwise_data.shape[i]);
}
shape_cache.push_back(t.columnwise_data.shape.front());
}
ret.data = shape_cache.data();
ret.ndim = shape_cache.size();
} else {
ret.data = t.data.shape.data();
ret.ndim = t.data.shape.size();
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
if (!t.has_data() && t.has_columnwise_data()) {
ret.data = t.columnwise_data.shape.data();
ret.ndim = t.columnwise_data.shape.size();
} else {
ret.data = t.data.shape.data();
ret.ndim = t.data.shape.size();
}
break;
}
default:
NVTE_ERROR("Cannot parse tensor shape with scaling mode \"",
transformer_engine::to_string(t.scaling_mode), "\"");
}

// Tensor has no data
ret.data = t.data.shape.data();
ret.ndim = t.data.shape.size();
return ret;
}

NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
if (tensor == nullptr) return {nullptr, 0};
if (tensor == nullptr) {
NVTE_ERROR("Invalid tensor");
}
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret;
ret.data = t.columnwise_data.shape.data();
ret.ndim = t.columnwise_data.shape.size();
return ret;
}

size_t nvte_tensor_ndim(const NVTETensor tensor) {
if (tensor == nullptr) return 0;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.data.shape.size();
}
size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }

size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
if (tensor == nullptr) return 0;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim);
return t.data.shape[dim];
const auto &shape = nvte_tensor_shape(tensor);
NVTE_CHECK(0 <= dim && dim < shape.ndim, "Attempted to access index ", dim,
" in a shape array with ", shape.ndim, " entries");
return shape.data[dim];
}

size_t nvte_tensor_numel(const NVTETensor tensor) {
if (tensor == nullptr) return 0;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
const auto &shape = nvte_tensor_shape(tensor);
size_t numel = 1;
for (auto size : t.data.shape) {
numel *= size;
for (size_t i = 0; i < shape.ndim; i++) {
numel *= shape.data[i];
}
return numel;
}
Expand Down
8 changes: 8 additions & 0 deletions transformer_engine/common/util/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,18 @@
#include <cudnn.h>
#include <nvrtc.h>

#include <iostream>
#include <stdexcept>

#include "../util/string.h"

/*! \brief Emit a non-fatal warning to stderr.
 *
 * Prefixes the message with the source file, line number, and enclosing
 * function, then concatenates all arguments via concat_strings. Unlike
 * NVTE_ERROR, this does not throw — execution continues after printing.
 */
#define NVTE_WARN(...) \
do { \
std::cerr << ::transformer_engine::concat_strings( \
__FILE__ ":", __LINE__, " in function ", __func__, ": ", \
::transformer_engine::concat_strings(__VA_ARGS__), "\n"); \
} while (false)

#define NVTE_ERROR(...) \
do { \
throw ::std::runtime_error(::transformer_engine::concat_strings( \
Expand Down
83 changes: 41 additions & 42 deletions transformer_engine/pytorch/csrc/extensions/type_converters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,73 +4,72 @@
* See LICENSE for license information.
************************************************************************/

#include <ATen/ATen.h>
#include <pybind11/pybind11.h>
#include <transformer_engine/transformer_engine.h>

#include "common.h"
#include "pybind.h"

namespace transformer_engine::pytorch {
namespace detail {

TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer) {
const at::Tensor &data = tensor.attr("_data").cast<at::Tensor>();
const at::Tensor &scale_inv = tensor.attr("_scale_inv").cast<at::Tensor>();
float *scale_inv_dptr = reinterpret_cast<float *>(scale_inv.data_ptr());
const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();

const auto &shape = getTensorShape(data);
auto ret = TensorWrapper(quantizer->get_scaling_mode());

bool transpose_valid = !tensor.attr("_transpose_invalid").cast<bool>();
std::optional<at::Tensor> transpose = std::nullopt;
if (transpose_valid) {
transpose = tensor.attr("_transpose").cast<std::optional<at::Tensor>>();
// FP8 data
const DType fp8_dtype = tensor.attr("_fp8_dtype").cast<DType>();
if (!tensor.attr("_data").is_none()) {
const auto &data = tensor.attr("_data").cast<at::Tensor>();
ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
}

auto ret = TensorWrapper(quantizer->get_scaling_mode());
// FP8 data transpose
if (!tensor.attr("_transpose_invalid").cast<bool>() && !tensor.attr("_transpose").is_none()) {
const auto &data_transpose = tensor.attr("_transpose").cast<at::Tensor>();
ret.set_columnwise_data(data_transpose.data_ptr(), fp8_dtype, getTensorShape(data_transpose));
}

ret.set_rowwise_data(data.data_ptr(), dtype, shape);
if (transpose_valid && transpose != std::nullopt) {
const auto &transpose_shape = getTensorShape(*transpose);
ret.set_columnwise_data(transpose->data_ptr(), dtype, transpose_shape);
// Scale-inverse
{
const auto &scale_inv = tensor.attr("_scale_inv").cast<at::Tensor>();
float *dptr = reinterpret_cast<float *>(scale_inv.data_ptr());
const auto &dtype = GetTransformerEngineDType(scale_inv.scalar_type());
const auto &shape = getTensorShape(scale_inv);
ret.set_rowwise_scale_inv(dptr, dtype, shape);
ret.set_columnwise_scale_inv(dptr, dtype, shape);
}

const auto scale_inv_dtype = GetTransformerEngineDType(scale_inv.scalar_type());
const auto scale_inv_shape = getTensorShape(scale_inv);
ret.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
ret.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape);
// Quantizer state
quantizer->set_quantization_params(&ret);

return ret;
}

TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) {
const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING);

bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());

if (rowwise_usage) {
const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr();
const auto &shape = getTensorShape(data_rowwise);
ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape);

const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise);
ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat8E8M0, scale_inv_rowwise_shape);
// Row-scaled data
const DType fp8_dtype = tensor.attr("_fp8_dtype").cast<DType>();
if (!tensor.attr("_rowwise_data").is_none()) {
const auto &data = tensor.attr("_rowwise_data").cast<at::Tensor>();
const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, getTensorShape(scale_inv));
}

if (columnwise_usage) {
const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr();
const auto &shape = getTensorShape(data_colwise);
ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape);

const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat8E8M0,
scale_inv_colwise_shape);
// Column-scaled data
if (!tensor.attr("_columnwise_data").is_none()) {
const auto &data = tensor.attr("_columnwise_data").cast<at::Tensor>();
const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
ret.set_columnwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data));
ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0,
getTensorShape(scale_inv));
}

// Quantizer state
quantizer->set_quantization_params(&ret);

return ret;
}

Expand Down
Loading
Loading