diff --git a/include/infiniop.h b/include/infiniop.h
index b3cf8b6ca..7ce791adf 100644
--- a/include/infiniop.h
+++ b/include/infiniop.h
@@ -8,6 +8,7 @@
 #include "infiniop/ops/clip.h"
 #include "infiniop/ops/conv.h"
 #include "infiniop/ops/dequantize_awq.h"
+#include "infiniop/ops/equal.h"
 #include "infiniop/ops/gemm.h"
 #include "infiniop/ops/mul.h"
 #include "infiniop/ops/random_sample.h"
diff --git a/include/infiniop/ops/equal.h b/include/infiniop/ops/equal.h
new file mode 100644
index 000000000..3ac071eb4
--- /dev/null
+++ b/include/infiniop/ops/equal.h
@@ -0,0 +1,36 @@
+#ifndef __INFINIOP_EQUAL_API_H__
+#define __INFINIOP_EQUAL_API_H__
+
+#include "../operator_descriptor.h"
+
+// Handle type for an Equal operator descriptor.
+typedef struct InfiniopDescriptor *infiniopEqualDescriptor_t;
+
+// Creates an Equal descriptor for output `c` and inputs `a`, `b`; stores it in `*desc_ptr`.
+__C __export infiniStatus_t infiniopCreateEqualDescriptor(
+    infiniopHandle_t handle,
+    infiniopEqualDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t c,
+    infiniopTensorDescriptor_t a,
+    infiniopTensorDescriptor_t b);
+
+// Stores the workspace size required by `desc` in `*size`.
+__C __export infiniStatus_t infiniopGetEqualWorkspaceSize(
+    infiniopEqualDescriptor_t desc,
+    size_t *size);
+
+// Compares `a` and `b` element-wise and writes the result to `c`.
+__C __export infiniStatus_t infiniopEqual(
+    infiniopEqualDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *c,
+    const void *a,
+    const void *b,
+    void *stream);
+
+// Destroys a descriptor created by infiniopCreateEqualDescriptor.
+__C __export infiniStatus_t infiniopDestroyEqualDescriptor(
+    infiniopEqualDescriptor_t desc);
+
+#endif // __INFINIOP_EQUAL_API_H__
diff --git a/src/infiniop-test/include/ops.hpp b/src/infiniop-test/include/ops.hpp
index 3820f7cfd..3da9f9855 100644
--- a/src/infiniop-test/include/ops.hpp
+++ b/src/infiniop-test/include/ops.hpp
@@ -16,6 +16,7 @@ DECLARE_INFINIOP_TEST(add)
 DECLARE_INFINIOP_TEST(causal_softmax)
 DECLARE_INFINIOP_TEST(rearrange)
 DECLARE_INFINIOP_TEST(sub)
