Skip to content

Commit 57b0907

Browse files
[mlir][bufferization] Return BufferLikeType in BufferizableOpInterface
Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize into custom types. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType. Affected implementors of the interface are update accordingly.
1 parent a9a71b6 commit 57b0907

File tree

14 files changed

+196
-85
lines changed

14 files changed

+196
-85
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
712712
/// This is the default implementation of
713713
/// BufferizableOpInterface::getBufferType. Should not be called from other
714714
/// places.
715-
FailureOr<BaseMemRefType>
715+
FailureOr<BufferLikeType>
716716
defaultGetBufferType(Value value, const BufferizationOptions &options,
717717
const BufferizationState &state,
718718
SmallVector<Value> &invocationStack);

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
525525
Note: This interface method should never be called directly from user
526526
code. Always use `bufferization::getBufferType`.
527527
}],
528-
/*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
528+
/*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
529529
/*methodName=*/"getBufferType",
530530
/*args=*/(ins "::mlir::Value":$value,
531531
"const ::mlir::bufferization::BufferizationOptions &":$options,

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
111111
AliasingValueList getAliasingValues(
112112
OpOperand &opOperand, const AnalysisState &state);
113113

114-
FailureOr<BaseMemRefType> getBufferType(
114+
FailureOr<BufferLikeType> getBufferType(
115115
Value value, const BufferizationOptions &options,
116116
const BufferizationState &state,
117117
SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
478478

479479
bool isWritable(Value value, const AnalysisState &state);
480480

481-
FailureOr<BaseMemRefType> getBufferType(
481+
FailureOr<BufferLikeType> getBufferType(
482482
Value value, const BufferizationOptions &options,
483483
const BufferizationState &state, SmallVector<Value> &invocationStack) {
484-
return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
484+
return getBuffer().getType();
485485
}
486486
}];
487487

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// Bufferization Type Interfaces
1414
//===----------------------------------------------------------------------===//
1515

16+
#include "mlir/IR/BuiltinTypes.h"
1617
#include "mlir/IR/Diagnostics.h"
1718
#include "mlir/IR/Types.h"
1819

mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
3232
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
3333
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
3434

35-
FailureOr<BaseMemRefType>
35+
FailureOr<BufferLikeType>
3636
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
3737
const BufferizationState &state,
3838
SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
110110
if (!bufferType)
111111
return op->emitOpError("could not infer buffer type of block argument");
112112

113-
return bufferType;
113+
return cast<BufferLikeType>(bufferType);
114114
}
115115

116116
protected:

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ struct SelectOpInterface
181181
return success();
182182
}
183183

184-
FailureOr<BaseMemRefType>
184+
FailureOr<BufferLikeType>
185185
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
186186
const BufferizationState &state,
187187
SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
196196
if (failed(trueType) || failed(falseType))
197197
return failure();
198198
if (*trueType == *falseType)
199-
return *trueType;
199+
return cast<BufferLikeType>(*trueType);
200200
if (trueType->getMemorySpace() != falseType->getMemorySpace())
201201
return op->emitError("inconsistent memory space on true/false operands");
202202

203203
// If the buffers have different types, they differ only in their layout
204204
// map.
205205
auto memrefType = llvm::cast<MemRefType>(*trueType);
206-
return getMemRefTypeWithFullyDynamicLayout(
206+
return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
207207
RankedTensorType::get(memrefType.getShape(),
208208
memrefType.getElementType()),
209-
memrefType.getMemorySpace());
209+
memrefType.getMemorySpace()));
210210
}
211211
};
212212

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -945,16 +945,18 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
945945
return AliasingOpOperandList(std::move(result));
946946
}
947947

