Skip to content

Commit

Permalink
Review comment v2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-Varma committed Jan 29, 2025
1 parent 380c729 commit 0904a94
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,77 @@ FailureOr<int64_t> getSplitStride(ArrayRef<AMDAIE::DmaCpyNdOp> dmaOps,
return splitStride;
}

/// Fetch and store all unique pairs of L2<->L1 Copy ops. This would helps us
/// figure out the split factor for all LogicalObjectFifos. Basically we get to
/// decide how many splits to perform for a particular L2 ObjectFifo based on
/// the total unique L2<->L1 Copy ops.
/// Eg:
/// %lhs = LOF_on_L2
/// %a = LOF_on_L1_0
/// %b = LOF_on_L1_1
/// %c = LOF_on_L1_2
/// DMA(%a, %lhs)
/// DMA(%b, %lhs)
/// DMA(%c, %lhs)
/// DMA(%b, %lhs)
/// DMA(%c, %lhs)
///
/// In the above snippet although we have 5 DMA ops for L2<->L1, only 3 of
/// them are unique. Hence we'd split %lhs into 3 unique splits, instead
/// of 5.
static DenseMap<Operation *, int64_t> fetchUniqueL2L1(ModuleOp moduleOp) {
DenseMap<Operation *, DenseSet<Operation *>> uniqueL2L1Pair;
moduleOp->walk([&](Operation *op) -> WalkResult {
if (auto copyOp = dyn_cast<CopyOpInterface>(op)) {
auto source = dyn_cast_if_present<AMDAIE::LogicalObjFifoOpInterface>(
copyOp.getSource().getDefiningOp());
auto target = dyn_cast_if_present<AMDAIE::LogicalObjFifoOpInterface>(
copyOp.getTarget().getDefiningOp());
if (!source || !target) {
return WalkResult::interrupt();
}
auto sourceFromMemrefOp =
dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
copyOp.getSource().getDefiningOp());
auto targetFromMemrefOp =
dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
copyOp.getTarget().getDefiningOp());
if (!sourceFromMemrefOp || !targetFromMemrefOp) {
return WalkResult::interrupt();
}
Operation *l2LofOp = nullptr;
Operation *l1LofOp = nullptr;
// L2 -> L1.
if (source.getMemorySpaceAsUInt() == 1 &&
target.getMemorySpaceAsUInt() == 2) {
l2LofOp = sourceFromMemrefOp;
l1LofOp = targetFromMemrefOp;
} else if (source.getMemorySpaceAsUInt() == 2 &&
target.getMemorySpaceAsUInt() == 1) {
// L1 -> L2.
l2LofOp = targetFromMemrefOp;
l1LofOp = sourceFromMemrefOp;
} else {
return WalkResult::advance();
}
uniqueL2L1Pair[l2LofOp].insert(l1LofOp);
return WalkResult::advance();
}
return WalkResult::advance();
});

DenseMap<Operation *, int64_t> uniqueL2L1Count;
for (auto &&[l2Lof, l1Lofs] : uniqueL2L1Pair)
uniqueL2L1Count[l2Lof] = l1Lofs.size();

return uniqueL2L1Count;
}

