From 00f530373c6961a3070b809e572b980dc5d80766 Mon Sep 17 00:00:00 2001 From: Annanya <34443592+annanyapr@users.noreply.github.com> Date: Fri, 24 Jan 2025 10:43:49 -0500 Subject: [PATCH] [Model] Fused rope implementation for DeepSeek-v2 (#3105) This PR fuses the rope operation with the matrix transposes that were previously performed before it. The model has been tested (after compiling) with these changes. --- .../model/deepseek_v2/deepseek_v2_model.py | 44 +++++++++++-------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py index e144dd77db..566a429003 100644 --- a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py +++ b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py @@ -137,38 +137,44 @@ def forward( k: Tensor, positions: Tensor, ): - def _rope(x: te.Tensor, positions: te.Tensor): + def _rope_fused(x: te.Tensor, positions: te.Tensor): + _, _, _, d_dim = x.shape + d_dim_half = d_dim // 2 dtype = x.dtype def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var): + d1 = d // d_dim_half + d2 = d % d_dim_half + cos_freq, sin_freq, var_map = self.rope_fn( positions[s], d, self.rotary_dim, self.theta, dtype ) - cos = cos_freq * x[b, s, h, d] - sin = sin_freq * tir.if_then_else( + cos = x[b, s, h, d2 * 2 + d1] * cos_freq + + partner_d = tir.if_then_else( d < self.rotary_dim // 2, - -x[b, s, h, d + self.rotary_dim // 2], - x[b, s, h, d - self.rotary_dim // 2], + d + self.rotary_dim // 2, + d - self.rotary_dim // 2, + ) + + partner_d1 = partner_d // d_dim_half + partner_d2 = partner_d % d_dim_half + sin = ( + x[b, s, h, partner_d2 * 2 + partner_d1] + * sin_freq + * tir.if_then_else( + d < self.rotary_dim // 2, tir.const(-1, dtype), tir.const(1, dtype) + ) ) expr = cos + sin - for var, value in var_map.items(): - expr = tir.Let(var, value, expr) + for var, val in var_map.items(): + expr = tir.Let(var, val, expr) return expr return 
te.compute(x.shape, compute, name="yarn_rope") - b, s, h, d = q.shape - q = op.reshape( - op.permute_dims(op.reshape(q, (b, s, h, d // 2, 2)), [0, 1, 2, 4, 3]), (b, s, h, d) - ) - - b, s, h, d = k.shape - k = op.reshape( - op.permute_dims(op.reshape(k, (b, s, h, d // 2, 2)), [0, 1, 2, 4, 3]), (b, s, h, d) - ) - - q_embed = op.tensor_expr_op(_rope, "rope", [q, positions]) - k_embed = op.tensor_expr_op(_rope, "rope", [k, positions]) + q_embed = op.tensor_expr_op(_rope_fused, "rope", [q, positions]) + k_embed = op.tensor_expr_op(_rope_fused, "rope", [k, positions]) return q_embed, k_embed