Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 8c152c5

Browse files
Mahesh Ravishankartensorflower-gardener
authored andcommitted
Introduce attributes that specify the final ABI for a spirv::ModuleOp.
To simplify the lowering into SPIR-V, while still respecting the ABI requirements of SPIR-V/Vulkan, split the process into two 1) While lowering a function to SPIR-V (when the function is an entry point function), allow specifying attributes on arguments and function itself that describe the ABI of the function. 2) Add a pass that materializes the ABI described in the function. Two attributes are needed. 1) Attribute on arguments of the entry point function that describe the descriptor_set, binding, storage class, etc, of the spv.globalVariable this argument will be replaced by 2) Attribute on function that specifies workgroup size, etc. (for now only workgroup size). Add the pass -spirv-lower-abi-attrs to materialize the ABI described by the attributes. This change makes the SPIRVBasicTypeConverter class unnecessary and is removed, further simplifying the SPIR-V lowering path. PiperOrigin-RevId: 282387587
1 parent 19cb0d1 commit 8c152c5

File tree

21 files changed

+662
-292
lines changed

21 files changed

+662
-292
lines changed

include/mlir/Dialect/SPIRV/LayoutUtils.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ class VulkanLayoutUtils {
5656
public:
5757
using Size = uint64_t;
5858

59-
/// Returns a new type with layout info. Assigns the type size in bytes to the
60-
/// `size`. Assigns the type alignment in bytes to the `alignment`.
61-
static Type decorateType(spirv::StructType structType, Size &size,
62-
Size &alignment);
59+
/// Returns a new StructType with layout info. Assigns the type size in bytes
60+
/// to the `size`. Assigns the type alignment in bytes to the `alignment`.
61+
static spirv::StructType decorateType(spirv::StructType structType,
62+
Size &size, Size &alignment);
6363
/// Checks whether a type is legal in terms of Vulkan layout info
6464
/// decoration. A type is dynamically illegal if it's a composite type in the
6565
/// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant Storage

include/mlir/Dialect/SPIRV/Passes.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,22 @@
2727
namespace mlir {
2828
namespace spirv {
2929

30-
// Creates a module pass that converts composite types used by objects in the
31-
// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant storage
32-
// classes with layout information.
33-
//
34-
// Right now this pass only supports Vulkan layout rules.
30+
class ModuleOp;
31+
/// Creates a module pass that converts composite types used by objects in the
32+
/// StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant storage
33+
/// classes with layout information.
34+
/// Right now this pass only supports Vulkan layout rules.
3535
std::unique_ptr<OpPassBase<mlir::ModuleOp>>
3636
createDecorateSPIRVCompositeTypeLayoutPass();
3737

38+
/// Creates a module pass that lowers the ABI attributes specified during SPIR-V
39+
/// Lowering. Specifically,
40+
/// 1) Creates the global variables for arguments of entry point function using
41+
/// the specification in the ABI attributes for each argument.
42+
/// 2) Inserts the EntryPointOp and the ExecutionModeOp for entry point
43+
/// functions using the specification in the EntryPointAttr.
44+
std::unique_ptr<OpPassBase<spirv::ModuleOp>> createLowerABIAttributesPass();
45+
3846
} // namespace spirv
3947
} // namespace mlir
4048

include/mlir/Dialect/SPIRV/SPIRVLowering.h

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,35 +23,22 @@
2323
#define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
2424

2525
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
26+
#include "mlir/IR/Attributes.h"
2627
#include "mlir/Support/StringExtras.h"
2728
#include "mlir/Transforms/DialectConversion.h"
2829
#include "llvm/ADT/SetVector.h"
2930

3031
namespace mlir {
3132

32-
/// Type conversion from Standard Types to SPIR-V Types.
33-
class SPIRVBasicTypeConverter : public TypeConverter {
34-
public:
35-
/// Converts types to SPIR-V supported types.
36-
virtual Type convertType(Type t);
37-
};
38-
3933
/// Converts a function type according to the requirements of a SPIR-V entry
4034
/// function. The arguments need to be converted to spv.GlobalVariables of
4135
/// spv.ptr types so that they could be bound by the runtime.
4236
class SPIRVTypeConverter final : public TypeConverter {
4337
public:
44-
explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
45-
: basicTypeConverter(basicTypeConverter) {}
38+
using TypeConverter::TypeConverter;
4639

4740
/// Converts types to SPIR-V types using the basic type converter.
4841
Type convertType(Type t) override;
49-
50-
/// Gets the basic type converter.
51-
Type convertBasicType(Type t) { return basicTypeConverter->convertType(t); }
52-
53-
private:
54-
SPIRVBasicTypeConverter *basicTypeConverter;
5542
};
5643

5744
/// Base class to define a conversion pattern to translate Ops into SPIR-V.
@@ -70,21 +57,35 @@ class SPIRVOpLowering : public OpConversionPattern<SourceOp> {
7057
private:
7158
};
7259

60+
#include "mlir/Dialect/SPIRV/SPIRVLowering.h.inc"
61+
7362
namespace spirv {
7463
/// Returns a value that represents a builtin variable value within the SPIR-V
7564
/// module.
7665
Value *getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin,
7766
OpBuilder &builder);
7867

7968
/// Legalizes a function as an entry function.
80-
LogicalResult lowerAsEntryFunction(FuncOp funcOp,
81-
SPIRVTypeConverter *typeConverter,
82-
ConversionPatternRewriter &rewriter,
83-
FuncOp &newFuncOp);
84-
85-
/// Finalizes entry function legalization. Inserts the spv.EntryPoint and
86-
/// spv.ExecutionMode ops.
87-
LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder);
69+
FuncOp lowerAsEntryFunction(FuncOp funcOp, SPIRVTypeConverter &typeConverter,
70+
ConversionPatternRewriter &rewriter,
71+
ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo,
72+
spirv::EntryPointABIAttr entryPointInfo);
73+
74+
/// Attribute name for specifying argument ABI information.
75+
StringRef getInterfaceVarABIAttrName();
76+
77+
/// Get the InterfaceVarABIAttr given its fields.
78+
InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet,
79+
unsigned binding,
80+
spirv::StorageClass storageClass,
81+
MLIRContext *context);
82+
83+
/// Attribute name for specifying entry point information.
84+
StringRef getEntryPointABIAttrName();
85+
86+
/// Get the EntryPointABIAttr given its fields.
87+
EntryPointABIAttr getEntryPointABIAttr(ArrayRef<int32_t> localSize,
88+
MLIRContext *context);
8889

