Skip to content

Commit

Permalink
Pre-commit fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-Varma committed Jan 31, 2025
1 parent 435701f commit 2409635
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,20 +137,23 @@ FailureOr<int64_t> getSplitStride(ArrayRef<AMDAIE::DmaCpyNdOp> dmaOps,
/// DMA(%c, %lhs)
///
/// In the above snippet although we have 5 DMA ops for L2<->L1, only 3 of
/// them are unique. Hence we'd split %lhs into 3 unique splits, instead of 5.
static FailureOr<int64_t> fetchTotalUniqueL2L1(SmallVector<CopyOpInterface> copyLikeOps, bool fetchTarget) {
DenseSet<Operation*> uniqueLof;
/// them are unique. Hence we'd split %lhs into 3 unique splits, instead
/// of 5.
static FailureOr<int64_t> fetchTotalUniqueL2L1(
SmallVector<CopyOpInterface> copyLikeOps, bool fetchTarget) {
DenseSet<Operation *> uniqueLof;
for (CopyOpInterface copyOp : copyLikeOps) {
AMDAIE::LogicalObjectFifoFromMemrefOp lof = nullptr;
if (fetchTarget) {
lof = dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
copyOp.getTarget().getDefiningOp());
copyOp.getTarget().getDefiningOp());
} else {
lof = dyn_cast_if_present<AMDAIE::LogicalObjectFifoFromMemrefOp>(
copyOp.getSource().getDefiningOp());
copyOp.getSource().getDefiningOp());
}
if (!lof) {
return copyOp.emitOpError()<< "could not retrieve source/target objectFifo";
return copyOp.emitOpError()
<< "could not retrieve source/target objectFifo";
}
uniqueLof.insert(lof);
}
Expand Down Expand Up @@ -181,7 +184,8 @@ LogicalResult collectSplittingDims(
ModuleOp &moduleOp, const SmallVector<DmaObjFifoPairT> &dmaObjFifoPairs,
DenseMap<AMDAIE::DmaCpyNdOp, DmaSplitInfo> &dmaSplitInfoMap,
DenseMap<AMDAIE::LogicalObjectFifoFromMemrefOp, ObjFifoSplitInfo>
&objFifoSplitInfoMap, int64_t numCols) {
&objFifoSplitInfoMap,
int64_t numCols) {
for (auto [dmaOp, objFifo] : dmaObjFifoPairs) {
LLVM_DEBUG(llvm::dbgs() << "dmaOp: " << dmaOp << "\n");
LLVM_DEBUG(llvm::dbgs() << "objFifo: " << objFifo << "\n");
Expand Down Expand Up @@ -256,7 +260,8 @@ LogicalResult collectSplittingDims(
// Calculate the new source stride to be used for splitting the DMA.
int64_t newSourceStride =
splitStride != 1 ? splitDimSize / splitStride : 1;
FailureOr<int64_t> maybeUniqueL2L1 = fetchTotalUniqueL2L1(objFifo.getCopyLikeConsumers(), /*fetchTarget=*/true);
FailureOr<int64_t> maybeUniqueL2L1 = fetchTotalUniqueL2L1(
objFifo.getCopyLikeConsumers(), /*fetchTarget=*/true);
if (failed(maybeUniqueL2L1)) {
objFifo.emitOpError()
<< "could not retrieve total unique L2<->L1 pairs";
Expand All @@ -277,7 +282,8 @@ LogicalResult collectSplittingDims(
LLVM_DEBUG(llvm::dbgs() << "splitFactor: " << splitFactor << "\n");
dmaSplitInfoMap[dmaOp] = {sourceSplitDim, newSourceStride, targetSplitDim,
1, splitFactor};
objFifoSplitInfoMap[objFifo] = {objFifoSplitDim, splitFactor, splitStride};
objFifoSplitInfoMap[objFifo] = {objFifoSplitDim, splitFactor,
splitStride};
} else if (dmaOp.getSourceObjectFifo() == objFifo) {
// Find outermost dimension in the access pattern that has stride ==
// sizeAfterSplit and size != 1.
Expand Down Expand Up @@ -323,7 +329,8 @@ LogicalResult collectSplittingDims(
// Calculate the new target stride to be used for splitting the DMA.
int64_t newTargetStride =
splitStride != 1 ? splitDimSize / splitStride : 1;
FailureOr<int64_t> maybeUniqueL2L1 = fetchTotalUniqueL2L1(objFifo.getCopyLikeProducers(), /*fetchTarget=*/false);
FailureOr<int64_t> maybeUniqueL2L1 = fetchTotalUniqueL2L1(
objFifo.getCopyLikeProducers(), /*fetchTarget=*/false);
if (failed(maybeUniqueL2L1)) {
objFifo.emitOpError()
<< "could not retrieve total unique L2<->L1 pairs";
Expand All @@ -344,7 +351,8 @@ LogicalResult collectSplittingDims(
LLVM_DEBUG(llvm::dbgs() << "splitFactor: " << splitFactor << "\n");
dmaSplitInfoMap[dmaOp] = {sourceSplitDim, 1, targetSplitDim,
newTargetStride, splitFactor};
objFifoSplitInfoMap[objFifo] = {objFifoSplitDim, splitFactor, splitStride};
objFifoSplitInfoMap[objFifo] = {objFifoSplitDim, splitFactor,
splitStride};
}
}
return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ LogicalResult splitLogicalObjectFifo(
/// the provided split factor.
LogicalResult splitDoublyStridedOp(
IRRewriter &rewriter, AMDAIE::DoublyStridedOpInterface op,
size_t sourceSplitDim = 0, size_t targetSplitDim = 0,
int64_t splitFactor, int64_t sourceSplitStride = 1, int64_t targetSplitStride = 1);
size_t sourceSplitDim = 0, size_t targetSplitDim = 0, int64_t splitFactor,
int64_t sourceSplitStride = 1, int64_t targetSplitStride = 1);

} // namespace mlir::iree_compiler::AMDAIE

Expand Down

0 comments on commit 2409635

Please sign in to comment.