Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert EmitC types merge #188

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@ bool isIntegerIndexOrOpaqueType(Type type);

/// Determines whether \p type is a valid floating-point type in EmitC.
bool isSupportedFloatType(mlir::Type type);

/// Determines whether \p type is a emitc.size_t/ssize_t type.
bool isAnySizeTType(mlir::Type type);

} // namespace emitc
} // namespace mlir

Expand Down
6 changes: 2 additions & 4 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
def CExpression : NativeOpTrait<"emitc::CExpression">;

// Types only used in binary arithmetic operations.
def IntegerIndexOrOpaqueType : AnyTypeOf<[EmitCIntegerType, Index,
EmitC_SignedSizeT, EmitC_SizeT, EmitC_OpaqueType]>;
def IntegerIndexOrOpaqueType : AnyTypeOf<[EmitCIntegerType, Index, EmitC_OpaqueType]>;
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[EmitCFloatType, IntegerIndexOrOpaqueType]>;

def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
Expand Down Expand Up @@ -288,7 +287,6 @@ def EmitC_CastOp : EmitC_Op<"cast",
let arguments = (ins EmitCType:$source);
let results = (outs EmitCType:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let hasFolder = 1;
}

def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
Expand Down Expand Up @@ -472,7 +470,7 @@ def EmitC_ForOp : EmitC_Op<"for",
upper bound and step respectively, and defines an SSA value for its
induction variable. It has one region capturing the loop body. The induction
variable is represented as an argument of this region. This SSA value is a
signless integer, or an index. The step is a value of same type.
signless integer or index. The step is a value of same type.

This operation has no result. The body region must contain exactly one block
that terminates with `emitc.yield`. Calling ForOp::build will create such a
Expand Down
8 changes: 0 additions & 8 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,4 @@ def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {
let assemblyFormat = "`<` qualified($pointee) `>`";
}

def EmitC_SignedSizeT : EmitC_Type<"SignedSizeT", "ssize_t"> {
let summary = "EmitC signed size type";
}

def EmitC_SizeT : EmitC_Type<"SizeT", "size_t"> {
let summary = "EmitC unsigned size type";
}

#endif // MLIR_DIALECT_EMITC_IR_EMITCTYPES
13 changes: 0 additions & 13 deletions mlir/include/mlir/Dialect/EmitC/Transforms/TypeConversions.h

This file was deleted.

174 changes: 69 additions & 105 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
Expand All @@ -37,11 +36,8 @@ class ArithConstantOpConversionPattern
matchAndRewrite(arith::ConstantOp arithConst,
arith::ConstantOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
if (!newTy)
return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
adaptor.getValue());
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(
arithConst, arithConst.getType(), adaptor.getValue());
return success();
}
};
Expand Down Expand Up @@ -205,35 +201,6 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
}
};

/// Check if the signedness of type \p ty matches the expected
/// signedness, and issue a type with the correct signedness if
/// necessary.
Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
if (isa<IntegerType>(ty)) {
// Turns signless integers into signed integers.
if (ty.isUnsignedInteger() != needsUnsigned) {
auto signedness = needsUnsigned
? IntegerType::SignednessSemantics::Unsigned
: IntegerType::SignednessSemantics::Signed;
return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
signedness);
}
} else if (emitc::isAnySizeTType(ty)) {
if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
if (needsUnsigned)
return emitc::SizeTType::get(ty.getContext());
return emitc::SignedSizeTType::get(ty.getContext());
}
}
return ty;
}

/// Insert a cast operation to type \p ty if \p val
/// does not have this type.
Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
}

class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -283,25 +250,31 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = adaptor.getLhs().getType();
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
type)) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t type");
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
return rewriter.notifyMatchFailure(op, "expected integer or index type");
}

bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());

Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);

Type arithmeticType = type;
if (type.isUnsignedInteger() != needsUnsigned) {
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
/*isSigned=*/!needsUnsigned);
}
Value lhs = adaptor.getLhs();
Value rhs = adaptor.getRhs();
if (arithmeticType != type) {
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
lhs);
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
rhs);
}
rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
return success();
}
};

template <typename ArithOp, bool castToUnsigned>
template <typename ArithOp, bool needsUnsigned>
class CastConversion : public OpConversionPattern<ArithOp> {
public:
using OpConversionPattern<ArithOp>::OpConversionPattern;
Expand All @@ -311,58 +284,52 @@ class CastConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type opReturnType = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
opReturnType))
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t result type");
if (!isa_and_nonnull<IntegerType>(opReturnType)) {
return rewriter.notifyMatchFailure(op, "expected integer result type");
}

