diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 26559c1321db5..ee7c7860b05c4 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -579,7 +579,7 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser, void spirv::ConstantOp::print(OpAsmPrinter &printer) { printer << ' ' << getValue(); - if (llvm::isa(getType())) + if (llvm::isa(getType())) printer << " : " << getType(); } @@ -626,18 +626,49 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, } return success(); } - if (auto arrayAttr = llvm::dyn_cast(value)) { - auto arrayType = llvm::dyn_cast(opType); - if (!arrayType) - return op.emitOpError( - "must have spirv.array result type for array value"); - Type elemType = arrayType.getElementType(); - for (Attribute element : arrayAttr.getValue()) { - // Verify array elements recursively. - if (failed(verifyConstantType(op, element, elemType))) - return failure(); + if (auto arrayAttr = mlir::dyn_cast(value)) { + // Case for Matrix result type + if (auto matrixType = mlir::dyn_cast(opType)) { + unsigned numColumns = matrixType.getNumColumns(); + unsigned numRows = matrixType.getNumRows(); + if (arrayAttr.size() != numColumns) + return op.emitOpError("expected ") + << numColumns << " columns in matrix constant, but got " + << arrayAttr.size(); + + Type elementTy = matrixType.getElementType(); + for (auto [colIndex, colAttr] : llvm::enumerate(arrayAttr)) { + // Ensure each column is a dense array of the right shape/type + auto denseAttr = mlir::dyn_cast(colAttr); + if (!denseAttr) + return op.emitOpError("matrix column #") + << colIndex << " must be a DenseElementsAttr"; + + auto shapedTy = mlir::dyn_cast(denseAttr.getType()); + if (!shapedTy || shapedTy.getNumElements() != numRows) + return op.emitOpError("matrix column #") + << colIndex << " has incorrect size: expected " + << numRows << " elements"; + + if (shapedTy.getElementType() != elementTy) + return op.emitOpError("matrix column #") + << colIndex << " has incorrect element type: expected " + << elementTy << ", got " << shapedTy.getElementType(); + } + return success(); } - return success(); + // Case for Array result type + if (auto arrayType = mlir::dyn_cast(opType)) { + Type elemType = arrayType.getElementType(); + for (Attribute element : arrayAttr.getValue()) { + // Verify array elements recursively. + if (failed(verifyConstantType(op, element, elemType))) + return failure(); + } + return success(); + } + return op.emitOpError( + "must have spirv.array or spirv.matrix result type for array value"); } return op.emitOpError("cannot have attribute: ") << value; } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 04469f1933819..ecc822e553aef 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1442,6 +1442,9 @@ spirv::Deserializer::processConstantComposite(ArrayRef operands) { } else if (auto arrayType = dyn_cast(resultType)) { auto attr = opBuilder.getArrayAttr(elements); constantMap.try_emplace(resultID, attr, resultType); + } else if (auto matrixType = dyn_cast(resultType)) { + auto attr = opBuilder.getArrayAttr(elements); + constantMap.try_emplace(resultID, attr, resultType); } else { return emitError(unknownLoc, "unsupported OpConstantComposite type: ") << resultType; diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 1f4f5d7f764db..b5e3cd381ef82 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -782,6 +782,22 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType, SmallVector index(rank); resultID = prepareDenseElementsConstant(loc, constType, attr, /*dim=*/0, index); + } else if (isa(constType)) { + if (auto arrayAttr = dyn_cast(valueAttr)) { + resultID = getNextID(); + SmallVector operands = {typeID, resultID}; + operands.reserve(arrayAttr.size() + 2); + for (Attribute elementAttr : arrayAttr) { + if (auto elementID = prepareConstant(loc, + cast(constType).getColumnType(), elementAttr)) { + operands.push_back(elementID); + } else { + return 0; + } + } + spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; + encodeInstructionInto(typesGlobalValues, opcode, operands); + } } else if (auto arrayAttr = dyn_cast(valueAttr)) { resultID = prepareArrayConstant(loc, constType, arrayAttr); } diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir index 3fc8dfb2767d1..5c835d2e08de9 100644 --- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir @@ -11,6 +11,13 @@ func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> ve return %0: vector<3xf32> } +// CHECK-LABEL: func @composite_construct_matrix +func.func @composite_construct_matrix(%v1: vector<3xf32>, %v2: vector<3xf32>, %v3: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> { + // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> + %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> + return %0: !spirv.matrix<3 x vector<3xf32>> +} + // CHECK-LABEL: func @composite_construct_struct func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> { // CHECK: spirv.CompositeConstruct @@ -89,9 +96,31 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2 %0 = spirv.CompositeConstruct %arg0, %arg2 : (f32, vector<2xf32>) -> vector<4xf32> return %0: vector<4xf32> } +// ----- + +func.func @composite_construct_matrix_wrong_column_count(%v1: vector<3xf32>, %v2: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> { + // expected-error @+1 {{'spirv.CompositeConstruct' op expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}} + %0 = spirv.CompositeConstruct %v1, %v2 : (vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> + return %0: !spirv.matrix<3 x vector<3xf32>> +} + +// ----- + +func.func @composite_construct_matrix_wrong_row_count(%v1: vector<4xf32>, %v2: vector<4xf32>, %v3: vector<4xf32>) -> !spirv.matrix<3 x vector<3xf32>> { + // expected-error @+1 {{operand type mismatch: expected operand type 'vector<3xf32>', but provided 'vector<4xf32>'}} + %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> !spirv.matrix<3 x vector<3xf32>> + return %0: !spirv.matrix<3 x vector<3xf32>> +} // ----- +func.func @composite_construct_matrix_wrong_element_type(%v1: vector<3xi32>, %v2: vector<3xi32>, %v3: vector<3xi32>) -> !spirv.matrix<3 x vector<3xf32>> { + // expected-error @+1 {{operand type mismatch: expected operand type 'vector<3xf32>', but provided 'vector<3xi32>'}} + %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xi32>, vector<3xi32>, vector<3xi32>) -> !spirv.matrix<3 x vector<3xf32>> + return %0: !spirv.matrix<3 x vector<3xf32>> +} +// ----- + //===----------------------------------------------------------------------===// // spirv.CompositeExtractOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir index 5e98b9fdb3c54..6003d2a3576b1 100644 --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -62,6 +62,7 @@ func.func @const() -> () { // CHECK: spirv.Constant dense<1.000000e+00> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>> // CHECK: spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>> // CHECK: spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>> + // CHECK: spirv.Constant [dense<1.000000e+00> : vector<3xf32>, dense<2.000000e+00> : vector<3xf32>, dense<3.000000e+00> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>> %0 = spirv.Constant true %1 = spirv.Constant 42 : i32 @@ -73,6 +74,7 @@ func.func @const() -> () { %7 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>> %8 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>> %9 = spirv.Constant [[dense<3.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1xvector<2xf32>>> + %10 = spirv.Constant [dense<1.0> : vector<3xf32>, dense<2.0> : vector<3xf32>, dense<3.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>> return } @@ -95,7 +97,7 @@ func.func @array_constant() -> () { // ----- func.func @array_constant() -> () { - // expected-error @+1 {{must have spirv.array result type for array value}} + // expected-error @+1 {{'spirv.Constant' op must have spirv.array or spirv.matrix result type for array value}} %0 = spirv.Constant [dense<3.0> : vector<2xf32>] : !spirv.rtarray> return } @@ -132,6 +134,30 @@ func.func @value_result_num_elements_mismatch() -> () { // ----- +func.func @matrix_constant() -> () { + // CHECK: spirv.Constant [dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : vector<3xf32>, dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : vector<3xf32>, dense<[7.000000e+00, 8.000000e+00, 9.000000e+00]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>> + %0 = spirv.Constant [dense<[1.0, 2.0, 3.0]> : vector<3xf32>, dense<[4.0, 5.0, 6.0]> : vector<3xf32>, dense<[7.0, 8.0, 9.0]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>> + return +} + +// ----- + +func.func @matrix_constant_wrong_column_count() -> () { + // expected-error @+1 {{expected 3 columns in matrix constant, but got 2}} + %0 = spirv.Constant [dense<1.0> : vector<3xf32>, dense<2.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>> + return +} + +// ----- + +func.func @matrix_constant_non_dense_column() -> () { + // expected-error @+1 {{matrix column #1 must be a DenseElementsAttr}} + %0 = spirv.Constant [dense<1.0> : vector<3xf32>, "wrong", dense<3.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>> + return +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.EntryPoint //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/composite-op.mlir b/mlir/test/Target/SPIRV/composite-op.mlir index 5f302fd0d38f8..bafdb3340d0e7 100644 --- a/mlir/test/Target/SPIRV/composite-op.mlir +++ b/mlir/test/Target/SPIRV/composite-op.mlir @@ -11,6 +11,11 @@ spirv.module Logical GLSL450 requires #spirv.vce { %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32> spirv.ReturnValue %0: vector<3xf32> } + spirv.func @composite_construct_matrix(%v1: vector<3xf32>, %v2: vector<3xf32>, %v3: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> "None" { + // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> + %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> + spirv.ReturnValue %0: !spirv.matrix<3 x vector<3xf32>> + } spirv.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 "None" { // CHECK: spirv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32 %0 = spirv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32 diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir index f3950214a7f05..0fa70c7e5cdbb 100644 --- a/mlir/test/Target/SPIRV/constant.mlir +++ b/mlir/test/Target/SPIRV/constant.mlir @@ -198,6 +198,17 @@ spirv.module Logical GLSL450 requires #spirv.vce { spirv.Return } + // CHECK-LABEL: @matrix_const + spirv.func @matrix_const() -> () "None" { + // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr>, Function> + %0 = spirv.Variable : !spirv.ptr>, Function> + // CHECK: %[[CST:.*]] = spirv.Constant [dense<[1.000000e+00, 0.000000e+00, 0.000000e+00]> : vector<3xf32>, dense<[0.000000e+00, 1.000000e+00, 0.000000e+00]> : vector<3xf32>, dense<[0.000000e+00, 0.000000e+00, 1.000000e+00]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>> + %1 = spirv.Constant [dense<[1., 0., 0.]> : vector<3xf32>, dense<[0., 1., 0.]> : vector<3xf32>, dense<[0., 0., 1.]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>> + // CHECK: spirv.Store "Function" %[[VAR]], %[[CST]] : !spirv.matrix<3 x vector<3xf32>> + spirv.Store "Function" %0, %1 : !spirv.matrix<3 x vector<3xf32>> + spirv.Return + } + // CHECK-LABEL: @ui64_array_const spirv.func @ui64_array_const() -> (!spirv.array<3xui64>) "None" { // CHECK: spirv.Constant [5, 6, 7] : !spirv.array<3 x i64>