diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index f4163b0743a9a..711c5de7c909f 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -1339,6 +1339,18 @@ class VPWidenRecipe : public VPRecipeWithIRFlags, public VPIRMetadata { void print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const override; #endif + + bool onlyFirstLaneUsed(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + + // The operand(1) of the extract value is the index to extract, which should + // be scalar. + if (Opcode == Instruction::ExtractValue) + return Op == getOperand(1); + + return false; + } }; /// VPWidenCastRecipe is a recipe to create vector cast instructions. @@ -1533,6 +1545,15 @@ class VPWidenCallRecipe : public VPRecipeWithIRFlags, public VPIRMetadata { void print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const override; #endif + + /// Returns true if the recipe only uses the first lane of operand \p Op. + bool onlyFirstLaneUsed(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + // The last operand is the function pointer of the underlying scalar + // function, which should be scalar. + return Op == getOperand(getNumOperands() - 1); + } }; /// A recipe representing a sequence of load -> update -> store as part of diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 3e12fdf9163eb..c84016ba166f3 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -938,6 +938,7 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const { default: return false; case Instruction::ExtractElement: + case Instruction::ExtractValue: return Op == getOperand(1); case Instruction::PHI: return true; @@ -959,8 +960,9 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const { case VPInstruction::PtrAdd: return Op == getOperand(0) || vputils::onlyFirstLaneUsed(this); case VPInstruction::ComputeAnyOfResult: - case VPInstruction::ComputeFindLastIVResult: return Op == getOperand(1); + case VPInstruction::ComputeFindLastIVResult: + return Op == getOperand(1) || Op == getOperand(2); }; llvm_unreachable("switch should return"); }