diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index ec77154d17caa..c42075ba43a56 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -24,6 +24,7 @@
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -3239,6 +3240,71 @@ static SDValue performBitcastCombine(SDNode *N,
   return SDValue();
 }
 
+// Fold a SIMD boolean reduction whose operand is a lane-wise comparison
+// against zero into a reduction over the compared vector itself:
+//
+//   any_true (setcc X, 0, eq) => (not (all_true X))
+//   all_true (setcc X, 0, eq) => (not (any_true X))
+//   any_true (setcc X, 0, ne) => (any_true X)
+//   all_true (setcc X, 0, ne) => (all_true X)
+//
+// N must be an INTRINSIC_WO_CHAIN node. Returns the replacement value, or an
+// empty SDValue when N does not match any of the patterns above.
+static SDValue performAnyAllCombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
+  using namespace llvm::SDPatternMatch;
+  SDLoc DL(N);
+
+  // This lambda captures N, DAG and DL by reference, so it must be rebuilt on
+  // every call; a function-local `static` would be initialized only once and
+  // keep referring to the locals of the first invocation.
+  auto SimdCombiner =
+      [&](Intrinsic::WASMIntrinsics InPre, ISD::CondCode SetType,
+          Intrinsic::WASMIntrinsics InPost, bool ShouldInvert) -> SDValue {
+    // Operand 0 of the intrinsic node is the intrinsic ID.
+    if (N->getConstantOperandVal(0) != InPre)
+      return SDValue();
+
+    // Operand 1 must compare some vector LHS against zero using SetType.
+    SDValue LHS;
+    if (!sd_match(N->getOperand(1), m_c_SetCC(m_Value(LHS), m_Zero(),
+                                              m_SpecificCondCode(SetType))))
+      return SDValue();
+
+    // Reject vectors whose lanes are too wide for their count to fit in
+    // 128 bits (v4i64, for example).
+    EVT LT = LHS.getValueType();
+    unsigned NumElts = LT.getVectorNumElements();
+    if (LT.getScalarSizeInBits() > 128 / NumElts)
+      return SDValue();
+
+    // Emit the InPost reduction on LHS as an i32, narrow it to i1 so it can
+    // be inverted when requested, then convert to N's result type.
+    SDValue Ret = DAG.getZExtOrTrunc(
+        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i32,
+                    {DAG.getConstant(InPost, DL, MVT::i32), LHS}),
+        DL, MVT::i1);
+    if (ShouldInvert)
+      Ret = DAG.getNOT(DL, Ret, MVT::i1);
+    return DAG.getZExtOrTrunc(Ret, DL, N->getValueType(0));
+  };
+
+  if (SDValue AnyTrueEQ = SimdCombiner(Intrinsic::wasm_anytrue, ISD::SETEQ,
+                                       Intrinsic::wasm_alltrue, true))
+    return AnyTrueEQ;
+  if (SDValue AllTrueEQ = SimdCombiner(Intrinsic::wasm_alltrue, ISD::SETEQ,
+                                       Intrinsic::wasm_anytrue, true))
+    return AllTrueEQ;
+  if (SDValue AnyTrueNE = SimdCombiner(Intrinsic::wasm_anytrue, ISD::SETNE,
+                                       Intrinsic::wasm_anytrue, false))
+    return AnyTrueNE;
+  if (SDValue AllTrueNE = SimdCombiner(Intrinsic::wasm_alltrue, ISD::SETNE,
+                                       Intrinsic::wasm_alltrue, false))
+    return AllTrueNE;
+
+  return SDValue();
+}
+
 template <int MatchRHS, ISD::CondCode MatchCond, bool RequiresNegate,
           Intrinsic::ID Intrin>
 static SDValue TryMatchTrue(SDNode *N, EVT VecVT, SelectionDAG &DAG) {
@@ -3427,8 +3493,11 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
     return performVectorTruncZeroCombine(N, DCI);
   case ISD::TRUNCATE:
     return performTruncateCombine(N, DCI);
-  case ISD::INTRINSIC_WO_CHAIN:
+  case ISD::INTRINSIC_WO_CHAIN: {
+    if (auto AnyAllCombine = performAnyAllCombine(N, DCI.DAG))
+      return AnyAllCombine;
     return performLowerPartialReduction(N, DCI.DAG);
+  }
   case ISD::MUL:
     return performMulCombine(N, DCI.DAG);
   }
