Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

standardize requests to optimize caching #228

Open
wants to merge 14 commits into
base: develop
Choose a base branch
from
22 changes: 14 additions & 8 deletions earthkit/data/sources/cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,26 +137,32 @@ def retrieve(target, args):
@normalize("date", "date-list(%Y-%m-%d)")
@normalize("area", "bounding-box(list)")
def _normalize_request(**kwargs):
return kwargs
request = {}
for k, v in sorted(kwargs.items()):
v = ensure_iterable(v)
if k not in ("area", "grid"):
v = sorted(v)
request[k] = v[0] if len(v) == 1 else v
return request

@cached_property
def requests(self):
requests = []
for arg in self._args:
request = self._normalize_request(**arg)
split_on = request.pop("split_on", None)
if split_on is None:
for request in self._args:
split_on = request.pop("split_on", {})
if not isinstance(split_on, dict):
split_on = {k: 1 for k in ensure_iterable(split_on) if k is not None}
if not split_on:
requests.append(request)
continue

if not isinstance(split_on, dict):
split_on = {k: 1 for k in ensure_iterable(split_on)}
request = self._normalize_request(**request)
for values in itertools.product(
*[batched(ensure_iterable(request[k]), v) for k, v in split_on.items()]
):
subrequest = dict(zip(split_on, values))
requests.append(request | subrequest)
return requests
return [self._normalize_request(**request) for request in requests]


source = CdsRetriever
30 changes: 30 additions & 0 deletions tests/sources/test_cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,36 @@ def test_cds_grib_multi_var_date(date, expected_date):
assert s.metadata("date") == expected_date


@pytest.mark.long_test
@pytest.mark.download
@pytest.mark.skipif(NO_CDS, reason="No access to CDS")
@pytest.mark.parametrize(
    "variable1,variable2,expected_vars",
    (
        ("2t", ["2t"], {"t2m"}),
        (["2t", "msl"], ["msl", "2t"], {"t2m", "msl"}),
    ),
)
def test_cds_normalized_request(variable1, variable2, expected_vars):
    """Check that equivalent CDS requests resolve to the same cached file.

    ``variable1`` and ``variable2`` spell the same selection differently
    (scalar vs. one-element list, or the same list in a different order).
    Both sources are expected to end up with the same ``path``, and the
    retrieved data must contain exactly ``expected_vars`` on the requested
    area and grid.
    """
    base_request = dict(
        product_type="reanalysis",
        area=[50, -50, 20, 50],
        grid=[2, 1],
        date="2012-12-12",
        time="12:00",
    )
    s1 = from_source(
        "cds", "reanalysis-era5-single-levels", variable=variable1, **base_request
    )
    s2 = from_source(
        "cds", "reanalysis-era5-single-levels", variable=variable2, **base_request
    )

    # Same path for both spellings means both requests map to one file.
    assert s1.path == s2.path

    # Convert once and reuse: every assertion below inspects the same dataset.
    ds = s1.to_xarray()
    assert set(ds.data_vars) == expected_vars

    # ``area`` and ``grid`` are order-sensitive; the coordinates confirm the
    # values were applied as given (longitude step 2, latitude step 1).
    assert ds["longitude"].values.tolist() == list(range(-50, 51, 2))
    assert ds["latitude"].values.tolist() == list(range(50, 19, -1))


@pytest.mark.long_test
@pytest.mark.download
@pytest.mark.skipif(NO_CDS, reason="No access to CDS")
Expand Down