Skip to content

feat: JAX training #4782

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 20 commits into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deepmd/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@
Callable[[Namespace], None]
The entry point hook of the backend.
"""
raise NotImplementedError
from deepmd.jax.entrypoints.main import (

Check warning on line 63 in deepmd/backend/jax.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/jax.py#L63

Added line #L63 was not covered by tests
main,
)

return main

Check warning on line 67 in deepmd/backend/jax.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/jax.py#L67

Added line #L67 was not covered by tests

@property
def deep_eval(self) -> type["DeepEvalBackend"]:
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,7 @@ def __init__(
self.mean = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision])
self.stddev = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision])
self.orig_sel = self.sel
self.ndescrpt = self.nnei * 4

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down
72 changes: 72 additions & 0 deletions deepmd/dpmodel/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Union,
)

import numpy as np

from deepmd.dpmodel.common import (
DEFAULT_PRECISION,
)
Expand All @@ -17,6 +19,10 @@
from deepmd.dpmodel.fitting.general_fitting import (
GeneralFitting,
)

from deepmd.utils.out_stat import (
compute_stats_from_redu,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -86,3 +92,69 @@
**super().serialize(),
"type": "ener",
}

def compute_output_stats(self, all_stat: dict, mixed_type: bool = False) -> None:
"""Compute the output statistics.

Parameters
----------
all_stat
must have the following components:
all_stat['energy'] of shape n_sys x n_batch x n_frame
can be prepared by model.make_stat_input
mixed_type
Whether to perform the mixed_type mode.
If True, the input data has the mixed_type format (see doc/model/train_se_atten.md),
in which frames in a system may have different natoms_vec(s), with the same nloc.
"""
self.bias_atom_e = self._compute_output_stats(

Check warning on line 110 in deepmd/dpmodel/fitting/ener_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/ener_fitting.py#L110

Added line #L110 was not covered by tests
all_stat, rcond=self.rcond, mixed_type=mixed_type
)

def _compute_output_stats(self, all_stat, rcond=1e-3, mixed_type=False):
data = all_stat["energy"]

Check warning on line 115 in deepmd/dpmodel/fitting/ener_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/ener_fitting.py#L115

Added line #L115 was not covered by tests
# data[sys_idx][batch_idx][frame_idx]
sys_ener = []
for ss in range(len(data)):
sys_data = []
for ii in range(len(data[ss])):
for jj in range(len(data[ss][ii])):
sys_data.append(data[ss][ii][jj])
sys_data = np.concatenate(sys_data)
sys_ener.append(np.average(sys_data))
sys_ener = np.array(sys_ener)
sys_tynatom = []
if mixed_type:
data = all_stat["real_natoms_vec"]
nsys = len(data)
for ss in range(len(data)):
tmp_tynatom = []
for ii in range(len(data[ss])):
for jj in range(len(data[ss][ii])):
tmp_tynatom.append(data[ss][ii][jj].astype(np.float64))
tmp_tynatom = np.average(np.array(tmp_tynatom), axis=0)
sys_tynatom.append(tmp_tynatom)

Check warning on line 136 in deepmd/dpmodel/fitting/ener_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/ener_fitting.py#L117-L136

Added lines #L117 - L136 were not covered by tests
else:
data = all_stat["natoms_vec"]
nsys = len(data)
for ss in range(len(data)):
sys_tynatom.append(data[ss][0].astype(np.float64))
sys_tynatom = np.array(sys_tynatom)
sys_tynatom = np.reshape(sys_tynatom, [nsys, -1])
sys_tynatom = sys_tynatom[:, 2:]
if len(self.atom_ener) > 0:

Check warning on line 145 in deepmd/dpmodel/fitting/ener_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/ener_fitting.py#L138-L145

Added lines #L138 - L145 were not covered by tests
# Atomic energies stats are incorrect if atomic energies are assigned.
# In this situation, we directly use these assigned energies instead of computing stats.
# This will make the loss decrease quickly
assigned_atom_ener = np.array(

Check warning on line 149 in deepmd/dpmodel/fitting/ener_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/ener_fitting.py#L149

Added line #L149 was not covered by tests
[ee if ee is not None else np.nan for ee in self.atom_ener_v]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix undefined attribute reference.

The code references self.atom_ener_v which is not defined in the class. This will cause an AttributeError at runtime.

Based on the similar implementation in deepmd/tf/fit/ener.py, this should likely be self.atom_ener:

-                [ee if ee is not None else np.nan for ee in self.atom_ener_v]
+                [ee if ee is not None else np.nan for ee in self.atom_ener]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
[ee if ee is not None else np.nan for ee in self.atom_ener_v]
[ee if ee is not None else np.nan for ee in self.atom_ener]
🤖 Prompt for AI Agents
In deepmd/dpmodel/fitting/ener_fitting.py at line 150, the code references an
undefined attribute self.atom_ener_v, causing an AttributeError. Replace
self.atom_ener_v with self.atom_ener to match the correct attribute name used in
the class, following the pattern from deepmd/tf/fit/ener.py.

)
else:
assigned_atom_ener = None
energy_shift, _ = compute_stats_from_redu(

Check warning on line 154 in deepmd/dpmodel/fitting/ener_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/ener_fitting.py#L153-L154

Added lines #L153 - L154 were not covered by tests
sys_ener.reshape(-1, 1),
sys_tynatom,
assigned_bias=assigned_atom_ener,
rcond=rcond,
)
return energy_shift.ravel()

Check warning on line 160 in deepmd/dpmodel/fitting/ener_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/ener_fitting.py#L160

Added line #L160 was not covered by tests
84 changes: 82 additions & 2 deletions deepmd/dpmodel/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,9 @@
delta=self.huber_delta,
)
loss += pref_e * l_huber_loss
more_loss["rmse_e"] = self.display_if_exist(l2_ener_loss, find_energy)
more_loss["rmse_e"] = self.display_if_exist(
xp.sqrt(l2_ener_loss) * atom_norm_ener, find_energy
)
if self.has_f:
l2_force_loss = xp.mean(xp.square(diff_f))
if not self.use_huber:
Expand All @@ -189,7 +191,9 @@
delta=self.huber_delta,
)
loss += pref_f * l_huber_loss
more_loss["rmse_f"] = self.display_if_exist(l2_force_loss, find_force)
more_loss["rmse_f"] = self.display_if_exist(
xp.sqrt(l2_force_loss), find_force
)
if self.has_v:
virial_reshape = xp.reshape(virial, [-1])
virial_hat_reshape = xp.reshape(virial_hat, [-1])
Expand Down Expand Up @@ -381,3 +385,79 @@
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
return cls(**data)


class EnergyHessianLoss(EnergyLoss):
def __init__(
self,
start_pref_h=0.0,
limit_pref_h=0.0,
**kwargs,
):
r"""Enable the layer to compute loss on hessian.

