Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 6a48882

Browse files
antiagainsttensorflower-gardener
authored andcommitted
De-duplicate EnumAttr overrides by defining defaults
EnumAttr should provide meaningful defaults so concrete instances do not need to duplicate the fields. PiperOrigin-RevId: 282398431
1 parent 8c152c5 commit 6a48882

File tree

8 files changed

+19
-51
lines changed

8 files changed

+19
-51
lines changed

include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,7 @@ def ICmpPredicate : I64EnumAttr<
174174
[ICmpPredicateEQ, ICmpPredicateNE, ICmpPredicateSLT, ICmpPredicateSLE,
175175
ICmpPredicateSGT, ICmpPredicateSGE, ICmpPredicateULT, ICmpPredicateULE,
176176
ICmpPredicateUGT, ICmpPredicateUGE]> {
177-
let cppNamespace = "mlir::LLVM";
178-
179-
let returnType = "ICmpPredicate";
180-
let convertFromStorage =
181-
"static_cast<" # returnType # ">($_self.getValue().getZExtValue())";
177+
let cppNamespace = "::mlir::LLVM";
182178
}
183179

184180
// Other integer operations.
@@ -225,11 +221,7 @@ def FCmpPredicate : I64EnumAttr<
225221
FCmpPredicateUEQ, FCmpPredicateUGT, FCmpPredicateUGE, FCmpPredicateULT,
226222
FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE
227223
]> {
228-
let cppNamespace = "mlir::LLVM";
229-
230-
let returnType = "FCmpPredicate";
231-
let convertFromStorage =
232-
"static_cast<" # returnType # ">($_self.getValue().getZExtValue())";
224+
let cppNamespace = "::mlir::LLVM";
233225
}
234226

235227
// Other integer operations.

include/mlir/Dialect/SPIRV/SPIRVBase.td

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,6 @@ def SPV_AddressingModelAttr :
326326
SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64,
327327
SPV_AM_PhysicalStorageBuffer64
328328
]> {
329-
let returnType = "::mlir::spirv::AddressingModel";
330-
let convertFromStorage = "static_cast<::mlir::spirv::AddressingModel>($_self.getInt())";
331329
let cppNamespace = "::mlir::spirv";
332330
}
333331

@@ -462,8 +460,6 @@ def SPV_BuiltInAttr :
462460
SPV_BI_HitTNV, SPV_BI_HitKindNV, SPV_BI_IncomingRayFlagsNV,
463461
SPV_BI_WarpsPerSMNV, SPV_BI_SMCountNV, SPV_BI_WarpIDNV, SPV_BI_SMIDNV
464462
]> {
465-
let returnType = "::mlir::spirv::BuiltIn";
466-
let convertFromStorage = "static_cast<::mlir::spirv::BuiltIn>($_self.getInt())";
467463
let cppNamespace = "::mlir::spirv";
468464
}
469465

@@ -672,8 +668,6 @@ def SPV_CapabilityAttr :
672668
SPV_C_SubgroupAvcMotionEstimationIntraINTEL,
673669
SPV_C_SubgroupAvcMotionEstimationChromaINTEL
674670
]> {
675-
let returnType = "::mlir::spirv::Capability";
676-
let convertFromStorage = "static_cast<::mlir::spirv::Capability>($_self.getInt())";
677671
let cppNamespace = "::mlir::spirv";
678672
}
679673

@@ -763,8 +757,6 @@ def SPV_DecorationAttr :
763757
SPV_D_AliasedPointer, SPV_D_CounterBuffer, SPV_D_UserSemantic,
764758
SPV_D_UserTypeGOOGLE
765759
]> {
766-
let returnType = "::mlir::spirv::Decoration";
767-
let convertFromStorage = "static_cast<::mlir::spirv::Decoration>($_self.getInt())";
768760
let cppNamespace = "::mlir::spirv";
769761
}
770762

@@ -781,8 +773,6 @@ def SPV_DimAttr :
781773
SPV_D_1D, SPV_D_2D, SPV_D_3D, SPV_D_Cube, SPV_D_Rect, SPV_D_Buffer,
782774
SPV_D_SubpassData
783775
]> {
784-
let returnType = "::mlir::spirv::Dim";
785-
let convertFromStorage = "static_cast<::mlir::spirv::Dim>($_self.getInt())";
786776
let cppNamespace = "::mlir::spirv";
787777
}
788778

@@ -866,8 +856,6 @@ def SPV_ExecutionModeAttr :
866856
SPV_EM_SampleInterlockOrderedEXT, SPV_EM_SampleInterlockUnorderedEXT,
867857
SPV_EM_ShadingRateInterlockOrderedEXT, SPV_EM_ShadingRateInterlockUnorderedEXT
868858
]> {
869-
let returnType = "::mlir::spirv::ExecutionMode";
870-
let convertFromStorage = "static_cast<::mlir::spirv::ExecutionMode>($_self.getInt())";
871859
let cppNamespace = "::mlir::spirv";
872860
}
873861

