
[mlir][xegpu] Relax rank restriction of TensorDescType #145916

Open · wants to merge 13 commits into main
25 changes: 20 additions & 5 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -42,9 +42,18 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
}];

let parameters = (ins
OptionalParameter<"MemorySpaceAttr">: $memory_space,
OptionalParameter<"IntegerAttr", "1">: $array_length,
OptionalParameter<"BoolAttr", "true">: $boundary_check
DefaultValuedParameter<
"MemorySpaceAttr",
"MemorySpaceAttr::get($_ctxt, xegpu::MemorySpace::Global)",
"Data memory location">: $memory_space,
DefaultValuedParameter<
"IntegerAttr",
"IntegerAttr::get(IntegerType::get($_ctxt, 64), 1)",
"Number of continuous blocks to load">: $array_length,
DefaultValuedParameter<
"BoolAttr",
"BoolAttr::get($_ctxt, 1)",
"Checking the out of boundary access">: $boundary_check
);

let builders = [
@@ -67,8 +76,8 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
TensorDesc is located, `Global` device memory or `Shared` local memory.
It is default to `Global`.

2. `chunk_size`: indicates number of contiguous elements accessed for each
offset, default is 1. It is used with `scattered` attr only.
2. `chunk_size`: Specifies the number of contiguous elements accessed per offset.
The default value is 1.
}];

let parameters = (ins
@@ -91,6 +100,12 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
)>
];

let extraClassDeclaration = [{
int64_t getChunkSizeAsInt() {
return getChunkSize().getInt();
}
}];

let genVerifyDecl = 1;
}
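To make the chunk_size semantics concrete, here is a minimal standalone sketch (not part of the dialect; the helper name and the assumption that offsets are element indices are illustrative): each offset of a scattered TensorDesc addresses `chunk_size` contiguous elements starting at that offset.

```cpp
#include <cstdint>
#include <vector>

// Illustrative only: enumerate the element indices touched by a scattered
// access with a given chunk_size.
std::vector<int64_t> gatheredIndices(const std::vector<int64_t> &offsets,
                                     int64_t chunkSize) {
  std::vector<int64_t> indices;
  for (int64_t off : offsets)
    for (int64_t i = 0; i < chunkSize; ++i)
      indices.push_back(off + i);
  return indices;
}

// offsets {0, 16, 32} with chunk_size = 4 touch elements
// {0,1,2,3, 16,17,18,19, 32,33,34,35}.
```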

2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -757,6 +757,8 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
let assemblyFormat = [{
$TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` type($offsets)
}];

let hasVerifier = 1;
}

def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]> {
23 changes: 11 additions & 12 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -17,12 +17,12 @@ def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64,
def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
def XeGPU_DpasOprType: VectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
def XeGPU_DpasResType: VectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>;
def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
def XeGPU_Vector2DType: FixedVectorOfRankAndType<[2], [XeGPU_ScalarType]>;

// common base class for types in XeGPU dialect
class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
@@ -118,7 +118,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
];

let extraClassDeclaration = [{
using TensorType::clone;
using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
using mlir::ShapedType::Trait<TensorDescType>::getRank;
using mlir::ShapedType::Trait<TensorDescType>::getNumElements;
@@ -157,6 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return MemorySpace::Global;
}

// get the ArrayLength for blocked TensorDesc
int getArrayLength() {
auto attr = getEncoding();
auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
@@ -181,13 +181,12 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
return bool(getEncodingAsScatterTensorDescAttr());
}

int getChunkSize() {
// get the ChunkSize for scattered TensorDesc
int getChunkSizeAsInt() {
auto attr = getEncoding();
auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr.");
if (scatter_attr)
return scatter_attr.getChunkSize().getInt();
return 1;
assert(scatter_attr && "invalid on non ScatterTensorDescAttr.");
return scatter_attr.getChunkSizeAsInt();
}

/// Helper to drop all layout information from the TensorDesc type.
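A minimal sketch of the accessor change above (standalone C++, not the MLIR API; names are assumptions): the old getChunkSize() fell back to 1 when the descriptor had no scatter encoding, while getChunkSizeAsInt() now asserts that the descriptor is scattered.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Old behavior: silently treat non-scattered descriptors as chunk size 1.
int64_t chunkSizeOld(std::optional<int64_t> scatterChunkSize) {
  return scatterChunkSize ? *scatterChunkSize : 1;
}

// New behavior: calling the accessor on a non-scattered descriptor is a
// programming error.
int64_t chunkSizeNew(std::optional<int64_t> scatterChunkSize) {
  assert(scatterChunkSize && "invalid on non ScatterTensorDescAttr.");
  return *scatterChunkSize;
}
```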
45 changes: 18 additions & 27 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -129,9 +129,7 @@ LogicalResult ScatterTensorDescAttr::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
int64_t chunkSize = chunk_size.getInt();
SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8,
16, 32, 64, 128, 256};
if (!llvm::is_contained(supportedChunkSizes, chunkSize))
if (chunkSize <= 0)
return emitError() << "invalid chunk size";

return success();
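The relaxation above replaces a fixed whitelist of chunk sizes with a positivity check; a standalone sketch of the before/after predicate (illustrative helpers, not the dialect API):

```cpp
#include <algorithm>
#include <array>
#include <cstdint>

// Before: only a hard-coded set of chunk sizes was accepted.
bool isValidChunkSizeOld(int64_t c) {
  constexpr std::array<int64_t, 10> supported{1, 2, 3, 4, 8, 16, 32, 64, 128, 256};
  return std::find(supported.begin(), supported.end(), c) != supported.end();
}

// After: any positive chunk size passes attribute verification.
bool isValidChunkSizeNew(int64_t c) { return c > 0; }
```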
@@ -310,15 +308,16 @@ LogicalResult TensorDescType::verify(
llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
mlir::Attribute encoding, mlir::Attribute layout) {
size_t rank = shape.size();
if (rank != 1 && rank != 2)
return emitError() << "expected 1D or 2D tensor";

if (rank == 0)
return emitError() << "expected non-zero rank tensor";

auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
if (blockAttr) {
MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
if (rank == 2 && memorySpaceAttr &&
if (rank > 1 && memorySpaceAttr &&
memorySpaceAttr.getValue() == MemorySpace::SLM)
return emitError() << "SLM is not supported for 2D block tensor";
return emitError() << "SLM is only supported for 1D block tensor";
}

// for gather and scatter ops, Low-precision types are packed in 32-bit units.
@@ -329,22 +328,18 @@ LogicalResult TensorDescType::verify(
: 1;
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
if (scatterAttr) {
// Expected tensor ranks for scattered data:
// - 1D tensor for fully non-contiguous elements (chunk size == 1)
// - 2D tensor for scattered blocks (chunk size > 1)
unsigned chunkSize = scatterAttr.getChunkSize().getInt();
int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
if (rank == 1 && chunkSize != 1)
return emitError() << "expected non-contiguous elements for 1D tensor";
if (rank == 2 && chunkSize < 2)
return emitError() << "expected chunk blocks for 2D tensor";

// If chunk size > 1, the second dimension of the tensor shape must be
// equal to chunk size and it must be a multiple of the packing factor.
// equal to chunk size and it must be a multiple of the
// chunkAlignmentFactor.
if (chunkSize > 1) {
if (shape.back() != chunkSize)
return emitError() << "expected tensor shape[1] to match chunk size";
return emitError() << "expected last dim of tensor to match chunk size";
if (shape.back() % chunkAlignmentFactor != 0)
return emitError() << "expected tensor shape[1] to be a multiple of "
"chunk alignment factor "
return emitError() << "expected last dim of tensor to be a multiple of "
<< chunkAlignmentFactor;
}
}
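The scattered-shape rules above can be summarized in a standalone sketch (an assumed helper that mirrors the verifier rather than calling it): low-precision elements are packed into 32-bit units, so when chunk_size > 1 the last dimension must equal the chunk size and be a multiple of the packing factor.

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Illustrative re-statement of the scatter checks in TensorDescType::verify.
std::optional<std::string> verifyScatterShape(const std::vector<int64_t> &shape,
                                              unsigned elemBits,
                                              int64_t chunkSize) {
  if (shape.empty())
    return "expected non-zero rank tensor";
  // Low-precision types are packed in 32-bit units.
  int64_t chunkAlignmentFactor = elemBits < 32 ? 32 / elemBits : 1;
  if (shape.size() == 1 && chunkSize != 1)
    return "expected non-contiguous elements for 1D tensor";
  if (chunkSize > 1) {
    if (shape.back() != chunkSize)
      return "expected last dim of tensor to match chunk size";
    if (shape.back() % chunkAlignmentFactor != 0)
      return "expected last dim of tensor to be a multiple of the chunk "
             "alignment factor";
  }
  return std::nullopt; // shape is valid
}

// Example: a 16x8 f16 tensor with chunk_size = 8 passes (8 == chunk_size and
// 8 % 2 == 0), while 16x8 with chunk_size = 4 is rejected.
```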
@@ -357,17 +352,13 @@ LogicalResult TensorDescType::verify(
auto laneData = layoutAttr.getLaneData();
if (scatterAttr && laneData) {
// Validate subgroup mapping rules for scattered tensors.
// A work-item's slice of the tensor with shape [sg_size] or
// [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
// respectively, the mapping should reflect that. This is because each
// work item access data in 32 bit granularity.

if (rank > 1 && laneData[0] != 1)
// if chunkSize > 1, the last dimension of the tensor should
// be distributed in the units divisible by chunkAlignmentFactor.
int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
Contributor:

do we allow lane layout also be nD?

Contributor Author:

Yes, the layout rank matches the tensor/vector rank.

return emitError()
<< "cannot map over non-contiguous scattered row elements";
if (laneData[rank - 1] != chunkAlignmentFactor)
return emitError() << "work item data mapping must match the number of "
"contiguous elements";
<< "expected last dim of lane_data to be a multiple of: "
<< chunkAlignmentFactor;
}

if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
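Similarly, a minimal sketch of the lane_data rule discussed in the thread above (standalone, with assumed names): when chunk_size > 1, the last dimension of lane_data must be a multiple of the 32-bit packing factor.

```cpp
#include <cstdint>
#include <vector>

// Illustrative check mirroring the lane_data validation for scattered
// tensor descriptors.
bool isValidLaneData(const std::vector<int64_t> &laneData, unsigned elemBits,
                     int64_t chunkSize) {
  int64_t chunkAlignmentFactor = elemBits < 32 ? 32 / elemBits : 1;
  if (chunkSize > 1 && laneData.back() % chunkAlignmentFactor != 0)
    return false;
  return true;
}

// f16 elements (16-bit) give a factor of 2, so lane_data ending in 2 is
// accepted while lane_data ending in 1 is rejected when chunk_size > 1.
```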
72 changes: 32 additions & 40 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -81,15 +81,18 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
auto tdescShape = getShapeOf(tdescTy);
auto chunkSize = tdescTy.getChunkSize();
auto chunkSize = tdescTy.getChunkSizeAsInt();

if (valueTy.getElementType() != tdescTy.getElementType())
return emitError()
<< "Value should have the same element type as TensorDesc.";

if (tdescShape[0] != maskShape[0])
llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
if (chunkSize > 1)
expectedMaskShape.pop_back();
if (expectedMaskShape != maskShape)
return emitError()
<< "dim-0 of the Mask and TensorDesc should be the same.";
<< "Mask should match TensorDesc except the chunk size dim.";

// a valid shape for SIMT case
if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
@@ -203,11 +206,9 @@ LogicalResult CreateNdDescOp::verify() {
"is a memref) should match with each other.");

// check result TensorDesc rank
invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);

if (invalidRank)
if (getType().getRank() > rank)
return emitOpError(
"Expecting the TensorDesc rank is up to 2 and not greater than the "
"Expecting the TensorDesc rank is not greater than the "
"ranks of shape, strides, offsets or the memref source.");

if (invalidElemTy)
@@ -247,9 +248,6 @@ LogicalResult LoadNdOp::verify() {
auto tdescTy = getTensorDescType();
auto valueTy = getType();

if (tdescTy.getRank() > 2)
return emitOpError("Expecting a 1D/2D TensorDesc.\n");

if (tdescTy.isScattered())
return emitOpError("Expects a non-scattered TensorDesc.\n");

@@ -316,15 +314,13 @@ LogicalResult LoadNdOp::verify() {
}

auto array_len = tdescTy.getArrayLength();
if (array_len > 1) {
if (array_len > 1)
tdescShape.insert(tdescShape.begin(), array_len);
}

if (tdescShape != valueShape) {
if (tdescShape != valueShape)
return emitOpError() << "Result shape " << makeString(valueShape)
<< " is not consistent with tensor descriptor "
<< tdescTy;
}

return success();
}
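For the SIMD path of load_nd, the expected result shape is the tensor_desc shape with array_length prepended when it is greater than 1; a minimal sketch (assumed helper name, not the op's API):

```cpp
#include <cstdint>
#include <vector>

// Illustrative: expected result shape for a blocked load with array_length.
std::vector<int64_t> expectedLoadNdShape(std::vector<int64_t> tdescShape,
                                         int64_t arrayLength) {
  if (arrayLength > 1)
    tdescShape.insert(tdescShape.begin(), arrayLength);
  return tdescShape;
}

// A tensor_desc of shape 8x16 with array_length = 2 is expected to produce a
// value of shape 2x8x16.
```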
Expand All @@ -336,9 +332,6 @@ LogicalResult StoreNdOp::verify() {
auto dstTy = getTensorDescType(); // Tile
auto valTy = getValueType(); // Vector

if (dstTy.getRank() > 2)
return emitOpError("Expecting a 1D/2D TensorDesc.\n");

if (dstTy.isScattered())
return emitOpError("Expects a non-scattered TensorDesc.\n");

@@ -370,22 +363,21 @@ LogicalResult StoreNdOp::verify() {
return emitOpError()
<< "TensorDesc doesn't need LayoutAttr for SIMT code";

if (tdescElems % valueElems) {
if (tdescElems % valueElems)
return emitOpError()
<< "Value shape " << makeString(getShapeOf(valTy))
<< " is not a valid distribution for tensor descriptor " << dstTy;
}

return success();
}

// SIMD code should have the same shape as the tensor descriptor.
auto tdescShape = getShapeOf(dstTy);
auto valueShape = getShapeOf(valTy);
if (tdescShape != valueShape) {
if (tdescShape != valueShape)
return emitOpError() << "Value shape " << makeString(valueShape)
<< " is not consistent with tensor descriptor "
<< dstTy;
}

return success();
}
@@ -449,25 +441,8 @@ LogicalResult CreateDescOp::verify() {
<< ", TensorDesc: " << tdescMemorySpace;

// check total size
auto chunkSize = tdescTy.getChunkSize();
auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
auto bitsPerLane = elemBits * chunkSize;
if (chunkSize > 1 && bitsPerLane % 32) {
// For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
// For 32-bit data, the hardware can support larger larger chunk size. So
// we can bitcast 8-bit/16-bit data to 32-bit data for better performance.
// But this requires the total size is 32 bit aligned to make the
// optimization work.
return emitOpError(
"access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
}

auto lscConstraints = 512 * 8; // each access is upto 512 bytes.
if (elemBits * tdescTy.getNumElements() > lscConstraints)
return emitOpError("total access size (simd_lanes * chunk_size * "
"sizeof(elemTy)) is upto 512 bytes.");

SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
Comment on lines -453 to -470
Contributor:

where are these verified now?

Contributor Author:

They are totally gone now. I suppose this will be checked in XeVM. They were appropriate when XeGPU was designed to match hardware abstraction. But now XeGPU is promoted to workgroup level.

Contributor:

> I suppose this will be checked in XeVM.

+1 looking at the current abstractions. A completely separate validation pass that also takes target uarch info might be best at this point.

auto chunkSize = tdescTy.getChunkSizeAsInt();
SmallVector<int64_t> shape(getOffsetsType().getShape());
if (chunkSize != 1)
shape.push_back(chunkSize);

@@ -563,6 +538,23 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, tensorDesc, ofrs);
}

LogicalResult UpdateOffsetOp::verify() {
auto tdescTy = getTensorDescType();
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");

SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
SmallVector<int64_t> offsetShape = getShapeOf(getOffsetsType());
if (tdescTy.getChunkSizeAsInt() > 1)
expectedOffsetShape.pop_back();

if (expectedOffsetShape != offsetShape)
return emitOpError(
"Offsets should match TensorDesc except the chunk size dim.");

return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
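The gather/scatter verifiers above all relate the offsets/mask shape to the tensor_desc shape through the chunk-size dimension; a standalone sketch of both directions (assumed helper names, not the ops' API):

```cpp
#include <cstdint>
#include <vector>

// create_desc: the tensor_desc shape is the offsets shape, with chunk_size
// appended as a trailing dim when it is not 1.
std::vector<int64_t> expectedTdescShape(std::vector<int64_t> offsetsShape,
                                        int64_t chunkSize) {
  if (chunkSize != 1)
    offsetsShape.push_back(chunkSize);
  return offsetsShape;
}

// load/store gather and update_offset: the expected mask/offsets shape is the
// tensor_desc shape with the trailing chunk-size dim dropped.
std::vector<int64_t> expectedMaskOrOffsetShape(std::vector<int64_t> tdescShape,
                                               int64_t chunkSize) {
  if (chunkSize > 1)
    tdescShape.pop_back();
  return tdescShape;
}

// Example: offsets of shape 16 with chunk_size = 8 give a tensor_desc of
// shape 16x8, and that tensor_desc in turn expects masks/offsets of shape 16.
```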
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -303,9 +303,7 @@ void XeGPUBlockingPass::runOnOperation() {
// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
if (tdescTy.isScattered()) {
auto scatterAttr =
llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
int64_t chunkSize = scatterAttr.getChunkSize().getInt();
int64_t chunkSize = tdescTy.getChunkSizeAsInt();

if (chunkSize > 1) {
int64_t blockedChunkSize = chunkSize;
@@ -315,7 +313,7 @@

// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
ctx, tdescTy.getMemorySpace(), blockedChunkSize);

encoding = newEncoding;
}