Skip to content

Support pltpu.roll on sublanes when not all lanes are used. #28228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 118 additions & 5 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/iterator_range.h"
Expand Down Expand Up @@ -2141,16 +2142,41 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
if (layout_in != layout) {
return op.emitOpError("Not implemented: unsupported layout for input");
}
if (layout_out != layout) {
LayoutOffsets expected_offsets_out = layout_in.offsets();
auto shift = getIntConst(amount, /*silent=*/true);
const bool has_static_shift = succeeded(shift);
int rotated_tiled_dim = op.getDimension() - (op.getType().getRank() - 2);
bool has_padding_along_rotation =
(rotated_tiled_dim == 0 || rotated_tiled_dim == 1) &&
op.getType().getShape()[op.getDimension()] %
layout.tiling()[rotated_tiled_dim] !=
0;
if (has_static_shift && has_padding_along_rotation) {
// We checked above that there are no implicit dims.
const int64_t dim_size = op.getType().getShape()[op.getDimension()];
// TODO(b/337384645): Currently we assume {0, 0} offsets in the input
// layout. Relax this assumption.
expected_offsets_out[rotated_tiled_dim] =
(dim_size - (shift.value() % dim_size)) %
layout.tiling()[rotated_tiled_dim];
}
if (layout_out.bitwidth() != layout.bitwidth() ||
layout_out.offsets() != expected_offsets_out ||
layout_out.tiling() != layout.tiling() ||
layout_out.implicit_dim() != layout.implicit_dim()) {
return op.emitOpError("Not implemented: unsupported layout for output");
}
auto vty = op.getResult().getType();
if (vty.getRank() < 2) {
return op.emitOpError("Not implemented: unsupported 1D shape");
}
if (*(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) != 0 ||
*(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) != 0) {
return op.emitOpError("Not implemented: unsupported unaliged shape");
// TODO(b/411170715): Allow sublane rotation once the bug is fixed.
// TODO(b/337384645): Support non-zero stride.
if (has_padding_along_rotation &&
(!has_static_shift ||
(rotated_tiled_dim == 0 ||
(rotated_tiled_dim == 1 && op.getStride().value_or(0) != 0)))) {
return op.emitOpError("Not implemented: unsupported unaligned shape");
}

ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
Expand Down Expand Up @@ -2277,6 +2303,88 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
return concatenate(chunks, axis);
};

// Applies lazy rotation (see go/pltpu-roll for details).
//
// Handles a roll along a tiled dim whose size is not a multiple of the tile
// size: instead of materializing an aligned copy, it rotates vregs and blends
// the per-vreg padding away with a mask, producing a result whose layout
// offset absorbs the misalignment (the "lazy" part).
auto lazyRotate = [&](const xla::Array<Value> &vregs, int64_t shift,
int axis) {
// Map the array axis onto the tiled dims: 0 = second-minor, 1 = minor.
const int tiling_dim = axis - (vregs.num_dimensions() - 2);
const int64_t tile_size = ctx.target_shape[tiling_dim];
const int64_t input_size = vty.getShape()[axis];
const int64_t normalized_shift = shift % input_size;
// Element index that ends up at position 0 after the roll.
const int64_t start_idx = input_size - normalized_shift;
const int64_t start_vreg_idx = start_idx / tile_size;
// Number of valid elements in the final (partial) tile along this dim;
// the remaining (tile_size - valid_amount) elements are padding. Nonzero
// by construction — this path is only taken for unaligned shapes.
const int64_t valid_amount = input_size % tile_size;

// We start with the following:
//
// vregs:
// +------+ +------+ +------+
// |░░░ 0 | | 1 | | 2 XXX|
// +------+ +------+ +------+
//
// where XXX is the padding and ░░░ is the prefix of the same size as the
// padding.

// After concatenation:
//
// concat:
// +------+ +------+ +------+ +------+ +------+ +------+
// |░░░ 0 | | 1 | | 2 XXX| |░░░ 0 | | 1 | | 2 XXX|
// +------+ +------+ +------+ +------+ +------+ +------+
auto concat = concatenate({vregs, vregs}, axis);
// Per-vreg slices of the doubled array along `axis`.
auto chunks = split(concat, axis);
int64_t original_num_chunks = chunks.size() / 2;

Value rotate_amount = mlirI32Const(valid_amount);
// Mask selecting the first `valid_amount` rows/columns along the tiled
// dim (true below `low`, i.e. for indices < valid_amount — the region we
// keep from the blended-in operand of each select below).
SmallVector<Value, 2> low = {mlirIndexConst(0), mlirIndexConst(0)};
low[tiling_dim] = mlirIndexConst(valid_amount);
auto mask = builder.create<tpu::CreateMaskOp>(
VectorType::get(ctx.target_shape, builder.getI1Type()), low,
/*high=*/
ArrayRef<Value>{mlirIndexConst(ctx.target_shape[0]),
mlirIndexConst(ctx.target_shape[1])});
// overwrite padding in the last vreg with valid data from the first vreg,
// yielding:
//
// +------+ +------+ +------+ +------+ +------+ +------+
// |░░░ 0 | | 1 | | 2 XXX| |░░░ 0 | | 1 | | 2 ░░░|
// +------+ +------+ +------+ +------+ +------+ +------+
chunks.back().Each([&](absl::Span<const int64_t> idxs, Value *v) {
// Rotate the first vreg so its prefix lands where the padding sits,
// then select it only inside the masked (padding) region.
*v = builder.create<arith::SelectOp>(
mask,
builder.create<tpu::DynamicRotateOp>(
res_vreg_ty, chunks.front()(idxs), rotate_amount, tiling_dim,
nullptr, nullptr),
*v);
});
// rotate the vregs starting from the middle vreg and then blend the vregs
// to overwrite the padding, yielding:
//
// +------+ +------+ +---+ +------+ +------+ +------+
// |░░░ 0 | | 1 | | 2 | |░░░ 0 | | 1 | | 2 ░░░|
// +------+ +------+ +---+ +------+ +------+ +------+
for (int64_t i = original_num_chunks; i < chunks.size(); ++i) {
chunks[i].Each([&](absl::Span<const int64_t> idxs, Value *v) {
*v = builder.create<tpu::DynamicRotateOp>(
res_vreg_ty, *v, rotate_amount, tiling_dim, nullptr, nullptr);
});
}
// Blend each chunk with its successor so the leading `valid_amount`
// elements come from the rotated neighbor. Must run after the rotation
// loop above — it reads the already-rotated chunks[i + 1].
for (int64_t i = original_num_chunks - 1; i < chunks.size() - 1; ++i) {
chunks[i].Each([&](absl::Span<const int64_t> idxs, Value *v) {
*v = builder.create<arith::SelectOp>(mask, chunks[i + 1](idxs), *v);
});
}
SmallVector<int64_t> result_dimensions =
layout_out.tileArrayImplicitShape(vty.getShape(), ctx.target_shape);
// assemble the result
// Pick the window of chunks beginning at the vreg holding element
// `start_idx` of the doubled sequence; the output layout's offset
// (computed by the caller) accounts for the intra-vreg misalignment.
xla::Array<Value> result(result_dimensions);
SmallVector<int64_t> starts(result.num_dimensions(), 0);
for (int64_t i = 0; i < result_dimensions[axis]; ++i) {
starts[axis] = i;
result.UpdateSlice(chunks[i + start_vreg_idx], starts);
}
return result;
};

std::function<xla::Array<Value>(const xla::Array<Value> &, Value, int, int)>
rotate;
rotate = [&](const xla::Array<Value> &vregs, Value shift, int axis,
Expand All @@ -2290,6 +2398,9 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
if (auto shift_cst = getIntConst(shift, /*silent=*/true);
succeeded(shift_cst)) {
int64_t static_shift = shift_cst.value();
if (has_padding_along_rotation) {
return lazyRotate(vregs, static_shift, axis);
}
if (tiling_dim >= 0) {
shift = mlirI32Const(static_shift % ctx.target_shape[tiling_dim]);
static_shift /= ctx.target_shape[tiling_dim];
Expand Down Expand Up @@ -2379,7 +2490,9 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
return result;
};

xla::Array<Value> out_tiles(in_tiles.dimensions());
SmallVector<int64_t> out_dimensions =
layout_out.tileArrayImplicitShape(vty.getShape(), ctx.target_shape);
xla::Array<Value> out_tiles(out_dimensions);
const auto dim = op.getDimension();
amount = modI(amount, vty.getDimSize(dim));

Expand Down
21 changes: 20 additions & 1 deletion jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -757,9 +757,28 @@ class VectorLayoutInferer {
if (op.getType().getRank() < 2) {
NYI("Unsupported 1D shape");
}
// TODO(b/337384645): Currently we assume {0, 0} offsets in the input
// layout. Relax this assumption.
auto layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
ImplicitDim::kNone);
setLayout(op, {layout, kNoLayout}, layout);
// Calculate the offsets for the output layout.
LayoutOffsets offsets_out = layout.offsets();
// We assume there are no implicit dims.
int tiling_dim = op.getDimension() - (op.getType().getRank() - 2);
if (auto amount = op.getAmount().getDefiningOp<arith::ConstantOp>();
amount && (tiling_dim == 0 || tiling_dim == 1)) {
if (auto integer_attr = dyn_cast<IntegerAttr>(amount.getValue())) {
const int64_t tile_size = layout.tiling()[tiling_dim];
const int64_t dim_size = op.getType().getShape()[op.getDimension()];
const int64_t shift = integer_attr.getValue().getSExtValue();
if (dim_size % tile_size != 0) {
offsets_out[tiling_dim] = (dim_size - (shift % dim_size)) % tile_size;
}
}
}
auto out_layout = VectorLayout(bitwidth, offsets_out,
nativeTiling(bitwidth), ImplicitDim::kNone);
setLayout(op, {layout, kNoLayout}, out_layout);
return success();
}

Expand Down
22 changes: 19 additions & 3 deletions tests/pallas/tpu_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2895,9 +2895,9 @@ def kernel(x_ref, out_ref):
)(x)
np.testing.assert_array_equal(out, state_utils.bitcast(x, jnp.uint32))

@only_passes_in_interpret()
def test_roll_partial(self):
"""b/337384645"""
def test_roll_partial_with_static_shift(self):
if not jtu.if_cloud_tpu_at_least(2025, 5, 15):
self.skipTest('Needs a newer libtpu')
x = np.arange(8192, dtype=jnp.float32).reshape(128, 64)

def kernel(x_ref, out_ref):
Expand All @@ -2908,6 +2908,22 @@ def kernel(x_ref, out_ref):
)(x)
np.testing.assert_array_equal(out, np.roll(x, 3, 1))

def test_roll_partial_with_dynamic_shift(self):
if not jtu.if_cloud_tpu_at_least(2025, 5, 15):
self.skipTest('Needs a newer libtpu')
if self.INTERPRET:
self.skipTest('Test only applies to non-interpret mode.')
x = np.arange(8192, dtype=jnp.float32).reshape(128, 64)

def kernel(x_ref, out_ref):
amount = x_ref[0, 0].astype(jnp.int32)
out_ref[...] = pltpu.roll(x_ref[...], amount, 1)

with self.assertRaisesRegex(Exception, 'unsupported unaligned shape'):
_ = self.pallas_call(
kernel, out_shape=jax.ShapeDtypeStruct((128, 64), jnp.float32)
)(x)

@only_passes_in_interpret()
def test_retiling1(self):
"""b/352626602"""
Expand Down
Loading