ggml : implement REGLU/GEGLU/SWIGLU ops #14158

Merged · 23 commits · Jun 29, 2025
ggml/include/ggml.h: 69 additions, 0 deletions
@@ -520,6 +520,8 @@ extern "C" {
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
GGML_OP_OPT_STEP_ADAMW,

GGML_OP_GLU,

GGML_OP_COUNT,
};

@@ -543,6 +545,14 @@ extern "C" {
GGML_UNARY_OP_COUNT,
};

enum ggml_glu_op {
GGML_GLU_OP_REGLU,
GGML_GLU_OP_GEGLU,
GGML_GLU_OP_SWIGLU,

GGML_GLU_OP_COUNT,
};

enum ggml_object_type {
GGML_OBJECT_TYPE_TENSOR,
GGML_OBJECT_TYPE_GRAPH,
@@ -658,6 +668,7 @@ extern "C" {
GGML_API const char * ggml_op_symbol(enum ggml_op op);

GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name

GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
@@ -762,6 +773,7 @@ extern "C" {
GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);

GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor);

GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
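The GLU family is modeled as a single ggml_op with the specific variant carried on the tensor itself, mirroring how ggml routes unary ops, so ggml_get_glu_op can presumably read the variant back out of the tensor's op_params. A minimal sketch under that assumption (the function body is not part of the diff shown here):

```c
// Sketch, assuming the GLU variant is stored in op_params[0] the same way
// unary ops store theirs; ggml_get_op_params_i32 is ggml's existing internal
// accessor for int32 op parameters.
enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
    GGML_ASSERT(tensor->op == GGML_OP_GLU);
    return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
}
```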
@@ -1090,6 +1102,63 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

// gated linear unit ops
// A: n columns, r rows,
// result is n / 2 columns, r rows,
// expects gate in second half of row, unless swapped is true
GGML_API struct ggml_tensor * ggml_glu(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_glu_op op,
bool swapped);

GGML_API struct ggml_tensor * ggml_reglu(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_reglu_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_geglu(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_geglu_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_swiglu(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_swiglu_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a);

// A: n columns, r rows,
// B: n columns, r rows,
GGML_API struct ggml_tensor * ggml_glu_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
enum ggml_glu_op op);

GGML_API struct ggml_tensor * ggml_reglu_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);

GGML_API struct ggml_tensor * ggml_geglu_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);

GGML_API struct ggml_tensor * ggml_swiglu_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);

// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
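All three variants follow one recipe: an input row of width n is split into an activation half x and a gate half g (the gate sits in the second half of the row unless the _swapped variant is used, or arrives as the separate tensor b in the _split variants), and each output element is act(x) * g, so output rows have width n/2. A minimal scalar sketch of the per-element math, with GELU shown in its tanh approximation:

```c
#include <math.h>

// Scalar reference for the three gated ops: activation(x) multiplied by gate g.
static float reglu_ref (float x, float g) { return (x > 0.0f ? x : 0.0f) * g; }   // relu(x) * g
static float swiglu_ref(float x, float g) { return (x / (1.0f + expf(-x))) * g; } // silu(x) * g
static float geglu_ref (float x, float g) {                                       // gelu(x) * g, tanh approx.
    return 0.5f*x*(1.0f + tanhf(0.7978845608f*(x + 0.044715f*x*x*x))) * g;
}
```

In a graph this fuses what previously took a unary op plus an element-wise multiply. A hypothetical wiring sketch (tensor names are illustrative, and the split form is assumed to mirror the fused layout, activating a and multiplying by the gate b):

```c
#include "ggml.h"

// Hypothetical helpers: ffn_up_out is assumed [2*n_ff, n_tokens] for the fused
// form; the split inputs are each assumed [n_ff, n_tokens].
static struct ggml_tensor * ffn_fused(struct ggml_context * ctx, struct ggml_tensor * ffn_up_out) {
    return ggml_swiglu(ctx, ffn_up_out);        // silu(first half) * second half
}
static struct ggml_tensor * ffn_split(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * gate) {
    return ggml_swiglu_split(ctx, a, gate);     // assumed: silu(a) * gate
}
```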
ggml/src/ggml-cpu/ggml-cpu.c: 16 additions, 0 deletions
@@ -1949,6 +1949,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
{
ggml_compute_forward_unary(params, tensor);
} break;
case GGML_OP_GLU:
{
ggml_compute_forward_glu(params, tensor);
} break;
case GGML_OP_GET_REL_POS:
{
ggml_compute_forward_get_rel_pos(params, tensor);
@@ -2159,6 +2163,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
GGML_ABORT("fatal error");
}
break;
case GGML_OP_GLU:
switch (ggml_get_glu_op(node)) {
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
{
n_tasks = n_threads;
} break;
default:
GGML_ABORT("fatal error");
}
break;
case GGML_OP_SILU_BACK:
case GGML_OP_MUL:
case GGML_OP_DIV:
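Each output row depends only on its own input row, so the scheduler can safely hand the op n_tasks = n_threads and let every thread stride over rows independently. A simplified sketch of that partitioning for the fused SWIGLU case, assuming contiguous f32 data; the real kernel walks the tensors' byte strides (nb[]) and also covers the swapped and split layouts:

```c
#include <math.h>
#include <stdint.h>

// Illustrative row partitioning: thread ith of nth takes every nth row.
// src rows are 2*ncols wide (activation half, then gate half); dst rows are ncols wide.
static void swiglu_rows_sketch(const float * src, float * dst,
                               int64_t ncols, int64_t nrows, int ith, int nth) {
    for (int64_t r = ith; r < nrows; r += nth) {
        const float * x = src + r*2*ncols;      // activation half of the row
        const float * g = x + ncols;            // gate half of the row
        float       * d = dst + r*ncols;
        for (int64_t c = 0; c < ncols; ++c) {
            const float xv = x[c];
            d[c] = (xv / (1.0f + expf(-xv))) * g[c];   // silu(x) * g
        }
    }
}
```

In the actual backend the thread index and count come from the compute params (params->ith, params->nth); the explicit arguments here stand in for those fields.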