From 0b93b3e0c61eac89c596fd94e7fb8be188f07f41 Mon Sep 17 00:00:00 2001
From: "Cui, Yifeng" <...>
Date: Thu, 9 Oct 2025 01:26:51 -0700
Subject: [PATCH] Enable native float8 dtypes for add/sub/mul/div

---
 .../native/xpu/sycl/BinaryDivFloorKernel.cpp  | 32 ++++++++++++----
 .../native/xpu/sycl/BinaryDivTrueKernel.cpp   | 37 ++++++++++++++-----
 .../native/xpu/sycl/BinaryDivTruncKernel.cpp  | 32 ++++++++++++----
 src/ATen/native/xpu/sycl/BinaryKernels.cpp    | 37 ++++++++++++++-----
 4 files changed, 106 insertions(+), 32 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/BinaryDivFloorKernel.cpp b/src/ATen/native/xpu/sycl/BinaryDivFloorKernel.cpp
index f5067bb2f0..7cdaff09e4 100644
--- a/src/ATen/native/xpu/sycl/BinaryDivFloorKernel.cpp
+++ b/src/ATen/native/xpu/sycl/BinaryDivFloorKernel.cpp
@@ -1,5 +1,5 @@
 #include <...>
-#include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <...>
 #include <...>
 #include <...>
@@ -74,23 +74,41 @@ void div_floor_kernel(TensorIteratorBase& iter) {
     // optimization for floating-point types: if the second operand is a CPU
     // scalar, compute a * reciprocal(b). Note that this may lose one bit of
     // precision compared to computing the division.
-    AT_DISPATCH_FLOATING_TYPES_AND2(
-        kHalf, kBFloat16, dtype, "div_floor_xpu", [&]() {
+    AT_DISPATCH_V2(
+        dtype,
+        "div_floor_xpu",
+        AT_WRAP([&]() {
           using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
           auto b = iter.scalar_value<accscalar_t>(2);
           if (C10_UNLIKELY(b == 0)) {
             return div_true_kernel(iter);
           }
           auto inv_b = accscalar_t(1.0) / b;
           iter.remove_operand(2);
           gpu_kernel(
               iter, DivFloorWithScalarFunctor<scalar_t, accscalar_t>(b, inv_b));
-        });
+        }),
+        AT_EXPAND(AT_FLOATING_TYPES),
+        kHalf,
+        kBFloat16,
+        kFloat8_e5m2,
+        kFloat8_e4m3fn,
+        kFloat8_e5m2fnuz,
+        kFloat8_e4m3fnuz);
   } else {
-    AT_DISPATCH_FLOATING_TYPES_AND2(
-        kHalf, kBFloat16, dtype, "div_floor_xpu", [&]() {
+    AT_DISPATCH_V2(
+        dtype,
+        "div_floor_xpu",
+        AT_WRAP([&]() {
           gpu_kernel_with_scalars(iter, DivFloorFloatFunctor<scalar_t>());
-        });
+        }),
+        AT_EXPAND(AT_FLOATING_TYPES),
+        kHalf,
+        kBFloat16,
+        kFloat8_e5m2,
+        kFloat8_e4m3fn,
+        kFloat8_e5m2fnuz,
+        kFloat8_e4m3fnuz);
   }
 }
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/BinaryDivTrueKernel.cpp b/src/ATen/native/xpu/sycl/BinaryDivTrueKernel.cpp
index 6f35c0cb62..9491a8ac2d 100644
--- a/src/ATen/native/xpu/sycl/BinaryDivTrueKernel.cpp
+++ b/src/ATen/native/xpu/sycl/BinaryDivTrueKernel.cpp
@@ -1,11 +1,10 @@
-#include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <...>
 #include <...>
 #include <...>
 #include <...>
-#include <...>
-
 #include <...>
+#include <...>
 
 namespace at::native::xpu {
 
@@ -21,22 +20,42 @@ void div_true_kernel(TensorIteratorBase& iter) {
     // optimization for floating-point types: if the second operand is a CPU
     // scalar, compute a * reciprocal(b). Note that this may lose one bit of
     // precision compared to computing the division.
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
-        kHalf, kBFloat16, common_dtype, "div_true_xpu", [&]() {
+    AT_DISPATCH_V2(
+        common_dtype,
+        "div_true_xpu",
+        AT_WRAP([&]() {
           using opmath_t = at::opmath_type<scalar_t>;
           auto inv_b = opmath_t(1.0) / iter.scalar_value<opmath_t>(2);
           iter.remove_operand(2);
           gpu_kernel(
               iter,
               BUnaryFunctor<scalar_t, scalar_t, scalar_t, MulFunctor<opmath_t>>(
                   MulFunctor<opmath_t>(), inv_b));
-        });
+        }),
+        AT_EXPAND(AT_COMPLEX_TYPES),
+        AT_EXPAND(AT_FLOATING_TYPES),
+        kHalf,
+        kBFloat16,
+        kFloat8_e5m2,
+        kFloat8_e4m3fn,
+        kFloat8_e5m2fnuz,
+        kFloat8_e4m3fnuz);
   } else {
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
-        kHalf, kBFloat16, common_dtype, "div_true_xpu", [&]() {
+    AT_DISPATCH_V2(
+        common_dtype,
+        "div_true_xpu",
+        AT_WRAP([&]() {
           DivFunctor<opmath_type<scalar_t>> f;
           gpu_kernel_with_scalars(iter, f);
-        });
+        }),
+        AT_EXPAND(AT_COMPLEX_TYPES),
+        AT_EXPAND(AT_FLOATING_TYPES),
+        kHalf,
+        kBFloat16,
+        kFloat8_e5m2,
+        kFloat8_e4m3fn,
+        kFloat8_e5m2fnuz,
+        kFloat8_e4m3fnuz);
   }
 }
 
