Skip to content

Commit f62e1d7

Browse files
Introduce gather() and scatter_add() reference functions for token-routing in Mixture-of-Experts models.
PiperOrigin-RevId: 820298662
1 parent cc9a196 commit f62e1d7

File tree

1 file changed

+154
-0
lines changed

1 file changed

+154
-0
lines changed
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# fmt: off
16+
17+
"""Ops for gather and scatter-add operations related mixture-of-experts."""
18+
19+
import jax
20+
import jax.numpy as jnp
21+
22+
23+
@jax.custom_vjp
def gather(x: jax.Array, expert_assignments: jax.Array) -> jax.Array:
  """Duplicates the rows of `x` and groups the copies by assigned expert.

  Args:
    x: `(m, d)` array.
    expert_assignments: `(m, n)` array of non-negative integers; row `i` lists
      the `n` experts that token `i` is assigned to. Usually, the values of
      this array are bounded between `[0, num_experts)`.

  Returns:
    `(m * n, d)` array in which every row of `x` appears `n` times and the
    copies are grouped by ascending expert id order.
  """
  # The forward rule also returns residuals for the backward pass; only the
  # primal output is needed here.
  gathered, _ = _gather_fwd(x, expert_assignments)
  return gathered
38+
39+
40+
@jax.custom_vjp
def scatter_add(x: jax.Array, expert_assignments: jax.Array) -> jax.Array:
  """Scatter-adds rows of `x` back to token order per the expert assignments.

  Args:
    x: `(m * n, d)` array whose rows are grouped by ascending expert id order,
      i.e. laid out like the output of `gather` for the same
      `expert_assignments`.
    expert_assignments: `(m, n)` array of non-negative integers that correspond
      to the `n` experts that each of `m` tokens is assigned to. Usually, the
      values of this array are bounded between `[0, num_experts)`.

  Returns:
    `(m, d)` array whose row `i` is the sum of the `n` rows of `x` that belong
    to token `i`.
  """
  # The forward rule also returns residuals for the backward pass; only the
  # primal output is needed here.
  return _scatter_add_fwd(x, expert_assignments)[0]
55+
56+
57+
def _gather_impl(x: jax.Array, gather_inds: jax.Array) -> jax.Array:
  """Selects rows of `x` along the leading axis using `gather_inds`."""
  selected_rows = x[gather_inds, :]
  return selected_rows
59+
60+
61+
def _scatter_add_impl(x: jax.Array, scatter_inds: jax.Array) -> jax.Array:
  """Sums, for each row of `scatter_inds`, the rows of `x` that it indexes.

  The result has one row per row of `scatter_inds`; the second axis of
  `scatter_inds` is reduced away by the sum.
  """
  flat_positions = jnp.ravel(scatter_inds)
  picked_rows = x[flat_positions, :]
  # Restore the index array's own shape in front of the feature axis so the
  # rows belonging to one output row sit together along axis 1.
  grouped_shape = scatter_inds.shape + (x.shape[-1],)
  grouped_rows = jnp.reshape(picked_rows, grouped_shape)
  return jnp.sum(grouped_rows, axis=1)
69+
70+
71+
def _gather_fwd(
    x: jax.Array, expert_assignments: jax.Array
) -> tuple[jax.Array, jax.Array]:
  """Forward rule for `gather`: returns the output and its residual.

  The residual is the `scatter_inds` array, which `_gather_bwd` uses to
  scatter-add the incoming gradient.
  """
  row_inds, residual_scatter_inds = gather_scatter_inds(expert_assignments)
  gathered = _gather_impl(x, row_inds)
  return gathered, residual_scatter_inds
76+
77+
78+
def _scatter_add_fwd(
    x: jax.Array, expert_assignments: jax.Array
) -> tuple[jax.Array, jax.Array]:
  """Forward rule for `scatter_add`: returns the output and its residual.

  The residual is the `gather_inds` array, which `_scatter_add_bwd` uses to
  gather the incoming gradient.
  """
  residual_gather_inds, position_inds = gather_scatter_inds(expert_assignments)
  combined = _scatter_add_impl(x, position_inds)
  return combined, residual_gather_inds
