@@ -199,6 +199,54 @@ def test_matmul_static_shapes3(self):
         torch.testing.assert_close(output, args[0] @ args[1], atol=1e-1, rtol=1e-2)
         self.assertExpectedJournal(code)
 
+    def test_matmul_packed_int4_block_size_constexpr(self):
+        torch.manual_seed(0)
+        M = N = K = 32
+
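+        # Kernel under test: B_packed stores two signed int4 values per int8
+        # byte along K and is unpacked to float16 inside each K-tile.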
+        @helion.kernel(use_default_config=True, static_shapes=True)
+        def matmul_bf16_packed_int4(
+            A: torch.Tensor, B_packed: torch.Tensor, C: torch.Tensor
+        ) -> torch.Tensor:
+            M0, K0 = A.shape
+            _, N0 = B_packed.shape
+
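+            # Block sizes are registered up front so block_n / block_k can be
+            # used like constexpr values below (e.g. in block_k // 2).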
+            block_n = hl.register_block_size(N0)
+            block_k = hl.register_block_size(K0)
+
+            for tile_m in hl.tile(M0):
+                for tile_n in hl.tile(N0, block_size=block_n):
+                    acc = hl.zeros((tile_m, tile_n), dtype=torch.float32)
+
+                    for tile_k in hl.tile(K0, block_size=block_k):
+                        tile_k_begin = tile_k.begin
+                        b_tile = B_packed[
+                            tile_k_begin // 2 : tile_k_begin // 2 + block_k // 2,
+                            tile_n,
+                        ]
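+                        # Unpack the two 4-bit values in each byte: shifting left
+                        # then right sign-extends the low nibble, while an
+                        # arithmetic right shift extracts the high nibble.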
+                        shift = hl.full((1,), 4, dtype=torch.int8)
+                        b_lo = (b_tile << shift) >> shift
+                        b_hi = b_tile >> shift
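+                        # Interleave low/high nibbles along K: stack, permute and
+                        # reshape produce a (block_k, block_n) tile with rows
+                        # ordered lo0, hi0, lo1, hi1, ...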
+                        stacked = torch.stack(
+                            (b_lo.to(torch.float16), b_hi.to(torch.float16)), dim=2
+                        )
+                        stacked = stacked.permute(0, 2, 1)
+                        b_block = stacked.reshape([block_k, block_n])
+                        acc = hl.dot(A[tile_m, tile_k], b_block, acc=acc)
+
+                    C[tile_m, tile_n] = acc
+
+            return C
+
+        A = torch.randn((M, K), dtype=torch.bfloat16, device=DEVICE)
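+        # K // 2 packed bytes per column; values in [0, 16) keep the high nibble zero.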
+        B_packed = torch.randint(0, 16, (K // 2, N), dtype=torch.int8, device=DEVICE)
+        C = torch.zeros((M, N), dtype=torch.float32, device=DEVICE)
+
+        matmul_bf16_packed_int4(A, B_packed, C)
+        torch.cuda.synchronize()
+
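+        # Smoke checks only: the output must be finite and not all zeros; there
+        # is no exact reference comparison for the packed-int4 path here.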
+        self.assertTrue(torch.isfinite(C).all())
+        self.assertFalse(torch.allclose(C, torch.zeros_like(C)))
+
     def test_matmul_split_k(self):
         @helion.kernel(dot_precision="ieee")
         def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: