diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 40578e5edc3ab..8cd82e43cde23 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1349,6 +1349,75 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, return nullptr; }; + // Given: + // (Cond ? TVal : FVal) op (MinMaxCand) [SelectIsLHS == true] + // or + // (MinMaxCand) op (Cond ? TVal : FVal) [SelectIsLHS == false] + // + // If MinMaxCand can be expressed as a select with the same Cond, + // then try to optimize it the same way as if it was a select. + // Such patterns may appear after foldSelectInstWithICmpConst(). + auto foldWithMinMax = [&](Value *Cond, Value *TVal, Value *FVal, + Value *MinMaxCand, + bool SelectIsLHS = true) -> Value * { + if (True && False) + return nullptr; + + const APInt *Cst; + Value *V; + bool FoundMatch = false; + + // umax(V, UMIN+1) is equivalent to (V == UMIN) ? UMIN+1 : V + FoundMatch = match(MinMaxCand, m_UMax(m_Value(V), m_APInt(Cst))) && + !Cst->isMinValue() && (*Cst - 1).isMinValue() && + match(Cond, m_c_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(V), + m_SpecificInt(*Cst - 1))); + + // umin(V, UMAX-1) is equivalent to (V == UMAX) ? UMAX-1 : V + if (!FoundMatch) + FoundMatch = + match(MinMaxCand, m_UMin(m_Value(V), m_APInt(Cst))) && + !Cst->isMaxValue() && (*Cst + 1).isMaxValue() && + match(Cond, m_c_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(V), + m_SpecificInt(*Cst + 1))); + + // smax(V, SMIN+1) is equivalent to (V == SMIN) ? SMIN+1 : V + if (!FoundMatch) + FoundMatch = + match(MinMaxCand, m_SMax(m_Value(V), m_APInt(Cst))) && + !Cst->isMinSignedValue() && (*Cst - 1).isMinSignedValue() && + match(Cond, m_c_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(V), + m_SpecificInt(*Cst - 1))); + + // smin(V, SMAX-1) is equivalent to (V == SMAX) ? SMAX-1 : V + if (!FoundMatch) + FoundMatch = + match(MinMaxCand, m_SMin(m_Value(V), m_APInt(Cst))) && + !Cst->isMaxSignedValue() && (*Cst + 1).isMaxSignedValue() && + match(Cond, m_c_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(V), + m_SpecificInt(*Cst + 1))); + + if (!FoundMatch) + return nullptr; + + Value *OtherTVal = ConstantInt::get(V->getType(), *Cst); + Value *OtherFVal = V; + if (SelectIsLHS) { + True = simplifyBinOp(Opcode, TVal, OtherTVal, FMF, Q); + False = simplifyBinOp(Opcode, FVal, OtherFVal, FMF, Q); + } else { + True = simplifyBinOp(Opcode, OtherTVal, TVal, FMF, Q); + False = simplifyBinOp(Opcode, OtherFVal, FVal, FMF, Q); + } + + if (!True || !False) + return nullptr; + + Value *SI = Builder.CreateSelect(Cond, True, False); + SI->takeName(&I); + return SI; + }; + if (LHSIsSelect && RHSIsSelect && A == D) { // (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F) Cond = A; @@ -1368,6 +1437,8 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, False = simplifyBinOp(Opcode, C, RHS, FMF, Q); if (Value *NewSel = foldAddNegate(B, C, RHS)) return NewSel; + if (Value *NewSel = foldWithMinMax(Cond, B, C, RHS)) + return NewSel; } else if (RHSIsSelect && RHS->hasOneUse()) { // X op (D ? E : F) -> D ? (X op E) : (X op F) Cond = D; @@ -1375,6 +1446,8 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, False = simplifyBinOp(Opcode, LHS, F, FMF, Q); if (Value *NewSel = foldAddNegate(E, F, LHS)) return NewSel; + if (Value *NewSel = foldWithMinMax(Cond, E, F, LHS, /*SelectIsLHS=*/false)) + return NewSel; } if (!True || !False) diff --git a/llvm/test/Transforms/InstCombine/select_arithmetic_with_min_max.ll b/llvm/test/Transforms/InstCombine/select_arithmetic_with_min_max.ll new file mode 100644 index 0000000000000..5f6f7675e1694 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/select_arithmetic_with_min_max.ll @@ -0,0 +1,107 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=instcombine -S | FileCheck %s + +; These tests check folding of a select and min/max feeding +; the same binary operation. When min/max can be expressed +; as a select with the same condition as in another select, +; the binary op might be applied to the operands of the selects. + +define void @test_umax1(i32 %V, ptr %m1, ptr %m2) { +; CHECK-LABEL: @test_umax1( +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[V:%.*]], 0 +; CHECK-NEXT: [[UMAX:%.*]] = call i32 @llvm.umax.i32(i32 [[V]], i32 1) +; CHECK-NEXT: [[OP:%.*]] = select i1 [[CMP]], i32 8, i32 0 +; CHECK-NEXT: store i32 [[UMAX]], ptr [[M1:%.*]], align 4 +; CHECK-NEXT: store i32 [[OP]], ptr [[M2:%.*]], align 4 +; CHECK-NEXT: ret void +; + %cmp = icmp eq i32 %V, 0 +; %umax = select i1 %cmp, i32 1, i32 %V + %umax = call i32 @llvm.umax.i32(i32 %V, i32 1) + %select = select i1 %cmp, i32 9, i32 %V + %op = sub i32 %select, %umax + store i32 %umax, ptr %m1 + store i32 %op, ptr %m2 + ret void +} + +define void @test_umax2(i32 %V, ptr %m1, ptr %m2) { +; CHECK-LABEL: @test_umax2( +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[V:%.*]], 0 +; CHECK-NEXT: [[UMAX:%.*]] = call i32 @llvm.umax.i32(i32 [[V]], i32 1) +; CHECK-NEXT: [[OP:%.*]] = select i1 [[CMP]], i32 8, i32 0 +; CHECK-NEXT: store i32 [[UMAX]], ptr [[M1:%.*]], align 4 +; CHECK-NEXT: store i32 [[OP]], ptr [[M2:%.*]], align 4 +; CHECK-NEXT: ret void +; + %cmp = icmp eq i32 0, %V +; %umax = select i1 %cmp, i32 1, i32 %V + %umax = call i32 @llvm.umax.i32(i32 1, i32 %V) + %select = select i1 %cmp, i32 9, i32 %V + %op = sub i32 %select, %umax + store i32 %umax, ptr %m1 + store i32 %op, ptr %m2 + ret void +} + +define void @test_umin(i16 %V, ptr %m1, ptr %m2) { +; CHECK-LABEL: @test_umin( +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[V:%.*]], -1 +; CHECK-NEXT: [[UMIN:%.*]] = call i16 @llvm.umin.i16(i16 [[V]], i16 -2) +; CHECK-NEXT: [[OP:%.*]] = select i1 [[CMP]], i16 4, i16 0 +; CHECK-NEXT: store i16 [[UMIN]], ptr [[M1:%.*]], align 2 +; CHECK-NEXT: store i16 [[OP]], ptr [[M2:%.*]], align 2 +; CHECK-NEXT: ret void +; + %cmp = icmp eq i16 %V, 65535 +; %umin = select i1 %cmp, i16 65534, i16 %V + %umin = call i16 @llvm.umin.i16(i16 %V, i16 65534) + %select = select i1 %cmp, i16 65530, i16 %V + %op = sub i16 %umin, %select + store i16 %umin, ptr %m1 + store i16 %op, ptr %m2 + ret void +} + +define void @test_smax(i8 %V, ptr %m1, ptr %m2) { +; CHECK-LABEL: @test_smax( +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[V:%.*]], -128 +; CHECK-NEXT: [[SMAX:%.*]] = call i8 @llvm.smax.i8(i8 [[V]], i8 -127) +; CHECK-NEXT: [[OP:%.*]] = select i1 [[CMP]], i8 -128, i8 0 +; CHECK-NEXT: store i8 [[SMAX]], ptr [[M1:%.*]], align 1 +; CHECK-NEXT: store i8 [[OP]], ptr [[M2:%.*]], align 1 +; CHECK-NEXT: ret void +; + %cmp = icmp eq i8 %V, -128 +; %smax = select i1 %cmp, i8 -127, i8 %V + %smax = call i8 @llvm.smax.i8(i8 -127, i8 %V) + %select = select i1 %cmp, i8 1, i8 %V + %op = sub i8 %smax, %select + store i8 %smax, ptr %m1 + store i8 %op, ptr %m2 + ret void +} + +define void @test_smin(i8 %V, ptr %m1, ptr %m2) { +; CHECK-LABEL: @test_smin( +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[V:%.*]], 127 +; CHECK-NEXT: [[SMIN:%.*]] = call i8 @llvm.smin.i8(i8 [[V]], i8 126) +; CHECK-NEXT: [[OP:%.*]] = select i1 [[CMP]], i8 6, i8 0 +; CHECK-NEXT: store i8 [[SMIN]], ptr [[M1:%.*]], align 1 +; CHECK-NEXT: store i8 [[OP]], ptr [[M2:%.*]], align 1 +; CHECK-NEXT: ret void +; + %cmp = icmp eq i8 %V, 127 +; %smin = select i1 %cmp, i8 126, i8 %V + %smin = call i8 @llvm.smin.i8(i8 %V, i8 126) + %select = select i1 %cmp, i8 120, i8 %V + %op = sub i8 %smin, %select + store i8 %smin, ptr %m1 + store i8 %op, ptr %m2 + ret void +} + +declare i32 @llvm.umax.i32(i32, i32) +declare i16 @llvm.umin.i16(i16, i16) +declare i8 @llvm.smax.i8(i8, i8) +declare i8 @llvm.smin.i8(i8, i8)