diff --git a/batch_invariant_ops/batch_invariant_ops.py b/batch_invariant_ops/batch_invariant_ops.py index e68e4b3..268f9b2 100644 --- a/batch_invariant_ops/batch_invariant_ops.py +++ b/batch_invariant_ops/batch_invariant_ops.py @@ -70,7 +70,7 @@ def matmul_kernel_persistent( k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n - tile_id_c = start_pid - NUM_SMS + # tile_id_c = start_pid - NUM_SMS offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) num_pid_in_group = GROUP_SIZE_M * num_pid_n @@ -103,8 +103,8 @@ def matmul_kernel_persistent( b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) accumulator = tl.dot(a, b, accumulator) - tile_id_c += NUM_SMS - pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) + # tile_id_c += NUM_SMS + # pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if C_LARGE: @@ -118,6 +118,10 @@ def matmul_kernel_persistent( accumulator += bias if c_ptr.dtype.element_ty == tl.float8e4nv: c = accumulator.to(tl.float8e4nv) + elif c_ptr.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif c_ptr.dtype.element_ty == tl.float32: + c = accumulator.to(tl.float32) else: c = accumulator.to(tl.float16) tl.store(c_ptrs, c, mask=c_mask) diff --git a/test_batch_invariance.py b/test_batch_invariance.py index 912e113..831fe24 100644 --- a/test_batch_invariance.py +++ b/test_batch_invariance.py @@ -8,9 +8,18 @@ pass def test_batch_invariance(dtype=torch.float32): - B, D = 2048, 4096 - a = torch.linspace(-100, 100, B*D, dtype=dtype).reshape(B, D) - b = torch.linspace(-100, 100, D*D, dtype=dtype).reshape(D, D) + M = 32 + K = 128 + N = 1024 + a = torch.linspace(-100, 100, M*K, dtype=dtype).reshape(M, K) + + # Create non-contiguous tensor to mimic the nn.Linear case while weight is always transposed + # See ref: https://github.com/pytorch/pytorch/blob/v2.8.0/torch/nn/modules/linear.py#L50 + b = torch.linspace(-100, 100, K*N, dtype=dtype).reshape(N, K) + b = b.transpose(0, 1) + + print(f"a is contiguous: {a.is_contiguous()}") + print(f"b is contiguous: {b.is_contiguous()}") # Method 1: Matrix-vector multiplication (batch size 1) out1 = torch.mm(a[:1], b)