diff --git a/llvm/test/CodeGen/WebAssembly/simd-setcc-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-setcc-reductions.ll
new file mode 100644
index 0000000000000..503c7e857e6e6
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-setcc-reductions.ll
@@ -0,0 +1,136 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
+
+target triple = "wasm64"
+
+define i32 @all_true_16_i8(<16 x i8> %v) {
+; CHECK-LABEL: all_true_16_i8:
+; CHECK:         .functype all_true_16_i8 (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i8x16.all_true $push0=, $0
+; CHECK-NEXT:    return $pop0
+  %1 = icmp eq <16 x i8> %v, zeroinitializer
+  %2 = bitcast <16 x i1> %1 to i16
+  %3 = icmp eq i16 %2, 0
+  %conv3 = zext i1 %3 to i32
+  ret i32 %conv3
+}
+
+
+define i32 @all_true_4_i32(<4 x i32> %v) {
+; CHECK-LABEL: all_true_4_i32:
+; CHECK:         .functype all_true_4_i32 (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i32x4.all_true $push0=, $0
+; CHECK-NEXT:    return $pop0
+  %1 = icmp eq <4 x i32> %v, zeroinitializer
+  %2 = bitcast <4 x i1> %1 to i4
+  %3 = icmp eq i4 %2, 0
+  %conv3 = zext i1 %3 to i32
+  ret i32 %conv3
+}
+
+
+define i32 @all_true_8_i16(<8 x i16> %v) {
+; CHECK-LABEL: all_true_8_i16:
+; CHECK:         .functype all_true_8_i16 (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i16x8.all_true $push0=, $0
+; CHECK-NEXT:    return $pop0
+  %1 = icmp eq <8 x i16> %v, zeroinitializer
+  %2 = bitcast <8 x i1> %1 to i8
+  %3 = icmp eq i8 %2, 0
+  %conv3 = zext i1 %3 to i32
+  ret i32 %conv3
+}
+
+
+define i32 @all_true_2_i64(<2 x i64> %v) {
+; CHECK-LABEL: all_true_2_i64:
+; CHECK:         .functype all_true_2_i64 (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i64x2.all_true $push0=, $0
+; CHECK-NEXT:    return $pop0
+  %1 = icmp eq <2 x i64> %v, zeroinitializer
+  %2 = bitcast <2 x i1> %1 to i2
+  %3 = icmp eq i2 %2, 0
+  %conv3 = zext i1 %3 to i32
+  ret i32 %conv3
+}
+
+
+define i32 @all_true_4_i64(<4 x i64> %v) {
+; CHECK-LABEL: all_true_4_i64:
+; CHECK:         .functype all_true_4_i64 (v128, v128) -> (i32)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    v128.const $push9=, 0, 0
+; CHECK-NEXT:    local.tee $push8=, $2=, $pop9
+; CHECK-NEXT:    i64x2.eq $push1=, $0, $pop8
+; CHECK-NEXT:    i64x2.eq $push0=, $1, $2
+; CHECK-NEXT:    i8x16.shuffle $push2=, $pop1, $pop0, 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
+; CHECK-NEXT:    v128.any_true $push3=, $pop2
+; CHECK-NEXT:    i32.const $push4=, -1
+; CHECK-NEXT:    i32.xor $push5=, $pop3, $pop4
+; CHECK-NEXT:    i32.const $push6=, 1
+; CHECK-NEXT:    i32.and $push7=, $pop5, $pop6
+; CHECK-NEXT:    return $pop7
+  %1 = icmp eq <4 x i64> %v, zeroinitializer
+  %2 = bitcast <4 x i1> %1 to i4
+  %3 = icmp eq i4 %2, 0
+  %conv3 = zext i1 %3 to i32
+  ret i32 %conv3
+}
+
+
+; setcc (iN (bitcast (set_cc (vNi1 X), 0, ne)), 0, ne
+; => any_true (set_cc (X), 0, ne)
+; => any_true (X)
+define i32 @any_true_1_4_i32(<4 x i32> %v) {
+; CHECK-LABEL: any_true_1_4_i32:
+; CHECK:         .functype any_true_1_4_i32 (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    v128.any_true $push0=, $0
+; CHECK-NEXT:    return $pop0
+  %1 = icmp ne <4 x i32> %v, zeroinitializer
+  %2 = bitcast <4 x i1> %1 to i4
+  %3 = icmp ne i4 %2, 0
+  %conv3 = zext i1 %3 to i32
+  ret i32 %conv3
+}
+
+; setcc (iN (bitcast (set_cc (vNi1 X), 0, eq)), -1, ne
+; => not all_true (set_cc (X), 0, eq)
+; => not all_true (set_cc (X), 0, eq)
+; => not not any_true (X)
+; => any_true (X)
+define i32 @any_true_2_4_i32(<4 x i32> %v) {
+; CHECK-LABEL: any_true_2_4_i32:
+; CHECK:         .functype any_true_2_4_i32 (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    v128.any_true $push0=, $0
+; CHECK-NEXT:    return $pop0
+  %1 = icmp eq <4 x i32> %v, zeroinitializer
+  %2 = bitcast <4 x i1> %1 to i4
+  %3 = icmp ne i4 %2, -1
+  %conv3 = zext i1 %3 to i32
+  ret i32 %conv3
+}
+
+
+; setcc (iN (bitcast (set_cc (vNi1 X), 0, ne)), -1, eq
+; => all_true (set_cc (X), 0, ne)
+; => all_true (X)
+define i32 @all_true_2_4_i32(<4 x i32> %v) {
+; CHECK-LABEL: all_true_2_4_i32:
+; CHECK:         .functype all_true_2_4_i32 (v128) -> (i32)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i32x4.all_true $push0=, $0
+; CHECK-NEXT:    return $pop0
+  %1 = icmp ne <4 x i32> %v, zeroinitializer
+  %2 = bitcast <4 x i1> %1 to i4
+  %3 = icmp eq i4 %2, -1
+  %conv3 = zext i1 %3 to i32
+  ret i32 %conv3
+}
+
+
diff --git a/llvm/test/CodeGen/WebAssembly/simd-vecreduce-bool.ll b/llvm/test/CodeGen/WebAssembly/simd-vecreduce-bool.ll
index e6497bca98dc2..f7143711394fa 100644
--- a/llvm/test/CodeGen/WebAssembly/simd-vecreduce-bool.ll
+++ b/llvm/test/CodeGen/WebAssembly/simd-vecreduce-bool.ll
@@ -1086,9 +1086,9 @@ define i1 @test_cmp_v16i8(<16 x i8> %x) {
 ; CHECK-LABEL: test_cmp_v16i8:
 ; CHECK:         .functype test_cmp_v16i8 (v128) -> (i32)
 ; CHECK-NEXT:  # %bb.0:
-; CHECK-NEXT:    v128.const $push0=, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
-; CHECK-NEXT:    i8x16.eq $push1=, $0, $pop0
-; CHECK-NEXT:    v128.any_true $push2=, $pop1
+; CHECK-NEXT:    i8x16.all_true $push0=, $0
+; CHECK-NEXT:    i32.const $push1=, 1
+; CHECK-NEXT:    i32.xor $push2=, $pop0, $pop1
 ; CHECK-NEXT:    return $pop2
   %zero = icmp eq <16 x i8> %x, zeroinitializer
   %ret = call i1 @llvm.vector.reduce.or.v16i1(<16 x i1> %zero)