35 changes: 29 additions & 6 deletions lib/Conversion/StructuredToMemref/StructuredToMemref.cpp
@@ -409,6 +409,30 @@ struct MakeTensorPtrConverter
// clampedOff - targetOffset
// d1 = --------------------
// strideRows
//
////////////////////////////////////////////////////////////////////////////
//
// cols
//
// wrappedAroundOff
// --------------*---------------------
// | |
// | targetOffset |
// | *------------| |
// | | | |
// | | | |
// rows| rowSize | | |
// | | | |
// | | | |
// | *------------| |
// | nextOff |
// | |
// | clampedOff |
// --------------*---------------------
//
// For the case where clampedOff did not overflow (i.e. no wraparound
// occurred), clamp d1 so it never exceeds the tile's row count:
// d1 = min(d1, rowSize)
//
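// Illustrative example, matching the 4x4 tiles in the lit tests below:
// if the tile's rows run past the end of the buffer, d1 counts the rows
// served by the first chunk and the remaining rowSize - d1 rows come
// from a second chunk starting at wrappedAroundOff; if nothing wraps
// around, the min above leaves d1 == rowSize and the second chunk is
// empty.
//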

auto resultType = getResultMemrefType(
op, /* offset */ ShapedType::kDynamic,
@@ -443,6 +467,7 @@ struct MakeTensorPtrConverter
rewriter.create<arith::AddIOp>(loc, modRow, wrappedAroundOff);
Value d1 = rewriter.create<arith::SubIOp>(loc, clampedOff, targetOffset);
d1 = rewriter.create<arith::DivSIOp>(loc, d1, strideRow);
d1 = rewriter.create<arith::MinSIOp>(loc, d1, rowSize);

SmallVector<Value> sizes1{d1, colSize};
memref::ReinterpretCastOp cast1 =
@@ -685,11 +710,10 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
ConversionPatternRewriter &rewriter) const {
OpFoldResult subviewRowFull = dims[0];
OpFoldResult subviewColFull = dims[1];
-OpFoldResult col1 =
+OpFoldResult subviewCol1 =
rewriter.create<memref::DimOp>(loc, block1, 1).getResult();
-OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, rewriter);
OpFoldResult subviewCol2 =
-subOFRs(subviewColFull, subviewCol1, loc, rewriter);
+rewriter.create<memref::DimOp>(loc, block2, 1).getResult();

SmallVector<OpFoldResult> offsets(dims.size(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(dims.size(), rewriter.getIndexAttr(1));
@@ -707,11 +731,10 @@ struct LoadConverter : public OpConversionPattern<tts::LoadOp> {
ConversionPatternRewriter &rewriter) const {
OpFoldResult subviewRowFull = dims[0];
OpFoldResult subviewColFull = dims[1];
-OpFoldResult row1 =
+OpFoldResult subviewRow1 =
rewriter.create<memref::DimOp>(loc, block1, 0).getResult();
-OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, rewriter);
OpFoldResult subviewRow2 =
-subOFRs(subviewRowFull, subviewRow1, loc, rewriter);
+rewriter.create<memref::DimOp>(loc, block2, 0).getResult();

SmallVector<OpFoldResult> offsets(dims.size(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(dims.size(), rewriter.getIndexAttr(1));
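To make the intent of the new clamp concrete, here is a scalar sketch in plain Python of how the two wraparound chunks are sized. The names mirror the SSA values above; this is an illustration under the assumption of non-negative offsets and strides, not the pass itself:

    def split_wrapped_rows(target_offset, clamped_off, stride_rows, row_size):
        # rows served by the first chunk, i.e. before the wraparound point
        d1 = (clamped_off - target_offset) // stride_rows
        # the new clamp: if clamped_off never overflowed, the division can
        # exceed the tile height, so cap d1 at the tile's row count
        d1 = min(d1, row_size)
        # the second chunk covers whatever wrapped around
        d2 = row_size - d1
        return d1, d2

Because the reinterpret_casts are now created with these already-clamped sizes, LoadConverter can read each subview size directly from memref.dim on block1/block2 instead of recomputing it with minOFRs/subOFRs, which is what the two hunks above simplify.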
161 changes: 161 additions & 0 deletions python/examples/test_mm.py
@@ -0,0 +1,161 @@

import torch
import triton
import pytest
import triton.language as tl
import benchmark

@triton.jit
def prev_multiple_of(a, b):
    # the largest multiple of b strictly less than a, i.e. the largest x < a
    # with x % b == 0 (e.g. prev_multiple_of(71, 32) == 64); when a is already
    # a multiple of b this returns a - b, so the peeled tail iteration in
    # mm_kernel below always has a (fully unmasked) tile to process
    return tl.cdiv(a, b) * b - b

@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # matrix multiplication
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program IDs for better L2 performance: consecutive programs
    # walk down a group of GROUP_M tile-rows before moving to the next
    # column of N-tiles, so neighboring programs reuse cached blocks of A and B
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = start_k + tl.arange(0, BLOCK_K)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling: one masked iteration handles the final K tile, which may
    # be partial when K is not a multiple of BLOCK_K
    rk = prev_multiple + tl.arange(0, BLOCK_K)
    mask_k = rk < K
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # write back the result tile, masking off out-of-bounds rows and columns
    tl.store(C, acc, mask=mask)


_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # scan from the narrowest dtype upward; the first operand to match is the
    # narrower one, so the other operand carries the higher precision
    for d in _ordered_datatypes:
        if a is d:
            return b
        if b is d:
            return a


def mm(a, b):
    device = a.device
    # make inputs contiguous when neither dimension has unit stride
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # check constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocate the output in the higher of the two input dtypes
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)
    # launch a 1-D grid with one program per (BLOCK_M, BLOCK_N) output tile
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )

    mm_kernel[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        32,  # BLOCK_M
        32,  # BLOCK_N
        32,  # BLOCK_K
        GROUP_M=8,
    )

    return c

@pytest.mark.interpreter
@pytest.mark.parametrize("M, N, K", [(1, 1, 32), (15, 160, 1024), (495, 5333, 71)])
@pytest.mark.parametrize("dtype", [torch.float32])
def test_accuracy_mm(M, N, K, dtype):
    device = 'cpu'
    a = torch.randn((M, K), dtype=dtype, device=device)
    b = torch.randn((K, N), dtype=dtype, device=device)

    ref_out = torch.mm(a, b)
    res_out = mm(a, b)

    torch.testing.assert_close(res_out, ref_out, atol=1e-2, rtol=0)


if __name__ == "__main__":
    benchmark.select_cpu_backend()
    M, N, K = (495, 5333, 71)
    test_accuracy_mm(M, N, K, torch.float32)
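As a quick sanity check on the tail handling, prev_multiple_of can be mirrored in plain Python (illustrative only; tl.cdiv is replaced by integer ceiling division, and prev_multiple_of_py is a hypothetical helper name):

    def prev_multiple_of_py(a, b):
        return -(-a // b) * b - b  # ceil(a / b) * b - b

    # e.g. for K = 71, BLOCK_K = 32: full tiles start at k = 0 and 32, and
    # the peeled iteration covers rk = 64..95 under the mask rk < 71
    assert prev_multiple_of_py(71, 32) == 64
    assert prev_multiple_of_py(1024, 32) == 992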
20 changes: 9 additions & 11 deletions test/Conversion/StructuredToMemref/wraparound_side_by_side.mlir
@@ -96,20 +96,18 @@ module {
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_15_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_19_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32>
// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>)
-// CHECK: [[VAR_20_:%.+]] = arith.minsi [[VAR_18_]], [[CST_4_]] : index
-// CHECK-DAG: [[VAR_21_:%.+]] = arith.subi [[CST_4_]], [[VAR_20_]] : index
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] [2, [[VAR_20_]]{{.}} [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] [2, [[VAR_18_]]{{.}} [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] [2, [[VAR_21_]]{{.}} [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
-// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_20_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>>
-// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]][0, [[VAR_20_]]{{.}} [2, [[VAR_21_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] [2, [[VAR_19_]]{{.}} [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_18_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>>
+// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]][0, [[VAR_18_]]{{.}} [2, [[VAR_19_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1]>>
// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1], offset: ?>>
-// CHECK: [[VAR_22_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32>
-// CHECK: bufferization.materialize_in_destination [[VAR_22_]] in writable [[VAR_reinterpret_cast_]] : (tensor<4x4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>) -> ()
-// CHECK-DAG: [[VAR_23_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_9_]] : index
-// CHECK-DAG: [[VAR_24_:%.+]] = arith.addi [[VAR_arg16_]], [[VAR_11_]] : index
-// CHECK: scf.yield [[VAR_23_]], [[VAR_24_]] : index, index
+// CHECK: [[VAR_20_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32>
+// CHECK: bufferization.materialize_in_destination [[VAR_20_]] in writable [[VAR_reinterpret_cast_]] : (tensor<4x4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>) -> ()
+// CHECK-DAG: [[VAR_21_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_9_]] : index
+// CHECK-DAG: [[VAR_22_:%.+]] = arith.addi [[VAR_arg16_]], [[VAR_11_]] : index
+// CHECK: scf.yield [[VAR_21_]], [[VAR_22_]] : index, index
// CHECK: }
// CHECK: return
// CHECK: }
27 changes: 13 additions & 14 deletions test/Conversion/StructuredToMemref/wraparound_stacked.mlir
@@ -83,26 +83,25 @@ module {
// CHECK: [[VAR_13_:%.+]] = arith.addi [[VAR_3_]], [[VAR_12_]] : index
// CHECK: [[VAR_14_:%.+]] = arith.subi [[VAR_13_]], [[VAR_11_]] : index
// CHECK: [[VAR_15_:%.+]] = arith.divsi [[VAR_14_]], [[VAR_1_]] : index
-// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_11_]]{{.}}, sizes: {{.}}[[VAR_15_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-// CHECK-DAG: [[VAR_16_:%.+]] = arith.subi [[CST_4_]], [[VAR_15_]] : index
+// CHECK: [[VAR_16_:%.+]] = arith.minsi [[VAR_15_]], [[CST_4_]] : index
+// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_11_]]{{.}}, sizes: {{.}}[[VAR_16_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_17_:%.+]] = arith.subi [[CST_4_]], [[VAR_16_]] : index
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_12_]]{{.}}, sizes: {{.}}[[VAR_16_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_12_]]{{.}}, sizes: {{.}}[[VAR_17_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_4_]]{{.}} : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32>
// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>)
-// CHECK: [[VAR_17_:%.+]] = arith.minsi [[VAR_15_]], [[CST_4_]] : index
-// CHECK-DAG: [[VAR_18_:%.+]] = arith.subi [[CST_4_]], [[VAR_17_]] : index
-// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_reinterpret_cast_0_]][0, 0] {{.}}[[VAR_16_]], 3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_18_]], 3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
-// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<4x4xf32> to memref<?x3xf32, strided<[4, 1]>>
-// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_17_]], 0] {{.}}[[VAR_18_]], 3] [1, 1] : memref<4x4xf32> to memref<?x3xf32, strided<[4, 1], offset: ?>>
+// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_reinterpret_cast_1_]][0, 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[?, ?], offset: ?>>
+// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_16_]], 3] [1, 1] : memref<4x4xf32> to memref<?x3xf32, strided<[4, 1]>>
+// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_16_]], 0] {{.}}[[VAR_17_]], 3] [1, 1] : memref<4x4xf32> to memref<?x3xf32, strided<[4, 1], offset: ?>>
// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_3_]] : memref<?x3xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[4, 1]>>
// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref<?x3xf32, strided<[?, ?], offset: ?>> to memref<?x3xf32, strided<[4, 1], offset: ?>>
-// CHECK: [[VAR_19_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32>
-// CHECK: bufferization.materialize_in_destination [[VAR_19_]] in writable [[VAR_reinterpret_cast_]] : (tensor<4x4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>) -> ()
-// CHECK-DAG: [[VAR_20_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_9_]] : index
-// CHECK-DAG: [[VAR_21_:%.+]] = arith.addi [[VAR_arg16_]], [[VAR_9_]] : index
-// CHECK: scf.yield [[VAR_20_]], [[VAR_21_]] : index, index
+// CHECK: [[VAR_18_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32>
+// CHECK: bufferization.materialize_in_destination [[VAR_18_]] in writable [[VAR_reinterpret_cast_]] : (tensor<4x4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>>) -> ()
+// CHECK-DAG: [[VAR_19_:%.+]] = arith.addi [[VAR_arg15_]], [[VAR_9_]] : index
+// CHECK-DAG: [[VAR_20_:%.+]] = arith.addi [[VAR_arg16_]], [[VAR_9_]] : index
+// CHECK: scf.yield [[VAR_19_]], [[VAR_20_]] : index, index
// CHECK: }
// CHECK: return
// CHECK: }