Skip to content

Commit 8202837

Browse files
committed
skip flash block sizes setting for cross attention.
1 parent 3cd6a44 commit 8202837

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

src/maxdiffusion/max_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -495,14 +495,14 @@ def get_flash_block_sizes(config):
495495
flash_block_sizes = None
496496
if len(config.flash_block_sizes.keys()) > 0:
497497
flash_block_sizes = splash_attention_kernel.BlockSizes(
498-
block_q=config.flash_block_sizes["block_q"],
499-
block_kv_compute=config.flash_block_sizes["block_kv_compute"],
500-
block_kv=config.flash_block_sizes["block_kv"],
501-
block_q_dkv=config.flash_block_sizes["block_q_dkv"],
502-
block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
503-
block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
504-
block_q_dq=config.flash_block_sizes["block_q_dq"],
505-
block_kv_dq=config.flash_block_sizes["block_kv_dq"],
498+
block_q=int(config.flash_block_sizes["block_q"]),
499+
block_kv_compute=int(config.flash_block_sizes["block_kv_compute"]),
500+
block_kv=int(config.flash_block_sizes["block_kv"]),
501+
block_q_dkv=int(config.flash_block_sizes["block_q_dkv"]),
502+
block_kv_dkv=int(config.flash_block_sizes["block_kv_dkv"]),
503+
block_kv_dkv_compute=int(config.flash_block_sizes["block_kv_dkv_compute"]),
504+
block_q_dq=int(config.flash_block_sizes["block_q_dq"]),
505+
block_kv_dq=int(config.flash_block_sizes["block_kv_dq"]),
506506
)
507507
return flash_block_sizes
508508

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ def _tpu_flash_attention(
184184
kv_max_block_size = key.shape[1]
185185
else:
186186
kv_max_block_size = q_max_block_size
187-
if flash_block_sizes:
187+
# ensure that for cross attention we override the block sizes.
188+
if flash_block_sizes and key.shape[1] == query.shape[1]:
188189
block_sizes = flash_block_sizes
189190
else:
190191
block_sizes = splash_attention_kernel.BlockSizes(

0 commit comments

Comments
 (0)