Skip to content

Commit 27272ef

Browse files
committed
workaround: avoid ggml cuda error
1 parent 18a2804 commit 27272ef

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

ggml_extend.hpp

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -954,7 +954,16 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
954954
if (scale != 1.f) {
955955
x = ggml_scale(ctx, x, scale);
956956
}
957-
x = ggml_mul_mat(ctx, w, x);
957+
if (x->ne[2] * x->ne[3] > 1024) {
958+
// workaround: avoid ggml cuda error
959+
int64_t ne2 = x->ne[2];
960+
int64_t ne3 = x->ne[3];
961+
x = ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1]*x->ne[2]*x->ne[3]);
962+
x = ggml_mul_mat(ctx, w, x);
963+
x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1]/ne2/ne3, ne2, ne3);
964+
} else {
965+
x = ggml_mul_mat(ctx, w, x);
966+
}
958967
if (force_prec_f32) {
959968
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
960969
}

0 commit comments

Comments
 (0)