fix(finetune): calculate fitting stat when using random fitting in finetuning process #4928
base: devel
Conversation
for more information, see https://pre-commit.ci
📝 Walkthrough

Adds a new `compute_fitting_stat` API to atomic model base classes and DP implementations, invokes it from `make_model` when `bias_adjust_mode == "set-by-statistic"`, and updates tests to set `model.data_stat_nbatch` and broaden state-dict comparison keys.
Sequence Diagram

```mermaid
sequenceDiagram
    participant MakeModel as make_model
    participant AtomicModel
    participant FittingNet
    MakeModel->>AtomicModel: change_out_bias(merged, bias_adjust_mode)
    alt bias_adjust_mode == "set-by-statistic"
        MakeModel->>AtomicModel: compute_fitting_stat(merged)
        AtomicModel->>FittingNet: compute_input_stats(sample_merged, protection)
        FittingNet-->>AtomicModel: stats computed
    else other modes
        Note over MakeModel: no compute_fitting_stat call
    end
```
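In short: when the bias adjust mode is "set-by-statistic", `make_model` now follows `change_out_bias` with a `compute_fitting_stat` call on the atomic model, reusing the same merged data. A minimal sketch of the hook, assuming the surrounding method layout (only the call names come from the diagram above):

```python
# Hypothetical sketch of the make_model hook; the method structure and the
# default mode value are assumptions, not verbatim source.
def change_out_bias(self, merged, bias_adjust_mode="change-by-statistic") -> None:
    self.atomic_model.change_out_bias(merged, bias_adjust_mode=bias_adjust_mode)
    if bias_adjust_mode == "set-by-statistic":
        # Reuse the already-merged data so sampling happens only once.
        self.atomic_model.compute_fitting_stat(merged)
```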
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Codecov Report

❌ Patch coverage is
Additional details and impacted files:

```
@@            Coverage Diff            @@
##            devel    #4928     +/-  ##
========================================
  Coverage   84.23%   84.23%
========================================
  Files         709      709
  Lines       70078    70092      +14
  Branches     3619     3619
========================================
+ Hits        59032    59044      +12
- Misses       9880     9883       +3
+ Partials     1166     1165       -1
```

☔ View full report in Codecov by Sentry.
for more information, see https://pre-commit.ci
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- deepmd/pd/model/atomic_model/base_atomic_model.py (1 hunks)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1 hunks)
- deepmd/pd/model/model/make_model.py (1 hunks)
- deepmd/pt/model/atomic_model/base_atomic_model.py (1 hunks)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (2 hunks)
- deepmd/pt/model/model/make_model.py (1 hunks)
- source/tests/pd/test_training.py (2 hunks)
- source/tests/pt/test_training.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run `ruff check .` and `ruff format .` before committing changes to Python code.
Files:
- deepmd/pt/model/atomic_model/dp_atomic_model.py
- deepmd/pt/model/atomic_model/base_atomic_model.py
- source/tests/pd/test_training.py
- deepmd/pd/model/atomic_model/dp_atomic_model.py
- deepmd/pd/model/atomic_model/base_atomic_model.py
- deepmd/pd/model/model/make_model.py
- deepmd/pt/model/model/make_model.py
- source/tests/pt/test_training.py
🧠 Learnings (2)
📚 Learning: 2025-09-18T11:37:10.532Z
Learnt from: CR
Repo: deepmodeling/deepmd-kit PR: 0
File: AGENTS.md:0-0
Timestamp: 2025-09-18T11:37:10.532Z
Learning: Applies to source/tests/tf/test_dp_test.py : Keep the core TensorFlow test `source/tests/tf/test_dp_test.py` passing; use it for quick validation
Applied to files:
- source/tests/pd/test_training.py
- source/tests/pt/test_training.py
📚 Learning: 2024-09-19T04:25:12.408Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-09-19T04:25:12.408Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
Applied to files:
- source/tests/pd/test_training.py
- source/tests/pt/test_training.py
🧬 Code graph analysis (6)
deepmd/pt/model/atomic_model/dp_atomic_model.py (3)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (406-424)
- deepmd/pt/model/atomic_model/base_atomic_model.py (1): compute_fitting_stat (496-512)
- deepmd/pt/model/task/fitting.py (1): compute_input_stats (78-157)

deepmd/pt/model/atomic_model/base_atomic_model.py (1)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (336-354)

deepmd/pd/model/atomic_model/dp_atomic_model.py (3)
- deepmd/pd/model/atomic_model/base_atomic_model.py (1): compute_fitting_stat (518-534)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (336-354)
- deepmd/pd/model/task/fitting.py (1): compute_input_stats (75-160)

deepmd/pd/model/atomic_model/base_atomic_model.py (1)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (406-424)

deepmd/pd/model/model/make_model.py (3)
- deepmd/pd/model/atomic_model/base_atomic_model.py (1): compute_fitting_stat (518-534)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (406-424)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (336-354)

deepmd/pt/model/model/make_model.py (2)
- deepmd/pt/model/atomic_model/base_atomic_model.py (1): compute_fitting_stat (496-512)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (336-354)
🔇 Additional comments (10)
deepmd/pd/model/model/make_model.py (1)
231-232: LGTM! Correct invocation of compute_fitting_stat. The conditional call to `compute_fitting_stat` after `change_out_bias` appropriately addresses the PR objective of computing fitting statistics when using random fitting in finetuning ("set-by-statistic" mode). The implementation correctly reuses the merged data.
source/tests/pd/test_training.py (2)

92-100: LGTM! Appropriate test assertion for fparam/aparam keys. The broadened exclusion condition correctly validates that fparam and aparam statistics are preserved during random finetuning, aligning with the PR's objective to compute fitting statistics properly.
197-197: LGTM! Configuration for data statistics batching. Adding `data_stat_nbatch = 100` appropriately exercises the data statistics batching behavior that is central to this PR's fitting statistics computation; a minimal sketch of this tweak follows.
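A hedged sketch of the test tweak (only the key name and value come from this review; the config layout is an assumption about `test_training.py`):

```python
# Exercise multi-batch data statistics in the finetuning test; the
# surrounding config dict structure is assumed, not verbatim source.
self.config["model"]["data_stat_nbatch"] = 100
```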
deepmd/pt/model/atomic_model/base_atomic_model.py (1)

496-512: LGTM! Appropriate placeholder for PT base atomic model. The no-op implementation is correct for the base class, allowing derived classes to provide concrete implementations. The documentation correctly references `torch.Tensor` for the PyTorch path.
deepmd/pt/model/model/make_model.py (1)

235-236: LGTM! Correct invocation of compute_fitting_stat in PT path. The implementation mirrors the PD path and correctly invokes `compute_fitting_stat` when `bias_adjust_mode` is "set-by-statistic", addressing the PR's objective for the PyTorch path.
source/tests/pt/test_training.py (2)

95-103: LGTM! Test assertions align with PD path. The broadened exclusion condition for fparam/aparam keys correctly validates the new fitting statistics computation during random finetuning in the PyTorch path, as illustrated below.
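A hypothetical illustration of such a broadened exclusion (the loop and variable names are assumptions, not the actual test code):

```python
import torch

# Skip randomly re-initialized fitting-net weights, but still compare the
# fparam/aparam statistics, which the PR now computes and preserves.
for key, trained_val in state_dict_trained.items():
    if "fitting_net" in key and not any(s in key for s in ("fparam", "aparam")):
        continue  # random fitting weights are expected to differ
    torch.testing.assert_close(trained_val, state_dict_finetuned[key])
```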
263-263: LGTM! Configuration mirrors PD test setup. Setting `data_stat_nbatch = 100` appropriately exercises data statistics batching in the PyTorch path, consistent with the PD tests.

deepmd/pt/model/atomic_model/dp_atomic_model.py (3)
8-8: LGTM! Union import for type hints. The `Union` import is correctly added to support the type hints for the new `compute_fitting_stat` method signature.
332-332: LGTM! Refactored to use compute_fitting_stat. Good refactoring that centralizes fitting statistics computation through the new `compute_fitting_stat` method, improving code organization and maintainability.
336-354: LGTM! Proper implementation of compute_fitting_stat. The method correctly delegates to `fitting_net.compute_input_stats` with the `data_stat_protect` parameter, providing a clean interface for computing fitting statistics from packed data; a sketch of the method's shape follows.
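A minimal sketch of such a delegation, based only on the names cited in this review and the sequence diagram (the docstring and exact argument handling are assumptions):

```python
from collections.abc import Callable
from typing import Union


class DPAtomicModel:  # sketch container; the real class derives from the base atomic model
    def compute_fitting_stat(
        self,
        sample_merged: Union[Callable[[], list[dict]], list[dict]],
    ) -> None:
        """Compute input statistics (e.g. mean and stddev) for the fitting net."""
        # Delegate to the fitting net; "protection" matches the sequence
        # diagram's compute_input_stats(sample_merged, protection).
        self.fitting_net.compute_input_stats(
            sample_merged, protection=self.data_stat_protect
        )
```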
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pd/model/atomic_model/base_atomic_model.py (1 hunks)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run `ruff check .` and `ruff format .` before committing changes to Python code.
Files:
- deepmd/pd/model/atomic_model/base_atomic_model.py
- deepmd/pt/model/atomic_model/dp_atomic_model.py
🧬 Code graph analysis (2)
deepmd/pd/model/atomic_model/base_atomic_model.py (2)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (336-354)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (406-424)

deepmd/pt/model/atomic_model/dp_atomic_model.py (2)
- deepmd/pt/model/atomic_model/base_atomic_model.py (1): compute_fitting_stat (496-512)
- deepmd/pt/model/task/fitting.py (1): compute_input_stats (78-157)
🔇 Additional comments (2)
deepmd/pt/model/atomic_model/dp_atomic_model.py (2)
8-8: LGTM: Import addition supports new method signature. The `Union` import is necessary for the `Union[Callable[[], list[dict]], list[dict]]` type hint in the new `compute_fitting_stat` method.
332-332: Good refactoring to use the public API. Calling `self.compute_fitting_stat(wrapped_sampler)` instead of directly accessing `self.fitting_net.compute_input_stats` establishes a consistent public interface and enables proper encapsulation, as sketched below.
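Sketched before/after of the call site (the `wrapped_sampler` name and both call targets come from this comment; the keyword arguments in the "before" line are assumptions):

```python
# Before: reach into the fitting net directly (assumed original form).
# self.fitting_net.compute_input_stats(wrapped_sampler, protection=self.data_stat_protect)

# After: route through the public method, which encapsulates the delegation.
self.compute_fitting_stat(wrapped_sampler)
```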
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/pt/model/atomic_model/base_atomic_model.py (1)
496-512: LGTM! Clean API extension for computing fitting statistics. The method provides a well-documented hook for derived classes to compute fitting statistics during finetuning. The no-op default (`pass`) is appropriate since not all atomic model types require fitting statistics computation.

Optional: consider adding a clarifying comment before the `pass` statement to make the intent explicit:

```diff
     the lazy function helps by only sampling once.
     """
+    # No-op in base class; derived classes override if fitting statistics are needed.
     pass
```

Reminder: run code quality checks.
As per coding guidelines, ensure you run the following before committing:

```
ruff check .
ruff format .
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- deepmd/pd/model/atomic_model/base_atomic_model.py (1 hunks)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1 hunks)
- deepmd/pt/model/atomic_model/base_atomic_model.py (1 hunks)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- deepmd/pd/model/atomic_model/base_atomic_model.py
- deepmd/pt/model/atomic_model/dp_atomic_model.py
- deepmd/pd/model/atomic_model/dp_atomic_model.py
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (AGENTS.md)
Always run `ruff check .` and `ruff format .` before committing changes to Python code.
Files:
deepmd/pt/model/atomic_model/base_atomic_model.py
🧬 Code graph analysis (1)
deepmd/pt/model/atomic_model/base_atomic_model.py (2)
- deepmd/pd/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (406-424)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): compute_fitting_stat (336-354)
Diff excerpt under review:

```python
        else:
            raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)

    def compute_fitting_stat(
```
I would recommend changing the method name to `compute_fitting_input_stat`.
Limitation: the input stat is not implemented in the Python backend.
Diff excerpt under review:

```python
        self,
        sample_merged: Union[Callable[[], list[dict]], list[dict]],
    ) -> None:
        """Compute the input statistics (e.g. mean and stddev) for the fittings from packed data.
```
| """Compute the input statistics (e.g. mean and stddev) for the fittings from packed data. | |
| """Compute the input statistics (e.g. mean and stddev) for the atomic model from packed data. |
In the finetuning process, the computation of the fitting stat was skipped in the previous code. There are two situations:

- `fparam` or `aparam` has the same meaning as in the finetuning task: the keys `fparam_avg`/`fparam_inv_std`/`aparam_avg`/`aparam_inv_std` are loaded from the pretrained model, which is correct.

Summary by CodeRabbit
New Features
Tests