@@ -117,7 +117,11 @@ def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals
117117 smoothed_freq_var = tir .Var ("smoothed_freq" , "float32" )
118118 cos_freq = tir .cos (smoothed_freq_var ).astype (dtype )
119119 sin_freq = tir .sin (smoothed_freq_var ).astype (dtype )
120- return cos_freq , sin_freq , {smoothed_freq_var : smoothed_freq , orig_freq_var : orig_freq }
120+ return (
121+ cos_freq ,
122+ sin_freq ,
123+ {smoothed_freq_var : smoothed_freq , orig_freq_var : orig_freq },
124+ )
121125
122126
123127def rope_freq_llama3 ( # pylint: disable=too-many-arguments,too-many-locals
@@ -147,7 +151,11 @@ def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals
147151 smoothed_freq_var = tir .Var ("smoothed_freq" , "float32" )
148152 cos_freq = tir .cos (smoothed_freq_var ).astype (dtype )
149153 sin_freq = tir .sin (smoothed_freq_var ).astype (dtype )
150- return cos_freq , sin_freq , {smoothed_freq_var : smoothed_freq , orig_freq_var : orig_freq }
154+ return (
155+ cos_freq ,
156+ sin_freq ,
157+ {smoothed_freq_var : smoothed_freq , orig_freq_var : orig_freq },
158+ )
151159
152160
153161def rope_freq_longrope ( # pylint: disable=too-many-arguments
@@ -275,7 +283,7 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
275283 beta_fast = rope_scaling ["beta_fast" ],
276284 beta_slow = rope_scaling ["beta_slow" ],
277285 )
278- raise ValueError (f' Unsupported RoPE scaling type: { rope_scaling [" rope_type" ] } ' )
286+ raise ValueError (f" Unsupported RoPE scaling type: { rope_scaling [' rope_type' ] } " )
279287
280288
281289# mypy: disable-error-code="attr-defined"
@@ -570,7 +578,10 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
570578 # long factors is the first half, short factors is the second half
571579 long_factors = T .Buffer ((rotary_dim // 2 ,), "float32" , data = ext_factors .data )
572580 short_factors = T .Buffer (
573- (rotary_dim // 2 ,), "float32" , data = ext_factors .data , elem_offset = (rotary_dim // 2 )
581+ (rotary_dim // 2 ,),
582+ "float32" ,
583+ data = ext_factors .data ,
584+ elem_offset = (rotary_dim // 2 ),
574585 )
575586
576587 if seq_len > original_max_position_embeddings :
@@ -687,6 +698,10 @@ def llama4_rope_with_position_map( # pylint: disable=too-many-arguments
687698 rotary_dim = head_dim
688699 scale = tir .const (scale , "float32" )
689700 is_longrope_scaling = rope_scaling .get ("rope_type" ) == "longrope"
701+ if is_longrope_scaling and "original_max_position_embeddings" in rope_scaling :
702+ original_max_position_embeddings = rope_scaling ["original_max_position_embeddings" ]
703+ else :
704+ original_max_position_embeddings = 0
690705
691706 def _rope ( # pylint: disable=too-many-arguments
692707 x : T .Buffer ,
@@ -770,7 +785,7 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
770785 var_q : T .handle ,
771786 var_k : T .handle ,
772787 var_v : T .handle ,
773- ext_factors : T .Buffer ((rotary_dim // 2 ,), "float32" ), # type: ignore
788+ ext_factors : T .Buffer ((rotary_dim ,), "float32" ), # type: ignore
774789 ):
775790 T .func_attr (
776791 {
@@ -787,37 +802,79 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
787802 position_map = T .match_buffer (
788803 var_position_map , (seq_len ,), "int32" , elem_offset = position_map_elem_offset
789804 )
790- for iters in T .grid (seq_len , fused_heads , head_dim ):
791- with T .block ("llama_fused_rope" ):
792- s , h , d = T .axis .remap ("SSS" , iters )
793- if h < num_q_heads :
794- q [s , h , d ] = T .if_then_else (
795- d < rotary_dim ,
796- _rope (
797- qkv ,
798- s ,
799- h ,
800- d ,
801- position_map [s ],
802- ext_factors if is_longrope_scaling else None ,
803- ),
804- qkv [s , h , d ],
805- )
806- elif h < num_q_heads + num_kv_heads :
807- k [s , h - num_q_heads , d ] = T .if_then_else (
808- d < rotary_dim ,
809- _rope (
810- qkv ,
811- s ,
812- h ,
813- d ,
814- position_map [s ],
815- ext_factors if is_longrope_scaling else None ,
816- ),
817- qkv [s , h , d ],
818- )
819- else :
820- v [s , h - (num_q_heads + num_kv_heads ), d ] = qkv [s , h , d ]
805+ # long factors is the first half, short factors is the second half
806+ long_factors = T .Buffer ((rotary_dim // 2 ,), "float32" , data = ext_factors .data )
807+ short_factors = T .Buffer (
808+ (rotary_dim // 2 ,),
809+ "float32" ,
810+ data = ext_factors .data ,
811+ elem_offset = (rotary_dim // 2 ),
812+ )
813+
814+ if seq_len > original_max_position_embeddings :
815+ for iters in T .grid (seq_len , fused_heads , head_dim ):
816+ with T .block ("llama_fused_rope" ):
817+ s , h , d = T .axis .remap ("SSS" , iters )
818+ if h < num_q_heads :
819+ q [s , h , d ] = T .if_then_else (
820+ d < rotary_dim ,
821+ _rope (
822+ qkv ,
823+ s ,
824+ h ,
825+ d ,
826+ position_map [s ],
827+ long_factors if is_longrope_scaling else None ,
828+ ),
829+ qkv [s , h , d ],
830+ )
831+ elif h < num_q_heads + num_kv_heads :
832+ k [s , h - num_q_heads , d ] = T .if_then_else (
833+ d < rotary_dim ,
834+ _rope (
835+ qkv ,
836+ s ,
837+ h ,
838+ d ,
839+ position_map [s ],
840+ long_factors if is_longrope_scaling else None ,
841+ ),
842+ qkv [s , h , d ],
843+ )
844+ else :
845+ v [s , h - (num_q_heads + num_kv_heads ), d ] = qkv [s , h , d ]
846+ else :
847+ for iters in T .grid (seq_len , fused_heads , head_dim ):
848+ with T .block ("llama_fused_rope" ):
849+ s , h , d = T .axis .remap ("SSS" , iters )
850+ if h < num_q_heads :
851+ q [s , h , d ] = T .if_then_else (
852+ d < rotary_dim ,
853+ _rope (
854+ qkv ,
855+ s ,
856+ h ,
857+ d ,
858+ position_map [s ],
859+ short_factors if is_longrope_scaling else None ,
860+ ),
861+ qkv [s , h , d ],
862+ )
863+ elif h < num_q_heads + num_kv_heads :
864+ k [s , h - num_q_heads , d ] = T .if_then_else (
865+ d < rotary_dim ,
866+ _rope (
867+ qkv ,
868+ s ,
869+ h ,
870+ d ,
871+ position_map [s ],
872+ short_factors if is_longrope_scaling else None ,
873+ ),
874+ qkv [s , h , d ],
875+ )
876+ else :
877+ v [s , h - (num_q_heads + num_kv_heads ), d ] = qkv [s , h , d ]
821878
822879 if is_longrope_scaling :
823880 return fused_rope_longrope_scaling
0 commit comments