Skip to content

Commit e4da49f

Browse files
authored
[CIR] Upstream __imag__ for ComplexType (#144262)
This change adds support for `__imag__` for ComplexType #141365
1 parent 585ed21 commit e4da49f

File tree

10 files changed

+191
-4
lines changed

10 files changed

+191
-4
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2400,6 +2400,35 @@ def ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
24002400
let hasFolder = 1;
24012401
}
24022402

2403+
//===----------------------------------------------------------------------===//
2404+
// ComplexImagOp
2405+
//===----------------------------------------------------------------------===//
2406+
2407+
def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
2408+
let summary = "Extract the imaginary part of a complex value";
2409+
let description = [{
2410+
`cir.complex.imag` operation takes an operand of `!cir.complex` type and
2411+
yields the imaginary part of it.
2412+
2413+
Example:
2414+
2415+
```mlir
2416+
%1 = cir.complex.imag %0 : !cir.complex<!cir.float> -> !cir.float
2417+
```
2418+
}];
2419+
2420+
let results = (outs CIR_AnyIntOrFloatType:$result);
2421+
let arguments = (ins CIR_ComplexType:$operand);
2422+
2423+
let assemblyFormat = [{
2424+
$operand `:` qualified(type($operand)) `->` qualified(type($result))
2425+
attr-dict
2426+
}];
2427+
2428+
let hasVerifier = 1;
2429+
let hasFolder = 1;
2430+
}
2431+
24032432
//===----------------------------------------------------------------------===//
24042433
// Assume Operations
24052434
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,11 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
371371
return create<cir::ComplexRealOp>(loc, operandTy.getElementType(), operand);
372372
}
373373

374+
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand) {
375+
auto operandTy = mlir::cast<cir::ComplexType>(operand.getType());
376+
return create<cir::ComplexImagOp>(loc, operandTy.getElementType(), operand);
377+
}
378+
374379
/// Create a cir.ptr_stride operation to get access to an array element.
375380
/// \p idx is the index of the element to access, \p shouldDecay is true if
376381
/// the result should decay to a pointer to the element type.

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
605605

606606
mlir::Value VisitUnaryReal(const UnaryOperator *e);
607607

608+
mlir::Value VisitUnaryImag(const UnaryOperator *e);
609+
608610
mlir::Value VisitCXXThisExpr(CXXThisExpr *te) { return cgf.loadCXXThis(); }
609611

610612
/// Emit a conversion from the specified type to the specified destination
@@ -1914,6 +1916,27 @@ mlir::Value ScalarExprEmitter::VisitUnaryReal(const UnaryOperator *e) {
19141916
return Visit(op);
19151917
}
19161918

1919+
mlir::Value ScalarExprEmitter::VisitUnaryImag(const UnaryOperator *e) {
1920+
// TODO(cir): handle scalar promotion.
1921+
Expr *op = e->getSubExpr();
1922+
if (op->getType()->isAnyComplexType()) {
1923+
// If it's an l-value, load through the appropriate subobject l-value.
1924+
// Note that we have to ask `e` because `op` might be an l-value that
1925+
// this won't work for, e.g. an Obj-C property.
1926+
if (e->isGLValue()) {
1927+
mlir::Location loc = cgf.getLoc(e->getExprLoc());
1928+
mlir::Value complex = cgf.emitComplexExpr(op);
1929+
return cgf.builder.createComplexImag(loc, complex);
1930+
}
1931+
1932+
// Otherwise, calculate and project.
1933+
cgf.cgm.errorNYI(e->getSourceRange(),
1934+
"VisitUnaryImag calculate and project");
1935+
}
1936+
1937+
return Visit(op);
1938+
}
1939+
19171940
/// Return the size or alignment of the type of argument of the sizeof
19181941
/// expression as an integer.
19191942
mlir::Value ScalarExprEmitter::VisitUnaryExprOrTypeTraitExpr(

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,6 +1932,24 @@ OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
19321932
return complex ? complex.getReal() : nullptr;
19331933
}
19341934

