Skip to content

Commit 894bd3a

Browse files
committed
use xarray instead of iris
1 parent 08a605c commit 894bd3a

File tree

1 file changed

+45
-61
lines changed

1 file changed

+45
-61
lines changed

oceans/datasets.py

Lines changed: 45 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
import functools
12
import warnings
23

34
import numpy as np
45
from netCDF4 import Dataset
56

67
from oceans.ocfis import get_profile, wrap_lon180
7-
import functools
88

99

1010
def _woa_variable(variable):
@@ -105,7 +105,7 @@ def _woa_url(variable, time_period, resolution):
105105
@functools.lru_cache(maxsize=256)
106106
def woa_profile(lon, lat, variable="temperature", time_period="annual", resolution="1"):
107107
"""
108-
Return an iris.cube instance from a World Ocean Atlas variable at a
108+
Return a xarray DAtaset instance from a World Ocean Atlas variable at a
109109
given lon, lat point.
110110
111111
Parameters
@@ -125,121 +125,105 @@ def woa_profile(lon, lat, variable="temperature", time_period="annual", resoluti
125125
126126
Returns
127127
-------
128-
Iris.cube instance with the climatology.
128+
xr.Dataset instance with the climatology.
129129
130130
Examples
131131
--------
132132
>>> import matplotlib.pyplot as plt
133133
>>> from oceans.datasets import woa_profile
134-
>>> cube = woa_profile(
134+
>>> woa = woa_profile(
135135
... -143, 10, variable="temperature", time_period="annual", resolution="5"
136136
... )
137137
>>> fig, ax = plt.subplots(figsize=(2.25, 5))
138-
>>> z = cube.coord(axis="Z").points
139-
>>> l = ax.plot(cube[0, :].data, z)
138+
>>> woa.plot(ax=ax, y="depth")
140139
>>> ax.grid(True)
141140
>>> ax.invert_yaxis()
142141
143142
"""
144-
import iris
143+
import cf_xarray # noqa
144+
import xarray as xr
145145

146146
url = _woa_url(variable=variable, time_period=time_period, resolution=resolution)
147-
148-
with warnings.catch_warnings():
149-
warnings.simplefilter("ignore")
150-
cubes = iris.load_raw(url)
151-
152-
# TODO: should we be using `an` instead of `mn`?
153147
v = _woa_variable(variable)
154-
cube = [c for c in cubes if c.var_name == f"{v}_mn"][0]
155-
scheme = iris.analysis.Nearest()
156-
sample_points = [("longitude", lon), ("latitude", lat)]
157-
kw = {
158-
"sample_points": sample_points,
159-
"scheme": scheme,
160-
"collapse_scalar": True,
161-
}
162-
return cube.interpolate(**kw)
148+
149+
ds = xr.open_dataset(url, decode_times=False)
150+
ds = ds[f"{v}_mn"]
151+
return ds.cf.sel({"X": lon, "Y": lat}, method="nearest")
163152

164153

