|
 
 import helion
 from helion._testing import DEVICE
+from helion._testing import run_example
 import helion.language as hl
 
 # %%
@@ -130,37 +131,72 @@ def run_kernel() -> torch.Tensor: |
 
 
 # %%
-def check(m: int, k: int, n: int) -> None:
+def _pack_int4_matrix(unpacked: torch.Tensor) -> torch.Tensor:
     """
-    Test the INT4 GEMM implementation.
+    Pack an int4 matrix into an int8 container with two values per byte.
 
     Args:
-        m (int): Number of rows in the left input matrix.
-        k (int): Shared dimension (must be even).
-        n (int): Number of columns in the right input matrix.
+        unpacked (torch.Tensor): Tensor of shape [K, N] with values in [-8, 7].
+
+    Returns:
+        torch.Tensor: Packed tensor of shape [K//2, N] in int8 format.
     """
-    # Create test matrices
-    A = torch.randn(m, k, dtype=torch.bfloat16, device=DEVICE)
+    k, n = unpacked.shape
+    assert k % 2 == 0, "K dimension must be even for int4 packing"
+    reshaped = unpacked.reshape(k // 2, 2, n).permute(1, 0, 2)
+    return ((reshaped[0] & 0xF) | (reshaped[1] << 4)).to(torch.int8)
 
-    # Create packed int4 matrix B (K//2 x N)
-    # Generate random int4 values in range [-8, 7] and pack them
-    B_unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8, device=DEVICE)
 
-    # Pack using the same format as tritonbench
-    B_reshaped = B_unpacked.reshape(k // 2, 2, n).permute(1, 0, 2)
-    B_packed = ((B_reshaped[0] & 0xF) | (B_reshaped[1] << 4)).to(torch.int8)
+def _unpack_int4_matrix(packed: torch.Tensor) -> torch.Tensor:
+    """
+    Unpack an int4 matrix stored as two 4-bit values per int8 byte.
+
+    Args:
+        packed (torch.Tensor): Packed tensor of shape [K//2, N] in int8 format.
+
+    Returns:
+        torch.Tensor: Unpacked tensor of shape [K, N] in int8 format.
+    """
+    b_lo = ((packed << 4) >> 4).to(torch.int8)
+    b_hi = (packed >> 4).to(torch.int8)
+    stacked = torch.stack([b_lo, b_hi], dim=1)
+    return stacked.reshape(packed.shape[0] * 2, packed.shape[1])
+
 
-    # Convert unpacked values to bfloat16 for reference
-    B_unpacked_bf16 = B_unpacked.to(torch.bfloat16)
+def reference_matmul_bf16_int4(A: Tensor, B_packed: Tensor) -> Tensor:
+    """
+    Reference implementation that unpacks the int4 weights and performs matmul.
+
+    Args:
+        A (Tensor): Input tensor in bfloat16 format.
+        B_packed (Tensor): Packed int4 tensor.
+
+    Returns:
+        Tensor: Output tensor in bfloat16 format.
+    """
+    B_unpacked = _unpack_int4_matrix(B_packed).to(torch.bfloat16)
+    return torch.matmul(A, B_unpacked)
 
-    # Compute reference result
-    expected = torch.matmul(A, B_unpacked_bf16)
 
-    # Run the kernel
-    result = matmul_bf16_int4(A, B_packed)
+def check(m: int, k: int, n: int) -> None:
+    """
+    Test the INT4 GEMM implementation using the run_example utility.
 
-    # Check accuracy with appropriate tolerance
-    torch.testing.assert_close(result, expected, rtol=2e-1, atol=1.0)
+    Args:
+        m (int): Number of rows in the left input matrix.
+        k (int): Shared dimension (must be even).
+        n (int): Number of columns in the right input matrix.
+    """
+    A = torch.randn(m, k, dtype=torch.bfloat16, device=DEVICE)
+    B_unpacked = torch.randint(-8, 8, (k, n), dtype=torch.int8, device=DEVICE)
+    B_packed = _pack_int4_matrix(B_unpacked)
+    run_example(
+        matmul_bf16_int4,
+        reference_matmul_bf16_int4,
+        (A, B_packed),
+        rtol=2e-1,
+        atol=1.0,
+    )
     print(f"Test passed for shapes: M={m}, K={k}, N={n}")
 
 
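The packing scheme above stores the even-indexed row of each K pair in the low nibble and the odd-indexed row in the high nibble; unpacking recovers the signed values with arithmetic shifts, since (x << 4) >> 4 sign-extends the low nibble in int8 arithmetic. Below is a minimal, self-contained round-trip sketch assuming only PyTorch on CPU; pack_int4 and unpack_int4 are illustrative local copies of the diff's _pack_int4_matrix/_unpack_int4_matrix helpers, not names from the example file.

# Round-trip sanity check for the int4 packing scheme (illustrative only).
import torch


def pack_int4(unpacked: torch.Tensor) -> torch.Tensor:
    # Even K rows land in the low nibble, odd K rows in the high nibble.
    k, n = unpacked.shape
    assert k % 2 == 0, "K must be even"
    pairs = unpacked.reshape(k // 2, 2, n).permute(1, 0, 2)
    return ((pairs[0] & 0xF) | (pairs[1] << 4)).to(torch.int8)


def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    # (x << 4) >> 4 sign-extends the low nibble; >> 4 alone sign-extends
    # the high nibble, because >> is an arithmetic shift on signed tensors.
    lo = ((packed << 4) >> 4).to(torch.int8)
    hi = (packed >> 4).to(torch.int8)
    return torch.stack([lo, hi], dim=1).reshape(-1, packed.shape[1])


values = torch.randint(-8, 8, (6, 3), dtype=torch.int8)
packed = pack_int4(values)
assert packed.shape == (3, 3) and packed.dtype == torch.int8
assert torch.equal(unpack_int4(packed), values)
print("int4 pack/unpack round-trip OK")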
|
|