Skip to content

add DataTree.subset #10400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,7 @@ Index into all nodes in the subtree simultaneously.

DataTree.isel
DataTree.sel
DataTree.subset

.. DataTree.drop_sel
.. DataTree.drop_isel
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ v2025.07.0 (unreleased)

New Features
~~~~~~~~~~~~
- Added :py:meth:`~xarray.DataTree.subset` to index variables on all nodes of a datatree (:pull:`10400`)
By `Mathias Hauser <https://github.com/mathause>`_.


Breaking changes
Expand Down
30 changes: 30 additions & 0 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Iterable,
Iterator,
Mapping,
Sequence,
)
from html import escape
from typing import (
Expand Down Expand Up @@ -1014,6 +1015,35 @@ def __delitem__(self, key: str) -> None:
else:
raise KeyError(key)

def subset(
    self, keys: str | Sequence[str], *, errors: ErrorOptions = "raise"
) -> DataTree:
    """Select data variables by name on every node of the tree.

    Parameters
    ----------
    keys : str or sequence of str
        Name(s) of the data variables to select. A single string is
        treated as one variable name.
    errors : {"raise", "ignore"}, default: "raise"
        If "raise", the requested names are looked up unchanged on every
        node, so a name missing on any node results in a ``KeyError``.
        If "ignore", each node only keeps those requested names that are
        present in its data variables.

    Returns
    -------
    out : DataTree
        The result of applying the selection to each node through
        ``map_over_datasets``.
    """
    # Always hand a *list* to ``ds[...]``. Any other sequence type that is
    # hashable (e.g. a tuple) would be interpreted by ``Dataset.__getitem__``
    # as the name of one single variable instead of a collection of names,
    # making ``errors="raise"`` behave differently from ``errors="ignore"``.
    names = [keys] if isinstance(keys, str) else list(keys)

    def getitem(ds):
        if errors == "ignore":
            # Silently drop names this node does not have as data variables.
            return ds[[name for name in names if name in ds.data_vars]]
        return ds[names]

    return map_over_datasets(getitem, self)

@overload
def update(self, other: Dataset) -> None: ...

Expand Down
28 changes: 28 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,34 @@ def test_getitem_dict_like_selection_access_to_dataset(self) -> None:
assert_identical(results[{"temp": 1}], data[{"temp": 1}]) # type: ignore[index]


def test_subset() -> None:
    """Check ``DataTree.subset`` for missing names and for ``errors="ignore"``."""
    both_vars = xr.Dataset(data_vars={"var1": ("x", [1, 2]), "var2": ("x", [0, 1])})
    only_var1 = xr.Dataset(data_vars={"var1": ("x", [1, 2])})
    tree = xr.DataTree.from_dict({"ds1": both_vars, "ds2": only_var1})

    tree_var1 = xr.DataTree.from_dict({"ds1": both_vars[["var1"]], "ds2": only_var1})

    # "var1" errors because map_over_datasets does not skip empty nodes;
    # "var2" would keep erroring even if map_over_datasets ever skipped
    # empty nodes.
    for missing in ("var1", "var2"):
        with pytest.raises(KeyError, match=missing):
            tree.subset(missing)

    # A bare string and a one-element list must select the same thing.
    for keys in ("var1", ["var1"]):
        selected = tree.subset(keys, errors="ignore")
        xr.testing.assert_equal(selected, tree_var1)

    # Asking for every variable returns a tree equal to the input.
    selected = tree.subset(["var1", "var2"], errors="ignore")
    xr.testing.assert_equal(selected, tree)


class TestUpdate:
def test_update(self) -> None:
dt = DataTree()
Expand Down
Loading