Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions src/runtime/device_c_api.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#include <tvm/runtime/device_api.h>

using tvm::Device;
using tvm::runtime::DeviceAPI;

extern "C" {
struct TVMDeviceAPI {
DeviceAPI* impl;
};

TVM_DLL TVMDeviceAPI* TVMDeviceAPIGet(DLDevice dev, bool allow_missing) {
DeviceAPI* api = DeviceAPI::Get(dev, allow_missing);
if (!api) {
return nullptr;
}
TVMDeviceAPI* handle = new TVMDeviceAPI();
handle->impl = api;
return handle;
}

TVM_DLL void TVMDeviceAPIDestroy(TVMDeviceAPI* handle) { delete handle; }

TVM_DLL void TVMDeviceAPISetDevice(TVMDeviceAPI* handle, DLDevice dev) {
handle->impl->SetDevice(dev);
}

TVM_DLL void TVMDeviceAPIGetAttr(TVMDeviceAPI* handle, DLDevice dev, int kind, void* out_any) {
auto* any = static_cast<tvm::ffi::Any*>(out_any);
handle->impl->GetAttr(dev, static_cast<tvm::runtime::DeviceAttrKind>(kind), any);
}

TVM_DLL size_t TVMDeviceAPIGetDataSize(TVMDeviceAPI* handle, const DLTensor* arr,
const char* mem_scope) {
tvm::ffi::Optional<tvm::ffi::String> opt_scope = std::nullopt;
if (mem_scope != nullptr && mem_scope[0] != '\0') {
opt_scope = tvm::ffi::String(mem_scope);
}
return handle->impl->GetDataSize(*arr, opt_scope);
}

TVM_DLL void TVMDeviceAPIGetTargetProperty(TVMDeviceAPI* handle, DLDevice dev, const char* property,
void* out_any) {
auto* any = static_cast<tvm::ffi::Any*>(out_any);
handle->impl->GetTargetProperty(dev, std::string(property), any);
}

TVM_DLL void* TVMDeviceAPIAllocDataSpace(TVMDeviceAPI* handle, DLDevice dev, size_t nbytes,
size_t alignment, DLDataType type_hint) {
return handle->impl->AllocDataSpace(dev, nbytes, alignment, type_hint);
}

TVM_DLL void* TVMDeviceAPIAllocDataSpaceND(TVMDeviceAPI* handle, DLDevice dev, int ndim,
const int64_t* shape, DLDataType dtype,
const char* mem_scope) {
tvm::ffi::Optional<tvm::ffi::String> opt_scope = std::nullopt;
if (mem_scope != nullptr && mem_scope[0] != '\0') {
opt_scope = tvm::ffi::String(mem_scope);
}
return handle->impl->AllocDataSpace(dev, ndim, shape, dtype, opt_scope);
}

TVM_DLL void TVMDeviceAPIFreeDataSpace(TVMDeviceAPI* handle, DLDevice dev, void* ptr) {
handle->impl->FreeDataSpace(dev, ptr);
}

TVM_DLL void TVMDeviceAPICopyDataFromTo(TVMDeviceAPI* handle, DLTensor* from, DLTensor* to,
TVMStreamHandle stream) {
handle->impl->CopyDataFromTo(from, to, stream);
}

TVM_DLL TVMStreamHandle TVMDeviceAPICreateStream(TVMDeviceAPI* handle, DLDevice dev) {
return handle->impl->CreateStream(dev);
}

TVM_DLL void TVMDeviceAPIFreeStream(TVMDeviceAPI* handle, DLDevice dev, TVMStreamHandle stream) {
handle->impl->FreeStream(dev, stream);
}

TVM_DLL void TVMDeviceAPIStreamSync(TVMDeviceAPI* handle, DLDevice dev, TVMStreamHandle stream) {
handle->impl->StreamSync(dev, stream);
}

TVM_DLL void TVMDeviceAPISetStream(TVMDeviceAPI* handle, DLDevice dev, TVMStreamHandle stream) {
handle->impl->SetStream(dev, stream);
}

TVM_DLL TVMStreamHandle TVMDeviceAPIGetCurrentStream(TVMDeviceAPI* handle, DLDevice dev) {
return handle->impl->GetCurrentStream(dev);
}

TVM_DLL void TVMDeviceAPISyncStreamFromTo(TVMDeviceAPI* handle, DLDevice dev,
TVMStreamHandle event_src, TVMStreamHandle event_dst) {
handle->impl->SyncStreamFromTo(dev, event_src, event_dst);
}

TVM_DLL void* TVMDeviceAPIAllocWorkspace(TVMDeviceAPI* handle, DLDevice dev, size_t nbytes,
DLDataType type_hint) {
return handle->impl->AllocWorkspace(dev, nbytes, type_hint);
}

TVM_DLL void TVMDeviceAPIFreeWorkspace(TVMDeviceAPI* handle, DLDevice dev, void* ptr) {
handle->impl->FreeWorkspace(dev, ptr);
}

TVM_DLL int TVMDeviceAPISupportsDevicePointerArithmeticsOnHost(TVMDeviceAPI* handle) {
return handle->impl->SupportsDevicePointerArithmeticsOnHost() ? 1 : 0;
}
}
10 changes: 5 additions & 5 deletions src/runtime/vm/attn_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ std::unique_ptr<PagedPrefillFunc> ConvertPagedPrefillFunc(ffi::Array<ffi::Any> a
return nullptr;
}
ffi::String backend_name = args[0].cast<ffi::String>();
if (backend_name == "tir") {
if (backend_name == "tirx" || backend_name == "tir") {
CHECK_EQ(args.size(), 2);
ffi::Function attn_func = args[1].cast<ffi::Function>();
return std::make_unique<TIRPagedPrefillFunc>(std::move(attn_func), attn_kind);
Expand All @@ -53,7 +53,7 @@ std::unique_ptr<RaggedPrefillFunc> ConvertRaggedPrefillFunc(ffi::Array<ffi::Any>
return nullptr;
}
ffi::String backend_name = args[0].cast<ffi::String>();
if (backend_name == "tir") {
if (backend_name == "tirx" || backend_name == "tir") {
CHECK_EQ(args.size(), 2);
ffi::Function attn_func = args[1].cast<ffi::Function>();
return std::make_unique<TIRRaggedPrefillFunc>(std::move(attn_func), attn_kind);
Expand Down Expand Up @@ -82,7 +82,7 @@ std::unique_ptr<PagedDecodeFunc> ConvertPagedDecodeFunc(ffi::Array<ffi::Any> arg
return nullptr;
}
ffi::String backend_name = args[0].cast<ffi::String>();
if (backend_name == "tir") {
if (backend_name == "tirx" || backend_name == "tir") {
CHECK_EQ(args.size(), 2);
ffi::Function attn_func = args[1].cast<ffi::Function>();
return std::make_unique<TIRPagedDecodeFunc>(std::move(attn_func), attn_kind);
Expand All @@ -104,7 +104,7 @@ std::unique_ptr<PagedPrefillTreeMaskFunc> ConvertPagedPrefillTreeMaskFunc(ffi::A
return nullptr;
}
ffi::String backend_name = args[0].cast<ffi::String>();
if (backend_name == "tir") {
if (backend_name == "tirx" || backend_name == "tir") {
CHECK_EQ(args.size(), 2);
ffi::Function attn_func = args[1].cast<ffi::Function>();
return std::make_unique<TIRPagedPrefillTreeMaskFunc>(std::move(attn_func), attn_kind);
Expand All @@ -119,7 +119,7 @@ std::unique_ptr<RaggedPrefillTreeMaskFunc> ConvertRaggedPrefillTreeMaskFunc(
return nullptr;
}
ffi::String backend_name = args[0].cast<ffi::String>();
if (backend_name == "tir") {
if (backend_name == "tirx" || backend_name == "tir") {
CHECK_EQ(args.size(), 2);
ffi::Function attn_func = args[1].cast<ffi::Function>();
return std::make_unique<TIRRaggedPrefillTreeMaskFunc>(std::move(attn_func), attn_kind);
Expand Down