Skip to content

Commit 9fa4e8d

Browse files
LJ-underdog and Copilot authored
Add attn sink (#2892)
* enable attn sink Signed-off-by: JL-underdog <[email protected]> * update attn_sink script Signed-off-by: JL-underdog <[email protected]> * fix some error Signed-off-by: JL-underdog <[email protected]> * clang-format Signed-off-by: JL-underdog <[email protected]> * update fmha_bwd mask Signed-off-by: JL-underdog <[email protected]> * update fmha_bwd_kernel's mask Signed-off-by: JL-underdog <[email protected]> * update block_fmha_pipeline_qr_ks_vs.hpp Signed-off-by: JL-underdog <[email protected]> * fix ci error Signed-off-by: LJ-underdog <[email protected]> * fix format error Signed-off-by: LJ-underdog <[email protected]> * Update block_fmha_bwd_pipeline_default_policy.hpp * Update fmha_fwd_runner.hpp * Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp * Update fmha_fwd_runner.hpp * Update fmha_fwd_runner.hpp * Update fmha_fwd_runner.hpp * update splitkv_pipeline Signed-off-by: LJ-underdog <[email protected]> * update splitkv&pagedkv pipeline Signed-off-by: LJ-underdog <[email protected]> * add sink test Signed-off-by: LJ-underdog <[email protected]> * update attn_sink result log Signed-off-by: LJ-underdog <[email protected]> * update smoke_test_fwd_sink.sh Signed-off-by: LJ-underdog <[email protected]> * update test file Signed-off-by: LJ-underdog <[email protected]> * update test script Signed-off-by: LJ-underdog <[email protected]> * Update block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp * use constexpr kHasSink for sink in fmha pipeline Signed-off-by: Linjun-AMD <[email protected]> * update by pre-commit Signed-off-by: Linjun-AMD <[email protected]> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <[email protected]> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <[email protected]> * Update include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp Co-authored-by: Copilot <[email protected]> * Update fmha_fwd.py * Update 
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py Co-authored-by: Copilot <[email protected]> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp Co-authored-by: Copilot <[email protected]> * Remove causal mask setting logic from mask.hpp Removed the mask setting logic for causal masks. * fix ci error that some usage of lambda not supported in c++17 Signed-off-by: LJ-underdog <[email protected]> * Update remod.py * add smoke sink test Signed-off-by: LJ-underdog <[email protected]> * Update fmha_pagedkv_prefill.py * Update FmhaFwdPipeline parameters in fmha_fwd.py * update block_fmha_pipeline_qr_ks_vs_async_trload.hpp Signed-off-by: LJ-underdog <[email protected]> * fix c++17 unsupported error Signed-off-by: LJ-underdog <[email protected]> * Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp * Fix formatting of sink_seq_end assignment * Fix indentation for sink_seq_end assignment * Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp --------- Signed-off-by: JL-underdog <[email protected]> Signed-off-by: LJ-underdog <[email protected]> Signed-off-by: Linjun-AMD <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 84540ed commit 9fa4e8d

25 files changed

+940
-195
lines changed

example/ck_tile/01_fmha/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS
6262
# there is no corresponding instance for parameters).
6363
if(BUILD_TESTING)
6464
# Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill
65-
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv)
65+
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*_nsink*,*@*_nlogits*_nbias*_nsink*,*,*_nlogits*_nskip*_pagedkv*)
6666
endif()
6767

6868
# generate a list of kernels, but not actually emit files at config sta

example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

Lines changed: 46 additions & 28 deletions
Large diffs are not rendered by default.

