Skip to content

Feature/add driver checkpoint #391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
21 changes: 21 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ commands:
enum:
- "test_main"
- "test_savepoint"
- "test_driver_checkpoint"
- "savepoint_tests"
- "savepoint_tests_mpi"
- "physics_savepoint_tests"
Expand Down Expand Up @@ -459,6 +460,20 @@ jobs:
target: test_savepoint
num_ranks: 6

test_driver_checkpoint:
machine:
image: ubuntu-2004:202111-02
resource_class: large
environment:
GOOGLE_APPLICATION_CREDENTIALS: /tmp/key.json
steps:
- checkout
- make_savepoints:
backend: numpy
experiment: c12_6ranks_baroclinic_dycore_microphysics
target: test_driver_checkpoint
num_ranks: 6

test_notebooks:
docker:
- image: gcr.io/vcm-ml/pace_notebook_examples
Expand Down Expand Up @@ -609,6 +624,12 @@ workflows:
filters:
tags:
only: /^v.*/
- test_driver_checkpoint:
context:
- GCLOUD_ENCODED_KEY
filters:
tags:
only: /^v.*/
- test_util:
filters:
tags:
Expand Down
40 changes: 40 additions & 0 deletions .jenkins/driver_checkpoint_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/bin/bash

## Summary:
# Jenkins plan (only working on Piz Daint) to run the driver checkpoint
# savepoint tests (6 MPI ranks, numpy backend) as a SLURM batch job.

## Syntax:
# .jenkins/driver_checkpoint_test.sh

# Resolve this script's directory so the plan works regardless of CWD.
JENKINS_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PACE_DIR=$JENKINS_DIR/../
export VIRTUALENV=${PACE_DIR}/venv
${JENKINS_DIR}/install_virtualenv.sh ${VIRTUALENV}
source ${VIRTUALENV}/bin/activate

# Scheduler helpers; provides launch_job used below.
BUILDENV_DIR=$PACE_DIR/buildenv
. ${BUILDENV_DIR}/schedulerTools.sh

# Generate the SLURM batch script that runs `make test_driver_checkpoint`.
cat << EOF > run.daint.slurm
#!/bin/bash
#SBATCH --constraint=gpu
#SBATCH --job-name=driver_checkpoint_test
#SBATCH --ntasks=6
#SBATCH --hint=multithread
#SBATCH --ntasks-per-node=6
#SBATCH --cpus-per-task=1
#SBATCH --output=driver.out
#SBATCH --time=00:45:00
#SBATCH --gres=gpu:1
#SBATCH --account=s1053
#SBATCH --partition=normal
########################################################
set -x
export OMP_NUM_THREADS=1
export TEST_ARGS="-v -s -rsx --backend=numpy "
export EXPERIMENT=c12_6ranks_baroclinic_dycore_microphysics
export MPIRUN_CALL="srun"
CONTAINER_CMD="" MPIRUN_ARGS="" DEV=n make test_driver_checkpoint
EOF
# launch_job comes from schedulerTools.sh; 3600 is presumably a timeout
# in seconds — confirm against buildenv. Then dump the job log.
launch_job run.daint.slurm 3600
cat driver.out
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,11 @@ test_main: build

test_savepoint: ## top level savepoint tests
TARGET=dycore $(MAKE) get_test_data
$(CONTAINER_CMD) $(CONTAINER_FLAGS) bash -c "$(SAVEPOINT_SETUP) && cd $(PACE_PATH) && $(MPIRUN_CALL) python -m pytest --data_path=$(EXPERIMENT_DATA_RUN)/dycore/ $(TEST_ARGS) $(PACE_PATH)/tests/savepoint"
$(CONTAINER_CMD) $(CONTAINER_FLAGS) bash -c "$(SAVEPOINT_SETUP) && cd $(PACE_PATH) && $(MPIRUN_CALL) python -m pytest --data_path=$(EXPERIMENT_DATA_RUN)/dycore/ $(TEST_ARGS) -k test_fv_dynamics $(PACE_PATH)/tests/savepoint"

test_driver_checkpoint:
TARGET=all EXPERIMENT=c12_6ranks_baroclinic_dycore_microphysics $(MAKE) get_test_data
$(CONTAINER_CMD) $(CONTAINER_FLAGS) bash -c "$(SAVEPOINT_SETUP) && cd $(PACE_PATH) && $(MPIRUN_CALL) python -m mpi4py -m pytest --data_path=$(EXPERIMENT_DATA_RUN)/all/ $(TEST_ARGS) -k test_driver $(PACE_PATH)/tests/savepoint"