@@ -894,8 +882,6 @@ def SPV_ExecutionModelAttr :
894882
SPV_EM_TaskNV, SPV_EM_MeshNV, SPV_EM_RayGenerationNV, SPV_EM_IntersectionNV,
895883
SPV_EM_AnyHitNV, SPV_EM_ClosestHitNV, SPV_EM_MissNV, SPV_EM_CallableNV
896884
]> {
897-
let returnType = "::mlir::spirv::ExecutionModel";
898-
let convertFromStorage = "static_cast<::mlir::spirv::ExecutionModel>($_self.getInt())";
899885
let cppNamespace = "::mlir::spirv";
900886
}
901887

@@ -909,8 +895,6 @@ def SPV_FunctionControlAttr :
909895
BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [
910896
SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const
911897
]> {
912-
let returnType = "::mlir::spirv::FunctionControl";
913-
let convertFromStorage = "static_cast<::mlir::spirv::FunctionControl>($_self.getInt())";
914898
let cppNamespace = "::mlir::spirv";
915899
}
916900

@@ -967,8 +951,6 @@ def SPV_ImageFormatAttr :
967951
SPV_IF_Rgb10a2ui, SPV_IF_Rg32ui, SPV_IF_Rg16ui, SPV_IF_Rg8ui, SPV_IF_R16ui,
968952
SPV_IF_R8ui
969953
]> {
970-
let returnType = "::mlir::spirv::ImageFormat";
971-
let convertFromStorage = "static_cast<::mlir::spirv::ImageFormat>($_self.getInt())";
972954
let cppNamespace = "::mlir::spirv";
973955
}
974956

@@ -979,8 +961,6 @@ def SPV_LinkageTypeAttr :
979961
I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", [
980962
SPV_LT_Export, SPV_LT_Import
981963
]> {
982-
let returnType = "::mlir::spirv::LinkageType";
983-
let convertFromStorage = "static_cast<::mlir::spirv::LinkageType>($_self.getInt())";
984964
let cppNamespace = "::mlir::spirv";
985965
}
986966

@@ -1001,8 +981,6 @@ def SPV_LoopControlAttr :
1001981
SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations,
1002982
SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount
1003983
]> {
1004-
let returnType = "::mlir::spirv::LoopControl";
1005-
let convertFromStorage = "static_cast<::mlir::spirv::LoopControl>($_self.getInt())";
1006984
let cppNamespace = "::mlir::spirv";
1007985
}
1008986

@@ -1020,8 +998,6 @@ def SPV_MemoryAccessAttr :
1020998
SPV_MA_MakePointerAvailable, SPV_MA_MakePointerVisible,
1021999
SPV_MA_NonPrivatePointer
10221000
]> {
1023-
let returnType = "::mlir::spirv::MemoryAccess";
1024-
let convertFromStorage = "static_cast<::mlir::spirv::MemoryAccess>($_self.getInt())";
10251001
let cppNamespace = "::mlir::spirv";
10261002
}
10271003

@@ -1034,8 +1010,6 @@ def SPV_MemoryModelAttr :
10341010
I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [
10351011
SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_Vulkan
10361012
]> {
1037-
let returnType = "::mlir::spirv::MemoryModel";
1038-
let convertFromStorage = "static_cast<::mlir::spirv::MemoryModel>($_self.getInt())";
10391013
let cppNamespace = "::mlir::spirv";
10401014
}
10411015

@@ -1063,8 +1037,6 @@ def SPV_MemorySemanticsAttr :
10631037
SPV_MS_AtomicCounterMemory, SPV_MS_ImageMemory, SPV_MS_OutputMemory,
10641038
SPV_MS_MakeAvailable, SPV_MS_MakeVisible, SPV_MS_Volatile
10651039
]> {
1066-
let returnType = "::mlir::spirv::MemorySemantics";
1067-
let convertFromStorage = "static_cast<::mlir::spirv::MemorySemantics>($_self.getInt())";
10681040
let cppNamespace = "::mlir::spirv";
10691041
}
10701042

@@ -1080,8 +1052,6 @@ def SPV_ScopeAttr :
10801052
SPV_S_CrossDevice, SPV_S_Device, SPV_S_Workgroup, SPV_S_Subgroup,
10811053
SPV_S_Invocation, SPV_S_QueueFamily
10821054
]> {
1083-
let returnType = "::mlir::spirv::Scope";
1084-
let convertFromStorage = "static_cast<::mlir::spirv::Scope>($_self.getInt())";
10851055
let cppNamespace = "::mlir::spirv";
10861056
}
10871057

@@ -1093,8 +1063,6 @@ def SPV_SelectionControlAttr :
10931063
BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [
10941064
SPV_SC_None, SPV_SC_Flatten, SPV_SC_DontFlatten
10951065
]> {
1096-
let returnType = "::mlir::spirv::SelectionControl";
1097-
let convertFromStorage = "static_cast<::mlir::spirv::SelectionControl>($_self.getInt())";
10981066
let cppNamespace = "::mlir::spirv";
10991067
}
11001068

