Add averages with time dimension removed #236
Conversation
- Update docstrings of methods in `TemporalAccessor` class
- Add `DEFAULT_SEASON_CONFIG` to reduce code duplication
- Add comments for sections of code that are opaque
- Rename `_averager()` to `_grouped_average()`
Codecov Report

```
@@            Coverage Diff            @@
##              main      #236   +/-  ##
=========================================
  Coverage   100.00%   100.00%
=========================================
  Files            9         8     -1
  Lines          742       723    -19
=========================================
- Hits           742       723    -19
```

Continue to review full report at Codecov.
- Remove `data_var` arg from `_get_weights()`
- Update `_convert_df_to_dt()` to use regular dt objects for mean mode and non-monthly frequencies
- Add `TemporalAccessor` attribute `_dim_name`
- Update docstrings
- Rename `.mean()` to `.average()`
- Rename existing `.average()` to `.group_average()`
- Update "mean" to "average" and "average" to "group_average" in `Mode` type alias
- Rename `_grouped_average()` to `_group_average()`
- Rename `_time_grouped` to `_grouped_time`
- Update docstring for `_convert_df_to_dt()`
- Update conditional in `_convert_df_to_dt()`
- Update `_validate_weights()` to get `num_groups` from `_get_weights()` instead of `self._time_grouped`
- Rename classes in `test_temporal.py` to reflect private methods
- Add placeholder tests in `test_temporal.py`
Initial review comments

Hey Tom, just to follow up on the xcdat meeting that we had, regarding the two points:

I'm thinking one possibility to resolve both is to make the new method an option of `average()`/`group_average()`. Assume this example:
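(The original example was not preserved in this thread; a hypothetical sketch of the combined-API idea, with assumed argument names, might look like the following.)

```python
# Hypothetical sketch of a combined API (argument names are assumptions):
# whether the result is time-collapsed or grouped depends on the arguments.
ds.temporal.average("ts")                # no freq -> time dimension removed
ds.temporal.average("ts", freq="month")  # freq given -> grouped average
```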
- Update `_group_time_coords()` to `_label_time_coords()`
- Update `_groupby_freq()` to `_group_data()`
- Extract `_get_dt_components()` from `_group_time_coords()`
- Rename `_process_season_dataframe()` to `_process_season_df()`
- Move `_convert_df_to_dt()` further down the class
- Rename `DATETIME_COMPONENTS` to `TIME_GROUPS`
- Update season frequency in `TIME_GROUPS`
- Add logic to `_get_dt_components()` to handle seasonal frequency
Hi Jill, thanks for the suggestion. On the surface it seems like a relatively simple solution. It would be nice to have a single API that can easily solve the naming inconsistency! After investigating more closely with a prototype implementation, I unfortunately found that combining these two different averaging operations under a single API makes the API less predictable, because each operation produces a different output depending on the arguments passed. For example, an end-user might expect a grouped average output, but they get a time collapsed average instead because of the arguments they pass.

My proposed (and current) solution is to have two well-defined APIs that serve different purposes, because I think it is easier to use and can be cleanly implemented. The name of each API/method explains what it does, while its docstring fills the gap in how to use it.
As mentioned before, we will write an API translation table for CDAT users to help smooth the transition process to xCDAT. Let me know what you think! Also tagging @lee1043 to see if he has any thoughts or ideas.
@tomvothecoder Thank you for the effort to explore the different options! Appreciate it. What you suggested about two APIs actually makes sense to me. Tagging @lee1043 or @pochedls to see if they have some input on this implementation or perhaps better API names?
@tomvothecoder @chengzhuzhang thank you for the discussion. I was initially leaning toward what @chengzhuzhang suggested earlier -- a single API on top of two capabilities -- but I am leaning back to the two well-defined, separate APIs. I think the point @tomvothecoder raised is very important -- make the API predictable, which I think is one of the keys to being a successful community tool. While we expect users to READ docs first, we know it is often not the case. In my personal experience, I often start directly from some sample code, and read the docs when I get stuck. When an API's predictability is great and I know exactly what the API does, I can trust the API more.
- Testing private methods introduces coupling to implementation details. We should be testing public methods, which test behaviors and outputs
- Use custom dataset fixtures with fewer coordinate points and values other than 1 for easier and more robust testing
- Add tests for `season_config` arg
- Update names of tests
Force-pushed from 38ac565 to 131ff69
Hi @lee1043 and @chengzhuzhang, this PR is ready for review.
The PR's description should contain all of the relevant information to make the review process smooth. Thanks!
```python
def _validate_weights(self, weights: xr.DataArray, num_groups: int):
    """Validates that the sum of the weights for each group equals 1.0.

    Parameters
    ----------
    weights : xr.DataArray
        The weights for the time coordinates.
    num_groups : int
        The number of groups.
    """
    actual_sum = self._group_data(weights).sum().values  # type: ignore
    expected_sum = np.ones(num_groups)

    np.testing.assert_allclose(actual_sum, expected_sum)
```
This refactored implementation of `_validate_weights()` speeds up weighted averages significantly. The bottleneck was due to the call `freq_groups = self._groupby_freq(data_var).count()` for getting the count of unique grouping labels. Now we just pass `num_groups` directly to the method from `_get_weights()`.

For the dataset I was testing on and a frequency of "month", `_get_weights()` is now about 4-5x faster (~10.8-11.6 secs down to 2.8-3.2 secs).
test_temporal.py
I removed a lot of the unit tests for private methods in `test_temporal.py`. I learned that testing private methods is unnecessary and introduces coupling to the implementation details, which breaks encapsulation. It makes refactoring significantly more tedious because the tests for the private methods also need to be updated (there is also a risk of test duplication with public methods). Instead, tests should revolve around public methods and their behaviors/outputs.
More info on why testing private methods is considered "bad" practice:
https://softwareengineering.stackexchange.com/a/380296
https://stackoverflow.com/questions/105007/should-i-test-private-methods-or-only-public-ones
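For instance, a test against the public API might look like this (a hypothetical sketch; the fixture helper and dataset contents are assumptions, not actual xcdat tests):

```python
def test_average_removes_time_dimension():
    # Assumed fixture helper that builds a Dataset with a "ts" variable
    # and monthly time coordinates with bounds.
    ds = make_monthly_dataset()

    result = ds.temporal.average("ts", freq="month")

    # Assert on observable behavior instead of private helpers
    # like _get_weights() or _group_data().
    assert "time" not in result.ts.dims
```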
xcdat/temporal.py
```python
def average(
    self,
    data_var: str,
    freq: Frequency,
    center_times: bool = False,
    season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
):
    """
    Returns a Dataset with the time weighted average of a data variable
    and the time dimension removed.

    This method is particularly useful for monthly or yearly time series
    data since the number of days per month can vary based on the calendar
    type, which can affect weighting. For unweighted time averages or other
    frequencies, call xarray's native ``.mean()`` method on the data
    variable instead.

    Weights are calculated by first determining the length of time for
    each coordinate point using the difference of its upper and lower
    bounds. The time lengths are grouped, then each time length is
    divided by the total sum of the time lengths to get the weight of
    each coordinate point.

    Parameters
    ----------
    data_var : str
        The key of the data variable for calculating averages.
    freq : Frequency
        The time frequency for calculating weights.

        * "year": weights by year
        * "season": weights by season
        * "month": weights by month
        * "day": weights by day
        * "hour": weights by hour

    center_times : bool, optional
        If True, center time coordinates using the midpoint between their
        upper and lower bounds. Otherwise, use the provided time
        coordinates, by default False.

    season_config : SeasonConfigInput, optional
        A dictionary for "season" frequency configurations. If configs for
        predefined seasons are passed, configs for custom seasons are
        ignored and vice versa.

        Configs for predefined seasons:

        * "dec_mode" (Literal["DJF", "JFD"], by default "DJF")
            The mode for the season that includes December.

            * "DJF": season includes the previous year December.
            * "JFD": season includes the same year December.
              Xarray labels the season with December as "DJF", but it
              is actually "JFD".

        * "drop_incomplete_djf" (bool, by default False)
            If the "dec_mode" is "DJF", this flag drops (True) or keeps
            (False) time coordinates that fall under incomplete DJF
            seasons. Incomplete DJF seasons include the start year
            Jan/Feb and the end year Dec.

        Configs for custom seasons:

        * "custom_seasons" (List[List[str]], by default None)
            List of sublists containing month strings, with each sublist
            representing a custom season.

            * Month strings must be in the three letter format (e.g., 'Jan')
            * Each month must be included once in a custom season
            * Order of the months in each custom season does not matter
            * Custom seasons can vary in length

            >>> # Example of custom seasons in a three month format:
            >>> custom_seasons = [
            >>>     ["Jan", "Feb", "Mar"],  # "JanFebMar"
            >>>     ["Apr", "May", "Jun"],  # "AprMayJun"
            >>>     ["Jul", "Aug", "Sep"],  # "JulAugSep"
            >>>     ["Oct", "Nov", "Dec"],  # "OctNovDec"
            >>> ]

    Returns
    -------
    xr.Dataset
        Dataset with the time weighted average of the data variable and
        the time dimension removed.

    Examples
    --------
    Get weighted averages for a monthly time series data variable:

    >>> ds_month = ds.temporal.average("ts", freq="month", center_times=False)
    >>> ds_month.ts
    """
    return self._averager(
        data_var, "average", freq, True, center_times, season_config
    )
```
The new `average()` method.
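As a worked illustration of the weighting scheme described in the docstring (a minimal numeric sketch, not xcdat internals), consider one non-leap year of monthly data:

```python
import numpy as np

# Each month's weight is its length divided by the total length of the
# group (here, the whole year), so the weights within a group sum to 1.0.
days_in_month = np.array([31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31])
weights = days_in_month / days_in_month.sum()

print(round(weights[0], 4))      # January: 31/365 ~= 0.0849
print(round(weights[1], 4))      # February: 28/365 ~= 0.0767
print(round(weights.sum(), 4))   # 1.0 -- the invariant _validate_weights() checks
```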
I wonder if `freq` can be automatically assigned with a new parameter if `weighted=True`. Having `freq` seems like it can be confused with `group_average`.
@chengzhuzhang won't be able to review this PR this week. @lee1043 can you give this a shot whenever you get a chance? Thanks
I am testing this PR in this notebook. While I can give an overview at our regular meeting or we can have a separate chat, I am just sharing in advance for your interest. Basically what I am doing in the notebook is trying every option, even if it does not make sense, so I can learn how the function works and responds, and I think I am understanding it.

In the notebook, I also tried some silly things, such as getting the temporal average of a monthly time series with freq=day or hour (which does not make sense, but I was curious what would be returned) and compared their differences in Section 1.6. Theoretically they might have to return the same values as the original monthly data because they don't have higher frequency input, but it is interesting to see how different they are and that there are some patterns in the differences.

Regarding the group average, in the notebook I am using a monthly time series as input (shape: (1872, Y, X)). For year and season, it seems the right numbers were returned for the time dimension. So far I have mostly checked the numbers in the time dimensions and have yet to cross-validate the returned values with CDAT. I will continue to work on that.
Thanks for reviewing @lee1043! You have a great validation process going on.
I agree, and Jill mentioned a similar solution to what you outlined. I think it might confuse users, though. I will try to figure out how to infer the `freq` automatically.
I noticed this too. This isn't necessarily an implementation bug, but it can happen if the user passes a frequency that isn't aligned with the actual frequency of the time intervals (e.g., with monthly time coordinates).
This is possible for the user if they want to calculate weights like that, but it doesn't really make sense. Removing the `freq` arg and inferring the frequency instead avoids this.
The grouping behaviors for group averages and climatologies are based on CDAT. In the case of the monthly time series input, grouping on month will return the original data because each group contains only a single time coordinate. Here is a code snippet to validate this behavior against CDAT:
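(The original snippet was not preserved in this thread. A plain-xarray stand-in for the idea, using synthetic monthly data rather than the actual CDAT comparison:)

```python
import numpy as np
import pandas as pd
import xarray as xr

# Two years of synthetic monthly data.
time = pd.date_range("2000-01-01", periods=24, freq="MS")
ts = xr.DataArray(np.arange(24.0), coords={"time": time}, dims="time")

# Grouping by a year-month label puts exactly one time point in each group,
# so the average of each group is identical to the original data.
grouped = ts.groupby(ts.time.dt.strftime("%Y-%m")).mean()
np.testing.assert_allclose(grouped.values, ts.values)
```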
The grouping behavior that produces (12, Y, X) is a result of calculating the climatology, not the group average:
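Reusing `ts` from the sketch above (again a stand-in, not the original snippet):

```python
# Grouping by month alone (climatology-style) collapses the 24 time points
# into 12 groups, one per calendar month.
clim = ts.groupby("time.month").mean()
assert clim.sizes["month"] == 12
```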
The code lines below show how grouping is done in xCDAT and CDAT, which was also validated in this notebook (see xcdat/temporal.py, lines 33 to 59 at da2cf9e).
- Update `TestAverage` unit tests to test different time frequencies
```python
def _infer_freq(self) -> Frequency:
    """Infers the time frequency from the coordinates.

    This method infers the time frequency from the coordinates by
    calculating the minimum delta and comparing it against a set of
    conditionals.

    The native ``xr.infer_freq()`` method does not work for all cases
    because the frequency can be irregular (e.g., different hour
    measurements), which ends up returning None.

    Returns
    -------
    Frequency
        The time frequency.
    """
    time_coords = self._dataset[self._dim_name]
    min_delta = pd.to_timedelta(np.diff(time_coords).min(), unit="ns")

    if min_delta < pd.Timedelta(days=1):
        return "hour"
    elif min_delta >= pd.Timedelta(days=1) and min_delta < pd.Timedelta(days=28):
        return "day"
    elif min_delta >= pd.Timedelta(days=28) and min_delta < pd.Timedelta(days=365):
        return "month"
    else:
        return "year"
```
@lee1043 I added this new method, `_infer_freq()`, for inferring the `freq` of the coordinates. The frequency is used for calculating weights if `weighed=True`. Let me know how the algorithm looks, thanks.
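A small sketch of the motivation from the docstring (synthetic mid-month timestamps; not part of the PR):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Irregular monthly spacing (30, 30, and 31 day deltas) makes
# xr.infer_freq() return None, while the minimum-delta heuristic
# still classifies the coordinates as "month".
time = pd.to_datetime(["2000-01-16", "2000-02-15", "2000-03-16", "2000-04-16"])
print(xr.infer_freq(time))  # None

min_delta = pd.to_timedelta(np.diff(time).min(), unit="ns")
print(pd.Timedelta(days=28) <= min_delta < pd.Timedelta(days=365))  # True -> "month"
```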
@tomvothecoder I think this is great and very straightforward, thank you for revising it. I am playing with the new function in the notebook now.

P.S.: typo in the above comment: `weighed=True` -- `weighed` --> `weighted`
xcdat/temporal.py
```python
def _average(self, data_var: xr.DataArray) -> xr.DataArray:
    """
    Returns the weighted averages for a data variable with the time
    dimension removed.

    Parameters
    ----------
    data_var : xr.DataArray
        The data variable.

    Returns
    -------
    xr.DataArray
        The weighted averages for a data variable with the time dimension
        removed.
    """
    dv = data_var.copy()

    with xr.set_options(keep_attrs=True):
        if self._weighted:
            self._weights = self._get_weights()
            dv = dv.weighted(self._weights).mean(dim=self._dim_name)
        else:
            dv = dv.mean(dim=self._dim_name)

    dv = self._add_operation_attrs(dv)

    return dv
```
This is the main method that performs the work for `ds.temporal.average()`.
@tomvothecoder thank you for the revision. I have tested the function in the following notebooks in this PR, and I think the function is working as expected!
@lee1043 Great, thanks for adding those notebooks! I will do a final review and then merge.
@lee1043 thank you for reviewing and testing this PR, and it is great to get more validation notebooks for it! At some point, we can also do a similar test to Steve's: loop over all CMIP models and compare results between CDAT and xCDAT.
@chengzhuzhang I agree, I'll ask Steve to post his scripts so we can reuse them for validation.
Description
This PR adds the ability to calculate averages weighted by time with the time dimension removed (via `ds.temporal.average()`). It simply wraps xarray's multi-line function calls into an API, with an example below for comparison.
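(The original comparison example was not preserved here; a simplified sketch of the idea follows, assuming a file with a `ts` variable and `time_bnds` bounds. Note that xcdat additionally groups the time lengths by the chosen frequency before normalizing, which this manual version omits.)

```python
import numpy as np
import xcdat

ds = xcdat.open_dataset("ts.nc")  # hypothetical input file

# Manual xarray-style approach: weight each time point by the length of its
# time bounds (upper minus lower), then take the weighted mean over time.
time_lengths = (ds.time_bnds[:, 1] - ds.time_bnds[:, 0]).astype(np.float64)
weights = time_lengths / time_lengths.sum()
ts_manual = ds.ts.weighted(weights).mean(dim="time")

# xcdat approach: a single call.
ts_xcdat = ds.temporal.average("ts", freq="month").ts
```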
Summary of Changes

- Add `temporal.average()`
- Rename the existing `.average()` to `.group_average()`
- Update docstrings and internals of `class TemporalAccessor`
- Refactor `._validate_weights()` for a significant performance increase of about 4-5x (for the monthly Dataset I was testing on)

Related API docs: https://xcdat.readthedocs.io/en/feature-201-temporal-mean/generated/xarray.Dataset.temporal.average.html#
Floating point closeness comparison with CDAT/cdutil: https://github.com/xCDAT/xcdat_test/blob/pr-236-validation/validation/v1.0.0/temporal_average/qa_cdat_closeness.ipynb