Skip to content

Commit 4594737

Browse files
committed
[AMDGPU] Tail call support for whole wave functions
Support tail calls to whole wave functions (trivial) and from whole wave functions (slightly more involved, because we need a new pseudo for the tail-call return that patches up the EXEC mask). Move the expansion of whole wave function return pseudos (both regular and tail-call returns) to prolog/epilog insertion, since that is where we patch up the EXEC mask. Unnecessary register spills will be dealt with in a future patch.
1 parent 3cc5557 commit 4594737

15 files changed

+2477
-41
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7977,13 +7977,18 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
79777977
}
79787978
case Intrinsic::amdgcn_call_whole_wave: {
79797979
TargetLowering::ArgListTy Args;
7980+
bool isTailCall = I.isTailCall();
79807981

79817982
// The first argument is the callee. Skip it when assembling the call args.
79827983
TargetLowering::ArgListEntry Arg;
79837984
for (unsigned Idx = 1; Idx < I.arg_size(); ++Idx) {
79847985
Arg.Node = getValue(I.getArgOperand(Idx));
79857986
Arg.Ty = I.getArgOperand(Idx)->getType();
79867987
Arg.setAttributes(&I, Idx);
7988+
7989+
if (Arg.IsSRet && isa<Instruction>(I.getArgOperand(Idx)))
7990+
isTailCall = false;
7991+
79877992
Args.push_back(Arg);
79887993
}
79897994

@@ -7998,7 +8003,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
79988003
.setChain(getRoot())
79998004
.setCallee(CallingConv::AMDGPU_Gfx_WholeWave, I.getType(),
80008005
getValue(I.getArgOperand(0)), std::move(Args))
8001-
.setTailCall(false)
8006+
.setTailCall(isTailCall && canTailCall(I))
80028007
.setIsPreallocated(
80038008
I.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0)
80048009
.setConvergent(I.isConvergent())
@@ -8879,6 +8884,29 @@ SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI,
88798884
return Result;
88808885
}
88818886

8887+
bool SelectionDAGBuilder::canTailCall(const CallBase &CB) const {
8888+
bool isMustTailCall = CB.isMustTailCall();
8889+
8890+
// Avoid emitting tail calls in functions with the disable-tail-calls
8891+
// attribute.
8892+
auto *Caller = CB.getParent()->getParent();
8893+
if (Caller->getFnAttribute("disable-tail-calls").getValueAsString() ==
8894+
"true" &&
8895+
!isMustTailCall)
8896+
return false;
8897+
8898+
// We can't tail call inside a function with a swifterror argument. Lowering
8899+
// does not support this yet. It would have to move into the swifterror
8900+
// register before the call.
8901+
if (DAG.getTargetLoweringInfo().supportSwiftError() &&
8902+
Caller->getAttributes().hasAttrSomewhere(Attribute::SwiftError))
8903+
return false;
8904+
8905+
// Check if target-independent constraints permit a tail call here.
8906+
// Target-dependent constraints are checked within TLI->LowerCallTo.
8907+
return isInTailCallPosition(CB, DAG.getTarget());
8908+
}
8909+
88828910
void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
88838911
bool isTailCall, bool isMustTailCall,
88848912
const BasicBlock *EHPadBB,
@@ -8893,21 +8921,8 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
88938921
const Value *SwiftErrorVal = nullptr;
88948922
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
88958923

8896-
if (isTailCall) {
8897-
// Avoid emitting tail calls in functions with the disable-tail-calls
8898-
// attribute.
8899-
auto *Caller = CB.getParent()->getParent();
8900-
if (Caller->getFnAttribute("disable-tail-calls").getValueAsString() ==
8901-
"true" && !isMustTailCall)
8902-
isTailCall = false;
8903-
8904-
// We can't tail call inside a function with a swifterror argument. Lowering
8905-
// does not support this yet. It would have to move into the swifterror
8906-
// register before the call.
8907-
if (TLI.supportSwiftError() &&
8908-
Caller->getAttributes().hasAttrSomewhere(Attribute::SwiftError))
8909-
isTailCall = false;
8910-
}
8924+
if (isTailCall)
8925+
isTailCall = canTailCall(CB);
89118926

