
Commit fb5ab8d

Fix llama4_rope_with_position_map to support partial rotary factor
1 parent d5d3d81 commit fb5ab8d
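
Note on the change: with a partial rotary factor, only the first rotary_dim of the head_dim channels in each query/key head are rotated; the remaining channels pass through unchanged, which is what the T.if_then_else(d < rotary_dim, ...) guards in the kernel express. A minimal NumPy sketch of that idea follows (the function name, pairing convention, and theta default are illustrative assumptions, not taken from the TVM kernel):

import numpy as np

def partial_rope(x, pos, rotary_dim, theta=10000.0):
    # Illustrative only: rotate the first `rotary_dim` channels of a head vector,
    # leave channels >= rotary_dim untouched.
    x = np.asarray(x, dtype=np.float32)
    half = rotary_dim // 2
    inv_freq = theta ** (-np.arange(half, dtype=np.float32) / half)  # one frequency per rotated pair
    cos, sin = np.cos(pos * inv_freq), np.sin(pos * inv_freq)
    x1, x2 = x[..., :half], x[..., half:rotary_dim]
    out = x.copy()
    out[..., :half] = x1 * cos - x2 * sin            # rotate first half of the rotary span
    out[..., half:rotary_dim] = x2 * cos + x1 * sin  # rotate second half of the rotary span
    return out                                       # channels >= rotary_dim are unchanged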

File tree

1 file changed: +93 −36 lines

python/tvm/relax/frontend/nn/llm/position_embedding.py

Lines changed: 93 additions & 36 deletions
@@ -117,7 +117,11 @@ def rope_freq_llama4(  # pylint: disable=too-many-arguments,too-many-locals
     smoothed_freq_var = tir.Var("smoothed_freq", "float32")
     cos_freq = tir.cos(smoothed_freq_var).astype(dtype)
     sin_freq = tir.sin(smoothed_freq_var).astype(dtype)
-    return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq}
+    return (
+        cos_freq,
+        sin_freq,
+        {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq},
+    )
 
 
 def rope_freq_llama3(  # pylint: disable=too-many-arguments,too-many-locals
@@ -147,7 +151,11 @@ def rope_freq_llama3(  # pylint: disable=too-many-arguments,too-many-locals
     smoothed_freq_var = tir.Var("smoothed_freq", "float32")
     cos_freq = tir.cos(smoothed_freq_var).astype(dtype)
     sin_freq = tir.sin(smoothed_freq_var).astype(dtype)
-    return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq}
+    return (
+        cos_freq,
+        sin_freq,
+        {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq},
+    )
 
 
 def rope_freq_longrope(  # pylint: disable=too-many-arguments
@@ -275,7 +283,7 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
             beta_fast=rope_scaling["beta_fast"],
             beta_slow=rope_scaling["beta_slow"],
         )
-    raise ValueError(f'Unsupported RoPE scaling type: {rope_scaling["rope_type"]}')
+    raise ValueError(f"Unsupported RoPE scaling type: {rope_scaling['rope_type']}")
 
 
 # mypy: disable-error-code="attr-defined"
@@ -570,7 +578,10 @@ def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
         # long factors is the first half, short factors is the second half
         long_factors = T.Buffer((rotary_dim // 2,), "float32", data=ext_factors.data)
         short_factors = T.Buffer(
-            (rotary_dim // 2,), "float32", data=ext_factors.data, elem_offset=(rotary_dim // 2)
+            (rotary_dim // 2,),
+            "float32",
+            data=ext_factors.data,
+            elem_offset=(rotary_dim // 2),
         )
 
         if seq_len > original_max_position_embeddings:
@@ -687,6 +698,10 @@ def llama4_rope_with_position_map(  # pylint: disable=too-many-arguments
         rotary_dim = head_dim
     scale = tir.const(scale, "float32")
     is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"
+    if is_longrope_scaling and "original_max_position_embeddings" in rope_scaling:
+        original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
+    else:
+        original_max_position_embeddings = 0
 
     def _rope(  # pylint: disable=too-many-arguments
         x: T.Buffer,
@@ -770,7 +785,7 @@ def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
         var_q: T.handle,
         var_k: T.handle,
         var_v: T.handle,
-        ext_factors: T.Buffer((rotary_dim // 2,), "float32"),  # type: ignore
+        ext_factors: T.Buffer((rotary_dim,), "float32"),  # type: ignore
     ):
         T.func_attr(
             {
@@ -787,37 +802,79 @@ def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
         position_map = T.match_buffer(
             var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
         )
-        for iters in T.grid(seq_len, fused_heads, head_dim):
-            with T.block("llama_fused_rope"):
-                s, h, d = T.axis.remap("SSS", iters)
-                if h < num_q_heads:
-                    q[s, h, d] = T.if_then_else(
-                        d < rotary_dim,
-                        _rope(
-                            qkv,
-                            s,
-                            h,
-                            d,
-                            position_map[s],
-                            ext_factors if is_longrope_scaling else None,
-                        ),
-                        qkv[s, h, d],
-                    )
-                elif h < num_q_heads + num_kv_heads:
-                    k[s, h - num_q_heads, d] = T.if_then_else(
-                        d < rotary_dim,
-                        _rope(
-                            qkv,
-                            s,
-                            h,
-                            d,
-                            position_map[s],
-                            ext_factors if is_longrope_scaling else None,
-                        ),
-                        qkv[s, h, d],
-                    )
-                else:
-                    v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+        # long factors is the first half, short factors is the second half
+        long_factors = T.Buffer((rotary_dim // 2,), "float32", data=ext_factors.data)
+        short_factors = T.Buffer(
+            (rotary_dim // 2,),
+            "float32",
+            data=ext_factors.data,
+            elem_offset=(rotary_dim // 2),
+        )
+
+        if seq_len > original_max_position_embeddings:
+            for iters in T.grid(seq_len, fused_heads, head_dim):
+                with T.block("llama_fused_rope"):
+                    s, h, d = T.axis.remap("SSS", iters)
+                    if h < num_q_heads:
+                        q[s, h, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                long_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    elif h < num_q_heads + num_kv_heads:
+                        k[s, h - num_q_heads, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                long_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    else:
+                        v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
+        else:
+            for iters in T.grid(seq_len, fused_heads, head_dim):
+                with T.block("llama_fused_rope"):
+                    s, h, d = T.axis.remap("SSS", iters)
+                    if h < num_q_heads:
+                        q[s, h, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                short_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    elif h < num_q_heads + num_kv_heads:
+                        k[s, h - num_q_heads, d] = T.if_then_else(
+                            d < rotary_dim,
+                            _rope(
+                                qkv,
+                                s,
+                                h,
+                                d,
+                                position_map[s],
+                                short_factors if is_longrope_scaling else None,
+                            ),
+                            qkv[s, h, d],
+                        )
+                    else:
+                        v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
 
     if is_longrope_scaling:
         return fused_rope_longrope_scaling
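
Reading of the last hunk (a summary, not authoritative documentation of the TVM internals): for longrope scaling, ext_factors now holds both factor sets back to back — long factors in the first rotary_dim // 2 entries, short factors in the next rotary_dim // 2 — and the kernel picks a set by comparing seq_len against original_max_position_embeddings, which falls back to 0 when the rope_scaling config omits it (so the long factors are chosen). A plain-Python sketch of that selection, with illustrative names:

from typing import Any, Dict, Sequence

def select_longrope_factors(ext_factors: Sequence[float], rotary_dim: int,
                            seq_len: int, rope_scaling: Dict[str, Any]) -> Sequence[float]:
    # ext_factors packs long factors first, short factors second (rotary_dim // 2 each).
    half = rotary_dim // 2
    long_factors, short_factors = ext_factors[:half], ext_factors[half : 2 * half]
    if rope_scaling.get("rope_type") == "longrope" and "original_max_position_embeddings" in rope_scaling:
        original_max = rope_scaling["original_max_position_embeddings"]
    else:
        original_max = 0  # mirrors the fallback added in this commit
    return long_factors if seq_len > original_max else short_factors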
