Skip to content

Commit

Permalink
add xarray Dataset Reader
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentsarago committed Jan 21, 2025
1 parent acab661 commit dd255d2
Showing 1 changed file with 238 additions and 3 deletions.
241 changes: 238 additions & 3 deletions rio_tiler/io/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from __future__ import annotations

import contextlib
import warnings
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import attr
import numpy
Expand Down Expand Up @@ -43,8 +44,8 @@


@attr.s
class XarrayReader(BaseReader):
"""Xarray Reader.
class DataArrayReader(BaseReader):
"""Xarray DataArray Reader.
Attributes:
dataset (xarray.DataArray): Xarray DataArray dataset.
Expand Down Expand Up @@ -567,3 +568,237 @@ def feature(
img.array.mask = numpy.where(~cutline_mask, img.array.mask, True)

return img


@attr.s
class DatasetReader(BaseReader):
    """Xarray Dataset Reader.

    Wraps an `xarray.Dataset` and delegates each read operation to a
    `DataArrayReader` built from one of the dataset variables.

    Attributes:
        input (str): dataset path.
        dataset (xarray.Dataset): Xarray dataset. Opened from `input` with `opener` when not provided.
        tms (morecantile.TileMatrixSet, optional): TileMatrixSet grid definition. Defaults to `WebMercatorQuad`.
        opener (Callable): Xarray dataset opener. Defaults to `xarray.open_dataset`.
        opener_options (dict): Options to forward to the opener callable.

    Examples:
        >>> with DatasetReader(
                "s3://mur-sst/zarr-v1",
                opener_options={"engine": "zarr"}
            ) as src:
                print(src)
                print(src.variables)
                img = src.tile(x, y, z, variable=src.variables[0])

    """

    input: str = attr.ib()
    dataset: xarray.Dataset = attr.ib(default=None)

    tms: TileMatrixSet = attr.ib(default=WEB_MERCATOR_TMS)

    # The default opener is resolved lazily (at instantiation) so that defining
    # the class does not dereference `xarray` when the optional dependency is
    # missing; `__attrs_post_init__` then reports the problem explicitly.
    opener: Callable[..., xarray.Dataset] = attr.ib(
        factory=lambda: xarray.open_dataset
    )
    opener_options: Dict = attr.ib(factory=dict)

    _ctx_stack: contextlib.ExitStack = attr.ib(init=False, factory=contextlib.ExitStack)

    def __attrs_post_init__(self):
        """Open the dataset when none was provided and reset bounds and CRS."""
        assert xarray is not None, "xarray must be installed to use XarrayReader"
        assert rioxarray is not None, "rioxarray must be installed to use XarrayReader"

        # Compare against `None` explicitly: the truthiness of an
        # `xarray.Dataset` reflects whether it holds data variables, so a
        # user-provided dataset without any would otherwise be re-opened.
        if self.dataset is None:
            # The opened dataset is registered on the exit stack so `close()`
            # releases it; a user-provided dataset is left for the caller to manage.
            self.dataset = self._ctx_stack.enter_context(
                self.opener(self.input, **self.opener_options)
            )

        # Bounds and CRS are reported per variable by `spatial_info`.
        self.bounds = None
        self.crs = None

    def close(self):
        """Close xarray dataset."""
        self._ctx_stack.close()

    def __exit__(self, exc_type, exc_value, traceback):
        """Support using with Context Managers."""
        self.close()

    @property
    def variables(self) -> List[str]:
        """Return dataset variable names"""
        return list(self.dataset.data_vars)

    def _arrange_dims(self, da: xarray.DataArray) -> xarray.DataArray:
        """Arrange coordinates and time dimensions.

        An rioxarray.exceptions.InvalidDimensionOrder error is raised if the coordinates are not in the correct order time, y, and x.
        See: https://github.com/corteva/rioxarray/discussions/674

        We conform to using x and y as the spatial dimension names.

        Args:
            da (xarray.DataArray): DataArray to re-arrange.

        Returns:
            xarray.DataArray: DataArray whose last two dimensions are `y` and `x`.

        Raises:
            ValueError: if no known latitude/longitude dimension name is found.

        """
        if "x" not in da.dims and "y" not in da.dims:
            try:
                latitude_var_name = next(
                    name
                    for name in ["lat", "latitude", "LAT", "LATITUDE", "Lat"]
                    if name in da.dims
                )
                longitude_var_name = next(
                    name
                    for name in ["lon", "longitude", "LON", "LONGITUDE", "Lon"]
                    if name in da.dims
                )
            except StopIteration as e:
                raise ValueError(f"Couldn't find X/Y dimensions in {da.dims}") from e

            da = da.rename({latitude_var_name: "y", longitude_var_name: "x"})

        if "TIME" in da.dims:
            da = da.rename({"TIME": "time"})

        # Keep any non-spatial dimension first, followed by y then x. With no
        # extra dimension this is simply `transpose("y", "x")`.
        extra_dims = [d for d in da.dims if d not in ("x", "y")]
        da = da.transpose(*extra_dims, "y", "x")

        # If min/max values are stored in `valid_range` we add them in `valid_min/valid_max`
        vmin, vmax = da.attrs.get("valid_min"), da.attrs.get("valid_max")
        if "valid_range" in da.attrs and (vmin is None or vmax is None):
            valid_range = da.attrs["valid_range"]
            da.attrs.update({"valid_min": valid_range[0], "valid_max": valid_range[1]})

        return da

    def get_variable(
        self, variable: str, drop_dim: Optional[str] = None
    ) -> xarray.DataArray:
        """Get DataArray from xarray Dataset.

        Args:
            variable (str): name of the dataset variable to read.
            drop_dim (str, optional): `{dim}={value}` selector. The value is selected on `dim` and the dimension is then dropped.

        Returns:
            xarray.DataArray: 2D or 3D DataArray with `y`/`x` spatial dimensions and a CRS.

        Raises:
            ValueError: if `drop_dim` is not in the `{dim}={value}` form.
            AssertionError: if the resulting DataArray is not 2D or 3D.

        """
        da = self.dataset[variable]

        if drop_dim:
            # Split on the first `=` only, so the value itself may contain `=`.
            dim_to_drop, sep, dim_val = drop_dim.partition("=")
            if not sep:
                raise ValueError(
                    f"Invalid `drop_dim` value {drop_dim!r}, expected '{{dim}}={{value}}'"
                )

            da = da.sel({dim_to_drop: dim_val}).drop_vars(dim_to_drop)

        da = self._arrange_dims(da)

        # Make sure we have a valid CRS
        crs = da.rio.crs or "epsg:4326"
        da = da.rio.write_crs(crs)

        if crs == "epsg:4326" and (da.x > 180).any():
            # Adjust the longitude coordinates to the -180 to 180 range
            da = da.assign_coords(x=(da.x + 180) % 360 - 180)

            # Sort the dataset by the updated longitude coordinates
            da = da.sortby(da.x)

        assert len(da.dims) in [
            2,
            3,
        ], "rio_tiler.io.xarray.DatasetReader can only work with 2D or 3D DataArray"

        return da

    def spatial_info(
        self, variable: str, drop_dim: Optional[str] = None
    ) -> Dict[str, Any]:
        """Return spatial info (crs, bounds, minzoom, maxzoom) for a variable.

        Args:
            variable (str): name of the dataset variable to read.
            drop_dim (str, optional): `{dim}={value}` selector forwarded to `get_variable`.

        Returns:
            dict: `crs`, `bounds`, `minzoom` and `maxzoom` of the variable.

        """
        # Forward the reader's TileMatrixSet (as `tile` does) so the reported
        # zoom levels refer to the configured grid.
        da = DataArrayReader(
            self.get_variable(variable, drop_dim=drop_dim),
            tms=self.tms,
        )
        return {
            "crs": da.crs,
            "bounds": da.bounds,
            "minzoom": da.minzoom,
            "maxzoom": da.maxzoom,
        }

    def get_geographic_bounds(  # type: ignore
        self, crs: CRS, variable: str, drop_dim: Optional[str] = None
    ) -> BBox:
        """Return Geographic Bounds for a Geographic CRS.

        Args:
            crs (CRS): geographic CRS forwarded to `DataArrayReader.get_geographic_bounds`.
            variable (str): name of the dataset variable to read.
            drop_dim (str, optional): `{dim}={value}` selector forwarded to `get_variable`.

        """
        return DataArrayReader(
            self.get_variable(variable, drop_dim=drop_dim),
        ).get_geographic_bounds(crs)

    def info(self, variable: str, drop_dim: Optional[str] = None) -> Info:  # type: ignore
        """Return xarray.DataArray info for a variable.

        Args:
            variable (str): name of the dataset variable to read.
            drop_dim (str, optional): `{dim}={value}` selector forwarded to `get_variable`.

        """
        return DataArrayReader(
            self.get_variable(variable, drop_dim=drop_dim),
        ).info()

    def statistics(  # type: ignore
        self,
        *args: Any,
        variable: str,
        drop_dim: Optional[str] = None,
        **kwargs: Any,
    ) -> Dict[str, BandStatistics]:
        """Return statistics from a dataset.

        Args:
            args: positional arguments forwarded to `DataArrayReader.statistics`.
            variable (str): name of the dataset variable to read.
            drop_dim (str, optional): `{dim}={value}` selector forwarded to `get_variable`.
            kwargs: keyword arguments forwarded to `DataArrayReader.statistics`.

        """
        return DataArrayReader(
            self.get_variable(variable, drop_dim=drop_dim),
        ).statistics(*args, **kwargs)

    def tile(  # type: ignore
        self,
        *args: Any,
        variable: str,
        drop_dim: Optional[str] = None,
        **kwargs: Any,
    ) -> ImageData:
        """Read a Web Map tile from a dataset.

        Args:
            args: positional arguments forwarded to `DataArrayReader.tile`.
            variable (str): name of the dataset variable to read.
            drop_dim (str, optional): `{dim}={value}` selector forwarded to `get_variable`.
            kwargs: keyword arguments forwarded to `DataArrayReader.tile`.

        """
        return DataArrayReader(
            self.get_variable(variable, drop_dim=drop_dim),
            tms=self.tms,
        ).tile(*args, **kwargs)

    def part(  # type: ignore
        self,
        *args: Any,
        variable: str,
        drop_dim: Optional[str] = None,
        **kwargs: Any,
    ) -> ImageData:
        """Read part of a dataset.

        Args:
            args: positional arguments forwarded to `DataArrayReader.part`.
            variable (str): name of the dataset variable to read.
            drop_dim (str, optional): `{dim}={value}` selector forwarded to `get_variable`.
            kwargs: keyword arguments forwarded to `DataArrayReader.part`.

        """
        return DataArrayReader(self.get_variable(variable, drop_dim=drop_dim)).part(
            *args, **kwargs
        )

    def preview(  # type: ignore
        self,
        *args: Any,
        variable: str,
        drop_dim: Optional[str] = None,
        **kwargs: Any,
    ) -> ImageData:
        """Return a preview of a dataset.

        Args:
            args: positional arguments forwarded to `DataArrayReader.preview`.
            variable (str): name of the dataset variable to read.
            drop_dim (str, optional): `{dim}={value}` selector forwarded to `get_variable`.
            kwargs: keyword arguments forwarded to `DataArrayReader.preview`.

        """
        return DataArrayReader(self.get_variable(variable, drop_dim=drop_dim)).preview(
            *args, **kwargs
        )

    def point(  # type: ignore
        self,
        *args: Any,
        variable: str,
        drop_dim: Optional[str] = None,
        **kwargs: Any,
    ) -> PointData:
        """Read a pixel value from a dataset.

        Args:
            args: positional arguments forwarded to `DataArrayReader.point`.
            variable (str): name of the dataset variable to read.
            drop_dim (str, optional): `{dim}={value}` selector forwarded to `get_variable`.
            kwargs: keyword arguments forwarded to `DataArrayReader.point`.

        """
        return DataArrayReader(self.get_variable(variable, drop_dim=drop_dim)).point(
            *args, **kwargs
        )

    def feature(  # type: ignore
        self,
        *args: Any,
        variable: str,
        drop_dim: Optional[str] = None,
        **kwargs: Any,
    ) -> ImageData:
        """Read part of a dataset defined by a geojson feature.

        Args:
            args: positional arguments forwarded to `DataArrayReader.feature`.
            variable (str): name of the dataset variable to read.
            drop_dim (str, optional): `{dim}={value}` selector forwarded to `get_variable`.
            kwargs: keyword arguments forwarded to `DataArrayReader.feature`.

        """
        return DataArrayReader(self.get_variable(variable, drop_dim=drop_dim)).feature(
            *args, **kwargs
        )


# Compat: keep the previous `XarrayReader` name available as an alias of
# `DataArrayReader` (the class was renamed).
XarrayReader = DataArrayReader

0 comments on commit dd255d2

Please sign in to comment.