[mlir][xegpu] Relax rank restriction of TensorDescType #145916
@@ -81,15 +81,18 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   auto maskShape = getShapeOf(maskTy);
   auto valueShape = getShapeOf(valueTy);
   auto tdescShape = getShapeOf(tdescTy);
-  auto chunkSize = tdescTy.getChunkSize();
+  auto chunkSize = tdescTy.getChunkSizeAsInt();

   if (valueTy.getElementType() != tdescTy.getElementType())
     return emitError()
            << "Value should have the same element type as TensorDesc.";

-  if (tdescShape[0] != maskShape[0])
+  llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
+  if (chunkSize > 1)
+    expectedMaskShape.pop_back();
+  if (expectedMaskShape != maskShape)
     return emitError()
-           << "dim-0 of the Mask and TensorDesc should be the same.";
+           << "Mask should match TensorDesc except the chunk size dim.";

   // a valid shape for SIMT case
   if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
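In effect, the mask must now match the full (possibly higher-rank) TensorDesc shape minus the trailing chunk dimension, rather than only dim-0. A sketch of IR the relaxed check accepts (value names and shapes are invented for illustration, not taken from the PR's tests):

    // chunk_size = 2 adds a trailing dim to the TensorDesc, which the mask drops:
    // tensor_desc shape 4x2 -> mask shape 4
    %tdesc = xegpu.create_tdesc %src, %offsets : ui64, vector<4xindex>
        -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
    %val = xegpu.load %tdesc, %mask
        : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
        -> vector<4x2xf32>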
@@ -203,11 +206,9 @@ LogicalResult CreateNdDescOp::verify() {
                        "is a memref) should match with each other.");

-  // check result TensorDesc rank
-  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);
-
-  if (invalidRank)
+  if (getType().getRank() > rank)
     return emitOpError(
-        "Expecting the TensorDesc rank is up to 2 and not greater than the "
+        "Expecting the TensorDesc rank is not greater than the "
         "ranks of shape, strides, offsets or the memref source.");

   if (invalidElemTy)
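With the rank-2 cap removed, CreateNdDescOp only requires that the TensorDesc rank not exceed the source rank, so a rank-3 descriptor over a rank-3 memref becomes legal. A sketch (shapes invented for the example):

    %tdesc = xegpu.create_nd_tdesc %src[0, 0, 0]
        : memref<8x24x32xf32> -> !xegpu.tensor_desc<4x8x16xf32>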
@@ -247,9 +248,6 @@ LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();

-  if (tdescTy.getRank() > 2)
-    return emitOpError("Expecting a 1D/2D TensorDesc.\n");
-
   if (tdescTy.isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -316,15 +314,13 @@ LogicalResult LoadNdOp::verify() {
   }

   auto array_len = tdescTy.getArrayLength();
-  if (array_len > 1) {
+  if (array_len > 1)
     tdescShape.insert(tdescShape.begin(), array_len);
-  }

-  if (tdescShape != valueShape) {
+  if (tdescShape != valueShape)
     return emitOpError() << "Result shape " << makeString(valueShape)
                          << " is not consistent with tensor descriptor "
                          << tdescTy;
-  }

   return success();
 }
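The shape check itself is rank-agnostic: when array_length is greater than 1 it is simply prepended to the expected result shape. For example (a sketch; shapes are illustrative):

    // array_length = 2 is prepended: expected result shape is 2x8x16
    %val = xegpu.load_nd %tdesc
        : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
        -> vector<2x8x16xf16>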
@@ -336,9 +332,6 @@ LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector

-  if (dstTy.getRank() > 2)
-    return emitOpError("Expecting a 1D/2D TensorDesc.\n");
-
   if (dstTy.isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
@@ -370,22 +363,21 @@ LogicalResult StoreNdOp::verify() {
       return emitOpError()
              << "TensorDesc doesn't need LayoutAttr for SIMT code";

-    if (tdescElems % valueElems) {
+    if (tdescElems % valueElems)
       return emitOpError()
              << "Value shape " << makeString(getShapeOf(valTy))
              << " is not a valid distribution for tensor descriptor " << dstTy;
-    }

     return success();
   }

   // SIMD code should have the same shape as the tensor descriptor.
   auto tdescShape = getShapeOf(dstTy);
   auto valueShape = getShapeOf(valTy);
-  if (tdescShape != valueShape) {
+  if (tdescShape != valueShape)
     return emitOpError() << "Value shape " << makeString(valueShape)
                          << " is not consistent with tensor descriptor "
                          << dstTy;
-  }

   return success();
 }
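On the SIMT path, only divisibility is enforced: the per-lane value's element count must divide the descriptor's element count. A speculative sketch (whether a given store takes the SIMT path depends on context not shown in this diff):

    // 16 lanes each holding vector<8xf32> cover tensor_desc<8x16xf32>; 8 divides 128
    xegpu.store_nd %val, %tdesc : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>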
@@ -449,25 +441,8 @@ LogicalResult CreateDescOp::verify() {
                          << ", TensorDesc: " << tdescMemorySpace;

-  // check total size
-  auto chunkSize = tdescTy.getChunkSize();
-  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
-  auto bitsPerLane = elemBits * chunkSize;
-  if (chunkSize > 1 && bitsPerLane % 32) {
-    // For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
-    // For 32-bit data, the hardware can support larger larger chunk size. So
-    // we can bitcast 8-bit/16-bit data to 32-bit data for better performance.
-    // But this requires the total size is 32 bit aligned to make the
-    // optimization work.
-    return emitOpError(
-        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
-  }
-
-  auto lscConstraints = 512 * 8; // each access is upto 512 bytes.
-  if (elemBits * tdescTy.getNumElements() > lscConstraints)
-    return emitOpError("total access size (simd_lanes * chunk_size * "
-                       "sizeof(elemTy)) is upto 512 bytes.");
-
-  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
+  auto chunkSize = tdescTy.getChunkSizeAsInt();
+  SmallVector<int64_t> shape(getOffsetsType().getShape());
+  if (chunkSize != 1)
+    shape.push_back(chunkSize);

Review thread on the removed lines (-453 to -470):

Comment: Where are these verified now?

Reply: They are totally gone now. I suppose this will be checked in XeVM. They were appropriate when XeGPU was designed to match the hardware abstraction, but XeGPU has now been promoted to the workgroup level.

Reply: +1, looking at the current abstractions.
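The descriptor shape is now derived directly from the offsets operand, with the chunk size appended as a trailing dimension when it is greater than 1; since the offsets can themselves be multi-dimensional, the relaxed rank carries through. A sketch (hypothetical shapes):

    // 2-D offsets carry over directly: offsets shape 4x8 -> tensor_desc shape 4x8
    %tdesc = xegpu.create_tdesc %src, %offsets : ui64, vector<4x8xindex>
        -> !xegpu.tensor_desc<4x8xf32, #xegpu.scatter_tdesc_attr<>>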
@@ -563,6 +538,23 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, tensorDesc, ofrs);
 }

+LogicalResult UpdateOffsetOp::verify() {
+  auto tdescTy = getTensorDescType();
+  if (!tdescTy.isScattered())
+    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
+  SmallVector<int64_t> offsetShape = getShapeOf(getOffsetsType());
+  if (tdescTy.getChunkSizeAsInt() > 1)
+    expectedOffsetShape.pop_back();
+
+  if (expectedOffsetShape != offsetShape)
+    return emitOpError(
+        "Offsets should match TensorDesc except the chunk size dim.");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_DpasOp
 //===----------------------------------------------------------------------===//
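The new verifier mirrors the gather/scatter rule: the offsets operand must match the TensorDesc shape with the trailing chunk dimension dropped. A sketch (names and shapes invented):

    // tensor_desc shape 4x2 with chunk_size = 2 -> offsets shape 4
    %new = xegpu.update_offset %tdesc, %delta
        : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xindex>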
Comment: Do we allow the lane layout to also be nD?

Reply: Yes, the layout rank matches the tensor/vector rank.
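A speculative sketch of what that would look like for a rank-3 descriptor, assuming lane_layout/lane_data simply grow to match the tensor rank (the values are invented):

    // rank-3 tensor_desc with a rank-3 lane layout
    !xegpu.tensor_desc<4x8x16xf32,
        #xegpu.layout<lane_layout = [1, 2, 8], lane_data = [1, 1, 1]>>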