[MLIR][Memref] Improve expand-strided-metadata pass #129642

Open · wants to merge 1 commit into main
148 changes: 54 additions & 94 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -253,14 +253,9 @@ struct ExtractStridedMetadataOpSubviewFolder
/// Compute the expanded sizes of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origSizes hold the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// sizes#i =
/// baseSizes#groupId / product(expandShapeSizes#j,
/// for j in group excluding reassIdx#i)
/// Where reassIdx#i is the reassociation index at index i in \p groupId.
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
/// For static dim sizes, we take the values from the result type
/// of \p expandShape. For dynamic dims, we take the values from the
/// output_shape operand.
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
@@ -275,42 +270,28 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,

unsigned groupSize = reassocGroup.size();
SmallVector<OpFoldResult> expandedSizes(groupSize);

uint64_t productOfAllStaticSizes = 1;
std::optional<unsigned> dynSizeIdx;
MemRefType expandShapeType = expandShape.getResultType();

// Fill up all the statically known sizes.
DenseMap<unsigned, Value> dynSizes;
Operation::operand_range dynOutShapes = expandShape.getOutputShape();
for (unsigned i = 0, dynCount = 0, e = expandShapeType.getRank(); i < e;
i++) {
if (expandShapeType.isDynamicDim(i))
dynSizes[i] = dynOutShapes[dynCount++];
}
for (unsigned i = 0; i < groupSize; ++i) {
uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
unsigned index = reassocGroup[i];
uint64_t dimSize = expandShapeType.getDimSize(index);
if (ShapedType::isDynamic(dimSize)) {
assert(!dynSizeIdx && "There must be at most one dynamic size per group");
dynSizeIdx = i;
expandedSizes[i] = dynSizes[index];
continue;
}
productOfAllStaticSizes *= dimSize;
expandedSizes[i] = builder.getIndexAttr(dimSize);
}

// Compute the dynamic size using the original size and all the other known
// static sizes:
// expandSize = origSize / productOfAllStaticSizes.
if (dynSizeIdx) {
AffineExpr s0 = builder.getAffineSymbolExpr(0);
expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
origSizes[groupId]);
}

return expandedSizes;
}
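
A minimal standalone sketch of the size arithmetic (plain C++, not the MLIR implementation; dynamicSizeByDivision and the concrete numbers are made up for illustration): the old lowering recovered a group's single dynamic size by dividing the collapsed source size by the product of the group's static sizes, while the new lowering simply reuses the value already carried by the output_shape operand.

#include <cassert>
#include <cstdint>
#include <vector>

// Old scheme: recover the one dynamic size of a reassociation group from the
// collapsed source size and the product of the group's static sizes.
int64_t dynamicSizeByDivision(int64_t collapsedSrcSize,
                              const std::vector<int64_t> &staticSizes) {
  int64_t product = 1;
  for (int64_t s : staticSizes)
    product *= s;
  return collapsedSrcSize / product; // floordiv in the emitted affine map
}

int main() {
  // Expanding a collapsed dim of size 12 into {?, 4}: the old lowering emits
  // 12 floordiv 4 = 3; the new lowering reuses the output_shape value (3)
  // directly, so no division is materialized.
  int64_t fromDivision = dynamicSizeByDivision(/*collapsedSrcSize=*/12, {4});
  int64_t fromOutputShape = 3; // hypothetical value fed through output_shape
  assert(fromDivision == fromOutputShape);
  (void)fromDivision;
  (void)fromOutputShape;
  return 0;
}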

/// Compute the expanded strides of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// strides#i =
/// origStrides#reassDim * product(expandShapeSizes#j, for j in
@@ -320,11 +301,8 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
/// and expandShapeSizes#j is either:
/// - The constant size at dimension j, derived directly from the result type of
/// the expand_shape op, or
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
/// element.)
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
/// - The dynamic size at dimension j, derived from the output_shape operand
/// of the expand_shape op.
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
@@ -334,74 +312,56 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
ArrayRef<OpFoldResult> origSizes,
ArrayRef<OpFoldResult> origStrides,
unsigned groupId) {
SmallVector<int64_t, 2> reassocGroup =
expandShape.getReassociationIndices()[groupId];
auto reassocIndices = expandShape.getReassociationIndices();
unsigned currIdx = 0;
for (unsigned i = 0; i < groupId; i++)
currIdx += reassocIndices[i].size();
SmallVector<int64_t, 2> reassocGroup = reassocIndices[groupId];
assert(!reassocGroup.empty() &&
"Reassociation group should have at least one dimension");

unsigned groupSize = reassocGroup.size();
MemRefType expandShapeType = expandShape.getResultType();

std::optional<int64_t> dynSizeIdx;

// Fill up the expanded strides, with the information we can deduce from the
// resulting shape.
uint64_t currentStride = 1;
Location loc = expandShape.getLoc();
SmallVector<OpFoldResult> expandedStrides(groupSize);
for (int i = groupSize - 1; i >= 0; --i) {
expandedStrides[i] = builder.getIndexAttr(currentStride);
uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
if (ShapedType::isDynamic(dimSize)) {
assert(!dynSizeIdx && "There must be at most one dynamic size per group");
dynSizeIdx = i;
continue;
}

currentStride *= dimSize;
DenseMap<int, Value> dynSizes;
unsigned dynCount = 0;
Operation::operand_range dynOutShapes = expandShape.getOutputShape();
for (unsigned i = 0, e = expandShapeType.getRank(); i < e; i++) {
if (expandShapeType.isDynamicDim(i))
dynSizes[i] = dynOutShapes[dynCount++];
}

// Collect the statically known information about the original stride.
Value source = expandShape.getSrc();
auto sourceType = cast<MemRefType>(source.getType());
auto [strides, offset] = sourceType.getStridesAndOffset();

OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
? origStrides[groupId]
: builder.getIndexAttr(strides[groupId]);

// Apply the original stride to all the strides.
int64_t doneStrideIdx = 0;
// If we saw a dynamic dimension, we need to fix-up all the strides up to
// that dimension with the dynamic size.
if (dynSizeIdx) {
int64_t productOfAllStaticSizes = currentStride;
assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
"We shouldn't be able to change dynamicity");
OpFoldResult origSize = origSizes[groupId];

AffineExpr s0 = builder.getAffineSymbolExpr(0);
AffineExpr s1 = builder.getAffineSymbolExpr(1);
for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
int64_t baseExpandedStride =
cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
.getInt();
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
builder, expandShape.getLoc(),
(s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
{origSize, origStride});
}
}

// Now apply the origStride to the remaining dimensions.
OpFoldResult origStride = origStrides[groupId];
AffineExpr s0 = builder.getAffineSymbolExpr(0);
for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
int64_t baseExpandedStride =
cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
.getInt();
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
AffineExpr s1 = builder.getAffineSymbolExpr(1);
int64_t resultOffset;
SmallVector<int64_t, 4> resultStrides;
(void)expandShapeType.getStridesAndOffset(resultStrides, resultOffset);
expandedStrides[groupSize - 1] =
!ShapedType::isDynamic(resultStrides[currIdx + groupSize - 1])
? builder.getIndexAttr(resultStrides[currIdx + groupSize - 1])
: origStride;
OpFoldResult currentStride = builder.getIndexAttr(1);
for (int i = groupSize - 2; i >= 0; i--) {
unsigned index = reassocGroup[i + 1];
// Multiply `currentStride` with `dimSize`.
currentStride =
expandShapeType.isDynamicDim(index)
? makeComposedFoldedAffineApply(builder, loc, s0 * s1,
{currentStride, dynSizes[index]})
: makeComposedFoldedAffineApply(
builder, loc, s0 * expandShapeType.getDimSize(index),
{currentStride});
// Multiply `origStride` into the stride of the current dimension in the group.
expandedStrides[i] = makeComposedFoldedAffineApply(
builder, loc, s0 * s1, {currentStride, origStride});
}
for (unsigned i = 0; i < groupSize; i++) {
if (!ShapedType::isDynamic(resultStrides[currIdx + i]))
expandedStrides[i] = builder.getIndexAttr(resultStrides[currIdx + i]);
}

return expandedStrides;
}
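
As a sanity check on the arithmetic above, here is a minimal standalone sketch (plain C++, not the MLIR implementation; expandedGroupStrides and the concrete numbers are made up for illustration): walking a reassociation group from the innermost dimension outwards, each stride is the product of the sizes of the dimensions to its right, multiplied by the stride of the collapsed source dimension; in the pass itself, strides that are static in the result type then override the computed values.

#include <cassert>
#include <cstdint>
#include <vector>

// Compute the strides of one reassociation group: the innermost stride equals
// the source stride, and each outer stride accumulates the sizes to its right.
std::vector<int64_t>
expandedGroupStrides(const std::vector<int64_t> &groupSizes,
                     int64_t origStride) {
  std::vector<int64_t> strides(groupSizes.size());
  int64_t current = 1;
  for (int i = static_cast<int>(groupSizes.size()) - 1; i >= 0; --i) {
    strides[i] = current * origStride;
    current *= groupSizes[i];
  }
  return strides;
}

int main() {
  // A source dim of size 24 with stride 2 expanded into {2, 3, 4} yields
  // strides {24, 8, 2}: e.g. the outermost stride is 3 * 4 * 2 = 24.
  std::vector<int64_t> s = expandedGroupStrides({2, 3, 4}, /*origStride=*/2);
  assert((s == std::vector<int64_t>{24, 8, 2}));
  (void)s;
  return 0;
}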
