Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion shardy/round_trip_import/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,5 @@ cc_library(
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@stablehlo//:stablehlo_passes",
"@stablehlo//:stablehlo_passes_optimization",
],
)
3 changes: 1 addition & 2 deletions shardy/round_trip_import/pipelines.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ limitations under the License.
#include "shardy/round_trip_import/import_shardy_attrs.h"
#include "shardy/round_trip_import/shard_map_import.h"
#include "stablehlo/transforms/Passes.h"
#include "stablehlo/transforms/optimization/Passes.h"

namespace mlir {
namespace sdy {
Expand All @@ -41,7 +40,7 @@ void addSdyRoundTripImportPipeline(OpPassManager& pm) {
.enableFolding(false)
.enableConstantCSE(false);
pm.addNestedPass<func::FuncOp>(
stablehlo::createStablehloAggressiveSimplificationPass({}, config));
stablehlo::createStablehloAggressiveSimplificationPass(config));
pm.addPass(createSdyRoundTripImportCallbackCustomCallsPass());
pm.addPass(createSdyRoundTripImportShardyAttrsPass());
pm.addPass(createSdyRoundTripShardMapImportPass());
Expand Down
185 changes: 80 additions & 105 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,52 @@ diff --ruN a/stablehlo/stablehlo/conversions/linalg/tests/miscellaneous.mlir b/s
// CHECK: linalg.yield %[[VAL2]] : i32
// CHECK: %[[RES_UNSIGNED:.+]] = builtin.unrealized_conversion_cast %[[RES]] : tensor<2x1x5xi32> to tensor<2x1x5xui32>
// CHECK: return %[[RES_UNSIGNED]]
diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp
--- stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp
+++ stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp
@@ -61,7 +61,6 @@
} // namespace

LinalgTypeConverter::LinalgTypeConverter() : RemoveSignTypeConverter() {
- addArgumentMaterialization(scalarToTensor);
addSourceMaterialization(scalarToTensor);
addTargetMaterialization(scalarToTensor);
}
diff --ruN a/stablehlo/stablehlo/dialect/StablehloAttrs.td b/stablehlo/stablehlo/dialect/StablehloAttrs.td
--- stablehlo/stablehlo/dialect/StablehloAttrs.td
+++ stablehlo/stablehlo/dialect/StablehloAttrs.td
@@ -221,7 +221,7 @@
);
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
- let constBuilderCall = "ResultAccuracyAttr::get($_builder.getContext(), APFloat(0.0), APFloat(0.0), 0, ResultAccuracyModeAttr::get($_builder.getContext(), $0))";
+ let constBuilderCall = "::mlir::stablehlo::ResultAccuracyAttr::get($_builder.getContext(), APFloat(0.0), APFloat(0.0), 0, ::mlir::stablehlo::ResultAccuracyModeAttr::get($_builder.getContext(), $0))";
}