89128927
for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I) {
89138928
TargetLowering::ArgListEntry Entry;
@@ -8952,11 +8967,6 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
89528967
Args.push_back(Entry);
89538968
}
89548969

8955-
// Check if target-independent constraints permit a tail call here.
8956-
// Target-dependent constraints are checked within TLI->LowerCallTo.
8957-
if (isTailCall && !isInTailCallPosition(CB, DAG.getTarget()))
8958-
isTailCall = false;
8959-
89608970
// Disable tail calls if there is a swifterror argument. Targets have not
89618971
// been updated to support tail calls.
89628972
if (TLI.supportSwiftError() && SwiftErrorVal)

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,10 @@ class SelectionDAGBuilder {
408408
bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr,
409409
const TargetLowering::PtrAuthInfo *PAI = nullptr);
410410

411+
// Check some of the target-independent constraints for tail calls. This does
412+
// not iterate over the call arguments.
413+
bool canTailCall(const CallBase &CB) const;
414+
411415
// Lower range metadata from 0 to N to assert zext to an integer of nearest
412416
// floor power of two.
413417
SDValue lowerRangeToAssertZExt(SelectionDAG &DAG, const Instruction &I,

llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -993,8 +993,14 @@ static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
993993
return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
994994
}
995995

996-
return CC == CallingConv::AMDGPU_Gfx ? AMDGPU::SI_TCRETURN_GFX :
997-
AMDGPU::SI_TCRETURN;
996+
if (CallerF.getFunction().getCallingConv() ==
997+
CallingConv::AMDGPU_Gfx_WholeWave)
998+
return AMDGPU::SI_TCRETURN_GFX_WholeWave;
999+
1000+
if (CC == CallingConv::AMDGPU_Gfx || CC == CallingConv::AMDGPU_Gfx_WholeWave)
1001+
return AMDGPU::SI_TCRETURN_GFX;
1002+
1003+
return AMDGPU::SI_TCRETURN;
9981004
}
9991005

