Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ dependencies = [
# # 8.9.7

[tool.pixi.target.linux-64.pypi-dependencies]
tensorflow = {version = "~=2.16.1", extras = ["and-cuda"] }
tensorflow = {version = "~=2.16.2", extras = ["and-cuda"] }

[tool.pixi.target.osx-arm64.dependencies]
tensorflow = {version = "~=2.16.1", channel = "conda-forge"}
tensorflow = {version = "~=2.16.2", channel = "conda-forge"}

[project.optional-dependencies]
dev = [
Expand Down
36 changes: 34 additions & 2 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,8 @@ def _run_mirrored_grad(
)
for key, value in per_replica_details.items()
}
apply_fn(total_grad)
with self.strategy.scope():
apply_fn(total_grad)
return mean_loss_details

@tf.function(reduce_retracing=True)
Expand Down Expand Up @@ -1459,6 +1460,36 @@ def _get_train_fns(self, train_gen=True, train_disc=False):
logger.error(msg)
raise ValueError(msg)

def _mask_obs_in_exo(self, hi_res_exo):
"""Randomly mask a fraction of non-NaN obs values in the exo dict.

For each key in ``self.obs_features``, ``self._obs_mask_fraction`` of
the existing non-NaN locations are replaced with NaN. All other
keys are returned unchanged. This is called inside
``get_single_grad_gen`` so that the generator receives sparser
observations than those used by the loss.

Parameters
----------
hi_res_exo : dict
Exogenous high-resolution input dict returned by
``get_hr_exo_input``.

Returns
-------
dict
Copy of ``hi_res_exo`` with obs features partially masked.
"""
out = dict(hi_res_exo)
for k in self.obs_features:
v = out[k]
Comment on lines +1483 to +1485
not_nan = tf.math.logical_not(tf.math.is_nan(v))
rand = tf.random.uniform(tf.shape(v), dtype=v.dtype)
drop = tf.math.logical_and(not_nan, rand < self._obs_mask_fraction)
nan_fill = tf.fill(tf.shape(v), tf.cast(float('nan'), v.dtype))
out[k] = tf.where(drop, nan_fill, v)
return out

@tf.function(reduce_retracing=True)
def get_single_grad_gen(
self,
Expand All @@ -1470,14 +1501,15 @@ def get_single_grad_gen(
"""Run generator-only gradient calculation for one mini-batch."""
with self._training_scope(device_name), tf.GradientTape() as tape:
hi_res_exo = self.get_hr_exo_input(hi_res_true)
hi_res_exo = self._mask_obs_in_exo(hi_res_exo)
hi_res_gen = self._tf_generate(low_res, hi_res_exo)
loss, loss_details = self.calc_loss(
hi_res_true, hi_res_gen, **calc_loss_kwargs
)
grad = tape.gradient(loss, self.generator_weights)
return grad, loss_details

@tf.function
@tf.function(reduce_retracing=True)
def apply_grad_gen(self, grad):
"""Apply a generator gradient update."""
self.optimizer.apply_gradients(zip(grad, self.generator_weights))
Expand Down
15 changes: 14 additions & 1 deletion sup3r/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
default_device=None,
name=None,
sparse_disc=False,
obs_mask_fraction=0.0,
):
"""
Parameters
Expand Down Expand Up @@ -105,6 +106,15 @@ def __init__(
observations for training. Note that if True, the discriminator
model architecture should be designed to handle sparse data (e.g.
by using masking layers or other techniques).
obs_mask_fraction : float
Fraction of non-NaN observation values to randomly mask (set to
NaN) in the exogenous obs input to the generator during training.
This is applied *after* ``get_hr_exo_input`` so the loss (which
uses the unmasked ``hi_res_true``) still sees the full observation
density, while the generator must learn to infer spatial structure
from the sparser input. Only features whose key ends with
``'_obs'`` are masked; e.g. topography is not affected. Default
is 0.0 (no additional masking).
name : str | None
Optional name for the GAN.
"""
Expand Down Expand Up @@ -146,6 +156,7 @@ def __init__(
self._means = means
self._stdevs = stdevs
self._sparse_disc = sparse_disc
self._obs_mask_fraction = obs_mask_fraction

def save(self, out_dir):
"""Save the GAN with its sub-networks to a directory.
Expand Down Expand Up @@ -374,7 +385,7 @@ def get_single_grad_disc(
grad = tape.gradient(loss, self.discriminator_weights)
return grad, loss_details

@tf.function
@tf.function(reduce_retracing=True)
def apply_grad_disc(self, grad):
"""Apply a discriminator gradient update."""
self.optimizer_disc.apply_gradients(
Expand Down Expand Up @@ -453,6 +464,8 @@ def model_params(self):
'stdevs': stdevs,
'meta': self.meta,
'default_device': self.default_device,
'sparse_disc': self._sparse_disc,
'obs_mask_fraction': self._obs_mask_fraction,
}

@property
Expand Down
4 changes: 2 additions & 2 deletions sup3r/pipeline/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def run(cls, strategy, node_index):
"""
if not strategy.node_finished(node_index):
logger.info(
'Starting forward pass on node %s with %s chunks using %s '
'Starting forward pass on node %s with %s chunk(s) using %s '
'execution.',
node_index,
len(strategy.node_chunks[node_index]),
Expand Down Expand Up @@ -517,7 +517,7 @@ def _run_serial(cls, strategy, node_index):
raise MemoryError(msg)

logger.info(
'Finished forward passes on %s chunks in %s',
'Finished forward pass(es) on %s chunk(s) in %s',
len(strategy.node_chunks[node_index]),
dt.now() - start,
)
Expand Down
5 changes: 5 additions & 0 deletions sup3r/pipeline/forward_pass_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None):
basename = config.get('job_name')
log_pattern = config.get('log_pattern', None)

logger.info(
'Initializing forward pass strategy on head node to '
'compute chunk distribution indices: %s',
config_file,
)
sig = signature(ForwardPassStrategy)
strategy_kwargs = {k: v for k, v in config.items() if k in sig.parameters}
strategy = ForwardPassStrategy(**strategy_kwargs, head_node=True)
Expand Down
19 changes: 9 additions & 10 deletions sup3r/pipeline/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,11 @@ def get_time_slices(self):
return unpadded_slice, padded_slice

def init_input_handler(self):
"""Get input handler instance for given input kwargs. If self.head_node
is False or features are being cached we get all requested features.
Otherwise this is part of initialization on a head node and just used
to get the shape of the input domain, so we don't need to get any
features yet."""
"""Get input handler instance for given input kwargs. If
`self.head_node` is False or features are being cached we get all
requested features. Otherwise this is part of initialization on a head
node and just used to get the shape of the input domain, so we don't
need to get any features yet."""
self.input_handler_kwargs = self.input_handler_kwargs or {}
self.input_handler_kwargs['file_paths'] = self.file_paths
self.input_handler_kwargs['features'] = self.features
Expand All @@ -374,10 +374,9 @@ def init_input_handler(self):

input_handler_kwargs['time_slice'] = self.padded_time_slice
logger.info(
'Initializing %s for %s features over padded time slice %s.',
InputHandler.__name__,
'Loading low-resolution data for %s features: %s',
len(input_handler_kwargs['features']),
self.padded_time_slice,
input_handler_kwargs['features'],
)
handler = InputHandler(**input_handler_kwargs)
logger.info(
Expand Down Expand Up @@ -469,7 +468,7 @@ def preflight(self):
non_masked = self.fwp_slicer.n_spatial_chunks - sum(self.fwp_mask)
non_masked *= self.fwp_slicer.n_time_chunks
logger.info(
'Chunk strategy uses %s nodes across %s total chunks '
'Chunk strategy uses %s node(s) across %s chunk(s): '
'(%s spatial x %s temporal, %s unmasked).',
len(self.node_chunks),
self.fwp_slicer.n_chunks,
Expand Down Expand Up @@ -689,7 +688,7 @@ def load_exo_data(self, model):
data.update(ExoDataHandler(**exo_kwargs).data)
exo_data = ExoData(data)
if exo_kwargs_list:
logger.info(
logger.debug(
'Finished loading exogenous data for %s features.',
len(exo_kwargs_list),
)
Expand Down
21 changes: 15 additions & 6 deletions sup3r/preprocessing/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,18 @@ def values(self):
return np.asarray(out)
return out

def to_dataarray(self) -> Union[np.ndarray, da.core.Array]:
def to_dataarray(self) -> xr.DataArray:
"""Return xr.DataArray for the contained xr.Dataset."""
if not self.features:
coords = [self._ds[f] for f in Dimension.coords_2d()]
return da.stack(coords, axis=-1)
# xarray raises when to_array() is called on an empty dataset.
# Return a zero-variable DataArray with the correct dims.
spatial_time_dims = tuple(
d for d in Dimension.order() if d in self._ds.sizes
)
dims = (*spatial_time_dims, Dimension.VARIABLE)
return xr.DataArray(
np.empty(self.shape, dtype=np.float32), dims=dims
)
return self.ordered(self._ds.to_array())

def as_array(self):
Expand All @@ -259,8 +266,8 @@ def as_array(self):

out = self.to_dataarray()
out = getattr(out, 'data', out)

if self.loaded:
out = np.asarray(out)
self._as_array_cache = out

return out
Expand Down Expand Up @@ -670,8 +677,10 @@ def set_regular_grid(self):
np.diff(self._ds[Dimension.LONGITUDE].values, axis=0), 0
)
if not (lat_lon_2d and same_lats and same_lons):
msg = ('Cannot set regular grid for non-regular data. Latitude '
f'check = {same_lats}, Longitude check = {same_lons}.')
msg = (
'Cannot set regular grid for non-regular data. Latitude '
f'check = {same_lats}, Longitude check = {same_lons}.'
)
logger.warning(msg)
warn(msg)
else:
Expand Down
15 changes: 6 additions & 9 deletions sup3r/preprocessing/batch_queues/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,14 @@ def __init__(self, samplers, n_space_bins=1, n_time_bins=1, **kwargs):
Parameters
----------
samplers : list[Sampler]
List of Sampler instances
See :class:`~sup3r.preprocessing.BatchQueueDC`
for full documentation.
n_space_bins : int
Number of spatial bins to use for weighted sampling. e.g. if this
is 4 the spatial domain will be divided into 4 equal regions and
losses will be calculated across these regions during traning in
order to adaptively sample from lower performing regions.
See :class:`~sup3r.preprocessing.BatchQueueDC`
for full documentation.
n_time_bins : int
Number of time bins to use for weighted sampling. e.g. if this
is 4 the temporal domain will be divided into 4 equal periods and
losses will be calculated across these periods during traning in
order to adaptively sample from lower performing time periods.
See :class:`~sup3r.preprocessing.BatchQueueDC`
for full documentation.
kwargs : dict
Keyword arguments for parent class.
"""
Expand Down
2 changes: 1 addition & 1 deletion sup3r/preprocessing/derivers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(
# derivations can reuse them instead of recomputing the same feature.
for f in new_features:
self.data[f] = self.derive(f)
logger.info('Finished deriving %s.', f)
logger.debug('Finished deriving %s.', f)
self.data = (
self.data[list(self.data.coords)]
if not features
Expand Down
Loading
Loading