-
Notifications
You must be signed in to change notification settings - Fork 27
Implementation of healpix cell masking #407
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
…sking to use these, then style improvements
…ng_rate, update comments, remove archived class
…rom batchify_source
…Masker class, remove handling special cases of masking (all masked)
…any prints and hardcoded hl_mask and hl_data
…g strategy specific args
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for implementing it! Looks good already, just some minor comments.
src/weathergen/datasets/masking.py
Outdated
# NOTE: adding strategy_kwargs to allow for strategy-specific configurations | ||
# e.g., for healpix strategy, we might need hl_data and hl_mask parameters | ||
# or for different strategies, we might need different parameters? | ||
strategy_kwargs: dict, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. This looks good
src/weathergen/datasets/masking.py
Outdated
): | ||
self.masking_rate = masking_rate | ||
self.masking_strategy = masking_strategy | ||
self.masking_rate_sampling = masking_rate_sampling | ||
|
||
# NOTE: strategy_kwargs is a dictionary that can hold any additional parameters | ||
self.strategy_kwargs = strategy_kwargs or {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems better to make sure that {} is passed to the function when strategy_kwargs is not set. This should do it:
cf.get( "strategy_kwargs", {})
Also please don't use NOTE in a comment ... that's the whole point of a code comment :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using NOTE for myself! Will remove before merging
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Implemented this, I think it works neatly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True, didn't read carefully enough.
src/weathergen/datasets/masking.py
Outdated
@@ -54,6 +61,9 @@ def mask_source( | |||
token_lens = [len(t) for t in tokenized_data] | |||
num_tokens = sum(token_lens) | |||
|
|||
# print("Length of each token t in tokenized_data:", token_lens) | |||
print("Number of tokens in the batch:", num_tokens) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove before we merge the PR
@@ -89,6 +99,10 @@ def mask_source( | |||
if block_size > 0 and num_tokens > 0: | |||
start_index = self.rng.integers(0, max(1, num_tokens - block_size + 1)) | |||
flat_mask[start_index : start_index + block_size] = True | |||
|
|||
elif self.masking_strategy == "healpix": | |||
flat_mask = self._generate_healpix_mask(token_lens, rate) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then we should have a separate function for each of the masking strategies. My feeling is, we might want to implement in small classes at some point for generality but a separate function for every strategy seems like a good starting point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will raise another issue once this is merged
src/weathergen/datasets/masking.py
Outdated
np.ndarray: A flat boolean array (the token-level mask). | ||
""" | ||
|
||
print("Generating HEALPix mask...") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove before merging
src/weathergen/datasets/masking.py
Outdated
|
||
# NOTE: hl_data and hl_mask are expected to be provided in strategy_kwargs? | ||
hl_data = self.strategy_kwargs.get("hl_data") | ||
hl_mask = self.strategy_kwargs.get("hl_mask") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should fail as early as possible when hl_data and hl_mask are not set. Or can we fall back to default values in this case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tjhunter is it ok to fall back to default values for a case like this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes! This is a good place for default values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
personally, in this case, I would say self.strategy_kwargs.hl_data
etc. rather than looking (yet) for a sensible default.
src/weathergen/datasets/masking.py
Outdated
hl_data = self.strategy_kwargs.get("hl_data") | ||
hl_mask = self.strategy_kwargs.get("hl_mask") | ||
|
||
print( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove
# print(f"[HEALPix Setup] Each parent cell at L{hl_mask} contains {num_children_per_parent} child cells at L{hl_data}.") | ||
|
||
# Choose parent cells to mask based on the specified rate. | ||
num_parents_to_mask = int(np.round(rate * num_parent_cells)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be sampled when masking_rate_sampling = True
src/weathergen/datasets/masking.py
Outdated
# print(f"[HEALPix Masking] Parent IDs selected: {parent_ids_to_mask}") | ||
|
||
# Now determine which child cells (and their tokens) are masked. | ||
# This is cells. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unclear
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed
src/weathergen/datasets/masking.py
Outdated
# This is cells. | ||
cell_mask = np.zeros(num_data_cells, dtype=bool) | ||
# print("[HEALPix Masking] Mapping parent cells to child cell indices:") | ||
for parent_id in parent_ids_to_mask: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this loop be vectorized?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, good point, done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@shmh40 thanks! I have some style comments. 2 high level comments:
- it is missing how you would use it. when writing doc / updating the config, think how a newcomer would want to use this feature
- (not you): the
rng
in masking is depending on time. This makes the code non-deterministic. This issue to me is larger than biased training. I would love for us to just have a single seed for everything set in config
src/weathergen/datasets/masking.py
Outdated
|
||
# NOTE: hl_data and hl_mask are expected to be provided in strategy_kwargs? | ||
hl_data = self.strategy_kwargs.get("hl_data") | ||
hl_mask = self.strategy_kwargs.get("hl_mask") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes! This is a good place for default values.
src/weathergen/datasets/masking.py
Outdated
hl_data = self.strategy_kwargs.get("hl_data") | ||
hl_mask = self.strategy_kwargs.get("hl_mask") | ||
|
||
if hl_data is None or hl_mask is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: it is better to say: assert hl_data is not None and hl_mask is not None, "If ..."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
testing systems such as pytest
can then do clever things to present to you the offending values if it fails.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But can we test this in the constructor. This will be a much cleaner stack trace and it will fail much earlier
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This has been moved to be tested in the constructor, thanks.
I have not included default values at the moment, which I think is fine for now, we can include them later/in another PR if we like. I am not sure what our plan is overall with defaults, or @tjhunter the reasons for not having defaults. Did you write up a brief doc on your discussion with Sophie? Sorry if I missed it.
src/weathergen/datasets/masking.py
Outdated
"If masking with HEALPix, hl_data and hl_mask must be provided in strategy_kwargs." | ||
) | ||
|
||
if hl_mask >= hl_data: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here.
src/weathergen/datasets/masking.py
Outdated
assert False, "hl_mask must be less than hl_data for HEALPix masking." | ||
|
||
num_data_cells = 12 * (4**hl_data) | ||
if len(token_lens) != num_data_cells: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for pointing these out
|
||
# if masking_rate_sampling is enabled, sample the rate from a normal distribution. | ||
if self.masking_rate_sampling: | ||
rate = np.clip( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this chunk is copied over. can you factorize them into a single class method? I am sure we will need to adjust this formula eventually, better to adjust it once.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is an existing issue closely related to this, will do it at some point in a separate PR
config/default_config.yml
Outdated
# sample the masking rate (with normal distribution centered at masking_rate) | ||
masking_rate_sampling: True | ||
# sample a subset of all target points, useful e.g. to reduce memory requirements | ||
sampling_rate_target: 1.0 | ||
# include a masking strategy here, currently only supporting "random" and "block" | ||
masking_strategy: "random" | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you put a commented out example of the healpix strategy? Right now, we do not have documentation for what is expected in hl_data and hl_mask
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, good point, thank you! Added!
src/weathergen/datasets/masking.py
Outdated
|
||
# NOTE: hl_data and hl_mask are expected to be provided in strategy_kwargs? | ||
hl_data = self.strategy_kwargs.get("hl_data") | ||
hl_mask = self.strategy_kwargs.get("hl_mask") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
personally, in this case, I would say self.strategy_kwargs.hl_data
etc. rather than looking (yet) for a sensible default.
rate (float): The desired masking rate, applied to the parent cells. | ||
|
||
Returns: | ||
np.ndarray: A flat boolean array (the token-level mask). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is the len() of this array? the indexing is unclear to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
number of tokens
src/weathergen/datasets/masking.py
Outdated
): | ||
self.masking_rate = masking_rate | ||
self.masking_strategy = masking_strategy | ||
self.masking_rate_sampling = masking_rate_sampling | ||
|
||
# strategy_kwargs is a dictionary that can hold any additional parameters | ||
self.strategy_kwargs = strategy_kwargs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would call it strategy_config. kwargs
is an implementation choice, and it will be later replaced by a class (likely)
@@ -182,7 +182,9 @@ def __init__( | |||
if cf.training_mode == "forecast": | |||
self.tokenizer = TokenizerForecast(cf.healpix_level, cf.data_loader_rng_seed) | |||
elif cf.training_mode == "masking": | |||
masker = Masker(cf.masking_rate, cf.masking_strategy, cf.masking_rate_sampling) | |||
masker = Masker( | |||
cf.masking_rate, cf.masking_strategy, cf.masking_rate_sampling, cf.get("strategy_kwargs", {}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
masking_strategy_extra
instead of strategy_kwargs
? strategy is vague.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed name to masking_strategy_config
In my review I already made a suggestion how to fix it ;) |
Yes! Somehow it did not appear in the history. |
…onfig, update config with example of healpix
Description
This draft PR implements masking of cell based on healpix cells and the healpix level. Masking can be done at arbitrary healpix levels, and makes use of the nested indexing of the healpix cells.
One question (maybe @tjhunter): is the given implementation ok for passing args for specific masking strategies? I am not sure what style we want. For example, here for healpix masking, we want to pass the healpix level of the data, and the healpix level that we want the masking to occur on e.g. our data is healpix level 5, and we want to do very large scale masking (e.g. level 0 or 1). Of course these args are only relevant when we are doing healpix masking, so it is implemented here just as a dictionary of "strategy_kwargs" that can be passed in the config, with args specific to the masking strategy (hl_data, hl_mask), otherwise it is ignored. Hope that is ok.
Note this PR extends PR #383, and is currently set to merge into that branch shmh40/dev/masking_class.
Type of Change
Issue Number
Fixes #397.
Code Compatibility
Code Performance and Testing
uv run train
and (if necessary)uv run evaluate
on a least one GPU node and it works$WEATHER_GENERATOR_PRIVATE
directoryDependencies
Documentation
Additional Notes