diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index 7a714c2090..761b4f552f 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -60,7 +60,11 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]: Callable[[Namespace], None] The entry point hook of the backend. """ - raise NotImplementedError + from deepmd.jax.entrypoints.main import ( + main, + ) + + return main @property def deep_eval(self) -> type["DeepEvalBackend"]: diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 51c56e9681..703cc887a3 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -781,6 +781,7 @@ def __init__( self.mean = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision]) self.stddev = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision]) self.orig_sel = self.sel + self.ndescrpt = self.nnei * 4 def get_rcut(self) -> float: """Returns the cut-off radius.""" diff --git a/deepmd/dpmodel/fitting/ener_fitting.py b/deepmd/dpmodel/fitting/ener_fitting.py index 6435b6468f..3f7684d1f9 100644 --- a/deepmd/dpmodel/fitting/ener_fitting.py +++ b/deepmd/dpmodel/fitting/ener_fitting.py @@ -6,6 +6,8 @@ Union, ) +import numpy as np + from deepmd.dpmodel.common import ( DEFAULT_PRECISION, ) @@ -17,6 +19,10 @@ from deepmd.dpmodel.fitting.general_fitting import ( GeneralFitting, ) + +from deepmd.utils.out_stat import ( + compute_stats_from_redu, +) from deepmd.utils.version import ( check_version_compatibility, ) @@ -86,3 +92,69 @@ def serialize(self) -> dict: **super().serialize(), "type": "ener", } + + def compute_output_stats(self, all_stat: dict, mixed_type: bool = False) -> None: + """Compute the output statistics. + + Parameters + ---------- + all_stat + must have the following components: + all_stat['energy'] of shape n_sys x n_batch x n_frame + can be prepared by model.make_stat_input + mixed_type + Whether to perform the mixed_type mode. + If True, the input data has the mixed_type format (see doc/model/train_se_atten.md), + in which frames in a system may have different natoms_vec(s), with the same nloc. + """ + self.bias_atom_e = self._compute_output_stats( + all_stat, rcond=self.rcond, mixed_type=mixed_type + ) + + def _compute_output_stats(self, all_stat, rcond=1e-3, mixed_type=False): + data = all_stat["energy"] + # data[sys_idx][batch_idx][frame_idx] + sys_ener = [] + for ss in range(len(data)): + sys_data = [] + for ii in range(len(data[ss])): + for jj in range(len(data[ss][ii])): + sys_data.append(data[ss][ii][jj]) + sys_data = np.concatenate(sys_data) + sys_ener.append(np.average(sys_data)) + sys_ener = np.array(sys_ener) + sys_tynatom = [] + if mixed_type: + data = all_stat["real_natoms_vec"] + nsys = len(data) + for ss in range(len(data)): + tmp_tynatom = [] + for ii in range(len(data[ss])): + for jj in range(len(data[ss][ii])): + tmp_tynatom.append(data[ss][ii][jj].astype(np.float64)) + tmp_tynatom = np.average(np.array(tmp_tynatom), axis=0) + sys_tynatom.append(tmp_tynatom) + else: + data = all_stat["natoms_vec"] + nsys = len(data) + for ss in range(len(data)): + sys_tynatom.append(data[ss][0].astype(np.float64)) + sys_tynatom = np.array(sys_tynatom) + sys_tynatom = np.reshape(sys_tynatom, [nsys, -1]) + sys_tynatom = sys_tynatom[:, 2:] + if len(self.atom_ener) > 0: + # Atomic energies stats are incorrect if atomic energies are assigned. + # In this situation, we directly use these assigned energies instead of computing stats. 
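+ # (Unassigned types are passed as np.nan below; compute_stats_from_redu
+ # keeps the non-NaN entries fixed and least-squares fits only the rest.)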
+ # This will make the loss decrease quickly
+ assigned_atom_ener = np.array(
+ [ee if ee is not None else np.nan for ee in self.atom_ener]
+ )
+ else:
+ assigned_atom_ener = None
+ energy_shift, _ = compute_stats_from_redu(
+ sys_ener.reshape(-1, 1),
+ sys_tynatom,
+ assigned_bias=assigned_atom_ener,
+ rcond=rcond,
+ )
+ return energy_shift.ravel()
diff --git a/deepmd/dpmodel/loss/ener.py b/deepmd/dpmodel/loss/ener.py
index 7a17fcfcf0..d4b49540a4 100644
--- a/deepmd/dpmodel/loss/ener.py
+++ b/deepmd/dpmodel/loss/ener.py
@@ -177,7 +177,9 @@ def call(
delta=self.huber_delta,
)
loss += pref_e * l_huber_loss
- more_loss["rmse_e"] = self.display_if_exist(l2_ener_loss, find_energy)
+ more_loss["rmse_e"] = self.display_if_exist(
+ xp.sqrt(l2_ener_loss) * atom_norm_ener, find_energy
+ )
if self.has_f:
l2_force_loss = xp.mean(xp.square(diff_f))
if not self.use_huber:
@@ -189,7 +191,9 @@ def call(
delta=self.huber_delta,
)
loss += pref_f * l_huber_loss
- more_loss["rmse_f"] = self.display_if_exist(l2_force_loss, find_force)
+ more_loss["rmse_f"] = self.display_if_exist(
+ xp.sqrt(l2_force_loss), find_force
+ )
if self.has_v:
virial_reshape = xp.reshape(virial, [-1])
virial_hat_reshape = xp.reshape(virial_hat, [-1])
@@ -381,3 +385,79 @@ def deserialize(cls, data: dict) -> "Loss":
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
return cls(**data)
+
+
+class EnergyHessianLoss(EnergyLoss):
+ def __init__(
+ self,
+ start_pref_h=0.0,
+ limit_pref_h=0.0,
+ **kwargs,
+ ):
+ r"""Enable the layer to compute loss on hessian.
+
+ Parameters
+ ----------
+ start_pref_h : float
+ The prefactor of hessian loss at the start of the training.
+ limit_pref_h : float
+ The prefactor of hessian loss at the end of the training.
+ **kwargs
+ Other keyword arguments. 
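+
+ Notes
+ -----
+ The hessian term is active only when both ``start_pref_h`` and
+ ``limit_pref_h`` are nonzero (see ``self.has_h``).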
+ """ + EnergyLoss.__init__(self, **kwargs) + self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0 + + self.start_pref_h = start_pref_h + self.limit_pref_h = limit_pref_h + + def call( + self, + learning_rate: float, + natoms: int, + model_dict: dict[str, np.ndarray], + label_dict: dict[str, np.ndarray], + ) -> dict[str, np.ndarray]: + """Calculate loss from model results and labeled results.""" + loss, more_loss = EnergyLoss.call( + self, learning_rate, natoms, model_dict, label_dict + ) + xp = array_api_compat.array_namespace(model_dict["energy"]) + coef = learning_rate / self.starter_learning_rate + pref_h = self.limit_pref_h + (self.start_pref_h - self.limit_pref_h) * coef + + if ( + self.has_h + and "energy_derv_r_derv_r" in model_dict + and "hessian" in label_dict + ): + find_hessian = label_dict.get("find_hessian", 0.0) + pref_h = pref_h * find_hessian + diff_h = label_dict["hessian"].reshape( + -1, + ) - model_dict["energy_derv_r_derv_r"].reshape( + -1, + ) + l2_hessian_loss = xp.mean(xp.square(diff_h)) + loss += pref_h * l2_hessian_loss + rmse_h = xp.sqrt(l2_hessian_loss) + more_loss["rmse_h"] = self.display_if_exist(rmse_h, find_hessian) + + more_loss["rmse"] = xp.sqrt(loss) + return loss, more_loss + + @property + def label_requirement(self) -> list[DataRequirementItem]: + """Add hessian label requirement needed for this loss calculation.""" + label_requirement = super().label_requirement + if self.has_h: + label_requirement.append( + DataRequirementItem( + "hessian", + ndof=1, # 9=3*3 --> 3N*3N=ndof*natoms*natoms + atomic=True, + must=False, + high_prec=False, + ) + ) + return label_requirement diff --git a/deepmd/dpmodel/utils/env_mat_stat.py b/deepmd/dpmodel/utils/env_mat_stat.py index e25739fa56..278f565a3a 100644 --- a/deepmd/dpmodel/utils/env_mat_stat.py +++ b/deepmd/dpmodel/utils/env_mat_stat.py @@ -119,12 +119,15 @@ def iter( "last_dim should be 1 for raial-only or 4 for full descriptor." 
) for system in data: - coord, atype, box, natoms = ( + coord, atype, box = ( system["coord"], system["atype"], system["box"], - system["natoms"], ) + coord = xp.reshape(coord, (coord.shape[0], -1, 3)) # (nframes, nloc, 3) + atype = xp.reshape(atype, (coord.shape[0], -1)) # (nframes, nloc) + if box is not None: + box = xp.reshape(box, (coord.shape[0], 3, 3)) ( extended_coord, extended_atype, diff --git a/deepmd/dpmodel/utils/learning_rate.py b/deepmd/dpmodel/utils/learning_rate.py index 90c18fca22..9bfc6d2a16 100644 --- a/deepmd/dpmodel/utils/learning_rate.py +++ b/deepmd/dpmodel/utils/learning_rate.py @@ -45,9 +45,8 @@ def __init__( self.decay_rate = decay_rate self.min_lr = stop_lr - def value(self, step) -> np.float64: + def value(self, step, xp=np) -> np.float64: """Get the learning rate at the given step.""" - step_lr = self.start_lr * np.power(self.decay_rate, step // self.decay_steps) - if step_lr < self.min_lr: - step_lr = self.min_lr + step_lr = self.start_lr * xp.power(self.decay_rate, step // self.decay_steps) + step_lr = xp.clip(step_lr, self.min_lr, None) return step_lr diff --git a/deepmd/jax/entrypoints/__init__.py b/deepmd/jax/entrypoints/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/entrypoints/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/entrypoints/freeze.py b/deepmd/jax/entrypoints/freeze.py new file mode 100644 index 0000000000..bd283e8681 --- /dev/null +++ b/deepmd/jax/entrypoints/freeze.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from pathlib import ( + Path, +) + +from deepmd.jax.utils.serialization import ( + deserialize_to_file, + serialize_from_file, +) + + +def freeze( + *, + checkpoint_folder: str, + output: str, + **kwargs, +) -> None: + """Freeze the graph in supplied folder. + + Parameters + ---------- + checkpoint_folder : str + location of either the folder with checkpoint or the checkpoint prefix + output : str + output file name + **kwargs + other arguments + """ + if (Path(checkpoint_folder) / "checkpoint").is_file(): + checkpoint_meta = Path(checkpoint_folder) / "checkpoint" + checkpoint_folder = checkpoint_meta.read_text().strip() + if Path(checkpoint_folder).is_dir(): + data = serialize_from_file(checkpoint_folder) + deserialize_to_file(output, data) + else: + raise FileNotFoundError(f"Checkpoint {checkpoint_folder} does not exist.") diff --git a/deepmd/jax/entrypoints/main.py b/deepmd/jax/entrypoints/main.py new file mode 100644 index 0000000000..6bbb9f08f7 --- /dev/null +++ b/deepmd/jax/entrypoints/main.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""DeePMD-Kit entry point module.""" + +import argparse +from pathlib import ( + Path, +) +from typing import ( + Optional, + Union, +) + +from deepmd.backend.suffix import ( + format_model_suffix, +) +from deepmd.jax.entrypoints.freeze import ( + freeze, +) +from deepmd.jax.entrypoints.train import ( + train, +) +from deepmd.loggers.loggers import ( + set_log_handles, +) +from deepmd.main import ( + parse_args, +) + +__all__ = ["main"] + + +def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None: + """DeePMD-Kit entry point. 
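+
+ Parses command-line arguments (unless an ``argparse.Namespace`` is
+ supplied) and dispatches the ``train`` and ``freeze`` subcommands to
+ their JAX implementations.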
+ + Parameters + ---------- + args : list[str] or argparse.Namespace, optional + list of command line arguments, used to avoid calling from the subprocess, + as it is quite slow to import tensorflow; if Namespace is given, it will + be used directly + + Raises + ------ + RuntimeError + if no command was input + """ + if not isinstance(args, argparse.Namespace): + args = parse_args(args=args) + + dict_args = vars(args) + set_log_handles( + args.log_level, + Path(args.log_path) if args.log_path else None, + mpi_log=None, + ) + + if args.command == "train": + train(**dict_args) + elif args.command == "freeze": + dict_args["output"] = format_model_suffix( + dict_args["output"], preferred_backend=args.backend, strict_prefer=True + ) + freeze(**dict_args) + elif args.command is None: + pass + else: + raise RuntimeError(f"unknown command {args.command}") diff --git a/deepmd/jax/entrypoints/train.py b/deepmd/jax/entrypoints/train.py new file mode 100644 index 0000000000..27b3e54e55 --- /dev/null +++ b/deepmd/jax/entrypoints/train.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""DeePMD training entrypoint script. + +Can handle local or distributed training. +""" + +import json +import logging +import time +from typing import ( + Optional, +) + +from deepmd.common import ( + j_loader, +) +from deepmd.jax.env import ( + jax, + jax_export, +) +from deepmd.jax.train.trainer import ( + DPTrainer, +) +from deepmd.utils import random as dp_random +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) +from deepmd.utils.data_system import ( + get_data, +) +from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter + +__all__ = ["train"] + +log = logging.getLogger(__name__) + + +class SummaryPrinter(BaseSummaryPrinter): + """Summary printer for JAX.""" + + def is_built_with_cuda(self) -> bool: + """Check if the backend is built with CUDA.""" + return jax_export.default_export_platform() == "cuda" + + def is_built_with_rocm(self) -> bool: + """Check if the backend is built with ROCm.""" + return jax_export.default_export_platform() == "rocm" + + def get_compute_device(self) -> str: + """Get Compute device.""" + return jax.default_backend() + + def get_ngpus(self) -> int: + """Get the number of GPUs.""" + return jax.device_count() + + def get_backend_info(self) -> dict: + """Get backend information.""" + return { + "Backend": "JAX", + "JAX ver": jax.__version__, + } + + +def train( + *, + INPUT: str, + init_model: Optional[str], + restart: Optional[str], + output: str, + init_frz_model: str, + mpi_log: str, + log_level: int, + log_path: Optional[str], + skip_neighbor_stat: bool = False, + finetune: Optional[str] = None, + use_pretrain_script: bool = False, + **kwargs, +) -> None: + """Run DeePMD model training. 
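+
+ Loads and normalizes the input script, constructs a ``DPTrainer``,
+ prepares the training (and optional validation) data systems, and runs
+ the training loop.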
+ + Parameters + ---------- + INPUT : str + json/yaml control file + init_model : Optional[str] + path prefix of checkpoint files or None + restart : Optional[str] + path prefix of checkpoint files or None + output : str + path for dump file with arguments + init_frz_model : str + path to frozen model or None + mpi_log : str + mpi logging mode + log_level : int + logging level defined by int 0-3 + log_path : Optional[str] + logging file path or None if logs are to be output only to stdout + skip_neighbor_stat : bool, default=False + skip checking neighbor statistics + finetune : Optional[str] + path to pretrained model or None + use_pretrain_script : bool + Whether to use model script in pretrained model when doing init-model or init-frz-model. + Note that this option is true and unchangeable for fine-tuning. + **kwargs + additional arguments + + Raises + ------ + RuntimeError + if distributed training job name is wrong + """ + # load json database + jdata = j_loader(INPUT) + + origin_type_map = None + + jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json") + + jdata = normalize(jdata) + jdata = update_sel(jdata) + + with open(output, "w") as fp: + json.dump(jdata, fp, indent=4) + SummaryPrinter()() + + # make necessary checks + assert "training" in jdata + + # init the model + + model = DPTrainer( + jdata, + init_model=init_model, + restart=restart, + ) + rcut = model.model.get_rcut() + type_map = model.model.get_type_map() + if len(type_map) == 0: + ipt_type_map = None + else: + ipt_type_map = type_map + + # init random seed of data systems + seed = jdata["training"].get("seed", None) + if seed is not None: + seed = seed % (2**32) + dp_random.seed(seed) + + # init data + train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, None) + train_data.add_data_requirements(model.data_requirements) + train_data.print_summary("training") + if jdata["training"].get("validation_data", None) is not None: + valid_data = get_data( + jdata["training"]["validation_data"], + rcut, + train_data.type_map, + None, + ) + valid_data.add_data_requirements(model.data_requirements) + valid_data.print_summary("validation") + else: + valid_data = None + + # get training info + stop_batch = jdata["training"]["numb_steps"] + origin_type_map = jdata["model"].get("origin_type_map", None) + if ( + origin_type_map is not None and not origin_type_map + ): # get the type_map from data if not provided + origin_type_map = get_data( + jdata["training"]["training_data"], rcut, None, None + ).get_type_map() + + # train the model with the provided systems in a cyclic way + start_time = time.time() + model.train(train_data, valid_data) + end_time = time.time() + log.info("finished training") + log.info(f"wall time: {(end_time - start_time):.3f} s") + + +def update_sel(jdata): + log.info( + "Calculate neighbor statistics... 
(add --skip-neighbor-stat to skip this step)" + ) + jdata_cpy = jdata.copy() + type_map = jdata["model"].get("type_map") + train_data = get_data( + jdata["training"]["training_data"], + 0, # not used + type_map, + None, # not used + ) + # TODO: OOM, need debug + # jdata_cpy["model"], min_nbor_dist = BaseModel.update_sel( + # train_data, type_map, jdata["model"] + # ) + return jdata_cpy diff --git a/deepmd/jax/train/__init__.py b/deepmd/jax/train/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/train/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py new file mode 100644 index 0000000000..408f67da9f --- /dev/null +++ b/deepmd/jax/train/trainer.py @@ -0,0 +1,459 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +import shutil +import time +from pathlib import ( + Path, +) +from typing import ( + Optional, +) + +import numpy as np +import optax +import orbax.checkpoint as ocp + +from deepmd.common import ( + symlink_prefix_files, +) +from deepmd.dpmodel.loss.ener import ( + EnergyHessianLoss, + EnergyLoss, +) +from deepmd.dpmodel.model.transform_output import ( + communicate_extended_output, +) +from deepmd.dpmodel.utils.learning_rate import ( + LearningRateExp, +) +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) +from deepmd.jax.env import ( + jnp, + nnx, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) +from deepmd.jax.model.model import ( + get_model, +) +from deepmd.jax.utils.serialization import ( + serialize_from_file, +) +from deepmd.loggers.training import ( + format_training_message, + format_training_message_per_task, +) +from deepmd.utils.data import ( + DataRequirementItem, +) +from deepmd.utils.model_stat import ( + make_stat_input, +) + +log = logging.getLogger(__name__) + + +class DPTrainer: + def __init__( + self, + jdata, + init_model: Optional[str] = None, + restart: Optional[str] = None, + ) -> None: + self.init_model = init_model + self.restart = restart + self.model_def_script = jdata["model"] + self.start_step = 0 + if self.init_model is not None: + model_dict = serialize_from_file(self.init_model) + self.model = BaseModel.deserialize(model_dict["model"]) + elif self.restart is not None: + model_dict = serialize_from_file(self.restart) + self.model = BaseModel.deserialize(model_dict["model"]) + self.start_step = model_dict["@variables"].get("current_step", 0) + else: + # from scratch + self.model = get_model(jdata["model"]) + self.training_param = jdata["training"] + self.num_steps = self.training_param["numb_steps"] + + def get_lr_and_coef(lr_param): + lr_type = lr_param.get("type", "exp") + if lr_type == "exp": + lr = LearningRateExp( + lr_param["start_lr"], + lr_param["stop_lr"], + lr_param["decay_steps"], + self.num_steps, + ) + else: + raise RuntimeError("unknown learning_rate type " + lr_type) + return lr + + learning_rate_param = jdata["learning_rate"] + self.lr = get_lr_and_coef(learning_rate_param) + loss_param = jdata.get("loss", {}) + loss_param["starter_learning_rate"] = learning_rate_param["start_lr"] + + loss_type = loss_param.get("type", "ener") + if loss_type == "ener" and loss_param.get("start_pref_h", 0.0) > 0.0: + self.loss = EnergyHessianLoss.get_loss(loss_param) + self.model.enable_hessian() + elif loss_type == "ener": + self.loss = 
EnergyLoss.get_loss(loss_param) + else: + raise RuntimeError("unknown loss type " + loss_type) + + # training + tr_data = jdata["training"] + self.disp_file = tr_data.get("disp_file", "lcurve.out") + self.disp_freq = tr_data.get("disp_freq", 1000) + self.save_freq = tr_data.get("save_freq", 1000) + self.save_ckpt = tr_data.get("save_ckpt", "model.ckpt") + self.max_ckpt_keep = tr_data.get("max_ckpt_keep", 5) + self.display_in_training = tr_data.get("disp_training", True) + self.timing_in_training = tr_data.get("time_training", True) + self.profiling = tr_data.get("profiling", False) + self.profiling_file = tr_data.get("profiling_file", "timeline.json") + self.enable_profiler = tr_data.get("enable_profiler", False) + self.tensorboard = tr_data.get("tensorboard", False) + self.tensorboard_log_dir = tr_data.get("tensorboard_log_dir", "log") + self.tensorboard_freq = tr_data.get("tensorboard_freq", 1) + self.mixed_prec = tr_data.get("mixed_precision", None) + self.change_bias_after_training = tr_data.get( + "change_bias_after_training", False + ) + self.numb_fparam = self.model.get_dim_fparam() + + if tr_data.get("validation_data", None) is not None: + self.valid_numb_batch = tr_data["validation_data"].get("numb_btch", 1) + else: + self.valid_numb_batch = 1 + + # if init the graph with the frozen model + self.frz_model = None + self.ckpt_meta = None + self.model_type = None + + @property + def data_requirements(self) -> list[DataRequirementItem]: + return self.loss.label_requirement + + def train(self, train_data, valid_data=None) -> None: + model = self.model + tx = optax.adam( + learning_rate=lambda step: self.lr.value(self.start_step + step, xp=jnp), + ) + optimizer = nnx.Optimizer(model, tx) + + # data stat + if self.init_model is None and self.restart is None: + data_stat_nbatch = 10 # TODO + all_stat = make_stat_input(train_data, data_stat_nbatch, merge_sys=False) + all_stat["atype"] = all_stat.pop("type") + + # swap dict key and list idx + all_stat_sys = [ + { + kk: jnp.asarray(np.concatenate(vv[ii], axis=0)) + for kk, vv in all_stat.items() + if not kk.startswith("find_") + } + for ii in range(train_data.get_nsystems()) + ] + model.atomic_model.descriptor.compute_input_stats(all_stat_sys) + model.atomic_model.fitting.compute_output_stats(all_stat) + + def loss_fn( + model, + lr, + label_dict, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ): + model_dict_lower = model.call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + model_dict = communicate_extended_output( + model_dict_lower, + model.model_output_def(), + mapping, + do_atomic_virial=False, + ) + loss, more_loss = self.loss( + learning_rate=lr, + natoms=label_dict["coord"].shape[1], + model_dict=model_dict, + label_dict=label_dict, + ) + return loss + + @nnx.jit + def loss_fn_more_loss( + model, + lr, + label_dict, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ): + model_dict_lower = model.call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + model_dict = communicate_extended_output( + model_dict_lower, + model.model_output_def(), + mapping, + do_atomic_virial=False, + ) + loss, more_loss = self.loss( + learning_rate=lr, + natoms=label_dict["coord"].shape[1], + model_dict=model_dict, + label_dict=label_dict, + ) + return more_loss + + @nnx.jit + def train_step( + model, + optimizer, + lr, + label_dict, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ): + grads = nnx.grad(loss_fn)( + model, + lr, + 
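+ # nnx.grad differentiates loss_fn with respect to its first argument
+ # (the model), returning gradients structured like the model's nnx state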
label_dict, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + optimizer.update(grads) + + start_time = time.time() + disp_file_fp = open(self.disp_file, "w") + for step in range(self.start_step, self.num_steps): + batch_data = train_data.get_batch() + # numpy to jax + jax_data = { + kk: jnp.asarray(vv) if not kk.startswith("find_") else bool(vv.item()) + for kk, vv in batch_data.items() + } + extended_coord, extended_atype, nlist, mapping, fp, ap = prepare_input( + rcut=model.get_rcut(), + sel=model.get_sel(), + coord=jax_data["coord"], + atype=jax_data["type"], + box=jax_data["box"] if jax_data["find_box"] else None, + fparam=jax_data.get("fparam", None), + aparam=jax_data.get("aparam", None), + ) + train_step( + model, + optimizer, + self.lr.value(step), + jax_data, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + if self.display_in_training and ( + step == 0 or (step + 1) % self.disp_freq == 0 + ): + wall_time = time.time() - start_time + log.info( + format_training_message( + batch=step + 1, + wall_time=wall_time, + ) + ) + more_loss = loss_fn_more_loss( + optimizer.model, + self.lr.value(step), + jax_data, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + if valid_data is not None: + valid_batch_data = valid_data.get_batch() + jax_valid_data = { + kk: jnp.asarray(vv) for kk, vv in valid_batch_data.items() + } + extended_coord, extended_atype, nlist, mapping, fp, ap = ( + prepare_input( + rcut=model.get_rcut(), + sel=model.get_sel(), + coord=jax_valid_data["coord"], + atype=jax_valid_data["type"], + box=jax_valid_data["box"] + if jax_valid_data["find_box"] + else None, + fparam=jax_valid_data.get("fparam", None), + aparam=jax_valid_data.get("aparam", None), + ) + ) + valid_more_loss = loss_fn_more_loss( + optimizer.model, + self.lr.value(step), + jax_valid_data, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + else: + valid_more_loss = None + self.print_on_training( + disp_file_fp, + train_results=more_loss, + valid_results=valid_more_loss, + cur_batch=step + 1, + cur_lr=self.lr.value(step), + ) + start_time = time.time() + if (step + 1) % self.save_freq == 0: + # save model + _, state = nnx.split(model) + ckpt_path = Path(f"{self.save_ckpt}-{step + 1}.jax") + if ckpt_path.is_dir(): + # remove old checkpoint if it exists + shutil.rmtree(ckpt_path) + model_def_script_cpy = self.model_def_script.copy() + model_def_script_cpy["current_step"] = step + 1 + with ocp.Checkpointer( + ocp.CompositeCheckpointHandler("state", "model_def_script") + ) as checkpointer: + checkpointer.save( + ckpt_path.absolute(), + ocp.args.Composite( + state=ocp.args.StandardSave(state.to_pure_dict()), + model_def_script=ocp.args.JsonSave(model_def_script_cpy), + ), + ) + log.info(f"Trained model has been saved to: {ckpt_path!s}") + symlink_prefix_files(f"{self.save_ckpt}-{step + 1}", self.save_ckpt) + with open("checkpoint", "w") as fp: + fp.write(f"{self.save_ckpt}.jax") + + disp_file_fp.close() + + @staticmethod + def print_on_training( + fp, + train_results, + valid_results, + cur_batch, + cur_lr, + ) -> None: + print_str = "" + print_str += f"{cur_batch:7d}" + if valid_results is not None: + prop_fmt = " %11.2e %11.2e" + for k in valid_results.keys(): + # assert k in train_results.keys() + print_str += prop_fmt % (valid_results[k], train_results[k]) + else: + prop_fmt = " %11.2e" + for k in train_results.keys(): + print_str += prop_fmt % (train_results[k]) + print_str += f" {cur_lr:8.1e}\n" + log.info( + 
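+ # mirror the training metrics to the logger; the same numbers are
+ # written to the lcurve file (disp_file) via print_str below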
format_training_message_per_task( + batch=cur_batch, + task_name="trn", + rmse=train_results, + learning_rate=cur_lr, + ) + ) + if valid_results is not None: + log.info( + format_training_message_per_task( + batch=cur_batch, + task_name="val", + rmse=valid_results, + learning_rate=None, + ) + ) + fp.write(print_str) + fp.flush() + + +def prepare_input( + *, # enforce keyword-only arguments + rcut: float, + sel: list[int], + coord: np.ndarray, + atype: np.ndarray, + box: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, +): + nframes, nloc = atype.shape[:2] + cc, bb, fp, ap = coord, box, fparam, aparam + del coord, box, fparam, aparam + if bb is not None: + coord_normalized = normalize_coord( + cc.reshape(nframes, nloc, 3), + bb.reshape(nframes, 3, 3), + ) + else: + coord_normalized = cc.copy() + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, bb, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + # types will be distinguished in the lower interface, + # so it doesn't need to be distinguished here + distinguish_types=False, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + return extended_coord, extended_atype, nlist, mapping, fp, ap diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 5d4da49e08..454affba31 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -177,6 +177,7 @@ def convert_str_to_int_key(item: dict) -> None: convert_str_to_int_key(state) model_def_script = data.model_def_script + current_step = model_def_script.pop("current_step", 0) abstract_model = get_model(model_def_script) graphdef, abstract_state = nnx.split(abstract_model) abstract_state.replace_by_pure_dict(state) @@ -187,7 +188,9 @@ def convert_str_to_int_key(item: dict) -> None: "jax_version": jax.__version__, "model": model_dict, "model_def_script": model_def_script, - "@variables": {}, + "@variables": { + "current_step": current_step, + }, } return data elif model_file.endswith(".hlo"):
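# NOTE: "current_step" is written into model_def_script by the trainer at
# save time (see trainer.py above) and surfaced here under "@variables",
# so DPTrainer can resume the step counter and learning-rate schedule
# when restart is given.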