@@ -1496,29 +1496,118 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
14961496 }
14971497};
14981498
1499+ class CIRSwitchOpLowering : public mlir ::OpConversionPattern<cir::SwitchOp> {
1500+ public:
1501+ using OpConversionPattern<cir::SwitchOp>::OpConversionPattern;
1502+
1503+ mlir::LogicalResult
1504+ matchAndRewrite (cir::SwitchOp op, OpAdaptor adaptor,
1505+ mlir::ConversionPatternRewriter &rewriter) const override {
1506+ rewriter.setInsertionPointAfter (op);
1507+ llvm::SmallVector<CaseOp> cases;
1508+ if (!op.isSimpleForm (cases))
1509+ llvm_unreachable (" NYI" );
1510+
1511+ llvm::SmallVector<int64_t > caseValues;
1512+ // Maps the index of a CaseOp in `cases`, to the index in `caseValues`.
1513+ // This is necessary because some CaseOp might carry 0 or multiple values.
1514+ llvm::DenseMap<size_t , unsigned > indexMap;
1515+ caseValues.reserve (cases.size ());
1516+ for (auto [i, caseOp] : llvm::enumerate (cases)) {
1517+ switch (caseOp.getKind ()) {
1518+ case CaseOpKind::Equal: {
1519+ auto valueAttr = caseOp.getValue ()[0 ];
1520+ auto value = cast<cir::IntAttr>(valueAttr);
1521+ indexMap[i] = caseValues.size ();
1522+ caseValues.push_back (value.getUInt ());
1523+ break ;
1524+ }
1525+ case CaseOpKind::Default:
1526+ break ;
1527+ case CaseOpKind::Range:
1528+ case CaseOpKind::Anyof:
1529+ llvm_unreachable (" NYI" );
1530+ }
1531+ }
1532+
1533+ auto operand = adaptor.getOperands ()[0 ];
1534+ // `scf.index_switch` expects an index of type `index`.
1535+ auto indexType = mlir::IndexType::get (getContext ());
1536+ auto indexCast = rewriter.create <mlir::arith::IndexCastOp>(
1537+ op.getLoc (), indexType, operand);
1538+ auto indexSwitch = rewriter.create <mlir::scf::IndexSwitchOp>(
1539+ op.getLoc (), mlir::TypeRange{}, indexCast, caseValues, cases.size ());
1540+
1541+ bool metDefault = false ;
1542+ for (auto [i, caseOp] : llvm::enumerate (cases)) {
1543+ auto ®ion = caseOp.getRegion ();
1544+ switch (caseOp.getKind ()) {
1545+ case CaseOpKind::Equal: {
1546+ auto &caseRegion = indexSwitch.getCaseRegions ()[indexMap[i]];
1547+ rewriter.inlineRegionBefore (region, caseRegion, caseRegion.end ());
1548+ break ;
1549+ }
1550+ case CaseOpKind::Default: {
1551+ auto &defaultRegion = indexSwitch.getDefaultRegion ();
1552+ rewriter.inlineRegionBefore (region, defaultRegion, defaultRegion.end ());
1553+ metDefault = true ;
1554+ break ;
1555+ }
1556+ case CaseOpKind::Range:
1557+ case CaseOpKind::Anyof:
1558+ llvm_unreachable (" NYI" );
1559+ }
1560+ }
1561+
1562+ // `scf.index_switch` expects its default region to contain exactly one
1563+ // block. If we don't have a default region in `cir.switch`, we need to
1564+ // supply it here.
1565+ if (!metDefault) {
1566+ auto &defaultRegion = indexSwitch.getDefaultRegion ();
1567+ mlir::Block *block =
1568+ rewriter.createBlock (&defaultRegion, defaultRegion.end ());
1569+ rewriter.setInsertionPointToEnd (block);
1570+ rewriter.create <mlir::scf::YieldOp>(op.getLoc ());
1571+ }
1572+
1573+ // The final `cir.break` should be replaced to `scf.yield`.
1574+ // After MLIRLoweringPrepare pass, every case must end with a `cir.break`.
1575+ for (auto ®ion : indexSwitch.getCaseRegions ()) {
1576+ auto &lastBlock = region.back ();
1577+ auto &lastOp = lastBlock.back ();
1578+ assert (isa<BreakOp>(lastOp));
1579+ rewriter.setInsertionPointAfter (&lastOp);
1580+ rewriter.replaceOpWithNewOp <mlir::scf::YieldOp>(&lastOp);
1581+ }
1582+
1583+ rewriter.replaceOp (op, indexSwitch);
1584+
1585+ return mlir::success ();
1586+ }
1587+ };
1588+
14991589void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
15001590 mlir::TypeConverter &converter) {
15011591 patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
15021592
1503- patterns
1504- .add <CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1505- CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1506- CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1507- CIRFuncOpLowering, CIRScopeOpLowering, CIRBrCondOpLowering,
1508- CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1509- CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1510- CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1511- CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1512- CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1513- CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1514- CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering,
1515- CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1516- CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1517- CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
1518- CIRVectorInsertLowering, CIRVectorExtractLowering,
1519- CIRVectorCmpOpLowering, CIRACosOpLowering, CIRASinOpLowering,
1520- CIRUnreachableOpLowering, CIRTanOpLowering, CIRTrapOpLowering>(
1521- converter, patterns.getContext ());
1593+ patterns.add <
1594+ CIRSwitchOpLowering, CIRGetElementOpLowering, CIRATanOpLowering,
1595+ CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1596+ CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1597+ CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
1598+ CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
1599+ CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1600+ CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
1601+ CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1602+ CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
1603+ CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering,
1604+ CIRPtrStrideOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1605+ CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1606+ CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1607+ CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1608+ CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
1609+ CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
1610+ CIRTrapOpLowering>(converter, patterns.getContext ());
15221611}
15231612
15241613static mlir::TypeConverter prepareTypeConverter () {
@@ -1624,6 +1713,7 @@ mlir::ModuleOp lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule,
16241713
16251714 mlir::PassManager pm (mlirCtx);
16261715
1716+ pm.addPass (createMLIRCoreDialectsLoweringPreparePass ());
16271717 pm.addPass (createConvertCIRToMLIRPass ());
16281718 pm.addPass (createConvertMLIRToLLVMPass ());
16291719
@@ -1669,6 +1759,7 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
16691759
16701760 mlir::PassManager pm (mlirCtx);
16711761
1762+ pm.addPass (createMLIRCoreDialectsLoweringPreparePass ());
16721763 pm.addPass (createConvertCIRToMLIRPass ());
16731764
16741765 auto result = !mlir::failed (pm.run (theModule));
0 commit comments