diff --git a/src/ATen/native/xpu/sycl/BinaryDivTruncKernel.cpp b/src/ATen/native/xpu/sycl/BinaryDivTruncKernel.cpp
index 4c2e8a11cd..6f1348e409 100644
--- a/src/ATen/native/xpu/sycl/BinaryDivTruncKernel.cpp
+++ b/src/ATen/native/xpu/sycl/BinaryDivTruncKernel.cpp
@@ -1,5 +1,5 @@
 #include <...>
-#include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <...>
 #include <...>
 #include <...>
@@ -40,18 +40,36 @@ void div_trunc_kernel(TensorIteratorBase& iter) {
     // optimization for floating-point types: if the second operand is a CPU
     // scalar, compute a * reciprocal(b). Note that this may lose one bit of
     // precision compared to computing the division.
-    AT_DISPATCH_FLOATING_TYPES_AND2(
-        kHalf, kBFloat16, dtype, "div_trunc_xpu", [&]() {
+    AT_DISPATCH_V2(
+        dtype,
+        "div_trunc_xpu",
+        AT_WRAP([&]() {
           using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
           auto inv_b = accscalar_t(1.0) / iter.scalar_value<accscalar_t>(2);
           iter.remove_operand(2);
           gpu_kernel(iter, DivTruncScalarFunctor<scalar_t, accscalar_t>(inv_b));
-        });
+        }),
+        AT_EXPAND(AT_FLOATING_TYPES),
+        kHalf,
+        kBFloat16,
+        kFloat8_e5m2,
+        kFloat8_e4m3fn,
+        kFloat8_e5m2fnuz,
+        kFloat8_e4m3fnuz);
   } else {
-    AT_DISPATCH_FLOATING_TYPES_AND2(
-        kHalf, kBFloat16, dtype, "div_trunc_xpu", [&]() {
+    AT_DISPATCH_V2(
+        dtype,
+        "div_trunc_xpu",
+        AT_WRAP([&]() {
           gpu_kernel_with_scalars(iter, DivTruncFunctor<scalar_t>());
-        });
+        }),
+        AT_EXPAND(AT_FLOATING_TYPES),
+        kHalf,
+        kBFloat16,
+        kFloat8_e5m2,
+        kFloat8_e4m3fn,
+        kFloat8_e5m2fnuz,
+        kFloat8_e4m3fnuz);
   }
 }
 
diff --git a/src/ATen/native/xpu/sycl/BinaryKernels.cpp b/src/ATen/native/xpu/sycl/BinaryKernels.cpp
index daafadd231..3df15f4288 100644
--- a/src/ATen/native/xpu/sycl/BinaryKernels.cpp
+++ b/src/ATen/native/xpu/sycl/BinaryKernels.cpp
@@ -1,11 +1,10 @@
-#include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <...>
 #include <...>
 #include <...>
 #include <...>
-#include <...>
-
 #include <...>
+#include <...>
 
 namespace at::native::xpu {
 
@@ -28,12 +27,22 @@ void add_kernel(TensorIteratorBase& iter, const c10::Scalar& alpha) {
     opmath_gpu_kernel_with_scalars<scalar_t>(
         iter, AddFunctor<opmath_t>(alpha.to<opmath_t>()));
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-        kHalf, kBFloat16, kBool, iter.common_dtype(), "add_xpu", [&]() {
+    AT_DISPATCH_V2(
+        common_dtype,
+        "add_xpu",
+        AT_WRAP([&]() {
           using opmath_t = opmath_type<scalar_t>;
           opmath_gpu_kernel_with_scalars<scalar_t>(
               iter, AddFunctor<opmath_t>(alpha.to<opmath_t>()));
-        });
+        }),
+        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+        kBool,
+        kHalf,
+        kBFloat16,
+        kFloat8_e5m2,
+        kFloat8_e4m3fn,
+        kFloat8_e5m2fnuz,
+        kFloat8_e4m3fnuz);
   }
 }
 
@@ -49,12 +58,22 @@ void mul_kernel(TensorIteratorBase& iter) {
     opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(
         iter, MulFunctor<opmath_t>());
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-        kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_xpu", [&]() {
iter.common_dtype(), "mul_xpu", [&]() { + AT_DISPATCH_V2( + common_dtype, + "mul_xpu", + AT_WRAP([&]() { using opmath_t = opmath_type; opmath_symmetric_gpu_kernel_with_scalars( iter, MulFunctor()); - }); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + kBool, + kHalf, + kBFloat16, + kFloat8_e5m2, + kFloat8_e4m3fn, + kFloat8_e5m2fnuz, + kFloat8_e4m3fnuz); } }