Skip to content

Commit 03fb141

Browse files
Support pltpu.roll on sublanes when not all lanes are used.
PiperOrigin-RevId: 743964736
1 parent c11f2d1 commit 03fb141

File tree

3 files changed

+122
-8
lines changed

3 files changed

+122
-8
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

+79-5
Original file line numberDiff line numberDiff line change
@@ -2209,16 +2209,25 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
22092209
if (layout_in != layout) {
22102210
return op.emitOpError("Not implemented: unsupported layout for input");
22112211
}
2212-
if (layout_out != layout) {
2212+
// We support non-zero offsets in the output layout via lazy rotation.
2213+
if (layout_out.bitwidth() != layout.bitwidth() ||
2214+
layout_out.tiling() != layout.tiling() ||
2215+
layout_out.implicit_dim() != layout.implicit_dim()) {
22132216
return op.emitOpError("Not implemented: unsupported layout for output");
22142217
}
22152218
auto vty = op.getResult().getType();
22162219
if (vty.getRank() < 2) {
22172220
return op.emitOpError("Not implemented: unsupported 1D shape");
22182221
}
2219-
if (*(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) != 0 ||
2220-
*(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) != 0) {
2221-
return op.emitOpError("Not implemented: unsupported unaliged shape");
2222+
if (*(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) != 0 &&
2223+
op.getDimension() == vty.getRank() - 2) {
2224+
return op.emitOpError(
2225+
"Not implemented: unsupported unaligned shape in sublane dimension");
2226+
}
2227+
if (*(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) != 0 &&
2228+
op.getStride().has_value()) {
2229+
return op.emitOpError(
2230+
"Not implemented: unsupported unaligned shape in lane dimension");
22222231
}
22232232

22242233
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
@@ -2345,6 +2354,63 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
23452354
return concatenate(chunks, axis);
23462355
};
23472356

2357+
// Applies lazy rotation (see go/pltpu-roll for details).
2358+
auto lazyRotate = [&](const xla::Array<Value> &vregs, int64_t shift,
2359+
int axis) {
2360+
const int tiling_dim = axis - (vregs.num_dimensions() - 2);
2361+
const int64_t tile_size = ctx.target_shape[tiling_dim];
2362+
const int64_t input_size = vty.getShape()[axis];
2363+
const int64_t normalized_shift = shift % input_size;
2364+
const int64_t start_idx = input_size - normalized_shift;
2365+
const int64_t start_vreg_idx = start_idx / tile_size;
2366+
const int64_t valid_amount = input_size % tile_size;
2367+
2368+
auto concat = concatenate({vregs, vregs}, axis);
2369+
auto chunks = split(concat, axis);
2370+
int64_t original_num_chunks = chunks.size() / 2;
2371+
xla::Array<Value> front_chunk_copy(chunks.front());
2372+
2373+
Value rotate_amount = mlirI32Const(valid_amount);
2374+
auto iota = builder.create<tpu::IotaOp>(
2375+
i32_vreg, builder.getI32IntegerAttr(tiling_dim));
2376+
auto mask = builder.create<arith::CmpIOp>(
2377+
arith::CmpIPredicate::sge, iota,
2378+
builder.create<arith::ConstantOp>(DenseElementsAttr::get(
2379+
i32_vreg, builder.getI32IntegerAttr(valid_amount))));
2380+
// overwrite padding in the last vreg with valid data from the first vreg.
2381+
chunks.back().Each([&](absl::Span<const int64_t> idxs, Value *v) {
2382+
*v = builder.create<arith::SelectOp>(
2383+
mask,
2384+
builder.create<tpu::DynamicRotateOp>(
2385+
res_vreg_ty, front_chunk_copy(idxs), rotate_amount, tiling_dim,
2386+
nullptr, nullptr),
2387+
*v);
2388+
});
2389+
// rotate the vregs starting from the middle vreg.
2390+
for (int64_t i = original_num_chunks; i < chunks.size(); ++i) {
2391+
chunks[i].Each([&](absl::Span<const int64_t> idxs, Value *v) {
2392+
*v = builder.create<tpu::DynamicRotateOp>(
2393+
res_vreg_ty, *v, rotate_amount, tiling_dim, nullptr, nullptr);
2394+
});
2395+
}
2396+
// blend the vregs to overwrite the padding.
2397+
for (int64_t i = original_num_chunks - 1; i < chunks.size() - 1; ++i) {
2398+
chunks[i].Each([&](absl::Span<const int64_t> idxs, Value *v) {
2399+
*v = builder.create<arith::SelectOp>(mask, chunks[i + 1](idxs), *v);
2400+
});
2401+
}
2402+
SmallVector<int64_t> result_dimensions =
2403+
layout_out.tileArrayImplicitShape(vty.getShape(), ctx.target_shape);
2404+
// assemble the result
2405+
xla::Array<Value> result(result_dimensions);
2406+
SmallVector<int64_t> starts(result.num_dimensions(), 0);
2407+
for (int64_t i = 0; i < result_dimensions[axis]; ++i) {
2408+
starts[axis] = i;
2409+
result.UpdateSlice(chunks[i + start_vreg_idx], starts);
2410+
}
2411+
return result;
2412+
};
2413+
23482414
std::function<xla::Array<Value>(const xla::Array<Value> &, Value, int, int)>
23492415
rotate;
23502416
rotate = [&](const xla::Array<Value> &vregs, Value shift, int axis,
@@ -2353,7 +2419,15 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
23532419
CHECK(axis >= 0 && axis < vregs.num_dimensions());
23542420
int tiling_dim = axis - (vregs.num_dimensions() - 2);
23552421
CHECK((tiling_dim != 1 && stride == 0) || (tiling_dim == 1 && stride >= 0));
2422+
const bool has_padding =
2423+
(tiling_dim == 0 || tiling_dim == 1) &&
2424+
vty.getShape()[axis] % ctx.target_shape[tiling_dim] != 0;
23562425
SmallVector<xla::Array<Value>, 4> chunks;
2426+
// Handle rotation with static shift and padding lazily.
2427+
if (auto shift_cst = getIntConst(shift, /*silent=*/true);
2428+
succeeded(shift_cst) && has_padding) {
2429+
return lazyRotate(vregs, shift_cst.value(), axis);
2430+
}
23572431
// Handle rotation with static shift.
23582432
if (auto shift_cst = getIntConst(shift, /*silent=*/true);
23592433
succeeded(shift_cst)) {
@@ -2445,7 +2519,7 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
24452519
roll_by *= 2;
24462520
}
24472521
return result;
2448-
};
2522+
}; // end of rotate
24492523

