Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/equal.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/mul.h"
#include "infiniop/ops/random_sample.h"
Expand Down
26 changes: 26 additions & 0 deletions include/infiniop/ops/equal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef __INFINIOP_EQUAL_API_H__
#define __INFINIOP_EQUAL_API_H__

#include "../operator_descriptor.h"

/* Handle type for an Equal operator descriptor: a pointer to the generic
 * InfiniopDescriptor struct. */
typedef struct InfiniopDescriptor *infiniopEqualDescriptor_t;

/*
 * Creates an Equal operator descriptor and returns it through desc_ptr.
 *
 * handle   - library handle the descriptor is created for.
 * desc_ptr - out-parameter that receives the new descriptor.
 * c        - tensor descriptor of the result tensor.
 * a, b     - tensor descriptors of the two tensors being compared.
 *
 * Returns an infiniStatus_t status code.
 * NOTE(review): the accepted dtypes and shape rules are decided by each
 * backend's implementation, not by this declaration - confirm per device.
 */
__C __export infiniStatus_t infiniopCreateEqualDescriptor(infiniopHandle_t handle,
infiniopEqualDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c,
infiniopTensorDescriptor_t a,
infiniopTensorDescriptor_t b);

/*
 * Reports, through *size, the workspace size that infiniopEqual needs for the
 * given descriptor (presumably in bytes - TODO confirm the unit).
 * Returns an infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopGetEqualWorkspaceSize(infiniopEqualDescriptor_t desc, size_t *size);

/*
 * Executes the Equal operator described by desc.
 *
 * desc           - descriptor obtained from infiniopCreateEqualDescriptor.
 * workspace      - scratch buffer handed to the implementation.
 * workspace_size - size of the scratch buffer; see
 *                  infiniopGetEqualWorkspaceSize for the required amount.
 * c              - buffer that receives the result.
 * a, b           - read-only buffers holding the two operands.
 * stream         - backend execution stream/queue
 *                  (NOTE(review): exact meaning is backend specific - confirm).
 *
 * Returns an infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopEqual(infiniopEqualDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
const void *a,
const void *b,
void *stream);

/*
 * Destroys a descriptor previously created by infiniopCreateEqualDescriptor.
 * Returns an infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopDestroyEqualDescriptor(infiniopEqualDescriptor_t desc);

#endif
2 changes: 2 additions & 0 deletions src/infiniop-test/include/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ DECLARE_INFINIOP_TEST(add)
DECLARE_INFINIOP_TEST(causal_softmax)
DECLARE_INFINIOP_TEST(rearrange)
DECLARE_INFINIOP_TEST(sub)
DECLARE_INFINIOP_TEST(equal)

#define REGISTER_INFINIOP_TEST(name) \
{ \
Expand Down Expand Up @@ -43,6 +44,7 @@ DECLARE_INFINIOP_TEST(sub)
REGISTER_INFINIOP_TEST(causal_softmax) \
REGISTER_INFINIOP_TEST(rearrange) \
REGISTER_INFINIOP_TEST(sub) \
REGISTER_INFINIOP_TEST(equal) \
}

namespace infiniop_test {
Expand Down
109 changes: 109 additions & 0 deletions src/infiniop-test/src/ops/equal.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#include "ops.hpp"
#include "utils.hpp"
#include <infinirt.h>
#include <iomanip>
#include <iostream>

namespace infiniop_test::equal {
// Tensors that make up one equal test case. All four are required by
// Test::build, which looks them up under the keys "a", "b", "c" and "ans".
struct Test::Attributes {
// First operand passed to the operator.
std::shared_ptr<Tensor> a;
// Second operand passed to the operator.
std::shared_ptr<Tensor> b;
// Tensor the operator writes its result into.
std::shared_ptr<Tensor> c;
// Reference result that the computed `c` is compared against.
std::shared_ptr<Tensor> ans;
};

std::shared_ptr<Test> Test::build(
std::unordered_map<std::string, std::vector<uint8_t>> attributes,
std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
double rtol, double atol) {
auto test = std::shared_ptr<Test>(new Test(rtol, atol));
test->_attributes = new Attributes();
if (tensors.find("a") == tensors.end()
|| tensors.find("b") == tensors.end()
|| tensors.find("c") == tensors.end()
|| tensors.find("ans") == tensors.end()) {
throw std::runtime_error("Invalid Test");
}

test->_attributes->a = tensors["a"];
test->_attributes->b = tensors["b"];
test->_attributes->c = tensors["c"];
test->_attributes->ans = tensors["ans"];

return test;
}

// Runs the equal operator on the requested device, validates the result
// against the reference tensor and then benchmarks the operator.
//
// handle     - infiniop handle used to create the operator descriptor.
// device     - device type the tensors are moved to before running.
// device_id  - index of the device of that type.
// warm_ups   - forwarded to benchmark() as the warm-up count.
// iterations - forwarded to benchmark() as the iteration count.
//
// Returns TEST_FAILED(...) when descriptor creation, workspace setup,
// execution or result comparison fails, otherwise TEST_PASSED with the time
// reported by benchmark().
//
// The operator descriptor and the workspace buffer are released on every
// return path once they have been acquired.
std::shared_ptr<infiniop_test::Result> Test::run(
    infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) {
    infiniopEqualDescriptor_t op_desc;
    auto a = _attributes->a->to(device, device_id);
    auto b = _attributes->b->to(device, device_id);
    auto c = _attributes->c->to(device, device_id);
    CHECK_OR(infiniopCreateEqualDescriptor(handle, &op_desc,
                                           c->desc(),
                                           a->desc(),
                                           b->desc()),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));

    void *workspace = nullptr;

    // Scope guard: from here on the descriptor exists, so every exit must
    // destroy it and free the workspace if one was allocated. The guard holds
    // a reference to `workspace` so it sees the pointer set by infinirtMalloc.
    struct Cleanup {
        infiniopEqualDescriptor_t desc;
        void *&buffer;
        ~Cleanup() {
            if (buffer != nullptr) {
                infinirtFree(buffer);
            }
            infiniopDestroyEqualDescriptor(desc);
        }
    } cleanup{op_desc, workspace};

    size_t workspace_size = 0;
    CHECK_OR(infiniopGetEqualWorkspaceSize(op_desc, &workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));
    CHECK_OR(infinirtMalloc(&workspace, workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));
    CHECK_OR(infiniopEqual(op_desc, workspace, workspace_size,
                           c->data(),
                           a->data(),
                           b->data(),
                           nullptr),
             return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution."));

    // allClose signals a mismatch by throwing; turn that into a test result.
    try {
        allClose(c, _attributes->ans, _rtol, _atol);
    } catch (const std::exception &e) {
        return TEST_FAILED(RESULT_INCORRECT, e.what());
    }

    // The lambda copies the raw handles/pointers it uses; it never touches
    // the guard object, so ownership stays with `cleanup`.
    const double elapsed_time = benchmark(
        [=]() {
            infiniopEqual(
                op_desc, workspace, workspace_size,
                c->data(),
                a->data(),
                b->data(),
                nullptr);
        },
        warm_ups, iterations);

    return TEST_PASSED(elapsed_time);
}

// Names of the scalar attributes this test reads: none for equal.
std::vector<std::string> Test::attribute_names() {
    return std::vector<std::string>();
}

// Names of every tensor a test case must provide: both operands, the output
// buffer and the reference answer.
std::vector<std::string> Test::tensor_names() {
    return std::vector<std::string>{"a", "b", "c", "ans"};
}

// Names of the tensors the operator writes to: only "c".
std::vector<std::string> Test::output_names() {
    return std::vector<std::string>{"c"};
}

std::string Test::toString() const {
std::ostringstream oss;
oss << op_name() << std::endl;
oss << "- a: " << _attributes->a->info() << std::endl;
oss << "- b: " << _attributes->b->info() << std::endl;
oss << "- c: " << _attributes->c->info() << std::endl;
oss << std::scientific << std::setprecision(2);
oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
return oss.str();
}

// Releases the Attributes object that Test::build allocated with `new`.
Test::~Test() {
delete _attributes;
}

} // namespace infiniop_test::equal
66 changes: 66 additions & 0 deletions src/infiniop/ops/equal/cpu/equal_cpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include "equal_cpu.h"
#include "infinicore.h"

namespace op::equal::cpu {

// No custom teardown: the compiler-generated destructor is used.
Descriptor::~Descriptor() = default;

// Validates the tensor descriptors of an equal operation and builds the CPU
// elementwise descriptor for it.
//
// handle_        - generic handle, reinterpreted as the CPU device handle.
// desc_ptr       - destination for the created descriptor (presumably filled
//                  in by CREATE_ELEMENTWISE_CPU_DESCRIPTOR - TODO confirm).
// out_desc       - descriptor of the output tensor.
// input_desc_vec - descriptors of the two operands, in the order {a, b}.
//
// Returns INFINI_STATUS_SUCCESS on success and
// INFINI_STATUS_BAD_TENSOR_DTYPE when the operands do not share one dtype.
// The CHECK_* and CREATE_* macros may return other statuses on failure.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);

    const auto &a_desc = input_desc_vec.at(0);
    const auto &b_desc = input_desc_vec.at(1);
    const auto &c_shape = out_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();

    auto dtype = a_desc->dtype();

    CHECK_DTYPE(dtype, INFINI_DTYPE_BOOL, INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);

    // calculate() dispatches on a single dtype and reads both input buffers
    // with that element type, so mixed-dtype operands must be rejected here.
    if (b_desc->dtype() != dtype) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // NOTE(review): calculate() always writes `bool` elements to the output,
    // but out_desc's dtype is not validated here - confirm the expected
    // output dtype and consider checking it.

    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

    // create CPU elementwise descriptor
    CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);

    return INFINI_STATUS_SUCCESS;
}

// Runs the element-wise equality on the CPU.
//
// The element type of both inputs is chosen from the dtype stored in the
// descriptor; the output element type is always `bool`. `workspace` and
// `workspace_size` are not used by this backend.
//
// Returns the status of the elementwise launcher, or
// INFINI_STATUS_BAD_TENSOR_DTYPE when the stored dtype is not supported.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    switch (_dtype) {
    case INFINI_DTYPE_BOOL:
        return _device_info->calculate<EqualOp, bool, bool, bool>(_info, output, inputs, stream);
    case INFINI_DTYPE_I8:
        return _device_info->calculate<EqualOp, bool, int8_t, int8_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_I16:
        return _device_info->calculate<EqualOp, bool, int16_t, int16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_I32:
        return _device_info->calculate<EqualOp, bool, int32_t, int32_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_I64:
        return _device_info->calculate<EqualOp, bool, int64_t, int64_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<EqualOp, bool, bf16_t, bf16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_F16:
        return _device_info->calculate<EqualOp, bool, fp16_t, fp16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<EqualOp, bool, float, float>(_info, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<EqualOp, bool, double, double>(_info, output, inputs, stream);
    default:
        break;
    }

    // Every supported dtype returned above; anything else is rejected.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::equal::cpu
29 changes: 29 additions & 0 deletions src/infiniop/ops/equal/cpu/equal_cpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef __EQUAL_CPU_H__
#define __EQUAL_CPU_H__

#include "../../../elementwise/cpu/elementwise_cpu.h"

// Expands the shared elementwise macro for the (equal, cpu) pair; presumably
// this declares op::equal::cpu::Descriptor, whose members are defined in the
// matching source file - confirm against the macro definition.
ELEMENTWISE_DESCRIPTOR(equal, cpu)

namespace op::equal::cpu {
// Element-wise equality functor for the CPU elementwise launcher.
//
// Tout is the element type produced (the CPU backend uses bool); Ta and Tb
// are the operand element types and must be the same type, which is enforced
// at compile time.
typedef struct EqualOp {
public:
    // Number of input tensors this functor consumes.
    static constexpr size_t num_inputs = 2;

    // Returns the result of `a == b`, converted to Tout.
    template <typename Tout, typename Ta, typename Tb>
    Tout operator()(const Ta &a, const Tb &b) const {
        // A type mismatch is a programming error; reject it when the template
        // is instantiated rather than aborting the process at run time.
        static_assert(std::is_same_v<Ta, Tb>, "Ta and Tb must be the same type!");

        if constexpr (std::is_same_v<Ta, bf16_t> || std::is_same_v<Ta, fp16_t>) {
            // The 16-bit float types are widened to float before comparing.
            const float f_a = utils::cast<float, Ta>(a);
            const float f_b = utils::cast<float, Ta>(b);
            return f_a == f_b;
        } else {
            return a == b;
        }
    }
} EqualOp;
} // namespace op::equal::cpu

#endif // __EQUAL_CPU_H__
19 changes: 19 additions & 0 deletions src/infiniop/ops/equal/cuda/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef __EQUAL_CUDA_H__
#define __EQUAL_CUDA_H__

namespace op::equal::cuda {
// Element-wise equality functor for the device elementwise launchers.
//
// Tout is the element type produced; Ta and Tb are the operand element types
// and must be the same type, which is enforced at compile time.
typedef struct EqualOp {
public:
    // Number of input tensors this functor consumes.
    static constexpr size_t num_inputs = 2;

    // Returns the result of `a == b`, converted to Tout.
    template <typename Tout, typename Ta, typename Tb>
    __device__ __forceinline__ Tout operator()(const Ta &a, const Tb &b) const {
        // Reject mismatched operand types at instantiation time. The previous
        // run-time check called std::abort, a host function that cannot be
        // called from device code.
        static_assert(std::is_same_v<Ta, Tb>, "Ta and Tb must be the same type!");
        return a == b;
    }
} EqualOp;
} // namespace op::equal::cuda

#endif // __EQUAL_CUDA_H__
8 changes: 8 additions & 0 deletions src/infiniop/ops/equal/metax/equal_metax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __EQUAL_METAX_API_H__
#define __EQUAL_METAX_API_H__

#include "../../../elementwise/metax/elementwise_metax_api.h"

// Expands the shared elementwise macro for the (equal, metax) pair;
// presumably this declares op::equal::metax::Descriptor, whose members are
// defined in the matching .maca file - confirm against the macro definition.
ELEMENTWISE_DESCRIPTOR(equal, metax)

#endif // __EQUAL_METAX_API_H__
73 changes: 73 additions & 0 deletions src/infiniop/ops/equal/metax/equal_metax.maca
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#include "equal_metax.h"

#include "../../../elementwise/metax/elementwise_metax.h"

#include "../cuda/kernel.cuh"

namespace op::equal::metax {

// No custom teardown: the compiler-generated destructor is used.
Descriptor::~Descriptor() = default;

// Validates the tensor descriptors of an equal operation and builds the
// MetaX elementwise descriptor for it.
//
// handle_        - generic handle, reinterpreted as the MetaX device handle.
// desc_ptr       - destination for the created descriptor (presumably filled
//                  in by CREATE_ELEMENTWISE_METAX_DESCRIPTOR - TODO confirm).
// out_desc       - descriptor of the output tensor.
// input_desc_vec - descriptors of the two operands, in the order {a, b}.
//
// Returns INFINI_STATUS_SUCCESS on success and
// INFINI_STATUS_BAD_TENSOR_DTYPE when the operands do not share one dtype.
// The CHECK_* and CREATE_* macros may return other statuses on failure.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::metax::Handle *>(handle_);

    const auto &a_desc = input_desc_vec.at(0);
    const auto &b_desc = input_desc_vec.at(1);
    const auto &c_shape = out_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();

    auto dtype = a_desc->dtype();

    CHECK_DTYPE(dtype, INFINI_DTYPE_BOOL, INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);

    // calculate() dispatches on a single dtype and reads both input buffers
    // with that element type, so mixed-dtype operands must be rejected here.
    if (b_desc->dtype() != dtype) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // NOTE(review): calculate() always writes `bool` elements to the output,
    // but out_desc's dtype is not validated here - confirm the expected
    // output dtype and consider checking it.

    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

    // create METAX elementwise descriptor
    CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

// Launches the element-wise equality through the MetaX elementwise launcher.
//
// The caller's workspace must be at least as large as the size recorded in
// the descriptor. The element type of both inputs is chosen from the stored
// dtype; the output element type is always `bool`, and 256 is passed as the
// launcher's first template argument.
//
// Returns INFINI_STATUS_INSUFFICIENT_WORKSPACE when the workspace is too
// small, INFINI_STATUS_BAD_TENSOR_DTYPE for an unsupported dtype, and
// otherwise the status reported by the launcher.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    const bool workspace_too_small = _workspace_size > workspace_size;
    if (workspace_too_small) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    switch (_dtype) {
    case INFINI_DTYPE_BOOL:
        return _device_info->calculate<256, cuda::EqualOp, bool, bool, bool>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_I8:
        return _device_info->calculate<256, cuda::EqualOp, bool, int8_t, int8_t>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_I16:
        return _device_info->calculate<256, cuda::EqualOp, bool, int16_t, int16_t>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_I32:
        return _device_info->calculate<256, cuda::EqualOp, bool, int32_t, int32_t>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_I64:
        return _device_info->calculate<256, cuda::EqualOp, bool, int64_t, int64_t>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<256, cuda::EqualOp, bool, cuda_bfloat16, cuda_bfloat16>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, cuda::EqualOp, bool, half, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, cuda::EqualOp, bool, float, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, cuda::EqualOp, bool, double, double>(_info, workspace, output, inputs, stream);
    default:
        break;
    }

    // Every supported dtype returned above; anything else is rejected.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} // namespace op::equal::metax
Loading
Loading