
Commit 7de5c7c

CUDA: add set rows for f32 and f16 (#14551)
* CUDA: add set rows for f32 and f16
* Review: change kernel params, use strides from host
* Use 1-d kernel
* Review: use int64_t for blockDim.x, rename nb->s for clarity
1 parent 8eff955 commit 7de5c7c

File tree

3 files changed: +147 −0 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 10 additions & 0 deletions
@@ -43,6 +43,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/set-rows.cuh"
 #include "ggml.h"

 #include <algorithm>
@@ -2230,6 +2231,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
         case GGML_OP_GET_ROWS_BACK:
             ggml_cuda_op_get_rows_back(ctx, dst);
             break;
+        case GGML_OP_SET_ROWS:
+            ggml_cuda_op_set_rows(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_cuda_dup(ctx, dst);
             break;
@@ -3216,6 +3220,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
             {
                 return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                       op->src[0]->type == GGML_TYPE_F32 &&
+                       op->src[1]->type == GGML_TYPE_I64;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
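
The support check admits only F32 sources indexed by I64 rows, writing F32 or F16 destinations. As a point of reference, a minimal CPU sketch of the operator's semantics in the plain 2-D case (an illustrative helper, not part of the commit or of ggml): source row r is scattered to destination row idx[r].

// Illustrative CPU reference for the 2-D case: dst[idx[r], :] = src[r, :].
// set_rows_ref_2d is a hypothetical name, not ggml API.
static void set_rows_ref_2d(const float * src, const int64_t * idx, float * dst,
                            int64_t n_rows, int64_t n_cols) {
    for (int64_t r = 0; r < n_rows; ++r) {
        for (int64_t c = 0; c < n_cols; ++c) {
            dst[idx[r]*n_cols + c] = src[r*n_cols + c];
        }
    }
}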

ggml/src/ggml-cuda/set-rows.cu

Lines changed: 130 additions & 0 deletions
#include "set-rows.cuh"

typedef void (*set_rows_kernel_t)(const char * src, char * dst);

// Per-element copy/convert functor: the generic template is a no-op and the
// supported (src, dst) type pairs are provided as specializations.
template<typename src_t, typename dst_t>
__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {}

template<>
__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
    *dst_h = __float2half(*src_f);
}

template<>
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
    *dst_f = *src_f;
}
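
Supporting another destination type would take one more specialization here, plus matching cases in ggml_cuda_op_set_rows and the supports_op check. A hypothetical bf16 variant, not part of this commit, might read:

// Hypothetical, not in this commit: would additionally need <cuda_bf16.h>
// and a GGML_TYPE_BF16 branch on the host side.
template<>
__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
    *dst_b = __float2bfloat16(*src_f);
}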
template<typename src_t, typename dst_t>
static __global__ void k_set_rows(
        const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t s10, const int64_t s11, const int64_t s12,
        const int64_t s1, const int64_t s2, const int64_t s3) {

    // 1-d launch: one thread per src0 element.
    const int64_t i        = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;

    if (i >= ne_total) {
        return;
    }

    // Unflatten i into the 4-d source coordinates (i00, i01, i02, i03).
    const int64_t i03 = i / (ne00 * ne01 * ne02);
    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;

    // The row-index tensor src1 is broadcast over dims 2 and 3.
    const int64_t i12 = i03 % ne12;
    const int64_t i11 = i02 % ne11;
    const int64_t i10 = i01;

    // Destination row for source row i01, looked up in src1.
    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);

    const src_t * src0_row    = src0 + i01*s01 + i02*s02 + i03*s03;
    dst_t       * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;

    const src_t * src_elem = src0_row + i00;
    dst_t       * dst_elem = dst_row_ptr + i00;
    set_rows_1(src_elem, dst_elem);
}
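
To make the unflattening concrete, take an illustrative shape ne00=4, ne01=3, ne02=2, ne03=1 and flat index i = 17:

// ne00*ne01*ne02 = 24, ne00*ne01 = 12
// i03 = 17 / 24                 = 0
// i02 = (17 - 0*24) / 12        = 1
// i01 = (17 - 0*24 - 1*12) / 4  = 1
// i00 =  17 - 0*24 - 1*12 - 1*4 = 1
// i.e. element (i00, i01, i02, i03) = (1, 1, 1, 0), consistent with a row-major walk.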

template<typename src_t, typename dst_t>
static void set_rows_cuda(
        const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    const int64_t ne_total   = ne00 * ne01 * ne02 * ne03;
    const int     num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
    const dim3    block_size(CUDA_SET_ROWS_BLOCK_SIZE);
    const dim3    grid_size(num_blocks);

    // Convert the byte strides coming from the host into element strides.
    const int64_t s01 = nb01/sizeof(src_t);
    const int64_t s02 = nb02/sizeof(src_t);
    const int64_t s03 = nb03/sizeof(src_t);
    const int64_t s10 = nb10/sizeof(int64_t);
    const int64_t s11 = nb11/sizeof(int64_t);
    const int64_t s12 = nb12/sizeof(int64_t);
    const int64_t s1  = nb1/sizeof(dst_t);
    const int64_t s2  = nb2/sizeof(dst_t);
    const int64_t s3  = nb3/sizeof(dst_t);

    if (ne_total > 0) {
        k_set_rows<<<grid_size, block_size, 0, stream>>>(
                src0_d, src1_d, dst_d,
                ne00, ne01, ne02, ne03,
                ne10, ne11, ne12, ne13,
                s01, s02, s03,
                s10, s11, s12,
                s1, s2, s3);
    }
}
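
The nb* arguments are ggml byte strides taken from the host tensors; dividing by the element size turns them into the element strides the kernel indexes with. This assumes each stride is a multiple of the element size, which holds for the dense F32/F16/I64 tensors admitted by supports_op. For instance, for an illustrative contiguous f32 source with ne00 = 64:

// nb01 = 64 * sizeof(float) = 256 bytes  ->  s01 = 256 / 4 = 64 elements per row step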

void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_I64);

    GGML_TENSOR_BINARY_OP_LOCALS

    const float   * src0_d = (const float *)src0->data;
    const int64_t * src1_d = (const int64_t *)src1->data;

    cudaStream_t stream = ctx.stream();

    if (dst->type == GGML_TYPE_F32) {
        set_rows_cuda(
            src0_d, src1_d, (float*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else if (dst->type == GGML_TYPE_F16) {
        set_rows_cuda(
            src0_d, src1_d, (half*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            stream
        );
    } else {
        GGML_ABORT("unsupported type");
    }
}
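
For completeness, a throwaway test sketch, not part of the commit: since both functions above are static, it would have to be appended to set-rows.cu itself, and error checking is omitted. It scatters three f32 rows of width 4 into an 8-row f16 destination.

// Hypothetical test harness, assuming it is appended to set-rows.cu.
static void test_set_rows() {
    const int64_t ne00 = 4, ne01 = 3;       // 3 source rows of 4 floats
    const float   src_h[12] = { 0,1,2,3, 4,5,6,7, 8,9,10,11 };
    const int64_t idx_h[3]  = { 2, 0, 5 };  // destination row per source row
    const int64_t n_dst_rows = 8;

    float * src_d; int64_t * idx_d; half * dst_d;
    cudaMalloc(&src_d, sizeof(src_h));
    cudaMalloc(&idx_d, sizeof(idx_h));
    cudaMalloc(&dst_d, n_dst_rows*ne00*sizeof(half));
    cudaMemcpy(src_d, src_h, sizeof(src_h), cudaMemcpyHostToDevice);
    cudaMemcpy(idx_d, idx_h, sizeof(idx_h), cudaMemcpyHostToDevice);

    // Contiguous byte strides; dims 2 and 3 are collapsed to 1.
    set_rows_cuda<float, half>(src_d, idx_d, dst_d,
        ne00, ne01, 1, 1,                   // src0 shape
        ne01, 1, 1, 1,                      // src1 shape
        ne00*sizeof(float), ne01*ne00*sizeof(float), ne01*ne00*sizeof(float),
        sizeof(int64_t), ne01*sizeof(int64_t), ne01*sizeof(int64_t),
        ne00*sizeof(half), n_dst_rows*ne00*sizeof(half), n_dst_rows*ne00*sizeof(half),
        0 /* default stream */);
    cudaDeviceSynchronize();                // dst rows 2, 0 and 5 now hold the data
    cudaFree(src_d); cudaFree(idx_d); cudaFree(dst_d);
}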

ggml/src/ggml-cuda/set-rows.cuh

Lines changed: 7 additions & 0 deletions
#pragma once

#include "common.cuh"

#define CUDA_SET_ROWS_BLOCK_SIZE 256

void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
