Skip to content

[mlir][spirv] Add support for Constant Matrices #123334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 43 additions & 12 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,

void spirv::ConstantOp::print(OpAsmPrinter &printer) {
printer << ' ' << getValue();
if (llvm::isa<spirv::ArrayType>(getType()))
if (llvm::isa<spirv::ArrayType, spirv::MatrixType>(getType()))
printer << " : " << getType();
}

Expand Down Expand Up @@ -626,18 +626,49 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
}
return success();
}
if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
if (!arrayType)
return op.emitOpError(
"must have spirv.array result type for array value");
Type elemType = arrayType.getElementType();
for (Attribute element : arrayAttr.getValue()) {
// Verify array elements recursively.
if (failed(verifyConstantType(op, element, elemType)))
return failure();
if (auto arrayAttr = mlir::dyn_cast<ArrayAttr>(value)) {
// Case for Matrix result type
if (auto matrixType = mlir::dyn_cast<spirv::MatrixType>(opType)) {
unsigned numColumns = matrixType.getNumColumns();
unsigned numRows = matrixType.getNumRows();
if (arrayAttr.size() != numColumns)
return op.emitOpError("expected ")
<< numColumns << " columns in matrix constant, but got "
<< arrayAttr.size();

Type elementTy = matrixType.getElementType();
for (auto [colIndex, colAttr] : llvm::enumerate(arrayAttr)) {
// Ensure each column is a dense array of the right shape/type
auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(colAttr);
if (!denseAttr)
return op.emitOpError("matrix column #")
<< colIndex << " must be a DenseElementsAttr";

auto shapedTy = mlir::dyn_cast<ShapedType>(denseAttr.getType());
if (!shapedTy || shapedTy.getNumElements() != numRows)
return op.emitOpError("matrix column #")
<< colIndex << " has incorrect size: expected "
<< numRows << " elements";

if (shapedTy.getElementType() != elementTy)
return op.emitOpError("matrix column #")
<< colIndex << " has incorrect element type: expected "
<< elementTy << ", got " << shapedTy.getElementType();
}
return success();
}
return success();
// Case for Array result type
if (auto arrayType = mlir::dyn_cast<spirv::ArrayType>(opType)) {
Type elemType = arrayType.getElementType();
for (Attribute element : arrayAttr.getValue()) {
// Verify array elements recursively.
if (failed(verifyConstantType(op, element, elemType)))
return failure();
}
return success();
}
return op.emitOpError(
"must have spirv.array or spirv.matrix result type for array value");
}
return op.emitOpError("cannot have attribute: ") << value;
}
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1442,6 +1442,9 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
} else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
auto attr = opBuilder.getArrayAttr(elements);
constantMap.try_emplace(resultID, attr, resultType);
} else if (auto matrixType = dyn_cast<spirv::MatrixType>(resultType)) {
auto attr = opBuilder.getArrayAttr(elements);
constantMap.try_emplace(resultID, attr, resultType);
} else {
return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
<< resultType;
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,22 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
SmallVector<uint64_t, 4> index(rank);
resultID = prepareDenseElementsConstant(loc, constType, attr,
/*dim=*/0, index);
} else if (isa<spirv::MatrixType>(constType)) {
if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
resultID = getNextID();
SmallVector<uint32_t, 4> operands = {typeID, resultID};
operands.reserve(arrayAttr.size() + 2);
for (Attribute elementAttr : arrayAttr) {
if (auto elementID = prepareConstant(loc,
cast<spirv::MatrixType>(constType).getColumnType(), elementAttr)) {
operands.push_back(elementID);
} else {
return 0;
}
}
spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
encodeInstructionInto(typesGlobalValues, opcode, operands);
}
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
resultID = prepareArrayConstant(loc, constType, arrayAttr);
}
Expand Down
29 changes: 29 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> ve
return %0: vector<3xf32>
}

// CHECK-LABEL: func @composite_construct_matrix
func.func @composite_construct_matrix(%v1: vector<3xf32>, %v2: vector<3xf32>, %v3: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> {
// CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
%0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
return %0: !spirv.matrix<3 x vector<3xf32>>
}

// CHECK-LABEL: func @composite_construct_struct
func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
// CHECK: spirv.CompositeConstruct
Expand Down Expand Up @@ -89,9 +96,31 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
%0 = spirv.CompositeConstruct %arg0, %arg2 : (f32, vector<2xf32>) -> vector<4xf32>
return %0: vector<4xf32>
}
// -----

