From ca64017fec67ec713b63217af7ae7466bed08010 Mon Sep 17 00:00:00 2001 From: Igor Wodiany Date: Mon, 19 May 2025 14:22:01 +0100 Subject: [PATCH 1/2] [mlir][spirv] Allow disabling control flow structurization Currently some control flow patterns cannot be structurized into current SPIR-V MLIR constructs, e.g., conditional early exits (break). Since the support for early exit cannot be currently added (https://github.com/llvm/llvm-project/pull/138688#pullrequestreview-2830791677 in see #138688) this patch enables structurizer to be disabled to keep the control flow unstructurized. By default, the control flow is structurized. --- mlir/include/mlir/Target/SPIRV/Deserialization.h | 5 +++-- .../SPIRV/Deserialization/Deserialization.cpp | 7 ++++--- .../SPIRV/Deserialization/Deserializer.cpp | 16 ++++++++++++++-- .../Target/SPIRV/Deserialization/Deserializer.h | 8 +++++++- mlir/lib/Target/SPIRV/TranslateRegistration.cpp | 14 +++++++++++--- 5 files changed, 39 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Target/SPIRV/Deserialization.h b/mlir/include/mlir/Target/SPIRV/Deserialization.h index a346a7fd1e5f7..adff26298c4cc 100644 --- a/mlir/include/mlir/Target/SPIRV/Deserialization.h +++ b/mlir/include/mlir/Target/SPIRV/Deserialization.h @@ -27,8 +27,9 @@ class ModuleOp; /// in the given `context`. Returns the ModuleOp on success; otherwise, reports /// errors to the error handler registered with `context` and returns a null /// module. -OwningOpRef deserialize(ArrayRef binary, - MLIRContext *context); +OwningOpRef +deserialize(ArrayRef binary, MLIRContext *context, + bool enableControlFlowStructurization = true); } // namespace spirv } // namespace mlir diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp index 7bb8762660599..11a9de91060a8 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp @@ -12,9 +12,10 @@ using namespace mlir; -OwningOpRef spirv::deserialize(ArrayRef binary, - MLIRContext *context) { - Deserializer deserializer(binary, context); +OwningOpRef +spirv::deserialize(ArrayRef binary, MLIRContext *context, + bool enableControlFlowStructurization) { + Deserializer deserializer(binary, context, enableControlFlowStructurization); if (failed(deserializer.deserialize())) return nullptr; diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 7afd6e9b25b77..1aa647485496d 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -49,9 +49,11 @@ static inline bool isFnEntryBlock(Block *block) { //===----------------------------------------------------------------------===// spirv::Deserializer::Deserializer(ArrayRef binary, - MLIRContext *context) + MLIRContext *context, + bool enableControlFlowStructurization) : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), - module(createModuleOp()), opBuilder(module->getRegion()) + module(createModuleOp()), opBuilder(module->getRegion()), + enableControlFlowStructurization(enableControlFlowStructurization) #ifndef NDEBUG , logger(llvm::dbgs()) @@ -2361,6 +2363,16 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() { } LogicalResult spirv::Deserializer::structurizeControlFlow() { + if (!enableControlFlowStructurization) { + LLVM_DEBUG( + { + logger.startLine() + << "//----- [cf] skip structurizing control flow -----//\n"; + logger.indent(); + }); + return success(); + } + LLVM_DEBUG({ logger.startLine() << "//----- [cf] start structurizing control flow -----//\n"; diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index bcc78e3e6508d..c1df77ba7b647 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -121,7 +121,10 @@ class Deserializer { public: /// Creates a deserializer for the given SPIR-V `binary` module. /// The SPIR-V ModuleOp will be created into `context. - explicit Deserializer(ArrayRef binary, MLIRContext *context); + /// `enableControlFlowStructurization` is used to enable control flow + /// structurization. + explicit Deserializer(ArrayRef binary, MLIRContext *context, + bool enableControlFlowStructurization); /// Deserializes the remembered SPIR-V binary module. LogicalResult deserialize(); @@ -622,6 +625,9 @@ class Deserializer { /// A list of all structs which have unresolved member types. SmallVector deferredStructTypesInfos; + /// A flag to enable or disable structurizer + bool enableControlFlowStructurization; + #ifndef NDEBUG /// A logger used to emit information during the deserialzation process. llvm::ScopedPrinter logger; diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp index ff34f02d07b73..90e5b684607be 100644 --- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp +++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp @@ -37,7 +37,8 @@ using namespace mlir; // Deserializes the SPIR-V binary module stored in the file named as // `inputFilename` and returns a module containing the SPIR-V module. static OwningOpRef -deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context) { +deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context, + bool enableControlFlowStructurization) { context->loadDialect(); // Make sure the input stream can be treated as a stream of SPIR-V words @@ -51,17 +52,24 @@ deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context) { auto binary = llvm::ArrayRef(reinterpret_cast(start), size / sizeof(uint32_t)); - return spirv::deserialize(binary, context); + return spirv::deserialize(binary, context, enableControlFlowStructurization); } namespace mlir { void registerFromSPIRVTranslation() { + static llvm::cl::opt noControlFlowStructurization( + "spirv-no-control-flow-structurization", + llvm::cl::desc("Disable control flow structurization to enable " + "deserialization of early exits (see #138688)"), + llvm::cl::init(false)); + TranslateToMLIRRegistration fromBinary( "deserialize-spirv", "deserializes the SPIR-V module", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer"); return deserializeModule( - sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context); + sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context, + !noControlFlowStructurization); }); } } // namespace mlir From ccac21a04a1a39413546fc0e63dc3bf3e5fd25c0 Mon Sep 17 00:00:00 2001 From: Igor Wodiany Date: Tue, 27 May 2025 15:43:25 +0100 Subject: [PATCH 2/2] Use struct --- .../mlir/Target/SPIRV/Deserialization.h | 8 +++++++- .../SPIRV/Deserialization/Deserialization.cpp | 4 ++-- .../SPIRV/Deserialization/Deserializer.cpp | 7 +++---- .../SPIRV/Deserialization/Deserializer.h | 9 ++++----- .../lib/Target/SPIRV/TranslateRegistration.cpp | 18 ++++++++++-------- 5 files changed, 26 insertions(+), 20 deletions(-) diff --git a/mlir/include/mlir/Target/SPIRV/Deserialization.h b/mlir/include/mlir/Target/SPIRV/Deserialization.h index adff26298c4cc..68eb863b4a6f2 100644 --- a/mlir/include/mlir/Target/SPIRV/Deserialization.h +++ b/mlir/include/mlir/Target/SPIRV/Deserialization.h @@ -23,13 +23,19 @@ class MLIRContext; namespace spirv { class ModuleOp; +struct DeserializationOptions { + // Whether to structurize control flow into `spirv.mlir.selection` and + // `spirv.mlir.loop`. + bool enableControlFlowStructurization = true; +}; + /// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp /// in the given `context`. Returns the ModuleOp on success; otherwise, reports /// errors to the error handler registered with `context` and returns a null /// module. OwningOpRef deserialize(ArrayRef binary, MLIRContext *context, - bool enableControlFlowStructurization = true); + const DeserializationOptions &options = {}); } // namespace spirv } // namespace mlir diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp index 11a9de91060a8..b82c61cafc8a7 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp @@ -14,8 +14,8 @@ using namespace mlir; OwningOpRef spirv::deserialize(ArrayRef binary, MLIRContext *context, - bool enableControlFlowStructurization) { - Deserializer deserializer(binary, context, enableControlFlowStructurization); + const DeserializationOptions &options) { + Deserializer deserializer(binary, context, options); if (failed(deserializer.deserialize())) return nullptr; diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 1aa647485496d..a21d691ae5142 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -50,10 +50,9 @@ static inline bool isFnEntryBlock(Block *block) { spirv::Deserializer::Deserializer(ArrayRef binary, MLIRContext *context, - bool enableControlFlowStructurization) + const spirv::DeserializationOptions &options) : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), - module(createModuleOp()), opBuilder(module->getRegion()), - enableControlFlowStructurization(enableControlFlowStructurization) + module(createModuleOp()), opBuilder(module->getRegion()), options(options) #ifndef NDEBUG , logger(llvm::dbgs()) @@ -2363,7 +2362,7 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() { } LogicalResult spirv::Deserializer::structurizeControlFlow() { - if (!enableControlFlowStructurization) { + if (!options.enableControlFlowStructurization) { LLVM_DEBUG( { logger.startLine() diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index c1df77ba7b647..e4556e7652b17 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/IR/Builders.h" +#include "mlir/Target/SPIRV/Deserialization.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringRef.h" @@ -121,10 +122,8 @@ class Deserializer { public: /// Creates a deserializer for the given SPIR-V `binary` module. /// The SPIR-V ModuleOp will be created into `context. - /// `enableControlFlowStructurization` is used to enable control flow - /// structurization. explicit Deserializer(ArrayRef binary, MLIRContext *context, - bool enableControlFlowStructurization); + const DeserializationOptions &options); /// Deserializes the remembered SPIR-V binary module. LogicalResult deserialize(); @@ -625,8 +624,8 @@ class Deserializer { /// A list of all structs which have unresolved member types. SmallVector deferredStructTypesInfos; - /// A flag to enable or disable structurizer - bool enableControlFlowStructurization; + /// Deserialization options. + DeserializationOptions options; #ifndef NDEBUG /// A logger used to emit information during the deserialzation process. diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp index 90e5b684607be..682fff2784775 100644 --- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp +++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp @@ -38,7 +38,7 @@ using namespace mlir; // `inputFilename` and returns a module containing the SPIR-V module. static OwningOpRef deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context, - bool enableControlFlowStructurization) { + const spirv::DeserializationOptions &options) { context->loadDialect(); // Make sure the input stream can be treated as a stream of SPIR-V words @@ -52,16 +52,18 @@ deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context, auto binary = llvm::ArrayRef(reinterpret_cast(start), size / sizeof(uint32_t)); - return spirv::deserialize(binary, context, enableControlFlowStructurization); + return spirv::deserialize(binary, context, options); } namespace mlir { void registerFromSPIRVTranslation() { - static llvm::cl::opt noControlFlowStructurization( - "spirv-no-control-flow-structurization", - llvm::cl::desc("Disable control flow structurization to enable " - "deserialization of early exits (see #138688)"), - llvm::cl::init(false)); + static llvm::cl::opt enableControlFlowStructurization( + "spirv-structurize-control-flow", + llvm::cl::desc( + "Enable control flow structurization into `spirv.mlir.selection` and " + "`spirv.mlir.loop`. This may need to be disabled to support " + "deserialization of early exits (see #138688)"), + llvm::cl::init(true)); TranslateToMLIRRegistration fromBinary( "deserialize-spirv", "deserializes the SPIR-V module", @@ -69,7 +71,7 @@ void registerFromSPIRVTranslation() { assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer"); return deserializeModule( sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context, - !noControlFlowStructurization); + {enableControlFlowStructurization}); }); } } // namespace mlir