Support pltpu.roll on sublanes when not all lanes are used. #28228

Open · wants to merge 1 commit into main
132 changes: 127 additions & 5 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -36,6 +36,7 @@ limitations under the License.
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/iterator_range.h"
@@ -2209,16 +2210,45 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
if (layout_in != layout) {
return op.emitOpError("Not implemented: unsupported layout for input");
}
if (layout_out != layout) {
// We support non-zero offsets in the output layout via lazy rotation.
LayoutOffsets expected_offsets_out = layout_in.offsets();
auto shift = getIntConst(amount, /*silent=*/true);
int tiling_dim = op.getDimension() - (op.getType().getRank() - 2);
if (succeeded(shift) && tiling_dim == 1) {
const int64_t tile_size = layout_out.tiling()[tiling_dim];
// We assume there are no implicit dims.
const int64_t dim_size = op.getType().getShape()[op.getDimension()];
if (dim_size % tile_size != 0) {
// TODO(b/337384645): Currently we assume {0, 0} offsets in the input
// layout. Relax this assumption.
expected_offsets_out[tiling_dim] =
(dim_size - (shift.value() % dim_size)) % tile_size;
}
}
if (layout_out.bitwidth() != layout.bitwidth() ||
layout_out.offsets() != expected_offsets_out ||
layout_out.tiling() != layout.tiling() ||
layout_out.implicit_dim() != layout.implicit_dim()) {
return op.emitOpError("Not implemented: unsupported layout for output");
}
auto vty = op.getResult().getType();
if (vty.getRank() < 2) {
return op.emitOpError("Not implemented: unsupported 1D shape");
}
if (*(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) != 0 ||
*(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) != 0) {
return op.emitOpError("Not implemented: unsupported unaliged shape");
// We do not check the stride here since an unaligned sublane dimension is
// not supported due to b/411170715.
if (*(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) != 0 &&
op.getDimension() == vty.getRank() - 2) {
return op.emitOpError(
"Not implemented: unsupported unaligned shape in sublane dimension");
}
// We check the stride since an unaligned lane dimension is only supported
// without a stride or with a zero stride.
if (*(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) != 0 &&
op.getDimension() == vty.getRank() - 1 &&
op.getStride().value_or(0) != 0) {
return op.emitOpError(
"Not implemented: unsupported unaligned shape in lane dimension");
}

ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
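
As a quick sanity check of the offset rule above, here is a hedged Python sketch of the arithmetic; the concrete numbers assume the (128, 64) f32 case exercised by the test below, where the lane tile size is 128 and dim_size = 64 is not tile-aligned.

```python
def expected_lane_offset(dim_size: int, tile_size: int, shift: int) -> int:
    # Mirrors the expected_offsets_out computation above; it only applies
    # when dim_size % tile_size != 0, i.e. the last vreg along the rolled
    # dimension is partially padded.
    return (dim_size - (shift % dim_size)) % tile_size

# Rolling 64 lanes held in 128-lane vregs by 3 leaves the data starting at
# lane 61, so the output layout carries offsets {0, 61}.
assert expected_lane_offset(64, 128, 3) == 61
assert expected_lane_offset(64, 128, 67) == 61  # the shift wraps mod dim_size
```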
@@ -2345,6 +2375,88 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
return concatenate(chunks, axis);
};

// Applies lazy rotation (see go/pltpu-roll for details).
auto lazyRotate = [&](const xla::Array<Value> &vregs, int64_t shift,
int axis) {
const int tiling_dim = axis - (vregs.num_dimensions() - 2);
const int64_t tile_size = ctx.target_shape[tiling_dim];
const int64_t input_size = vty.getShape()[axis];
const int64_t normalized_shift = shift % input_size;
const int64_t start_idx = input_size - normalized_shift;
const int64_t start_vreg_idx = start_idx / tile_size;
const int64_t valid_amount = input_size % tile_size;

// We start with the following:
//
// vregs:
// +------+ +------+ +------+
// |░░░ 0 | | 1 | | 2 XXX|
// +------+ +------+ +------+
//
// where XXX is the padding and ░░░ is the prefix of the same size as the
// padding.

// After concatenation:
//
// concat:
// +------+ +------+ +------+ +------+ +------+ +------+
// |░░░ 0 | | 1 | | 2 XXX| |░░░ 0 | | 1 | | 2 XXX|
// +------+ +------+ +------+ +------+ +------+ +------+
auto concat = concatenate({vregs, vregs}, axis);
auto chunks = split(concat, axis);
int64_t original_num_chunks = chunks.size() / 2;

Value rotate_amount = mlirI32Const(valid_amount);
SmallVector<Value, 2> low = {mlirIndexConst(0), mlirIndexConst(0)};
low[tiling_dim] = mlirIndexConst(valid_amount);
auto mask = builder.create<tpu::CreateMaskOp>(
VectorType::get(ctx.target_shape, builder.getI1Type()), low,
/*high=*/
ArrayRef<Value>{mlirIndexConst(ctx.target_shape[0]),
mlirIndexConst(ctx.target_shape[1])});
// Overwrite padding in the last vreg with valid data from the first vreg,
// yielding:
//
// +------+ +------+ +------+ +------+ +------+ +------+
// |░░░ 0 | | 1 | | 2 XXX| |░░░ 0 | | 1 | | 2 ░░░|
// +------+ +------+ +------+ +------+ +------+ +------+
chunks.back().Each([&](absl::Span<const int64_t> idxs, Value *v) {
*v = builder.create<arith::SelectOp>(
mask,
builder.create<tpu::DynamicRotateOp>(
res_vreg_ty, chunks.front()(idxs), rotate_amount, tiling_dim,
nullptr, nullptr),
*v);
});
// Rotate the vregs starting from the middle vreg and then blend the vregs
// to overwrite the padding, yielding:
//
// +------+ +------+ +---+ +------+ +------+ +------+
// |░░░ 0 | | 1 | | 2 | |░░░ 0 | | 1 | | 2 ░░░|
// +------+ +------+ +---+ +------+ +------+ +------+
for (int64_t i = original_num_chunks; i < chunks.size(); ++i) {
chunks[i].Each([&](absl::Span<const int64_t> idxs, Value *v) {
*v = builder.create<tpu::DynamicRotateOp>(
res_vreg_ty, *v, rotate_amount, tiling_dim, nullptr, nullptr);
});
}
for (int64_t i = original_num_chunks - 1; i < chunks.size() - 1; ++i) {
chunks[i].Each([&](absl::Span<const int64_t> idxs, Value *v) {
*v = builder.create<arith::SelectOp>(mask, chunks[i + 1](idxs), *v);
});
}
SmallVector<int64_t> result_dimensions =
layout_out.tileArrayImplicitShape(vty.getShape(), ctx.target_shape);
// Assemble the result.
xla::Array<Value> result(result_dimensions);
SmallVector<int64_t> starts(result.num_dimensions(), 0);
for (int64_t i = 0; i < result_dimensions[axis]; ++i) {
starts[axis] = i;
result.UpdateSlice(chunks[i + start_vreg_idx], starts);
}
return result;
};
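
To make the diagrams above concrete, the following is a hedged, element-level NumPy model of what lazy rotation produces: the rolled data ends up contiguous in the padded buffer, starting at the new layout offset. It models the end state only, not the per-vreg concat/rotate/blend mechanics.

```python
import numpy as np

def lazy_roll_reference(x: np.ndarray, shift: int, tile: int):
    # Element-level model of lazy rotation: instead of producing data that
    # starts at element 0, the result starts at the layout offset
    # (n - shift) % tile inside a tile-aligned buffer.
    n = x.shape[0]
    shift %= n
    offset = (n - shift) % tile            # matches expected_offsets_out
    num_tiles = -(-(offset + n) // tile)   # ceiling division
    buf = np.zeros(num_tiles * tile, dtype=x.dtype)
    buf[offset:offset + n] = np.roll(x, shift)
    return buf, offset

x = np.arange(5)
buf, offset = lazy_roll_reference(x, shift=2, tile=4)
assert offset == 3  # (5 - 2) % 4
# The rolled data [3, 4, 0, 1, 2] sits at elements 3..7 of two 4-wide tiles.
assert np.array_equal(buf[offset:offset + 5], np.roll(x, 2))
```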

std::function<xla::Array<Value>(const xla::Array<Value> &, Value, int, int)>
rotate;
rotate = [&](const xla::Array<Value> &vregs, Value shift, int axis,
@@ -2353,7 +2465,15 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
CHECK(axis >= 0 && axis < vregs.num_dimensions());
int tiling_dim = axis - (vregs.num_dimensions() - 2);
CHECK((tiling_dim != 1 && stride == 0) || (tiling_dim == 1 && stride >= 0));
const bool has_padding =
(tiling_dim == 0 || tiling_dim == 1) &&
vty.getShape()[axis] % ctx.target_shape[tiling_dim] != 0;
SmallVector<xla::Array<Value>, 4> chunks;
// Handle rotation with static shift and padding lazily.
if (auto shift_cst = getIntConst(shift, /*silent=*/true);
succeeded(shift_cst) && has_padding) {
return lazyRotate(vregs, shift_cst.value(), axis);
}
// Handle rotation with static shift.
if (auto shift_cst = getIntConst(shift, /*silent=*/true);
succeeded(shift_cst)) {
@@ -2447,7 +2567,9 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount,
return result;
};

xla::Array<Value> out_tiles(in_tiles.dimensions());
SmallVector<int64_t> out_dimensions =
layout_out.tileArrayImplicitShape(vty.getShape(), ctx.target_shape);
xla::Array<Value> out_tiles(out_dimensions);
const auto dim = op.getDimension();
amount = modI(amount, vty.getDimSize(dim));
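
Since the output layout may now carry a nonzero lane offset, the result's vreg array is sized from layout_out instead of reusing the input's dimensions. Below is a rough Python model of that shape computation; tile_array_shape_2d is a hypothetical stand-in for tileArrayImplicitShape on a rank-2 vector, assuming it rounds offset plus dim size up to whole tiles.

```python
import math

def tile_array_shape_2d(shape, offsets, tiling):
    # Hypothetical model of layout_out.tileArrayImplicitShape for rank 2:
    # each tiled dim needs ceil((offset + size) / tile) vregs.
    return tuple(
        math.ceil((off + size) / tile)
        for size, off, tile in zip(shape, offsets, tiling)
    )

# (128, 64) f32 under native (8, 128) tiling: 16 x 1 vregs, with or without
# the lane offset, since 61 + 64 still fits in one 128-lane vreg.
assert tile_array_shape_2d((128, 64), (0, 0), (8, 128)) == (16, 1)
assert tile_array_shape_2d((128, 64), (0, 61), (8, 128)) == (16, 1)
# A large enough offset would spill into a second vreg along the lanes.
assert tile_array_shape_2d((128, 64), (0, 100), (8, 128)) == (16, 2)
```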

21 changes: 20 additions & 1 deletion jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -759,7 +759,26 @@ class VectorLayoutInferer {
}
auto layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
ImplicitDim::kNone);
setLayout(op, {layout, kNoLayout}, layout);
// Calculate the offsets for the output layout.
LayoutOffsets offsets_out = layout.offsets();
// We assume there are no implicit dims.
int tiling_dim = op.getDimension() - (op.getType().getRank() - 2);
if (auto amount = op.getAmount().getDefiningOp<arith::ConstantOp>()) {
if (auto integer_attr = dyn_cast<IntegerAttr>(amount.getValue());
integer_attr && tiling_dim == 1) {
const int64_t tile_size = layout.tiling()[tiling_dim];
const int64_t dim_size = op.getType().getShape()[op.getDimension()];
const int64_t shift = integer_attr.getValue().getSExtValue();
if (dim_size % tile_size != 0) {
// TODO(b/337384645): Currently we assume {0, 0} offsets in the input
// layout. Relax this assumption.
offsets_out[tiling_dim] = (dim_size - (shift % dim_size)) % tile_size;
}
}
}
auto out_layout = VectorLayout(bitwidth, offsets_out,
nativeTiling(bitwidth), ImplicitDim::kNone);
setLayout(op, {layout, kNoLayout}, out_layout);
return success();
}

4 changes: 2 additions & 2 deletions tests/pallas/tpu_pallas_test.py
@@ -2895,9 +2895,9 @@ def kernel(x_ref, out_ref):
)(x)
np.testing.assert_array_equal(out, state_utils.bitcast(x, jnp.uint32))

@only_passes_in_interpret()
def test_roll_partial(self):
"""b/337384645"""
if not jtu.if_cloud_tpu_at_least(2025, 5, 10):
self.skipTest('Needs a newer libtpu')
x = np.arange(8192, dtype=jnp.float32).reshape(128, 64)

def kernel(x_ref, out_ref):
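
The kernel body is collapsed in this view, so as orientation only, here is a hedged sketch of exercising pltpu.roll on this shape in Pallas; it mirrors the test above but is an assumed reconstruction, not the test's verbatim code.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, out_ref):
    # 64 lanes held in 128-lane vregs: the rolled (lane) dimension is not
    # tile-aligned, which is exactly the case this PR enables on hardware.
    out_ref[...] = pltpu.roll(x_ref[...], 3, 1)

x = np.arange(8192, dtype=jnp.float32).reshape(128, 64)
out = pl.pallas_call(
    kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x)
np.testing.assert_array_equal(out, np.roll(x, 3, axis=1))
```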