diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h index cda07d6a91364..4d399f2e0ed19 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -83,59 +83,6 @@ namespace acc { /// combined and the final mapping value would be 5 (4 | 1). enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4 }; -/// Used to obtain the `varPtr` from a data clause operation. -/// Returns empty value if not a data clause operation or is a data exit -/// operation with no `varPtr`. -mlir::Value getVarPtr(mlir::Operation *accDataClauseOp); - -/// Used to obtain the `accPtr` from a data clause operation. -/// When a data entry operation, it obtains its result `accPtr` value. -/// If a data exit operation, it obtains its operand `accPtr` value. -/// Returns empty value if not a data clause operation. -mlir::Value getAccPtr(mlir::Operation *accDataClauseOp); - -/// Used to obtain the `varPtrPtr` from a data clause operation. -/// Returns empty value if not a data clause operation. -mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp); - -/// Used to obtain `bounds` from an acc data clause operation. -/// Returns an empty vector if there are no bounds. -mlir::SmallVector getBounds(mlir::Operation *accDataClauseOp); - -/// Used to obtain `async` operands from an acc data clause operation. -/// Returns an empty vector if there are no such operands. -mlir::SmallVector -getAsyncOperands(mlir::Operation *accDataClauseOp); - -/// Returns an array of acc:DeviceTypeAttr attributes attached to -/// an acc data clause operation, that correspond to the device types -/// associated with the async clauses with an async-value. -mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp); - -/// Returns an array of acc:DeviceTypeAttr attributes attached to -/// an acc data clause operation, that correspond to the device types -/// associated with the async clauses without an async-value. -mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp); - -/// Used to obtain the `name` from an acc operation. -std::optional getVarName(mlir::Operation *accOp); - -/// Used to obtain the `dataClause` from a data entry operation. -/// Returns empty optional if not a data entry operation. -std::optional -getDataClause(mlir::Operation *accDataEntryOp); - -/// Used to find out whether data operation is implicit. -/// Returns false if not a data operation or if it is a data operation without -/// implicit flag. -bool getImplicitFlag(mlir::Operation *accDataEntryOp); - -/// Used to get an immutable range iterating over the data operands. -mlir::ValueRange getDataOperands(mlir::Operation *accOp); - -/// Used to get a mutable range iterating over the data operands. -mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp); - /// Used to obtain the attribute name for declare. static constexpr StringLiteral getDeclareAttrName() { return StringLiteral("acc.declare"); diff --git a/mlir/include/mlir/Dialect/OpenACC/Utils/OpenACCUtils.h b/mlir/include/mlir/Dialect/OpenACC/Utils/OpenACCUtils.h new file mode 100644 index 0000000000000..2c84381ff8331 --- /dev/null +++ b/mlir/include/mlir/Dialect/OpenACC/Utils/OpenACCUtils.h @@ -0,0 +1,116 @@ +//===- OpenACCUtils.h - OpenACC Utilities -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_OPENACC_UTILS_OPENACCUTILS_H_ +#define MLIR_DIALECT_OPENACC_UTILS_OPENACCUTILS_H_ + +#include "mlir/Dialect/OpenACC/OpenACC.h" + +namespace mlir { +namespace acc { +/// Used to obtain the `varPtr` from a data clause operation. +/// Returns empty value if not a data clause operation or is a data exit +/// operation with no `varPtr`. +mlir::Value getVarPtr(mlir::Operation *accDataClauseOp); + +/// Used to set the `varPtr` of a data clause operation. +/// Returns true if it was set successfully and false if this is not a data +/// clause operation. +bool setVarPtr(mlir::Operation *accDataClauseOp, mlir::Value varPtr); + +/// Used to obtain the `accPtr` from a data clause operation. +/// When a data entry operation, it obtains its result `accPtr` value. +/// If a data exit operation, it obtains its operand `accPtr` value. +/// Returns empty value if not a data clause operation. +mlir::Value getAccPtr(mlir::Operation *accDataClauseOp); + +/// Used to set the `accPtr` for a data exit operation. +/// Returns true if it was set successfully and false if is not a data exit +/// operation (data entry operations have their result as `accPtr` which +/// cannot be changed). +bool setAccPtr(mlir::Operation *accDataClauseOp, mlir::Value accPtr); + +/// Used to obtain the `varPtrPtr` from a data clause operation. +/// Returns empty value if not a data clause operation. +mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp); + +/// Used to set the `varPtrPtr` for a data clause operation. +/// Returns false if the operation does not have varPtrPtr or is not a data +/// clause op. +bool setVarPtrPtr(mlir::Operation *accDataClauseOp, mlir::Value varPtrPtr); + +/// Used to obtain `bounds` from an acc data clause operation. +/// Returns an empty vector if there are no bounds. +mlir::SmallVector getBounds(mlir::Operation *accDataClauseOp); + +/// Used to set `bounds` for an acc data clause operation. It completely +/// replaces all bounds operands with the new list. +/// Returns false if new bounds were not set (such as when argument is not +/// an acc data clause operation). +bool setBounds(mlir::Operation *accDataClauseOp, + mlir::SmallVector &bounds); +bool setBounds(mlir::Operation *accDataClauseOp, mlir::Value bound); + +/// Used to obtain the `dataClause` from a data clause operation. +/// Returns empty optional if not a data operation. +std::optional +getDataClause(mlir::Operation *accDataClauseOp); + +/// Used to set the `dataClause` on a data clause operation. +/// Returns true if successfully set and false otherwise. +bool setDataClause(mlir::Operation *accDataClauseOp, + mlir::acc::DataClause dataClause); + +/// Used to find out whether this data operation uses structured runtime +/// counters. Returns false if not a data operation or if it is a data operation +/// without the structured flag set. +bool getStructuredFlag(mlir::Operation *accDataClauseOp); + +/// Used to update the data clause operation whether it represents structured +/// or dynamic (value of `structured` is passed as false). +/// Returns true if successfully set and false otherwise. +bool setStructuredFlag(mlir::Operation *accDataClauseOp, bool structured); + +/// Used to find out whether data operation is implicit. +/// Returns false if not a data operation or if it is a data operation without +/// implicit flag. +bool getImplicitFlag(mlir::Operation *accDataClauseOp); + +/// Used to update the data clause operation whether this operation is +/// implicit or explicit (`implicit` set as false). +/// Returns true if successfully set and false otherwise. +bool setImplicitFlag(mlir::Operation *accDataClauseOp, bool implicit); + +/// Used to obtain the `name` from an acc operation. +std::optional getVarName(mlir::Operation *accDataClauseOp); + +/// Used to obtain `async` operands from an acc data clause operation. +/// Returns an empty vector if there are no such operands. +mlir::SmallVector +getAsyncOperands(mlir::Operation *accDataClauseOp); + +/// Returns an array of acc:DeviceTypeAttr attributes attached to +/// an acc data clause operation, that correspond to the device types +/// associated with the async clauses with an async-value. +mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp); + +/// Returns an array of acc:DeviceTypeAttr attributes attached to +/// an acc data clause operation, that correspond to the device types +/// associated with the async clauses without an async-value. +mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp); + +/// Used to get an immutable range iterating over the data operands. +mlir::ValueRange getDataOperands(mlir::Operation *accOp); + +/// Used to get a mutable range iterating over the data operands. +mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp); + +} // namespace acc +} // namespace mlir + +#endif // MLIR_DIALECT_OPENACC_UTILS_OPENACCUTILS_H_ diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt index 9f57627c321fb..31167e6af908b 100644 --- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 280260e0485bb..1cc67629f9741 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -2379,11 +2379,21 @@ checkDeclareOperands(Op &op, const mlir::ValueRange &operands, "expect valid declare data entry operation or acc.getdeviceptr " "as defining op"); - mlir::Value varPtr{getVarPtr(operand.getDefiningOp())}; + mlir::Value varPtr{ + llvm::TypeSwitch( + operand.getDefiningOp()) + .Case( + [&](auto entry) { return entry.getVarPtr(); }) + .Default([&](mlir::Operation *) { return mlir::Value(); })}; assert(varPtr && "declare operands can only be data entry operations which " "must have varPtr"); std::optional dataClauseOptional{ - getDataClause(operand.getDefiningOp())}; + llvm::TypeSwitch>( + operand.getDefiningOp()) + .Case( + [&](auto entry) { return entry.getDataClause(); }) + .Default([&](mlir::Operation *) { return std::nullopt; })}; assert(dataClauseOptional.has_value() && "declare operands can only be data entry operations which must have " "dataClause"); @@ -2409,8 +2419,13 @@ checkDeclareOperands(Op &op, const mlir::ValueRange &operands, // since implicit data action may be inserted to do actions like updating // device copy, in which case the variable is not necessarily implicitly // declare'd. + bool operandOpImplicitFlag{ + llvm::TypeSwitch(operand.getDefiningOp()) + .Case( + [&](auto entry) { return entry.getImplicit(); }) + .Default([&](mlir::Operation *) { return false; })}; if (declAttr.getImplicit() && - declAttr.getImplicit() != acc::getImplicitFlag(operand.getDefiningOp())) + declAttr.getImplicit() != operandOpImplicitFlag) return op.emitError( "implicitness must match between declare op and flag on variable"); } @@ -2868,127 +2883,3 @@ LogicalResult acc::WaitOp::verify() { #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc" - -//===----------------------------------------------------------------------===// -// acc dialect utilities -//===----------------------------------------------------------------------===// - -mlir::Value mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) { - auto varPtr{llvm::TypeSwitch(accDataClauseOp) - .Case( - [&](auto entry) { return entry.getVarPtr(); }) - .Case( - [&](auto exit) { return exit.getVarPtr(); }) - .Default([&](mlir::Operation *) { return mlir::Value(); })}; - return varPtr; -} - -mlir::Value mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) { - auto accPtr{llvm::TypeSwitch(accDataClauseOp) - .Case( - [&](auto dataClause) { return dataClause.getAccPtr(); }) - .Default([&](mlir::Operation *) { return mlir::Value(); })}; - return accPtr; -} - -mlir::Value mlir::acc::getVarPtrPtr(mlir::Operation *accDataClauseOp) { - auto varPtrPtr{ - llvm::TypeSwitch(accDataClauseOp) - .Case( - [&](auto dataClause) { return dataClause.getVarPtrPtr(); }) - .Default([&](mlir::Operation *) { return mlir::Value(); })}; - return varPtrPtr; -} - -mlir::SmallVector -mlir::acc::getBounds(mlir::Operation *accDataClauseOp) { - mlir::SmallVector bounds{ - llvm::TypeSwitch>( - accDataClauseOp) - .Case([&](auto dataClause) { - return mlir::SmallVector( - dataClause.getBounds().begin(), dataClause.getBounds().end()); - }) - .Default([&](mlir::Operation *) { - return mlir::SmallVector(); - })}; - return bounds; -} - -mlir::SmallVector -mlir::acc::getAsyncOperands(mlir::Operation *accDataClauseOp) { - return llvm::TypeSwitch>( - accDataClauseOp) - .Case([&](auto dataClause) { - return mlir::SmallVector( - dataClause.getAsyncOperands().begin(), - dataClause.getAsyncOperands().end()); - }) - .Default([&](mlir::Operation *) { - return mlir::SmallVector(); - }); -} - -mlir::ArrayAttr -mlir::acc::getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp) { - return llvm::TypeSwitch(accDataClauseOp) - .Case([&](auto dataClause) { - return dataClause.getAsyncOperandsDeviceTypeAttr(); - }) - .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; }); -} - -mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) { - return llvm::TypeSwitch(accDataClauseOp) - .Case( - [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); }) - .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; }); -} - -std::optional mlir::acc::getVarName(mlir::Operation *accOp) { - auto name{ - llvm::TypeSwitch>(accOp) - .Case([&](auto entry) { return entry.getName(); }) - .Default([&](mlir::Operation *) -> std::optional { - return {}; - })}; - return name; -} - -std::optional -mlir::acc::getDataClause(mlir::Operation *accDataEntryOp) { - auto dataClause{ - llvm::TypeSwitch>( - accDataEntryOp) - .Case( - [&](auto entry) { return entry.getDataClause(); }) - .Default([&](mlir::Operation *) { return std::nullopt; })}; - return dataClause; -} - -bool mlir::acc::getImplicitFlag(mlir::Operation *accDataEntryOp) { - auto implicit{llvm::TypeSwitch(accDataEntryOp) - .Case( - [&](auto entry) { return entry.getImplicit(); }) - .Default([&](mlir::Operation *) { return false; })}; - return implicit; -} - -mlir::ValueRange mlir::acc::getDataOperands(mlir::Operation *accOp) { - auto dataOperands{ - llvm::TypeSwitch(accOp) - .Case( - [&](auto entry) { return entry.getDataClauseOperands(); }) - .Default([&](mlir::Operation *) { return mlir::ValueRange(); })}; - return dataOperands; -} - -mlir::MutableOperandRange -mlir::acc::getMutableDataOperands(mlir::Operation *accOp) { - auto dataOperands{ - llvm::TypeSwitch(accOp) - .Case( - [&](auto entry) { return entry.getDataClauseOperandsMutable(); }) - .Default([&](mlir::Operation *) { return nullptr; })}; - return dataOperands; -} diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt index 7d934956089a5..42fcc138c5ffb 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIROpenACCTransforms LINK_LIBS PUBLIC MLIROpenACCDialect + MLIROpenACCUtils MLIRFuncDialect MLIRIR MLIRPass diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp index 026b309ce4969..897c2dabd88ea 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp +++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp @@ -6,12 +6,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/OpenACC/Transforms/Passes.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" +#include "mlir/Dialect/OpenACC/Utils/OpenACCUtils.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/RegionUtils.h" #include "llvm/Support/ErrorHandling.h" namespace mlir { diff --git a/mlir/lib/Dialect/OpenACC/Utils/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Utils/CMakeLists.txt new file mode 100644 index 0000000000000..83a843b65168b --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Utils/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_dialect_library(MLIROpenACCUtils + OpenACCUtils.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC + + DEPENDS + MLIROpenACCPassIncGen + MLIROpenACCOpsIncGen + MLIROpenACCEnumsIncGen + MLIROpenACCAttributesIncGen + MLIROpenACCMPOpsInterfacesIncGen + MLIROpenACCOpsInterfacesIncGen + MLIROpenACCTypeInterfacesIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIROpenACCDialect + MLIRSupport +) diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp new file mode 100644 index 0000000000000..2fa4d70fd1f3f --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -0,0 +1,240 @@ +//===- OpenACCUtils.cpp ---------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Utils/OpenACCUtils.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "llvm/ADT/TypeSwitch.h" + +mlir::Value mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) { + auto varPtr{llvm::TypeSwitch(accDataClauseOp) + .Case([&](auto dataClauseOp) { + return dataClauseOp.getVarPtr(); + }) + .Default([&](mlir::Operation *) { return mlir::Value(); })}; + return varPtr; +} + +bool mlir::acc::setVarPtr(mlir::Operation *accDataClauseOp, + mlir::Value varPtr) { + bool res{llvm::TypeSwitch(accDataClauseOp) + .Case([&](auto dataClauseOp) { + dataClauseOp.getVarPtrMutable().assign(varPtr); + return true; + }) + .Default([&](mlir::Operation *) { return false; })}; + return res; +} + +mlir::Value mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) { + auto accPtr{ + llvm::TypeSwitch(accDataClauseOp) + .Case( + [&](auto dataClauseOp) { return dataClauseOp.getAccPtr(); }) + .Default([&](mlir::Operation *) { return mlir::Value(); })}; + return accPtr; +} + +bool mlir::acc::setAccPtr(mlir::Operation *accDataClauseOp, + mlir::Value accPtr) { + bool res{llvm::TypeSwitch(accDataClauseOp) + .Case([&](auto dataClauseOp) { + // Cannot set the result of an existing operation and + // data entry ops produce `accPtr` as a result. + return false; + }) + .Case([&](auto dataClauseOp) { + dataClauseOp.getAccPtrMutable().assign(accPtr); + return true; + }) + .Default([&](mlir::Operation *) { return false; })}; + return res; +} + +mlir::Value mlir::acc::getVarPtrPtr(mlir::Operation *accDataClauseOp) { + auto varPtrPtr{ + llvm::TypeSwitch(accDataClauseOp) + .Case( + [&](auto dataClauseOp) { return dataClauseOp.getVarPtrPtr(); }) + .Default([&](mlir::Operation *) { return mlir::Value(); })}; + return varPtrPtr; +} + +bool mlir::acc::setVarPtrPtr(mlir::Operation *accDataClauseOp, + mlir::Value varPtrPtr) { + bool res{llvm::TypeSwitch(accDataClauseOp) + .Case([&](auto dataClauseOp) { + dataClauseOp.getVarPtrPtrMutable().assign(varPtrPtr); + return true; + }) + .Default([&](mlir::Operation *) { return false; })}; + return res; +} + +mlir::SmallVector +mlir::acc::getBounds(mlir::Operation *accDataClauseOp) { + mlir::SmallVector bounds{ + llvm::TypeSwitch>( + accDataClauseOp) + .Case([&](auto dataClauseOp) { + return mlir::SmallVector( + dataClauseOp.getBounds().begin(), + dataClauseOp.getBounds().end()); + }) + .Default([&](mlir::Operation *) { + return mlir::SmallVector(); + })}; + return bounds; +} + +bool mlir::acc::setBounds(mlir::Operation *accDataClauseOp, + mlir::SmallVector &bounds) { + bool res{ + llvm::TypeSwitch(accDataClauseOp) + .Case([&](auto dataClauseOp) { + dataClauseOp.getBoundsMutable().assign(bounds); + return true; + }) + .Default([&](mlir::Operation *) { return false; })}; + return res; +} + +bool mlir::acc::setBounds(mlir::Operation *accDataClauseOp, mlir::Value bound) { + mlir::SmallVector bounds({bound}); + return setBounds(accDataClauseOp, bounds); +} + +std::optional +mlir::acc::getVarName(mlir::Operation *accDataClauseOp) { + auto name{ + llvm::TypeSwitch>( + accDataClauseOp) + .Case( + [&](auto dataClauseOp) { return dataClauseOp.getName(); }) + .Default([&](mlir::Operation *) -> std::optional { + return {}; + })}; + return name; +} + +std::optional +mlir::acc::getDataClause(mlir::Operation *accDataClauseOp) { + auto dataClause{ + llvm::TypeSwitch>( + accDataClauseOp) + .Case( + [&](auto dataClauseOp) { return dataClauseOp.getDataClause(); }) + .Default([&](mlir::Operation *) { return std::nullopt; })}; + return dataClause; +} + +bool mlir::acc::setDataClause(mlir::Operation *accDataClauseOp, + mlir::acc::DataClause dataClause) { + bool res{ + llvm::TypeSwitch(accDataClauseOp) + .Case([&](auto dataClauseOp) { + dataClauseOp.setDataClause(dataClause); + return true; + }) + .Default([&](mlir::Operation *) { return false; })}; + return res; +} + +bool mlir::acc::getStructuredFlag(mlir::Operation *accDataClauseOp) { + auto structured{ + llvm::TypeSwitch(accDataClauseOp) + .Case( + [&](auto dataClauseOp) { return dataClauseOp.getStructured(); }) + .Default([&](mlir::Operation *) { return false; })}; + return structured; +} + +bool mlir::acc::setStructuredFlag(mlir::Operation *accDataClauseOp, + bool structured) { + auto res{ + llvm::TypeSwitch(accDataClauseOp) + .Case([&](auto dataClauseOp) { + dataClauseOp.setStructured(structured); + return true; + }) + .Default([&](mlir::Operation *) { return false; })}; + return res; +} + +bool mlir::acc::getImplicitFlag(mlir::Operation *accDataClauseOp) { + auto implicit{ + llvm::TypeSwitch(accDataClauseOp) + .Case( + [&](auto dataClauseOp) { return dataClauseOp.getImplicit(); }) + .Default([&](mlir::Operation *) { return false; })}; + return implicit; +} + +bool mlir::acc::setImplicitFlag(mlir::Operation *accDataClauseOp, + bool implicit) { + auto res{ + llvm::TypeSwitch(accDataClauseOp) + .Case([&](auto dataClauseOp) { + dataClauseOp.setImplicit(implicit); + return true; + }) + .Default([&](mlir::Operation *) { return false; })}; + return res; +} + +mlir::SmallVector +mlir::acc::getAsyncOperands(mlir::Operation *accDataClauseOp) { + return llvm::TypeSwitch>( + accDataClauseOp) + .Case([&](auto dataClauseOp) { + return mlir::SmallVector( + dataClauseOp.getAsyncOperands().begin(), + dataClauseOp.getAsyncOperands().end()); + }) + .Default([&](mlir::Operation *) { + return mlir::SmallVector(); + }); +} + +mlir::ArrayAttr +mlir::acc::getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp) { + return llvm::TypeSwitch(accDataClauseOp) + .Case([&](auto dataClauseOp) { + return dataClauseOp.getAsyncOperandsDeviceTypeAttr(); + }) + .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; }); +} + +mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) { + return llvm::TypeSwitch(accDataClauseOp) + .Case( + [&](auto dataClauseOp) { return dataClauseOp.getAsyncOnlyAttr(); }) + .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; }); +} + +mlir::ValueRange mlir::acc::getDataOperands(mlir::Operation *accOp) { + auto dataOperands{ + llvm::TypeSwitch(accOp) + .Case([&](auto accConstructOp) { + return accConstructOp.getDataClauseOperands(); + }) + .Default([&](mlir::Operation *) { return mlir::ValueRange(); })}; + return dataOperands; +} + +mlir::MutableOperandRange +mlir::acc::getMutableDataOperands(mlir::Operation *accOp) { + auto dataOperands{ + llvm::TypeSwitch(accOp) + .Case([&](auto accConstructOp) { + return accConstructOp.getDataClauseOperandsMutable(); + }) + .Default([&](mlir::Operation *) { return nullptr; })}; + return dataOperands; +} diff --git a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt index 5133d7fc38296..c1ba546f7d069 100644 --- a/mlir/unittests/Dialect/OpenACC/CMakeLists.txt +++ b/mlir/unittests/Dialect/OpenACC/CMakeLists.txt @@ -1,8 +1,10 @@ add_mlir_unittest(MLIROpenACCTests OpenACCOpsTest.cpp + OpenACCUtilsTest.cpp ) target_link_libraries(MLIROpenACCTests PRIVATE MLIRIR MLIROpenACCDialect + MLIROpenACCUtils ) diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp new file mode 100644 index 0000000000000..2f3be7c9106a7 --- /dev/null +++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp @@ -0,0 +1,457 @@ +//===- OpenACCUtilsTest.cpp - Unit tests for OpenACC utils ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Utils/OpenACCUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "llvm/ADT/SmallVector.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::acc; + +class OpenACCUtilsTest : public ::testing::Test { +protected: + OpenACCUtilsTest() : b(&context), loc(UnknownLoc::get(&context)) { + context.loadDialect(); + } + + MLIRContext context; + OpBuilder b; + Location loc; +}; + +template +void testDataOpVarPtr(OpBuilder &b, MLIRContext &context, Location loc) { + auto memrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + OwningOpRef accPtrOp = + b.create(loc, varPtrOp->getResult(), true, true); + auto memrefTy2 = MemRefType::get({}, b.getF64Type()); + OwningOpRef varPtrOp2 = + b.create(loc, memrefTy2); + + OwningOpRef op; + if constexpr (std::is_same() || + std::is_same()) { + op = b.create(loc, /*accPtr=*/accPtrOp->getResult(), + varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } else { + op = b.create(loc, varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } + EXPECT_EQ(varPtrOp->getResult(), getVarPtr(op.get())); + EXPECT_EQ(op->getVarPtr(), getVarPtr(op.get())); + setVarPtr(op.get(), varPtrOp2->getResult()); + EXPECT_EQ(varPtrOp2->getResult(), getVarPtr(op.get())); + EXPECT_EQ(op->getVarPtr(), getVarPtr(op.get())); +} + +TEST_F(OpenACCUtilsTest, dataOpVarPtr) { + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); +} + +template +void testDataOpAccPtr(OpBuilder &b, MLIRContext &context, Location loc) { + auto memrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + OwningOpRef accPtrOp = + b.create(loc, varPtrOp->getResult(), true, true); + auto memrefTy2 = MemRefType::get({}, b.getF64Type()); + OwningOpRef varPtrOp2 = + b.create(loc, memrefTy2); + OwningOpRef accPtrOp2 = + b.create(loc, varPtrOp2->getResult(), true, true); + + OwningOpRef op; + if constexpr (std::is_same() || + std::is_same()) { + op = b.create(loc, /*accPtr=*/accPtrOp->getResult(), + varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + EXPECT_EQ(op->getAccPtr(), getAccPtr(op.get())); + EXPECT_EQ(op->getAccPtr(), accPtrOp->getResult()); + setAccPtr(op.get(), accPtrOp2->getResult()); + EXPECT_EQ(op->getAccPtr(), getAccPtr(op.get())); + EXPECT_EQ(op->getAccPtr(), accPtrOp2->getResult()); + } else if constexpr (std::is_same() || + std::is_same()) { + op = b.create(loc, /*accPtr=*/accPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + EXPECT_EQ(op->getAccPtr(), getAccPtr(op.get())); + EXPECT_EQ(op->getAccPtr(), accPtrOp->getResult()); + setAccPtr(op.get(), accPtrOp2->getResult()); + EXPECT_EQ(op->getAccPtr(), getAccPtr(op.get())); + EXPECT_EQ(op->getAccPtr(), accPtrOp2->getResult()); + } else { + op = b.create(loc, varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + EXPECT_EQ(op->getAccPtr(), op->getResult()); + EXPECT_EQ(op->getAccPtr(), getAccPtr(op.get())); + } +} + +TEST_F(OpenACCUtilsTest, dataOpAccPtr) { + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); + testDataOpAccPtr(b, context, loc); +} + +template +void testDataOpVarPtrPtr(OpBuilder &b, MLIRContext &context, Location loc) { + auto memrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + + auto memrefTy2 = MemRefType::get({}, memrefTy); + OwningOpRef varPtrPtr = + b.create(loc, memrefTy2); + + OwningOpRef op = b.create(loc, varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + + EXPECT_EQ(op->getVarPtrPtr(), getVarPtrPtr(op.get())); + EXPECT_EQ(op->getVarPtrPtr(), Value()); + setVarPtrPtr(op.get(), varPtrPtr->getResult()); + EXPECT_EQ(op->getVarPtrPtr(), getVarPtrPtr(op.get())); + EXPECT_EQ(op->getVarPtrPtr(), varPtrPtr->getResult()); +} + +TEST_F(OpenACCUtilsTest, dataOpVarPtrPtr) { + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); + testDataOpVarPtr(b, context, loc); +} + +template +void testDataOpBounds(OpBuilder &b, MLIRContext &context, Location loc) { + auto memrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + OwningOpRef accPtrOp = + b.create(loc, varPtrOp->getResult(), true, true); + OwningOpRef extent = + b.create(loc, 1); + OwningOpRef bounds = + b.create(loc, extent->getResult()); + + OwningOpRef op; + if constexpr (std::is_same() || + std::is_same()) { + op = b.create(loc, /*accPtr=*/accPtrOp->getResult(), + varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } else if constexpr (std::is_same() || + std::is_same()) { + op = b.create(loc, /*accPtr=*/accPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } else { + op = b.create(loc, varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } + + EXPECT_EQ(op->getBounds().size(), getBounds(op.get()).size()); + for (auto [bound1, bound2] : + llvm::zip(op->getBounds(), getBounds(op.get()))) { + EXPECT_EQ(bound1, bound2); + } + setBounds(op.get(), bounds->getResult()); + EXPECT_EQ(op->getBounds().size(), getBounds(op.get()).size()); + for (auto [bound1, bound2] : + llvm::zip(op->getBounds(), getBounds(op.get()))) { + EXPECT_EQ(bound1, bound2); + EXPECT_EQ(bound1, bounds->getResult()); + } +} + +TEST_F(OpenACCUtilsTest, dataOpBounds) { + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); + testDataOpBounds(b, context, loc); +} + +template +void testDataOpName(OpBuilder &b, MLIRContext &context, Location loc) { + auto memrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + OwningOpRef accPtrOp = + b.create(loc, varPtrOp->getResult(), true, true); + + OwningOpRef op; + if constexpr (std::is_same() || + std::is_same()) { + op = b.create(loc, /*accPtr=*/accPtrOp->getResult(), + varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true, "varName"); + } else if constexpr (std::is_same() || + std::is_same()) { + op = b.create(loc, /*accPtr=*/accPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true, "varName"); + } else { + op = b.create(loc, varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true, "varName"); + } + + EXPECT_EQ(op->getNameAttr().str(), "varName"); + EXPECT_EQ(getVarName(op.get()), "varName"); +} + +TEST_F(OpenACCUtilsTest, dataOpName) { + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); + testDataOpName(b, context, loc); +} + +template +void testDataOpStructured(OpBuilder &b, MLIRContext &context, Location loc) { + auto memrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + OwningOpRef accPtrOp = + b.create(loc, varPtrOp->getResult(), true, true); + + OwningOpRef op; + if constexpr (std::is_same() || + std::is_same()) { + op = b.create(loc, /*accPtr=*/accPtrOp->getResult(), + varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } else if constexpr (std::is_same() || + std::is_same()) { + op = b.create(loc, /*accPtr=*/accPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } else { + op = b.create(loc, varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } + + EXPECT_EQ(op->getStructured(), getStructuredFlag(op.get())); + EXPECT_EQ(op->getStructured(), true); + setStructuredFlag(op.get(), false); + EXPECT_EQ(op->getStructured(), getStructuredFlag(op.get())); + EXPECT_EQ(op->getStructured(), false); +} + +TEST_F(OpenACCUtilsTest, dataOpStructured) { + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); + testDataOpStructured(b, context, loc); +} + +template +void testDataOpImplicit(OpBuilder &b, MLIRContext &context, Location loc) { + auto memrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + OwningOpRef accPtrOp = + b.create(loc, varPtrOp->getResult(), true, true); + + OwningOpRef op; + if constexpr (std::is_same() || + std::is_same()) { + op = b.create(loc, /*accPtr=*/accPtrOp->getResult(), + varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } else if constexpr (std::is_same() || + std::is_same()) { + op = b.create(loc, /*accPtr=*/accPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } else { + op = b.create(loc, varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } + + EXPECT_EQ(op->getImplicit(), getImplicitFlag(op.get())); + EXPECT_EQ(op->getImplicit(), true); + setImplicitFlag(op.get(), false); + EXPECT_EQ(op->getImplicit(), getImplicitFlag(op.get())); + EXPECT_EQ(op->getImplicit(), false); +} + +TEST_F(OpenACCUtilsTest, dataOpImplicit) { + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); + testDataOpImplicit(b, context, loc); +} + +template +void testDataOpDataClause(OpBuilder &b, MLIRContext &context, Location loc, + DataClause dataClause) { + auto memrefTy = MemRefType::get({}, b.getI32Type()); + OwningOpRef varPtrOp = + b.create(loc, memrefTy); + OwningOpRef accPtrOp = + b.create(loc, varPtrOp->getResult(), true, true); + + OwningOpRef op; + if constexpr (std::is_same() || + std::is_same()) { + op = b.create(loc, /*accPtr=*/accPtrOp->getResult(), + varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } else if constexpr (std::is_same() || + std::is_same()) { + op = b.create(loc, /*accPtr=*/accPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } else { + op = b.create(loc, varPtrOp->getResult(), + /*structured=*/true, /*implicit=*/true); + } + + EXPECT_EQ(op->getDataClause(), getDataClause(op.get()).value()); + EXPECT_EQ(op->getDataClause(), dataClause); + setDataClause(op.get(), DataClause::acc_getdeviceptr); + EXPECT_EQ(op->getDataClause(), getDataClause(op.get()).value()); + EXPECT_EQ(op->getDataClause(), DataClause::acc_getdeviceptr); +} + +TEST_F(OpenACCUtilsTest, dataOpDataClause) { + testDataOpDataClause(b, context, loc, DataClause::acc_private); + testDataOpDataClause(b, context, loc, + DataClause::acc_firstprivate); + testDataOpDataClause(b, context, loc, DataClause::acc_reduction); + testDataOpDataClause(b, context, loc, DataClause::acc_deviceptr); + testDataOpDataClause(b, context, loc, DataClause::acc_present); + testDataOpDataClause(b, context, loc, DataClause::acc_copyin); + testDataOpDataClause(b, context, loc, DataClause::acc_create); + testDataOpDataClause(b, context, loc, DataClause::acc_no_create); + testDataOpDataClause(b, context, loc, DataClause::acc_attach); + testDataOpDataClause(b, context, loc, + DataClause::acc_getdeviceptr); + testDataOpDataClause(b, context, loc, + DataClause::acc_update_device); + testDataOpDataClause(b, context, loc, + DataClause::acc_use_device); + testDataOpDataClause( + b, context, loc, DataClause::acc_declare_device_resident); + testDataOpDataClause(b, context, loc, + DataClause::acc_declare_link); + testDataOpDataClause(b, context, loc, DataClause::acc_cache); + testDataOpDataClause(b, context, loc, DataClause::acc_copyout); + testDataOpDataClause(b, context, loc, + DataClause::acc_update_host); + testDataOpDataClause(b, context, loc, DataClause::acc_delete); + testDataOpDataClause(b, context, loc, DataClause::acc_detach); +}