Skip to content

Commit feaeda0

Browse files
committed
Review: change kernel params, use strides from host
1 parent a93bf9f commit feaeda0

File tree

2 files changed

+54
-43
lines changed

2 files changed

+54
-43
lines changed

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,36 @@
22

33
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
44

5-
static __device__ void set_rows_1_f32_f32(const char * src, char * dst) {
6-
const float * src_f = (const float *) src;
7-
float * dst_f = (float *) dst;
8-
*dst_f = *src_f;
5+
template<typename src_t, typename dst_t>
6+
__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
7+
GGML_ABORT("unsupport type for set_rows");
98
}
109

11-
static __device__ void set_rows_1_f32_f16(const char * src, char * dst) {
12-
const float * src_f = (const float *) src;
13-
half * dst_h = (half *) dst;
10+
template<>
11+
__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
1412
*dst_h = __float2half(*src_f);
1513
}
1614

15+
template<>
16+
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
17+
*dst_f = *src_f;
18+
}
19+
1720
// TODO: consolidate kernels from cpy.cu, get_rows, etc. to make this function generic
18-
template<set_rows_kernel_t set_rows_1>
21+
template<typename src_t, typename dst_t>
1922
static __global__ void k_set_rows(
20-
const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
23+
const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
2124
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
2225
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
2326
const size_t nb01, const size_t nb02, const size_t nb03,
2427
const size_t nb10, const size_t nb11, const size_t nb12,
2528
const size_t nb1, const size_t nb2, const size_t nb3,
2629
const size_t src_type_size, const size_t dst_type_size) {
2730

28-
const int i03 = blockIdx.z;
29-
const int i02 = blockIdx.y;
30-
const int i01 = blockIdx.x * blockDim.y + threadIdx.y; // Row index
31+
const int i03 = blockIdx.z / ne02;
32+
const int i02 = blockIdx.z % ne02;
33+
const int i01 = blockDim.x * blockIdx.x + threadIdx.x;
34+
const int i00 = blockIdx.y;
3135

3236
if (i01 >= ne01) {
3337
return;
@@ -37,21 +41,19 @@ static __global__ void k_set_rows(
3741
const int i11 = i02 % ne11;
3842
const int i10 = i01;
3943

40-
const int64_t dst_row = *(int64_t *)((char *)src1 + i10*nb10 + i11*nb11 + i12*nb12);
44+
const int64_t dst_row = *(src1 + i10*nb10 + i11*nb11 + i12*nb12);
4145

42-
const char * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03;
43-
char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
46+
const src_t * src0_row = (const src_t *)src0 + i01*nb01 + i02*nb02 + i03*nb03;
47+
dst_t * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
4448

45-
for (int col = threadIdx.x; col < ne00; col += blockDim.x) {
46-
const char * src_elem = src0_row + col * src_type_size;
47-
char * dst_elem = dst_row_ptr + col * dst_type_size;
48-
set_rows_1(src_elem, dst_elem);
49-
}
49+
const src_t* src_elem = src0_row + i00;
50+
dst_t* dst_elem = dst_row_ptr + i00;
51+
set_rows_1(src_elem, dst_elem);
5052
}
5153

52-
template<set_rows_kernel_t set_rows_1>
54+
template<typename src_t, typename dst_t>
5355
static void set_rows_cuda(
54-
const char * src0_d, const int64_t * src1_d, char * dst_d,
56+
const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
5557
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
5658
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
5759
const size_t nb01, const size_t nb02, const size_t nb03,
@@ -60,32 +62,39 @@ static void set_rows_cuda(
6062
const size_t src_type_size, const size_t dst_type_size,
6163
cudaStream_t stream) {
6264

63-
const int max_threads_per_row = 256;
64-
const int threads_per_row = std::min((int)ne00, max_threads_per_row);
65-
66-
const int max_threads_per_block = 256;
67-
const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row);
68-
69-
const dim3 block_size(threads_per_row, rows_per_block, 1);
65+
const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
7066
const dim3 grid_size(
71-
(ne01 + rows_per_block - 1) / rows_per_block, // thread-groups
72-
ne02,
73-
ne03
67+
(ne01 + CUDA_SET_ROWS_BLOCK_SIZE - 1)/CUDA_SET_ROWS_BLOCK_SIZE,
68+
ne00,
69+
ne03*ne02
7470
);
7571

76-
if (ne01 > 0 && ne00 > 0) {
77-
k_set_rows<set_rows_1><<<grid_size, block_size, 0, stream>>>(
72+
const int s1 = nb01 / sizeof(src_t);
73+
const int s2 = nb02 / sizeof(src_t);
74+
const int s3 = nb03 / sizeof(src_t);
75+
76+
const int s10 = nb10 / sizeof(int64_t);
77+
const int s11 = nb11 / sizeof(int64_t);
78+
const int s12 = nb12 / sizeof(int64_t);
79+
80+
const int s_dst = nb1 / sizeof(dst_t);
81+
const int s_dst2 = nb2 / sizeof(dst_t);
82+
const int s_dst3 = nb3 / sizeof(dst_t);
83+
84+
85+
if(ne01 > 0 && ne00 > 0) {
86+
k_set_rows<<<grid_size, block_size, 0, stream>>>(
7887
src0_d, src1_d, dst_d,
7988
ne00, ne01, ne02, ne03,
8089
ne10, ne11, ne12, ne13,
81-
nb01, nb02, nb03,
82-
nb10, nb11, nb12,
83-
nb1, nb2, nb3,
84-
src_type_size, dst_type_size
85-
);
90+
s1, s2, s3,
91+
s10, s11, s12,
92+
s_dst, s_dst2, s_dst3,
93+
src_type_size, dst_type_size);
8694
}
8795
}
8896

97+
8998
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
9099
const ggml_tensor * src0 = dst->src[0];
91100
const ggml_tensor * src1 = dst->src[1];
@@ -101,8 +110,8 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
101110
cudaStream_t stream = ctx.stream();
102111

103112
if (dst->type == GGML_TYPE_F32) {
104-
set_rows_cuda<set_rows_1_f32_f32>(
105-
(const char *)src0_d, src1_d, (char *)dst->data,
113+
set_rows_cuda(
114+
src0_d, src1_d, (float*)dst->data,
106115
ne00, ne01, ne02, ne03,
107116
ne10, ne11, ne12, ne13,
108117
nb01, nb02, nb03,
@@ -112,8 +121,8 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
112121
stream
113122
);
114123
} else if (dst->type == GGML_TYPE_F16) {
115-
set_rows_cuda<set_rows_1_f32_f16>(
116-
(const char *)src0_d, src1_d, (char *)dst->data,
124+
set_rows_cuda(
125+
src0_d, src1_d, (half*)dst->data,
117126
ne00, ne01, ne02, ne03,
118127
ne10, ne11, ne12, ne13,
119128
nb01, nb02, nb03,

ggml/src/ggml-cuda/set-rows.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,6 @@
22

33
#include "common.cuh"
44

5+
#define CUDA_SET_ROWS_BLOCK_SIZE 64
6+
57
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

0 commit comments

Comments
 (0)