Skip to content

Commit cd34df2

Browse files
committed
Fix complex py2c conversion to allow for numpy complex types
1 parent 430ca86 commit cd34df2

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

c++/cpp2py/converters/complex.hpp

+19-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#pragma once
2-
#include <Python.h>
2+
#include "./../pyref.hpp"
3+
4+
#include <numpy/arrayobject.h>
35

46
namespace cpp2py {
57

@@ -8,6 +10,18 @@ namespace cpp2py {
810
template <> struct py_converter<std::complex<double>> {
911
static PyObject *c2py(std::complex<double> x) { return PyComplex_FromDoubles(x.real(), x.imag()); }
1012
static std::complex<double> py2c(PyObject *ob) {
13+
14+
if (PyArray_CheckScalar(ob)) {
15+
// Convert NPY Scalar Type to Builtin Type
16+
pyref py_builtin = PyObject_CallMethod(ob, "item", NULL);
17+
if (PyComplex_Check(py_builtin)) {
18+
auto r = PyComplex_AsCComplex(py_builtin);
19+
return {r.real, r.imag};
20+
} else {
21+
return PyFloat_AsDouble(py_builtin);
22+
}
23+
}
24+
1125
if (PyComplex_Check(ob)) {
1226
auto r = PyComplex_AsCComplex(ob);
1327
return {r.real, r.imag};
@@ -16,6 +30,10 @@ namespace cpp2py {
1630
}
1731
static bool is_convertible(PyObject *ob, bool raise_exception) {
1832
if (PyComplex_Check(ob) || PyFloat_Check(ob) || PyLong_Check(ob)) return true;
33+
if (PyArray_CheckScalar(ob)) {
34+
pyref py_arr = PyArray_FromScalar(ob, NULL);
35+
if (PyArray_ISINTEGER((PyObject*)py_arr) or PyArray_ISFLOAT((PyObject*)py_arr) or PyArray_ISCOMPLEX((PyObject*)py_arr)) return true;
36+
}
1937
if (raise_exception) { PyErr_SetString(PyExc_TypeError, "Cannot convert to complex"); }
2038
return false;
2139
}

0 commit comments

Comments
 (0)