Skip to content

Commit 08a6823

Browse files
Update pass documentation
1 parent ba51026 commit 08a6823

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -183,19 +183,38 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
183183
}
184184

185185
def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
186-
let summary = "Reifies the results of all `ReifyRankedShapedTypeOpInterface` operations";
186+
let summary ="Reifies the results of `tensor::PadOp` and `tensor::ConcatOp`.";
187187
let description = [{
188-
This pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface`
189-
operation with ranked `memref` and `tensor` results. Replacing the
190-
operations with their reified versions, and inserting casts when results
191-
shapes are updated.
188+
This pass reifies the shapes of a subset of `ReifyRankedShapedTypeOpInterface`
189+
ops with `tensor` results.
190+
191+
The pass currently only supports result shape type reification for:
192+
- tensor::PadOp
193+
- tensor::ConcatOp
194+
It addresses a representation gap where implicit op semantics are needed to
195+
infer static result types from dynamic operands.
196+
But it does so by using `ReifyRankedShapedTypeOpInterface` as the source of
197+
truth rather than the op itself. As a consequence, this cannot generalize
198+
today.
199+
200+
TODO: in the future, we should consider coupling this information with op
201+
"transfer functions" (e.g. `IndexingMapOpInterface`) to provide a source of
202+
truth that can work across result shape inference, canonicalization and op
203+
verifiers.
204+
205+
The pass replaces the operations with their reified versions, when more
206+
static information can be derived, and inserts casts when results shapes
207+
are updated.
192208

193209
Example:
194210
```mlir
195211
#map = affine_map<(d0) -> (-d0 + 256)>
196-
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
212+
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
213+
-> tensor<1x?x64xf32>
214+
{
197215
%0 = affine.apply #map(%arg1)
198-
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
216+
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
217+
: tensor<64x?x64xf32> to tensor<1x?x64xf32>
199218
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
200219
^bb0(%arg3: index, %arg4: index, %arg5: index):
201220
tensor.yield %arg0 : f32
@@ -205,9 +224,12 @@ def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
205224

206225
// mlir-opt --reify-result-shapes
207226
#map = affine_map<()[s0] -> (-s0 + 256)>
208-
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
227+
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
228+
-> tensor<1x?x64xf32>
229+
{
209230
%0 = affine.apply #map()[%arg1]
210-
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
231+
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
232+
: tensor<64x?x64xf32> to tensor<1x?x64xf32>
211233
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
212234
^bb0(%arg3: index, %arg4: index, %arg5: index):
213235
tensor.yield %arg0 : f32

mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1818
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
1919
#include "mlir/Dialect/Tensor/IR/Tensor.h"
20+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2021
#include "mlir/Interfaces/InferTypeOpInterface.h"
2122
#include "llvm/Support/InterleavedRange.h"
2223

@@ -134,8 +135,10 @@ struct ReifyResultShapesPass final
134135
void ReifyResultShapesPass::runOnOperation() {
135136
SmallVector<ReifyRankedShapedTypeOpInterface> ops;
136137
getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
137-
// Some ops have rigid type checkers and need to update their operands.
138-
// Only admit the ones that are explicitly supported for now.
138+
// Handle ops that are not DPS and that do not carry an tied operand shapes.
139+
// For now, limit to tensor::PadOp and tensor::ConcatOp.
140+
if (isa<DestinationStyleOpInterface>(op.getOperation()))
141+
return;
139142
if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation()))
140143
return;
141144
ops.push_back(op);

0 commit comments

Comments
 (0)