diff --git a/mlir/include/mlir/Target/SPIRV/Deserialization.h b/mlir/include/mlir/Target/SPIRV/Deserialization.h index a346a7fd1e5f7..68eb863b4a6f2 100644 --- a/mlir/include/mlir/Target/SPIRV/Deserialization.h +++ b/mlir/include/mlir/Target/SPIRV/Deserialization.h @@ -23,12 +23,19 @@ class MLIRContext; namespace spirv { class ModuleOp; +struct DeserializationOptions { + // Whether to structurize control flow into `spirv.mlir.selection` and + // `spirv.mlir.loop`. + bool enableControlFlowStructurization = true; +}; + /// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp /// in the given `context`. Returns the ModuleOp on success; otherwise, reports /// errors to the error handler registered with `context` and returns a null /// module. -OwningOpRef deserialize(ArrayRef binary, - MLIRContext *context); +OwningOpRef +deserialize(ArrayRef binary, MLIRContext *context, + const DeserializationOptions &options = {}); } // namespace spirv } // namespace mlir diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp index 7bb8762660599..b82c61cafc8a7 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp @@ -12,9 +12,10 @@ using namespace mlir; -OwningOpRef spirv::deserialize(ArrayRef binary, - MLIRContext *context) { - Deserializer deserializer(binary, context); +OwningOpRef +spirv::deserialize(ArrayRef binary, MLIRContext *context, + const DeserializationOptions &options) { + Deserializer deserializer(binary, context, options); if (failed(deserializer.deserialize())) return nullptr; diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 7afd6e9b25b77..a21d691ae5142 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -49,9 +49,10 @@ static inline bool isFnEntryBlock(Block *block) { //===----------------------------------------------------------------------===// spirv::Deserializer::Deserializer(ArrayRef binary, - MLIRContext *context) + MLIRContext *context, + const spirv::DeserializationOptions &options) : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), - module(createModuleOp()), opBuilder(module->getRegion()) + module(createModuleOp()), opBuilder(module->getRegion()), options(options) #ifndef NDEBUG , logger(llvm::dbgs()) @@ -2361,6 +2362,16 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() { } LogicalResult spirv::Deserializer::structurizeControlFlow() { + if (!options.enableControlFlowStructurization) { + LLVM_DEBUG( + { + logger.startLine() + << "//----- [cf] skip structurizing control flow -----//\n"; + logger.indent(); + }); + return success(); + } + LLVM_DEBUG({ logger.startLine() << "//----- [cf] start structurizing control flow -----//\n"; diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index bcc78e3e6508d..e4556e7652b17 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/IR/Builders.h" +#include "mlir/Target/SPIRV/Deserialization.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringRef.h" @@ -121,7 +122,8 @@ class Deserializer { public: /// Creates a deserializer for the given SPIR-V `binary` module. /// The SPIR-V ModuleOp will be created into `context. - explicit Deserializer(ArrayRef binary, MLIRContext *context); + explicit Deserializer(ArrayRef binary, MLIRContext *context, + const DeserializationOptions &options); /// Deserializes the remembered SPIR-V binary module. LogicalResult deserialize(); @@ -622,6 +624,9 @@ class Deserializer { /// A list of all structs which have unresolved member types. SmallVector deferredStructTypesInfos; + /// Deserialization options. + DeserializationOptions options; + #ifndef NDEBUG /// A logger used to emit information during the deserialzation process. llvm::ScopedPrinter logger; diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp index ff34f02d07b73..682fff2784775 100644 --- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp +++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp @@ -37,7 +37,8 @@ using namespace mlir; // Deserializes the SPIR-V binary module stored in the file named as // `inputFilename` and returns a module containing the SPIR-V module. static OwningOpRef -deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context) { +deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context, + const spirv::DeserializationOptions &options) { context->loadDialect(); // Make sure the input stream can be treated as a stream of SPIR-V words @@ -51,17 +52,26 @@ deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context) { auto binary = llvm::ArrayRef(reinterpret_cast(start), size / sizeof(uint32_t)); - return spirv::deserialize(binary, context); + return spirv::deserialize(binary, context, options); } namespace mlir { void registerFromSPIRVTranslation() { + static llvm::cl::opt enableControlFlowStructurization( + "spirv-structurize-control-flow", + llvm::cl::desc( + "Enable control flow structurization into `spirv.mlir.selection` and " + "`spirv.mlir.loop`. This may need to be disabled to support " + "deserialization of early exits (see #138688)"), + llvm::cl::init(true)); + TranslateToMLIRRegistration fromBinary( "deserialize-spirv", "deserializes the SPIR-V module", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer"); return deserializeModule( - sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context); + sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context, + {enableControlFlowStructurization}); }); } } // namespace mlir