Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
address comments
Browse files Browse the repository at this point in the history
newling committed Jan 28, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent ca75e91 commit 09b0e9c
Showing 4 changed files with 45 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -34,25 +34,6 @@

using namespace xilinx;

namespace {

using namespace mlir;

std::string arrString(ArrayRef<int64_t> vs) {
return std::string("[")
.append(llvm::join(
llvm::map_range(vs, [](int64_t v) { return std::to_string(v); }),
","))
.append("]");
}

std::string intOpFoldResultsString(ArrayRef<OpFoldResult> ofrs) {
auto maybeValues = mlir::getConstantIntValues(ofrs);
if (maybeValues.has_value()) return "[not all constant]";
return arrString(maybeValues.value());
}
} // namespace

namespace mlir::iree_compiler::AMDAIE {

/// Compute the 'global' repetition count: the product over all dimensions with
@@ -114,13 +95,8 @@ static FailureOr<size_t> getRepetitionCount(LogicalObjFifoOpInterface op) {

// merge the repetition counts:
if (repetitionCounts.empty()) return 1;
size_t combinedRepetitionCount = 1;

// Sort and unique-ify the repetition counts:
std::sort(repetitionCounts.begin(), repetitionCounts.end());
auto last = std::unique(repetitionCounts.begin(), repetitionCounts.end());
repetitionCounts.erase(last, repetitionCounts.end());
combinedRepetitionCount = repetitionCounts[0];
int64_t combinedRepetitionCount =
*std::min_element(repetitionCounts.begin(), repetitionCounts.end());

// if any of the repetition counts are not divisible by the combined
// repetition count, that's a problem:
@@ -130,7 +106,7 @@ static FailureOr<size_t> getRepetitionCount(LogicalObjFifoOpInterface op) {
return op.emitOpError()
<< " could not resolved a common repetition count based on the "
"individual repetition counts: "
<< arrString(repetitionCounts);
<< getArrayString<int64_t>(repetitionCounts);
}
return combinedRepetitionCount;
}
@@ -366,8 +342,8 @@ LogicalResult AIEDeviceBuilder::foldDimsAndReturnAsStatic(
if (failed(foldRepetitionCount(rewriter.getContext(), sizes, strides,
repetitionCount))) {
return emitError() << "could not fold repetition counts from sizes: "
<< intOpFoldResultsString(sizes)
<< " strides: " << intOpFoldResultsString(strides)
<< getConstantIntValuesString(sizes)
<< " strides: " << getConstantIntValuesString(strides)
<< " repetitionCount: " << repetitionCount << ".";
}
SmallVector<OpFoldResult> offsets(
Original file line number Diff line number Diff line change
@@ -13,6 +13,13 @@

namespace mlir::iree_compiler::AMDAIE {

std::string getConstantIntValuesString(ArrayRef<OpFoldResult> ofrs) {
auto maybeValues = mlir::getConstantIntValues(ofrs);
if (maybeValues.has_value())
return getArrayString<int64_t>(maybeValues.value());
return "[not all constant integers]";
}

template <typename T>
std::optional<T> getConfigAttr(IREE::HAL::ExecutableTargetAttr targetAttr,
StringRef name) {
Original file line number Diff line number Diff line change
@@ -91,6 +91,20 @@ int findLargestFactor(int num, int max, int multiple);

} // namespace detail

/// Convert an array into a string, for example "[1,2,3]".
template <typename T>
std::string getArrayString(ArrayRef<T> vs) {
return std::string("[")
.append(llvm::join(
llvm::map_range(vs, [](T v) { return std::to_string(v); }), ","))
.append("]");
}

/// If all values in `opFoldResults` are constant, return a string
/// representation of the constant values. Otherwise, return
/// "[not constant integers]".
std::string getConstantIntValuesString(ArrayRef<OpFoldResult> opFoldResults);

} // namespace mlir::iree_compiler::AMDAIE

#endif
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "gtest/gtest.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

namespace {

@@ -42,6 +43,24 @@ TEST(FindLargestFactorTest, Test0) {
detail::findLargestFactor(firstPrimeAbove1e5, firstPrimeAbove1e5 - 1), 1);
}

TEST(OpFoldResultPrinting, Test0) {
mlir::MLIRContext context;
llvm::SmallVector<mlir::OpFoldResult> opFoldResults = {};
EXPECT_EQ(getConstantIntValuesString(opFoldResults), "[]");

mlir::OpFoldResult three = getAsIndexOpFoldResult(&context, 3);
opFoldResults.push_back(three);
EXPECT_EQ(getConstantIntValuesString(opFoldResults), "[3]");

mlir::OpFoldResult four = getAsIndexOpFoldResult(&context, 4);
opFoldResults.push_back(four);
EXPECT_EQ(getConstantIntValuesString(opFoldResults), "[3,4]");

opFoldResults.push_back(mlir::Value{});
EXPECT_EQ(getConstantIntValuesString(opFoldResults),
"[not all constant integers]");
}

} // namespace

int main(int argc, char **argv) {

0 comments on commit 09b0e9c

Please sign in to comment.