Parameters
----------
start_pref_h : float
The prefactor of hessian loss at the start of the training.
limit_pref_h : float
The prefactor of hessian loss at the end of the training.
**kwargs
Other keyword arguments.
"""
EnergyLoss.__init__(self, **kwargs)
self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0

Check warning on line 409 in deepmd/dpmodel/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/loss/ener.py#L408-L409

Added lines #L408 - L409 were not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix logical condition for enabling Hessian loss.

The condition uses and which requires both prefactors to be non-zero. This is likely too restrictive - the Hessian loss should be enabled if either prefactor is non-zero.

Apply this fix:

-        self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
+        self.has_h = start_pref_h != 0.0 or limit_pref_h != 0.0
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0
self.has_h = start_pref_h != 0.0 or limit_pref_h != 0.0
🤖 Prompt for AI Agents
In deepmd/dpmodel/loss/ener.py at line 409, the condition for enabling Hessian
loss uses 'and' to check if both start_pref_h and limit_pref_h are non-zero,
which is too restrictive. Change the logical operator from 'and' to 'or' so that
the Hessian loss is enabled if either start_pref_h or limit_pref_h is non-zero.


self.start_pref_h = start_pref_h
self.limit_pref_h = limit_pref_h

Check warning on line 412 in deepmd/dpmodel/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/loss/ener.py#L411-L412

Added lines #L411 - L412 were not covered by tests

def call(
self,
learning_rate: float,
natoms: int,
model_dict: dict[str, np.ndarray],
label_dict: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
"""Calculate loss from model results and labeled results."""
loss, more_loss = EnergyLoss.call(

Check warning on line 422 in deepmd/dpmodel/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/loss/ener.py#L422

Added line #L422 was not covered by tests
self, learning_rate, natoms, model_dict, label_dict
)
xp = array_api_compat.array_namespace(model_dict["energy"])
coef = learning_rate / self.starter_learning_rate
pref_h = self.limit_pref_h + (self.start_pref_h - self.limit_pref_h) * coef

Check warning on line 427 in deepmd/dpmodel/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/loss/ener.py#L425-L427

Added lines #L425 - L427 were not covered by tests

if (

Check warning on line 429 in deepmd/dpmodel/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/loss/ener.py#L429

Added line #L429 was not covered by tests
self.has_h
and "energy_derv_r_derv_r" in model_dict
and "hessian" in label_dict
):
find_hessian = label_dict.get("find_hessian", 0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Verify find_hessian data type consistency.

The code uses label_dict.get("find_hessian", 0.0) which returns a scalar default, but other find_* variables in the parent class appear to be arrays from the labeled data.

Let me verify the expected data type for find_* variables:


🏁 Script executed:

#!/bin/bash
# Search for find_* variable usage patterns in loss calculations
rg -B 2 -A 2 'find_[a-z_]+\s*=' deepmd/dpmodel/loss/ --type py

Length of output: 1419


🏁 Script executed:

#!/bin/bash
# Search for DataRequirementItem definitions that might indicate find_* data structure
ast-grep --pattern 'DataRequirementItem(
  "find_$_",
  $$$
)'

Length of output: 64


Ensure find_hessian is accessed like the other find_* flags

The other find_* variables are pulled directly from label_dict (no default), so using .get(..., 0.0) here is inconsistent and may silently disable Hessian loss even when "hessian" is present. Please update this so that missing flags are surfaced (or, if you really intend an optional flag, default to a matching array of ones).

• File deepmd/dpmodel/loss/ener.py, line 434

-    find_hessian = label_dict.get("find_hessian", 0.0)
+    find_hessian = label_dict["find_hessian"]

If you do need a default, consider instead:

find_hessian = label_dict.get(
    "find_hessian",
    xp.ones(label_dict["hessian"].shape[0])  # match per-frame flag shape
)
🤖 Prompt for AI Agents
In deepmd/dpmodel/loss/ener.py at line 434, the assignment of find_hessian uses
label_dict.get with a scalar default 0.0, which is inconsistent with other
find_* variables that do not use defaults and are arrays. To fix this, remove
the default value so that missing keys raise an error or, if a default is
necessary, set it to an array of ones matching the shape of
label_dict["hessian"]. This ensures data type consistency and proper handling of
the find_hessian flag.

pref_h = pref_h * find_hessian
diff_h = label_dict["hessian"].reshape(

Check warning on line 436 in deepmd/dpmodel/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/loss/ener.py#L434-L436

Added lines #L434 - L436 were not covered by tests
-1,
) - model_dict["energy_derv_r_derv_r"].reshape(
-1,
)
l2_hessian_loss = xp.mean(xp.square(diff_h))
loss += pref_h * l2_hessian_loss
rmse_h = xp.sqrt(l2_hessian_loss)
more_loss["rmse_h"] = self.display_if_exist(rmse_h, find_hessian)

Check warning on line 444 in deepmd/dpmodel/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/loss/ener.py#L441-L444

Added lines #L441 - L444 were not covered by tests

more_loss["rmse"] = xp.sqrt(loss)
return loss, more_loss

Check warning on line 447 in deepmd/dpmodel/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/loss/ener.py#L446-L447

Added lines #L446 - L447 were not covered by tests

@property
def label_requirement(self) -> list[DataRequirementItem]:
"""Add hessian label requirement needed for this loss calculation."""
label_requirement = super().label_requirement
if self.has_h:
label_requirement.append(

Check warning on line 454 in deepmd/dpmodel/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/loss/ener.py#L452-L454

Added lines #L452 - L454 were not covered by tests
DataRequirementItem(
"hessian",
ndof=1, # 9=3*3 --> 3N*3N=ndof*natoms*natoms
atomic=True,
must=False,
high_prec=False,
)
)
return label_requirement

Check warning on line 463 in deepmd/dpmodel/loss/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/loss/ener.py#L463

Added line #L463 was not covered by tests
7 changes: 5 additions & 2 deletions deepmd/dpmodel/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,15 @@ def iter(
"last_dim should be 1 for raial-only or 4 for full descriptor."
)
for system in data:
coord, atype, box, natoms = (
coord, atype, box = (
system["coord"],
system["atype"],
system["box"],
system["natoms"],
)
coord = xp.reshape(coord, (coord.shape[0], -1, 3)) # (nframes, nloc, 3)
atype = xp.reshape(atype, (coord.shape[0], -1)) # (nframes, nloc)
if box is not None:
box = xp.reshape(box, (coord.shape[0], 3, 3))
(
extended_coord,
extended_atype,
Expand Down
7 changes: 3 additions & 4 deletions deepmd/dpmodel/utils/learning_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ def __init__(
self.decay_rate = decay_rate
self.min_lr = stop_lr

def value(self, step) -> np.float64:
def value(self, step, xp=np) -> np.float64:
"""Get the learning rate at the given step."""
step_lr = self.start_lr * np.power(self.decay_rate, step // self.decay_steps)
if step_lr < self.min_lr:
step_lr = self.min_lr
step_lr = self.start_lr * xp.power(self.decay_rate, step // self.decay_steps)
step_lr = xp.clip(step_lr, self.min_lr, None)
return step_lr
1 change: 1 addition & 0 deletions deepmd/jax/entrypoints/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
36 changes: 36 additions & 0 deletions deepmd/jax/entrypoints/freeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from pathlib import (

Check warning on line 2 in deepmd/jax/entrypoints/freeze.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/freeze.py#L2

Added line #L2 was not covered by tests
Path,
)

from deepmd.jax.utils.serialization import (

Check warning on line 6 in deepmd/jax/entrypoints/freeze.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/freeze.py#L6

Added line #L6 was not covered by tests
deserialize_to_file,
serialize_from_file,
)


def freeze(

Check warning on line 12 in deepmd/jax/entrypoints/freeze.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/freeze.py#L12

Added line #L12 was not covered by tests
*,
checkpoint_folder: str,
output: str,
**kwargs,
) -> None:
"""Freeze the graph in supplied folder.

