@@ -13,6 +13,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
     return __half2float(val);
 }
 
+struct soft_max_params {
+
+    int64_t nheads;
+    uint32_t n_head_log2;
+    int64_t ncols;
+    int64_t nrows_x;
+    int64_t nrows_y;
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    int64_t nb11;
+    int64_t nb12;
+    int64_t nb13;
+
+    int64_t ne12;
+    int64_t ne13;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+};
+
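Review note: packing what used to be ten scalar kernel arguments into one trivially-copyable struct keeps every launch site in sync as fields are added. Because the struct is passed by value, the CUDA runtime copies it into kernel parameter space at launch time; no separate device copy is needed. A minimal standalone sketch of the pattern, with all names hypothetical (not from this patch):

// Sketch only: by-value struct parameters for a CUDA kernel.
#include <cstdint>
#include <cstdio>

struct demo_params {
    int64_t n;
    float   scale;
};

__global__ void scale_kernel(const float * x, float * y, const demo_params p) {
    // every thread reads the struct straight from kernel parameter space
    const int64_t i = blockIdx.x*(int64_t) blockDim.x + threadIdx.x;
    if (i < p.n) {
        y[i] = x[i]*p.scale;
    }
}

int main() {
    const demo_params p = {1024, 0.5f};
    float * x;
    float * y;
    cudaMalloc(&x, p.n*sizeof(float));
    cudaMalloc(&y, p.n*sizeof(float));
    scale_kernel<<<(p.n + 255)/256, 256>>>(x, y, p);
    cudaDeviceSynchronize();
    printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(x);
    cudaFree(y);
    return 0;
}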
 // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
 // As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here.
 #ifdef __clang__
@@ -21,24 +44,32 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
-        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+        const float * x, const T * mask, float * dst, const soft_max_params p) {
+    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
 
     const int tid = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int64_t i03 = blockIdx.z;
+    const int64_t i02 = blockIdx.y;
+    const int64_t i01 = blockIdx.x;
+
+    // TODO: noncontiguous inputs/outputs
+    const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
+
+    const int64_t i11 = i01;
+    const int64_t i12 = i02 % p.ne12;
+    const int64_t i13 = i03 % p.ne13;
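The modulo indexing is what implements ggml-style broadcasting: the mask may have fewer heads (ne12) or batches (ne13) than src0 (ne02/ne03), and mask indices wrap around. rowx meanwhile flattens the 3D block index back into the old linear row number for the still-contiguous x/dst pointers. A toy illustration of the wrap-around, with made-up shapes:

// Toy illustration (hypothetical shapes): broadcasting a 2-head mask over
// 8 src0 heads via the same modulo the kernel uses.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne02 = 8; // src0 heads
    const int64_t ne12 = 2; // mask heads
    for (int64_t i02 = 0; i02 < ne02; ++i02) {
        printf("src0 head %2lld reads mask head %lld\n",
               (long long) i02, (long long) (i02 % ne12));
    }
    return 0;
}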
 
     x    += int64_t(rowx)*ncols;
-    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
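Two details in this line: ggml strides (nb11..nb13) are byte strides, hence the division by sizeof(T) before adding them to a typed pointer, and the trailing multiply by (mask != nullptr) zeroes the offset branchlessly when no mask tensor was given. A small sanity check of the byte-to-element conversion, with invented values:

// Sanity check (invented values): an fp16 mask row of 64 elements has a row
// stride of nb11 = 128 bytes, so row 3 starts 192 elements into the tensor.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne10 = 64;                    // mask row length
    const int64_t nb11 = ne10*sizeof(uint16_t); // byte stride between rows (fp16)
    const int64_t i11  = 3;
    printf("element offset = %lld\n",
           (long long) (i11*nb11 / (int64_t) sizeof(uint16_t)));
    return 0;
}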
     dst  += int64_t(rowx)*ncols;
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
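With the 3D grid, the head index for the ALiBi slope is simply i02 (blockIdx.y), where the old code had to recover it as rowx/nrows_y. For orientation, ggml's get_alibi_slope helper has roughly the following shape; this is a paraphrase from memory, so consult common.cuh for the authoritative definition:

// Rough paraphrase of ggml's get_alibi_slope() -- not the actual definition.
static __device__ __forceinline__ float get_alibi_slope_sketch(
        const float max_bias, const uint32_t h, const uint32_t n_head_log2,
        const float m0, const float m1) {
    if (max_bias <= 0.0f) {
        return 1.0f;                                  // ALiBi disabled
    }
    const float base = h < n_head_log2 ? m0 : m1;     // two geometric series
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
    return powf(base, exph);                          // per-head slope
}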
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -55,7 +86,7 @@ static __global__ void soft_max_f32(
             break;
         }
 
-        const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
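This is the only line of the main loop the patch touches: the per-element transform is x*scale + slope*mask, and everything after it is an ordinary max-subtracted softmax. A plain C++ single-row reference can serve as an oracle when testing the kernel; this is a hypothetical helper, not part of the patch:

// CPU reference for one row of masked, scaled softmax (test oracle).
#include <algorithm>
#include <cmath>
#include <vector>

void soft_max_ref(const float * x, const float * mask, float * dst,
                  const int ncols, const float scale, const float slope) {
    std::vector<float> vals(ncols);
    float max_val = -INFINITY;
    for (int i = 0; i < ncols; ++i) {
        vals[i] = x[i]*scale + (mask ? slope*mask[i] : 0.0f);
        max_val = std::max(max_val, vals[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {
        vals[i] = std::exp(vals[i] - max_val); // subtract max for numerical stability
        sum += vals[i];
    }
    for (int i = 0; i < ncols; ++i) {
        dst[i] = vals[i]/sum;
    }
}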
@@ -151,63 +182,60 @@ static __global__ void soft_max_back_f32(
 }
 
 template <typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
+    const int64_t ncols_x = params.ncols;
+
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
-    const dim3 block_nums(nrows_x, 1, 1);
+    const dim3 block_nums(params.ne01, params.ne02, params.ne03);
     const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
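The launch sizing itself is unchanged except for the grid: nth is rounded up to the next power of two, capped at CUDA_SOFT_MAX_BLOCK_SIZE, and the shared buffer holds the warp-padded row plus one float per warp for the inter-warp reduction. With illustrative numbers (invented, not from the patch):

// Illustrative sizing: ncols_x = 4096, WARP_SIZE = 32.
//   nth: 32 -> 64 -> ... -> 1024      (capped at CUDA_SOFT_MAX_BLOCK_SIZE)
//   nbytes_shared = (GGML_PAD(4096, 32) + 32)*sizeof(float)
//                 = (4096 + 32)*4
//                 = 16512 bytes       (well under the common 48 KiB default)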
 
-    const uint32_t n_head = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
                 soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 64:
                 soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 128:
                 soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 256:
                 soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 512:
                 soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 1024:
                 soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 2048:
                 soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 4096:
                 soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             default:
                 soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
         }
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
     }
 }
 
@@ -235,10 +263,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
     const int64_t nrows_y = src0->ne[1];
 
+    const int64_t ne00 = src0->ne[0];
+
     float scale    = 1.0f;
     float max_bias = 0.0f;
@@ -247,10 +276,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
+
+    const uint32_t n_head      = src0->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
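The slope parameters are now computed on the host from src0->ne[2] directly, where the old code recovered the head count as nrows_x/nrows_y. For concreteness, with illustrative numbers (not from the patch):

// Illustrative: n_head = 32, max_bias = 8.0
//   n_head_log2 = 1 << floor(log2(32)) = 32
//   m0 = 2^(-8.0     / 32) = 2^-0.25  ~= 0.8409
//   m1 = 2^(-(8.0/2) / 32) = 2^-0.125 ~= 0.9170
// Heads 0..31 then get slopes m0^(h+1); heads at or beyond n_head_log2
// (none in this example) would fall back to odd powers of m1.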
+
+
+    soft_max_params params = {};
+    params.nheads      = src0->ne[2];
+    params.n_head_log2 = n_head_log2;
+    params.ncols       = ne00;
+    params.nrows_x     = nrows_x;
+    params.nrows_y     = nrows_y;
+    params.ne00        = src0->ne[0];
+    params.ne01        = src0->ne[1];
+    params.ne02        = src0->ne[2];
+    params.ne03        = src0->ne[3];
+    params.nb11        = nb11;
+    params.nb12        = nb12;
+    params.nb13        = nb13;
+    params.ne12        = ne12;
+    params.ne13        = ne13;
+    params.scale       = scale;
+    params.max_bias    = max_bias;
+    params.m0          = m0;
+    params.m1          = m1;
+
+
250
313
if (use_f16) {
251
- soft_max_f32_cuda (src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias , stream);
314
+ soft_max_f32_cuda (src0_d, (const half *) src1_d, dst_d, params , stream);
252
315
} else {
253
- soft_max_f32_cuda (src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias , stream);
316
+ soft_max_f32_cuda (src0_d, (const float *) src1_d, dst_d, params , stream);
254
317
}
255
318
}
256
319