24502524
xla::Array<Value> out_tiles(in_tiles.dimensions());
24512525
const auto dim = op.getDimension();

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

+43-1
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,46 @@ class VectorLayoutInferer {
748748
return success();
749749
}
750750

751+
// Helper function to compute the layout offsets for a dynamic rotate op.
752+
LayoutOffsets compute_layout_offsets(tpu::DynamicRotateOp op) {
753+
LayoutOffsets layout_offsets{0, 0};
754+
const unsigned int bitwidth = op.getType().getElementTypeBitWidth();
755+
const auto tiling = nativeTiling(bitwidth);
756+
const int tiling_dim = op.getDimension() - (op.getType().getRank() - 2);
757+
if (tiling_dim != 0 && tiling_dim != 1) {
758+
return layout_offsets;
759+
}
760+
const int64_t tile_size = tiling[tiling_dim];
761+
const int64_t dim_size = op.getType().getShape()[op.getDimension()];
762+
if (dim_size % tile_size == 0) {
763+
return layout_offsets;
764+
}
765+
auto amount = op.getAmount().getDefiningOp<arith::ConstantOp>();
766+
if (!amount) {
767+
return layout_offsets;
768+
}
769+
auto integer_attr = dyn_cast<IntegerAttr>(amount.getValue());
770+
if (!integer_attr) {
771+
return layout_offsets;
772+
}
773+
if (auto stride = op.getStride(); stride.has_value() && *stride != 0) {
774+
return layout_offsets;
775+
}
776+
if (tiling_dim != 0 && tiling_dim != 1) {
777+
return layout_offsets;
778+
}
779+
int64_t shift_amount = integer_attr.getValue().getSExtValue();
780+
// Normalize the shift amount to the dimension size.
781+
shift_amount = shift_amount % dim_size;
782+
CHECK_GE(shift_amount, 0);
783+
CHECK_LE(shift_amount, dim_size);
784+
// Absolute offset.
785+
int64_t offset = dim_size - shift_amount;
786+
// Convert to relative offsets within the tile.
787+
layout_offsets[tiling_dim] = offset % tile_size;
788+
return layout_offsets;
789+
}
790+
751791
LogicalResult infer(tpu::DynamicRotateOp op) {
752792
auto bitwidth = op.getType().getElementTypeBitWidth();
753793
// TODO(b/347067057): Support dynamic rotate with packed dtype.
@@ -759,7 +799,9 @@ class VectorLayoutInferer {
759799
}
760800
auto layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
761801
ImplicitDim::kNone);
762-
setLayout(op, {layout, kNoLayout}, layout);
802+
auto out_layout = VectorLayout(bitwidth, compute_layout_offsets(op),
803+
nativeTiling(bitwidth), ImplicitDim::kNone);
804+
setLayout(op, {layout, kNoLayout}, out_layout);
763805
return success();
764806
}
765807

tests/pallas/tpu_pallas_test.py

-2
Original file line numberDiff line numberDiff line change
@@ -2844,9 +2844,7 @@ def kernel(x_ref, out_ref):
28442844
)(x)
28452845
np.testing.assert_array_equal(out, state_utils.bitcast(x, jnp.uint32))
28462846

2847-
@only_passes_in_interpret()
28482847
def test_roll_partial(self):
2849-
"""b/337384645"""
28502848
x = np.arange(8192, dtype=jnp.float32).reshape(128, 64)
28512849

28522850
def kernel(x_ref, out_ref):

0 commit comments

Comments
 (0)