@@ -1128,8 +1096,6 @@ def SPV_StorageClassAttr :
11281096
SPV_SC_RayPayloadNV, SPV_SC_HitAttributeNV, SPV_SC_IncomingRayPayloadNV,
11291097
SPV_SC_ShaderRecordBufferNV, SPV_SC_PhysicalStorageBuffer
11301098
]> {
1131-
let returnType = "::mlir::spirv::StorageClass";
1132-
let convertFromStorage = "static_cast<::mlir::spirv::StorageClass>($_self.getInt())";
11331099
let cppNamespace = "::mlir::spirv";
11341100
}
11351101

include/mlir/IR/OpBase.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,12 +938,18 @@ class IntEnumAttr<I intType, string name, string description,
938938
class I32EnumAttr<string name, string description,
939939
list<I32EnumAttrCase> cases> :
940940
IntEnumAttr<I32, name, description, cases> {
941+
let returnType = cppNamespace # "::" # name;
941942
let underlyingType = "uint32_t";
943+
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
944+
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
942945
}
943946
class I64EnumAttr<string name, string description,
944947
list<I64EnumAttrCase> cases> :
945948
IntEnumAttr<I64, name, description, cases> {
949+
let returnType = cppNamespace # "::" # name;
946950
let underlyingType = "uint64_t";
951+
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
952+
let constBuilderCall = "$_builder.getI64IntegerAttr(static_cast<int64_t>($0))";
947953
}
948954

949955
// A bit enum stored with 32-bit IntegerAttr.
@@ -963,7 +969,10 @@ class BitEnumAttr<string name, string description,
963969
")))">
964970
]>;
965971

972+
let returnType = cppNamespace # "::" # name;
966973
let underlyingType = "uint32_t";
974+
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
975+
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
967976

968977
// We need to return a string because we may concatenate symbols for multiple
969978
// bits together.

test/lib/TestDialect/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ set(LLVM_OPTIONAL_SOURCES
66
set(LLVM_TARGET_DEFINITIONS TestOps.td)
77
mlir_tablegen(TestOps.h.inc -gen-op-decls)
88
mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
9+
mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls)
10+
mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs)
911
mlir_tablegen(TestPatterns.inc -gen-rewriters)
1012
add_public_tablegen_target(MLIRTestOpsIncGen)
1113

test/lib/TestDialect/TestDialect.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/TypeUtilities.h"
2323
#include "mlir/Transforms/FoldUtils.h"
2424
#include "mlir/Transforms/InliningUtils.h"
25+
#include "llvm/ADT/StringSwitch.h"
2526

2627
using namespace mlir;
2728

@@ -304,5 +305,7 @@ SmallVector<Type, 2> mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
304305
// Static initialization for Test dialect registration.
305306
static mlir::DialectRegistration<mlir::TestDialect> testDialect;
306307

308+
#include "TestOpEnums.cpp.inc"
309+
307310
#define GET_OP_CLASSES
308311
#include "TestOps.cpp.inc"

test/lib/TestDialect/TestDialect.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
#include "mlir/IR/StandardTypes.h"
3333
#include "mlir/IR/SymbolTable.h"
3434

35+
#include "TestOpEnums.h.inc"
36+
3537
namespace mlir {
3638

3739
class TestDialect : public Dialect {

test/lib/TestDialect/TestOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
694694
def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>;
695695

696696
def MultiResultOpEnum: I64EnumAttr<
697-
"Multi-result op kinds", "", [
697+
"MultiResultOpEnum", "Multi-result op kinds", [
698698
MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
699699
MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6
700700
]>;

utils/spirv/gen_spirv_dialect.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,6 @@ def get_case_symbol(kind_name, case_name):
200200
enum_attr = 'def SPV_{name}Attr :\n '\
201201
'{category}EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n'\
202202
' ]> {{\n'\
203-
' let returnType = "::mlir::spirv::{name}";\n'\
204-
' let convertFromStorage = '\
205-
'"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\
206203
' let cppNamespace = "::mlir::spirv";\n}}'.format(
207204
name=kind_name, category=kind_category, cases=case_names)
208205
return kind_name, case_defs + '\n\n' + enum_attr
@@ -240,9 +237,6 @@ def gen_opcode(instructions):
240237
' I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\
241238
'{lst}\n'\
242239
' ]> {{\n'\
243-
' let returnType = "::mlir::spirv::{name}";\n'\
244-
' let convertFromStorage = '\
245-
'"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\
246240
' let cppNamespace = "::mlir::spirv";\n}}'.format(
247241
name='Opcode', lst=opcode_list)
248242
return opcode_str + '\n\n' + enum_attr

0 commit comments

Comments
 (0)