Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasRaoux authored Apr 10, 2023
1 parent 9364cf7 commit df7522b
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 19 deletions.
8 changes: 4 additions & 4 deletions compiler/src/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -908,10 +908,10 @@ struct FoldTensorUpdateOpWithCasts : public OpRewritePattern<TensorUpdateOp> {
auto targetCastOp = updateOp.getTarget().getDefiningOp<tensor::CastOp>();
auto updateCastOp = updateOp.getUpdate().getDefiningOp<tensor::CastOp>();
if (!targetCastOp && !updateCastOp) return failure();
auto target =
(targetCastOp ? targetCastOp.getSource() : updateOp.getTarget());
auto update =
(updateCastOp ? updateCastOp.getSource() : updateOp.getUpdate());
Value target = (targetCastOp ? cast<Value>(targetCastOp.getSource())
: cast<Value>(updateOp.getTarget()));
Value update = (updateCastOp ? cast<Value>(updateCastOp.getSource())
: cast<Value>(updateOp.getUpdate()));
auto newOp = rewriter.create<TensorUpdateOp>(
updateOp.getLoc(), target.getType(), target,
refreshDimsOnTypeChange(updateOp, updateOp.getTarget().getType(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ struct ScatterImplicitIndex : public OpRewritePattern<mhlo::ScatterOp> {
PatternRewriter &rewriter) const final {
auto dimNumbers = op.getScatterDimensionNumbers();
auto indexVectorDim = dimNumbers.getIndexVectorDim();
auto indices = op.getScatterIndices();
Value indices = op.getScatterIndices();
auto indicesTy = indices.getType().cast<ShapedType>();

// Check indices vector has an implicit dim.
Expand Down
1 change: 0 additions & 1 deletion compiler/src/iree/compiler/Tools/init_mlir_passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ inline void registerMlirPasses() {
registerLoopCoalescingPass();
registerLoopInvariantCodeMotionPass();
registerAffineScalarReplacementPass();
registerSCFParallelLoopCollapsingPass();
registerPrintOpStatsPass();
registerViewOpGraphPass();
registerStripDebugInfoPass();
Expand Down
2 changes: 1 addition & 1 deletion integrations/tensorflow/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")

TENSORFLOW_COMMIT = "4f1dd6d5123f4eb6afc85fac36df09b4a8b49c83"
TENSORFLOW_COMMIT = "5dd766f144ee0fc20506ee6476dc21c8e1816b69"

git_repository(
name = "org_tensorflow",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ transform.sequence failures(propagate) {
: (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
%2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation
transform.structured.vectorize %2 { vectorize_padding }
transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op
{bufferize_function_boundaries = true}
%3 = transform.structured.match ops{["func.func"]} in %module_op
%module_op1 = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op
{bufferize_function_boundaries = true} : (!pdl.operation) -> !pdl.operation
%3 = transform.structured.match ops{["func.func"]} in %module_op1
: (!pdl.operation) -> !pdl.operation


%func = transform.structured.match ops{["func.func"]} in %module_op
%func = transform.structured.match ops{["func.func"]} in %module_op1
: (!pdl.operation) -> !pdl.operation
%func_e_2 = transform.vector.lower_contraction %func
lowering_strategy = "outerproduct"
Expand All @@ -36,5 +36,5 @@ transform.sequence failures(propagate) {
lowering_strategy = "shuffle"
: (!pdl.operation) -> !pdl.operation

lower_to_llvm %module_op : (!pdl.operation) -> !pdl.operation
lower_to_llvm %module_op1 : (!pdl.operation) -> !pdl.operation
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ transform.sequence failures(propagate) {
: (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
%2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation
transform.structured.vectorize %2 { vectorize_padding }
transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op
{bufferize_function_boundaries = true}
%3 = transform.structured.match ops{["func.func"]} in %module_op
%module_op1 = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op
{bufferize_function_boundaries = true} : (!pdl.operation) -> !pdl.operation
%3 = transform.structured.match ops{["func.func"]} in %module_op1
: (!pdl.operation) -> !pdl.operation


%func = transform.structured.match ops{["func.func"]} in %module_op
%func = transform.structured.match ops{["func.func"]} in %module_op1
: (!pdl.operation) -> !pdl.operation
%func_e_2 = transform.vector.lower_contraction %func
lowering_strategy = "outerproduct"
Expand All @@ -36,5 +36,5 @@ transform.sequence failures(propagate) {
lowering_strategy = "shuffle"
: (!pdl.operation) -> !pdl.operation

lower_to_llvm %module_op : (!pdl.operation) -> !pdl.operation
lower_to_llvm %module_op1 : (!pdl.operation) -> !pdl.operation
}
2 changes: 1 addition & 1 deletion third_party/llvm-project
2 changes: 1 addition & 1 deletion third_party/mlir-hlo

0 comments on commit df7522b

Please sign in to comment.