diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 4fe900e9421f8..06239a01825f3 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1349,6 +1349,41 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, return nullptr; }; + // Special case for reconstructing across a select: + // (Cond ? V1 : (X & Mask)) | + // zext (Cond ? V2 : trunc X) + // -> (Cond ? (V1 | zext V2) : X) + auto foldReconstruction = [&](Value *V1, Value *Masked, + Value *ZExtSel) -> Value * { + if (Opcode != Instruction::Or) + return nullptr; + + Value *X; + const APInt *C; + if (!match(Masked, m_OneUse(m_And(m_Value(X), m_APInt(C))))) + return nullptr; + + Value *V2, *Trunc; + if (!match(ZExtSel, m_ZExt(m_OneUse(m_Select(m_Specific(Cond), m_Value(V2), + m_Value(Trunc)))))) + return nullptr; + + if (*C != APInt::getBitsSetFrom(X->getType()->getScalarSizeInBits(), + Trunc->getType()->getScalarSizeInBits())) { + return nullptr; + } + + if (!match(Trunc, m_Trunc(m_Specific(X)))) + return nullptr; + + Value *ZExtTrue = Builder.CreateZExt(V2, V1->getType()); + Value *True = simplifyBinOp(Opcode, V1, ZExtTrue, FMF, Q); + if (!True) + return nullptr; + + return Builder.CreateSelect(Cond, True, X, I.getName()); + }; + if (LHSIsSelect && RHSIsSelect && A == D) { // (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F) Cond = A; @@ -1368,6 +1403,8 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, False = simplifyBinOp(Opcode, C, RHS, FMF, Q); if (Value *NewSel = foldAddNegate(B, C, RHS)) return NewSel; + if (Value *NewSel = foldReconstruction(B, C, RHS)) + return NewSel; } else if (RHSIsSelect && RHS->hasOneUse()) { // X op (D ? E : F) -> D ? (X op E) : (X op F) Cond = D; @@ -1375,6 +1412,8 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, False = simplifyBinOp(Opcode, LHS, F, FMF, Q); if (Value *NewSel = foldAddNegate(E, F, LHS)) return NewSel; + if (Value *NewSel = foldReconstruction(E, F, LHS)) + return NewSel; } if (!True || !False) diff --git a/llvm/test/Transforms/InstCombine/select-reconstruction.ll b/llvm/test/Transforms/InstCombine/select-reconstruction.ll new file mode 100644 index 0000000000000..ea1c4e09e7b17 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/select-reconstruction.ll @@ -0,0 +1,120 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt < %s -passes=instcombine -S | FileCheck %s + +define i40 @select_reconstruction_i40(i40 %arg0) { +; CHECK-LABEL: define i40 @select_reconstruction_i40( +; CHECK-SAME: i40 [[ARG0:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = trunc i40 [[ARG0]] to i8 +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP1]], 2 +; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], i40 0, i40 [[ARG0]] +; CHECK-NEXT: ret i40 [[TMP3]] +; + %low = trunc i40 %arg0 to i8 + %is_low_two = icmp eq i8 %low, 2 + %high = and i40 %arg0, -256 + %select_low = select i1 %is_low_two, i8 0, i8 %low + %select_high = select i1 %is_low_two, i40 0, i40 %high + %zext_low = zext i8 %select_low to i40 + %recomb = or disjoint i40 %select_high, %zext_low + ret i40 %recomb +} + +define i40 @select_reconstruction_any_cmp_val(i40 %arg0, i8 %arg1) { +; CHECK-LABEL: define i40 @select_reconstruction_any_cmp_val( +; CHECK-SAME: i40 [[ARG0:%.*]], i8 [[ARG1:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = trunc i40 [[ARG0]] to i8 +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[ARG1]], [[TMP1]] +; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], i40 0, i40 [[ARG0]] +; CHECK-NEXT: ret i40 [[TMP3]] +; + %low = trunc i40 %arg0 to i8 + %is_low_arg1 = icmp eq i8 %low, %arg1 + %high = and i40 %arg0, -256 + %select_low = select i1 %is_low_arg1, i8 0, i8 %low + %select_high = select i1 %is_low_arg1, i40 0, i40 %high + %zext_low = zext i8 %select_low to i40 + %recomb = or disjoint i40 %select_high, %zext_low + ret i40 %recomb +} + +; negative test +define i40 @select_reconstruction_257_mask(i40 %arg0) { +; CHECK-LABEL: define i40 @select_reconstruction_257_mask( +; CHECK-SAME: i40 [[ARG0:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = trunc i40 [[ARG0]] to i8 +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP1]], 2 +; CHECK-NEXT: [[TMP3:%.*]] = and i40 [[ARG0]], -257 +; CHECK-NEXT: [[SELECT_LOW:%.*]] = select i1 [[TMP2]], i8 0, i8 [[TMP1]] +; CHECK-NEXT: [[TMP4:%.*]] = select i1 [[TMP2]], i40 0, i40 [[TMP3]] +; CHECK-NEXT: [[ZEXT_LOW:%.*]] = zext i8 [[SELECT_LOW]] to i40 +; CHECK-NEXT: [[RECOMB:%.*]] = or disjoint i40 [[TMP4]], [[ZEXT_LOW]] +; CHECK-NEXT: ret i40 [[RECOMB]] +; + %low = trunc i40 %arg0 to i8 + %is_low_two = icmp eq i8 %low, 2 + %high = and i40 %arg0, -257 + %select_low = select i1 %is_low_two, i8 0, i8 %low + %select_high = select i1 %is_low_two, i40 0, i40 %high + %zext_low = zext i8 %select_low to i40 + %recomb = or disjoint i40 %select_high, %zext_low + ret i40 %recomb +} + +define i40 @select_reconstruction_i16_mask(i40 %arg0) { +; CHECK-LABEL: define i40 @select_reconstruction_i16_mask( +; CHECK-SAME: i40 [[ARG0:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = trunc i40 [[ARG0]] to i16 +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i16 [[TMP1]], 2 +; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], i40 0, i40 [[ARG0]] +; CHECK-NEXT: ret i40 [[TMP3]] +; + %low = trunc i40 %arg0 to i16 + %is_low_two = icmp eq i16 %low, 2 + %high = and i40 %arg0, -65536 + %select_low = select i1 %is_low_two, i16 0, i16 %low + %select_high = select i1 %is_low_two, i40 0, i40 %high + %zext_low = zext i16 %select_low to i40 + %recomb = or disjoint i40 %select_high, %zext_low + ret i40 %recomb +} + +define <2 x i32> @select_reconstruction_vec_any_cmp_val(<2 x i32> %arg0, <2 x i8> %arg1) { +; CHECK-LABEL: define <2 x i32> @select_reconstruction_vec_any_cmp_val( +; CHECK-SAME: <2 x i32> [[ARG0:%.*]], <2 x i8> [[ARG1:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = trunc <2 x i32> [[ARG0]] to <2 x i8> +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq <2 x i8> [[ARG1]], [[TMP1]] +; CHECK-NEXT: [[TMP7:%.*]] = select <2 x i1> [[TMP2]], <2 x i32> zeroinitializer, <2 x i32> [[ARG0]] +; CHECK-NEXT: ret <2 x i32> [[TMP7]] +; + %low = trunc <2 x i32> %arg0 to <2 x i8> + %is_low_arg1 = icmp eq <2 x i8> %low, %arg1 + %high = and <2 x i32> %arg0, + %select_low = select <2 x i1> %is_low_arg1, <2 x i8> , <2 x i8> %low + %select_high = select <2 x i1> %is_low_arg1, <2 x i32> , <2 x i32> %high + %zext_low = zext <2 x i8> %select_low to <2 x i32> + %recomb = or <2 x i32> %select_high, %zext_low + ret <2 x i32> %recomb +} + +; negative test +define i40 @select_reconstruction_impure_i16_mask_and(i40 %arg0) { +; CHECK-LABEL: define i40 @select_reconstruction_impure_i16_mask_and( +; CHECK-SAME: i40 [[ARG0:%.*]]) { +; CHECK-NEXT: [[TMP1:%.*]] = trunc i40 [[ARG0]] to i16 +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i16 [[TMP1]], 2 +; CHECK-NEXT: [[TMP3:%.*]] = and i40 [[ARG0]], 180 +; CHECK-NEXT: [[TMP4:%.*]] = select i1 [[TMP2]], i16 0, i16 [[TMP1]] +; CHECK-NEXT: [[TMP5:%.*]] = select i1 [[TMP2]], i40 0, i40 [[TMP3]] +; CHECK-NEXT: [[TMP6:%.*]] = zext i16 [[TMP4]] to i40 +; CHECK-NEXT: [[TMP7:%.*]] = and i40 [[TMP5]], [[TMP6]] +; CHECK-NEXT: ret i40 [[TMP7]] +; + %low = trunc i40 %arg0 to i16 + %is_low_two = icmp eq i16 %low, 2 + %high = and i40 %arg0, -65356 + %select_low = select i1 %is_low_two, i16 0, i16 %low + %select_high = select i1 %is_low_two, i40 0, i40 %high + %zext_low = zext i16 %select_low to i40 + %recomb = and i40 %select_high, %zext_low + ret i40 %recomb +}