Skip to content

Commit a94d184

Browse files
committed
[mlir][spirv] Add support for Aligned memory operand in CoopMatrix
1 parent 714b2fd commit a94d184

File tree

4 files changed

+73
-19
lines changed

4 files changed

+73
-19
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
112112
}];
113113

114114
let assemblyFormat = [{
115-
$pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
115+
$pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:`
116116
type(operands) `->` type($result)
117117
}];
118118

@@ -123,11 +123,13 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
123123
Capability<[SPIRV_C_CooperativeMatrixKHR]>
124124
];
125125

126+
// TODO: Add scope operand for MakePointer*. See #145485.
126127
let arguments = (ins
127128
SPIRV_AnyPtr:$pointer,
128129
SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
129130
SPIRV_Integer:$stride,
130-
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
131+
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand,
132+
OptionalAttr<I32Attr>:$alignment
131133
);
132134

133135
let results = (outs
@@ -139,7 +141,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
139141
"spirv::ConstantOp":$stride,
140142
"spirv::CooperativeMatrixLayoutKHR":$layout), [{
141143
build($_builder, $_state, result, pointer, layout, stride,
142-
spirv::MemoryAccessAttr{});
144+
spirv::MemoryAccessAttr{}, IntegerAttr{});
143145
}]>
144146
];
145147
}
@@ -194,7 +196,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
194196
}];
195197

196198
let assemblyFormat = [{
197-
$pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
199+
$pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:`
198200
type(operands)
199201
}];
200202

@@ -205,12 +207,14 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
205207
Capability<[SPIRV_C_CooperativeMatrixKHR]>
206208
];
207209

210+
// TODO: Add scope operand for MakePointer*. See #145485.
208211
let arguments = (ins
209212
SPIRV_AnyPtr:$pointer,
210213
SPIRV_AnyCooperativeMatrix:$object,
211214
SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
212215
SPIRV_Integer:$stride,
213-
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
216+
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand,
217+
OptionalAttr<I32Attr>:$alignment
214218
);
215219

216220
let results = (outs);
@@ -220,7 +224,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
220224
"spirv::ConstantOp":$stride,
221225
"spirv::CooperativeMatrixLayoutKHR":$layout), [{
222226
build($_builder, $_state, pointer, object, layout, stride,
223-
spirv::MemoryAccessAttr{});
227+
spirv::MemoryAccessAttr{}, IntegerAttr{});
224228
}]>
225229
];
226230
}

mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ namespace mlir::spirv {
2323

2424
static LogicalResult
2525
verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
26-
spirv::MemoryAccessAttr memoryOperand) {
26+
spirv::MemoryAccessAttr memoryOperand,
27+
IntegerAttr alignment) {
2728
auto pointerType = cast<PointerType>(pointer);
2829
Type pointeeType = pointerType.getPointeeType();
2930
if (!isa<ScalarType, VectorType>(pointeeType)) {
@@ -49,13 +50,18 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
4950
"not compatible with memory operand 'MakePointerVisible'");
5051
}
5152

52-
// The 'Aligned' memory operand requires an alignment literal to follow,
53-
// which needs to be implemented on the level of op parsing and
54-
// (de-)serialization.
55-
// TODO: Consider adding support for this attribute value.
56-
if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
57-
spirv::MemoryAccess::Aligned)) {
58-
return op->emitOpError("has unhandled memory operand 'Aligned'");
53+
// TODO: Need to check that NonPrivatePointer is set for MakePointer*. See
54+
// #145485.
55+
56+
if (spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
57+
!alignment) {
58+
return op->emitOpError("missing value for the 'Aligned' memory operand");
59+
}
60+
61+
if (!spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
62+
alignment) {
63+
return op->emitOpError(
64+
"found alignment attribute for non-'Aligned' memory operand");
5965
}
6066
}
6167

@@ -72,7 +78,8 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
7278

7379
LogicalResult KHRCooperativeMatrixLoadOp::verify() {
7480
return verifyCoopMatrixAccess(*this, getPointer().getType(),
75-
getResult().getType(), getMemoryOperandAttr());
81+
getResult().getType(), getMemoryOperandAttr(),
82+
getAlignmentAttr());
7683
}
7784

7885
//===----------------------------------------------------------------------===//
@@ -81,7 +88,8 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
8188

8289
LogicalResult KHRCooperativeMatrixStoreOp::verify() {
8390
return verifyCoopMatrixAccess(*this, getPointer().getType(),
84-
getObject().getType(), getMemoryOperandAttr());
91+
getObject().getType(), getMemoryOperandAttr(),
92+
getAlignmentAttr());
8593
}
8694

8795
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuf
5858
spirv.Return
5959
}
6060

61+
// CHECK-LABEL: @cooperative_matrix_load_aligned
62+
spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
63+
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16 :
64+
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
65+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Aligned>, 16 :
66+
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
67+
spirv.Return
68+
}
69+
6170
// CHECK-LABEL: @cooperative_matrix_store
6271
spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
6372
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
@@ -90,6 +99,16 @@ spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr<i32, StorageBu
9099
spirv.Return
91100
}
92101

102+
// CHECK-LABEL: @cooperative_matrix_store_aligned
103+
spirv.func @cooperative_matrix_store_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
104+
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
105+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16 :
106+
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
107+
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned>, 16 :
108+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
109+
spirv.Return
110+
}
111+
93112
// -----
94113

95114
spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
@@ -120,7 +139,7 @@ spirv.func @cooperative_matrix_load_bad_operad(%ptr : !spirv.ptr<i32, StorageBuf
120139
// -----
121140

122141
spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
123-
// expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
142+
// expected-error @+1 {{missing value for the 'Aligned' memory operand}}
124143
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
125144
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
126145
spirv.Return
@@ -129,7 +148,7 @@ spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer
129148
// -----
130149

131150
spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
132-
// expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
151+
// expected-error @+1 {{missing value for the 'Aligned' memory operand}}
133152
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile|Aligned> :
134153
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
135154
spirv.Return
@@ -179,14 +198,23 @@ spirv.func @cooperative_matrix_store_bad_operand(%ptr : !spirv.ptr<i32, StorageB
179198

180199
spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
181200
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
182-
// expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
201+
// expected-error @+1 {{missing value for the 'Aligned' memory operand}}
183202
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned> :
184203
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
185204
spirv.Return
186205
}
187206

188207
// -----
189208

209+
spirv.func @cooperative_matrix_store_bad_operand_arg(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
210+
// expected-error @+1 {{found alignment attribute for non-'Aligned' memory operand}}
211+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <MakePointerVisible>, 16 :
212+
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
213+
spirv.Return
214+
}
215+
216+
// -----
217+
190218
spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
191219
%b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
192220
%c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {

mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ spirv.module Logical GLSL450 requires
3030
spirv.Return
3131
}
3232

33+
// CHECK-LABEL: @cooperative_matrix_load_3
34+
spirv.func @cooperative_matrix_load_3(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
35+
// CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16
36+
// CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
37+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Aligned>, 16 :
38+
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
39+
spirv.Return
40+
}
41+
3342
// CHECK-LABEL: @cooperative_matrix_store_1
3443
spirv.func @cooperative_matrix_store_1(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
3544
%m : !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>) "None" {
@@ -38,6 +47,11 @@ spirv.module Logical GLSL450 requires
3847
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
3948
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
4049

50+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Aligned>, 16
51+
// CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
52+
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned>, 16 :
53+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
54+
4155
// CHECK-NEXT: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Volatile|Nontemporal>
4256
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Volatile|Nontemporal> :
4357
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32

0 commit comments

Comments
 (0)