diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index d9fb4701f7e9..463bc07e76be 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -475,8 +475,8 @@ def te_gelu(x: te.Tensor): dtype = x.dtype erf_inp = x * tir.const(0.5**0.5, dtype) - if dtype == "float16": - erf = topi.math.cast(topi.erf(topi.math.cast(erf_inp, "float32")), "float16") + if dtype == "float16" or dtype == "float8_e5m2" or dtype == "float8_e4m3fn": + erf = topi.math.cast(topi.erf(topi.math.cast(erf_inp, "float32")), dtype) else: erf = topi.erf(erf_inp) diff --git a/src/relax/op/distributed/nn.cc b/src/relax/op/distributed/nn.cc index ec0bdaeb3242..c1e786a894a3 100644 --- a/src/relax/op/distributed/nn.cc +++ b/src/relax/op/distributed/nn.cc @@ -32,7 +32,8 @@ StructInfo InferDistStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) ctx->ReportFatal(Diagnostic::Error(call) << "Input of distributed operator must have known ndim"); } - if (!input_tensor_sinfo->IsUnknownDtype() && !input_tensor_sinfo->dtype.is_float()) { + if (!input_tensor_sinfo->IsUnknownDtype() && !input_tensor_sinfo->dtype.is_float() && + !input_tensor_sinfo->dtype.is_float16() && !input_tensor_sinfo->dtype.is_float8()) { ctx->ReportFatal(Diagnostic::Error(call) << "Softmax requires the input tensor to have float " "dtype. However, the given input dtype is " << input_tensor_sinfo->dtype); diff --git a/src/relax/op/distributed/unary.h b/src/relax/op/distributed/unary.h index cfde689421f7..7a02343ab2fb 100644 --- a/src/relax/op/distributed/unary.h +++ b/src/relax/op/distributed/unary.h @@ -40,7 +40,8 @@ StructInfo InferDistStructInfoUnary(const Call& call, const BlockBuilder& ctx, TensorStructInfo input_tensor_sinfo = input_dtensor_sinfo->tensor_sinfo; if (require_float_dtype && !input_tensor_sinfo->IsUnknownDtype() && - !input_tensor_sinfo->dtype.is_float()) { + !input_tensor_sinfo->dtype.is_float() && !input_tensor_sinfo->dtype.is_float16() && + !input_tensor_sinfo->dtype.is_float8()) { ctx->ReportFatal( Diagnostic::Error(call) << call->op diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index b4668d65d399..4f3ab668c54d 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -74,7 +74,8 @@ StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) { if (data_sinfo->IsUnknownNdim()) { return data_sinfo; } - if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) { + if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float() && + !data_sinfo->dtype.is_float16() && !data_sinfo->dtype.is_float8()) { ctx->ReportFatal(Diagnostic::Error(call) << "Softmax requires the input tensor to have float " "dtype. 
However, the given input dtype is " << data_sinfo->dtype); diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index eea6db22fdda..f0cf908c5054 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -199,8 +199,9 @@ template <bool require_float_dtype, typename FType> inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx, FType f_compute_out_dtype) { TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); - if (require_float_dtype && !input_sinfo->IsUnknownDtype() && - (!input_sinfo->dtype.is_float() && !input_sinfo->dtype.is_bfloat())) { + if (require_float_dtype && !input_sinfo->IsUnknownDtype() && !input_sinfo->dtype.is_float() && + !input_sinfo->dtype.is_bfloat() && !input_sinfo->dtype.is_float16() && + !input_sinfo->dtype.is_float8()) { ctx->ReportFatal( Diagnostic::Error(call) << call->op diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 575f52e2257a..b2f7bdf1d262 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -214,7 +214,7 @@ std::string CodeGenC::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) { if (alloc_storage_scope_.count(buffer_var)) { scope = alloc_storage_scope_.at(buffer_var); } - bool is_vol = IsVolatile(buffer_var); + bool is_vol = IsVolatile(buffer_var) && !t.is_float8(); auto ptr_cast = [this, is_vol, scope](DataType pointed_to) { std::ostringstream ptr_os; @@ -840,7 +840,8 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) { std::string value = this->PrintExpr(op->value); std::string ref = this->GetBufferRef(value_dtype, op->buffer.get(), index_expr); this->PrintIndent(); - stream << ref << " = " << value << ";\n"; + stream << ref << " = "; + stream << value << ";\n"; } else { arith::PVar<PrimExpr> base; @@ -876,7 +877,16 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) { stream << '['; PrintVecElemLoad(index, index_expr.dtype(), i, stream); stream << "] = "; - PrintVecElemLoad(value, op->value.dtype(), i, stream); + if (op->value.dtype().is_float8()) { + ICHECK(value_dtype.lanes() == 2); + std::string fp8_type = op->value.dtype().is_float8_e5m2() ? "e5m2" : "e4m3"; + static const char access[] = {'x', 'y'}; + stream << "__nv_fp8_" << fp8_type << "(__half2("; + PrintVecElemLoad(value, op->value.dtype(), i, stream); + stream << ")."
<< access[i % 2] << ")"; + } else { + PrintVecElemLoad(value, op->value.dtype(), i, stream); + } stream << ";\n"; } EndScope(vec_scope); diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index a97e66d3467c..370ae6893a80 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -516,7 +516,9 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) { - os << "make_"; + if (!t.is_float8()) { + os << "make_"; // There is no make___nv_fp8 (/usr/local/cuda/include/vector_functions.hpp) + } PrintType(t, os); } @@ -533,22 +535,56 @@ void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr l std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); - for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { + if (t.is_float8()) { std::ostringstream value_temp; - if (isalpha(op[0])) { - value_temp << op << "("; - PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); - value_temp << ", "; - PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); - value_temp << ")"; + ICHECK(t.is_float8_e4m3fn() || t.is_float8_e5m2()); + if (t.lanes() == 2) { + value_temp << "__nv_fp8x2_" << (t.is_float8_e5m2() ? "e5m2" : "e4m3") << "("; } else { - value_temp << "("; - PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); - value_temp << op; - PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); - value_temp << ")"; + value_temp << "__nv_fp8x4_" << (t.is_float8_e5m2() ? "e5m2" : "e4m3") << "("; + } + for (int i = 0, lanes = t.lanes() / 2; i < lanes; ++i) { + if (isalpha(op[0]) || op[0] == '_') { + value_temp << op << "2" + << "(__half2("; + PrintVecElemLoad(vlhs, lhs.dtype(), i * lanes, value_temp); + value_temp << "), __half2("; + PrintVecElemLoad(vrhs, rhs.dtype(), i * lanes, value_temp); + value_temp << "))"; + } else { + value_temp << "__half2("; + PrintVecElemLoad(vlhs, lhs.dtype(), i * lanes, value_temp); + value_temp << ") " << op << " __half2("; + PrintVecElemLoad(vrhs, rhs.dtype(), i * lanes, value_temp); + value_temp << ")"; + } + + if (i != lanes - 1) { + value_temp << ", "; + } + if (i == lanes - 1) { + value_temp << ")"; + PrintVecElemStore(sret, t, i, value_temp.str()); + } + } + } else { + for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { + std::ostringstream value_temp; + if (isalpha(op[0])) { + value_temp << op << "("; + PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + value_temp << ", "; + PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } else { + value_temp << "("; + PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); + value_temp << op; + PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); + value_temp << ")"; + } + PrintVecElemStore(sret, t, i, value_temp.str()); } - PrintVecElemStore(sret, t, i, value_temp.str()); } } EndScope(ssa_scope); @@ -563,6 +599,7 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, } static const char access[] = {'x', 'y', 'z', 'w'}; + std::string fp8_type = (t.is_float8()) ? (t.is_float8_e4m3fn() ? "e4m3" : "e5m2") : ""; ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { std::string type_name = t.is_int() ? "char" : "unsigned char"; @@ -584,6 +621,9 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, } else { os << "((nv_bfloat162*)(&(" << vec << "." 
<< access[i / 2] << ")))->" << access[i % 2]; } + } else if (t.is_float8()) { + os << "__nv_cvt_fp8x2_to_halfraw2(" << vec << ".__x," + << (t.is_float8_e5m2() ? "__NV_E5M2" : "__NV_E4M3") << ")"; } else if (t.lanes() > 4 && t.lanes() <= 8) { std::string type_name; if (t.bits() == 16) { @@ -644,6 +684,10 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " << value << ";\n"; } + } else if (t.is_float8()) { + // Since fp8 is a packed type (2 or 4 lanes), we only want to call this at the end. + ICHECK(i == (t.lanes() / 2) - 1); + stream << vec << " = " << value << ";\n"; } else if (t.lanes() > 4 && t.lanes() <= 8) { std::string type_name; if (t.bits() == 16) { @@ -830,7 +874,23 @@ void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args, bool skip_first_arg, std::ostream& os) { } os << sret; } else { - CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os); + if (ret_dtype.is_float8()) { + std::string fp8_type = (ret_dtype.is_float8_e5m2() ? "__NV_E5M2" : "__NV_E4M3"); + os << "__nv_fp8_" << (ret_dtype.is_float8_e5m2() ? "e5m2" : "e4m3") << "("; + + os << global_symbol << "(__half(__nv_cvt_fp8_to_halfraw("; + for (size_t i = static_cast<size_t>(skip_first_arg); i < args.size(); ++i) { + this->PrintExpr(args[i], os); + os << ".__x, " << fp8_type << "))"; + if (i < args.size() - 1) { + os << ", " + << "__half(__nv_cvt_fp8_to_halfraw("; + } + } + os << "))"; + } else { + CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os); + } } } @@ -1691,6 +1751,19 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) { return; } + if (t.is_float8()) { + if (i == 0) { + PrintVecConstructor(t, os); + os << "(make_float" << t.lanes() << "("; + } + if (i != 0) os << ", "; + os << "static_cast<float>(" << value << ")"; + if (i == t.lanes() - 1) { + os << "))"; + } + return; + } + if (i == 0) { PrintVecConstructor(t, os); os << "("; @@ -1704,5 +1777,115 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) { return; } +template <typename T> +inline void PrintBinaryExpr(const T* op, const char* opstr, + std::ostream& os, // NOLINT(*) + CodeGenCUDA* p) { + if (op->dtype.lanes() == 1) { + if (op->dtype.is_float8()) { + std::string fp8_type = (op->dtype.is_float8_e5m2() ? "__NV_E5M2" : "__NV_E4M3"); + if (isalpha(opstr[0]) || opstr[0] == '_') { + os << "__nv_fp8_" << (op->dtype.is_float8_e5m2() ? "e5m2" : "e4m3") << "("; + os << opstr << "("; + os << "__half(__nv_cvt_fp8_to_halfraw("; + p->PrintExpr(op->a, os); + os << ".__x, " << fp8_type << ")), __half(__nv_cvt_fp8_to_halfraw("; + p->PrintExpr(op->b, os); + os << ".__x, " << fp8_type << ")))"; + os << ")"; + } else { + os << "__nv_fp8_" << (op->dtype.is_float8_e5m2() ?
"e5m2" : "e4m3") << "("; + os << "__half(__nv_cvt_fp8_to_halfraw("; + p->PrintExpr(op->a, os); + os << ".__x, " << fp8_type << ")) " << opstr << " __half(__nv_cvt_fp8_to_halfraw("; + p->PrintExpr(op->b, os); + os << ".__x, " << fp8_type << ")))"; + } + } else { + if (isalpha(opstr[0])) { + os << opstr << '('; + p->PrintExpr(op->a, os); + os << ", "; + p->PrintExpr(op->b, os); + os << ')'; + } else { + os << '('; + p->PrintExpr(op->a, os); + os << ' ' << opstr << ' '; + p->PrintExpr(op->b, os); + os << ')'; + } + } + + } else { + p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os); + } +} + +void CodeGenCUDA::VisitExpr_(const AddNode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "+", os, this); +} +void CodeGenCUDA::VisitExpr_(const SubNode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "-", os, this); +} +void CodeGenCUDA::VisitExpr_(const MulNode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "*", os, this); +} +void CodeGenCUDA::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "/", os, this); +} +void CodeGenCUDA::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*) + if (op->dtype.is_int() || op->dtype.is_uint()) { + PrintBinaryExpr(op, "%", os, this); + } else { + ICHECK(op->dtype.is_float()) << "Expected floating point or integer dtype in Mod, but got " + << op->dtype; + if (op->dtype.bits() == 32) { + PrintBinaryExpr(op, "fmodf", os, this); + } else if (op->dtype.bits() == 64) { + PrintBinaryExpr(op, "fmod", os, this); + } else { + ICHECK(false) + << "Non single or double precision floating point in Mod, expected 32 or 64 bits but got " + << op->dtype.bits() << " bits."; + } + } +} + +void CodeGenCUDA::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, op->dtype.is_float8() ? "__hmin" : "min", os, this); +} +void CodeGenCUDA::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, op->dtype.is_float8() ? 
"__hmax" : "max", os, this); +} +void CodeGenCUDA::VisitExpr_(const EQNode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "==", os, this); +} +void CodeGenCUDA::VisitExpr_(const NENode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "!=", os, this); +} +void CodeGenCUDA::VisitExpr_(const LTNode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "<", os, this); +} +void CodeGenCUDA::VisitExpr_(const LENode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "<=", os, this); +} +void CodeGenCUDA::VisitExpr_(const GTNode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, ">", os, this); +} +void CodeGenCUDA::VisitExpr_(const GENode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, ">=", os, this); +} +void CodeGenCUDA::VisitExpr_(const AndNode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "&&", os, this); +} +void CodeGenCUDA::VisitExpr_(const OrNode* op, std::ostream& os) { // NOLINT(*) + PrintBinaryExpr(op, "||", os, this); +} +void CodeGenCUDA::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*) + os << '!'; + PrintExpr(op->a, os); +} + } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index ed5709ac12be..802111441420 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -68,6 +68,22 @@ class CodeGenCUDA final : public CodeGenC { void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; void VisitExpr_(const CallNode* op, std::ostream& os) final; void VisitExpr_(const CastNode* op, std::ostream& os) final; + void VisitExpr_(const AddNode* op, std::ostream& os) final; + void VisitExpr_(const SubNode* op, std::ostream& os) final; + void VisitExpr_(const MulNode* op, std::ostream& os) final; + void VisitExpr_(const DivNode* op, std::ostream& os) final; + void VisitExpr_(const ModNode* op, std::ostream& os) final; + void VisitExpr_(const MinNode* op, std::ostream& os) final; + void VisitExpr_(const MaxNode* op, std::ostream& os) final; + void VisitExpr_(const EQNode* op, std::ostream& os) final; + void VisitExpr_(const NENode* op, std::ostream& os) final; + void VisitExpr_(const LTNode* op, std::ostream& os) final; + void VisitExpr_(const LENode* op, std::ostream& os) final; + void VisitExpr_(const GTNode* op, std::ostream& os) final; + void VisitExpr_(const GENode* op, std::ostream& os) final; + void VisitExpr_(const AndNode* op, std::ostream& os) final; + void VisitExpr_(const OrNode* op, std::ostream& os) final; + void VisitExpr_(const NotNode* op, std::ostream& os) final; void VisitStmt_(const EvaluateNode* op) final; void VisitStmt_(const AllocateNode* op) final; void VisitStmt_(const AttrStmtNode* op) final; diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index e762bde69f4d..0567c55958c0 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -52,7 +52,7 @@ struct CUDAMath { default: return ""; } - } else if (t.is_bfloat16()) { + } else if (t.is_bfloat16() || t.is_float8()) { if (name == "fabs") { return "__habs"; } else if (name == "round") { diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8_operators.py b/tests/python/codegen/test_target_codegen_cuda_fp8_operators.py new file mode 100644 index 000000000000..0c6a1126ca4f --- /dev/null +++ b/tests/python/codegen/test_target_codegen_cuda_fp8_operators.py @@ -0,0 +1,420 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more 
contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import pytest + +import tvm +import tvm.testing +from tvm import relax +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T +from tvm import dlight + + +@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.parametrize("original_dtype", ["float16", "float32"]) +@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2", "float16"]) +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_fp8_matmul_compile(dtype, original_dtype, batch_size): + bb = relax.BlockBuilder() + batch = T.int64(batch_size) + x = relax.Var("x", R.Tensor((batch, 784), original_dtype)) + weight = relax.const(np.random.randn(784, 128), original_dtype) + + with bb.function("forward", [x]): + with bb.dataflow(): + lv1 = bb.emit(relax.op.astype(x, dtype)) + lv2 = bb.emit(relax.op.astype(weight, dtype)) + lv3 = bb.emit(relax.op.matmul(lv1, lv2, dtype)) + lv4 = bb.emit(relax.op.astype(lv3, original_dtype)) + gv = bb.emit_output(lv4) + bb.emit_func_output(gv) + + mod = bb.get() + mod.show() + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + + with target: + mod = relax.get_pipeline("zero")(mod) + mod = dlight.ApplyDefaultSchedule( # pylint: disable=not-callable + dlight.gpu.Matmul(), + dlight.gpu.GEMV(), + dlight.gpu.Reduction(), + dlight.gpu.GeneralReduction(), + dlight.gpu.Fallback(), + )(mod) + + _exe = relax.build(mod, target) + + +@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.parametrize("original_dtype", ["float16", "float32"]) +@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"]) +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_fp8_conv2d_compile(dtype, original_dtype, batch_size): + bb = relax.BlockBuilder() + batch = batch_size + x = relax.Var("x", R.Tensor((batch, 1, 28, 28), original_dtype)) + weight = relax.const(np.random.randn(32, 1, 3, 3), original_dtype) + + with bb.function("forward", [x]): + with bb.dataflow(): + lv1 = bb.emit(relax.op.astype(x, dtype)) + lv2 = bb.emit(relax.op.astype(weight, dtype)) + lv3 = bb.emit( + relax.op.nn.conv2d( + lv1, + lv2, + strides=[1, 1], + padding=[1, 1, 1, 1], + dilation=[1, 1], + groups=1, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + ) + ) + lv4 = bb.emit(relax.op.astype(lv3, original_dtype)) + gv = bb.emit_output(lv4) + bb.emit_func_output(gv) + + mod = bb.get() + mod.show() + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + + with target: + mod = relax.get_pipeline("zero")(mod) + mod = dlight.ApplyDefaultSchedule( # pylint: disable=not-callable + dlight.gpu.Matmul(), + dlight.gpu.GEMV(), + dlight.gpu.Reduction(), + dlight.gpu.GeneralReduction(), + dlight.gpu.Fallback(), + )(mod) + + _exe = relax.build(mod, target) + + 
+@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.parametrize("original_dtype", ["float16", "float32"]) +@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"]) +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_fp8_maxpool2d_compile(dtype, original_dtype, batch_size): + bb = relax.BlockBuilder() + batch = batch_size + x = relax.Var("x", R.Tensor((batch, 1, 28, 28), original_dtype)) + + with bb.function("forward", [x]): + with bb.dataflow(): + lv1 = bb.emit(relax.op.astype(x, dtype)) + lv3 = bb.emit( + relax.op.nn.max_pool2d( + lv1, + pool_size=[3, 3], + strides=[2, 2], + dilation=[1, 1], + padding=[1, 1, 1, 1], + ceil_mode=False, + count_include_pad=False, + layout="NCHW", + out_layout="NCHW", + ) + ) + lv4 = bb.emit(relax.op.astype(lv3, original_dtype)) + gv = bb.emit_output(lv4) + bb.emit_func_output(gv) + + mod = bb.get() + mod.show() + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + + with target: + mod = relax.get_pipeline("zero")(mod) + mod = dlight.ApplyDefaultSchedule( # pylint: disable=not-callable + dlight.gpu.Matmul(), + dlight.gpu.GEMV(), + dlight.gpu.Reduction(), + dlight.gpu.GeneralReduction(), + dlight.gpu.Fallback(), + )(mod) + + _exe = relax.build(mod, target) + + +@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.parametrize("original_dtype", ["float16", "float32"]) +@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"]) +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_fp8_add_compile(dtype, original_dtype, batch_size): + bb = relax.BlockBuilder() + batch = batch_size + x = relax.Var("x", R.Tensor((batch, 784), original_dtype)) + bias = relax.const(np.random.randn(784), original_dtype) + + with bb.function("forward", [x]): + with bb.dataflow(): + lv1 = bb.emit(relax.op.astype(x, dtype)) + lv2 = bb.emit(relax.op.astype(bias, dtype)) + lv3 = bb.emit(relax.op.add(lv1, lv2)) + lv4 = bb.emit(relax.op.astype(lv3, original_dtype)) + gv = bb.emit_output(lv4) + bb.emit_func_output(gv) + + mod = bb.get() + mod.show() + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + + with target: + mod = relax.get_pipeline("zero")(mod) + mod = dlight.ApplyDefaultSchedule( # pylint: disable=not-callable + dlight.gpu.Matmul(), + dlight.gpu.GEMV(), + dlight.gpu.Reduction(), + dlight.gpu.GeneralReduction(), + dlight.gpu.Fallback(), + )(mod) + + _exe = relax.build(mod, target) + + +@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.parametrize("original_dtype", ["float16", "float32"]) +@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"]) +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_fp8_relu_compile(dtype, original_dtype, batch_size): + bb = relax.BlockBuilder() + batch = batch_size + x = relax.Var("x", R.Tensor((batch, 784), original_dtype)) + + with bb.function("forward", [x]): + with bb.dataflow(): + lv1 = bb.emit(relax.op.astype(x, dtype)) + lv2 = bb.emit(relax.op.nn.relu(lv1)) + lv3 = bb.emit(relax.op.astype(lv2, original_dtype)) + gv = bb.emit_output(lv3) + bb.emit_func_output(gv) + + mod = bb.get() + mod.show() + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + + with target: + mod = relax.get_pipeline("zero")(mod) + mod = dlight.ApplyDefaultSchedule( # pylint: disable=not-callable + dlight.gpu.Matmul(), + dlight.gpu.GEMV(), + dlight.gpu.Reduction(), + dlight.gpu.GeneralReduction(), + dlight.gpu.Fallback(), + )(mod) + + _exe = relax.build(mod, target) + + 
+@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.parametrize("original_dtype", ["float16", "float32"]) +@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"]) +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_fp8_gelu_compile(dtype, original_dtype, batch_size): + bb = relax.BlockBuilder() + batch = batch_size + x = relax.Var("x", R.Tensor((batch, 784), original_dtype)) + + with bb.function("forward", [x]): + with bb.dataflow(): + lv1 = bb.emit(relax.op.astype(x, dtype)) + lv2 = bb.emit(relax.op.nn.gelu(lv1)) + lv3 = bb.emit(relax.op.astype(lv2, original_dtype)) + gv = bb.emit_output(lv3) + bb.emit_func_output(gv) + + mod = bb.get() + mod.show() + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + + with target: + mod = relax.get_pipeline("zero")(mod) + mod = dlight.ApplyDefaultSchedule( # pylint: disable=not-callable + dlight.gpu.Matmul(), + dlight.gpu.GEMV(), + dlight.gpu.Reduction(), + dlight.gpu.GeneralReduction(), + dlight.gpu.Fallback(), + )(mod) + + _exe = relax.build(mod, target) + + +@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.parametrize("original_dtype", ["float16", "float32"]) +@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"]) +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_fp8_gelu_tanh_compile(dtype, original_dtype, batch_size): + bb = relax.BlockBuilder() + batch = batch_size + x = relax.Var("x", R.Tensor((batch, 784), original_dtype)) + + with bb.function("forward", [x]): + with bb.dataflow(): + lv1 = bb.emit(relax.op.astype(x, dtype)) + lv2 = bb.emit(relax.op.nn.gelu_tanh(lv1)) + lv3 = bb.emit(relax.op.astype(lv2, original_dtype)) + gv = bb.emit_output(lv3) + bb.emit_func_output(gv) + + mod = bb.get() + mod.show() + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + + with target: + mod = relax.get_pipeline("zero")(mod) + mod = dlight.ApplyDefaultSchedule( # pylint: disable=not-callable + dlight.gpu.Matmul(), + dlight.gpu.GEMV(), + dlight.gpu.Reduction(), + dlight.gpu.GeneralReduction(), + dlight.gpu.Fallback(), + )(mod) + + _exe = relax.build(mod, target) + + +@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.parametrize("original_dtype", ["float16", "float32"]) +@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"]) +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_fp8_sigmoid_compile(dtype, original_dtype, batch_size): + bb = relax.BlockBuilder() + batch = batch_size + x = relax.Var("x", R.Tensor((batch, 784), original_dtype)) + + with bb.function("forward", [x]): + with bb.dataflow(): + lv1 = bb.emit(relax.op.astype(x, dtype)) + lv2 = bb.emit(relax.op.nn.sigmoid(lv1)) + lv3 = bb.emit(relax.op.astype(lv2, original_dtype)) + gv = bb.emit_output(lv3) + bb.emit_func_output(gv) + + mod = bb.get() + mod.show() + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + + with target: + mod = relax.get_pipeline("zero")(mod) + mod = dlight.ApplyDefaultSchedule( # pylint: disable=not-callable + dlight.gpu.Matmul(), + dlight.gpu.GEMV(), + dlight.gpu.Reduction(), + dlight.gpu.GeneralReduction(), + dlight.gpu.Fallback(), + )(mod) + + _exe = relax.build(mod, target) + + +@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.parametrize("original_dtype", ["float16", "float32"]) +@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"]) +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_fp8_silu_compile(dtype, original_dtype, batch_size): + bb = relax.BlockBuilder() + batch =
batch_size + x = relax.Var("x", R.Tensor((batch, 784), original_dtype)) + + with bb.function("forward", [x]): + with bb.dataflow(): + lv1 = bb.emit(relax.op.astype(x, dtype)) + lv2 = bb.emit(relax.op.nn.silu(lv1)) + lv3 = bb.emit(relax.op.astype(lv2, original_dtype)) + gv = bb.emit_output(lv3) + bb.emit_func_output(gv) + + mod = bb.get() + mod.show() + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + + with target: + mod = relax.get_pipeline("zero")(mod) + mod = dlight.ApplyDefaultSchedule( # pylint: disable=not-callable + dlight.gpu.Matmul(), + dlight.gpu.GEMV(), + dlight.gpu.Reduction(), + dlight.gpu.GeneralReduction(), + dlight.gpu.Fallback(), + )(mod) + + _exe = relax.build(mod, target) + + +@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.parametrize("original_dtype", ["float16", "float32"]) +@pytest.mark.parametrize("dtype", ["float8_e4m3fn", "float8_e5m2"]) +@pytest.mark.parametrize("batch_size", [1, 64]) +def test_fp8_softmax_compile(dtype, original_dtype, batch_size): + bb = relax.BlockBuilder() + batch = batch_size + x = relax.Var("x", R.Tensor((batch, 784), original_dtype)) + + with bb.function("forward", [x]): + with bb.dataflow(): + lv1 = bb.emit(relax.op.astype(x, dtype)) + lv2 = bb.emit(relax.op.nn.softmax(lv1)) + lv3 = bb.emit(relax.op.astype(lv2, original_dtype)) + gv = bb.emit_output(lv3) + bb.emit_func_output(gv) + + mod = bb.get() + mod.show() + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + + with target: + mod = relax.get_pipeline("zero")(mod) + mod = dlight.ApplyDefaultSchedule( # pylint: disable=not-callable + dlight.gpu.Matmul(), + dlight.gpu.GEMV(), + dlight.gpu.Reduction(), + dlight.gpu.GeneralReduction(), + dlight.gpu.Fallback(), + )(mod) + + _exe = relax.build(mod, target) + + +if __name__ == "__main__": + tvm.testing.main()
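
For reviewers, a minimal hand-written sketch of the device code the new scalar fp8 path in PrintBinaryExpr is meant to produce for an e4m3 add. This is illustrative only and not emitted verbatim by the patch; the kernel name, parameter names, and launch shape are placeholders, and it assumes the cuda_fp8.h/cuda_fp16.h headers from CUDA 11.8+ on an sm_89 (or newer) target:

// Illustrative sketch, not generated output: shows the unpack-to-half,
// compute-in-half, repack-to-fp8 expression shape used by PrintBinaryExpr
// for lanes() == 1. Assumes CUDA 11.8+ headers and an sm_89+ target.
#include <cuda_fp16.h>
#include <cuda_fp8.h>

__global__ void add_e4m3(__nv_fp8_e4m3* out, const __nv_fp8_e4m3* a,
                         const __nv_fp8_e4m3* b, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Each fp8 operand is widened to __half via __nv_cvt_fp8_to_halfraw, the
    // add happens in half precision, and the result is rewrapped in the
    // __nv_fp8_e4m3 constructor, matching the expression the codegen prints.
    out[i] = __nv_fp8_e4m3(__half(__nv_cvt_fp8_to_halfraw(a[i].__x, __NV_E4M3)) +
                           __half(__nv_cvt_fp8_to_halfraw(b[i].__x, __NV_E4M3)));
  }
}

This mirrors the intrin_rule_cuda.cc change: fp8 math is routed through the half-precision intrinsics (__habs, __hmin, __hmax, and the h-prefixed math functions) rather than native fp8 arithmetic, since cuda_fp8.h exposes fp8 primarily as a storage and conversion type.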