Skip to content

Commit

Permalink
Two warp tiles per CTA in each dim, increase instr to 64_64_16
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobhinkle committed Jan 2, 2025
1 parent 521d5cc commit dce16ad
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4275,18 +4275,18 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
MatMulTileOptions gemm_tile;
// Regardless of the instruction, this should result in 2 warp groups i.e. 256
// threads
gemm_tile.cta_tile = GemmTile(128, 256, 16);
gemm_tile.warp_tile = GemmTile(64, 128, 16);
gemm_tile.cta_tile = GemmTile(256, 256, 16);
gemm_tile.warp_tile = GemmTile(128, 128, 16);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
mparams.mma_macro = MmaMacro::Hopper_64_8_16;
mparams.mma_macro = MmaMacro::Hopper_64_64_16;
mparams.tile_sizes = gemm_tile;
mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
mparams.circular_buffer_options.smem_circular_buffer_stage = 2;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
mparams.splitk_factor = 1;
// NOTE: disabling smem use for this test since we currrently hit a bank
Expand Down

0 comments on commit dce16ad

Please sign in to comment.