Skip to content

Commit b755dc0

Browse files
[MLIR][Memref] Improve expand-strided-metadata pass
Add support to obtain expanded sizes and strides of result of memref.expand_shape op having multiple dynamic dims within a reassociation set.
1 parent 4460766 commit b755dc0

File tree

3 files changed

+144
-257
lines changed

3 files changed

+144
-257
lines changed

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 54 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -253,14 +253,9 @@ struct ExtractStridedMetadataOpSubviewFolder
253253
/// Compute the expanded sizes of the given \p expandShape for the
254254
/// \p groupId-th reassociation group.
255255
/// \p origSizes hold the sizes of the source shape as values.
256-
/// This is used to compute the new sizes in cases of dynamic shapes.
257-
///
258-
/// sizes#i =
259-
/// baseSizes#groupId / product(expandShapeSizes#j,
260-
/// for j in group excluding reassIdx#i)
261-
/// Where reassIdx#i is the reassociation index at index i in \p groupId.
262-
///
263-
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
256+
/// For static dim sizes, we take the values from the result type
257+
/// of \p expandShape. For dynamic dims, we take the values from the
258+
/// output_shape attribute.
264259
///
265260
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
266261
/// this is not possible because this function uses the Affine dialect and the
@@ -275,42 +270,28 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
275270

276271
unsigned groupSize = reassocGroup.size();
277272
SmallVector<OpFoldResult> expandedSizes(groupSize);
278-
279-
uint64_t productOfAllStaticSizes = 1;
280-
std::optional<unsigned> dynSizeIdx;
281273
MemRefType expandShapeType = expandShape.getResultType();
282-
283-
// Fill up all the statically known sizes.
274+
DenseMap<unsigned, Value> dynSizes;
275+
Operation::operand_range dynOutShapes = expandShape.getOutputShape();
276+
for (unsigned i = 0, dynCount = 0, e = expandShapeType.getRank(); i < e;
277+
i++) {
278+
if (expandShapeType.isDynamicDim(i))
279+
dynSizes[i] = dynOutShapes[dynCount++];
280+
}
284281
for (unsigned i = 0; i < groupSize; ++i) {
285-
uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
282+
unsigned index = reassocGroup[i];
283+
uint64_t dimSize = expandShapeType.getDimSize(index);
286284
if (ShapedType::isDynamic(dimSize)) {
287-
assert(!dynSizeIdx && "There must be at most one dynamic size per group");
288-
dynSizeIdx = i;
285+
expandedSizes[i] = dynSizes[index];
289286
continue;
290287
}
291-
productOfAllStaticSizes *= dimSize;
292288
expandedSizes[i] = builder.getIndexAttr(dimSize);
293289
}
294-
295-
// Compute the dynamic size using the original size and all the other known
296-
// static sizes:
297-
// expandSize = origSize / productOfAllStaticSizes.
298-
if (dynSizeIdx) {
299-
AffineExpr s0 = builder.getAffineSymbolExpr(0);
300-
expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
301-
builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
302-
origSizes[groupId]);
303-
}
304-
305290
return expandedSizes;
306291
}
307292

