Skip to content

Add inner_dtypes to NestedDtype for sub-column dtype casting #230

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions docs/tutorials/low_level.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -340,16 +340,11 @@
{
"cell_type": "code",
"execution_count": null,
"id": "422e719861ae40f6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-05T20:34:52.352751Z",
"start_time": "2025-03-05T20:34:52.350143Z"
}
},
"id": "da7788cc04b78a2a",
"metadata": {},
"outputs": [],
"source": [
"nested_series.equals(pd.Series(struct_series, dtype=NestedDtype.from_pandas_arrow_dtype(struct_series.dtype)))"
"nested_series.equals(pd.Series(struct_series, dtype=NestedDtype(struct_series.dtype)))"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions src/nested_pandas/series/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def to_flat(self, fields: list[str] | None = None) -> pd.DataFrame:
index=pd.Series(index, name=self._series.index.name),
name=field,
copy=False,
dtype=pd.ArrowDtype(chunked_array.type),
dtype=self._series.dtype.inner_dtype(field),
)

return pd.DataFrame(flat_series)
Expand Down Expand Up @@ -292,7 +292,7 @@ def get_flat_series(self, field: str) -> pd.Series:

return pd.Series(
flat_chunked_array,
dtype=pd.ArrowDtype(flat_chunked_array.type),
dtype=self._series.dtype.inner_dtype(field),
index=self.get_flat_index(),
name=field,
copy=False,
Expand Down
127 changes: 94 additions & 33 deletions src/nested_pandas/series/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,34 @@

@register_extension_dtype
class NestedDtype(ExtensionDtype):
"""Data type to handle packed time series data"""
"""Data type to handle packed time series data

Parameters
----------
pyarrow_dtype : pyarrow.StructType or pd.ArrowDtype
The pyarrow data type to use for the nested type. It must be a struct
type where all fields are list types.
inner_dtypes : Mapping[str, object] or None, default None
A mapping of field names and their inner types. This will be used to:
1. Cast to the correct types when getting flat representations
of the nested fields.
2. Carry information about double-nested fields; in that case, use a
NestedDtype as the inner type.
Dtypes must be pandas-recognisable types, such as Python native types,
numpy dtypes or extension array dtypes. Please wrap pyarrow types with
pd.ArrowDtype.
We trust these dtypes and make no attempt to validate them when
casting.
If None, all inner types are assumed to be the same as the
corresponding list element types.
"""

# ExtensionDtype overrides #

_metadata = ("pyarrow_dtype",)
_metadata = (
"pyarrow_dtype",
"inner_dtypes",
)
"""Attributes to use as metadata for __eq__ and __hash__"""

@property
Expand All @@ -38,7 +61,12 @@ def na_value(self) -> Type[pd.NA]:
@property
def name(self) -> str:
"""The string representation of the nested type"""
fields = ", ".join([f"{field.name}: [{field.type.value_type!s}]" for field in self.pyarrow_dtype])
# Replace pd.ArrowDtype with pa.DataType, because it has nicer __str__
nice_dtypes = {
field: dtype.pyarrow_dtype if isinstance(dtype, pd.ArrowDtype) else dtype
for field, dtype in self.fields.items()
}
fields = ", ".join([f"{field}: [{dtype!s}]" for field, dtype in nice_dtypes.items()])
return f"nested<{fields}>"

@classmethod
Expand Down Expand Up @@ -134,19 +162,26 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> ExtensionArray:
# Additional methods and attributes #

pyarrow_dtype: pa.StructType
inner_dtypes: dict[str, object]

def __init__(self, pyarrow_dtype: pa.DataType) -> None:
def __init__(
self, pyarrow_dtype: pa.DataType | pd.ArrowDtype, inner_dtypes: Mapping[str, object] | None = None
) -> None:
if isinstance(pyarrow_dtype, pd.ArrowDtype):
pyarrow_dtype = pyarrow_dtype.pyarrow_dtype
self.pyarrow_dtype = self._validate_dtype(pyarrow_dtype)
self.inner_dtypes = self._validate_inner_dtypes(self.pyarrow_dtype, inner_dtypes)

@classmethod
def from_fields(cls, fields: Mapping[str, pa.DataType]) -> Self: # type: ignore[name-defined] # noqa: F821
def from_fields(cls, fields: Mapping[str, pa.DataType | pa.ArrowDtype | Self]) -> Self: # type: ignore[name-defined] # noqa: F821
"""Make NestedDtype from a mapping of field names and list item types.

Parameters
----------
fields : Mapping[str, pa.DataType]
A mapping of field names and their item types. Since all fields are lists, the item types are
inner types of the lists, not the list types themselves.
fields : Mapping[str, pa.DataType | NestedDtype]
A mapping of field names and their item types. Since all fields are
lists, the item types are inner types of the lists, not the list
types themselves.

Returns
-------
Expand All @@ -163,9 +198,17 @@ def from_fields(cls, fields: Mapping[str, pa.DataType]) -> Self: # type: ignore
... == pa.struct({"a": pa.list_(pa.float64()), "b": pa.list_(pa.int64())})
... )
"""
pyarrow_dtype = pa.struct({field: pa.list_(pa_type) for field, pa_type in fields.items()})
pyarrow_dtype = cast(pa.StructType, pyarrow_dtype)
return cls(pyarrow_dtype=pyarrow_dtype)
pa_fields = {}
inner_dtypes = {}
for field, dtype in fields.items():
if isinstance(dtype, NestedDtype):
inner_dtypes[field] = dtype
dtype = dtype.pyarrow_dtype
elif isinstance(dtype, pd.ArrowDtype):
dtype = dtype.pyarrow_dtype
pa_fields[field] = dtype
pyarrow_dtype = pa.struct({field: pa.list_(pa_type) for field, pa_type in pa_fields.items()})
return cls(pyarrow_dtype=pyarrow_dtype, inner_dtypes=inner_dtypes or None)

@staticmethod
def _validate_dtype(pyarrow_dtype: pa.DataType) -> pa.StructType:
Expand All @@ -183,36 +226,54 @@ def _validate_dtype(pyarrow_dtype: pa.DataType) -> pa.StructType:
)
return pyarrow_dtype

