@@ -2209,16 +2209,25 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
  if (layout_in != layout) {
    return op.emitOpError("Not implemented: unsupported layout for input");
  }
-  if (layout_out != layout) {
+  // We support non-zero offsets in the output layout via lazy rotation.
+  if (layout_out.bitwidth() != layout.bitwidth() ||
+      layout_out.tiling() != layout.tiling() ||
+      layout_out.implicit_dim() != layout.implicit_dim()) {
    return op.emitOpError("Not implemented: unsupported layout for output");
  }
  auto vty = op.getResult().getType();
  if (vty.getRank() < 2) {
    return op.emitOpError("Not implemented: unsupported 1D shape");
  }
-  if (*(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) != 0 ||
-      *(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) != 0) {
-    return op.emitOpError("Not implemented: unsupported unaliged shape");
+  if (*(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) != 0 &&
+      op.getDimension() == vty.getRank() - 2) {
+    return op.emitOpError(
+        "Not implemented: unsupported unaligned shape in sublane dimension");
+  }
+  if (*(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) != 0 &&
+      op.getStride().has_value()) {
+    return op.emitOpError(
+        "Not implemented: unsupported unaligned shape in lane dimension");
  }

  ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
@@ -2345,6 +2354,63 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
    return concatenate(chunks, axis);
  };

+  // Applies lazy rotation (see go/pltpu-roll for details).
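+  // Roughly: the vreg sequence along the rotated axis is doubled, masked
+  // dynamic rotates and selects overwrite the padding of the partially
+  // filled vreg with valid data, and the rolled result is assembled as a
+  // contiguous window of vregs starting at the vreg that holds element
+  // start_idx. Any remaining sub-vreg shift is presumably carried by the
+  // non-zero offsets now permitted in the output layout (see the check
+  // above).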
+  auto lazyRotate = [&](const xla::Array<Value> &vregs, int64_t shift,
+                        int axis) {
+    const int tiling_dim = axis - (vregs.num_dimensions() - 2);
+    const int64_t tile_size = ctx.target_shape[tiling_dim];
+    const int64_t input_size = vty.getShape()[axis];
+    const int64_t normalized_shift = shift % input_size;
+    const int64_t start_idx = input_size - normalized_shift;
+    const int64_t start_vreg_idx = start_idx / tile_size;
+    const int64_t valid_amount = input_size % tile_size;
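+    // For example, rolling an axis of size 10 by shift = 3 with 8-sublane
+    // vregs gives normalized_shift = 3, start_idx = 7, start_vreg_idx = 0,
+    // and valid_amount = 2 (the number of valid rows in the last vreg).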
+
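+    // Doubling the vreg sequence lets the rolled window of vregs wrap past
+    // the end without extra bookkeeping; front_chunk_copy keeps the first
+    // vreg as it was before the in-place updates below.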
+    auto concat = concatenate({vregs, vregs}, axis);
+    auto chunks = split(concat, axis);
+    int64_t original_num_chunks = chunks.size() / 2;
+    xla::Array<Value> front_chunk_copy(chunks.front());
+
+    Value rotate_amount = mlirI32Const(valid_amount);
+    auto iota = builder.create<tpu::IotaOp>(
+        i32_vreg, builder.getI32IntegerAttr(tiling_dim));
+    auto mask = builder.create<arith::CmpIOp>(
+        arith::CmpIPredicate::sge, iota,
+        builder.create<arith::ConstantOp>(DenseElementsAttr::get(
+            i32_vreg, builder.getI32IntegerAttr(valid_amount))));
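+    // mask is true for positions >= valid_amount along tiling_dim, i.e.
+    // exactly the positions that hold padding in a partially filled vreg.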
+    // overwrite padding in the last vreg with valid data from the first vreg.
+    chunks.back().Each([&](absl::Span<const int64_t> idxs, Value *v) {
+      *v = builder.create<arith::SelectOp>(
+          mask,
+          builder.create<tpu::DynamicRotateOp>(
+              res_vreg_ty, front_chunk_copy(idxs), rotate_amount, tiling_dim,
+              nullptr, nullptr),
+          *v);
+    });
+    // rotate the vregs starting from the middle vreg.
+    for (int64_t i = original_num_chunks; i < chunks.size(); ++i) {
+      chunks[i].Each([&](absl::Span<const int64_t> idxs, Value *v) {
+        *v = builder.create<tpu::DynamicRotateOp>(
+            res_vreg_ty, *v, rotate_amount, tiling_dim, nullptr, nullptr);
+      });
+    }
+    // blend the vregs to overwrite the padding.
+    for (int64_t i = original_num_chunks - 1; i < chunks.size() - 1; ++i) {
+      chunks[i].Each([&](absl::Span<const int64_t> idxs, Value *v) {
+        *v = builder.create<arith::SelectOp>(mask, chunks[i + 1](idxs), *v);
+      });
+    }
+    SmallVector<int64_t> result_dimensions =
+        layout_out.tileArrayImplicitShape(vty.getShape(), ctx.target_shape);
+    // assemble the result
+    xla::Array<Value> result(result_dimensions);
+    SmallVector<int64_t> starts(result.num_dimensions(), 0);
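+    // The roll itself is realized by copying a window of vregs out of the
+    // doubled, fixed-up sequence, starting at the vreg that holds element
+    // start_idx.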
+    for (int64_t i = 0; i < result_dimensions[axis]; ++i) {
+      starts[axis] = i;
+      result.UpdateSlice(chunks[i + start_vreg_idx], starts);
+    }
+    return result;
+  };
+
  std::function<xla::Array<Value>(const xla::Array<Value> &, Value, int, int)>
      rotate;
  rotate = [&](const xla::Array<Value> &vregs, Value shift, int axis,
@@ -2353,7 +2419,15 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
    CHECK(axis >= 0 && axis < vregs.num_dimensions());
    int tiling_dim = axis - (vregs.num_dimensions() - 2);
    CHECK((tiling_dim != 1 && stride == 0) || (tiling_dim == 1 && stride >= 0));
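+    // has_padding: the rotated axis does not fill its last vreg exactly,
+    // e.g. an axis of size 10 tiled into 8-sublane vregs.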
+    const bool has_padding =
+        (tiling_dim == 0 || tiling_dim == 1) &&
+        vty.getShape()[axis] % ctx.target_shape[tiling_dim] != 0;
    SmallVector<xla::Array<Value>, 4> chunks;
+    // Handle rotation with static shift and padding lazily.
+    if (auto shift_cst = getIntConst(shift, /*silent=*/true);
+        succeeded(shift_cst) && has_padding) {
+      return lazyRotate(vregs, shift_cst.value(), axis);
+    }
    // Handle rotation with static shift.
    if (auto shift_cst = getIntConst(shift, /*silent=*/true);
        succeeded(shift_cst)) {
@@ -2445,7 +2519,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
      roll_by *= 2;
    }
    return result;
-  };
+  };  // end of rotate

  xla::Array<Value> out_tiles(in_tiles.dimensions());
  const auto dim = op.getDimension();