Merged
11 changes: 11 additions & 0 deletions flash_sparse_attn/ops/triton/cache_utils.py
@@ -8,6 +8,7 @@
 _COMPILED_KERNEL_CACHE_MAXSIZE = 4096
 _COMPILED_KERNEL_CACHE: OrderedDict = OrderedDict()
 _LAUNCHER_CACHE_MAXSIZE = 1024
+_STATIC_BUFFER_POOL: dict[tuple, torch.Tensor] = {}
 
 
 def _compiled_cache_get(key):
@@ -24,6 +25,16 @@ def _compiled_cache_put(key, compiled):
         _COMPILED_KERNEL_CACHE.popitem(last=False)
 
 
+def get_static_buffer(shape, dtype, device, tag=""):
+    key = (shape, dtype, device.type, device.index, tag)
+    buf = _STATIC_BUFFER_POOL.get(key)
+    if buf is not None and buf.shape == shape:
+        return buf
+    buf = torch.empty(shape, dtype=dtype, device=device)
+    _STATIC_BUFFER_POOL[key] = buf
+    return buf
+
+
 @functools.lru_cache(maxsize=8)
 def get_device_num_sms(device: torch.device) -> int:
     """
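A note on the new pool's semantics: get_static_buffer returns the same cached tensor for every call whose (shape, dtype, device, tag) key matches, so a returned buffer's contents are only valid until the next call that hits the same key. A minimal sketch of that aliasing, using only the function added above (the surrounding script is illustrative, not part of this PR):

    import torch
    from flash_sparse_attn.ops.triton import cache_utils

    device = torch.device("cuda", 0)
    # Two lookups with the same key alias one allocation, so the device
    # address is stable across calls; that stability is presumably the
    # point, since CUDA graph replay requires fixed tensor addresses.
    a = cache_utils.get_static_buffer((4, 128), torch.float32, device, tag="demo")
    b = cache_utils.get_static_buffer((4, 128), torch.float32, device, tag="demo")
    assert a.data_ptr() == b.data_ptr()

    # The flip side: a caller that needs a result to survive the next
    # call with the same key must copy it out explicitly.
    kept = a.clone()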
72 changes: 38 additions & 34 deletions flash_sparse_attn/ops/triton/flash_dense_bwd.py
@@ -645,6 +645,7 @@ def _flash_dense_attn_base_backward(
     is_causal: bool = False,
     softmax_scale: float = None,
     window_size: Tuple[int, int] = (None, None),
+    skip_checks: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     device = query.device
     arch = cache_utils.get_device_arch(device)
@@ -656,23 +657,24 @@
     softmax_scale_log2 = softmax_scale * math.log2(math.e)
     qhead_per_kvhead = num_heads_q // num_heads_kv
 
-    assert_inputs.assert_bwd_inputs(
-        query,
-        key,
-        value,
-        out,
-        dout,
-        lse,
-        cu_seqlens_q=None,
-        cu_seqlens_k=None,
-        seqused_q=None,
-        seqused_k=None,
-        num_heads_q=num_heads_q,
-        num_heads_kv=num_heads_kv,
-        head_dim=head_dim,
-        device=device,
-        arch=arch,
-    )
+    if not skip_checks:
+        assert_inputs.assert_bwd_inputs(
+            query,
+            key,
+            value,
+            out,
+            dout,
+            lse,
+            cu_seqlens_q=None,
+            cu_seqlens_k=None,
+            seqused_q=None,
+            seqused_k=None,
+            num_heads_q=num_heads_q,
+            num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
+            device=device,
+            arch=arch,
+        )
 
     TILE_K = max(triton.next_power_of_2(head_dim), 16)
 
@@ -834,6 +836,7 @@ def _flash_dense_attn_varlen_base_backward(
     window_size: Tuple[int, int] = (None, None),
     seqused_q: Optional[torch.Tensor] = None,
     seqused_k: Optional[torch.Tensor] = None,
+    skip_checks: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     device = query.device
     arch = cache_utils.get_device_arch(device)
@@ -848,23 +851,24 @@
     softmax_scale_log2 = softmax_scale * math.log2(math.e)
     qhead_per_kvhead = num_heads_q // num_heads_kv
 
-    assert_inputs.assert_bwd_inputs(
-        query,
-        key,
-        value,
-        out,
-        dout,
-        lse,
-        cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_k,
-        seqused_q=seqused_q,
-        seqused_k=seqused_k,
-        num_heads_q=num_heads_q,
-        num_heads_kv=num_heads_kv,
-        head_dim=head_dim,
-        device=device,
-        arch=arch,
-    )
+    if not skip_checks:
+        assert_inputs.assert_bwd_inputs(
+            query,
+            key,
+            value,
+            out,
+            dout,
+            lse,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            seqused_q=seqused_q,
+            seqused_k=seqused_k,
+            num_heads_q=num_heads_q,
+            num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
+            device=device,
+            arch=arch,
+        )
 
     TILE_K = max(triton.next_power_of_2(head_dim), 16)
 
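Both backward entry points gain the same skip_checks escape hatch, defaulting to False so existing callers keep full validation. A sketch of the intended call pattern, validate once per input signature and then bypass the per-call asserts on the hot path; the checked_call helper and its cache are hypothetical wrapper code, not part of this PR:

    import torch

    _validated_signatures: set = set()

    def checked_call(attn_fn, *tensors, **kwargs):
        # Run full input validation the first time a given combination of
        # shapes/dtypes/device is seen, then pass skip_checks=True for
        # identical steady-state calls.
        sig = tuple((tuple(t.shape), t.dtype, str(t.device)) for t in tensors)
        result = attn_fn(*tensors, skip_checks=sig in _validated_signatures, **kwargs)
        _validated_signatures.add(sig)
        return result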
110 changes: 61 additions & 49 deletions flash_sparse_attn/ops/triton/flash_dense_dec.py
@@ -868,6 +868,7 @@ def _flash_dense_attn_base_decode(
     window_size: Tuple[int, int] = (None, None),
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
+    skip_checks: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     device = query.device
     arch = cache_utils.get_device_arch(device)
@@ -881,21 +882,22 @@
     softmax_scale_log2 = softmax_scale * math.log2(math.e)
     qheads_per_kvhead = num_heads_q // num_heads_kv
 
-    assert_inputs.assert_dec_inputs(
-        query,
-        key,
-        value,
-        query_scale=query_scale,
-        key_scale=key_scale,
-        value_scale=value_scale,
-        cu_seqlens_k=None,
-        seqused_k=None,
-        num_heads_q=num_heads_q,
-        num_heads_kv=num_heads_kv,
-        head_dim=head_dim,
-        device=device,
-        arch=arch,
-    )
+    if not skip_checks:
+        assert_inputs.assert_dec_inputs(
+            query,
+            key,
+            value,
+            query_scale=query_scale,
+            key_scale=key_scale,
+            value_scale=value_scale,
+            cu_seqlens_k=None,
+            seqused_k=None,
+            num_heads_q=num_heads_q,
+            num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
+            device=device,
+            arch=arch,
+        )
 
     TILE_K = max(triton.next_power_of_2(head_dim), 16)
 
@@ -920,23 +922,29 @@ def _flash_dense_attn_base_decode(
     out = (
         out
         if out is not None
-        else torch.empty(query.shape, dtype=out_dtype, device=device)
+        else cache_utils.get_static_buffer(
+            query.shape, out_dtype, device, tag="dec_out"
+        )
     )
     lse = (
         lse
         if lse is not None
-        else torch.empty((batch_size, num_heads_q), dtype=torch.float32, device=device)
+        else cache_utils.get_static_buffer(
+            (batch_size, num_heads_q), torch.float32, device, tag="dec_lse"
+        )
     )
 
-    out_partial = torch.empty(
+    out_partial = cache_utils.get_static_buffer(
         (num_splits, batch_size, num_heads_q, head_dim),
-        dtype=torch.float32,
-        device=query.device,
+        torch.float32,
+        device,
+        tag="dec_out_partial",
     )
-    lse_partial = torch.empty(
+    lse_partial = cache_utils.get_static_buffer(
         (num_splits, batch_size, num_heads_q),
-        dtype=torch.float32,
-        device=query.device,
+        torch.float32,
+        device,
+        tag="dec_lse_partial",
     )
 
     grid = launch_grid.get_dec_grid(
@@ -1069,6 +1077,7 @@ def _flash_dense_attn_varlen_base_decode(
     seqused_k: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
+    skip_checks: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     device = query.device
     arch = cache_utils.get_device_arch(device)
@@ -1083,21 +1092,22 @@
     softmax_scale_log2 = softmax_scale * math.log2(math.e)
     qheads_per_kvhead = num_heads_q // num_heads_kv
 
-    assert_inputs.assert_dec_inputs(
-        query,
-        key,
-        value,
-        query_scale=query_scale,
-        key_scale=key_scale,
-        value_scale=value_scale,
-        cu_seqlens_k=cu_seqlens_k,
-        seqused_k=seqused_k,
-        num_heads_q=num_heads_q,
-        num_heads_kv=num_heads_kv,
-        head_dim=head_dim,
-        device=device,
-        arch=arch,
-    )
+    if not skip_checks:
+        assert_inputs.assert_dec_inputs(
+            query,
+            key,
+            value,
+            query_scale=query_scale,
+            key_scale=key_scale,
+            value_scale=value_scale,
+            cu_seqlens_k=cu_seqlens_k,
+            seqused_k=seqused_k,
+            num_heads_q=num_heads_q,
+            num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
+            device=device,
+            arch=arch,
+        )
 
     TILE_K = max(triton.next_power_of_2(head_dim), 16)
 
@@ -1122,27 +1132,29 @@ def _flash_dense_attn_varlen_base_decode(
     out = (
         out
         if out is not None
-        else torch.empty(query.shape, dtype=out_dtype, device=device)
+        else cache_utils.get_static_buffer(
+            query.shape, out_dtype, device, tag="dec_out"
+        )
     )
     lse = (
         lse
         if lse is not None
-        else torch.empty(
-            (batch_size, num_heads_q),
-            dtype=torch.float32,
-            device=device,
+        else cache_utils.get_static_buffer(
+            (batch_size, num_heads_q), torch.float32, device, tag="dec_lse"
         )
     )
 
-    out_partial = torch.empty(
+    out_partial = cache_utils.get_static_buffer(
         (num_splits, batch_size, num_heads_q, head_dim),
-        dtype=torch.float32,
-        device=device,
+        torch.float32,
+        device,
+        tag="dec_out_partial",
     )
-    lse_partial = torch.empty(
+    lse_partial = cache_utils.get_static_buffer(
         (num_splits, batch_size, num_heads_q),
-        dtype=torch.float32,
-        device=device,
+        torch.float32,
+        device,
+        tag="dec_lse_partial",
     )
 
     grid = launch_grid.get_dec_grid(
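Allocating out, lse, and the split-KV scratch (out_partial, lse_partial) from the static pool gives each decode call stable device addresses instead of fresh torch.empty allocations, which is what CUDA graph capture and replay require. A minimal capture sketch under that assumption; decode_step stands in for a decode wrapper returning (out, lse) and is not part of this PR:

    import torch

    def capture_decode(decode_step, query, key, value):
        # Warm-up call outside capture so lazy allocations, including the
        # pooled dec_out_partial / dec_lse_partial scratch, happen first.
        decode_step(query, key, value)
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            out, lse = decode_step(query, key, value)
        # Replay reruns the kernels against the same pooled buffers:
        # refresh query/key/value in place, then call graph.replay().
        return graph, out, lse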
60 changes: 32 additions & 28 deletions flash_sparse_attn/ops/triton/flash_dense_fwd.py
@@ -617,6 +617,7 @@ def _flash_dense_attn_base_forward(
     window_size: Tuple[int, int] = (None, None),
     is_split_kv: bool = False,
     pack_gqa: bool = False,
+    skip_checks: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, float]:
     device = query.device
     arch = cache_utils.get_device_arch(device)
@@ -630,20 +631,21 @@
     qheads_per_kvhead = num_heads_q // num_heads_kv
     qheads_per_kvhead_packgqa = num_heads_q // num_heads_kv if pack_gqa else 1
 
-    assert_inputs.assert_fwd_inputs(
-        query,
-        key,
-        value,
-        cu_seqlens_q=None,
-        cu_seqlens_k=None,
-        seqused_q=None,
-        seqused_k=None,
-        num_heads_q=num_heads_q,
-        num_heads_kv=num_heads_kv,
-        head_dim=head_dim,
-        device=device,
-        arch=arch,
-    )
+    if not skip_checks:
+        assert_inputs.assert_fwd_inputs(
+            query,
+            key,
+            value,
+            cu_seqlens_q=None,
+            cu_seqlens_k=None,
+            seqused_q=None,
+            seqused_k=None,
+            num_heads_q=num_heads_q,
+            num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
+            device=device,
+            arch=arch,
+        )
 
     TILE_K = max(triton.next_power_of_2(head_dim), 16)
 
@@ -775,6 +777,7 @@ def _flash_dense_attn_varlen_base_forward(
     pack_gqa: bool = False,
     seqused_q: Optional[torch.Tensor] = None,
     seqused_k: Optional[torch.Tensor] = None,
+    skip_checks: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, float]:
     device = query.device
     arch = cache_utils.get_device_arch(device)
@@ -791,20 +794,21 @@
     qheads_per_kvhead = num_heads_q // num_heads_kv
     qheads_per_kvhead_packgqa = num_heads_q // num_heads_kv if pack_gqa else 1
 
-    assert_inputs.assert_fwd_inputs(
-        query,
-        key,
-        value,
-        cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_k,
-        seqused_q=seqused_q,
-        seqused_k=seqused_k,
-        num_heads_q=num_heads_q,
-        num_heads_kv=num_heads_kv,
-        head_dim=head_dim,
-        device=device,
-        arch=arch,
-    )
+    if not skip_checks:
+        assert_inputs.assert_fwd_inputs(
+            query,
+            key,
+            value,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            seqused_q=seqused_q,
+            seqused_k=seqused_k,
+            num_heads_q=num_heads_q,
+            num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
+            device=device,
+            arch=arch,
+        )
 
     TILE_K = max(triton.next_power_of_2(head_dim), 16)
 
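For the forward paths the benefit of skip_checks is host-side: assert_fwd_inputs is a fixed Python cost per launch, which plausibly matters for short decode-style workloads. A hypothetical micro-benchmark sketch for quantifying that cost; fwd stands in for either forward function with its arguments pre-built in args and kwargs:

    import time
    import torch

    def validation_overhead(fwd, args, kwargs, iters=200):
        # The difference between the checked and unchecked loops
        # approximates the per-call cost of input validation.
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            fwd(*args, skip_checks=False, **kwargs)
        torch.cuda.synchronize()
        checked = time.perf_counter() - t0

        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters):
            fwd(*args, skip_checks=True, **kwargs)
        torch.cuda.synchronize()
        unchecked = time.perf_counter() - t0
        return (checked - unchecked) / iters  # seconds saved per call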