@property
def fields(self) -> dict[str, pa.DataType]:
"""The mapping of field names and their item types."""
return {field.name: field.type.value_type for field in self.pyarrow_dtype}

@property
def field_names(self) -> list[str]:
"""The list of field names of the nested type"""
return [field.name for field in self.pyarrow_dtype]
@staticmethod
def _validate_inner_dtypes(
    pyarrow_dtype: pa.StructType, inner_dtypes: Mapping[str, object] | None
) -> dict[str, object]:
    """Check the requested inner dtypes against the struct type.

    Parameters
    ----------
    pyarrow_dtype : pa.StructType
        The struct type whose fields the inner dtypes refer to.
    inner_dtypes : Mapping[str, object] or None
        Mapping of field names to the dtypes requested for their items.
        None and an empty mapping both mean "no overrides".

    Returns
    -------
    dict[str, object]
        A plain-dict copy of ``inner_dtypes``, or an empty dict when
        nothing was given.

    Raises
    ------
    ValueError
        If a key is not a field name of ``pyarrow_dtype``.
    TypeError
        If an empty series of the field's element type cannot be cast to
        the requested inner dtype.
    """
    # Nothing requested, nothing to validate.
    if inner_dtypes is None:
        return {}
    if len(inner_dtypes) == 0:
        return {}

    validated = dict(inner_dtypes)
    for field_name, inner_dtype in validated.items():
        if field_name not in pyarrow_dtype.names:
            raise ValueError(f"Field '{field_name}' not found in the pyarrow struct type.")

        element_type = pyarrow_dtype[field_name].type.value_type
        # Try the cast on an empty series: only the dtype conversion is
        # exercised, no values are involved.
        probe = pd.Series([], dtype=pd.ArrowDtype(element_type))
        try:
            probe.astype(inner_dtype)
        except TypeError as e:
            raise TypeError(
                f"Could not cast the inner dtype '{inner_dtype}' for field '{field_name}' to the"
                f" corresponding element type '{element_type}'. {e}"
            ) from e
    return validated

@classmethod
def from_pandas_arrow_dtype(cls, pandas_arrow_dtype: ArrowDtype):
"""Construct NestedDtype from a pandas.ArrowDtype.
def inner_dtype(self, field: str) -> object:
"""Get the inner dtype for a field.

Parameters
----------
pandas_arrow_dtype : ArrowDtype
The pandas.ArrowDtype to construct NestedDtype from.
field : str
The field name.

Returns
-------
NestedDtype
The constructed NestedDtype.

Raises
------
ValueError
If the given dtype is not a valid nested type.
object
The inner dtype for the field.
"""
return cls(pyarrow_dtype=pandas_arrow_dtype.pyarrow_dtype)
return self.inner_dtypes.get(field, pd.ArrowDtype(self.pyarrow_dtype[field].type.value_type))

@property
def fields(self) -> dict[str, object]:
    """Map each field name to the pandas dtype of that field's items."""
    item_dtypes: dict[str, object] = {}
    for struct_field in self.pyarrow_dtype:
        field_name = struct_field.name
        # inner_dtype() supplies the dtype reported for this field.
        item_dtypes[field_name] = self.inner_dtype(field_name)
    return item_dtypes

@property
def field_names(self) -> list[str]:
    """Return the names of all fields of the nested type as a list."""
    names: list[str] = []
    for struct_field in self.pyarrow_dtype:
        names.append(struct_field.name)
    return names

def to_pandas_arrow_dtype(self) -> ArrowDtype:
"""Convert NestedDtype to a pandas.ArrowDtype.
Expand Down
Loading