-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][linalg] Extend FuseElementwiseOps
pattern to work with named ops
#144922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
linalg.map
op
linalg.map
opLinalgOp
✅ With the latest revision this PR passed the C/C++ code formatter. |
The current changes in this PR will fuse |
Although i didn't expect differences between implementing for generic vs map, but i am seeing a failure in a case with linalg.map that passes for the equivalent linalg.generic version. map version
fails with
generic version
successfully fuses and generates
since It's strange that this case doesn't constant fold in the first place. The first 3 generics can definitely be constant folded to |
I think I know why this fails. The fundamental difference between |
// CHECK-NEXT: %[[EXP0:.*]] = math.exp %[[IN1]] | ||
// CHECK-NEXT: %[[SQRT0:.*]] = math.sqrt %[[IN0]] | ||
// CHECK-NEXT: %[[EXP1:.*]] = math.exp %[[IN1]] | ||
// CHECK-NEXT: %[[SQRT1:.*]] = math.sqrt %[[IN0]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think the duplicate computations are an old artifact. these do go away with cse
but let me know if this is something that should be looked at
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is indeed odd. Looks like a bug in the fuser. Could be related to the map
vs generic
issue you've seen above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did make a generic version of this and ran the old version of the pass and got same results to confirm it's a pre-existing thing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here's the generic version
#map = affine_map<(d0)->(d0)>
func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg3: tensor<8xf32>) -> tensor<8xf32> {
%init = tensor.empty() : tensor<8xi1>
%initf = tensor.empty() : tensor<8xf32>
%0 = linalg.generic {
indexing_maps = [#map, #map],
iterator_types = ["parallel"]} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>) {
^bb0(%in0 : f32, %out : f32):
%sqrt = math.sqrt %in0 : f32
linalg.yield %sqrt : f32
} -> tensor<8xf32>
%1 = linalg.generic {
indexing_maps = [#map, #map],
iterator_types = ["parallel"]} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>) {
^bb0(%in0 : f32, %out : f32):
%sqrt = math.exp %in0 : f32
linalg.yield %sqrt : f32
} -> tensor<8xf32>
%2 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel"]} ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs(%init : tensor<8xi1>) {
^bb0(%in0 : f32, %in1 : f32, %out: i1):
%cmp = arith.cmpf olt, %in0, %in1 : f32
linalg.yield %cmp : i1
} -> tensor<8xi1>
%3 = linalg.generic {
indexing_maps = [#map, #map, #map, #map],
iterator_types = ["parallel"]} ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>) {
^bb0(%in0 : i1, %in1 : f32, %in2 : f32, %out: f32):
%select = arith.select %in0, %in1, %in2 : f32
linalg.yield %select : f32
} -> tensor<8xf32>
return %3 : tensor<8xf32>
}
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: None (srcarroll) ChangesThis patch modifies Full diff: https://github.com/llvm/llvm-project/pull/144922.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 147a2907f52e4..f0c8f0de06637 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -529,8 +529,8 @@ fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
-llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
- GenericOp consumer,
+llvm::SmallDenseSet<int> getPreservedProducerResults(LinalgOp producer,
+ LinalgOp consumer,
OpOperand *fusedOperand);
/// Try to peel and canonicalize loop `op` and return the new result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f97ed3d6d5111..fc435b47f5977 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -77,11 +77,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
// of the fused producer & consumer after the fusion can still compute the
// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
- GenericOp producer, GenericOp consumer,
+ LinalgOp producer, LinalgOp consumer,
ArrayRef<OpOperand *> opOperandsToIgnore) {
SmallVector<AffineMap> indexingMaps;
- SmallVector<GenericOp> ops = {producer, consumer};
+ SmallVector<LinalgOp> ops = {producer, consumer};
for (auto &op : ops) {
for (auto &opOperand : op->getOpOperands()) {
if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
@@ -109,8 +109,9 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
-llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
- GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
+llvm::SmallDenseSet<int>
+mlir::linalg::getPreservedProducerResults(LinalgOp producer, LinalgOp consumer,
+ OpOperand *fusedOperand) {
llvm::SmallDenseSet<int> preservedProducerResults;
llvm::SmallVector<OpOperand *> opOperandsToIgnore;
@@ -140,8 +141,8 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (!fusedOperand)
return false;
- auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
- auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
+ auto producer = fusedOperand->get().getDefiningOp<LinalgOp>();
+ auto consumer = dyn_cast<LinalgOp>(fusedOperand->getOwner());
// Check producer and consumer are generic ops.
if (!producer || !consumer)
@@ -215,16 +216,39 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
- RewriterBase &rewriter, GenericOp fusedOp,
+ RewriterBase &rewriter, LinalgOp fusedOp,
AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
- auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
- auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+ auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
+ auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// Build the region of the fused op.
+
+ // Since some ops, like `linalg.map`, do not have block arguments for init operands
+ // then we first "generalize" the block by adding arguments for init operands when
+ // they aren't present. We detect this case by checking if
+ // `getOpOperandsMatchingBBargs() == getDpsInputOperands();
Block &producerBlock = producer->getRegion(0).front();
+ if (producer.getOpOperandsMatchingBBargs() ==
+ producer.getDpsInputOperands()) {
+ for (auto init : producer.getDpsInits()) {
+ Type bbType = isa<ShapedType>(init.getType())
+ ? cast<ShapedType>(init.getType()).getElementType()
+ : init.getType();
+ producerBlock.addArgument(bbType, producer.getLoc());
+ }
+ }
Block &consumerBlock = consumer->getRegion(0).front();
+ if (consumer.getOpOperandsMatchingBBargs() ==
+ consumer.getDpsInputOperands()) {
+ for (auto init : consumer.getDpsInits()) {
+ Type bbType = isa<ShapedType>(init.getType())
+ ? cast<ShapedType>(init.getType()).getElementType()
+ : init.getType();
+ consumerBlock.addArgument(bbType, consumer.getLoc());
+ }
+ }
OpBuilder::InsertionGuard guard(rewriter);
- Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
+ Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
IRMapping mapper;
// 2. Add an index operation for every fused loop dimension and use the
@@ -330,7 +354,7 @@ static void generateFusedElementwiseOpRegion(
rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
// Sanity checks.
- assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
+ assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
"Ill-formed GenericOp region");
}
@@ -340,8 +364,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
assert(areElementwiseOpsFusable(fusedOperand) &&
"expected elementwise operation pre-conditions to pass");
auto producerResult = cast<OpResult>(fusedOperand->get());
- auto producer = cast<GenericOp>(producerResult.getOwner());
- auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+ auto producer = cast<LinalgOp>(producerResult.getOwner());
+ auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// TODO: allow fusing the producer of an output operand.
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
@@ -418,10 +442,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
// Generate the fused op.
auto fusedOp = rewriter.create<GenericOp>(
consumer.getLoc(), fusedResultTypes, fusedInputOperands,
- fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
- consumer.getIteratorTypes(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr);
+ fusedOutputOperands, fusedIndexMaps, consumer.getIteratorTypesArray());
if (!fusedOp.getShapesToLoopsMap()) {
// Fused op has invalid indexing maps. Typically this means something is off
// in the input, but going ahead here would result in verification errors.
@@ -460,14 +481,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
-class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
+class FuseElementwiseOps : public OpInterfaceRewritePattern<LinalgOp> {
public:
FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
PatternBenefit benefit = 1)
- : OpRewritePattern<GenericOp>(context, benefit),
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
controlFn(std::move(fun)) {}
- LogicalResult matchAndRewrite(GenericOp genericOp,
+ LogicalResult matchAndRewrite(LinalgOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand &opOperand : genericOp->getOpOperands()) {
@@ -494,7 +515,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
rewriter.eraseOp(genericOp);
return success();
}
- return failure();
+ return rewriter.notifyMatchFailure(genericOp, "no fusable operands");
}
private:
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 66fc55fadf8fa..b581567cf57a7 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1014,3 +1014,24 @@ module {
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
// CHECK: linalg.yield %[[T3]] : f32
// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+ %fill = tensor.empty() : tensor<8xf32>
+ %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+ %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+ return %mapped_65 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT: linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index bd9977f1410b9..18ca8b42fa79c 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -59,3 +59,57 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+ %fill = tensor.empty() : tensor<8xf32>
+ %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+ %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+ return %mapped_65 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT: linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.map
+
+// -----
+
+func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+ %init = tensor.empty() : tensor<8xi1>
+ %initf = tensor.empty() : tensor<8xf32>
+ %0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+ %1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+ %2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
+ (%in0 : f32, %in1 : f32) {
+ %cmp = arith.cmpf olt, %in0, %in1 : f32
+ linalg.yield %cmp : i1
+ }
+ %3 = linalg.map { arith.select } ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>)
+ return %3 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops_mixed_types
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[EXP0:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT: %[[SQRT0:.*]] = math.sqrt %[[IN0]]
+// CHECK-NEXT: %[[EXP1:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT: %[[SQRT1:.*]] = math.sqrt %[[IN0]]
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[SQRT1]], %[[EXP1]]
+// CHECK-NEXT: %[[RES:.*]] = arith.select %[[CMP]], %[[SQRT0]], %[[EXP0]]
+// CHECK-NEXT: linalg.yield %[[RES]]
+// CHECK-NOT: linalg.map
+
|
After a chat with @rengolin here , in particular [quote="rengolin, post:6, topic:83927"] it's probably important enough to warrant discussion |
It occurs to me now that options like Any thoughts? Edit: Actually, I was just confused by the naming, but I see now how |
LinalgOp
FuseElementwiseOps
pattern to work with named ops
I'm now realizing I should extend all the patterns in |
// CHECK-NEXT: %[[EXP0:.*]] = math.exp %[[IN1]] | ||
// CHECK-NEXT: %[[SQRT0:.*]] = math.sqrt %[[IN0]] | ||
// CHECK-NEXT: %[[EXP1:.*]] = math.exp %[[IN1]] | ||
// CHECK-NEXT: %[[SQRT1:.*]] = math.sqrt %[[IN0]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is indeed odd. Looks like a bug in the fuser. Could be related to the map
vs generic
issue you've seen above.
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32> | ||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32> | ||
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32> | ||
// CHECK: %[[FUSED_OP:.+]] = linalg.generic |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe not for this PR, but we have discussed keeping the named ops for as long as possible. Here, since they're both map
s, we could fuse into a map
still. Technically, they're the same (as discussed in the forum), but if I have a chain of matches and fusers, I'd have to match against all possible representations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yah I mentioned in a comment that I wanted to try to do that, but don't know of a clean way to do that yet. I've done something like that before by just using clone
and modifying as a way to generalizing a transform from a named op to that same named op. It seems more complicated here, but maybe not
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also that wouldn't help with elementwise+elementwise -> map anyway
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also that wouldn't help with elementwise+elementwise -> map anyway
Right, the idea is that ew
+ ew
-> map
is still better than to generic
. So we only walk up the tree when needed (assuming generic
-> map
-> ew
is the branch we're walking).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one thing to note. ew
+ ew
-> map
only works if maps for both ew
are same rank identity on all operands since indexing maps for map
are limited
|
||
// ----- | ||
|
||
func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you also add a test where the map
operations each have more than one region ops? The fuser should be able to cope with it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure. I initially didn't go that far with tests because at the time my reasoning was that the logic is unchanged so should generalize. but then I ran across that one issue with the bb args, so I'm less convinced of that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
// ----- | ||
|
||
func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to fuse map
and generic
together, if their affine maps and iterator types are compatible? If yes, we should have a quick test on it. If not, this should eventually be supported (separate PR), so a FIXME in the code would help.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my thinking is that if it works for the generic form it should work for map form. but again, that block arg oddity kinda ruins that. nevertheless I expect that to work, but will add a test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
// Build the region of the fused op. | ||
|
||
// Since some ops, like `linalg.map`, do not have block arguments for init |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like a missed opportunity. IIRC, neither contract
and elementwise
have that problem. Perhaps we can update map
to behave like the new ones first? @javedabsar1 @shahidact
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that it needs to be updated, or even better if we deprecate it in favour of linalg.elementwise
as linalg.map
is same semantically with linalg.elementwise
, in fact linalg.elementwise
seems more general, IIRC.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rengolin and I had a discussion of that here. seems like we want to keep map
and it does support multi-op regions whereas elementwise
does not. So elementwise
isn't quite more general. They each have something the other doesn't. But yes i agree that this is an odd feature of map
, so if we do keep it would be nice to canonicalize this kind of thing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so i actually found another bug and i think it's because of what i'm doing here. i'm modifying the blocks for producers and consumers. as long as they don't stick around, this works fine. but if the producer has more than one user so that it has to stick around, then this logic here converts it to an invalid op. i'm in the process of confirming this. In any event, it's probably not a good idea to modify the blocks in place like this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps we can update map to behave like the new ones first?
i'm fine with waiting for that change so we don't have to figure out special logic here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i fixed the bug here. the problem was what i expected. i made note here that this is a hack and should be changed before merging. So I'll either come up with something better or wait for the map
op change (if we decide to go with that).
I can't actually reproduce the bug with builtin passes. I only discovered it when I applied my own control function for populateElementwiseOpsFusionPatterns
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can't actually reproduce the bug with builtin passes. I only discovered it when I applied my own control function for
populateElementwiseOpsFusionPatterns
If there's a bug present is the current fuser, it'd be worth adding such test case to the TestLinalgElementwiseFusion
. You could extend it with another option flag that adds necessary custom control logic.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure. however, since the bug was an artifact of my own changes that I noted are hacky, and those changes will go away before this merges anyway, it would be irrelevant by that point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rengolin do you know if anyone will make the change to map
, or should I do it?
I'd leave this for another PR. I have similar comments here, where fusion of
We can go slow here, as there are many nuances. This PR hits the spot of being short and meaningful as is. As this will definitely need more discussion, we can use the evolution of PRs as inflection points in that discussion. |
Sounds good. Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I dont have a major issue with the change itself, but I think this is different from the direction I think we had agreed on. I think we need to remove regions from the elementwise/contract/fill ops. If we go down that path then we will really not need anything cause elementwise fusion will require generalization (or at least some interface methods that can give you the region needed.
At the very least I think a lot of complexity introduced here are due to map operations. If we fix that then Id expect almost no changes to the fusion logic here. Id rather do that than import that complexity here.
I wasn't involved in the discussion so I dont know exactly what was agreed upon but, as far as I know, this doesn't preclude any op design changes in linalg.
Not sure why we would do that. If we want these ops to have generic counterparts, then I dont think it makes sense to remove regions. But again, since I wasn't involved in the discussion I'll concede to those of you who know more about this. Nevertheless, I dont think these ops should be excluded from fusion if that's what you're getting at.
Agree. discussed here |
The region does not to be there always. We only need it when generalizing. Having it always is kind of unnecessary (and currently there is a region but is hidden from parsing/printing which is extremely spooky.
Cool! Ill wait for that to land before reviewing this again. |
Hmm ok. Well I don't have a strong opinion about that exactly. But I do have a strong opinion about being able to fuse such ops in a more or less generic way (provided they are structured to do so, so I understand there might be certain ops that don't make sense to fuse in the first place). So I think some way to describe regions for such ops is actually useful. But again I wasn't involved in the discussion so I accept that there are nuances around this that I'm likely not aware of. Would like to hear other people's thoughts on it too. @rengolin |
You can have a method that can generate the ops that would go into the region on demand. |
Oh ya sure. I think you initially mentioned an interface method. Something like that would satisfy me enough I think |
Yeah. I just hate that there is this implicit region that is not parsed or printed and blocks Linalg ops from being used effectively in things like PDL. |
I see. I dont know PDL but sounds fair enough to me |
This patch modifies
FuseElementwiseOps
to support fusing named ops, such aslinalg.map
andlinalg.elementwise
, which are always elementwise. The fundamental logic remains the same and, for the most part, we need only to changeGenericOp
toLinalgOp
.