/// Find the logical objectFifo and DMA source/target splitting dimensions for
/// each DMA and objectFifo pair.
///
/// Each pair is handled in the following way:
/// At first we find count of total unique L2<->L1 pairs for all L2 objectFifos.
/// Then each DMA and objectFifo pair is handled in the following way:
/// First, compute the objectFifo splitting dimension based on the last non-unit
/// shape dimension and the number of available columns. Afterwards, depending
/// on which logical objectFifo is being split on, find the outermost dimension
Expand All @@ -139,11 +206,12 @@ FailureOr<int64_t> getSplitStride(ArrayRef<AMDAIE::DmaCpyNdOp> dmaOps,
/// splitting because that's the number of elements that should be
/// produced/consumed on the respective sides before splitting.
LogicalResult collectSplittingDims(
const SmallVector<DmaObjFifoPairT> &dmaObjFifoPairs,
ModuleOp &moduleOp, const SmallVector<DmaObjFifoPairT> &dmaObjFifoPairs,
DenseMap<AMDAIE::DmaCpyNdOp, DmaSplitInfo> &dmaSplitInfoMap,
DenseMap<AMDAIE::LogicalObjectFifoFromMemrefOp, ObjFifoSplitInfo>
&objFifoSplitInfoMap,
int64_t numCols) {
DenseMap<Operation *, int64_t> uniqueL2L1Pair = fetchUniqueL2L1(moduleOp);
for (auto [dmaOp, objFifo] : dmaObjFifoPairs) {
LLVM_DEBUG(llvm::dbgs() << "dmaOp: " << dmaOp << "\n");
LLVM_DEBUG(llvm::dbgs() << "objFifo: " << objFifo << "\n");
Expand Down Expand Up @@ -218,17 +286,24 @@ LogicalResult collectSplittingDims(
// Calculate the new source stride to be used for splitting the DMA.
int64_t newSourceStride =
splitStride != 1 ? splitDimSize / splitStride : 1;
int64_t splitFactor = std::gcd(uniqueL2L1Pair[objFifo], numCols);
int64_t sourceSize = (*sourceSizes)[sourceSplitDim];
int64_t targetSize = (*targetSizes)[targetSplitDim];
if (sourceSize % splitFactor != 0 || targetSize % splitFactor != 0) {
splitFactor = std::gcd(sourceSize, targetSize);
}
LLVM_DEBUG(llvm::dbgs() << "sourceSplitDim: " << sourceSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs() << "targetSplitDim: " << targetSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "newSourceStride: " << newSourceStride << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "objFifoSplitDim: " << objFifoSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs() << "splitStride: " << splitStride << "\n");
LLVM_DEBUG(llvm::dbgs() << "splitFactor: " << numCols << "\n");
LLVM_DEBUG(llvm::dbgs() << "splitFactor: " << splitFactor << "\n");
dmaSplitInfoMap[dmaOp] = {sourceSplitDim, newSourceStride, targetSplitDim,
1, numCols};
objFifoSplitInfoMap[objFifo] = {objFifoSplitDim, numCols, splitStride};
1, splitFactor};
objFifoSplitInfoMap[objFifo] = {objFifoSplitDim, splitFactor,
splitStride};
} else if (dmaOp.getSourceObjectFifo() == objFifo) {
// Find outermost dimension in the access pattern that has stride ==
// sizeAfterSplit and size != 1.
Expand Down Expand Up @@ -274,17 +349,24 @@ LogicalResult collectSplittingDims(
// Calculate the new target stride to be used for splitting the DMA.
int64_t newTargetStride =
splitStride != 1 ? splitDimSize / splitStride : 1;
int64_t splitFactor = std::gcd(uniqueL2L1Pair[objFifo], numCols);
int64_t sourceSize = (*sourceSizes)[sourceSplitDim];
int64_t targetSize = (*targetSizes)[targetSplitDim];
if (sourceSize % splitFactor != 0 || targetSize % splitFactor != 0) {
splitFactor = std::gcd(sourceSize, targetSize);
}
LLVM_DEBUG(llvm::dbgs() << "sourceSplitDim: " << sourceSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs() << "targetSplitDim: " << targetSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "newTargetStride: " << newTargetStride << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "objFifoSplitDim: " << objFifoSplitDim << "\n");
LLVM_DEBUG(llvm::dbgs() << "splitStride: " << splitStride << "\n");
LLVM_DEBUG(llvm::dbgs() << "splitFactor: " << numCols << "\n");
LLVM_DEBUG(llvm::dbgs() << "splitFactor: " << splitFactor << "\n");
dmaSplitInfoMap[dmaOp] = {sourceSplitDim, 1, targetSplitDim,
newTargetStride, numCols};
objFifoSplitInfoMap[objFifo] = {objFifoSplitDim, numCols, splitStride};
newTargetStride, splitFactor};
objFifoSplitInfoMap[objFifo] = {objFifoSplitDim, splitFactor,
splitStride};
}
}
return success();
Expand Down Expand Up @@ -343,115 +425,28 @@ void AMDAIESplitLogicalObjFifosPass::runOnOperation() {
DenseMap<AMDAIE::DmaCpyNdOp, DmaSplitInfo> dmaSplitInfoMap;
DenseMap<AMDAIE::LogicalObjectFifoFromMemrefOp, ObjFifoSplitInfo>
objFifoSplitInfoMap;
if (failed(collectSplittingDims(dmaObjFifoPairs, dmaSplitInfoMap,
if (failed(collectSplittingDims(moduleOp, dmaObjFifoPairs, dmaSplitInfoMap,
objFifoSplitInfoMap, numColumns))) {
return signalPassFailure();
}

// Fetch and store all unique pairs of L2<->L1 Copy ops. This would helps us
// figure out the split factor for all LogicalObjectFifos. Basically we get to
// decide how many splits to perform for a particular L2 ObjectFifo based on
// the total unique L2<->L1 Copy ops.
// Eg:
// %lhs = LOF_on_L2
// %a = LOF_on_L1_0
// %b = LOF_on_L1_1
// %c = LOF_on_L1_2
// DMA(%a, %lhs)
// DMA(%b, %lhs)
// DMA(%c, %lhs)
// DMA(%b, %lhs)
// DMA(%c, %lhs)
//
// In the above snippet although we have 5 DMA ops for L2<->L1, only 3 of
// them are unique. Hence we'd split %lhs into 3 unique splits, instead
// of 5.
DenseMap<Operation *, DenseSet<Operation *>> uniqueL2L1Pair;
moduleOp->walk([&](Operation *op) -> WalkResult {
if (auto copyOp = dyn_cast<CopyOpInterface>(op)) {
auto source = dyn_cast_if_present<AMDAIE::LogicalObjFifoOpInterface>(
copyOp.getSource().getDefiningOp());
auto target = dyn_cast_if_present<AMDAIE::LogicalObjFifoOpInterface>(
copyOp.getTarget().getDefiningOp());
if (!source || !target) {
return WalkResult::interrupt();
}
auto sourceFromMemrefOp =
dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
copyOp.getSource().getDefiningOp());
auto targetFromMemrefOp =
dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
copyOp.getTarget().getDefiningOp());
if (!sourceFromMemrefOp || !targetFromMemrefOp) {
return WalkResult::interrupt();
}
Operation *l2DefOp = nullptr;
Operation *l1DefOp = nullptr;
// L2 -> L1.
if (target.getMemorySpaceAsUInt() == 2) {
l2DefOp = sourceFromMemrefOp.getMemref().getDefiningOp();
l1DefOp = targetFromMemrefOp;
} else if (source.getMemorySpaceAsUInt() == 2) {
// L1 -> L2.
l2DefOp = targetFromMemrefOp.getMemref().getDefiningOp();
l1DefOp = sourceFromMemrefOp;
} else {
return WalkResult::advance();
}
if (!l2DefOp || !l1DefOp) {
return WalkResult::interrupt();
}
uniqueL2L1Pair[l2DefOp].insert(l1DefOp);
return WalkResult::advance();
}
return WalkResult::advance();
});

/// Split the DMA and objectFifo ops based on the calcuated splitting
/// dimensions.
DenseMap<Operation *, int64_t> splitFactorOfLOF;
for (auto &&[dmaOp, dmaSplitInfo] : dmaSplitInfoMap) {
auto dmaCpyNd = cast<AMDAIE::DmaCpyNdOp>(dmaOp.getOperation());
int64_t splitFactor = dmaSplitInfo.splitSize;
auto sourceDefOp =
dmaCpyNd.getSource()
.getDefiningOp<AMDAIE::LogicalObjectFifoFromMemrefOp>();
auto targetDefOp =
dmaCpyNd.getTarget()
.getDefiningOp<AMDAIE::LogicalObjectFifoFromMemrefOp>();
if (!sourceDefOp || !targetDefOp) {
LLVM_DEBUG(llvm::dbgs()
<< "Expected defining op of source/target for : " << dmaOp);
return signalPassFailure();
}
if (dmaCpyNd.getSourceMemorySpaceAsUInt() == 0) {
if (Operation *l2DefOp = targetDefOp.getMemref().getDefiningOp())
splitFactor = uniqueL2L1Pair[l2DefOp].size();
} else if (dmaCpyNd.getTargetMemorySpaceAsUInt() == 0) {
if (Operation *l2DefOp = sourceDefOp.getMemref().getDefiningOp())
splitFactor = uniqueL2L1Pair[l2DefOp].size();
}
// In cases where the number of available columns < the inferred split
// factor, we'll cap the final split factor by the lower bound.
splitFactor = std::gcd(dmaSplitInfo.splitSize, splitFactor);
FailureOr<int64_t> maybeSplitFactor = splitDoublyStridedOp(
rewriter, dmaCpyNd, dmaSplitInfo.sourceSplitDim,
dmaSplitInfo.targetSplitDim, splitFactor, dmaSplitInfo.newSourceStride,
dmaSplitInfo.newTargetStride);
if (failed(maybeSplitFactor)) {
auto stridedOp =
cast<AMDAIE::DoublyStridedOpInterface>(dmaOp.getOperation());
if (failed(splitDoublyStridedOp(
rewriter, stridedOp, dmaSplitInfo.sourceSplitDim,
dmaSplitInfo.targetSplitDim, dmaSplitInfo.splitSize,
dmaSplitInfo.newSourceStride, dmaSplitInfo.newTargetStride))) {
LLVM_DEBUG(llvm::dbgs()
<< "Failed to perform splitting of the DMA op: " << dmaOp);
return signalPassFailure();
}
// The above function might change the split factor based on divisibility
// with source/target. Therefore here we maintain the final split factor
// which we'll use later to split the LogicalObjectFifo.
splitFactorOfLOF[targetDefOp] = *maybeSplitFactor;
splitFactorOfLOF[sourceDefOp] = *maybeSplitFactor;
}
for (auto &&[objFifo, splitInfo] : objFifoSplitInfoMap) {
if (failed(splitLogicalObjectFifo(rewriter, objFifo, splitInfo.splitDim,
splitFactorOfLOF[objFifo],
splitInfo.splitSize,
splitInfo.splitStride))) {
LLVM_DEBUG(llvm::dbgs()
<< "Failed to perform splitting of objectFifo op");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -753,13 +753,12 @@ LogicalResult splitLogicalObjectFifo(IRRewriter &rewriter,
/// Split doubly strided operations on a source and target split dimension with
/// the provided split factor which might get updated. On success, return the
/// split factor to the caller, else return failure.
FailureOr<int64_t> splitDoublyStridedOp(IRRewriter &rewriter,
AMDAIE::DoublyStridedOpInterface op,
size_t sourceSplitDim,
size_t targetSplitDim,
std::optional<size_t> maybeSplitFactor,
int64_t sourceSplitStride,
int64_t targetSplitStride) {
LogicalResult splitDoublyStridedOp(IRRewriter &rewriter,
AMDAIE::DoublyStridedOpInterface op,
size_t sourceSplitDim, size_t targetSplitDim,
std::optional<size_t> maybeSplitFactor,
int64_t sourceSplitStride,
int64_t targetSplitStride) {
if (!op->use_empty())
return op.emitOpError() << "can't be split because it has uses";
SmallVector<OpFoldResult> sourceOffsets = op.getSourceMixedOffsets();
Expand Down Expand Up @@ -802,15 +801,9 @@ FailureOr<int64_t> splitDoublyStridedOp(IRRewriter &rewriter,
}
int64_t sourceSize = maybeSourceSize.value();
int64_t targetSize = maybeTargetSize.value();
int64_t splitFactor = maybeSplitFactor.has_value()
? maybeSplitFactor.value()
: std::gcd(sourceSize, targetSize);
if (sourceSize % splitFactor != 0 || targetSize % splitFactor != 0) {
int64_t newSplitFactor = std::gcd(sourceSize, targetSize);
LLVM_DEBUG(llvm::dbgs() << "split factor has been changed from "
<< splitFactor << " to " << newSplitFactor);
splitFactor = newSplitFactor;
}
assert(maybeSplitFactor.has_value() &&
"expected split factor to be sent by the caller");
int64_t splitFactor = maybeSplitFactor.value();

int64_t newSourceSize = sourceSize / splitFactor;
int64_t newTargetSize = targetSize / splitFactor;
Expand Down Expand Up @@ -859,7 +852,7 @@ FailureOr<int64_t> splitDoublyStridedOp(IRRewriter &rewriter,
targetOffsets[targetSplitDim] = newTargetOffset.value();
}
rewriter.eraseOp(op);
return splitFactor;
return success();
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ LogicalResult splitLogicalObjectFifo(
/// split factor to the caller, else return failure.
/// NOTE: If no split factor is provided, the doubly strided operation will be
/// split on the size of the dimension being split.
FailureOr<int64_t> splitDoublyStridedOp(
LogicalResult splitDoublyStridedOp(
IRRewriter &rewriter, AMDAIE::DoublyStridedOpInterface op,
size_t sourceSplitDim = 0, size_t targetSplitDim = 0,
std::optional<size_t> splitFactor = std::nullopt,
Expand Down

0 comments on commit 0904a94

Please sign in to comment.