Parameters
----------
checkpoint_folder : str
location of either the folder with checkpoint or the checkpoint prefix
output : str
output file name
**kwargs
other arguments
"""
if (Path(checkpoint_folder) / "checkpoint").is_file():
checkpoint_meta = Path(checkpoint_folder) / "checkpoint"
checkpoint_folder = checkpoint_meta.read_text().strip()
if Path(checkpoint_folder).is_dir():
data = serialize_from_file(checkpoint_folder)
deserialize_to_file(output, data)

Check warning on line 34 in deepmd/jax/entrypoints/freeze.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/freeze.py#L29-L34

Added lines #L29 - L34 were not covered by tests
else:
raise FileNotFoundError(f"Checkpoint {checkpoint_folder} does not exist.")

Check warning on line 36 in deepmd/jax/entrypoints/freeze.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/freeze.py#L36

Added line #L36 was not covered by tests
67 changes: 67 additions & 0 deletions deepmd/jax/entrypoints/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""DeePMD-Kit entry point module."""

import argparse
from pathlib import (

Check warning on line 5 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L4-L5

Added lines #L4 - L5 were not covered by tests
Path,
)
from typing import (

Check warning on line 8 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L8

Added line #L8 was not covered by tests
Optional,
Union,
)

from deepmd.backend.suffix import (

Check warning on line 13 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L13

Added line #L13 was not covered by tests
format_model_suffix,
)
from deepmd.jax.entrypoints.freeze import (

Check warning on line 16 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L16

Added line #L16 was not covered by tests
freeze,
)
from deepmd.jax.entrypoints.train import (

Check warning on line 19 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L19

Added line #L19 was not covered by tests
train,
)
from deepmd.loggers.loggers import (

Check warning on line 22 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L22

Added line #L22 was not covered by tests
set_log_handles,
)
from deepmd.main import (

Check warning on line 25 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L25

Added line #L25 was not covered by tests
parse_args,
)

