Skip to content

Commit 5adaa20

Browse files
authored
Revert "Add attn sink (#2892)" (#3250)
This reverts commit 9fa4e8d.
1 parent 9fa4e8d commit 5adaa20

25 files changed

+195
-940
lines changed

example/ck_tile/01_fmha/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS
6262
# there is no corresponding instance for parameters).
6363
if(BUILD_TESTING)
6464
# Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill
65-
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*_nsink*,*@*_nlogits*_nbias*_nsink*,*,*_nlogits*_nskip*_pagedkv*)
65+
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv)
6666
endif()
6767

6868
# generate a list of kernels, but not actually emit files at config stage

example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

Lines changed: 28 additions & 46 deletions
Large diffs are not rendered by default.

example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@
7474
{F_pagedkv},
7575
kHasUnevenSplits,
7676
kMergeNumHeadGroupsSeqLenQ,
77-
{F_occupancy},
78-
{F_sink}>;
77+
{F_occupancy}>;
7978
8079
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
8180
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
@@ -119,7 +118,7 @@
119118
}} // anonymous namespace
120119
121120
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
122-
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad},
121+
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
123122
{F_dvpad}>;
124123
125124
#pragma clang diagnostic push
@@ -281,8 +280,8 @@
281280
"""
282281

283282
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
284-
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
285-
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
283+
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
284+
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
286285
287286
// get combine kernel tile sizes
288287
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
@@ -334,15 +333,14 @@ class FmhaFwdSplitKVApiTrait:
334333
dpad: str
335334
dvpad: str
336335
pagedkv: str
337-
sink: str # sink or not
338336
bn1comb: int # tile size along v head_dim of combine kernel
339337

340338
@property
341339
def name(self) -> str:
342340
return (
343341
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
344342
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-"
345-
+ f"{self.dvpad}-{self.pagedkv}-{self.sink}"
343+
+ f"{self.dvpad}-{self.pagedkv}"
346344
)
347345

348346
@property
@@ -428,7 +426,6 @@ class FmhaFwdSplitKVPipeline:
428426
F_lse: str #
429427
F_squant: str #
430428
F_pagedkv: str # t/f
431-
F_sink: str # t/f
432429
F_mask: str # value from MASK_MAP
433430

434431
@property
@@ -489,10 +486,6 @@ def pad_name() -> str:
489486
n += "_pagedkv"
490487
else:
491488
n += "_npagedkv"
492-
if self.F_sink == "t":
493-
n += "_sink"
494-
else:
495-
n += "_nsink"
496489
return n
497490

498491

@@ -575,7 +568,6 @@ def api(self) -> str:
575568
F_lse=BOOL_MAP[trait.lse],
576569
F_squant=BOOL_MAP[trait.squant],
577570
F_pagedkv=BOOL_MAP[trait.pagedkv],
578-
F_sink=BOOL_MAP[trait.sink],
579571
F_scheck=trait.scheck,
580572
F_skcheck=trait.skcheck,
581573
F_dcheck=trait.dcheck,
@@ -676,7 +668,6 @@ def template(self) -> str:
676668
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
677669
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
678670
F_occupancy=self.F_tile.F_occupancy,
679-
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
680671
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
681672
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
682673
F_mode=MODE_MAP[self.F_mode],
@@ -750,23 +741,19 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]:
750741
squant = "t" if dtype == "fp8" else "f"
751742
pipelines = []
752743
if dtype in ["fp16", "bf16"]:
753-
for logits, mask, bias, pagedkv, sink in itertools.product(
754-
["t", "f"],
755-
get_mask_map(mask_impl).keys(),
756-
BIAS_MAP.keys(),
757-
["t", "f"],
758-
["t", "f"],
744+
for logits, mask, bias, pagedkv in itertools.product(
745+
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]
759746
):
760-
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
761-
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
762-
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
763-
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
747+
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
748+
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
749+
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
750+
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
764751
elif dtype in ["fp8", "bf8"]:
765752
for logits, mask, bias in itertools.product(
766753
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
767754
):
768-
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
769-
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
755+
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
756+
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
770757
elif dtype in ["fp8fp16", "fp8bf16"]:
771758
# TODO
772759
None
@@ -922,7 +909,6 @@ def get_fwd_splitkv_blobs(
922909
cond &= pipeline.F_vlayout == "row"
923910
cond &= pipeline.F_bias in ["no", "alibi"]
924911
cond &= pipeline.F_squant == "f"
925-
cond &= pipeline.F_sink == "f"
926912
if not cond:
927913
continue
928914
# PyTorch integration
@@ -932,7 +918,6 @@ def get_fwd_splitkv_blobs(
932918
cond &= pipeline.F_bias in ["no", "bias"]
933919
cond &= pipeline.F_squant == "f"
934920
cond &= mode == "batch"
935-
cond &= pipeline.F_sink == "f"
936921
if not cond:
937922
continue
938923
# Aiter(mha_varlen_fwd) integration
@@ -1091,7 +1076,6 @@ def write_blobs(
10911076
lse=kernel.F_pipeline.F_lse,
10921077
squant=kernel.F_pipeline.F_squant,
10931078
pagedkv=kernel.F_pipeline.F_pagedkv,
1094-
sink=kernel.F_pipeline.F_sink,
10951079
spad=kernel.F_pipeline.F_spad,
10961080
skpad=kernel.F_pipeline.F_skpad,
10971081
dpad=kernel.F_pipeline.F_dpad,

example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@
6666
{F_pagedkv}, //pagedkv
6767
{F_squant},
6868
{F_occupancy},
69-
{F_skip},
70-
{F_sink}>;
69+
{F_skip}>;
7170
7271
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
7372
@@ -102,7 +101,7 @@
102101
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
103102
104103
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
105-
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>;
104+
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
106105
107106
template<>
108107
float fmha_fwd_pagedkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
@@ -131,9 +130,9 @@
131130
}}
132131
"""
133132

