Skip to content

[AMDGPU] Tail call support for whole wave functions #145860

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: users/rovka/whole-wave-funcs-call
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 31 additions & 21 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7977,13 +7977,18 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
}
case Intrinsic::amdgcn_call_whole_wave: {
TargetLowering::ArgListTy Args;
bool isTailCall = I.isTailCall();

// The first argument is the callee. Skip it when assembling the call args.
TargetLowering::ArgListEntry Arg;
for (unsigned Idx = 1; Idx < I.arg_size(); ++Idx) {
Arg.Node = getValue(I.getArgOperand(Idx));
Arg.Ty = I.getArgOperand(Idx)->getType();
Arg.setAttributes(&I, Idx);

if (Arg.IsSRet && isa<Instruction>(I.getArgOperand(Idx)))
isTailCall = false;

Comment on lines +7989 to +7991
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you include this in that isEligibleForTailCall function instead? And comment it?

Args.push_back(Arg);
}

Expand All @@ -7998,7 +8003,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
.setChain(getRoot())
.setCallee(CallingConv::AMDGPU_Gfx_WholeWave, I.getType(),
getValue(I.getArgOperand(0)), std::move(Args))
.setTailCall(false)
.setTailCall(isTailCall && canTailCall(I))
.setIsPreallocated(
I.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0)
.setConvergent(I.isConvergent())
Expand Down Expand Up @@ -8879,6 +8884,29 @@ SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI,
return Result;
}

bool SelectionDAGBuilder::canTailCall(const CallBase &CB) const {
bool isMustTailCall = CB.isMustTailCall();

// Avoid emitting tail calls in functions with the disable-tail-calls
// attribute.
auto *Caller = CB.getParent()->getParent();
if (Caller->getFnAttribute("disable-tail-calls").getValueAsString() ==
"true" &&
!isMustTailCall)
return false;

// We can't tail call inside a function with a swifterror argument. Lowering
// does not support this yet. It would have to move into the swifterror
// register before the call.
if (DAG.getTargetLoweringInfo().supportSwiftError() &&
Caller->getAttributes().hasAttrSomewhere(Attribute::SwiftError))
return false;

// Check if target-independent constraints permit a tail call here.
// Target-dependent constraints are checked within TLI->LowerCallTo.
return isInTailCallPosition(CB, DAG.getTarget());
}

void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
bool isTailCall, bool isMustTailCall,
const BasicBlock *EHPadBB,
Expand All @@ -8893,21 +8921,8 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
const Value *SwiftErrorVal = nullptr;
const TargetLowering &TLI = DAG.getTargetLoweringInfo();

if (isTailCall) {
// Avoid emitting tail calls in functions with the disable-tail-calls
// attribute.
auto *Caller = CB.getParent()->getParent();
if (Caller->getFnAttribute("disable-tail-calls").getValueAsString() ==
"true" && !isMustTailCall)
isTailCall = false;

// We can't tail call inside a function with a swifterror argument. Lowering
// does not support this yet. It would have to move into the swifterror
// register before the call.
if (TLI.supportSwiftError() &&
Caller->getAttributes().hasAttrSomewhere(Attribute::SwiftError))
isTailCall = false;
}
if (isTailCall)
isTailCall = canTailCall(CB);

for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I) {
TargetLowering::ArgListEntry Entry;
Expand Down Expand Up @@ -8952,11 +8967,6 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
Args.push_back(Entry);
}

// Check if target-independent constraints permit a tail call here.
// Target-dependent constraints are checked within TLI->LowerCallTo.
if (isTailCall && !isInTailCallPosition(CB, DAG.getTarget()))
isTailCall = false;

// Disable tail calls if there is an swifterror argument. Targets have not
// been updated to support tail calls.
if (TLI.supportSwiftError() && SwiftErrorVal)
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,10 @@ class SelectionDAGBuilder {
bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr,
const TargetLowering::PtrAuthInfo *PAI = nullptr);

// Check some of the target-independent constraints for tail calls. This does
// not iterate over the call arguments.
bool canTailCall(const CallBase &CB) const;

// Lower range metadata from 0 to N to assert zext to an integer of nearest
// floor power of two.
SDValue lowerRangeToAssertZExt(SelectionDAG &DAG, const Instruction &I,
Expand Down
35 changes: 28 additions & 7 deletions llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -993,8 +993,14 @@ static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
}

return CC == CallingConv::AMDGPU_Gfx ? AMDGPU::SI_TCRETURN_GFX :
AMDGPU::SI_TCRETURN;
if (CallerF.getFunction().getCallingConv() ==
CallingConv::AMDGPU_Gfx_WholeWave)
return AMDGPU::SI_TCRETURN_GFX_WholeWave;

