Skip to content

Commit a6656c1

Browse files
committed
Fix bug with unit sized dims and block_sizes
stack-info: PR: #932, branch: jansel/stack/191
1 parent 2fe7490 commit a6656c1

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

helion/_compiler/compile_environment.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,9 +582,11 @@ def from_config(
582582
@dataclasses.dataclass
583583
class LoopSpecBlockSizeSource(BlockSizeSource):
584584
def from_config(self, config: Config, block_size_info: BlockSizeInfo) -> int:
585-
index = CompileEnvironment.current().config_spec.block_sizes.block_id_to_index(
586-
block_size_info.block_id
587-
)
585+
env = CompileEnvironment.current()
586+
size = block_size_info.size
587+
if isinstance(size, (int, torch.SymInt)) and env.known_equal(size, 1):
588+
return 1
589+
index = env.config_spec.block_sizes.block_id_to_index(block_size_info.block_id)
588590
return config.block_sizes[index]
589591

590592

test/test_matmul.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,31 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
272272
torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
273273
self.assertExpectedJournal(code)
274274

275+
def test_matmul_config_reuse_with_unit_dim(self):
276+
torch.manual_seed(0)
277+
big_args = (
278+
torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
279+
torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
280+
)
281+
big_bound = matmul_with_addmm.bind(big_args)
282+
big_spec = big_bound.config_spec
283+
self.assertEqual(len(big_spec.block_sizes), 3)
284+
big_config = big_spec.default_config()
285+
286+
small_args = (
287+
torch.randn([1, 64], device=DEVICE, dtype=torch.float32),
288+
torch.randn([64, 64], device=DEVICE, dtype=torch.float32),
289+
)
290+
small_bound = matmul_with_addmm.bind(small_args)
291+
small_spec = small_bound.config_spec
292+
self.assertEqual(len(small_spec.block_sizes), 3)
293+
294+
# Previously raised when reusing configs tuned on larger shapes.
295+
small_bound.set_config(big_config)
296+
result = small_bound(*small_args)
297+
expected = small_args[0] @ small_args[1]
298+
torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-2)
299+
275300
def test_matmul_packed_rhs(self):
276301
@helion.kernel(static_shapes=False)
277302
def matmul_with_packed_b(

0 commit comments

Comments
 (0)