diff --git a/.circleci/config.yml b/.circleci/config.yml index 149f4e319..fd436ad3e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -101,6 +101,7 @@ commands: enum: - "test_main" - "test_savepoint" + - "test_driver_checkpoint" - "savepoint_tests" - "savepoint_tests_mpi" - "physics_savepoint_tests" @@ -459,6 +460,20 @@ jobs: target: test_savepoint num_ranks: 6 + test_driver_checkpoint: + machine: + image: ubuntu-2004:202111-02 + resource_class: large + environment: + GOOGLE_APPLICATION_CREDENTIALS: /tmp/key.json + steps: + - checkout + - make_savepoints: + backend: numpy + experiment: c12_6ranks_baroclinic_dycore_microphysics + target: test_driver_checkpoint + num_ranks: 6 + test_notebooks: docker: - image: gcr.io/vcm-ml/pace_notebook_examples @@ -609,6 +624,12 @@ workflows: filters: tags: only: /^v.*/ + - test_driver_checkpoint: + context: + - GCLOUD_ENCODED_KEY + filters: + tags: + only: /^v.*/ - test_util: filters: tags: diff --git a/.jenkins/driver_checkpoint_test.sh b/.jenkins/driver_checkpoint_test.sh new file mode 100755 index 000000000..3678883e3 --- /dev/null +++ b/.jenkins/driver_checkpoint_test.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +## Summary: +# Jenkins plan (only working on Piz Daint) to run the driver checkpoint test (make test_driver_checkpoint) as a SLURM job. + +## Syntax: +# .jenkins/driver_checkpoint_test.sh + +JENKINS_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" +PACE_DIR=$JENKINS_DIR/../ +export VIRTUALENV=${PACE_DIR}/venv +${JENKINS_DIR}/install_virtualenv.sh ${VIRTUALENV} +source ${VIRTUALENV}/bin/activate + +BUILDENV_DIR=$PACE_DIR/buildenv +. 
${BUILDENV_DIR}/schedulerTools.sh + +cat << EOF > run.daint.slurm +#!/bin/bash +#SBATCH --constraint=gpu +#SBATCH --job-name=driver_checkpoint_test +#SBATCH --ntasks=6 +#SBATCH --hint=multithread +#SBATCH --ntasks-per-node=6 +#SBATCH --cpus-per-task=1 +#SBATCH --output=driver.out +#SBATCH --time=00:45:00 +#SBATCH --gres=gpu:1 +#SBATCH --account=s1053 +#SBATCH --partition=normal +######################################################## +set -x +export OMP_NUM_THREADS=1 +export TEST_ARGS="-v -s -rsx --backend=numpy " +export EXPERIMENT=c12_6ranks_baroclinic_dycore_microphysics +export MPIRUN_CALL="srun" +CONTAINER_CMD="" MPIRUN_ARGS="" DEV=n make test_driver_checkpoint +EOF +launch_job run.daint.slurm 3600 +cat driver.out diff --git a/Makefile b/Makefile index 92771a95e..d4819300b 100644 --- a/Makefile +++ b/Makefile @@ -133,7 +133,11 @@ test_main: build test_savepoint: ## top level savepoint tests TARGET=dycore $(MAKE) get_test_data - $(CONTAINER_CMD) $(CONTAINER_FLAGS) bash -c "$(SAVEPOINT_SETUP) && cd $(PACE_PATH) && $(MPIRUN_CALL) python -m pytest --data_path=$(EXPERIMENT_DATA_RUN)/dycore/ $(TEST_ARGS) $(PACE_PATH)/tests/savepoint" + $(CONTAINER_CMD) $(CONTAINER_FLAGS) bash -c "$(SAVEPOINT_SETUP) && cd $(PACE_PATH) && $(MPIRUN_CALL) python -m pytest --data_path=$(EXPERIMENT_DATA_RUN)/dycore/ $(TEST_ARGS) -k test_fv_dynamics $(PACE_PATH)/tests/savepoint" + +test_driver_checkpoint: + TARGET=all EXPERIMENT=c12_6ranks_baroclinic_dycore_microphysics $(MAKE) get_test_data + $(CONTAINER_CMD) $(CONTAINER_FLAGS) bash -c "$(SAVEPOINT_SETUP) && cd $(PACE_PATH) && $(MPIRUN_CALL) python -m mpi4py -m pytest --data_path=$(EXPERIMENT_DATA_RUN)/all/ $(TEST_ARGS) -k test_driver $(PACE_PATH)/tests/savepoint" test_notebooks: ## tests for jupyter notebooks, must be run in correct Python environment pytest --nbmake "examples/notebooks" diff --git a/driver/pace/driver/checkpointer.py b/driver/pace/driver/checkpointer.py new file mode 100644 index 000000000..a1a9b5855 --- /dev/null +++ 
b/driver/pace/driver/checkpointer.py @@ -0,0 +1,116 @@ +import abc +import dataclasses +import logging +from typing import ClassVar + +import dacite +import yaml + +from pace.util import SavepointThresholds +from pace.util.checkpointer import ( + Checkpointer, + NullCheckpointer, + SnapshotCheckpointer, + ThresholdCalibrationCheckpointer, + ValidationCheckpointer, +) + +from .registry import Registry + + +logger = logging.getLogger(__name__) + + +class CheckpointerInitializer(abc.ABC): + @abc.abstractmethod + def get_checkpointer(self, rank: int) -> Checkpointer: + ... + + +@dataclasses.dataclass +class CheckpointerInitializerSelector(CheckpointerInitializer): + """ + Dataclass for selecting the implementation of CheckpointerInitializer to use. + + Used to circumvent the issue that dacite expects static class definitions, + but we would like to dynamically define which CheckpointerInitializer to use. + Does this by representing the part of the yaml specification that asks which + initializer to use, but deferring to the implementation + in that initializer when called. + """ + + type: str + config: CheckpointerInitializer + registry: ClassVar[Registry] = Registry() + + @classmethod + def register(cls, type_name): + return cls.registry.register(type_name) + + def get_checkpointer(self, rank: int) -> Checkpointer: + return self.config.get_checkpointer(rank=rank) + + @classmethod + def from_dict(cls, config: dict): + instance = cls.registry.from_dict(config) + return cls(config=instance, type=config["type"]) + + +@CheckpointerInitializerSelector.register("null") +@dataclasses.dataclass +class NullCheckpointerInit(CheckpointerInitializer): + """ + Configuration for null checkpointer. 
+ """ + + def get_checkpointer(self, rank: int) -> Checkpointer: + return NullCheckpointer() + + +@CheckpointerInitializerSelector.register("threshold_calibration") +@dataclasses.dataclass +class ThresholdCalibrationCheckpointerInit(CheckpointerInitializer): + """ + Configuration for threshold calibration checkpointer. + """ + + factor: float = 10.0 + + def get_checkpointer(self, rank: int) -> Checkpointer: + return ThresholdCalibrationCheckpointer(self.factor) + + +@CheckpointerInitializerSelector.register("validation") +@dataclasses.dataclass +class ValidationCheckpointerInit(CheckpointerInitializer): + """ + Configuration for validation checkpointer. + """ + + savepoint_data_path: str + threshold_filename: str + + def get_checkpointer(self, rank: int) -> Checkpointer: + with open(self.threshold_filename, "r") as f: + data = yaml.safe_load(f) + thresholds = dacite.from_dict( + data_class=SavepointThresholds, + data=data, + config=dacite.Config(strict=True), + ) + return ValidationCheckpointer( + savepoint_data_path=self.savepoint_data_path, + thresholds=thresholds, + rank=rank, + ) + + +@CheckpointerInitializerSelector.register("snapshot") +@dataclasses.dataclass +class SnapshotCheckpointerInit(CheckpointerInitializer): + """ + Configuration for snapshot checkpointer. + """ + + def get_checkpointer(self, rank: int) -> Checkpointer: + return SnapshotCheckpointer(rank=rank) diff --git a/driver/pace/driver/driver.py b/driver/pace/driver/driver.py index ea4a4939a..2db1a854b 100644 --- a/driver/pace/driver/driver.py +++ b/driver/pace/driver/driver.py @@ -24,9 +24,11 @@ # TODO: move update_atmos_state into pace.driver from pace.stencils import update_atmos_state +from pace.util.checkpointer import Checkpointer from pace.util.communicator import CubedSphereCommunicator from . 
import diagnostics +from .checkpointer import CheckpointerInitializerSelector, NullCheckpointerInit from .comm import CreatesCommSelector from .grid import GeneratedGridConfig, GridInitializerSelector from .initialization import InitializerSelector @@ -83,6 +85,7 @@ class DriverConfig: initial state of the model before timestepping output_frequency: number of model timesteps between diagnostic timesteps, defaults to every timestep + checkpointer_config: specifies the type of checkpointer to use """ stencil_config: pace.dsl.StencilConfig @@ -124,6 +127,11 @@ class DriverConfig: pair_debug: bool = False output_initial_state: bool = False output_frequency: int = 1 + checkpointer_config: CheckpointerInitializerSelector = dataclasses.field( + default_factory=lambda: CheckpointerInitializerSelector( + type="null", config=NullCheckpointerInit() + ) + ) safety_check_frequency: Optional[int] = None @functools.cached_property @@ -229,6 +237,9 @@ def get_driver_state( grid_data=grid_data, ) + def get_checkpointer(self, rank: int) -> Checkpointer: + return self.checkpointer_config.get_checkpointer(rank) + @classmethod def from_dict(cls, kwargs: Dict[str, Any]) -> "DriverConfig": if isinstance(kwargs["dycore_config"], dict): @@ -274,7 +285,14 @@ def from_dict(cls, kwargs: Dict[str, Any]) -> "DriverConfig": kwargs["grid_config"] = GridInitializerSelector.from_dict( kwargs["grid_config"] ) - + if "checkpointer_config" in kwargs: + kwargs["checkpointer_config"] = CheckpointerInitializerSelector.from_dict( + kwargs["checkpointer_config"] + ) + else: + kwargs["checkpointer_config"] = CheckpointerInitializerSelector( + type="null", config=NullCheckpointerInit() + ) if ( isinstance(kwargs["stencil_config"], dict) and "dace_config" in kwargs["stencil_config"].keys() @@ -467,6 +485,7 @@ def exit_instead_of_build(self): ) logger.info("setting up state done") + self.checkpointer = self.config.get_checkpointer(rank=communicator.rank) self._start_time = self.config.initialization.start_time 
logger.info("setting up dycore object started") self.dycore = fv3core.DynamicalCore( @@ -479,6 +498,7 @@ def exit_instead_of_build(self): timestep=self.config.timestep, phis=self.state.dycore_state.phis, state=self.state.dycore_state, + checkpointer=self.checkpointer, ) logger.info("setting up dycore object done") @@ -513,6 +533,7 @@ def exit_instead_of_build(self): dycore_only=self.config.dycore_only, apply_tendencies=self.config.apply_tendencies, tendency_state=self.state.tendency_state, + checkpointer=self.checkpointer, ) else: # Make sure those are set to None to raise any issues diff --git a/driver/pace/driver/grid.py b/driver/pace/driver/grid.py index d2295e005..e31f7bd0f 100644 --- a/driver/pace/driver/grid.py +++ b/driver/pace/driver/grid.py @@ -1,16 +1,18 @@ import abc import dataclasses import logging +import os from typing import ClassVar, Optional, Tuple import f90nml +import xarray as xr import pace.driver import pace.dsl import pace.physics import pace.stencils import pace.util.grid -from pace.stencils.testing import TranslateGrid +from pace.stencils.testing import TranslateGrid, dataset_to_dict from pace.util import CubedSphereCommunicator, QuantityFactory from pace.util.grid import ( DampingCoefficients, @@ -139,6 +141,7 @@ class SerialboxGridConfig(GridInitializer): """ path: str + netcdf: bool = True @property def _f90_namelist(self) -> f90nml.Namelist: @@ -163,10 +166,21 @@ def _get_serialized_grid( communicator: pace.util.CubedSphereCommunicator, backend: str, ) -> pace.stencils.testing.grid.Grid: # type: ignore - ser = self._serializer(communicator) - grid = TranslateGrid.new_from_serialized_data( - ser, communicator.rank, self._namelist.layout, backend - ).python_grid() + if self.netcdf: + ds_grid: xr.Dataset = xr.open_dataset( + os.path.join(self.path, "Grid-Info.nc") + ).isel(savepoint=0) + grid = TranslateGrid( + dataset_to_dict(ds_grid.isel(rank=communicator.rank)), + rank=communicator.rank, + layout=self._namelist.layout, + 
backend=backend, + ).python_grid() + else: + ser = self._serializer(communicator) + grid = TranslateGrid.new_from_serialized_data( + ser, communicator.rank, self._namelist.layout, backend + ).python_grid() return grid def get_grid( diff --git a/driver/pace/driver/initialization.py b/driver/pace/driver/initialization.py index 60f2fc425..9e77cd26e 100644 --- a/driver/pace/driver/initialization.py +++ b/driver/pace/driver/initialization.py @@ -7,6 +7,7 @@ from typing import Callable, ClassVar, Type, TypeVar import f90nml +import xarray as xr import pace.driver import pace.dsl @@ -21,7 +22,7 @@ from pace.dsl.stencil import StencilFactory from pace.dsl.stencil_config import CompilationConfig from pace.fv3core.testing import TranslateFVDynamics -from pace.stencils.testing import TranslateGrid +from pace.stencils.testing import TranslateGrid, dataset_to_dict from pace.util.namelist import Namelist from .registry import Registry @@ -285,6 +286,7 @@ class SerialboxInit(Initializer): path: str serialized_grid: bool + netcdf: bool = True @property def start_time(self) -> datetime: @@ -298,17 +300,6 @@ def _f90_namelist(self) -> f90nml.Namelist: def _namelist(self) -> Namelist: return Namelist.from_f90nml(self._f90_namelist) - def _get_serialized_grid( - self, - communicator: pace.util.CubedSphereCommunicator, - backend: str, - ) -> pace.stencils.testing.grid.Grid: # type: ignore - ser = self._serializer(communicator) - grid = TranslateGrid.new_from_serialized_data( - ser, communicator.rank, self._namelist.layout, backend - ).python_grid() - return grid - def _serializer(self, communicator: pace.util.CubedSphereCommunicator): import serialbox @@ -319,6 +310,18 @@ def _serializer(self, communicator: pace.util.CubedSphereCommunicator): ) return serializer + def _get_grid(self, communicator: pace.util.CubedSphereCommunicator, backend: str): + ds_grid: xr.Dataset = xr.open_dataset( + os.path.join(self.path, "Grid-Info.nc") + ).isel(savepoint=0) + grid = TranslateGrid( + 
dataset_to_dict(ds_grid.isel(rank=communicator.rank)), + rank=communicator.rank, + layout=self._namelist.layout, + backend=backend, + ).python_grid() + return grid + def get_driver_state( self, quantity_factory: pace.util.QuantityFactory, @@ -353,10 +356,6 @@ def _initialize_dycore_state( backend: str, ) -> fv3core.DycoreState: - grid = self._get_serialized_grid(communicator=communicator, backend=backend) - - ser = self._serializer(communicator) - savepoint_in = ser.get_savepoint("Driver-In")[0] dace_config = DaceConfig( communicator, backend, @@ -370,11 +369,35 @@ def _initialize_dycore_state( dace_config=dace_config, ) stencil_factory = StencilFactory( - config=stencil_config, grid_indexing=grid.grid_indexing + config=stencil_config, + grid_indexing=pace.dsl.GridIndexing.from_sizer_and_communicator( + sizer=pace.util.SubtileGridSizer.from_tile_params( + nx_tile=self._namelist.npx - 1, + ny_tile=self._namelist.npy - 1, + nz=self._namelist.npz, + n_halo=3, + tile_partitioner=communicator.partitioner.tile, + tile_rank=communicator.rank, + extra_dim_lengths={}, + layout=self._namelist.layout, + ), + cube=communicator, + ), ) + grid = self._get_grid(communicator, backend) translate_object = TranslateFVDynamics(grid, self._namelist, stencil_factory) - input_data = translate_object.collect_input_data(ser, savepoint_in) - dycore_state = translate_object.state_from_inputs(input_data) + if self.netcdf: + ds = xr.open_dataset(os.path.join(self.path, "Driver-In.nc")).sel( + savepoint=0, rank=communicator.rank + ) + input_data = dataset_to_dict(ds.copy()) + dycore_state, grid_data = translate_object.prepare_data(input_data) + else: + ser = self._serializer(communicator) + savepoint_in = ser.get_savepoint("Driver-In")[0] + + input_data = translate_object.collect_input_data(ser, savepoint_in) + dycore_state = translate_object.state_from_inputs(input_data) return dycore_state diff --git a/fv3core/pace/fv3core/stencils/fv_dynamics.py b/fv3core/pace/fv3core/stencils/fv_dynamics.py 
index 9129fc9d3..cad8b989d 100644 --- a/fv3core/pace/fv3core/stencils/fv_dynamics.py +++ b/fv3core/pace/fv3core/stencils/fv_dynamics.py @@ -150,6 +150,13 @@ def __init__( dace_compiletime_args=["state", "tag"], ) + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate="_checkpoint_driver_in", + dace_compiletime_args=["state"], + ) + orchestrate( obj=self, config=stencil_factory.config.dace_config, @@ -337,6 +344,21 @@ def _checkpoint_fvdynamics(self, state: DycoreState, tag: str): qvapor=state.qvapor, ) + def _checkpoint_driver_in(self, state: DycoreState): + if self.call_checkpointer: + self.checkpointer( + "Driver-In", + u=state.u, + v=state.v, + w=state.w, + delz=state.delz, + ua=state.ua, + va=state.va, + uc=state.uc, + vc=state.vc, + qvapor=state.qvapor, + ) + def _checkpoint_remapping_in( self, state: DycoreState, @@ -434,6 +456,7 @@ def step_dynamics( state: model prognostic state and inputs """ self._checkpoint_fvdynamics(state=state, tag="In") + self._checkpoint_driver_in(state=state) self._compute(state, timer) self._checkpoint_fvdynamics(state=state, tag="Out") diff --git a/stencils/pace/stencils/update_atmos_state.py b/stencils/pace/stencils/update_atmos_state.py index cb97fabb5..529eb06a3 100644 --- a/stencils/pace/stencils/update_atmos_state.py +++ b/stencils/pace/stencils/update_atmos_state.py @@ -249,7 +249,12 @@ def __init__( dycore_only: bool, apply_tendencies: bool, tendency_state, + checkpointer: Optional[pace.util.Checkpointer] = None, ): + self.checkpointer = checkpointer + # this is only computed in init because Dace does not yet support + # this operation + self.call_checkpointer = checkpointer is not None orchestrate( obj=self, config=stencil_factory.config.dace_config, @@ -258,7 +263,13 @@ def __init__( "phy_state", ], ) - + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate="_checkpoint_driver_out", + dace_compiletime_args=["state"], + ) + 
self.grid_indexing = stencil_factory.grid_indexing grid_indexing = stencil_factory.grid_indexing self.namelist = namelist self._rdt = 1.0 / Float(self.namelist.dt_atmos) @@ -294,6 +305,26 @@ def __init__( # fill_GFS_delp self._apply_tendencies = apply_tendencies + def _checkpoint_driver_out(self, state: fv3core.DycoreState): + if self.call_checkpointer: + self.checkpointer( + "Driver-Out", + u=state.u, + v=state.v, + w=state.w, + delz=state.delz, + ua=state.ua, + va=state.va, + uc=state.uc, + vc=state.vc, + qvapor=state.qvapor, + qliquid=state.qliquid, + qrain=state.qrain, + qsnow=state.qsnow, + qice=state.qice, + qgraupel=state.qgraupel, + ) + # [DaCe] Parsing limit: accessing a quantity withing a dataclass more than # one-level down in the call stack is forbidden for now due to the quantity # being resolved early has an array (loose of type of the object leads @@ -347,3 +378,4 @@ def __call__( pt_dt, dt=dt, ) + self._checkpoint_driver_out(state=dycore_state) diff --git a/tests/savepoint/test_checkpoints.py b/tests/savepoint/test_checkpoints.py index dacc6b5cd..5edfbbf82 100644 --- a/tests/savepoint/test_checkpoints.py +++ b/tests/savepoint/test_checkpoints.py @@ -10,7 +10,15 @@ import pace.dsl import pace.util -from pace import fv3core +from pace import fv3core, physics +from pace.driver import Driver, DriverConfig +from pace.driver.checkpointer import ( + Checkpointer, + NullCheckpointerInit, + ValidationCheckpointerInit, +) +from pace.driver.grid import SerialboxGridConfig +from pace.driver.initialization import Initializer, SerialboxInit from pace.fv3core.initialization.dycore_state import DycoreState from pace.fv3core.testing.translate_fvdynamics import TranslateFVDynamics from pace.stencils.testing import TranslateGrid, dataset_to_dict @@ -144,6 +152,143 @@ def test_fv_dynamics( dycore.step_dynamics(state) +def make_driver( + namelist: pace.util.Namelist, + data_path: str, + backend: str, + checkpointer: Checkpointer, +): + communicator = 
pace.util.CubedSphereCommunicator( + comm=pace.util.MPIComm(), + partitioner=pace.util.CubedSpherePartitioner( + tile=pace.util.TilePartitioner(layout=namelist.layout) + ), + ) + stencil_config = pace.dsl.StencilConfig( + compilation_config=pace.dsl.CompilationConfig( + backend=backend, + communicator=communicator, + rebuild=False, + ) + ) + initialization = SerialboxInit(path=data_path, serialized_grid=True, netcdf=True) + nx_tile = namelist.npx - 1 + nz = namelist.npz + layout = namelist.layout + dt_atmos = namelist.dt_atmos + grid_config = SerialboxGridConfig(path=data_path, netcdf=True) + dycore_config = fv3core.DynamicalCoreConfig.from_namelist(namelist) + physics_config = physics.PhysicsConfig.from_namelist(namelist) + seconds = 225.0 + driver_config = DriverConfig( + stencil_config=stencil_config, + initialization=initialization, + nx_tile=nx_tile, + nz=nz, + layout=layout, + dt_atmos=dt_atmos, + grid_config=grid_config, + dycore_config=dycore_config, + physics_config=physics_config, + seconds=seconds, + checkpointer_config=checkpointer, + ) + driver = Driver(config=driver_config) + return driver, communicator + + +def test_driver( + backend: str, data_path: str, calibrate_thresholds: bool, threshold_path: str +): + print("start test call") + namelist = pace.util.Namelist.from_f90nml( + f90nml.read(os.path.join(data_path, "input.nml")) + ) + threshold_filename = os.path.join(threshold_path, "driver.yaml") + driver, communicator = make_driver( + namelist=namelist, + data_path=data_path, + backend=backend, + checkpointer=NullCheckpointerInit(), + ) + if calibrate_thresholds: + thresholds = _calibrate_driver_thresholds( + n_trials=10, + factor=20.0, + namelist=namelist, + data_path=data_path, + backend=backend, + initialization=driver.config.initialization, + quantity_factory=driver.quantity_factory, + communicator=communicator, + damping_coefficients=driver.state.damping_coefficients, + driver_grid_data=driver.state.driver_grid_data, + 
grid_data=driver.state.grid_data, + ) + print(f"calibrated thresholds: {thresholds}") + if communicator.rank == 0: + with open(threshold_filename, "w") as f: + yaml.safe_dump(dataclasses.asdict(thresholds), f) + communicator.comm.barrier() + validation = ValidationCheckpointerInit( + savepoint_data_path=data_path, threshold_filename=threshold_filename + ) + driver, _ = make_driver( + namelist=namelist, + data_path=data_path, + backend=backend, + checkpointer=validation, + ) + with driver.checkpointer.trial(): + driver.step_all() + + +def _calibrate_driver_thresholds( + n_trials: int, + factor: float, + namelist, + data_path: str, + backend: str, + initialization: Initializer, + quantity_factory: pace.util.QuantityFactory, + communicator: pace.util.CubedSphereCommunicator, + damping_coefficients: DampingCoefficients, + driver_grid_data: pace.util.grid.DriverGridData, + grid_data: pace.util.grid.GridData, +): + calibration = pace.util.ThresholdCalibrationCheckpointer(factor=factor) + for i in range(n_trials): + print(f"running calibration trial {i}") + trial_state = initialization.get_driver_state( + quantity_factory, + communicator, + damping_coefficients, + driver_grid_data, + grid_data, + ).dycore_state + perturb(dycore_state_to_dict(trial_state)) + driver, _ = make_driver( + namelist=namelist, + data_path=data_path, + backend=backend, + checkpointer=NullCheckpointerInit(), + ) + # Need to use the same calibration checkpointer for all trials + driver.checkpointer = calibration + driver.state.dycore_state = trial_state + driver.dycore.checkpointer = calibration + driver.dycore.acoustic_dynamics.checkpointer = calibration + driver.end_of_step_update.checkpointer = calibration + with calibration.trial(): + driver.step_all() + all_thresholds = communicator.comm.allgather(calibration.thresholds) + thresholds = merge_thresholds(all_thresholds) + set_manual_thresholds(thresholds) + set_manual_thresholds(thresholds, savepoint_name="Driver-In") + 
override_var_relative_threshold(thresholds, "Driver-Out", 100.0, "qgraupel") + return thresholds + + def _calibrate_thresholds( initializer: StateInitializer, communicator: pace.util.CubedSphereCommunicator, @@ -181,13 +326,29 @@ def _calibrate_thresholds( return thresholds -def set_manual_thresholds(thresholds: SavepointThresholds): +def set_manual_thresholds( + thresholds: SavepointThresholds, savepoint_name: str = "FVDynamics-In" +): # all thresholds on the input data are 0 because no computation has happened yet - for entry in thresholds.savepoints["FVDynamics-In"]: + for entry in thresholds.savepoints[savepoint_name]: for name in entry: entry[name] = pace.util.Threshold(relative=0.0, absolute=0.0) +def override_var_relative_threshold( + thresholds: SavepointThresholds, + savepoint_name: str, + relative: float, + var: str, +): + for entry in thresholds.savepoints[savepoint_name]: + for name in entry: + if name == var: + entry[name] = pace.util.Threshold( + relative=relative, absolute=entry[name].absolute + ) + + def merge_thresholds(all_thresholds: List[pace.util.SavepointThresholds]): thresholds = all_thresholds[0] for other_thresholds in all_thresholds[1:]: diff --git a/tests/savepoint/thresholds/driver.yaml b/tests/savepoint/thresholds/driver.yaml new file mode 100644 index 000000000..e3cb26673 --- /dev/null +++ b/tests/savepoint/thresholds/driver.yaml @@ -0,0 +1,430 @@ +savepoints: + C_SW-In: + - delpd: + absolute: 4.547473508864641e-12 + relative: 4.440654652315846e-15 + divgdd: + absolute: 0.0 + relative: 0.0 + ptd: + absolute: 7.105427357601002e-13 + relative: 3.7526204112473214e-14 + uad: + absolute: 320.0 + relative: 4.4408713566206765e-15 + ucd: + absolute: 7.105427357601002e-14 + relative: 4.440798732542637e-15 + ud: + absolute: 7.105427357601002e-14 + relative: 4.4408708176729106e-15 + utd: + absolute: 0.0 + relative: 0.0 + vad: + absolute: 320.0 + relative: 4.440857803874169e-15 + vcd: + absolute: 7.105427357601002e-14 + relative: 
4.440884509085042e-15 + vd: + absolute: 7.105427357601002e-14 + relative: 4.4408708176729106e-15 + vtd: + absolute: 0.0 + relative: 0.0 + wd: + absolute: 1.3877787807814457e-16 + relative: 4.440873885562907e-15 + C_SW-Out: + - delpd: + absolute: 4.547473508864641e-12 + relative: 4.440654652315846e-15 + divgdd: + absolute: 7.73446959961583e-19 + relative: 6.60373166336595e-09 + ptd: + absolute: 7.105427357601002e-13 + relative: 3.7526204112473214e-14 + uad: + absolute: 640.0 + relative: 2.5610943871133864e-09 + ucd: + absolute: 3.552713678800501e-13 + relative: 1.0041731179975834e-10 + ud: + absolute: 7.105427357601002e-14 + relative: 4.4408708176729106e-15 + utd: + absolute: 3.337860107421875e-05 + relative: 1.0106137143516776e-10 + vad: + absolute: 640.0 + relative: 2.9879456190554356e-09 + vcd: + absolute: 3.552713678800501e-13 + relative: 1.3152167624547145e-09 + vd: + absolute: 7.105427357601002e-14 + relative: 4.4408708176729106e-15 + vtd: + absolute: 3.337860107421875e-05 + relative: 2.759744670137014e-09 + wd: + absolute: 1.3877787807814457e-16 + relative: 4.440873885562907e-15 + D_SW-In: + - delpcd: + absolute: 3.337860107421875e-05 + relative: 2.759744670137014e-09 + delpd: + absolute: 4.547473508864641e-12 + relative: 4.440654652315846e-15 + divgdd: + absolute: 7.73446959961583e-19 + relative: 6.60373166336595e-09 + mfxd: + absolute: 0.0 + relative: 0.0 + mfyd: + absolute: 0.0 + relative: 0.0 + ptd: + absolute: 7.105427357601002e-13 + relative: 3.7526204112473214e-14 + uad: + absolute: 640.0 + relative: 2.5610943871133864e-09 + ucd: + absolute: 1.9184653865522705e-12 + relative: 8.84542968289086e-09 + ud: + absolute: 7.105427357601002e-14 + relative: 4.4408708176729106e-15 + vad: + absolute: 640.0 + relative: 2.9879456190554356e-09 + vcd: + absolute: 1.5631940186722204e-12 + relative: 4.71833428344894e-07 + vd: + absolute: 7.105427357601002e-14 + relative: 4.4408708176729106e-15 + wd: + absolute: 1.3877787807814457e-16 + relative: 4.440873885562907e-15 + 
xfxd: + absolute: 0.0 + relative: 0.0 + yfxd: + absolute: 0.0 + relative: 0.0 + zhd: + absolute: 2.9103830456733704e-10 + relative: 1.2510886218119119e-14 + D_SW-Out: + - delpcd: + absolute: 3.337860107421875e-05 + relative: 6.60373166336595e-09 + delpd: + absolute: 3.0789142329012975e-05 + relative: 3.685628763982423e-08 + divgdd: + absolute: 6.363335016247931e-19 + relative: 7.787161293301443e-08 + mfxd: + absolute: 2699755520.0 + relative: 1.092569174756783e-05 + mfyd: + absolute: 2718433280.0 + relative: 1.5722289505077605e-05 + ptd: + absolute: 4.4649794972428936e-09 + relative: 1.45238023389111e-10 + uad: + absolute: 640.0 + relative: 2.5610943871133864e-09 + ucd: + absolute: 1.1013412404281553e-12 + relative: 3.787578374646247e-07 + ud: + absolute: 2.2351741790771484e-07 + relative: 4.128687367137859e-11 + vad: + absolute: 640.0 + relative: 2.9879456190554356e-09 + vcd: + absolute: 1.1368683772161603e-12 + relative: 3.7415288441422884e-07 + vd: + absolute: 2.2351741790771484e-07 + relative: 1.7276235613495599e-10 + wd: + absolute: 2.4303545287374106e-11 + relative: 9.884849487502515e-08 + xfxd: + absolute: 0.0002837181091308594 + relative: 8.845428594092422e-09 + yfxd: + absolute: 0.0002193450927734375 + relative: 4.7183342924510334e-07 + Driver-In: + - delz: + absolute: 0.0 + relative: 0.0 + qvapor: + absolute: 0.0 + relative: 0.0 + u: + absolute: 0.0 + relative: 0.0 + ua: + absolute: 0.0 + relative: 0.0 + uc: + absolute: 0.0 + relative: 0.0 + v: + absolute: 0.0 + relative: 0.0 + va: + absolute: 0.0 + relative: 0.0 + vc: + absolute: 0.0 + relative: 0.0 + w: + absolute: 0.0 + relative: 0.0 + Driver-Out: + - delz: + absolute: 42949672960.0 + relative: 4.1818686382352055e-09 + qgraupel: + absolute: 4.1186977060240354e-19 + relative: 100.0 + qice: + absolute: 8.579538828036863e-13 + relative: 200.7843137254902 + qliquid: + absolute: 5.778978742754497e-12 + relative: 100.0 + qrain: + absolute: 7.329630202475637e-15 + relative: 57.142857142857146 + qsnow: + 
absolute: 0.0 + relative: 0.0 + qvapor: + absolute: 5.779057787869135e-12 + relative: 5.259824345429924e-10 + u: + absolute: 1.1179039915987232e-07 + relative: 2.7096590900922386e-05 + ua: + absolute: 640.0 + relative: 8.205840217910252e-07 + uc: + absolute: 1.1013412404281553e-12 + relative: 3.787578374646247e-07 + v: + absolute: 1.0295885033428931e-07 + relative: 5.786090220049254e-05 + va: + absolute: 640.0 + relative: 1.3000770306829817e-05 + vc: + absolute: 1.1368683772161603e-12 + relative: 3.7415288441422884e-07 + w: + absolute: 2.4117311691988774e-07 + relative: 0.012770736069167316 + FVDynamics-In: + - delz: + absolute: 0.0 + relative: 0.0 + qvapor: + absolute: 0.0 + relative: 0.0 + u: + absolute: 0.0 + relative: 0.0 + ua: + absolute: 0.0 + relative: 0.0 + uc: + absolute: 0.0 + relative: 0.0 + v: + absolute: 0.0 + relative: 0.0 + va: + absolute: 0.0 + relative: 0.0 + vc: + absolute: 0.0 + relative: 0.0 + w: + absolute: 0.0 + relative: 0.0 + FVDynamics-Out: + - delz: + absolute: 42949672960.0 + relative: 4.1818686382352055e-09 + qvapor: + absolute: 6.938893903907228e-17 + relative: 4.440759627770741e-15 + u: + absolute: 1.1179039915987232e-07 + relative: 2.7096595351274196e-05 + ua: + absolute: 640.0 + relative: 8.173580987327614e-07 + uc: + absolute: 1.1013412404281553e-12 + relative: 3.787578374646247e-07 + v: + absolute: 1.0295885033428931e-07 + relative: 5.786089077256556e-05 + va: + absolute: 640.0 + relative: 1.3000770306829817e-05 + vc: + absolute: 1.1368683772161603e-12 + relative: 3.7415288441422884e-07 + w: + absolute: 2.4117311691988774e-07 + relative: 0.012770736069167316 + Remapping-In: + - cappa: + absolute: 2.220446049250313e-15 + relative: 7.808004003536861e-15 + delp: + absolute: 3.0789142329012975e-05 + relative: 3.685628763982423e-08 + delz: + absolute: 42949672960.0 + relative: 3.675492881583554e-08 + dp1: + absolute: 2.0526092612271896e-05 + relative: 2.457020580652998e-08 + omga: + absolute: 4.510281037539698e-16 + relative: 
3.445719536224e-11 + pe: + absolute: 4.52948734164238e-05 + relative: 1.0461736693988013e-08 + peln: + absolute: 1.0461729260669017e-08 + relative: 1.3207323078842145e-09 + phis: + absolute: 0.0 + relative: 0.0 + pk: + absolute: 2.874294580124115e-08 + relative: 2.9892885543624336e-09 + pkz: + absolute: 2.7328290741479577e-08 + relative: 2.7822971963059133e-09 + ps: + absolute: .nan + relative: 2.986143728676916e-15 + pt: + absolute: 4.4649794972428936e-09 + relative: 1.4523802285210573e-10 + te_2d: + absolute: 0.0 + relative: 0.0 + u: + absolute: 1.1633431995505816e-07 + relative: 2.5634435381877942e-05 + ua: + absolute: 640.0 + relative: 2.5610943871133864e-09 + v: + absolute: 9.211390761265648e-08 + relative: 5.759187191109189e-05 + va: + absolute: 640.0 + relative: 2.9879456190554356e-09 + w: + absolute: 2.411407160323592e-07 + relative: 0.012771932314352977 + wsd: + absolute: 0.0 + relative: 0.0 + Remapping-Out: + - cappa: + absolute: 7.488454301096681e-12 + relative: 2.6269152479286847e-11 + delp: + absolute: 1.0318763088434935e-06 + relative: 5.630487788261942e-10 + delz: + absolute: 42949672960.0 + relative: 4.1818686382352055e-09 + dp1: + absolute: 2.0526092612271896e-05 + relative: 2.457020580652998e-08 + omga: + absolute: 4.510281037539698e-16 + relative: 3.445719536224e-11 + pe: + absolute: 4.52948734164238e-05 + relative: 4.53403465846039e-10 + peln: + absolute: 4.5339731968851993e-10 + relative: 3.939689472067183e-11 + pk: + absolute: 3.475335574876226e-09 + relative: 1.2955172037039188e-10 + pkz: + absolute: 4.517630713962717e-09 + relative: 1.7868505070598108e-10 + pt: + absolute: 1.0104253078679903e-06 + relative: 4.223494641918948e-09 + te_2d: + absolute: 0.0 + relative: 0.0 + u: + absolute: 1.1179039915987232e-07 + relative: 2.7096595351274196e-05 + v: + absolute: 1.0295885033428931e-07 + relative: 5.786089077256556e-05 + w: + absolute: 2.4117311691988774e-07 + relative: 0.012770736069167316 + Tracer2D1L-In: + - cxd: + absolute: 
7.37257477290143e-16 + relative: 8.845428186550091e-09 + cyd: + absolute: 5.724587470723463e-16 + relative: 4.718334315123565e-07 + dp1: + absolute: 4.547473508864641e-12 + relative: 4.440654652315846e-15 + mfxd: + absolute: 2699755520.0 + relative: 1.092569174756783e-05 + mfyd: + absolute: 2718433280.0 + relative: 1.5722289505077605e-05 + Tracer2D1L-Out: + - cxd: + absolute: 2.45029690981724e-16 + relative: 8.845426742609778e-09 + cyd: + absolute: 1.9081958235744878e-16 + relative: 4.718334337123763e-07 + dp1: + absolute: 2.0526092612271896e-05 + relative: 2.457020580652998e-08 + mfxd: + absolute: 899973120.0 + relative: 1.0925691746673609e-05 + mfyd: + absolute: 906158080.0 + relative: 1.5722289505077605e-05