10001006
// Add operands to call instruction to track the callee.
@@ -1273,6 +1279,13 @@ bool AMDGPUCallLowering::lowerTailCall(
12731279
unsigned Opc = getCallOpcode(MF, Info.Callee.isReg(), /*IsTailCall*/ true,
12741280
ST.isWave32(), CalleeCC, IsDynamicVGPRChainCall);
12751281
auto MIB = MIRBuilder.buildInstrNoInsert(Opc);
1282+
1283+
if (FuncInfo->isWholeWaveFunction())
1284+
addOriginalExecToReturn(MF, MIB);
1285+
1286+
// Keep track of the index of the next operand to be added to the call.
1287+
unsigned CalleeIdx = MIB->getNumOperands();
1288+
12761289
if (!addCallTargetOperands(MIB, MIRBuilder, Info, IsDynamicVGPRChainCall))
12771290
return false;
12781291

@@ -1390,7 +1403,7 @@ bool AMDGPUCallLowering::lowerTailCall(
13901403
// If we have -tailcallopt, we need to adjust the stack. We'll do the call
13911404
// sequence start and end here.
13921405
if (!IsSibCall) {
1393-
MIB->getOperand(1).setImm(FPDiff);
1406+
MIB->getOperand(CalleeIdx + 1).setImm(FPDiff);
13941407
CallSeqStart.addImm(NumBytes).addImm(0);
13951408
// End the call sequence *before* emitting the call. Normally, we would
13961409
// tidy the frame up after the call. However, here, we've laid out the
@@ -1402,16 +1415,24 @@ bool AMDGPUCallLowering::lowerTailCall(
14021415
// Now we can add the actual call instruction to the correct basic block.
14031416
MIRBuilder.insertInstr(MIB);
14041417

1418+
// If this is a whole wave tail call, we need to constrain the register for
1419+
// the original EXEC.
1420+
if (MIB->getOpcode() == AMDGPU::SI_TCRETURN_GFX_WholeWave) {
1421+
MIB->getOperand(0).setReg(
1422+
constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(),
1423+
*MIB, MIB->getDesc(), MIB->getOperand(0), 0));
1424+
}
1425+
14051426
// If Callee is a reg, since it is used by a target specific
14061427
// instruction, it must have a register class matching the
14071428
// constraint of that instruction.
14081429

14091430
// FIXME: We should define regbankselectable call instructions to handle
14101431
// divergent call targets.
1411-
if (MIB->getOperand(0).isReg()) {
1412-
MIB->getOperand(0).setReg(
1413-
constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(),
1414-
*MIB, MIB->getDesc(), MIB->getOperand(0), 0));
1432+
if (MIB->getOperand(CalleeIdx).isReg()) {
1433+
MIB->getOperand(CalleeIdx).setReg(constrainOperandRegClass(
1434+
MF, *TRI, MRI, *TII, *ST.getRegBankInfo(), *MIB, MIB->getDesc(),
1435+
MIB->getOperand(CalleeIdx), CalleeIdx));
14151436
}
14161437

14171438
MF.getFrameInfo().setHasTailCall();

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5578,6 +5578,7 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
55785578
NODE_NAME_CASE(CALL)
55795579
NODE_NAME_CASE(TC_RETURN)
55805580
NODE_NAME_CASE(TC_RETURN_GFX)
5581+
NODE_NAME_CASE(TC_RETURN_GFX_WholeWave)
55815582
NODE_NAME_CASE(TC_RETURN_CHAIN)
55825583
NODE_NAME_CASE(TC_RETURN_CHAIN_DVGPR)
55835584
NODE_NAME_CASE(TRAP)

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ enum NodeType : unsigned {
402402
CALL,
403403
TC_RETURN,
404404
TC_RETURN_GFX,
405+
TC_RETURN_GFX_WholeWave,
405406
TC_RETURN_CHAIN,
406407
TC_RETURN_CHAIN_DVGPR,
407408
TRAP,

llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ def AMDGPUtc_return_gfx: SDNode<"AMDGPUISD::TC_RETURN_GFX", AMDGPUTCReturnTP,
9494
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
9595
>;
9696

97+
def AMDGPUtc_return_gfx_ww: SDNode<"AMDGPUISD::TC_RETURN_GFX_WholeWave", AMDGPUTCReturnTP,
98+
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
99+
>;
100+
97101
def AMDGPUtc_return_chain: SDNode<"AMDGPUISD::TC_RETURN_CHAIN",
98102
SDTypeProfile<0, -1, [SDTCisPtrTy<0>]>,
99103
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]

llvm/lib/Target/AMDGPU/SIFrameLowering.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,9 +1125,18 @@ void SIFrameLowering::emitCSRSpillRestores(
11251125
RestoreWWMRegisters(WWMCalleeSavedRegs);
11261126

11271127
// The original EXEC is the first operand of the return instruction.
1128-
const MachineInstr &Return = MBB.instr_back();
1129-
assert(Return.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN &&
1130-
"Unexpected return inst");
1128+
MachineInstr &Return = MBB.instr_back();
1129+
unsigned Opcode = Return.getOpcode();
1130+
switch (Opcode) {
1131+
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN:
1132+
Opcode = AMDGPU::SI_RETURN;
1133+
break;
1134+
case AMDGPU::SI_TCRETURN_GFX_WholeWave:
1135+
Opcode = AMDGPU::SI_TCRETURN_GFX;
1136+
break;
1137+
default:
1138+
llvm_unreachable("Unexpected return inst");
1139+
}
11311140
Register OrigExec = Return.getOperand(0).getReg();
11321141

11331142
if (!WWMScratchRegs.empty()) {
@@ -1141,6 +1150,11 @@ void SIFrameLowering::emitCSRSpillRestores(
11411150
// Restore original EXEC.
11421151
unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
11431152
BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addReg(OrigExec);
1153+
1154+
// Drop the first operand and update the opcode.
1155+
Return.removeOperand(0);
1156+
Return.setDesc(TII->get(Opcode));
1157+
11441158
return;
11451159
}
11461160

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4130,6 +4130,11 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
41304130
break;
41314131
}
41324132

4133+
// If the caller is a whole wave function, we need to use a special opcode
4134+
// so we can patch up EXEC.
4135+
if (Info->isWholeWaveFunction())
4136+
OPC = AMDGPUISD::TC_RETURN_GFX_WholeWave;
4137+
41334138
return DAG.getNode(OPC, DL, MVT::Other, Ops);
41344139
}
41354140

@@ -5871,6 +5876,7 @@ SITargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
58715876
MI.eraseFromParent();
58725877
return SplitBB;
58735878
}
5879+
case AMDGPU::SI_TCRETURN_GFX_WholeWave:
58745880
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN: {
58755881
assert(MFI->isWholeWaveFunction());
58765882

llvm/lib/Target/AMDGPU/SIInstrInfo.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2515,7 +2515,6 @@ bool SIInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
25152515
MI.setDesc(get(ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64));
25162516
break;
25172517
}
2518-
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN:
25192518
case AMDGPU::SI_RETURN: {
25202519
const MachineFunction *MF = MBB.getParent();
25212520
const GCNSubtarget &ST = MF->getSubtarget<GCNSubtarget>();

llvm/lib/Target/AMDGPU/SIInstructions.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,33 @@ def SI_WHOLE_WAVE_FUNC_RETURN : SPseudoInstSI <
670670
def : GCNPat<
671671
(AMDGPUwhole_wave_return), (SI_WHOLE_WAVE_FUNC_RETURN (i1 (IMPLICIT_DEF)))>;
672672

673+
// Restores the previous EXEC and otherwise behaves entirely like a SI_TCRETURN.
674+
// This is used for tail calls *from* a whole wave function. Tail calls to
675+
// a whole wave function may use the usual opcodes, depending on the calling
676+
// convention of the caller.
677+
def SI_TCRETURN_GFX_WholeWave : SPseudoInstSI <
678+
(outs),
679+
(ins SReg_1:$orig_exec, Gfx_CCR_SGPR_64:$src0, unknown:$callee, i32imm:$fpdiff)> {
680+
let isCall = 1;
681+
let isTerminator = 1;
682+
let isReturn = 1;
683+
let isBarrier = 1;
684+
let UseNamedOperandTable = 1;
685+
let SchedRW = [WriteBranch];
686+
let isConvergent = 1;
687+
688+
// We're going to use custom handling to set the $orig_exec to the correct value.
689+
let usesCustomInserter = 1;
690+
}
691+
692+
// Generate a SI_TCRETURN_GFX_WholeWave pseudo with a placeholder for its
693+
// argument. It will be filled in by the custom inserter.
694+
def : GCNPat<
695+
(AMDGPUtc_return_gfx_ww i64:$src0, tglobaladdr:$callee, i32:$fpdiff),
696+
(SI_TCRETURN_GFX_WholeWave (i1 (IMPLICIT_DEF)), Gfx_CCR_SGPR_64:$src0,
697+
tglobaladdr:$callee, i32:$fpdiff)>;
698+
699+
673700
// Return for returning shaders to a shader variant epilog.
674701
def SI_RETURN_TO_EPILOG : SPseudoInstSI <
675702
(outs), (ins variable_ops), [(AMDGPUreturn_to_epilog)]> {

llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1412,6 +1412,7 @@ constexpr bool mayTailCallThisCC(CallingConv::ID CC) {
14121412
switch (CC) {
14131413
case CallingConv::C:
14141414
case CallingConv::AMDGPU_Gfx:
1415+
case CallingConv::AMDGPU_Gfx_WholeWave:
14151416
return true;
14161417
default:
14171418
return canGuaranteeTCO(CC);

0 commit comments

Comments
 (0)