Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ uv run dmypy run # Type checking with mypy
GitHub issues/PRs unless specifically instructed
- When creating commits, always include a co-authorship trailer:
`Co-authored-by: Claude <[email protected]>`
- Submit upstream PRs against `main`
1 change: 1 addition & 0 deletions doc/api/dataset.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Dataset contents
Dataset.rename_dims
Dataset.swap_dims
Dataset.expand_dims
Dataset.subset
Dataset.drop_vars
Dataset.drop_indexes
Dataset.drop_duplicates
Expand Down
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ New Features
- :py:func:`combine_nested` now support :py:class:`DataTree` objects
(:pull:`10849`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- Added :py:meth:`Dataset.subset` method for type-stable selection of multiple
variables. Unlike indexing with ``__getitem__``, this method always returns
a Dataset and accepts sequence types (lists and tuples) (:issue:`3894`).

Breaking Changes
~~~~~~~~~~~~~~~~
Expand Down
96 changes: 96 additions & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,8 @@ def as_numpy(self) -> Self:
def _copy_listed(self, names: Iterable[Hashable]) -> Self:
"""Create a new Dataset with the listed variables from this dataset and
the all relevant coordinates. Skips all validation.

For public API, use Dataset.subset() instead.
"""
variables: dict[Hashable, Variable] = {}
coord_names = set()
Expand Down Expand Up @@ -1223,6 +1225,100 @@ def _copy_listed(self, names: Iterable[Hashable]) -> Self:

return self._replace(variables, coord_names, dims, indexes=indexes)

def subset(self, names: Sequence[Hashable]) -> Self:
"""Return a new Dataset with only the specified variables.

This is a type-stable method for selecting multiple variables from a
Dataset. Unlike indexing with ``__getitem__``, this method always
returns a Dataset and accepts sequence types (lists, tuples) of variable
names.

All coordinates needed for the selected variables are automatically
included in the returned Dataset.

Parameters
----------
names : sequence of hashable
A sequence (list or tuple) of variable names to include in the
returned Dataset. The names must exist in the Dataset.

Returns
-------
Dataset
A new Dataset containing only the specified variables and their
required coordinates.

Raises
------
TypeError
If ``names`` is not a sequence (e.g., if it's a set, generator, or string).
KeyError
If any of the variable names do not exist in the Dataset.

See Also
--------
Dataset.__getitem__ : Select variables using indexing notation.
Dataset.drop_vars : Remove variables from a Dataset.

Examples
--------
>>> ds = xr.Dataset(
... {
... "temperature": (["x", "y"], [[1, 2], [3, 4]]),
... "pressure": (["x", "y"], [[5, 6], [7, 8]]),
... "humidity": (["x"], [0.5, 0.6]),
... },
... coords={"x": [10, 20], "y": [30, 40]},
... )
>>> ds
<xarray.Dataset> Size: 112B
Dimensions: (x: 2, y: 2)
Coordinates:
* x (x) int64 16B 10 20
* y (y) int64 16B 30 40
Data variables:
temperature (x, y) int64 32B 1 2 3 4
pressure (x, y) int64 32B 5 6 7 8
humidity (x) float64 16B 0.5 0.6

Select a subset of variables using a list:

>>> ds.subset(["temperature", "humidity"])
<xarray.Dataset> Size: 80B
Dimensions: (x: 2, y: 2)
Coordinates:
* x (x) int64 16B 10 20
* y (y) int64 16B 30 40
Data variables:
temperature (x, y) int64 32B 1 2 3 4
humidity (x) float64 16B 0.5 0.6

Unlike ``__getitem__``, this method accepts tuples:

>>> vars_tuple = ("temperature", "pressure")
>>> ds.subset(vars_tuple) # Works with tuples
<xarray.Dataset> Size: 96B
Dimensions: (x: 2, y: 2)
Coordinates:
* x (x) int64 16B 10 20
* y (y) int64 16B 30 40
Data variables:
temperature (x, y) int64 32B 1 2 3 4
pressure (x, y) int64 32B 5 6 7 8

The method always returns a Dataset, providing type stability:

>>> result = ds.subset(["temperature"])
>>> isinstance(result, xr.Dataset)
True
"""
# Validate that names is a sequence (but not a string)
if isinstance(names, str) or not isinstance(names, Sequence):
raise TypeError(
f"names must be a sequence (list or tuple), got {type(names).__name__}"
)
return self._copy_listed(names)

def _construct_dataarray(self, name: Hashable) -> DataArray:
"""Construct a DataArray by indexing this dataset"""
from xarray.core.dataarray import DataArray
Expand Down
51 changes: 51 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4451,6 +4451,57 @@ def test_getitem_multiple_dtype(self) -> None:
dataset = Dataset({key: ("dim0", range(1)) for key in keys})
assert_identical(dataset, dataset[keys])

def test_subset(self) -> None:
data = create_test_data()

# Test with list of variables
result = data.subset(["var1", "var2"])
expected = Dataset({"var1": data["var1"], "var2": data["var2"]})
assert_identical(result, expected)

# Test with tuple (the original issue from #3894)
vars_tuple = ("var1", "var2")
result_tuple = data.subset(vars_tuple)
assert_identical(result_tuple, expected)

# Test type stability - always returns a Dataset
result_single = data.subset(["var1"])
expected_single = Dataset({"var1": data["var1"]})
assert_identical(result_single, expected_single)

# Test that coordinates are preserved
ds = Dataset(
{
"temperature": (["x", "y"], [[1, 2], [3, 4]]),
"pressure": (["x", "y"], [[5, 6], [7, 8]]),
"humidity": (["x"], [0.5, 0.6]),
},
coords={"x": [10, 20], "y": [30, 40]},
)
result_coords = ds.subset(["temperature", "humidity"])
expected_coords = Dataset(
{
"temperature": (["x", "y"], [[1, 2], [3, 4]]),
"humidity": (["x"], [0.5, 0.6]),
},
coords={"x": [10, 20], "y": [30, 40]},
)
assert_identical(result_coords, expected_coords)

# Test error handling for non-existent variable
with pytest.raises(KeyError):
data.subset(["var1", "notfound"])

# Test that non-sequence types raise TypeError
with pytest.raises(TypeError, match="names must be a sequence"):
data.subset({"var1", "var2"}) # type: ignore[arg-type] # set

with pytest.raises(TypeError, match="names must be a sequence"):
data.subset(v for v in ["var1", "var2"]) # type: ignore[arg-type] # generator

with pytest.raises(TypeError, match="names must be a sequence"):
data.subset("var1") # string (valid Sequence type, but explicitly rejected)

def test_getitem_extra_dim_index_coord(self) -> None:
class AnyIndex(Index):
def should_add_coord_to_array(self, name, var, dims):
Expand Down
Loading