7474 {F_pagedkv},
7575 kHasUnevenSplits,
7676 kMergeNumHeadGroupsSeqLenQ,
77- {F_occupancy},
78- {F_sink}>;
77+ {F_occupancy}>;
7978
8079using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
8180 typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
119118}} // anonymous namespace
120119
121120using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
122- {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, { F_spad}, {F_skpad}, {F_dpad},
121+ {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
123122 {F_dvpad}>;
124123
125124#pragma clang diagnostic push
281280"""
282281
283282FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
284- ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ( {F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
285- using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
283+ ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
284+ using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
286285
287286 // get combine kernel tile sizes
288287 using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
@@ -334,15 +333,14 @@ class FmhaFwdSplitKVApiTrait:
334333 dpad : str
335334 dvpad : str
336335 pagedkv : str
337- sink : str # sink or not
338336 bn1comb : int # tile size along v head_dim of combine kernel
339337
340338 @property
341339 def name (self ) -> str :
342340 return (
343341 f"{ self .hdim } -{ self .dtype } -{ self .mode } -{ self .bm0 } -{ self .bn0 } -{ self .bk0 } -{ self .bn0 } -{ self .bk1 } -{ self .bk0max } -"
344342 + f"{ self .vlayout } -{ self .logits } -{ self .mask } -{ self .bias } -{ self .lse } -{ self .squant } -{ self .spad } -{ self .skpad } -{ self .dpad } -"
345- + f"{ self .dvpad } -{ self .pagedkv } - { self . sink } "
343+ + f"{ self .dvpad } -{ self .pagedkv } "
346344 )
347345
348346 @property
@@ -428,7 +426,6 @@ class FmhaFwdSplitKVPipeline:
428426 F_lse : str #
429427 F_squant : str #
430428 F_pagedkv : str # t/f
431- F_sink : str # t/f
432429 F_mask : str # value from MASK_MAP
433430
434431 @property
@@ -489,10 +486,6 @@ def pad_name() -> str:
489486 n += "_pagedkv"
490487 else :
491488 n += "_npagedkv"
492- if self .F_sink == "t" :
493- n += "_sink"
494- else :
495- n += "_nsink"
496489 return n
497490
498491
@@ -575,7 +568,6 @@ def api(self) -> str:
575568 F_lse = BOOL_MAP [trait .lse ],
576569 F_squant = BOOL_MAP [trait .squant ],
577570 F_pagedkv = BOOL_MAP [trait .pagedkv ],
578- F_sink = BOOL_MAP [trait .sink ],
579571 F_scheck = trait .scheck ,
580572 F_skcheck = trait .skcheck ,
581573 F_dcheck = trait .dcheck ,
@@ -676,7 +668,6 @@ def template(self) -> str:
676668 F_squant = BOOL_MAP [self .F_pipeline .F_squant ],
677669 F_pagedkv = BOOL_MAP [self .F_pipeline .F_pagedkv ],
678670 F_occupancy = self .F_tile .F_occupancy ,
679- F_sink = BOOL_MAP [self .F_pipeline .F_sink ],
680671 F_pipeline_enum = PIPELINE_ENUM_MAP [self .F_pipeline .tag ],
681672 F_mask = get_mask_map (self .mask_impl )[self .F_pipeline .F_mask ],
682673 F_mode = MODE_MAP [self .F_mode ],
@@ -750,23 +741,19 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]:
750741 squant = "t" if dtype == "fp8" else "f"
751742 pipelines = []
752743 if dtype in ["fp16" , "bf16" ]:
753- for logits , mask , bias , pagedkv , sink in itertools .product (
754- ["t" , "f" ],
755- get_mask_map (mask_impl ).keys (),
756- BIAS_MAP .keys (),
757- ["t" , "f" ],
758- ["t" , "f" ],
744+ for logits , mask , bias , pagedkv in itertools .product (
745+ ["t" , "f" ], get_mask_map (mask_impl ).keys (), BIAS_MAP .keys (), ["t" , "f" ]
759746 ):
760- pipelines .append (Pipeline ("qr" , "row" , "f" , "t" , "f" , "f" , logits , bias , "t" , squant , pagedkv , sink , mask )) # fmt: skip
761- pipelines .append (Pipeline ("qr" , "row" , "t" , "f" , "f" , "f" , logits , bias , "t" , squant , pagedkv , sink , mask )) # fmt: skip
762- pipelines .append (Pipeline ("qr" , "row" , "t" , "t" , "f" , "f" , logits , bias , "t" , squant , pagedkv , sink , mask )) # fmt: skip
763- pipelines .append (Pipeline ("qr" , "row" , "t" , "t" , "t" , "t" , logits , bias , "t" , squant , pagedkv , sink , mask )) # fmt: skip
747+ pipelines .append (Pipeline ("qr" , "row" , "f" , "t" , "f" , "f" , logits , bias , "t" , squant , pagedkv , mask )) # fmt: skip
748+ pipelines .append (Pipeline ("qr" , "row" , "t" , "f" , "f" , "f" , logits , bias , "t" , squant , pagedkv , mask )) # fmt: skip
749+ pipelines .append (Pipeline ("qr" , "row" , "t" , "t" , "f" , "f" , logits , bias , "t" , squant , pagedkv , mask )) # fmt: skip
750+ pipelines .append (Pipeline ("qr" , "row" , "t" , "t" , "t" , "t" , logits , bias , "t" , squant , pagedkv , mask )) # fmt: skip
764751 elif dtype in ["fp8" , "bf8" ]:
765752 for logits , mask , bias in itertools .product (
766753 ["t" , "f" ], get_mask_map (mask_impl ).keys (), BIAS_MAP .keys ()
767754 ):
768- pipelines .append (Pipeline ("qr" , "row" , "f" , "f" , "f" , "f" , logits , bias , "t" , squant , "f" , "f" , mask )) # fmt: skip
769- pipelines .append (Pipeline ("qr" , "row" , "t" , "t" , "f" , "f" , logits , bias , "t" , squant , "f" , "f" , mask )) # fmt: skip
755+ pipelines .append (Pipeline ("qr" , "row" , "f" , "f" , "f" , "f" , logits , bias , "t" , squant , "f" , mask )) # fmt: skip
756+ pipelines .append (Pipeline ("qr" , "row" , "t" , "t" , "f" , "f" , logits , bias , "t" , squant , "f" , mask )) # fmt: skip
770757 elif dtype in ["fp8fp16" , "fp8bf16" ]:
771758 # TODO
772759 None
@@ -922,7 +909,6 @@ def get_fwd_splitkv_blobs(
922909 cond &= pipeline .F_vlayout == "row"
923910 cond &= pipeline .F_bias in ["no" , "alibi" ]
924911 cond &= pipeline .F_squant == "f"
925- cond &= pipeline .F_sink == "f"
926912 if not cond :
927913 continue
928914 # PyTorch integration
@@ -932,7 +918,6 @@ def get_fwd_splitkv_blobs(
932918 cond &= pipeline .F_bias in ["no" , "bias" ]
933919 cond &= pipeline .F_squant == "f"
934920 cond &= mode == "batch"
935- cond &= pipeline .F_sink == "f"
936921 if not cond :
937922 continue
938923 # Aiter(mha_varlen_fwd) integration
@@ -1091,7 +1076,6 @@ def write_blobs(
10911076 lse = kernel .F_pipeline .F_lse ,
10921077 squant = kernel .F_pipeline .F_squant ,
10931078 pagedkv = kernel .F_pipeline .F_pagedkv ,
1094- sink = kernel .F_pipeline .F_sink ,
10951079 spad = kernel .F_pipeline .F_spad ,
10961080 skpad = kernel .F_pipeline .F_skpad ,
10971081 dpad = kernel .F_pipeline .F_dpad ,
0 commit comments