Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
4c3072e
Init branch
Chengqian-Zhang Nov 8, 2025
2c2fb83
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 8, 2025
cf93f90
delete jit
Chengqian-Zhang Nov 8, 2025
9cfe69d
Merge branch '1108_default_fparam_stat' of https://github.com/Chengqi…
Chengqian-Zhang Nov 8, 2025
a4d9dce
fix typo
Chengqian-Zhang Nov 8, 2025
c3eae48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 8, 2025
9b76f49
fix typo
Chengqian-Zhang Nov 8, 2025
7b607b0
fix conflict
Chengqian-Zhang Nov 8, 2025
07861fe
Add UT for writing/load fitting stat to/from stat_file
Chengqian-Zhang Nov 9, 2025
5cfbf6c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 9, 2025
489fe84
Add UT to fitting stat when using share fitting
Chengqian-Zhang Nov 9, 2025
52aba93
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 9, 2025
6f18144
Merge branch 'devel' into 1108_default_fparam_stat
Chengqian-Zhang Nov 9, 2025
ddc80d6
Merge branch 'devel' into 1108_default_fparam_stat
Chengqian-Zhang Nov 18, 2025
2b7ff28
Change StatItem.number from int to float
Chengqian-Zhang Nov 18, 2025
d6120a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
07483a7
Merge branch 'devel' into 1108_default_fparam_stat
Chengqian-Zhang Nov 23, 2025
a9d28f7
Solve conflict
Chengqian-Zhang Nov 23, 2025
dbe92ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2025
9ba8496
Fix UT
Chengqian-Zhang Nov 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions deepmd/pd/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,14 @@ def wrapped_sampler():
return sampled

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
self.compute_fitting_input_stat(wrapped_sampler)
self.compute_fitting_input_stat(wrapped_sampler, stat_file_path)
if compute_or_load_out_stat:
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def compute_fitting_input_stat(
self,
sample_merged: Union[Callable[[], list[dict]], list[dict]],
stat_file_path: Optional[DPPath] = None,
) -> None:
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.

Expand All @@ -416,9 +417,13 @@ def compute_fitting_input_stat(
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
            The path to the statistics file.
"""
self.fitting_net.compute_input_stats(
sample_merged, protection=self.data_stat_protect
sample_merged,
protection=self.data_stat_protect,
stat_file_path=stat_file_path,
)

def get_dim_fparam(self) -> int:
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pd/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
get_index_between_two_maps,
map_atom_exclude_types,
)
from deepmd.utils.path import (
DPPath,
)

dtype = env.GLOBAL_PD_FLOAT_PRECISION
device = env.DEVICE
Expand Down Expand Up @@ -76,6 +79,7 @@ def compute_input_stats(
self,
merged: Union[Callable[[], list[dict]], list[dict]],
protection: float = 1e-2,
stat_file_path: Optional[DPPath] = None,
) -> None:
"""
Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
Expand All @@ -91,6 +95,8 @@ def compute_input_stats(
the lazy function helps by only sampling once.
protection : float
Divided-by-zero protection
stat_file_path : Optional[DPPath]
The path to the stat file.
"""
if self.numb_fparam == 0 and self.numb_aparam == 0:
# skip data statistics
Expand Down
21 changes: 19 additions & 2 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,16 +326,26 @@ def wrapped_sampler() -> list[dict]:
atom_exclude_types = self.atom_excl.get_exclude_types()
for sample in sampled:
sample["atom_exclude_types"] = list(atom_exclude_types)
if (
"find_fparam" not in sampled[0]
and "fparam" not in sampled[0]
and self.has_default_fparam()
):
default_fparam = self.get_default_fparam()
for sample in sampled:
nframe = sample["atype"].shape[0]
sample["fparam"] = default_fparam.repeat(nframe, 1)
return sampled

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
self.compute_fitting_input_stat(wrapped_sampler)
self.compute_fitting_input_stat(wrapped_sampler, stat_file_path)
if compute_or_load_out_stat:
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def compute_fitting_input_stat(
self,
sample_merged: Union[Callable[[], list[dict]], list[dict]],
stat_file_path: Optional[DPPath] = None,
) -> None:
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.

Expand All @@ -348,9 +358,13 @@ def compute_fitting_input_stat(
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
            The path to the statistics file.
"""
self.fitting_net.compute_input_stats(
sample_merged, protection=self.data_stat_protect
sample_merged,
protection=self.data_stat_protect,
stat_file_path=stat_file_path,
)

def get_dim_fparam(self) -> int:
Expand All @@ -361,6 +375,9 @@ def has_default_fparam(self) -> bool:
"""Check if the model has default frame parameters."""
return self.fitting_net.has_default_fparam()

def get_default_fparam(self) -> Optional[torch.Tensor]:
return self.fitting_net.get_default_fparam()

def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
return self.fitting_net.get_dim_aparam()
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,9 @@ def has_default_fparam(self) -> bool:
"""Check if the model has default frame parameters."""
return self.atomic_model.has_default_fparam()

def get_default_fparam(self) -> Optional[torch.Tensor]:
return self.atomic_model.get_default_fparam()

@torch.jit.export
def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
Expand Down
Loading