+DECLARE_INFINIOP_TEST(equal)
 
 #define REGISTER_INFINIOP_TEST(name) \
     { \
@@ -43,6 +44,7 @@ DECLARE_INFINIOP_TEST(sub)
         REGISTER_INFINIOP_TEST(causal_softmax) \
         REGISTER_INFINIOP_TEST(rearrange) \
         REGISTER_INFINIOP_TEST(sub) \
+        REGISTER_INFINIOP_TEST(equal) \
     }
 
 namespace infiniop_test {
diff 
--git a/src/infiniop-test/src/ops/equal.cpp b/src/infiniop-test/src/ops/equal.cpp
new file mode 100644
index 000000000..a4c236410
--- /dev/null
+++ b/src/infiniop-test/src/ops/equal.cpp
@@ -0,0 +1,119 @@
+#include "ops.hpp"
+#include "utils.hpp"
+#include <infinirt.h>
+#include <iomanip>
+#include <iostream>
+#include <sstream>
+
+namespace infiniop_test::equal {
+// Tensors of one test case: inputs `a` and `b`, output `c`, expected result `ans`.
+struct Test::Attributes {
+    std::shared_ptr<Tensor> a;
+    std::shared_ptr<Tensor> b;
+    std::shared_ptr<Tensor> c;
+    std::shared_ptr<Tensor> ans;
+};
+
+// Builds a test from the parsed tensors; throws if any of a/b/c/ans is missing.
+std::shared_ptr<Test> Test::build(
+    std::unordered_map<std::string, std::vector<uint8_t>> attributes,
+    std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
+    double rtol, double atol) {
+    auto test = std::shared_ptr<Test>(new Test(rtol, atol));
+    test->_attributes = new Attributes();
+    if (tensors.find("a") == tensors.end()
+        || tensors.find("b") == tensors.end()
+        || tensors.find("c") == tensors.end()
+        || tensors.find("ans") == tensors.end()) {
+        throw std::runtime_error("Invalid Test");
+    }
+
+    test->_attributes->a = tensors["a"];
+    test->_attributes->b = tensors["b"];
+    test->_attributes->c = tensors["c"];
+    test->_attributes->ans = tensors["ans"];
+
+    return test;
+}
+
+// Runs the operator once, checks `c` against `ans`, then benchmarks it.
+std::shared_ptr<infiniop_test::Result> Test::run(
+    infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) {
+    auto a = _attributes->a->to(device, device_id);
+    auto b = _attributes->b->to(device, device_id);
+    auto c = _attributes->c->to(device, device_id);
+
+    infiniopEqualDescriptor_t op_desc = nullptr;
+    CHECK_OR(infiniopCreateEqualDescriptor(handle, &op_desc,
+                                           c->desc(),
+                                           a->desc(),
+                                           b->desc()),
+             return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));
+    // Destroy the descriptor on every exit path from here on.
+    std::unique_ptr<InfiniopDescriptor, decltype(&infiniopDestroyEqualDescriptor)>
+        desc_guard(op_desc, &infiniopDestroyEqualDescriptor);
+
+    size_t workspace_size = 0;
+    CHECK_OR(infiniopGetEqualWorkspaceSize(op_desc, &workspace_size),
+             return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));
+    void *workspace = nullptr;
+    CHECK_OR(infinirtMalloc(&workspace, workspace_size),
+             return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));
+    // Free the workspace on every exit path; a null pointer is left alone.
+    std::unique_ptr<void, decltype(&infinirtFree)> workspace_guard(workspace, &infinirtFree);
+
+    CHECK_OR(infiniopEqual(op_desc, workspace, workspace_size,
+                           c->data(),
+                           a->data(),
+                           b->data(),
+                           nullptr),
+             return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution."));
+
+    try {
+        allClose(c, _attributes->ans, _rtol, _atol);
+    } catch (const std::exception &e) {
+        return TEST_FAILED(RESULT_INCORRECT, e.what());
+    }
+
+    double elapsed_time = benchmark(
+        [=]() {
+            infiniopEqual(
+                op_desc, workspace, workspace_size,
+                c->data(),
+                a->data(),
+                b->data(),
+                nullptr);
+        },
+        warm_ups, iterations);
+
+    return TEST_PASSED(elapsed_time);
+}
+
+std::vector<std::string> Test::attribute_names() {
+    return {};
+}
+
+std::vector<std::string> Test::tensor_names() {
+    return {"a", "b", "c", "ans"};
+}
+
+std::vector<std::string> Test::output_names() {
+    return {"c"};
+}
+
+std::string Test::toString() const {
+    std::ostringstream oss;
+    oss << op_name() << std::endl;
+    oss << "- a: " << _attributes->a->info() << std::endl;
+    oss << "- b: " << _attributes->b->info() << std::endl;
+    oss << "- c: " << _attributes->c->info() << std::endl;
+    oss << std::scientific << std::setprecision(2);
+    oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
+    return oss.str();
+}
+
+Test::~Test() {
+    delete _attributes;
+}
+
+} // namespace infiniop_test::equal
diff --git a/src/infiniop/ops/equal/cpu/equal_cpu.cc b/src/infiniop/ops/equal/cpu/equal_cpu.cc
new file mode 100644
index 000000000..aea021ed1
--- /dev/null
+++ b/src/infiniop/ops/equal/cpu/equal_cpu.cc
@@ -0,0 +1,75 @@
+#include "equal_cpu.h"
+#include "infinicore.h"
+
+namespace op::equal::cpu {
+
+Descriptor::~Descriptor() = default;
+
+// Validates the tensor descriptors and builds the CPU elementwise descriptor.
+// Both inputs must share one supported dtype and the output must be BOOL.
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
+
+    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
+
+    const auto &a_desc = input_desc_vec.at(0);
+    const auto &b_desc = input_desc_vec.at(1);
+    const auto &c_shape = out_desc->shape();
+    const auto &a_shape = a_desc->shape();
+    const auto &b_shape = b_desc->shape();
+
+    auto dtype = a_desc->dtype();
+
+    CHECK_DTYPE(dtype, INFINI_DTYPE_BOOL, INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
+
+    // calculate() reads both inputs as `dtype` and writes bool elements, so any
+    // other combination would reinterpret memory.
+    if (b_desc->dtype() != dtype || out_desc->dtype() != INFINI_DTYPE_BOOL) {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
+
+    // create CPU elementwise descriptor
+    CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+// Dispatches on the input dtype recorded at creation; the output type is always bool.
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void *output,
+    std::vector<const void *> inputs,
+    void *stream) const {
+
+    switch (_dtype) {
+    case INFINI_DTYPE_BOOL:
+        return _device_info->calculate<EqualOp, bool, bool, bool>(_info, output, inputs, stream);
+    case INFINI_DTYPE_I8:
+        return _device_info->calculate<EqualOp, bool, int8_t, int8_t>(_info, output, inputs, stream);
+    case INFINI_DTYPE_I16:
+        return _device_info->calculate<EqualOp, bool, int16_t, int16_t>(_info, output, inputs, stream);
+    case INFINI_DTYPE_I32:
+        return _device_info->calculate<EqualOp, bool, int32_t, int32_t>(_info, output, inputs, stream);
+    case INFINI_DTYPE_I64:
+        return _device_info->calculate<EqualOp, bool, int64_t, int64_t>(_info, output, inputs, stream);
+    case INFINI_DTYPE_BF16:
+        return _device_info->calculate<EqualOp, bool, bf16_t, bf16_t>(_info, output, inputs, stream);
+    case INFINI_DTYPE_F16:
+        return _device_info->calculate<EqualOp, bool, fp16_t, fp16_t>(_info, output, inputs, stream);
+    case INFINI_DTYPE_F32:
+        return _device_info->calculate<EqualOp, bool, float, float>(_info, output, inputs, stream);
+    case INFINI_DTYPE_F64:
+        return _device_info->calculate<EqualOp, bool, double, double>(_info, output, inputs, stream);
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    return INFINI_STATUS_SUCCESS;
+}
+} // namespace op::equal::cpu
diff --git a/src/infiniop/ops/equal/cpu/equal_cpu.h b/src/infiniop/ops/equal/cpu/equal_cpu.h
new file mode 100644
index 000000000..c09a276d7
--- /dev/null
+++ b/src/infiniop/ops/equal/cpu/equal_cpu.h
@@ -0,0 +1,28 @@
+#ifndef __EQUAL_CPU_H__
+#define __EQUAL_CPU_H__
+
+#include "../../../elementwise/cpu/elementwise_cpu.h"
+
+ELEMENTWISE_DESCRIPTOR(equal, cpu)
+
+namespace op::equal::cpu {
+// Element-wise equality functor: returns `a == b` converted to Tout.
+typedef struct EqualOp {
+public:
+    static constexpr size_t num_inputs = 2;
+    template <typename Tout, typename Ta, typename Tb>
+    Tout operator()(const Ta &a, const Tb &b) const {
+        static_assert(std::is_same_v<Ta, Tb>, "Ta and Tb must be the same type!");
+        if constexpr (std::is_same_v<Ta, bf16_t> || std::is_same_v<Ta, fp16_t>) {
+            // Half-precision storage types are compared after conversion to float.
+            float f_a = utils::cast<float>(a);
+            float f_b = utils::cast<float>(b);
+            return f_a == f_b;
+        } else {
+            return a == b;
+        }
+    }
+} EqualOp;
+} // namespace op::equal::cpu
+
+#endif // __EQUAL_CPU_H__
diff --git a/src/infiniop/ops/equal/cuda/kernel.cuh b/src/infiniop/ops/equal/cuda/kernel.cuh
new file mode 100644
index 000000000..636913b26
--- /dev/null
+++ b/src/infiniop/ops/equal/cuda/kernel.cuh
@@ -0,0 +1,17 @@
+#ifndef __EQUAL_CUDA_H__
+#define __EQUAL_CUDA_H__
+
+namespace op::equal::cuda {
+// Element-wise equality functor used by the nvidia and metax backends.
+typedef struct EqualOp {
+public:
+    static constexpr size_t num_inputs = 2;
+    template <typename Tout, typename Ta, typename Tb>
+    __device__ __forceinline__ Tout operator()(const Ta &a, const Tb &b) const {
+        static_assert(std::is_same_v<Ta, Tb>, "Ta and Tb must be the same type!");
+        return a == b;
+    }
+} EqualOp;
+} // namespace op::equal::cuda
+
+#endif // __EQUAL_CUDA_H__
diff --git a/src/infiniop/ops/equal/metax/equal_metax.h b/src/infiniop/ops/equal/metax/equal_metax.h
new file mode 100644
index 000000000..6e4cd64b9
--- /dev/null
+++ b/src/infiniop/ops/equal/metax/equal_metax.h
@@ -0,0 +1,8 @@
+#ifndef __EQUAL_METAX_API_H__
+#define __EQUAL_METAX_API_H__
+
+#include "../../../elementwise/metax/elementwise_metax_api.h"
+
+ELEMENTWISE_DESCRIPTOR(equal, metax)
+
+#endif // __EQUAL_METAX_API_H__
diff --git a/src/infiniop/ops/equal/metax/equal_metax.maca b/src/infiniop/ops/equal/metax/equal_metax.maca
new file mode 100644
index 000000000..7629cf6aa
--- /dev/null
+++ b/src/infiniop/ops/equal/metax/equal_metax.maca
@@ -0,0 +1,82 @@
+#include "equal_metax.h"
+
+#include "../../../elementwise/metax/elementwise_metax.h"
+
+#include "../cuda/kernel.cuh"
+
+namespace op::equal::metax {
+
+Descriptor::~Descriptor() = default;
+
+// Validates the tensor descriptors and builds the METAX elementwise descriptor.
+// Both inputs must share one supported dtype and the output must be BOOL.
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
+
+    auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
+
+    const auto &a_desc = input_desc_vec.at(0);
+    const auto &b_desc = input_desc_vec.at(1);
+    const auto &c_shape = out_desc->shape();
+    const auto &a_shape = a_desc->shape();
+    const auto &b_shape = b_desc->shape();
+
+    auto dtype = a_desc->dtype();
+
+    CHECK_DTYPE(dtype, INFINI_DTYPE_BOOL, INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
+
+    // calculate() reads both inputs as `dtype` and writes bool elements, so any
+    // other combination would reinterpret memory.
+    if (b_desc->dtype() != dtype || out_desc->dtype() != INFINI_DTYPE_BOOL) {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
+
+    // create METAX elementwise descriptor
+    CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+// Checks the workspace size, then dispatches on the input dtype recorded at creation.
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void *output,
+    std::vector<const void *> inputs,
+    void *stream) const {
+
+    if (workspace_size < _workspace_size) {
+        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
+    }
+
+    switch (_dtype) {
+    case INFINI_DTYPE_BOOL:
+        return _device_info->calculate<256, cuda::EqualOp, bool, bool, bool>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_I8:
+        return _device_info->calculate<256, cuda::EqualOp, bool, int8_t, int8_t>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_I16:
+        return _device_info->calculate<256, cuda::EqualOp, bool, int16_t, int16_t>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_I32:
+        return _device_info->calculate<256, cuda::EqualOp, bool, int32_t, int32_t>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_I64:
+        return _device_info->calculate<256, cuda::EqualOp, bool, int64_t, int64_t>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_BF16:
+        return _device_info->calculate<256, cuda::EqualOp, bool, cuda_bfloat16, cuda_bfloat16>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_F16:
+        return _device_info->calculate<256, cuda::EqualOp, bool, half, half>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_F32:
+        return _device_info->calculate<256, cuda::EqualOp, bool, float, float>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_F64:
+        return _device_info->calculate<256, cuda::EqualOp, bool, double, double>(_info, workspace, output, inputs, stream);
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    return INFINI_STATUS_SUCCESS;
+}
+} // namespace op::equal::metax
diff --git a/src/infiniop/ops/equal/nvidia/equal_nvidia.cu b/src/infiniop/ops/equal/nvidia/equal_nvidia.cu
new file mode 100644
index 000000000..6e8f7444c
--- /dev/null
+++ b/src/infiniop/ops/equal/nvidia/equal_nvidia.cu
@@ -0,0 +1,82 @@
+#include "../../../elementwise/nvidia/elementwise_nvidia.cuh"
+
+#include "../cuda/kernel.cuh"
+#include "equal_nvidia.cuh"
+#include "infinicore.h"
+
+namespace op::equal::nvidia {
+
+Descriptor::~Descriptor() = default;
+
+// Validates the tensor descriptors and builds the CUDA elementwise descriptor.
+// Both inputs must share one supported dtype and the output must be BOOL.
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
+
+    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
+
+    const auto &a_desc = input_desc_vec.at(0);
+    const auto &b_desc = input_desc_vec.at(1);
+    const auto &c_shape = out_desc->shape();
+    const auto &a_shape = a_desc->shape();
+    const auto &b_shape = b_desc->shape();
+
+    auto dtype = a_desc->dtype();
+
+    CHECK_DTYPE(dtype, INFINI_DTYPE_BOOL, INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
+
+    // calculate() reads both inputs as `dtype` and writes bool elements, so any
+    // other combination would reinterpret memory.
+    if (b_desc->dtype() != dtype || out_desc->dtype() != INFINI_DTYPE_BOOL) {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
+
+    // create CUDA elementwise descriptor
+    CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+// Checks the workspace size, then dispatches on the input dtype recorded at creation.
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void *output,
+    std::vector<const void *> inputs,
+    void *stream) const {
+
+    if (workspace_size < _workspace_size) {
+        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
+    }
+
+    switch (_dtype) {
+    case INFINI_DTYPE_BOOL:
+        return _device_info->calculate<256, cuda::EqualOp, bool, bool, bool>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_I8:
+        return _device_info->calculate<256, cuda::EqualOp, bool, int8_t, int8_t>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_I16:
+        return _device_info->calculate<256, cuda::EqualOp, bool, int16_t, int16_t>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_I32:
+        return _device_info->calculate<256, cuda::EqualOp, bool, int32_t, int32_t>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_I64:
+        return _device_info->calculate<256, cuda::EqualOp, bool, int64_t, int64_t>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_BF16:
+        return _device_info->calculate<256, cuda::EqualOp, bool, cuda_bfloat16, cuda_bfloat16>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_F16:
+        return _device_info->calculate<256, cuda::EqualOp, bool, half, half>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_F32:
+        return _device_info->calculate<256, cuda::EqualOp, bool, float, float>(_info, workspace, output, inputs, stream);
+    case INFINI_DTYPE_F64:
+        return _device_info->calculate<256, cuda::EqualOp, bool, double, double>(_info, workspace, output, inputs, stream);
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    return INFINI_STATUS_SUCCESS;
+}
+} // namespace op::equal::nvidia
diff --git a/src/infiniop/ops/equal/nvidia/equal_nvidia.cuh b/src/infiniop/ops/equal/nvidia/equal_nvidia.cuh
new file mode 100644
index 000000000..361e54b02
--- /dev/null
+++ b/src/infiniop/ops/equal/nvidia/equal_nvidia.cuh
@@ -0,0 +1,8 @@
+#ifndef __EQUAL_CUDA_API_H__
+#define __EQUAL_CUDA_API_H__
+
+#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh"
+
+ELEMENTWISE_DESCRIPTOR(equal, nvidia)
+
+#endif // __EQUAL_CUDA_API_H__
diff --git a/src/infiniop/ops/equal/operator.cc b/src/infiniop/ops/equal/operator.cc
new file mode 100644
index 000000000..2c46c28cd
--- /dev/null
+++ b/src/infiniop/ops/equal/operator.cc
@@ -0,0 +1,148 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/equal.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/equal_cpu.h"
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
+#include "nvidia/equal_nvidia.cuh"
+#endif
+#ifdef ENABLE_METAX_API
+#include "metax/equal_metax.h"
+#endif
+
+// Creates the backend descriptor that matches the handle's device.
+__C infiniStatus_t infiniopCreateEqualDescriptor(
+    infiniopHandle_t handle,
+    infiniopEqualDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t c_desc,
+    infiniopTensorDescriptor_t a_desc,
+    infiniopTensorDescriptor_t b_desc) {
+
+#define CREATE(CASE, NAMESPACE) \
+    case CASE: \
+        return op::equal::NAMESPACE::Descriptor::create( \
+            handle, \
+            reinterpret_cast<op::equal::NAMESPACE::Descriptor **>(desc_ptr), \
+            c_desc, \
+            {a_desc, \
+             b_desc})
+
+    switch (handle->device) {
+
+#ifdef ENABLE_CPU_API
+        CREATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_NVIDIA_API
+        CREATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_METAX_API
+        CREATE(INFINI_DEVICE_METAX, metax);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CREATE
+}
+
+// Stores the workspace size of the descriptor's backend in `*size`.
+__C infiniStatus_t infiniopGetEqualWorkspaceSize(infiniopEqualDescriptor_t desc, size_t *size) {
+
+#define GET(CASE, NAMESPACE) \
+    case CASE: \
+        *size = reinterpret_cast<op::equal::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
+        return INFINI_STATUS_SUCCESS
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        GET(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_NVIDIA_API
+        GET(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        GET(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_METAX_API
+        GET(INFINI_DEVICE_METAX, metax);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef GET
+}
+
+// Runs the operator on the backend recorded in the descriptor.
+__C infiniStatus_t infiniopEqual(
+    infiniopEqualDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *c,
+    const void *a,
+    const void *b,
+    void *stream) {
+
+#define CALCULATE(CASE, NAMESPACE) \
+    case CASE: \
+        return reinterpret_cast<const op::equal::NAMESPACE::Descriptor *>(desc) \
+            ->calculate(workspace, workspace_size, c, {a, b}, stream)
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_NVIDIA_API
+        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_METAX_API
+        CALCULATE(INFINI_DEVICE_METAX, metax);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE
+}
+
+// Deletes the descriptor through its backend type.
+__C infiniStatus_t
+infiniopDestroyEqualDescriptor(infiniopEqualDescriptor_t desc) {
+
+#define DELETE(CASE, NAMESPACE) \
+    case CASE: \
+        delete reinterpret_cast<const op::equal::NAMESPACE::Descriptor *>(desc); \
+        return INFINI_STATUS_SUCCESS
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        DELETE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_NVIDIA_API
+        DELETE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_METAX_API
+        DELETE(INFINI_DEVICE_METAX, metax);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef DELETE
+}
diff --git a/test/infiniop/equal.py b/test/infiniop/equal.py
new file mode 100644
index 000000000..a4897da04
--- /dev/null
+++ b/test/infiniop/equal.py
@@ -0,0 +1,201 @@
+import torch
+import ctypes
+from ctypes import c_uint64
+from libinfiniop import (
+    LIBINFINIOP,
+    TestTensor,
+    get_test_devices,
+    check_error,
+    test_operator,
+    get_args,
+    debug,
+    get_tolerance,
+    profile_operation,
+    TestWorkspace,
+    InfiniDtype,
+    InfiniDtypeNames,
+    InfiniDeviceNames,
+    infiniopOperatorDescriptor_t,
+)
+from enum import Enum, auto
+
+# ==============================================================================
+# Configuration (Internal Use Only)
+# ==============================================================================
+# These are not meant to be imported from other modules
+_TEST_CASES_ = [
+    # shape, a_stride, b_stride, c_stride
+    ((13, 4), None, None, None),
+    ((13, 4), (10, 1), (10, 1), (10, 1)),
+    ((13, 4), (0, 1), None, None),
+    ((13, 4, 4), None, None, None),
+    ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
+    ((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
+    ((16, 5632), None, None, None),
+    ((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
+    ((4, 4, 5632), None, None, None),
+    ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
+]
+
+
+class Inplace(Enum):
+    OUT_OF_PLACE = auto()
+    INPLACE_A = auto()
+    INPLACE_B = auto()
+
+
+# Inplace options applied for each test case in _TEST_CASES_
+_INPLACE = [
+    Inplace.OUT_OF_PLACE,
+]
+
+# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_
+_TEST_CASES = [
+    test_case + (inplace_item,)
+    for test_case in _TEST_CASES_
+    for inplace_item in _INPLACE
+]
+
+# Data types used for testing
+_TENSOR_DTYPES = [
+    InfiniDtype.BOOL,
+    InfiniDtype.I8,
+    InfiniDtype.I16,
+    InfiniDtype.I32,
+    InfiniDtype.I64,
+    InfiniDtype.BF16,
+    InfiniDtype.F16,
+    InfiniDtype.F32,
+    InfiniDtype.F64,
+]
+
+# Tolerance map for different data types
+_TOLERANCE_MAP = {
+    InfiniDtype.BOOL: {"atol": 0, "rtol": 0},
+    InfiniDtype.I8: {"atol": 0, "rtol": 0},
+    InfiniDtype.I16: {"atol": 0, "rtol": 0},
+    InfiniDtype.I32: {"atol": 0, "rtol": 0},
+    InfiniDtype.I64: {"atol": 0, "rtol": 0},
+    InfiniDtype.BF16: {"atol": 0, "rtol": 0},
+    InfiniDtype.F16: {"atol": 0, "rtol": 0},
+    InfiniDtype.F32: {"atol": 0, "rtol": 0},
+    InfiniDtype.F64: {"atol": 0, "rtol": 0},
+}
+
+DEBUG = False
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
+
+# Reference implementation: element-wise a == b written into c.
+def eq(c, a, b):
+    torch.eq(a, b, out=c)
+
+
+def test(
+    handle,
+    device,
+    shape,
+    a_stride=None,
+    b_stride=None,
+    c_stride=None,
+    inplace=Inplace.OUT_OF_PLACE,
+    dtype=InfiniDtype.F16,
+    sync=None,
+):
+    # The result tensor is always BOOL, so it can only alias an input of that dtype.
+    if inplace != Inplace.OUT_OF_PLACE and dtype != InfiniDtype.BOOL:
+        return
+
+    a = TestTensor(shape, a_stride, dtype, device)
+    b = TestTensor(shape, b_stride, dtype, device)
+    if inplace == Inplace.INPLACE_A:
+        if a_stride != c_stride:
+            return
+        c = a
+    elif inplace == Inplace.INPLACE_B:
+        if c_stride != b_stride:
+            return
+        c = b
+    else:
+        c = TestTensor(shape, c_stride, InfiniDtype.BOOL, device, mode="ones")
+
+    if c.is_broadcast():
+        return
+
+    print(
+        f"Testing Equal on {InfiniDeviceNames[device]} with shape:{shape} a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} "
+        f"dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}"
+    )
+
+    eq(c.torch_tensor(), a.torch_tensor(), b.torch_tensor())
+
+    if sync is not None:
+        sync()
+
+    descriptor = infiniopOperatorDescriptor_t()
+    check_error(
+        LIBINFINIOP.infiniopCreateEqualDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            c.descriptor,
+            a.descriptor,
+            b.descriptor,
+        )
+    )
+
+    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
+    for tensor in [a, b, c]:
+        tensor.destroy_desc()
+
+    workspace_size = c_uint64(0)
+    check_error(
+        LIBINFINIOP.infiniopGetEqualWorkspaceSize(
+            descriptor, ctypes.byref(workspace_size)
+        )
+    )
+    workspace = TestWorkspace(workspace_size.value, c.device)
+
+    def lib_equal():
+        check_error(
+            LIBINFINIOP.infiniopEqual(
+                descriptor,
+                workspace.data(),
+                workspace.size(),
+                c.data(),
+                a.data(),
+                b.data(),
+                None,
+            )
+        )
+
+    lib_equal()
+
+    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
+    if DEBUG:
+        debug(c.actual_tensor(), c.torch_tensor(), atol=atol, rtol=rtol)
+    assert torch.equal(c.actual_tensor(), c.torch_tensor())
+
+    # Profiling workflow
+    if PROFILE:
+        # fmt: off
+        profile_operation("PyTorch", lambda: eq(c.torch_tensor(), a.torch_tensor(), b.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation(" lib", lambda: lib_equal(), device, NUM_PRERUN, NUM_ITERATIONS)
+        # fmt: on
+    check_error(LIBINFINIOP.infiniopDestroyEqualDescriptor(descriptor))
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    # Configure testing options
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    for device in get_test_devices(args):
+        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
+
+    print("\033[92mTest passed!\033[0m")
diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py
index ba1ce33df..2796e7e07 100644
--- a/test/infiniop/libinfiniop/op_register.py
+++ b/test/infiniop/libinfiniop/op_register.py
@@ -583,3 +583,35 @@ def softplus_(lib):
     ]
     lib.infiniopDestroySoftplusDescriptor.restype = c_int32
     lib.infiniopDestroySoftplusDescriptor.argtypes = [infiniopOperatorDescriptor_t]
+
+
+@OpRegister.operator
+def equal_(lib):
+    lib.infiniopCreateEqualDescriptor.restype = c_int32
+    lib.infiniopCreateEqualDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopOperatorDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+    ]
+
+    lib.infiniopGetEqualWorkspaceSize.restype = c_int32
+    lib.infiniopGetEqualWorkspaceSize.argtypes = [
+        infiniopOperatorDescriptor_t,
+        POINTER(c_size_t),
+    ]
+
+    lib.infiniopEqual.restype = c_int32
+    lib.infiniopEqual.argtypes = [
+        infiniopOperatorDescriptor_t,
+        c_void_p,
+        c_size_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+
+    lib.infiniopDestroyEqualDescriptor.restype = c_int32
+    lib.infiniopDestroyEqualDescriptor.argtypes = [infiniopOperatorDescriptor_t]
diff --git a/test/infiniop/libinfiniop/utils.py b/test/infiniop/libinfiniop/utils.py
index 082cdf459..4188aa348 100644
--- a/test/infiniop/libinfiniop/utils.py
+++ b/test/infiniop/libinfiniop/utils.py
@@ -67,10 +67,35 @@ def __init__(
                 torch_strides.append(strides[i])
             else:
                 torch_shape.append(shape[i])
+
+        # BOOL tensors are generated as F32 and thresholded at 0.5 after creation.
+        is_bool = dt == InfiniDtype.BOOL
+        if is_bool:
+            dt = InfiniDtype.F32
+
+        is_int = (
+            dt == InfiniDtype.I8
+            or dt == InfiniDtype.I16
+            or dt == InfiniDtype.I32
+            or dt == InfiniDtype.I64
+        )
+
+        torch_dtype = to_torch_dtype(dt)
         if mode == "random":
-            self._torch_tensor = torch.rand(
-                torch_shape, dtype=to_torch_dtype(dt), device=torch_device_map[device]
-            )
+            if is_int:
+                self._torch_tensor = torch.randint(
+                    0,
+                    100,
+                    torch_shape,
+                    dtype=torch_dtype,
+                    device=torch_device_map[device],
+                )
+            else:
+                self._torch_tensor = torch.rand(
+                    torch_shape,
+                    dtype=torch_dtype,
+                    device=torch_device_map[device],
+                )
         elif mode == "zeros":
             self._torch_tensor = torch.zeros(
                 torch_shape, dtype=to_torch_dtype(dt), device=torch_device_map[device]
@@ -97,6 +122,9 @@ def __init__(
         else:
             raise ValueError("Unsupported mode")
 
+        if is_bool:
+            self._torch_tensor = self._torch_tensor > 0.5
+
         if scale is not None:
             self._torch_tensor *= scale
         if bias is not None:
@@ -157,6 +185,8 @@ def to_torch_dtype(dt: InfiniDtype, compatability_mode=False):
         return torch.float32
     elif dt == InfiniDtype.F64:
         return torch.float64
+    elif dt == InfiniDtype.BOOL:
+        return torch.bool
     # TODO: These following types may not be supported by older
     # versions of PyTorch. Use compatability mode to convert them.
     elif dt == InfiniDtype.U16: