@@ -253,14 +253,9 @@ struct ExtractStridedMetadataOpSubviewFolder
253
253
// / Compute the expanded sizes of the given \p expandShape for the
254
254
// / \p groupId-th reassociation group.
255
255
// / \p origSizes holds the sizes of the source shape as values.
256
- // / This is used to compute the new sizes in cases of dynamic shapes.
257
- // /
258
- // / sizes#i =
259
- // / baseSizes#groupId / product(expandShapeSizes#j,
260
- // / for j in group excluding reassIdx#i)
261
- // / Where reassIdx#i is the reassociation index at index i in \p groupId.
262
- // /
263
- // / \post result.size() == expandShape.getReassociationIndices()[groupId].size()
256
+ // / For static dim sizes, we take the values from the result type
257
+ // / of \p expandShape. For dynamic dims, we take the values from the
258
+ // / output_shape operands.
264
259
// /
265
260
// / TODO: Move this utility function directly within ExpandShapeOp. For now,
266
261
// / this is not possible because this function uses the Affine dialect and the
@@ -275,42 +270,27 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
275
270
276
271
unsigned groupSize = reassocGroup.size ();
277
272
SmallVector<OpFoldResult> expandedSizes (groupSize);
278
-
279
- uint64_t productOfAllStaticSizes = 1 ;
280
- std::optional<unsigned > dynSizeIdx;
281
273
MemRefType expandShapeType = expandShape.getResultType ();
282
-
283
- // Fill up all the statically known sizes.
274
+ DenseMap<unsigned , Value> dynSizes;
275
+ Operation::operand_range dynOutShapes = expandShape.getOutputShape ();
276
+ for (unsigned i = 0 , dynCount = 0 ; i < expandShapeType.getRank (); i++) {
277
+ if (expandShapeType.isDynamicDim (i))
278
+ dynSizes[i] = dynOutShapes[dynCount++];
279
+ }
284
280
for (unsigned i = 0 ; i < groupSize; ++i) {
285
- uint64_t dimSize = expandShapeType.getDimSize (reassocGroup[i]);
281
+ unsigned index = reassocGroup[i];
282
+ uint64_t dimSize = expandShapeType.getDimSize (index);
286
283
if (ShapedType::isDynamic (dimSize)) {
287
- assert (!dynSizeIdx && " There must be at most one dynamic size per group" );
288
- dynSizeIdx = i;
284
+ expandedSizes[i] = dynSizes[index];
289
285
continue ;
290
286
}
291
- productOfAllStaticSizes *= dimSize;
292
287
expandedSizes[i] = builder.getIndexAttr (dimSize);
293
288
}
294
-
295
- // Compute the dynamic size using the original size and all the other known
296
- // static sizes:
297
- // expandSize = origSize / productOfAllStaticSizes.
298
- if (dynSizeIdx) {
299
- AffineExpr s0 = builder.getAffineSymbolExpr (0 );
300
- expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply (
301
- builder, expandShape.getLoc (), s0.floorDiv (productOfAllStaticSizes),
302
- origSizes[groupId]);
303
- }
304
-
305
289
return expandedSizes;
306
290
}
307
291
308
292
// / Compute the expanded strides of the given \p expandShape for the
309
293
// / \p groupId-th reassociation group.
310
- // / \p origStrides and \p origSizes hold respectively the strides and sizes
311
- // / of the source shape as values.
312
- // / This is used to compute the strides in cases of dynamic shapes and/or
313
- // / dynamic stride for this reassociation group.
314
294
// /
315
295
// / strides#i =
316
296
// / origStrides#reassDim * product(expandShapeSizes#j, for j in
@@ -320,11 +300,8 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
320
300
// / and expandShapeSizes#j is either:
321
301
// / - The constant size at dimension j, derived directly from the result type of
322
302
// / the expand_shape op, or
323
- // / - An affine expression: baseSizes#reassDim / product of all constant sizes
324
- // / in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
325
- // / element.)
326
- // /
327
- // / \post result.size() == expandShape.getReassociationIndices()[groupId].size()
303
+ // / - The dynamic size at dimension j, derived from the output_shape operands
304
+ // / of the expand_shape op.
328
305
// /
329
306
// / TODO: Move this utility function directly within ExpandShapeOp. For now,
330
307
// / this is not possible because this function uses the Affine dialect and the
@@ -334,74 +311,58 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
334
311
ArrayRef<OpFoldResult> origSizes,
335
312
ArrayRef<OpFoldResult> origStrides,
336
313
unsigned groupId) {
337
- SmallVector<int64_t , 2 > reassocGroup =
338
- expandShape.getReassociationIndices ()[groupId];
314
+ auto reassocIndices = expandShape.getReassociationIndices ();
315
+ unsigned currIdx = 0 ;
316
+ for (unsigned i = 0 ; i < groupId; i++)
317
+ currIdx += reassocIndices[i].size ();
318
+ SmallVector<int64_t , 2 > reassocGroup = reassocIndices[groupId];
339
319
assert (!reassocGroup.empty () &&
340
320
" Reassociation group should have at least one dimension" );
341
321
342
322
unsigned groupSize = reassocGroup.size ();
343
323
MemRefType expandShapeType = expandShape.getResultType ();
344
-
345
- std::optional<int64_t > dynSizeIdx;
346
-
347
324
// Fill up the expanded strides, with the information we can deduce from the
348
325
// resulting shape.
349
- uint64_t currentStride = 1 ;
326
+ Location loc = expandShape. getLoc () ;
350
327
SmallVector<OpFoldResult> expandedStrides (groupSize);
351
- for (int i = groupSize - 1 ; i >= 0 ; --i) {
352
- expandedStrides[i] = builder.getIndexAttr (currentStride);
353
- uint64_t dimSize = expandShapeType.getDimSize (reassocGroup[i]);
354
- if (ShapedType::isDynamic (dimSize)) {
355
- assert (!dynSizeIdx && " There must be at most one dynamic size per group" );
356
- dynSizeIdx = i;
357
- continue ;
358
- }
359
-
360
- currentStride *= dimSize;
328
+ DenseMap<int , Value> dynSizes;
329
+ unsigned dynCount = 0 ;
330
+ Operation::operand_range dynOutShapes = expandShape.getOutputShape ();
331
+ for (unsigned i = 0 ; i < expandShapeType.getRank (); i++) {
332
+ if (expandShapeType.isDynamicDim (i))
333
+ dynSizes[i] = dynOutShapes[dynCount++];
361
334
}
362
-
363
- // Collect the statically known information about the original stride.
364
- Value source = expandShape.getSrc ();
365
- auto sourceType = cast<MemRefType>(source.getType ());
366
- auto [strides, offset] = sourceType.getStridesAndOffset ();
367
-
368
- OpFoldResult origStride = ShapedType::isDynamic (strides[groupId])
369
- ? origStrides[groupId]
370
- : builder.getIndexAttr (strides[groupId]);
371
-
372
- // Apply the original stride to all the strides.
373
- int64_t doneStrideIdx = 0 ;
374
- // If we saw a dynamic dimension, we need to fix-up all the strides up to
375
- // that dimension with the dynamic size.
376
- if (dynSizeIdx) {
377
- int64_t productOfAllStaticSizes = currentStride;
378
- assert (ShapedType::isDynamic (sourceType.getDimSize (groupId)) &&
379
- " We shouldn't be able to change dynamicity" );
380
- OpFoldResult origSize = origSizes[groupId];
381
-
382
- AffineExpr s0 = builder.getAffineSymbolExpr (0 );
383
- AffineExpr s1 = builder.getAffineSymbolExpr (1 );
384
- for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
385
- int64_t baseExpandedStride =
386
- cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
387
- .getInt ();
388
- expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply (
389
- builder, expandShape.getLoc (),
390
- (s0 * baseExpandedStride).floorDiv (productOfAllStaticSizes) * s1,
391
- {origSize, origStride});
392
- }
393
- }
394
-
395
- // Now apply the origStride to the remaining dimensions.
335
+ OpFoldResult origStride = origStrides[groupId];
396
336
AffineExpr s0 = builder.getAffineSymbolExpr (0 );
397
- for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
398
- int64_t baseExpandedStride =
399
- cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
400
- .getInt ();
401
- expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply (
402
- builder, expandShape.getLoc (), s0 * baseExpandedStride, {origStride});
337
+ AffineExpr s1 = builder.getAffineSymbolExpr (1 );
338
+ int64_t resultOffset;
339
+ SmallVector<int64_t , 4 > resultStrides;
340
+ (void )expandShapeType.getStridesAndOffset (resultStrides, resultOffset);
341
+ expandedStrides[groupSize - 1 ] =
342
+ !ShapedType::isDynamic (resultStrides[currIdx + groupSize - 1 ])
343
+ ? builder.getIndexAttr (resultStrides[currIdx + groupSize - 1 ])
344
+ : origStride;
345
+ OpFoldResult currentStride = builder.getIndexAttr (1 );
346
+ for (int i = groupSize - 2 ; i >= 0 ; i--) {
347
+ unsigned index = reassocGroup[i + 1 ];
348
+ // Multiply `currentStride` with `dimSize`.
349
+ currentStride =
350
+ expandShapeType.isDynamicDim (index)
351
+ ? makeComposedFoldedAffineApply (builder, loc, s0 * s1,
352
+ {currentStride, dynSizes[index]})
353
+ : makeComposedFoldedAffineApply (
354
+ builder, loc, s0 * expandShapeType.getDimSize (index),
355
+ {currentStride});
356
// Multiply all the strides in the current reassociation group by `origStride`.
357
+ expandedStrides[i] = makeComposedFoldedAffineApply (
358
+ builder, loc, s0 * s1, {currentStride, origStride});
403
359
}
360
+ for (unsigned i = 0 ; i < groupSize; i++)
361
+ {
362
+ if (!ShapedType::isDynamic (resultStrides[currIdx + i]))
363
+ expandedStrides[i] = builder.getIndexAttr (resultStrides[currIdx + i]);
404
364
365
+ }
405
366
return expandedStrides;
406
367
}
407
368
0 commit comments