@@ -495,14 +495,14 @@ def get_flash_block_sizes(config):
495495 flash_block_sizes = None
496496 if len (config .flash_block_sizes .keys ()) > 0 :
497497 flash_block_sizes = splash_attention_kernel .BlockSizes (
498- block_q = config .flash_block_sizes ["block_q" ],
499- block_kv_compute = config .flash_block_sizes ["block_kv_compute" ],
500- block_kv = config .flash_block_sizes ["block_kv" ],
501- block_q_dkv = config .flash_block_sizes ["block_q_dkv" ],
502- block_kv_dkv = config .flash_block_sizes ["block_kv_dkv" ],
503- block_kv_dkv_compute = config .flash_block_sizes ["block_kv_dkv_compute" ],
504- block_q_dq = config .flash_block_sizes ["block_q_dq" ],
505- block_kv_dq = config .flash_block_sizes ["block_kv_dq" ],
498+ block_q = int ( config .flash_block_sizes ["block_q" ]) ,
499+ block_kv_compute = int ( config .flash_block_sizes ["block_kv_compute" ]) ,
500+ block_kv = int ( config .flash_block_sizes ["block_kv" ]) ,
501+ block_q_dkv = int ( config .flash_block_sizes ["block_q_dkv" ]) ,
502+ block_kv_dkv = int ( config .flash_block_sizes ["block_kv_dkv" ]) ,
503+ block_kv_dkv_compute = int ( config .flash_block_sizes ["block_kv_dkv_compute" ]) ,
504+ block_q_dq = int ( config .flash_block_sizes ["block_q_dq" ]) ,
505+ block_kv_dq = int ( config .flash_block_sizes ["block_kv_dq" ]) ,
506506 )
507507 return flash_block_sizes
508508
0 commit comments