diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td index 46732ba19afed..fd75532ae3d70 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -112,7 +112,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad" }]; let assemblyFormat = [{ - $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:` + $pointer `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:` type(operands) `->` type($result) }]; @@ -123,11 +123,13 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad" Capability<[SPIRV_C_CooperativeMatrixKHR]> ]; + // TODO: Add scope operand for MakePointer*. See #145485. let arguments = (ins SPIRV_AnyPtr:$pointer, SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout, SPIRV_Integer:$stride, - OptionalAttr:$memory_operand + OptionalAttr:$memory_operand, + OptionalAttr:$alignment ); let results = (outs @@ -139,7 +141,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad" "spirv::ConstantOp":$stride, "spirv::CooperativeMatrixLayoutKHR":$layout), [{ build($_builder, $_state, result, pointer, layout, stride, - spirv::MemoryAccessAttr{}); + spirv::MemoryAccessAttr{}, IntegerAttr{}); }]> ]; } @@ -194,7 +196,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor }]; let assemblyFormat = [{ - $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:` + $pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? ( `,` $alignment^ )? attr-dict `:` type(operands) }]; @@ -205,12 +207,14 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor Capability<[SPIRV_C_CooperativeMatrixKHR]> ]; + // TODO: Add scope operand for MakePointer*. See #145485. let arguments = (ins SPIRV_AnyPtr:$pointer, SPIRV_AnyCooperativeMatrix:$object, SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout, SPIRV_Integer:$stride, - OptionalAttr:$memory_operand + OptionalAttr:$memory_operand, + OptionalAttr:$alignment ); let results = (outs); @@ -220,7 +224,7 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor "spirv::ConstantOp":$stride, "spirv::CooperativeMatrixLayoutKHR":$layout), [{ build($_builder, $_state, pointer, object, layout, stride, - spirv::MemoryAccessAttr{}); + spirv::MemoryAccessAttr{}, IntegerAttr{}); }]> ]; } diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp index 2ff3efdc96a7f..fa20cc179f892 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp @@ -23,7 +23,8 @@ namespace mlir::spirv { static LogicalResult verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, - spirv::MemoryAccessAttr memoryOperand) { + spirv::MemoryAccessAttr memoryOperand, + IntegerAttr alignment) { auto pointerType = cast(pointer); Type pointeeType = pointerType.getPointeeType(); if (!isa(pointeeType)) { @@ -49,13 +50,18 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, "not compatible with memory operand 'MakePointerVisible'"); } - // The 'Aligned' memory operand requires an alignment literal to follow, - // which needs to be implemented on the level of op parsing and - // (de-)serialization. - // TODO: Consider adding support for this attribute value. - if (spirv::bitEnumContainsAll(memoryOperand.getValue(), - spirv::MemoryAccess::Aligned)) { - return op->emitOpError("has unhandled memory operand 'Aligned'"); + // TODO: Need to check that NonPrivatePointer is set for MakePointer*. See + // #145485. + + if (spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) && + !alignment) { + return op->emitOpError("missing value for the 'Aligned' memory operand"); + } + + if (!spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) && + alignment) { + return op->emitOpError( + "found alignment attribute for non-'Aligned' memory operand"); } } @@ -72,7 +78,8 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, LogicalResult KHRCooperativeMatrixLoadOp::verify() { return verifyCoopMatrixAccess(*this, getPointer().getType(), - getResult().getType(), getMemoryOperandAttr()); + getResult().getType(), getMemoryOperandAttr(), + getAlignmentAttr()); } //===----------------------------------------------------------------------===// @@ -81,7 +88,8 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() { LogicalResult KHRCooperativeMatrixStoreOp::verify() { return verifyCoopMatrixAccess(*this, getPointer().getType(), - getObject().getType(), getMemoryOperandAttr()); + getObject().getType(), getMemoryOperandAttr(), + getAlignmentAttr()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir index 8733ff93768ab..56d477cca97b7 100644 --- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir @@ -58,6 +58,15 @@ spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr, %stride : i32) "None" { + // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, , , 16 : + // CHECK-SAME: !spirv.ptr, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , , 16 : + !spirv.ptr, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + spirv.Return +} + // CHECK-LABEL: @cooperative_matrix_store spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { @@ -90,6 +99,16 @@ spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr, %stride : i32, + %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { + // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, , , 16 : + // CHECK-SAME: !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32 + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, , , 16 : + !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32 + spirv.Return +} + // ----- spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32) "None" { @@ -120,7 +139,7 @@ spirv.func @cooperative_matrix_load_bad_operad(%ptr : !spirv.ptr, %stride : i32) "None" { - // expected-error @+1 {{op has unhandled memory operand 'Aligned'}} + // expected-error @+1 {{missing value for the 'Aligned' memory operand}} %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , : !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA> spirv.Return @@ -129,7 +148,7 @@ spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr, %stride : i32) "None" { - // expected-error @+1 {{op has unhandled memory operand 'Aligned'}} + // expected-error @+1 {{missing value for the 'Aligned' memory operand}} %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , : !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA> spirv.Return @@ -179,7 +198,7 @@ spirv.func @cooperative_matrix_store_bad_operand(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { - // expected-error @+1 {{op has unhandled memory operand 'Aligned'}} + // expected-error @+1 {{missing value for the 'Aligned' memory operand}} spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, , : !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32 spirv.Return @@ -187,6 +206,15 @@ spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %str // ----- +spirv.func @cooperative_matrix_store_bad_operand_arg(%ptr : !spirv.ptr, %stride : i32) "None" { + // expected-error @+1 {{found alignment attribute for non-'Aligned' memory operand}} + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , , 16 : + !spirv.ptr, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + spirv.Return +} + +// ----- + spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>, %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>, %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" { diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir index 153ff47937972..77949908e8883 100644 --- a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir +++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir @@ -30,6 +30,15 @@ spirv.module Logical GLSL450 requires spirv.Return } + // CHECK-LABEL: @cooperative_matrix_load_3 + spirv.func @cooperative_matrix_load_3(%ptr : !spirv.ptr, %stride : i32) "None" { + // CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, , , 16 + // CHECK-SAME: : !spirv.ptr, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , , 16 : + !spirv.ptr, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA> + spirv.Return + } + // CHECK-LABEL: @cooperative_matrix_store_1 spirv.func @cooperative_matrix_store_1(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>) "None" { @@ -38,6 +47,11 @@ spirv.module Logical GLSL450 requires spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, : !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32 + // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, , , 16 + // CHECK-SAME: : !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32 + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, , , 16 : + !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32 + // CHECK-NEXT: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, , spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, , : !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32