example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@
7474
{F_pagedkv},
7575
kHasUnevenSplits,
7676
kMergeNumHeadGroupsSeqLenQ,
77-
{F_occupancy}>;
77+
{F_occupancy},
78+
{F_sink}>;
7879
7980
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
8081
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
@@ -118,7 +119,7 @@
118119
}} // anonymous namespace
119120
120121
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
121-
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
122+
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad},
122123
{F_dvpad}>;
123124
124125
#pragma clang diagnostic push
@@ -280,8 +281,8 @@
280281
"""
281282

282283
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
283-
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
284-
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
284+
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
285+
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
285286
286287
// get combine kernel tile sizes
287288
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
@@ -333,14 +334,15 @@ class FmhaFwdSplitKVApiTrait:
333334
dpad: str
334335
dvpad: str
335336
pagedkv: str
337+
sink: str # sink or not
336338
bn1comb: int # tile size along v head_dim of combine kernel
337339

338340
@property
339341
def name(self) -> str:
340342
return (
341343
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
342344
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-"
343-
+ f"{self.dvpad}-{self.pagedkv}"
345+
+ f"{self.dvpad}-{self.pagedkv}-{self.sink}"
344346
)
345347

346348
@property
@@ -426,6 +428,7 @@ class FmhaFwdSplitKVPipeline:
426428
F_lse: str #
427429
F_squant: str #
428430
F_pagedkv: str # t/f
431+
F_sink: str # t/f
429432
F_mask: str # value from MASK_MAP
430433

431434
@property
@@ -486,6 +489,10 @@ def pad_name() -> str:
486489
n += "_pagedkv"
487490
else:
488491
n += "_npagedkv"
492+
if self.F_sink == "t":
493+
n += "_sink"
494+
else:
495+
n += "_nsink"
489496
return n
490497

491498

@@ -568,6 +575,7 @@ def api(self) -> str:
568575
F_lse=BOOL_MAP[trait.lse],
569576
F_squant=BOOL_MAP[trait.squant],
570577
F_pagedkv=BOOL_MAP[trait.pagedkv],
578+
F_sink=BOOL_MAP[trait.sink],
571579
F_scheck=trait.scheck,
572580
F_skcheck=trait.skcheck,
573581
F_dcheck=trait.dcheck,
@@ -668,6 +676,7 @@ def template(self) -> str:
668676
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
669677
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
670678
F_occupancy=self.F_tile.F_occupancy,
679+
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
671680
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
672681
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
673682
F_mode=MODE_MAP[self.F_mode],
@@ -741,19 +750,23 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]:
741750
squant = "t" if dtype == "fp8" else "f"
742751
pipelines = []
743752
if dtype in ["fp16", "bf16"]:
744-
for logits, mask, bias, pagedkv in itertools.product(
745-
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]
753+
for logits, mask, bias, pagedkv, sink in itertools.product(
754+
["t", "f"],
755+
get_mask_map(mask_impl).keys(),
756+
BIAS_MAP.keys(),
757+
["t", "f"],
758+
["t", "f"],
746759
):
747-
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
748-
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
749-
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
750-
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
760+
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
761+
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
762+
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
763+
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
751764
elif dtype in ["fp8", "bf8"]:
752765
for logits, mask, bias in itertools.product(
753766
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
754767
):
755-
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
756-
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
768+
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
769+
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
757770
elif dtype in ["fp8fp16", "fp8bf16"]:
758771
# TODO
759772
None
@@ -909,6 +922,7 @@ def get_fwd_splitkv_blobs(
909922
cond &= pipeline.F_vlayout == "row"
910923
cond &= pipeline.F_bias in ["no", "alibi"]
911924
cond &= pipeline.F_squant == "f"
925+
cond &= pipeline.F_sink == "f"
912926
if not cond:
913927
continue
914928
# PyTorch integration
@@ -918,6 +932,7 @@ def get_fwd_splitkv_blobs(
918932
cond &= pipeline.F_bias in ["no", "bias"]
919933
cond &= pipeline.F_squant == "f"
920934
cond &= mode == "batch"
935+
cond &= pipeline.F_sink == "f"
921936
if not cond:
922937
continue
923938
# Aiter(mha_varlen_fwd) integration
@@ -1076,6 +1091,7 @@ def write_blobs(
10761091
lse=kernel.F_pipeline.F_lse,
10771092
squant=kernel.F_pipeline.F_squant,
10781093
pagedkv=kernel.F_pipeline.F_pagedkv,
1094+
sink=kernel.F_pipeline.F_sink,
10791095
spad=kernel.F_pipeline.F_spad,
10801096
skpad=kernel.F_pipeline.F_skpad,
10811097
dpad=kernel.F_pipeline.F_dpad,

example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@
6666
{F_pagedkv}, //pagedkv
6767
{F_squant},
6868
{F_occupancy},
69-
{F_skip}>;
69+
{F_skip},
70+
{F_sink}>;
7071
7172
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
7273
@@ -101,7 +102,7 @@
101102
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
102103
103104
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
104-
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
105+
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>;
105106
106107
template<>
107108
float fmha_fwd_pagedkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
@@ -130,9 +131,9 @@
130131
}}
131132
"""
132133

