Skip to content

Commit a96120e

Browse files
authored
Preparation for list-struct backend for NestedExtensionArray (#242)
* Allow NestedDtype initialization from pd.ArrowDtype
* NestedDtype.inner_dtypes
* .nest.to_flat to respect inner_dtypes
* Allow ArrowDtype in NestedDtype.from_fields
* Handle and derive inner_dtypes
* Fix a typo in variable name
* Pull changes from inner-dtypes branch
* NestedDtype.to_pandas_arrow_dtype(list_struct: bool)
* Building NestedDtype from list-struct
* NestedExtensionArray._pa_table and ._pa_struct_array
* test NestedExtensionArray.to_pyarrow_scalar
* Add to_pyarrow_scalar to API docs
1 parent 626718e commit a96120e

File tree

6 files changed

+204
-21
lines changed

6 files changed

+204
-21
lines changed

docs/reference/ext_array.rst

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Functions
2323
series.ext_array.NestedExtensionArray.dropna
2424
series.ext_array.NestedExtensionArray.from_sequence
2525
series.ext_array.NestedExtensionArray.to_arrow_ext_array
26+
series.ext_array.NestedExtensionArray.to_pyarrow_scalar
2627
series.ext_array.NestedExtensionArray.iter_field_lists
2728
series.ext_array.NestedExtensionArray.view_fields
2829
series.ext_array.NestedExtensionArray.set_flat_field

src/nested_pandas/series/dtype.py

+55-18
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,24 @@
1313
from pandas.core.arrays import ExtensionArray
1414
from pandas.core.dtypes.base import ExtensionDtype
1515

16-
from nested_pandas.series.utils import is_pa_type_a_list
16+
from nested_pandas.series.utils import (
17+
transpose_list_struct_type,
18+
transpose_struct_list_type,
19+
)
1720

1821
__all__ = ["NestedDtype"]
1922

2023

2124
@register_extension_dtype
2225
class NestedDtype(ExtensionDtype):
23-
"""Data type to handle packed time series data"""
26+
"""Data type to handle packed time series data
27+
28+
Parameters
29+
----------
30+
pyarrow_dtype : pyarrow.StructType or pyarrow.ListType
31+
The pyarrow data type to use for the nested type. It must be either a
32+
struct type where all fields are list types, or a list type whose value type is a struct.
33+
"""
2434

2535
# ExtensionDtype overrides #
2636

@@ -38,7 +48,12 @@ def na_value(self) -> Type[pd.NA]:
3848
@property
3949
def name(self) -> str:
4050
"""The string representation of the nested type"""
41-
fields = ", ".join([f"{field.name}: [{field.type.value_type!s}]" for field in self.pyarrow_dtype])
51+
# Replace pd.ArrowDtype with pa.DataType, because it has nicer __str__
52+
nice_dtypes = {
53+
field: dtype.pyarrow_dtype if isinstance(dtype, pd.ArrowDtype) else dtype
54+
for field, dtype in self.fields.items()
55+
}
56+
fields = ", ".join([f"{field}: [{dtype!s}]" for field, dtype in nice_dtypes.items()])
4257
return f"nested<{fields}>"
4358

4459
@classmethod
@@ -136,7 +151,7 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> ExtensionArray:
136151
pyarrow_dtype: pa.StructType
137152

138153
def __init__(self, pyarrow_dtype: pa.DataType) -> None:
139-
self.pyarrow_dtype = self._validate_dtype(pyarrow_dtype)
154+
self.pyarrow_dtype, self.list_struct_pyarrow_dtype = self._validate_dtype(pyarrow_dtype)
140155

141156
@classmethod
142157
def from_fields(cls, fields: Mapping[str, pa.DataType]) -> Self: # type: ignore[name-defined] # noqa: F821
@@ -168,20 +183,33 @@ def from_fields(cls, fields: Mapping[str, pa.DataType]) -> Self: # type: ignore
168183
return cls(pyarrow_dtype=pyarrow_dtype)
169184

170185
@staticmethod
171-
def _validate_dtype(pyarrow_dtype: pa.DataType) -> pa.StructType:
186+
def _validate_dtype(pyarrow_dtype: pa.DataType) -> tuple[pa.StructType, pa.ListType]:
187+
"""Check that the given pyarrow type is castable to the nested type.
188+
189+
Parameters
190+
----------
191+
pyarrow_dtype : pa.DataType
192+
The pyarrow type to check and cast.
193+
194+
Returns
195+
-------
196+
pa.StructType
197+
Struct-list pyarrow type representing the nested type.
198+
pa.ListType
199+
List-struct pyarrow type representing the nested type.
200+
"""
172201
if not isinstance(pyarrow_dtype, pa.DataType):
173202
raise TypeError(f"Expected a 'pyarrow.DataType' object, got {type(pyarrow_dtype)}")
174-
if not pa.types.is_struct(pyarrow_dtype):
175-
raise ValueError("NestedDtype can only be constructed with pyarrow struct type.")
176-
pyarrow_dtype = cast(pa.StructType, pyarrow_dtype)
177-
178-
for field in pyarrow_dtype:
179-
if not is_pa_type_a_list(field.type):
180-
raise ValueError(
181-
"NestedDtype can only be constructed with pyarrow struct type, all fields must be list "
182-
f"type. Given struct has unsupported field {field}"
183-
)
184-
return pyarrow_dtype
203+
if pa.types.is_struct(pyarrow_dtype):
204+
struct_type = cast(pa.StructType, pyarrow_dtype)
205+
return struct_type, transpose_struct_list_type(struct_type)
206+
# Currently, LongList and others are not supported
207+
if pa.types.is_list(pyarrow_dtype):
208+
list_type = cast(pa.ListType, pyarrow_dtype)
209+
return transpose_list_struct_type(list_type), list_type
210+
raise ValueError(
211+
f"NestedDtype can only be constructed with pa.StructType or pa.ListType only, got {pyarrow_dtype}"
212+
)
185213

186214
@property
187215
def fields(self) -> dict[str, pa.DataType]:
@@ -194,13 +222,14 @@ def field_names(self) -> list[str]:
194222
return [field.name for field in self.pyarrow_dtype]
195223

196224
@classmethod
197-
def from_pandas_arrow_dtype(cls, pandas_arrow_dtype: ArrowDtype):
225+
def from_pandas_arrow_dtype(cls, pandas_arrow_dtype: ArrowDtype) -> Self: # type: ignore[name-defined] # noqa: F821
198226
"""Construct NestedDtype from a pandas.ArrowDtype.
199227
200228
Parameters
201229
----------
202230
pandas_arrow_dtype : ArrowDtype
203231
The pandas.ArrowDtype to construct NestedDtype from.
232+
Must be a struct-list or list-struct type.
204233
205234
Returns
206235
-------
@@ -214,12 +243,20 @@ def from_pandas_arrow_dtype(cls, pandas_arrow_dtype: ArrowDtype):
214243
"""
215244
return cls(pyarrow_dtype=pandas_arrow_dtype.pyarrow_dtype)
216245

217-
def to_pandas_arrow_dtype(self) -> ArrowDtype:
246+
def to_pandas_arrow_dtype(self, list_struct: bool = False) -> ArrowDtype:
218247
"""Convert NestedDtype to a pandas.ArrowDtype.
219248
249+
Parameters
250+
----------
251+
list_struct : bool, default False
252+
If False (default) use pyarrow struct-list type,
253+
otherwise use pyarrow list-struct type.
254+
220255
Returns
221256
-------
222257
ArrowDtype
223258
The corresponding pandas.ArrowDtype.
224259
"""
260+
if list_struct:
261+
return ArrowDtype(self.list_struct_pyarrow_dtype)
225262
return ArrowDtype(self.pyarrow_dtype)

src/nested_pandas/series/ext_array.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False) -> Self: #
225225
pa_array = cls._box_pa_array(scalars, pa_type=pa_type)
226226
return cls(pa_array)
227227

