diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index d2ba76cdad904..dcb33ab907ae6 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4369,6 +4369,7 @@ def SPIRV_OC_OpTranspose : I32EnumAttrCase<"OpTranspose", 8 def SPIRV_OC_OpImageSampleImplicitLod : I32EnumAttrCase<"OpImageSampleImplicitLod", 87>; def SPIRV_OC_OpImageSampleExplicitLod : I32EnumAttrCase<"OpImageSampleExplicitLod", 88>; def SPIRV_OC_OpImageSampleProjDrefImplicitLod : I32EnumAttrCase<"OpImageSampleProjDrefImplicitLod", 93>; +def SPIRV_OC_OpImageFetch : I32EnumAttrCase<"OpImageFetch", 95>; def SPIRV_OC_OpImageDrefGather : I32EnumAttrCase<"OpImageDrefGather", 97>; def SPIRV_OC_OpImageRead : I32EnumAttrCase<"OpImageRead", 98>; def SPIRV_OC_OpImageWrite : I32EnumAttrCase<"OpImageWrite", 99>; @@ -4577,7 +4578,8 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpCompositeConstruct, SPIRV_OC_OpCompositeExtract, SPIRV_OC_OpCompositeInsert, SPIRV_OC_OpTranspose, SPIRV_OC_OpImageSampleImplicitLod, SPIRV_OC_OpImageSampleExplicitLod, - SPIRV_OC_OpImageSampleProjDrefImplicitLod, SPIRV_OC_OpImageDrefGather, + SPIRV_OC_OpImageSampleProjDrefImplicitLod, SPIRV_OC_OpImageFetch, + SPIRV_OC_OpImageDrefGather, SPIRV_OC_OpImageRead, SPIRV_OC_OpImageWrite, SPIRV_OC_OpImage, SPIRV_OC_OpImageQuerySize, SPIRV_OC_OpConvertFToU, SPIRV_OC_OpConvertFToS, SPIRV_OC_OpConvertSToF, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td index 7610966b84be3..e23efa57e5e53 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td @@ -541,4 +541,50 @@ def SPIRV_ImageSampleProjDrefImplicitLodOp : SPIRV_Op<"ImageSampleProjDrefImplic // ----- +def SPIRV_ImageFetchOp : SPIRV_Op<"ImageFetch", + [SPIRV_DimIsNot<"image", ["Cube"]>, + SPIRV_SampledOperandIs<"image", ["NeedSampler"]>, + SPIRV_NoneOrElementMatchImage<"result", "image">]> { + let summary = "Fetch a single texel from an image whose Sampled operand is 1. "; + + let description = [{ + Result Type must be a vector of four components of floating-point type or + integer type. Its components must be the same as Sampled Type of the underlying + OpTypeImage (unless that underlying Sampled Type is OpTypeVoid). + + Image must be an object whose type is OpTypeImage. Its Dim operand must not be + Cube, and its Sampled operand must be 1. + + Coordinate must be a scalar or vector of integer type. It contains (u[, v] … [, + array layer]) as needed by the definition of Sampled Image. + + Image Operands encodes what operands follow, as per Image Operands. + + + + #### Example: + + ```mlir + %0 = spirv.ImageFetch %1, %2 : !spirv.image, vector<2xsi32> -> vector<4xf32> + ``` + }]; + + let arguments = (ins + SPIRV_AnyImage:$image, + SPIRV_ScalarOrVectorOf:$coordinate, + OptionalAttr:$image_operands, + Variadic:$operand_arguments + ); + + let results = (outs + AnyTypeOf<[SPIRV_Vec4, SPIRV_Vec4]>:$result + ); + + let assemblyFormat = [{ + $image `,` $coordinate custom($image_operands) ( `,` $operand_arguments^ )? attr-dict + `:` type($image) `,` type($coordinate) ( `,` type($operand_arguments)^ )? + `->` type($result) + }]; +} + #endif // MLIR_DIALECT_SPIRV_IR_IMAGE_OPS diff --git a/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp index f7af79ceefa82..661f3d5d9b81d 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp @@ -332,3 +332,12 @@ LogicalResult spirv::ImageSampleProjDrefImplicitLodOp::verify() { return verifyImageOperands(getOperation(), getImageOperandsAttr(), getOperandArguments()); } + +//===----------------------------------------------------------------------===// +// spirv.ImageFetchOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::ImageFetchOp::verify() { + return verifyImageOperands(getOperation(), getImageOperandsAttr(), + getOperandArguments()); +} diff --git a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir index 484a54023edc0..d3aaef7ebdef6 100644 --- a/mlir/test/Dialect/SPIRV/IR/image-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/image-ops.mlir @@ -304,6 +304,56 @@ func.func @sample_implicit_proj_dref(%arg0 : !spirv.sampled_image, %arg1: vector<2xsi32>) -> () { + // CHECK: {{%.*}} = spirv.ImageFetch {{%.*}}, {{%.*}} : !spirv.image, vector<2xsi32> -> vector<4xf32> + %0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image, vector<2xsi32> -> vector<4xf32> + spirv.Return +} + +// ----- + +func.func @image_fetch_dim_cube(%arg0: !spirv.image, %arg1: vector<2xsi32>) -> () { + // expected-error @+1 {{op failed to verify that the Dim operand of the underlying image must not be Cube}} + %0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image, vector<2xsi32> -> vector<4xf32> + spirv.Return +} + +// ----- + +func.func @image_fetch_no_sampler(%arg0: !spirv.image, %arg1: vector<2xsi32>) -> () { + // expected-error @+1 {{op failed to verify that the sampled operand of the underlying image must be NeedSampler}} + %0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image, vector<2xsi32> -> vector<4xf16> + spirv.Return +} + +// ----- + +func.func @image_fetch_type_mismatch(%arg0: !spirv.image, %arg1: vector<2xsi32>) -> () { + // expected-error @+1 {{op failed to verify that the result component type must match the image sampled type}} + %0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image, vector<2xsi32> -> vector<4xf16> + spirv.Return +} + +// ----- + +func.func @image_fetch_2d_result(%arg0: !spirv.image, %arg1: vector<2xsi32>) -> () { + // expected-error @+1 {{op result #0 must be vector of 16/32/64-bit float values of length 4 or vector of 8/16/32/64-bit integer values of length 4, but got 'vector<2xf32>'}} + %0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image, vector<2xsi32> -> vector<2xf32> + spirv.Return +} + +// ----- + +func.func @image_fetch_float_coords(%arg0: !spirv.image, %arg1: vector<2xf32>) -> () { + // expected-error @+1 {{op operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16, but got 'vector<2xf32>'}} + %0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image, vector<2xf32> -> vector<2xf32> + spirv.Return +} + //===----------------------------------------------------------------------===// // spirv.ImageOperands: Bias //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/image-ops.mlir b/mlir/test/Target/SPIRV/image-ops.mlir index b8d19f0f9a7d1..5c90fe9dbb89f 100644 --- a/mlir/test/Target/SPIRV/image-ops.mlir +++ b/mlir/test/Target/SPIRV/image-ops.mlir @@ -38,6 +38,11 @@ spirv.module Logical GLSL450 requires #spirv.vce>, vector<4xf32>, f32 -> f32 spirv.Return } + spirv.func @image_fetch(%arg0 : !spirv.image, %arg1 : vector<2xsi32>) "None" { + // CHECK: spirv.ImageFetch {{%.*}}, {{%.*}} : !spirv.image, vector<2xsi32> -> vector<4xf32> + %0 = spirv.ImageFetch %arg0, %arg1 : !spirv.image, vector<2xsi32> -> vector<4xf32> + spirv.Return + } } // -----