8990
} // namespace spirv
9091
} // namespace mlir
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===//
2+
//
3+
// Copyright 2019 The MLIR Authors.
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
// =============================================================================
17+
//
18+
// This is the base file for supporting lowering to SPIR-V dialect. This
19+
// file defines SPIR-V attributes used for specifying the shader
20+
// interface or ABI. This is because SPIR-V module is expected to work in
21+
// an execution environment as specified by a client API. A SPIR-V module
22+
// needs to "link" correctly with the execution environment regarding the
23+
// resources that are used in the SPIR-V module and get populated with
24+
// data via the client API. The shader interface (or ABI) is passed into
25+
// SPIR-V lowering path via attributes defined in this file. A
26+
// compilation flow targeting SPIR-V is expected to attach such
27+
// attributes to resources and other suitable places.
28+
//
29+
//===----------------------------------------------------------------------===//
30+
31+
#ifndef SPIRV_LOWERING
32+
#define SPIRV_LOWERING
33+
34+
include "mlir/Dialect/SPIRV/SPIRVBase.td"
35+
36+
// For arguments that eventually map to spv.globalVariable for the
37+
// shader interface, this attribute specifies the information regarding
38+
// the global variable :
39+
// 1) Descriptor Set.
40+
// 2) Binding number.
41+
// 3) Storage class.
42+
def SPV_InterfaceVarABIAttr:
43+
StructAttr<"InterfaceVarABIAttr", SPV_Dialect,
44+
[StructFieldAttr<"descriptor_set", I32Attr>,
45+
StructFieldAttr<"binding", I32Attr>,
46+
StructFieldAttr<"storage_class", SPV_StorageClassAttr>]>;
47+
48+
// For entry functions, this attribute specifies information related to entry
49+
// points in the generated SPIR-V module:
50+
// 1) WorkGroup Size.
51+
def SPV_EntryPointABIAttr:
52+
StructAttr<"EntryPointABIAttr", SPV_Dialect,
53+
[StructFieldAttr<"local_size", I32ElementsAttr>]>;
54+
55+
#endif // SPIRV_LOWERING

include/mlir/Dialect/SPIRV/SPIRVStructureOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,10 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope, Symbol]> {
243243
"TypeAttr type, ArrayRef<NamedAttribute> namedAttrs", [{
244244
state.addAttribute("type", type);
245245
state.addAttributes(namedAttrs);
246-
}]>
246+
}]>,
247+
OpBuilder<[{Builder *builder, OperationState &state,
248+
Type type, StringRef name, unsigned descriptorSet,
249+
unsigned binding}]>
247250
];
248251

249252
let results = (outs);

include/mlir/IR/OpBase.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,10 @@ class ArrayMinCount<int n> : AttrConstraint<
12301230
CPred<"$_self.cast<ArrayAttr>().size() >= " # n>,
12311231
"with at least " # n # " elements">;
12321232