228-
# Tricky to implement, but required by things like pd.read_csv
228+
# Tricky to implement but required by things like pd.read_csv
229229
@classmethod
230230
def _from_sequence_of_strings(cls, strings, *, dtype=None, copy: bool = False) -> Self: # type: ignore[name-defined] # noqa: F821
231231
return super()._from_sequence_of_strings(strings, dtype=dtype, copy=copy)
@@ -680,6 +680,29 @@ def _list_array(self) -> pa.ChunkedArray:
680680
list_chunks.append(transpose_struct_list_array(struct_chunk, validate=False))
681681
return pa.chunked_array(list_chunks)
682682

683+
@property
684+
def _struct_array(self) -> pa.ChunkedArray:
685+
"""Pyarrow chunked struct-list array representation
686+
687+
Returns
688+
-------
689+
pa.ChunkedArray
690+
Pyarrow chunked-array of struct-list arrays.
691+
"""
692+
return self._chunked_array
693+
694+
@property
695+
def _pa_table(self) -> pa.Table:
696+
"""Pyarrow table representation of the extension array.
697+
698+
Returns
699+
-------
700+
pa.Table
701+
Pyarrow table where each column is a list array corresponding
702+
to a field of the struct array.
703+
"""
704+
return pa.Table.from_struct_array(self._struct_array)
705+
683706
@classmethod
684707
def from_sequence(cls, scalars, *, dtype: NestedDtype | pd.ArrowDtype | pa.DataType = None) -> Self: # type: ignore[name-defined] # noqa: F821
685708
"""Construct a NestedExtensionArray from a sequence of items
@@ -755,6 +778,23 @@ def to_arrow_ext_array(self, list_struct: bool = False) -> ArrowExtensionArray:
755778
return ArrowExtensionArray(self._list_array)
756779
return ArrowExtensionArray(self._chunked_array)
757780

781+
def to_pyarrow_scalar(self, list_struct: bool = False) -> pa.ListScalar:
782+
"""Convert to a pyarrow scalar of a list type
783+
784+
Parameters
785+
----------
786+
list_struct : bool, optional
787+
If False (default), return list-struct-list scalar,
788+
otherwise list-list-struct scalar.
789+
790+
Returns
791+
-------
792+
pyarrow.ListScalar
793+
"""
794+
pa_array = self._list_array if list_struct else self._chunked_array
795+
pa_type = pa.list_(pa_array.type)
796+
return cast(pa.ListScalar, pa.scalar(pa_array, type=pa_type))
797+
758798
def _replace_chunked_array(self, pa_array: pa.ChunkedArray, *, validate: bool) -> None:
759799
if validate:
760800
self._validate(pa_array)

src/nested_pandas/series/utils.py

+3
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ def transpose_list_struct_type(t: pa.ListType) -> pa.StructType:
157157
if not is_pa_type_a_list(t):
158158
raise ValueError(f"Expected a ListType, got {t}")
159159

160+
if not pa.types.is_struct(t.value_type):
161+
raise ValueError(f"Expected a StructType as a list value type, got {t.value_type}")
162+
160163
struct_type = cast(pa.StructType, t.value_type)
161164
fields = []
162165
for field in struct_type:

tests/nested_pandas/series/test_dtype.py

+42-2
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,38 @@
1818
),
1919
],
2020
)
21-
def test_from_pyarrow_dtype(pyarrow_dtype):
21+
def test_from_pyarrow_dtype_struct_list(pyarrow_dtype):
2222
"""Test that we can construct NestedDtype from pyarrow struct type."""
2323
dtype = NestedDtype(pyarrow_dtype)
2424
assert dtype.pyarrow_dtype == pyarrow_dtype
2525

2626

27+
@pytest.mark.parametrize(
28+
"pyarrow_dtype",
29+
[
30+
pa.list_(pa.struct([pa.field("a", pa.int64())])),
31+
pa.list_(pa.struct([pa.field("a", pa.int64()), pa.field("b", pa.float64())])),
32+
pa.list_(
33+
pa.struct(
34+
[
35+
pa.field("a", pa.list_(pa.int64())),
36+
pa.field("b", pa.list_(pa.float64())),
37+
]
38+
)
39+
),
40+
],
41+
)
42+
def test_from_pyarrow_dtype_list_struct(pyarrow_dtype):
43+
"""Test that we can construct NestedDtype from pyarrow list type."""
44+
dtype = NestedDtype(pyarrow_dtype)
45+
assert dtype.list_struct_pyarrow_dtype == pyarrow_dtype
46+
47+
2748
@pytest.mark.parametrize(
2849
"pyarrow_dtype",
2950
[
3051
pa.int64(),
3152
pa.list_(pa.int64()),
32-
pa.list_(pa.struct([pa.field("a", pa.int64())])),
3353
pa.struct([pa.field("a", pa.int64())]),
3454
pa.struct([pa.field("a", pa.int64()), pa.field("b", pa.float64())]),
3555
pa.struct([pa.field("a", pa.list_(pa.int64())), pa.field("b", pa.float64())]),
@@ -49,6 +69,26 @@ def test_to_pandas_arrow_dtype():
4969
)
5070

5171

72+
def test_from_pandas_arrow_dtype():
73+
"""Test that we can construct NestedDtype from pandas.ArrowDtype."""
74+
dtype_from_struct = NestedDtype.from_pandas_arrow_dtype(
75+
pd.ArrowDtype(pa.struct([pa.field("a", pa.list_(pa.int64()))]))
76+
)
77+
assert dtype_from_struct.pyarrow_dtype == pa.struct([pa.field("a", pa.list_(pa.int64()))])
78+
dtype_from_list = NestedDtype.from_pandas_arrow_dtype(
79+
pd.ArrowDtype(pa.list_(pa.struct([pa.field("a", pa.int64())])))
80+
)
81+
assert dtype_from_list.pyarrow_dtype == pa.struct([pa.field("a", pa.list_(pa.int64()))])
82+
83+
84+
def test_to_pandas_list_struct_arrow_dtype():
85+
"""Test that NestedDtype.to_pandas_arrow_dtype(list_struct=True) returns the correct pyarrow type."""
86+
dtype = NestedDtype.from_fields({"a": pa.list_(pa.int64()), "b": pa.float64()})
87+
assert dtype.to_pandas_arrow_dtype(list_struct=True) == pd.ArrowDtype(
88+
pa.list_(pa.struct([pa.field("a", pa.list_(pa.int64())), pa.field("b", pa.float64())]))
89+
)
90+
91+
5292
def test_from_fields():
5393
"""Test NestedDtype.from_fields()."""
5494
fields = {"a": pa.int64(), "b": pa.float64()}

tests/nested_pandas/series/test_ext_array.py

+62
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,38 @@ def test_chunked_list_struct_array():
626626
assert ext_array.chunked_list_struct_array.type == ext_array._pyarrow_list_struct_dtype
627627

628628

629+
def test_to_pyarrow_scalar():
630+
"""Test .to_pyarrow_scalar is correct."""
631+
struct_array = pa.StructArray.from_arrays(
632+
arrays=[
633+
pa.array([np.array([1, 2, 3]), np.array([1, 2, 1])]),
634+
pa.array([-np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0])]),
635+
],
636+
names=["a", "b"],
637+
)
638+
ext_array = NestedExtensionArray(struct_array)
639+
640+
desired_struct_list = pa.scalar(
641+
[
642+
{"a": [1, 2, 3], "b": [-4.0, -5.0, -6.0]},
643+
{"a": [1, 2, 1], "b": [-3.0, -4.0, -5.0]},
644+
],
645+
type=pa.list_(
646+
pa.struct([pa.field("a", pa.list_(pa.int64())), pa.field("b", pa.list_(pa.float64()))])
647+
),
648+
)
649+
desired_list_struct = pa.scalar(
650+
[
651+
[{"a": 1, "b": -4.0}, {"a": 2, "b": -5.0}, {"a": 3, "b": -6.0}],
652+
[{"a": 1, "b": -3.0}, {"a": 2, "b": -4.0}, {"a": 1, "b": -5.0}],
653+
],
654+
type=pa.list_(pa.list_(pa.struct([pa.field("a", pa.int64()), pa.field("b", pa.float64())]))),
655+
)
656+
# pyarrow returns a single bool for ==
657+
assert ext_array.to_pyarrow_scalar(list_struct=False) == desired_struct_list
658+
assert ext_array.to_pyarrow_scalar(list_struct=True) == desired_list_struct
659+
660+
629661
def test_list_offsets_single_chunk():
630662
"""Test that the .list_offset property is correct for a single chunk."""
631663
struct_array = pa.StructArray.from_arrays(
@@ -1873,6 +1905,36 @@ def test___init___with_list_struct_array():
18731905
assert pa.array(ext_array) == struct_array
18741906

18751907

1908+
def test__struct_array():
1909+
"""Test ._struct_array property"""
1910+
struct_array = pa.StructArray.from_arrays(
1911+
arrays=[
1912+
pa.array([np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 1.0, 2.0])]),
1913+
pa.array([-np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0, 6.0])]),
1914+
],
1915+
names=["a", "b"],
1916+
)
1917+
ext_array = NestedExtensionArray(struct_array)
1918+
1919+
assert ext_array._struct_array.combine_chunks() == struct_array
1920+
1921+
1922+
def test__pa_table():
1923+
"""Test ._pa_table property"""
1924+
struct_array = pa.StructArray.from_arrays(
1925+
arrays=[
1926+
pa.array([np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 1.0, 2.0])]),
1927+
pa.array([-np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0, 6.0])]),
1928+
],
1929+
names=["a", "b"],
1930+
)
1931+
ext_array = NestedExtensionArray(struct_array)
1932+
1933+
assert ext_array._pa_table == pa.Table.from_arrays(
1934+
arrays=[struct_array.field("a"), struct_array.field("b")], names=["a", "b"]
1935+
)
1936+
1937+
18761938
def test__from_sequence_of_strings():
18771939
"""We do not support from_sequence_of_strings() which would apply things like pd.read_csv()"""
18781940
with pytest.raises(NotImplementedError):

0 commit comments

Comments
 (0)