134-
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) &&
133+
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
135134
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
136-
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>;
135+
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
137136
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
138137
}}
139138
"""
@@ -165,13 +164,12 @@ class FmhaFwdApiTrait:
165164
dpad: str
166165
dvpad: str
167166
skip: str
168-
sink: str
169167

170168
@property
171169
def name(self) -> str:
172170
return (
173171
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
174-
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
172+
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
175173
)
176174

177175
@property
@@ -259,7 +257,6 @@ class FmhaFwdPipeline:
259257
F_squant: str #
260258
F_mask: str # value from MASK_MAP
261259
F_skip: str # true/false
262-
F_sink: str # true/false
263260

264261
@property
265262
def name(self) -> str:
@@ -324,10 +321,6 @@ def pad_name() -> str:
324321
n += "_pagedkv"
325322
else:
326323
n += "_npagedkv"
327-
if self.F_sink == "t":
328-
n += "_sink"
329-
else:
330-
n += "_nsink"
331324

332325
return n
333326

@@ -371,7 +364,6 @@ def api(self) -> str:
371364
F_lse=BOOL_MAP[trait.lse],
372365
F_pagedkv=BOOL_MAP[trait.pagedkv],
373366
F_skip=BOOL_MAP[trait.skip],
374-
F_sink=BOOL_MAP[trait.sink],
375367
F_squant=BOOL_MAP[trait.squant],
376368
F_scheck=trait.scheck,
377369
F_skcheck=trait.skcheck,
@@ -489,7 +481,6 @@ def template(self) -> str:
489481
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
490482
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
491483
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
492-
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
493484
F_occupancy=self.F_tile.F_occupancy,
494485
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
495486
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
@@ -536,7 +527,6 @@ def api_trait(self) -> FmhaFwdApiTrait:
536527
dpad=self.F_pipeline.F_dpad,
537528
dvpad=self.F_pipeline.F_dvpad,
538529
skip=self.F_pipeline.F_skip,
539-
sink=self.F_pipeline.F_sink,
540530
)
541531

542532

@@ -550,23 +540,22 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]:
550540
squant = "t" if dtype == "fp8" else "f"
551541
pipelines = []
552542
if dtype in ["fp16", "bf16"]:
553-
for logits, mask, bias, pagedkv, skip, sink in itertools.product(
543+
for logits, mask, bias, pagedkv, skip in itertools.product(
554544
["t", "f"],
555545
get_mask_map(mask_impl).keys(),
556546
BIAS_MAP.keys(),
557547
["t"],
558548
["f"],
559-
["t", "f"],
560549
):
561-
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
562-
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
550+
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
551+
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
563552
elif dtype in ["fp8", "bf8"]:
564553
# no need lse/dropout kernels
565554
for logits, mask, bias in itertools.product(
566555
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
567556
):
568-
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
569-
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
557+
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
558+
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
570559
elif dtype in ["fp8fp16", "fp8bf16"]:
571560
pass # TODO
572561
else:
@@ -690,7 +679,6 @@ def get_fwd_blobs(
690679
cond &= pipeline.F_bias in ["no", "alibi"]
691680
cond &= pipeline.F_squant == "f"
692681
cond &= pipeline.F_skip == "f"
693-
cond &= pipeline.F_sink == "f"
694682
if not cond:
695683
continue
696684
# PyTorch integration
@@ -700,7 +688,6 @@ def get_fwd_blobs(
700688
cond &= pipeline.F_bias in ["no", "bias"]
701689
cond &= pipeline.F_squant == "f"
702690
cond &= pipeline.F_skip == "f"
703-
cond &= pipeline.F_sink == "f"
704691
if not cond:
705692
continue
706693
# Aiter(mha_fwd) integration

0 commit comments

Comments
 (0)