Skip to content

Commit 80212cb

Browse files
committed
issue/346: 增加CublasLt支持
1 parent 79dbccd commit 80212cb

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

src/infiniop/devices/nvidia/nvidia_common.cu

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,18 @@ infiniStatus_t Handle::Internal::useCudnn(cudaStream_t stream, const Fn<cudnnHan
4949
}
5050
#endif
5151

52+
#ifdef ENABLE_CUBLASLT_API
53+
// Runs `f` with a cuBLASLt handle borrowed from `blaslt_handles`.
//
// A handle is popped from the pool; if the pool had none, a new one is
// created with cublasLtCreate. After `f` returns, the handle is pushed back
// to the pool regardless of whether `f` succeeded, so a failing callback does
// not drop the handle.
//
// Parameters:
//   stream - not used by this function. NOTE(review): presumably callers pass
//            the stream directly to the cuBLASLt call they make inside `f`
//            (cuBLASLt takes the stream per operation) - confirm with callers.
//   f      - callback invoked with the borrowed handle.
//
// Returns:
//   The early-return value of CHECK_CUBLASLT if handle creation fails, the
//   status reported through CHECK_STATUS if `f` fails, otherwise
//   INFINI_STATUS_SUCCESS.
infiniStatus_t Handle::Internal::useCublasLt(cudaStream_t stream, const Fn<cublasLtHandle_t> &f) const {
    (void)stream; // intentionally unused, see the header comment

    auto handle = blaslt_handles.pop();
    if (!handle) {
        // Create into a local first and then assign: writing through
        // `&(*handle)` while `handle` is empty dereferences an empty
        // optional-like object and never marks it as holding a value.
        cublasLtHandle_t created{};
        CHECK_CUBLASLT(cublasLtCreate(&created));
        handle = created;
    }

    // Capture the callback result before checking it, so the handle is
    // returned to the pool even on the failure path.
    auto status = f(*handle);
    blaslt_handles.push(std::move(*handle));
    CHECK_STATUS(status);
    return INFINI_STATUS_SUCCESS;
}
62+
#endif
63+
5264
int Handle::Internal::warpSize() const { return _warp_size; }
5365
int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; }
5466
int Handle::Internal::blockSizeX() const { return _block_size[0]; }

src/infiniop/devices/nvidia/nvidia_handle.cuh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,13 @@
1111
#include <cudnn.h>
1212
#endif
1313

14+
#ifdef ENABLE_CUBLASLT_API
15+
#include <cublasLt.h>
16+
#endif
17+
1418
#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
1519
#define CHECK_CUDNN(API) CHECK_INTERNAL(API, CUDNN_STATUS_SUCCESS)
20+
#define CHECK_CUBLASLT(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
1621

1722
namespace device::nvidia {
1823

@@ -21,6 +26,9 @@ class Handle::Internal {
2126
#ifdef ENABLE_CUDNN_API
2227
Pool<cudnnHandle_t> dnn_handles;
2328
#endif
29+
#ifdef ENABLE_CUBLASLT_API
30+
Pool<cublasLtHandle_t> blaslt_handles;
31+
#endif
2432

2533
int _warp_size,
2634
_max_threads_per_block,
@@ -37,6 +45,9 @@ public:
3745
#ifdef ENABLE_CUDNN_API
3846
infiniStatus_t useCudnn(cudaStream_t stream, const Fn<cudnnHandle_t> &f) const;
3947
#endif
48+
#ifdef ENABLE_CUBLASLT_API
49+
infiniStatus_t useCublasLt(cudaStream_t stream, const Fn<cublasLtHandle_t> &f) const;
50+
#endif
4051

4152
int warpSize() const;
4253
int maxThreadsPerBlock() const;

xmake.lua

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,16 @@ if has_config("cudnn") then
6666
add_defines("ENABLE_CUDNN_API")
6767
end
6868

69+
option("cublaslt")
70+
set_default(true)
71+
set_showmenu(true)
72+
set_description("Whether to compile cublaslt for Nvidia GPU")
73+
option_end()
74+
75+
if has_config("cublaslt") then
76+
add_defines("ENABLE_CUBLASLT_API")
77+
end
78+
6979
-- 寒武纪
7080
option("cambricon-mlu")
7181
set_default(false)
@@ -244,6 +254,20 @@ target("infiniop")
244254
if has_config("iluvatar-gpu") then
245255
add_deps("infiniop-iluvatar")
246256
end
257+
if has_config("sugon-dcu") then
258+
local builddir = string.format(
259+
"build/%s/%s/%s",
260+
get_config("plat"),
261+
get_config("arch"),
262+
get_config("mode")
263+
)
264+
add_shflags("-s", "-shared", "-fPIC")
265+
add_links("cublas", "cublaslt", "cudnn", "cudadevrt", "cudart_static", "rt", "pthread", "dl")
266+
-- Using -linfiniop-nvidia will fail, manually link the target using full path
267+
add_deps("nv-gpu", {inherit = false})
268+
add_links(builddir.."/libinfiniop-nvidia.a")
269+
set_toolchains("sugon-dcu-linker")
270+
end
247271

248272
if has_config("cambricon-mlu") then
249273
add_deps("infiniop-cambricon")

0 commit comments

Comments
 (0)