[BugFix][Codegen, CUDA] Fix faulty codegen for FP8 #17673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 6 commits into base: main
4 changes: 2 additions & 2 deletions python/tvm/relax/transform/legalize_ops/nn.py
@@ -475,8 +475,8 @@ def te_gelu(x: te.Tensor):
dtype = x.dtype
erf_inp = x * tir.const(0.5**0.5, dtype)

if dtype == "float16":
erf = topi.math.cast(topi.erf(topi.math.cast(erf_inp, "float32")), "float16")
if dtype == "float16" or dtype == "float8_e5m2" or dtype == "float8_e4m3fn":
erf = topi.math.cast(topi.erf(topi.math.cast(erf_inp, "float32")), dtype)
else:
erf = topi.erf(erf_inp)

3 changes: 2 additions & 1 deletion src/relax/op/distributed/nn.cc
@@ -32,7 +32,8 @@ StructInfo InferDistStructInfoSoftmax(const Call& call, const BlockBuilder& ctx)
ctx->ReportFatal(Diagnostic::Error(call)
<< "Input of distributed operator must have known ndim");
}
if (!input_tensor_sinfo->IsUnknownDtype() && !input_tensor_sinfo->dtype.is_float()) {
if (!input_tensor_sinfo->IsUnknownDtype() && !input_tensor_sinfo->dtype.is_float() &&
!input_tensor_sinfo->dtype.is_float16() && !input_tensor_sinfo->dtype.is_float8()) {
ctx->ReportFatal(Diagnostic::Error(call) << "Softmax requires the input tensor to have float "
"dtype. However, the given input dtype is "
<< input_tensor_sinfo->dtype);
3 changes: 2 additions & 1 deletion src/relax/op/distributed/unary.h
@@ -40,7 +40,8 @@ StructInfo InferDistStructInfoUnary(const Call& call, const BlockBuilder& ctx,
TensorStructInfo input_tensor_sinfo = input_dtensor_sinfo->tensor_sinfo;

if (require_float_dtype && !input_tensor_sinfo->IsUnknownDtype() &&
!input_tensor_sinfo->dtype.is_float()) {
!input_tensor_sinfo->dtype.is_float() && !input_tensor_sinfo->dtype.is_float16() &&
!input_tensor_sinfo->dtype.is_float8()) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< call->op
3 changes: 2 additions & 1 deletion src/relax/op/nn/nn.cc
@@ -74,7 +74,8 @@ StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) {
if (data_sinfo->IsUnknownNdim()) {
return data_sinfo;
}
if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float() &&
!data_sinfo->dtype.is_float16() && !data_sinfo->dtype.is_float8()) {
ctx->ReportFatal(Diagnostic::Error(call) << "Softmax requires the input tensor to have float "
"dtype. However, the given input dtype is "
<< data_sinfo->dtype);
5 changes: 3 additions & 2 deletions src/relax/op/op_common.h
@@ -199,8 +199,9 @@ template <bool require_float_dtype, typename FType>
inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx,
FType f_compute_out_dtype) {
TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
if (require_float_dtype && !input_sinfo->IsUnknownDtype() &&
(!input_sinfo->dtype.is_float() && !input_sinfo->dtype.is_bfloat())) {
if (require_float_dtype && !input_sinfo->IsUnknownDtype() && !input_sinfo->dtype.is_float() &&
!input_sinfo->dtype.is_bfloat() && !input_sinfo->dtype.is_float16() &&
!input_sinfo->dtype.is_float8()) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< call->op
16 changes: 13 additions & 3 deletions src/target/source/codegen_c.cc
@@ -214,7 +214,7 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExp
if (alloc_storage_scope_.count(buffer_var)) {
scope = alloc_storage_scope_.at(buffer_var);
}
bool is_vol = IsVolatile(buffer_var);
bool is_vol = IsVolatile(buffer_var) && !t.is_float8();

auto ptr_cast = [this, is_vol, scope](DataType pointed_to) {
std::ostringstream ptr_os;
@@ -840,7 +840,8 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
std::string value = this->PrintExpr(op->value);
std::string ref = this->GetBufferRef(value_dtype, op->buffer.get(), index_expr);
this->PrintIndent();
stream << ref << " = " << value << ";\n";
stream << ref << " = ";
stream << value << ";\n";
} else {
arith::PVar<PrimExpr> base;

@@ -876,7 +877,16 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
stream << '[';
PrintVecElemLoad(index, index_expr.dtype(), i, stream);
stream << "] = ";
PrintVecElemLoad(value, op->value.dtype(), i, stream);
if (op->value.dtype().is_float8()) {
ICHECK(value_dtype.lanes() == 2);
std::string fp8_type = op->value.dtype().is_float8_e5m2() ? "e5m2" : "e4m3";
static const char access[] = {'x', 'y'};
stream << "__nv_fp8_" << fp8_type << "(__half2(";
PrintVecElemLoad(value, op->value.dtype(), i, stream);
stream << ")." << access[i % 2] << ")";
} else {
PrintVecElemLoad(value, op->value.dtype(), i, stream);
}
stream << ";\n";
}
EndScope(vec_scope);
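Reviewer note (not part of the diff): a minimal sketch of what the new fp8 branch above emits for a 2-lane float8_e4m3 buffer store. The kernel, buffer, and value names are illustrative only, not taken from the PR.

```cuda
// Illustrative only. Each element is stored as
// __nv_fp8_e4m3(__half2(<vector load>).x / .y), matching the branch above.
#include <cuda_fp16.h>
#include <cuda_fp8.h>

__global__ void store_fp8x2(__nv_fp8_e4m3* dst, __nv_fp8x2_e4m3 val) {
  dst[0] = __nv_fp8_e4m3(__half2(__nv_cvt_fp8x2_to_halfraw2(val.__x, __NV_E4M3)).x);
  dst[1] = __nv_fp8_e4m3(__half2(__nv_cvt_fp8x2_to_halfraw2(val.__x, __NV_E4M3)).y);
}
```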
214 changes: 199 additions & 15 deletions src/target/source/codegen_cuda.cc
@@ -516,7 +516,9 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
}

void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) {
os << "make_";
if (!t.is_float8()) {
os << "make_"; // There is no make___nv_fp8 (/usr/local/cuda/include/vector_functions.hpp)
}
PrintType(t, os);
}

@@ -533,22 +535,56 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l
std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());

    for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
      std::ostringstream value_temp;
      if (isalpha(op[0])) {
        value_temp << op << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << ", ";
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      } else {
        value_temp << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << op;
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      }
      PrintVecElemStore(sret, t, i, value_temp.str());
    }
    if (t.is_float8()) {
      std::ostringstream value_temp;
      ICHECK(t.is_float8_e4m3fn() || t.is_float8_e5m2());
      if (t.lanes() == 2) {
        value_temp << "__nv_fp8x2_" << (t.is_float8_e5m2() ? "e5m2" : "e4m3") << "(";
      } else {
        value_temp << "__nv_fp8x4_" << (t.is_float8_e5m2() ? "e5m2" : "e4m3") << "(";
      }
      for (int i = 0, lanes = t.lanes() / 2; i < lanes; ++i) {
        if (isalpha(op[0]) || op[0] == '_') {
          value_temp << op << "2"
                     << "(__half2(";
          PrintVecElemLoad(vlhs, lhs.dtype(), i * lanes, value_temp);
          value_temp << "), __half2(";
          PrintVecElemLoad(vrhs, rhs.dtype(), i * lanes, value_temp);
          value_temp << "))";
        } else {
          value_temp << "__half2(";
          PrintVecElemLoad(vlhs, lhs.dtype(), i * lanes, value_temp);
          value_temp << ") " << op << " __half2(";
          PrintVecElemLoad(vrhs, rhs.dtype(), i * lanes, value_temp);
          value_temp << ")";
        }

        if (i != lanes - 1) {
          value_temp << ", ";
        }
        if (i == lanes - 1) {
          value_temp << ")";
          PrintVecElemStore(sret, t, i, value_temp.str());
        }
      }
    } else {
      for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
        std::ostringstream value_temp;
        if (isalpha(op[0])) {
          value_temp << op << "(";
          PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
          value_temp << ", ";
          PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
          value_temp << ")";
        } else {
          value_temp << "(";
          PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
          value_temp << op;
          PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
          value_temp << ")";
        }
        PrintVecElemStore(sret, t, i, value_temp.str());
      }
    }
EndScope(ssa_scope);
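Not part of the diff: a hedged sketch of the expression the fp8 path above assembles for an infix op on a 2-lane e5m2 vector. The names a and b stand in for the SSA values vlhs/vrhs and are not taken from the PR.

```cuda
// Illustrative: both operands are widened to __half2, combined in half
// precision, and packed back via the __nv_fp8x2_e5m2 converting constructor.
#include <cuda_fp16.h>
#include <cuda_fp8.h>

__device__ __nv_fp8x2_e5m2 add_fp8x2(__nv_fp8x2_e5m2 a, __nv_fp8x2_e5m2 b) {
  return __nv_fp8x2_e5m2(__half2(__nv_cvt_fp8x2_to_halfraw2(a.__x, __NV_E5M2)) +
                         __half2(__nv_cvt_fp8x2_to_halfraw2(b.__x, __NV_E5M2)));
}
```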
@@ -563,6 +599,7 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
}

static const char access[] = {'x', 'y', 'z', 'w'};
std::string fp8_type = (t.is_float8()) ? (t.is_float8_e4m3fn() ? "e4m3" : "e5m2") : "";
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
std::string type_name = t.is_int() ? "char" : "unsigned char";
Expand All @@ -584,6 +621,9 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
} else {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
}
} else if (t.is_float8()) {
os << "__nv_cvt_fp8x2_to_halfraw2(" << vec << ".__x,"
<< (t.is_float8_e5m2() ? "__NV_E5M2" : "__NV_E4M3") << ")";
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
@@ -644,6 +684,10 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
<< " = " << value << ";\n";
}
} else if (t.is_float8()) {
// Since fp8 is a packed type (2 or 4 lanes), only the final element's call performs the store.
ICHECK(i == (t.lanes() / 2) - 1);
stream << vec << " = " << value << ";\n";
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
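For reference (illustrative, not from the PR): a small sketch of the conversion the fp8 element load relies on, assuming an e4m3 pair; the helper name is a placeholder. The matching store branch then writes the fully assembled packed value once, at the last lane, which is what the ICHECK above enforces.

```cuda
// Illustrative: the fp8 element load reads the packed storage (.__x) of a
// 2-lane fp8 vector and widens it to half2.
#include <cuda_fp16.h>
#include <cuda_fp8.h>

__device__ __half2 load_fp8x2_as_half2(__nv_fp8x2_e4m3 v) {
  return __half2(__nv_cvt_fp8x2_to_halfraw2(v.__x, __NV_E4M3));
}
```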
@@ -830,7 +874,24 @@ void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Arr
}
os << sret;
} else {
CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os);
if (ret_dtype.is_float8()) {
std::string fp8_type = (ret_dtype.is_float8_e5m2() ? "__NV_E5M2" : "__NV_E4M3");
os << "__nv_fp8_" << (ret_dtype.is_float8_e5m2() ? "e5m2" : "e4m3") << "(";

LOG_INFO << global_symbol;
os << global_symbol << "(__half(__nv_cvt_fp8_to_halfraw(";
for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) {
this->PrintExpr(args[i], os);
os << ".__x, " << fp8_type << "))";
if (i < args.size() - 1) {
os << ", "
<< "__half(__nv_cvt_fp8_to_halfraw(";
}
}
os << "))";
} else {
CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os);
}
}
}
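Not part of the diff: a sketch of the call shape this branch prints for a unary extern such as hexp on an e4m3 scalar; multi-argument calls repeat the unpack sequence per argument. Function and variable names are placeholders.

```cuda
// Illustrative: unpack fp8 -> halfraw -> __half, call the half intrinsic,
// repack the result into fp8.
#include <cuda_fp16.h>
#include <cuda_fp8.h>

__device__ __nv_fp8_e4m3 exp_fp8(__nv_fp8_e4m3 x) {
  return __nv_fp8_e4m3(hexp(__half(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))));
}
```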

@@ -1691,6 +1752,19 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
return;
}

if (t.is_float8()) {
if (i == 0) {
PrintVecConstructor(t, os);
os << "(make_float" << t.lanes() << "(";
}
if (i != 0) os << ", ";
os << "static_cast<float>(" << value << ")";
if (i == t.lanes() - 1) {
os << "))";
}
return;
}

if (i == 0) {
PrintVecConstructor(t, os);
os << "(";
@@ -1704,5 +1778,115 @@
return;
}
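Reviewer note: combined with the PrintVecConstructor change above (no make_ prefix for fp8), the new branch produces roughly the following construction; element names are placeholders.

```cuda
// Illustrative: assembling a 2-lane fp8 vector from scalar elements through the
// packed type's converting constructor around make_float2.
#include <cuda_fp8.h>

__device__ __nv_fp8x2_e4m3 pack_fp8x2(float v0, float v1) {
  return __nv_fp8x2_e4m3(make_float2(static_cast<float>(v0), static_cast<float>(v1)));
}
```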

template <typename T>
inline void PrintBinaryExpr(const T* op, const char* opstr,
std::ostream& os, // NOLINT(*)
CodeGenCUDA* p) {
if (op->dtype.lanes() == 1) {
if (op->dtype.is_float8()) {
std::string fp8_type = (op->dtype.is_float8_e5m2() ? "__NV_E5M2" : "__NV_E4M3");
if (isalpha(opstr[0]) || opstr[0] == '_') {
os << "__nv_fp8_" << (op->dtype.is_float8_e5m2() ? "e5m2" : "e4m3") << "(";
os << opstr << "(";
os << "__half(__nv_cvt_fp8_to_halfraw(";
p->PrintExpr(op->a, os);
os << ".__x, " << fp8_type << ")), __half(__nv_cvt_fp8_to_halfraw(";
p->PrintExpr(op->b, os);
os << ".__x, " << fp8_type << ")))";
os << ")";
} else {
os << "__nv_fp8_" << (op->dtype.is_float8_e5m2() ? "e5m2" : "e4m3") << "(";
os << "__half(__nv_cvt_fp8_to_halfraw(";
p->PrintExpr(op->a, os);
os << ".__x, " << fp8_type << ")) " << opstr << " __half(__nv_cvt_fp8_to_halfraw(";
p->PrintExpr(op->b, os);
os << ".__x, " << fp8_type << ")))";
}
} else {
if (isalpha(opstr[0])) {
os << opstr << '(';
p->PrintExpr(op->a, os);
os << ", ";
p->PrintExpr(op->b, os);
os << ')';
} else {
os << '(';
p->PrintExpr(op->a, os);
os << ' ' << opstr << ' ';
p->PrintExpr(op->b, os);
os << ')';
}
}

} else {
p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os);
}
}

void CodeGenCUDA::VisitExpr_(const AddNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "+", os, this);
}
void CodeGenCUDA::VisitExpr_(const SubNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "-", os, this);
}
void CodeGenCUDA::VisitExpr_(const MulNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "*", os, this);
}
void CodeGenCUDA::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "/", os, this);
}
void CodeGenCUDA::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*)
if (op->dtype.is_int() || op->dtype.is_uint()) {
PrintBinaryExpr(op, "%", os, this);
} else {
ICHECK(op->dtype.is_float()) << "Expected floating point or integer dtype in Mod, but got "
<< op->dtype;
if (op->dtype.bits() == 32) {
PrintBinaryExpr(op, "fmodf", os, this);
} else if (op->dtype.bits() == 64) {
PrintBinaryExpr(op, "fmod", os, this);
} else {
ICHECK(false)
<< "Non single or double precision floating point in Mod, expected 32 or 64 bits but got "
<< op->dtype.bits() << " bits.";
}
}
}

void CodeGenCUDA::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, op->dtype.is_float8() ? "__hmin" : "min", os, this);
}
void CodeGenCUDA::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, op->dtype.is_float8() ? "__hmax" : "max", os, this);
}
void CodeGenCUDA::VisitExpr_(const EQNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "==", os, this);
}
void CodeGenCUDA::VisitExpr_(const NENode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "!=", os, this);
}
void CodeGenCUDA::VisitExpr_(const LTNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<", os, this);
}
void CodeGenCUDA::VisitExpr_(const LENode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "<=", os, this);
}
void CodeGenCUDA::VisitExpr_(const GTNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">", os, this);
}
void CodeGenCUDA::VisitExpr_(const GENode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, ">=", os, this);
}
void CodeGenCUDA::VisitExpr_(const AndNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "&&", os, this);
}
void CodeGenCUDA::VisitExpr_(const OrNode* op, std::ostream& os) { // NOLINT(*)
PrintBinaryExpr(op, "||", os, this);
}
void CodeGenCUDA::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*)
os << '!';
PrintExpr(op->a, os);
}

} // namespace codegen
} // namespace tvm
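Not part of the diff: for the scalar path in PrintBinaryExpr, a hedged sketch of what a float8 Min lowers to once the MinNode visitor passes "__hmin" as opstr; operand names are placeholders.

```cuda
// Illustrative: scalar fp8 min through the half-precision intrinsic.
#include <cuda_fp16.h>
#include <cuda_fp8.h>

__device__ __nv_fp8_e5m2 min_fp8(__nv_fp8_e5m2 a, __nv_fp8_e5m2 b) {
  return __nv_fp8_e5m2(__hmin(__half(__nv_cvt_fp8_to_halfraw(a.__x, __NV_E5M2)),
                              __half(__nv_cvt_fp8_to_halfraw(b.__x, __NV_E5M2))));
}
```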
16 changes: 16 additions & 0 deletions src/target/source/codegen_cuda.h
@@ -68,6 +68,22 @@ class CodeGenCUDA final : public CodeGenC {
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
void VisitExpr_(const CallNode* op, std::ostream& os) final;
void VisitExpr_(const CastNode* op, std::ostream& os) final;
void VisitExpr_(const AddNode* op, std::ostream& os) final;
void VisitExpr_(const SubNode* op, std::ostream& os) final;
void VisitExpr_(const MulNode* op, std::ostream& os) final;
void VisitExpr_(const DivNode* op, std::ostream& os) final;
void VisitExpr_(const ModNode* op, std::ostream& os) final;
void VisitExpr_(const MinNode* op, std::ostream& os) final;
void VisitExpr_(const MaxNode* op, std::ostream& os) final;
void VisitExpr_(const EQNode* op, std::ostream& os) final;
void VisitExpr_(const NENode* op, std::ostream& os) final;
void VisitExpr_(const LTNode* op, std::ostream& os) final;
void VisitExpr_(const LENode* op, std::ostream& os) final;
void VisitExpr_(const GTNode* op, std::ostream& os) final;
void VisitExpr_(const GENode* op, std::ostream& os) final;
void VisitExpr_(const AndNode* op, std::ostream& os) final;
void VisitExpr_(const OrNode* op, std::ostream& os) final;
void VisitExpr_(const NotNode* op, std::ostream& os) final;
void VisitStmt_(const EvaluateNode* op) final;
void VisitStmt_(const AllocateNode* op) final;
void VisitStmt_(const AttrStmtNode* op) final;
2 changes: 1 addition & 1 deletion src/target/source/intrin_rule_cuda.cc
@@ -52,7 +52,7 @@ struct CUDAMath {
default:
return "";
}
} else if (t.is_bfloat16()) {
} else if (t.is_bfloat16() || t.is_float8()) {
if (name == "fabs") {
return "__habs";
} else if (name == "round") {
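Reviewer note: with this mapping, fabs on an fp8 operand dispatches to __habs after the codegen paths above have widened the value to __half; a hedged sketch with an illustrative helper name.

```cuda
// Illustrative: fabs on fp8 lowers to __habs on the widened half value.
#include <cuda_fp16.h>
#include <cuda_fp8.h>

__device__ __nv_fp8_e4m3 abs_fp8(__nv_fp8_e4m3 x) {
  return __nv_fp8_e4m3(__habs(__half(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3))));
}
```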