if (adaptor.getOperands().size() != 1) {
return rewriter.notifyMatchFailure(
op, "CastConversion only supports unary ops");
}

Type operandType = adaptor.getIn().getType();
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
operandType))
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t operand type");

// to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
// equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
// truncation.
if (opReturnType.isInteger(1)) {
Type attrType = (emitc::isAnySizeTType(operandType))
? rewriter.getIndexType()
: operandType;
auto constOne = rewriter.create<emitc::ConstantOp>(
op.getLoc(), operandType, rewriter.getIntegerAttr(attrType, 1));
auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
op.getLoc(), operandType, adaptor.getIn(), constOne);
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
oneAndOperand);
return success();
if (!isa_and_nonnull<IntegerType>(operandType)) {
return rewriter.notifyMatchFailure(op, "expected integer operand type");
}

bool isTruncation =
(isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
operandType.getIntOrFloatBitWidth() >
opReturnType.getIntOrFloatBitWidth());
bool doUnsigned = castToUnsigned || isTruncation;

// Adapt the signedness of the result (bitwidth-preserving cast)
// This is needed e.g., if the return type is signless.
Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
bool isTruncation = operandType.getIntOrFloatBitWidth() >
opReturnType.getIntOrFloatBitWidth();
bool doUnsigned = needsUnsigned || isTruncation;

Type castType = opReturnType;
// For int conversions: if the op is a ui variant and the type wanted as
// return type isn't unsigned, we need to issue an unsigned type to do
// the conversion.
if (castType.isUnsignedInteger() != doUnsigned) {
castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
/*isSigned=*/!doUnsigned);
}

// Adapt the signedness of the operand (bitwidth-preserving cast)
Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
Value actualOp = adaptor.getIn();
// Fix the signedness of the operand if necessary
if (operandType.isUnsignedInteger() != doUnsigned) {
Type correctSignednessType =
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
/*isSigned=*/!doUnsigned);
actualOp = rewriter.template create<emitc::CastOp>(
op.getLoc(), correctSignednessType, actualOp);
}

// Actual cast (may change bitwidth)
auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
castDestType, actualOp);
auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
actualOp);

// Cast to the expected output type
auto result = adaptValueType(cast, rewriter, opReturnType);
// Fix the signedness of what this operation returns (for integers,
// the arith ops want signless results)
if (castType != opReturnType) {
result = rewriter.template create<emitc::CastOp>(op.getLoc(),
opReturnType, result);
}

rewriter.replaceOp(op, result);
return success();
Expand All @@ -388,11 +355,7 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
if (!newTy)
return rewriter.notifyMatchFailure(arithOp,
"converting result type failed");
rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
adaptor.getOperands());

return success();
Expand All @@ -409,17 +372,17 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
type)) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t type");
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
return rewriter.notifyMatchFailure(op, "expected integer type");
}

if (type.isInteger(1)) {
// arith expects wrap-around arithmethic, which doesn't happen on `bool`.
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
}

Value lhs = adaptor.getLhs();
Value rhs = adaptor.getRhs();
Type arithmeticType = type;
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
!bitEnumContainsAll(op.getOverflowFlags(),
Expand All @@ -429,15 +392,20 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
/*isSigned=*/false);
}
if (arithmeticType != type) {
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
lhs);
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
rhs);
}

Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);

Value arithmeticResult = rewriter.template create<EmitCOp>(
op.getLoc(), arithmeticType, lhs, rhs);

Value result = adaptValueType(arithmeticResult, rewriter, type);
Value result = rewriter.template create<EmitCOp>(op.getLoc(),
arithmeticType, lhs, rhs);

if (arithmeticType != type) {
result =
rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
}
rewriter.replaceOp(op, result);
return success();
}
Expand Down Expand Up @@ -567,8 +535,6 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();

mlir::populateEmitCSizeTypeConversionPatterns(typeConverter);

// clang-format off
patterns.add<
ArithConstantOpConversionPattern,
Expand All @@ -588,8 +554,6 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
UnsignedCastConversion<arith::TruncIOp>,
SignedCastConversion<arith::ExtSIOp>,
UnsignedCastConversion<arith::ExtUIOp>,
SignedCastConversion<arith::IndexCastOp>,
UnsignedCastConversion<arith::IndexCastUIOp>,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ add_mlir_conversion_library(MLIRArithToEmitC
LINK_LIBS PUBLIC
MLIRArithDialect
MLIREmitCDialect
MLIREmitCTransforms
MLIRPass
MLIRTransformUtils
)
Loading