948-
FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
948+
FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
949949
Value value, const BufferizationOptions &options,
950950
const BufferizationState &bufferizationState,
951951
SmallVector<Value> &invocationStack) {
952952
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
953953
auto tensorType = cast<TensorType>(value.getType());
954954

955955
// No further analysis is possible for a block argument.
956-
if (llvm::isa<BlockArgument>(value))
957-
return bufferization::getMemRefType(tensorType, options);
956+
if (llvm::isa<BlockArgument>(value)) {
957+
return cast<BufferLikeType>(
958+
bufferization::getMemRefType(tensorType, options));
959+
}
958960

959961
// Value is an OpResult.
960962
Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
966968
// If the OpResult has an equivalent OpOperand, both OpResult and
967969
// OpOperand bufferize to the exact same buffer type.
968970
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
969-
return asMemRefType(getBufferType(equivalentOperand, options,
970-
bufferizationState, invocationStack));
971+
return getBufferType(equivalentOperand, options, bufferizationState,
972+
invocationStack);
971973
}
972974

973975
// If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
977979
if (!memSpace.has_value())
978980
return op->emitError("could not infer memory space");
979981

980-
return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
982+
return cast<BufferLikeType>(
983+
getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
981984
}
982985

983986
bool bufferization::detail::defaultIsRepetitiveRegion(

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
222222
return {};
223223
}
224224

225-
FailureOr<BaseMemRefType>
225+
FailureOr<BufferLikeType>
226226
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
227227
const BufferizationState &state,
228228
SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
245245
return getOperation()->emitError("could not infer memory space");
246246
}
247247

248-
return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
248+
return cast<BufferLikeType>(
249+
getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
249250
}
250251

251252
LogicalResult AllocTensorOp::verify() {

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ struct CallOpInterface
211211
return result;
212212
}
213213

214-
FailureOr<BaseMemRefType>
214+
FailureOr<BufferLikeType>
215215
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
216216
const BufferizationState &state,
217217
SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
229229
Type resultType =
230230
funcType.getResult(cast<OpResult>(value).getResultNumber());
231231
if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
232-
return bufferizedType;
232+
return cast<BufferLikeType>(bufferizedType);
233233

234234
// Otherwise, call the type converter to compute the bufferized type.
235235
auto tensorType = cast<TensorType>(resultType);
236-
return options.functionArgTypeConverterFn(
237-
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
236+
return cast<BufferLikeType>(options.functionArgTypeConverterFn(
237+
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
238+
options));
238239
}
239240

240241
/// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
396397
return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
397398
}
398399

399-
FailureOr<BaseMemRefType>
400+
FailureOr<BufferLikeType>
400401
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
401402
const BufferizationState &state,
402403
SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
405406

406407
// Function arguments are special.
407408
if (bbArg.getOwner() == &funcOp.getBody().front())
408-
return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
409-
options);
409+
return cast<BufferLikeType>(
410+
getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
410411

411412
return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
412413
getBufferType(op, value, options, state, invocationStack);

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ struct IfOpInterface
274274
return success();
275275
}
276276

277-
FailureOr<BaseMemRefType>
277+
FailureOr<BufferLikeType>
278278
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
279279
const BufferizationState &state,
280280
SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
313313

314314
// Best case: Both branches have the exact same buffer type.
315315
if (thenBufferType == elseBufferType)
316-
return thenBufferType;
316+
return cast<BufferLikeType>(thenBufferType);
317317

318318
// Memory space mismatch.
319319
if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
320320
return op->emitError("inconsistent memory space on then/else branches");
321321

322322
// Layout maps are different: Promote to fully dynamic layout map.
323-
return getMemRefTypeWithFullyDynamicLayout(
324-
cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
323+
return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
324+
cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
325325
}
326326
};
327327

@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
392392
return success();
393393
}
394394

395-
FailureOr<BaseMemRefType>
395+
FailureOr<BufferLikeType>
396396
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
397397
const BufferizationState &state,
398398
SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
436436
cast<TensorType>(value.getType()), bufferType.getMemorySpace());
437437
}
438438

439-
return bufferType;
439+
return cast<BufferLikeType>(bufferType);
440440
}
441441
};
442442