1935+
//===----------------------------------------------------------------------===//
1936+
// ComplexImagOp
1937+
//===----------------------------------------------------------------------===//
1938+
1939+
LogicalResult cir::ComplexImagOp::verify() {
1940+
if (getType() != getOperand().getType().getElementType()) {
1941+
emitOpError() << ": result type does not match operand type";
1942+
return failure();
1943+
}
1944+
return success();
1945+
}
1946+
1947+
OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
1948+
auto complex =
1949+
mlir::cast_if_present<cir::ConstComplexAttr>(adaptor.getOperand());
1950+
return complex ? complex.getImag() : nullptr;
1951+
}
1952+
19351953
//===----------------------------------------------------------------------===//
19361954
// TableGen'd op method definitions
19371955
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,9 @@ void CIRCanonicalizePass::runOnOperation() {
141141
// Many operations are here to perform a manual `fold` in
142142
// applyOpPatternsGreedily.
143143
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
144-
ComplexCreateOp, ComplexRealOp, VecCmpOp, VecCreateOp, VecExtractOp,
145-
VecShuffleOp, VecShuffleDynamicOp, VecTernaryOp>(op))
144+
ComplexCreateOp, ComplexImagOp, ComplexRealOp, VecCmpOp,
145+
VecCreateOp, VecExtractOp, VecShuffleOp, VecShuffleDynamicOp,
146+
VecTernaryOp>(op))
146147
ops.push_back(op);
147148
});
148149

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1904,7 +1904,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
19041904
CIRToLLVMVecShuffleDynamicOpLowering,
19051905
CIRToLLVMVecTernaryOpLowering,
19061906
CIRToLLVMComplexCreateOpLowering,
1907-
CIRToLLVMComplexRealOpLowering
1907+
CIRToLLVMComplexRealOpLowering,
1908+
CIRToLLVMComplexImagOpLowering
19081909
// clang-format on
19091910
>(converter, patterns.getContext());
19101911

@@ -2217,6 +2218,15 @@ mlir::LogicalResult CIRToLLVMComplexRealOpLowering::matchAndRewrite(
22172218
return mlir::success();
22182219
}
22192220

2221+
mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite(
2222+
cir::ComplexImagOp op, OpAdaptor adaptor,
2223+
mlir::ConversionPatternRewriter &rewriter) const {
2224+
mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
2225+
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
2226+
op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{1});
2227+
return mlir::success();
2228+
}
2229+
22202230
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
22212231
return std::make_unique<ConvertCIRToLLVMPass>();
22222232
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,16 @@ class CIRToLLVMComplexRealOpLowering
453453
mlir::ConversionPatternRewriter &) const override;
454454
};
455455

