#include "set-rows.cuh"

// Per-element copy/convert callback: reads one source element at src and
// writes it (possibly type-converted) to dst. One instantiation exists per
// (src type, dst type) pair and is passed to k_set_rows as a template arg.
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
// Copy a single F32 element from src to dst (no conversion).
static __device__ void set_rows_1_f32_f32(const char * src, char * dst) {
    *(float *) dst = *(const float *) src;
}
// Convert a single F32 element to F16 and store it at dst.
static __device__ void set_rows_1_f32_f16(const char * src, char * dst) {
    const float v = *(const float *) src;
    *(half *) dst = __float2half(v);
}
// TODO: consolidate kernels from cpy.cu, get_rows etc to make this function generic
//
// Scatter rows of src0 into dst at the row positions given by the int64
// indices in src1, converting elements via set_rows_1.
//
// Launch layout: blockIdx.z / blockIdx.y walk the two outer dims of src0;
// blockIdx.x*blockDim.y + threadIdx.y selects the row; threadIdx.x strides
// over the columns of that row.
template <set_rows_kernel_t set_rows_1>
static __global__ void k_set_rows(
        const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        const size_t src_type_size, const size_t dst_type_size) {
    const int i03 = blockIdx.z;
    const int i02 = blockIdx.y;
    const int i01 = blockIdx.x*blockDim.y + threadIdx.y; // row index within src0

    // grid tail guard: the x-grid is rounded up to whole blocks of rows
    if (i01 >= ne01) {
        return;
    }

    // src1 is broadcast over the two outer dims (modulo its own extents)
    const int i12 = i03 % ne12;
    const int i11 = i02 % ne11;
    const int i10 = i01;

    // destination row for this source row, read from the index tensor
    const int64_t dst_row = *(const int64_t *)((const char *) src1 + i10*nb10 + i11*nb11 + i12*nb12);

    const char * src0_row    = src0 + i01*nb01    + i02*nb02 + i03*nb03;
    char       * dst_row_ptr = dst  + dst_row*nb1 + i02*nb2  + i03*nb3;

    // the x-dimension threads cooperate on the columns of one row
    for (int col = threadIdx.x; col < ne00; col += blockDim.x) {
        set_rows_1(src0_row + col*src_type_size, dst_row_ptr + col*dst_type_size);
    }
}
// Host-side launcher for k_set_rows.
//
// Launch configuration: up to 256 threads cooperate on one row (threadIdx.x
// over columns) and as many rows as fit into a 256-thread block are packed
// along threadIdx.y; blockIdx.y / blockIdx.z cover the two outer dims.
// Empty tensors (ne00 == 0 or ne01 == 0) launch nothing.
template <set_rows_kernel_t set_rows_1>
static void set_rows_cuda(
        const char * src0_d, const int64_t * src1_d, char * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        const size_t src_type_size, const size_t dst_type_size,
        cudaStream_t stream) {
    constexpr int max_threads_per_row   = 256;
    constexpr int max_threads_per_block = 256;

    const int threads_per_row = std::min((int) ne00, max_threads_per_row);
    const int rows_per_block  = std::max(1, max_threads_per_block / threads_per_row);

    const dim3 block_size(threads_per_row, rows_per_block, 1);
    const dim3 grid_size((ne01 + rows_per_block - 1)/rows_per_block, ne02, ne03); // ceil-div over rows

    // avoid launching a zero-sized grid for empty tensors
    if (ne01 > 0 && ne00 > 0) {
        k_set_rows<set_rows_1><<<grid_size, block_size, 0, stream>>>(
            src0_d, src1_d, dst_d,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
            nb01, nb02, nb03,
            nb10, nb11, nb12,
            nb1, nb2, nb3,
            src_type_size, dst_type_size);
    }
}
// Backend entry point for GGML_OP_SET_ROWS: dst[row_idx[i]] = src0[i].
//
// src0 must be F32 and src1 must hold I64 row indices; dst may be F32
// (straight copy) or F16 (converted per element). Dispatches the matching
// set_rows_cuda instantiation on the context's stream; aborts on any other
// destination type.
void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_I64);

    GGML_TENSOR_BINARY_OP_LOCALS

    const float   * src0_d = (const float   *) src0->data;
    const int64_t * src1_d = (const int64_t *) src1->data;

    cudaStream_t stream = ctx.stream();

    switch (dst->type) {
        case GGML_TYPE_F32:
            set_rows_cuda<set_rows_1_f32_f32>(
                (const char *) src0_d, src1_d, (char *) dst->data,
                ne00, ne01, ne02, ne03,
                ne10, ne11, ne12, ne13,
                nb01, nb02, nb03,
                nb10, nb11, nb12,
                nb1, nb2, nb3,
                sizeof(float), sizeof(float),
                stream);
            break;
        case GGML_TYPE_F16:
            set_rows_cuda<set_rows_1_f32_f16>(
                (const char *) src0_d, src1_d, (char *) dst->data,
                ne00, ne01, ne02, ne03,
                ne10, ne11, ne12, ne13,
                nb01, nb02, nb03,
                nb10, nb11, nb12,
                nb1, nb2, nb3,
                sizeof(float), sizeof(half),
                stream);
            break;
        default:
            // fixed paste artifact: message previously carried a stray leading space
            GGML_ABORT("unsupported type");
    }
}