Skip to content

Commit ddf4d0c

Browse files
Improve expand-strided-metadata pass
Add support for obtaining the expanded sizes and strides of the result of a memref.expand_shape op that has multiple dynamic dims within a single reassociation group.
1 parent 4460766 commit ddf4d0c

File tree

4 files changed

+111
-211
lines changed

4 files changed

+111
-211
lines changed

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 54 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -253,14 +253,9 @@ struct ExtractStridedMetadataOpSubviewFolder
253253
/// Compute the expanded sizes of the given \p expandShape for the
254254
/// \p groupId-th reassociation group.
255255
/// \p origSizes holds the sizes of the source shape as values.
256-
/// This is used to compute the new sizes in cases of dynamic shapes.
257-
///
258-
/// sizes#i =
259-
/// baseSizes#groupId / product(expandShapeSizes#j,
260-
/// for j in group excluding reassIdx#i)
261-
/// Where reassIdx#i is the reassociation index at index i in \p groupId.
262-
///
263-
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
256+
/// For static dim sizes, we take the values from the result type
257+
/// of \p expandShape. For dynamic dims, we take the values from the
258+
/// output_shape operands.
264259
///
265260
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
266261
/// this is not possible because this function uses the Affine dialect and the
@@ -275,42 +270,27 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
275270

276271
unsigned groupSize = reassocGroup.size();
277272
SmallVector<OpFoldResult> expandedSizes(groupSize);
278-
279-
uint64_t productOfAllStaticSizes = 1;
280-
std::optional<unsigned> dynSizeIdx;
281273
MemRefType expandShapeType = expandShape.getResultType();
282-
283-
// Fill up all the statically known sizes.
274+
DenseMap<unsigned, Value> dynSizes;
275+
Operation::operand_range dynOutShapes = expandShape.getOutputShape();
276+
for (unsigned i = 0, dynCount = 0; i < expandShapeType.getRank(); i++) {
277+
if (expandShapeType.isDynamicDim(i))
278+
dynSizes[i] = dynOutShapes[dynCount++];
279+
}
284280
for (unsigned i = 0; i < groupSize; ++i) {
285-
uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
281+
unsigned index = reassocGroup[i];
282+
uint64_t dimSize = expandShapeType.getDimSize(index);
286283
if (ShapedType::isDynamic(dimSize)) {
287-
assert(!dynSizeIdx && "There must be at most one dynamic size per group");
288-
dynSizeIdx = i;
284+
expandedSizes[i] = dynSizes[index];
289285
continue;
290286
}
291-
productOfAllStaticSizes *= dimSize;
292287
expandedSizes[i] = builder.getIndexAttr(dimSize);
293288
}
294-
295-
// Compute the dynamic size using the original size and all the other known
296-
// static sizes:
297-
// expandSize = origSize / productOfAllStaticSizes.
298-
if (dynSizeIdx) {
299-
AffineExpr s0 = builder.getAffineSymbolExpr(0);
300-
expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
301-
builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
302-
origSizes[groupId]);
303-
}
304-
305289
return expandedSizes;
306290
}
307291

