Skip to content

[mlir][spirv] Add SPV_KHR_float_controls2 extension #143974

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup
def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
def SPV_KHR_float_controls2 : I32EnumAttrCase<"SPV_KHR_float_controls2", 32>;

def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
Expand Down Expand Up @@ -469,7 +470,8 @@ def SPIRV_ExtensionAttr :
SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
SPV_NV_shader_subgroup_partitioned, SPV_NV_shading_rate,
SPV_NV_stereo_view_rendering, SPV_NV_viewport_array2, SPV_NV_bindless_texture,
SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes
SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes,
SPV_KHR_float_controls2
]>;

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -850,6 +852,11 @@ def SPIRV_C_BitInstructions : I32EnumAttrCase<"BitIn
Extension<[SPV_KHR_bit_instructions]>
];
}
def SPIRV_C_FloatControls2 : I32EnumAttrCase<"FloatControls2", 6029> {
list<Availability> availability = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a MinVersion here? It seems like:

This extension requires SPIR-V 1.2.

Source

Extension<[SPV_KHR_float_controls2]>
];
}
def SPIRV_C_AtomicFloat32AddEXT : I32EnumAttrCase<"AtomicFloat32AddEXT", 6033> {
list<Availability> availability = [
Extension<[SPV_EXT_shader_atomic_float_add]>
Expand Down Expand Up @@ -1461,8 +1468,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_USMStorageClassesINTEL, SPIRV_C_IOPipesINTEL, SPIRV_C_BlockingPipesINTEL,
SPIRV_C_FPGARegINTEL, SPIRV_C_DotProductInputAll,
SPIRV_C_DotProductInput4x8BitPacked, SPIRV_C_DotProduct, SPIRV_C_RayCullMaskKHR,
SPIRV_C_CooperativeMatrixKHR,
SPIRV_C_BitInstructions, SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT,
SPIRV_C_CooperativeMatrixKHR, SPIRV_C_BitInstructions, SPIRV_C_FloatControls2,
SPIRV_C_AtomicFloat32AddEXT, SPIRV_C_AtomicFloat64AddEXT,
SPIRV_C_LongConstantCompositeINTEL, SPIRV_C_OptNoneINTEL,
SPIRV_C_AtomicFloat16AddEXT, SPIRV_C_DebugInfoModuleINTEL, SPIRV_C_SplitBarrierINTEL,
SPIRV_C_GroupUniformArithmeticKHR, SPIRV_C_Shader, SPIRV_C_Vector16,
Expand Down Expand Up @@ -2255,7 +2262,8 @@ def SPIRV_D_FuncParamAttr : I32EnumAttrCase<"FuncParamAttr"
def SPIRV_D_FPRoundingMode : I32EnumAttrCase<"FPRoundingMode", 39>;
def SPIRV_D_FPFastMathMode : I32EnumAttrCase<"FPFastMathMode", 40> {
list<Availability> availability = [
Capability<[SPIRV_C_Kernel]>
MinVersion<SPIRV_V_1_0>,
Capability<[SPIRV_C_FloatControls2, SPIRV_C_Kernel]>
];
}
def SPIRV_D_LinkageAttributes : I32EnumAttrCase<"LinkageAttributes", 41> {
Expand Down Expand Up @@ -3086,6 +3094,11 @@ def SPIRV_EM_SchedulerTargetFmaxMhzINTEL : I32EnumAttrCase<"SchedulerTarget
Capability<[SPIRV_C_FPGAKernelAttributesINTEL]>
];
}
def SPIRV_EM_FPFastMathDefault : I32EnumAttrCase<"FPFastMathDefault", 6028> {
list<Availability> availability = [
Capability<[SPIRV_C_FloatControls2]>
];
}
def SPIRV_EM_StreamingInterfaceINTEL : I32EnumAttrCase<"StreamingInterfaceINTEL", 6154> {
list<Availability> availability = [
Capability<[SPIRV_C_FPGAKernelAttributesINTEL]>
Expand Down Expand Up @@ -4781,22 +4794,27 @@ def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>;
def SPIRV_FPFMM_NSZ : I32BitEnumAttrCaseBit<"NSZ", 2>;
def SPIRV_FPFMM_AllowRecip : I32BitEnumAttrCaseBit<"AllowRecip", 3>;
def SPIRV_FPFMM_Fast : I32BitEnumAttrCaseBit<"Fast", 4>;
def SPIRV_FPFMM_AllowContractFastINTEL : I32BitEnumAttrCaseBit<"AllowContractFastINTEL", 16> {
def SPIRV_FPFMM_AllowContract : I32BitEnumAttrCaseBit<"AllowContract", 16> {
list<Availability> availability = [
Capability<[SPIRV_C_FPFastMathModeINTEL, SPIRV_C_FloatControls2]>
];
}
def SPIRV_FPFMM_AllowReassoc : I32BitEnumAttrCaseBit<"AllowReassoc", 17> {
list<Availability> availability = [
Capability<[SPIRV_C_FPFastMathModeINTEL]>
Capability<[SPIRV_C_FPFastMathModeINTEL, SPIRV_C_FloatControls2]>
];
}
def SPIRV_FPFMM_AllowReassocINTEL : I32BitEnumAttrCaseBit<"AllowReassocINTEL", 17> {
def SPIRV_FPFMM_AllowTransform : I32BitEnumAttrCaseBit<"AllowTransform", 18> {
list<Availability> availability = [
Capability<[SPIRV_C_FPFastMathModeINTEL]>
Capability<[SPIRV_C_FloatControls2]>
];
}

def SPIRV_FPFastMathModeAttr :
SPIRV_BitEnumAttr<"FPFastMathMode", "Indicates a floating-point fast math flag", "fastmath_mode", [
SPIRV_FPFMM_None, SPIRV_FPFMM_NotNaN, SPIRV_FPFMM_NotInf, SPIRV_FPFMM_NSZ,
SPIRV_FPFMM_AllowRecip, SPIRV_FPFMM_Fast, SPIRV_FPFMM_AllowContractFastINTEL,
SPIRV_FPFMM_AllowReassocINTEL
SPIRV_FPFMM_AllowRecip, SPIRV_FPFMM_Fast, SPIRV_FPFMM_AllowContract,
SPIRV_FPFMM_AllowReassoc, SPIRV_FPFMM_AllowTransform
]>;

#endif // MLIR_DIALECT_SPIRV_IR_BASE
10 changes: 10 additions & 0 deletions mlir/test/Target/SPIRV/decorations.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ spirv.func @fadd_decorations(%arg: f32) -> f32 "None" {

// -----

spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this missing the new capability?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To add to this, what I find useful with changes like that is to serialize the test into .spv and run it through spirv-val. This should tell you if you missed something or not.

spirv.func @fadd_floatcontrols2(%arg: f32) -> f32 "None" {
// CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<AllowContract|AllowReassoc|AllowTransform>}
%0 = spirv.FAdd %arg, %arg {fp_fast_math_mode = #spirv.fastmath_mode<AllowContract|AllowReassoc|AllowTransform>} : f32
spirv.ReturnValue %0 : f32
}
}

// -----

spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
spirv.func @fmul_decorations(%arg: f32) -> f32 "None" {
// CHECK: spirv.FMul %{{.*}}, %{{.*}} {no_contraction}
Expand Down
Loading