456+
class CIRToLLVMComplexImagOpLowering
457+
: public mlir::OpConversionPattern<cir::ComplexImagOp> {
458+
public:
459+
using mlir::OpConversionPattern<cir::ComplexImagOp>::OpConversionPattern;
460+
461+
mlir::LogicalResult
462+
matchAndRewrite(cir::ComplexImagOp op, OpAdaptor,
463+
mlir::ConversionPatternRewriter &) const override;
464+
};
465+
456466
} // namespace direct
457467
} // namespace cir
458468

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,29 @@ void foo9(double a, double b) {
216216
// OGCG: store double %[[TMP_A]], ptr %[[C_REAL_PTR]], align 8
217217
// OGCG: store double %[[TMP_B]], ptr %[[C_IMAG_PTR]], align 8
218218

219+
void foo12() {
220+
double _Complex c;
221+
double imag = __imag__ c;
222+
}
223+
224+
// CIR: %[[COMPLEX:.*]] = cir.alloca !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>, ["c"]
225+
// CIR: %[[INIT:.*]] = cir.alloca !cir.double, !cir.ptr<!cir.double>, ["imag", init]
226+
// CIR: %[[TMP:.*]] = cir.load{{.*}} %[[COMPLEX]] : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
227+
// CIR: %[[IMAG:.*]] = cir.complex.imag %[[TMP]] : !cir.complex<!cir.double> -> !cir.double
228+
// CIR: cir.store{{.*}} %[[IMAG]], %[[INIT]] : !cir.double, !cir.ptr<!cir.double>
229+
230+
// LLVM: %[[COMPLEX:.*]] = alloca { double, double }, i64 1, align 8
231+
// LLVM: %[[INIT:.*]] = alloca double, i64 1, align 8
232+
// LLVM: %[[TMP:.*]] = load { double, double }, ptr %[[COMPLEX]], align 8
233+
// LLVM: %[[IMAG:.*]] = extractvalue { double, double } %[[TMP]], 1
234+
// LLVM: store double %[[IMAG]], ptr %[[INIT]], align 8
235+
236+
// OGCG: %[[COMPLEX:.*]] = alloca { double, double }, align 8
237+
// OGCG: %[[INIT:.*]] = alloca double, align 8
238+
// OGCG: %[[IMAG:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX]], i32 0, i32 1
239+
// OGCG: %[[TMP:.*]] = load double, ptr %[[IMAG]], align 8
240+
// OGCG: store double %[[TMP]], ptr %[[INIT]], align 8
241+
219242
void foo13() {
220243
double _Complex c;
221244
double real = __real__ c;
@@ -281,6 +304,39 @@ void foo15() {
281304
// OGCG: store i32 %[[A_REAL]], ptr %[[B_REAL_PTR]], align 4
282305
// OGCG: store i32 %[[A_IMAG]], ptr %[[B_IMAG_PTR]], align 4
283306

307+
int foo16(int _Complex a, int _Complex b) {
308+
return __imag__ a + __imag__ b;
309+
}
310+
311+
// CIR: %[[RET:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"]
312+
// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
313+
// CIR: %[[A_IMAG:.*]] = cir.complex.imag %[[COMPLEX_A]] : !cir.complex<!s32i> -> !s32i
314+
// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
315+
// CIR: %[[B_IMAG:.*]] = cir.complex.imag %[[COMPLEX_B]] : !cir.complex<!s32i> -> !s32i
316+
// CIR: %[[ADD:.*]] = cir.binop(add, %[[A_IMAG]], %[[B_IMAG]]) nsw : !s32i
317+
// CIR: cir.store %[[ADD]], %[[RET]] : !s32i, !cir.ptr<!s32i>
318+
// CIR: %[[TMP:.*]] = cir.load %[[RET]] : !cir.ptr<!s32i>, !s32i
319+
// CIR: cir.return %[[TMP]] : !s32i
320+
321+
// LLVM: %[[RET:.*]] = alloca i32, i64 1, align 4
322+
// LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
323+
// LLVM: %[[A_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 1
324+
// LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
325+
// LLVM: %[[B_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 1
326+
// LLVM: %[[ADD:.*]] = add nsw i32 %[[A_IMAG]], %[[B_IMAG]]
327+
// LLVM: store i32 %[[ADD]], ptr %[[RET]], align 4
328+
// LLVM: %[[TMP:.*]] = load i32, ptr %[[RET]], align 4
329+
// LLVM: ret i32 %[[TMP]]
330+
331+
// OGCG: %[[COMPLEX_A:.*]] = alloca { i32, i32 }, align 4
332+
// OGCG: %[[COMPLEX_B:.*]] = alloca { i32, i32 }, align 4
333+
// OGCG: %[[A_IMAG:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 1
334+
// OGCG: %[[TMP_A:.*]] = load i32, ptr %[[A_IMAG]], align 4
335+
// OGCG: %[[B_IMAG:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 1
336+
// OGCG: %[[TMP_B:.*]] = load i32, ptr %[[B_IMAG]], align 4
337+
// OGCG: %[[ADD:.*]] = add nsw i32 %[[TMP_A]], %[[TMP_B]]
338+
// OGCG: ret i32 %[[ADD]]
339+
284340
int foo17(int _Complex a, int _Complex b) {
285341
return __real__ a + __real__ b;
286342
}
@@ -312,4 +368,4 @@ int foo17(int _Complex a, int _Complex b) {
312368
// OGCG: %[[B_REAL:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0
313369
// OGCG: %[[TMP_B:.*]] = load i32, ptr %[[B_REAL]], align 4
314370
// OGCG: %[[ADD:.*]] = add nsw i32 %[[TMP_A]], %[[TMP_B]]
315-
// OGCG: ret i32 %[[ADD]]
371+
// OGCG: ret i32 %[[ADD]]

clang/test/CIR/IR/invalid-complex.cir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,15 @@ module {
3333
cir.return
3434
}
3535
}
36+
37+
// -----
38+
39+
module {
40+
cir.func @complex_imag_invalid_result_type() -> !cir.double {
41+
%0 = cir.alloca !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>, ["c"]
42+
%2 = cir.load align(8) %0 : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
43+
// expected-error @below {{result type does not match operand type}}
44+
%3 = cir.complex.imag %2 : !cir.complex<!cir.double> -> !cir.float
45+
cir.return
46+
}
47+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: cir-opt %s -cir-canonicalize -o - | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
cir.func @fold_complex_imag_test() -> !s32i {
7+
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"]
8+
%2 = cir.const #cir.const_complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i> : !cir.complex<!s32i>
9+
%4 = cir.complex.imag %2 : !cir.complex<!s32i> -> !s32i
10+
cir.store %4, %0 : !s32i, !cir.ptr<!s32i>
11+
%5 = cir.load %0 : !cir.ptr<!s32i>, !s32i
12+
cir.return %5 : !s32i
13+
}
14+
15+
// CHECK: cir.func @fold_complex_imag_test() -> !s32i {
16+
// CHECK: %[[RET:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"]
17+
// CHECK: %[[IMAG:.*]] = cir.const #cir.int<2> : !s32i
18+
// CHECK: cir.store %[[IMAG]], %[[RET]] : !s32i, !cir.ptr<!s32i>
19+
// CHECK: %[[TMP:.]] = cir.load %[[RET]] : !cir.ptr<!s32i>, !s32i
20+
// CHECK: cir.return %[[TMP]] : !s32i
21+
// CHECK: }
22+
23+
}

0 commit comments

Comments
 (0)