[Attn] Fix the construction of attn result merge kernel (#1995)
This PR fixes a mistake where the number of key/value heads, rather than the number
of attention (query) heads, was passed to the attention result merge kernel.
MasterJH5574 authored Mar 21, 2024
1 parent c74f176 commit 244c2e7
1 changed file, 1 addition and 1 deletion: python/mlc_llm/nn/kv_cache.py
@@ -345,7 +345,7 @@ def __init__( # pylint: disable=too-many-locals
 bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_prefill_sliding_window"),
 bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_decode_sliding_window"),
 bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_ragged"),
-bb.add_func(_merge_state_inplace(num_key_value_heads, head_dim, dtype, target), "tir_attention_merge_state"),
+bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, dtype, target), "tir_attention_merge_state"),
 bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"),
 bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
 # fmt: on
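For context, below is a minimal NumPy sketch of what a merge-state operation computes; it is not the TIR kernel built by _merge_state_inplace in this repository, and the array shapes are assumptions for illustration. Partial attention outputs and their log-sum-exp (LSE) values are laid out per query head, which is why the kernel has to be constructed with num_attention_heads rather than num_key_value_heads.

import numpy as np

def merge_state_inplace(v, lse, v_other, lse_other):
    # v, v_other:     (seq_len, num_attention_heads, head_dim) partial attention outputs
    # lse, lse_other: (seq_len, num_attention_heads) log-sum-exp of the attention scores
    # Merge the second partial softmax-attention result into the first, in place.
    max_lse = np.maximum(lse, lse_other)  # stabilize the exponentials
    w = np.exp(lse - max_lse)
    w_other = np.exp(lse_other - max_lse)
    denom = w + w_other
    v[:] = (v * w[..., None] + v_other * w_other[..., None]) / denom[..., None]
    lse[:] = max_lse + np.log(denom)      # log-sum-exp of the combined scores
    return v, lse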