308292
/// Compute the expanded strides of the given \p expandShape for the
309293
/// \p groupId-th reassociation group.
310-
/// \p origStrides and \p origSizes hold respectively the strides and sizes
311-
/// of the source shape as values.
312-
/// This is used to compute the strides in cases of dynamic shapes and/or
313-
/// dynamic stride for this reassociation group.
314294
///
315295
/// strides#i =
316296
/// origStrides#reassDim * product(expandShapeSizes#j, for j in
@@ -320,11 +300,8 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
320300
/// and expandShapeSizes#j is either:
321301
/// - The constant size at dimension j, derived directly from the result type of
322302
/// the expand_shape op, or
323-
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
324-
/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
325-
/// element.)
326-
///
327-
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
303+
/// - The dynamic size at dimension j, derived from the output_shape operands
304+
/// of the expand_shape op.
328305
///
329306
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
330307
/// this is not possible because this function uses the Affine dialect and the
@@ -334,74 +311,58 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
334311
ArrayRef<OpFoldResult> origSizes,
335312
ArrayRef<OpFoldResult> origStrides,
336313
unsigned groupId) {
337-
SmallVector<int64_t, 2> reassocGroup =
338-
expandShape.getReassociationIndices()[groupId];
314+
auto reassocIndices = expandShape.getReassociationIndices();
315+
unsigned currIdx = 0;
316+
for (unsigned i = 0; i < groupId; i++)
317+
currIdx += reassocIndices[i].size();
318+
SmallVector<int64_t, 2> reassocGroup = reassocIndices[groupId];
339319
assert(!reassocGroup.empty() &&
340320
"Reassociation group should have at least one dimension");
341321

342322
unsigned groupSize = reassocGroup.size();
343323
MemRefType expandShapeType = expandShape.getResultType();
344-
345-
std::optional<int64_t> dynSizeIdx;
346-
347324
// Fill up the expanded strides, with the information we can deduce from the
348325
// resulting shape.
349-
uint64_t currentStride = 1;
326+
Location loc = expandShape.getLoc();
350327
SmallVector<OpFoldResult> expandedStrides(groupSize);
351-
for (int i = groupSize - 1; i >= 0; --i) {
352-
expandedStrides[i] = builder.getIndexAttr(currentStride);
353-
uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
354-
if (ShapedType::isDynamic(dimSize)) {
355-
assert(!dynSizeIdx && "There must be at most one dynamic size per group");
356-
dynSizeIdx = i;
357-
continue;
358-
}
359-
360-
currentStride *= dimSize;
328+
DenseMap<int, Value> dynSizes;
329+
unsigned dynCount = 0;
330+
Operation::operand_range dynOutShapes = expandShape.getOutputShape();
331+
for (unsigned i = 0; i < expandShapeType.getRank(); i++) {
332+
if (expandShapeType.isDynamicDim(i))
333+
dynSizes[i] = dynOutShapes[dynCount++];
361334
}
362-
363-
// Collect the statically known information about the original stride.
364-
Value source = expandShape.getSrc();
365-
auto sourceType = cast<MemRefType>(source.getType());
366-
auto [strides, offset] = sourceType.getStridesAndOffset();
367-
368-
OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
369-
? origStrides[groupId]
370-
: builder.getIndexAttr(strides[groupId]);
371-
372-
// Apply the original stride to all the strides.
373-
int64_t doneStrideIdx = 0;
374-
// If we saw a dynamic dimension, we need to fix-up all the strides up to
375-
// that dimension with the dynamic size.
376-
if (dynSizeIdx) {
377-
int64_t productOfAllStaticSizes = currentStride;
378-
assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
379-
"We shouldn't be able to change dynamicity");
380-
OpFoldResult origSize = origSizes[groupId];
381-
382-
AffineExpr s0 = builder.getAffineSymbolExpr(0);
383-
AffineExpr s1 = builder.getAffineSymbolExpr(1);
384-
for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
385-
int64_t baseExpandedStride =
386-
cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
387-
.getInt();
388-
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
389-
builder, expandShape.getLoc(),
390-
(s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
391-
{origSize, origStride});
392-
}
393-
}
394-
395-
// Now apply the origStride to the remaining dimensions.
335+
OpFoldResult origStride = origStrides[groupId];
396336
AffineExpr s0 = builder.getAffineSymbolExpr(0);
397-
for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
398-
int64_t baseExpandedStride =
399-
cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
400-
.getInt();
401-
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
402-
builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
337+
AffineExpr s1 = builder.getAffineSymbolExpr(1);
338+
int64_t resultOffset;
339+
SmallVector<int64_t, 4> resultStrides;
340+
(void)expandShapeType.getStridesAndOffset(resultStrides, resultOffset);
341+
expandedStrides[groupSize - 1] =
342+
!ShapedType::isDynamic(resultStrides[currIdx + groupSize - 1])
343+
? builder.getIndexAttr(resultStrides[currIdx + groupSize - 1])
344+
: origStride;
345+
OpFoldResult currentStride = builder.getIndexAttr(1);
346+
for (int i = groupSize - 2; i >= 0; i--) {
347+
unsigned index = reassocGroup[i + 1];
348+
// Multiply `currentStride` by the size of dimension `index`.
349+
currentStride =
350+
expandShapeType.isDynamicDim(index)
351+
? makeComposedFoldedAffineApply(builder, loc, s0 * s1,
352+
{currentStride, dynSizes[index]})
353+
: makeComposedFoldedAffineApply(
354+
builder, loc, s0 * expandShapeType.getDimSize(index),
355+
{currentStride});
356+
// Scale by `origStride` to get the stride of dimension i in the current
// reassociation group.
357+
expandedStrides[i] = makeComposedFoldedAffineApply(
358+
builder, loc, s0 * s1, {currentStride, origStride});
403359
}
360+
for (unsigned i = 0; i < groupSize; i++)
361+
{
362+
if (!ShapedType::isDynamic(resultStrides[currIdx + i]))
363+
expandedStrides[i] = builder.getIndexAttr(resultStrides[currIdx + i]);
404364

365+
}
405366
return expandedStrides;
406367
}
407368

mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -686,4 +686,4 @@ func.func @load_and_assume(
686686
memref.assume_alignment %arg0, 16 : memref<?x?xf32, strided<[?, ?], offset: ?>>
687687
%2 = memref.load %arg0[%i0, %i1] : memref<?x?xf32, strided<[?, ?], offset: ?>>
688688
func.return %2 : f32
689-
}
689+
}

0 commit comments

Comments
 (0)