Skip to content

[mlir][spirv] Add support for Aligned memory operand in CoopMatrix memory operations #145480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
}];

let assemblyFormat = [{
$pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
$pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:`
type(operands) `->` type($result)
}];

Expand All @@ -123,11 +123,13 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
Capability<[SPIRV_C_CooperativeMatrixKHR]>
];

// TODO: Add scope operand for MakePointer*. See #145485.
let arguments = (ins
SPIRV_AnyPtr:$pointer,
SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
SPIRV_Integer:$stride,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand,
OptionalAttr<I32Attr>:$alignment
);

let results = (outs
Expand All @@ -139,7 +141,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
"spirv::ConstantOp":$stride,
"spirv::CooperativeMatrixLayoutKHR":$layout), [{
build($_builder, $_state, result, pointer, layout, stride,
spirv::MemoryAccessAttr{});
spirv::MemoryAccessAttr{}, IntegerAttr{});
}]>
];
}
Expand Down Expand Up @@ -194,7 +196,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
}];

let assemblyFormat = [{
$pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
$pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:`
type(operands)
}];

Expand All @@ -205,12 +207,14 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
Capability<[SPIRV_C_CooperativeMatrixKHR]>
];

// TODO: Add scope operand for MakePointer*. See #145485.
let arguments = (ins
SPIRV_AnyPtr:$pointer,
SPIRV_AnyCooperativeMatrix:$object,
SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
SPIRV_Integer:$stride,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand,
OptionalAttr<I32Attr>:$alignment
);

let results = (outs);
Expand All @@ -220,7 +224,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
"spirv::ConstantOp":$stride,
"spirv::CooperativeMatrixLayoutKHR":$layout), [{
build($_builder, $_state, pointer, object, layout, stride,
spirv::MemoryAccessAttr{});
spirv::MemoryAccessAttr{}, IntegerAttr{});
}]>
];
}
Expand Down
28 changes: 18 additions & 10 deletions mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ namespace mlir::spirv {

static LogicalResult
verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
spirv::MemoryAccessAttr memoryOperand) {
spirv::MemoryAccessAttr memoryOperand,
IntegerAttr alignment) {
auto pointerType = cast<PointerType>(pointer);
Type pointeeType = pointerType.getPointeeType();
if (!isa<ScalarType, VectorType>(pointeeType)) {
Expand All @@ -49,13 +50,18 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
"not compatible with memory operand 'MakePointerVisible'");
}

// The 'Aligned' memory operand requires an alignment literal to follow,
// which needs to be implemented on the level of op parsing and
// (de-)serialization.
// TODO: Consider adding support for this attribute value.
if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
spirv::MemoryAccess::Aligned)) {
return op->emitOpError("has unhandled memory operand 'Aligned'");
// TODO: Need to check that NonPrivatePointer is set for MakePointer*. See
// #145485.

if (spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
!alignment) {
return op->emitOpError("missing value for the 'Aligned' memory operand");
}

if (!spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
alignment) {
return op->emitOpError(
"found alignment attribute for non-'Aligned' memory operand");
}
}

Expand All @@ -72,7 +78,8 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,

LogicalResult KHRCooperativeMatrixLoadOp::verify() {
return verifyCoopMatrixAccess(*this, getPointer().getType(),
getResult().getType(), getMemoryOperandAttr());
getResult().getType(), getMemoryOperandAttr(),
getAlignmentAttr());
}

//===----------------------------------------------------------------------===//
Expand All @@ -81,7 +88,8 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {

LogicalResult KHRCooperativeMatrixStoreOp::verify() {
return verifyCoopMatrixAccess(*this, getPointer().getType(),
getObject().getType(), getMemoryOperandAttr());
getObject().getType(), getMemoryOperandAttr(),
getAlignmentAttr());
}

//===----------------------------------------------------------------------===//
Expand Down
34 changes: 31 additions & 3 deletions mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuf
spirv.Return
}

// CHECK-LABEL: @cooperative_matrix_load_aligned
spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16 :
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Aligned>, 16 :
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
spirv.Return
}

// CHECK-LABEL: @cooperative_matrix_store
spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
Expand Down Expand Up @@ -90,6 +99,16 @@ spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr<i32, StorageBu
spirv.Return
}

// CHECK-LABEL: @cooperative_matrix_store_aligned
spirv.func @cooperative_matrix_store_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16 :
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned>, 16 :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
spirv.Return
}

// -----

spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
Expand Down Expand Up @@ -120,7 +139,7 @@ spirv.func @cooperative_matrix_load_bad_operad(%ptr : !spirv.ptr<i32, StorageBuf
// -----

spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
// expected-error @+1 {{missing value for the 'Aligned' memory operand}}
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
spirv.Return
Expand All @@ -129,7 +148,7 @@ spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer
// -----

spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
// expected-error @+1 {{missing value for the 'Aligned' memory operand}}
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile|Aligned> :
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
spirv.Return
Expand Down Expand Up @@ -179,14 +198,23 @@ spirv.func @cooperative_matrix_store_bad_operand(%ptr : !spirv.ptr<i32, StorageB

spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
// expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
// expected-error @+1 {{missing value for the 'Aligned' memory operand}}
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned> :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
spirv.Return
}

// -----

spirv.func @cooperative_matrix_store_bad_operand_arg(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// expected-error @+1 {{found alignment attribute for non-'Aligned' memory operand}}
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <MakePointerVisible>, 16 :
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
spirv.Return
}

// -----

spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
%b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
%c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ spirv.module Logical GLSL450 requires
spirv.Return
}

// CHECK-LABEL: @cooperative_matrix_load_3
spirv.func @cooperative_matrix_load_3(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
// CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16
// CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Aligned>, 16 :
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
spirv.Return
}

// CHECK-LABEL: @cooperative_matrix_store_1
spirv.func @cooperative_matrix_store_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
%m : !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>) "None" {
Expand All @@ -38,6 +47,11 @@ spirv.module Logical GLSL450 requires
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32

// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16
// CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned>, 16 :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32

// CHECK-NEXT: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Volatile|Nontemporal>
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Volatile|Nontemporal> :
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
Expand Down
Loading