test_notebooks: ## tests for jupyter notebooks, must be run in correct Python environment
pytest --nbmake "examples/notebooks"
Expand Down
116 changes: 116 additions & 0 deletions driver/pace/driver/checkpointer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import abc
import dataclasses
import logging
from typing import ClassVar

import dacite
import yaml

from pace.util import SavepointThresholds
from pace.util.checkpointer import (
Checkpointer,
NullCheckpointer,
SnapshotCheckpointer,
ThresholdCalibrationCheckpointer,
ValidationCheckpointer,
)

from .registry import Registry


logger = logging.getLogger(__name__)


class CheckpointerInitializer(abc.ABC):
    """Interface for configuration objects that can build a Checkpointer."""

    @abc.abstractmethod
    def get_checkpointer(self, rank: int) -> Checkpointer:
        """Construct the configured checkpointer for the given rank."""


@dataclasses.dataclass
class CheckpointerInitializerSelector(CheckpointerInitializer):
    """
    Selects which concrete CheckpointerInitializer implementation to use.

    dacite needs a statically known dataclass to deserialize into, while the
    concrete initializer is only known at runtime from the "type" key of the
    yaml configuration. This wrapper captures that "type" key and forwards
    every call to the registered implementation it names.
    """

    # name under which the concrete initializer was registered
    # (the "type" key of the yaml configuration)
    type: str
    # the concrete initializer that calls are forwarded to
    config: CheckpointerInitializer
    # class-level registry mapping type names to initializer classes
    registry: ClassVar[Registry] = Registry()

    @classmethod
    def register(cls, type_name):
        """Return a decorator registering a class under ``type_name``."""
        return cls.registry.register(type_name)

    @classmethod
    def from_dict(cls, config: dict):
        """Build a selector from a configuration dictionary."""
        selected = cls.registry.from_dict(config)
        return cls(type=config["type"], config=selected)

    def get_checkpointer(self, rank: int) -> Checkpointer:
        """Delegate construction to the wrapped initializer."""
        return self.config.get_checkpointer(rank=rank)


@CheckpointerInitializerSelector.register("null")
@dataclasses.dataclass
class NullCheckpointerInit(CheckpointerInitializer):
    """
    Configuration for the null (no-op) checkpointer, used when no
    checkpointing is requested.
    """

    def get_checkpointer(self, rank: int) -> Checkpointer:
        return NullCheckpointer()


@CheckpointerInitializerSelector.register("threshold_calibration")
@dataclasses.dataclass
class ThresholdCalibrationCheckpointerInit(CheckpointerInitializer):
    """
    Configuration for the threshold calibration checkpointer.
    """

    # scaling factor forwarded to ThresholdCalibrationCheckpointer;
    # its exact semantics are defined by that class
    factor: float = 10.0

    def get_checkpointer(self, rank: int) -> Checkpointer:
        """Build a calibration checkpointer with the configured factor."""
        calibrator = ThresholdCalibrationCheckpointer(self.factor)
        return calibrator


@CheckpointerInitializerSelector.register("validation")
@dataclasses.dataclass
class ValidationCheckpointerInit(CheckpointerInitializer):
    """
    Configuration for the validation checkpointer.
    """

    # path to the reference savepoint data to validate against
    savepoint_data_path: str
    # yaml file describing a SavepointThresholds instance
    threshold_filename: str

    def get_checkpointer(self, rank: int) -> Checkpointer:
        """Load thresholds from disk and build a ValidationCheckpointer."""
        with open(self.threshold_filename, "r") as threshold_file:
            raw_thresholds = yaml.safe_load(threshold_file)
        # strict=True makes dacite reject keys that are not fields of
        # SavepointThresholds, surfacing typos in the threshold file early
        thresholds = dacite.from_dict(
            data_class=SavepointThresholds,
            data=raw_thresholds,
            config=dacite.Config(strict=True),
        )
        return ValidationCheckpointer(
            savepoint_data_path=self.savepoint_data_path,
            thresholds=thresholds,
            rank=rank,
        )


@CheckpointerInitializerSelector.register("snapshot")
@dataclasses.dataclass
class SnapshotCheckpointerInit(CheckpointerInitializer):
    """
    Configuration for the snapshot checkpointer.
    """

    def get_checkpointer(self, rank: int) -> Checkpointer:
        """Build a SnapshotCheckpointer bound to the given rank."""
        snapshotter = SnapshotCheckpointer(rank=rank)
        return snapshotter
23 changes: 22 additions & 1 deletion driver/pace/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@

# TODO: move update_atmos_state into pace.driver
from pace.stencils import update_atmos_state
from pace.util.checkpointer import Checkpointer
from pace.util.communicator import CubedSphereCommunicator

