diff --git a/pyproject.toml b/pyproject.toml index 398f6661c0..4082a74ba0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,10 +53,10 @@ dependencies = [ # # 8.9.7 [tool.pixi.target.linux-64.pypi-dependencies] -tensorflow = {version = "~=2.16.1", extras = ["and-cuda"] } +tensorflow = {version = "~=2.16.2", extras = ["and-cuda"] } [tool.pixi.target.osx-arm64.dependencies] -tensorflow = {version = "~=2.16.1", channel = "conda-forge"} +tensorflow = {version = "~=2.16.2", channel = "conda-forge"} [project.optional-dependencies] dev = [ diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index f591aff67b..c6b3b84105 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -1066,7 +1066,8 @@ def _run_mirrored_grad( ) for key, value in per_replica_details.items() } - apply_fn(total_grad) + with self.strategy.scope(): + apply_fn(total_grad) return mean_loss_details @tf.function(reduce_retracing=True) @@ -1459,6 +1460,36 @@ def _get_train_fns(self, train_gen=True, train_disc=False): logger.error(msg) raise ValueError(msg) + def _mask_obs_in_exo(self, hi_res_exo): + """Randomly mask a fraction of non-NaN obs values in the exo dict. + + For each key in ``self.obs_features``, ``self._obs_mask_fraction`` of + the existing non-NaN locations are replaced with NaN. All other + keys are returned unchanged. This is called inside + ``get_single_grad_gen`` so that the generator receives sparser + observations than those used by the loss. + + Parameters + ---------- + hi_res_exo : dict + Exogenous high-resolution input dict returned by + ``get_hr_exo_input``. + + Returns + ------- + dict + Copy of ``hi_res_exo`` with obs features partially masked. + """ + out = dict(hi_res_exo) + for k in self.obs_features: + v = out[k] + not_nan = tf.math.logical_not(tf.math.is_nan(v)) + rand = tf.random.uniform(tf.shape(v), dtype=v.dtype) + drop = tf.math.logical_and(not_nan, rand < self._obs_mask_fraction) + nan_fill = tf.fill(tf.shape(v), tf.cast(float('nan'), v.dtype)) + out[k] = tf.where(drop, nan_fill, v) + return out + @tf.function(reduce_retracing=True) def get_single_grad_gen( self, @@ -1470,6 +1501,7 @@ def get_single_grad_gen( """Run generator-only gradient calculation for one mini-batch.""" with self._training_scope(device_name), tf.GradientTape() as tape: hi_res_exo = self.get_hr_exo_input(hi_res_true) + hi_res_exo = self._mask_obs_in_exo(hi_res_exo) hi_res_gen = self._tf_generate(low_res, hi_res_exo) loss, loss_details = self.calc_loss( hi_res_true, hi_res_gen, **calc_loss_kwargs @@ -1477,7 +1509,7 @@ def get_single_grad_gen( grad = tape.gradient(loss, self.generator_weights) return grad, loss_details - @tf.function + @tf.function(reduce_retracing=True) def apply_grad_gen(self, grad): """Apply a generator gradient update.""" self.optimizer.apply_gradients(zip(grad, self.generator_weights)) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 15325d8cc5..8444f9b474 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -39,6 +39,7 @@ def __init__( default_device=None, name=None, sparse_disc=False, + obs_mask_fraction=0.0, ): """ Parameters @@ -105,6 +106,15 @@ def __init__( observations for training. Note that if True, the discriminator model architecture should be designed to handle sparse data (e.g. by using masking layers or other techniques). + obs_mask_fraction : float + Fraction of non-NaN observation values to randomly mask (set to + NaN) in the exogenous obs input to the generator during training. + This is applied *after* ``get_hr_exo_input`` so the loss (which + uses the unmasked ``hi_res_true``) still sees the full observation + density, while the generator must learn to infer spatial structure + from the sparser input. Only features whose key ends with + ``'_obs'`` are masked; e.g. topography is not affected. Default + is 0.0 (no additional masking). name : str | None Optional name for the GAN. """ @@ -146,6 +156,7 @@ def __init__( self._means = means self._stdevs = stdevs self._sparse_disc = sparse_disc + self._obs_mask_fraction = obs_mask_fraction def save(self, out_dir): """Save the GAN with its sub-networks to a directory. @@ -374,7 +385,7 @@ def get_single_grad_disc( grad = tape.gradient(loss, self.discriminator_weights) return grad, loss_details - @tf.function + @tf.function(reduce_retracing=True) def apply_grad_disc(self, grad): """Apply a discriminator gradient update.""" self.optimizer_disc.apply_gradients( @@ -453,6 +464,8 @@ def model_params(self): 'stdevs': stdevs, 'meta': self.meta, 'default_device': self.default_device, + 'sparse_disc': self._sparse_disc, + 'obs_mask_fraction': self._obs_mask_fraction, } @property diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index b542cc9d5e..d937c57c99 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -448,7 +448,7 @@ def run(cls, strategy, node_index): """ if not strategy.node_finished(node_index): logger.info( - 'Starting forward pass on node %s with %s chunks using %s ' + 'Starting forward pass on node %s with %s chunk(s) using %s ' 'execution.', node_index, len(strategy.node_chunks[node_index]), @@ -517,7 +517,7 @@ def _run_serial(cls, strategy, node_index): raise MemoryError(msg) logger.info( - 'Finished forward passes on %s chunks in %s', + 'Finished forward pass(es) on %s chunk(s) in %s', len(strategy.node_chunks[node_index]), dt.now() - start, ) diff --git a/sup3r/pipeline/forward_pass_cli.py b/sup3r/pipeline/forward_pass_cli.py index ff42f4b510..6ac778f8cc 100644 --- a/sup3r/pipeline/forward_pass_cli.py +++ b/sup3r/pipeline/forward_pass_cli.py @@ -58,6 +58,11 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None): basename = config.get('job_name') log_pattern = config.get('log_pattern', None) + logger.info( + 'Initializing forward pass strategy on head node to ' + 'compute chunk distribution indices: %s', + config_file, + ) sig = signature(ForwardPassStrategy) strategy_kwargs = {k: v for k, v in config.items() if k in sig.parameters} strategy = ForwardPassStrategy(**strategy_kwargs, head_node=True) diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index a4f916edcd..98643a54a5 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -355,11 +355,11 @@ def get_time_slices(self): return unpadded_slice, padded_slice def init_input_handler(self): - """Get input handler instance for given input kwargs. If self.head_node - is False or features are being cached we get all requested features. - Otherwise this is part of initialization on a head node and just used - to get the shape of the input domain, so we don't need to get any - features yet.""" + """Get input handler instance for given input kwargs. If + `self.head_node` is False or features are being cached we get all + requested features. Otherwise this is part of initialization on a head + node and just used to get the shape of the input domain, so we don't + need to get any features yet.""" self.input_handler_kwargs = self.input_handler_kwargs or {} self.input_handler_kwargs['file_paths'] = self.file_paths self.input_handler_kwargs['features'] = self.features @@ -374,10 +374,9 @@ def init_input_handler(self): input_handler_kwargs['time_slice'] = self.padded_time_slice logger.info( - 'Initializing %s for %s features over padded time slice %s.', - InputHandler.__name__, + 'Loading low-resolution data for %s features: %s', len(input_handler_kwargs['features']), - self.padded_time_slice, + input_handler_kwargs['features'], ) handler = InputHandler(**input_handler_kwargs) logger.info( @@ -469,7 +468,7 @@ def preflight(self): non_masked = self.fwp_slicer.n_spatial_chunks - sum(self.fwp_mask) non_masked *= self.fwp_slicer.n_time_chunks logger.info( - 'Chunk strategy uses %s nodes across %s total chunks ' + 'Chunk strategy uses %s node(s) across %s chunk(s): ' '(%s spatial x %s temporal, %s unmasked).', len(self.node_chunks), self.fwp_slicer.n_chunks, @@ -689,7 +688,7 @@ def load_exo_data(self, model): data.update(ExoDataHandler(**exo_kwargs).data) exo_data = ExoData(data) if exo_kwargs_list: - logger.info( + logger.debug( 'Finished loading exogenous data for %s features.', len(exo_kwargs_list), ) diff --git a/sup3r/preprocessing/accessor.py b/sup3r/preprocessing/accessor.py index 95bb9dcfe4..3b84663a38 100644 --- a/sup3r/preprocessing/accessor.py +++ b/sup3r/preprocessing/accessor.py @@ -244,11 +244,18 @@ def values(self): return np.asarray(out) return out - def to_dataarray(self) -> Union[np.ndarray, da.core.Array]: + def to_dataarray(self) -> xr.DataArray: """Return xr.DataArray for the contained xr.Dataset.""" if not self.features: - coords = [self._ds[f] for f in Dimension.coords_2d()] - return da.stack(coords, axis=-1) + # xarray raises when to_array() is called on an empty dataset. + # Return a zero-variable DataArray with the correct dims. + spatial_time_dims = tuple( + d for d in Dimension.order() if d in self._ds.sizes + ) + dims = (*spatial_time_dims, Dimension.VARIABLE) + return xr.DataArray( + np.empty(self.shape, dtype=np.float32), dims=dims + ) return self.ordered(self._ds.to_array()) def as_array(self): @@ -259,8 +266,8 @@ def as_array(self): out = self.to_dataarray() out = getattr(out, 'data', out) - if self.loaded: + out = np.asarray(out) self._as_array_cache = out return out @@ -670,8 +677,10 @@ def set_regular_grid(self): np.diff(self._ds[Dimension.LONGITUDE].values, axis=0), 0 ) if not (lat_lon_2d and same_lats and same_lons): - msg = ('Cannot set regular grid for non-regular data. Latitude ' - f'check = {same_lats}, Longitude check = {same_lons}.') + msg = ( + 'Cannot set regular grid for non-regular data. Latitude ' + f'check = {same_lats}, Longitude check = {same_lons}.' + ) logger.warning(msg) warn(msg) else: diff --git a/sup3r/preprocessing/batch_queues/dc.py b/sup3r/preprocessing/batch_queues/dc.py index 237dbf16a5..09c6e46765 100644 --- a/sup3r/preprocessing/batch_queues/dc.py +++ b/sup3r/preprocessing/batch_queues/dc.py @@ -78,17 +78,14 @@ def __init__(self, samplers, n_space_bins=1, n_time_bins=1, **kwargs): Parameters ---------- samplers : list[Sampler] - List of Sampler instances + See :class:`~sup3r.preprocessing.BatchQueueDC` + for full documentation. n_space_bins : int - Number of spatial bins to use for weighted sampling. e.g. if this - is 4 the spatial domain will be divided into 4 equal regions and - losses will be calculated across these regions during traning in - order to adaptively sample from lower performing regions. + See :class:`~sup3r.preprocessing.BatchQueueDC` + for full documentation. n_time_bins : int - Number of time bins to use for weighted sampling. e.g. if this - is 4 the temporal domain will be divided into 4 equal periods and - losses will be calculated across these periods during traning in - order to adaptively sample from lower performing time periods. + See :class:`~sup3r.preprocessing.BatchQueueDC` + for full documentation. kwargs : dict Keyword arguments for parent class. """ diff --git a/sup3r/preprocessing/derivers/base.py b/sup3r/preprocessing/derivers/base.py index f17c0e55a3..88523a3c52 100644 --- a/sup3r/preprocessing/derivers/base.py +++ b/sup3r/preprocessing/derivers/base.py @@ -73,7 +73,7 @@ def __init__( # derivations can reuse them instead of recomputing the same feature. for f in new_features: self.data[f] = self.derive(f) - logger.info('Finished deriving %s.', f) + logger.debug('Finished deriving %s.', f) self.data = ( self.data[list(self.data.coords)] if not features diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index 06a91e19b9..b963c8f74b 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -16,7 +16,11 @@ uniform_time_sampler, ) from sup3r.preprocessing.utilities import compute_if_dask, lowered -from sup3r.utilities.utilities import RANDOM_GENERATOR +from sup3r.utilities.utilities import ( + OUTPUT_ATTRS, + RANDOM_GENERATOR, + get_feature_basename, +) logger = logging.getLogger(__name__) @@ -82,29 +86,31 @@ def __init__( proxy_obs_kwargs : dict | None Optional dictionary of keyword arguments to pass to the proxy observation generator. This is only used when training with proxy - observations. Keys can include ``onshore_obs_frac``, - ``offshore_obs_frac``, and ``perturbation_scale``. + observations. Top-level keys (``onshore_obs_frac``, + ``offshore_obs_frac``, ``perturbation_scale``) apply to all obs + features as defaults. A source-feature-named sub-dict (keyed by + the gridded feature name, e.g. ``u_100m`` for ``u_100m_obs``) + overrides any of those keys for that specific feature:: + + proxy_obs_kwargs = { + 'onshore_obs_frac': {'spatial': [0.3, 0.7], 'temporal': 1}, + 'perturbation_scale': 0.01, + 'u_100m': { + 'onshore_obs_frac': {'spatial': 0.9}, + 'perturbation_scale': 0.05, + }, + } perturbation_scale : float - Scale of the perturbation to add to the proxy observations when - using proxy observations. This specifies the multiplier of the - noise sampled from (-standard deviation, standard deviation). - The standdard deviation is calculated per feature over each - batch. - onshore_obs_frac : float | dict - Fraction of onshore observations to include in each batch when - using proxy observations. This can be a single float or a - dictionary with keys 'spatial' and 'temporal' to specify the - fraction for each domain. If a dictionary is provided, the - actual fraction for each batch will be sampled uniformly - between the specified spatial and temporal fractions. - offshore_obs_frac : float | dict - Fraction of offshore observations to include in each batch when - using proxy observations. This can be a single float or a - dictionary with keys 'spatial' and 'temporal' to specify the - fraction for each domain. If a dictionary is provided, the - actual fraction for each batch will be sampled uniformly - between the specified spatial and temporal fractions. + If non-zero, uniform noise scaled by this value times the + per-feature batch standard deviation is added to proxy obs. + onshore_obs_frac : dict + Fraction of onshore observations per batch. Keys are + 'spatial' and 'temporal'. Each value is a float (fixed + fraction) or a [min, max] list to sample uniformly per batch. + offshore_obs_frac : dict + Same as ``onshore_obs_frac`` but applied where topography + <= 0. Ignored when topography is not a source feature. mode : str Mode for sampling data. Options are 'lazy' or 'eager'. 'eager' mode pre-loads all data into memory as numpy arrays for faster access. @@ -136,35 +142,6 @@ def use_proxy_obs(self): """ return bool(self.proxy_obs_kwargs) - @property - def onshore_obs_frac(self): - """Fraction of onshore observations to include in each batch when using - proxy observations. This can be a single float or a dictionary with - keys 'spatial' and 'temporal' to specify the fraction for each domain. - If a dictionary is provided, the actual fraction for each batch will be - sampled uniformly between the specified spatial and temporal fractions. - """ - return self.proxy_obs_kwargs.get('onshore_obs_frac', {}) - - @property - def offshore_obs_frac(self): - """Fraction of offshore observations to include in each batch when - using proxy observations. This can be a single float or a dictionary - with keys 'spatial' and 'temporal' to specify the fraction for each - domain. If a dictionary is provided, the actual fraction for each - batch will be sampled uniformly between the specified spatial and - temporal fractions. - """ - return self.proxy_obs_kwargs.get('offshore_obs_frac', {}) - - @property - def perturbation_scale(self): - """Scale of the perturbation to add to the proxy observations when - using proxy observations. This specifies the multiplier of the noise - sampled from (-standard deviation, standard deviation). - """ - return self.proxy_obs_kwargs.get('perturbation_scale', 0.01) - def get_sample_index(self, n_obs=None): """Randomly gets spatiotemporal sample index. @@ -218,12 +195,15 @@ def preflight(self): 'building batches with n_samples = batch_size, each with ' 'n_time_steps = sample_shape[2].' ) - if self.data.shape[2] < self.sample_shape[2] * self.batch_size: + if ( + self.data.shape[2] < self.sample_shape[2] * self.batch_size + and self.data.shape[2] > 1 + ): logger.warning(msg) warn(msg) if self.mode == 'eager': - logger.info('Received mode = "eager".') + logger.debug('Received mode = "eager".') _ = self.compute() def check_proxy_obs_consistency(self): @@ -430,12 +410,14 @@ def _fast_batch_possible(self): def _get_proxy_obs(self, hi_res): """Generate proxy observation data by masking the gridded high-res - data. Adds a perturbation to the proxy observations sampled from a - gaussian distribution with mean 0 and standard deviation equal to the - standard deviation of the unmasked values for each feature. This is - done to prevent the model from learning to ignore the obs features - because they are exactly the same as the gridded features at the - observed locations. Unobserved locations are set to NaN. + data. Optionally adds a perturbation to the proxy observations sampled + from a gaussian distribution with mean 0 and standard deviation equal + to the standard deviation of the unmasked values for each feature + multiplied by perturbation_scale. This is done to prevent the model + from learning to ignore the obs features because they are exactly the + same as the gridded features at the observed locations. This can also + encourage the model to condition on obs that differ significantly from + the gridded data. Unobserved locations are set to NaN. Parameters ---------- @@ -451,11 +433,22 @@ def _get_proxy_obs(self, hi_res): """ obs_mask = self._get_full_obs_mask(hi_res) obs = hi_res[..., self.obs_features_ind].copy() + stds = np.std(obs, axis=(1, 2, 3), keepdims=True) obs[obs_mask[..., : obs.shape[-1]]] = np.nan - if self.perturbation_scale > 0: - stdev = np.nanstd(obs, axis=(0, 1, 2, 3), keepdims=True) - noise = np.random.uniform(-stdev, stdev) - obs += self.perturbation_scale * noise + for i, feat in enumerate(self.obs_features): + scale = self._get_proxy_kwarg('perturbation_scale', feat, 0) + if scale > 0: + srange = stds[..., i] * scale + obs[..., i] += np.random.normal(scale=srange) + base = get_feature_basename(feat.replace('_obs', '')) + attrs = OUTPUT_ATTRS.get(base, {}) + lo = attrs.get('min', -np.inf) + hi = attrs.get('max', np.inf) + obs[..., i] = np.where( + np.isnan(obs[..., i]), + obs[..., i], + np.clip(obs[..., i], lo, hi), + ) return obs def _append_obs_features(self, samples): @@ -641,6 +634,28 @@ def obs_features_ind(self): ) return [self.hr_source_features.index(f) for f in check_feats] + def _get_proxy_kwarg(self, key, feat, default): + """Get a proxy obs kwarg value for a specific obs feature, with + fallback to the global default in ``proxy_obs_kwargs``. + + Parameters + ---------- + key : str + The kwarg name, e.g. ``'onshore_obs_frac'`` or + ``'perturbation_scale'``. + feat : str + The obs feature name (e.g. ``'u_100m_obs'``). The ``'_obs'`` + suffix is stripped to look up the source-feature override key. + default : + Value returned when neither a feature-level nor a global entry + exists in ``proxy_obs_kwargs``. + """ + src = feat.replace('_obs', '') + feat_overrides = self.proxy_obs_kwargs.get(src, {}) + if key in feat_overrides: + return feat_overrides[key] + return self.proxy_obs_kwargs.get(key, default) + def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): """Get observation mask for a given spatial and time obs fraction for an entire batch. This is divided between spatial and time fractions @@ -666,16 +681,14 @@ def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): ------- np.ndarray Mask which is True for locations that are not observed and False - for locations that are observed. - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) + for locations that are observed. Shape: + (n_obs, spatial_1, spatial_2, n_temporal, 1) Notes ----- - The output mask is repeated along the feature dimension, so each - feature will have the same observation mask. The output mask is not - repeated along the batch dimension, so each sample in the batch will - have a different observation mask. + The output mask has a trailing singleton feature dimension. Callers + are responsible for repeating or concatenating across features. + Each sample in the batch has an independent mask. """ s_range = ( spatial_frac @@ -688,7 +701,6 @@ def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): else [time_frac, time_frac] ) n_obs, n_spatial_1, n_spatial_2, n_temporal = hi_res.shape[:-1] - n_features = len(self.obs_features) s_fracs = RANDOM_GENERATOR.uniform(*s_range, size=n_obs) t_fracs = RANDOM_GENERATOR.uniform(*t_range, size=n_obs) @@ -706,21 +718,57 @@ def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): t_mask = t_mask[:, None, None, :, None] mask = ~(s_mask & t_mask) - return np.repeat(mask, n_features, axis=-1) + return mask + + def _get_topo(self, hi_res): + """Return the topography slice from ``hi_res``, or ``None`` if + topography is not in the source features.""" + if 'topography' not in self.hr_source_features: + return None + topo_idx = self.hr_source_features.index('topography') + return hi_res[..., topo_idx] + + def _get_feat_obs_mask(self, hi_res, feat, topo): + """Build the observation mask for a single obs feature, applying the + offshore mask where topography is non-positive when ``topo`` is + provided. + + Parameters + ---------- + hi_res : np.ndarray + High-resolution batch data. + feat : str + Obs feature name (e.g. ``'u_100m_obs'``). + topo : np.ndarray | None + Topography array with shape ``(n_obs, s1, s2, n_temporal)``, or + ``None`` when topography is unavailable. + + Returns + ------- + np.ndarray + Boolean mask with shape ``(n_obs, s1, s2, n_temporal, 1)``. + """ + on_frac = self._get_proxy_kwarg('onshore_obs_frac', feat, {}) + on_sf = on_frac.get('spatial', 0.0) + on_tf = on_frac.get('temporal', 1.0) + feat_mask = self._get_obs_mask(hi_res, on_sf, on_tf) + if topo is None: + return feat_mask + off_frac = self._get_proxy_kwarg('offshore_obs_frac', feat, {}) + if not off_frac: + return feat_mask + off_sf = off_frac.get('spatial', 0.0) + off_tf = off_frac.get('temporal', 1.0) + offshore_mask = self._get_obs_mask(hi_res, off_sf, off_tf) + return np.where(topo[..., None] > 0, feat_mask, offshore_mask) def _get_full_obs_mask(self, hi_res): - """Define observation mask for the current batch. This differs from - ``_get_obs_mask`` by defining a composite mask based on separate - onshore and offshore masks. This is because there is often more - observation data available onshore than offshore.""" - on_sf = self.onshore_obs_frac.get('spatial', 0.0) - on_tf = self.onshore_obs_frac.get('temporal', 1.0) - obs_mask = self._get_obs_mask(hi_res, on_sf, on_tf) - if 'topography' in self.hr_source_features and self.offshore_obs_frac: - topo_idx = self.hr_source_features.index('topography') - topo = hi_res[..., topo_idx] - off_sf = self.offshore_obs_frac.get('spatial', 0.0) - off_tf = self.offshore_obs_frac.get('temporal', 1.0) - offshore_mask = self._get_obs_mask(hi_res, off_sf, off_tf) - obs_mask = np.where(topo[..., None] > 0, obs_mask, offshore_mask) - return obs_mask + """Define observation mask for the current batch. Builds a per-feature + composite mask that applies separate onshore and offshore fractions and + supports per-feature ``proxy_obs_kwargs`` overrides.""" + topo = self._get_topo(hi_res) + per_feat_masks = [ + self._get_feat_obs_mask(hi_res, feat, topo) + for feat in self.obs_features + ] + return np.concatenate(per_feat_masks, axis=-1) diff --git a/sup3r/preprocessing/samplers/cc.py b/sup3r/preprocessing/samplers/cc.py index ba36b5d8c7..aa3ed747e2 100644 --- a/sup3r/preprocessing/samplers/cc.py +++ b/sup3r/preprocessing/samplers/cc.py @@ -42,48 +42,22 @@ def __init__( Parameters ---------- data : Sup3rDataset - A :class:`~sup3r.preprocessing.Sup3rDataset` instance with low-res - and high-res data members + See :class:`~sup3r.preprocessing.DualSampler` + for full documentation. sample_shape : tuple - Size of arrays to sample from the high-res data. The sample shape - for the low-res sampler will be determined from the enhancement - factors. + See :class:`~sup3r.preprocessing.DualSampler` + for full documentation. s_enhance : int - Spatial enhancement factor + See :class:`~sup3r.preprocessing.DualSampler` + for full documentation. t_enhance : int - Temporal enhancement factor + Temporal enhancement factor. Defaults to 24 for daily data. feature_sets : Optional[dict] - Optional dictionary describing how the full set of features is - split between ``lr_features``, ``hr_exo_features``, and - ``hr_out_features``. - - lr_features : list | tuple - List of feature names or patt*erns to use as low-resolution - model inputs. If no entry is provided then all available - features from the data will be used. - hr_out_features : list | tuple - List of feature names or patt*erns that should be output - by the generative model and available as ground truth targets. - If no entry is provided then all features in lr_features will - be used. - hr_exo_features : list | tuple - List of feature names or patt*erns that should be available as - high-resolution model inputs (like topography or observations) - or for bespoke loss functions. Features used as inputs are - injected into the model mid-network to condition output on - high-resolution information. The model configuration should - have the appropriate layers to use these features. e.g. - ``Sup3rConcat`` for topography injection, ``Sup3rObsModel`` or - ``Sup3rCrossAttention`` for obs injection. If no entry is - provided then hr_exo_features will be empty. - - *To include sparse features as inputs or targets the features - must have an "_obs" suffix. + See :class:`~sup3r.preprocessing.DualSampler` + for full documentation. mode : str - Mode for sampling data. Options are 'lazy' or 'eager'. 'eager' mode - pre-loads all data into memory as numpy arrays for faster access. - 'lazy' mode samples directly from the underlying data object, which - could be backed by dask arrays or on-disk netCDF files. + See :class:`~sup3r.preprocessing.DualSampler` + for full documentation. See Also -------- @@ -99,10 +73,12 @@ def __init__( if t_enhance == 1: hr = data.daily if s_enhance > 1: - lr = lr.coarsen({ - Dimension.SOUTH_NORTH: s_enhance, - Dimension.WEST_EAST: s_enhance, - }).mean() + lr = lr.coarsen( + { + Dimension.SOUTH_NORTH: s_enhance, + Dimension.WEST_EAST: s_enhance, + } + ).mean() data = Sup3rDataset(low_res=lr, high_res=hr) super().__init__( data=data, diff --git a/sup3r/preprocessing/samplers/dc.py b/sup3r/preprocessing/samplers/dc.py index 7143a46e73..aa7e59e0ac 100644 --- a/sup3r/preprocessing/samplers/dc.py +++ b/sup3r/preprocessing/samplers/dc.py @@ -42,49 +42,15 @@ def __init__( Parameters ---------- data : Union[Sup3rX, Sup3rDataset], - Object with data that will be sampled from. Usually the `.data` - attribute of various :class:`Container` objects. i.e. - :class:`Loader`, :class:`Rasterizer`, :class:`Deriver`, as long as - the spatial dimensions are not flattened. + See :class:`~sup3r.preprocessing.Sampler` for full documentation. sample_shape : tuple - Size of arrays to sample from the contained data. + See :class:`~sup3r.preprocessing.Sampler` for full documentation. batch_size : int - Number of samples to get to build a single batch. A sample of - (sample_shape[0], sample_shape[1], batch_size * sample_shape[2]) - is first selected from underlying dataset and then reshaped into - (batch_size, *sample_shape) to get a single batch. This is more - efficient than getting N = batch_size samples and then stacking. + See :class:`~sup3r.preprocessing.Sampler` for full documentation. feature_sets : Optional[dict] - Optional dictionary describing how the full set of features is - split between ``lr_features``, ``hr_exo_features``, and - ``hr_out_features``. - - lr_features : list | tuple - List of feature names or patt*erns to use as low-resolution - model inputs. If no entry is provided then all available - features from the data will be used. - hr_out_features : list | tuple - List of feature names or patt*erns that should be output - by the generative model and available as ground truth targets. - If no entry is provided then all features in lr_features will - be used. - hr_exo_features : list | tuple - List of feature names or patt*erns that should be available as - high-resolution model inputs (like topography or observations) - or for bespoke loss functions. Features used as inputs are - injected into the model mid-network to condition output on - high-resolution information. The model configuration should - have the appropriate layers to use these features. e.g. - ``Sup3rConcat`` for topography injection, ``Sup3rObsModel`` or - ``Sup3rCrossAttention`` for obs injection. If no entry is - provided then hr_exo_features will be empty. - - - *To include sparse features as inputs or targets the features - must have an "_obs" suffix. + See :class:`~sup3r.preprocessing.Sampler` for full documentation. mode : str - Loading mode for sampling. - See :class:`~sup3r.preprocessing.Sampler` + See :class:`~sup3r.preprocessing.Sampler` for full documentation. spatial_weights : Union[np.ndarray, da.core.Array] | List | None Set of weights used to initialize the spatial sampling. e.g. If we want to start off sampling across 2 spatial bins evenly this should diff --git a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index 3142cdbc70..610e500ff8 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -47,58 +47,9 @@ def __init__( t_enhance : int Temporal enhancement factor feature_sets : Optional[dict] - Optional dictionary describing how the full set of features is - split between ``lr_features``, ``hr_exo_features``, and - ``hr_out_features``. - - lr_features : list | tuple - List of feature names or patt*erns to use as low-resolution - model inputs. If no entry is provided then all available - features from the data will be used. - hr_out_features : list | tuple - List of feature names or patt*erns that should be output - by the generative model and available as ground truth targets. - If no entry is provided then all features in the high res data - will be used. - hr_exo_features : list | tuple - List of feature names or patt*erns that should be available - as high-resolution model inputs (like topography or - observations) or bespoke loss functions. Features used for - input are injected into the model mid-network to condition - output on high-resolution information. The model configuration - should have the appropriate layers to use these features. e.g. - ``Sup3rConcat`` for topography injection, ``Sup3rObsModel`` or - ``Sup3rCrossAttention`` for obs injection. If no entry is - provided then hr_exo_features will be empty. - - *To include sparse features as inputs or targets the features - must have an "_obs" suffix. + See :class:`~sup3r.preprocessing.Sampler` for full documentation. proxy_obs_kwargs : dict | None - Optional dictionary of keyword arguments to pass to the proxy - observation generator. This is only used when training with proxy - observations. Keys can include ``onshore_obs_frac``, - ``offshore_obs_frac``, and ``perturbation_scale``. - - perturbation_scale : float - Scale of the perturbation to add to the proxy observations when - using proxy observations. This specifies the multiplier of the - noise sampled from (-standard deviation, standard deviation). - The standdard deviation is calculated per feature over each - batch. - onshore_obs_frac : float | dict - Fraction of onshore observations to include in each batch when - using proxy observations. This can be a single float or a - dictionary with keys 'spatial' and 'temporal' to specify the - fraction for each domain. If a dictionary is provided, the - actual fraction for each batch will be sampled uniformly - between the specified spatial and temporal fractions. - offshore_obs_frac : float | dict - Fraction of offshore observations to include in each batch when - using proxy observations. This can be a single float or a - dictionary with keys 'spatial' and 'temporal' to specify the - fraction for each domain. If a dictionary is provided, the - actual fraction for each batch will be sampled uniformly - between the specified spatial and temporal fractions. + See :class:`~sup3r.preprocessing.Sampler` for full documentation. mode : str Mode for sampling data. Options are 'lazy' or 'eager'. 'eager' mode pre-loads all data into memory as numpy arrays for faster access. diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py index 2278d622e6..eca81f80be 100644 --- a/sup3r/utilities/loss_metrics.py +++ b/sup3r/utilities/loss_metrics.py @@ -617,8 +617,7 @@ def call(self, x_true, x_gen): tf.shape(x_true), tf.shape(x_gen), message=( - 'LowResLoss requires x_true and x_gen to have matching ' - 'shapes' + 'LowResLoss requires x_true and x_gen to have matching shapes' ), ) s_only = x_true.shape.rank == 4 @@ -716,7 +715,6 @@ def call(self, x_true, x_gen): for x_true_f, x_gen_f in zip( tf.unstack(x_true, axis=-1), tf.unstack(x_gen, axis=-1) ): - # VGG input needs 3 RGB channels x_true_f = tf.stack([x_true_f] * 3, axis=-1) x_gen_f = tf.stack([x_gen_f] * 3, axis=-1) @@ -894,15 +892,12 @@ def call(self, x_true, x_gen): x_true = _assert_rank_in(x_true, (5,), msg) x_gen = _assert_rank_in(x_gen, (5,), msg) - x_true_div = tf.stack( - [ - self._compute_md(x_true, feature) - for feature in self.gen_features - ] - ) - x_gen_div = tf.stack( - [self._compute_md(x_gen, feature) for feature in self.gen_features] - ) + x_true_div = tf.stack([ + self._compute_md(x_true, feature) for feature in self.gen_features + ]) + x_gen_div = tf.stack([ + self._compute_md(x_gen, feature) for feature in self.gen_features + ]) return self.LOSS_METRIC(x_true_div, x_gen_div) @@ -1312,6 +1307,212 @@ def call(self, x_true, x_gen): return obs_loss +class ObsAssimilationLoss(Sup3rLoss): + """Loss for training with both dense and sparse ground truth where the obs + locations are explicitly considered. This is designed to encourage + matching observations where they exist while blending smoothly back to + the dense background field over a Gaussian neighbourhood around each + observation. + + Assumes ``true_features`` contains ``gen_features`` followed by sparse + observation versions of those same features in matching order. For + example:: + + gen_features = ['u_10m', 'v_10m'] + true_features = ['u_10m', 'v_10m', 'u_10m_obs', 'v_10m_obs'] + + The loss is a single MAE term between generator output and a blended + target field: + + 1. At obs locations, the target is the observed value. + 2. Away from obs, the target relaxes back toward the dense background. + 3. Between the two, a Gaussian kernel determines the observation + influence out to ``blend_distance`` grid cells. + """ + + LOSS_METRIC = MeanAbsoluteError() + + def __init__(self, gen_features, true_features=None, blend_distance=1): + """ + Parameters + ---------- + gen_features : list of str + Generator output feature names (N features). + true_features : list of str | None + True-data feature names. Must contain 2N entries: the first N + match ``gen_features`` (dense reference) and the last N are the + corresponding sparse observation features (NaN where missing). + Defaults to ``gen_features + [f + '_obs' for f in gen_features]``. + blend_distance : int + Spatial distance in grid cells over which the Gaussian + observation/background blend is applied. A value of 0 uses the + observation exactly at observed cells and the background + everywhere else. + """ + gen_features = list(gen_features) + if true_features is None: + true_features = gen_features + [f + '_obs' for f in gen_features] + true_features = list(true_features) + + n = len(gen_features) + if len(true_features) != 2 * n: + raise ValueError( + 'ObsBlendLoss requires len(true_features) == ' + f'2 * len(gen_features). Got {len(true_features)} true ' + f'features and {n} gen features.' + ) + + super().__init__( + gen_features=gen_features, true_features=true_features + ) + self._blend_distance = blend_distance + + def _get_gaussian_kernel(self, dtype): + """Get a 2-D Gaussian kernel for observation blending. + + Parameters + ---------- + dtype : tf.DType + Tensor dtype for the kernel. + + Returns + ------- + tf.Tensor + Gaussian kernel with shape ``(K, K)`` and center weight 1. + """ + r = self._blend_distance + if r == 0: + return tf.ones((1, 1), dtype=dtype) + + sigma = tf.cast(r / 2.0, dtype) + coords = tf.cast(tf.range(-r, r + 1), dtype) + weights = tf.exp(-0.5 * tf.square(coords / sigma)) + return tf.expand_dims(weights, 0) * tf.expand_dims(weights, 1) + + @staticmethod + def _to_spatial_4d(x): + """Reshape 4-D or 5-D tensors to 4-D for spatial filtering. + + Parameters + ---------- + x : tf.Tensor + Tensor with shape ``(B, H, W, C)`` or ``(B, H, W, T, C)``. + + Returns + ------- + tuple[tf.Tensor, tf.Tensor | None] + Reshaped 4-D tensor and the original shape tensor when input was + 5-D. The original shape is ``None`` for 4-D inputs. + """ + if x.shape.rank == 4: + return x, None + + shape = tf.shape(x) + x = tf.transpose(x, [0, 3, 1, 2, 4]) + return tf.reshape( + x, (shape[0] * shape[3], shape[1], shape[2], shape[4]) + ), shape + + @staticmethod + def _from_spatial_4d(x, original_shape): + """Restore a filtered tensor to its original 4-D or 5-D shape.""" + if original_shape is None: + return x + + x = tf.reshape( + x, + ( + original_shape[0], + original_shape[3], + original_shape[1], + original_shape[2], + original_shape[4], + ), + ) + return tf.transpose(x, [0, 2, 3, 1, 4]) + + def _gaussian_blend_target(self, x_true_bg, x_obs, obs_mask): + """Blend sparse observations into the dense background field. + + Parameters + ---------- + x_true_bg : tf.Tensor + Dense background tensor with shape ``(B, H, W, C)`` or + ``(B, H, W, T, C)``. + x_obs : tf.Tensor + Sparse observation tensor matching ``x_true_bg`` with NaNs where + observations are missing. + obs_mask : tf.Tensor + Boolean mask matching ``x_obs`` where True marks valid obs. + + Returns + ------- + tf.Tensor + Blended target tensor with no NaNs. + """ + if self._blend_distance == 0: + return tf.where(obs_mask, x_obs, x_true_bg) + + x_true_bg_4d, original_shape = self._to_spatial_4d(x_true_bg) + x_obs_4d, _ = self._to_spatial_4d(x_obs) + obs_mask_4d, _ = self._to_spatial_4d( + tf.cast(obs_mask, x_true_bg.dtype) + ) + + kernel = self._get_gaussian_kernel(x_true_bg.dtype) + channels = tf.shape(x_true_bg_4d)[-1] + kernel = tf.tile( + kernel[..., tf.newaxis, tf.newaxis], [1, 1, channels, 1] + ) + + obs_values = tf.where( + obs_mask_4d > 0, x_obs_4d, tf.zeros_like(x_obs_4d) + ) + obs_weight = tf.nn.depthwise_conv2d( + obs_mask_4d, kernel, strides=[1, 1, 1, 1], padding='SAME' + ) + obs_smooth = tf.nn.depthwise_conv2d( + obs_values, kernel, strides=[1, 1, 1, 1], padding='SAME' + ) + obs_blend = tf.math.divide_no_nan(obs_smooth, obs_weight) + blend_weight = tf.minimum(obs_weight, tf.ones_like(obs_weight)) + target = blend_weight * obs_blend + (1.0 - blend_weight) * x_true_bg_4d + target = tf.where(obs_mask_4d > 0, x_obs_4d, target) + return self._from_spatial_4d(target, original_shape) + + @tf.function + def call(self, x_true, x_gen): + """Evaluate the sparse observation loss. + + Parameters + ---------- + x_true : tf.Tensor + True data of shape ``(B, H, W, 2C)`` or ``(B, H, W, T, 2C)``. + The first C channels are dense reference fields; the last C + channels are sparse obs fields (NaN where no observation). + x_gen : tf.Tensor + Generator output of shape ``(B, H, W, C)`` or ``(B, H, W, T, C)``. + + Returns + ------- + tf.Tensor + Scalar (0-D) loss value. + """ + dtype = tf.as_dtype(tf.keras.backend.floatx()) + x_true = tf.cast(x_true, dtype) + x_gen = tf.cast(x_gen, dtype) + + n = len(self.gen_features) + x_true_bg = x_true[..., :n] # dense reference, shape (..., n) + x_obs = x_true[ + ..., n: + ] # sparse obs, shape (..., n); NaN where missing + + obs_mask = ~tf.math.is_nan(x_obs) + target = self._gaussian_blend_target(x_true_bg, x_obs, obs_mask) + return self.LOSS_METRIC(target, x_gen) + + def _reshape_depth_feature_for_vertical_derivative(x): """Reshape a stacked depth tensor for use with tf_derivative. diff --git a/tests/data_wrapper/test_access.py b/tests/data_wrapper/test_access.py index 6d35e10592..93c0662631 100644 --- a/tests/data_wrapper/test_access.py +++ b/tests/data_wrapper/test_access.py @@ -6,7 +6,7 @@ import pytest import xarray as xr -from sup3r.preprocessing import Dimension, Loader +from sup3r.preprocessing import Loader from sup3r.preprocessing.accessor import Sup3rX from sup3r.preprocessing.base import Sup3rDataset from sup3r.utilities.pytest.helpers import ( @@ -70,9 +70,7 @@ def test_correct_single_member_access(data): _ = data['u', 0, 0, 0, 0] assert data['u'][0, 0, 0, 0].shape == () assert ['u', 'v'] in data - out = data[[Dimension.LATITUDE, Dimension.LONGITUDE]][:] - assert out.shape == (20, 20, 2) - assert np.array_equal(np.asarray(out), np.asarray(data.lat_lon)) + assert data.lat_lon.shape == (20, 20, 2) assert len(data.time_index) == 100 out = data.isel(time=slice(0, 10)) assert out.sx.as_array().shape == (20, 20, 10, 3, 2) @@ -104,14 +102,9 @@ def test_correct_multi_member_access(): _ = data['u'] _ = data[['u', 'v']] - out = data[[Dimension.LATITUDE, Dimension.LONGITUDE]][...] lat_lon = data.lat_lon time_index = data.time_index - assert all(o.shape == (20, 20, 2) for o in out) - assert all( - np.array_equal(np.asarray(o), np.asarray(ll)) - for o, ll in zip(out, lat_lon) - ) + assert all(o.shape == (20, 20, 2) for o in lat_lon) assert all(len(ti) == 100 for ti in time_index) out = data.isel(time=slice(0, 10)) assert (o.as_array().shape == (20, 20, 10, 3, 2) for o in out) diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 8a320aaa60..e1c690ed8c 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -926,3 +926,87 @@ def test_slicing_pad(input_files): assert chunk.input_data.shape == padded_truth.shape assert np.allclose(chunk.input_data, padded_truth) + + +def test_fwp_all_features_as_exo(input_files): + """Test that ForwardPassStrategy and ForwardPass correctly handles the edge + case where all low-resolution model features are provided through + ``exo_handler_kwargs``. + + When all lr features are in ``exo_handler_kwargs``, ``_init_features`` + returns ``features=[]``. The ``DataHandler`` is then initialized with no + features, so ``input_handler[[]].isel(...).as_array()`` must return an + array shaped ``(s1, s2, t, 0)`` (zero feature channels). However, the + ``to_dataarray()`` fallback in ``accessor.py`` currently returns + latitude/longitude coordinate arrays stacked as ``(s1, s2, 2)`` instead, + which breaks downstream exo temporal-expansion logic in + ``pad_source_data`` (``input_data.shape[2]`` yields 2 rather than the + number of time steps). + """ + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + + Sup3rGan.seed() + model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + model.meta['lr_features'] = FEATURES + model.meta['hr_out_features'] = FEATURES + model.meta['s_enhance'] = s_enhance + model.meta['t_enhance'] = t_enhance + _ = model.generate(np.ones((4, 10, 10, 12, len(FEATURES)))) + + with tempfile.TemporaryDirectory() as td: + model_dir = os.path.join(td, 'model') + model.save(model_dir) + + # Route all lr features through exo_handler_kwargs so that + # _init_features returns features=[]. + exo_handler_kwargs = {f: {'file_paths': input_files} for f in FEATURES} + input_handler_kwargs = { + 'target': target, + 'shape': shape, + 'time_slice': time_slice, + } + + # head_node=True skips load_exo_data (no cache_dir in exo kwargs, + # max_nodes=1), letting us test the _init_features / input_data shape + # logic without needing fully-configured ExoDataHandler source files. + strat = ForwardPassStrategy( + input_files, + model_kwargs={'model_dir': model_dir}, + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=0, + temporal_pad=0, + input_handler_kwargs=input_handler_kwargs, + exo_handler_kwargs=exo_handler_kwargs, + head_node=True, + max_nodes=1, + ) + + # _init_features must return [] when all lr features are exo. + assert strat.features == [], ( + f'Expected features=[] when all lr features are in ' + f'exo_handler_kwargs, got features={strat.features}' + ) + + # Replicate the slice kwargs built by prep_chunk_data for chunk 0. + lr_pad_slice = strat.lr_pad_slices[0] + ti_pad_slice = strat.ti_pad_slices[0] + kwargs = dict(zip(Dimension.dims_2d(), lr_pad_slice)) + kwargs[Dimension.TIME] = ti_pad_slice + + input_data = strat.input_handler[strat.features].isel(**kwargs) + input_data.load() + arr = input_data.as_array() + + # input_data should carry 0 feature channels with a valid time axis, + # not lat/lon coordinate arrays collapsed into the last dimension. + assert arr.ndim == 4, ( + f'input_data.as_array() returned a {arr.ndim}D array with shape ' + f'{arr.shape}; expected 4D (s1, s2, t, 0). This indicates that ' + 'to_dataarray() falls back to returning coordinate arrays ' + '(lat/lon) when features=[], producing shape (s1, s2, 2) instead ' + 'of (s1, s2, t, 0).' + ) + assert arr.shape[-1] == 0, ( + f'Expected 0 feature channels (features=[]), got shape {arr.shape}' + ) diff --git a/tests/samplers/test_with_obs.py b/tests/samplers/test_with_obs.py index d1f1a20197..b88db0df63 100644 --- a/tests/samplers/test_with_obs.py +++ b/tests/samplers/test_with_obs.py @@ -152,3 +152,31 @@ def test_proxy_obs_onshore_offshore_topography_fractions(sampler_cls): assert np.isclose(onshore_frac, 0.8, atol=0.12) assert np.isclose(offshore_frac, 0.1, atol=0.08) assert onshore_frac > offshore_frac + + +@pytest.mark.parametrize('sampler_cls', [Sampler, DualSampler]) +def test_proxy_obs_per_feature_override(sampler_cls): + """Feature-level override in proxy_obs_kwargs yields different observed + fractions per obs channel.""" + sampler = _make_sampler( + sampler_cls=sampler_cls, + hr_shape=(60, 60, 500), + sample_shape=(30, 30, 1), + batch_size=20, + proxy_obs_kwargs={ + 'onshore_obs_frac': {'spatial': 0.1, 'temporal': 1.0}, + 'u_100m': { + 'onshore_obs_frac': {'spatial': 0.8, 'temporal': 1.0} + }, + }, + ) + + batch = _get_hr_batch(sampler) + obs = batch[..., -2:] + + u_frac = np.isfinite(obs[..., 0]).mean() # u_100m_obs (overridden to 0.8) + v_frac = np.isfinite(obs[..., 1]).mean() # v_100m_obs (global default 0.1) + + assert u_frac > v_frac + assert np.isclose(u_frac, 0.8, atol=0.1) + assert np.isclose(v_frac, 0.1, atol=0.08) diff --git a/tests/utilities/test_loss_metrics.py b/tests/utilities/test_loss_metrics.py index 3fa71c8a62..25e558a4e2 100644 --- a/tests/utilities/test_loss_metrics.py +++ b/tests/utilities/test_loss_metrics.py @@ -17,6 +17,7 @@ LowResLoss, MaterialDerivativeLoss, MmdLoss, + ObsAssimilationLoss, SpatialExtremesLoss, SpatiotemporalFftLoss, TemporalExtremesLoss, @@ -266,6 +267,43 @@ def test_lr_loss(): assert ex_loss > loss +def test_obs_assimilation_loss_no_obs_matches_background_mae(): + """Obs assimilation loss should reduce to background MAE without obs.""" + x_true_bg = np.arange(25, dtype=np.float32).reshape(1, 5, 5, 1) + x_obs = np.full_like(x_true_bg, np.nan) + x_true = np.concatenate([x_true_bg, x_obs], axis=-1) + x_gen = x_true_bg + 2.0 + + loss = ObsAssimilationLoss( + gen_features=['f'], + true_features=['f', 'f_obs'], + blend_distance=2, + )(x_true, x_gen) + + assert np.allclose(loss, MeanAbsoluteError()(x_true_bg, x_gen)) + + +def test_obs_assimilation_loss_blends_obs_into_background(): + """Obs assimilation loss should apply Gaussian blending around obs.""" + x_true_bg = np.zeros((1, 5, 5, 1), dtype=np.float32) + x_obs = np.full_like(x_true_bg, np.nan) + x_obs[:, 2, 2, 0] = 1.0 + x_true = np.concatenate([x_true_bg, x_obs], axis=-1) + x_gen = np.zeros_like(x_true_bg) + + loss = ObsAssimilationLoss( + gen_features=['f'], + true_features=['f', 'f_obs'], + blend_distance=1, + )(x_true, x_gen) + + edge_weight = np.exp(-2.0) + corner_weight = np.exp(-4.0) + expected = (1.0 + 4 * edge_weight + 4 * corner_weight) / 25.0 + + assert np.allclose(loss, expected) + + def test_md_loss(): """Test the material derivative calculation in the material derivative content loss class."""