Skip to content

Commit 6cfc296

Browse files
authored
Merge pull request #199 from SwayamInSync/198
2 parents f52ffdf + cebefce commit 6cfc296

File tree

2 files changed

+260
-4
lines changed

2 files changed

+260
-4
lines changed

quaddtype/numpy_quaddtype/src/scalar.c

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "scalar.h"
1515
#include "scalar_ops.h"
1616
#include "dragon4.h"
17+
#include "dtype.h"
1718

1819
// For IEEE 754 binary128 (quad precision), we need 36 decimal digits
1920
// to guarantee round-trip conversion (string -> parse -> equals original value)
@@ -42,7 +43,77 @@ QuadPrecision_raw_new(QuadBackendType backend)
4243

4344
QuadPrecisionObject *
4445
QuadPrecision_from_object(PyObject *value, QuadBackendType backend)
45-
{
46+
{
47+
// Handle numpy scalars (np.int32, np.float32, etc.) before arrays
48+
// We need to check this before PySequence_Check because some numpy scalars are sequences
49+
if (PyArray_CheckScalar(value)) {
50+
QuadPrecisionObject *self = QuadPrecision_raw_new(backend);
51+
if (!self)
52+
return NULL;
53+
54+
// Try as floating point first
55+
if (PyArray_IsScalar(value, Floating)) {
56+
PyObject *py_float = PyNumber_Float(value);
57+
if (py_float == NULL) {
58+
Py_DECREF(self);
59+
return NULL;
60+
}
61+
double dval = PyFloat_AsDouble(py_float);
62+
Py_DECREF(py_float);
63+
64+
if (backend == BACKEND_SLEEF) {
65+
self->value.sleef_value = Sleef_cast_from_doubleq1(dval);
66+
}
67+
else {
68+
self->value.longdouble_value = (long double)dval;
69+
}
70+
return self;
71+
}
72+
// Try as integer
73+
else if (PyArray_IsScalar(value, Integer)) {
74+
PyObject *py_int = PyNumber_Long(value);
75+
if (py_int == NULL) {
76+
Py_DECREF(self);
77+
return NULL;
78+
}
79+
long long lval = PyLong_AsLongLong(py_int);
80+
Py_DECREF(py_int);
81+
82+
if (backend == BACKEND_SLEEF) {
83+
self->value.sleef_value = Sleef_cast_from_int64q1(lval);
84+
}
85+
else {
86+
self->value.longdouble_value = (long double)lval;
87+
}
88+
return self;
89+
}
90+
// For other scalar types, fall through to error handling
91+
Py_DECREF(self);
92+
}
93+
94+
// this checks arrays and sequences (array, tuple)
95+
// rejects strings; they're parsed below
96+
if (PyArray_Check(value) || (PySequence_Check(value) && !PyUnicode_Check(value) && !PyBytes_Check(value)))
97+
{
98+
QuadPrecDTypeObject *dtype_descr = new_quaddtype_instance(backend);
99+
if (dtype_descr == NULL) {
100+
return NULL;
101+
}
102+
103+
// steals reference to the descriptor
104+
PyObject *result = PyArray_FromAny(
105+
value,
106+
(PyArray_Descr *)dtype_descr,
107+
0,
108+
0,
109+
NPY_ARRAY_ENSUREARRAY, // this should handle the casting if possible
110+
NULL
111+
);
112+
113+
// PyArray_FromAny steals the reference to dtype_descr, so no need to DECREF
114+
return (QuadPrecisionObject *)result;
115+
}
116+
46117
QuadPrecisionObject *self = QuadPrecision_raw_new(backend);
47118
if (!self)
48119
return NULL;
@@ -105,21 +176,21 @@ QuadPrecision_from_object(PyObject *value, QuadBackendType backend)
105176
const char *type_cstr = PyUnicode_AsUTF8(type_str);
106177
if (type_cstr != NULL) {
107178
PyErr_Format(PyExc_TypeError,
108-
"QuadPrecision value must be a quad, float, int or string, but got %s "
179+
"QuadPrecision value must be a quad, float, int, string, array or sequence, but got %s "
109180
"instead",
110181
type_cstr);
111182
}
112183
else {
113184
PyErr_SetString(
114185
PyExc_TypeError,
115-
"QuadPrecision value must be a quad, float, int or string, but got an "
186+
"QuadPrecision value must be a quad, float, int, string, array or sequence, but got an "
116187
"unknown type instead");
117188
}
118189
Py_DECREF(type_str);
119190
}
120191
else {
121192
PyErr_SetString(PyExc_TypeError,
122-
"QuadPrecision value must be a quad, float, int or string, but got an "
193+
"QuadPrecision value must be a quad, float, int, string, array or sequence, but got an "
123194
"unknown type instead");
124195
}
125196
Py_DECREF(self);

quaddtype/tests/test_quaddtype.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,191 @@ def test_create_scalar_simple():
1313
assert isinstance(QuadPrecision(1), QuadPrecision)
1414

1515

16+
class TestQuadPrecisionArrayCreation:
17+
"""Test suite for QuadPrecision array creation from sequences and arrays."""
18+
19+
def test_create_array_from_list(self):
20+
"""Test that QuadPrecision can create arrays from lists."""
21+
# Test with simple list
22+
result = QuadPrecision([3, 4, 5])
23+
assert isinstance(result, np.ndarray)
24+
assert result.dtype.name == "QuadPrecDType128"
25+
assert result.shape == (3,)
26+
np.testing.assert_array_equal(result, np.array([3, 4, 5], dtype=QuadPrecDType(backend='sleef')))
27+
28+
# Test with float list
29+
result = QuadPrecision([1.5, 2.5, 3.5])
30+
assert isinstance(result, np.ndarray)
31+
assert result.dtype.name == "QuadPrecDType128"
32+
assert result.shape == (3,)
33+
np.testing.assert_array_equal(result, np.array([1.5, 2.5, 3.5], dtype=QuadPrecDType(backend='sleef')))
34+
35+
def test_create_array_from_tuple(self):
36+
"""Test that QuadPrecision can create arrays from tuples."""
37+
result = QuadPrecision((10, 20, 30))
38+
assert isinstance(result, np.ndarray)
39+
assert result.dtype.name == "QuadPrecDType128"
40+
assert result.shape == (3,)
41+
np.testing.assert_array_equal(result, np.array([10, 20, 30], dtype=QuadPrecDType(backend='sleef')))
42+
43+
def test_create_array_from_ndarray(self):
44+
"""Test that QuadPrecision can create arrays from numpy arrays."""
45+
arr = np.array([1, 2, 3, 4])
46+
result = QuadPrecision(arr)
47+
assert isinstance(result, np.ndarray)
48+
assert result.dtype.name == "QuadPrecDType128"
49+
assert result.shape == (4,)
50+
np.testing.assert_array_equal(result, arr.astype(QuadPrecDType(backend='sleef')))
51+
52+
def test_create_2d_array_from_nested_list(self):
53+
"""Test that QuadPrecision can create 2D arrays from nested lists."""
54+
result = QuadPrecision([[1, 2], [3, 4]])
55+
assert isinstance(result, np.ndarray)
56+
assert result.dtype.name == "QuadPrecDType128"
57+
assert result.shape == (2, 2)
58+
expected = np.array([[1, 2], [3, 4]], dtype=QuadPrecDType(backend='sleef'))
59+
np.testing.assert_array_equal(result, expected)
60+
61+
def test_create_array_with_backend(self):
62+
"""Test that QuadPrecision respects backend parameter for arrays."""
63+
# Test with sleef backend (default)
64+
result_sleef = QuadPrecision([1, 2, 3], backend='sleef')
65+
assert isinstance(result_sleef, np.ndarray)
66+
assert result_sleef.dtype == QuadPrecDType(backend='sleef')
67+
68+
# Test with longdouble backend
69+
result_ld = QuadPrecision([1, 2, 3], backend='longdouble')
70+
assert isinstance(result_ld, np.ndarray)
71+
assert result_ld.dtype == QuadPrecDType(backend='longdouble')
72+
73+
def test_quad_precision_array_vs_astype_equivalence(self):
74+
"""Test that QuadPrecision(array) is equivalent to array.astype(QuadPrecDType)."""
75+
test_arrays = [
76+
[1, 2, 3],
77+
[1.5, 2.5, 3.5],
78+
[[1, 2], [3, 4]],
79+
np.array([10, 20, 30]),
80+
]
81+
82+
for arr in test_arrays:
83+
result_quad = QuadPrecision(arr)
84+
result_astype = np.array(arr).astype(QuadPrecDType(backend='sleef'))
85+
np.testing.assert_array_equal(result_quad, result_astype)
86+
assert result_quad.dtype == result_astype.dtype
87+
88+
def test_create_empty_array(self):
89+
"""Test that QuadPrecision can create arrays from empty sequences."""
90+
result = QuadPrecision([])
91+
assert isinstance(result, np.ndarray)
92+
assert result.dtype.name == "QuadPrecDType128"
93+
assert result.shape == (0,)
94+
expected = np.array([], dtype=QuadPrecDType(backend='sleef'))
95+
np.testing.assert_array_equal(result, expected)
96+
97+
def test_create_from_numpy_int_scalars(self):
98+
"""Test that QuadPrecision can create scalars from numpy integer types."""
99+
# Test np.int32
100+
result = QuadPrecision(np.int32(42))
101+
assert isinstance(result, QuadPrecision)
102+
assert float(result) == 42.0
103+
104+
# Test np.int64
105+
result = QuadPrecision(np.int64(100))
106+
assert isinstance(result, QuadPrecision)
107+
assert float(result) == 100.0
108+
109+
# Test np.uint32
110+
result = QuadPrecision(np.uint32(255))
111+
assert isinstance(result, QuadPrecision)
112+
assert float(result) == 255.0
113+
114+
# Test np.int8
115+
result = QuadPrecision(np.int8(-128))
116+
assert isinstance(result, QuadPrecision)
117+
assert float(result) == -128.0
118+
119+
def test_create_from_numpy_float_scalars(self):
120+
"""Test that QuadPrecision can create scalars from numpy floating types."""
121+
# Test np.float64
122+
result = QuadPrecision(np.float64(3.14))
123+
assert isinstance(result, QuadPrecision)
124+
assert abs(float(result) - 3.14) < 1e-10
125+
126+
# Test np.float32
127+
result = QuadPrecision(np.float32(2.71))
128+
assert isinstance(result, QuadPrecision)
129+
# Note: float32 has limited precision, so we use a looser tolerance
130+
assert abs(float(result) - 2.71) < 1e-5
131+
132+
# Test np.float16
133+
result = QuadPrecision(np.float16(1.5))
134+
assert isinstance(result, QuadPrecision)
135+
assert abs(float(result) - 1.5) < 1e-3
136+
137+
def test_create_from_zero_dimensional_array(self):
138+
"""Test that QuadPrecision can create from 0-d numpy arrays."""
139+
# 0-d array from scalar
140+
arr_0d = np.array(5.5)
141+
result = QuadPrecision(arr_0d)
142+
assert isinstance(result, np.ndarray)
143+
assert result.shape == () # 0-d array
144+
assert result.dtype.name == "QuadPrecDType128"
145+
expected = np.array(5.5, dtype=QuadPrecDType(backend='sleef'))
146+
np.testing.assert_array_equal(result, expected)
147+
148+
# Another test with integer
149+
arr_0d = np.array(42)
150+
result = QuadPrecision(arr_0d)
151+
assert isinstance(result, np.ndarray)
152+
assert result.shape == ()
153+
expected = np.array(42, dtype=QuadPrecDType(backend='sleef'))
154+
np.testing.assert_array_equal(result, expected)
155+
156+
def test_numpy_scalar_with_backend(self):
157+
"""Test that numpy scalars respect the backend parameter."""
158+
# Test with sleef backend
159+
result = QuadPrecision(np.int32(10), backend='sleef')
160+
assert isinstance(result, QuadPrecision)
161+
assert "backend='sleef'" in repr(result)
162+
163+
# Test with longdouble backend
164+
result = QuadPrecision(np.float64(3.14), backend='longdouble')
165+
assert isinstance(result, QuadPrecision)
166+
assert "backend='longdouble'" in repr(result)
167+
168+
def test_numpy_scalar_types_coverage(self):
169+
"""Test a comprehensive set of numpy scalar types."""
170+
# Integer types
171+
int_types = [
172+
(np.int8, 10),
173+
(np.int16, 1000),
174+
(np.int32, 100000),
175+
(np.int64, 10000000),
176+
(np.uint8, 200),
177+
(np.uint16, 50000),
178+
(np.uint32, 4000000000),
179+
]
180+
181+
for dtype, value in int_types:
182+
result = QuadPrecision(dtype(value))
183+
assert isinstance(result, QuadPrecision), f"Failed for {dtype.__name__}"
184+
assert float(result) == float(value), f"Value mismatch for {dtype.__name__}"
185+
186+
# Float types
187+
float_types = [
188+
(np.float16, 1.5),
189+
(np.float32, 2.5),
190+
(np.float64, 3.5),
191+
]
192+
193+
for dtype, value in float_types:
194+
result = QuadPrecision(dtype(value))
195+
assert isinstance(result, QuadPrecision), f"Failed for {dtype.__name__}"
196+
# Use appropriate tolerance based on dtype precision
197+
expected = float(dtype(value))
198+
assert abs(float(result) - expected) < 1e-5, f"Value mismatch for {dtype.__name__}"
199+
200+
16201
def test_string_roundtrip():
17202
# Test with various values that require full quad precision
18203
test_values = [

0 commit comments

Comments
 (0)