308293
/// Compute the expanded strides of the given \p expandShape for the
309294
/// \p groupId-th reassociation group.
310-
/// \p origStrides and \p origSizes hold respectively the strides and sizes
311-
/// of the source shape as values.
312-
/// This is used to compute the strides in cases of dynamic shapes and/or
313-
/// dynamic stride for this reassociation group.
314295
///
315296
/// strides#i =
316297
/// origStrides#reassDim * product(expandShapeSizes#j, for j in
@@ -320,11 +301,8 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
320301
/// and expandShapeSizes#j is either:
321302
/// - The constant size at dimension j, derived directly from the result type of
322303
/// the expand_shape op, or
323-
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
324-
/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
325-
/// element.)
326-
///
327-
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
304+
/// - The dynamic size at dimension j, derived from the output_shape attribute
305+
/// of the expand shape op.
328306
///
329307
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
330308
/// this is not possible because this function uses the Affine dialect and the
@@ -334,74 +312,56 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
334312
ArrayRef<OpFoldResult> origSizes,
335313
ArrayRef<OpFoldResult> origStrides,
336314
unsigned groupId) {
337-
SmallVector<int64_t, 2> reassocGroup =
338-
expandShape.getReassociationIndices()[groupId];
315+
auto reassocIndices = expandShape.getReassociationIndices();
316+
unsigned currIdx = 0;
317+
for (unsigned i = 0; i < groupId; i++)
318+
currIdx += reassocIndices[i].size();
319+
SmallVector<int64_t, 2> reassocGroup = reassocIndices[groupId];
339320
assert(!reassocGroup.empty() &&
340321
"Reassociation group should have at least one dimension");
341322

342323
unsigned groupSize = reassocGroup.size();
343324
MemRefType expandShapeType = expandShape.getResultType();
344-
345-
std::optional<int64_t> dynSizeIdx;
346-
347325
// Fill up the expanded strides, with the information we can deduce from the
348326
// resulting shape.
349-
uint64_t currentStride = 1;
327+
Location loc = expandShape.getLoc();
350328
SmallVector<OpFoldResult> expandedStrides(groupSize);
351-
for (int i = groupSize - 1; i >= 0; --i) {
352-
expandedStrides[i] = builder.getIndexAttr(currentStride);
353-
uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
354-
if (ShapedType::isDynamic(dimSize)) {
355-
assert(!dynSizeIdx && "There must be at most one dynamic size per group");
356-
dynSizeIdx = i;
357-
continue;
358-
}
359-
360-
currentStride *= dimSize;
329+
DenseMap<int, Value> dynSizes;
330+
unsigned dynCount = 0;
331+
Operation::operand_range dynOutShapes = expandShape.getOutputShape();
332+
for (unsigned i = 0, e = expandShapeType.getRank(); i < e; i++) {
333+
if (expandShapeType.isDynamicDim(i))
334+
dynSizes[i] = dynOutShapes[dynCount++];
361335
}
362-
363-
// Collect the statically known information about the original stride.
364-
Value source = expandShape.getSrc();
365-
auto sourceType = cast<MemRefType>(source.getType());
366-
auto [strides, offset] = sourceType.getStridesAndOffset();
367-
368-
OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
369-
? origStrides[groupId]
370-
: builder.getIndexAttr(strides[groupId]);
371-
372-
// Apply the original stride to all the strides.
373-
int64_t doneStrideIdx = 0;
374-
// If we saw a dynamic dimension, we need to fix-up all the strides up to
375-
// that dimension with the dynamic size.
376-
if (dynSizeIdx) {
377-
int64_t productOfAllStaticSizes = currentStride;
378-
assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
379-
"We shouldn't be able to change dynamicity");
380-
OpFoldResult origSize = origSizes[groupId];
381-
382-
AffineExpr s0 = builder.getAffineSymbolExpr(0);
383-
AffineExpr s1 = builder.getAffineSymbolExpr(1);
384-
for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
385-
int64_t baseExpandedStride =
386-
cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
387-
.getInt();
388-
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
389-
builder, expandShape.getLoc(),
390-
(s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
391-
{origSize, origStride});
392-
}
393-
}
394-
395-
// Now apply the origStride to the remaining dimensions.
336+
OpFoldResult origStride = origStrides[groupId];
396337
AffineExpr s0 = builder.getAffineSymbolExpr(0);
397-
for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
398-
int64_t baseExpandedStride =
399-
cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
400-
.getInt();
401-
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
402-
builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
338+
AffineExpr s1 = builder.getAffineSymbolExpr(1);
339+
int64_t resultOffset;
340+
SmallVector<int64_t, 4> resultStrides;
341+
(void)expandShapeType.getStridesAndOffset(resultStrides, resultOffset);
342+
expandedStrides[groupSize - 1] =
343+
!ShapedType::isDynamic(resultStrides[currIdx + groupSize - 1])
344+
? builder.getIndexAttr(resultStrides[currIdx + groupSize - 1])
345+
: origStride;
346+
OpFoldResult currentStride = builder.getIndexAttr(1);
347+
for (int i = groupSize - 2; i >= 0; i--) {
348+
unsigned index = reassocGroup[i + 1];
349+
// Multiply `currentStride` with `dimSize`.
350+
currentStride =
351+
expandShapeType.isDynamicDim(index)
352+
? makeComposedFoldedAffineApply(builder, loc, s0 * s1,
353+
{currentStride, dynSizes[index]})
354+
: makeComposedFoldedAffineApply(
355+
builder, loc, s0 * expandShapeType.getDimSize(index),
356+
{currentStride});
357+
// Multiply `origStride` to all the strides in reassociation current group.
358+
expandedStrides[i] = makeComposedFoldedAffineApply(
359+
builder, loc, s0 * s1, {currentStride, origStride});
360+
}
361+
for (unsigned i = 0; i < groupSize; i++) {
362+
if (!ShapedType::isDynamic(resultStrides[currIdx + i]))
363+
expandedStrides[i] = builder.getIndexAttr(resultStrides[currIdx + i]);
403364
}
404-
405365
return expandedStrides;
406366
}
407367

0 commit comments

Comments
 (0)