Skip to content

Commit

Permalink
[Deepseek-V2] Fused rope implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
annanyapr committed Jan 22, 2025
1 parent a175d44 commit 8eac411
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,39 +136,41 @@ def forward(
  q: Tensor,
  k: Tensor,
  positions: Tensor,
- ):
- def _rope(x: te.Tensor, positions: te.Tensor):
+ ):
+ def _rope_fused(x: te.Tensor, positions: te.Tensor):
+ _, _, _, d_dim = x.shape
+ d_dim_half = d_dim // 2
  dtype = x.dtype

  def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var):
- cos_freq, sin_freq, var_map = self.rope_fn(
- positions[s], d, self.rotary_dim, self.theta, dtype
+ d1 = d // d_dim_half
+ d2 = d % d_dim_half
+
+ cos_freq, sin_freq, var_map = self.rope_fn(positions[s], d, self.rotary_dim, self.theta, dtype)
+ cos = x[b, s, h, d2 * 2 + d1] * cos_freq
+
+ partner_d = tir.if_then_else(
+ d < self.rotary_dim // 2,
+ d + self.rotary_dim // 2,
+ d - self.rotary_dim // 2,
  )
- cos = cos_freq * x[b, s, h, d]
- sin = sin_freq * tir.if_then_else(
+
+ partner_d1 = partner_d // d_dim_half
+ partner_d2 = partner_d % d_dim_half
+ sin = x[b, s, h, partner_d2 * 2 + partner_d1] * sin_freq * tir.if_then_else(
  d < self.rotary_dim // 2,
- -x[b, s, h, d + self.rotary_dim // 2],
- x[b, s, h, d - self.rotary_dim // 2],
+ tir.const(-1, dtype),
+ tir.const(1, dtype)
  )
  expr = cos + sin
- for var, value in var_map.items():
- expr = tir.Let(var, value, expr)
+ for var, val in var_map.items():
+ expr = tir.Let(var, val, expr)
  return expr

  return te.compute(x.shape, compute, name="yarn_rope")

- b, s, h, d = q.shape
- q = op.reshape(
- op.permute_dims(op.reshape(q, (b, s, h, d // 2, 2)), [0, 1, 2, 4, 3]), (b, s, h, d)
- )
-
- b, s, h, d = k.shape
- k = op.reshape(
- op.permute_dims(op.reshape(k, (b, s, h, d // 2, 2)), [0, 1, 2, 4, 3]), (b, s, h, d)
- )
-
- q_embed = op.tensor_expr_op(_rope, "rope", [q, positions])
- k_embed = op.tensor_expr_op(_rope, "rope", [k, positions])
+ q_embed = op.tensor_expr_op(_rope_fused, "rope", [q, positions])
+ k_embed = op.tensor_expr_op(_rope_fused, "rope", [k, positions])
  return q_embed, k_embed


Expand Down

0 comments on commit 8eac411

Please sign in to comment.