@@ -253,14 +253,9 @@ struct ExtractStridedMetadataOpSubviewFolder
253
253
// / Compute the expanded sizes of the given \p expandShape for the
254
254
// / \p groupId-th reassociation group.
255
255
// / \p origSizes hold the sizes of the source shape as values.
256
- // / This is used to compute the new sizes in cases of dynamic shapes.
257
- // /
258
- // / sizes#i =
259
- // / baseSizes#groupId / product(expandShapeSizes#j,
260
- // / for j in group excluding reassIdx#i)
261
- // / Where reassIdx#i is the reassociation index at index i in \p groupId.
262
- // /
263
- // / \post result.size() == expandShape.getReassociationIndices()[groupId].size()
256
+ // / For static dim sizes, we take the values from the result type
257
+ // / of \p expandShape. For dynamic dims, we take the values from the
258
+ // / output_shape attribute.
264
259
// /
265
260
// / TODO: Move this utility function directly within ExpandShapeOp. For now,
266
261
// / this is not possible because this function uses the Affine dialect and the
@@ -275,42 +270,28 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
275
270
276
271
unsigned groupSize = reassocGroup.size ();
277
272
SmallVector<OpFoldResult> expandedSizes (groupSize);
278
-
279
- uint64_t productOfAllStaticSizes = 1 ;
280
- std::optional<unsigned > dynSizeIdx;
281
273
MemRefType expandShapeType = expandShape.getResultType ();
282
-
283
- // Fill up all the statically known sizes.
274
+ DenseMap<unsigned , Value> dynSizes;
275
+ Operation::operand_range dynOutShapes = expandShape.getOutputShape ();
276
+ for (unsigned i = 0 , dynCount = 0 , e = expandShapeType.getRank (); i < e;
277
+ i++) {
278
+ if (expandShapeType.isDynamicDim (i))
279
+ dynSizes[i] = dynOutShapes[dynCount++];
280
+ }
284
281
for (unsigned i = 0 ; i < groupSize; ++i) {
285
- uint64_t dimSize = expandShapeType.getDimSize (reassocGroup[i]);
282
+ unsigned index = reassocGroup[i];
283
+ uint64_t dimSize = expandShapeType.getDimSize (index);
286
284
if (ShapedType::isDynamic (dimSize)) {
287
- assert (!dynSizeIdx && " There must be at most one dynamic size per group" );
288
- dynSizeIdx = i;
285
+ expandedSizes[i] = dynSizes[index];
289
286
continue ;
290
287
}
291
- productOfAllStaticSizes *= dimSize;
292
288
expandedSizes[i] = builder.getIndexAttr (dimSize);
293
289
}
294
-
295
- // Compute the dynamic size using the original size and all the other known
296
- // static sizes:
297
- // expandSize = origSize / productOfAllStaticSizes.
298
- if (dynSizeIdx) {
299
- AffineExpr s0 = builder.getAffineSymbolExpr (0 );
300
- expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply (
301
- builder, expandShape.getLoc (), s0.floorDiv (productOfAllStaticSizes),
302
- origSizes[groupId]);
303
- }
304
-
305
290
return expandedSizes;
306
291
}
307
292
308
293
// / Compute the expanded strides of the given \p expandShape for the
309
294
// / \p groupId-th reassociation group.
310
- // / \p origStrides and \p origSizes hold respectively the strides and sizes
311
- // / of the source shape as values.
312
- // / This is used to compute the strides in cases of dynamic shapes and/or
313
- // / dynamic stride for this reassociation group.
314
295
// /
315
296
// / strides#i =
316
297
// / origStrides#reassDim * product(expandShapeSizes#j, for j in
@@ -320,11 +301,8 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
320
301
// / and expandShapeSizes#j is either:
321
302
// / - The constant size at dimension j, derived directly from the result type of
322
303
// / the expand_shape op, or
323
- // / - An affine expression: baseSizes#reassDim / product of all constant sizes
324
- // / in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
325
- // / element.)
326
- // /
327
- // / \post result.size() == expandShape.getReassociationIndices()[groupId].size()
304
+ // / - The dynamic size at dimension j, derived from the output_shape attribute
305
+ // / of the expand shape op.
328
306
// /
329
307
// / TODO: Move this utility function directly within ExpandShapeOp. For now,
330
308
// / this is not possible because this function uses the Affine dialect and the
@@ -334,74 +312,56 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
334
312
ArrayRef<OpFoldResult> origSizes,
335
313
ArrayRef<OpFoldResult> origStrides,
336
314
unsigned groupId) {
337
- SmallVector<int64_t , 2 > reassocGroup =
338
- expandShape.getReassociationIndices ()[groupId];
315
+ auto reassocIndices = expandShape.getReassociationIndices ();
316
+ unsigned currIdx = 0 ;
317
+ for (unsigned i = 0 ; i < groupId; i++)
318
+ currIdx += reassocIndices[i].size ();
319
+ SmallVector<int64_t , 2 > reassocGroup = reassocIndices[groupId];
339
320
assert (!reassocGroup.empty () &&
340
321
" Reassociation group should have at least one dimension" );
341
322
342
323
unsigned groupSize = reassocGroup.size ();
343
324
MemRefType expandShapeType = expandShape.getResultType ();
344
-
345
- std::optional<int64_t > dynSizeIdx;
346
-
347
325
// Fill up the expanded strides, with the information we can deduce from the
348
326
// resulting shape.
349
- uint64_t currentStride = 1 ;
327
+ Location loc = expandShape. getLoc () ;
350
328
SmallVector<OpFoldResult> expandedStrides (groupSize);
351
- for (int i = groupSize - 1 ; i >= 0 ; --i) {
352
- expandedStrides[i] = builder.getIndexAttr (currentStride);
353
- uint64_t dimSize = expandShapeType.getDimSize (reassocGroup[i]);
354
- if (ShapedType::isDynamic (dimSize)) {
355
- assert (!dynSizeIdx && " There must be at most one dynamic size per group" );
356
- dynSizeIdx = i;
357
- continue ;
358
- }
359
-
360
- currentStride *= dimSize;
329
+ DenseMap<int , Value> dynSizes;
330
+ unsigned dynCount = 0 ;
331
+ Operation::operand_range dynOutShapes = expandShape.getOutputShape ();
332
+ for (unsigned i = 0 , e = expandShapeType.getRank (); i < e; i++) {
333
+ if (expandShapeType.isDynamicDim (i))
334
+ dynSizes[i] = dynOutShapes[dynCount++];
361
335
}
362
-
363
- // Collect the statically known information about the original stride.
364
- Value source = expandShape.getSrc ();
365
- auto sourceType = cast<MemRefType>(source.getType ());
366
- auto [strides, offset] = sourceType.getStridesAndOffset ();
367
-
368
- OpFoldResult origStride = ShapedType::isDynamic (strides[groupId])
369
- ? origStrides[groupId]
370
- : builder.getIndexAttr (strides[groupId]);
371
-
372
- // Apply the original stride to all the strides.
373
- int64_t doneStrideIdx = 0 ;
374
- // If we saw a dynamic dimension, we need to fix-up all the strides up to
375
- // that dimension with the dynamic size.
376
- if (dynSizeIdx) {
377
- int64_t productOfAllStaticSizes = currentStride;
378
- assert (ShapedType::isDynamic (sourceType.getDimSize (groupId)) &&
379
- " We shouldn't be able to change dynamicity" );
380
- OpFoldResult origSize = origSizes[groupId];
381
-
382
- AffineExpr s0 = builder.getAffineSymbolExpr (0 );
383
- AffineExpr s1 = builder.getAffineSymbolExpr (1 );
384
- for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
385
- int64_t baseExpandedStride =
386
- cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
387
- .getInt ();
388
- expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply (
389
- builder, expandShape.getLoc (),
390
- (s0 * baseExpandedStride).floorDiv (productOfAllStaticSizes) * s1,
391
- {origSize, origStride});
392
- }
393
- }
394
-
395
- // Now apply the origStride to the remaining dimensions.
336
+ OpFoldResult origStride = origStrides[groupId];
396
337
AffineExpr s0 = builder.getAffineSymbolExpr (0 );
397
- for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
398
- int64_t baseExpandedStride =
399
- cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
400
- .getInt ();
401
- expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply (
402
- builder, expandShape.getLoc (), s0 * baseExpandedStride, {origStride});
338
+ AffineExpr s1 = builder.getAffineSymbolExpr (1 );
339
+ int64_t resultOffset;
340
+ SmallVector<int64_t , 4 > resultStrides;
341
+ (void )expandShapeType.getStridesAndOffset (resultStrides, resultOffset);
342
+ expandedStrides[groupSize - 1 ] =
343
+ !ShapedType::isDynamic (resultStrides[currIdx + groupSize - 1 ])
344
+ ? builder.getIndexAttr (resultStrides[currIdx + groupSize - 1 ])
345
+ : origStride;
346
+ OpFoldResult currentStride = builder.getIndexAttr (1 );
347
+ for (int i = groupSize - 2 ; i >= 0 ; i--) {
348
+ unsigned index = reassocGroup[i + 1 ];
349
+ // Multiply `currentStride` with `dimSize`.
350
+ currentStride =
351
+ expandShapeType.isDynamicDim (index)
352
+ ? makeComposedFoldedAffineApply (builder, loc, s0 * s1,
353
+ {currentStride, dynSizes[index]})
354
+ : makeComposedFoldedAffineApply (
355
+ builder, loc, s0 * expandShapeType.getDimSize (index),
356
+ {currentStride});
357
+ // Multiply `origStride` to all the strides in reassociation current group.
358
+ expandedStrides[i] = makeComposedFoldedAffineApply (
359
+ builder, loc, s0 * s1, {currentStride, origStride});
360
+ }
361
+ for (unsigned i = 0 ; i < groupSize; i++) {
362
+ if (!ShapedType::isDynamic (resultStrides[currIdx + i]))
363
+ expandedStrides[i] = builder.getIndexAttr (resultStrides[currIdx + i]);
403
364
}
404
-
405
365
return expandedStrides;
406
366
}
407
367
0 commit comments