From 9c36fe7dee91874df235ca554562c903c2b30fc2 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Fri, 21 Feb 2025 09:06:52 -0800 Subject: [PATCH] #sdy Add extra `export_test.py` tests for using different meshes.. Under Shardy, we can: - use the same mesh on save and load (no test added here, but would be tested by existing tests) - use one mesh on save and another mesh on load with different axis names - use one mesh on save and another mesh on load with different axis names and sizes. For this case Shardy propagation may not be optimal if the module doesn't specify out shardings. This is very hard to write a unit test for, and is rare to happen, and is something we have been considering adding in Shardy b/399957785. This will be something we can allow for during Shardy propagation. This can be a standalone fix in Shardy without making any changes to JAX or XLA. PiperOrigin-RevId: 729550047 --- shardy/dialect/sdy/ir/utils.cc | 19 ++++++++++++++++--- shardy/dialect/sdy/ir/utils.h | 14 ++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/shardy/dialect/sdy/ir/utils.cc b/shardy/dialect/sdy/ir/utils.cc index 0e7643d6..c7ff3306 100644 --- a/shardy/dialect/sdy/ir/utils.cc +++ b/shardy/dialect/sdy/ir/utils.cc @@ -128,6 +128,20 @@ int64_t isScalar(Value value) { return false; } +MeshOp getMeshOp(Operation* op, SymbolRefAttr meshSymName) { + return SymbolTable::lookupNearestSymbolFrom( + op, meshSymName); +} + +MeshOp getMeshOp(Operation* op, StringRef meshName) { + return getMeshOp(op, SymbolRefAttr::get(op->getContext(), meshName)); +} + +MeshOp getMeshOp(const SymbolTable& symbolTable, StringRef meshName) { + return symbolTable.lookup(meshName); +} + + MeshAttr getMeshOrLookup(const SymbolTable& symbolTable, Attribute meshOrRef) { if (auto mesh = dyn_cast(meshOrRef)) { return mesh; @@ -143,7 +157,7 @@ MeshAttr getMeshOrLookup(Operation* op, Attribute meshOrRef) { } MeshAttr getMeshAttr(const SymbolTable& symbolTable, StringRef meshName) { - if (auto meshOp = symbolTable.lookup(meshName)) { + if (MeshOp meshOp = getMeshOp(symbolTable, meshName)) { return meshOp.getMesh(); } @@ -160,8 +174,7 @@ MeshAttr getMeshAttr(Operation* op, StringRef meshName) { } MeshAttr getMeshAttr(Operation* op, SymbolRefAttr meshSymName) { - if (auto meshOp = - SymbolTable::lookupNearestSymbolFrom(op, meshSymName)) { + if (MeshOp meshOp = getMeshOp(op, meshSymName)) { return meshOp.getMesh(); } diff --git a/shardy/dialect/sdy/ir/utils.h b/shardy/dialect/sdy/ir/utils.h index a8b42a67..ef589282 100644 --- a/shardy/dialect/sdy/ir/utils.h +++ b/shardy/dialect/sdy/ir/utils.h @@ -90,6 +90,20 @@ int64_t getTensorRank(Value value); // Returns true if the value is a tensor with rank 0. int64_t isScalar(Value value); +// Looks up the mesh symbol with the given `meshName` in `symbolTable`, and +// returns it if it exists in the table, or nullptr otherwise. +MeshOp getMeshOp(const SymbolTable& symbolTable, StringRef meshName); + +// Looks up the mesh symbol with the given `meshSymName` in the symbol table of +// the enclosing module of `op`, and returns it if it exists in the table, or +// nullptr otherwise. +MeshOp getMeshOp(Operation* op, SymbolRefAttr meshSymName); + +// Looks up the mesh symbol with the given `meshName` in the symbol table of +// the enclosing module of `op`, and returns it if it exists in the table, or +// nullptr otherwise. +MeshOp getMeshOp(Operation* op, StringRef meshName); + // If `meshOrRef` is a `MeshAttr`, returns it, otherwise, looks up the // referenced mesh symbol in `symbolTable`, and returns its `MeshAttr` // if it exists in the table, or nullptr otherwise.