133-
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
134+
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) &&
134135
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
135-
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
136+
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>;
136137
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
137138
}}
138139
"""
@@ -164,12 +165,13 @@ class FmhaFwdApiTrait:
164165
dpad: str
165166
dvpad: str
166167
skip: str
168+
sink: str
167169

168170
@property
169171
def name(self) -> str:
170172
return (
171173
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
172-
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
174+
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
173175
)
174176

175177
@property
@@ -257,6 +259,7 @@ class FmhaFwdPipeline:
257259
F_squant: str #
258260
F_mask: str # value from MASK_MAP
259261
F_skip: str # true/false
262+
F_sink: str # true/false
260263

261264
@property
262265
def name(self) -> str:
@@ -321,6 +324,10 @@ def pad_name() -> str:
321324
n += "_pagedkv"
322325
else:
323326
n += "_npagedkv"
327+
if self.F_sink == "t":
328+
n += "_sink"
329+
else:
330+
n += "_nsink"
324331

325332
return n
326333

@@ -364,6 +371,7 @@ def api(self) -> str:
364371
F_lse=BOOL_MAP[trait.lse],
365372
F_pagedkv=BOOL_MAP[trait.pagedkv],
366373
F_skip=BOOL_MAP[trait.skip],
374+
F_sink=BOOL_MAP[trait.sink],
367375
F_squant=BOOL_MAP[trait.squant],
368376
F_scheck=trait.scheck,
369377
F_skcheck=trait.skcheck,
@@ -481,6 +489,7 @@ def template(self) -> str:
481489
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
482490
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
483491
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
492+
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
484493
F_occupancy=self.F_tile.F_occupancy,
485494
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
486495
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
@@ -527,6 +536,7 @@ def api_trait(self) -> FmhaFwdApiTrait:
527536
dpad=self.F_pipeline.F_dpad,
528537
dvpad=self.F_pipeline.F_dvpad,
529538
skip=self.F_pipeline.F_skip,
539+
sink=self.F_pipeline.F_sink,
530540
)
531541

532542

@@ -540,22 +550,23 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]:
540550
squant = "t" if dtype == "fp8" else "f"
541551
pipelines = []
542552
if dtype in ["fp16", "bf16"]:
543-
for logits, mask, bias, pagedkv, skip in itertools.product(
553+
for logits, mask, bias, pagedkv, skip, sink in itertools.product(
544554
["t", "f"],
545555
get_mask_map(mask_impl).keys(),
546556
BIAS_MAP.keys(),
547557
["t"],
548558
["f"],
559+
["t", "f"],
549560
):
550-
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
551-
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
561+
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
562+
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
552563
elif dtype in ["fp8", "bf8"]:
553564
# no need lse/dropout kernels
554565
for logits, mask, bias in itertools.product(
555566
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
556567
):
557-
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
558-
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
568+
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
569+
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
559570
elif dtype in ["fp8fp16", "fp8bf16"]:
560571
pass # TODO
561572
else:
@@ -679,6 +690,7 @@ def get_fwd_blobs(
679690
cond &= pipeline.F_bias in ["no", "alibi"]
680691
cond &= pipeline.F_squant == "f"
681692
cond &= pipeline.F_skip == "f"
693+
cond &= pipeline.F_sink == "f"
682694
if not cond:
683695
continue
684696
# PyTorch integration
@@ -688,6 +700,7 @@ def get_fwd_blobs(
688700
cond &= pipeline.F_bias in ["no", "bias"]
689701
cond &= pipeline.F_squant == "f"
690702
cond &= pipeline.F_skip == "f"
703+
cond &= pipeline.F_sink == "f"
691704
if not cond:
692705
continue
693706
# Aiter(mha_fwd) integration

0 commit comments

Comments (0)