#endif // STABLEHLO_DIALECT_STABLEHLO_ATTRS
diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dialect/VhloAttrs.td
--- stablehlo/stablehlo/dialect/VhloAttrs.td
+++ stablehlo/stablehlo/dialect/VhloAttrs.td
@@ -102,7 +102,7 @@
// Corresponds to IntegerConstant from the StableHLO spec.
def VHLO_IntegerAttrV1 : VHLO_AttrDef<"IntegerV1", "0.9.0", "current"> {
let mnemonic = "integer_v1";
- let parameters = (ins "mlir::Type":$type, "APInt":$value);
+ let parameters = (ins "mlir::Type":$type, APIntParameter<"">:$value);
let genVerifyDecl = 1;
let extraClassDefinition = [{
LogicalResult IntegerV1Attr::verify(
diff --ruN a/stablehlo/stablehlo/dialect/VhloTypes.cpp b/stablehlo/stablehlo/dialect/VhloTypes.cpp
--- stablehlo/stablehlo/dialect/VhloTypes.cpp
+++ stablehlo/stablehlo/dialect/VhloTypes.cpp
@@ -333,7 +333,6 @@
void VhloTypeConverter::addUnrealizedMaterializations() {
addTargetMaterialization(materializeIllegalCast);
addSourceMaterialization(materializeIllegalCast);
- addArgumentMaterialization(materializeIllegalCast);
}

namespace {
diff --ruN a/stablehlo/stablehlo/tests/TestUtils.cpp b/stablehlo/stablehlo/tests/TestUtils.cpp
--- stablehlo/stablehlo/tests/TestUtils.cpp
+++ stablehlo/stablehlo/tests/TestUtils.cpp
Expand All @@ -76,46 +122,6 @@ diff --ruN a/stablehlo/stablehlo/tests/TestUtils.cpp b/stablehlo/stablehlo/tests
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h
--- stablehlo/stablehlo/transforms/Passes.h
+++ stablehlo/stablehlo/transforms/Passes.h
@@ -41,10 +41,6 @@
namespace stablehlo {

#define GEN_PASS_DECL
-
-std::unique_ptr<::mlir::Pass> createStablehloAggressiveSimplificationPass(
- GreedyRewriteConfig config);
-
#define GEN_PASS_REGISTRATION
#include "stablehlo/transforms/Passes.h.inc"

@@ -61,10 +57,23 @@
RewritePatternSet *patterns,
TypeConverter *converter);

+inline void populateStablehloToVhloPatterns(RewritePatternSet *patterns,
+ MLIRContext *context,
+ TypeConverter *converter) {
+ populateStablehloToVhloPatterns(context, patterns, converter);
+}
+
// Populates VHLO ops to StableHLO ops rewriting patterns.
+
void populateVhloToStablehloPatterns(MLIRContext *context,
- RewritePatternSet *patterns,
- TypeConverter *converter);
+ RewritePatternSet *patterns,
+ TypeConverter *converter);
+
+inline void populateVhloToStablehloPatterns(RewritePatternSet *patterns,
+ TypeConverter *converter,
+ MLIRContext *context) {
+ populateVhloToStablehloPatterns(context, patterns, converter);
+}

// Populates VHLO downgrade rewriting patterns.
void populateVhloToVersionPatterns(MLIRContext *context,
diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp
--- stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp
+++ stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp
Expand All @@ -135,7 +141,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/
+ .setStrictness(GreedyRewriteStrictness::AnyOp);

RewritePatternSet patterns_(context);
populateStablehloCanonicalizeDynamismPatterns(context, &patterns_);
populateStablehloCanonicalizeDynamismPatterns(&patterns_, context);
@@ -325,7 +325,7 @@
auto func = getOperation();
if (failed(applyPatternsGreedily(func, patterns, config))) {
Expand All @@ -156,7 +162,7 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp b
+ config.setUseTopDownTraversal(true);

RewritePatternSet patterns_(context);
populateStablehloCompatibilityExpanderPatterns(context, &patterns_,
populateStablehloCompatibilityExpanderPatterns(&patterns_, context,
@@ -347,7 +347,7 @@
failed(applyPatternsGreedily(module, patterns, config))) {
module.emitError(
Expand All @@ -169,16 +175,16 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloCompatibilityExpander.cpp b
diff --ruN a/stablehlo/stablehlo/transforms/StablehloComplexMathExpander.cpp b/stablehlo/stablehlo/transforms/StablehloComplexMathExpander.cpp
--- stablehlo/stablehlo/transforms/StablehloComplexMathExpander.cpp
+++ stablehlo/stablehlo/transforms/StablehloComplexMathExpander.cpp
@@ -51,7 +51,7 @@
@@ -49,7 +49,7 @@

public:
LogicalResult initialize(MLIRContext *context) override {
- config.useTopDownTraversal = true;
+ config.setUseTopDownTraversal(true);
RewritePatternSet patterns_(context);
populateStablehloComplexMathExpanderPatterns(context, &patterns_);
populateStablehloComplexMathExpanderPatterns(&patterns_, context);
patterns = std::move(patterns_);
@@ -62,7 +62,7 @@
@@ -60,7 +60,7 @@
auto func = getOperation();
if (failed(applyPatternsGreedily(func, patterns, config))) {
func.emitError("Failed to converge StableHLOComplexMathExpanderPass in ")
Expand Down Expand Up @@ -233,9 +239,9 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehl
+ .setMaxNumRewrites(GreedyRewriteConfig::kNoLimit)
+ .setStrictness(GreedyRewriteStrictness::AnyOp);

populateStablehloRefineShapesPatterns(context, &patterns);
populateStablehloRefineShapesPatterns(&patterns, context);
patterns.add<RefineCallOpPattern>(context, state);
@@ -1049,7 +1049,7 @@
@@ -1046,7 +1046,7 @@

if (failed(applyPatternsGreedily(func, std::move(patterns), config)))
func.emitError("Failed to converge StablehloRefineShapes in ")
Expand All @@ -256,64 +262,33 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloWrapInComposite.cpp b/stabl
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
config))) {
signalPassFailure();
diff --ruN a/stablehlo/stablehlo/transforms/optimization/Passes.h b/stablehlo/stablehlo/transforms/optimization/Passes.h
--- stablehlo/stablehlo/transforms/optimization/Passes.h
+++ stablehlo/stablehlo/transforms/optimization/Passes.h
@@ -16,6 +16,7 @@
#ifndef STABLEHLO_TRANSFORMS_OPTIMIZATION_PASSES_H
#define STABLEHLO_TRANSFORMS_OPTIMIZATION_PASSES_H
diff --ruN a/stablehlo/stablehlo/transforms/conversions/TypeConversion.cpp b/stablehlo/stablehlo/transforms/conversions/TypeConversion.cpp
--- stablehlo/stablehlo/transforms/conversions/TypeConversion.cpp
+++ stablehlo/stablehlo/transforms/conversions/TypeConversion.cpp
@@ -77,7 +77,6 @@
addConversion(convertInteger);
addConversion(convertShapedType);

+#include <memory>
#include <utility>

#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -33,6 +34,15 @@
#define GEN_PASS_DECL
#define GEN_PASS_REGISTRATION
#include "stablehlo/transforms/optimization/Passes.h.inc"
+
+std::unique_ptr<::mlir::Pass> createStablehloAggressiveSimplificationPass(
+ StablehloAggressiveSimplificationPassOptions options,
+ GreedyRewriteConfig rewriteConfig);
+
+inline std::unique_ptr<::mlir::Pass>
+createStablehloAggressiveSimplificationPass(GreedyRewriteConfig config) {
+ return createStablehloAggressiveSimplificationPass({}, config);
+}

std::pair<StablehloAggressiveFolderPassOptions,
StablehloAggressiveSimplificationPassOptions>
@@ -70,6 +80,13 @@
RewritePatternSet *patterns,
PatternBenefit benefit = 1);

+// TODO(gunhyun): To be deleted in the next integrate.
+inline void populateStablehloShapeFolderPatterns(RewritePatternSet *patterns,
+ MLIRContext *context,
+ PatternBenefit benefit = 1) {
+ populateStablehloShapeFolderPatterns(context, patterns, benefit);
+}
+
/// Some workloads in XLA import StableHLO from HLO. Since there are a few
/// differences in HLO (no implicit captures, lots of tuples, etc.), this
/// set of patterns brings the imported HLO back to a more canonical form
@@ -90,6 +107,10 @@
MLIRContext *context, RewritePatternSet *patterns,
StablehloAggressiveFolderPassOptions &&options,
PatternBenefit benefit = 1) = delete;
+void populateStablehloShapeFolderPatterns(
+ RewritePatternSet *patterns, MLIRContext *context,
+ StablehloAggressiveFolderPassOptions &&options,
+ PatternBenefit benefit = 1) = delete;
void populateStablehloAggressiveFolderPatterns(
MLIRContext *context, RewritePatternSet *patterns,
StablehloAggressiveFolderPassOptions &&options,
@@ -98,6 +119,7 @@
MLIRContext *context, RewritePatternSet *patterns,
StablehloAggressiveSimplificationPassOptions &&options) = delete;

+
} // namespace stablehlo
} // namespace mlir
- addArgumentMaterialization(materializeCastFromIllegal);
addSourceMaterialization(materializeCastToIllegal);
addTargetMaterialization(materializeCastFromIllegal);
}
diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp
--- stablehlo/stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp
+++ stablehlo/stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp
@@ -56,10 +56,10 @@

void runOnOperation() override {
GreedyRewriteConfig config;
- config.fold = true;
- config.cseConstants = true;
- config.maxIterations = kFoldOpEltLimit;
- config.useTopDownTraversal = false;
+ config.enableFolding(true)
+ .enableConstantCSE(true)
+ .setMaxIterations(kFoldOpEltLimit)
+ .setUseTopDownTraversal(false);
if (failed(applyPatternsGreedily(getOperation(), patterns, config)))
signalPassFailure();
}

4 changes: 2 additions & 2 deletions third_party/stablehlo/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
#
STABLEHLO_COMMIT = "63246aea4a0310274d1eaebe1ec150a63e935a28"
STABLEHLO_SHA256 = "86c7265b8989f5511f4df151981fb89650c238dd302d9fa2041d96c73b53b2b5"
STABLEHLO_COMMIT = "a54938f0651d3b4b7be9771848eda2463c92a8e7"
STABLEHLO_SHA256 = "d5251548b20e51b05d1803720fecde83316b0e662d00ac093cceba7c1584b51d"
#

tf_http_archive(
Expand Down
Loading