Implementation of healpix cell masking #407

Open
wants to merge 46 commits into
base: develop

Conversation

shmh40
Contributor

@shmh40 shmh40 commented Jun 27, 2025

Description

This draft PR implements masking based on HEALPix cells and the HEALPix level. Masking can be done at arbitrary HEALPix levels and makes use of the nested indexing of the HEALPix cells.

One question (maybe @tjhunter): is the given implementation OK for passing args for specific masking strategies? I am not sure what style we want. For example, for HEALPix masking we want to pass the HEALPix level of the data and the HEALPix level at which the masking should occur, e.g. our data is at HEALPix level 5 and we want to do very large-scale masking (level 0 or 1). These args are only relevant when we are doing HEALPix masking, so they are implemented here as a dictionary "strategy_kwargs" that can be passed in the config with args specific to the masking strategy (hl_data, hl_mask); for other strategies it is ignored. Hope that is OK.

Note this PR extends PR #383, and is currently set to merge into that branch shmh40/dev/masking_class.
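To make the nested-indexing idea concrete, here is a minimal sketch. It assumes the standard HEALPix NESTED scheme, where a level `l` has 12 * 4**l cells and the children of a parent cell occupy one contiguous index range. The function names `num_cells` and `children_of` are illustrative and do not come from the PR.

```python
def num_cells(level: int) -> int:
    """Number of HEALPix cells at a level: 12 base cells, each split in 4 per level."""
    return 12 * 4**level


def children_of(parent_id: int, hl_mask: int, hl_data: int) -> range:
    """Indices at level hl_data of the cells nested inside parent_id at level hl_mask."""
    assert hl_mask < hl_data, "hl_mask must be coarser than hl_data"
    n_children = 4 ** (hl_data - hl_mask)
    # In NESTED ordering the children of one parent are contiguous.
    return range(parent_id * n_children, (parent_id + 1) * n_children)


# Data at level 5, masking at level 1: one masked parent covers 4**4 data cells.
children = children_of(3, hl_mask=1, hl_data=5)
print(num_cells(5), num_cells(1), len(children))
```

With the example values from the description (data at level 5, masking at level 0 or 1), a single masked parent removes a large contiguous patch of data cells, which is what makes the masking "very large scale".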

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update

Issue Number

Fixes #397.

Code Compatibility

  • I have performed a self-review of my code

Code Performance and Testing

  • I ran the uv run train and (if necessary) uv run evaluate on at least one GPU node and it works
  • If the new feature introduces modifications at the config level, I have made sure to have notified the other software developers through Mattermost and updated the paths in the $WEATHER_GENERATOR_PRIVATE directory

Dependencies

  • I have ensured that the code is still pip-installable after the changes and runs
  • I have tested that new dependencies themselves are pip-installable.
  • I have not introduced new dependencies in the inference portion of the pipeline

Documentation

  • My code follows the style guidelines of this project
  • I have updated the documentation and docstrings to reflect the changes
  • I have added comments to my code, particularly in hard-to-understand areas

Additional Notes

shmh40 and others added 29 commits June 24, 2025 08:14
…ng_rate, update comments, remove archived class
…Masker class, remove handling special cases of masking (all masked)
…any prints and hardcoded hl_mask and hl_data
@shmh40 shmh40 requested a review from clessig June 27, 2025 16:47
@shmh40 shmh40 added the enhancement New feature or request label Jun 27, 2025
@shmh40 shmh40 moved this to In Progress in WeatherGen-dev Jun 27, 2025
@shmh40 shmh40 added the model Related to model training or definition (not generic infra) label Jun 27, 2025
Collaborator

@clessig clessig left a comment

Thanks for implementing it! Looks good already, just some minor comments.

# NOTE: adding strategy_kwargs to allow for strategy-specific configurations
# e.g., for healpix strategy, we might need hl_data and hl_mask parameters
# or for different strategies, we might need different parameters?
strategy_kwargs: dict,
Collaborator

Thanks. This looks good

):
self.masking_rate = masking_rate
self.masking_strategy = masking_strategy
self.masking_rate_sampling = masking_rate_sampling

# NOTE: strategy_kwargs is a dictionary that can hold any additional parameters
self.strategy_kwargs = strategy_kwargs or {}
Collaborator

It seems better to make sure that {} is passed to the function when strategy_kwargs is not set. This should do it:

cf.get( "strategy_kwargs", {})

Also please don't use NOTE in a comment ... that's the whole point of a code comment :)

Contributor Author

Using NOTE for myself! Will remove before merging

Contributor Author

Implemented this, I think it works neatly

Collaborator

True, didn't read carefully enough.
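The suggestion in this thread is to resolve the default where the config is read, so the class always receives a dictionary. A small sketch of that pattern follows. Here `cf` is a plain `dict` standing in for the project's config object, whose actual type is not shown in this conversation, and `build_strategy_kwargs` is a hypothetical helper.

```python
def build_strategy_kwargs(cf: dict) -> dict:
    # Resolve the default at the call site so the Masker never sees None.
    return cf.get("strategy_kwargs", {})


# A config without strategy-specific arguments and one with them.
cf_without = {"masking_strategy": "random"}
cf_with = {"masking_strategy": "healpix", "strategy_kwargs": {"hl_data": 5, "hl_mask": 1}}

print(build_strategy_kwargs(cf_without))
print(build_strategy_kwargs(cf_with))
```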

@@ -54,6 +61,9 @@ def mask_source(
        token_lens = [len(t) for t in tokenized_data]
        num_tokens = sum(token_lens)

        # print("Length of each token t in tokenized_data:", token_lens)
        print("Number of tokens in the batch:", num_tokens)
Collaborator

Remove before we merge the PR

@@ -89,6 +99,10 @@ def mask_source(
            if block_size > 0 and num_tokens > 0:
                start_index = self.rng.integers(0, max(1, num_tokens - block_size + 1))
                flat_mask[start_index : start_index + block_size] = True

        elif self.masking_strategy == "healpix":
            flat_mask = self._generate_healpix_mask(token_lens, rate)
Collaborator

Then we should have a separate function for each of the masking strategies. My feeling is, we might want to implement in small classes at some point for generality but a separate function for every strategy seems like a good starting point.

Contributor Author

Will raise another issue once this is merged
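One possible shape for "a separate function for every strategy" is a small dispatch table, sketched below. This is an illustration under stated assumptions, not the project's code. The block logic mirrors the lines quoted above, but deriving `block_size` from the rate and the per-token behaviour of the random strategy are assumptions. All names (`mask_random`, `mask_block`, `STRATEGIES`, `make_mask`) are hypothetical.

```python
import numpy as np


def mask_random(num_tokens: int, rate: float, rng: np.random.Generator) -> np.ndarray:
    # Assumed behaviour: mask each token independently with probability `rate`.
    return rng.random(num_tokens) < rate


def mask_block(num_tokens: int, rate: float, rng: np.random.Generator) -> np.ndarray:
    # Mask one contiguous block; deriving block_size from rate is an assumption.
    flat_mask = np.zeros(num_tokens, dtype=bool)
    block_size = int(np.round(rate * num_tokens))
    if block_size > 0 and num_tokens > 0:
        start = rng.integers(0, max(1, num_tokens - block_size + 1))
        flat_mask[start : start + block_size] = True
    return flat_mask


# One entry per strategy; a new strategy only needs a new function and a new key.
STRATEGIES = {"random": mask_random, "block": mask_block}


def make_mask(strategy: str, num_tokens: int, rate: float, rng: np.random.Generator) -> np.ndarray:
    assert strategy in STRATEGIES, f"Unknown masking strategy: {strategy}"
    return STRATEGIES[strategy](num_tokens, rate, rng)


rng = np.random.default_rng(0)
mask = make_mask("block", num_tokens=20, rate=0.25, rng=rng)
print(int(mask.sum()))
```

If the strategies later become small classes, the table can hold classes instead of functions without changing the call site.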

np.ndarray: A flat boolean array (the token-level mask).
"""

print("Generating HEALPix mask...")
Collaborator

Remove before merging


# NOTE: hl_data and hl_mask are expected to be provided in strategy_kwargs?
hl_data = self.strategy_kwargs.get("hl_data")
hl_mask = self.strategy_kwargs.get("hl_mask")
Collaborator

We should fail as early as possible when hl_data and hl_mask are not set. Or can we fall back to default values in this case?

Contributor Author

@tjhunter is it ok to fall back to default values for a case like this?

Collaborator

yes! This is a good place for default values.

Collaborator

personally, in this case, I would say self.strategy_kwargs.hl_data etc. rather than looking (yet) for a sensible default.

hl_data = self.strategy_kwargs.get("hl_data")
hl_mask = self.strategy_kwargs.get("hl_mask")

print(
Collaborator

Remove

# print(f"[HEALPix Setup] Each parent cell at L{hl_mask} contains {num_children_per_parent} child cells at L{hl_data}.")

# Choose parent cells to mask based on the specified rate.
num_parents_to_mask = int(np.round(rate * num_parent_cells))
Collaborator

This should be sampled when masking_rate_sampling = True

# print(f"[HEALPix Masking] Parent IDs selected: {parent_ids_to_mask}")

# Now determine which child cells (and their tokens) are masked.
# This is cells.
Collaborator

Unclear

Contributor Author

Removed

# This is cells.
cell_mask = np.zeros(num_data_cells, dtype=bool)
# print("[HEALPix Masking] Mapping parent cells to child cell indices:")
for parent_id in parent_ids_to_mask:
Collaborator

Can this loop be vectorized?

Contributor Author

Yes, good point, done
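The vectorized version is not shown in this conversation. One way to replace the per-parent loop, assuming NESTED ordering where each parent's children are contiguous, is to build all child indices with broadcasting. The function name `healpix_cell_mask` is hypothetical; the variable names follow the snippets quoted above.

```python
import numpy as np


def healpix_cell_mask(parent_ids_to_mask, hl_mask: int, hl_data: int) -> np.ndarray:
    """Boolean mask over data cells, True for every child of a masked parent."""
    parent_ids = np.asarray(parent_ids_to_mask, dtype=np.int64)
    num_data_cells = 12 * 4**hl_data
    num_children_per_parent = 4 ** (hl_data - hl_mask)
    # Broadcast: one row per masked parent, one column per child offset.
    child_offsets = np.arange(num_children_per_parent)
    child_ids = parent_ids[:, None] * num_children_per_parent + child_offsets[None, :]
    cell_mask = np.zeros(num_data_cells, dtype=bool)
    cell_mask[child_ids.ravel()] = True
    return cell_mask


# Mask base cells 0 and 5 for data at level 2 (16 children per base cell).
cell_mask = healpix_cell_mask([0, 5], hl_mask=0, hl_data=2)
print(int(cell_mask.sum()))
```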

Base automatically changed from shmh40/dev/masking_class to develop June 28, 2025 12:31
@shmh40 shmh40 linked an issue Jun 30, 2025 that may be closed by this pull request
@shmh40 shmh40 marked this pull request as ready for review June 30, 2025 11:41
Collaborator

@tjhunter tjhunter left a comment

@shmh40 thanks! I have some style comments. 2 high level comments:

  • It is missing how you would use it. When writing docs or updating the config, think about how a newcomer would want to use this feature.
  • (not you): the RNG in masking depends on time. This makes the code non-deterministic. This issue to me is larger than biased training. I would love for us to just have a single seed for everything set in config.
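As an illustration of the second point, a generator seeded from a config value reproduces the same masks on every run. This is only a sketch: `SeededMasker` and its `seed` argument are hypothetical, and how the project would thread a single seed through its config is not decided in this thread.

```python
import numpy as np


class SeededMasker:
    """Hypothetical stand-in for Masker that takes its seed from config, not the clock."""

    def __init__(self, masking_rate: float, seed: int):
        self.masking_rate = masking_rate
        self.rng = np.random.default_rng(seed)

    def sample_mask(self, num_tokens: int) -> np.ndarray:
        # Mask each token independently with probability masking_rate.
        return self.rng.random(num_tokens) < self.masking_rate


# Two maskers built from the same seed produce identical masks.
a = SeededMasker(0.5, seed=1234).sample_mask(16)
b = SeededMasker(0.5, seed=1234).sample_mask(16)
print(np.array_equal(a, b))
```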


hl_data = self.strategy_kwargs.get("hl_data")
hl_mask = self.strategy_kwargs.get("hl_mask")

if hl_data is None or hl_mask is None:
Collaborator

style: it is better to say: assert hl_data is not None and hl_mask is not None, "If ..."

Collaborator

testing systems such as pytest can then do clever things to present to you the offending values if it fails.

Collaborator

But can we test this in the constructor? This will give a much cleaner stack trace and it will fail much earlier.

Contributor Author

This has been moved to be tested in the constructor, thanks.

I have not included default values at the moment, which I think is fine for now, we can include them later/in another PR if we like. I am not sure what our plan is overall with defaults, or @tjhunter the reasons for not having defaults. Did you write up a brief doc on your discussion with Sophie? Sorry if I missed it.
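A reduced sketch of what checking in the constructor could look like, combining the assert style and the fail-early request from this thread. The class below is a simplified stand-in and not the PR's actual `Masker`; the message strings follow the ones quoted in the diff.

```python
class Masker:
    """Simplified sketch: only the argument validation is shown."""

    def __init__(self, masking_rate, masking_strategy, masking_rate_sampling, strategy_kwargs):
        self.masking_rate = masking_rate
        self.masking_strategy = masking_strategy
        self.masking_rate_sampling = masking_rate_sampling
        self.strategy_kwargs = strategy_kwargs

        # Validate strategy-specific arguments once, before any data is processed.
        if masking_strategy == "healpix":
            hl_data = strategy_kwargs.get("hl_data")
            hl_mask = strategy_kwargs.get("hl_mask")
            assert hl_data is not None and hl_mask is not None, (
                "If masking with HEALPix, hl_data and hl_mask must be provided in strategy_kwargs."
            )
            assert hl_mask < hl_data, "hl_mask must be less than hl_data for HEALPix masking."


# A valid configuration constructs without error.
masker = Masker(0.6, "healpix", True, {"hl_data": 5, "hl_mask": 1})

# A missing hl_mask fails at construction time rather than inside mask generation.
try:
    Masker(0.6, "healpix", True, {"hl_data": 5})
    failed_early = False
except AssertionError:
    failed_early = True
print(failed_early)
```

One trade-off to keep in mind: `assert` statements are skipped when Python runs with `-O`, so checks that must always run are sometimes written as explicit exceptions instead.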

"If masking with HEALPix, hl_data and hl_mask must be provided in strategy_kwargs."
)

if hl_mask >= hl_data:
Collaborator

same here.

assert False, "hl_mask must be less than hl_data for HEALPix masking."

num_data_cells = 12 * (4**hl_data)
if len(token_lens) != num_data_cells:
Collaborator

same

Contributor Author

Thank you for pointing these out


# if masking_rate_sampling is enabled, sample the rate from a normal distribution.
if self.masking_rate_sampling:
rate = np.clip(
Collaborator

this chunk is copied over. can you factorize them into a single class method? I am sure we will need to adjust this formula eventually, better to adjust it once.

Contributor Author

There is an existing issue closely related to this, will do it at some point in a separate PR
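For reference, the factoring requested here could look roughly like the sketch below: one method that every strategy calls to obtain the rate. The normal distribution centred at `masking_rate` and the use of `np.clip` come from the quoted code and config comment. The spread (0.1), the clipping bounds (0.01, 0.99), the `seed` argument and the method name `_sample_rate` are placeholder assumptions.

```python
import numpy as np


class Masker:
    """Reduced sketch: only the parts needed to show a shared rate-sampling method."""

    def __init__(self, masking_rate: float, masking_rate_sampling: bool, seed: int = 0):
        self.masking_rate = masking_rate
        self.masking_rate_sampling = masking_rate_sampling
        self.rng = np.random.default_rng(seed)  # seed argument is an assumption

    def _sample_rate(self) -> float:
        """Return the masking rate to use for one batch."""
        if not self.masking_rate_sampling:
            return self.masking_rate
        # Normal distribution centred at masking_rate; spread and bounds are placeholders.
        return float(np.clip(self.rng.normal(self.masking_rate, 0.1), 0.01, 0.99))


masker = Masker(masking_rate=0.6, masking_rate_sampling=True)
rate = masker._sample_rate()  # every strategy would call this one method
print(0.01 <= rate <= 0.99)
```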

# sample the masking rate (with normal distribution centered at masking_rate)
masking_rate_sampling: True
# sample a subset of all target points, useful e.g. to reduce memory requirements
sampling_rate_target: 1.0
# include a masking strategy here, currently only supporting "random" and "block"
masking_strategy: "random"


Collaborator

can you put a commented out example of the healpix strategy? Right now, we do not have documentation for what is expected in hl_data and hl_mask

Contributor Author

Yes, good point, thank you! Added!


rate (float): The desired masking rate, applied to the parent cells.

Returns:
np.ndarray: A flat boolean array (the token-level mask).
Collaborator

what is the len() of this array? the indexing is unclear to me

Collaborator

number of tokens
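A small sketch may help with the indexing question. It assumes the token-level mask is obtained by repeating each data cell's flag once per token in that cell, which is an assumption about the implementation and is not confirmed in this thread. The values below are made up for illustration.

```python
import numpy as np

# Tokens per data cell (cells may hold zero tokens) and a cell-level mask.
token_lens = [2, 0, 3, 1]
cell_mask = np.array([True, False, False, True])

# Repeat each cell's flag once per token: the result has sum(token_lens) entries.
flat_mask = np.repeat(cell_mask, token_lens)
print(len(flat_mask), flat_mask.tolist())
```

Under this assumption the array length equals the total number of tokens in the batch, and entry i refers to the i-th token when tokens are laid out cell by cell.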

):
self.masking_rate = masking_rate
self.masking_strategy = masking_strategy
self.masking_rate_sampling = masking_rate_sampling

# strategy_kwargs is a dictionary that can hold any additional parameters
self.strategy_kwargs = strategy_kwargs
Collaborator

I would call it strategy_config. kwargs is an implementation choice, and it will be later replaced by a class (likely)

@@ -182,7 +182,9 @@ def __init__(
         if cf.training_mode == "forecast":
             self.tokenizer = TokenizerForecast(cf.healpix_level, cf.data_loader_rng_seed)
         elif cf.training_mode == "masking":
-            masker = Masker(cf.masking_rate, cf.masking_strategy, cf.masking_rate_sampling)
+            masker = Masker(
+                cf.masking_rate, cf.masking_strategy, cf.masking_rate_sampling, cf.get("strategy_kwargs", {})
Collaborator

masking_strategy_extra instead of strategy_kwargs ? strategy is vague.

Contributor Author

Changed name to masking_strategy_config

@clessig
Collaborator

clessig commented Jul 6, 2025

  • (not you): the RNG in masking depends on time. This makes the code non-deterministic. This issue to me is larger than biased training. I would love for us to just have a single seed for everything set in config.

In my review I already made a suggestion how to fix it ;)

@tjhunter
Collaborator

tjhunter commented Jul 7, 2025

  • (not you): the RNG in masking depends on time. This makes the code non-deterministic. This issue to me is larger than biased training. I would love for us to just have a single seed for everything set in config.

In my review I already made a suggestion how to fix it ;)

Yes! Somehow it did not appear in the history.

@tjhunter tjhunter added the merge-hold Do not merge this PR yet, being tested. label Jul 16, 2025
@shmh40 shmh40 requested a review from sophie-xhonneux July 21, 2025 14:55
Labels
enhancement New feature or request merge-hold Do not merge this PR yet, being tested. model Related to model training or definition (not generic infra)
Projects
Status: In Progress
Development

Successfully merging this pull request may close these issues.

Masking strategies based on healpix cells
3 participants