83+
84+
85+
def _gather_bwd(res: jax.Array, grad: jax.Array) -> tuple[jax.Array, None]:
  """Backward rule for `gather`.

  Args:
    res: The `scatter_inds` residual saved by `_gather_fwd`.
    grad: Gradient with respect to the output of `gather`.

  Returns:
    A pair of the gradient with respect to `x` and `None` for
    `expert_assignments`.
  """
  # Every input row was copied to several output rows, so its gradient is the
  # sum of the gradients of those copies.
  x_grad = _scatter_add_impl(grad, res)
  return x_grad, None
88+
89+
90+
def _scatter_add_bwd(res: jax.Array, grad: jax.Array) -> tuple[jax.Array, None]:
  """Backward rule for `scatter_add`.

  Args:
    res: The `gather_inds` residual saved by `_scatter_add_fwd`.
    grad: Gradient with respect to the output of `scatter_add`.

  Returns:
    A pair of the gradient with respect to `x` and `None` for
    `expert_assignments`.
  """
  # Each output row is a plain sum of input rows, so every contributing input
  # row receives a copy of that output row's gradient.
  x_grad = _gather_impl(grad, res)
  return x_grad, None
93+
94+
95+
# Register the custom differentiation rules. The two ops are each other's
# backward pass: `gather` back-propagates through a scatter-add and
# `scatter_add` back-propagates through a gather.
scatter_add.defvjp(_scatter_add_fwd, _scatter_add_bwd)
gather.defvjp(_gather_fwd, _gather_bwd)
97+
98+
99+
def gather_scatter_inds(
    expert_assignments: jax.Array,
) -> tuple[jax.Array, jax.Array]:
  """Builds the index arrays used by the gather and scatter-add operations.

  Example:
    # For a system with 4 experts, 3 tokens, and 2 experts per token, we might
    # have expert assignments as follows:
    expert_assignments = [
      [0, 1],  # token 0 is assigned to experts 0 & 1.
      [0, 2],  # token 1 is assigned to experts 0 & 2.
      [1, 3],  # token 2 is assigned to experts 1 & 3.
    ]

    # A valid `gather_inds` array needs to group tokens together based on
    # expert assignment. For example:
    #  - expert 0: token 0, token 1.
    #  - expert 1: token 0, token 2.
    #  - expert 2: token 1.
    #  - expert 3: token 2.
    #
    # This can be accomplished with the following `gather_inds`:
    #
    gather_inds = [0, 1, 0, 2, 1, 2]

    # A valid `scatter_inds` must scatter-add back to the original token order
    # by taking the values from the following positions and mapping them back
    # to the following original tokens:
    #  - token 0: values from indices 0 and 2.
    #  - token 1: values from indices 1 and 4.
    #  - token 2: values from indices 3 and 5.
    #
    # This can be accomplished with the following `scatter_inds`:
    #
    scatter_inds = [
      [0, 2],
      [1, 4],
      [3, 5],
    ]

  Args:
    expert_assignments: `(m, n)` array of values within `[0, num_experts)`.

  Returns:
    gather_inds: `(m * n,)` array of integers with values within `[0, m)` that
      duplicates tokens and groups them by ascending expert id order via
      `x[gather_inds, :]`.
    scatter_inds: `(m, n)` array of integers with values within `[0, m * n)`
      whose row `i` lists, in ascending order, the positions of token `i` in
      the gathered layout. It enables the scatter-add operation via
      `jnp.sum(jnp.reshape(x[jnp.ravel(scatter_inds), :], (m, n, -1)), axis=1)`
  """
  num_tokens, experts_per_token = expert_assignments.shape

  # Sorting the flattened assignments orders the (token, slot) pairs by expert
  # id. Flat position `p` belongs to token `p // n`, so the integer division
  # maps the sorted order back to token rows.
  flat_order_by_expert = jnp.argsort(jnp.ravel(expert_assignments))
  gather_inds = flat_order_by_expert // experts_per_token

  # Arg-sorting `gather_inds` lists, token by token, the gathered positions
  # that hold that token. Each token occurs exactly `n` times, so the result
  # reshapes to one row per token; each row is then sorted ascending.
  positions_by_token = jnp.reshape(
      jnp.argsort(gather_inds), (num_tokens, experts_per_token)
  )
  scatter_inds = jnp.sort(positions_by_token, axis=1)
  return gather_inds, scatter_inds

0 commit comments

Comments
 (0)