Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 93 additions & 36 deletions python/tvm/relax/frontend/nn/llm/position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals
smoothed_freq_var = tir.Var("smoothed_freq", "float32")
cos_freq = tir.cos(smoothed_freq_var).astype(dtype)
sin_freq = tir.sin(smoothed_freq_var).astype(dtype)
return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq}
return (
cos_freq,
sin_freq,
{smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq},
)


def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals
Expand Down Expand Up @@ -147,7 +151,11 @@ def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals
smoothed_freq_var = tir.Var("smoothed_freq", "float32")
cos_freq = tir.cos(smoothed_freq_var).astype(dtype)
sin_freq = tir.sin(smoothed_freq_var).astype(dtype)
return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq}
return (
cos_freq,
sin_freq,
{smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq},
)


def rope_freq_longrope( # pylint: disable=too-many-arguments
Expand Down Expand Up @@ -275,7 +283,7 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
beta_fast=rope_scaling["beta_fast"],
beta_slow=rope_scaling["beta_slow"],
)
raise ValueError(f'Unsupported RoPE scaling type: {rope_scaling["rope_type"]}')
raise ValueError(f"Unsupported RoPE scaling type: {rope_scaling['rope_type']}")


# mypy: disable-error-code="attr-defined"
Expand Down Expand Up @@ -570,7 +578,10 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
# long factors is the first half, short factors is the second half
long_factors = T.Buffer((rotary_dim // 2,), "float32", data=ext_factors.data)
short_factors = T.Buffer(
(rotary_dim // 2,), "float32", data=ext_factors.data, elem_offset=(rotary_dim // 2)
(rotary_dim // 2,),
"float32",
data=ext_factors.data,
elem_offset=(rotary_dim // 2),
)

if seq_len > original_max_position_embeddings:
Expand Down Expand Up @@ -687,6 +698,10 @@ def llama4_rope_with_position_map( # pylint: disable=too-many-arguments
rotary_dim = head_dim
scale = tir.const(scale, "float32")
is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"
if is_longrope_scaling and "original_max_position_embeddings" in rope_scaling:
original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
else:
original_max_position_embeddings = 0

def _rope( # pylint: disable=too-many-arguments
x: T.Buffer,
Expand Down Expand Up @@ -770,7 +785,7 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
var_q: T.handle,
var_k: T.handle,
var_v: T.handle,
ext_factors: T.Buffer((rotary_dim // 2,), "float32"), # type: ignore
ext_factors: T.Buffer((rotary_dim,), "float32"), # type: ignore
):
T.func_attr(
{
Expand All @@ -787,37 +802,79 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
position_map = T.match_buffer(
var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
)
for iters in T.grid(seq_len, fused_heads, head_dim):
with T.block("llama_fused_rope"):
s, h, d = T.axis.remap("SSS", iters)
if h < num_q_heads:
q[s, h, d] = T.if_then_else(
d < rotary_dim,
_rope(
qkv,
s,
h,
d,
position_map[s],
ext_factors if is_longrope_scaling else None,
),
qkv[s, h, d],
)
elif h < num_q_heads + num_kv_heads:
k[s, h - num_q_heads, d] = T.if_then_else(
d < rotary_dim,
_rope(
qkv,
s,
h,
d,
position_map[s],
ext_factors if is_longrope_scaling else None,
),
qkv[s, h, d],
)
else:
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
# long factors is the first half, short factors is the second half
long_factors = T.Buffer((rotary_dim // 2,), "float32", data=ext_factors.data)
short_factors = T.Buffer(
(rotary_dim // 2,),
"float32",
data=ext_factors.data,
elem_offset=(rotary_dim // 2),
)

if seq_len > original_max_position_embeddings:
for iters in T.grid(seq_len, fused_heads, head_dim):
with T.block("llama_fused_rope"):
s, h, d = T.axis.remap("SSS", iters)
if h < num_q_heads:
q[s, h, d] = T.if_then_else(
d < rotary_dim,
_rope(
qkv,
s,
h,
d,
position_map[s],
long_factors if is_longrope_scaling else None,
),
qkv[s, h, d],
)
elif h < num_q_heads + num_kv_heads:
k[s, h - num_q_heads, d] = T.if_then_else(
d < rotary_dim,
_rope(
qkv,
s,
h,
d,
position_map[s],
long_factors if is_longrope_scaling else None,
),
qkv[s, h, d],
)
else:
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
else:
for iters in T.grid(seq_len, fused_heads, head_dim):
with T.block("llama_fused_rope"):
s, h, d = T.axis.remap("SSS", iters)
if h < num_q_heads:
q[s, h, d] = T.if_then_else(
d < rotary_dim,
_rope(
qkv,
s,
h,
d,
position_map[s],
short_factors if is_longrope_scaling else None,
),
qkv[s, h, d],
)
elif h < num_q_heads + num_kv_heads:
k[s, h - num_q_heads, d] = T.if_then_else(
d < rotary_dim,
_rope(
qkv,
s,
h,
d,
position_map[s],
short_factors if is_longrope_scaling else None,
),
qkv[s, h, d],
)
else:
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]

if is_longrope_scaling:
return fused_rope_longrope_scaling
Expand Down
Loading