Skip to content

Commit

Permalink
[LowerToAIE] fix for op could not fold repetition counts (#1055)
Browse files Browse the repository at this point in the history
This is a fix for the compilation error
```
<unknown>:0: error: 'amdaie.npu.circular_dma_cpy_nd' op could not fold repetition counts
```

that is observed in some transposed matmuls with the new pack-peel
pipeline. What was happening was that the `lower-to-aie` pass assumed
that a stride-0 leading dimension was either always or never present, but
that wasn't the case.
  • Loading branch information
newling authored Jan 29, 2025
1 parent 7a059cd commit 9438dd2
Show file tree
Hide file tree
Showing 6 changed files with 404 additions and 67 deletions.
21 changes: 13 additions & 8 deletions build_tools/ci/cpu_comparison/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,10 @@ def __init__(
self.n_kernel_runs = n_kernel_runs

self.tile_pipeline = tile_pipeline
if tile_pipeline == "pack-peel":
self.labels.append("PackPeel")
elif tile_pipeline == "pad-pack":
self.labels.append("PadPack")
self.labels.append(self.tile_pipeline)

self.lower_to_aie_pipeline = lower_to_aie_pipeline
if lower_to_aie_pipeline == "air":
self.labels.append("Air")
elif lower_to_aie_pipeline == "objectFifo":
self.labels.append("ObjectFifo")
self.labels.append(self.lower_to_aie_pipeline)

self.use_ukernel = use_ukernel
if use_ukernel:
Expand Down Expand Up @@ -1670,6 +1664,17 @@ def __init__(self):
for input_type, acc_type in zip(["i8", "bf16"], ["i32", "f32"]):
self.register(MatmulTransposeB(32, 32, 32, input_type, acc_type))
self.register(MatmulTransposeB(128, 256, 128, input_type, acc_type))
self.register(
MatmulTransposeB(
128,
256,
128,
input_type,
acc_type,
tile_pipeline="pack-peel-4-level-tiling",
name_suffix="4level",
)
)
self.register(MatmulTransposeB(1536, 1536, 2048, input_type, acc_type))

# MatmulTransposeA test(s):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,77 +36,79 @@ using namespace xilinx;

namespace mlir::iree_compiler::AMDAIE {

/// Utility to update the current common repetition count based on the new size
/// and stride access pattern. If this new access pattern has a smaller
/// repetition count, the common repetition count will be decreased.
static FailureOr<size_t> getUpdatedRepetitionCount(
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides,
size_t curRepetitionCount) {
if (!strides.empty() && !sizes.empty() && isConstantIntValue(strides[0], 0)) {
int i = 0;
size_t repetitionCount{1};
while (i < strides.size() && isConstantIntValue(strides[i], 0)) {
std::optional<int64_t> maybeRepetitionCount =
getConstantIntValue(sizes[i]);
if (!maybeRepetitionCount) return failure();
assert(maybeRepetitionCount.value() >= 0 &&
"sizes should always be larger or equal to zero");
repetitionCount *= maybeRepetitionCount.value();
i += 1;
}
if (curRepetitionCount == 1) return repetitionCount;
size_t newRepetitionCount = std::min(repetitionCount, curRepetitionCount);
if (repetitionCount % newRepetitionCount != 0) return failure();
if (curRepetitionCount % newRepetitionCount != 0) return failure();
return newRepetitionCount;
/// Compute the 'global' repetition count: the product over all dimensions with
/// zero stride of the size of the dimension.
///
/// The case where sizes and strides are empty is a special case, and '0' is
/// returned.
static int64_t getRepetitionCount(ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
assert(sizes.size() == strides.size() &&
"expected stride and size vectors of same size");
if (strides.empty()) return 0;
size_t repetitionCount{1};
for (uint32_t i = 0; i < strides.size(); ++i) {
if (!isConstantIntValue(strides[i], 0)) continue;
std::optional<int64_t> maybeSize = getConstantIntValue(sizes[i]);
assert(maybeSize.has_value() &&
"expected constant size in this zero stride dimension");
assert(maybeSize.value() >= 0 && "expected a non-negative size");
repetitionCount *= maybeSize.value();
}
return curRepetitionCount;
return repetitionCount;
}

/// Utility to retrieve the common repetition count from all producers and
/// consumers of a logical objectFifo.
static FailureOr<size_t> getRepetitionCount(LogicalObjFifoOpInterface op) {
size_t repetitionCount = 1;
SmallVector<int64_t> repetitionCounts;
auto appendRepetitionCount = [&](ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
size_t repetitionCount = getRepetitionCount(sizes, strides);
if (repetitionCount != 0) repetitionCounts.push_back(repetitionCount);
};

for (Operation *userOp : op->getUsers()) {
if (auto connectionOp = dyn_cast<AMDAIE::ConnectionOp>(userOp);
connectionOp.getTarget() &&
dyn_cast_if_present<LogicalObjFifoOpInterface>(
connectionOp.getTarget().getDefiningOp()) == op) {
// Handle producer connection operations.
FailureOr<AMDAIE::NpuCircularDmaCpyNdOp> maybeNpuDmaUserOp =
connectionOp.getNpuCircularDmaCpyNdUser();
if (failed(maybeNpuDmaUserOp)) {
return connectionOp.emitOpError()
<< "does not have a circular DMA op user";
}
FailureOr<size_t> maybeNewRepetitionCount = getUpdatedRepetitionCount(
maybeNpuDmaUserOp->getTargetMixedSizes(),
maybeNpuDmaUserOp->getTargetMixedStrides(), repetitionCount);
if (failed(maybeNewRepetitionCount)) {
return maybeNpuDmaUserOp->emitOpError() << "no repetition count found";
}
repetitionCount = maybeNewRepetitionCount.value();
} else if (auto connectionOp = dyn_cast<AMDAIE::ConnectionOp>(userOp);
connectionOp.getSource() &&
dyn_cast_if_present<LogicalObjFifoOpInterface>(
connectionOp.getSource().getDefiningOp()) == op) {
// Handle consumer connection operations.
if (auto connectionOp = dyn_cast<AMDAIE::ConnectionOp>(userOp)) {
FailureOr<AMDAIE::NpuCircularDmaCpyNdOp> maybeNpuDmaUserOp =
connectionOp.getNpuCircularDmaCpyNdUser();
if (failed(maybeNpuDmaUserOp)) {
return connectionOp.emitOpError()
<< "does not have a circular DMA op user";

if (failed(maybeNpuDmaUserOp)) continue;

AMDAIE::NpuCircularDmaCpyNdOp npuDma = maybeNpuDmaUserOp.value();

if (connectionOp.getTarget() &&
dyn_cast_if_present<LogicalObjFifoOpInterface>(
connectionOp.getTarget().getDefiningOp()) == op) {
appendRepetitionCount(npuDma.getTargetMixedSizes(),
npuDma.getTargetMixedStrides());
}
FailureOr<size_t> maybeNewRepetitionCount = getUpdatedRepetitionCount(
maybeNpuDmaUserOp->getSourceMixedSizes(),
maybeNpuDmaUserOp->getSourceMixedStrides(), repetitionCount);
if (failed(maybeNewRepetitionCount)) {
return maybeNpuDmaUserOp->emitOpError() << "no repetition count found";

if (connectionOp.getSource() &&
dyn_cast_if_present<LogicalObjFifoOpInterface>(
connectionOp.getSource().getDefiningOp()) == op) {
appendRepetitionCount(npuDma.getSourceMixedSizes(),
npuDma.getSourceMixedStrides());
}
repetitionCount = maybeNewRepetitionCount.value();
}
}
return repetitionCount;

// merge the repetition counts:
if (repetitionCounts.empty()) return 1;
int64_t combinedRepetitionCount =
*std::min_element(repetitionCounts.begin(), repetitionCounts.end());

// if any of the repetition counts are not divisible by the combined
// repetition count, that's a problem:
if (!std::all_of(
repetitionCounts.begin(), repetitionCounts.end(),
[&](size_t c) { return c % combinedRepetitionCount == 0; })) {
return op.emitOpError()
<< " could not resolved a common repetition count based on the "
"individual repetition counts: "
<< getArrayString<int64_t>(repetitionCounts);
}
return combinedRepetitionCount;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -339,7 +341,10 @@ LogicalResult AIEDeviceBuilder::foldDimsAndReturnAsStatic(
function_ref<InFlightDiagnostic()> emitError) {
if (failed(foldRepetitionCount(rewriter.getContext(), sizes, strides,
repetitionCount))) {
return emitError() << "could not fold repetition counts";
return emitError() << "could not fold repetition counts from sizes: "
<< getConstantIntValuesString(sizes)
<< " strides: " << getConstantIntValuesString(strides)
<< " repetitionCount: " << repetitionCount << ".";
}
SmallVector<OpFoldResult> offsets(
strides.size(), getAsIndexOpFoldResult(rewriter.getContext(), 0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@

namespace mlir::iree_compiler::AMDAIE {

/// Return a printable form of `ofrs`. When `mlir::getConstantIntValues`
/// yields a value for the whole list, the result is the bracketed,
/// comma-separated list produced by `getArrayString`. Otherwise the fixed
/// placeholder "[not all constant integers]" is returned.
std::string getConstantIntValuesString(ArrayRef<OpFoldResult> ofrs) {
  const auto constants = mlir::getConstantIntValues(ofrs);
  // Guard clause: no constant list is available, so use the placeholder.
  if (!constants) return "[not all constant integers]";
  return getArrayString<int64_t>(*constants);
}

template <typename T>
std::optional<T> getConfigAttr(IREE::HAL::ExecutableTargetAttr targetAttr,
StringRef name) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@ int findLargestFactor(int num, int max, int multiple);

} // namespace detail

/// Render the elements of `vs` as a bracketed, comma-separated list without
/// spaces, for example "[1,2,3]". An empty array is rendered as "[]". Each
/// element is converted with `std::to_string`.
template <typename T>
std::string getArrayString(ArrayRef<T> vs) {
  std::string rendered = "[";
  bool isFirst = true;
  for (T element : vs) {
    // The separator goes between elements only, never before the first one.
    if (!isFirst) rendered += ",";
    isFirst = false;
    rendered += std::to_string(element);
  }
  rendered += "]";
  return rendered;
}

/// If all values in `opFoldResults` are constant integers, return a string
/// representation of the constant values, for example "[3,4]". Otherwise,
/// return "[not all constant integers]".
std::string getConstantIntValuesString(ArrayRef<OpFoldResult> opFoldResults);

} // namespace mlir::iree_compiler::AMDAIE

#endif
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "gtest/gtest.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

namespace {

Expand Down Expand Up @@ -42,6 +43,24 @@ TEST(FindLargestFactorTest, Test0) {
detail::findLargestFactor(firstPrimeAbove1e5, firstPrimeAbove1e5 - 1), 1);
}

// Exercises getConstantIntValuesString on a list that grows step by step:
// empty, one constant, two constants, and finally a list containing an entry
// that is not a constant integer.
TEST(OpFoldResultPrinting, Test0) {
  mlir::MLIRContext context;
  llvm::SmallVector<mlir::OpFoldResult> opFoldResults;

  // An empty list prints as an empty pair of brackets.
  EXPECT_EQ(getConstantIntValuesString(opFoldResults), "[]");

  // A single constant index value.
  opFoldResults.push_back(getAsIndexOpFoldResult(&context, 3));
  EXPECT_EQ(getConstantIntValuesString(opFoldResults), "[3]");

  // Two constants are separated by a comma with no space.
  opFoldResults.push_back(getAsIndexOpFoldResult(&context, 4));
  EXPECT_EQ(getConstantIntValuesString(opFoldResults), "[3,4]");

  // Appending a default-constructed mlir::Value is expected to switch the
  // output to the fallback string.
  opFoldResults.push_back(mlir::Value{});
  EXPECT_EQ(getConstantIntValuesString(opFoldResults),
            "[not all constant integers]");
}

} // namespace

int main(int argc, char **argv) {
Expand Down
Loading

0 comments on commit 9438dd2

Please sign in to comment.