18 changes: 18 additions & 0 deletions deepmd/pd/model/atomic_model/base_atomic_model.py
@@ -515,6 +515,24 @@ def change_out_bias(
else:
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)

def compute_fitting_stat(
self,
sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.

Parameters
----------
sample_merged : Union[Callable[[], list[dict]], list[dict]]
- list[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor`
originating from the `i`-th data system.
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
"""
pass

def _get_forward_wrapper_func(self) -> Callable[..., paddle.Tensor]:
"""Get a forward wrapper of the atomic model for output bias calculation."""

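Note: a minimal sketch of the two accepted forms of `sample_merged` for the new hook. The loader `load_system_batches`, the `atomic_model` instance, and the `"fparam"` key are illustrative assumptions, not part of this diff.

```python
import paddle

def load_system_batches(i: int) -> dict:
    # Hypothetical loader: one packed batch of tensors for data system i.
    return {"fparam": paddle.randn([32, 2])}

# Form 1: an eagerly built list, one dict per data system.
merged = [load_system_batches(i) for i in range(3)]
atomic_model.compute_fitting_stat(merged)

# Form 2: a lazy callable, evaluated only if the statistics are actually needed.
atomic_model.compute_fitting_stat(lambda: [load_system_batches(i) for i in range(3)])
```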
20 changes: 20 additions & 0 deletions deepmd/pd/model/atomic_model/dp_atomic_model.py
@@ -403,6 +403,26 @@ def wrapped_sampler():
if compute_or_load_out_stat:
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def compute_fitting_stat(
self,
sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.

Parameters
----------
sample_merged : Union[Callable[[], list[dict]], list[dict]]
- list[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor`
originating from the `i`-th data system.
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
"""
self.fitting_net.compute_input_stats(
sample_merged, protection=self.data_stat_protect
)

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.fitting_net.get_dim_fparam()
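Note: the `protection` argument passed above guards against degenerate statistics when a frame parameter is (nearly) constant across the sampled data. A rough NumPy illustration of that role only; the actual computation lives in the fitting net's `compute_input_stats`.

```python
import numpy as np

def fparam_stat(fparam_batches: list[np.ndarray], protection: float = 1e-2):
    # Stack all sampled frames: shape (nframes_total, dim_fparam).
    stacked = np.concatenate(fparam_batches, axis=0)
    avg = stacked.mean(axis=0)
    std = stacked.std(axis=0)
    # Floor the standard deviation so later normalization never divides by ~0.
    std = np.where(std < protection, protection, std)
    return avg, std
```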
2 changes: 2 additions & 0 deletions deepmd/pd/model/model/make_model.py
@@ -228,6 +228,8 @@ def change_out_bias(
merged,
bias_adjust_mode=bias_adjust_mode,
)
if bias_adjust_mode == "set-by-statistic":
self.atomic_model.compute_fitting_stat(merged)

def forward_common_lower(
self,
18 changes: 18 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -493,6 +493,24 @@ def change_out_bias(
else:
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)

def compute_fitting_stat(
Collaborator: I would recommend changing the method name to compute_fitting_input_stat.

self,
sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
Member (suggested change):
- """Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
+ """Compute the input statistics (e.g. mean and stddev) for the atomic model from packed data.

Parameters
----------
sample_merged : Union[Callable[[], list[dict]], list[dict]]
- list[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
"""
pass

def _get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
"""Get a forward wrapper of the atomic model for output bias calculation."""

25 changes: 22 additions & 3 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -5,6 +5,7 @@
Any,
Callable,
Optional,
Union,
)

import torch
@@ -328,12 +329,30 @@ def wrapped_sampler() -> list[dict]:
return sampled

self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
self.fitting_net.compute_input_stats(
wrapped_sampler, protection=self.data_stat_protect
)
self.compute_fitting_stat(wrapped_sampler)
if compute_or_load_out_stat:
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)

def compute_fitting_stat(
self,
sample_merged: Union[Callable[[], list[dict]], list[dict]],
) -> None:
"""Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.

Parameters
----------
sample_merged : Union[Callable[[], list[dict]], list[dict]]
- list[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
"""
self.fitting_net.compute_input_stats(
sample_merged, protection=self.data_stat_protect
)

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
return self.fitting_net.get_dim_fparam()
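Note: `wrapped_sampler` in the hunk above follows a cache-and-reuse idiom, so the expensive sampling runs at most once while every stat routine still receives a zero-argument callable. A standalone sketch of that idiom; `sampled_func` is a stand-in for whatever produces the merged batches.

```python
def make_cached_sampler(sampled_func):
    cache: list[dict] = []

    def wrapped_sampler() -> list[dict]:
        if not cache:
            cache.extend(sampled_func())  # sample once, on first use
        return cache

    return wrapped_sampler

# sampler = make_cached_sampler(sampled_func)
# descriptor.compute_input_stats(sampler, stat_file_path)  # first call samples
# atomic_model.compute_fitting_stat(sampler)               # reuses the cache
```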
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/make_model.py
@@ -232,6 +232,8 @@ def change_out_bias(
merged,
bias_adjust_mode=bias_adjust_mode,
)
if bias_adjust_mode == "set-by-statistic":
self.atomic_model.compute_fitting_stat(merged)

def forward_common_lower(
self,
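Note: with this change, resetting the output bias from statistics also refreshes the fitting net's fparam/aparam input statistics. A hypothetical fine-tuning call site, where `finetune_samples` is an assumed list of packed batches:

```python
model.change_out_bias(
    finetune_samples,  # list[dict], or a lazy callable returning one
    bias_adjust_mode="set-by-statistic",
)
# Per this diff, the wrapper then runs:
#   self.atomic_model.change_out_bias(merged, bias_adjust_mode=bias_adjust_mode)
#   self.atomic_model.compute_fitting_stat(merged)  # only for "set-by-statistic"
```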
7 changes: 6 additions & 1 deletion source/tests/pd/test_training.py
@@ -89,7 +89,11 @@ def test_dp_train(self) -> None:
state_dict_trained[state_key].numpy(),
state_dict_finetuned_empty[state_key].numpy(),
)
if "fitting_net" not in state_key:
if (
("fitting_net" not in state_key)
or ("fparam" in state_key)
or ("aparam" in state_key)
):
np.testing.assert_allclose(
state_dict_trained[state_key].numpy(),
state_dict_finetuned_random[state_key].numpy(),
@@ -190,6 +194,7 @@ def setUp(self) -> None:
self.config["training"]["save_freq"] = 1
self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000"
shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy")
self.config["model"]["data_stat_nbatch"] = 100

def tearDown(self) -> None:
(self.set_path / "fparam.npy").unlink(missing_ok=True)
7 changes: 6 additions & 1 deletion source/tests/pt/test_training.py
@@ -92,7 +92,11 @@ def test_dp_train(self) -> None:
state_dict_trained[state_key],
state_dict_finetuned_empty[state_key],
)
if "fitting_net" not in state_key:
if (
("fitting_net" not in state_key)
or ("fparam" in state_key)
or ("aparam" in state_key)
):
torch.testing.assert_close(
state_dict_trained[state_key],
state_dict_finetuned_random[state_key],
@@ -256,6 +260,7 @@ def setUp(self) -> None:
self.config["training"]["save_freq"] = 1
self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000"
shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy")
self.config["model"]["data_stat_nbatch"] = 100

def tearDown(self) -> None:
(self.set_path / "fparam.npy").unlink(missing_ok=True)