165154
@functools.lru_cache(maxsize=256)
166155
def woa_subset(
167-
bbox,
156+
min_lon,
157+
max_lon,
158+
min_lat,
159+
max_lat,
168160
variable="temperature",
169161
time_period="annual",
170162
resolution="5",
171163
full=False,
172164
):
173165
"""
174-
Return an iris.cube instance from a World Ocean Atlas variable at a
166+
Return an xarray Dataset instance from a World Ocean Atlas variable at a
175167
given lon, lat bounding box.
176168
177169
Parameters
178170
----------
179-
bbox: list, tuple
180-
minx, maxx, miny, maxy positions to extract.
171+
min_lon, max_lon, min_lat, max_lat: positions to extract.
181172
See `woa_profile` for the other options.
182173
183174
Returns
184175
-------
185-
`iris.Cube` instance with the climatology.
176+
`xr.Dataset` instance with the climatology.
186177
187178
Examples
188179
--------
189180
>>> # Extract a 2D surface -- Annual temperature climatology:
190-
>>> import iris.plot as iplt
191181
>>> import matplotlib.pyplot as plt
192-
>>> from oceans.colormaps import cm
193-
>>> bbox = [2.5, 357.5, -87.5, 87.5]
194-
>>> cube = woa_subset(
195-
... bbox, variable="temperature", time_period="annual", resolution="5"
182+
>>> from cmcrameri import cm
183+
>>> from oceans.datasets import woa_subset
184+
>>> bbox = [-177.5, 177.5, -87.5, 87.5]
185+
>>> woa = woa_subset(
186+
... *bbox, variable="temperature", time_period="annual", resolution="5"
196187
... )
197-
>>> c = cube[0, 0, ...] # Slice singleton time and first level.
198-
>>> cs = iplt.pcolormesh(c, cmap=cm.avhrr)
199-
>>> cbar = plt.colorbar(cs)
188+
>>> woa.squeeze().sel(depth=0).plot(cmap=cm.lajolla)
200189
201190
>>> # Extract a square around the Mariana Trench averaging into a profile.
202-
>>> import iris
191+
>>> import matplotlib.pyplot as plt
192+
>>> import numpy as np
203193
>>> from oceans.colormaps import get_color
204194
>>> colors = get_color(12)
205195
>>> months = "Jan Feb Apr Mar May Jun Jul Aug Sep Oct Nov Dec".split()
196+
>>> def area_weights_avg(woa):
197+
... woa = woa["t_mn"].squeeze()
198+
... weights = np.cos(np.deg2rad(woa["lat"])).where(~woa.isnull())
199+
... weights /= weights.mean()
200+
... return (woa * weights).mean(dim=["lon", "lat"])
201+
...
206202
>>> bbox = [-143, -141, 10, 12]
207203
>>> fig, ax = plt.subplots(figsize=(5, 5))
208204
>>> for month in months:
209-
... cube = woa_subset(
210-
... bbox, time_period=month, variable="temperature", resolution="1"
205+
... woa = woa_subset(
206+
... *bbox, time_period=month, variable="temperature", resolution="1"
211207
... )
212-
... grid_areas = iris.analysis.cartography.area_weights(cube)
213-
... c = cube.collapsed(
214-
... ["longitude", "latitude"], iris.analysis.MEAN, weights=grid_areas
215-
... )
216-
... z = c.coord(axis="Z").points
217-
... l = ax.plot(c[0, :].data, z, label=month, color=next(colors))
208+
... profile = area_weights_avg(woa)
209+
... profile.plot(ax=ax, y="depth", label=month, color=next(colors))
218210
...
219211
>>> ax.grid(True)
220212
>>> ax.invert_yaxis()
221213
>>> leg = ax.legend(loc="lower left")
222-
>>> _ = ax.set_ylim(200, 0)
214+
>>> ax.set_ylim(200, 0)
223215
224216
"""
225-
import iris
217+
import cf_xarray # noqa
218+
import xarray as xr
226219

227-
v = _woa_variable(variable)
228220
url = _woa_url(variable, time_period, resolution)
229-
cubes = iris.load_raw(url)
230-
cubes = [
231-
cube.intersection(longitude=(bbox[0], bbox[1]), latitude=(bbox[2], bbox[3]))
232-
for cube in cubes
233-
]
234-
235-
with warnings.catch_warnings():
236-
warnings.simplefilter("ignore")
237-
cubes = iris.cube.CubeList(cubes)
238-
221+
ds = xr.open_dataset(url, decode_times=False)
222+
ds = ds.cf.sel({"X": slice(min_lon, max_lon), "Y": slice(min_lat, max_lat)})
223+
v = _woa_variable(variable)
239224
if full:
240-
return cubes
241-
else:
242-
return [c for c in cubes if c.var_name == f"{v}_mn"][0]
225+
return ds
226+
return ds[[f"{v}_mn"]] # always return a dataset
243227

244228

245229
@functools.lru_cache(maxsize=256)
@@ -322,7 +306,7 @@ def get_isobath(bbox, iso=-200, tfile=None, smoo=False):
322306
>>> lon, lat, bathy = etopo_subset(bbox=bbox, smoo=True)
323307
>>> fig, ax = plt.subplots()
324308
>>> cs = ax.pcolormesh(lon, lat, bathy)
325-
>>> for segment in segments:
309+
>>> for segment in segmentsz:
326310
... lines = ax.plot(segment[:, 0], segment[:, -1], "k", linewidth=2)
327311
...
328312

0 commit comments

Comments
 (0)