-
Notifications
You must be signed in to change notification settings - Fork 556
feat: JAX training #4782
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: devel
Are you sure you want to change the base?
feat: JAX training #4782
Changes from all commits
c3e9ce9
ac5a7a3
b99d66d
ad20de9
c62c356
d0a1ce7
e88838b
29922d7
d5d5f06
6947ee6
9218157
e3dca7a
1fdb40c
1474327
d7f06d6
68b3727
7930827
a0cd67a
d21c39c
15bb506
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -177,7 +177,9 @@ | |||||
delta=self.huber_delta, | ||||||
) | ||||||
loss += pref_e * l_huber_loss | ||||||
more_loss["rmse_e"] = self.display_if_exist(l2_ener_loss, find_energy) | ||||||
more_loss["rmse_e"] = self.display_if_exist( | ||||||
xp.sqrt(l2_ener_loss) * atom_norm_ener, find_energy | ||||||
) | ||||||
if self.has_f: | ||||||
l2_force_loss = xp.mean(xp.square(diff_f)) | ||||||
if not self.use_huber: | ||||||
|
@@ -189,7 +191,9 @@ | |||||
delta=self.huber_delta, | ||||||
) | ||||||
loss += pref_f * l_huber_loss | ||||||
more_loss["rmse_f"] = self.display_if_exist(l2_force_loss, find_force) | ||||||
more_loss["rmse_f"] = self.display_if_exist( | ||||||
xp.sqrt(l2_force_loss), find_force | ||||||
) | ||||||
if self.has_v: | ||||||
virial_reshape = xp.reshape(virial, [-1]) | ||||||
virial_hat_reshape = xp.reshape(virial_hat, [-1]) | ||||||
|
@@ -381,3 +385,79 @@ | |||||
check_version_compatibility(data.pop("@version"), 2, 1) | ||||||
data.pop("@class") | ||||||
return cls(**data) | ||||||
|
||||||
|
||||||
class EnergyHessianLoss(EnergyLoss): | ||||||
def __init__( | ||||||
self, | ||||||
start_pref_h=0.0, | ||||||
limit_pref_h=0.0, | ||||||
**kwargs, | ||||||
): | ||||||
r"""Enable the layer to compute loss on hessian. | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
start_pref_h : float | ||||||
The prefactor of hessian loss at the start of the training. | ||||||
limit_pref_h : float | ||||||
The prefactor of hessian loss at the end of the training. | ||||||
**kwargs | ||||||
Other keyword arguments. | ||||||
""" | ||||||
EnergyLoss.__init__(self, **kwargs) | ||||||
self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix logical condition for enabling Hessian loss. The condition uses Apply this fix: - self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
+ self.has_h = start_pref_h != 0.0 or limit_pref_h != 0.0 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||
|
||||||
self.start_pref_h = start_pref_h | ||||||
self.limit_pref_h = limit_pref_h | ||||||
|
||||||
def call( | ||||||
self, | ||||||
learning_rate: float, | ||||||
natoms: int, | ||||||
model_dict: dict[str, np.ndarray], | ||||||
label_dict: dict[str, np.ndarray], | ||||||
) -> dict[str, np.ndarray]: | ||||||
"""Calculate loss from model results and labeled results.""" | ||||||
loss, more_loss = EnergyLoss.call( | ||||||
self, learning_rate, natoms, model_dict, label_dict | ||||||
) | ||||||
xp = array_api_compat.array_namespace(model_dict["energy"]) | ||||||
coef = learning_rate / self.starter_learning_rate | ||||||
pref_h = self.limit_pref_h + (self.start_pref_h - self.limit_pref_h) * coef | ||||||
|
||||||
if ( | ||||||
self.has_h | ||||||
and "energy_derv_r_derv_r" in model_dict | ||||||
and "hessian" in label_dict | ||||||
): | ||||||
find_hessian = label_dict.get("find_hessian", 0.0) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainVerify find_hessian data type consistency. The code uses Let me verify the expected data type for find_* variables: 🏁 Script executed: #!/bin/bash
# Search for find_* variable usage patterns in loss calculations
rg -B 2 -A 2 'find_[a-z_]+\s*=' deepmd/dpmodel/loss/ --type py Length of output: 1419 🏁 Script executed: #!/bin/bash
# Search for DataRequirementItem definitions that might indicate find_* data structure
ast-grep --pattern 'DataRequirementItem(
"find_$_",
$$$
)' Length of output: 64 Ensure The other • File deepmd/dpmodel/loss/ener.py, line 434 - find_hessian = label_dict.get("find_hessian", 0.0)
+ find_hessian = label_dict["find_hessian"] If you do need a default, consider instead: find_hessian = label_dict.get(
"find_hessian",
xp.ones(label_dict["hessian"].shape[0]) # match per-frame flag shape
) 🤖 Prompt for AI Agents
|
||||||
pref_h = pref_h * find_hessian | ||||||
diff_h = label_dict["hessian"].reshape( | ||||||
-1, | ||||||
) - model_dict["energy_derv_r_derv_r"].reshape( | ||||||
-1, | ||||||
) | ||||||
l2_hessian_loss = xp.mean(xp.square(diff_h)) | ||||||
loss += pref_h * l2_hessian_loss | ||||||
rmse_h = xp.sqrt(l2_hessian_loss) | ||||||
more_loss["rmse_h"] = self.display_if_exist(rmse_h, find_hessian) | ||||||
|
||||||
more_loss["rmse"] = xp.sqrt(loss) | ||||||
return loss, more_loss | ||||||
|
||||||
@property | ||||||
def label_requirement(self) -> list[DataRequirementItem]: | ||||||
"""Add hessian label requirement needed for this loss calculation.""" | ||||||
label_requirement = super().label_requirement | ||||||
if self.has_h: | ||||||
label_requirement.append( | ||||||
DataRequirementItem( | ||||||
"hessian", | ||||||
ndof=1, # 9=3*3 --> 3N*3N=ndof*natoms*natoms | ||||||
atomic=True, | ||||||
must=False, | ||||||
high_prec=False, | ||||||
) | ||||||
) | ||||||
return label_requirement | ||||||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from pathlib import ( | ||
Path, | ||
) | ||
|
||
from deepmd.jax.utils.serialization import ( | ||
deserialize_to_file, | ||
serialize_from_file, | ||
) | ||
|
||
|
||
def freeze( | ||
*, | ||
checkpoint_folder: str, | ||
output: str, | ||
**kwargs, | ||
) -> None: | ||
"""Freeze the graph in supplied folder. | ||
|
||
Parameters | ||
---------- | ||
checkpoint_folder : str | ||
location of either the folder with checkpoint or the checkpoint prefix | ||
output : str | ||
output file name | ||
**kwargs | ||
other arguments | ||
""" | ||
if (Path(checkpoint_folder) / "checkpoint").is_file(): | ||
checkpoint_meta = Path(checkpoint_folder) / "checkpoint" | ||
checkpoint_folder = checkpoint_meta.read_text().strip() | ||
if Path(checkpoint_folder).is_dir(): | ||
data = serialize_from_file(checkpoint_folder) | ||
deserialize_to_file(output, data) | ||
else: | ||
raise FileNotFoundError(f"Checkpoint {checkpoint_folder} does not exist.") | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
"""DeePMD-Kit entry point module.""" | ||
|
||
import argparse | ||
from pathlib import ( | ||
Path, | ||
) | ||
from typing import ( | ||
Optional, | ||
Union, | ||
) | ||
|
||
from deepmd.backend.suffix import ( | ||
format_model_suffix, | ||
) | ||
from deepmd.jax.entrypoints.freeze import ( | ||
freeze, | ||
) | ||
from deepmd.jax.entrypoints.train import ( | ||
train, | ||
) | ||
from deepmd.loggers.loggers import ( | ||
set_log_handles, | ||
) | ||
from deepmd.main import ( | ||
parse_args, | ||
) | ||
|
||
__all__ = ["main"] | ||
|
||
|
||
def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None: | ||
"""DeePMD-Kit entry point. | ||
|
||
Parameters | ||
---------- | ||
args : list[str] or argparse.Namespace, optional | ||
list of command line arguments, used to avoid calling from the subprocess, | ||
as it is quite slow to import tensorflow; if Namespace is given, it will | ||
be used directly | ||
|
||
Raises | ||
------ | ||
RuntimeError | ||
if no command was input | ||
""" | ||
if not isinstance(args, argparse.Namespace): | ||
args = parse_args(args=args) | ||
|
||
dict_args = vars(args) | ||
set_log_handles( | ||
args.log_level, | ||
Path(args.log_path) if args.log_path else None, | ||
mpi_log=None, | ||
) | ||
|
||
if args.command == "train": | ||
train(**dict_args) | ||
elif args.command == "freeze": | ||
dict_args["output"] = format_model_suffix( | ||
dict_args["output"], preferred_backend=args.backend, strict_prefer=True | ||
) | ||
freeze(**dict_args) | ||
elif args.command is None: | ||
pass | ||
else: | ||
raise RuntimeError(f"unknown command {args.command}") | ||
Comment on lines
+64
to
+67
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Improve command handling for better error reporting. The current implementation silently passes when no command is provided, which might hide configuration issues. Additionally, the error message for unknown commands could be more helpful. Apply this diff to improve error handling: - elif args.command is None:
- pass
+ elif args.command is None:
+ raise RuntimeError("No command specified. Available commands: train, freeze")
else:
- raise RuntimeError(f"unknown command {args.command}")
+ raise RuntimeError(
+ f"Unknown command '{args.command}'. Available commands: train, freeze"
+ ) 🤖 Prompt for AI Agents
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix undefined attribute reference.
The code references
self.atom_ener_v
which is not defined in the class. This will cause anAttributeError
at runtime.Based on the similar implementation in
deepmd/tf/fit/ener.py
, this should likely beself.atom_ener
:📝 Committable suggestion
🤖 Prompt for AI Agents