-
Notifications
You must be signed in to change notification settings - Fork 556
chore(array-api): implement scatter_sum #4654
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: devel
Are you sure you want to change the base?
Conversation
Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR implements a scatter_sum operation using the array-api-extra package to support generalized array APIs. Key changes include:
- Adding the "array-api-extra>=0.5.0" dependency in pyproject.toml.
- Introducing new utility functions xp_ravel and xp_scatter_sum in deepmd/dpmodel/array_api.py with revised handling for various array types.
- Updating deepmd/dpmodel/model/transform_output.py to use xp_scatter_sum instead of a JAX-specific version.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
File | Description |
---|---|
pyproject.toml | Adds dependency on array-api-extra to support the new scatter_sum functionality. |
deepmd/dpmodel/array_api.py | Adds xp_ravel and xp_scatter_sum; refines xp_take_along_axis for broader support. |
deepmd/dpmodel/model/transform_output.py | Replaces JAX-specific scatter_sum with the new, generalized xp_scatter_sum. |
deepmd/jax/common.py | Removes the now redundant scatter_sum function. |
Comments suppressed due to low confidence (2)
deepmd/dpmodel/array_api.py:90
- [nitpick] Consider renaming the parameter 'input' to avoid shadowing the built-in function and improve clarity, e.g., use 'inp' or 'input_array'.
def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray:
deepmd/dpmodel/array_api.py:90
- Ensure that xp_scatter_sum is covered by unit tests for various array types (e.g., numpy, JAX) to verify its correct behavior across supported array APIs.
def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray:
📝 WalkthroughWalkthroughThis PR extends the array manipulation functionality by updating the Changes
Sequence Diagram(s)sequenceDiagram
participant CO as communicate_extended_output
participant XSS as xp_scatter_sum
participant XR as xp_ravel
participant XTA as xp_take_along_axis
CO->>XSS: Invoke scatter sum for force/virial tensors
XSS->>XR: Flatten input tensor via xp_ravel
XSS->>XTA: Adjust values based on indices via xp_take_along_axis
XSS-->>CO: Return computed tensor
Possibly related PRs
Suggested labels
Suggested reviewers
Tip ⚡🧪 Multi-step agentic review comment chat (experimental)
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (4)
💤 Files with no reviewable changes (1)
⏰ Context from checks skipped due to timeout of 90000ms (23)
🔇 Additional comments (8)
✨ Finishing Touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
It seems that |
Address #4649 (comment). This PR depends on array_api_extra.at.
Summary by CodeRabbit