func.func @composite_construct_matrix_wrong_column_count(%v1: vector<3xf32>, %v2: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> {
// expected-error @+1 {{'spirv.CompositeConstruct' op expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}}
%0 = spirv.CompositeConstruct %v1, %v2 : (vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
return %0: !spirv.matrix<3 x vector<3xf32>>
}

// -----

func.func @composite_construct_matrix_wrong_row_count(%v1: vector<4xf32>, %v2: vector<4xf32>, %v3: vector<4xf32>) -> !spirv.matrix<3 x vector<3xf32>> {
// expected-error @+1 {{operand type mismatch: expected operand type 'vector<3xf32>', but provided 'vector<4xf32>'}}
%0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> !spirv.matrix<3 x vector<3xf32>>
return %0: !spirv.matrix<3 x vector<3xf32>>
}

// -----

func.func @composite_construct_matrix_wrong_element_type(%v1: vector<3xi32>, %v2: vector<3xi32>, %v3: vector<3xi32>) -> !spirv.matrix<3 x vector<3xf32>> {
// expected-error @+1 {{operand type mismatch: expected operand type 'vector<3xf32>', but provided 'vector<3xi32>'}}
%0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xi32>, vector<3xi32>, vector<3xi32>) -> !spirv.matrix<3 x vector<3xf32>>
return %0: !spirv.matrix<3 x vector<3xf32>>
}
// -----

//===----------------------------------------------------------------------===//
// spirv.CompositeExtractOp
//===----------------------------------------------------------------------===//
Expand Down
28 changes: 27 additions & 1 deletion mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func.func @const() -> () {
// CHECK: spirv.Constant dense<1.000000e+00> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
// CHECK: spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
// CHECK: spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
// CHECK: spirv.Constant [dense<1.000000e+00> : vector<3xf32>, dense<2.000000e+00> : vector<3xf32>, dense<3.000000e+00> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>

%0 = spirv.Constant true
%1 = spirv.Constant 42 : i32
Expand All @@ -73,6 +74,7 @@ func.func @const() -> () {
%7 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
%8 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
%9 = spirv.Constant [[dense<3.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1xvector<2xf32>>>
%10 = spirv.Constant [dense<1.0> : vector<3xf32>, dense<2.0> : vector<3xf32>, dense<3.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
return
}

Expand All @@ -95,7 +97,7 @@ func.func @array_constant() -> () {
// -----

func.func @array_constant() -> () {
// expected-error @+1 {{must have spirv.array result type for array value}}
// expected-error @+1 {{'spirv.Constant' op must have spirv.array or spirv.matrix result type for array value}}
%0 = spirv.Constant [dense<3.0> : vector<2xf32>] : !spirv.rtarray<vector<2xf32>>
return
}
Expand Down Expand Up @@ -132,6 +134,30 @@ func.func @value_result_num_elements_mismatch() -> () {

// -----

func.func @matrix_constant() -> () {
// CHECK: spirv.Constant [dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : vector<3xf32>, dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : vector<3xf32>, dense<[7.000000e+00, 8.000000e+00, 9.000000e+00]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
%0 = spirv.Constant [dense<[1.0, 2.0, 3.0]> : vector<3xf32>, dense<[4.0, 5.0, 6.0]> : vector<3xf32>, dense<[7.0, 8.0, 9.0]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
return
}

// -----

func.func @matrix_constant_wrong_column_count() -> () {
// expected-error @+1 {{expected 3 columns in matrix constant, but got 2}}
%0 = spirv.Constant [dense<1.0> : vector<3xf32>, dense<2.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
return
}

// -----

func.func @matrix_constant_non_dense_column() -> () {
// expected-error @+1 {{matrix column #1 must be a DenseElementsAttr}}
%0 = spirv.Constant [dense<1.0> : vector<3xf32>, "wrong", dense<3.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
return
}

// -----

//===----------------------------------------------------------------------===//
// spirv.EntryPoint
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/test/Target/SPIRV/composite-op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
spirv.ReturnValue %0: vector<3xf32>
}
spirv.func @composite_construct_matrix(%v1: vector<3xf32>, %v2: vector<3xf32>, %v3: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> "None" {
// CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
%0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
spirv.ReturnValue %0: !spirv.matrix<3 x vector<3xf32>>
}
spirv.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 "None" {
// CHECK: spirv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32
%0 = spirv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Target/SPIRV/constant.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.Return
}

// CHECK-LABEL: @matrix_const
spirv.func @matrix_const() -> () "None" {
// CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<!spirv.matrix<3 x vector<3xf32>>, Function>
%0 = spirv.Variable : !spirv.ptr<!spirv.matrix<3 x vector<3xf32>>, Function>
// CHECK: %[[CST:.*]] = spirv.Constant [dense<[1.000000e+00, 0.000000e+00, 0.000000e+00]> : vector<3xf32>, dense<[0.000000e+00, 1.000000e+00, 0.000000e+00]> : vector<3xf32>, dense<[0.000000e+00, 0.000000e+00, 1.000000e+00]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
%1 = spirv.Constant [dense<[1., 0., 0.]> : vector<3xf32>, dense<[0., 1., 0.]> : vector<3xf32>, dense<[0., 0., 1.]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
// CHECK: spirv.Store "Function" %[[VAR]], %[[CST]] : !spirv.matrix<3 x vector<3xf32>>
spirv.Store "Function" %0, %1 : !spirv.matrix<3 x vector<3xf32>>
spirv.Return
}

// CHECK-LABEL: @ui64_array_const
spirv.func @ui64_array_const() -> (!spirv.array<3xui64>) "None" {
// CHECK: spirv.Constant [5, 6, 7] : !spirv.array<3 x i64>
Expand Down
Loading