if (CC == CallingConv::AMDGPU_Gfx || CC == CallingConv::AMDGPU_Gfx_WholeWave)
return AMDGPU::SI_TCRETURN_GFX;

return AMDGPU::SI_TCRETURN;
}

// Add operands to call instruction to track the callee.
Expand Down Expand Up @@ -1273,6 +1279,13 @@ bool AMDGPUCallLowering::lowerTailCall(
unsigned Opc = getCallOpcode(MF, Info.Callee.isReg(), /*IsTailCall*/ true,
ST.isWave32(), CalleeCC, IsDynamicVGPRChainCall);
auto MIB = MIRBuilder.buildInstrNoInsert(Opc);

if (FuncInfo->isWholeWaveFunction())
addOriginalExecToReturn(MF, MIB);

// Keep track of the index of the next operand to be added to the call
unsigned CalleeIdx = MIB->getNumOperands();

if (!addCallTargetOperands(MIB, MIRBuilder, Info, IsDynamicVGPRChainCall))
return false;

Expand Down Expand Up @@ -1390,7 +1403,7 @@ bool AMDGPUCallLowering::lowerTailCall(
// If we have -tailcallopt, we need to adjust the stack. We'll do the call
// sequence start and end here.
if (!IsSibCall) {
MIB->getOperand(1).setImm(FPDiff);
MIB->getOperand(CalleeIdx + 1).setImm(FPDiff);
CallSeqStart.addImm(NumBytes).addImm(0);
// End the call sequence *before* emitting the call. Normally, we would
// tidy the frame up after the call. However, here, we've laid out the
Expand All @@ -1402,16 +1415,24 @@ bool AMDGPUCallLowering::lowerTailCall(
// Now we can add the actual call instruction to the correct basic block.
MIRBuilder.insertInstr(MIB);

// If this is a whole wave tail call, we need to constrain the register for
// the original EXEC.
if (MIB->getOpcode() == AMDGPU::SI_TCRETURN_GFX_WholeWave) {
MIB->getOperand(0).setReg(
constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(),
*MIB, MIB->getDesc(), MIB->getOperand(0), 0));
}

// If Callee is a reg, since it is used by a target specific
// instruction, it must have a register class matching the
// constraint of that instruction.

// FIXME: We should define regbankselectable call instructions to handle
// divergent call targets.
if (MIB->getOperand(0).isReg()) {
MIB->getOperand(0).setReg(
constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(),
*MIB, MIB->getDesc(), MIB->getOperand(0), 0));
if (MIB->getOperand(CalleeIdx).isReg()) {
MIB->getOperand(CalleeIdx).setReg(constrainOperandRegClass(
MF, *TRI, MRI, *TII, *ST.getRegBankInfo(), *MIB, MIB->getDesc(),
MIB->getOperand(CalleeIdx), CalleeIdx));
}

MF.getFrameInfo().setHasTailCall();
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5639,6 +5639,7 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(CALL)
NODE_NAME_CASE(TC_RETURN)
NODE_NAME_CASE(TC_RETURN_GFX)
NODE_NAME_CASE(TC_RETURN_GFX_WholeWave)
NODE_NAME_CASE(TC_RETURN_CHAIN)
NODE_NAME_CASE(TC_RETURN_CHAIN_DVGPR)
NODE_NAME_CASE(TRAP)
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ enum NodeType : unsigned {
CALL,
TC_RETURN,
TC_RETURN_GFX,
TC_RETURN_GFX_WholeWave,
TC_RETURN_CHAIN,
TC_RETURN_CHAIN_DVGPR,
TRAP,
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def AMDGPUtc_return_gfx: SDNode<"AMDGPUISD::TC_RETURN_GFX", AMDGPUTCReturnTP,
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
>;

// Tail-call return from a whole wave function using the gfx calling
// convention. Same profile as AMDGPUtc_return_gfx; selected to the
// SI_TCRETURN_GFX_WholeWave pseudo so the original EXEC can be restored.
def AMDGPUtc_return_gfx_ww: SDNode<"AMDGPUISD::TC_RETURN_GFX_WholeWave", AMDGPUTCReturnTP,
  [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
>;

def AMDGPUtc_return_chain: SDNode<"AMDGPUISD::TC_RETURN_CHAIN",
SDTypeProfile<0, -1, [SDTCisPtrTy<0>]>,
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
Expand Down
20 changes: 17 additions & 3 deletions llvm/lib/Target/AMDGPU/SIFrameLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1125,9 +1125,18 @@ void SIFrameLowering::emitCSRSpillRestores(
RestoreWWMRegisters(WWMCalleeSavedRegs);

// The original EXEC is the first operand of the return instruction.
const MachineInstr &Return = MBB.instr_back();
assert(Return.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN &&
"Unexpected return inst");
MachineInstr &Return = MBB.instr_back();
unsigned Opcode = Return.getOpcode();
switch (Opcode) {
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN:
Opcode = AMDGPU::SI_RETURN;
break;
case AMDGPU::SI_TCRETURN_GFX_WholeWave:
Opcode = AMDGPU::SI_TCRETURN_GFX;
break;
default:
llvm_unreachable("Unexpected return inst");
}
Register OrigExec = Return.getOperand(0).getReg();

if (!WWMScratchRegs.empty()) {
Expand All @@ -1141,6 +1150,11 @@ void SIFrameLowering::emitCSRSpillRestores(
// Restore original EXEC.
unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addReg(OrigExec);

// Drop the first operand and update the opcode.
Return.removeOperand(0);
Return.setDesc(TII->get(Opcode));

return;
}

Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4131,6 +4131,11 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
break;
}

// If the caller is a whole wave function, we need to use a special opcode
// so we can patch up EXEC.
if (Info->isWholeWaveFunction())
OPC = AMDGPUISD::TC_RETURN_GFX_WholeWave;

return DAG.getNode(OPC, DL, MVT::Other, Ops);
}

Expand Down Expand Up @@ -5872,6 +5877,7 @@ SITargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
MI.eraseFromParent();
return SplitBB;
}
case AMDGPU::SI_TCRETURN_GFX_WholeWave:
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN: {
assert(MFI->isWholeWaveFunction());

Expand Down
1 change: 0 additions & 1 deletion llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2515,7 +2515,6 @@ bool SIInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
MI.setDesc(get(ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64));
break;
}
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN:
case AMDGPU::SI_RETURN: {
const MachineFunction *MF = MBB.getParent();
const GCNSubtarget &ST = MF->getSubtarget<GCNSubtarget>();
Expand Down
27 changes: 27 additions & 0 deletions llvm/lib/Target/AMDGPU/SIInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,33 @@ def SI_WHOLE_WAVE_FUNC_RETURN : SPseudoInstSI <
// Select the whole-wave-return node to its pseudo. The i1 operand (which
// carries the original EXEC) is only a placeholder (IMPLICIT_DEF) at this
// point — presumably filled in later by custom insertion/frame lowering;
// verify against SI_WHOLE_WAVE_FUNC_RETURN's handling.
def : GCNPat<
  (AMDGPUwhole_wave_return), (SI_WHOLE_WAVE_FUNC_RETURN (i1 (IMPLICIT_DEF)))>;

// Restores the previous EXEC and otherwise behaves entirely like a SI_TCRETURN.
// This is used for tail calls *from* a whole wave function. Tail calls to
// a whole wave function may use the usual opcodes, depending on the calling
// convention of the caller.
//
// Operands:
//   $orig_exec - the EXEC mask to restore before transferring control to the
//                callee.
//   $src0      - the callee address, in a gfx-CC-reserved SGPR pair.
//   $callee    - the symbolic call target.
//   $fpdiff    - stack adjustment immediate (FPDiff).
def SI_TCRETURN_GFX_WholeWave : SPseudoInstSI <
  (outs),
  (ins SReg_1:$orig_exec, Gfx_CCR_SGPR_64:$src0, unknown:$callee, i32imm:$fpdiff)> {
  let isCall = 1;
  let isTerminator = 1;
  let isReturn = 1;
  let isBarrier = 1;
  let UseNamedOperandTable = 1;
  let SchedRW = [WriteBranch];
  let isConvergent = 1;

  // We're going to use custom handling to set the $orig_exec to the correct value.
  let usesCustomInserter = 1;
}

// Generate a SI_TCRETURN_GFX_WholeWave pseudo with a placeholder
// (IMPLICIT_DEF) for its $orig_exec argument. It will be filled in by the
// custom inserter (the pseudo sets usesCustomInserter).
def : GCNPat<
  (AMDGPUtc_return_gfx_ww i64:$src0, tglobaladdr:$callee, i32:$fpdiff),
  (SI_TCRETURN_GFX_WholeWave (i1 (IMPLICIT_DEF)), Gfx_CCR_SGPR_64:$src0,
    tglobaladdr:$callee, i32:$fpdiff)>;


// Return for returning shaders to a shader variant epilog.
def SI_RETURN_TO_EPILOG : SPseudoInstSI <
(outs), (ins variable_ops), [(AMDGPUreturn_to_epilog)]> {
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,7 @@ constexpr bool mayTailCallThisCC(CallingConv::ID CC) {
switch (CC) {
case CallingConv::C:
case CallingConv::AMDGPU_Gfx:
case CallingConv::AMDGPU_Gfx_WholeWave:
return true;
default:
return canGuaranteeTCO(CC);
Expand Down
Loading
Loading