Skip to content

Commit

Permalink
[Model] Fused rope implementation for DeepSeek-v2 (#3105)
Browse files Browse the repository at this point in the history
This PR fuses the RoPE operation with the matrix transposes that
were previously performed as separate steps before it. The model
has been tested (after compiling) with these changes applied.
  • Loading branch information
annanyapr authored Jan 24, 2025
1 parent 2c1001b commit 00f5303
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,38 +137,44 @@ def forward(
k: Tensor,
positions: Tensor,
):
def _rope(x: te.Tensor, positions: te.Tensor):
def _rope_fused(x: te.Tensor, positions: te.Tensor):
_, _, _, d_dim = x.shape
d_dim_half = d_dim // 2
dtype = x.dtype

def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var):
d1 = d // d_dim_half
d2 = d % d_dim_half

cos_freq, sin_freq, var_map = self.rope_fn(
positions[s], d, self.rotary_dim, self.theta, dtype
)
cos = cos_freq * x[b, s, h, d]
sin = sin_freq * tir.if_then_else(
cos = x[b, s, h, d2 * 2 + d1] * cos_freq

partner_d = tir.if_then_else(
d < self.rotary_dim // 2,
-x[b, s, h, d + self.rotary_dim // 2],
x[b, s, h, d - self.rotary_dim // 2],
d + self.rotary_dim // 2,
d - self.rotary_dim // 2,
)

partner_d1 = partner_d // d_dim_half
partner_d2 = partner_d % d_dim_half
sin = (
x[b, s, h, partner_d2 * 2 + partner_d1]
* sin_freq
* tir.if_then_else(
d < self.rotary_dim // 2, tir.const(-1, dtype), tir.const(1, dtype)
)
)
expr = cos + sin
for var, value in var_map.items():
expr = tir.Let(var, value, expr)
for var, val in var_map.items():
expr = tir.Let(var, val, expr)
return expr

return te.compute(x.shape, compute, name="yarn_rope")

b, s, h, d = q.shape
q = op.reshape(
op.permute_dims(op.reshape(q, (b, s, h, d // 2, 2)), [0, 1, 2, 4, 3]), (b, s, h, d)
)

b, s, h, d = k.shape
k = op.reshape(
op.permute_dims(op.reshape(k, (b, s, h, d // 2, 2)), [0, 1, 2, 4, 3]), (b, s, h, d)
)

q_embed = op.tensor_expr_op(_rope, "rope", [q, positions])
k_embed = op.tensor_expr_op(_rope, "rope", [k, positions])
q_embed = op.tensor_expr_op(_rope_fused, "rope", [q, positions])
k_embed = op.tensor_expr_op(_rope_fused, "rope", [k, positions])
return q_embed, k_embed


Expand Down

0 comments on commit 00f5303

Please sign in to comment.