Skip to content

Commit 7bc5da0

Browse files
committed
fold transpose identity
1 parent 24f7531 commit 7bc5da0

File tree

4 files changed

+45
-30
lines changed

4 files changed

+45
-30
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6219,6 +6219,12 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
62196219
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
62206220
return ub::PoisonAttr::get(getContext());
62216221

6222+
// Identity transpose.
6223+
if (llvm::all_of(llvm::enumerate(getPermutation()), [](auto it) {
6224+
return it.value() == static_cast<int64_t>(it.index());
6225+
}))
6226+
return getVector();
6227+
62226228
return {};
62236229
}
62246230

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -450,17 +450,6 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
450450

451451
// -----
452452

453-
// CHECK-LABEL: transpose_3D_identity
454-
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
455-
func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
456-
// CHECK-NOT: transpose
457-
%0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
458-
// CHECK-NEXT: return [[ARG]]
459-
return %0 : vector<4x3x2xf32>
460-
}
461-
462-
// -----
463-
464453
// CHECK-LABEL: transpose_2D_sequence
465454
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
466455
func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {

mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
// +----------------------------------------
77

88
// CHECK-LABEL: @broadcast_to_shape_cast
9-
// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
10-
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
11-
// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8>
9+
// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
10+
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
11+
// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8>
1212
func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
1313
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
1414
return %0 : vector<1x1x4xi8>
@@ -49,9 +49,9 @@ func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
4949
// 2 -> 1
5050
// Because 0 < 1, this permutation is order preserving and effectively a shape_cast.
5151
// CHECK-LABEL: @transpose_to_shape_cast
52-
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
53-
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
54-
// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
52+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
53+
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
54+
// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
5555
func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
5656
%0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
5757
return %0 : vector<2x2x1xf32>
@@ -64,10 +64,10 @@ func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf3
6464
// 2 -> 4
6565
// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
6666
// CHECK-LABEL: @shape_cast_of_transpose
67-
// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>)
68-
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
69-
// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
70-
// CHECK: return %[[SHAPE_CAST]]
67+
// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>)
68+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
69+
// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
70+
// CHECK: return %[[SHAPE_CAST]]
7171
func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x1x1x1x4xi8> {
7272
%0 = vector.transpose %arg, [1, 0, 3, 4, 2] : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
7373
return %0 : vector<4x1x1x1x4xi8>
@@ -101,8 +101,8 @@ func.func @negative_transpose_to_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector
101101
// -----
102102

103103
// CHECK-LABEL: @shape_cast_of_transpose_scalable
104-
// CHECK-NEXT: vector.shape_cast
105-
// CHECK-NEXT: return
104+
// CHECK-NEXT: vector.shape_cast
105+
// CHECK-NEXT: return
106106
func.func @shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
107107
%0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
108108
%1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
@@ -125,9 +125,9 @@ func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]
125125
// A test where a transpose cannot be transformed to a shape_cast because it is not order
126126
// preserving
127127
// CHECK-LABEL: @negative_transpose_to_shape_cast
128-
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
129-
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
130-
// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
128+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
129+
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
130+
// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
131131
func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
132132
%0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
133133
return %0 : vector<2x2x1xf32>
@@ -140,9 +140,9 @@ func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector
140140
// +----------------------------------------
141141

142142
// CHECK-LABEL: @extract_to_shape_cast
143-
// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
144-
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
145-
// CHECK-NEXT: return %[[SCAST]] : vector<4xf32>
143+
// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
144+
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
145+
// CHECK-NEXT: return %[[SCAST]] : vector<4xf32>
146146
func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
147147
%0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
148148
return %0 : vector<4xf32>

mlir/test/Dialect/Vector/single-fold.mlir

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,24 @@ func.func @fold_insert_in_single_pass() -> vector<2xf16> {
3535
// CHECK: arith.constant dense<[0.000000e+00, 2.500000e+00]> : vector<2xf16>
3636
%0 = vector.insert %c2, %cst [%c1] : f16 into vector<2xf16>
3737
return %0 : vector<2xf16>
38-
}
38+
}
39+
40+
// -----
41+
42+
// CHECK-LABEL: transpose_3D_identity
43+
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
44+
// CHECK-NEXT: return [[ARG]]
45+
func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
46+
%0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
47+
return %0 : vector<4x3x2xf32>
48+
}
49+
50+
// -----
51+
52+
// CHECK-LABEL: transpose_0D_identity
53+
// CHECK-SAME: ([[ARG:%.*]]: vector<i8>)
54+
// CHECK-NEXT: return [[ARG]]
55+
func.func @transpose_0D_identity(%arg : vector<i8>) -> vector<i8> {
56+
%0 = vector.transpose %arg, [] : vector<i8> to vector<i8>
57+
return %0 : vector<i8>
58+
}

0 commit comments

Comments
 (0)