From d718c192db468773c66152dbcd45bd34d4118443 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 25 Apr 2026 17:03:26 -0600 Subject: [PATCH 01/34] feat: implement multi-GPU support and related optimizations in Sup3r models --- sup3r/models/abstract.py | 228 +++++++++++++++++++++++-------- sup3r/models/base.py | 102 +++++++++----- sup3r/models/conditional.py | 15 +- tests/training/test_train_gan.py | 64 +++++++++ 4 files changed, 310 insertions(+), 99 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index b0f61a125..03ebfbbda 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -6,7 +6,7 @@ import pprint import time from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor, as_completed +from contextlib import nullcontext from warnings import warn import numpy as np @@ -36,10 +36,12 @@ class AbstractSingleModel(ABC, TensorboardMixIn): def __init__(self): super().__init__() self.gpu_list = tf.config.list_physical_devices('GPU') - self.default_device = '/cpu:0' if len(self.gpu_list) == 0 else '/gpu:0' + self.default_device = '/gpu:0' if len(self.gpu_list) > 0 else '/cpu:0' self.timer = Timer() self.name = None self.loss_name = None + self._strategy = None + self._multi_gpu = False self._loss_fun = None self._version_record = VERSION_RECORD self._meta = None @@ -73,7 +75,8 @@ def load_network(self, model, name): model = self._load_model_from_string(model, name) if isinstance(model, list): - model = CustomNetwork(hidden_layers=model, name=name) + with self._training_scope(): + model = CustomNetwork(hidden_layers=model, name=name) if not isinstance(model, CustomNetwork): msg = ( @@ -88,7 +91,7 @@ def load_network(self, model, name): def _load_model_from_string(self, model, name): """Load a CustomNetwork object from a config or a .pkl file""" if model.endswith('.pkl'): - with tf.device(self.default_device): + with self._training_scope(self.default_device): return CustomNetwork.load(model) model = 
load_config(model) @@ -111,6 +114,56 @@ def _load_model_from_string(self, model, name): logger.error(msg) raise KeyError(msg) + @property + def strategy(self): + """Optional TensorFlow distribution strategy.""" + return self._strategy + + def configure_multi_gpu(self, multi_gpu=False): + """Configure optional multi-GPU training through MirroredStrategy. + + Parameters + ---------- + multi_gpu : bool + Flag to create a MirroredStrategy automatically when multiple GPUs + are available. + """ + + self._multi_gpu = bool(multi_gpu) + if self._strategy is not None or not multi_gpu: + return self._strategy + + if len(self.gpu_list) > 1: + devices = [f'/gpu:{i}' for i in range(len(self.gpu_list))] + self._strategy = tf.distribute.MirroredStrategy(devices=devices) + else: + self._strategy = None + + if self._strategy is not None: + logger.info( + 'Configured distribution strategy "%s" with %s replicas.', + self._strategy.__class__.__name__, + self._strategy.num_replicas_in_sync, + ) + elif multi_gpu: + logger.warning( + 'multi_gpu=True was requested but fewer than two GPUs are ' + 'available. Falling back to non-strategy execution.' + ) + + def _training_scope(self, device=None): + """Get a strategy scope or a concrete device context.""" + if tf.distribute.get_replica_context() is not None: + if device is not None: + return tf.device(device) + return nullcontext() + + if self.strategy is not None: + return self.strategy.scope() + if device is not None: + return tf.device(device) + return nullcontext() + @property def means(self): """Get the data normalization mean values. 
@@ -284,6 +337,11 @@ def optimizer(self): ------- tf.keras.optimizers.Optimizer """ + if self._optimizer is None: + with self._training_scope(): + self._optimizer = self.init_optimizer( + self._optimizer_config, learning_rate=None + ) return self._optimizer def update_optimizer_gen(self, **kwargs): @@ -294,9 +352,11 @@ def update_optimizer_gen(self, **kwargs): kwargs : dict kwargs to use for optimizer configuration update """ - conf = self.get_optimizer_config(self.optimizer) + conf = self._optimizer_config.copy() conf.update(**kwargs) - self._optimizer = self.optimizer.__class__.from_config(conf) + self._optimizer_config = conf + if self.optimizer is not None: + self._optimizer = self.optimizer.__class__.from_config(conf) @property def history(self): @@ -360,7 +420,7 @@ def _init_generator_weights(self, lr_shape, hr_shape, device=None): exo_tensor = tf.cast(np.ones(exo_shape), dtype=tf.float32) hi_res_exo = dict.fromkeys(self.hr_exo_features, exo_tensor) - with tf.device(device): + with self._training_scope(device): out = self._tf_generate(low_res, hi_res_exo) msg = ( @@ -398,6 +458,23 @@ def init_optimizer(optimizer, learning_rate): return optimizer + @classmethod + def get_optimizer_init_config(cls, optimizer, learning_rate): + """Get a serializable optimizer config from init inputs.""" + if isinstance(optimizer, dict): + conf = optimizer.copy() + else: + conf = cls.get_optimizer_config( + cls.init_optimizer(optimizer, learning_rate) + ) + + for key, value in conf.items(): + if np.issubdtype(type(value), np.floating): + conf[key] = float(value) + elif np.issubdtype(type(value), np.integer): + conf[key] = int(value) + return conf + @staticmethod def load_saved_params(out_dir, verbose=True): """Load saved model_params (you need this and the gen+disc models @@ -906,7 +983,31 @@ def finish_epoch( return stop - def _run_parallel_grad( + def _distribute_value(self, value, num_replicas): + """Split a batch-like tensor and distribute one chunk per replica.""" + value = 
tf.convert_to_tensor(value) + chunks = tf.split(value, num_replicas, axis=0) + return self.strategy.experimental_distribute_values_from_function( + lambda ctx: chunks[ctx.replica_id_in_sync_group] + ) + + def _distribute_calc_loss_kwargs( + self, calc_loss_kwargs, batch_size, num_replicas + ): + """Distribute batch-shaped loss kwargs and pass through scalars.""" + distributed = {} + for key, value in calc_loss_kwargs.items(): + if ( + isinstance(value, (np.ndarray, tf.Tensor)) + and value.shape + and value.shape[0] == batch_size + ): + distributed[key] = self._distribute_value(value, num_replicas) + continue + distributed[key] = value + return distributed + + def _run_mirrored_grad( self, low_res, hi_res_true, @@ -914,46 +1015,54 @@ def _run_parallel_grad( apply_fn, **calc_loss_kwargs, ): - """Compute gradient for one mini-batch of (low_res, hi_res_true) - across multiple GPUs""" - - lr_chunks = tf.split(low_res, len(self.gpu_list), axis=0) - hr_true_chunks = tf.split(hi_res_true, len(self.gpu_list), axis=0) - calc_loss_kwargs_chunks = [ - dict(calc_loss_kwargs) for _ in range(len(self.gpu_list)) - ] - if 'mask' in calc_loss_kwargs: - mask_chunks = tf.split( - calc_loss_kwargs['mask'], len(self.gpu_list), axis=0 + """Compute gradient for one mini-batch using MirroredStrategy.""" + + if self.strategy is None: + msg = ( + 'Mirrored strategy execution requested but no strategy is ' + 'configured on the model.' 
) - for i, mask_chunk in enumerate(mask_chunks): - calc_loss_kwargs_chunks[i]['mask'] = mask_chunk - - futures = [] - with ThreadPoolExecutor(max_workers=len(self.gpu_list)) as exe: - for i in range(len(self.gpu_list)): - futures.append( - exe.submit( - grad_fn, - lr_chunks[i], - hr_true_chunks[i], - device_name=f'/gpu:{i}', - **calc_loss_kwargs_chunks[i], - ) - ) - # sum the gradients from each gpu to weight equally in - # optimizer momentum calculation - grads = [] - details = [] - for future in as_completed(futures): - grad, loss_details = future.result() - grads.append(grad) - details.append(loss_details) + logger.error(msg) + raise RuntimeError(msg) + + num_replicas = self.strategy.num_replicas_in_sync + batch_size = low_res.shape[0] + if batch_size % num_replicas != 0: + msg = ( + 'Batch size must be divisible by the number of mirrored ' + f'replicas. Received batch_size={batch_size} and ' + f'num_replicas={num_replicas}.' + ) + logger.error(msg) + raise ValueError(msg) + + dist_low_res = self._distribute_value(low_res, num_replicas) + dist_hi_res_true = self._distribute_value(hi_res_true, num_replicas) + dist_loss_kwargs = self._distribute_calc_loss_kwargs( + calc_loss_kwargs, batch_size, num_replicas + ) + + per_replica_grad, per_replica_details = self.strategy.run( + grad_fn, + args=(dist_low_res, dist_hi_res_true), + kwargs=dist_loss_kwargs, + ) + total_grad = tf.nest.map_structure( - lambda *x: tf.reduce_sum(x, axis=0), *grads + lambda grad: ( + None + if grad is None + else self.strategy.reduce( + tf.distribute.ReduceOp.MEAN, grad, axis=None + ) + ), + per_replica_grad, ) mean_loss_details = { - k: tf.reduce_mean([d[k] for d in details]) for k in details[0] + key: self.strategy.reduce( + tf.distribute.ReduceOp.MEAN, value, axis=None + ) + for key, value in per_replica_details.items() } apply_fn(total_grad) return mean_loss_details @@ -1006,12 +1115,11 @@ def run_gradient_descent( you're training just the generator or one of the discriminator models. 
Defaults to the generator optimizer. multi_gpu : bool - Flag to break up the batch for parallel gradient descent - calculations on multiple gpus. If True and multiple GPUs are - present, each batch from the batch_handler will be divided up - between the GPUs and resulting gradients from each GPU will be - summed and then applied once per batch at the nominal learning - rate that the model and optimizer were initialized with. + Flag to use multi-GPU distributed training. If True and a + strategy has been configured, the mini-batch will be distributed + across replicas and gradients will be reduced before a single + optimizer update is applied. If True and no strategy is + configured, this method falls back to serial execution. calc_loss_kwargs : dict Kwargs to pass to the self.calc_loss() method @@ -1024,8 +1132,13 @@ def run_gradient_descent( grad_fn, apply_fn = self._get_train_fns( train_gen=train_gen, train_disc=train_disc ) - if not multi_gpu or len(self.gpu_list) < 2: - loss_details = self._run_serial_grad( + use_strategy = ( + multi_gpu + and self.strategy is not None + and self.strategy.num_replicas_in_sync >= 1 + ) + if use_strategy: + loss_details = self._run_mirrored_grad( low_res, hi_res_true, grad_fn=grad_fn, @@ -1033,11 +1146,12 @@ def run_gradient_descent( **calc_loss_kwargs, ) msg = ( - 'Finished single gradient descent step in ' + 'Finished mirrored gradient descent step on ' + f'{self.strategy.num_replicas_in_sync} replicas in ' f'{time.time() - start_time:.4f} seconds' ) else: - loss_details = self._run_parallel_grad( + loss_details = self._run_serial_grad( low_res, hi_res_true, grad_fn=grad_fn, @@ -1045,8 +1159,8 @@ def run_gradient_descent( **calc_loss_kwargs, ) msg = ( - f'Finished gradient descent steps on {len(self.gpu_list)} ' - f'GPUs in {time.time() - start_time:.4f} seconds' + 'Finished single gradient descent step in ' + f'{time.time() - start_time:.4f} seconds' ) logger.debug(msg) return loss_details @@ -1352,7 +1466,7 @@ def 
get_single_grad_gen( **calc_loss_kwargs, ): """Run generator-only gradient calculation for one mini-batch.""" - with tf.device(device_name), tf.GradientTape() as tape: + with self._training_scope(device_name), tf.GradientTape() as tape: hi_res_exo = self.get_hr_exo_input(hi_res_true) hi_res_gen = self._tf_generate(low_res, hi_res_exo) loss, loss_details = self.calc_loss( diff --git a/sup3r/models/base.py b/sup3r/models/base.py index d826fce15..545a9bd91 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -1,6 +1,5 @@ """Sup3r model software""" -import copy import logging import os import pprint @@ -27,6 +26,7 @@ def __init__( self, gen_layers, disc_layers, + *, loss='MeanSquaredError', optimizer=None, learning_rate=1e-4, @@ -95,8 +95,7 @@ def __init__( Option for default device placement of model weights. If None and a single GPU exists, that GPU will be the default device. If None and multiple GPUs exist, the first GPU will be the default device - (this was tested as most efficient given the custom multi-gpu - strategy developed in self.run_gradient_descent()). Examples: + for serial execution and weight initialization. Examples: "/gpu:0" or "/cpu:0" sparse_disc : bool Flag to indicate if the discriminator can handle sparse input data. @@ -106,7 +105,6 @@ def __init__( observations for training. Note that if True, the discriminator model architecture should be designed to handle sparse data (e.g. by using masking layers or other techniques). - name : str | None Optional name for the GAN. 
""" @@ -127,15 +125,23 @@ def __init__( self._init_records() - optimizer_disc = optimizer_disc or copy.deepcopy(optimizer) - learning_rate_disc = learning_rate_disc or learning_rate - self._optimizer = self.init_optimizer(optimizer, learning_rate) - self._optimizer_disc = self.init_optimizer( + if optimizer_disc is None: + optimizer_disc = optimizer + if learning_rate_disc is None: + learning_rate_disc = learning_rate + + self._optimizer = None + self._optimizer_disc = None + self._optimizer_config = self.get_optimizer_init_config( + optimizer, learning_rate + ) + self._optimizer_disc_config = self.get_optimizer_init_config( optimizer_disc, learning_rate_disc ) - self._gen = self.load_network(gen_layers, 'generator') - self._disc = self.load_network(disc_layers, 'discriminator') + with self._training_scope(): + self._gen = self.load_network(gen_layers, 'generator') + self._disc = self.load_network(disc_layers, 'discriminator') self._means = means self._stdevs = stdevs @@ -205,7 +211,11 @@ def _load(cls, model_dir, verbose=True): return fp_gen, fp_disc, params @classmethod - def load(cls, model_dir, verbose=True): + def load( + cls, + model_dir, + verbose=True, + ): """Load the GAN with its sub-networks from a previously saved-to output directory. 
@@ -337,6 +347,11 @@ def optimizer_disc(self): ------- tf.keras.optimizers.Optimizer """ + if self._optimizer_disc is None: + with self._training_scope(): + self._optimizer_disc = self.init_optimizer( + self._optimizer_disc_config, learning_rate=None + ) return self._optimizer_disc @tf.function @@ -348,7 +363,9 @@ def get_single_grad_disc( **calc_loss_kwargs, ): """Run discriminator-only gradient calculation for one mini-batch.""" - with tf.device(device_name), tf.GradientTape() as tape: + with self._training_scope( + device_name + ), tf.GradientTape() as tape: hi_res_exo = self.get_hr_exo_input(hi_res_true) hi_res_gen = self._tf_generate(low_res, hi_res_exo) loss, loss_details = self.calc_loss( @@ -378,9 +395,13 @@ def update_optimizer_disc(self, **kwargs): kwargs to use for optimizer configuration update """ - conf = self.get_optimizer_config(self.optimizer_disc) + conf = self._optimizer_disc_config.copy() conf.update(**kwargs) - self._optimizer_disc = self.optimizer_disc.__class__.from_config(conf) + self._optimizer_disc_config = conf + if self.optimizer_disc is not None: + self._optimizer_disc = self.optimizer_disc.__class__.from_config( + conf + ) def update_optimizer(self, option='generator', **kwargs): """Update optimizer by changing current configuration with kwargs and @@ -430,8 +451,8 @@ def model_params(self): 'name': self.name, 'loss': self.loss_name, 'version_record': self.version_record, - 'optimizer': self.get_optimizer_config(self.optimizer), - 'optimizer_disc': self.get_optimizer_config(self.optimizer_disc), + 'optimizer': self._optimizer_config, + 'optimizer_disc': self._optimizer_disc_config, 'means': means, 'stdevs': stdevs, 'meta': self.meta, @@ -478,7 +499,7 @@ def init_weights(self, lr_shape, hr_shape, train_disc=False, device=None): 'Initializing discriminator weights on device "%s"', device ) hi_res = tf.cast(np.ones(hr_shape), dtype=tf.float32) - with tf.device(device): + with self._training_scope(device): _ = self._tf_discriminate(hi_res) 
@staticmethod @@ -645,6 +666,29 @@ def train(self, batch_handler, config=None, **kwargs): """ config = TrainingConfig.for_gan(config=config, **kwargs) + strategy_was_unset = self.strategy is None + self.configure_multi_gpu(multi_gpu=config.multi_gpu) + + if config.multi_gpu and self.strategy is None: + logger.warning( + 'multi_gpu=True was requested but the model does not have a ' + 'configured strategy. Falling back to the existing serial ' + 'logic.' + ) + + if self.strategy is not None and strategy_was_unset: + self._optimizer = None + self._optimizer_disc = None + + if self.optimizer is None or self.optimizer_disc is None: + with self._training_scope(): + self._optimizer = self.init_optimizer( + self._optimizer_config, learning_rate=None + ) + self._optimizer_disc = self.init_optimizer( + self._optimizer_disc_config, learning_rate=None + ) + if config.log_tb: self._init_tensorboard_writer(config.out_dir) @@ -909,14 +953,10 @@ def _train_batch( Weight factor for the adversarial loss component of the generator vs. the discriminator. multi_gpu : bool - Flag to break up the batch for parallel gradient descent - calculations on multiple gpus. If True and multiple GPUs are - present, each batch from the batch_handler will be divided up - between the GPUs and resulting gradients from each GPU will be - summed and then applied once per batch at the nominal learning - rate that the model and optimizer were initialized with. - If true and multiple gpus are found, ``default_device`` device - should be set to /gpu:0 + Flag to use multi-GPU distributed training. If True and a + strategy has been configured, the batch gradient step will be run + through the configured strategy. If no strategy is configured, + this method falls back to serial execution. Returns ------- @@ -1049,14 +1089,10 @@ def _train_epoch( the discriminators will not train unless train_disc=True or and train_gen=False. 
multi_gpu : bool - Flag to break up the batch for parallel gradient descent - calculations on multiple gpus. If True and multiple GPUs are - present, each batch from the batch_handler will be divided up - between the GPUs and resulting gradients from each GPU will be - summed and then applied once per batch at the nominal learning - rate that the model and optimizer were initialized with. - If true and multiple gpus are found, ``default_device`` device - should be set to /gpu:0 + Flag to use multi-GPU distributed training. If True and a + strategy has been configured, batch updates will be distributed + across replicas. If no strategy is configured, this method falls + back to serial execution. export_tb : bool Whether to export profiling information to tensorboard. This can then be viewed in the tensorboard dashboard under the profile tab diff --git a/sup3r/models/conditional.py b/sup3r/models/conditional.py index d4854aa90..6c61ef834 100644 --- a/sup3r/models/conditional.py +++ b/sup3r/models/conditional.py @@ -62,9 +62,8 @@ def __init__( default_device : str | None Option for default device placement of model weights. If None and a single GPU exists, that GPU will be the default device. If None and - multiple GPUs exist, the CPU will be the default device (this was - tested as most efficient given the custom multi-gpu strategy - developed in self.run_gradient_descent()) + multiple GPUs exist, the CPU will be the default device for + serial execution and weight initialization. name : str | None Optional name for the model. """ @@ -264,12 +263,10 @@ def _train_epoch(self, batch_handler, multi_gpu=False): batch_handler : sup3r.preprocessing.BatchHandler BatchHandler object to iterate through multi_gpu : bool - Flag to break up the batch for parallel gradient descent - calculations on multiple gpus. 
If True and multiple GPUs are - present, each batch from the batch_handler will be divided up - between the GPUs and the resulting gradient from each GPU will - constitute a single gradient descent step with the nominal learning - rate that the model was initialized with. + Flag to use multi-GPU distributed training. If True and a + strategy has been configured, batch updates will be distributed + across replicas. If no strategy is configured, this method falls + back to serial execution. Returns ------- diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index ad4802abb..cc7dcbaf6 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -402,6 +402,70 @@ def test_optimizer_update(): assert model.optimizer_disc.learning_rate == 0.1 +def test_run_gradient_descent_multi_gpu_dispatch(monkeypatch): + """Test that multi_gpu execution uses the strategy-backed path.""" + + class FakeStrategy: + num_replicas_in_sync = 2 + + model = Sup3rGan( + pytest.S_FP_GEN, + pytest.S_FP_DISC, + learning_rate=1e-4, + ) + model._strategy = FakeStrategy() + model._multi_gpu = True + + called = {} + + def fake_get_train_fns(train_gen=True, train_disc=False): + called['train_flags'] = (train_gen, train_disc) + return 'grad_fn', 'apply_fn' + + def fake_run_mirrored_grad( + low_res, + hi_res_true, + grad_fn, + apply_fn, + **calc_loss_kwargs, + ): + called['mirrored'] = { + 'grad_fn': grad_fn, + 'apply_fn': apply_fn, + 'kwargs': calc_loss_kwargs, + 'low_res_shape': low_res.shape, + 'hi_res_shape': hi_res_true.shape, + } + return {'loss_gen': tf.constant(0.0)} + + monkeypatch.setattr(model, '_get_train_fns', fake_get_train_fns) + monkeypatch.setattr(model, '_run_mirrored_grad', fake_run_mirrored_grad) + monkeypatch.setattr( + model, + '_run_serial_grad', + lambda *args, **kwargs: pytest.fail('serial path should not run'), + ) + + low_res = np.ones((2, 4, 4, len(FEATURES)), dtype=np.float32) + hi_res = np.ones((2, 8, 8, len(FEATURES)), 
dtype=np.float32) + loss_details = model.run_gradient_descent( + low_res, + hi_res, + train_gen=True, + train_disc=False, + multi_gpu=True, + weight_gen_advers=0.0, + ) + + assert called['train_flags'] == (True, False) + assert called['mirrored']['grad_fn'] == 'grad_fn' + assert called['mirrored']['apply_fn'] == 'apply_fn' + assert called['mirrored']['kwargs']['weight_gen_advers'] == 0.0 + assert called['mirrored']['low_res_shape'] == low_res.shape + assert called['mirrored']['hi_res_shape'] == hi_res.shape + assert float(loss_details['loss_gen']) == 0.0 + + def test_input_res_check(): """Make sure error is raised for invalid input resolution""" From 10523d773e8fd631d49f7a84d438a88127a76e29 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 26 Apr 2026 14:12:58 -0600 Subject: [PATCH 02/34] refactor: streamline loss calculation and update loss details in Sup3rGan --- sup3r/models/base.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 545a9bd91..2030a842f 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -363,9 +363,7 @@ def get_single_grad_disc( **calc_loss_kwargs, ): """Run discriminator-only gradient calculation for one mini-batch.""" - with self._training_scope( - device_name - ), tf.GradientTape() as tape: + with self._training_scope(device_name), tf.GradientTape() as tape: hi_res_exo = self.get_hr_exo_input(hi_res_true) hi_res_gen = self._tf_generate(low_res, hi_res_exo) loss, loss_details = self.calc_loss( @@ -848,14 +846,17 @@ def calc_loss( loss_details = {} loss = None + loss_gen_advers = None + loss_disc = None + loss_gen = None + loss_gen_content = None disc_out_true = None disc_out_gen = None - loss_gen_advers = None if train_disc or compute_disc: disc_out_true = self._tf_discriminate(hi_res_true) disc_out_gen = self._tf_discriminate(hi_res_gen) - loss_details['loss_disc'] = self.calc_loss_disc( + loss_disc = self.calc_loss_disc( 
disc_out_true=disc_out_true, disc_out_gen=disc_out_gen ) @@ -863,7 +864,6 @@ def calc_loss( loss_gen_advers = self.calc_loss_disc( disc_out_true=disc_out_gen, disc_out_gen=disc_out_true ) - loss_details['loss_gen_advers'] = loss_gen_advers if train_gen: loss_gen_content, loss_gen_content_details = ( @@ -874,13 +874,20 @@ def calc_loss( if loss_gen_advers is None else loss_gen_content + weight_gen_advers * loss_gen_advers ) - loss_details['loss_gen'] = loss - loss_details['loss_gen_content'] = loss_gen_content + loss_gen = loss loss_details.update(loss_gen_content_details) elif train_disc: - loss = loss_details['loss_disc'] - + loss = loss_disc + + loss_details['loss_gen_advers'] = loss_gen_advers + loss_details['loss'] = loss + loss_details['loss_disc'] = loss_disc + loss_details['loss_gen'] = loss_gen + loss_details['loss_gen_content'] = loss_gen_content + loss_details = { + k: float(v) for k, v in loss_details.items() if v is not None + } return loss, loss_details def calc_val_loss(self, batch_handler, weight_gen_advers): From 8214e7cb1d29e22353d60ec256a806ca3fad4067 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 27 Apr 2026 09:35:56 -0600 Subject: [PATCH 03/34] Refactor Sup3rX class for improved caching and feature handling - Introduced caching mechanisms for eager array views in Sup3rX, including as_array and feature selection. - Added methods to clear cache and retrieve feature indices and selectors. - Enhanced the loaded property to utilize caching for performance. - Updated the shape property to cache results for efficiency. - Refined data handling in Sup3rDataset for better member access. - Simplified batch queue and sampler classes by removing unnecessary conversions. - Added tests for rasterizer caching and dual sampler functionality to ensure correctness. 
--- .gitignore | 1 + pixi.lock | 45 ++++++++++- sup3r/preprocessing/accessor.py | 81 ++++++++++++++++--- sup3r/preprocessing/base.py | 17 ++-- sup3r/preprocessing/batch_queues/abstract.py | 16 ++-- sup3r/preprocessing/batch_queues/base.py | 5 +- sup3r/preprocessing/samplers/base.py | 45 ++++++----- sup3r/preprocessing/samplers/dual.py | 7 +- sup3r/preprocessing/samplers/utilities.py | 4 +- tests/batch_handlers/test_bh_h5_cc.py | 2 +- ...caching.py => test_rasterizers_caching.py} | 0 ...{test_dual.py => test_rasterizers_dual.py} | 0 ...general.py => test_rasterizers_general.py} | 0 tests/samplers/test_samplers_dual.py | 45 +++++++++++ 14 files changed, 213 insertions(+), 55 deletions(-) rename tests/rasterizers/{test_rasterizer_caching.py => test_rasterizers_caching.py} (100%) rename tests/rasterizers/{test_dual.py => test_rasterizers_dual.py} (100%) rename tests/rasterizers/{test_rasterizer_general.py => test_rasterizers_general.py} (100%) create mode 100644 tests/samplers/test_samplers_dual.py diff --git a/.gitignore b/.gitignore index 6bb9a376a..10bdc732b 100644 --- a/.gitignore +++ b/.gitignore @@ -124,3 +124,4 @@ tags # test dirs exo_cache +.timings diff --git a/pixi.lock b/pixi.lock index 94dd0dcb8..9d069fd65 100644 --- a/pixi.lock +++ b/pixi.lock @@ -296,8 +296,10 @@ environments: - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2b/03/13dde6512ad7b4557eb792fbcf0c653af6076b81e5941d36ec61f7ce6028/astunparse-1.6.3-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/cb/8c/2b30c12155ad8de0cf641d76a8b396a16d2c36bc6d50b621a62b7c4567c1/build-1.3.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/e8/2d/d2a548598be01649e2d46231d151a6c56d10b964d94043a335ae56ea2d92/flatbuffers-25.12.19-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/1d/33/f1c6a276de27b7d7339a34749cc33fa87f077f921969c47185d34a887ae2/gast-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl - pypi: https://files.pythonhosted.org/packages/a3/de/c648ef6835192e6e2cc03f40b19eeda4382c49b5bafb43d88b931c4c74ac/google_pasta-0.2.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/54/bf/f4a3b9693e35d25b24b0b39fa46d7d8a3c439e0a3036c3451764678fec20/grpcio-1.78.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz @@ -310,6 +312,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/fb/53/78c98ef5c8b2b784453487f3e4d6c017b20747c58b470393e230c78d18e8/numcodecs-0.16.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/74/aa/f0e402ab0c1aa371e3143ed0b79744ebd6091307087cda59e3249f45cac8/nvidia_cublas_cu12-12.3.4.1-py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/6a/43/b30e742c204c5c81a6954c5f20ce82b098f9c4ca3d3cb76eea5476b83f5d/nvidia_cuda_cupti_cu12-12.3.101-py3-none-manylinux1_x86_64.whl @@ -347,6 +350,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/7f/b2/0bba9bbb4596d2d2f285a16c2ab04118f6b957d8441566e1abb892e6a6b2/werkzeug-3.1.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/87/22/b76d483683216dde3d67cba61fb2444be8d5be289bf628c13fc0fd90e5f9/wheel-0.46.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/c7/7376998449689cf2adbdbeacad47084410d00f3ae04cf73e6127cf52b950/wrapt-1.14.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/44/15/bb13b4913ef95ad5448490821eee4671d0e67673342e4d4070854e5fe081/zarr-3.1.5-py3-none-any.whl - pypi: ./ osx-arm64: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/_openmp_mutex-4.5-7_kmp_llvm.conda @@ -584,12 +588,15 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstandard-0.25.0-py312h37e1c23_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: https://files.pythonhosted.org/packages/cb/8c/2b30c12155ad8de0cf641d76a8b396a16d2c36bc6d50b621a62b7c4567c1/build-1.3.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz - pypi: 
https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/aa/1b/1c00fac58293e95eaf2c6f560b6710af615c666bcc6babcb31a5de2e642e/netCDF4-1.6.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/fa/3d/f4f2ba829efb54b6cd2d91349c7463316a9cc55a43fc980447416c88540f/pkginfo-1.12.1.2-py3-none-any.whl @@ -605,6 +612,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/99/55/db07de81b5c630da5cbf5c7df646580ca26dfaefa593667fc6f2fe016d2e/tabulate-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/44/15/bb13b4913ef95ad5448490821eee4671d0e67673342e4d4070854e5fe081/zarr-3.1.5-py3-none-any.whl - pypi: ./ default: channels: @@ -877,8 +885,10 @@ 
environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2b/03/13dde6512ad7b4557eb792fbcf0c653af6076b81e5941d36ec61f7ce6028/astunparse-1.6.3-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e8/2d/d2a548598be01649e2d46231d151a6c56d10b964d94043a335ae56ea2d92/flatbuffers-25.12.19-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/1d/33/f1c6a276de27b7d7339a34749cc33fa87f077f921969c47185d34a887ae2/gast-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl - pypi: https://files.pythonhosted.org/packages/a3/de/c648ef6835192e6e2cc03f40b19eeda4382c49b5bafb43d88b931c4c74ac/google_pasta-0.2.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/54/bf/f4a3b9693e35d25b24b0b39fa46d7d8a3c439e0a3036c3451764678fec20/grpcio-1.78.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz @@ -891,6 +901,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/fb/53/78c98ef5c8b2b784453487f3e4d6c017b20747c58b470393e230c78d18e8/numcodecs-0.16.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/74/aa/f0e402ab0c1aa371e3143ed0b79744ebd6091307087cda59e3249f45cac8/nvidia_cublas_cu12-12.3.4.1-py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/6a/43/b30e742c204c5c81a6954c5f20ce82b098f9c4ca3d3cb76eea5476b83f5d/nvidia_cuda_cupti_cu12-12.3.101-py3-none-manylinux1_x86_64.whl @@ -926,6 +937,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/7f/b2/0bba9bbb4596d2d2f285a16c2ab04118f6b957d8441566e1abb892e6a6b2/werkzeug-3.1.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/87/22/b76d483683216dde3d67cba61fb2444be8d5be289bf628c13fc0fd90e5f9/wheel-0.46.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/c7/7376998449689cf2adbdbeacad47084410d00f3ae04cf73e6127cf52b950/wrapt-1.14.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/44/15/bb13b4913ef95ad5448490821eee4671d0e67673342e4d4070854e5fe081/zarr-3.1.5-py3-none-any.whl - pypi: ./ osx-arm64: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/_openmp_mutex-4.5-7_kmp_llvm.conda @@ -1143,12 +1155,15 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstandard-0.25.0-py312h37e1c23_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda + - pypi: 
https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz - pypi: https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/aa/1b/1c00fac58293e95eaf2c6f560b6710af615c666bcc6babcb31a5de2e642e/netCDF4-1.6.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/aa/4b/4e69ccbf34f2f303e32dc0dc8853d82282f109ba41b7a9366d518751e500/pyjson5-2.0.0-cp312-cp312-macosx_11_0_arm64.whl @@ -1162,6 +1177,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/99/55/db07de81b5c630da5cbf5c7df646580ca26dfaefa593667fc6f2fe016d2e/tabulate-0.10.0-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/44/15/bb13b4913ef95ad5448490821eee4671d0e67673342e4d4070854e5fe081/zarr-3.1.5-py3-none-any.whl - pypi: ./ dev: channels: @@ -2409,8 +2425,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2b/03/13dde6512ad7b4557eb792fbcf0c653af6076b81e5941d36ec61f7ce6028/astunparse-1.6.3-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e8/2d/d2a548598be01649e2d46231d151a6c56d10b964d94043a335ae56ea2d92/flatbuffers-25.12.19-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/1d/33/f1c6a276de27b7d7339a34749cc33fa87f077f921969c47185d34a887ae2/gast-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl - pypi: https://files.pythonhosted.org/packages/a3/de/c648ef6835192e6e2cc03f40b19eeda4382c49b5bafb43d88b931c4c74ac/google_pasta-0.2.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/54/bf/f4a3b9693e35d25b24b0b39fa46d7d8a3c439e0a3036c3451764678fec20/grpcio-1.78.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz @@ -2423,6 +2441,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/fb/53/78c98ef5c8b2b784453487f3e4d6c017b20747c58b470393e230c78d18e8/numcodecs-0.16.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/74/aa/f0e402ab0c1aa371e3143ed0b79744ebd6091307087cda59e3249f45cac8/nvidia_cublas_cu12-12.3.4.1-py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/6a/43/b30e742c204c5c81a6954c5f20ce82b098f9c4ca3d3cb76eea5476b83f5d/nvidia_cuda_cupti_cu12-12.3.101-py3-none-manylinux1_x86_64.whl @@ -2459,6 +2478,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/7f/b2/0bba9bbb4596d2d2f285a16c2ab04118f6b957d8441566e1abb892e6a6b2/werkzeug-3.1.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/87/22/b76d483683216dde3d67cba61fb2444be8d5be289bf628c13fc0fd90e5f9/wheel-0.46.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/c7/7376998449689cf2adbdbeacad47084410d00f3ae04cf73e6127cf52b950/wrapt-1.14.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/44/15/bb13b4913ef95ad5448490821eee4671d0e67673342e4d4070854e5fe081/zarr-3.1.5-py3-none-any.whl - pypi: ./ osx-arm64: - conda: 
https://conda.anaconda.org/conda-forge/osx-arm64/_openmp_mutex-4.5-7_kmp_llvm.conda @@ -2681,12 +2701,15 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstandard-0.25.0-py312h37e1c23_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda + - pypi: https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz - pypi: https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/aa/1b/1c00fac58293e95eaf2c6f560b6710af615c666bcc6babcb31a5de2e642e/netCDF4-1.6.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - pypi: 
git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/aa/4b/4e69ccbf34f2f303e32dc0dc8853d82282f109ba41b7a9366d518751e500/pyjson5-2.0.0-cp312-cp312-macosx_11_0_arm64.whl @@ -2701,6 +2724,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/99/55/db07de81b5c630da5cbf5c7df646580ca26dfaefa593667fc6f2fe016d2e/tabulate-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/44/15/bb13b4913ef95ad5448490821eee4671d0e67673342e4d4070854e5fe081/zarr-3.1.5-py3-none-any.whl - pypi: ./ test: channels: @@ -2975,8 +2999,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2b/03/13dde6512ad7b4557eb792fbcf0c653af6076b81e5941d36ec61f7ce6028/astunparse-1.6.3-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e8/2d/d2a548598be01649e2d46231d151a6c56d10b964d94043a335ae56ea2d92/flatbuffers-25.12.19-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/1d/33/f1c6a276de27b7d7339a34749cc33fa87f077f921969c47185d34a887ae2/gast-0.7.0-py3-none-any.whl + - pypi: 
https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl - pypi: https://files.pythonhosted.org/packages/a3/de/c648ef6835192e6e2cc03f40b19eeda4382c49b5bafb43d88b931c4c74ac/google_pasta-0.2.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/54/bf/f4a3b9693e35d25b24b0b39fa46d7d8a3c439e0a3036c3451764678fec20/grpcio-1.78.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz @@ -2989,6 +3015,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/fb/53/78c98ef5c8b2b784453487f3e4d6c017b20747c58b470393e230c78d18e8/numcodecs-0.16.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/74/aa/f0e402ab0c1aa371e3143ed0b79744ebd6091307087cda59e3249f45cac8/nvidia_cublas_cu12-12.3.4.1-py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/6a/43/b30e742c204c5c81a6954c5f20ce82b098f9c4ca3d3cb76eea5476b83f5d/nvidia_cuda_cupti_cu12-12.3.101-py3-none-manylinux1_x86_64.whl @@ -3026,6 +3053,7 @@ environments: - pypi: 
https://files.pythonhosted.org/packages/7f/b2/0bba9bbb4596d2d2f285a16c2ab04118f6b957d8441566e1abb892e6a6b2/werkzeug-3.1.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/87/22/b76d483683216dde3d67cba61fb2444be8d5be289bf628c13fc0fd90e5f9/wheel-0.46.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/c7/7376998449689cf2adbdbeacad47084410d00f3ae04cf73e6127cf52b950/wrapt-1.14.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/44/15/bb13b4913ef95ad5448490821eee4671d0e67673342e4d4070854e5fe081/zarr-3.1.5-py3-none-any.whl - pypi: ./ osx-arm64: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/_openmp_mutex-4.5-7_kmp_llvm.conda @@ -3245,12 +3273,15 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstandard-0.25.0-py312h37e1c23_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda + - pypi: https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz - pypi: https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/aa/1b/1c00fac58293e95eaf2c6f560b6710af615c666bcc6babcb31a5de2e642e/netCDF4-1.6.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/aa/4b/4e69ccbf34f2f303e32dc0dc8853d82282f109ba41b7a9366d518751e500/pyjson5-2.0.0-cp312-cp312-macosx_11_0_arm64.whl @@ -3266,6 +3297,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/99/55/db07de81b5c630da5cbf5c7df646580ca26dfaefa593667fc6f2fe016d2e/tabulate-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/44/15/bb13b4913ef95ad5448490821eee4671d0e67673342e4d4070854e5fe081/zarr-3.1.5-py3-none-any.whl - pypi: ./ viz: channels: @@ -3705,8 +3737,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2b/03/13dde6512ad7b4557eb792fbcf0c653af6076b81e5941d36ec61f7ce6028/astunparse-1.6.3-py2.py3-none-any.whl + - pypi: 
https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e8/2d/d2a548598be01649e2d46231d151a6c56d10b964d94043a335ae56ea2d92/flatbuffers-25.12.19-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/1d/33/f1c6a276de27b7d7339a34749cc33fa87f077f921969c47185d34a887ae2/gast-0.7.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl - pypi: https://files.pythonhosted.org/packages/a3/de/c648ef6835192e6e2cc03f40b19eeda4382c49b5bafb43d88b931c4c74ac/google_pasta-0.2.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/54/bf/f4a3b9693e35d25b24b0b39fa46d7d8a3c439e0a3036c3451764678fec20/grpcio-1.78.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz @@ -3719,6 +3753,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/fb/53/78c98ef5c8b2b784453487f3e4d6c017b20747c58b470393e230c78d18e8/numcodecs-0.16.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - 
pypi: https://files.pythonhosted.org/packages/74/aa/f0e402ab0c1aa371e3143ed0b79744ebd6091307087cda59e3249f45cac8/nvidia_cublas_cu12-12.3.4.1-py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/6a/43/b30e742c204c5c81a6954c5f20ce82b098f9c4ca3d3cb76eea5476b83f5d/nvidia_cuda_cupti_cu12-12.3.101-py3-none-manylinux1_x86_64.whl @@ -3754,6 +3789,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/7f/b2/0bba9bbb4596d2d2f285a16c2ab04118f6b957d8441566e1abb892e6a6b2/werkzeug-3.1.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/87/22/b76d483683216dde3d67cba61fb2444be8d5be289bf628c13fc0fd90e5f9/wheel-0.46.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b7/c7/7376998449689cf2adbdbeacad47084410d00f3ae04cf73e6127cf52b950/wrapt-1.14.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/44/15/bb13b4913ef95ad5448490821eee4671d0e67673342e4d4070854e5fe081/zarr-3.1.5-py3-none-any.whl - pypi: ./ osx-arm64: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/_openmp_mutex-4.5-7_kmp_llvm.conda @@ -4123,12 +4159,15 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstandard-0.25.0-py312h37e1c23_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda + - pypi: https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz - pypi: https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl - 
pypi: https://files.pythonhosted.org/packages/aa/1b/1c00fac58293e95eaf2c6f560b6710af615c666bcc6babcb31a5de2e642e/netCDF4-1.6.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/aa/4b/4e69ccbf34f2f303e32dc0dc8853d82282f109ba41b7a9366d518751e500/pyjson5-2.0.0-cp312-cp312-macosx_11_0_arm64.whl @@ -4142,6 +4181,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/99/55/db07de81b5c630da5cbf5c7df646580ca26dfaefa593667fc6f2fe016d2e/tabulate-0.10.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/44/15/bb13b4913ef95ad5448490821eee4671d0e67673342e4d4070854e5fe081/zarr-3.1.5-py3-none-any.whl - pypi: ./ packages: - conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-20_gnu.conda @@ -10578,8 +10618,8 @@ packages: requires_python: '>=3.9' - 
pypi: ./ name: nrel-sup3r - version: 0.2.7.dev15+gfdf85195e.d20260423 - sha256: 958a347dca72e7615053c48d0e1fa737a832ad79d7816dc56f6d153123bb81ca + version: 0.2.7.dev29+g2a8491d0d + sha256: 86efec4f2e4479b26d09b6f7c65223a94c21e325f2b3815d0ee41672780dadb0 requires_dist: - nlr-phygnn @ git+https://github.com/NatLabRockies/phygnn.git@bnb/tf - nrel-rex>=0.2.91 @@ -10595,6 +10635,7 @@ packages: - pillow>=10.0 - scipy>=1.0.0 - xarray>=2023.0 + - zarr>=2.18.0,<4 - pre-commit ; extra == 'dev' - pylint ; extra == 'dev' - ruff>=0.4 ; extra == 'dev' diff --git a/sup3r/preprocessing/accessor.py b/sup3r/preprocessing/accessor.py index 70f5c98be..3a70abe0f 100644 --- a/sup3r/preprocessing/accessor.py +++ b/sup3r/preprocessing/accessor.py @@ -93,8 +93,46 @@ def __init__(self, ds: Union[xr.Dataset, Self]): self._ds = ds self._features = None self._meta = None + self._shape = None + self._loaded = None + self._as_array_cache = None + self._feature_inds = None + self._feature_sel_cache = None self.time_slice = None + def _clear_array_cache(self): + """Clear cached eager array views after mutating the dataset.""" + self._features = None + self._shape = None + self._loaded = None + self._as_array_cache = None + self._feature_inds = None + self._feature_sel_cache = None + + def _get_feature_inds(self, features): + """Get cached feature indices for the ordered eager array.""" + if self._feature_inds is None: + self._feature_inds = { + feature: i for i, feature in enumerate(self.features) + } + return [self._feature_inds[feature] for feature in features] + + def _get_feature_sel(self, features): + """Get cached feature selector for eager ordered array slicing.""" + if self._feature_sel_cache is None: + self._feature_sel_cache = {} + + cache_key = tuple(features) + if cache_key not in self._feature_sel_cache: + feature_inds = self._get_feature_inds(features) + self._feature_sel_cache[cache_key] = ( + slice(None) + if feature_inds == list(range(len(self.features))) + else feature_inds + ) + 
+ return self._feature_sel_cache[cache_key] + def __getitem__( self, keys ) -> Union[Union[np.ndarray, da.core.Array], Self]: @@ -216,8 +254,16 @@ def to_dataarray(self) -> Union[np.ndarray, da.core.Array]: def as_array(self): """Return ``.data`` attribute of an xarray.DataArray with our standard dimension order ``(lats, lons, time, ..., features)``""" + if self.loaded and self._as_array_cache is not None: + return self._as_array_cache + out = self.to_dataarray() - return getattr(out, 'data', out) + out = getattr(out, 'data', out) + + if self.loaded: + self._as_array_cache = out + + return out def _stack_features(self, arrs): if self.loaded: @@ -227,6 +273,7 @@ def _stack_features(self, arrs): def compute(self, **kwargs): """Load `._ds` into memory. This updates the internal `xr.Dataset` if it has not been loaded already.""" + self._clear_array_cache() if not self.loaded: logger.debug(f'Loading dataset into memory: {self._ds}') logger.debug(f'Pre-loading: {_mem_check()}') @@ -240,15 +287,18 @@ def compute(self, **kwargs): ) logger.debug(f'Loaded dataset into memory: {self._ds}') logger.debug(f'Post-loading: {_mem_check()}') + self._loaded = True return self @property def loaded(self): """Check if data has been loaded as numpy arrays.""" - return all( - isinstance(self._ds[f].data, np.ndarray) - for f in list(self._ds.data_vars) - ) + if self._loaded is None: + self._loaded = all( + isinstance(self._ds[f].data, np.ndarray) + for f in self._ds.data_vars + ) + return self._loaded @property def flattened(self): @@ -291,6 +341,7 @@ def update_ds(self, new_dset, attrs=None): data_vars.update(new_data) self._ds = xr.Dataset(coords=coords, data_vars=data_vars, attrs=attrs) + self._clear_array_cache() return self @property @@ -317,6 +368,10 @@ def sample(self, idx): _lowered(idx[-1]) if is_type_of(idx[-1], str) else self.features ) + if self.loaded: + feature_sel = self._get_feature_sel(features) + return self.as_array()[(*idx[:-1], feature_sel)] + out = 
self._ds[features].isel(**isel_kwargs) return self.ordered(out.to_array()).data @@ -356,6 +411,7 @@ def std(self, **kwargs): def normalize(self, means, stds): """Normalize dataset using given means and stds. These are provided as dictionaries.""" + self._clear_array_cache() feats = set(self._ds.data_vars).intersection(means).intersection(stds) for f in feats: self._ds[f] = (self._ds[f] - means[f]) / stds[f] @@ -454,6 +510,7 @@ def assign( array). If dims are not provided this will try to use stored dims of the variable, if it exists already. """ + self._clear_array_cache() data_vars = self.add_dims_to_data_vars(vals) if all(f in self.coords for f in vals): self._ds = self._ds.assign_coords(data_vars) @@ -464,7 +521,9 @@ def assign( @property def features(self): """Features in this container.""" - return list(self._ds.data_vars) + if self._features is None: + self._features = list(self._ds.data_vars) + return self._features @property def dtype(self): @@ -475,9 +534,13 @@ def dtype(self): def shape(self): """Get shape of underlying xr.DataArray, using our standard dimension order.""" - dim_dict = dict(self._ds.sizes) - dim_vals = [dim_dict[k] for k in Dimension.order() if k in dim_dict] - return (*dim_vals, len(self._ds.data_vars)) + if self._shape is None: + dim_dict = dict(self._ds.sizes) + dim_vals = [ + dim_dict[k] for k in Dimension.order() if k in dim_dict + ] + self._shape = (*dim_vals, len(self.features)) + return self._shape @property def size(self): diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index 803e83ffb..d751b0bea 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -20,10 +20,7 @@ import sup3r.preprocessing.accessor # noqa: F401 # pylint: disable=W0611 from sup3r.preprocessing.accessor import Sup3rX -from sup3r.preprocessing.utilities import ( - composite_info, - is_type_of, -) +from sup3r.preprocessing.utilities import composite_info, is_type_of from sup3r.utilities.utilities import Timer logger = 
logging.getLogger(__name__) @@ -225,9 +222,13 @@ def sample(self, idx): of slices for the dimensions (south_north, west_east, time) and a list of feature names or a tuple of the same, for multi-member datasets (dual datasets and dual with observations datasets).""" + if len(self._ds) == 2: + return (self._ds[0].sample(idx[0]), self._ds[1].sample(idx[1])) + if len(self._ds) > 1: return tuple(d.sample(idx[i]) for i, d in enumerate(self)) - return self._ds[-1].sample(idx) + + return self._ds[0].sample(idx) def isel(self, *args, **kwargs): """Return new Sup3rDataset with isel applied to each member.""" @@ -236,7 +237,7 @@ def isel(self, *args, **kwargs): def __getitem__(self, keys): """If keys is an int this is interpreted as a request for that member of ``self._ds``. Otherwise, if there's only a single member of - ``self._ds`` we get self._ds[-1][keys]. If there's two members we get + ``self._ds`` we get self._ds[0][keys]. If there's two members we get ``(self._ds[0][keys], self._ds[1][keys])`` and cast this back to a ``Sup3rDataset`` if each of ``self._ds[i][keys]`` is a ``Sup3rX`` object""" @@ -245,7 +246,7 @@ def __getitem__(self, keys): out = tuple(self._getitem(d, keys) for d in self._ds) if len(self._ds) == 1: - return out[-1] + return out[0] if all(isinstance(o, Sup3rX) for o in out): return type(self)(**dict(zip(self._ds.dset_names, out))) return out @@ -282,7 +283,7 @@ def __setitem__(self, keys, data): so interpret this as sending a tuple / list element to each dset member. e.g. 
``vals[0] -> dsets[0]``, ``vals[1] -> dsets[1]``, etc""" if len(self._ds) == 1: - self._ds[-1].__setitem__(keys, data) + self._ds[0].__setitem__(keys, data) else: for i, self_i in enumerate(self): dat = data[i] if isinstance(data, (tuple, list)) else data diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index 4c5bef05e..f848d64b3 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -124,6 +124,8 @@ def queue_shape(self): @property def queue_len(self): """Get number of batches in the queue.""" + if self.queue is None: + return self.queue_futures return self.queue.size().numpy() + self.queue_futures @property @@ -134,6 +136,8 @@ def queue_futures(self): def get_queue(self): """Return FIFO queue for storing batches.""" + if self.mode == 'eager' or self.queue_cap == 0: + return None return tf.queue.FIFOQueue( self.queue_cap, dtypes=[tf.float32] * len(self.queue_shape), @@ -235,7 +239,7 @@ def __iter__(self): def get_batch(self) -> DsetTuple: """Get batch from queue or directly from a ``Sampler`` through ``sample_batch``.""" - if self.mode == 'eager' or self.queue_cap == 0 or self.queue_len == 0: + if self.queue is None or self.queue_len == 0: samples = self.sample_batch() else: samples = self.queue.dequeue() @@ -252,6 +256,7 @@ def running(self): return ( self._training_flag.is_set() and self.queue_thread.is_alive() + and self.queue is not None and not self.queue.is_closed() ) @@ -327,13 +332,14 @@ def sample_batch(self): These samples are wrapped in an ``np.asarray`` call, so they have been loaded into memory. 
""" - out = next(self.get_random_container()) - if not isinstance(out, tuple): - return tf.convert_to_tensor(out, dtype=tf.float32) - return tuple(tf.convert_to_tensor(o, dtype=tf.float32) for o in out) + return next(self.get_random_container()) def log_queue_info(self): """Log info about queue size.""" + if self.queue is None: + return '{} queue disabled (mode={}, queue_cap={})'.format( + self._thread_name.title(), self.mode, self.queue_cap + ) return '{} queue length: {} / {}'.format( self._thread_name.title(), self.queue_len, self.queue_cap ) diff --git a/sup3r/preprocessing/batch_queues/base.py b/sup3r/preprocessing/batch_queues/base.py index c45f8c81c..0d3e36762 100644 --- a/sup3r/preprocessing/batch_queues/base.py +++ b/sup3r/preprocessing/batch_queues/base.py @@ -66,7 +66,8 @@ def transform( (batch_size, spatial_1, spatial_2, features) (batch_size, spatial_1, spatial_2, temporal, features) """ - lr_samples = numpy_if_tensor(samples)[..., self.lr_features_ind] + samples = numpy_if_tensor(samples) + lr_samples = samples[..., self.lr_features_ind] low_res = spatial_coarsening(lr_samples, self.s_enhance) low_res = ( low_res @@ -81,5 +82,5 @@ def transform( low_res = smooth_data( low_res, self.lr_features, smoothing_ignore, smoothing ) - high_res = numpy_if_tensor(samples)[..., self.hr_features_ind] + high_res = samples[..., self.hr_features_ind] return low_res, high_res diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index 836e9bab2..b99c76a2a 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -4,6 +4,7 @@ import logging from fnmatch import fnmatch +from functools import cached_property from typing import Optional from warnings import warn @@ -186,12 +187,7 @@ def get_sample_index(self, n_obs=None): time_slice = uniform_time_sampler( self.shape, self.sample_shape[2] * n_obs ) - feats = ( - [f for f in self.hr_source_features if f not in self.obs_features] - if 
self.use_proxy_obs - else self.hr_source_features - ) - return (*spatial_slice, time_slice, feats) + return (*spatial_slice, time_slice, self.hr_sample_features) def preflight(self): """Perform shape and feature checks.""" @@ -225,6 +221,7 @@ def preflight(self): if self.data.shape[2] < self.sample_shape[2] * self.batch_size: logger.warning(msg) warn(msg) + if self.mode == 'eager': logger.info('Received mode = "eager".') _ = self.compute() @@ -410,7 +407,8 @@ def _compute_samples(self, samples): def _fast_batch(self): """Get batch of samples with adjacent time slices.""" - out = self.data.sample(self.get_sample_index(n_obs=self.batch_size)) + idx = self.get_sample_index(n_obs=self.batch_size) + out = self.data.sample(idx) out = self._compute_samples(out) if isinstance(out, tuple): out = tuple(self._reshape_samples(o) for o in out) @@ -537,14 +535,14 @@ def _parse_features(self, unparsed_feats): parsed_feats = out return lowered(parsed_feats) - @property + @cached_property def lr_features(self): """List of feature names or patt*erns to use as low-resolution model inputs. If no entry is provided then all available features from the data will be used.""" return self._parse_features(self._lr_features) - @property + @cached_property def hr_source_features(self): """List of feature names or patt*erns that should be available natively as high-resolution. For a non-dual sampler this is all features, since @@ -562,19 +560,29 @@ def hr_source_features(self): feats += [f for f in self.hr_exo_features if f not in feats] return feats - @property + @cached_property def hr_features(self): """List of feature names or patt*erns that the model is shown at high-resolution. This does not include features that are only shown to - the model after coarsening. Thus, this includes hr_out_features and - and hr_exo_features.""" + the model after coarsening. 
Thus, this includes hr_out_features and + and hr_exo_features but not lr_features.""" out = [ f for f in self.hr_out_features if f not in self.hr_exo_features ] out += self.hr_exo_features return out - @property + @cached_property + def hr_sample_features(self): + """List of feature names used in the sample index for the + high-resolution training data.""" + return ( + [f for f in self.hr_source_features if f not in self.obs_features] + if self.use_proxy_obs + else self.hr_source_features + ) + + @cached_property def hr_out_features(self): """List of feature names or patt*erns that should be output by the generative model. If no entry is provided then all features in @@ -582,7 +590,7 @@ def hr_out_features(self): hr_out = self._parse_features(self._hr_out_features) return self.lr_features if len(hr_out) == 0 else hr_out - @property + @cached_property def hr_exo_features(self): """Get a list of exogenous high-resolution features that are only used for training e.g., mid-network high-res topo injection. These must come @@ -590,7 +598,7 @@ def hr_exo_features(self): model as low-res features.""" return self._parse_features(self._hr_exo_features) - @property + @cached_property def obs_features(self): """List of feature names or patt*erns that should be treated as observations. These features will be included in the high-res data but @@ -600,7 +608,7 @@ def obs_features(self): values where observations are not available.""" return [f for f in self.hr_source_features if '_obs' in f] - @property + @cached_property def hr_features_ind(self): """Get the high-resolution feature channel indices that should be included for loss calculations. This includes hr_out_features and @@ -609,14 +617,14 @@ def hr_features_ind(self): """ return [self.hr_source_features.index(f) for f in self.hr_features] - @property + @cached_property def lr_features_ind(self): """Get the low-resolution feature channel indices that should be included for training. This includes lr_features. 
""" return [self.hr_source_features.index(f) for f in self.lr_features] - @property + @cached_property def obs_features_ind(self): """Get the source feature indices in ``features`` for each obs feature. Each obs feature named ``_obs`` maps to the @@ -632,7 +640,6 @@ def obs_features_ind(self): if self.use_proxy_obs else self.obs_features ) - return [self.hr_source_features.index(f) for f in check_feats] def _get_obs_mask(self, hi_res, spatial_frac, time_frac=1.0): diff --git a/sup3r/preprocessing/samplers/dual.py b/sup3r/preprocessing/samplers/dual.py index cfe05285a..3142cdbc7 100644 --- a/sup3r/preprocessing/samplers/dual.py +++ b/sup3r/preprocessing/samplers/dual.py @@ -213,11 +213,6 @@ def get_sample_index(self, n_obs=None): slice(s.start * self.t_enhance, s.stop * self.t_enhance) for s in lr_index[2:-1] ] - hr_feats = ( - [f for f in self.hr_source_features if f not in self.obs_features] - if self.use_proxy_obs - else self.hr_source_features - ) - hr_index = (*hr_index, hr_feats) + hr_index = (*hr_index, self.hr_sample_features) return (lr_index, hr_index) diff --git a/sup3r/preprocessing/samplers/utilities.py b/sup3r/preprocessing/samplers/utilities.py index 2b9e69cd5..ebc026a0a 100644 --- a/sup3r/preprocessing/samplers/utilities.py +++ b/sup3r/preprocessing/samplers/utilities.py @@ -6,9 +6,7 @@ import dask.array as da import numpy as np -from sup3r.preprocessing.utilities import ( - _compute_chunks_if_dask, -) +from sup3r.preprocessing.utilities import _compute_chunks_if_dask from sup3r.utilities.utilities import RANDOM_GENERATOR logger = logging.getLogger(__name__) diff --git a/tests/batch_handlers/test_bh_h5_cc.py b/tests/batch_handlers/test_bh_h5_cc.py index f349d6ece..1ec9c1084 100644 --- a/tests/batch_handlers/test_bh_h5_cc.py +++ b/tests/batch_handlers/test_bh_h5_cc.py @@ -381,7 +381,7 @@ def test_surf_min_max_vars(): assert batch.low_res.shape[-1] == len(surf_features) # compare daily avg temp vs min and max - blr = batch.low_res.numpy() + blr = 
batch.low_res assert (blr[..., 0] > blr[..., 2]).all() assert (blr[..., 0] < blr[..., 3]).all() diff --git a/tests/rasterizers/test_rasterizer_caching.py b/tests/rasterizers/test_rasterizers_caching.py similarity index 100% rename from tests/rasterizers/test_rasterizer_caching.py rename to tests/rasterizers/test_rasterizers_caching.py diff --git a/tests/rasterizers/test_dual.py b/tests/rasterizers/test_rasterizers_dual.py similarity index 100% rename from tests/rasterizers/test_dual.py rename to tests/rasterizers/test_rasterizers_dual.py diff --git a/tests/rasterizers/test_rasterizer_general.py b/tests/rasterizers/test_rasterizers_general.py similarity index 100% rename from tests/rasterizers/test_rasterizer_general.py rename to tests/rasterizers/test_rasterizers_general.py diff --git a/tests/samplers/test_samplers_dual.py b/tests/samplers/test_samplers_dual.py new file mode 100644 index 000000000..40bed01f3 --- /dev/null +++ b/tests/samplers/test_samplers_dual.py @@ -0,0 +1,45 @@ +"""Dual sampler regression tests.""" + +import numpy as np + +from sup3r.preprocessing import DualSampler +from sup3r.preprocessing.base import Sup3rDataset +from sup3r.utilities.pytest.helpers import DummyData +from sup3r.utilities.utilities import RANDOM_GENERATOR + +LR_FEATURES = ['u_100m', 'v_100m', 'temperature_2m'] + + +def test_dual_sampler_eager_vs_lazy(): + """Eager dual sampling should match lazy sampling for the same indices.""" + lr = DummyData( + data_shape=(20, 20, 100), features=LR_FEATURES + ).data.high_res + hr = DummyData( + data_shape=(40, 40, 100), features=[*LR_FEATURES, 'topography'] + ).data.high_res + data = Sup3rDataset(low_res=lr, high_res=hr) + kwargs = { + 'data': data, + 'sample_shape': (20, 20, 8), + 'batch_size': 4, + 's_enhance': 2, + 't_enhance': 1, + 'feature_sets': { + 'lr_features': LR_FEATURES, + 'hr_out_features': ['u_100m', 'v_100m'], + 'hr_exo_features': ['topography'], + }, + } + + state = RANDOM_GENERATOR.bit_generator.state + eager_sampler = 
DualSampler(mode='eager', **kwargs) + RANDOM_GENERATOR.bit_generator.state = state + lazy_sampler = DualSampler(mode='lazy', **kwargs) + + eager_batch = next(eager_sampler) + RANDOM_GENERATOR.bit_generator.state = state + lazy_batch = next(lazy_sampler) + + assert np.allclose(eager_batch[0], lazy_batch[0]) + assert np.allclose(eager_batch[1], lazy_batch[1]) From a170c16a05fc43c8bdebfe8147cbcf4bef080285 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 27 Apr 2026 09:53:42 -0600 Subject: [PATCH 04/34] test: add assertions for queue state in eager and lazy batchers --- tests/batch_handlers/test_bh_general.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/batch_handlers/test_bh_general.py b/tests/batch_handlers/test_bh_general.py index b3e29ac17..661b47807 100644 --- a/tests/batch_handlers/test_bh_general.py +++ b/tests/batch_handlers/test_bh_general.py @@ -148,6 +148,8 @@ def test_eager_vs_lazy(): assert eager_batcher.containers[0].mode == 'eager' assert not lazy_batcher.loaded assert lazy_batcher.containers[0].mode == 'lazy' + assert eager_batcher.queue is None + assert lazy_batcher.queue is None assert np.array_equal( eager_batcher.data[0].as_array(), From e581de247bed3805b9be99c3c1a83cd08ad8ac9b Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 28 Apr 2026 08:34:58 -0600 Subject: [PATCH 05/34] feat: implement exception handling in TrainingSession with dedicated thread class --- sup3r/models/utilities.py | 24 ++++++++++++++++- tests/training/test_training_session.py | 36 +++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 tests/training/test_training_session.py diff --git a/sup3r/models/utilities.py b/sup3r/models/utilities.py index d51d5a81b..57740848a 100644 --- a/sup3r/models/utilities.py +++ b/sup3r/models/utilities.py @@ -18,6 +18,23 @@ logger = logging.getLogger(__name__) +class _ExceptionPropagatingThread(threading.Thread): + """Thread wrapper that captures exceptions for re-raising on join.""" + + def 
__init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.exception = None + self.exception_traceback = None + + def run(self): + """Run the configured target and capture any failure.""" + try: + super().run() + except BaseException as exc: + self.exception = exc + self.exception_traceback = exc.__traceback__ + + def get_sup3r_layers(): """Get all classes from phygnn.layers.custom_layers whose names start with 'Sup3r'. @@ -171,7 +188,7 @@ def __init__(self, batch_handler, model, config=None, **kwargs): def run(self): """Wrap model.train().""" - model_thread = threading.Thread( + model_thread = _ExceptionPropagatingThread( target=self.model.train, args=(self.batch_handler,), kwargs={'config': self.config}, @@ -194,6 +211,11 @@ def run(self): sys.exit() model_thread.join() + if model_thread.exception is not None: + self.batch_handler.stop() + raise model_thread.exception.with_traceback( + model_thread.exception_traceback + ) logger.info('Finished training') diff --git a/tests/training/test_training_session.py b/tests/training/test_training_session.py new file mode 100644 index 000000000..7b010e371 --- /dev/null +++ b/tests/training/test_training_session.py @@ -0,0 +1,36 @@ +"""Tests for training session error propagation.""" + +import pytest + +from sup3r.models.utilities import TrainingSession + + +class _FakeBatchHandler: + def __init__(self): + self.stopped = False + + def stop(self): + self.stopped = True + + +class _FailingModel: + def train(self, batch_handler, config=None): + raise RuntimeError('thread failure') + + +def test_training_session_reraises_thread_failure(): + """Worker thread failures should propagate to the caller.""" + + batch_handler = _FakeBatchHandler() + session = TrainingSession( + batch_handler=batch_handler, + model=_FailingModel(), + input_resolution={'spatial': '1km', 'temporal': '1h'}, + out_dir='test_{epoch}', + n_epoch=1, + ) + + with pytest.raises(RuntimeError, match='thread failure'): + session.run() + + assert 
batch_handler.stopped From 05f25b0fd6b593ec5ee6dcbd604fac7398491ca7 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 29 Apr 2026 13:29:25 -0600 Subject: [PATCH 06/34] refactor: remove unnecessary tf.function decorator from calc_loss method in Sup3rGan --- sup3r/models/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 2030a842f..4cfc87ee8 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -789,7 +789,6 @@ def train(self, batch_handler, config=None, **kwargs): batch_handler.stop() - @tf.function def calc_loss( self, hi_res_true, From 36987ff342d49b1db0493b552dfbc2dfa6584837 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 29 Apr 2026 19:38:06 -0600 Subject: [PATCH 07/34] feat: enhance loss metrics with TensorFlow rank assertions and optimize kernel calculations --- sup3r/utilities/loss_metrics.py | 144 +++++++++++++++++++++----------- 1 file changed, 97 insertions(+), 47 deletions(-) diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py index 12bf9350a..07a05b216 100644 --- a/sup3r/utilities/loss_metrics.py +++ b/sup3r/utilities/loss_metrics.py @@ -124,10 +124,22 @@ def gaussian_kernel(x_true, x_gen, sigma=1.0): return result +def _assert_rank_in(x, ranks, message): + """TensorFlow rank assertion that is safe under tf.function tracing.""" + tf.debugging.assert_equal( + tf.reduce_any( + tf.equal(tf.rank(x), tf.constant(ranks, dtype=tf.int32)) + ), + True, + message=message, + ) + + class ExpLoss(Sup3rLoss): """Loss class for squared exponential difference""" @staticmethod + @tf.function def call(x_true, x_gen): """Exponential difference loss function @@ -152,6 +164,7 @@ class MmdLoss(Sup3rLoss): """Loss class for max mean discrepancy loss""" @staticmethod + @tf.function def call(x_true, x_gen, sigma=1.0): """Maximum mean discrepancy (MMD) based on Gaussian kernel function for keras models @@ -186,6 +199,7 @@ class SpatialDerivativeLoss(Sup3rLoss): LOSS_METRIC = 
MeanAbsoluteError() + @tf.function def call(self, x_true, x_gen): """Custom content loss that encourages accuracy of spatial derivatives @@ -208,7 +222,8 @@ def call(self, x_true, x_gen): 'spatiotemporal data only. Received tensor(s) that are not at ' 'least 4D' ) - assert len(x_true.shape) >= 4 and len(x_gen.shape) >= 4, msg + tf.debugging.assert_greater_equal(tf.rank(x_true), 4, message=msg) + tf.debugging.assert_greater_equal(tf.rank(x_gen), 4, message=msg) x_true_div = tf_derivative(x_true, axis=1) + tf_derivative( x_true, axis=2 @@ -223,6 +238,7 @@ class TemporalDerivativeLoss(Sup3rLoss): LOSS_METRIC = MeanAbsoluteError() + @tf.function def call(self, x_true, x_gen): """Custom content loss that encourages accuracy of temporal derivative @@ -244,7 +260,8 @@ def call(self, x_true, x_gen): f'The {self.__class__.__name__} is meant to be used on ' 'spatiotemporal data only. Received tensor(s) that are not 5D' ) - assert len(x_true.shape) == 5 and len(x_gen.shape) == 5, msg + _assert_rank_in(x_true, (5,), msg) + _assert_rank_in(x_gen, (5,), msg) x_true_div = tf_derivative(x_true, axis=3) x_gen_div = tf_derivative(x_gen, axis=3) @@ -257,6 +274,7 @@ class CoarseMseLoss(Sup3rLoss): MSE_LOSS = MeanSquaredError() + @tf.function def call(self, x_true, x_gen): """Exponential difference loss function @@ -286,6 +304,7 @@ class SpatialExtremesLoss(Sup3rLoss): MAE_LOSS = MeanAbsoluteError() + @tf.function def call(self, x_true, x_gen): """Custom content loss that encourages temporal min/max accuracy @@ -321,6 +340,7 @@ class TemporalExtremesLoss(Sup3rLoss): MAE_LOSS = MeanAbsoluteError() + @tf.function def call(self, x_true, x_gen): """Custom content loss that encourages temporal min/max accuracy @@ -358,11 +378,11 @@ class SpatialFftLoss(Sup3rLoss): @staticmethod def _freq_weights(x): """Get product of squared frequencies to weight frequency amplitudes""" - k0 = np.array([k**2 for k in range(x.shape[1])]) - k1 = np.array([k**2 for k in range(x.shape[2])]) - freqs = 
np.multiply.outer(k0, k1) - freqs = tf.convert_to_tensor(freqs[np.newaxis, ..., np.newaxis]) - return tf.cast(freqs, x.dtype) + shape = tf.shape(x) + k0 = tf.cast(tf.range(shape[1]), x.dtype) + k1 = tf.cast(tf.range(shape[2]), x.dtype) + freqs = tf.square(k0)[:, tf.newaxis] * tf.square(k1)[tf.newaxis, :] + return freqs[tf.newaxis, ..., tf.newaxis] def _fft(self, x): """Apply needed transpositions and fft operation.""" @@ -373,6 +393,7 @@ def _fft(self, x): x_hat = tf.math.multiply(self._freq_weights(x), x_hat) return tf.math.log(1 + x_hat) + @tf.function def call(self, x_true, x_gen): """Custom content loss that encourages frequency domain accuracy @@ -404,13 +425,16 @@ class SpatiotemporalFftLoss(Sup3rLoss): @staticmethod def _freq_weights(x): """Get product of squared frequencies to weight frequency amplitudes""" - k0 = np.array([k**2 for k in range(x.shape[1])]) - k1 = np.array([k**2 for k in range(x.shape[2])]) - f = np.array([f**2 for f in range(x.shape[3])]) - freqs = np.multiply.outer(k0, k1) - freqs = np.multiply.outer(freqs, f) - freqs = tf.convert_to_tensor(freqs[np.newaxis, ..., np.newaxis]) - return tf.cast(freqs, x.dtype) + shape = tf.shape(x) + k0 = tf.cast(tf.range(shape[1]), x.dtype) + k1 = tf.cast(tf.range(shape[2]), x.dtype) + freq_t = tf.cast(tf.range(shape[3]), x.dtype) + freqs = ( + tf.square(k0)[:, tf.newaxis, tf.newaxis] + * tf.square(k1)[tf.newaxis, :, tf.newaxis] + * tf.square(freq_t)[tf.newaxis, tf.newaxis, :] + ) + return freqs[tf.newaxis, ..., tf.newaxis] def _fft(self, x): """Apply needed transpositions and fft operation.""" @@ -421,6 +445,7 @@ def _fft(self, x): x_hat = tf.math.multiply(self._freq_weights(x), x_hat) return tf.math.log(1 + x_hat) + @tf.function def call(self, x_true, x_gen): """Custom content loss that encourages frequency domain accuracy @@ -498,7 +523,7 @@ def __init__( def _s_coarsen_4d_tensor(self, tensor): """Perform spatial coarsening on a 4D tensor of shape (n_obs, spatial_1, spatial_2, features)""" - shape = 
tensor.shape + shape = tf.shape(tensor) tensor = tf.reshape( tensor, ( @@ -516,7 +541,7 @@ def _s_coarsen_4d_tensor(self, tensor): def _s_coarsen_5d_tensor(self, tensor): """Perform spatial coarsening on a 5D tensor of shape (n_obs, spatial_1, spatial_2, time, features)""" - shape = tensor.shape + shape = tf.shape(tensor) tensor = tf.reshape( tensor, ( @@ -542,8 +567,12 @@ def _t_coarsen_sample(self, tensor): def _t_coarsen_avg(self, tensor): """Perform temporal coarsening on a 5D tensor of shape (n_obs, spatial_1, spatial_2, time, features)""" - shape = tensor.shape - assert len(shape) == 5 + shape = tf.shape(tensor) + _assert_rank_in( + tensor, + (5,), + 'LowResLoss temporal coarsening expects 5D tensors', + ) tensor = tf.reshape( tensor, (shape[0], shape[1], shape[2], -1, self._t_enhance, shape[4]), @@ -551,6 +580,7 @@ def _t_coarsen_avg(self, tensor): tensor = tf.math.reduce_sum(tensor, axis=4) / self._t_enhance return tensor + @tf.function def call(self, x_true, x_gen): """Custom content loss calculated on re-coarsened low-res fields @@ -575,8 +605,15 @@ def call(self, x_true, x_gen): x_true = tf.cast(x_true, dtype) x_gen = tf.cast(x_gen, dtype) - assert x_true.shape == x_gen.shape - s_only = len(x_true.shape) == 4 + tf.debugging.assert_equal( + tf.shape(x_true), + tf.shape(x_gen), + message=( + 'LowResLoss requires x_true and x_gen to have matching ' + 'shapes' + ), + ) + s_only = x_true.shape.rank == 4 ex_loss = tf.constant(0, dtype=x_true.dtype) if self._ex_loss is not None: @@ -641,6 +678,7 @@ def _feature_loss(self, x_true, x_gen): loss += tf.reduce_mean(tf.square(x_true_f - x_gen_f)) return loss + @tf.function def call(self, x_true, x_gen): """Perceptual loss calculated on true and synthetic feature maps @@ -660,20 +698,16 @@ def call(self, x_true, x_gen): tf.tensor 0D tensor loss value """ - if len(x_true.shape) == 5: - new_shape = ( - x_true.shape[0] * x_true.shape[3], - x_true.shape[1], - x_true.shape[2], - x_true.shape[-1], - ) + if x_true.shape.rank 
== 5: + shape = tf.shape(x_true) + new_shape = (shape[0] * shape[3], shape[1], shape[2], shape[4]) x_true = tf.reshape(x_true, new_shape) x_gen = tf.reshape(x_gen, new_shape) losses = [] - for i in range(x_true.shape[-1]): - x_true_f = x_true[..., i] - x_gen_f = x_gen[..., i] + for x_true_f, x_gen_f in zip( + tf.unstack(x_true, axis=-1), tf.unstack(x_gen, axis=-1) + ): # VGG input needs 3 RGB channels x_true_f = tf.stack([x_true_f] * 3, axis=-1) @@ -703,6 +737,7 @@ def __init__(self, n_projections=1024): super().__init__() self._n_projections = n_projections + @tf.function def call(self, x_true, x_gen): """Sliced Wasserstein distance based on random 1D projections @@ -720,28 +755,34 @@ def call(self, x_true, x_gen): tf.tensor 0D tensor loss value """ - assert len(x_gen.shape) in {4, 5} and len(x_true.shape) in {4, 5}, ( + msg = ( f'The {self.__class__.__name__} is meant to be used on spatial or ' 'spatiotemporal data only. Received tensor(s) that are not 4/5D' ) - if len(x_true.shape) == 4: + _assert_rank_in(x_true, (4, 5), msg) + _assert_rank_in(x_gen, (4, 5), msg) + + if x_true.shape.rank == 4: x_true = tf.expand_dims(x_true, axis=3) x_gen = tf.expand_dims(x_gen, axis=3) - B, H, W, T, C = x_true.shape + shape = tf.shape(x_true) + B, H, W, T, C = shape[0], shape[1], shape[2], shape[3], shape[4] # Flatten only spatial/time dims → (B, HWT, C) x_true_flat = tf.reshape(x_true, (B, H * W * T, C)) x_gen_flat = tf.reshape(x_gen, (B, H * W * T, C)) # Random projection directions over HWT only - proj = tf.random.normal((self._n_projections, H * W * T)) + proj = tf.random.normal( + (self._n_projections, H * W * T), dtype=x_true.dtype + ) proj = tf.math.l2_normalize(proj, axis=-1) # normalize # Project spatial dimensions → (num_proj, B, C) # matmul: (num_proj, HWT) @ (B, HWT, C) → (B, num_proj, C) - x_true_proj = proj @ x_true_flat - x_gen_proj = proj @ x_gen_flat + x_true_proj = tf.einsum('ph,bhc->bpc', proj, x_true_flat) + x_gen_proj = tf.einsum('ph,bhc->bpc', proj, 
x_gen_flat) # Sort each projection's distribution along the projection dimension x_true_sorted = tf.sort(x_true_proj, axis=1) @@ -819,6 +860,7 @@ def _compute_md(self, x, feature): return x_div + @tf.function def call(self, x_true, x_gen): """Custom content loss that encourages accuracy of the material derivative. @@ -841,7 +883,8 @@ def call(self, x_true, x_gen): f'The {self.__class__.__name__} is meant to be used on ' 'spatiotemporal data only. Received tensor(s) that are not 5D' ) - assert len(x_true.shape) == 5 and len(x_gen.shape) == 5, msg + _assert_rank_in(x_true, (5,), msg) + _assert_rank_in(x_gen, (5,), msg) x_true_div = tf.stack( [ @@ -1008,6 +1051,7 @@ def _compute_heat_transfer_residual(self, x): ) return -qc + q + int_g + @tf.function def call(self, __, x_gen): """Evaluate the conductive heat-transfer loss @@ -1034,7 +1078,7 @@ def call(self, __, x_gen): 'or spatiotemporal data only. Received tensor(s) that are not ' '4D or 5D' ) - assert len(x_gen.shape) in {4, 5}, msg + _assert_rank_in(x_gen, (4, 5), msg) expr = self._compute_heat_transfer_residual(x_gen) return self.LOSS_METRIC(tf.zeros_like(expr), expr) @@ -1108,6 +1152,7 @@ def _compute_temperature_gradient(self, x): dt = tf_derivative(t, axis=3) return tf.math.maximum(-1 * dt, tf.constant([0.0], dt.dtype)) + @tf.function def call(self, __, x_gen): """Evaluate the positive temperature-gradient loss @@ -1132,7 +1177,7 @@ def call(self, __, x_gen): 'or spatiotemporal data only. 
Received tensor(s) that are not ' '4D or 5D' ) - assert len(x_gen.shape) in {4, 5}, msg + _assert_rank_in(x_gen, (4, 5), msg) temp_grads = self._compute_temperature_gradient(x_gen) return self.LOSS_METRIC(tf.zeros_like(temp_grads), temp_grads) @@ -1180,6 +1225,7 @@ def __init__( true_features=[moho_gradient_layer], ) + @tf.function def call(self, x_moho, x_gen): """Evaluate the Moho heat-flow boundary-condition loss @@ -1226,6 +1272,7 @@ class GeothermalObsLoss(Sup3rLoss): LOSS_METRIC = MeanAbsoluteError() + @tf.function def call(self, x_true, x_gen): """Evaluate the masked geothermal observation loss @@ -1233,25 +1280,28 @@ def call(self, x_true, x_gen): with the configured generated and true features. Observed values may contain NaNs, which are ignored when computing the loss. """ - check = x_true.shape[-1] == len(self.true_features) - check &= x_gen.shape[-1] == len(self.gen_features) msg = ( f'Number of features in `x_true`: {x_true.shape[-1]} must match ' f'the length of `true_features`: {len(self.true_features)}, ' f'`x_gen`: {x_gen.shape[-1]} must match the length of ' f'`gen_features`: {len(self.gen_features)}' ) - assert check, msg + tf.debugging.assert_equal( + tf.shape(x_true)[-1], len(self.true_features), message=msg + ) + tf.debugging.assert_equal( + tf.shape(x_gen)[-1], len(self.gen_features), message=msg + ) mask = tf.math.logical_not(tf.math.is_nan(x_true)) x_true_m = tf.boolean_mask(x_true, mask) x_gen_m = tf.boolean_mask(x_gen, mask) - - return ( - tf.constant(0, dtype=x_true.dtype) - if tf.math.reduce_all(tf.math.is_nan(x_true_m)) - else self.LOSS_METRIC(x_true_m, x_gen_m) + obs_loss = tf.cond( + tf.math.reduce_all(tf.math.is_nan(x_true)), + lambda: tf.constant(0, dtype=x_true.dtype), + lambda: self.LOSS_METRIC(x_true_m, x_gen_m), ) + return obs_loss def _reshape_depth_feature_for_vertical_derivative(x): From 52c6458a6d24fd0b209669121e36c488dab549a2 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 30 Apr 2026 10:58:48 -0600 Subject: [PATCH 08/34] 
fix: ensure proper data type casting for high-resolution features in AbstractSingleModel --- sup3r/models/abstract.py | 10 +++++++--- tests/conftest.py | 3 +++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 03ebfbbda..6440fb0f6 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -578,7 +578,10 @@ def _combine_loss_input(self, hi_res_true, hi_res_gen): """ if hi_res_true.shape[-1] > hi_res_gen.shape[-1]: exo_dict = self.get_hr_exo_input(hi_res_true) - exo_data = [exo_dict[feat] for feat in self.hr_exo_features] + exo_data = [ + tf.cast(exo_dict[feat], hi_res_gen.dtype) + for feat in self.hr_exo_features + ] hi_res_gen = tf.concat((hi_res_gen, *exo_data), axis=-1) return hi_res_gen @@ -1287,6 +1290,7 @@ def run_exo_layer(self, layer, input_array, exogenous_data, norm_in=True): feat, norm_in=norm_in, ) + exo = exo.astype(input_array.dtype, copy=False) if feat in features: feat_stack.append(exo) else: @@ -1387,9 +1391,9 @@ def _run_exo_layer(cls, layer, input_array, hi_res_exo): for feat in features + exo_features: assert feat in hi_res_exo, msg.format(feat) if feat in features: - feat_stack.append(hi_res_exo[feat]) + feat_stack.append(tf.cast(hi_res_exo[feat], input_array.dtype)) else: - extras.append(hi_res_exo[feat]) + extras.append(tf.cast(hi_res_exo[feat], input_array.dtype)) hr_exo = tf.concat(feat_stack, axis=-1) if len(extras) > 0: extras = tf.concat(extras, axis=-1) diff --git a/tests/conftest.py b/tests/conftest.py index 3f0d1c4a3..178de9ddf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,9 @@ GLOBAL_STATE = RANDOM_GENERATOR.bit_generator.state +os.environ.setdefault('CUDA_VISIBLE_DEVICES', '-1') +os.environ.setdefault('TF_ENABLE_ONEDNN_OPTS', '0') + @pytest.hookimpl def pytest_configure(config): # pylint: disable=unused-argument # noqa: ARG001 From 563ddffc870c02adf256b4b163e541cd14ed3dc5 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 30 Apr 
2026 12:56:02 -0600 Subject: [PATCH 09/34] fix: remove unnecessary type casting for exo in AbstractSingleModel --- sup3r/models/abstract.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 6440fb0f6..1f55939c7 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -1290,7 +1290,6 @@ def run_exo_layer(self, layer, input_array, exogenous_data, norm_in=True): feat, norm_in=norm_in, ) - exo = exo.astype(input_array.dtype, copy=False) if feat in features: feat_stack.append(exo) else: From 5e3b1abecf995de6e67689108693c9f0103b7981 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 30 Apr 2026 15:15:08 -0600 Subject: [PATCH 10/34] fix: update error handling in loss functions to raise ValueError instead of AssertionError --- sup3r/utilities/loss_metrics.py | 28 ++++++++++++++++++---------- tests/utilities/test_loss_metrics.py | 8 ++++---- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py index 07a05b216..2278d622e 100644 --- a/sup3r/utilities/loss_metrics.py +++ b/sup3r/utilities/loss_metrics.py @@ -126,13 +126,21 @@ def gaussian_kernel(x_true, x_gen, sigma=1.0): def _assert_rank_in(x, ranks, message): """TensorFlow rank assertion that is safe under tf.function tracing.""" - tf.debugging.assert_equal( + rank = x.shape.rank + if rank is not None: + if rank not in ranks: + raise ValueError(message) + return x + + assertion = tf.debugging.assert_equal( tf.reduce_any( tf.equal(tf.rank(x), tf.constant(ranks, dtype=tf.int32)) ), True, message=message, ) + with tf.control_dependencies([assertion]): + return tf.identity(x) class ExpLoss(Sup3rLoss): @@ -260,8 +268,8 @@ def call(self, x_true, x_gen): f'The {self.__class__.__name__} is meant to be used on ' 'spatiotemporal data only. 
Received tensor(s) that are not 5D' ) - _assert_rank_in(x_true, (5,), msg) - _assert_rank_in(x_gen, (5,), msg) + x_true = _assert_rank_in(x_true, (5,), msg) + x_gen = _assert_rank_in(x_gen, (5,), msg) x_true_div = tf_derivative(x_true, axis=3) x_gen_div = tf_derivative(x_gen, axis=3) @@ -568,7 +576,7 @@ def _t_coarsen_avg(self, tensor): """Perform temporal coarsening on a 5D tensor of shape (n_obs, spatial_1, spatial_2, time, features)""" shape = tf.shape(tensor) - _assert_rank_in( + tensor = _assert_rank_in( tensor, (5,), 'LowResLoss temporal coarsening expects 5D tensors', @@ -759,8 +767,8 @@ def call(self, x_true, x_gen): f'The {self.__class__.__name__} is meant to be used on spatial or ' 'spatiotemporal data only. Received tensor(s) that are not 4/5D' ) - _assert_rank_in(x_true, (4, 5), msg) - _assert_rank_in(x_gen, (4, 5), msg) + x_true = _assert_rank_in(x_true, (4, 5), msg) + x_gen = _assert_rank_in(x_gen, (4, 5), msg) if x_true.shape.rank == 4: x_true = tf.expand_dims(x_true, axis=3) @@ -883,8 +891,8 @@ def call(self, x_true, x_gen): f'The {self.__class__.__name__} is meant to be used on ' 'spatiotemporal data only. Received tensor(s) that are not 5D' ) - _assert_rank_in(x_true, (5,), msg) - _assert_rank_in(x_gen, (5,), msg) + x_true = _assert_rank_in(x_true, (5,), msg) + x_gen = _assert_rank_in(x_gen, (5,), msg) x_true_div = tf.stack( [ @@ -1078,7 +1086,7 @@ def call(self, __, x_gen): 'or spatiotemporal data only. Received tensor(s) that are not ' '4D or 5D' ) - _assert_rank_in(x_gen, (4, 5), msg) + x_gen = _assert_rank_in(x_gen, (4, 5), msg) expr = self._compute_heat_transfer_residual(x_gen) return self.LOSS_METRIC(tf.zeros_like(expr), expr) @@ -1177,7 +1185,7 @@ def call(self, __, x_gen): 'or spatiotemporal data only. 
Received tensor(s) that are not ' '4D or 5D' ) - _assert_rank_in(x_gen, (4, 5), msg) + x_gen = _assert_rank_in(x_gen, (4, 5), msg) temp_grads = self._compute_temperature_gradient(x_gen) return self.LOSS_METRIC(tf.zeros_like(temp_grads), temp_grads) diff --git a/tests/utilities/test_loss_metrics.py b/tests/utilities/test_loss_metrics.py index 0ea2340da..3fa71c8a6 100644 --- a/tests/utilities/test_loss_metrics.py +++ b/tests/utilities/test_loss_metrics.py @@ -288,7 +288,7 @@ def test_md_loss(): with pytest.raises(ValueError): tf_derivative(x, axis=0) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): md_loss(x[..., 0], y[..., 0]) assert np.allclose(u_div, u_div_np) @@ -346,7 +346,7 @@ def test_geothermal_heat_transfer_loss_depth_intersection_and_errors(): loss_obj = GeothermalConductiveHeatTransferLoss( dx=1, dy=1, depths=[0, 1000, 2000] ) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): loss_obj(np.zeros((2, 4, 6)), np.zeros((2, 4, 6))) @@ -418,7 +418,7 @@ def test_geothermal_heat_transfer_loss_errors(): loss_obj = GeothermalConductiveHeatTransferLoss( dx=dx, dy=dy, depths=[0, 1, 2, 3] ) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): loss_obj(np.zeros((2, 4, 6)), np.zeros((2, 4, 6))) @@ -429,7 +429,7 @@ def test_geothermal_temp_grad_loss_depth_intersection_and_errors(): GeothermalPositiveTemperatureGradientLoss(depths=[0]) loss_obj = GeothermalPositiveTemperatureGradientLoss(depths=[0, 1, 2]) - with pytest.raises(AssertionError): + with pytest.raises(ValueError): loss_obj(np.zeros((2, 4, 6)), np.zeros((2, 4, 6))) From f5f318d3a4d0fb7f0eed8fd4a9db58999c7da6ff Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 5 May 2026 08:15:01 -0600 Subject: [PATCH 11/34] refactor: change logger info to debug level for training loss and batch step messages in Sup3rGan --- sup3r/models/abstract.py | 3 +-- sup3r/models/base.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git 
a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 1f55939c7..06dcc0c68 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -509,8 +509,7 @@ def load_saved_params(out_dir, verbose=True): version_record = params.pop('version_record') if verbose: logger.info( - 'Loading model from disk ' - 'that was created with the ' + 'Loading model from disk that was created with the ' 'following package versions: \n{}'.format( pprint.pformat(version_record, indent=2) ) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 4cfc87ee8..57a58a720 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -1044,7 +1044,7 @@ def _post_batch(self, ib, b_loss_details, n_batches, previous_means): disc_loss = self._train_record['train_loss_disc'].values.mean() gen_loss = self._train_record['train_loss_gen'].values.mean() - logger.info( + logger.debug( 'Batch {} out of {} has (gen / disc) loss of: ({:.2e} / {:.2e}). ' 'Running mean (gen / disc): ({:.2e} / {:.2e}). Trained ' '(gen / disc): ({} / {})'.format( @@ -1153,7 +1153,7 @@ def _train_epoch( batch_step_time = time.time() - start batch_load_time = total_step_time - batch_step_time - logger.info( + logger.debug( f'Finished batch step {ib + 1} / {len(batch_handler)} in ' f'{total_step_time:.4f} seconds. Batch load time: ' f'{batch_load_time:.4f} seconds. 
Batch train time: ' From 3cc0326cbd13e2265821b6259c94dd230a69182d Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 6 May 2026 08:59:08 -0600 Subject: [PATCH 12/34] fix: update TensorFlow function decorators to reduce retracing and improve performance; adjust TrainingConfig defaults for checkpoint interval and adaptive update bounds --- sup3r/models/abstract.py | 6 +++--- sup3r/models/base.py | 22 ++++++++++++++-------- sup3r/models/utilities.py | 4 ++-- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 06dcc0c68..6bdd05566 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -1069,7 +1069,7 @@ def _run_mirrored_grad( apply_fn(total_grad) return mean_loss_details - @tf.function + @tf.function(reduce_retracing=True) def _run_serial_grad( self, low_res, @@ -1459,7 +1459,7 @@ def _get_train_fns(self, train_gen=True, train_disc=False): logger.error(msg) raise ValueError(msg) - @tf.function + @tf.function(reduce_retracing=True) def get_single_grad_gen( self, low_res, @@ -1482,7 +1482,7 @@ def apply_grad_gen(self, grad): """Apply a generator gradient update.""" self.optimizer.apply_gradients(zip(grad, self.generator_weights)) - @tf.function + @tf.function(reduce_retracing=True) def get_single_grad_disc( self, low_res, diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 57a58a720..77287401f 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -711,11 +711,20 @@ def train(self, batch_handler, config=None, **kwargs): config.weight_gen_advers, config.n_epoch, epochs[0] ) ) + + lr_shape, hr_shape = batch_handler.shapes + self.init_weights(lr_shape, hr_shape, train_disc=config.train_disc) + weight_gen_advers = config.weight_gen_advers for epoch in epochs: t_epoch = time.time() + # convert to tensor to avoid retracing when using adaptive updating + # of adversarial weight. 
+ weight_gen_advers = tf.convert_to_tensor( + weight_gen_advers, dtype=tf.float32 + ) loss_details = self._train_epoch( batch_handler, - config.weight_gen_advers, + weight_gen_advers, config.train_gen, config.train_disc, config.disc_loss_bounds, @@ -723,7 +732,7 @@ def train(self, batch_handler, config=None, **kwargs): export_tb=config.export_tb, ) loss_details.update( - self.calc_val_loss(batch_handler, config.weight_gen_advers) + self.calc_val_loss(batch_handler, weight_gen_advers) ) msg = f'Epoch {epoch} of {epochs[-1]} ' @@ -741,7 +750,7 @@ def train(self, batch_handler, config=None, **kwargs): logger.info(msg) extras = { - 'weight_gen_advers': config.weight_gen_advers, + 'weight_gen_advers': weight_gen_advers, 'disc_loss_bound_0': config.disc_loss_bounds[0], 'disc_loss_bound_1': config.disc_loss_bounds[1], } @@ -753,11 +762,11 @@ def train(self, batch_handler, config=None, **kwargs): extras.update(opt_g) extras.update(opt_d) - config.weight_gen_advers = self.update_adversarial_weights( + weight_gen_advers = self.update_adversarial_weights( loss_details, config.adaptive_update_fraction, config.adaptive_update_bounds, - config.weight_gen_advers, + weight_gen_advers, config.train_disc, ) @@ -1108,9 +1117,6 @@ def _train_epoch( loss_details : dict Namespace of the breakdown of loss components """ - lr_shape, hr_shape = batch_handler.shapes - self.init_weights(lr_shape, hr_shape, train_disc=train_disc) - disc_th_low = np.min(disc_loss_bounds) disc_th_high = np.max(disc_loss_bounds) loss_means = self._train_record.mean().to_dict() diff --git a/sup3r/models/utilities.py b/sup3r/models/utilities.py index 57740848a..1afffadcc 100644 --- a/sup3r/models/utilities.py +++ b/sup3r/models/utilities.py @@ -71,12 +71,12 @@ class TrainingConfig: train_gen: bool = True train_disc: bool = True disc_loss_bounds: tuple = (0.45, 0.6) - checkpoint_int: Optional[int] = None + checkpoint_int: int = 10 out_dir: Optional[str] = None early_stop_on: Optional[str] = None 
early_stop_threshold: float = 0.005 early_stop_n_epoch: int = 5 - adaptive_update_bounds: tuple = (0.9, 0.99) + adaptive_update_bounds: tuple = (0.0, 1.0) adaptive_update_fraction: float = 0.0 multi_gpu: bool = False log_tb: bool = False From dbcc297b2be8f0154d8832b674f097bb65e89490 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 9 May 2026 16:10:39 -0600 Subject: [PATCH 13/34] Enhance logging across multiple modules - Added logging statements to provide detailed information on the execution flow in `batch_cli.py`, `bias_calc_cli.py`, `abstract.py`, `base.py`, `forward_pass.py`, `forward_pass_cli.py`, `pipeline_cli.py`, `strategy.py`, `h5.py`, `nc.py`, `data_collect_cli.py`, `accessor.py`, `abstract.py`, `stats.py`, `base.py`, `exo.py`, `loaders/base.py`, `loaders/h5.py`, `rasterizers/base.py`, `rasterizers/exo.py`, `rasterizers/extended.py`, and `utilities.py`. - Changed several logging levels from `info` to `debug` to reduce verbosity and ensure that only essential information is logged at the `info` level. - Improved the clarity of log messages by using formatted strings for better readability. - Added logging for job queueing and completion in the solar CLI to track the processing of nodes. 
--- sup3r/batch/batch_cli.py | 15 ++++ sup3r/bias/bias_calc_cli.py | 87 ++++++++++++++----- sup3r/models/abstract.py | 12 +-- sup3r/models/base.py | 4 +- sup3r/pipeline/forward_pass.py | 26 ++++-- sup3r/pipeline/forward_pass_cli.py | 17 ++++ sup3r/pipeline/pipeline_cli.py | 8 ++ sup3r/pipeline/strategy.py | 19 +++-- sup3r/postprocessing/collectors/h5.py | 36 ++++---- sup3r/postprocessing/collectors/nc.py | 6 +- sup3r/postprocessing/data_collect_cli.py | 88 +++++++++++++++----- sup3r/preprocessing/accessor.py | 14 ++-- sup3r/preprocessing/batch_queues/abstract.py | 4 +- sup3r/preprocessing/collections/stats.py | 27 +++--- sup3r/preprocessing/data_handlers/base.py | 11 +-- sup3r/preprocessing/data_handlers/exo.py | 13 +-- sup3r/preprocessing/loaders/base.py | 2 +- sup3r/preprocessing/loaders/h5.py | 2 +- sup3r/preprocessing/rasterizers/base.py | 6 +- sup3r/preprocessing/rasterizers/exo.py | 8 +- sup3r/preprocessing/rasterizers/extended.py | 11 +-- sup3r/preprocessing/utilities.py | 18 ++-- sup3r/solar/solar_cli.py | 8 ++ 23 files changed, 309 insertions(+), 133 deletions(-) diff --git a/sup3r/batch/batch_cli.py b/sup3r/batch/batch_cli.py index 2c1f99450..0f8b9bbd4 100644 --- a/sup3r/batch/batch_cli.py +++ b/sup3r/batch/batch_cli.py @@ -1,10 +1,14 @@ # pylint: disable=all """Batch Job CLI entry points.""" +import logging + import click from gaps.batch import BatchJob from sup3r import __version__ +logger = logging.getLogger(__name__) + @click.group() @click.version_option(version=__version__) @@ -41,6 +45,15 @@ def from_config(ctx, config_file, dry_run, cancel, delete, monitor_background, """Run Sup3r batch from a config file.""" ctx.ensure_object(dict) ctx.obj['VERBOSE'] = verbose or ctx.obj.get('VERBOSE', False) + logger.info( + 'Starting batch job from %s (dry_run=%s, cancel=%s, delete=%s, ' + 'monitor_background=%s).', + config_file, + dry_run, + cancel, + delete, + monitor_background, + ) batch = BatchJob(config_file) if cancel: @@ -50,6 +63,8 @@ def 
from_config(ctx, config_file, dry_run, cancel, delete, monitor_background, else: batch.run(dry_run=dry_run, monitor_background=monitor_background) + logger.info('Finished batch CLI invocation for %s.', config_file) + if __name__ == '__main__': main(obj={}) diff --git a/sup3r/bias/bias_calc_cli.py b/sup3r/bias/bias_calc_cli.py index baa760664..39ccc99e5 100644 --- a/sup3r/bias/bias_calc_cli.py +++ b/sup3r/bias/bias_calc_cli.py @@ -1,4 +1,5 @@ """sup3r bias correction calculation CLI entry points.""" + import copy import logging import os @@ -15,8 +16,12 @@ @click.group() @click.version_option(version=__version__) -@click.option('-v', '--verbose', is_flag=True, - help='Flag to turn on debug logging. Default is not verbose.') +@click.option( + '-v', + '--verbose', + is_flag=True, + help='Flag to turn on debug logging. Default is not verbose.', +) @click.pass_context def main(ctx, verbose): """Sup3r bias calc Command Line Interface""" @@ -25,16 +30,25 @@ def main(ctx, verbose): @main.command() -@click.option('--config_file', '-c', required=True, - type=click.Path(exists=True), - help='sup3r bias correction calculation config .json file.') -@click.option('-v', '--verbose', is_flag=True, - help='Flag to turn on debug logging. Default is not verbose.') +@click.option( + '--config_file', + '-c', + required=True, + type=click.Path(exists=True), + help='sup3r bias correction calculation config .json file.', +) +@click.option( + '-v', + '--verbose', + is_flag=True, + help='Flag to turn on debug logging. 
Default is not verbose.', +) @click.pass_context def from_config(ctx, config_file, verbose=False, pipeline_step=None): """Run sup3r bias correction calculation from a config file.""" - config = BaseCLI.from_config_preflight(ModuleName.BIAS_CALC, ctx, - config_file, verbose) + config = BaseCLI.from_config_preflight( + ModuleName.BIAS_CALC, ctx, config_file, verbose + ) exec_kwargs = config.get('execution_control', {}) hardware_option = exec_kwargs.pop('option', 'local') @@ -44,31 +58,57 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None): log_pattern = config.get('log_pattern', None) jobs = config['jobs'] + logger.info( + 'Preparing bias calculation from %s with hardware=%s across %s jobs ' + 'using %s.', + config_file, + hardware_option, + len(jobs), + calc_class_name, + ) for i_node, job in enumerate(jobs): node_config = copy.deepcopy(job) node_config['status_dir'] = config['status_dir'] node_config['log_file'] = ( - log_pattern if log_pattern is None - else os.path.normpath(log_pattern.format(node_index=i_node))) - name = ('{}_{}'.format(basename, str(i_node).zfill(6))) + log_pattern + if log_pattern is None + else os.path.normpath(log_pattern.format(node_index=i_node)) + ) + name = '{}_{}'.format(basename, str(i_node).zfill(6)) ctx.obj['NAME'] = name node_config['job_name'] = name - node_config["pipeline_step"] = pipeline_step + node_config['pipeline_step'] = pipeline_step cmd = BiasCalcClass.get_node_cmd(node_config) cmd_log = '\n\t'.join(cmd.split('\n')) logger.debug(f'Running command:\n\t{cmd_log}') + logger.info( + 'Queueing bias calculation node %s as job "%s".', + i_node, + name, + ) if hardware_option.lower() in AVAILABLE_HARDWARE_OPTIONS: kickoff_slurm_job(ctx, cmd, pipeline_step, **exec_kwargs) else: kickoff_local_job(ctx, cmd, pipeline_step) - -def kickoff_slurm_job(ctx, cmd, pipeline_step=None, alloc='sup3r', - memory=None, walltime=4, feature=None, - stdout_path='./stdout/'): + logger.info( + 'Finished queueing bias calculation 
work for %s jobs.', len(jobs) + ) + + +def kickoff_slurm_job( + ctx, + cmd, + pipeline_step=None, + alloc='sup3r', + memory=None, + walltime=4, + feature=None, + stdout_path='./stdout/', +): """Run sup3r on HPC via SLURM job submission. Parameters @@ -94,8 +134,17 @@ def kickoff_slurm_job(ctx, cmd, pipeline_step=None, alloc='sup3r', stdout_path : str Path to print .stdout and .stderr files. """ - BaseCLI.kickoff_slurm_job(ModuleName.BIAS_CALC, ctx, cmd, alloc, memory, - walltime, feature, stdout_path, pipeline_step) + BaseCLI.kickoff_slurm_job( + ModuleName.BIAS_CALC, + ctx, + cmd, + alloc, + memory, + walltime, + feature, + stdout_path, + pipeline_step, + ) def kickoff_local_job(ctx, cmd, pipeline_step=None): diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 6bdd05566..f4396e00c 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -202,11 +202,11 @@ def set_norm_stats(self, new_means, new_stdevs): if new_means is not None and new_stdevs is not None: logger.info('Setting new normalization statistics...') - logger.info( + logger.debug( "Model's previous data mean values:\n%s", pprint.pformat(self._means, indent=2), ) - logger.info( + logger.debug( "Model's previous data stdev values:\n%s", pprint.pformat(self._stdevs, indent=2), ) @@ -239,11 +239,11 @@ def set_norm_stats(self, new_means, new_stdevs): f'in new means array: {self._means}' ) - logger.info( + logger.debug( 'Set data normalization mean values:\n%s', pprint.pformat(self._means, indent=2), ) - logger.info( + logger.debug( 'Set data normalization stdev values:\n%s', pprint.pformat(self._stdevs, indent=2), ) @@ -508,7 +508,7 @@ def load_saved_params(out_dir, verbose=True): if 'version_record' in params: version_record = params.pop('version_record') if verbose: - logger.info( + logger.debug( 'Loading model from disk that was created with the ' 'following package versions: \n{}'.format( pprint.pformat(version_record, indent=2) @@ -953,7 +953,7 @@ def finish_epoch( stop : 
bool Flag to early stop training. """ - self.log_loss_details(loss_details) + self.log_loss_details(loss_details, level='INFO') self._history.at[epoch, 'elapsed_time'] = time.time() - t0 entry = np.vstack(list(loss_details.values())).T self._history.loc[epoch, list(loss_details.keys())] = entry diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 77287401f..d872b8532 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -202,7 +202,7 @@ def _load(cls, model_dir, verbose=True): msg = 'Active python environment versions: \n{}'.format( pprint.pformat(VERSION_RECORD, indent=4) ) - logger.info(msg) + logger.debug(msg) fp_gen = os.path.join(model_dir, 'model_gen.pkl') fp_disc = os.path.join(model_dir, 'model_disc.pkl') @@ -782,7 +782,7 @@ def train(self, batch_handler, config=None, **kwargs): config.early_stop_n_epoch, extras=extras, ) - logger.info( + logger.debug( 'Finished training epoch in {:.4f} seconds'.format( time.time() - t_epoch ) diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index e37ff4159..0f530789e 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -145,7 +145,7 @@ def pad_source_data(self, input_data, pad_width, exo_data, mode='reflect'): """ out = np.pad(input_data, (*pad_width, (0, 0)), mode=mode) - logger.info( + logger.debug( 'Padded input data shape from %s to %s using mode "%s" ' 'with padding argument: %s', input_data.shape, @@ -177,8 +177,13 @@ def pad_source_data(self, input_data, pad_width, exo_data, mode='reflect'): ) new_exo = np.pad(new_exo, exo_pad_width, mode=mode) exo_data[feature]['steps'][i]['data'] = new_exo - logger.info( - f'Got exo data for feature: {feature}, model step: {i}' + logger.debug( + 'Got exo data for feature: %s, model step: %d, ' + 's_enhance: %d, t_enhance: %d', + feature, + i, + s_enhance, + t_enhance, ) return out, exo_data @@ -442,6 +447,13 @@ def run(cls, strategy, node_index): will be run. 
""" if not strategy.node_finished(node_index): + logger.info( + 'Starting forward pass on node %s with %s chunks using %s ' + 'execution.', + node_index, + len(strategy.node_chunks[node_index]), + 'serial' if strategy.pass_workers == 1 else 'parallel', + ) if strategy.pass_workers == 1: cls._run_serial(strategy, node_index) else: @@ -484,7 +496,7 @@ def _run_serial(cls, strategy, node_index): overwrite=(not strategy.incremental), meta=fwp.meta, ) - logger.info( + logger.debug( 'Finished forward pass on chunk_index=' f'{chunk_index} in {dt.now() - now}. {i + 1} of ' f'{len(strategy.node_chunks[node_index])} ' @@ -570,7 +582,7 @@ def _run_parallel(cls, strategy, node_index): f'{chunk_idx} in {dt.now() - start_time}. ' f'{i + 1} of {len(futures)} complete. {_mem_check()}' ) - logger.info(msg) + logger.debug(msg) except Exception as e: msg = ( 'Error running forward pass on chunk_index=' @@ -648,7 +660,7 @@ def run_chunk( """ msg = f'Running forward pass for chunk_index={chunk.index}.' - logger.info(msg) + logger.debug(msg) model = get_model(model_class, model_kwargs) @@ -685,7 +697,7 @@ def run_chunk( failed = cls._output_check(output_data, allowed_const=allowed_const) if chunk.out_file is not None and not failed: - logger.info(f'Saving forward pass output to {chunk.out_file}.') + logger.debug(f'Saving forward pass output to {chunk.out_file}.') output_type = get_source_type(chunk.out_file) cls.OUTPUT_HANDLER_CLASS[output_type]._write_output( data=output_data, diff --git a/sup3r/pipeline/forward_pass_cli.py b/sup3r/pipeline/forward_pass_cli.py index eaf380dd1..ff42f4b51 100644 --- a/sup3r/pipeline/forward_pass_cli.py +++ b/sup3r/pipeline/forward_pass_cli.py @@ -62,6 +62,13 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None): strategy_kwargs = {k: v for k, v in config.items() if k in sig.parameters} strategy = ForwardPassStrategy(**strategy_kwargs, head_node=True) + logger.info( + 'Preparing forward pass from %s with hardware=%s across %s nodes.', + 
config_file, + hardware_option, + len(strategy.node_chunks), + ) + if node_index is not None: nodes = ( [node_index] if not isinstance(node_index, list) else node_index @@ -82,11 +89,21 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None): node_config['pipeline_step'] = pipeline_step cmd = ForwardPass.get_node_cmd(node_config) + logger.info( + 'Queueing forward pass node %s as job "%s".', + i_node, + name, + ) + if hardware_option.lower() in AVAILABLE_HARDWARE_OPTIONS: kickoff_slurm_job(ctx, cmd, pipeline_step, **exec_kwargs) else: kickoff_local_job(ctx, cmd, pipeline_step) + logger.info( + 'Finished queueing forward pass work for %s nodes.', len(nodes) + ) + def kickoff_slurm_job( ctx, diff --git a/sup3r/pipeline/pipeline_cli.py b/sup3r/pipeline/pipeline_cli.py index 1a62add9c..d652fc89b 100644 --- a/sup3r/pipeline/pipeline_cli.py +++ b/sup3r/pipeline/pipeline_cli.py @@ -42,7 +42,15 @@ def from_config(ctx, config_file, cancel, monitor, background, verbose=False): """Run sup3r pipeline from a config file.""" ctx.ensure_object(dict) ctx.obj['VERBOSE'] = verbose or ctx.obj.get('VERBOSE', False) + logger.info( + 'Starting pipeline from %s (cancel=%s, monitor=%s, background=%s).', + config_file, + cancel, + monitor, + background, + ) pipeline(config_file, cancel, monitor, background) + logger.info('Finished pipeline invocation for %s.', config_file) if __name__ == '__main__': diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index c4f67a467..54ac35fe4 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -365,6 +365,7 @@ def _init_features(self, model): """Initialize feature attributes.""" self.exo_handler_kwargs = self.exo_handler_kwargs or {} exo_features = list(self.exo_handler_kwargs) + exo_features = [f for f in exo_features if f in model.hr_exo_features] features = [f for f in model.lr_features if f not in exo_features] return features, exo_features @@ -455,7 +456,7 @@ def hr_lat_lon(self): """Get high 
resolution lat lons""" lr_lat_lon = self.input_handler.lat_lon shape = tuple(d * self.s_enhance for d in lr_lat_lon.shape[:-1]) - logger.info( + logger.debug( f'Getting high-resolution grid for full output domain: {shape}' ) return OutputHandler.get_lat_lon(lr_lat_lon, shape) @@ -512,13 +513,13 @@ def prep_chunk_data(self, chunk_index=0): kwargs = dict(zip(Dimension.dims_2d(), lr_pad_slice)) kwargs[Dimension.TIME] = ti_pad_slice input_data = self.input_handler[self.features].isel(**kwargs) - logger.info( + logger.debug( 'Loading data for chunk_index=%s into memory.', chunk_index ) input_data.load() if self.bias_correct_kwargs != {}: - logger.info( + logger.debug( f'Bias correcting data for chunk_index={chunk_index}, ' f'with shape={input_data.shape}' ) @@ -570,12 +571,12 @@ def init_chunk(self, chunk_index=0): 'lr_pad_slice': lr_pad_slice, 'ti_pad_slice': ti_pad_slice, } - logger.info( + logger.debug( 'Initializing ForwardPassChunk with: ' f'{pprint.pformat(args_dict, indent=2)}' ) - logger.info(f'Getting input data for chunk_index={chunk_index}.') + logger.debug(f'Getting input data for chunk_index={chunk_index}.') input_data, exo_data = self.timer( self.prep_chunk_data, log=True, call_id=chunk_index @@ -661,13 +662,13 @@ def fwp_mask(self): sup3r.pipeline.strategy.ForwardPassStrategy """ mask = np.zeros(len(self.lr_pad_slices)) - logger.info('Checking for mask in input handler.') + logger.debug('Checking for mask in input handler.') input_handler_kwargs = copy.deepcopy(self.input_handler_kwargs) try: InputHandler = get_input_handler_class(self.input_handler_name) input_handler_kwargs['features'] = ['mask'] handler = InputHandler(**input_handler_kwargs) - logger.info( + logger.debug( 'Found "mask" in DataHandler. 
Computing forward pass ' 'chunk mask for %s chunks', len(self.lr_pad_slices), @@ -698,7 +699,7 @@ def chunk_finished(self, chunk_idx, log=True): and self.incremental ) if check and log: - logger.info( + logger.debug( '%s already exists and incremental = True. Skipping forward ' 'pass for chunk index %s.', out_file, @@ -713,7 +714,7 @@ def chunk_masked(self, chunk_idx, log=True): s_chunk_idx, _ = self.fwp_slicer.get_chunk_indices(chunk_idx) mask_check = self.fwp_mask[s_chunk_idx] if mask_check and log: - logger.info( + logger.debug( 'Chunk %s has spatial chunk index %s, which corresponds to a ' 'masked spatial region. Skipping forward pass for this chunk.', chunk_idx, diff --git a/sup3r/postprocessing/collectors/h5.py b/sup3r/postprocessing/collectors/h5.py index 6f9d88fc6..55885181a 100644 --- a/sup3r/postprocessing/collectors/h5.py +++ b/sup3r/postprocessing/collectors/h5.py @@ -265,8 +265,8 @@ def get_unique_chunk_files(self, file_paths): t_files = list(t_files.values()) s_files = list(s_files.values()) - logger.info('Found %s unique temporal chunks', len(t_files)) - logger.info('Found %s unique spatial chunks', len(s_files)) + logger.debug('Found %s unique temporal chunks', len(t_files)) + logger.debug('Found %s unique spatial chunks', len(s_files)) return t_files, s_files def _get_collection_attrs(self, file_paths, max_workers=None): @@ -309,7 +309,7 @@ def _get_collection_attrs(self, file_paths, max_workers=None): time_index = dask.compute( *ti_tasks, scheduler='threads', num_workers=max_workers ) - logger.info( + logger.debug( 'Finished getting meta and time_index for all unique chunks.' 
) time_index = pd.DatetimeIndex(np.concatenate(time_index)) @@ -323,7 +323,7 @@ def _get_collection_attrs(self, file_paths, max_workers=None): meta = meta.drop_duplicates(subset=['latitude', 'longitude']) meta = meta.sort_values('gid') - logger.info('Finished building full meta and time index.') + logger.debug('Finished building full meta and time index.') return time_index, meta def get_target_and_masked_meta( @@ -363,12 +363,12 @@ def get_target_and_masked_meta( target_meta, meta, threshold=threshold ) masked_meta = meta.iloc[mask] - logger.info(f'Masked meta coordinates: {len(masked_meta)}') + logger.debug('Masked meta coordinates: %s', len(masked_meta)) mask = self.get_coordinate_indices( masked_meta, target_meta, threshold=threshold ) target_meta = target_meta.iloc[mask] - logger.info(f'Target meta coordinates: {len(target_meta)}') + logger.debug('Target meta coordinates: %s', len(target_meta)) else: target_meta = masked_meta = meta @@ -418,7 +418,7 @@ def get_collection_attrs( that all the files in file_paths have the same global file attributes). """ - logger.info(f'Using target_meta_file={target_meta_file}') + logger.debug('Using target_meta_file=%s', target_meta_file) if isinstance(target_meta_file, str): msg = f'Provided target meta ({target_meta_file}) does not exist.' 
assert os.path.exists(target_meta_file), msg @@ -426,14 +426,14 @@ def get_collection_attrs( time_index, meta = self._get_collection_attrs( file_paths, max_workers=max_workers ) - logger.info('Getting target and masked meta.') + logger.debug('Getting target and masked meta.') target_meta, masked_meta = self.get_target_and_masked_meta( meta, target_meta_file, threshold=threshold ) shape = (len(time_index), len(target_meta)) - logger.info('Getting global attrs from %s', file_paths[0]) + logger.debug('Getting global attrs from %s', file_paths[0]) with RexOutputs(file_paths[0], mode='r') as fin: global_attrs = fin.global_attrs @@ -588,10 +588,10 @@ def _collect_flist( else: msg = ( 'No target coordinates found in masked meta. Skipping ' - f'collection for {file_paths}.' + 'collection for %s.' ) - logger.warning(msg) - warn(msg) + logger.warning(msg, file_paths) + warn(msg % file_paths) def get_flist_chunks(self, file_paths, n_writes=None): """Group files by temporal_chunk_index and then combines these groups @@ -620,8 +620,10 @@ def get_flist_chunks(self, file_paths, n_writes=None): if n_writes is not None and n_writes > len(file_chunks): logger.info( - f'n_writes ({n_writes}) too big, setting to the number ' - f'of temporal chunks ({len(file_chunks)}).' 
+ 'n_writes (%s) too big, setting to the number of temporal ' + 'chunks (%s).', + n_writes, + len(file_chunks), ) n_writes = len(file_chunks) @@ -635,7 +637,9 @@ def get_flist_chunks(self, file_paths, n_writes=None): n_writes, ) - logger.debug(f'Grouped file list into {len(file_chunks)} time chunks.') + logger.debug( + 'Grouped file list into %s time chunks.', len(file_chunks) + ) return flist_chunks @@ -698,7 +702,7 @@ def collect_feature( else: for i, flist in enumerate(flist_chunks): - logger.info( + logger.debug( 'Collecting file list chunk %s out of %s ', i + 1, len(flist_chunks), diff --git a/sup3r/postprocessing/collectors/nc.py b/sup3r/postprocessing/collectors/nc.py index dc509c874..ea73439e3 100644 --- a/sup3r/postprocessing/collectors/nc.py +++ b/sup3r/postprocessing/collectors/nc.py @@ -59,7 +59,7 @@ def collect( cacher_kwargs : dict | None Dictionary of kwargs to pass to Cacher._write_single. """ - logger.info(f'Initializing collection for file_paths={file_paths}') + logger.info('Initializing collection for file_paths=%s', file_paths) if log_level is not None: init_logger( @@ -71,10 +71,10 @@ def collect( collector = cls(file_paths) logger.info( - 'Collecting {} files to {}'.format(len(collector.flist), out_file) + 'Collecting %s files to %s', len(collector.flist), out_file ) if overwrite and os.path.exists(out_file): - logger.info(f'overwrite=True, removing {out_file}.') + logger.info('overwrite=True, removing %s.', out_file) os.remove(out_file) if not os.path.exists(out_file): diff --git a/sup3r/postprocessing/data_collect_cli.py b/sup3r/postprocessing/data_collect_cli.py index ea7d58a04..79aad0461 100644 --- a/sup3r/postprocessing/data_collect_cli.py +++ b/sup3r/postprocessing/data_collect_cli.py @@ -1,4 +1,5 @@ """sup3r data collection CLI entry points.""" + import copy import logging @@ -15,8 +16,12 @@ @click.group() @click.version_option(version=__version__) -@click.option('-v', '--verbose', is_flag=True, - help='Flag to turn on debug logging. 
Default is not verbose.') +@click.option( + '-v', + '--verbose', + is_flag=True, + help='Flag to turn on debug logging. Default is not verbose.', +) @click.pass_context def main(ctx, verbose): """Sup3r Data Collection Command Line Interface""" @@ -25,17 +30,26 @@ def main(ctx, verbose): @main.command() -@click.option('--config_file', '-c', required=True, - type=click.Path(exists=True), - help='sup3r data collection configuration json file.') -@click.option('-v', '--verbose', is_flag=True, - help='Flag to turn on debug logging. Default is not verbose.') +@click.option( + '--config_file', + '-c', + required=True, + type=click.Path(exists=True), + help='sup3r data collection configuration json file.', +) +@click.option( + '-v', + '--verbose', + is_flag=True, + help='Flag to turn on debug logging. Default is not verbose.', +) @click.pass_context def from_config(ctx, config_file, verbose=False, pipeline_step=None): """Run sup3r data collection from a config file. If dset_split is True this each feature will be collected into a separate file.""" - config = BaseCLI.from_config_preflight(ModuleName.DATA_COLLECT, ctx, - config_file, verbose) + config = BaseCLI.from_config_preflight( + ModuleName.DATA_COLLECT, ctx, config_file, verbose + ) dset_split = config.get('dset_split', False) exec_kwargs = config.get('execution_control', {}) @@ -44,21 +58,31 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None): collector_types = {'h5': CollectorH5, 'nc': CollectorNC} Collector = collector_types[source_type] + logger.info( + 'Preparing data collection from %s to %s using %s output files.', + config_file, + config['out_file'], + source_type, + ) + configs = [config] if dset_split: configs = [] for feature in config['features']: f_config = copy.deepcopy(config) f_out_file = config['out_file'].replace( - f'.{source_type}', f'_{feature}.{source_type}') + f'.{source_type}', f'_{feature}.{source_type}' + ) f_job_name = config['job_name'] + f'_{feature}' f_log_file = 
config.get('log_file', None) if f_log_file is not None: f_log_file = f_log_file.replace('.log', f'_{feature}.log') - f_config.update({'features': [feature], - 'out_file': f_out_file, - 'job_name': f_job_name, - 'log_file': f_log_file}) + f_config.update({ + 'features': [feature], + 'out_file': f_out_file, + 'job_name': f_job_name, + 'log_file': f_log_file, + }) configs.append(f_config) @@ -67,11 +91,21 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None): config['pipeline_step'] = pipeline_step cmd = Collector.get_node_cmd(config) + logger.info( + 'Queueing data collection job "%s" for %s features.', + config['job_name'], + len(config['features']), + ) + if hardware_option.lower() in AVAILABLE_HARDWARE_OPTIONS: kickoff_slurm_job(ctx, cmd, pipeline_step, **exec_kwargs) else: kickoff_local_job(ctx, cmd, pipeline_step) + logger.info( + 'Finished queueing data collection work for %s jobs.', len(configs) + ) + def kickoff_local_job(ctx, cmd, pipeline_step=None): """Run sup3r data collection locally. @@ -91,9 +125,16 @@ def kickoff_local_job(ctx, cmd, pipeline_step=None): BaseCLI.kickoff_local_job(ModuleName.DATA_COLLECT, ctx, cmd, pipeline_step) -def kickoff_slurm_job(ctx, cmd, pipeline_step=None, alloc='sup3r', - memory=None, walltime=4, feature=None, - stdout_path='./stdout/'): +def kickoff_slurm_job( + ctx, + cmd, + pipeline_step=None, + alloc='sup3r', + memory=None, + walltime=4, + feature=None, + stdout_path='./stdout/', +): """Run sup3r on HPC via SLURM job submission. Parameters @@ -119,8 +160,17 @@ def kickoff_slurm_job(ctx, cmd, pipeline_step=None, alloc='sup3r', stdout_path : str Path to print .stdout and .stderr files. 
""" - BaseCLI.kickoff_slurm_job(ModuleName.DATA_COLLECT, ctx, cmd, alloc, memory, - walltime, feature, stdout_path, pipeline_step) + BaseCLI.kickoff_slurm_job( + ModuleName.DATA_COLLECT, + ctx, + cmd, + alloc, + memory, + walltime, + feature, + stdout_path, + pipeline_step, + ) if __name__ == '__main__': diff --git a/sup3r/preprocessing/accessor.py b/sup3r/preprocessing/accessor.py index 3a70abe0f..95bb9dcfe 100644 --- a/sup3r/preprocessing/accessor.py +++ b/sup3r/preprocessing/accessor.py @@ -275,18 +275,20 @@ def compute(self, **kwargs): it has not been loaded already.""" self._clear_array_cache() if not self.loaded: - logger.debug(f'Loading dataset into memory: {self._ds}') - logger.debug(f'Pre-loading: {_mem_check()}') + logger.debug('Loading dataset into memory: %s', self._ds) + logger.debug('Pre-loading: %s', _mem_check()) for f in list(self._ds.data_vars) + list(self._ds.coords): if hasattr(self._ds[f], 'compute'): self._ds[f] = self._ds[f].compute(**kwargs) logger.debug( - f'Loaded {f} into memory with shape ' - f'{self._ds[f].shape}. {_mem_check()}' + 'Loaded %s into memory with shape %s. 
%s', + f, + self._ds[f].shape, + _mem_check(), ) - logger.debug(f'Loaded dataset into memory: {self._ds}') - logger.debug(f'Post-loading: {_mem_check()}') + logger.debug('Loaded dataset into memory: %s', self._ds) + logger.debug('Post-loading: %s', _mem_check()) self._loaded = True return self diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index f848d64b3..cceddf123 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -218,14 +218,14 @@ def start(self) -> None: and self.mode == 'lazy' and self.queue_cap > 0 ): - logger.info(f'Starting {self._thread_name} queue.') + logger.info('Starting %s queue.', self._thread_name) self.queue_thread.start() def stop(self) -> None: """Stop loading batches.""" self._training_flag.clear() if self.queue_thread.is_alive(): - logger.info(f'Stopping {self._thread_name} queue.') + logger.info('Stopping %s queue.', self._thread_name) self.queue_thread.join() def __len__(self): diff --git a/sup3r/preprocessing/collections/stats.py b/sup3r/preprocessing/collections/stats.py index 1671a05ed..e17f8da2e 100644 --- a/sup3r/preprocessing/collections/stats.py +++ b/sup3r/preprocessing/collections/stats.py @@ -115,7 +115,7 @@ def get_means(self, means): means = self._init_stats_dict(means) needed_features = set(self.features) - set(means) if any(needed_features): - logger.info(f'Getting means for {needed_features}.') + logger.debug('Getting means for %s.', needed_features) cmeans = [ cm * w for cm, w in zip( @@ -124,7 +124,7 @@ def get_means(self, means): ) ] for f in needed_features: - logger.info(f'Computing mean for {f}.') + logger.debug('Computing mean for %s.', f) means[f] = np.float32(np.nansum([cm[f] for cm in cmeans])) return means @@ -134,13 +134,13 @@ def get_stds(self, stds): stds = self._init_stats_dict(stds) needed_features = set(self.features) - set(stds) if any(needed_features): - logger.info(f'Getting stds for 
{needed_features}.') + logger.debug('Getting stds for %s.', needed_features) cstds = [ w * cm**2 for cm, w in zip(self._get_stat('std'), self.container_weights) ] for f in needed_features: - logger.info(f'Computing std for {f}.') + logger.debug('Computing std for %s.', f) stds[f] = np.float32( np.sqrt(np.nansum([cs[f] for cs in cstds])) ) @@ -159,22 +159,25 @@ def save_stats(self, stds, means): with open(stds, 'w') as f: f.write(safe_serialize(self.stds)) logger.info( - f'Saved standard deviations {self.stds} to {stds}.' + 'Saved standard deviations %s to %s.', + self.stds, + stds, ) if isinstance(means, str) and ( not os.path.exists(means) or self._added_stats(means, self.means) ): with open(means, 'w') as f: f.write(safe_serialize(self.means)) - logger.info(f'Saved means {self.means} to {means}.') + logger.info('Saved means %s to %s.', self.means, means) def normalize(self, containers): """Normalize container data with computed stats.""" - logger.debug( - 'Normalizing containers with:\n' - f'means: {pprint.pformat(self.means, indent=2)}\n' - f'stds: {pprint.pformat(self.stds, indent=2)}' - ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + 'Normalizing containers with:\nmeans: %s\nstds: %s', + pprint.pformat(self.means, indent=2), + pprint.pformat(self.stds, indent=2), + ) for i, c in enumerate(containers): - logger.info(f'Normalizing container {i + 1}') + logger.debug('Normalizing container %s', i + 1) c.normalize(means=self.means, stds=self.stds) diff --git a/sup3r/preprocessing/data_handlers/base.py b/sup3r/preprocessing/data_handlers/base.py index 3eff51b2f..8e72faf7d 100644 --- a/sup3r/preprocessing/data_handlers/base.py +++ b/sup3r/preprocessing/data_handlers/base.py @@ -224,7 +224,7 @@ def __init__( ) self.rasterizer = self.loader = self.cache if any(missing_features) or just_coords: - logger.info('%s not found in cache', missing_features) + logger.debug('%s not found in cache', missing_features) self.rasterizer = Rasterizer( 
file_paths=file_paths, res_kwargs=res_kwargs, @@ -338,8 +338,8 @@ def _deriver_hook(self): n_data_days = int(len(self.time_index) / day_steps) logger.info( - 'Calculating daily average datasets for {} training ' - 'data days.'.format(n_data_days) + 'Calculating daily average datasets for %s training data days.', + n_data_days, ) daily_data = self.data.coarsen(time=day_steps).mean() feats = [f for f in self.features if 'clearsky_ratio' not in f] @@ -370,8 +370,9 @@ def _deriver_hook(self): ) logger.info( - 'Finished calculating daily average datasets for {} ' - 'training data days.'.format(n_data_days) + 'Finished calculating daily average datasets for %s training ' + 'data days.', + n_data_days, ) hourly_data = self.data[self.requested_features] daily_data = daily_data[self.requested_features] diff --git a/sup3r/preprocessing/data_handlers/exo.py b/sup3r/preprocessing/data_handlers/exo.py index e1aed21fc..e3bacad24 100644 --- a/sup3r/preprocessing/data_handlers/exo.py +++ b/sup3r/preprocessing/data_handlers/exo.py @@ -252,7 +252,7 @@ def get_chunk(self, lr_slices): :class:`SingleExoDataStep` objects. This is the sliced exo data for the extent specified by `lr_slices`. 
""" - logger.debug(f'Getting exo data chunk for lr_slices={lr_slices}.') + logger.debug('Getting exo data chunk for lr_slices=%s.', lr_slices) exo_chunk = {f: {'steps': []} for f in self} for feature in self: for step in self[feature]['steps']: @@ -369,13 +369,10 @@ def get_exo_steps(cls, feature, models): steps = [] for i, model in enumerate(models): is_sfc_model = model.__class__.__name__ == 'SurfaceSpatialMetModel' - obs_features = getattr(model, 'obs_features', []) if feature.lower() in _lowered(model.lr_features) or is_sfc_model: steps.append({'model': i, 'combine_type': 'input'}) if feature.lower() in _lowered(model.hr_exo_features): steps.append({'model': i, 'combine_type': 'layer'}) - if feature.lower() in _lowered(obs_features): - steps.append({'model': i, 'combine_type': 'layer'}) if ( feature.lower() in _lowered(model.hr_out_features) or is_sfc_model @@ -409,6 +406,12 @@ def cache_files(self): def get_all_step_data(self): """Get exo data for each model step.""" data = {self.feature: {'steps': []}} + logger.debug( + 'Getting exo data for all steps with s_enhancements=%s and ' + 't_enhancements=%s', + self.s_enhancements, + self.t_enhancements, + ) for i, (s_enhance, t_enhance) in enumerate( zip(self.s_enhancements, self.t_enhancements) ): @@ -428,7 +431,7 @@ def get_all_step_data(self): None if step is None else step.shape for step in data[self.feature]['steps'] ] - logger.info( + logger.debug( 'Got exogenous_data of length {} with shapes: {}'.format( len(data[self.feature]['steps']), shapes ) diff --git a/sup3r/preprocessing/loaders/base.py b/sup3r/preprocessing/loaders/base.py index 4c01ea6a4..656f9f664 100644 --- a/sup3r/preprocessing/loaders/base.py +++ b/sup3r/preprocessing/loaders/base.py @@ -97,7 +97,7 @@ def __init__( self.data.meta = self._res.meta if self.chunks is None: - logger.info(f'Pre-loading data into memory for: {features}') + logger.debug('Pre-loading data into memory for: %s', features) self.data.compute() def _parse_chunks(self, dims, 
feature=None): diff --git a/sup3r/preprocessing/loaders/h5.py b/sup3r/preprocessing/loaders/h5.py index 8a4762f98..266de6110 100644 --- a/sup3r/preprocessing/loaders/h5.py +++ b/sup3r/preprocessing/loaders/h5.py @@ -170,7 +170,7 @@ def _check_for_elevation(self, data_vars, dims, chunks): def _get_data_vars(self, dims): """Define data_vars dict for xr.Dataset construction.""" data_vars = {} - logger.debug(f'Rechunking features with chunks: {self.chunks}') + logger.debug('Rechunking features with chunks: %s', self.chunks) chunks = self._parse_chunks(dims) data_vars = self._check_for_elevation( data_vars, dims=dims, chunks=chunks diff --git a/sup3r/preprocessing/rasterizers/base.py b/sup3r/preprocessing/rasterizers/base.py index f78ca7a70..725e2222d 100644 --- a/sup3r/preprocessing/rasterizers/base.py +++ b/sup3r/preprocessing/rasterizers/base.py @@ -130,7 +130,7 @@ def lat_lon(self): def rasterize_data(self): """Get rasterized data.""" - logger.info( + logger.debug( 'Rasterizing data for target / shape: %s / %s', np.asarray(self._target), np.asarray(self._grid_shape), @@ -152,7 +152,7 @@ def check_target_and_shape(self, full_lat_lon): def get_raster_index(self): """Get set of slices or indices selecting the requested region from the contained data.""" - logger.info( + logger.debug( 'Getting raster index for target / shape: %s / %s', np.asarray(self._target), np.asarray(self._grid_shape), @@ -222,7 +222,7 @@ def get_closest_row_col(self, lat_lon, target): add_msg = f'This exceeds the given threshold: {self.threshold}' logger.error(f'{msg} {add_msg}') raise RuntimeError(f'{msg} {add_msg}') - logger.info(msg) + logger.debug(msg) return row, col def get_lat_lon(self): diff --git a/sup3r/preprocessing/rasterizers/exo.py b/sup3r/preprocessing/rasterizers/exo.py index 5f5838425..8f1b394c2 100644 --- a/sup3r/preprocessing/rasterizers/exo.py +++ b/sup3r/preprocessing/rasterizers/exo.py @@ -336,14 +336,14 @@ def data(self): cache_fp = self.cache_file if cache_fp is not None 
and os.path.exists(cache_fp): logger.info( - 'Loading cached data for {} from {}'.format( - self.feature, cache_fp - ) + 'Loading cached data for %s from %s', + self.feature, + cache_fp, ) data = Loader(cache_fp) else: data = self.get_data() - logger.info(f'Finished rasterizing "{self.feature}"') + logger.info('Finished rasterizing "%s"', self.feature) if cache_fp is not None and not os.path.exists(cache_fp): Cacher._write_single( diff --git a/sup3r/preprocessing/rasterizers/extended.py b/sup3r/preprocessing/rasterizers/extended.py index 44260e7b1..7ca986c23 100644 --- a/sup3r/preprocessing/rasterizers/extended.py +++ b/sup3r/preprocessing/rasterizers/extended.py @@ -157,7 +157,7 @@ def save_raster_index(self): """Save raster index to cache file.""" os.makedirs(os.path.dirname(self.raster_file), exist_ok=True) np.savetxt(self.raster_file, self.raster_index) - logger.info(f'Saved raster_index to {self.raster_file}') + logger.info('Saved raster_index to %s', self.raster_file) def get_raster_index(self): """Get set of slices or indices selecting the requested region from @@ -172,9 +172,10 @@ def _get_flat_data_raster_index(self): from WTK or NSRDB data.""" if self.raster_file is None or not os.path.exists(self.raster_file): - logger.info( - f'Calculating raster_index for target={self._target}, ' - f'shape={self._grid_shape}.' 
+ logger.debug( + 'Calculating raster_index for target=%s, shape=%s.', + self._target, + self._grid_shape, ) msg = ('Either shape + target or a raster_file must be provided ' 'for flattened data rasterization.') @@ -186,7 +187,7 @@ def _get_flat_data_raster_index(self): ) else: raster_index = np.loadtxt(self.raster_file).astype(np.int32) - logger.info(f'Loaded raster_index from {self.raster_file}') + logger.info('Loaded raster_index from %s', self.raster_file) return raster_index diff --git a/sup3r/preprocessing/utilities.py b/sup3r/preprocessing/utilities.py index 457f45580..565da6034 100644 --- a/sup3r/preprocessing/utilities.py +++ b/sup3r/preprocessing/utilities.py @@ -56,7 +56,7 @@ def get_input_handler_class(input_handler_name: Optional[str] = None): if input_handler_name is None: input_handler_name = 'DataHandler' - logger.info( + logger.debug( '"input_handler_name" arg was not provided. Using ' f'"{input_handler_name}". If this is incorrect, please provide ' 'input_handler_name="DataHandlerName".' 
@@ -119,13 +119,15 @@ def _get_args_dict(thing, fun, *args, **kwargs): def _log_args(thing, fun, *args, **kwargs): """Log annotated attributes and args.""" - - args_dict = _get_args_dict(thing, fun, *args, **kwargs) - name = thing.__class__.__name__ - logger.info( - f'Initialized {name} with:\n{pprint.pformat(args_dict, indent=2)}' - ) - logger.debug(_mem_check()) + if logger.isEnabledFor(logging.DEBUG): + args_dict = _get_args_dict(thing, fun, *args, **kwargs) + name = thing.__class__.__name__ + logger.debug( + 'Initialized %s with:\n%s', + name, + pprint.pformat(args_dict, indent=2), + ) + logger.debug('%s', _mem_check()) def wrapper(self, *args, **kwargs): _log_args(self, func, *args, **kwargs) diff --git a/sup3r/solar/solar_cli.py b/sup3r/solar/solar_cli.py index d094a9f60..c78b319ca 100644 --- a/sup3r/solar/solar_cli.py +++ b/sup3r/solar/solar_cli.py @@ -86,11 +86,19 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None): node_config['temporal_ids'] = list(temporal_ids) cmd = Solar.get_node_cmd(node_config) + logger.info( + 'Queueing solar node %s as job "%s".', + i_node, + name, + ) + if hardware_option.lower() in AVAILABLE_HARDWARE_OPTIONS: kickoff_slurm_job(ctx, cmd, pipeline_step, **exec_kwargs) else: kickoff_local_job(ctx, cmd, pipeline_step) + logger.info('Finished queueing solar work for %s nodes.', max_nodes) + def kickoff_slurm_job( ctx, From 17ddf08dc897a6ce57f8a00de89c54025873f173 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 10 May 2026 08:20:13 -0600 Subject: [PATCH 14/34] Refactor logging levels and improve model loading in the pipeline - Changed logging from INFO to DEBUG in various modules to reduce verbosity. - Updated the model loading mechanism in ForwardPass to directly use the strategy's model. - Enhanced ForwardPassStrategy initialization with additional logging for clarity. - Modified the way features are resolved in BaseDeriver to improve flexibility. 
- Added optional model parameter in ForwardPass to allow preloaded model reuse. - Improved cache handling in DataHandler and Cacher with clearer logging. - Adjusted output handling in writers to provide more informative logging during file operations. --- pixi.lock | 99 +++++++++--------- sup3r/models/base.py | 2 +- sup3r/pipeline/forward_pass.py | 12 ++- sup3r/pipeline/strategy.py | 109 +++++++++++++------- sup3r/preprocessing/data_handlers/base.py | 17 ++- sup3r/preprocessing/data_handlers/nc_cc.py | 2 +- sup3r/preprocessing/derivers/base.py | 41 ++++++-- sup3r/preprocessing/loaders/base.py | 2 +- sup3r/preprocessing/rasterizers/base.py | 4 +- sup3r/preprocessing/rasterizers/dual.py | 8 +- sup3r/preprocessing/rasterizers/exo.py | 15 ++- sup3r/preprocessing/rasterizers/extended.py | 4 +- sup3r/solar/solar.py | 2 +- sup3r/utilities/cli.py | 13 +++ sup3r/utilities/utilities.py | 2 +- sup3r/writers/base.py | 6 +- sup3r/writers/cachers.py | 8 +- sup3r/writers/h5.py | 6 ++ sup3r/writers/nc.py | 6 ++ sup3r/writers/utilities.py | 2 +- 20 files changed, 233 insertions(+), 127 deletions(-) diff --git a/pixi.lock b/pixi.lock index 9d069fd65..42f66dc2a 100644 --- a/pixi.lock +++ b/pixi.lock @@ -309,6 +309,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/e5/f1/93219c44bae4017e6e43391fa4433592de08e05def9d885227d3596f21a5/ml_dtypes-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b2/bc/465daf1de06409cdd4532082806770ee0d8d7df434da79c76564d0f69741/namex-0.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ad/6a/8596c70af1e24484580fbea8feeda4b95f8e004e9de0f9e233e08342baa9/netCDF4-1.6.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/e5/92/a50dbf4aea6b38476dd74bb97f52d35594e51794bdd514d721be6f30b301/nlr_phygnn-0.0.35-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl @@ -328,7 +329,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/1e/07/bf730d44c2fe1b676ad9cc2be5f5f861eb5d153fb6951987a2d6a96379a9/nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0a/e5/a77df15a62b37bb14c81b5757e2a0573f57e7c06d125a410ad2cd7cefb72/optree-0.19.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/fa/3d/f4f2ba829efb54b6cd2d91349c7463316a9cc55a43fc980447416c88540f/pkginfo-1.12.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d6/c6/c9deaa6e789b6fc41b88ccbdfe7a42d2b82663248b715f55aa77fbc00724/protobuf-4.25.8-cp37-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/10/39/de2423e6a13fb2f44ecf068df41ff1c7368ecd8b06f728afa1fb30f4ff0a/pyjson5-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl @@ -593,12 +593,12 @@ environments: - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz - pypi: https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/aa/1b/1c00fac58293e95eaf2c6f560b6710af615c666bcc6babcb31a5de2e642e/netCDF4-1.6.5-cp312-cp312-macosx_11_0_arm64.whl + - pypi: https://files.pythonhosted.org/packages/e5/92/a50dbf4aea6b38476dd74bb97f52d35594e51794bdd514d721be6f30b301/nlr_phygnn-0.0.35-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/fa/3d/f4f2ba829efb54b6cd2d91349c7463316a9cc55a43fc980447416c88540f/pkginfo-1.12.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/aa/4b/4e69ccbf34f2f303e32dc0dc8853d82282f109ba41b7a9366d518751e500/pyjson5-2.0.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl @@ -898,6 +898,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/e5/f1/93219c44bae4017e6e43391fa4433592de08e05def9d885227d3596f21a5/ml_dtypes-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/b2/bc/465daf1de06409cdd4532082806770ee0d8d7df434da79c76564d0f69741/namex-0.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ad/6a/8596c70af1e24484580fbea8feeda4b95f8e004e9de0f9e233e08342baa9/netCDF4-1.6.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/e5/92/a50dbf4aea6b38476dd74bb97f52d35594e51794bdd514d721be6f30b301/nlr_phygnn-0.0.35-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl @@ -917,7 +918,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/1e/07/bf730d44c2fe1b676ad9cc2be5f5f861eb5d153fb6951987a2d6a96379a9/nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0a/e5/a77df15a62b37bb14c81b5757e2a0573f57e7c06d125a410ad2cd7cefb72/optree-0.19.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/d6/c6/c9deaa6e789b6fc41b88ccbdfe7a42d2b82663248b715f55aa77fbc00724/protobuf-4.25.8-cp37-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/10/39/de2423e6a13fb2f44ecf068df41ff1c7368ecd8b06f728afa1fb30f4ff0a/pyjson5-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl @@ -1160,12 +1160,12 @@ environments: - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz - pypi: https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/aa/1b/1c00fac58293e95eaf2c6f560b6710af615c666bcc6babcb31a5de2e642e/netCDF4-1.6.5-cp312-cp312-macosx_11_0_arm64.whl + - pypi: https://files.pythonhosted.org/packages/e5/92/a50dbf4aea6b38476dd74bb97f52d35594e51794bdd514d721be6f30b301/nlr_phygnn-0.0.35-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/aa/4b/4e69ccbf34f2f303e32dc0dc8853d82282f109ba41b7a9366d518751e500/pyjson5-2.0.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/10/99/781fe0c827be2742bcc775efefccb3b048a3a9c6ce9aec0cbf4a101677e5/pytz-2026.1.post1-py2.py3-none-any.whl @@ -1663,6 +1663,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/e5/f1/93219c44bae4017e6e43391fa4433592de08e05def9d885227d3596f21a5/ml_dtypes-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b2/bc/465daf1de06409cdd4532082806770ee0d8d7df434da79c76564d0f69741/namex-0.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ad/6a/8596c70af1e24484580fbea8feeda4b95f8e004e9de0f9e233e08342baa9/netCDF4-1.6.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/e5/92/a50dbf4aea6b38476dd74bb97f52d35594e51794bdd514d721be6f30b301/nlr_phygnn-0.0.35-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl @@ -1683,7 +1684,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/1e/07/bf730d44c2fe1b676ad9cc2be5f5f861eb5d153fb6951987a2d6a96379a9/nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0a/e5/a77df15a62b37bb14c81b5757e2a0573f57e7c06d125a410ad2cd7cefb72/optree-0.19.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: 
https://files.pythonhosted.org/packages/fa/3d/f4f2ba829efb54b6cd2d91349c7463316a9cc55a43fc980447416c88540f/pkginfo-1.12.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5d/19/fd3ef348460c80af7bb4669ea7926651d1f95c23ff2df18b9d24bab4f3fa/pre_commit-4.5.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d6/c6/c9deaa6e789b6fc41b88ccbdfe7a42d2b82663248b715f55aa77fbc00724/protobuf-4.25.8-cp37-abi3-manylinux2014_x86_64.whl @@ -2118,13 +2118,13 @@ environments: - pypi: https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/27/1a/1f68f9ba0c207934b35b86a8ca3aad8395a3d6dd7921c0686e23853ff5a9/mccabe-0.7.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/aa/1b/1c00fac58293e95eaf2c6f560b6710af615c666bcc6babcb31a5de2e642e/netCDF4-1.6.5-cp312-cp312-macosx_11_0_arm64.whl + - pypi: https://files.pythonhosted.org/packages/e5/92/a50dbf4aea6b38476dd74bb97f52d35594e51794bdd514d721be6f30b301/nlr_phygnn-0.0.35-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: 
https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/fa/3d/f4f2ba829efb54b6cd2d91349c7463316a9cc55a43fc980447416c88540f/pkginfo-1.12.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5d/19/fd3ef348460c80af7bb4669ea7926651d1f95c23ff2df18b9d24bab4f3fa/pre_commit-4.5.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/aa/4b/4e69ccbf34f2f303e32dc0dc8853d82282f109ba41b7a9366d518751e500/pyjson5-2.0.0-cp312-cp312-macosx_11_0_arm64.whl @@ -2438,6 +2438,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/e5/f1/93219c44bae4017e6e43391fa4433592de08e05def9d885227d3596f21a5/ml_dtypes-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b2/bc/465daf1de06409cdd4532082806770ee0d8d7df434da79c76564d0f69741/namex-0.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ad/6a/8596c70af1e24484580fbea8feeda4b95f8e004e9de0f9e233e08342baa9/netCDF4-1.6.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/e5/92/a50dbf4aea6b38476dd74bb97f52d35594e51794bdd514d721be6f30b301/nlr_phygnn-0.0.35-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl @@ -2457,7 +2458,6 @@ environments: - pypi: 
https://files.pythonhosted.org/packages/1e/07/bf730d44c2fe1b676ad9cc2be5f5f861eb5d153fb6951987a2d6a96379a9/nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0a/e5/a77df15a62b37bb14c81b5757e2a0573f57e7c06d125a410ad2cd7cefb72/optree-0.19.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/d6/c6/c9deaa6e789b6fc41b88ccbdfe7a42d2b82663248b715f55aa77fbc00724/protobuf-4.25.8-cp37-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/10/39/de2423e6a13fb2f44ecf068df41ff1c7368ecd8b06f728afa1fb30f4ff0a/pyjson5-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl @@ -2706,12 +2706,12 @@ environments: - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz - pypi: https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/aa/1b/1c00fac58293e95eaf2c6f560b6710af615c666bcc6babcb31a5de2e642e/netCDF4-1.6.5-cp312-cp312-macosx_11_0_arm64.whl + - pypi: https://files.pythonhosted.org/packages/e5/92/a50dbf4aea6b38476dd74bb97f52d35594e51794bdd514d721be6f30b301/nlr_phygnn-0.0.35-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/aa/4b/4e69ccbf34f2f303e32dc0dc8853d82282f109ba41b7a9366d518751e500/pyjson5-2.0.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/10/99/781fe0c827be2742bcc775efefccb3b048a3a9c6ce9aec0cbf4a101677e5/pytz-2026.1.post1-py2.py3-none-any.whl @@ -3012,6 +3012,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/e5/f1/93219c44bae4017e6e43391fa4433592de08e05def9d885227d3596f21a5/ml_dtypes-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/b2/bc/465daf1de06409cdd4532082806770ee0d8d7df434da79c76564d0f69741/namex-0.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ad/6a/8596c70af1e24484580fbea8feeda4b95f8e004e9de0f9e233e08342baa9/netCDF4-1.6.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/e5/92/a50dbf4aea6b38476dd74bb97f52d35594e51794bdd514d721be6f30b301/nlr_phygnn-0.0.35-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl @@ -3031,7 +3032,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/1e/07/bf730d44c2fe1b676ad9cc2be5f5f861eb5d153fb6951987a2d6a96379a9/nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0a/e5/a77df15a62b37bb14c81b5757e2a0573f57e7c06d125a410ad2cd7cefb72/optree-0.19.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/d6/c6/c9deaa6e789b6fc41b88ccbdfe7a42d2b82663248b715f55aa77fbc00724/protobuf-4.25.8-cp37-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/10/39/de2423e6a13fb2f44ecf068df41ff1c7368ecd8b06f728afa1fb30f4ff0a/pyjson5-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl @@ -3278,12 +3278,12 @@ environments: - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz - pypi: https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/aa/1b/1c00fac58293e95eaf2c6f560b6710af615c666bcc6babcb31a5de2e642e/netCDF4-1.6.5-cp312-cp312-macosx_11_0_arm64.whl + - pypi: https://files.pythonhosted.org/packages/e5/92/a50dbf4aea6b38476dd74bb97f52d35594e51794bdd514d721be6f30b301/nlr_phygnn-0.0.35-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/aa/4b/4e69ccbf34f2f303e32dc0dc8853d82282f109ba41b7a9366d518751e500/pyjson5-2.0.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/27/16/ad52f56b96d851a2bcfdc1e754c3531341885bd7177a128c13ff2ca72ab4/pytest_env-1.6.0-py3-none-any.whl @@ -3750,6 +3750,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/e5/f1/93219c44bae4017e6e43391fa4433592de08e05def9d885227d3596f21a5/ml_dtypes-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/b2/bc/465daf1de06409cdd4532082806770ee0d8d7df434da79c76564d0f69741/namex-0.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ad/6a/8596c70af1e24484580fbea8feeda4b95f8e004e9de0f9e233e08342baa9/netCDF4-1.6.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/e5/92/a50dbf4aea6b38476dd74bb97f52d35594e51794bdd514d721be6f30b301/nlr_phygnn-0.0.35-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl @@ -3769,7 +3770,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/1e/07/bf730d44c2fe1b676ad9cc2be5f5f861eb5d153fb6951987a2d6a96379a9/nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0a/e5/a77df15a62b37bb14c81b5757e2a0573f57e7c06d125a410ad2cd7cefb72/optree-0.19.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/d6/c6/c9deaa6e789b6fc41b88ccbdfe7a42d2b82663248b715f55aa77fbc00724/protobuf-4.25.8-cp37-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/10/39/de2423e6a13fb2f44ecf068df41ff1c7368ecd8b06f728afa1fb30f4ff0a/pyjson5-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl @@ -4164,12 +4164,12 @@ environments: - pypi: https://files.pythonhosted.org/packages/8d/9e/f7caf7486a22c3f8dde60228a9905c73dd676cdcacbdaa4390acfc9ae959/h5pyd-0.18.0.tar.gz - pypi: https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/aa/1b/1c00fac58293e95eaf2c6f560b6710af615c666bcc6babcb31a5de2e642e/netCDF4-1.6.5-cp312-cp312-macosx_11_0_arm64.whl + - pypi: https://files.pythonhosted.org/packages/e5/92/a50dbf4aea6b38476dd74bb97f52d35594e51794bdd514d721be6f30b301/nlr_phygnn-0.0.35-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/51/b6/95a52fcd1a92e339a61d34d054401671c50e3522f8aaea37d413e651951d/nrel_farms-1.0.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/51/a025e1b9fbe459fa45eef37abc6602a16e20979cba77b723bbea2dccc203/NREL_gaps-0.6.14-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/29/15/b6b2b49b4e5e17f0d2c1006d609b8adb13aa96944c6b8b5eb02a39df99a4/NREL_rex-0.2.98-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/62/5e/3a6a3e90f35cea3853c45e5d5fb9b7192ce4384616f932cf7591298ab6e1/numpydoc-1.10.0-py3-none-any.whl - - pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 - pypi: https://files.pythonhosted.org/packages/aa/4b/4e69ccbf34f2f303e32dc0dc8853d82282f109ba41b7a9366d518751e500/pyjson5-2.0.0-cp312-cp312-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/10/99/781fe0c827be2742bcc775efefccb3b048a3a9c6ce9aec0cbf4a101677e5/pytz-2026.1.post1-py2.py3-none-any.whl @@ -10469,9 +10469,10 @@ packages: purls: [] size: 136216 timestamp: 1758194284857 -- pypi: git+https://github.com/NatLabRockies/phygnn.git?rev=bnb%2Ftf#525ceb98d43ea799271b6aea391be5b0cbfe0475 +- pypi: https://files.pythonhosted.org/packages/e5/92/a50dbf4aea6b38476dd74bb97f52d35594e51794bdd514d721be6f30b301/nlr_phygnn-0.0.35-py3-none-any.whl name: nlr-phygnn - version: 0.0.35.dev22+g525ceb98d + version: 0.0.35 + sha256: adccc8119865e78ed7b617440ab48e75664ae99e513b2fa32390ba4a55aa3f23 requires_dist: - matplotlib>=3.1 - nrel-rex @@ -10491,6 +10492,41 @@ packages: - pytest>=5.2 ; extra == 'test' - pytest-env ; extra == 'test' requires_python: '>=3.9,<3.14' +- pypi: ./ + name: nlr-sup3r + version: 0.2.9.dev13+gdbcc297b2.d20260510 + sha256: 7c041c39bb3c483a4161ac1bc0d0d139d3d3d74677bc583d595f3583c8fdcf9c + requires_dist: + - nlr-phygnn>=0.0.35 + - nrel-rex>=0.2.91 + - nrel-gaps>=0.6.13 + - nrel-farms>=1.0.4 + - dask>=2022.0 + - netcdf4>=1.5.8,<1.7.0 + - h5netcdf>=1.1.0 + - cftime>=1.6.2 + - matplotlib>=3.1 + - numpy>=1.7.0,<2.0.0 + - pandas>=2.0 + - pillow>=10.0 + - scipy>=1.0.0 + - xarray>=2023.0 + - zarr>=2.18.0,<4 + - pre-commit ; extra == 'dev' + - pylint ; extra == 'dev' + - ruff>=0.4 ; extra == 'dev' + - zarr>=2.18.0,<4 ; extra == 'dev' + - sphinx>=7.0 ; extra == 'doc' + - sphinx-click>=4.0 ; extra == 'doc' + - sphinx-book-theme>=1.1.1 ; extra == 'doc' + - sphinx-autosummary-accessors>=2023.4.0 ; extra == 'doc' + - pytest>=5.2 ; extra == 'test' + - pytest-env ; extra == 'test' + - pytest-cov>=5.0.0 ; extra == 'test' + - build~=1.3.0 ; extra == 'build' + - pkginfo>=1.10.0,<2 ; extra == 'build' + - twine>=5.0 ; extra == 'build' + requires_python: '>=3.9,<3.13' - pypi: https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl name: nodeenv 
version: 1.10.0 @@ -10616,41 +10652,6 @@ packages: - pytest-timeout>=2.3.1 ; extra == 'test' - flaky>=3.8.1 ; extra == 'test' requires_python: '>=3.9' -- pypi: ./ - name: nrel-sup3r - version: 0.2.7.dev29+g2a8491d0d - sha256: 86efec4f2e4479b26d09b6f7c65223a94c21e325f2b3815d0ee41672780dadb0 - requires_dist: - - nlr-phygnn @ git+https://github.com/NatLabRockies/phygnn.git@bnb/tf - - nrel-rex>=0.2.91 - - nrel-gaps>=0.6.13 - - nrel-farms>=1.0.4 - - dask>=2022.0 - - netcdf4>=1.5.8,<1.7.0 - - h5netcdf>=1.1.0 - - cftime>=1.6.2 - - matplotlib>=3.1 - - numpy>=1.7.0,<2.0.0 - - pandas>=2.0 - - pillow>=10.0 - - scipy>=1.0.0 - - xarray>=2023.0 - - zarr>=2.18.0,<4 - - pre-commit ; extra == 'dev' - - pylint ; extra == 'dev' - - ruff>=0.4 ; extra == 'dev' - - zarr>=2.18.0,<4 ; extra == 'dev' - - sphinx>=7.0 ; extra == 'doc' - - sphinx-click>=4.0 ; extra == 'doc' - - sphinx-book-theme>=1.1.1 ; extra == 'doc' - - sphinx-autosummary-accessors>=2023.4.0 ; extra == 'doc' - - pytest>=5.2 ; extra == 'test' - - pytest-env ; extra == 'test' - - pytest-cov>=5.0.0 ; extra == 'test' - - build~=1.3.0 ; extra == 'build' - - pkginfo>=1.10.0,<2 ; extra == 'build' - - twine>=5.0 ; extra == 'build' - requires_python: '>=3.9,<3.13' - pypi: https://files.pythonhosted.org/packages/f5/6c/86644987505dcb90ba6d627d6989c27bafb0699f9fd00187e06d05ea8594/numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl name: numcodecs version: 0.16.5 diff --git a/sup3r/models/base.py b/sup3r/models/base.py index d872b8532..b58ecfc0c 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -176,7 +176,7 @@ def save(self, out_dir): logger.info('Saved GAN to disk in directory: {}'.format(out_dir)) @classmethod - def _load(cls, model_dir, verbose=True): + def _load(cls, model_dir, verbose=False): """Get gen, disc, and params for given model_dir. 
Parameters diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 0f530789e..62e0c1a19 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -52,7 +52,7 @@ def __init__(self, strategy, node_index=0): """ self.timer = Timer() self.strategy = strategy - self.model = get_model(strategy.model_class, strategy.model_kwargs) + self.model = strategy.model self.node_index = node_index output_type = get_source_type(strategy.out_pattern) @@ -458,6 +458,7 @@ def run(cls, strategy, node_index): cls._run_serial(strategy, node_index) else: cls._run_parallel(strategy, node_index) + logger.info('Finished forward pass on node %s.', node_index) logger.debug( 'Timing report:\n%s', pprint.pformat(strategy.timer.log, indent=2), @@ -490,6 +491,7 @@ def _run_serial(cls, strategy, node_index): model_kwargs=strategy.model_kwargs, model_class=strategy.model_class, allowed_const=strategy.allowed_const, + model=fwp.model, output_workers=strategy.output_workers, invert_uv=strategy.invert_uv, nn_fill=strategy.nn_fill, @@ -604,6 +606,7 @@ def run_chunk( model_kwargs, model_class, allowed_const, + model=None, invert_uv=False, meta=None, nn_fill=True, @@ -631,6 +634,10 @@ def run_chunk( True to allow any constant output or a list of allowed possible constant outputs. See :class:`ForwardPassStrategy` for more information on this argument. + model : Sup3rGan | None + Optional preloaded model instance to reuse for the chunk. If + ``None``, the model is loaded from ``model_class`` and + ``model_kwargs``. invert_uv : bool Whether to convert uv to windspeed and winddirection for writing output. When this method is called during a pipeline forward pass @@ -662,7 +669,8 @@ def run_chunk( msg = f'Running forward pass for chunk_index={chunk.index}.' 
logger.debug(msg) - model = get_model(model_class, model_kwargs) + if model is None: + model = get_model(model_class, model_kwargs) mask = np.isnan(chunk.input_data).any(axis=(0, 1, 2)) feats = np.array(model.lr_features[: len(mask)])[mask] diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index 54ac35fe4..fd24f3733 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -239,13 +239,25 @@ def __post_init__(self): self.bias_correct_kwargs = self.bias_correct_kwargs or {} self.timer = Timer() - model = get_model(self.model_class, self.model_kwargs) - self.s_enhancements = model.s_enhancements - self.t_enhancements = model.t_enhancements - self.s_enhance, self.t_enhance = model.s_enhance, model.t_enhance - self.input_features = model.lr_features - self.output_features = model.hr_out_features - self.features, self.exo_features = self._init_features(model) + logger.info( + 'Initializing forward pass strategy for model=%s with ' + 'input_handler=%s across %s input files.', + self.model_class, + self.input_handler_name or 'DataHandler', + len(self.file_paths), + ) + logger.debug('Forward pass strategy input files: %s', self.file_paths) + + self.model = get_model(self.model_class, self.model_kwargs) + self.s_enhancements = self.model.s_enhancements + self.t_enhancements = self.model.t_enhancements + self.s_enhance, self.t_enhance = ( + self.model.s_enhance, + self.model.t_enhance, + ) + self.input_features = self.model.lr_features + self.output_features = self.model.hr_out_features + self.features, self.exo_features = self._init_features(self.model) self.time_slice, self.padded_time_slice = self.get_time_slices() self.input_handler = self.timer(self.init_input_handler, log=True)() self.fwp_chunk_shape = self._get_fwp_chunk_shape() @@ -274,14 +286,16 @@ def __post_init__(self): if self.head_node and cache_check: logger.warning(msg) warn(msg) - _ = self.timer(self.load_exo_data, log=True)(model) + _ = self.timer(self.load_exo_data, 
log=True)(self.model) if not self.head_node: # This will either load cache files created on the head node or # directly load the exogenous data if no cache is being used. hr_shape = self.hr_lat_lon.shape[:-1] self.gids = np.arange(np.prod(hr_shape)).reshape(hr_shape) - self.exo_data = self.timer(self.load_exo_data, log=True)(model) + self.exo_data = self.timer(self.load_exo_data, log=True)( + self.model + ) self.preflight() @@ -359,7 +373,20 @@ def init_input_handler(self): input_handler_kwargs['chunks'] = 'auto' input_handler_kwargs['time_slice'] = self.padded_time_slice - return InputHandler(**input_handler_kwargs) + logger.info( + 'Initializing %s for %s features over padded time slice %s.', + InputHandler.__name__, + len(input_handler_kwargs['features']), + self.padded_time_slice, + ) + handler = InputHandler(**input_handler_kwargs) + logger.info( + '%s ready with grid shape %s and %s time steps.', + InputHandler.__name__, + handler.grid_shape, + len(handler.time_index), + ) + return handler def _init_features(self, model): """Initialize feature attributes.""" @@ -431,17 +458,15 @@ def preflight(self): self.lr_slices, self.lr_pad_slices, self.hr_slices = out non_masked = self.fwp_slicer.n_spatial_chunks - sum(self.fwp_mask) - non_masked *= int(self.fwp_slicer.n_time_chunks) - log_dict = { - 'n_nodes': len(self.node_chunks), - 'n_spatial_chunks': self.fwp_slicer.n_spatial_chunks, - 'n_time_chunks': self.fwp_slicer.n_time_chunks, - 'n_total_chunks': self.fwp_slicer.n_chunks, - 'non_masked_chunks': non_masked, - } + non_masked *= self.fwp_slicer.n_time_chunks logger.info( - f'Chunk strategy description:\n' - f'{pprint.pformat(log_dict, indent=2)}' + 'Chunk strategy uses %s nodes across %s total chunks ' + '(%s spatial x %s temporal, %s unmasked).', + len(self.node_chunks), + self.fwp_slicer.n_chunks, + self.fwp_slicer.n_spatial_chunks, + self.fwp_slicer.n_time_chunks, + int(non_masked), ) def get_chunk_indices(self, chunk_index): @@ -643,9 +668,21 @@ def 
load_exo_data(self, model): """ data = {} exo_data = None - for exo_kwargs in self.get_exo_kwargs(model): + exo_kwargs_list = self.get_exo_kwargs(model) + if exo_kwargs_list: + logger.info( + 'Loading exogenous data for %s features: %s.', + len(exo_kwargs_list), + [kwargs['feature'] for kwargs in exo_kwargs_list], + ) + for exo_kwargs in exo_kwargs_list: data.update(ExoDataHandler(**exo_kwargs).data) exo_data = ExoData(data) + if exo_kwargs_list: + logger.info( + 'Finished loading exogenous data for %s features.', + len(exo_kwargs_list), + ) return exo_data @cached_property @@ -664,23 +701,25 @@ def fwp_mask(self): mask = np.zeros(len(self.lr_pad_slices)) logger.debug('Checking for mask in input handler.') input_handler_kwargs = copy.deepcopy(self.input_handler_kwargs) - try: - InputHandler = get_input_handler_class(self.input_handler_name) - input_handler_kwargs['features'] = ['mask'] - handler = InputHandler(**input_handler_kwargs) + InputHandler = get_input_handler_class(self.input_handler_name) + input_handler_kwargs['features'] = [] + handler = InputHandler(**input_handler_kwargs) + mask_feature = handler.resolve_feature('mask', strict=False) + if mask_feature is None: logger.debug( - 'Found "mask" in DataHandler. Computing forward pass ' - 'chunk mask for %s chunks', - len(self.lr_pad_slices), - ) - mask_vals = handler.data['mask'].values - for s_chunk_idx, lr_slices in enumerate(self.lr_pad_slices): - mask_check = mask_vals[lr_slices[0], lr_slices[1]] - mask[s_chunk_idx] = bool(np.prod(mask_check.flatten())) - except Exception: - logger.info( 'No "mask" found in DataHandler. No chunks will be masked.' ) + return mask + + logger.debug( + 'Found "mask" in DataHandler. 
Computing forward pass ' + 'chunk mask for %s chunks', + len(self.lr_pad_slices), + ) + mask_vals = getattr(mask_feature, 'values', mask_feature) + for s_chunk_idx, lr_slices in enumerate(self.lr_pad_slices): + mask_check = mask_vals[lr_slices[0], lr_slices[1]] + mask[s_chunk_idx] = bool(np.prod(mask_check.flatten())) return mask def node_finished(self, node_idx): diff --git a/sup3r/preprocessing/data_handlers/base.py b/sup3r/preprocessing/data_handlers/base.py index 8e72faf7d..8371276d0 100644 --- a/sup3r/preprocessing/data_handlers/base.py +++ b/sup3r/preprocessing/data_handlers/base.py @@ -207,11 +207,23 @@ def __init__( """ # pylint: disable=line-too-long features = parse_to_list(features=features) + source_files = expand_paths(file_paths) cached_files, cached_features, _, missing_features = _check_for_cache( features=features, cache_kwargs=cache_kwargs ) + logger.debug( + 'DataHandler preparing %s requested features from %s source files ' + '(%s cached, %s via rasterizer).', + len(features), + len(source_files), + len(cached_features), + len(missing_features), + ) + just_coords = not features + if just_coords: + logger.info('Rasterizing source data for coordinate-only access.') raster_feats = load_features if any(missing_features) else [] self.rasterizer = self.loader = self.cache = None if any(cached_features): @@ -263,7 +275,10 @@ def __init__( expand_paths(file_paths) + cached_files ) - if cache_kwargs is not None and 'cache_pattern' in cache_kwargs: + should_cache = cache_kwargs is not None and ( + bool(missing_features) or cache_kwargs.get('overwrite', False) + ) + if should_cache and 'cache_pattern' in cache_kwargs: self.cacher = Cacher(data=self.data, cache_kwargs=cache_kwargs) self._deriver_hook() diff --git a/sup3r/preprocessing/data_handlers/nc_cc.py b/sup3r/preprocessing/data_handlers/nc_cc.py index 1187a75fb..17b6b0ddd 100644 --- a/sup3r/preprocessing/data_handlers/nc_cc.py +++ b/sup3r/preprocessing/data_handlers/nc_cc.py @@ -124,7 +124,7 @@ def 
run_input_checks(self): def run_wrap_checks(self, cs_ghi): """Run check on rasterized data from clearsky_ghi source.""" - logger.info( + logger.debug( 'Reshaped clearsky_ghi data to final shape {} to ' 'correspond with CC daily average data over source ' 'time_slice {} with (lat, lon) grid shape of {}'.format( diff --git a/sup3r/preprocessing/derivers/base.py b/sup3r/preprocessing/derivers/base.py index bb20cb2a2..7f40252eb 100644 --- a/sup3r/preprocessing/derivers/base.py +++ b/sup3r/preprocessing/derivers/base.py @@ -205,23 +205,25 @@ def has_interp_variables(self, feature): count += 1 return count > 1 or fstruct.basename in self.data - def derive(self, feature) -> Union[np.ndarray, da.core.Array]: - """Routine to derive requested features. Employs a little recursion to - locate differently named features with a name map in the feature - registry. i.e. if `FEATURE_REGISTRY` contains a key, value pair like - "windspeed": "wind_speed" then requesting "windspeed" will ultimately - return a compute method (or fetch from raw data) for "wind_speed + def resolve_feature( + self, feature, strict=True + ) -> Union[np.ndarray, da.core.Array, None]: + """Resolve a feature from contained data or available derivations. - Note - ---- - Features are all saved as lower case names and __contains__ checks will - use feature.lower() + Parameters + ---------- + feature : str + Feature to resolve from the contained data or available compute + methods. + strict : bool + Whether to raise if the feature cannot be resolved. If ``False``, + return ``None`` instead. 
""" if feature not in self.data: compute_check = self.check_registry(feature) if compute_check is not None and isinstance(compute_check, str): new_feature = self.map_new_name(feature, compute_check) - return self.derive(new_feature) + return self.resolve_feature(new_feature, strict=strict) if compute_check is not None: return compute_check @@ -234,6 +236,9 @@ def derive(self, feature) -> Union[np.ndarray, da.core.Array]: feature, interp_kwargs=self.interp_kwargs ) + if not strict: + return None + msg = ( 'Could not find "%s" in contained data or in the available ' 'compute methods.' @@ -250,6 +255,20 @@ def derive(self, feature) -> Union[np.ndarray, da.core.Array]: warn(msg) return self.data[feature] + def derive(self, feature) -> Union[np.ndarray, da.core.Array]: + """Routine to derive requested features. Employs a little recursion to + locate differently named features with a name map in the feature + registry. i.e. if `FEATURE_REGISTRY` contains a key, value pair like + "windspeed": "wind_speed" then requesting "windspeed" will ultimately + return a compute method (or fetch from raw data) for "wind_speed + + Note + ---- + Features are all saved as lower case names and __contains__ checks will + use feature.lower() + """ + return self.resolve_feature(feature, strict=True) + def get_single_level_data(self, feature): """When doing level interpolation we should include the single level data available. e.g. If we have u_100m already and want to interpolate diff --git a/sup3r/preprocessing/loaders/base.py b/sup3r/preprocessing/loaders/base.py index 656f9f664..f36d8bb2d 100644 --- a/sup3r/preprocessing/loaders/base.py +++ b/sup3r/preprocessing/loaders/base.py @@ -73,7 +73,7 @@ def __init__( Optional base loader update. 
The default for H5 files is MultiFileResourceX and for NETCDF or ZARR is xarray.open_mfdataset """ - logger.info( + logger.debug( 'Loading features: %s from files: %s', features, file_paths ) super().__init__() diff --git a/sup3r/preprocessing/rasterizers/base.py b/sup3r/preprocessing/rasterizers/base.py index 725e2222d..3891c1b9d 100644 --- a/sup3r/preprocessing/rasterizers/base.py +++ b/sup3r/preprocessing/rasterizers/base.py @@ -58,8 +58,8 @@ def __init__( are more than this value away from the target lat/lon, an error is raised. """ - logger.info( - 'Rasterizing features: %s from files: %s', + logger.debug( + 'Rasterizing features: "%s" from files: %s', features, loader.file_paths, ) diff --git a/sup3r/preprocessing/rasterizers/dual.py b/sup3r/preprocessing/rasterizers/dual.py index 66e1a9f6f..b8a8b92ec 100644 --- a/sup3r/preprocessing/rasterizers/dual.py +++ b/sup3r/preprocessing/rasterizers/dual.py @@ -177,7 +177,7 @@ def update_hr_data(self): : self.hr_required_shape[2] ], } - logger.info( + logger.debug( 'Updating self.data.high_res with new shape: ' f'{self.hr_required_shape[:3]}' ) @@ -203,7 +203,7 @@ def update_lr_data(self): cached features if available and overwrite=False""" if self._regrid_lr: - logger.info('Regridding low resolution feature data.') + logger.debug('Regridding low resolution feature data.') regridder = self.get_regridder() lr_data_new = {} @@ -219,7 +219,7 @@ def update_lr_data(self): : self.lr_required_shape[2] ], } - logger.info('Updating self.data.low_res with regridded data.') + logger.debug('Updating self.data.low_res with regridded data.') self.data.low_res = self.data.low_res.update_ds({ **lr_coords_new, **lr_data_new, @@ -228,7 +228,7 @@ def update_lr_data(self): def check_regridded_lr_data(self): """Check for NaNs after regridding and do NN fill if needed.""" fill_feats = [] - logger.info('Checking for NaNs after regridding') + logger.debug('Checking for NaNs after regridding') qa_info = 
self.data.low_res.qa(stats=['nan_perc']) for f in self.data.low_res.features: nan_perc = qa_info[f]['nan_perc'] diff --git a/sup3r/preprocessing/rasterizers/exo.py b/sup3r/preprocessing/rasterizers/exo.py index 8f1b394c2..9c7c1b242 100644 --- a/sup3r/preprocessing/rasterizers/exo.py +++ b/sup3r/preprocessing/rasterizers/exo.py @@ -288,10 +288,9 @@ def get_distance_upper_bound(self): diff = da.diff(self.hr_lat_lon, axis=0) diff = da.abs(da.median(diff, axis=0)).max() self.distance_upper_bound = np.asarray(diff) - logger.info( - 'Set distance upper bound to {:.4f}'.format( - self.distance_upper_bound - ) + logger.debug( + 'Set distance upper bound to %.4f', + self.distance_upper_bound, ) return self.distance_upper_bound @@ -335,7 +334,7 @@ def data(self): cache_fp = self.cache_file if cache_fp is not None and os.path.exists(cache_fp): - logger.info( + logger.debug( 'Loading cached data for %s from %s', self.feature, cache_fp, @@ -343,7 +342,7 @@ def data(self): data = Loader(cache_fp) else: data = self.get_data() - logger.info('Finished rasterizing "%s"', self.feature) + logger.debug('Finished rasterizing "%s"', self.feature) if cache_fp is not None and not os.path.exists(cache_fp): Cacher._write_single( @@ -367,7 +366,7 @@ def _check_coverage(self, hr_data): 'probably means the source data is not high enough ' 'resolution. Filling raster with NN.' 
) - logger.warning(msg) + logger.debug(msg) warn(msg) hr_data = nn_fill_array(hr_data) return hr_data @@ -610,7 +609,7 @@ def __new__(cls, feature, file_paths, source_files=None, **kwargs): 'feature': feature, **kwargs, } - logger.info( + logger.debug( f'Using {ExoClass.__name__} to rasterize feature "{feature}"' ) return ExoClass(**kwargs) diff --git a/sup3r/preprocessing/rasterizers/extended.py b/sup3r/preprocessing/rasterizers/extended.py index 7ca986c23..ba75bcad3 100644 --- a/sup3r/preprocessing/rasterizers/extended.py +++ b/sup3r/preprocessing/rasterizers/extended.py @@ -157,7 +157,7 @@ def save_raster_index(self): """Save raster index to cache file.""" os.makedirs(os.path.dirname(self.raster_file), exist_ok=True) np.savetxt(self.raster_file, self.raster_index) - logger.info('Saved raster_index to %s', self.raster_file) + logger.debug('Saved raster_index to %s', self.raster_file) def get_raster_index(self): """Get set of slices or indices selecting the requested region from @@ -187,7 +187,7 @@ def _get_flat_data_raster_index(self): ) else: raster_index = np.loadtxt(self.raster_file).astype(np.int32) - logger.info('Loaded raster_index from %s', self.raster_file) + logger.debug('Loaded raster_index from %s', self.raster_file) return raster_index diff --git a/sup3r/solar/solar.py b/sup3r/solar/solar.py index aa039da31..e8c200cdb 100644 --- a/sup3r/solar/solar.py +++ b/sup3r/solar/solar.py @@ -578,7 +578,7 @@ def write(self, fp_out, features=('ghi', 'dni', 'dhi')): run_attrs['nsrdb_source'] = self._nsrdb_fp fh.run_attrs = run_attrs - logger.info(f'Finished writing file: {fp_out}') + logger.debug(f'Finished writing file: {fp_out}') @classmethod def run_temporal_chunks( diff --git a/sup3r/utilities/cli.py b/sup3r/utilities/cli.py index e704e8b43..f61cc7535 100644 --- a/sup3r/utilities/cli.py +++ b/sup3r/utilities/cli.py @@ -70,6 +70,13 @@ def from_config(cls, module_name, module_class, ctx, config_file, verbose, exec_kwargs = config.get('execution_control', {}) 
hardware_option = exec_kwargs.pop('option', 'local') + logger.info( + 'Preparing sup3r %s from %s with hardware=%s.', + module_name, + config_file, + hardware_option, + ) + cmd = module_class.get_node_cmd(config) if hardware_option.lower() in AVAILABLE_HARDWARE_OPTIONS: @@ -80,6 +87,12 @@ def from_config(cls, module_name, module_class, ctx, config_file, verbose, cls.kickoff_local_job(module_name, ctx, cmd, pipeline_step=pipeline_step) + logger.info( + 'Finished sup3r %s submission for %s.', + module_name, + config_file, + ) + @classmethod def from_config_preflight(cls, module_name, ctx, config_file, verbose): """Parse conifg file prior to running sup3r module. diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index 23f9d7c8a..cc807f8a5 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -120,7 +120,7 @@ def preprocess_datasets(dset): def xr_open_mfdataset(files, **kwargs): """Wrapper for xr.open_mfdataset with default opening options.""" - default_kwargs = {'engine': 'netcdf4'} + default_kwargs = {'engine': 'netcdf4', 'compat': 'override'} default_kwargs.update(kwargs) if isinstance(files, str): files = [files] diff --git a/sup3r/writers/base.py b/sup3r/writers/base.py index 38247f980..8f7f1d050 100644 --- a/sup3r/writers/base.py +++ b/sup3r/writers/base.py @@ -76,7 +76,7 @@ def _init_h5(out_file, time_index, meta, global_attrs): with RexOutputs(out_file, mode='w-') as f: logger.info('Initializing output file: {}'.format(out_file)) - logger.info( + logger.debug( 'Initializing output file with shape {} ' 'and meta data:\n{}'.format((len(time_index), len(meta)), meta) ) @@ -101,7 +101,7 @@ def _ensure_dset_in_output(cls, out_file, dset, data=None): with RexOutputs(out_file, mode='a') as f: if dset not in f.dsets: attrs, dtype = get_dset_attrs(dset) - logger.info( + logger.debug( 'Initializing dataset "{}" with shape {} and ' 'dtype {}'.format(dset, f.shape, dtype) ) @@ -150,7 +150,7 @@ def write_data( attrs=attrs, 
chunks=attrs['chunks'], ) - logger.info(f'Added {dset} to output file {out_file}.') + logger.debug('Added %s to output file %s.', dset, out_file) if global_attrs is not None: attrs = { diff --git a/sup3r/writers/cachers.py b/sup3r/writers/cachers.py index 4b0e62e65..976dfa855 100644 --- a/sup3r/writers/cachers.py +++ b/sup3r/writers/cachers.py @@ -114,7 +114,7 @@ def _write_single( _, ext = os.path.splitext(out_file) os.makedirs(os.path.dirname(out_file), exist_ok=True) tmp_file = get_tmp_file(out_file) - logger.info( + logger.debug( 'Writing %s to %s with max_workers=%s. %s', features, tmp_file, @@ -145,7 +145,7 @@ def _write_single( time_last=time_last, ) os.replace(tmp_file, out_file) - logger.info('Moved %s to %s', tmp_file, out_file) + logger.debug('Moved %s to %s', tmp_file, out_file) def cache_data( self, @@ -214,7 +214,7 @@ def cache_data( time_last=time_last, overwrite=overwrite, ) - logger.info('Finished writing %s', missing_files) + logger.debug('Finished writing %s', missing_files) return missing_files + cached_files @staticmethod @@ -589,4 +589,4 @@ def write_netcdf( compute=True, ) - logger.info('Finished writing %s to %s', features, out_file) + logger.debug('Finished writing %s to %s', features, out_file) diff --git a/sup3r/writers/h5.py b/sup3r/writers/h5.py index a6033056e..951dfaad8 100644 --- a/sup3r/writers/h5.py +++ b/sup3r/writers/h5.py @@ -85,6 +85,12 @@ def _write_output( f'File already exists at {out_file}. Skipping write.' ) return + logger.info( + 'Writing H5 output to %s for %s features with shape %s.', + out_file, + len(features), + data.shape, + ) msg = ( f'Output data shape ({data.shape}) and lat_lon shape ' f'({lat_lon.shape}) conflict.' diff --git a/sup3r/writers/nc.py b/sup3r/writers/nc.py index 2e4d47bed..244112901 100644 --- a/sup3r/writers/nc.py +++ b/sup3r/writers/nc.py @@ -78,6 +78,12 @@ def _write_output( already exists. Default is False to avoid accidentally overwriting files. 
""" + logger.info( + 'Writing NETCDF output to %s for %s features with shape %s.', + out_file, + len(features), + data.shape, + ) data, features = cls._transform_output( data=data, features=features, diff --git a/sup3r/writers/utilities.py b/sup3r/writers/utilities.py index 5c142fb7b..99410a321 100644 --- a/sup3r/writers/utilities.py +++ b/sup3r/writers/utilities.py @@ -32,7 +32,7 @@ def _check_for_cache(features, cache_kwargs): ] if any(cached_files): - logger.info( + logger.debug( 'Found cache files for %s with file pattern: %s', cached_features, cache_pattern ) From 5a7d88acf482564ae88692d8cfd530a20454f02a Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 10 May 2026 09:09:30 -0600 Subject: [PATCH 15/34] feat: implement resolve_feature method for feature alias handling and update references across modules --- sup3r/preprocessing/base.py | 14 ++++++++++++++ sup3r/preprocessing/data_handlers/base.py | 2 +- sup3r/preprocessing/derivers/base.py | 18 ++---------------- sup3r/qa/qa.py | 12 +++++++----- 4 files changed, 24 insertions(+), 22 deletions(-) diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index d751b0bea..3152881f1 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -434,6 +434,20 @@ def shape(self): """Get shape of underlying data.""" return self.data.shape + def resolve_feature(self, feature, strict=True): + """Resolve feature name to a feature in the underlying data. This is + used for handling feature aliases.""" + if feature in self.data.features: + return self.data[feature] + elif strict: + msg = ( + 'Did not find feature %s in underlying data. 
Available ' + 'features are: %s' + ) + logger.error(msg, feature, self.data.features) + raise KeyError(msg % (feature, self.data.features)) + return None + def __len__(self): return len(self.data) diff --git a/sup3r/preprocessing/data_handlers/base.py b/sup3r/preprocessing/data_handlers/base.py index 8371276d0..2b9f7c175 100644 --- a/sup3r/preprocessing/data_handlers/base.py +++ b/sup3r/preprocessing/data_handlers/base.py @@ -74,7 +74,7 @@ class DataHandler(Deriver): ... 'chunks': cache_chunks}) >>> # Derive more features from already initialized data handler: - >>> dh['windspeed_60m'] = dh.derive('windspeed_60m') + >>> dh['windspeed_60m'] = dh.resolve_feature('windspeed_60m') Derive wind speed and direction at 200m above the ground from files for geopotential height (zg), surface elevation (orog), and u/v at 10m, diff --git a/sup3r/preprocessing/derivers/base.py b/sup3r/preprocessing/derivers/base.py index 7f40252eb..d753761e6 100644 --- a/sup3r/preprocessing/derivers/base.py +++ b/sup3r/preprocessing/derivers/base.py @@ -70,7 +70,7 @@ def __init__( features = parse_to_list(data=data, features=features) new_features = [f for f in features if f not in self.data] for f in new_features: - self.data[f] = self.derive(f) + self.data[f] = self.resolve_feature(f) logger.info('Finished deriving %s.', f) self.data = ( self.data[list(self.data.coords)] @@ -131,7 +131,7 @@ def check_registry( if any(missing) and can_derive: logger.debug(msg, missing) for f in missing: - self.data[f] = self.derive(f) + self.data[f] = self.resolve_feature(f) msg = 'All required features %s found. Proceeding.' if not missing or all(f in self.data for f in missing): logger.debug(msg, inputs) @@ -255,20 +255,6 @@ def resolve_feature( warn(msg) return self.data[feature] - def derive(self, feature) -> Union[np.ndarray, da.core.Array]: - """Routine to derive requested features. Employs a little recursion to - locate differently named features with a name map in the feature - registry. i.e. 
if `FEATURE_REGISTRY` contains a key, value pair like - "windspeed": "wind_speed" then requesting "windspeed" will ultimately - return a compute method (or fetch from raw data) for "wind_speed - - Note - ---- - Features are all saved as lower case names and __contains__ checks will - use feature.lower() - """ - return self.resolve_feature(feature, strict=True) - def get_single_level_data(self, feature): """When doing level interpolation we should include the single level data available. e.g. If we have u_100m already and want to interpolate diff --git a/sup3r/qa/qa.py b/sup3r/qa/qa.py index 6416e60fa..9bb17db49 100644 --- a/sup3r/qa/qa.py +++ b/sup3r/qa/qa.py @@ -234,7 +234,7 @@ def output_type(self): e.g. 'nc' or 'h5' """ ftype = get_source_type(self._out_fp) - if ftype not in ('nc', 'h5'): + if ftype not in {'nc', 'h5'}: msg = 'Did not recognize output file type: {}'.format(self._out_fp) logger.error(msg) raise TypeError(msg) @@ -247,7 +247,7 @@ def bias_correct_input_handler(self, input_handler): (1) Check if we need to derive any features included in the bias_correct_kwargs. - (2) Derive these features using the input_handler.derive method, and + (2) Derive these features using input_handler.resolve_feature, and update the stored data. (3) Apply bias correction to all the features in the bias_correct_kwargs @@ -261,13 +261,15 @@ def bias_correct_input_handler(self, input_handler): ) msg = ( f'Features {need_derive} need to be derived prior to bias ' - 'correction, but the input_handler has no derive method. ' + 'correction, but the input_handler has no resolve_feature method. ' 'Request an appropriate input_handler with ' 'input_handler_name=DataHandlerName.' 
) - assert len(need_derive) == 0 or hasattr(input_handler, 'derive'), msg + assert len(need_derive) == 0 or hasattr( + input_handler, 'resolve_feature' + ), msg for f in need_derive: - input_handler.data[f] = input_handler.derive(f) + input_handler.data[f] = input_handler.resolve_feature(f) bc_feats = list( set(input_handler.features).intersection( set(lowered(self.bias_correct_kwargs.keys())) From 471c8ed63fe6ca711e070c5863eee2869cd5cc73 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 10 May 2026 11:18:42 -0600 Subject: [PATCH 16/34] fix: include lr_features in exo_features filtering for ForwardPassStrategy --- sup3r/pipeline/strategy.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index fd24f3733..334aec7b2 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -392,7 +392,11 @@ def _init_features(self, model): """Initialize feature attributes.""" self.exo_handler_kwargs = self.exo_handler_kwargs or {} exo_features = list(self.exo_handler_kwargs) - exo_features = [f for f in exo_features if f in model.hr_exo_features] + exo_features = [ + f + for f in exo_features + if f in model.hr_exo_features or f in model.lr_features + ] features = [f for f in model.lr_features if f not in exo_features] return features, exo_features From 4b40e2e57aa02056a2ac685c1da759c1c7c5f5c7 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 10 May 2026 11:45:13 -0600 Subject: [PATCH 17/34] fix: update hr_exo_features to return topography for LinearInterp model --- sup3r/models/linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sup3r/models/linear.py b/sup3r/models/linear.py index 056af0d3c..3316ee166 100644 --- a/sup3r/models/linear.py +++ b/sup3r/models/linear.py @@ -106,8 +106,8 @@ def hr_out_features(self): @property def hr_exo_features(self): - """Returns an empty list for LinearInterp model""" - return [] + """Returns topography for LinearInterp model""" + 
return ['topography'] def save(self, out_dir): """ From cb06b9e8d5c11b7afe4f78b3fe7915d95afd5015 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 10 May 2026 12:37:12 -0600 Subject: [PATCH 18/34] fix: enhance exo_features filtering to support multiple submodels in ForwardPassStrategy --- sup3r/pipeline/strategy.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index 334aec7b2..158ddd4dd 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -392,11 +392,14 @@ def _init_features(self, model): """Initialize feature attributes.""" self.exo_handler_kwargs = self.exo_handler_kwargs or {} exo_features = list(self.exo_handler_kwargs) - exo_features = [ - f - for f in exo_features - if f in model.hr_exo_features or f in model.lr_features - ] + # If the model has multiple submodels with different features, we need + # to keep all exo features that are needed for any of the submodels. + # model.lr_features only inputs for the first model + models = getattr(model, 'models', [model]) + lr_features = {f for m in models for f in m.lr_features} + exo_features = set(exo_features).intersection( + lr_features | set(model.hr_exo_features) + ) features = [f for f in model.lr_features if f not in exo_features] return features, exo_features From 780bcd44c6d18e9c137b1b05ecf3890a777faaec Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 10 May 2026 13:47:59 -0600 Subject: [PATCH 19/34] fix: update exo_features filtering to include hr_exo_features in ForwardPassStrategy --- sup3r/pipeline/strategy.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index 158ddd4dd..1f7c55223 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -397,9 +397,11 @@ def _init_features(self, model): # model.lr_features only inputs for the first model models = getattr(model, 'models', [model]) lr_features 
= {f for m in models for f in m.lr_features} - exo_features = set(exo_features).intersection( - lr_features | set(model.hr_exo_features) - ) + exo_features = [ + f + for f in exo_features + if f in lr_features or f in model.hr_exo_features + ] features = [f for f in model.lr_features if f not in exo_features] return features, exo_features From 8f8a98d73fa14929b86d19ad82ed6657efeedee0 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 10 May 2026 14:36:53 -0600 Subject: [PATCH 20/34] fix: update lr_features comment for clarity in ForwardPassStrategy and modify res_kwargs for NC collect test --- sup3r/pipeline/strategy.py | 2 +- tests/output/test_output_handling.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index 1f7c55223..fb883fba3 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -394,7 +394,7 @@ def _init_features(self, model): exo_features = list(self.exo_handler_kwargs) # If the model has multiple submodels with different features, we need # to keep all exo features that are needed for any of the submodels. 
- # model.lr_features only inputs for the first model + # model.lr_features only includes inputs for the first model models = getattr(model, 'models', [model]) lr_features = {f for m in models for f in m.lr_features} exo_features = [ diff --git a/tests/output/test_output_handling.py b/tests/output/test_output_handling.py index 3161772e5..f7e9c09c3 100644 --- a/tests/output/test_output_handling.py +++ b/tests/output/test_output_handling.py @@ -179,7 +179,7 @@ def test_general_nc_collect(): out_files, fp_out, features=[*features, 'latitude', 'longitude'], - res_kwargs={'combine': 'nested', 'concat_dim': 'time'}, + res_kwargs={'compat': 'no_conflicts'}, ) with Loader(fp_out) as res: From 48be408f4889fd577fb18319a9d3d884a7f30285 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 10 May 2026 15:04:15 -0600 Subject: [PATCH 21/34] fix: update feature resolution method in test_netcdf_uv_invert for consistency --- tests/output/test_output_handling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/output/test_output_handling.py b/tests/output/test_output_handling.py index f7e9c09c3..bd0792364 100644 --- a/tests/output/test_output_handling.py +++ b/tests/output/test_output_handling.py @@ -285,8 +285,8 @@ def test_netcdf_uv_invert(): dh = DataHandler( fp_out, features=['windspeed_10m', 'winddirection_10m'] ) - uvals = dh.derive('u_10m').values - vvals = dh.derive('v_10m').values + uvals = dh.resolve_feature('u_10m').values + vvals = dh.resolve_feature('v_10m').values assert np.allclose(data[..., 0], uvals, atol=1e-5) assert np.allclose(data[..., 1], vvals, atol=1e-5) From de27efdd92d316a3e9d322fbdbefa2157a626040 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 10 May 2026 16:56:01 -0600 Subject: [PATCH 22/34] fix: rename resolve_feature to derive for consistency and update related usages --- sup3r/pipeline/strategy.py | 11 ++++------- sup3r/preprocessing/base.py | 5 +++-- sup3r/preprocessing/data_handlers/base.py | 2 +- 
sup3r/preprocessing/derivers/base.py | 8 ++++---- sup3r/qa/qa.py | 8 ++++---- tests/output/test_output_handling.py | 4 ++-- 6 files changed, 18 insertions(+), 20 deletions(-) diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index fb883fba3..94fd1737b 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -710,14 +710,11 @@ def fwp_mask(self): mask = np.zeros(len(self.lr_pad_slices)) logger.debug('Checking for mask in input handler.') input_handler_kwargs = copy.deepcopy(self.input_handler_kwargs) + input_handler_kwargs['features'] = 'all' InputHandler = get_input_handler_class(self.input_handler_name) - input_handler_kwargs['features'] = [] handler = InputHandler(**input_handler_kwargs) - mask_feature = handler.resolve_feature('mask', strict=False) - if mask_feature is None: - logger.debug( - 'No "mask" found in DataHandler. No chunks will be masked.' - ) + if 'mask' not in handler: + logger.debug('No "mask" found in data. No chunks will be masked.') return mask logger.debug( @@ -725,7 +722,7 @@ def fwp_mask(self): 'chunk mask for %s chunks', len(self.lr_pad_slices), ) - mask_vals = getattr(mask_feature, 'values', mask_feature) + mask_vals = handler.data['mask'].values for s_chunk_idx, lr_slices in enumerate(self.lr_pad_slices): mask_check = mask_vals[lr_slices[0], lr_slices[1]] mask[s_chunk_idx] = bool(np.prod(mask_check.flatten())) diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index 3152881f1..a1c676a71 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -434,9 +434,10 @@ def shape(self): """Get shape of underlying data.""" return self.data.shape - def resolve_feature(self, feature, strict=True): + def derive(self, feature, strict=True): """Resolve feature name to a feature in the underlying data. 
This is - used for handling feature aliases.""" + used for handling feature aliases and for deriving new features from + existing ones.""" if feature in self.data.features: return self.data[feature] elif strict: diff --git a/sup3r/preprocessing/data_handlers/base.py b/sup3r/preprocessing/data_handlers/base.py index 2b9f7c175..8371276d0 100644 --- a/sup3r/preprocessing/data_handlers/base.py +++ b/sup3r/preprocessing/data_handlers/base.py @@ -74,7 +74,7 @@ class DataHandler(Deriver): ... 'chunks': cache_chunks}) >>> # Derive more features from already initialized data handler: - >>> dh['windspeed_60m'] = dh.resolve_feature('windspeed_60m') + >>> dh['windspeed_60m'] = dh.derive('windspeed_60m') Derive wind speed and direction at 200m above the ground from files for geopotential height (zg), surface elevation (orog), and u/v at 10m, diff --git a/sup3r/preprocessing/derivers/base.py b/sup3r/preprocessing/derivers/base.py index d753761e6..4ec0a4c6b 100644 --- a/sup3r/preprocessing/derivers/base.py +++ b/sup3r/preprocessing/derivers/base.py @@ -70,7 +70,7 @@ def __init__( features = parse_to_list(data=data, features=features) new_features = [f for f in features if f not in self.data] for f in new_features: - self.data[f] = self.resolve_feature(f) + self.data[f] = self.derive(f) logger.info('Finished deriving %s.', f) self.data = ( self.data[list(self.data.coords)] @@ -131,7 +131,7 @@ def check_registry( if any(missing) and can_derive: logger.debug(msg, missing) for f in missing: - self.data[f] = self.resolve_feature(f) + self.data[f] = self.derive(f) msg = 'All required features %s found. Proceeding.' 
if not missing or all(f in self.data for f in missing): logger.debug(msg, inputs) @@ -205,7 +205,7 @@ def has_interp_variables(self, feature): count += 1 return count > 1 or fstruct.basename in self.data - def resolve_feature( + def derive( self, feature, strict=True ) -> Union[np.ndarray, da.core.Array, None]: """Resolve a feature from contained data or available derivations. @@ -223,7 +223,7 @@ def resolve_feature( compute_check = self.check_registry(feature) if compute_check is not None and isinstance(compute_check, str): new_feature = self.map_new_name(feature, compute_check) - return self.resolve_feature(new_feature, strict=strict) + return self.derive(new_feature, strict=strict) if compute_check is not None: return compute_check diff --git a/sup3r/qa/qa.py b/sup3r/qa/qa.py index 9bb17db49..e1fa302b9 100644 --- a/sup3r/qa/qa.py +++ b/sup3r/qa/qa.py @@ -247,7 +247,7 @@ def bias_correct_input_handler(self, input_handler): (1) Check if we need to derive any features included in the bias_correct_kwargs. - (2) Derive these features using input_handler.resolve_feature, and + (2) Derive these features using input_handler.derive, and update the stored data. (3) Apply bias correction to all the features in the bias_correct_kwargs @@ -261,15 +261,15 @@ def bias_correct_input_handler(self, input_handler): ) msg = ( f'Features {need_derive} need to be derived prior to bias ' - 'correction, but the input_handler has no resolve_feature method. ' + 'correction, but the input_handler has no derive method. ' 'Request an appropriate input_handler with ' 'input_handler_name=DataHandlerName.' 
) assert len(need_derive) == 0 or hasattr( - input_handler, 'resolve_feature' + input_handler, 'derive' ), msg for f in need_derive: - input_handler.data[f] = input_handler.resolve_feature(f) + input_handler.data[f] = input_handler.derive(f) bc_feats = list( set(input_handler.features).intersection( set(lowered(self.bias_correct_kwargs.keys())) diff --git a/tests/output/test_output_handling.py b/tests/output/test_output_handling.py index bd0792364..f7e9c09c3 100644 --- a/tests/output/test_output_handling.py +++ b/tests/output/test_output_handling.py @@ -285,8 +285,8 @@ def test_netcdf_uv_invert(): dh = DataHandler( fp_out, features=['windspeed_10m', 'winddirection_10m'] ) - uvals = dh.resolve_feature('u_10m').values - vvals = dh.resolve_feature('v_10m').values + uvals = dh.derive('u_10m').values + vvals = dh.derive('v_10m').values assert np.allclose(data[..., 0], uvals, atol=1e-5) assert np.allclose(data[..., 1], vvals, atol=1e-5) From e634dcecb5edb399219fe3df6810d37bd8cc5b86 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 10 May 2026 18:04:23 -0600 Subject: [PATCH 23/34] fix: update mask handling in ForwardPassStrategy to improve logging and error handling --- sup3r/pipeline/strategy.py | 35 +++++++++++++++------------- sup3r/preprocessing/derivers/base.py | 1 - 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index 94fd1737b..df35f6665 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -710,22 +710,25 @@ def fwp_mask(self): mask = np.zeros(len(self.lr_pad_slices)) logger.debug('Checking for mask in input handler.') input_handler_kwargs = copy.deepcopy(self.input_handler_kwargs) - input_handler_kwargs['features'] = 'all' - InputHandler = get_input_handler_class(self.input_handler_name) - handler = InputHandler(**input_handler_kwargs) - if 'mask' not in handler: - logger.debug('No "mask" found in data. 
No chunks will be masked.') - return mask - - logger.debug( - 'Found "mask" in DataHandler. Computing forward pass ' - 'chunk mask for %s chunks', - len(self.lr_pad_slices), - ) - mask_vals = handler.data['mask'].values - for s_chunk_idx, lr_slices in enumerate(self.lr_pad_slices): - mask_check = mask_vals[lr_slices[0], lr_slices[1]] - mask[s_chunk_idx] = bool(np.prod(mask_check.flatten())) + try: + InputHandler = get_input_handler_class(self.input_handler_name) + input_handler_kwargs['features'] = ['mask'] + handler = InputHandler(**input_handler_kwargs) + logger.debug( + 'Found "mask" in %s. Computing forward pass ' + 'chunk mask for %s chunks', + self.input_handler_name, + len(self.lr_pad_slices), + ) + mask_vals = handler.data['mask'].values + for s_chunk_idx, lr_slices in enumerate(self.lr_pad_slices): + mask_check = mask_vals[lr_slices[0], lr_slices[1]] + mask[s_chunk_idx] = bool(np.prod(mask_check.flatten())) + except Exception: + logger.debug( + 'No "mask" found in %s. No chunks will be masked.', + self.input_handler_name, + ) return mask def node_finished(self, node_idx): diff --git a/sup3r/preprocessing/derivers/base.py b/sup3r/preprocessing/derivers/base.py index 4ec0a4c6b..f95e7477d 100644 --- a/sup3r/preprocessing/derivers/base.py +++ b/sup3r/preprocessing/derivers/base.py @@ -243,7 +243,6 @@ def derive( 'Could not find "%s" in contained data or in the available ' 'compute methods.' 
) - logger.error(msg, feature) raise RuntimeError(msg % feature) if np.isnan(self.data[feature]).any(): From be71f38812ab9b0e50802a51da37a7ac809ca632 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 10 May 2026 19:50:15 -0600 Subject: [PATCH 24/34] f-string to % formatting for logging to save time when log_level = INFO --- sup3r/bias/base.py | 23 +++--- sup3r/bias/bias_calc.py | 9 +-- sup3r/bias/bias_calc_cli.py | 2 +- sup3r/bias/bias_calc_vortex.py | 67 ++++++++++------ sup3r/bias/bias_transforms.py | 10 ++- sup3r/bias/mixins.py | 14 ++-- sup3r/bias/presrat.py | 6 +- sup3r/bias/qdm.py | 6 +- sup3r/bias/utilities.py | 17 ++-- sup3r/cli.py | 14 ++-- sup3r/models/abstract.py | 28 ++++--- sup3r/models/base.py | 80 +++++++++---------- sup3r/models/conditional.py | 31 ++++--- sup3r/models/dc.py | 19 +++-- sup3r/models/linear.py | 16 ++-- sup3r/models/multi_step.py | 79 +++++++++--------- sup3r/models/solar_cc.py | 2 +- sup3r/models/surface.py | 15 ++-- sup3r/pipeline/forward_pass.py | 70 +++++++++------- sup3r/pipeline/strategy.py | 13 +-- sup3r/postprocessing/collectors/h5.py | 11 ++- sup3r/preprocessing/base.py | 5 +- sup3r/preprocessing/batch_queues/utilities.py | 6 +- sup3r/preprocessing/collections/stats.py | 9 ++- sup3r/preprocessing/data_handlers/base.py | 4 +- sup3r/preprocessing/data_handlers/exo.py | 8 +- sup3r/preprocessing/data_handlers/nc_cc.py | 26 +++--- sup3r/preprocessing/derivers/base.py | 17 +++- sup3r/preprocessing/rasterizers/base.py | 38 ++++++--- sup3r/preprocessing/rasterizers/dual.py | 6 +- sup3r/preprocessing/rasterizers/exo.py | 4 +- sup3r/preprocessing/samplers/base.py | 5 +- sup3r/preprocessing/utilities.py | 9 ++- sup3r/qa/qa.py | 25 +++--- sup3r/qa/utilities.py | 2 +- sup3r/solar/solar.py | 35 ++++---- sup3r/solar/solar_cli.py | 15 ++-- sup3r/utilities/cli.py | 18 +++-- sup3r/utilities/era_downloader.py | 43 +++++----- sup3r/utilities/interpolation.py | 2 +- sup3r/utilities/utilities.py | 2 +- sup3r/writers/base.py | 25 +++--- 
sup3r/writers/cachers.py | 16 ++-- sup3r/writers/h5.py | 4 +- sup3r/writers/utilities.py | 2 +- tests/utilities/test_era_downloader.py | 1 + 46 files changed, 463 insertions(+), 396 deletions(-) diff --git a/sup3r/bias/base.py b/sup3r/bias/base.py index 25664c714..00b13997c 100644 --- a/sup3r/bias/base.py +++ b/sup3r/bias/base.py @@ -115,8 +115,10 @@ class is used, all data will be loaded in this class' """ logger.info( - 'Initializing DataRetrievalBase for base dset "{}" ' - 'correcting biased dataset(s): {}'.format(base_dset, bias_feature) + 'Initializing DataRetrievalBase for base dset "%s" correcting ' + 'biased dataset(s): %s', + base_dset, + bias_feature, ) self.base_fps = base_fps self.bias_fps = bias_fps @@ -235,9 +237,8 @@ def distance_upper_bound(self): diff = np.max(np.median(diff, axis=0)) self._distance_upper_bound = diff logger.info( - 'Set distance upper bound to {:.4f}'.format( - self._distance_upper_bound - ) + 'Set distance upper bound to %.4f', + self._distance_upper_bound, ) return self._distance_upper_bound @@ -587,13 +588,11 @@ def _match_zero_rate(bias_data, base_data): q_zero_bias_out = np.nanmean(bias_data == 0) logger.debug( - 'Input bias/base zero rate is {:.3e}/{:.3e}, ' - 'output is {:.3e}/{:.3e}'.format( - q_zero_bias_in, - q_zero_base_in, - q_zero_bias_out, - q_zero_base_out, - ) + 'Input bias/base zero rate is %.3e/%.3e, output is %.3e/%.3e', + q_zero_bias_in, + q_zero_base_in, + q_zero_bias_out, + q_zero_base_out, ) return bias_data diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index 4e3f40f72..2deb0ba16 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -158,9 +158,7 @@ def write_outputs(self, fp_out, out): for k, v in self.meta.items(): f.attrs[k] = json.dumps(v) - logger.info( - 'Wrote scalar adder factors to file: {}'.format(fp_out) - ) + logger.info('Wrote scalar adder factors to file: %s', fp_out) def _get_run_kwargs(self, **kwargs_extras): """Get dictionary of kwarg dictionaries to use for 
calls to @@ -236,9 +234,8 @@ def run( logger.debug('Starting linear correction calculation...') logger.info( - 'Initialized scalar / adder with shape: {}'.format( - self.bias_gid_raster.shape - ) + 'Initialized scalar / adder with shape: %s', + self.bias_gid_raster.shape, ) self.out = self._run( out=self.out, diff --git a/sup3r/bias/bias_calc_cli.py b/sup3r/bias/bias_calc_cli.py index 39ccc99e5..852a4536d 100644 --- a/sup3r/bias/bias_calc_cli.py +++ b/sup3r/bias/bias_calc_cli.py @@ -82,7 +82,7 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None): cmd = BiasCalcClass.get_node_cmd(node_config) cmd_log = '\n\t'.join(cmd.split('\n')) - logger.debug(f'Running command:\n\t{cmd_log}') + logger.debug('Running command:\n\t%s', cmd_log) logger.info( 'Queueing bias calculation node %s as job "%s".', i_node, diff --git a/sup3r/bias/bias_calc_vortex.py b/sup3r/bias/bias_calc_vortex.py index e926203ed..b4add1a4d 100644 --- a/sup3r/bias/bias_calc_vortex.py +++ b/sup3r/bias/bias_calc_vortex.py @@ -111,7 +111,7 @@ def convert_month_height_tif(self, month, height): corresponding input file and write this to a netcdf file. 
""" infile = self.get_input_file(month, height) - logger.info(f'Getting mean windspeed_{height}m for {month}.') + logger.info('Getting mean windspeed_%sm for %s.', height, month) out_file = infile.replace('.tif', '.nc') if os.path.exists(out_file) and self.overwrite: os.remove(out_file) @@ -138,7 +138,7 @@ def convert_all_tifs(self): """Write netcdf files for all heights for all months.""" for i in range(1, 13): month = calendar.month_name[i] - logger.info(f'Converting tif files to netcdf files for {month}') + logger.info('Converting tif files to netcdf files for %s', month) self.convert_month_tif(month) @property @@ -175,23 +175,26 @@ def get_month(self, month): os.remove(month_file) if os.path.exists(month_file) and not self.overwrite: - logger.info(f'Loading month_file {month_file}.') + logger.info('Loading month_file %s.', month_file) data = xr.open_mfdataset(month_file) else: logger.info( - 'Getting mean windspeed for all heights ' - f'({self.in_heights}) for {month}' + 'Getting mean windspeed for all heights (%s) for %s', + self.in_heights, + month, ) data = xr.open_mfdataset(self.get_height_files(month)) logger.info( - 'Interpolating windspeed for all heights ' - f'({self.out_heights}) for {month}.' + 'Interpolating windspeed for all heights (%s) for %s.', + self.out_heights, + month, ) data = self.interp(data) data.to_netcdf(month_file, format='NETCDF4', engine='h5netcdf') logger.info( - 'Saved interpolated means for all heights for ' - f'{month} to {month_file}.' + 'Saved interpolated means for all heights for %s to %s.', + month, + month_file, ) return data @@ -222,8 +225,10 @@ def interp(self, data): lev_array[..., i] = h logger.info( - f'Interpolating {self.in_features} to {self.out_features} ' - f'for {var_array.shape[0]} coordinates.' 
+ 'Interpolating %s to %s for %s coordinates.', + self.in_features, + self.out_features, + var_array.shape[0], ) tmp = [ interp1d(h, v, fill_value='extrapolate')(self.out_heights) @@ -312,13 +317,13 @@ def write_data(self, fp_out, out): for dset, data in out.items(): OutputHandler._ensure_dset_in_output(fp_out, dset) f[dset] = data.T - logger.info(f'Added {dset} to {fp_out}.') + logger.info('Added %s to %s.', dset, fp_out) logger.info( - f'Wrote monthly means for all out heights: {fp_out}' + 'Wrote monthly means for all out heights: %s', fp_out ) elif os.path.exists(fp_out): - logger.info(f'{fp_out} already exists and overwrite=False.') + logger.info('%s already exists and overwrite=False.', fp_out) @classmethod def run( @@ -376,13 +381,18 @@ def get_bc_factors(cls, bc_file, dset, month, global_scalar=1): """ with Resource(bc_file) as res: logger.info( - f'Getting {dset} bias correction factors for month {month}.' + 'Getting %s bias correction factors for month %s.', + dset, + month, ) bc_factor = res[f'{dset}_scalar', :, month - 1] factors = global_scalar * bc_factor logger.info( - f'Retrieved {dset} bias correction factors for month {month}. ' - f'Using global_scalar={global_scalar}.' + 'Retrieved %s bias correction factors for month %s. ' + 'Using global_scalar=%s.', + dset, + month, + global_scalar, ) return factors @@ -419,7 +429,9 @@ def _correct_month( month=month, global_scalar=global_scalar, ) - logger.info(f'Applying bias correction factors for month {month}') + logger.info( + 'Applying bias correction factors for month %s', month + ) fh[dset, mask, :] = bc_factors * fh_in[dset, mask, :] @classmethod @@ -453,7 +465,9 @@ def update_file( Number of workers to use for parallel processing. 
""" tmp_file = get_tmp_file(out_file) - logger.info(f'Bias correcting {dset} in {in_file} with {bc_file}.') + logger.info( + 'Bias correcting %s in %s with %s.', dset, in_file, bc_file + ) with Resource(in_file) as fh_in: OutputHandler._init_h5( tmp_file, fh_in.time_index, fh_in.meta, fh_in.global_attrs @@ -482,8 +496,7 @@ def update_file( logger.info('Finished bias correcting %s in %s', dset, in_file) os.replace(tmp_file, out_file) - msg = f'Saved bias corrected {dset} to: {out_file}' - logger.info(msg) + logger.info('Saved bias corrected %s to: %s', dset, out_file) @classmethod def run( @@ -518,14 +531,16 @@ def run( Number of workers to use for parallel processing. """ if os.path.exists(out_file) and not overwrite: - logger.info( - f'{out_file} already exists and overwrite=False. Skipping.' - ) + logger.info( + '%s already exists and overwrite=False. Skipping.', + out_file, + ) else: if os.path.exists(out_file) and overwrite: logger.info( - f'{out_file} exists but overwrite=True. ' - f'Removing {out_file}.' + '%s exists but overwrite=True. 
Removing %s.', + out_file, + out_file, ) os.remove(out_file) cls.update_file( diff --git a/sup3r/bias/bias_transforms.py b/sup3r/bias/bias_transforms.py index 445440ca7..8c3784e34 100644 --- a/sup3r/bias/bias_transforms.py +++ b/sup3r/bias/bias_transforms.py @@ -610,9 +610,9 @@ def _apply_qdm( tmp = np.reshape(subset.data, (-1, subset.shape[-1])).T # Apply QDM correction - logger.info(f'Applying QDM to data with shape {tmp.shape}...') + logger.info('Applying QDM to data with shape %s...', tmp.shape) tmp = QDM(tmp, max_workers=max_workers) - logger.info(f'Finished QDM on data shape {tmp.shape}!') + logger.info('Finished QDM on data shape %s!', tmp.shape) # Reorgnize array back from (time, space) # to (spatial, spatial, temporal) @@ -1074,8 +1074,10 @@ def local_presrat_bc( k_factor = np.minimum(k_factor, np.max(k_range)) logger.debug( - f'Presrat K Factor has shape {k_factor.shape} and ranges ' - f'from {k_factor.min()} to {k_factor.max()}' + 'Presrat K Factor has shape %s and ranges from %s to %s', + k_factor.shape, + k_factor.min(), + k_factor.max(), ) if lr_padded_slice is not None: diff --git a/sup3r/bias/mixins.py b/sup3r/bias/mixins.py index 0c41b0288..dfc1bb580 100644 --- a/sup3r/bias/mixins.py +++ b/sup3r/bias/mixins.py @@ -61,9 +61,9 @@ def fill_and_smooth( """ if len(self.bad_bias_gids) > 0: logger.info( - 'Found {} bias gids that are out of bounds: {}'.format( - len(self.bad_bias_gids), self.bad_bias_gids - ) + 'Found %s bias gids that are out of bounds: %s', + len(self.bad_bias_gids), + self.bad_bias_gids, ) for key, arr in out.items(): @@ -77,10 +77,10 @@ def fill_and_smooth( if needs_fill: logger.info( - 'Filling NaN values outside of valid spatial ' - 'extent for dataset "{}" for timestep {}'.format( - key, idt - ) + 'Filling NaN values outside of valid spatial extent ' + 'for dataset "%s" for timestep %s', + key, + idt, ) arr_smooth = nn_fill_array(arr_smooth) diff --git a/sup3r/bias/presrat.py b/sup3r/bias/presrat.py index da09ad478..be17b7e14 
100644 --- a/sup3r/bias/presrat.py +++ b/sup3r/bias/presrat.py @@ -414,9 +414,7 @@ def run( logger.debug('Calculating CDF parameters for QDM') logger.info( - 'Initialized params with shape: {}'.format( - self.bias_gid_raster.shape - ) + 'Initialized params with shape: %s', self.bias_gid_raster.shape ) self.out = self._run( out=self.out, @@ -491,4 +489,4 @@ def write_outputs( if extra_attrs is not None: for a, v in extra_attrs.items(): f.attrs[a] = v - logger.info('Wrote quantiles to file: {}'.format(fp_out)) + logger.info('Wrote quantiles to file: %s', fp_out) diff --git a/sup3r/bias/qdm.py b/sup3r/bias/qdm.py index a635ee811..166059622 100644 --- a/sup3r/bias/qdm.py +++ b/sup3r/bias/qdm.py @@ -490,7 +490,7 @@ def write_outputs(self, fp_out, out=None): f.attrs['bias_fps'] = self.bias_fps f.attrs['bias_fut_fps'] = self.bias_fut_fps f.attrs['time_window_center'] = self.time_window_center - logger.info('Wrote quantiles to file: {}'.format(fp_out)) + logger.info('Wrote quantiles to file: %s', fp_out) def _get_run_kwargs(self, **kwargs_extras): """Get dictionary of kwarg dictionaries to use for calls to @@ -562,9 +562,7 @@ def run( logger.debug('Calculating CDF parameters for QDM') logger.info( - 'Initialized params with shape: {}'.format( - self.bias_gid_raster.shape - ) + 'Initialized params with shape: %s', self.bias_gid_raster.shape ) self.out = self._run( diff --git a/sup3r/bias/utilities.py b/sup3r/bias/utilities.py index 20f05c68b..ebce253ac 100644 --- a/sup3r/bias/utilities.py +++ b/sup3r/bias/utilities.py @@ -15,7 +15,6 @@ get_date_range_kwargs, ) - logger = logging.getLogger(__name__) @@ -90,10 +89,9 @@ def lin_bc(handler, bc_files, bias_feature=None, threshold=0.1): raise RuntimeError(msg) logger.info( - 'Bias correcting "{}" with linear ' - 'correction from "{}"'.format( - feature, os.path.basename(fp) - ) + 'Bias correcting "%s" with linear correction from "%s"', + feature, + os.path.basename(fp), ) handler.data[feature] = ( scalar * 
handler.data[feature][...] + adder @@ -193,10 +191,9 @@ def qdm_bc( if feature not in completed and check: logger.info( - 'Bias correcting "{}" with QDM ' - 'correction from "{}"'.format( - feature, os.path.basename(fp) - ) + 'Bias correcting "%s" with QDM correction from "%s"', + feature, + os.path.basename(fp), ) handler.data[feature] = local_qdm_bc( handler.data[feature], @@ -257,7 +254,7 @@ def bias_correct_feature( lat_lon = input_handler.lat_lon if bc_method is not None: bc_method = getattr(sup3r.bias.bias_transforms, bc_method) - logger.info(f'Running bias correction with: {bc_method}.') + logger.info('Running bias correction with: %s.', bc_method) feature_kwargs = bc_kwargs[source_feature] if 'date_range_kwargs' in signature(bc_method).parameters: diff --git a/sup3r/cli.py b/sup3r/cli.py index d38d13ced..d64fac1dc 100644 --- a/sup3r/cli.py +++ b/sup3r/cli.py @@ -120,7 +120,7 @@ def forward_pass(ctx, verbose): To run the job locally, use ``execution_control: {"option": "local"}``. """ # noqa : D301 config_file = ctx.obj['CONFIG_FILE'] - verbose = any([verbose, ctx.obj['VERBOSE']]) + verbose = any((verbose, ctx.obj['VERBOSE'])) ctx.invoke(fwp_cli, config_file=config_file, verbose=verbose) @@ -169,7 +169,7 @@ def solar(ctx, verbose): To run the job locally, use ``execution_control: {"option": "local"}``. """ # noqa : D301 config_file = ctx.obj['CONFIG_FILE'] - verbose = any([verbose, ctx.obj['VERBOSE']]) + verbose = any((verbose, ctx.obj['VERBOSE'])) ctx.invoke(solar_cli, config_file=config_file, verbose=verbose) @@ -232,7 +232,7 @@ def bias_calc(ctx, verbose): To run the job locally, use ``execution_control: {"option": "local"}``. """ # noqa : D301 config_file = ctx.obj['CONFIG_FILE'] - verbose = any([verbose, ctx.obj['VERBOSE']]) + verbose = any((verbose, ctx.obj['VERBOSE'])) ctx.invoke(bias_calc_cli, config_file=config_file, verbose=verbose) @@ -274,7 +274,7 @@ def data_collect(ctx, verbose): and you can set ``"option": "kestrel"`` to run on the NLR HPC. 
""" # noqa : D301 config_file = ctx.obj['CONFIG_FILE'] - verbose = any([verbose, ctx.obj['VERBOSE']]) + verbose = any((verbose, ctx.obj['VERBOSE'])) ctx.invoke(dc_cli, config_file=config_file, verbose=verbose) @@ -317,7 +317,7 @@ def qa(ctx, verbose): and you can set ``"option": "kestrel"`` to run on the NLR HPC. """ # noqa : D301 config_file = ctx.obj['CONFIG_FILE'] - verbose = any([verbose, ctx.obj['VERBOSE']]) + verbose = any((verbose, ctx.obj['VERBOSE'])) ctx.invoke(qa_cli, config_file=config_file, verbose=verbose) @@ -373,7 +373,7 @@ def pipeline(ctx, cancel, monitor, background, verbose): """ # noqa: D301 if ctx.invoked_subcommand is None: config_file = ctx.obj['CONFIG_FILE'] - verbose = any([verbose, ctx.obj['VERBOSE']]) + verbose = any((verbose, ctx.obj['VERBOSE'])) ctx.invoke( pipe_cli, config_file=config_file, @@ -443,7 +443,7 @@ def batch(ctx, dry_run, cancel, delete, monitor_background, verbose): """ # noqa : D301 if ctx.invoked_subcommand is None: config_file = ctx.obj['CONFIG_FILE'] - verbose = any([verbose, ctx.obj['VERBOSE']]) + verbose = any((verbose, ctx.obj['VERBOSE'])) ctx.invoke( batch_cli, config_file=config_file, diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index f4396e00c..97d008802 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -233,7 +233,7 @@ def set_norm_stats(self, new_means, new_stdevs): missing += [ f for f in self.hr_out_features if f not in self._means ] - if any(missing): + if missing: msg = ( f'Need means for features "{missing}" but did not find ' f'in new means array: {self._means}' @@ -272,7 +272,7 @@ def norm_input(self, low_res): low_res = low_res.numpy() missing = [fn for fn in self.lr_features if fn not in self._means] - if any(missing): + if missing: msg = ( f'Could not find low-res input features {missing} in ' f'means/stdevs: {self._means}/{self._stdevs}' @@ -311,7 +311,7 @@ def un_norm_output(self, output): missing = [ fn for fn in self.hr_out_features if fn not in 
self._means ] - if any(missing): + if missing: msg = ( f'Could not find high-res output features {missing} in ' f'means/stdevs: {self._means}/{self._stdevs}' @@ -367,6 +367,9 @@ def history(self): ------- pandas.DataFrame | None """ + if self._history is None: + self._history = pd.DataFrame(columns=['elapsed_time']) + self._history.index.name = 'epoch' return self._history @property @@ -510,9 +513,8 @@ def load_saved_params(out_dir, verbose=True): if verbose: logger.debug( 'Loading model from disk that was created with the ' - 'following package versions: \n{}'.format( - pprint.pformat(version_record, indent=2) - ) + 'following package versions: \n%s', + pprint.pformat(version_record, indent=2), ) means = params.get('means', None) @@ -839,9 +841,9 @@ def log_loss_details(loss_details, level='INFO'): for k, v in sorted(loss_details.items()): - msg_format = '\t{}: {}' if isinstance(v, str) else '\t{}: {:.2e}' + msg_format = '\t%s: %s' if isinstance(v, str) else '\t%s: %.2e' if level.lower() == 'info': - logger.info(msg_format.format(k, v)) + logger.info(msg_format, k, v) else: - logger.debug(msg_format.format(k, v)) + logger.debug(msg_format, k, v) @staticmethod def early_stop(history, column, threshold=0.005, n_epoch=5): @@ -878,11 +880,11 @@ def early_stop(history, column, threshold=0.005, n_epoch=5): if all(diffs[-n_epoch:] < threshold): stop = True logger.info( - 'Found early stop condition, loss values "{}" ' - 'have absolute relative differences less than ' - 'threshold {}: {}'.format( - column, threshold, diffs[-n_epoch:] - ) + 'Found early stop condition, loss values "%s" have ' + 'absolute relative differences less than threshold %s: %s', + column, + threshold, + diffs[-n_epoch:], ) return stop diff --git a/sup3r/models/base.py b/sup3r/models/base.py index b58ecfc0c..3090f70fd 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -173,7 +173,7 @@ def save(self, out_dir): self.save_params(out_dir) - logger.info('Saved GAN to disk in directory: {}'.format(out_dir)) + logger.info('Saved GAN to disk in directory: %s',
out_dir) @classmethod def _load(cls, model_dir, verbose=False): @@ -196,13 +196,11 @@ def _load(cls, model_dir, verbose=False): Dictionary of model params to be used in model initialization """ if verbose: - logger.info( - 'Loading GAN from disk in directory: {}'.format(model_dir) - ) - msg = 'Active python environment versions: \n{}'.format( - pprint.pformat(VERSION_RECORD, indent=4) + logger.info('Loading GAN from disk in directory: %s', model_dir) + logger.debug( + 'Active python environment versions: \n%s', + pprint.pformat(VERSION_RECORD, indent=4), ) - logger.debug(msg) fp_gen = os.path.join(model_dir, 'model_gen.pkl') fp_disc = os.path.join(model_dir, 'model_disc.pkl') @@ -637,7 +635,8 @@ def update_adversarial_weights( if update_frac != 1: logger.debug( - f'New discriminator weight: {weight_gen_advers:.4e}' + 'New discriminator weight: %.4e', + weight_gen_advers, ) return weight_gen_advers @@ -696,20 +695,15 @@ def train(self, batch_handler, config=None, **kwargs): batch_handler=batch_handler, ) - epochs = list(range(config.n_epoch)) - - if self._history is None: - self._history = pd.DataFrame(columns=['elapsed_time']) - self._history.index.name = 'epoch' - else: - epochs += self._history.index.values[-1] + 1 + epochs = range(len(self.history), len(self.history) + config.n_epoch) t0 = time.time() logger.info( - 'Training model with adversarial weight: {} ' - 'for {} epochs starting at epoch {}'.format( - config.weight_gen_advers, config.n_epoch, epochs[0] - ) + 'Training model with adversarial weight: %s for %s epochs ' + 'starting at epoch %s', + config.weight_gen_advers, + config.n_epoch, + epochs[0], ) lr_shape, hr_shape = batch_handler.shapes @@ -783,17 +777,15 @@ def train(self, batch_handler, config=None, **kwargs): extras=extras, ) logger.debug( - 'Finished training epoch in {:.4f} seconds'.format( - time.time() - t_epoch - ) + 'Finished training epoch in %.4f seconds', + time.time() - t_epoch, ) if stop: break logger.info( - 'Finished training {} 
epochs in {:.4f} seconds'.format( - config.n_epoch, - time.time() - t0, - ) + 'Finished training %s epochs in %.4f seconds', + config.n_epoch, + time.time() - t0, ) batch_handler.stop() @@ -1054,20 +1046,19 @@ def _post_batch(self, ib, b_loss_details, n_batches, previous_means): gen_loss = self._train_record['train_loss_gen'].values.mean() logger.debug( - 'Batch {} out of {} has (gen / disc) loss of: ({:.2e} / {:.2e}). ' - 'Running mean (gen / disc): ({:.2e} / {:.2e}). Trained ' - '(gen / disc): ({} / {})'.format( - ib + 1, - n_batches, - b_loss_details['loss_gen'], - b_loss_details['loss_disc'], - gen_loss, - disc_loss, - trained_gen, - trained_disc, - ) + 'Batch %s out of %s has (gen / disc) loss of: (%.2e / %.2e). ' + 'Running mean (gen / disc): (%.2e / %.2e). Trained ' + '(gen / disc): (%s / %s)', + ib + 1, + n_batches, + b_loss_details['loss_gen'], + b_loss_details['loss_disc'], + gen_loss, + disc_loss, + trained_gen, + trained_disc, ) - if all([not trained_gen, not trained_disc]): + if all((not trained_gen, not trained_disc)): msg = ( 'For some reason none of the GAN networks trained during ' 'batch {} out of {}!'.format(ib, n_batches) @@ -1160,10 +1151,13 @@ def _train_epoch( batch_load_time = total_step_time - batch_step_time logger.debug( - f'Finished batch step {ib + 1} / {len(batch_handler)} in ' - f'{total_step_time:.4f} seconds. Batch load time: ' - f'{batch_load_time:.4f} seconds. Batch train time: ' - f'{batch_step_time:.4f} seconds.' + 'Finished batch step %s / %s in %.4f seconds. Batch load ' + 'time: %.4f seconds. 
Batch train time: %.4f seconds.', + ib + 1, + len(batch_handler), + total_step_time, + batch_load_time, + batch_step_time, ) prev_time = time.time() diff --git a/sup3r/models/conditional.py b/sup3r/models/conditional.py index 6c61ef834..12cf2f812 100644 --- a/sup3r/models/conditional.py +++ b/sup3r/models/conditional.py @@ -115,7 +115,7 @@ def save(self, out_dir): self.save_params(out_dir) - logger.info('Saved model to disk in directory: {}'.format(out_dir)) + logger.info('Saved model to disk in directory: %s', out_dir) @classmethod def load(cls, model_dir, verbose=True): @@ -135,13 +135,11 @@ def load(cls, model_dir, verbose=True): Returns a pretrained gan model that was previously saved to out_dir """ if verbose: + logger.info('Loading model from disk in directory: %s', model_dir) logger.info( - 'Loading model from disk in directory: {}'.format(model_dir) + 'Active python environment versions: \n%s', + pprint.pformat(VERSION_RECORD, indent=4), ) - msg = 'Active python environment versions: \n{}'.format( - pprint.pformat(VERSION_RECORD, indent=4) - ) - logger.info(msg) fp_gen = os.path.join(model_dir, 'model_gen.pkl') params = cls.load_saved_params(model_dir, verbose=verbose) @@ -294,10 +292,10 @@ def _train_epoch(self, batch_handler, multi_gpu=False): loss_details = self._train_record.mean().to_dict() logger.debug( - 'Batch {} out of {} has epoch-average gen loss of: ' - '{:.2e}. '.format( - ib, len(batch_handler), loss_details['train_loss_gen'] - ) + 'Batch %s out of %s has epoch-average gen loss of: %.2e. 
', + ib, + len(batch_handler), + loss_details['train_loss_gen'], ) return loss_details @@ -329,19 +327,20 @@ def train(self, batch_handler, config=None, **kwargs): batch_handler=batch_handler, ) - epochs = list(range(config.n_epoch)) - if self._history is None: self._history = pd.DataFrame(columns=['elapsed_time']) self._history.index.name = 'epoch' + start_epoch = 0 else: - epochs += self._history.index.values[-1] + 1 + start_epoch = int(self._history.index.values[-1]) + 1 + + epochs = range(start_epoch, start_epoch + config.n_epoch) t0 = time.time() logger.info( - 'Training model for {} epochs starting at epoch {}'.format( - config.n_epoch, epochs[0] - ) + 'Training model for %s epochs starting at epoch %s', + config.n_epoch, + epochs[0], ) for epoch in epochs: diff --git a/sup3r/models/dc.py b/sup3r/models/dc.py index 9c21040e9..a23a91510 100644 --- a/sup3r/models/dc.py +++ b/sup3r/models/dc.py @@ -49,8 +49,11 @@ def calc_val_loss_gen(self, batch_handler, weight_gen_advers): dtype=np.float32, ) for i, batch in enumerate(batch_handler.val_data): - logger.info(f'Calculating validation loss for batch {i} / ' - f'{len(batch_handler.val_data)}...') + logger.info( + 'Calculating validation loss for batch %s / %s...', + i, + len(batch_handler.val_data), + ) hi_res_exo = self.get_hr_exo_input(batch.high_res) hi_res_gen = self._tf_generate(batch.low_res, hi_res_exo) loss, loss_details = self.calc_loss( @@ -96,20 +99,20 @@ def calc_val_loss(self, batch_handler, weight_gen_advers): s_weights /= s_weights.sum() logger.debug( - f'Previous spatial weights: {batch_handler.spatial_weights}' + 'Previous spatial weights: %s', batch_handler.spatial_weights ) logger.debug( - f'Previous temporal weights: {batch_handler.temporal_weights}' + 'Previous temporal weights: %s', batch_handler.temporal_weights ) batch_handler.update_weights( spatial_weights=s_weights, temporal_weights=t_weights ) logger.debug( - 'New spatiotemporal weights (space, time):\n' - f'{total_losses / 
total_losses.sum()}' + 'New spatiotemporal weights (space, time):\n%s', + total_losses / total_losses.sum(), ) - logger.debug(f'New spatial weights: {s_weights}') - logger.debug(f'New temporal weights: {t_weights}') + logger.debug('New spatial weights: %s', s_weights) + logger.debug('New temporal weights: %s', t_weights) loss_details['mean_val_loss_gen'] = round(np.mean(total_losses), 3) loss_details['mean_val_loss_gen_content'] = round( diff --git a/sup3r/models/linear.py b/sup3r/models/linear.py index 3316ee166..e18104f06 100644 --- a/sup3r/models/linear.py +++ b/sup3r/models/linear.py @@ -74,8 +74,7 @@ def load(cls, model_dir, verbose=False): model = cls(**kwargs) if verbose: - logger.info('Loading LinearInterp with meta data: {}' - .format(model.meta)) + logger.info('Loading LinearInterp with meta data: %s', model.meta) return model @@ -152,11 +151,14 @@ def generate(self, low_res, norm_in=False, un_norm_out=False, int(low_res.shape[2] * self._s_enhance), int(low_res.shape[3] * self._t_enhance), len(self.hr_out_features)) - logger.debug('LinearInterp model with s_enhance of {} ' - 'and t_enhance of {} ' - 'downscaling low-res shape {} to high-res shape {}' - .format(self._s_enhance, self._t_enhance, - low_res.shape, hr_shape)) + logger.debug( + 'LinearInterp model with s_enhance of %s and t_enhance of %s ' + 'downscaling low-res shape %s to high-res shape %s', + self._s_enhance, + self._t_enhance, + low_res.shape, + hr_shape, + ) hi_res = np.zeros(hr_shape, dtype=np.float32) diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py index 2e5b21dfe..598121f2e 100644 --- a/sup3r/models/multi_step.py +++ b/sup3r/models/multi_step.py @@ -243,9 +243,10 @@ def generate( try: hi_res = self._transpose_model_input(model, hi_res) logger.debug( - 'Data input to model #{} of {} has shape {}'.format( - i + 1, len(self.models), hi_res.shape - ) + 'Data input to model #%s of %s has shape %s', + i + 1, + len(self.models), + hi_res.shape, ) hi_res = 
self._match_model_input(i, hi_res, i_exo_data) @@ -257,9 +258,10 @@ def generate( exogenous_data=i_exo_data, ) logger.debug( - 'Data output from model #{} of {} has shape {}'.format( - i + 1, len(self.models), hi_res.shape - ) + 'Data output from model #%s of %s has shape %s', + i + 1, + len(self.models), + hi_res.shape, ) except Exception as e: msg = ( @@ -409,8 +411,8 @@ def generate( and/or pressure_*m """ logger.debug( - 'Data input to the 1st step spatial-only ' - 'enhancement has shape {}'.format(low_res.shape) + 'Data input to the 1st step spatial-only enhancement has shape %s', + low_res.shape, ) msg = ( @@ -660,13 +662,11 @@ def idf_wind(self): """Get an array of feature indices for the subset of features required for the spatial_wind_models. This excludes topography which is assumed to be provided as exogenous_data.""" - return np.array( - [ - self.lr_features.index(fn) - for fn in self.spatial_wind_models.lr_features - if fn != 'topography' - ] - ) + return np.array([ + self.lr_features.index(fn) + for fn in self.spatial_wind_models.lr_features + if fn != 'topography' + ]) @property def idf_wind_out(self): @@ -675,25 +675,21 @@ def idf_wind_out(self): indices of u_200m + v_200m from the output features of spatial_wind_models""" temporal_solar_features = self.temporal_solar_models.lr_features - return np.array( - [ - self.spatial_wind_models.hr_out_features.index(fn) - for fn in temporal_solar_features[1:] - ] - ) + return np.array([ + self.spatial_wind_models.hr_out_features.index(fn) + for fn in temporal_solar_features[1:] + ]) @property def idf_solar(self): """Get an array of feature indices for the subset of features required for the spatial_solar_models. 
This excludes topography which is assumed to be provided as exogenous_data.""" - return np.array( - [ - self.lr_features.index(fn) - for fn in self.spatial_solar_models.lr_features - if fn != 'topography' - ] - ) + return np.array([ + self.lr_features.index(fn) + for fn in self.spatial_solar_models.lr_features + if fn != 'topography' + ]) def generate( self, low_res, norm_in=True, un_norm_out=True, exogenous_data=None @@ -735,10 +731,9 @@ def generate( """ logger.debug( - 'Data input to the SolarMultiStepGan has shape {} which ' - 'will be split up for solar- and wind-only features.'.format( - low_res.shape - ) + 'Data input to the SolarMultiStepGan has shape %s which will be ' + 'split up for solar- and wind-only features.', + low_res.shape, ) if isinstance(exogenous_data, dict) and not isinstance( exogenous_data, ExoData @@ -780,10 +775,10 @@ def generate( raise RuntimeError(msg) from e logger.debug( - 'Data output from the 1st step spatial enhancement has ' - 'shape {} (solar) and shape {} (wind)'.format( - hi_res_solar.shape, hi_res_wind.shape - ) + 'Data output from the 1st step spatial enhancement has shape %s ' + '(solar) and shape %s (wind)', + hi_res_solar.shape, + hi_res_wind.shape, ) hi_res = (hi_res_solar, hi_res_wind[..., self.idf_wind_out]) @@ -791,15 +786,15 @@ def generate( logger.debug( 'Data output from the concatenated solar + wind 1st step ' - 'spatial-only enhancement has shape {}'.format(hi_res.shape) + 'spatial-only enhancement has shape %s', + hi_res.shape, ) hi_res = np.transpose(hi_res, axes=(1, 2, 0, 3)) hi_res = np.expand_dims(hi_res, axis=0) logger.debug( - 'Data from the concatenated solar + wind 1st step ' - 'spatial-only enhancement has been reshaped to {}'.format( - hi_res.shape - ) + 'Data from the concatenated solar + wind 1st step spatial-only ' + 'enhancement has been reshaped to %s', + hi_res.shape, ) try: @@ -820,7 +815,7 @@ def generate( hi_res = self.temporal_pad(low_res, hi_res) logger.debug( - 'Final SolarMultiStepGan 
output has shape: {}'.format(hi_res.shape) + 'Final SolarMultiStepGan output has shape: %s', hi_res.shape ) return hi_res diff --git a/sup3r/models/solar_cc.py b/sup3r/models/solar_cc.py index ebcb9a808..38370f7f8 100644 --- a/sup3r/models/solar_cc.py +++ b/sup3r/models/solar_cc.py @@ -297,7 +297,7 @@ def generate(self, low_res, **kwargs): low_res, super().generate(low_res=low_res, **kwargs) ) - logger.debug('Final SolarCC output has shape: {}'.format(hi_res.shape)) + logger.debug('Final SolarCC output has shape: %s', hi_res.shape) return hi_res diff --git a/sup3r/models/surface.py b/sup3r/models/surface.py index 2fd85bd12..64b1b9977 100644 --- a/sup3r/models/surface.py +++ b/sup3r/models/surface.py @@ -621,8 +621,10 @@ def generate( lr_topo = np.asarray(lr_topo) hr_topo = np.asarray(hr_topo) logger.debug( - 'SurfaceSpatialMetModel received low/high res topo ' - 'shapes of {} and {}'.format(lr_topo.shape, hr_topo.shape) + 'SurfaceSpatialMetModel received low/high res topo shapes of %s ' + 'and %s', + lr_topo.shape, + hr_topo.shape, ) msg = f'topo_lr needs to be 2d but has shape {lr_topo.shape}' @@ -651,10 +653,11 @@ def generate( len(self.hr_out_features), ) logger.debug( - 'SurfaceSpatialMetModel with s_enhance of {} ' - 'downscaling low-res shape {} to high-res shape {}'.format( - self._s_enhance, low_res.shape, hr_shape - ) + 'SurfaceSpatialMetModel with s_enhance of %s downscaling ' + 'low-res shape %s to high-res shape %s', + self._s_enhance, + low_res.shape, + hr_shape, ) hi_res = np.zeros(hr_shape, dtype=np.float32) diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 62e0c1a19..b542cc9d5 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -480,7 +480,9 @@ def _run_serial(cls, strategy, node_index): """ start = dt.now() - logger.debug(f'Running forward passes on node {node_index} in serial.') + logger.debug( + 'Running forward passes on node %s in serial.', node_index + ) fwp = cls(strategy, 
node_index=node_index) for i, chunk_index in enumerate(strategy.node_chunks[node_index]): now = dt.now() @@ -499,10 +501,13 @@ def _run_serial(cls, strategy, node_index): meta=fwp.meta, ) logger.debug( - 'Finished forward pass on chunk_index=' - f'{chunk_index} in {dt.now() - now}. {i + 1} of ' - f'{len(strategy.node_chunks[node_index])} ' - f'complete. {_mem_check()}.' + 'Finished forward pass on chunk_index=%s in %s. %s of %s ' + 'complete. %s.', + chunk_index, + dt.now() - now, + i + 1, + len(strategy.node_chunks[node_index]), + _mem_check(), ) if failed: msg = ( @@ -512,9 +517,9 @@ def _run_serial(cls, strategy, node_index): raise MemoryError(msg) logger.info( - 'Finished forward passes on ' - f'{len(strategy.node_chunks[node_index])} chunks in ' - f'{dt.now() - start}' + 'Finished forward passes on %s chunks in %s', + len(strategy.node_chunks[node_index]), + dt.now() - start, ) @classmethod @@ -533,8 +538,10 @@ def _run_parallel(cls, strategy, node_index): """ logger.info( - f'Running parallel forward passes on node {node_index}' - f' with pass_workers={strategy.pass_workers}.' + 'Running parallel forward passes on node %s with ' + 'pass_workers=%s.', + node_index, + strategy.pass_workers, ) futures = {} @@ -564,8 +571,9 @@ def _run_parallel(cls, strategy, node_index): } logger.info( - f'Started {len(futures)} forward pass runs in ' - f'{dt.now() - now}.' + 'Started %s forward pass runs in %s.', + len(futures), + dt.now() - now, ) try: @@ -579,24 +587,25 @@ def _run_parallel(cls, strategy, node_index): 'with constant output or NaNs.' ) raise MemoryError(msg) - msg = ( - 'Finished forward pass on chunk_index=' - f'{chunk_idx} in {dt.now() - start_time}. ' - f'{i + 1} of {len(futures)} complete. {_mem_check()}' + logger.debug( + 'Finished forward pass on chunk_index=%s in %s. %s ' + 'of %s complete. 
%s', + chunk_idx, + dt.now() - start_time, + i + 1, + len(futures), + _mem_check(), ) - logger.debug(msg) except Exception as e: - msg = ( - 'Error running forward pass on chunk_index=' - f'{futures[future]["chunk_index"]}.' - ) - logger.exception(msg) - raise RuntimeError(msg) from e + msg = 'Error running forward pass on chunk_index=%s.' + chunk_idx = futures[future]['chunk_index'] + logger.exception(msg, chunk_idx) + raise RuntimeError(msg % chunk_idx) from e logger.info( - 'Finished asynchronous forward passes on ' - f'{len(strategy.node_chunks[node_index])} chunks in ' - f'{dt.now() - start}' + 'Finished asynchronous forward passes on %s chunks in %s', + len(strategy.node_chunks[node_index]), + dt.now() - start, ) @classmethod @@ -666,8 +675,9 @@ def run_chunk( Array of high-resolution output from generator """ - msg = f'Running forward pass for chunk_index={chunk.index}.' - logger.debug(msg) + logger.debug( + 'Running forward pass for chunk_index=%s.', chunk.index + ) if model is None: model = get_model(model_class, model_kwargs) @@ -705,7 +715,9 @@ def run_chunk( failed = cls._output_check(output_data, allowed_const=allowed_const) if chunk.out_file is not None and not failed: - logger.debug(f'Saving forward pass output to {chunk.out_file}.') + logger.debug( + 'Saving forward pass output to %s.', chunk.out_file + ) output_type = get_source_type(chunk.out_file) cls.OUTPUT_HANDLER_CLASS[output_type]._write_output( data=output_data, diff --git a/sup3r/pipeline/strategy.py b/sup3r/pipeline/strategy.py index df35f6665..a4f916edc 100644 --- a/sup3r/pipeline/strategy.py +++ b/sup3r/pipeline/strategy.py @@ -491,7 +491,7 @@ def hr_lat_lon(self): lr_lat_lon = self.input_handler.lat_lon shape = tuple(d * self.s_enhance for d in lr_lat_lon.shape[:-1]) logger.debug( - f'Getting high-resolution grid for full output domain: {shape}' + 'Getting high-resolution grid for full output domain: %s', shape ) return OutputHandler.get_lat_lon(lr_lat_lon, shape) @@ -554,8 +554,9 @@ 
def prep_chunk_data(self, chunk_index=0): if self.bias_correct_kwargs != {}: logger.debug( - f'Bias correcting data for chunk_index={chunk_index}, ' - f'with shape={input_data.shape}' + 'Bias correcting data for chunk_index=%s, with shape=%s', + chunk_index, + input_data.shape, ) fun = self.timer( bias_correct_features, @@ -606,11 +607,11 @@ def init_chunk(self, chunk_index=0): 'ti_pad_slice': ti_pad_slice, } logger.debug( - 'Initializing ForwardPassChunk with: ' - f'{pprint.pformat(args_dict, indent=2)}' + 'Initializing ForwardPassChunk with: %s', + pprint.pformat(args_dict, indent=2), ) - logger.debug(f'Getting input data for chunk_index={chunk_index}.') + logger.debug('Getting input data for chunk_index=%s.', chunk_index) input_data, exo_data = self.timer( self.prep_chunk_data, log=True, call_id=chunk_index diff --git a/sup3r/postprocessing/collectors/h5.py b/sup3r/postprocessing/collectors/h5.py index 55885181a..e607446ac 100644 --- a/sup3r/postprocessing/collectors/h5.py +++ b/sup3r/postprocessing/collectors/h5.py @@ -486,12 +486,11 @@ def _write_flist_data( raise OSError(msg) from e logger.debug( - 'Finished writing "{}" for row {} and col {} to: {}'.format( - feature, - y_write_slice, - x_write_slice, - os.path.basename(out_file), - ) + 'Finished writing "%s" for row %s and col %s to: %s', + feature, + y_write_slice, + x_write_slice, + os.path.basename(out_file), ) def _collect_flist( diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index a1c676a71..2720b4f54 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -425,8 +425,9 @@ def post_init_log(self, args_dict=None): """Log additional arguments after initialization.""" if args_dict is not None: logger.info( - f'Finished initializing {self.__class__.__name__} with:\n' - f'{pprint.pformat(args_dict, indent=2)}' + 'Finished initializing %s with:\n%s', + self.__class__.__name__, + pprint.pformat(args_dict, indent=2), ) @property diff --git 
a/sup3r/preprocessing/batch_queues/utilities.py b/sup3r/preprocessing/batch_queues/utilities.py index 59f0a99e6..49b5023b5 100644 --- a/sup3r/preprocessing/batch_queues/utilities.py +++ b/sup3r/preprocessing/batch_queues/utilities.py @@ -28,16 +28,16 @@ def temporal_simple_enhancing(data, t_enhance=4, mode='constant'): 5D array with same dimensions as data with new enhanced resolution """ - if t_enhance in [None, 1]: + if t_enhance in {None, 1}: enhanced_data = data - elif t_enhance not in [None, 1] and len(data.shape) == 5: + elif t_enhance not in {None, 1} and len(data.shape) == 5: if mode == 'constant': enhancement = [1, 1, 1, t_enhance, 1] enhanced_data = zoom( data, enhancement, order=0, mode='nearest', grid_mode=True ) elif mode == 'linear': - index_t_hr = np.array(list(range(data.shape[3] * t_enhance))) + index_t_hr = np.arange(data.shape[3] * t_enhance) index_t_lr = index_t_hr[::t_enhance] enhanced_data = interp1d( index_t_lr, data, axis=3, fill_value='extrapolate' diff --git a/sup3r/preprocessing/collections/stats.py b/sup3r/preprocessing/collections/stats.py index e17f8da2e..b46ea4b70 100644 --- a/sup3r/preprocessing/collections/stats.py +++ b/sup3r/preprocessing/collections/stats.py @@ -114,7 +114,7 @@ def get_means(self, means): handlers.""" means = self._init_stats_dict(means) needed_features = set(self.features) - set(means) - if any(needed_features): + if needed_features: logger.debug('Getting means for %s.', needed_features) cmeans = [ cm * w @@ -125,6 +125,9 @@ def get_means(self, means): ] for f in needed_features: logger.debug('Computing mean for %s.', f) + # we use nansum here because the mean could be nan for a given + # container if all values are nan but there could be non-nan + # values in other containers means[f] = np.float32(np.nansum([cm[f] for cm in cmeans])) return means @@ -133,7 +136,7 @@ def get_stds(self, stds): all data handlers.""" stds = self._init_stats_dict(stds) needed_features = set(self.features) - set(stds) - if 
any(needed_features): + if needed_features: logger.debug('Getting stds for %s.', needed_features) cstds = [ w * cm**2 @@ -141,6 +144,8 @@ def get_stds(self, stds): ] for f in needed_features: logger.debug('Computing std for %s.', f) + # we use nansum here because one container could have all nans + # but there could be non-nan values in other containers stds[f] = np.float32( np.sqrt(np.nansum([cs[f] for cs in cstds])) ) diff --git a/sup3r/preprocessing/data_handlers/base.py b/sup3r/preprocessing/data_handlers/base.py index 8371276d0..1fa8949eb 100644 --- a/sup3r/preprocessing/data_handlers/base.py +++ b/sup3r/preprocessing/data_handlers/base.py @@ -224,9 +224,9 @@ def __init__( just_coords = not features if just_coords: logger.info('Rasterizing source data for coordinate-only access.') - raster_feats = load_features if any(missing_features) else [] + raster_feats = load_features if missing_features else [] self.rasterizer = self.loader = self.cache = None - if any(cached_features): + if cached_features: self.cache = Loader( file_paths=cached_files, features=load_features, diff --git a/sup3r/preprocessing/data_handlers/exo.py b/sup3r/preprocessing/data_handlers/exo.py index e3bacad24..850e6b6e1 100644 --- a/sup3r/preprocessing/data_handlers/exo.py +++ b/sup3r/preprocessing/data_handlers/exo.py @@ -188,7 +188,7 @@ def split(self, split_steps): ) for s in steps_i: s.update({'model': s['model'] - min_step}) - if any(steps_i): + if steps_i: split_dict[i][feature] = {'steps': steps_i} return [ExoData(split) for split in split_dict.values()] @@ -432,9 +432,9 @@ def get_all_step_data(self): for step in data[self.feature]['steps'] ] logger.debug( - 'Got exogenous_data of length {} with shapes: {}'.format( - len(data[self.feature]['steps']), shapes - ) + 'Got exogenous_data of length %s with shapes: %s', + len(data[self.feature]['steps']), + shapes, ) return data diff --git a/sup3r/preprocessing/data_handlers/nc_cc.py b/sup3r/preprocessing/data_handlers/nc_cc.py index 
17b6b0ddd..3c15c4edc 100644 --- a/sup3r/preprocessing/data_handlers/nc_cc.py +++ b/sup3r/preprocessing/data_handlers/nc_cc.py @@ -125,13 +125,12 @@ def run_input_checks(self): def run_wrap_checks(self, cs_ghi): """Run check on rasterized data from clearsky_ghi source.""" logger.debug( - 'Reshaped clearsky_ghi data to final shape {} to ' - 'correspond with CC daily average data over source ' - 'time_slice {} with (lat, lon) grid shape of {}'.format( - cs_ghi.shape, - self.rasterizer.time_slice, - self.rasterizer.grid_shape, - ) + 'Reshaped clearsky_ghi data to final shape %s to correspond ' + 'with CC daily average data over source time_slice %s with ' + '(lat, lon) grid shape of %s', + cs_ghi.shape, + self.rasterizer.time_slice, + self.rasterizer.grid_shape, ) msg = ( 'nsrdb clearsky GHI time dimension {} ' @@ -182,13 +181,12 @@ def get_clearsky_ghi(self): i = np.expand_dims(i, axis=1) if len(i.shape) == 1 else i logger.info( - 'Extracting clearsky_ghi data from "{}" with time slice ' - '{} and {} locations with agg factor {}.'.format( - os.path.basename(self._nsrdb_source_fp), - t_slice, - i.shape[0], - i.shape[1], - ) + 'Extracting clearsky_ghi data from "%s" with time slice %s ' + 'and %s locations with agg factor %s.', + os.path.basename(self._nsrdb_source_fp), + t_slice, + i.shape[0], + i.shape[1], ) # spatial coarsening from NSRDB to GCM diff --git a/sup3r/preprocessing/derivers/base.py b/sup3r/preprocessing/derivers/base.py index f95e7477d..f17c0e55a 100644 --- a/sup3r/preprocessing/derivers/base.py +++ b/sup3r/preprocessing/derivers/base.py @@ -69,6 +69,8 @@ def __init__( self.interp_kwargs = interp_kwargs features = parse_to_list(data=data, features=features) new_features = [f for f in features if f not in self.data] + # Persist derived dependencies back onto self.data so downstream + # derivations can reuse them instead of recomputing the same feature. 
for f in new_features: self.data[f] = self.derive(f) logger.info('Finished deriving %s.', f) @@ -122,6 +124,9 @@ def check_registry( if hasattr(method, 'inputs'): inputs = self._get_inputs(feature, method) missing = [f for f in inputs if f not in self.data] + # Recursive derivation is only safe when the dependency chain does + # not loop back to the requested feature, unless interpolation can + # break that cycle from nearby levels. can_derive = all( (self.no_overlap(m) or self.has_interp_variables(m)) for m in missing @@ -164,6 +169,8 @@ def map_new_name(self, feature, pattern): name.""" fstruct = parse_feature(feature) pstruct = parse_feature(pattern) + # Registry aliases can point from one naming convention to another, + # including wildcard patterns like basename_*m. if '*' not in pattern: new_feature = pattern elif fstruct.height is not None: @@ -228,6 +235,8 @@ def derive( if compute_check is not None: return compute_check + # If no direct compute method exists, fall back to vertical + # interpolation from nearby heights or pressure levels. if self.has_interp_variables(feature): logger.debug( 'Attempting level interpolation for "%s"', feature @@ -279,6 +288,8 @@ def get_single_level_data(self, feature): if len(var_list) > 0: var_array = da.stack(var_list, axis=-1) + # Broadcast the scalar levels to match the stacked data shape so + # interpolation can treat data values and their levels uniformly. sl_shape = (*var_array.shape[:-1], len(lev_list)) lev_array = da.broadcast_to(da.from_array(lev_list), sl_shape) @@ -373,6 +384,8 @@ def do_level_interpolation( var_array = sl_var lev_array = sl_levs elif ml_var is not None and sl_var is not None: + # Prefer using every available level by combining explicit + # single-level fields with multi-level arrays before interpolation. 
var_array = np.concatenate([ml_var, sl_var], axis=-1) lev_array = np.concatenate([ml_levs, sl_levs], axis=-1) else: @@ -499,7 +512,7 @@ def __init__( elif np.isnan(self.data.as_array()).any(): logger.info( - f'Filling nan values with nan_method_kwargs=' - f'{nan_method_kwargs}' + 'Filling nan values with nan_method_kwargs=%s.', + nan_method_kwargs, ) self.data = self.data.interpolate_na(**nan_method_kwargs) diff --git a/sup3r/preprocessing/rasterizers/base.py b/sup3r/preprocessing/rasterizers/base.py index 3891c1b9d..8f626f9b8 100644 --- a/sup3r/preprocessing/rasterizers/base.py +++ b/sup3r/preprocessing/rasterizers/base.py @@ -213,16 +213,36 @@ def get_closest_row_col(self, lat_lon, target): lat_lon[..., 0] - target[0], lat_lon[..., 1] - target[1] ) row, col = np.unravel_index(np.argmin(dist, axis=None), dist.shape) - msg = ( - 'The distance between the closest coordinate: ' - f'{np.asarray(lat_lon[row, col])} and the requested ' - f'target: {np.asarray(target)} is {np.asarray(dist.min())}. ' - ) + closest_coord = np.asarray(lat_lon[row, col]) + requested_target = np.asarray(target) + min_dist = np.asarray(dist.min()) if self.threshold is not None and dist.min() > self.threshold: - add_msg = f'This exceeds the given threshold: {self.threshold}' - logger.error(f'{msg} {add_msg}') - raise RuntimeError(f'{msg} {add_msg}') - logger.debug(msg) + logger.error( + 'The distance between the closest coordinate: %s and the ' + 'requested target: %s is %s. This exceeds the given ' + 'threshold: %s', + closest_coord, + requested_target, + min_dist, + self.threshold, + ) + raise RuntimeError( + 'The distance between the closest coordinate: {} and the ' + 'requested target: {} is {}. 
This exceeds the given ' + 'threshold: {}'.format( + closest_coord, + requested_target, + min_dist, + self.threshold, + ) + ) + logger.debug( + 'The distance between the closest coordinate: %s and the ' + 'requested target: %s is %s.', + closest_coord, + requested_target, + min_dist, + ) return row, col def get_lat_lon(self): diff --git a/sup3r/preprocessing/rasterizers/dual.py b/sup3r/preprocessing/rasterizers/dual.py index b8a8b92ec..aeb071e9d 100644 --- a/sup3r/preprocessing/rasterizers/dual.py +++ b/sup3r/preprocessing/rasterizers/dual.py @@ -178,8 +178,8 @@ def update_hr_data(self): ], } logger.debug( - 'Updating self.data.high_res with new shape: ' - f'{self.hr_required_shape[:3]}' + 'Updating self.data.high_res with new shape: %s', + self.hr_required_shape[:3], ) self.data.high_res = self.data.high_res.update_ds({ **hr_coords_new, @@ -242,7 +242,7 @@ def check_regridded_lr_data(self): logger.error(msg) raise ValueError(msg) - if any(fill_feats): + if fill_feats: msg = ( 'Doing nearest neighbor nan fill on low_res data for ' f'features = {fill_feats}' diff --git a/sup3r/preprocessing/rasterizers/exo.py b/sup3r/preprocessing/rasterizers/exo.py index 9c7c1b242..4f261b4ac 100644 --- a/sup3r/preprocessing/rasterizers/exo.py +++ b/sup3r/preprocessing/rasterizers/exo.py @@ -610,7 +610,9 @@ def __new__(cls, feature, file_paths, source_files=None, **kwargs): **kwargs, } logger.debug( - f'Using {ExoClass.__name__} to rasterize feature "{feature}"' + 'Using %s to rasterize feature "%s"', + ExoClass.__name__, + feature, ) return ExoClass(**kwargs) diff --git a/sup3r/preprocessing/samplers/base.py b/sup3r/preprocessing/samplers/base.py index b99c76a2a..06a91e19b 100644 --- a/sup3r/preprocessing/samplers/base.py +++ b/sup3r/preprocessing/samplers/base.py @@ -310,9 +310,8 @@ def sample_shape(self, sample_shape): self._sample_shape = sample_shape if len(self._sample_shape) == 2: logger.info( - 'Found 2D sample shape of {}. 
Adding temporal dim of 1'.format( - self._sample_shape - ) + 'Found 2D sample shape of %s. Adding temporal dim of 1', + self._sample_shape, ) self._sample_shape = (*self._sample_shape, 1) diff --git a/sup3r/preprocessing/utilities.py b/sup3r/preprocessing/utilities.py index 565da6034..3fbef9ce3 100644 --- a/sup3r/preprocessing/utilities.py +++ b/sup3r/preprocessing/utilities.py @@ -57,9 +57,10 @@ def get_input_handler_class(input_handler_name: Optional[str] = None): input_handler_name = 'DataHandler' logger.debug( - '"input_handler_name" arg was not provided. Using ' - f'"{input_handler_name}". If this is incorrect, please provide ' - 'input_handler_name="DataHandlerName".' + '"input_handler_name" arg was not provided. Using "%s". If ' + 'this is incorrect, please provide ' + 'input_handler_name="DataHandlerName".', + input_handler_name, ) HandlerClass = ( @@ -340,7 +341,7 @@ def get_source_type(file_paths): if isinstance(file_paths, str) and '*' in file_paths: temp = glob(file_paths) - if any(temp): + if temp: file_paths = temp if not isinstance(file_paths, list): diff --git a/sup3r/qa/qa.py b/sup3r/qa/qa.py index e1fa302b9..ecf973983 100644 --- a/sup3r/qa/qa.py +++ b/sup3r/qa/qa.py @@ -314,7 +314,7 @@ def get_dset_out(self, name): array of shape (spatial_1, spatial_2, temporal) """ - logger.debug('Getting sup3r output dataset "{}"'.format(name)) + logger.debug('Getting sup3r output dataset "%s"', name) data = self.output_handler[name] if self.output_type == 'nc': data = data.values @@ -356,9 +356,12 @@ def coarsen_data(self, idf, feature, data): ) logger.info( - f'Coarsening feature "{feature}" with {self.s_enhance}x ' - f'spatial averaging and "{t_meth}" {self.t_enhance}x ' - 'temporal averaging' + 'Coarsening feature "%s" with %sx spatial averaging and "%s" ' + '%sx temporal averaging', + feature, + self.s_enhance, + t_meth, + self.t_enhance, ) data = spatial_coarsening( @@ -434,7 +437,7 @@ def export(self, qa_fp, data, dset_name, dset_suffix=''): """ if not 
os.path.exists(qa_fp): - logger.info('Initializing qa output file: "{}"'.format(qa_fp)) + logger.info('Initializing qa output file: "%s"', qa_fp) with RexOutputs(qa_fp, mode='w') as f: f.meta = self.input_handler.meta f.time_index = self.input_handler.time_index @@ -452,7 +455,7 @@ def export(self, qa_fp, data, dset_name, dset_suffix=''): if dset_suffix: dset_name = dset_name + '_' + dset_suffix - logger.info('Adding dataset "{}" to output file.'.format(dset_name)) + logger.info('Adding dataset "%s" to output file.', dset_name) # transpose and flatten to typical h5 (time, space) dimensions data = np.transpose(np.asarray(data), axes=(2, 0, 1)).reshape(shape) @@ -482,10 +485,12 @@ def run(self): ziter = zip(self.features, self.source_features, self.output_names) for idf, (feature, source_feature, dset_out) in enumerate(ziter): logger.info( - 'Running QA on dataset {} of {} for feature "{}" ' - 'with source feature name "{}"'.format( - idf + 1, len(self.features), feature, source_feature, - ) + 'Running QA on dataset %s of %s for feature "%s" with ' + 'source feature name "%s"', + idf + 1, + len(self.features), + feature, + source_feature, ) data_syn = self.get_dset_out(feature) data_syn = self.coarsen_data(idf, feature, data_syn) diff --git a/sup3r/qa/utilities.py b/sup3r/qa/utilities.py index 556ec3583..edb10f105 100644 --- a/sup3r/qa/utilities.py +++ b/sup3r/qa/utilities.py @@ -373,7 +373,7 @@ def continuous_dist(diffs, bins=None, range=None, interpolate=False): dx = dx[dx > 0] dx = np.mean(dx) bins = int((np.max(diffs) - np.min(diffs)) / dx) - logger.debug(f'Using n_bins={bins} to compute distribution') + logger.debug('Using n_bins=%s to compute distribution', bins) counts, edges = np.histogram(diffs, bins=bins, range=range) centers = edges[:-1] + (np.diff(edges) / 2) if interpolate: diff --git a/sup3r/solar/solar.py b/sup3r/solar/solar.py index e8c200cdb..6b8504f96 100644 --- a/sup3r/solar/solar.py +++ b/sup3r/solar/solar.py @@ -91,19 +91,16 @@ def __init__( 
self._sup3r_fps = [self._sup3r_fps] logger.debug( - 'Initializing solar module with sup3r files: {}'.format( - [os.path.basename(fp) for fp in self._sup3r_fps] - ) + 'Initializing solar module with sup3r files: %s', + [os.path.basename(fp) for fp in self._sup3r_fps], ) logger.debug( - 'Initializing solar module with temporal slice: {}'.format( - self.t_slice - ) + 'Initializing solar module with temporal slice: %s', + self.t_slice, ) logger.debug( - 'Initializing solar module with NSRDB source fp: {}'.format( - self._nsrdb_fp - ) + 'Initializing solar module with NSRDB source fp: %s', + self._nsrdb_fp, ) self.gan_data = MultiTimeResource(self._sup3r_fps) @@ -247,11 +244,9 @@ def nsrdb_tslice(self): self._nsrdb_tslice = slice(t0, t1, step) logger.debug( - 'Found nsrdb_tslice {} with corresponding ' - 'time index:\n\t{}'.format( - self._nsrdb_tslice, - self.nsrdb.time_index[self._nsrdb_tslice], - ) + 'Found nsrdb_tslice %s with corresponding time index:\n\t%s', + self._nsrdb_tslice, + self.nsrdb.time_index[self._nsrdb_tslice], ) return self._nsrdb_tslice @@ -384,7 +379,7 @@ def get_nsrdb_data(self, dset): the sites is an average across multiple NSRDB sites. 
""" - logger.debug('Retrieving "{}" from NSRDB source data.'.format(dset)) + logger.debug('Retrieving "%s" from NSRDB source data.', dset) out = None for idx in range(self.idnn.shape[1]): @@ -573,12 +568,12 @@ def write(self, fp_out, features=('ghi', 'dni', 'dhi')): attrs=attrs, chunks=attrs['chunks'], ) - logger.info(f'Added "{feature}" to output file.') + logger.info('Added "%s" to output file.', feature) run_attrs = self.gan_data.h5[self._sup3r_fps[0]].global_attrs run_attrs['nsrdb_source'] = self._nsrdb_fp fh.run_attrs = run_attrs - logger.debug(f'Finished writing file: {fp_out}') + logger.debug('Finished writing file: %s', fp_out) @classmethod def run_temporal_chunks( @@ -709,9 +704,9 @@ def _run_temporal_chunk( else: logger.info( - 'Running temporal index {} out of {}.'.format( - i + 1, len(fp_sets) - ) + 'Running temporal index %s out of %s.', + i + 1, + len(fp_sets), ) kwargs = { diff --git a/sup3r/solar/solar_cli.py b/sup3r/solar/solar_cli.py index c78b319ca..f637b00c5 100644 --- a/sup3r/solar/solar_cli.py +++ b/sup3r/solar/solar_cli.py @@ -60,14 +60,13 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None): max_nodes = config.get('max_nodes', len(temporal_ids)) max_nodes = min((max_nodes, len(temporal_ids))) logger.info( - 'Solar module found {} sets of chunked source files to run ' - 'on. Submitting to {} nodes based on the number of temporal ' - 'chunks {} and the requested number of nodes {}'.format( - len(fp_sets), - max_nodes, - len(temporal_ids), - config.get('max_nodes', None), - ) + 'Solar module found %s sets of chunked source files to run on. 
' + 'Submitting to %s nodes based on the number of temporal chunks %s ' + 'and the requested number of nodes %s', + len(fp_sets), + max_nodes, + len(temporal_ids), + config.get('max_nodes', None), ) temporal_id_chunks = np.array_split(temporal_ids, max_nodes) diff --git a/sup3r/utilities/cli.py b/sup3r/utilities/cli.py index f61cc7535..2f8177556 100644 --- a/sup3r/utilities/cli.py +++ b/sup3r/utilities/cli.py @@ -128,7 +128,7 @@ def from_config_preflight(cls, module_name, ctx, config_file, verbose): log_pattern = config.get('log_pattern', None) config_verbose = config.get('log_level', 'INFO') config_verbose = config_verbose == 'DEBUG' - verbose = any([verbose, config_verbose, ctx.obj['VERBOSE']]) + verbose = any((verbose, config_verbose, ctx.obj['VERBOSE'])) exec_kwargs = config.get('execution_control', {}) hardware_option = exec_kwargs.get('option', 'local') @@ -150,8 +150,8 @@ def from_config_preflight(cls, module_name, ctx, config_file, verbose): log_pattern = log_pattern.replace('.log', '_{node_index}.log') exec_kwargs['stdout_path'] = os.path.join(status_dir, 'stdout/') - logger.debug('Found execution kwargs: {}'.format(exec_kwargs)) - logger.debug('Hardware run option: "{}"'.format(hardware_option)) + logger.debug('Found execution kwargs: %s', exec_kwargs) + logger.debug('Hardware run option: "%s"', hardware_option) name = f'sup3r_{module_name.replace("-", "_")}' name += '_{}'.format(os.path.basename(status_dir)) @@ -249,8 +249,9 @@ def kickoff_slurm_job( if pipeline_step != module_name: job_info = f"{job_info} (pipeline step {pipeline_step!r})" logger.info( - f'Running sup3r {job_info} on SLURM with node ' - f'name "{name}".' 
+ 'Running sup3r %s on SLURM with node name "%s".', + job_info, + name, ) out = slurm_manager.sbatch( cmd, @@ -327,8 +328,9 @@ def kickoff_local_job(cls, module_name, ctx, cmd, pipeline_step=None): if pipeline_step != module_name: job_info = f"{job_info} (pipeline step {pipeline_step!r})" logger.info( - f'Running sup3r {job_info} locally with job ' - f'name "{name}".' + 'Running sup3r %s locally with job name "%s".', + job_info, + name, ) Status.mark_job_as_submitted( out_dir, @@ -381,6 +383,6 @@ def add_status_cmd(cls, config, pipeline_step, cmd): cmd += f"Status.make_single_job_file({status_file_arg_str})" cmd_log = '\n\t'.join(cmd.split('\n')) - logger.debug(f'Running command:\n\t{cmd_log}') + logger.debug('Running command:\n\t%s', cmd_log) return cmd diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index fe900cc04..947ae08b2 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -345,11 +345,12 @@ def download_file( os.remove(out_file) if not cls._can_skip_file(out_file) or overwrite: - msg = ( - f'Downloading {variables} to {out_file} with levels ' - f'= {levels}.' + logger.info( + 'Downloading %s to %s with levels = %s.', + variables, + out_file, + levels, ) - logger.info(msg) dataset = f'reanalysis-era5-{level_type}-levels' if 'monthly' in product_type: dataset += '-monthly-means' @@ -370,7 +371,7 @@ def download_file( cds_api_client.retrieve(dataset, entry, out_file) else: - logger.info(f'File already exists: {out_file}.') + logger.info('File already exists: %s.', out_file) def process_surface_file(self): """Rename variables and convert geopotential to geopotential height.""" @@ -393,8 +394,10 @@ def process_surface_file(self): ds.compute().to_netcdf(tmp_file, format='NETCDF4', engine='h5netcdf') os.replace(tmp_file, self.surface_file) logger.info( - f'Finished processing {self.surface_file}. Moved {tmp_file} to ' - f'{self.surface_file}.' + 'Finished processing %s. 
Moved %s to %s.', + self.surface_file, + tmp_file, + self.surface_file, ) def add_pressure(self, ds): @@ -466,8 +469,10 @@ def process_level_file(self): ds.compute().to_netcdf(tmp_file, format='NETCDF4', engine='h5netcdf') os.replace(tmp_file, self.level_file) logger.info( - f'Finished processing {self.level_file}. Moved ' - f'{tmp_file} to {self.level_file}.' + 'Finished processing %s. Moved %s to %s.', + self.level_file, + tmp_file, + self.level_file, ) def process_and_combine(self): @@ -483,11 +488,11 @@ def process_and_combine(self): if not self._can_skip_file(self.monthly_file) or self.overwrite: files = [] if os.path.exists(self.level_file): - logger.info(f'Processing {self.level_file}.') + logger.info('Processing %s.', self.level_file) self.process_level_file() files.append(self.level_file) if os.path.exists(self.surface_file): - logger.info(f'Processing {self.surface_file}.') + logger.info('Processing %s.', self.surface_file) self.process_surface_file() files.append(self.surface_file) @@ -499,7 +504,7 @@ def process_and_combine(self): if os.path.exists(self.surface_file): os.remove(self.surface_file) else: - logger.info(f'{self.monthly_file} already exists.') + logger.info('%s already exists.', self.monthly_file) def get_monthly_file(self): """Download level and surface files, process variables, and combine @@ -665,10 +670,10 @@ def run_for_var( ), msg tasks = [] - months = list(range(1, 13)) if months is None else months + months = range(1, 13) if months is None else months if days is None: days = [ - list(np.arange(1, monthrange(year, month)[1] + 1)) + range(1, monthrange(year, month)[1] + 1) for month in months ] days = [[str(day).zfill(2) for day in d] for d in days] @@ -871,7 +876,7 @@ def _can_skip_file(cls, file): def _combine_files(cls, files, out_file, chunks='auto', res_kwargs=None): if not os.path.exists(out_file): os.makedirs(os.path.dirname(out_file), exist_ok=True) - logger.info(f'Combining {files} into {out_file}.') + logger.info('Combining 
%s into %s.', files, out_file) try: res_kwargs = res_kwargs or {} loader = Loader(files, res_kwargs=res_kwargs) @@ -885,11 +890,11 @@ def _combine_files(cls, files, out_file, chunks='auto', res_kwargs=None): chunks=chunks, ) except Exception as e: - msg = f'Error combining {files}. {e}' - logger.error(msg) - raise RuntimeError(msg) from e + msg = 'Error combining %s. %s' + logger.error(msg, files, e) + raise RuntimeError(msg % (files, e)) from e else: - logger.info(f'{out_file} already exists.') + logger.info('%s already exists.', out_file) @classmethod def make_yearly_file( diff --git a/sup3r/utilities/interpolation.py b/sup3r/utilities/interpolation.py index af8113388..580b40660 100644 --- a/sup3r/utilities/interpolation.py +++ b/sup3r/utilities/interpolation.py @@ -183,7 +183,7 @@ def _check_lev_array(cls, lev_array, levels): raise RuntimeError(msg) nans = np.isnan(lev_array) - logger.debug('Level array shape: {}'.format(lev_array.shape)) + logger.debug('Level array shape: %s', lev_array.shape) lowest_height = np.min(lev_array, axis=-1) highest_height = np.max(lev_array, axis=-1) diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index cc807f8a5..23107d44c 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -46,7 +46,7 @@ def get_tmp_file(file): tmp_file = file + '.tmp' if os.path.exists(tmp_file): logger.warning( - f'Temporary file {tmp_file} already exists. Removing...' + 'Temporary file %s already exists. 
Removing...', tmp_file ) os.remove(tmp_file) return tmp_file diff --git a/sup3r/writers/base.py b/sup3r/writers/base.py index 8f7f1d050..783b1dc42 100644 --- a/sup3r/writers/base.py +++ b/sup3r/writers/base.py @@ -75,10 +75,11 @@ def _init_h5(out_file, time_index, meta, global_attrs): """ with RexOutputs(out_file, mode='w-') as f: - logger.info('Initializing output file: {}'.format(out_file)) + logger.info('Initializing output file: %s', out_file) logger.debug( - 'Initializing output file with shape {} ' - 'and meta data:\n{}'.format((len(time_index), len(meta)), meta) + 'Initializing output file with shape %s and meta data:\n%s', + (len(time_index), len(meta)), + meta, ) f.time_index = time_index f.meta = meta @@ -102,8 +103,10 @@ def _ensure_dset_in_output(cls, out_file, dset, data=None): if dset not in f.dsets: attrs, dtype = get_dset_attrs(dset) logger.debug( - 'Initializing dataset "{}" with shape {} and ' - 'dtype {}'.format(dset, f.shape, dtype) + 'Initializing dataset "%s" with shape %s and dtype %s', + dset, + f.shape, + dtype, ) f._create_dset( dset, @@ -160,11 +163,11 @@ def write_data( fh.run_attrs = attrs os.replace(tmp_file, out_file) - msg = ( - 'Saved output of size ' - f'{(len(data_list), *data_list[0].shape)} to: {out_file}' + logger.info( + 'Saved output of size %s to: %s', + (len(data_list), *data_list[0].shape), + out_file, ) - logger.info(msg) class RexOutputs(BaseRexOutputs): @@ -530,7 +533,7 @@ def get_times(low_res_times, shape): """ logger.debug('Getting high resolution time indices') logger.debug( - f'Low res times: {low_res_times[0]} to {low_res_times[-1]}' + 'Low res times: %s to %s', low_res_times[0], low_res_times[-1] ) t_enhance = int(shape / len(low_res_times)) @@ -547,7 +550,7 @@ def get_times(low_res_times, shape): leap_mask = (times.month == 2) & (times.day == 29) times = times[~leap_mask] - logger.debug(f'High res times: {times[0]} to {times[-1]}') + logger.debug('High res times: %s to %s', times[0], times[-1]) assert len(times) 
== shape, ( f'High res times length {len(times)} does not match expected ' f'shape {shape}' diff --git a/sup3r/writers/cachers.py b/sup3r/writers/cachers.py index 976dfa855..9c15e0d1c 100644 --- a/sup3r/writers/cachers.py +++ b/sup3r/writers/cachers.py @@ -104,8 +104,9 @@ def _write_single( """Write single NETCDF or H5 cache file.""" if os.path.exists(out_file) and not overwrite: logger.info( - f'{out_file} already exists. Delete or specify overwrite=True ' - 'if you want to overwrite.' + '%s already exists. Delete or specify overwrite=True if you ' + 'want to overwrite.', + out_file, ) return if features == 'all': @@ -192,15 +193,16 @@ def cache_data( cache_kwargs={'cache_pattern': cache_pattern}, ) - if any(cached_files) and not overwrite: + if cached_files and not overwrite: logger.info( - f'Cache files with pattern {cache_pattern} already exist. ' - 'Delete or specify overwrite=True to overwrite.' + 'Cache files with pattern %s already exist. Delete or specify ' + 'overwrite=True to overwrite.', + cache_pattern, ) - elif any(cached_files) and overwrite: + elif cached_files and overwrite: missing_files += cached_files - if any(missing_files): + if missing_files: logger.info('Caching %s to %s', missing_features, missing_files) for feature, out_file in zip(missing_features, missing_files): self._write_single( diff --git a/sup3r/writers/h5.py b/sup3r/writers/h5.py index 951dfaad8..a61c2729e 100644 --- a/sup3r/writers/h5.py +++ b/sup3r/writers/h5.py @@ -79,10 +79,10 @@ def _write_output( """ if os.path.exists(out_file): if overwrite: - logger.warning(f'Overwriting existing file at {out_file}.') + logger.warning('Overwriting existing file at %s.', out_file) else: logger.info( - f'File already exists at {out_file}. Skipping write.' + 'File already exists at %s. 
Skipping write.', out_file ) return logger.info( diff --git a/sup3r/writers/utilities.py b/sup3r/writers/utilities.py index 99410a321..ee4263132 100644 --- a/sup3r/writers/utilities.py +++ b/sup3r/writers/utilities.py @@ -31,7 +31,7 @@ def _check_for_cache(features, cache_kwargs): cache_pattern.format(feature=f) for f in missing_features ] - if any(cached_files): + if cached_files: logger.debug( 'Found cache files for %s with file pattern: %s', cached_features, cache_pattern diff --git a/tests/utilities/test_era_downloader.py b/tests/utilities/test_era_downloader.py index 9452b4b0c..01b23a688 100644 --- a/tests/utilities/test_era_downloader.py +++ b/tests/utilities/test_era_downloader.py @@ -116,6 +116,7 @@ def test_era_dl_year(tmpdir_factory): yearly_file_pattern=yearly_file_pattern, max_workers=1, combine_all_files=True, + res_kwargs={'compat': 'no_conflicts'}, ) combined_file = yearly_file_pattern.replace('_{var}_', '').format( From 7de4f58514abdc69376e820e73924b4af30a0417 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 11 May 2026 08:54:30 -0600 Subject: [PATCH 25/34] fix: update logging format in AbstractSingleModel to use % formatting for consistency --- sup3r/models/abstract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 97d008802..abfdb550d 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -839,7 +839,7 @@ def log_loss_details(loss_details, level='INFO'): Log level (e.g. 
INFO, DEBUG) """ for k, v in sorted(loss_details.items()): - msg_format = '\t{}: {}' if isinstance(v, str) else '\t{}: {:.2e}' + msg_format = '\t%s: %s' if isinstance(v, str) else '\t%s: %.2e' if level.lower() == 'info': logger.info(msg_format, k, v) else: From 16abe41b9870b35dc31baa2b9ff1148492ac04df Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 12 May 2026 13:00:33 -0600 Subject: [PATCH 26/34] fix: simplify loss_details dictionary comprehension to improve readability --- sup3r/models/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 3090f70fd..01c6ba127 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -885,9 +885,7 @@ def calc_loss( loss_details['loss_disc'] = loss_disc loss_details['loss_gen'] = loss_gen loss_details['loss_gen_content'] = loss_gen_content - loss_details = { - k: float(v) for k, v in loss_details.items() if v is not None - } + loss_details = {k: v for k, v in loss_details.items() if v is not None} return loss, loss_details def calc_val_loss(self, batch_handler, weight_gen_advers): From 405ecb81a346dc1594de98d437cd652b71f6d40a Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 13 May 2026 16:41:53 -0600 Subject: [PATCH 27/34] fix: update batch size calculation in AbstractSingleModel for compatibility with TensorFlow and modify hr_exo_features in LinearInterp and SurfaceSpatialMetModel for clarity --- sup3r/models/abstract.py | 2 +- sup3r/models/linear.py | 4 ++-- sup3r/models/surface.py | 5 +++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index abfdb550d..ddc0f4471 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -1030,7 +1030,7 @@ def _run_mirrored_grad( raise RuntimeError(msg) num_replicas = self.strategy.num_replicas_in_sync - batch_size = low_res.shape[0] + batch_size = tf.shape(low_res) [0] if batch_size % num_replicas != 0: msg = ( 'Batch size must be 
divisible by the number of mirrored ' diff --git a/sup3r/models/linear.py b/sup3r/models/linear.py index e18104f06..e96e7b7cf 100644 --- a/sup3r/models/linear.py +++ b/sup3r/models/linear.py @@ -105,8 +105,8 @@ def hr_out_features(self): @property def hr_exo_features(self): - """Returns topography for LinearInterp model""" - return ['topography'] + """Returns empty list for LinearInterp model""" + return [] def save(self, out_dir): """ diff --git a/sup3r/models/surface.py b/sup3r/models/surface.py index 64b1b9977..27a178f93 100644 --- a/sup3r/models/surface.py +++ b/sup3r/models/surface.py @@ -209,6 +209,11 @@ def feature_inds_other(self): ] return inds + @property + def hr_exo_features(self): + """Returns topography for Surface model""" + return ["topography"] + def _get_temp_rh_ind(self, idf_rh): """Get the feature index value for the temperature feature corresponding to a relative humidity feature at the same hub height. From 25a5d4a51a5ff9ed519110136b89c8c677ea6860 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 13 May 2026 16:48:16 -0600 Subject: [PATCH 28/34] fix: clarify hr_exo_features docstring in SurfaceSpatialMetModel for better understanding of exogenous data handling --- sup3r/models/surface.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sup3r/models/surface.py b/sup3r/models/surface.py index 27a178f93..d0acfee7e 100644 --- a/sup3r/models/surface.py +++ b/sup3r/models/surface.py @@ -203,16 +203,15 @@ def feature_inds_other(self): + self.feature_inds_rh ) inds = [ - i - for i, name in enumerate(self._lr_features) - if i not in finds_tprh + i for i in range(len(self._lr_features)) if i not in finds_tprh ] return inds @property def hr_exo_features(self): - """Returns topography for Surface model""" - return ["topography"] + """Returns topography for surface model so the inference machinery + knows to pass in the exogenous data""" + return ['topography'] def _get_temp_rh_ind(self, idf_rh): """Get the feature index value for 
the temperature feature From baeaa8df3556270b04725a6cf91349110310380a Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 13 May 2026 17:12:28 -0600 Subject: [PATCH 29/34] fix: update batch size calculation in AbstractSingleModel to use low_res.shape for accuracy --- sup3r/models/abstract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index ddc0f4471..abfdb550d 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -1030,7 +1030,7 @@ def _run_mirrored_grad( raise RuntimeError(msg) num_replicas = self.strategy.num_replicas_in_sync - batch_size = tf.shape(low_res) [0] + batch_size = low_res.shape[0] if batch_size % num_replicas != 0: msg = ( 'Batch size must be divisible by the number of mirrored ' From bcf10f39bf2a81dad4c02543ebbc778429af7f2c Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 13 May 2026 19:57:47 -0600 Subject: [PATCH 30/34] fix: wrap apply_fn call in strategy.scope for multi-GPU compatibility --- sup3r/models/abstract.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index abfdb550d..27c00a6dd 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -1068,7 +1068,8 @@ def _run_mirrored_grad( ) for key, value in per_replica_details.items() } - apply_fn(total_grad) + with self.strategy.scope(): + apply_fn(total_grad) return mean_loss_details @tf.function(reduce_retracing=True) From 54d5591637678ca11102f35c9ebc9442e5cc1f38 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 14 May 2026 07:35:20 -0600 Subject: [PATCH 31/34] fix: streamline sample_batch docstring for clarity --- sup3r/preprocessing/batch_queues/abstract.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index cceddf123..7dfa37d49 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ 
b/sup3r/preprocessing/batch_queues/abstract.py @@ -325,13 +325,7 @@ def get_random_container(self): def sample_batch(self): """Get random sampler from collection and return a batch of samples - from that sampler. - - Notes - ----- - These samples are wrapped in an ``np.asarray`` call, so they have been - loaded into memory. - """ + from that sampler.""" return next(self.get_random_container()) def log_queue_info(self): From cda59172e4ba27609b6e1cf0535c2a8f09c0a3ca Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 14 May 2026 13:50:22 -0600 Subject: [PATCH 32/34] fix: correct reference to optimizer in update_optimizer_gen method --- sup3r/models/abstract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 27c00a6dd..679a331dc 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -355,7 +355,7 @@ def update_optimizer_gen(self, **kwargs): conf = self._optimizer_config.copy() conf.update(**kwargs) self._optimizer_config = conf - if self.optimizer is not None: + if self._optimizer is not None: self._optimizer = self.optimizer.__class__.from_config(conf) @property From 82348b309aa01dde9141cd479a2dc475b691d814 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 14 May 2026 13:55:10 -0600 Subject: [PATCH 33/34] fix: correct optimizer reference in Sup3rGan class for consistency --- sup3r/models/abstract.py | 2 +- sup3r/models/base.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 679a331dc..78c2dadbd 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -356,7 +356,7 @@ def update_optimizer_gen(self, **kwargs): conf.update(**kwargs) self._optimizer_config = conf if self._optimizer is not None: - self._optimizer = self.optimizer.__class__.from_config(conf) + self._optimizer = self._optimizer.__class__.from_config(conf) @property def history(self): diff --git a/sup3r/models/base.py 
b/sup3r/models/base.py index 01c6ba127..15325d8cc 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -394,8 +394,8 @@ def update_optimizer_disc(self, **kwargs): conf = self._optimizer_disc_config.copy() conf.update(**kwargs) self._optimizer_disc_config = conf - if self.optimizer_disc is not None: - self._optimizer_disc = self.optimizer_disc.__class__.from_config( + if self._optimizer_disc is not None: + self._optimizer_disc = self._optimizer_disc.__class__.from_config( conf ) @@ -677,7 +677,7 @@ def train(self, batch_handler, config=None, **kwargs): self._optimizer = None self._optimizer_disc = None - if self.optimizer is None or self.optimizer_disc is None: + if self._optimizer is None or self._optimizer_disc is None: with self._training_scope(): self._optimizer = self.init_optimizer( self._optimizer_config, learning_rate=None From b665ce14faaf78074c29c9058e3ef1938e8b316c Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 14 May 2026 15:44:43 -0600 Subject: [PATCH 34/34] fix: remove unnecessary device context check in _training_scope method --- sup3r/models/abstract.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 78c2dadbd..f591aff67 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -154,8 +154,6 @@ def configure_multi_gpu(self, multi_gpu=False): def _training_scope(self, device=None): """Get a strategy scope or a concrete device context.""" if tf.distribute.get_replica_context() is not None: - if device is not None: - return tf.device(device) return nullcontext() if self.strategy is not None: @@ -1068,8 +1066,7 @@ def _run_mirrored_grad( ) for key, value in per_replica_details.items() } - with self.strategy.scope(): - apply_fn(total_grad) + apply_fn(total_grad) return mean_loss_details @tf.function(reduce_retracing=True)