1233+
class ArrayCount<int n> : AttrConstraint<
1234+
CPred<"$_self.cast<ArrayAttr>().size() == " #n>,
1235+
"with exactly " # n # " elements">;
1236+
12331237
class IntArrayNthElemEq<int index, int value> : AttrConstraint<
12341238
And<[
12351239
CPred<"$_self.cast<ArrayAttr>().size() > " # index>,

lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,15 +163,26 @@ PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
163163
PatternMatchResult
164164
KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
165165
ConversionPatternRewriter &rewriter) const {
166-
FuncOp newFuncOp;
167166
if (!gpu::GPUDialect::isKernel(funcOp)) {
168167
return matchFailure();
169168
}
170169

171-
if (failed(spirv::lowerAsEntryFunction(funcOp, &typeConverter, rewriter,
172-
newFuncOp))) {
170+
SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
171+
for (auto argNum : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
172+
argABI.push_back(spirv::getInterfaceVarABIAttr(
173+
0, argNum, spirv::StorageClass::StorageBuffer, rewriter.getContext()));
174+
}
175+
// TODO(ravishankarm) : For now set this to {32, 1, 1}. This is incorrect. The
176+
// actual workgroup size needs to be plumbed through.
177+
auto context = rewriter.getContext();
178+
auto entryPointAttr = spirv::getEntryPointABIAttr({32, 1, 1}, context);
179+
FuncOp newFuncOp = spirv::lowerAsEntryFunction(
180+
funcOp, typeConverter, rewriter, argABI, entryPointAttr);
181+
if (!newFuncOp) {
173182
return matchFailure();
174183
}
184+
newFuncOp.removeAttr(Identifier::get(gpu::GPUDialect::getKernelFuncAttrName(),
185+
rewriter.getContext()));
175186
return matchSuccess();
176187
}
177188

lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ void GPUToSPIRVPass::runOnModule() {
5454
if (!gpu::GPUDialect::isKernel(funcOp)) {
5555
return;
5656
}
57-
OpBuilder builder(module.getBodyRegion());
57+
OpBuilder builder(funcOp.getOperation());
5858
// Create a new spirv::ModuleOp for this function, and clone the
5959
// function into it.
6060
// TODO : Generalize this to account for different extensions,
@@ -77,45 +77,20 @@ void GPUToSPIRVPass::runOnModule() {
7777
});
7878

7979
/// Dialect conversion to lower the functions with the spirv::ModuleOps.
80-
SPIRVBasicTypeConverter basicTypeConverter;
81-
SPIRVTypeConverter typeConverter(&basicTypeConverter);
80+
SPIRVTypeConverter typeConverter;
8281
OwningRewritePatternList patterns;
8382
populateGPUToSPIRVPatterns(context, typeConverter, patterns);
8483
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
8584

8685
ConversionTarget target(*context);
8786
target.addLegalDialect<spirv::SPIRVDialect>();
88-
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
89-
// TODO(ravishankarm) : Currently lowering does not support handling
90-
// function conversion of non-kernel functions. This is to be added.
91-
92-
// For kernel functions, verify that the signature is void(void).
93-
return gpu::GPUDialect::isKernel(op) && op.getNumResults() == 0 &&
94-
op.getNumArguments() == 0;
95-
});
87+
target.addDynamicallyLegalOp<FuncOp>(
88+
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
9689

9790
if (failed(applyFullConversion(spirvModules, target, patterns,
9891
&typeConverter))) {
9992
return signalPassFailure();
10093
}
101-
102-
// After the SPIR-V modules have been generated, some finalization is needed
103-
// for the entry functions. For example, adding spv.EntryPoint op,
104-
// spv.ExecutionMode op, etc.
105-
for (auto *spvModule : spirvModules) {
106-
for (auto op :
107-
cast<spirv::ModuleOp>(spvModule).getBlock().getOps<FuncOp>()) {
108-
if (gpu::GPUDialect::isKernel(op)) {
109-
OpBuilder builder(op.getContext());
110-
builder.setInsertionPointAfter(op);
111-
if (failed(spirv::finalizeEntryFunction(op, builder))) {
112-
return signalPassFailure();
113-
}
114-
op.getOperation()->removeAttr(Identifier::get(
115-
gpu::GPUDialect::getKernelFuncAttrName(), op.getContext()));
116-
}
117-
}
118-
}
11994
}
12095

12196
OpPassBase<ModuleOp> *createConvertGPUToSPIRVPass() {

lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
6363
return matchFailure();
6464
}
6565
auto spirvConstType =
66-
typeConverter.convertBasicType(constIndexOp.getResult()->getType());
66+
typeConverter.convertType(constIndexOp.getResult()->getType());
6767
auto spirvConstVal =
6868
rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
6969
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
@@ -120,7 +120,7 @@ class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
120120
matchAndRewrite(StdOp operation, ArrayRef<Value *> operands,
121121
ConversionPatternRewriter &rewriter) const override {
122122
auto resultType =
123-
this->typeConverter.convertBasicType(operation.getResult()->getType());
123+
this->typeConverter.convertType(operation.getResult()->getType());
124124
rewriter.template replaceOpWithNewOp<SPIRVOp>(
125125
operation, resultType, operands, ArrayRef<NamedAttribute>());
126126
return this->matchSuccess();

lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ void ConvertStandardToSPIRVPass::runOnModule() {
4040
OwningRewritePatternList patterns;
4141
auto module = getModule();
4242

43-
SPIRVBasicTypeConverter basicTypeConverter;
44-
SPIRVTypeConverter typeConverter(&basicTypeConverter);
43+
SPIRVTypeConverter typeConverter;
4544
populateStandardToSPIRVPatterns(module.getContext(), typeConverter, patterns);
4645
ConversionTarget target(*(module.getContext()));
4746
target.addLegalDialect<spirv::SPIRVDialect>();

0 commit comments

Comments
 (0)