diff --git a/SYCL/BFloat16/bfloat16_builtins.cpp b/SYCL/BFloat16/bfloat16_builtins.cpp new file mode 100644 index 0000000000..ff84ecbeb3 --- /dev/null +++ b/SYCL/BFloat16/bfloat16_builtins.cpp @@ -0,0 +1,246 @@ +// REQUIRES: cuda +// +// Currently this test fails to compile for backends other than cuda. +// Other backends could use this test when bfloat16 math function support is +// added. +// +// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out -Xsycl-target-backend --cuda-gpu-arch=sm_80 +// RUN: %t.out + +#include + +#include +#include + +using namespace cl::sycl; +using sycl::ext::oneapi::experimental::bfloat16; + +constexpr int N = 60; // divisible by all tested array sizes +constexpr float bf16_eps = 0.00390625; + +float make_fp32(uint16_t x) { + uint32_t y = x; + y = y << 16; + auto res = reinterpret_cast(&y); + return *res; +} + +bool check(float a, float b) { + return fabs(2 * (a - b) / (a + b)) > bf16_eps * 2; +} + +#define TEST_BUILTIN_1_SCAL_IMPL(NAME) \ + { \ + buffer a_buf(&a[0], N); \ + buffer err_buf(&err, 1); \ + q.submit([&](handler &cgh) { \ + accessor A(a_buf, \ + cgh); \ + accessor ERR(err_buf, cgh); \ + cgh.parallel_for(N, [=](id<1> index) { \ + if (check(NAME(bfloat16{A[index]}), NAME(A[index]))) { \ + ERR[0] = 1; \ + } \ + }); \ + }); \ + } \ + assert(err == 0); + +#define TEST_BUILTIN_1_ARR_IMPL(NAME, SZ) \ + { \ + buffer a_buf{range<2>{N / SZ, SZ}}; \ + buffer err_buf(&err, 1); \ + q.submit([&](handler &cgh) { \ + accessor A(a_buf, \ + cgh); \ + accessor ERR(err_buf, cgh); \ + cgh.parallel_for(N / SZ, [=](id<1> index) { \ + marray arg; \ + for (int i = 0; i < SZ; i++) { \ + arg[i] = A[index][i]; \ + } \ + marray res = NAME(arg); \ + for (int i = 0; i < SZ; i++) { \ + if (check(res[i], NAME(A[index][i]))) { \ + ERR[0] = 1; \ + } \ + } \ + }); \ + }); \ + } \ + assert(err == 0); + +#define TEST_BUILTIN_1(NAME) \ + TEST_BUILTIN_1_SCAL_IMPL(NAME) \ + TEST_BUILTIN_1_ARR_IMPL(NAME, 1) \ + TEST_BUILTIN_1_ARR_IMPL(NAME, 2) \ + TEST_BUILTIN_1_ARR_IMPL(NAME, 3) \ + TEST_BUILTIN_1_ARR_IMPL(NAME, 4) \ + TEST_BUILTIN_1_ARR_IMPL(NAME, 5) + +#define TEST_BUILTIN_2_SCAL_IMPL(NAME) \ + { \ + buffer a_buf(&a[0], N); \ + buffer b_buf(&b[0], N); \ + buffer err_buf(&err, 1); \ + q.submit([&](handler &cgh) { \ + accessor A(a_buf, \ + cgh); \ + accessor B(b_buf, \ + cgh); \ + accessor ERR(err_buf, cgh); \ + cgh.parallel_for(N, [=](id<1> index) { \ + if (check(NAME(bfloat16{A[index]}, bfloat16{B[index]}), \ + NAME(A[index], B[index]))) { \ + ERR[0] = 1; \ + } \ + }); \ + }); \ + } \ + assert(err == 0); + +#define TEST_BUILTIN_2_ARR_IMPL(NAME, SZ) \ + { \ + buffer a_buf{range<2>{N / SZ, SZ}}; \ + buffer b_buf{range<2>{N / SZ, SZ}}; \ + buffer err_buf(&err, 1); \ + q.submit([&](handler &cgh) { \ + accessor A(a_buf, \ + cgh); \ + accessor B(b_buf, \ + cgh); \ + accessor ERR(err_buf, cgh); \ + cgh.parallel_for(N / SZ, [=](id<1> index) { \ + marray arg0, arg1; \ + for (int i = 0; i < SZ; i++) { \ + arg0[i] = A[index][i]; \ + arg1[i] = B[index][i]; \ + } \ + marray res = NAME(arg0, arg1); \ + for (int i = 0; i < SZ; i++) { \ + if (check(res[i], NAME(A[index][i], B[index][i]))) { \ + ERR[0] = 1; \ + } \ + } \ + }); \ + }); \ + } \ + assert(err == 0); + +#define TEST_BUILTIN_2(NAME) \ + TEST_BUILTIN_2_SCAL_IMPL(NAME) \ + TEST_BUILTIN_2_ARR_IMPL(NAME, 1) \ + TEST_BUILTIN_2_ARR_IMPL(NAME, 2) \ + TEST_BUILTIN_2_ARR_IMPL(NAME, 3) \ + TEST_BUILTIN_2_ARR_IMPL(NAME, 4) \ + TEST_BUILTIN_2_ARR_IMPL(NAME, 5) + +#define TEST_BUILTIN_3_SCAL_IMPL(NAME) \ + { \ + buffer a_buf(&a[0], N); \ + 
buffer b_buf(&b[0], N); \ + buffer c_buf(&c[0], N); \ + buffer err_buf(&err, 1); \ + q.submit([&](handler &cgh) { \ + accessor A(a_buf, \ + cgh); \ + accessor B(b_buf, \ + cgh); \ + accessor C(c_buf, \ + cgh); \ + accessor ERR(err_buf, cgh); \ + cgh.parallel_for(N, [=](id<1> index) { \ + if (check(NAME(bfloat16{A[index]}, bfloat16{B[index]}, \ + bfloat16{C[index]}), \ + NAME(A[index], B[index], C[index]))) { \ + ERR[0] = 1; \ + } \ + }); \ + }); \ + } \ + assert(err == 0); + +#define TEST_BUILTIN_3_ARR_IMPL(NAME, SZ) \ + { \ + buffer a_buf{range<2>{N / SZ, SZ}}; \ + buffer b_buf{range<2>{N / SZ, SZ}}; \ + buffer c_buf{range<2>{N / SZ, SZ}}; \ + buffer err_buf(&err, 1); \ + q.submit([&](handler &cgh) { \ + accessor A(a_buf, \ + cgh); \ + accessor B(b_buf, \ + cgh); \ + accessor C(c_buf, \ + cgh); \ + accessor ERR(err_buf, cgh); \ + cgh.parallel_for(N / SZ, [=](id<1> index) { \ + marray arg0, arg1, arg2; \ + for (int i = 0; i < SZ; i++) { \ + arg0[i] = A[index][i]; \ + arg1[i] = B[index][i]; \ + arg2[i] = C[index][i]; \ + } \ + marray res = NAME(arg0, arg1, arg2); \ + for (int i = 0; i < SZ; i++) { \ + if (check(res[i], NAME(A[index][i], B[index][i], C[index][i]))) { \ + ERR[0] = 1; \ + } \ + } \ + }); \ + }); \ + } \ + assert(err == 0); + +#define TEST_BUILTIN_3(NAME) \ + TEST_BUILTIN_3_SCAL_IMPL(NAME) \ + TEST_BUILTIN_3_ARR_IMPL(NAME, 1) \ + TEST_BUILTIN_3_ARR_IMPL(NAME, 2) \ + TEST_BUILTIN_3_ARR_IMPL(NAME, 3) \ + TEST_BUILTIN_3_ARR_IMPL(NAME, 4) \ + TEST_BUILTIN_3_ARR_IMPL(NAME, 5) + +#define TEST_BUILTIN_2_NAN(NAME) \ + { \ + buffer err_buf(&err, 1); \ + buffer nan_buf(&check_nan, 1); \ + q.submit([&](handler &cgh) { \ + accessor ERR(err_buf, cgh); \ + accessor checkNAN( \ + nan_buf, cgh); \ + cgh.single_task([=]() { \ + checkNAN[0] = NAME(bfloat16{NAN}, bfloat16{NAN}); \ + if ((NAME(bfloat16{2}, bfloat16{NAN}) != 2) || \ + (NAME(bfloat16{NAN}, bfloat16{2}) != 2)) { \ + ERR[0] = 1; \ + } \ + }); \ + }); \ + } \ + assert(err == 0); \ + assert(std::isnan(check_nan)); + +int main() { + queue q; + + if (q.get_device().has(aspect::ext_oneapi_bfloat16)) { + std::vector a(N), b(N), c(N); + int err = 0; + + for (int i = 0; i < N; i++) { + a[i] = (i - N / 2) / (float)N; + b[i] = (N / 2 - i) / (float)N; + c[i] = (float)(3 * i); + } + + TEST_BUILTIN_1(fabs); + TEST_BUILTIN_2(fmin); + TEST_BUILTIN_2(fmax); + TEST_BUILTIN_3(fma); + + float check_nan = 0; + TEST_BUILTIN_2_NAN(fmin); + TEST_BUILTIN_2_NAN(fmax); + } + return 0; +} diff --git a/SYCL/Matrix/element_wise_all_ops_cuda.cpp b/SYCL/Matrix/element_wise_all_ops_cuda.cpp new file mode 100644 index 0000000000..69976fa7e4 --- /dev/null +++ b/SYCL/Matrix/element_wise_all_ops_cuda.cpp @@ -0,0 +1,184 @@ +//==----------- element_wise_all_ops_cuda.cpp - DPC++ joint_matrix---------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: cuda + +// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 %s -o %t.out +// RUN: %t.out + +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; +using sycl::ext::oneapi::experimental::bfloat16; + +#define SG_SZ 32 +constexpr size_t nWGperDim = 2; + +class Logical {}; + +template +class KernelName; + +template struct big_matrix { +public: + T *mat; + +public: + T *get_data() { return mat; } + void set_data(T *data) { mat = data; } + big_matrix(T *data) : mat(data) {} +}; + +template +void assert_ops_ref(T *C, const float ref) { + for (size_t i = 0; i < M; i++) + for (size_t j = 0; j < N; j++) { + auto diff = C[i + j * M] - ref; + assert(std::fabs(static_cast(diff)) < + std::numeric_limits::epsilon()); + } +} +template +void matrix_verify_op(queue q, big_matrix &C, + nd_range<2> &r, const float ref, Operation Op) { + { + buffer bufC(C.get_data(), range<2>(N * nWGperDim, M * nWGperDim)); + + q.submit([&](handler &cgh) { + accessor accC(bufC, + cgh); + + cgh.parallel_for>( + r, [accC, + Op](nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { + const auto global_idx = spmd_item.get_global_id(0); + const auto global_idy = spmd_item.get_global_id(1); + const auto sg_startx = global_idx - spmd_item.get_local_id(0); + const auto sg_starty = global_idy - spmd_item.get_local_id(1); + + auto sg = spmd_item.get_sub_group(); + + joint_matrix sub_a; + joint_matrix sub_b; + joint_matrix sub_c; + + joint_matrix_fill(sg, sub_a, 3); + joint_matrix_fill(sg, sub_b, 1); + joint_matrix_fill(sg, sub_c, -80); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + if constexpr (std::is_same_v) { + if (wi_slice_a[i]) { + if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 3.0 || + wi_slice_a[i] < 4.0 || wi_slice_a[i] <= 3.0) { + T val = (wi_slice_a[i] != (2.0)) ? 
wi_slice_a[i] + : static_cast(2.0); + val = ((val) - (1)); + val = ((val) + (1)); + if (wi_slice_a[i] == (2.0)) { + val = ((val) - (2)); + val = ((val) * (3)); + val = ((val) / (2)); + + } else { + val = ((val) + (2)); + } + wi_slice_a[i] = val; + } + } + } else { + wi_slice_a[i] = Op(wi_slice_a[i], 2); + } + } + + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + + joint_matrix_store(sg, sub_c, + accC.get_pointer() + + (sg_startx * M) * (N * nWGperDim) + + sg_starty / SG_SZ * N, + (N * nWGperDim)); + }); // parallel for + }).wait(); + } + assert_ops_ref(C.get_data(), ref); +} + +static constexpr size_t MATRIX_M = 16 * nWGperDim; +static constexpr size_t MATRIX_N = 16 * nWGperDim; + +int main() { + + float D[MATRIX_M][MATRIX_N]; + big_matrix MD_f((float *)&D); + + queue q; + auto computeCapability = + std::stof(q.get_device().get_info()); + nd_range<2> r({nWGperDim, nWGperDim * SG_SZ}, {1, 1 * SG_SZ}); + + if (computeCapability >= 7.0) { + matrix_verify_op(q, MD_f, r, 0.0, + std::plus{}); + matrix_verify_op(q, MD_f, r, 0.0, Logical{}); + matrix_verify_op(q, MD_f, r, 16.0, + std::multiplies{}); + matrix_verify_op(q, MD_f, r, -56.0, + std::divides{}); + matrix_verify_op(q, MD_f, r, -64.0, + std::minus{}); + } + + if (computeCapability >= 7.2) { + int32_t D_i[MATRIX_M][MATRIX_N]; + big_matrix MD_i((int32_t *)&D_i); + matrix_verify_op(q, MD_i, r, 0, + std::plus{}); + matrix_verify_op(q, MD_i, r, 16, + std::multiplies{}); + matrix_verify_op(q, MD_i, r, -64, + std::minus{}); + matrix_verify_op(q, MD_i, r, 0, + std::plus{}); + matrix_verify_op(q, MD_i, r, 0.0, Logical{}); + matrix_verify_op(q, MD_i, r, 16, + std::multiplies{}); + matrix_verify_op(q, MD_i, r, -64, + std::minus{}); + } + + if (computeCapability >= 8.0) { + + matrix_verify_op(q, MD_f, r, 0.0, + std::plus{}); + matrix_verify_op(q, MD_f, r, 0.0, Logical{}); + matrix_verify_op(q, MD_f, r, 16.0, + std::multiplies{}); + matrix_verify_op(q, MD_f, r, -56.0, + std::divides{}); + matrix_verify_op(q, MD_f, r, -64.0, + std::minus{}); + + double D_d[MATRIX_M / 2][MATRIX_N / 2]; + big_matrix MD_d((double *)&D_d); + + matrix_verify_op(q, MD_d, r, -60.0, + std::plus{}); + matrix_verify_op(q, MD_d, r, -60.0, Logical{}); + matrix_verify_op(q, MD_d, r, -56.0, + std::multiplies{}); + matrix_verify_op(q, MD_d, r, -74.0, + std::divides{}); + matrix_verify_op(q, MD_d, r, -76.0, + std::minus{}); + } + + return 0; +} diff --git a/SYCL/Matrix/element_wise_wi_marray.cpp b/SYCL/Matrix/element_wise_wi_marray.cpp new file mode 100644 index 0000000000..6ab3947ed9 --- /dev/null +++ b/SYCL/Matrix/element_wise_wi_marray.cpp @@ -0,0 +1,67 @@ +//==----------- element_wise_wi_marray.cpp - DPC++ joint_matrix------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// REQUIRES: cuda + +// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 %s -o %t.out +// RUN: %t.out + +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; +using sycl::ext::oneapi::experimental::bfloat16; + +#define SG_SZ 32 + +template void verify_wi_marray(queue q) { + int err = 0; + { + buffer err_buf(&err, 1); + q.submit([&](handler &cgh) { + accessor ERR(err_buf, cgh); + + cgh.parallel_for( + nd_range<2>({1, 1 * SG_SZ}, {1, 1 * SG_SZ}), + [ERR](nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { + auto sg = spmd_item.get_sub_group(); + + joint_matrix sub_a; + joint_matrix sub_a_2; + + joint_matrix_fill(sg, sub_a, -1); + joint_matrix_fill(sg, sub_a_2, -1); + + auto wi_slice_a = sub_a.get_wi_data(); + for (int i = 0; i < wi_slice_a.length(); i++) { + wi_slice_a[i] = fabs(wi_slice_a[i]); + } + sub_a_2.wi_marray = fabs(sub_a_2.wi_marray); + + for (int i = 0; i < sub_a_2.wi_marray.size(); i++) { + if (sub_a_2.wi_marray[i] != wi_slice_a[i]) { + ERR[0] = 1; + } + } + }); // parallel for + }).wait(); + } + assert(err == 0); +} + +int main() { + + queue q; + auto computeCapability = + std::stof(q.get_device().get_info()); + + if (computeCapability >= 8.0) { + verify_wi_marray(q); + } + + return 0; +} diff --git a/SYCL/Matrix/joint_matrix_tensorcore.cpp b/SYCL/Matrix/joint_matrix_tensorcore.cpp index a489aeb0ca..2b5078d415 100644 --- a/SYCL/Matrix/joint_matrix_tensorcore.cpp +++ b/SYCL/Matrix/joint_matrix_tensorcore.cpp @@ -1,6 +1,6 @@ -// REQUIRES: gpu, cuda - -// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 %s -o %t.out +// REQUIRES: cuda +// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 %s -o %t.out +// RUN: %t.out // // Specifying the sm version via the --cuda-gpu-arch flag is necessary // for the Nvidia case. DPC++ JIT compilation is not @@ -11,6 +11,8 @@ using namespace sycl; using namespace sycl::ext::oneapi::experimental::matrix; +using sycl::ext::oneapi::experimental::bfloat16; +constexpr float bf16_eps = 0.00390625; // Example usage of Nvidia matrix multiply. 
// Optimizations such as memory paddings for avoiding bank conflicts are not @@ -43,17 +45,17 @@ class TypeHelper; template using KernelName = class TypeHelper; -float make_fp32(short x) { - unsigned int y = x; +float make_fp32(uint16_t x) { + uint32_t y = x; y = y << 16; - float *res = reinterpret_cast(&y); + auto res = reinterpret_cast(&y); return *res; } -unsigned short make_bf16(float x) { - int *res = reinterpret_cast(&x); +uint16_t make_bf16(float x) { + auto res = reinterpret_cast(&x); *res = *res >> 16; - return (unsigned short)*res; + return (uint16_t)*res; } template @@ -63,6 +65,10 @@ T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) { if constexpr (std::is_same::value) { for (int k = 0; k < Big_K; k++) res += make_fp32(A[m * Big_K + k]) * make_fp32(B[k * Big_N + n]); + } else if constexpr (std::is_same::value) { + for (int k = 0; k < Big_K; k++) + res += + make_fp32(A[m * Big_K + k].raw()) * make_fp32(B[k * Big_N + n].raw()); } else { for (int k = 0; k < Big_K; k++) @@ -75,7 +81,7 @@ T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) { template -void test() { +void test(queue &q) { constexpr auto Big_M = Sub_Tiles_M * @@ -105,7 +111,7 @@ void test() { for (int i = 0; i < Big_K * Big_N; i++) { B[i] = make_bf16(0.1f * (i % 10)); } - } else { + } else if constexpr (!std::is_same::value) { for (int i = 0; i < Big_M * Big_K; i++) { A[i] = i % 100; } @@ -114,110 +120,157 @@ void test() { B[i] = i % 100; } } + { + buffer bufA(A, range<1>(Big_M * Big_K)); + buffer bufB(B, range<1>(Big_K * Big_N)); + buffer bufC(C, range<1>(Big_M * Big_N)); + buffer bufD(D, range<1>(Big_M * Big_N)); + + // currently bfloat16 has to be initialized on device + if constexpr (std::is_same::value) { + q.submit([&](handler &cgh) { + accessor accA(bufA, + cgh); + + cgh.parallel_for>( + range<1>(Big_M * Big_K), [=](item<1> item) { + auto i = item.get_linear_id(); + accA[i] = 0.1f * (i % 10); + }); + }); + + q.submit([&](handler &cgh) { + accessor accB(bufB, + cgh); + + cgh.parallel_for>( + range<1>(Big_K * Big_N), [=](item<1> item) { + auto i = item.get_linear_id(); + accB[i] = 0.1f * (i % 10); + }); + }); + } - buffer bufA(A, range<1>(Big_M * Big_K)); - buffer bufB(B, range<1>(Big_K * Big_N)); - buffer bufC(C, range<1>(Big_M * Big_N)); - buffer bufD(D, range<1>(Big_M * Big_N)); - - queue q; - q.submit([&](handler &cgh) { - auto accC = bufC.template get_access(cgh); - auto accA = bufA.template get_access(cgh); - auto accB = bufB.template get_access(cgh); - auto accD = bufD.template get_access(cgh); - - range<2> LocalRange = {1, N_THREADS_PER_MATRIX_OP}; - range<2> GlobalRange = {Sub_Tiles_M, Sub_Tiles_N * N_THREADS_PER_MATRIX_OP}; - - cgh.parallel_for>( - nd_range<2>(GlobalRange, LocalRange), - [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { - sycl::sub_group sg = item.get_sub_group(); - const auto m = - item.get_group().get_group_id()[0]; // row id of current submatrix - // of BIG C matrix - const auto n = - item.get_group().get_group_id()[1]; // column id of current - // submatrix of BIG C matrix - - joint_matrix sub_a; - - joint_matrix sub_b; - - joint_matrix - sub_c; - - joint_matrix_load( - sg, sub_c, accC.get_pointer() + (m * M) * Big_N + n * N, Big_N); - - for (int k = 0; k < Sub_Tiles_K; - k++) // row/col id of current submatrix of BIG A/B matrices - { - joint_matrix_load(sg, sub_a, - accA.get_pointer() + (k * K) + (m * M * Big_K), - Big_K); - - joint_matrix_load(sg, sub_b, - accB.get_pointer() + (k * K * Big_N) + (n * N), - Big_N); - - // Convert values if 
using tf32 - if constexpr (std::is_same::value) { - for (auto i = 0; i < 4; ++i) { - sub_a.data[i] = round_to_tf32(sub_a.data[i]); - sub_b.data[i] = round_to_tf32(sub_b.data[i]); + q.submit([&](handler &cgh) { + accessor accA(bufA, cgh); + accessor accB(bufB, cgh); + accessor accC(bufC, cgh); + accessor accD(bufD, cgh); + + range<2> LocalRange = {1, N_THREADS_PER_MATRIX_OP}; + range<2> GlobalRange = {Sub_Tiles_M, + Sub_Tiles_N * N_THREADS_PER_MATRIX_OP}; + + cgh.parallel_for>( + nd_range<2>(GlobalRange, LocalRange), [=](nd_item<2> item) { + sub_group sg = item.get_sub_group(); + const auto m = + item.get_group().get_group_id()[0]; // row id of current + // submatrix of BIG C matrix + const auto n = + item.get_group().get_group_id()[1]; // column id of current + // submatrix of BIG C matrix + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + joint_matrix + sub_c; + + joint_matrix_load( + sg, sub_c, accC.get_pointer() + (m * M) * Big_N + n * N, Big_N); + + for (int k = 0; k < Sub_Tiles_K; + k++) // row/col id of current submatrix of BIG A/B matrices + { + joint_matrix_load(sg, sub_a, + accA.get_pointer() + (k * K) + (m * M * Big_K), + Big_K); + + joint_matrix_load(sg, sub_b, + accB.get_pointer() + (k * K * Big_N) + (n * N), + Big_N); + + // round values to correct precision if using tf32 + if constexpr (std::is_same::value) { + auto wi_size = sub_a.wi_marray.size(); + assert(wi_size == sub_b.wi_marray.size()); + for (auto i = 0; i < wi_size; ++i) { + sub_a.wi_marray[i] = round_to_tf32(sub_a.wi_marray[i]); + sub_b.wi_marray[i] = round_to_tf32(sub_b.wi_marray[i]); + } } - } - sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); - } - joint_matrix_store( - sg, sub_c, accD.get_pointer() + (m * M) * Big_N + n * N, Big_N); - }); - }); - - q.wait(); + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + } + joint_matrix_store( + sg, sub_c, accD.get_pointer() + (m * M) * Big_N + n * N, Big_N); + }); + }); + q.wait(); + } - const auto host_accessor = bufD.template get_access(); - for (int m = 0; m < Big_M; m++) + for (int m = 0; m < Big_M; m++) { for (int n = 0; n < Big_N; n++) { - - assert((host_accessor[m * Big_N + n] == - matrix_ref_mn(m, n, A, B, C))); + if constexpr (std::is_same::value) { + auto res_device = matrix_ref_mn(m, n, A, B, C); + assert(fabs(2 * (D[m * Big_N + n] - res_device)) / + (D[m * Big_N + n] + res_device) < + bf16_eps * 2); + } else { + assert((D[m * Big_N + n] == + matrix_ref_mn(m, n, A, B, C))); + } } + } }; int main() { - // A/B half, Accumulator float - test(); - test(); - test(); - - // A/B/Accumulator half - test(); - test(); - test(); - test(); - test(); - test(); + queue Q; + auto computeCapability = + std::stof(Q.get_device().get_info()); - test(); - test(); - test(); + if (computeCapability >= 7.0) { + // A/B half, Accumulator float + test(Q); + test(Q); + test(Q); - test(); + // A/B/Accumulator half + test(Q); + test(Q); + test(Q); + } + if (computeCapability >= 7.2) { + test(Q); + test(Q); + test(Q); + + test( + Q); + test(Q); + test(Q); + } + if (computeCapability >= 8.0) { + test(Q); - // A/B bf16 - test(); - test(); - test(); + // A/B bfloat16 using storage type + test(Q); + test(Q); + test(Q); - // A/B tf32 - test(); + test(Q); + test(Q); + test(Q); + // A/B tf32 + test(Q); + } return 0; };
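For reference, the tolerance used throughout these tests (bf16_eps = 0.00390625 = 2^-8) matches the 8-bit mantissa of bfloat16, and the make_fp32 / make_bf16 helpers convert between the uint16_t storage format and float by moving the high 16 bits of the IEEE-754 representation. The following is a minimal standalone sketch (host-only, not part of the patch) of that round trip and of the relative-error check used by check(); it uses std::memcpy instead of reinterpret_cast to sidestep aliasing concerns, and the helper names simply mirror the ones in the patch rather than any library API.

    // Standalone sketch: bfloat16 <-> float bit conversion and the
    // relative-error tolerance the tests above rely on.
    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <cstring>

    constexpr float bf16_eps = 0.00390625f; // 2^-8, one ulp of the bf16 mantissa

    // bfloat16 storage (uint16_t) is the high 16 bits of an IEEE-754 float.
    float make_fp32(uint16_t x) {
      uint32_t y = static_cast<uint32_t>(x) << 16;
      float f;
      std::memcpy(&f, &y, sizeof(f));
      return f;
    }

    // Truncating float -> bfloat16 storage conversion, as in
    // joint_matrix_tensorcore.cpp (drops the low 16 bits; no rounding).
    uint16_t make_bf16(float x) {
      uint32_t y;
      std::memcpy(&y, &x, sizeof(y));
      return static_cast<uint16_t>(y >> 16);
    }

    // Mirrors check(): reports a mismatch when the relative difference
    // exceeds twice the bf16 epsilon.
    bool mismatch(float a, float b) {
      return std::fabs(2 * (a - b) / (a + b)) > bf16_eps * 2;
    }

    int main() {
      float x = 0.3f;
      float roundtrip = make_fp32(make_bf16(x));
      assert(!mismatch(x, roundtrip)); // truncation error stays within tolerance
      return 0;
    }

As a side note on the constants passed to matrix_verify_op in element_wise_all_ops_cuda.cpp: with sub_a filled with 3, sub_b with 1, and sub_c with -80, each output element is K * Op(3, 2) * 1 - 80, so the expected values 0 / 16 / -56 / -64 for plus / multiplies / divides / minus are consistent with a K dimension of 16, and -60 / -56 / -74 / -76 with the K of 4 used by the double-precision shape (the exact joint_matrix shapes are an inference here, not stated in this excerpt).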