13
13
from pandas .core .arrays import ExtensionArray
14
14
from pandas .core .dtypes .base import ExtensionDtype
15
15
16
- from nested_pandas .series .utils import is_pa_type_a_list
16
+ from nested_pandas .series .utils import (
17
+ transpose_list_struct_type ,
18
+ transpose_struct_list_type ,
19
+ )
17
20
18
21
__all__ = ["NestedDtype" ]
19
22
20
23
21
24
@register_extension_dtype
22
25
class NestedDtype (ExtensionDtype ):
23
- """Data type to handle packed time series data"""
26
+ """Data type to handle packed time series data
27
+
28
+ Parameters
29
+ ----------
30
+ pyarrow_dtype : pyarrow.StructType or pd.ArrowDtype
31
+ The pyarrow data type to use for the nested type. It must be a struct
32
+ type where all fields are list types, or a list type of structs.
33
+ """
24
34
25
35
# ExtensionDtype overrides #
26
36
@@ -38,7 +48,12 @@ def na_value(self) -> Type[pd.NA]:
38
48
@property
39
49
def name (self ) -> str :
40
50
"""The string representation of the nested type"""
41
- fields = ", " .join ([f"{ field .name } : [{ field .type .value_type !s} ]" for field in self .pyarrow_dtype ])
51
+ # Replace pd.ArrowDtype with pa.DataType, because it has nicer __str__
52
+ nice_dtypes = {
53
+ field : dtype .pyarrow_dtype if isinstance (dtype , pd .ArrowDtype ) else dtype
54
+ for field , dtype in self .fields .items ()
55
+ }
56
+ fields = ", " .join ([f"{ field } : [{ dtype !s} ]" for field , dtype in nice_dtypes .items ()])
42
57
return f"nested<{ fields } >"
43
58
44
59
@classmethod
@@ -136,7 +151,7 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> ExtensionArray:
136
151
pyarrow_dtype : pa .StructType
137
152
138
153
def __init__ (self , pyarrow_dtype : pa .DataType ) -> None :
139
- self .pyarrow_dtype = self ._validate_dtype (pyarrow_dtype )
154
+ self .pyarrow_dtype , self . list_struct_pyarrow_dtype = self ._validate_dtype (pyarrow_dtype )
140
155
141
156
@classmethod
142
157
def from_fields (cls , fields : Mapping [str , pa .DataType ]) -> Self : # type: ignore[name-defined] # noqa: F821
@@ -168,20 +183,33 @@ def from_fields(cls, fields: Mapping[str, pa.DataType]) -> Self: # type: ignore
168
183
return cls (pyarrow_dtype = pyarrow_dtype )
169
184
170
185
@staticmethod
171
- def _validate_dtype (pyarrow_dtype : pa .DataType ) -> pa .StructType :
186
+ def _validate_dtype (pyarrow_dtype : pa .DataType ) -> tuple [pa .StructType , pa .ListType ]:
187
+ """Check that the given pyarrow type is castable to the nested type.
188
+
189
+ Parameters
190
+ ----------
191
+ pyarrow_dtype : pa.DataType
192
+ The pyarrow type to check and cast.
193
+
194
+ Returns
195
+ -------
196
+ pa.StructType
197
+ Struct-list pyarrow type representing the nested type.
198
+ pa.ListType
199
+ List-struct pyarrow type representing the nested type.
200
+ """
172
201
if not isinstance (pyarrow_dtype , pa .DataType ):
173
202
raise TypeError (f"Expected a 'pyarrow.DataType' object, got { type (pyarrow_dtype )} " )
174
- if not pa .types .is_struct (pyarrow_dtype ):
175
- raise ValueError ("NestedDtype can only be constructed with pyarrow struct type." )
176
- pyarrow_dtype = cast (pa .StructType , pyarrow_dtype )
177
-
178
- for field in pyarrow_dtype :
179
- if not is_pa_type_a_list (field .type ):
180
- raise ValueError (
181
- "NestedDtype can only be constructed with pyarrow struct type, all fields must be list "
182
- f"type. Given struct has unsupported field { field } "
183
- )
184
- return pyarrow_dtype
203
+ if pa .types .is_struct (pyarrow_dtype ):
204
+ struct_type = cast (pa .StructType , pyarrow_dtype )
205
+ return struct_type , transpose_struct_list_type (struct_type )
206
+ # Currently, LargeList and other list-like types are not supported
207
+ if pa .types .is_list (pyarrow_dtype ):
208
+ list_type = cast (pa .ListType , pyarrow_dtype )
209
+ return transpose_list_struct_type (list_type ), list_type
210
+ raise ValueError (
211
+ f"NestedDtype can only be constructed with pa.StructType or pa.ListType only, got { pyarrow_dtype } "
212
+ )
185
213
186
214
@property
187
215
def fields (self ) -> dict [str , pa .DataType ]:
@@ -194,13 +222,14 @@ def field_names(self) -> list[str]:
194
222
return [field .name for field in self .pyarrow_dtype ]
195
223
196
224
@classmethod
197
- def from_pandas_arrow_dtype (cls , pandas_arrow_dtype : ArrowDtype ):
225
+ def from_pandas_arrow_dtype (cls , pandas_arrow_dtype : ArrowDtype ) -> Self : # type: ignore[name-defined] # noqa: F821
198
226
"""Construct NestedDtype from a pandas.ArrowDtype.
199
227
200
228
Parameters
201
229
----------
202
230
pandas_arrow_dtype : ArrowDtype
203
231
The pandas.ArrowDtype to construct NestedDtype from.
232
+ Must be struct-list or list-struct type.
204
233
205
234
Returns
206
235
-------
@@ -214,12 +243,20 @@ def from_pandas_arrow_dtype(cls, pandas_arrow_dtype: ArrowDtype):
214
243
"""
215
244
return cls (pyarrow_dtype = pandas_arrow_dtype .pyarrow_dtype )
216
245
217
- def to_pandas_arrow_dtype (self ) -> ArrowDtype :
246
+ def to_pandas_arrow_dtype (self , list_struct : bool = False ) -> ArrowDtype :
218
247
"""Convert NestedDtype to a pandas.ArrowDtype.
219
248
249
+ Parameters
250
+ ----------
251
+ list_struct : bool, default False
252
+ If False (default) use pyarrow struct-list type,
253
+ otherwise use pyarrow list-struct type.
254
+
220
255
Returns
221
256
-------
222
257
ArrowDtype
223
258
The corresponding pandas.ArrowDtype.
224
259
"""
260
+ if list_struct :
261
+ return ArrowDtype (self .list_struct_pyarrow_dtype )
225
262
return ArrowDtype (self .pyarrow_dtype )
0 commit comments