__all__ = ["main"]

Check warning on line 29 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L29

Added line #L29 was not covered by tests


def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None:

Check warning on line 32 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L32

Added line #L32 was not covered by tests
"""DeePMD-Kit entry point.

Parameters
----------
args : list[str] or argparse.Namespace, optional
list of command line arguments, used to avoid calling from the subprocess,
as it is quite slow to import tensorflow; if Namespace is given, it will
be used directly

Raises
------
RuntimeError
if no command was input
"""
if not isinstance(args, argparse.Namespace):
args = parse_args(args=args)

Check warning on line 48 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L47-L48

Added lines #L47 - L48 were not covered by tests

dict_args = vars(args)
set_log_handles(

Check warning on line 51 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L50-L51

Added lines #L50 - L51 were not covered by tests
args.log_level,
Path(args.log_path) if args.log_path else None,
mpi_log=None,
)

if args.command == "train":
train(**dict_args)
elif args.command == "freeze":
dict_args["output"] = format_model_suffix(

Check warning on line 60 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L57-L60

Added lines #L57 - L60 were not covered by tests
dict_args["output"], preferred_backend=args.backend, strict_prefer=True
)
freeze(**dict_args)
elif args.command is None:
pass

Check warning on line 65 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L63-L65

Added lines #L63 - L65 were not covered by tests
else:
raise RuntimeError(f"unknown command {args.command}")

Check warning on line 67 in deepmd/jax/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/entrypoints/main.py#L67

Added line #L67 was not covered by tests
Comment on lines +64 to +67
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Improve command handling for better error reporting.

The current implementation silently passes when no command is provided, which might hide configuration issues. Additionally, the error message for unknown commands could be more helpful.

Apply this diff to improve error handling:

-    elif args.command is None:
-        pass
+    elif args.command is None:
+        raise RuntimeError("No command specified. Available commands: train, freeze")
     else:
-        raise RuntimeError(f"unknown command {args.command}")
+        raise RuntimeError(
+            f"Unknown command '{args.command}'. Available commands: train, freeze"
+        )
🤖 Prompt for AI Agents
In deepmd/jax/entrypoints/main.py around lines 64 to 67, replace the silent pass
when args.command is None with a clear error message indicating that no command
was provided. Also, enhance the RuntimeError message for unknown commands to
suggest checking available commands or usage. This improves error reporting by
explicitly handling missing commands and providing more informative feedback for
unknown commands.

Loading
Loading