@@ -183,19 +183,38 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
183
183
}
184
184
185
185
def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
186
- let summary = "Reifies the results of all `ReifyRankedShapedTypeOpInterface` operations ";
186
+ let summary ="Reifies the results of `tensor::PadOp` and `tensor::ConcatOp`. ";
187
187
let description = [{
188
- This pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface`
189
- operation with ranked `memref` and `tensor` results. Replacing the
190
- operations with their reified versions, and inserting casts when results
191
- shapes are updated.
188
+ This pass reifies the shapes of a subset of `ReifyRankedShapedTypeOpInterface`
189
+ ops with `tensor` results.
190
+
191
+ The pass currently only supports result shape type reification for:
192
+ - tensor::PadOp
193
+ - tensor::ConcatOp
194
+ It addresses a representation gap where implicit op semantics are needed to
195
+ infer static result types from dynamic operands.
196
+ But it does so by using `ReifyRankedShapedTypeOpInterface` as the source of
197
+ truth rather than the op itself. As a consequence, this cannot generalize
198
+ today.
199
+
200
+ TODO: in the future, we should consider coupling this information with op
201
+ "transfer functions" (e.g. `IndexingMapOpInterface`) to provide a source of
202
+ truth that can work across result shape inference, canonicalization and op
203
+ verifiers.
204
+
205
+ The pass replaces the operations with their reified versions, when more
206
+ static information can be derived, and inserts casts when results shapes
207
+ are updated.
192
208
193
209
Example:
194
210
```mlir
195
211
#map = affine_map<(d0) -> (-d0 + 256)>
196
- func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
212
+ func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
213
+ -> tensor<1x?x64xf32>
214
+ {
197
215
%0 = affine.apply #map(%arg1)
198
- %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
216
+ %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
217
+ : tensor<64x?x64xf32> to tensor<1x?x64xf32>
199
218
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
200
219
^bb0(%arg3: index, %arg4: index, %arg5: index):
201
220
tensor.yield %arg0 : f32
@@ -205,9 +224,12 @@ def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
205
224
206
225
// mlir-opt --reify-result-shapes
207
226
#map = affine_map<()[s0] -> (-s0 + 256)>
208
- func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
227
+ func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
228
+ -> tensor<1x?x64xf32>
229
+ {
209
230
%0 = affine.apply #map()[%arg1]
210
- %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
231
+ %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
232
+ : tensor<64x?x64xf32> to tensor<1x?x64xf32>
211
233
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
212
234
^bb0(%arg3: index, %arg4: index, %arg5: index):
213
235
tensor.yield %arg0 : f32
0 commit comments