from . import diagnostics
from .checkpointer import CheckpointerInitializerSelector, NullCheckpointerInit
from .comm import CreatesCommSelector
from .grid import GeneratedGridConfig, GridInitializerSelector
from .initialization import InitializerSelector
Expand Down Expand Up @@ -83,6 +85,7 @@ class DriverConfig:
initial state of the model before timestepping
output_frequency: number of model timesteps between diagnostic timesteps,
defaults to every timestep
checkpointer: specifies the type of checkpointer to use
"""

stencil_config: pace.dsl.StencilConfig
Expand Down Expand Up @@ -124,6 +127,11 @@ class DriverConfig:
pair_debug: bool = False
output_initial_state: bool = False
output_frequency: int = 1
checkpointer_config: CheckpointerInitializerSelector = dataclasses.field(
default_factory=lambda: CheckpointerInitializerSelector(
type="null", config=NullCheckpointerInit()
)
)
safety_check_frequency: Optional[int] = None

@functools.cached_property
Expand Down Expand Up @@ -229,6 +237,9 @@ def get_driver_state(
grid_data=grid_data,
)

def get_checkpointer(self, rank: int) -> Checkpointer:
return self.checkpointer_config.get_checkpointer(rank)

@classmethod
def from_dict(cls, kwargs: Dict[str, Any]) -> "DriverConfig":
if isinstance(kwargs["dycore_config"], dict):
Expand Down Expand Up @@ -274,7 +285,14 @@ def from_dict(cls, kwargs: Dict[str, Any]) -> "DriverConfig":
kwargs["grid_config"] = GridInitializerSelector.from_dict(
kwargs["grid_config"]
)

if "checkpointer_config" in kwargs:
kwargs["checkpointer_config"] = CheckpointerInitializerSelector.from_dict(
kwargs["checkpointer_config"]
)
else:
kwargs["checkpointer_config"] = CheckpointerInitializerSelector(
type="null", config=NullCheckpointerInit()
)
if (
isinstance(kwargs["stencil_config"], dict)
and "dace_config" in kwargs["stencil_config"].keys()
Expand Down Expand Up @@ -467,6 +485,7 @@ def exit_instead_of_build(self):
)
logger.info("setting up state done")

self.checkpointer = self.config.get_checkpointer(rank=communicator.rank)
self._start_time = self.config.initialization.start_time
logger.info("setting up dycore object started")
self.dycore = fv3core.DynamicalCore(
Expand All @@ -479,6 +498,7 @@ def exit_instead_of_build(self):
timestep=self.config.timestep,
phis=self.state.dycore_state.phis,
state=self.state.dycore_state,
checkpointer=self.checkpointer,
)
logger.info("setting up dycore object done")

Expand Down Expand Up @@ -513,6 +533,7 @@ def exit_instead_of_build(self):
dycore_only=self.config.dycore_only,
apply_tendencies=self.config.apply_tendencies,
tendency_state=self.state.tendency_state,
checkpointer=self.checkpointer,
)
else:
# Make sure those are set to None to raise any issues
Expand Down
24 changes: 19 additions & 5 deletions driver/pace/driver/grid.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import abc
import dataclasses
import logging
import os
from typing import ClassVar, Optional, Tuple

import f90nml
import xarray as xr

import pace.driver
import pace.dsl
import pace.physics
import pace.stencils
import pace.util.grid
from pace.stencils.testing import TranslateGrid
from pace.stencils.testing import TranslateGrid, dataset_to_dict
from pace.util import CubedSphereCommunicator, QuantityFactory
from pace.util.grid import (
DampingCoefficients,
Expand Down Expand Up @@ -139,6 +141,7 @@ class SerialboxGridConfig(GridInitializer):
"""

path: str
netcdf: bool = True

@property
def _f90_namelist(self) -> f90nml.Namelist:
Expand All @@ -163,10 +166,21 @@ def _get_serialized_grid(
communicator: pace.util.CubedSphereCommunicator,
backend: str,
) -> pace.stencils.testing.grid.Grid: # type: ignore
ser = self._serializer(communicator)
grid = TranslateGrid.new_from_serialized_data(
ser, communicator.rank, self._namelist.layout, backend
).python_grid()
if self.netcdf:
ds_grid: xr.Dataset = xr.open_dataset(
os.path.join(self.path, "Grid-Info.nc")
).isel(savepoint=0)
grid = TranslateGrid(
dataset_to_dict(ds_grid.isel(rank=communicator.rank)),
rank=communicator.rank,
layout=self._namelist.layout,
backend=backend,
).python_grid()
else:
ser = self._serializer(communicator)
grid = TranslateGrid.new_from_serialized_data(
ser, communicator.rank, self._namelist.layout, backend
).python_grid()
return grid

def get_grid(
Expand Down
Loading