Skip to content

(fix): disallow NumpyExtensionArray #10334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions properties/test_pandas_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import hypothesis.extra.pandas as pdst # isort:skip
import hypothesis.strategies as st # isort:skip
from hypothesis import given # isort:skip
from xarray.tests import has_pyarrow

numeric_dtypes = st.one_of(
npst.unsigned_integer_dtypes(endianness="="),
Expand Down Expand Up @@ -134,10 +135,39 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None:
xr.testing.assert_identical(dataset, roundtripped.to_xarray())


def test_roundtrip_1d_pandas_extension_array() -> None:
df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])})
arr = xr.Dataset.from_dataframe(df)["cat"]
@pytest.mark.parametrize(
"extension_array",
[
pd.Categorical(["a", "b", "c"]),
pd.array(["a", "b", "c"], dtype="string"),
pd.arrays.IntervalArray(
[pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)]
),
pd.arrays.TimedeltaArray._from_sequence(pd.TimedeltaIndex(["1h", "2h", "3h"])),
pd.arrays.DatetimeArray._from_sequence(
pd.DatetimeIndex(["2023-01-01", "2023-01-02", "2023-01-03"], freq="D")
),
np.array([1, 2, 3], dtype="int64"),
]
+ ([pd.array([1, 2, 3], dtype="int64[pyarrow]")] if has_pyarrow else []),
ids=["cat", "string", "interval", "timedelta", "datetime", "numpy"]
+ (["pyarrow"] if has_pyarrow else []),
)
@pytest.mark.parametrize("is_index", [True, False])
def test_roundtrip_1d_pandas_extension_array(extension_array, is_index) -> None:
df = pd.DataFrame({"arr": extension_array})
if is_index:
df = df.set_index("arr")
arr = xr.Dataset.from_dataframe(df)["arr"]
roundtripped = arr.to_pandas()
assert (df["cat"] == roundtripped).all()
assert df["cat"].dtype == roundtripped.dtype
xr.testing.assert_identical(arr, roundtripped.to_xarray())
df_arr_to_test = df.index if is_index else df["arr"]
assert (df_arr_to_test == roundtripped).all()
# `NumpyExtensionArray` types are not roundtripped; this includes `StringArray`, which subclasses `NumpyExtensionArray`.
if isinstance(extension_array, pd.arrays.NumpyExtensionArray): # type: ignore[attr-defined]
assert isinstance(arr.data, np.ndarray)
else:
assert (
df_arr_to_test.dtype
== (roundtripped.index if is_index else roundtripped).dtype
)
xr.testing.assert_identical(arr, roundtripped.to_xarray())
3 changes: 2 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
parse_dims_as_set,
)
from xarray.core.variable import (
UNSUPPORTED_EXTENSION_ARRAY_TYPES,
IndexVariable,
Variable,
as_variable,
Expand Down Expand Up @@ -7281,7 +7282,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
extension_arrays = []
for k, v in dataframe.items():
if not is_extension_array_dtype(v) or isinstance(
v.array, pd.arrays.DatetimeArray | pd.arrays.TimedeltaArray
v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES
):
arrays.append((k, np.asarray(v)))
else:
Expand Down
7 changes: 7 additions & 0 deletions xarray/core/extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin):
def __post_init__(self):
if not isinstance(self.array, pd.api.extensions.ExtensionArray):
raise TypeError(f"{self.array} is not an pandas ExtensionArray.")
# This does not check against the full UNSUPPORTED_EXTENSION_ARRAY_TYPES tuple
# because some of the types listed there (datetime arrays, for example) are
# still supported here: they need duck array support internally via this class.
if isinstance(self.array, pd.arrays.NumpyExtensionArray):
raise TypeError(
"`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally."
)

def __array_function__(self, func, types, args, kwargs):
def replace_duck_with_extension_array(args) -> list:
Expand Down
8 changes: 6 additions & 2 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,8 +1802,12 @@ def __array__(

def get_duck_array(self) -> np.ndarray | PandasExtensionArray:
# We return a PandasExtensionArray wrapper type that satisfies
# duck array protocols. This is what's needed for tests to pass.
if pd.api.types.is_extension_array_dtype(self.array):
# duck array protocols.
# `NumpyExtensionArray` is excluded
if pd.api.types.is_extension_array_dtype(self.array) and not isinstance(
self.array.array,
pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined]
):
from xarray.core.extension_array import PandasExtensionArray

return PandasExtensionArray(self.array.array)
Expand Down
16 changes: 15 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@
)
# https://github.com/python/mypy/issues/224
BASIC_INDEXING_TYPES = integer_types + (slice,)
# Pandas extension array types that are passed on as plain NumPy data rather
# than being wrapped in `PandasExtensionArray`: `_maybe_wrap_data` calls
# `.to_numpy()` on instances of these types.
UNSUPPORTED_EXTENSION_ARRAY_TYPES = (
pd.arrays.DatetimeArray,
pd.arrays.TimedeltaArray,
pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined]
)

if TYPE_CHECKING:
from xarray.core.types import (
Expand Down Expand Up @@ -190,6 +195,8 @@ def _maybe_wrap_data(data):
"""
if isinstance(data, pd.Index):
return PandasIndexingAdapter(data)
if isinstance(data, UNSUPPORTED_EXTENSION_ARRAY_TYPES):
return data.to_numpy()
if isinstance(data, pd.api.extensions.ExtensionArray):
return PandasExtensionArray(data)
return data
Expand Down Expand Up @@ -251,7 +258,14 @@ def convert_non_numpy_type(data):

# we don't want nested self-described arrays
if isinstance(data, pd.Series | pd.DataFrame):
pandas_data = data.values
if (
isinstance(data, pd.Series)
and pd.api.types.is_extension_array_dtype(data)
and not isinstance(data.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES)
):
pandas_data = data.array
else:
pandas_data = data.values # type: ignore[assignment]
if isinstance(pandas_data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
return convert_non_numpy_type(pandas_data)
else:
Expand Down
Loading