|
15 | 15 | from .. import exc |
16 | 16 | from .._compat import get_tensor_descriptor_fn_name |
17 | 17 | from .ast_extension import expr_from_string |
| 18 | +from .ast_extension import statement_from_string |
18 | 19 | from .compile_environment import CompileEnvironment |
19 | 20 | from .device_function import DeviceFunction |
20 | 21 | from .host_function import HostFunction |
@@ -353,7 +354,6 @@ def codegen_load( |
353 | 354 | ) |
354 | 355 | assert extra_mask is None |
355 | 356 | indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) |
356 | | - |
357 | 357 | # Load from tensor descriptor with permuted offsets |
358 | 358 | load_expr = expr_from_string( |
359 | 359 | f"{indexing.tensor_descriptor(state)}.load({indexing.offsets_str_permuted(state)})" |
@@ -383,23 +383,119 @@ def codegen_store( |
383 | 383 | ) |
384 | 384 | assert extra_mask is None |
385 | 385 | indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) |
| 386 | + store_value = indexing.reshape_store(state, value) |
| 387 | + |
| 388 | + config = DeviceFunction.current().config |
| 389 | + epilogue_subtiles = state.config.epilogue_subtiling |
| 390 | + if torch.cuda.get_device_capability() >= (9, 0) and (idx := state.device_function.device_store_index) < len(epilogue_subtiles): |
| 391 | + subtile_split = epilogue_subtiles[idx] |
| 392 | + state.device_function.device_store_index += 1 |
| 393 | + |
| 394 | + subtile_codegen = self._codegen_epilogue_subtile_store(state, fake_tensor, indexing, store_value, subtile_split, config) |
| 395 | + if subtile_codegen is not None: |
| 396 | + return subtile_codegen |
386 | 397 |
|
387 | 398 | # Apply permutation to the value being stored if needed |
388 | 399 | desc_arg = indexing.tensor_descriptor_arg(state) |
389 | | - store_value = indexing.reshape_store(state, value) |
390 | 400 |
|
391 | 401 | if desc_arg.permutation is not None: |
392 | 402 | # Apply permutation to the value |
393 | 403 | store_value = expr_from_string( |
394 | 404 | f"tl.permute({{store_val}}, {desc_arg.permutation!r})", |
395 | 405 | store_val=store_value, |
396 | 406 | ) |
397 | | - |
| 407 | + |
398 | 408 | return expr_from_string( |
399 | 409 | f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})", |
400 | 410 | value=store_value, |
401 | 411 | ) |
402 | 412 |
|
def _codegen_epilogue_subtile_store(
    self,
    state: CodegenState,
    fake_tensor: torch.Tensor,
    indexing: BlockedSubscriptIndexing,
    store_value: ast.AST,
    subtile_split: int,
    config: Config,
) -> ast.AST | None:
    """Emit a tensor-descriptor store split into two half-width subtiles.

    The value is reshaped to ``[block_m, 2, block_n // 2]``, permuted so the
    size-2 axis is last, and separated with ``tl.split``.  The first half is
    stored through a statement added to ``state.codegen``; the store of the
    second half is returned as an expression for the caller to emit.

    Args:
        state: Codegen state; supplies ``device_function`` and ``codegen``.
        fake_tensor: Tensor being stored to (not read here; kept so the
            signature stays parallel with the calling store path).
        indexing: Blocked indexing for the store.  On success its
            ``block_shape[1]`` is divided by ``subtile_split`` *in place* so
            the descriptor is created for the subtile shape.
        store_value: AST of the (already reshaped) value to store.
        subtile_split: Requested number of subtiles along the second axis.
            Only ``2`` is handled.
        config: Config used to resolve the concrete block size.

    Returns:
        The AST of the second subtile store, or ``None`` when subtiling is
        not applicable and the caller should emit a regular store.
    """
    # Only 2D tiles are supported.  The reshape below hard-codes a trailing
    # axis of size 2 for tl.split, so any split factor other than 2
    # (including 0 and 1) would produce a reshape with the wrong element
    # count; fall back to the plain store instead.
    # TODO: Support more epilogue subtile configs besides 2
    if (
        len(indexing.block_shape) != 2
        or len(indexing.offsets) != 2
        or subtile_split != 2
    ):
        return None

    env = CompileEnvironment.current()
    block_m, block_n = indexing.block_shape
    try:
        block_n_hint = env.size_hint(block_n)
        block_idx = env.get_block_id(block_n)
        block_size = env.block_sizes[block_idx].from_config(config)
    except Exception:
        # Deliberate best-effort: if the block size cannot be resolved,
        # skip the optimisation rather than fail code generation.
        return None

    # Need an even width to halve, and a tile wide enough to be worth it.
    if block_n_hint % 2 != 0 or block_size <= 16:
        return None

    device_fn = state.device_function
    codegen = state.codegen

    # Render the full-size dimensions before block_shape is narrowed.
    block_m_str = device_fn.literal_expr(block_m)
    block_n_str = device_fn.literal_expr(block_n)
    block_n_half_str = f"({block_n_str} // {subtile_split})"

    # Narrow the block shape so the descriptor matches one subtile.
    indexing.block_shape[1] = block_n // subtile_split
    desc_arg = indexing.tensor_descriptor_arg(state)
    if desc_arg.permutation is not None:
        # The stores below use unpermuted offsets and values, which would be
        # wrong for a permuted descriptor.  Undo the narrowing and let the
        # caller take the regular (permutation-aware) store path.
        # NOTE(review): the subtile-shaped descriptor requested above may
        # remain registered on the device function — confirm it is harmless.
        indexing.block_shape[1] = block_n
        return None

    # Lift the store value into a temporary variable for reuse
    acc_var = codegen.lift(store_value, prefix="acc")

    reshape_expr = expr_from_string(
        "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)",
        acc=acc_var,
        dim_m=expr_from_string(block_m_str),
        dim_half=expr_from_string(block_n_half_str),
    )
    reshape_var = codegen.lift(reshape_expr, prefix="acc")

    acc0_name = codegen.tmpvar(prefix="acc")
    acc1_name = codegen.tmpvar(prefix="acc")
    codegen.add_statement(
        statement_from_string(
            f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
            acc=reshape_var,
        )
    )

    desc_name = indexing.tensor_descriptor(state)
    store_template = f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})"

    # First subtile store, emitted immediately as a statement.
    codegen.add_statement(
        statement_from_string(
            store_template,
            off0=expr_from_string(indexing.offsets[0]),
            off1=expr_from_string(indexing.offsets[1]),
            value=expr_from_string(acc0_name),
        )
    )

    # The second subtile starts half a block further along the second axis.
    offset1_shifted = expr_from_string(
        "({offset} + {half})",
        offset=expr_from_string(indexing.offsets[1]),
        half=expr_from_string(block_n_half_str),
    )

    # Emit second subtile store as the expression returned to the caller
    return expr_from_string(
        store_template,
        off0=expr_from_string(indexing.offsets[0]),
        off1=offset1_shifted,
        value=expr_from_string(acc1_name),
    )
403 | 499 |
|
404 | 500 | class StackIndexingStrategy: |
405 | 501 | """ |
|
0 commit comments