diff --git a/include/circt/Dialect/Comb/Comb.td b/include/circt/Dialect/Comb/Comb.td index 8f4d1d6f154e..3bd7f9096904 100644 --- a/include/circt/Dialect/Comb/Comb.td +++ b/include/circt/Dialect/Comb/Comb.td @@ -15,6 +15,7 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" diff --git a/include/circt/Dialect/Comb/CombOps.h b/include/circt/Dialect/Comb/CombOps.h index 8414c347c459..56bad66912fb 100644 --- a/include/circt/Dialect/Comb/CombOps.h +++ b/include/circt/Dialect/Comb/CombOps.h @@ -20,6 +20,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/include/circt/Dialect/Comb/CombPasses.h b/include/circt/Dialect/Comb/CombPasses.h index c41b577cd768..cdbaf906b9bb 100644 --- a/include/circt/Dialect/Comb/CombPasses.h +++ b/include/circt/Dialect/Comb/CombPasses.h @@ -18,6 +18,10 @@ #include #include +namespace mlir { +class DataFlowSolver; +} + namespace circt { namespace comb { @@ -26,6 +30,10 @@ namespace comb { #define GEN_PASS_REGISTRATION #include "circt/Dialect/Comb/Passes.h.inc" +/// Add patterns for int range based narrowing. +void populateCombNarrowingPatterns(mlir::RewritePatternSet &patterns, + mlir::DataFlowSolver &solver); + } // namespace comb } // namespace circt diff --git a/include/circt/Dialect/Comb/Combinational.td b/include/circt/Dialect/Comb/Combinational.td index 114d5ce1ce58..3edc8473caa4 100644 --- a/include/circt/Dialect/Comb/Combinational.td +++ b/include/circt/Dialect/Comb/Combinational.td @@ -15,6 +15,7 @@ //===----------------------------------------------------------------------===// include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/IR/EnumAttr.td" // Base class for binary operators. @@ -30,7 +31,9 @@ class BinOp traits = []> : // Binary operator with uniform input/result types. class UTBinOp traits = []> : BinOp { + traits # [SameTypeOperands, SameOperandsAndResultType, + DeclareOpInterfaceMethods]> { let assemblyFormat = "(`bin` $twoState^)? $lhs `,` $rhs attr-dict `:` qualified(type($result))"; } @@ -42,8 +45,10 @@ class VariadicOp traits = []> : } class UTVariadicOp traits = []> : - VariadicOp { + VariadicOp]> { let hasCanonicalizeMethod = true; let hasFolder = true; @@ -76,7 +81,7 @@ let hasFolder = true in { } def AndOp : UTVariadicOp<"and", [Commutative]>; -def OrOp : UTVariadicOp<"or", [Commutative]>; +def OrOp : UTVariadicOp<"or", [Commutative]>; def XorOp : UTVariadicOp<"xor", [Commutative]> { let extraClassDeclaration = [{ /// Return true if this is a two operand xor with an all ones constant as @@ -114,7 +119,10 @@ def ICmpPredicate : I64EnumAttr< ICmpPredicateUGT, ICmpPredicateUGE, ICmpPredicateCEQ, ICmpPredicateCNE, ICmpPredicateWEQ, ICmpPredicateWNE]>; -def ICmpOp : CombOp<"icmp", [Pure, SameTypeOperands]> { +def ICmpOp : CombOp<"icmp", + [Pure, + SameTypeOperands, + DeclareOpInterfaceMethods]> { let summary = "Compare two integer values"; let description = [{ This operation compares two integers using a predicate. If the predicate is @@ -178,7 +186,9 @@ def ParityOp : UnaryI1ReductionOp<"parity">; //===----------------------------------------------------------------------===// // Extract a range of bits from the specified input. -def ExtractOp : CombOp<"extract", [Pure]> { +def ExtractOp : CombOp<"extract", + [Pure, + DeclareOpInterfaceMethods]> { let summary = "Extract a range of bits into a smaller value, lowBit " "specifies the lowest bit included."; @@ -203,7 +213,9 @@ def ExtractOp : CombOp<"extract", [Pure]> { //===----------------------------------------------------------------------===// // Other Operations //===----------------------------------------------------------------------===// -def ConcatOp : CombOp<"concat", [InferTypeOpInterface, Pure]> { +def ConcatOp : CombOp<"concat", + [InferTypeOpInterface, Pure, + DeclareOpInterfaceMethods]> { let summary = "Concatenate a variadic list of operands together."; let description = [{ See the comb rationale document for details on operand ordering. @@ -237,7 +249,9 @@ def ConcatOp : CombOp<"concat", [InferTypeOpInterface, Pure]> { }]; } -def ReplicateOp : CombOp<"replicate", [Pure]> { +def ReplicateOp : CombOp<"replicate", + [Pure, + DeclareOpInterfaceMethods]> { let summary = "Concatenate the operand a constant number of times"; let arguments = (ins HWIntegerType:$input); @@ -267,8 +281,10 @@ def ReplicateOp : CombOp<"replicate", [Pure]> { } // Select one of two values based on a condition. -def MuxOp : CombOp<"mux", - [Pure, AllTypesMatch<["trueValue", "falseValue", "result"]>]> { +def MuxOp : CombOp<"mux", + [Pure, AllTypesMatch<["trueValue", "falseValue", "result"]>, + DeclareOpInterfaceMethods]> { let summary = "Return one or the other operand depending on a selector bit"; let description = [{ ``` diff --git a/include/circt/Dialect/Comb/Passes.td b/include/circt/Dialect/Comb/Passes.td index 197e8b1b4696..8ce3960f3e6f 100644 --- a/include/circt/Dialect/Comb/Passes.td +++ b/include/circt/Dialect/Comb/Passes.td @@ -22,4 +22,16 @@ def LowerComb : Pass<"lower-comb"> { }]; } +def CombIntRangeNarrowing : Pass<"comb-int-range-narrowing"> { + let summary = "Reduce comb op bitwidth based on integer range analysis."; + let description = [{ + Compute a basic value range analysis, by propagating integer intervals + through the domain. The analysis is limited by a lack of sign-extension + operator in the comb dialect, leading to an over-approximation. + Particularly for signed arithmetic, a single interval is often an + over-approximation, a more precise analysis would require a union of + intervals. + }]; +} + #endif // CIRCT_DIALECT_COMB_PASSES_TD diff --git a/include/circt/Dialect/HW/HWMiscOps.td b/include/circt/Dialect/HW/HWMiscOps.td index bd534a5b7550..b8b5e1ae9f88 100644 --- a/include/circt/Dialect/HW/HWMiscOps.td +++ b/include/circt/Dialect/HW/HWMiscOps.td @@ -19,11 +19,14 @@ include "circt/Dialect/HW/HWOpInterfaces.td" include "circt/Dialect/HW/HWTypes.td" include "mlir/IR/OpAsmInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" def ConstantOp : HWOp<"constant", [Pure, ConstantLike, FirstAttrDerivedResultType, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Produce a constant value"; let description = [{ The constant operation produces a constant value of standard integer type diff --git a/include/circt/Dialect/HW/HWOps.h b/include/circt/Dialect/HW/HWOps.h index 75009161c4b5..02bff24c6d75 100644 --- a/include/circt/Dialect/HW/HWOps.h +++ b/include/circt/Dialect/HW/HWOps.h @@ -26,6 +26,7 @@ #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/ADT/StringExtras.h" diff --git a/lib/Analysis/TestPasses.cpp b/lib/Analysis/TestPasses.cpp index 8c79c09715b2..c2b3f6bc9528 100644 --- a/lib/Analysis/TestPasses.cpp +++ b/lib/Analysis/TestPasses.cpp @@ -18,6 +18,8 @@ #include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h" #include "circt/Dialect/HW/HWInstanceGraph.h" #include "circt/Scheduling/Problems.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -29,6 +31,7 @@ using namespace mlir; using namespace mlir::affine; +using namespace mlir::dataflow; using namespace circt; using namespace circt::analysis; using namespace circt::scheduling; @@ -263,6 +266,69 @@ void FIRRTLInstanceInfoPass::runOnOperation() { printModuleInfo(op, iInfo); } +//===----------------------------------------------------------------------===// +// Comb IntRange Analysis +//===----------------------------------------------------------------------===// + +namespace { +struct TestCombIntegerRangeAnalysisPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestCombIntegerRangeAnalysisPass) + + void runOnOperation() override; + StringRef getArgument() const override { + return "test-comb-int-range-analysis"; + } + StringRef getDescription() const override { + return "Perform integer range analysis on comb dialect and set results as " + "attributes."; + } +}; +} // namespace + +void TestCombIntegerRangeAnalysisPass::runOnOperation() { + Operation *op = getOperation(); + MLIRContext *ctx = op->getContext(); + DataFlowSolver solver; + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + // Append the integer range analysis as an operation attribute. + op->walk([&](Operation *op) { + for (auto value : op->getResults()) { + if (auto *range = solver.lookupState(value)) { + // All analyzed comb operations should return a single result. + assert(op->getResults().size() == 1 && + "Expected a single result for the operation analysis"); + assert(!range->getValue().isUninitialized() && + "Expected a valid range for the value"); + auto interval = range->getValue().getValue(); + auto smax = interval.smax(); + auto smaxAttr = + IntegerAttr::get(IntegerType::get(ctx, smax.getBitWidth()), smax); + op->setAttr("smax", smaxAttr); + auto smin = interval.smin(); + auto sminAttr = + IntegerAttr::get(IntegerType::get(ctx, smin.getBitWidth()), smin); + op->setAttr("smin", sminAttr); + auto umax = interval.umax(); + auto umaxAttr = IntegerAttr::get( + IntegerType::get(ctx, umax.getBitWidth(), IntegerType::Unsigned), + umax); + op->setAttr("umax", umaxAttr); + auto umin = interval.umin(); + auto uminAttr = IntegerAttr::get( + IntegerType::get(ctx, umin.getBitWidth(), IntegerType::Unsigned), + umin); + op->setAttr("umin", uminAttr); + } + } + }); +} + //===----------------------------------------------------------------------===// // Pass registration //===----------------------------------------------------------------------===// @@ -285,6 +351,9 @@ void registerAnalysisTestPasses() { registerPass([]() -> std::unique_ptr { return std::make_unique(); }); + registerPass([]() -> std::unique_ptr { + return std::make_unique(); + }); } } // namespace test } // namespace circt diff --git a/lib/Dialect/Comb/CMakeLists.txt b/lib/Dialect/Comb/CMakeLists.txt index 9ecd57f3809b..d43ddbfc964e 100644 --- a/lib/Dialect/Comb/CMakeLists.txt +++ b/lib/Dialect/Comb/CMakeLists.txt @@ -3,6 +3,7 @@ add_circt_dialect_library(CIRCTComb CombOps.cpp CombAnalysis.cpp CombDialect.cpp + InferIntRangeInterfaceImpls.cpp ADDITIONAL_HEADER_DIRS ${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/Comb @@ -19,6 +20,7 @@ add_circt_dialect_library(CIRCTComb CIRCTHW MLIRIR MLIRInferTypeOpInterface + MLIRInferIntRangeInterface ) add_dependencies(circt-headers MLIRCombIncGen MLIRCombEnumsIncGen) diff --git a/lib/Dialect/Comb/InferIntRangeInterfaceImpls.cpp b/lib/Dialect/Comb/InferIntRangeInterfaceImpls.cpp new file mode 100644 index 000000000000..d5587bdef8ed --- /dev/null +++ b/lib/Dialect/Comb/InferIntRangeInterfaceImpls.cpp @@ -0,0 +1,309 @@ +//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for comb -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implementation of the interval range analysis interface. +// The overflow flags are not set for the comb operations since they is +// no meaningful concept of overflow detection in comb. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Comb/CombOps.h" + +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" + +using namespace mlir; +using namespace mlir::intrange; +using namespace circt; +using namespace circt::comb; +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +void comb::AddOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + auto resultRange = argRanges[0]; + for (auto argRange : argRanges.drop_front()) + resultRange = + inferAdd({resultRange, argRange}, intrange::OverflowFlags::None); + + setResultRange(getResult(), resultRange); +}; + +//===----------------------------------------------------------------------===// +// SubOp +//===----------------------------------------------------------------------===// + +void comb::SubOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferSub(argRanges, intrange::OverflowFlags::None)); +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +void comb::MulOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + auto resultRange = argRanges[0]; + for (auto argRange : argRanges.drop_front()) + resultRange = + inferMul({resultRange, argRange}, intrange::OverflowFlags::None); + + setResultRange(getResult(), resultRange); +} + +//===----------------------------------------------------------------------===// +// DivUIOp +//===----------------------------------------------------------------------===// + +void comb::DivUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferDivU(argRanges)); +} + +//===----------------------------------------------------------------------===// +// DivSIOp +//===----------------------------------------------------------------------===// + +void comb::DivSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferDivS(argRanges)); +} + +//===----------------------------------------------------------------------===// +// ModUOp +//===----------------------------------------------------------------------===// + +void comb::ModUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferRemU(argRanges)); +} + +//===----------------------------------------------------------------------===// +// ModSOp +//===----------------------------------------------------------------------===// + +void comb::ModSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferRemS(argRanges)); +} +//===----------------------------------------------------------------------===// +// AndOp +//===----------------------------------------------------------------------===// + +void comb::AndOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + auto resultRange = argRanges[0]; + for (auto argRange : argRanges.drop_front()) + resultRange = inferAnd({resultRange, argRange}); + + setResultRange(getResult(), resultRange); +} + +//===----------------------------------------------------------------------===// +// OrOp +//===----------------------------------------------------------------------===// + +void comb::OrOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + auto resultRange = argRanges[0]; + for (auto argRange : argRanges.drop_front()) + resultRange = inferOr({resultRange, argRange}); + + setResultRange(getResult(), resultRange); +} + +//===----------------------------------------------------------------------===// +// XorOp +//===----------------------------------------------------------------------===// + +void comb::XorOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + auto resultRange = argRanges[0]; + for (auto argRange : argRanges.drop_front()) + resultRange = inferXor({resultRange, argRange}); + + setResultRange(getResult(), resultRange); +} + +//===----------------------------------------------------------------------===// +// ShlOp +//===----------------------------------------------------------------------===// + +void comb::ShlOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferShl(argRanges, intrange::OverflowFlags::None)); +} + +//===----------------------------------------------------------------------===// +// ShRUIOp +//===----------------------------------------------------------------------===// + +void comb::ShrUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferShrU(argRanges)); +} + +//===----------------------------------------------------------------------===// +// ShRSIOp +//===----------------------------------------------------------------------===// + +void comb::ShrSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferShrS(argRanges)); +} + +//===----------------------------------------------------------------------===// +// ConcatOp +//===----------------------------------------------------------------------===// + +void comb::ConcatOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + // Compute concat as an unsigned integer of bits + const auto resWidth = getResult().getType().getIntOrFloatBitWidth(); + auto totalWidth = resWidth; + APInt umin = APInt::getZero(resWidth); + APInt umax = APInt::getZero(resWidth); + for (auto [operand, arg] : llvm::zip(getOperands(), argRanges)) { + assert(totalWidth >= operand.getType().getIntOrFloatBitWidth() && + "ConcatOp: total width in interval range calculation is negative"); + totalWidth -= operand.getType().getIntOrFloatBitWidth(); + auto uminUpd = arg.umin().zext(resWidth).ushl_sat(totalWidth); + auto umaxUpd = arg.umax().zext(resWidth).ushl_sat(totalWidth); + umin = umin.uadd_sat(uminUpd); + umax = umax.uadd_sat(umaxUpd); + } + auto urange = ConstantIntRanges::fromUnsigned(umin, umax); + setResultRange(getResult(), urange); +}; + +//===----------------------------------------------------------------------===// +// ExtractOp +//===----------------------------------------------------------------------===// + +void comb::ExtractOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + // Right-shift and truncate (trunaction implicitly handled) + const auto resWidth = getResult().getType().getIntOrFloatBitWidth(); + const auto lowBit = getLowBit(); + auto umin = argRanges[0].umin().lshr(lowBit).trunc(resWidth); + auto umax = argRanges[0].umax().lshr(lowBit).trunc(resWidth); + auto urange = ConstantIntRanges::fromUnsigned(umin, umax); + setResultRange(getResult(), urange); +}; + +//===----------------------------------------------------------------------===// +// ReplicateOp +//===----------------------------------------------------------------------===// + +void comb::ReplicateOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + // Compute replicate as an unsigned integer of bits + const auto operandWidth = getOperand().getType().getIntOrFloatBitWidth(); + const auto resWidth = getResult().getType().getIntOrFloatBitWidth(); + APInt umin = APInt::getZero(resWidth); + APInt umax = APInt::getZero(resWidth); + auto uminIn = argRanges[0].umin().zext(resWidth); + auto umaxIn = argRanges[0].umax().zext(resWidth); + for (unsigned int totalWidth = 0; totalWidth < resWidth; + totalWidth += operandWidth) { + auto uminUpd = uminIn.ushl_sat(totalWidth); + auto umaxUpd = umaxIn.ushl_sat(totalWidth); + umin = umin.uadd_sat(uminUpd); + umax = umax.uadd_sat(umaxUpd); + } + auto urange = ConstantIntRanges::fromUnsigned(umin, umax); + setResultRange(getResult(), urange); +}; + +//===----------------------------------------------------------------------===// +// MuxOp +//===----------------------------------------------------------------------===// + +void comb::MuxOp::inferResultRangesFromOptional( + ArrayRef argRanges, SetIntLatticeFn setResultRange) { + std::optional mbCondVal = + argRanges[0].isUninitialized() + ? std::nullopt + : argRanges[0].getValue().getConstantValue(); + + const IntegerValueRange &trueCase = argRanges[1]; + const IntegerValueRange &falseCase = argRanges[2]; + + if (mbCondVal) { + if (mbCondVal->isZero()) + setResultRange(getResult(), falseCase); + else + setResultRange(getResult(), trueCase); + return; + } + setResultRange(getResult(), IntegerValueRange::join(trueCase, falseCase)); +} + +//===----------------------------------------------------------------------===// +// ICmpOp +//===----------------------------------------------------------------------===// + +void comb::ICmpOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + comb::ICmpPredicate combPred = getPredicate(); + + APInt min = APInt::getZero(1); + APInt max = APInt::getAllOnes(1); + + intrange::CmpPredicate pred; + switch (combPred) { + case comb::ICmpPredicate::eq: + pred = intrange::CmpPredicate::eq; + break; + case comb::ICmpPredicate::ne: + pred = intrange::CmpPredicate::ne; + break; + case comb::ICmpPredicate::slt: + pred = intrange::CmpPredicate::slt; + break; + case comb::ICmpPredicate::sle: + pred = intrange::CmpPredicate::sle; + break; + case comb::ICmpPredicate::sgt: + pred = intrange::CmpPredicate::sgt; + break; + case comb::ICmpPredicate::sge: + pred = intrange::CmpPredicate::sge; + break; + case comb::ICmpPredicate::ult: + pred = intrange::CmpPredicate::ult; + break; + case comb::ICmpPredicate::ule: + pred = intrange::CmpPredicate::ule; + break; + case comb::ICmpPredicate::ugt: + pred = intrange::CmpPredicate::ugt; + break; + case comb::ICmpPredicate::uge: + pred = intrange::CmpPredicate::uge; + break; + default: + // These predicates are not supported for integer range analysis + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); + return; + } + + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + std::optional truthValue = intrange::evaluatePred(pred, lhs, rhs); + if (truthValue.has_value() && *truthValue) + min = max; + else if (truthValue.has_value() && !(*truthValue)) + max = min; + + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); +} diff --git a/lib/Dialect/Comb/Transforms/CMakeLists.txt b/lib/Dialect/Comb/Transforms/CMakeLists.txt index 9043ccf5dea3..45b19a1d5096 100644 --- a/lib/Dialect/Comb/Transforms/CMakeLists.txt +++ b/lib/Dialect/Comb/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_circt_dialect_library(CIRCTCombTransforms LowerComb.cpp + IntRangeOptimizations.cpp DEPENDS CIRCTCombTransformsIncGen diff --git a/lib/Dialect/Comb/Transforms/IntRangeOptimizations.cpp b/lib/Dialect/Comb/Transforms/IntRangeOptimizations.cpp new file mode 100644 index 000000000000..ec57c1582757 --- /dev/null +++ b/lib/Dialect/Comb/Transforms/IntRangeOptimizations.cpp @@ -0,0 +1,140 @@ +//===- IntRangeOptimizations.cpp - Narrow ops in comb ------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/Comb/CombPasses.h" +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" + +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace circt; +using namespace circt::comb; +using namespace mlir; +using namespace mlir::dataflow; + +namespace circt { +namespace comb { +#define GEN_PASS_DEF_COMBINTRANGENARROWING +#include "circt/Dialect/Comb/Passes.h.inc" +} // namespace comb +} // namespace circt + +/// Gather ranges for all the values in `values`. Appends to the existing +/// vector. +static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values, + SmallVectorImpl &ranges) { + for (Value val : values) { + auto *maybeInferredRange = + solver.lookupState(val); + if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) + return failure(); + + const ConstantIntRanges &inferredRange = + maybeInferredRange->getValue().getValue(); + ranges.push_back(inferredRange); + } + return success(); +} + +namespace { +template +struct CombOpNarrow : public OpRewritePattern { + CombOpNarrow(MLIRContext *context, DataFlowSolver &s) + : OpRewritePattern(context), solver(s) {} + + LogicalResult matchAndRewrite(CombOpTy op, + PatternRewriter &rewriter) const override { + + auto opWidth = op.getType().getIntOrFloatBitWidth(); + if (op->getNumOperands() != 2 || op->getNumResults() != 1) + return rewriter.notifyMatchFailure( + op, "Only support binary operations with one result"); + + SmallVector ranges; + if (failed(collectRanges(solver, op->getOperands(), ranges))) + return rewriter.notifyMatchFailure(op, "input without specified range"); + if (failed(collectRanges(solver, op->getResults(), ranges))) + return rewriter.notifyMatchFailure(op, "output without specified range"); + + auto removeWidth = ranges[0].umax().countLeadingZeros(); + for (const ConstantIntRanges &range : ranges) { + auto rangeCanRemove = range.umax().countLeadingZeros(); + removeWidth = std::min(removeWidth, rangeCanRemove); + } + if (removeWidth == 0) + return rewriter.notifyMatchFailure(op, "no bits to remove"); + if (removeWidth == opWidth) + return rewriter.notifyMatchFailure( + op, "all bits to remove - replace by zero"); + + // Replace operator by narrower version of itself + Value lhs = op.getOperand(0); + Value rhs = op.getOperand(1); + + Location loc = op.getLoc(); + auto newWidth = opWidth - removeWidth; + // Create a replacement type for the extracted bits + auto replaceType = rewriter.getIntegerType(newWidth); + + // Extract the lsbs from each operand + auto extractLhsOp = + rewriter.create(loc, replaceType, lhs, 0); + auto extractRhsOp = + rewriter.create(loc, replaceType, rhs, 0); + auto narrowOp = rewriter.create(loc, extractLhsOp, extractRhsOp); + + // Concatenate zeros to match the original operator width + auto zero = + rewriter.create(loc, APInt::getZero(removeWidth)); + auto replaceOp = rewriter.create( + loc, op.getType(), ValueRange{zero, narrowOp}); + + rewriter.replaceOp(op, replaceOp); + return success(); + } + +private: + DataFlowSolver &solver; +}; + +struct CombIntRangeNarrowingPass + : comb::impl::CombIntRangeNarrowingBase { + + using CombIntRangeNarrowingBase::CombIntRangeNarrowingBase; + void runOnOperation() override; +}; +} // namespace + +void CombIntRangeNarrowingPass::runOnOperation() { + Operation *op = getOperation(); + MLIRContext *ctx = op->getContext(); + DataFlowSolver solver; + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + RewritePatternSet patterns(ctx); + populateCombNarrowingPatterns(patterns, solver); + + if (failed(applyPatternsGreedily(op, std::move(patterns)))) + signalPassFailure(); +} + +void comb::populateCombNarrowingPatterns(RewritePatternSet &patterns, + DataFlowSolver &solver) { + patterns.add, CombOpNarrow, + CombOpNarrow>(patterns.getContext(), solver); +} diff --git a/lib/Dialect/HW/CMakeLists.txt b/lib/Dialect/HW/CMakeLists.txt index e5a4fbd50bf8..431a649d6624 100644 --- a/lib/Dialect/HW/CMakeLists.txt +++ b/lib/Dialect/HW/CMakeLists.txt @@ -14,6 +14,7 @@ set(CIRCT_HW_Sources ModuleImplementation.cpp InnerSymbolTable.cpp PortConverter.cpp + InferIntRangeInterfaceImpls.cpp ) set(LLVM_OPTIONAL_SOURCES @@ -40,6 +41,8 @@ add_circt_dialect_library(CIRCTHW MLIRIR MLIRInferTypeOpInterface MLIRMemorySlotInterfaces + MLIRInferIntRangeCommon + MLIRInferIntRangeInterface ) add_circt_library(CIRCTHWReductions diff --git a/lib/Dialect/HW/InferIntRangeInterfaceImpls.cpp b/lib/Dialect/HW/InferIntRangeInterfaceImpls.cpp new file mode 100644 index 000000000000..2b06306ab34e --- /dev/null +++ b/lib/Dialect/HW/InferIntRangeInterfaceImpls.cpp @@ -0,0 +1,25 @@ +//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for HW -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/HWOps.h" + +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" + +using namespace mlir; +using namespace mlir::intrange; +using namespace circt; + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +void hw::ConstantOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), ConstantIntRanges::constant(getValue())); +} diff --git a/test/Analysis/comb-int-range-analysis.mlir b/test/Analysis/comb-int-range-analysis.mlir new file mode 100644 index 000000000000..fc60b7e49660 --- /dev/null +++ b/test/Analysis/comb-int-range-analysis.mlir @@ -0,0 +1,174 @@ +// RUN: circt-opt %s --test-comb-int-range-analysis | FileCheck %s + +// CHECK-LABEL: @basic_csa +hw.module @basic_csa(in %a : i1, in %b : i1, in %c : i1, out add_abc : i3) { + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2} + // CHECK-NEXT: %false = hw.constant false + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %false, %a {smax = 1 : i2, smin = 0 : i2, umax = 1 : ui2, umin = 0 : ui2} : i1, i1 + // CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %false, %b {smax = 1 : i2, smin = 0 : i2, umax = 1 : ui2, umin = 0 : ui2} : i1, i1 + // CHECK-NEXT: %[[ADD:.+]] = comb.add %[[A_EXT]], %[[B_EXT]] {smax = 1 : i2, smin = -2 : i2, umax = 2 : ui2, umin = 0 : ui2} : i2 + // CHECK-NEXT: %[[ADD_EXT:.+]] = comb.concat %false, %[[ADD]] {smax = 2 : i3, smin = 0 : i3, umax = 2 : ui3, umin = 0 : ui3} : i1, i2 + // CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i2, %c {smax = 1 : i3, smin = 0 : i3, umax = 1 : ui3, umin = 0 : ui3} : i2, i1 + // CHECK-NEXT: %[[ADD1:.+]] = comb.add %[[ADD_EXT]], %[[C_EXT]] {smax = 3 : i3, smin = 0 : i3, umax = 3 : ui3, umin = 0 : ui3} : i3 + %c0_i2 = hw.constant 0 : i2 + %false = hw.constant false + %0 = comb.concat %false, %a : i1, i1 + %1 = comb.concat %false, %b : i1, i1 + %2 = comb.add %0, %1 : i2 + %3 = comb.concat %false, %2 : i1, i2 + %4 = comb.concat %c0_i2, %c : i2, i1 + %5 = comb.add %3, %4 : i3 + hw.output %5 : i3 +} + +// CHECK-LABEL: @basic_mux +hw.module @basic_mux(in %a : i3, in %b : i3, in %sel : i1, out y : i4) { + // CHECK-NEXT: %false = hw.constant false {smax = false, smin = false, umax = 0 : ui1, umin = 0 : ui1} + // CHECK-NEXT: %true = hw.constant true {smax = true, smin = true, umax = 1 : ui1, umin = 1 : ui1} + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %true, %a {smax = -1 : i4, smin = -8 : i4, umax = 15 : ui4, umin = 8 : ui4} : i1, i3 + // CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %false, %b {smax = 7 : i4, smin = 0 : i4, umax = 7 : ui4, umin = 0 : ui4} : i1, i3 + // CHECK-NEXT: %[[MUX:.+]] = comb.mux %sel, %[[A_EXT]], %[[B_EXT]] {smax = 7 : i4, smin = -8 : i4, umax = 15 : ui4, umin = 0 : ui4} : i4 + %false = hw.constant false + %true = hw.constant true + %0 = comb.concat %true, %a : i1, i3 + %1 = comb.concat %false, %b : i1, i3 + %2 = comb.mux %sel, %0, %1 : i4 + hw.output %2 : i4 +} + +// CHECK-LABEL: @basic_fma +hw.module @basic_fma(in %a : i4, in %b : i4, in %c : i4, out d : i9) { + // CHECK-NEXT: %c0_i5 = hw.constant 0 : i5 {smax = 0 : i5, smin = 0 : i5, umax = 0 : ui5, umin = 0 : ui5} + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i5, %a {smax = 15 : i9, smin = 0 : i9, umax = 15 : ui9, umin = 0 : ui9} : i5, i4 + // CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %c0_i5, %b {smax = 15 : i9, smin = 0 : i9, umax = 15 : ui9, umin = 0 : ui9} : i5, i4 + // CHECK-NEXT: %[[MUL:.+]] = comb.mul %[[A_EXT]], %[[B_EXT]] {smax = 225 : i9, smin = 0 : i9, umax = 225 : ui9, umin = 0 : ui9} : i9 + // CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i5, %c {smax = 15 : i9, smin = 0 : i9, umax = 15 : ui9, umin = 0 : ui9} : i5, i4 + // CHECK-NEXT: %[[ADD:.+]] = comb.add %[[MUL]], %[[C_EXT]] {smax = 240 : i9, smin = 0 : i9, umax = 240 : ui9, umin = 0 : ui9} : i9 + %c0_i5 = hw.constant 0 : i5 + %0 = comb.concat %c0_i5, %a : i5, i4 + %1 = comb.concat %c0_i5, %b : i5, i4 + %2 = comb.mul %0, %1 : i9 + %3 = comb.concat %c0_i5, %c : i5, i4 + %4 = comb.add %2, %3 : i9 + hw.output %4 : i9 +} + +// CHECK-LABEL: @const_sub +hw.module @const_sub(in %a : i8, out sub_res : i10) { + // CHECK-NEXT: %c256_i10 = hw.constant 256 : i10 {smax = 256 : i10, smin = 256 : i10, umax = 256 : ui10, umin = 256 : ui10} + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2} + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i2, %a {smax = 255 : i10, smin = 0 : i10, umax = 255 : ui10, umin = 0 : ui10} : i2, i8 + // CHECK-NEXT: %[[SUB:.+]] = comb.sub %c256_i10, %[[A_EXT]] {smax = 256 : i10, smin = 1 : i10, umax = 256 : ui10, umin = 1 : ui10} : i10 + %c256_i10 = hw.constant 256 : i10 + %c0_i2 = hw.constant 0 : i2 + %0 = comb.concat %c0_i2, %a : i2, i8 + %1 = comb.sub %c256_i10, %0 : i10 + hw.output %1 : i10 +} + +// CHECK-LABEL: @logical_ops +hw.module @logical_ops(in %a : i8, in %b : i9, in %c : i10, in %d : i16, out res : i18) { + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2} + // CHECK-NEXT: %false = hw.constant false {smax = false, smin = false, umax = 0 : ui1, umin = 0 : ui1} + // CHECK-NEXT: %c0_i9 = hw.constant 0 : i9 {smax = 0 : i9, smin = 0 : i9, umax = 0 : ui9, umin = 0 : ui9} + // CHECK-NEXT: %c0_i8 = hw.constant 0 : i8 {smax = 0 : i8, smin = 0 : i8, umax = 0 : ui8, umin = 0 : ui8} + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i9, %a {smax = 255 : i17, smin = 0 : i17, umax = 255 : ui17, umin = 0 : ui17} : i9, i8 + // CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %c0_i8, %b {smax = 511 : i17, smin = 0 : i17, umax = 511 : ui17, umin = 0 : ui17} : i8, i9 + // CHECK-NEXT: %[[AND:.+]] = comb.and %[[A_EXT]], %[[B_EXT]] {smax = 255 : i17, smin = 0 : i17, umax = 255 : ui17, umin = 0 : ui17} : i17 + // CHECK-NEXT: %[[AND_EXT:.+]] = comb.concat %false, %[[AND]] {smax = 255 : i18, smin = 0 : i18, umax = 255 : ui18, umin = 0 : ui18} : i1, i17 + // CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i8, %c {smax = 1023 : i18, smin = 0 : i18, umax = 1023 : ui18, umin = 0 : ui18} : i8, i10 + // CHECK-NEXT: %[[OR:.+]] = comb.or %[[AND_EXT]], %[[C_EXT]] {smax = 1023 : i18, smin = 0 : i18, umax = 1023 : ui18, umin = 0 : ui18} : i18 + // CHECK-NEXT: %[[D_EXT:.+]] = comb.concat %c0_i2, %d {smax = 65535 : i18, smin = 0 : i18, umax = 65535 : ui18, umin = 0 : ui18} : i2, i16 + // CHECK-NEXT: %[[ADD:.+]] = comb.add %[[OR]], %[[D_EXT]] {smax = 66558 : i18, smin = 0 : i18, umax = 66558 : ui18, umin = 0 : ui18} : i18 + %c0_i2 = hw.constant 0 : i2 + %false = hw.constant false + %c0_i9 = hw.constant 0 : i9 + %c0_i8 = hw.constant 0 : i8 + %0 = comb.concat %c0_i9, %a : i9, i8 + %1 = comb.concat %c0_i8, %b : i8, i9 + %2 = comb.and %0, %1 : i17 + %3 = comb.concat %false, %2 : i1, i17 + %4 = comb.concat %c0_i8, %c : i8, i10 + %5 = comb.or %3, %4 : i18 + %6 = comb.concat %c0_i2, %d : i2, i16 + %7 = comb.add %5, %6 : i18 + hw.output %7 : i18 +} + +// CHECK-LABEL: @variadic_ops +hw.module @variadic_ops(in %a : i2, in %b : i2, in %c : i2) { + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2} + // CHECK-NEXT: %[[A_EXT2:.+]] = comb.concat %c0_i2, %a {smax = 3 : i4, smin = 0 : i4, umax = 3 : ui4, umin = 0 : ui4} : i2, i2 + // CHECK-NEXT: %[[B_EXT2:.+]] = comb.concat %c0_i2, %b {smax = 3 : i4, smin = 0 : i4, umax = 3 : ui4, umin = 0 : ui4} : i2, i2 + // CHECK-NEXT: %[[C_EXT2:.+]] = comb.concat %c0_i2, %c {smax = 3 : i4, smin = 0 : i4, umax = 3 : ui4, umin = 0 : ui4} : i2, i2 + // CHECK-NEXT: %[[ADD:.+]] = comb.add %[[A_EXT2]], %[[B_EXT2]], %[[C_EXT2]] {smax = 7 : i4, smin = -8 : i4, umax = 9 : ui4, umin = 0 : ui4} : i4 + // CHECK-NEXT: %c0_i3 = hw.constant 0 : i3 {smax = 0 : i3, smin = 0 : i3, umax = 0 : ui3, umin = 0 : ui3} + // CHECK-NEXT: %[[A_EXT3:.+]] = comb.concat %c0_i3, %a {smax = 3 : i5, smin = 0 : i5, umax = 3 : ui5, umin = 0 : ui5} : i3, i2 + // CHECK-NEXT: %[[B_EXT3:.+]] = comb.concat %c0_i3, %b {smax = 3 : i5, smin = 0 : i5, umax = 3 : ui5, umin = 0 : ui5} : i3, i2 + // CHECK-NEXT: %[[C_EXT3:.+]] = comb.concat %c0_i3, %c {smax = 3 : i5, smin = 0 : i5, umax = 3 : ui5, umin = 0 : ui5} : i3, i2 + // CHECK-NEXT: %[[MUL:.+]] = comb.mul %[[A_EXT3]], %[[B_EXT3]], %[[C_EXT3]] {smax = 15 : i5, smin = -16 : i5, umax = 27 : ui5, umin = 0 : ui5} : i5 + // CHECK-NEXT: %[[AND:.+]] = comb.and %a, %b, %c {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : i2 + // CHECK-NEXT: %[[OR:.+]] = comb.or %a, %b, %c {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : i2 + // CHECK-NEXT: %[[XOR:.+]] = comb.xor %a, %b, %c {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : i2 + // CHECK-NEXT: hw.output + %c0_i2 = hw.constant 0 : i2 + %0 = comb.concat %c0_i2, %a : i2, i2 + %1 = comb.concat %c0_i2, %b : i2, i2 + %2 = comb.concat %c0_i2, %c : i2, i2 + %3 = comb.add %0, %1, %2 : i4 + %c0_i3 = hw.constant 0 : i3 + %4 = comb.concat %c0_i3, %a : i3, i2 + %5 = comb.concat %c0_i3, %b : i3, i2 + %6 = comb.concat %c0_i3, %c : i3, i2 + %7 = comb.mul %4, %5, %6 : i5 + %8 = comb.and %a, %b, %c : i2 + %9 = comb.or %a, %b, %c : i2 + %10 = comb.xor %a, %b, %c : i2 + hw.output +} + +// CHECK-LABEL: @replicate_extract +hw.module @replicate_extract(in %a : i3, in %b : i3, in %sel : i1) { + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2} + // CHECK-NEXT: %[[EXT_A:.+]] = comb.extract %a from 1 {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : (i3) -> i2 + // CHECK-NEXT: %[[REPL_A:.+]] = comb.replicate %[[EXT_A]] {smax = 7 : i4, smin = -8 : i4, umax = 15 : ui4, umin = 0 : ui4} : (i2) -> i4 + // CHECK-NEXT: %[[REPL_SEL:.+]] = comb.replicate %sel {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : (i1) -> i2 + // CHECK-NEXT: %[[EXT_OUT:.+]] = comb.extract %[[REPL_A]] from 1 {smax = 1 : i2, smin = -2 : i2, umax = 3 : ui2, umin = 0 : ui2} : (i4) -> i2 + %c0_i2 = hw.constant 0 : i2 + %0 = comb.extract %a from 1 : (i3) -> i2 + %1 = comb.replicate %0 : (i2) -> i4 + %2 = comb.replicate %sel : (i1) -> i2 + %3 = comb.extract %1 from 1 : (i4) -> i2 + hw.output +} + +// CHECK-LABEL: @comp_predicates +hw.module @comp_predicates(in %a : i3, in %b : i3) { + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 {smax = 0 : i2, smin = 0 : i2, umax = 0 : ui2, umin = 0 : ui2} + // CHECK-NEXT: %c0_i3 = hw.constant 0 : i3 {smax = 0 : i3, smin = 0 : i3, umax = 0 : ui3, umin = 0 : ui3} + // CHECK-NEXT: %c-1_i3 = hw.constant -1 : i3 {smax = -1 : i3, smin = -1 : i3, umax = 7 : ui3, umin = 7 : ui3} + // CHECK-NEXT: %[[ULT:.+]] = comb.icmp ult %a, %c-1_i3 {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[ULE:.+]] = comb.icmp ule %a, %c-1_i3 {smax = true, smin = true, umax = 1 : ui1, umin = 1 : ui1} : i3 + // CHECK-NEXT: %[[UGT:.+]] = comb.icmp ugt %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[UGE:.+]] = comb.icmp uge %a, %c0_i3 {smax = true, smin = true, umax = 1 : ui1, umin = 1 : ui1} : i3 + // CHECK-NEXT: %[[SLT:.+]] = comb.icmp slt %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[SLE:.+]] = comb.icmp sle %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[SGT:.+]] = comb.icmp sgt %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[SGE:.+]] = comb.icmp sge %a, %c0_i3 {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[EQ:.+]] = comb.icmp eq %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + // CHECK-NEXT: %[[NE:.+]] = comb.icmp ne %a, %b {smax = false, smin = true, umax = 1 : ui1, umin = 0 : ui1} : i3 + %c0_i2 = hw.constant 0 : i2 + %c0_i3 = hw.constant 0 : i3 + %c7_i3 = hw.constant 7 : i3 + %0 = comb.icmp ult %a, %c7_i3 : i3 + %1 = comb.icmp ule %a, %c7_i3 : i3 + %2 = comb.icmp ugt %a, %b : i3 + %3 = comb.icmp uge %a, %c0_i3 : i3 + %4 = comb.icmp slt %a, %b : i3 + %5 = comb.icmp sle %a, %b : i3 + %6 = comb.icmp sgt %a, %b : i3 + %7 = comb.icmp sge %a, %c0_i3 : i3 + %8 = comb.icmp eq %a, %b : i3 + %9 = comb.icmp ne %a, %b : i3 + hw.output +} diff --git a/test/Dialect/Comb/comb-int-range-narrowing.mlir b/test/Dialect/Comb/comb-int-range-narrowing.mlir new file mode 100644 index 000000000000..26c68e5f9279 --- /dev/null +++ b/test/Dialect/Comb/comb-int-range-narrowing.mlir @@ -0,0 +1,125 @@ +// RUN: circt-opt %s --comb-int-range-narrowing | FileCheck %s + +// CHECK-LABEL: @basic_csa +hw.module @basic_csa(in %a : i1, in %b : i1, in %c : i1, out add_abc : i3) { + // CHECK-NEXT %c0_i2 = hw.constant 0 : i2 + // CHECK-NEXT %false = hw.constant false + // CHECK-NEXT %[[A_EXT:.+]] = comb.concat %false, %a : i1, i1 + // CHECK-NEXT %[[B_EXT:.+]] = comb.concat %false, %b : i1, i1 + // CHECK-NEXT %[[ADD_2:.+]] = comb.add %[[A_EXT]], %[[B_EXT]] : i2 + // CHECK-NEXT %[[ADD_2_EXT:.+]] = comb.concat %false, %[[ADD_2]] : i1, i2 + // CHECK-NEXT %[[C_EXT:.+]] = comb.concat %c0_i2, %c : i2, i1 + // CHECK-NEXT %[[ADD_2_2:.+]] = comb.extract %[[ADD_2_EXT]] from 0 : (i3) -> i2 + // CHECK-NEXT %[[C_2:.+]] = comb.extract %[[C_EXT]] from 0 : (i3) -> i2 + // CHECK-NEXT %[[ADD_3:.+]] = comb.add %[[ADD_2_2]], %[[C_2]] : i2 + // CHECK-NEXT %[[RES:.+]] = comb.concat %false, %[[ADD_3]] : i1, i2 + // CHECK-NEXT hw.output %[[RES]] : i3 + %c0_i2 = hw.constant 0 : i2 + %false = hw.constant false + %0 = comb.concat %false, %a : i1, i1 + %1 = comb.concat %false, %b : i1, i1 + %2 = comb.add %0, %1 : i2 + %3 = comb.concat %false, %2 : i1, i2 + %4 = comb.concat %c0_i2, %c : i2, i1 + %5 = comb.add %3, %4 : i3 + hw.output %5 : i3 +} + +// CHECK-LABEL: @basic_fma +hw.module @basic_fma(in %a : i4, in %b : i4, in %c : i4, out d : i9) { + // CHECK-NEXT: %false = hw.constant false + // CHECK-NEXT: %c0_i5 = hw.constant 0 : i5 + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i5, %a : i5, i4 + // CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %c0_i5, %b : i5, i4 + // CHECK-NEXT: %[[A:.+]] = comb.extract %[[A_EXT]] from 0 : (i9) -> i8 + // CHECK-NEXT: %[[B:.+]] = comb.extract %[[B_EXT]] from 0 : (i9) -> i8 + // CHECK-NEXT: %[[MUL:.+]] = comb.mul %[[A]], %[[B]] : i8 + // CHECK-NEXT: %[[MUL_EXT:.+]] = comb.concat %false, %[[MUL]] : i1, i8 + // CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i5, %c : i5, i4 + // CHECK-NEXT: %[[MUL_T:.+]] = comb.extract %[[MUL_EXT]] from 0 : (i9) -> i8 + // CHECK-NEXT: %[[C_T:.+]] = comb.extract %[[C_EXT]] from 0 : (i9) -> i8 + // CHECK-NEXT: %[[ADD_OUT:.+]] = comb.add %[[MUL_T]], %[[C_T]] : i8 + // CHECK-NEXT: %[[ADD_OUT_EXT:.+]] = comb.concat %false, %[[ADD_OUT]] : i1, i8 + // CHECK-NEXT: hw.output %[[ADD_OUT_EXT]] : i9 + %c0_i5 = hw.constant 0 : i5 + %0 = comb.concat %c0_i5, %a : i5, i4 + %1 = comb.concat %c0_i5, %b : i5, i4 + %2 = comb.mul %0, %1 : i9 + %3 = comb.concat %c0_i5, %c : i5, i4 + %4 = comb.add %2, %3 : i9 + hw.output %4 : i9 +} + +// CHECK-LABEL: @const_sub +hw.module @const_sub(in %a : i8, out sub_res : i10) { + // CHECK-NEXT: %false = hw.constant false + // CHECK-NEXT: %c-256_i9 = hw.constant -256 : i9 + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i2, %a : i2, i8 + // CHECK-NEXT: %[[A_T:.+]] = comb.extract %[[A_EXT]] from 0 : (i10) -> i9 + // CHECK-NEXT: %[[SUB:.+]] = comb.sub %c-256_i9, %[[A_T]] : i9 + // CHECK-NEXT: %[[RES:.+]] = comb.concat %false, %[[SUB]] : i1, i9 + // CHECK-NEXT: hw.output %[[RES]] : i10 + %c256_i10 = hw.constant 256 : i10 + %c0_i2 = hw.constant 0 : i2 + %0 = comb.concat %c0_i2, %a : i2, i8 + %1 = comb.sub %c256_i10, %0 : i10 + hw.output %1 : i10 +} + +// CHECK-LABEL: @do_nothing +hw.module @do_nothing(in %a : i8, in %b : i9, in %c : i10, in %d : i16, out res : i18) { + // CHECK-NEXT: %c0_i2 = hw.constant 0 : i2 + // CHECK-NEXT: %[[FALSE:.+]] = hw.constant false + // CHECK-NEXT: %c0_i9 = hw.constant 0 : i9 + // CHECK-NEXT: %c0_i8 = hw.constant 0 : i8 + // CHECK-NEXT: %[[A_EXT:.+]] = comb.concat %c0_i9, %a : i9, i8 + // CHECK-NEXT: %[[B_EXT:.+]] = comb.concat %c0_i8, %b : i8, i9 + // CHECK-NEXT: %[[MUL:.+]] = comb.mul %[[A_EXT]], %[[B_EXT]] : i17 + // CHECK-NEXT: %[[MUL_EXT:.+]] = comb.concat %[[FALSE]], %[[MUL]] : i1, i17 + // CHECK-NEXT: %[[C_EXT:.+]] = comb.concat %c0_i8, %c : i8, i10 + // CHECK-NEXT: %[[D_EXT:.+]] = comb.concat %c0_i2, %d : i2, i16 + // CHECK-NEXT: %[[RES:.+]] = comb.add %[[MUL_EXT]], %[[C_EXT]], %[[D_EXT]] : i18 + // CHECK-NEXT: hw.output %[[RES]] : i18 + %c0_i2 = hw.constant 0 : i2 + %false = hw.constant false + %c0_i9 = hw.constant 0 : i9 + %c0_i8 = hw.constant 0 : i8 + %0 = comb.concat %c0_i9, %a : i9, i8 + %1 = comb.concat %c0_i8, %b : i8, i9 + %2 = comb.mul %0, %1 : i17 + %3 = comb.concat %false, %2 : i1, i17 + %4 = comb.concat %c0_i8, %c : i8, i10 + %5 = comb.concat %c0_i2, %d : i2, i16 + %6 = comb.add %3, %4, %5 : i18 + hw.output %6 : i18 +} + +hw.module @logical_ops(in %a : i8, in %b : i9, in %c : i10, in %d : i16, out res : i18) { + // CHECK-NEXT %c0_i7 = hw.constant 0 : i7 + // CHECK-NEXT %[[FALSE:.+]] = hw.constant false + // CHECK-NEXT %c0_i9 = hw.constant 0 : i9 + // CHECK-NEXT %c0_i8 = hw.constant 0 : i8 + // CHECK-NEXT %[[A_EXT:.+]] = comb.concat %c0_i9, %a : i9, i8 + // CHECK-NEXT %[[B_EXT:.+]] = comb.concat %c0_i8, %b : i8, i9 + // CHECK-NEXT %[[AND:.+]] = comb.and %[[A_EXT]], %[[B_EXT]] : i17 + // CHECK-NEXT %[[C_EXT:.+]] = comb.concat %c0_i7, %c : i7, i10 + // CHECK-NEXT %[[OR:.+]] = comb.or %[[AND]], %[[C_EXT]] : i17 + // CHECK-NEXT %[[D_EXT:.+]] = comb.concat %[[FALSE]], %d : i1, i16 + // CHECK-NEXT %[[ADD:.+]] = comb.add %[[OR]], %[[D_EXT]] : i17 + // CHECK-NEXT %[[RES:.+]] = comb.concat %[[FALSE]], %[[ADD]] : i1, i17 + // CHECK-NEXT hw.output %[[RES]] : i18 + %c0_i2 = hw.constant 0 : i2 + %false = hw.constant false + %c0_i9 = hw.constant 0 : i9 + %c0_i8 = hw.constant 0 : i8 + %0 = comb.concat %c0_i9, %a : i9, i8 + %1 = comb.concat %c0_i8, %b : i8, i9 + %2 = comb.and %0, %1 : i17 + %3 = comb.concat %false, %2 : i1, i17 + %4 = comb.concat %c0_i8, %c : i8, i10 + %5 = comb.or %3, %4 : i18 + %6 = comb.concat %c0_i2, %d : i2, i16 + %7 = comb.add %5, %6 : i18 + hw.output %7 : i18 +}