feat(pt): add hook to last fitting layer output #4789
Conversation
📝 Walkthrough

This update adds support for evaluating and retrieving the output of the last hidden layer (the layer before the final layer) of the fitting network in deep potential models. New methods and hooks are introduced across the neural network, inference, and model classes to enable, cache, and access these intermediate outputs, with API extensions for both the standard and PyTorch-based implementations.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant DeepEval
    participant DeepEvalBackend
    participant DPModelCommon
    participant DPAtomicModel
    participant GeneralFitting
    User->>DeepEval: eval_fitting_last_layer(...)
    DeepEval->>DeepEvalBackend: eval_fitting_last_layer(...)
    DeepEvalBackend->>DPModelCommon: set_eval_fitting_last_layer_hook(True)
    DeepEvalBackend->>DPModelCommon: eval(...)
    DPModelCommon->>DPAtomicModel: set_eval_fitting_last_layer_hook(True)
    DPAtomicModel->>GeneralFitting: set_return_middle_output(True)
    DPModelCommon->>DPAtomicModel: forward_atomic(...)
    DPAtomicModel->>GeneralFitting: _forward_common(...)
    GeneralFitting->>GeneralFitting: call_until_last(...)
    GeneralFitting-->>DPAtomicModel: return {"middle_output": ...}
    DPAtomicModel->>DPAtomicModel: Cache middle_output
    DPAtomicModel->>DPAtomicModel: set_eval_fitting_last_layer_hook(False)
    DPModelCommon->>DeepEvalBackend: return eval_fitting_last_layer()
    DeepEvalBackend->>DeepEval: return result
    DeepEval->>User: return result
```
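The flow in the diagram above can be sketched in plain Python. This is a toy model of the hook mechanism, not the real deepmd code: the class and method names follow the diagram, but the layer arithmetic and method bodies are invented for illustration.

```python
def call_until_last(layers, x):
    """Run every layer except the last (no-op for 0 or 1 layers)."""
    if len(layers) <= 1:
        return x
    for layer in layers[:-1]:
        x = layer(x)
    return x

class GeneralFitting:
    def __init__(self):
        # toy layers: the last one stands in for the final fitting layer
        self.layers = [lambda v: v * 2, lambda v: v + 1, lambda v: -v]
        self.return_middle_output = False

    def set_return_middle_output(self, enable):
        self.return_middle_output = enable

    def _forward_common(self, x):
        out = x
        for layer in self.layers:
            out = layer(out)
        fit_ret = {"energy": out}
        if self.return_middle_output:
            # last-hidden-layer output: everything before the final layer
            fit_ret["middle_output"] = call_until_last(self.layers, x)
        return fit_ret

class DPAtomicModel:
    def __init__(self):
        self.fitting_net = GeneralFitting()
        self.cached_middle = None

    def set_eval_fitting_last_layer_hook(self, enable):
        self.fitting_net.set_return_middle_output(enable)
        self.cached_middle = None  # clear stale cache on toggle

    def forward_atomic(self, x):
        fit_ret = self.fitting_net._forward_common(x)
        if "middle_output" in fit_ret:
            self.cached_middle = fit_ret.pop("middle_output")
        return fit_ret

    def eval_fitting_last_layer(self):
        return self.cached_middle

model = DPAtomicModel()
model.set_eval_fitting_last_layer_hook(True)
energy = model.forward_atomic(3)["energy"]
middle = model.eval_fitting_last_layer()
```

With the toy layers, the full forward gives 3 → 6 → 7 → -7, while the cached middle output stops at 7, one layer short.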
Suggested reviewers
📜 Recent review details

Configuration used: CodeRabbit UI
📒 Files selected for processing (4)
🚧 Files skipped from review as they are similar to previous changes (4)
⏰ Context from checks skipped due to timeout of 90000ms (29)
✨ Finishing Touches
Actionable comments posted: 0
🧹 Nitpick comments (3)
deepmd/pt/model/atomic_model/dp_atomic_model.py (2)

82-92: Well-designed hook management methods.

The implementation correctly:
- Manages the hook enable/disable state
- Integrates with the fitting network's `set_return_middle_output` method
- Clears the cache to prevent stale data

Consider potential thread safety issues if multiple threads access these methods concurrently.
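One way to address the thread-safety concern is to serialize the toggle and the cache behind a single lock, so a toggle can never interleave with a partially written cache. A minimal sketch; the class and method names here are hypothetical, not the deepmd API:

```python
import threading

class HookState:
    """Illustrative guard for a hook enable/cache pair.

    A hypothetical sketch showing how one lock could keep the enable flag
    and the cached output consistent under concurrent access.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._enabled = False
        self._cache = None

    def set_enabled(self, enable):
        with self._lock:
            self._enabled = enable
            self._cache = None  # clear stale data atomically with the toggle

    def store(self, value):
        with self._lock:
            if self._enabled:
                self._cache = value

    def fetch(self):
        with self._lock:
            return self._cache

state = HookState()
state.set_enabled(True)
state.store("middle_output_tensor")
result = state.fetch()
```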
272-278: Correct implementation of middle output caching.

The logic properly checks for the presence of `middle_output`, removes it from the result dictionary, detaches it from the computation graph, and caches it. The assertion ensures the feature is only used with compatible fitting networks.

Consider making the error message more descriptive to help users understand which fitting network types support this feature.

```diff
- assert "middle_output" in fit_ret, (
-     f"eval_fitting_last_layer not supported for fitting net {type(self.fitting_net.__class__)}!"
- )
+ assert "middle_output" in fit_ret, (
+     f"eval_fitting_last_layer not supported for fitting net {type(self.fitting_net)}! "
+     f"Only mixed_types fitting networks support this feature."
+ )
```

deepmd/infer/deep_eval.py (1)
504-569: Well-implemented high-level interface method.

The implementation correctly follows the established pattern of input standardization and delegation to the backend. The parameter handling is consistent with other evaluation methods.

Minor documentation inconsistency: the docstring mentions an `efield` parameter that is not in the method signature.

```diff
-        efield
-            The external field on atoms.
-            The array should be of size nframes x natoms x 3
```
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (7)
- deepmd/dpmodel/utils/network.py (1 hunks)
- deepmd/infer/deep_eval.py (2 hunks)
- deepmd/pt/infer/deep_eval.py (2 hunks)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (3 hunks)
- deepmd/pt/model/model/dp_model.py (1 hunks)
- deepmd/pt/model/task/fitting.py (4 hunks)
- deepmd/pt/model/task/invar_fitting.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (4)

deepmd/pt/model/model/dp_model.py (2)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (2): set_eval_fitting_last_layer_hook (82-87), eval_fitting_last_layer (89-91)
- deepmd/pt/infer/deep_eval.py (1): eval_fitting_last_layer (683-736)

deepmd/pt/model/task/invar_fitting.py (1)
- deepmd/pt/model/task/fitting.py (1): _forward_common (505-645)

deepmd/pt/model/task/fitting.py (4)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): mixed_types (118-128)
- deepmd/pt/model/descriptor/se_a.py (2): mixed_types (171-175), mixed_types (587-597)
- deepmd/pt/model/descriptor/hybrid.py (1): mixed_types (143-147)
- deepmd/dpmodel/utils/network.py (1): call_until_last (636-651)

deepmd/infer/deep_eval.py (1)
- deepmd/pt/infer/deep_eval.py (1): eval_fitting_last_layer (683-736)
⏰ Context from checks skipped due to timeout of 90000ms (29)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Analyze (python)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test C++ (false)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Test C++ (true)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test Python (1, 3.9)
🔇 Additional comments (10)
deepmd/dpmodel/utils/network.py (1)

636-651: LGTM! Well-implemented method for intermediate output extraction.

The `call_until_last` method correctly implements a forward pass through all layers except the last one. The implementation properly handles edge cases (empty layers or a single layer) and follows the existing code patterns with clear documentation.

deepmd/pt/model/model/dp_model.py (1)
68-76: LGTM! New methods follow established patterns correctly.

The new fitting last layer hook methods are well-implemented:
- Consistent naming and documentation with the existing descriptor methods
- Proper delegation to `atomic_model` maintains the architecture
- `@torch.jit.export` decorators ensure TorchScript compatibility
- Clear documentation following existing patterns
deepmd/pt/model/task/invar_fitting.py (1)

184-194: LGTM! Safe and backward-compatible implementation.

The modified `forward` method correctly handles the conditional inclusion of `"middle_output"`:
- Properly captures the output from `_forward_common`
- Safely checks for the existence of `"middle_output"` before adding it to the result
- Correctly converts both the main output and the middle output to global precision
- Maintains backward compatibility when the middle output is not available
deepmd/pt/infer/deep_eval.py (2)

133-135: Verify the impact of disabling JIT compilation.

The JIT compilation is commented out with a TODO comment. This might impact model performance during inference. Ensure this is a temporary workaround and track the issue for resolution.

683-736: LGTM! Well-implemented evaluation method following established patterns.

The `eval_fitting_last_layer` method is correctly implemented:
- Follows the same pattern as `eval_descriptor` for consistency
- Comprehensive parameter documentation matching other evaluation methods
- Proper hook management (enable -> eval -> retrieve -> disable)
- Correct type conversion to a NumPy array for the return value
- Method signature consistent with other evaluation APIs
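The enable -> eval -> retrieve -> disable lifecycle benefits from a try/finally so the hook is always cleared, even if evaluation raises. A sketch with a toy stand-in model; the real `eval_fitting_last_layer` takes more parameters, and everything below is illustrative, not the deepmd implementation:

```python
import numpy as np

class _ToyModel:
    """Hypothetical stand-in exposing the hook interface described above."""

    def __init__(self):
        self._hook = False
        self._cache = None

    def set_eval_fitting_last_layer_hook(self, enable):
        self._hook = enable
        self._cache = None

    def eval(self, coords):
        if self._hook:
            self._cache = [c * 2.0 for c in coords]  # pretend hidden output

    def eval_fitting_last_layer(self):
        return self._cache

def eval_fitting_last_layer(model, coords):
    model.set_eval_fitting_last_layer_hook(True)   # enable
    try:
        model.eval(coords)                         # evaluate
        out = model.eval_fitting_last_layer()      # retrieve cached output
        return np.asarray(out)                     # convert to NumPy
    finally:
        model.set_eval_fitting_last_layer_hook(False)  # always disable

result = eval_fitting_last_layer(_ToyModel(), [1.0, 2.0])
```

The finally block guarantees later calls are unaffected by a failed evaluation, which is the main failure mode of a plain enable/disable pair.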
deepmd/pt/model/task/fitting.py (3)
331-331
: LGTM! Clean initialization of the new attribute.The
eval_return_middle_output
attribute is properly initialized toFalse
, ensuring backward compatibility.
454-456
: LGTM! Well-designed setter method.The method provides a clean interface to toggle the middle output evaluation behavior.
605-646
:✅ Verification successful
The return type change from tensor to dictionary is a significant modification.
The logic correctly handles the middle output extraction for mixed types and appropriately raises
NotImplementedError
for non-mixed types. The implementation properly integrates with thecall_until_last
method from the network utility layer.Consider the impact of this breaking change on existing code that expects a single tensor return value.
Verify that all callers of
_forward_common
have been updated to handle the dictionary return type:
🏁 Script executed:
#!/bin/bash # Description: Find all callers of _forward_common to ensure they handle the new dictionary return type # Expected: All callers should handle dictionary return or be updated accordingly rg -A 10 "_forward_common" --type pyLength of output: 4275
🏁 Script executed:
#!/bin/bash # Inspect how the Paddle invar_fitting caller handles the new dict return rg -C5 "self._forward_common" deepmd/pd/model/task/invar_fitting.pyLength of output: 298
Dict return correctly handled in all PyTorch callers

All PyTorch-based tasks now unpack the dictionary returned by `_forward_common`:
- deepmd/pt/model/task/dipole.py indexes `out[self.var_name]`
- deepmd/pt/model/task/polarizability.py indexes `out[self.var_name]`
- deepmd/pt/model/task/invar_fitting.py builds its `result` from `out[self.var_name]` (and `middle_output`)

The Paddle implementation still returns a tensor and its callers remain unaffected. No further updates are required.
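The unpacking contract the callers follow can be illustrated with a toy stand-in for `_forward_common`. The arithmetic below is invented; only the dictionary shape mirrors the change described above:

```python
def forward_common(x, var_name="energy", return_middle_output=False):
    """Toy stand-in for the new dict-return contract of _forward_common."""
    middle = x + 1          # pretend last-hidden-layer activation
    final = middle * 2      # pretend final-layer output
    out = {var_name: final}
    if return_middle_output:
        out["middle_output"] = middle
    return out

# Callers index the dict by their variable name instead of assuming a bare tensor:
out = forward_common(3, var_name="energy", return_middle_output=True)
energy = out["energy"]               # what the dipole/polarizability/invar tasks do
middle = out.get("middle_output")    # present only when the hook is enabled
```

Keying by `var_name` lets each task keep its own output name while the optional `middle_output` entry rides along only when requested, which is why the change stays backward compatible for callers that never enable the hook.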
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

65-67: LGTM! Consistent attribute additions following the existing pattern.

The new attributes for the fitting last layer hook follow the same design pattern as the existing descriptor hook, maintaining consistency.

Also applies to: 70-70

deepmd/infer/deep_eval.py (1)

218-258: Excellent addition of an abstract interface method.

The method signature and documentation are comprehensive and consistent with existing evaluation methods. The abstract nature ensures all backends implement this functionality.
Codecov Report

Attention: Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##            devel    #4789      +/-   ##
==========================================
+ Coverage   84.57%   84.58%   +0.01%
==========================================
  Files         699      699
  Lines       68070    68124      +54
  Branches     3541     3541
==========================================
+ Hits        57567    57621      +54
  Misses       9369     9369
  Partials     1134     1134
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
🧹 Nitpick comments (1)
source/tests/infer/test_models.py (1)

167-183: Fix unused loop variable `ii`.

The test implementation is correct and follows established patterns. However, the loop variable `ii` is not used within the loop body.

Apply this diff to fix the unused variable:

```diff
- for ii, result in enumerate(self.case.results):
+ for result in self.case.results:
```

Since the variable is not used for error messages or indexing in this test, it can be removed entirely.
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (5)
- deepmd/pt/infer/deep_eval.py (1 hunks)
- deepmd/pt/model/task/fitting.py (4 hunks)
- source/tests/infer/case.py (1 hunks)
- source/tests/infer/deeppot-testcase.yaml (1 hunks)
- source/tests/infer/test_models.py (1 hunks)

✅ Files skipped from review due to trivial changes (1)
- source/tests/infer/deeppot-testcase.yaml
🧰 Additional context used
🧠 Learnings (2)

deepmd/pt/infer/deep_eval.py (3)
- Learnt from: njzjz (PR deepmodeling/deepmd-kit#4226, deepmd/dpmodel/model/make_model.py:370-373, 2024-10-16): In `deepmd/dpmodel/model/make_model.py`, the variable `nall` assigned but not used is intentional and should not be flagged in future reviews.
- Learnt from: njzjz (PR deepmodeling/deepmd-kit#3875, doc/model/train-fitting-dos.md:107, 2024-10-08): For code blocks in `doc/model/train-fitting-dos.md` that display commands, use 'txt' as the language specification, per njzjz's preference.
- Learnt from: njzjz (PR deepmodeling/deepmd-kit#3875, doc/model/train-fitting-dos.md:107, 2024-06-13): For code blocks in `doc/model/train-fitting-dos.md` that display commands, use 'txt' as the language specification, per njzjz's preference.

source/tests/infer/test_models.py (3)
- Learnt from: njzjz (PR deepmodeling/deepmd-kit#4144, source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246, 2024-09-19): Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
- Learnt from: njzjz (PR deepmodeling/deepmd-kit#4144, source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246, 2024-10-08): Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
- Learnt from: njzjz (PR deepmodeling/deepmd-kit#4302, deepmd/pd/infer/inference.py:35-38, 2024-11-25): In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.
🧬 Code Graph Analysis (1)

source/tests/infer/test_models.py (4)
- deepmd/pt/infer/deep_eval.py (1): eval_fitting_last_layer (682-735)
- deepmd/pt/model/atomic_model/dp_atomic_model.py (1): eval_fitting_last_layer (89-91)
- deepmd/pt/model/model/dp_model.py (1): eval_fitting_last_layer (74-76)
- deepmd/infer/deep_eval.py (2): eval_fitting_last_layer (218-258), eval_fitting_last_layer (504-569)

🪛 Ruff (0.11.9)

source/tests/infer/test_models.py
- 171-171: Loop control variable `ii` not used within loop body. Rename unused `ii` to `_ii` (B007)
⏰ Context from checks skipped due to timeout of 90000ms (29)
🔇 Additional comments (5)
source/tests/infer/case.py (1)

128-133: LGTM: Consistent implementation of the fit_ll attribute.

The new `fit_ll` attribute follows the same pattern as other optional attributes in the `Result` class, with proper numpy array conversion, dtype specification, and reshaping.

deepmd/pt/infer/deep_eval.py (1)

682-735: LGTM: Consistent implementation following established patterns.

The new `eval_fitting_last_layer` method correctly follows the same pattern as `eval_descriptor`, with proper hook management, evaluation execution, and result conversion to a numpy array. The comprehensive docstring and parameter handling are well-implemented.

deepmd/pt/model/task/fitting.py (3)

331-331: LGTM: Proper initialization of the middle output flag.

The `eval_return_middle_output` attribute is correctly initialized to `False` to maintain backward compatibility.

454-455: LGTM: Clean method for controlling the middle output.

The `set_return_middle_output` method provides a clean interface for toggling the middle output functionality.

605-658: LGTM: Well-implemented middle output support.

The modified `_forward_common` method correctly handles both the `mixed_types=True` and `mixed_types=False` cases:
- For mixed types, it uses `call_until_last` on the single network
- For non-mixed types, it properly aggregates middle outputs from all networks with appropriate masking
- The dictionary return format maintains backward compatibility while enabling the new functionality

The implementation is clean and follows established patterns in the codebase.
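The non-mixed-types masking step can be pictured with a numpy sketch, under the assumption that each per-type network is (notionally) evaluated for all atoms and rows are then selected by an atom-type mask. The real implementation operates on torch tensors; the function and variable names below are illustrative only:

```python
import numpy as np

def aggregate_middle_outputs(per_type_outputs, atype):
    """Combine per-type network outputs using an atom-type mask.

    per_type_outputs: dict {type_id: (natoms, nhidden) array}, the middle
    output each type's network would produce for every atom.
    atype: (natoms,) integer atom types. A row is kept only where the
    atom's type matches the network that produced it.
    """
    natoms = atype.shape[0]
    nhidden = next(iter(per_type_outputs.values())).shape[1]
    out = np.zeros((natoms, nhidden))
    for type_id, net_out in per_type_outputs.items():
        mask = (atype == type_id)[:, None]     # (natoms, 1), broadcasts over nhidden
        out += np.where(mask, net_out, 0.0)    # zero out rows of other types
    return out

atype = np.array([0, 1, 0])
per_type = {
    0: np.full((3, 2), 1.0),  # pretend output of the type-0 network
    1: np.full((3, 2), 2.0),  # pretend output of the type-1 network
}
agg = aggregate_middle_outputs(per_type, atype)
```

Because each atom's type matches exactly one network, the masked sums never overlap and every row of the result comes from a single per-type network.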
Summary by CodeRabbit