
Commit b1ef5a8

[mlir][MemRef] Add support for emulating narrow floats (#148036)
This enables memref.load/store and vector.load/store support for sub-byte float types. Since the memref element type does not matter for loads and stores, we keep emulating with the same integer types of equivalent width and insert a few extra bitcasts around certain operations. No direct change is needed for vector.load/store support; the tests added for them verify that float types are handled as well.
1 parent 47c9609 commit b1ef5a8
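To illustrate the effect of the change, here is a sketch of what the emulation produces for a sub-byte float load, distilled from the memref_load_f4 test added below (the SSA names and the elided index arithmetic are illustrative, not part of the patch):

func.func @load_f4_example(%arg0: index) -> f4E2M1FN {
  %m = memref.alloc() : memref<5xf4E2M1FN>
  %v = memref.load %m[%arg0] : memref<5xf4E2M1FN>
  return %v : f4E2M1FN
}
// With an 8-bit load/store width, the body is rewritten to roughly:
//   %alloc   = memref.alloc() : memref<3xi8>
//   %byte    = memref.load %alloc[%byte_idx] : memref<3xi8>
//   %shifted = arith.shrsi %byte, %bit_offset : i8
//   %bits    = arith.trunci %shifted : i8 to i4
//   %value   = arith.bitcast %bits : i4 to f4E2M1FN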

File tree

4 files changed, +181 -13 lines changed

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 33 additions & 11 deletions
@@ -323,19 +323,28 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
     // It is not clear if this case actually happens in practice, but we keep
     // the operations just in case. Otherwise, if the arith computation bitwidth
     // is different from the emulated bitwidth we truncate the result.
-    Operation *result;
+    Value result;
     auto resultTy = getTypeConverter()->convertType(oldElementType);
-    if (resultTy == convertedElementType) {
+    auto conversionTy =
+        resultTy.isInteger()
+            ? resultTy
+            : IntegerType::get(rewriter.getContext(),
+                               resultTy.getIntOrFloatBitWidth());
+    if (conversionTy == convertedElementType) {
       auto mask = rewriter.create<arith::ConstantOp>(
           loc, convertedElementType,
           rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
 
       result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
     } else {
-      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
+      result = rewriter.create<arith::TruncIOp>(loc, conversionTy, bitsLoad);
     }
 
-    rewriter.replaceOp(op, result->getResult(0));
+    if (conversionTy != resultTy) {
+      result = rewriter.create<arith::BitcastOp>(loc, resultTy, result);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -415,8 +424,18 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
     }
 
     Location loc = op.getLoc();
-    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
-                                                          adaptor.getValue());
+
+    // Pad the input value with 0s on the left.
+    Value input = adaptor.getValue();
+    if (!input.getType().isInteger()) {
+      input = rewriter.create<arith::BitcastOp>(
+          loc,
+          IntegerType::get(rewriter.getContext(),
+                           input.getType().getIntOrFloatBitWidth()),
+          input);
+    }
+    Value extendedInput =
+        rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, input);
 
     // Special case 0-rank memref stores. No need for masking.
     if (convertedType.getRank() == 0) {
@@ -619,11 +638,11 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
     arith::NarrowTypeEmulationConverter &typeConverter) {
   typeConverter.addConversion(
       [&typeConverter](MemRefType ty) -> std::optional<Type> {
-        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
-        if (!intTy)
+        Type elementType = ty.getElementType();
+        if (!elementType.isIntOrFloat())
           return ty;
 
-        unsigned width = intTy.getWidth();
+        unsigned width = elementType.getIntOrFloatBitWidth();
         unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
         if (width >= loadStoreWidth)
           return ty;
@@ -636,8 +655,11 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
        if (!strides.empty() && strides.back() != 1)
          return nullptr;
 
-        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
-                                          intTy.getSignedness());
+        auto newElemTy = IntegerType::get(
+            ty.getContext(), loadStoreWidth,
+            elementType.isInteger()
+                ? cast<IntegerType>(elementType).getSignedness()
+                : IntegerType::SignednessSemantics::Signless);
        if (!newElemTy)
          return nullptr;

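For reference, the store path above now bitcasts a float value to the equivalent-width integer before the usual zero-extension and read-modify-write; a sketch matching the rank_zero_memref_store_f4 test added below (value names are illustrative):

// memref.store %arg0, %m[] : memref<f4E2M1FN>  -- with an 8-bit width becomes:
//   %bits = arith.bitcast %arg0 : f4E2M1FN to i4
//   %ext  = arith.extui %bits : i4 to i8
//   %old  = memref.atomic_rmw assign %ext, %alloc[] : (i8, memref<i8>) -> i8
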
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 12 additions & 2 deletions
@@ -1268,8 +1268,18 @@ struct ConvertVectorTransferRead final
     bool isDivisibleInSize =
         fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
 
-    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
-                                                      adaptor.getPadding());
+    // Pad the padding value with 0s on the left. These bits are discarded and
+    // thus their values don't matter.
+    Value padding = adaptor.getPadding();
+    if (!padding.getType().isInteger()) {
+      padding = rewriter.create<arith::BitcastOp>(
+          loc,
+          IntegerType::get(rewriter.getContext(),
+                           padding.getType().getIntOrFloatBitWidth()),
+          padding);
+    }
+    auto newPadding =
+        rewriter.create<arith::ExtUIOp>(loc, containerElemTy, padding);
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

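On the vector side, only the vector.transfer_read padding value needs handling: it is bitcast to an integer of the same width before being zero-extended to the container element type. A sketch matching the vector_transfer_read_f4 test added below (value names are illustrative):

//   %pad_f4 = arith.constant 0.0 : f4E2M1FN
//   %pad_i4 = arith.bitcast %pad_f4 : f4E2M1FN to i4
//   %pad_i8 = arith.extui %pad_i4 : i4 to i8
//   %read   = vector.transfer_read %alloc[%idx], %pad_i8 : memref<12xi8>, vector<4xi8>
//   %vec_f4 = vector.bitcast %read : vector<4xi8> to vector<8xf4E2M1FN>
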
mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 58 additions & 0 deletions
@@ -61,6 +61,41 @@ func.func @memref_load_i4(%arg0: index) -> i4 {
 
 // -----
 
+func.func @memref_load_f4(%arg0: index) -> f4E2M1FN {
+  %0 = memref.alloc() : memref<5xf4E2M1FN>
+  %1 = memref.load %0[%arg0] : memref<5xf4E2M1FN>
+  return %1 : f4E2M1FN
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+// CHECK: func @memref_load_f4(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: %[[BC:.+]] = arith.bitcast %[[TRUNC]] : i4 to f4E2M1FN
+// CHECK: return %[[BC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK32: func @memref_load_f4(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: %[[BC:.+]] = arith.bitcast %[[TRUNC]] : i4 to f4E2M1FN
+// CHECK32: return %[[BC]]
+
+// -----
+
 func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
   %0 = memref.alloc() : memref<3x125xi4>
   %align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
@@ -470,6 +505,29 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
 
 // -----
 
+func.func @rank_zero_memref_store_f4(%arg0: f4E2M1FN) -> () {
+  %0 = memref.alloc() : memref<f4E2M1FN>
+  memref.store %arg0, %0[] : memref<f4E2M1FN>
+  return
+}
+// CHECK-LABEL: func @rank_zero_memref
+// CHECK-SAME: %[[ARG0:.+]]: f4E2M1FN
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+// CHECK: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: return
+
+// CHECK32-LABEL: func @rank_zero_memref
+// CHECK32-SAME: %[[ARG0:.+]]: f4E2M1FN
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+// CHECK32: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+// CHECK32: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: return
+
+// -----
+
 func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
   %arr = memref.alloc() : memref<32x8x128xi4>
   %collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 78 additions & 0 deletions
@@ -53,6 +53,31 @@ func.func @vector_load_i4(%arg1: index, %arg2: index) -> vector<3x8xi4> {
 
 // -----
 
+func.func @vector_load_f4(%arg1: index, %arg2: index) -> vector<3x8xf4E2M1FN> {
+  %0 = memref.alloc() : memref<3x8xf4E2M1FN>
+  %cst = arith.constant dense<0.0> : vector<3x8xf4E2M1FN>
+  %1 = vector.load %0[%arg1, %arg2] : memref<3x8xf4E2M1FN>, vector<8xf4E2M1FN>
+  %2 = vector.insert %1, %cst [0] : vector<8xf4E2M1FN> into vector<3x8xf4E2M1FN>
+  return %2 : vector<3x8xf4E2M1FN>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_load_f4
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<12xi8>, vector<4xi8>
+// CHECK: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xf4E2M1FN>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_load_f4
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32>
+// CHECK32: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xf4E2M1FN>
+
+// -----
+
 func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> vector<8xi4> {
   %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
   %1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
@@ -119,6 +144,37 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
 
 // -----
 
+func.func @vector_transfer_read_f4(%arg1: index, %arg2: index) -> vector<8xf4E2M1FN> {
+  %c0 = arith.constant 0.0 : f4E2M1FN
+  %0 = memref.alloc() : memref<3x8xf4E2M1FN>
+  %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true]} :
+    memref<3x8xf4E2M1FN>, vector<8xf4E2M1FN>
+  return %1 : vector<8xf4E2M1FN>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_transfer_read_f4
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[CONST:.+]] = arith.constant 0.{{0+}}e+00 : f4E2M1FN
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[BC:.+]] = arith.bitcast %[[CONST]] : f4E2M1FN to i4
+// CHECK: %[[PAD:.+]] = arith.extui %[[BC]] : i4 to i8
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<12xi8>, vector<4xi8>
+// CHECK: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xf4E2M1FN>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_transfer_read_f4
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK32: %[[CONST:.+]] = arith.constant 0.{{0+}}e+00 : f4E2M1FN
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[BC:.+]] = arith.bitcast %[[CONST]] : f4E2M1FN to i4
+// CHECK32: %[[PAD:.+]] = arith.extui %[[BC]] : i4 to i32
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
+// CHECK32: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xf4E2M1FN>
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// vector.maskedload
 ///----------------------------------------------------------------------------------------
@@ -439,6 +495,28 @@ func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {
 
 // -----
 
+func.func @vector_store_f4(%arg0: vector<8xf4E2M1FN>, %arg1: index, %arg2: index) {
+  %0 = memref.alloc() : memref<4x8xf4E2M1FN>
+  vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xf4E2M1FN>, vector<8xf4E2M1FN>
+  return
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_store_f4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<16xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<4xi8>
+// CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<16xi8>, vector<4xi8>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_store_f4
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<1xi32>
+// CHECK32: vector.store %[[VEC_I32:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<4xi32>, vector<1xi32>
+
+// -----
+
 // FIXME: This example assumes that the store happens at a byte boundary, but
 // that's not guaranteed. Below is a counter-example with specific dimensions:
 //   vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>
