Skip to content

[mlir] Add requiresReplacedValues and visitReplacedValues to PromotableOpInterface #86792

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,10 @@ class LLVM_DbgIntrOp<string name, string argName, list<Trait> traits = []>
}];
}

def LLVM_DbgDeclareOp : LLVM_DbgIntrOp<"dbg.declare", "addr",
[DeclareOpInterfaceMethods<PromotableOpInterface>]> {
def LLVM_DbgDeclareOp : LLVM_DbgIntrOp<"dbg.declare", "addr", [
DeclareOpInterfaceMethods<PromotableOpInterface, [
"requiresReplacedValues", "visitReplacedValues"
]>]> {
let summary = "Describes how the address relates to a source language variable.";
let arguments = (ins
LLVM_AnyPointer:$addr,
Expand Down
30 changes: 30 additions & 0 deletions mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,36 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
(ins "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
"::mlir::RewriterBase &":$rewriter)
>,
InterfaceMethod<[{
This method allows the promoted operation to visit the SSA values used
in place of the memory slot once the promotion process of the memory
slot is complete.

If this method returns true, the `visitReplacedValues` method on this
operation will be called after the main mutation stage finishes
(i.e., after all ops have been processed with `removeBlockingUses`).

Operations should only the replaced values if the intended
transformation applies to all the replaced values. Furthermore, replaced
values must not be deleted.
}], "bool", "requiresReplacedValues", (ins), [{}],
[{ return false; }]
>,
InterfaceMethod<[{
Transforms the IR using the SSA values that replaced the memory slot.

This method will only be called after all blocking uses have been
scheduled for removal and if `requiresReplacedValues` returned
true.

The rewriter is located after the promotable operation on call. All IR
mutations must happen through the rewriter. During the transformation,
*no operation should be deleted*.
}],
"void", "visitReplacedValues",
(ins "::llvm::ArrayRef<std::pair<::mlir::Operation*, ::mlir::Value>>":$mutatedDefs,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name suggestion: replacingValues

"::mlir::RewriterBase &":$rewriter), [{}], [{ return; }]
Comment on lines +259 to +260
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like to be absolutely complete this method should also provide the memory slot, as one may want to know which memory slot specifically those definitions correspond to. But at the same time I am not really sure that this would be useful in practice, so change it only if you feel like it.

>,
];
}

Expand Down
19 changes: 12 additions & 7 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,6 @@ bool LLVM::StoreOp::canUsesBeRemoved(
DeletionKind LLVM::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
// `canUsesBeRemoved` checked this blocking use must be the stored slot
// pointer.
for (Operation *user : slot.ptr.getUsers())
if (auto declareOp = dyn_cast<LLVM::DbgDeclareOp>(user))
rewriter.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
declareOp.getVarInfo(),
declareOp.getLocationExpr());
return DeletionKind::Delete;
}

Expand Down Expand Up @@ -405,6 +398,18 @@ DeletionKind LLVM::DbgValueOp::removeBlockingUses(
return DeletionKind::Keep;
}

bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; }

void LLVM::DbgDeclareOp::visitReplacedValues(
ArrayRef<std::pair<Operation *, Value>> definitions,
RewriterBase &rewriter) {
for (auto [op, value] : definitions) {
rewriter.setInsertionPointAfter(op);
rewriter.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
getLocationExpr());
}
}

//===----------------------------------------------------------------------===//
// Interfaces for GEPOp
//===----------------------------------------------------------------------===//
Expand Down
16 changes: 15 additions & 1 deletion mlir/lib/Transforms/Mem2Reg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ class MemorySlotPromoter {
/// Contains the reaching definition at this operation. Reaching definitions
/// are only computed for promotable memory operations with blocking uses.
DenseMap<PromotableMemOpInterface, Value> reachingDefs;
DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
DominanceInfo &dominance;
MemorySlotPromotionInfo info;
const Mem2RegStatistics &statistics;
Expand Down Expand Up @@ -438,6 +439,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
reachingDef = stored;
replacedValuesMap[memOp] = stored;
}
}
}
Expand Down Expand Up @@ -552,6 +554,10 @@ void MemorySlotPromoter::removeBlockingUses() {
dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());

llvm::SmallVector<Operation *> toErase;
// List of all replaced values in the slot.
llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;
// Ops to visit with the `visitReplacedValues` method.
llvm::SmallVector<PromotableOpInterface> toVisit;
for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
Expand All @@ -565,7 +571,9 @@ void MemorySlotPromoter::removeBlockingUses() {
slot, info.userToBlockingUses[toPromote], rewriter,
reachingDef) == DeletionKind::Delete)
toErase.push_back(toPromote);

if (toPromoteMemOp.storesTo(slot))
if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
replacedValuesList.push_back({toPromoteMemOp, replacedValue});
continue;
}

Expand All @@ -574,6 +582,12 @@ void MemorySlotPromoter::removeBlockingUses() {
if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
rewriter) == DeletionKind::Delete)
toErase.push_back(toPromote);
if (toPromoteBasic.requiresReplacedValues())
toVisit.push_back(toPromoteBasic);
}
for (PromotableOpInterface op : toVisit) {
rewriter.setInsertionPointAfter(op);
op.visitReplacedValues(replacedValuesList, rewriter);
}

for (Operation *toEraseOp : toErase)
Expand Down
21 changes: 21 additions & 0 deletions mlir/test/Dialect/LLVMIR/mem2reg-dbginfo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,27 @@ llvm.func @basic_store_load(%arg0: i64) -> i64 {
llvm.return %2 : i64
}

// CHECK-LABEL: llvm.func @multiple_store_load
llvm.func @multiple_store_load(%arg0: i64) -> i64 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: = llvm.alloca
%1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
// CHECK-NOT: llvm.intr.dbg.declare
llvm.intr.dbg.declare #di_local_variable = %1 : !llvm.ptr
// CHECK-NOT: llvm.store
llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
// CHECK-NOT: llvm.intr.dbg.declare
llvm.intr.dbg.declare #di_local_variable = %1 : !llvm.ptr
// CHECK: llvm.intr.dbg.value #[[$VAR]] = %[[LOADED:.*]] : i64
// CHECK: llvm.intr.dbg.value #[[$VAR]] = %[[LOADED]] : i64
// CHECK-NOT: llvm.intr.dbg.value
// CHECK-NOT: llvm.intr.dbg.declare
// CHECK-NOT: llvm.store
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64
// CHECK: llvm.return %[[LOADED]] : i64
llvm.return %2 : i64
}

// CHECK-LABEL: llvm.func @block_argument_value
// CHECK-SAME: (%[[ARG0:.*]]: i64, {{.*}})
llvm.func @block_argument_value(%arg0: i64, %arg1: i1) -> i64 {
Expand Down