@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
522522
/// If both buffer types are equal, no casts are needed the computed buffer type
523523
/// can be used directly. Otherwise, the buffer types can only differ in their
524524
/// layout map and a cast must be inserted.
525-
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
525+
static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
526526
Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
527527
const BufferizationOptions &options, const BufferizationState &state,
528528
SmallVector<Value> &invocationStack) {
529529
// Determine the buffer type of the init_arg.
530-
auto initArgBufferType = bufferization::detail::asMemRefType(
531-
bufferization::getBufferType(initArg, options, state, invocationStack));
530+
auto initArgBufferType =
531+
bufferization::getBufferType(initArg, options, state, invocationStack);
532532
if (failed(initArgBufferType))
533533
return failure();
534534

@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
547547
}
548548

549549
// Compute the buffer type of the yielded value.
550-
BaseMemRefType yieldedValueBufferType;
550+
BufferLikeType yieldedValueBufferType;
551551
if (isa<BaseMemRefType>(yieldedValue.getType())) {
552552
// scf.yield was already bufferized.
553-
yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
553+
yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
554554
} else {
555555
// Note: This typically triggers a recursive call for the buffer type of
556556
// the iter_arg.
557-
auto maybeBufferType =
558-
bufferization::detail::asMemRefType(bufferization::getBufferType(
559-
yieldedValue, options, state, invocationStack));
557+
auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
558+
state, invocationStack);
560559
if (failed(maybeBufferType))
561560
return failure();
562561
yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
584583
"expected same shape");
585584
}
586585
#endif // NDEBUG
587-
return getMemRefTypeWithFullyDynamicLayout(
588-
iterTensorType, yieldedBufferType.getMemorySpace());
586+
return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
587+
iterTensorType, yieldedBufferType.getMemorySpace()));
589588
}
590589

591590
/// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
708707
return success();
709708
}
710709

711-
FailureOr<BaseMemRefType>
710+
FailureOr<BufferLikeType>
712711
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
713712
const BufferizationState &state,
714713
SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
719718
if (auto opResult = dyn_cast<OpResult>(value)) {
720719
// The type of an OpResult must match the corresponding iter_arg type.
721720
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
722-
auto bufferType =
723-
bufferization::getBufferType(bbArg, options, state, invocationStack);
724-
if (failed(bufferType))
725-
return failure();
726-
assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
727-
return cast<BaseMemRefType>(*bufferType);
721+
return bufferization::getBufferType(bbArg, options, state,
722+
invocationStack);
728723
}
729724

730725
// Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
10471042
return success();
10481043
}
10491044

1050-
FailureOr<BaseMemRefType>
1045+
FailureOr<BufferLikeType>
10511046
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
10521047
const BufferizationState &state,
10531048
SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
10811076
Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
10821077
if (!isa<TensorType>(conditionYieldedVal.getType())) {
10831078
// scf.condition was already bufferized.
1084-
return cast<BaseMemRefType>(conditionYieldedVal.getType());
1079+
return cast<BufferLikeType>(conditionYieldedVal.getType());
10851080
}
1086-
return bufferization::detail::asMemRefType(bufferization::getBufferType(
1087-
conditionYieldedVal, options, state, invocationStack));
1081+
return bufferization::getBufferType(conditionYieldedVal, options, state,
1082+
invocationStack);
10881083
}
10891084

10901085
/// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
13031298
return success();
13041299
}
13051300

1306-
FailureOr<BaseMemRefType>
1301+
FailureOr<BufferLikeType>
13071302
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
13081303
const BufferizationState &state,
13091304
SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
13121307
if (auto bbArg = dyn_cast<BlockArgument>(value))
13131308
// A tensor block argument has the same bufferized type as the
13141309
// corresponding output operand.
1315-
return bufferization::detail::asMemRefType(
1316-
bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
1317-
options, state, invocationStack));
1310+
return bufferization::getBufferType(
1311+
forallOp.getTiedOpOperand(bbArg)->get(), options, state,
1312+
invocationStack);
13181313

13191314
// The bufferized result type is the same as the bufferized type of the
13201315
// corresponding output operand.
1321-
return bufferization::detail::asMemRefType(bufferization::getBufferType(
1316+
return bufferization::getBufferType(
13221317
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
1323-
state, invocationStack));
1318+
state, invocationStack);
13241319
}
13251320

13261321
bool isRepetitiveRegion(Operation *op, unsigned index) const {

0 commit comments

Comments
 (0)