Skip to content

Commit

Permalink
Override params for horizontal fusion tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobhinkle committed Feb 6, 2025
1 parent 4d0226c commit 07c93c6
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4536,6 +4536,13 @@ TEST_P(MLPBenchmarkTest, FwdHorizontalFusion) {
(tv4_ref * (1. / (1.0 + at::exp(-tv4_ref))) * tv10_ref.to(at::kFloat))
.to(at::kBFloat16);

// Adjust parameters in order to fit smem and register constraints
mparams.tile_sizes.cta_tile = GemmTile(128, 128, 64);
mparams.tile_sizes.warp_tile = GemmTile(64, 128, 64);
mparams.mma_macro = MmaMacro::Hopper_64_128_16;
mparams.promote_prologue_smem_reuse = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 2;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

Expand Down Expand Up @@ -4603,6 +4610,13 @@ TEST_P(MLPBenchmarkTest, FwdHorizontalFusion_BroadcastInputs) {
(tv4_ref * (1. / (1.0 + at::exp(-tv4_ref))) * tv10_ref.to(at::kFloat))
.to(at::kBFloat16);

// Adjust parameters in order to fit smem and register constraints
mparams.tile_sizes.cta_tile = GemmTile(128, 128, 64);
mparams.tile_sizes.warp_tile = GemmTile(64, 128, 64);
mparams.mma_macro = MmaMacro::Hopper_64_128_16;
mparams.promote_prologue_smem_reuse = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 2;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

Expand Down

0 comments on commit 07c93c6

Please sign in to comment.