Skip to content

chore(array-api): implement scatter_sum #4654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: devel
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"""Utilities for the array API."""

import array_api_compat
import array_api_extra as xpx
import numpy as np
from packaging.version import (
Version,
)
Expand Down Expand Up @@ -48,7 +50,11 @@ def xp_swapaxes(a, axis1, axis2):

def xp_take_along_axis(arr, indices, axis):
xp = array_api_compat.array_namespace(arr)
if Version(xp.__array_api_version__) >= Version("2024.12"):
if (
Version(xp.__array_api_version__) >= Version("2024.12")
or array_api_compat.is_numpy_array(arr)
or array_api_compat.is_jax_array(arr)
):
# see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39
return xp.take_along_axis(arr, indices, axis=axis)
arr = xp_swapaxes(arr, axis, -1)
Expand All @@ -73,3 +79,22 @@ def xp_take_along_axis(arr, indices, axis):
out = xp.take(arr, indices)
out = xp.reshape(out, shape)
return xp_swapaxes(out, axis, -1)


def xp_ravel(input: np.ndarray) -> np.ndarray:
    """Collapse ``input`` into a one-dimensional array.

    The array namespace is looked up from ``input`` itself, and the
    flattening is done with that namespace's ``reshape``.

    Parameters
    ----------
    input
        Array to flatten.

    Returns
    -------
    np.ndarray
        ``input`` reshaped to a single axis.
    """
    namespace = array_api_compat.array_namespace(input)
    flat_shape = [-1]
    return namespace.reshape(input, flat_shape)


def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray:
    """Reduces all values from the src tensor to the indices specified in the index tensor.

    Every element of ``src`` is added to the element of ``input`` whose
    position along axis ``dim`` is given by the matching entry of ``index``.
    Repeated target positions accumulate.

    Parameters
    ----------
    input
        Array that receives the sums; its shape is the shape of the result.
    dim
        Axis along which ``index`` selects positions.
    index
        Integer array of target positions along ``dim``.
    src
        Values to add; flattened in the same order as ``index``.

    Returns
    -------
    np.ndarray
        Array with the shape of ``input`` holding the accumulated values.
    """
    xp = array_api_compat.array_namespace(input)
    shape = input.shape
    # ``array_api_compat.size`` is used instead of ``input.size`` because some
    # backends (e.g. PyTorch) expose ``size`` as a method rather than an int.
    idx = xp.arange(array_api_compat.size(input), dtype=xp.int64)
    idx = xp.reshape(idx, shape)
    # Translate the per-axis indices into positions in the flattened array.
    new_idx = xp_take_along_axis(idx, index, axis=dim)
    new_idx = xp_ravel(new_idx)
    flat_src = xp_ravel(src)
    flat_input = xp_ravel(input)
    if array_api_compat.is_numpy_array(flat_input):
        # NumPy's buffered ``x[idx] += y`` drops repeated indices, so use the
        # unbuffered ``np.add.at`` to accumulate every contribution.
        # ``np.array`` copies, so the caller's array is left untouched.
        flat_input = np.array(flat_input)
        np.add.at(flat_input, new_idx, flat_src)
    else:
        flat_input = xpx.at(flat_input, new_idx).add(flat_src)
    return xp.reshape(flat_input, shape)
43 changes: 15 additions & 28 deletions deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import array_api_compat
import numpy as np

from deepmd.dpmodel.array_api import (
xp_scatter_sum,
)
from deepmd.dpmodel.common import (
GLOBAL_ENER_FLOAT_PRECISION,
)
Expand Down Expand Up @@ -98,20 +101,12 @@ def communicate_extended_output(
mapping = xp.reshape(mapping, (mldims + [1] * len(derv_r_ext_dims)))
mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims)
force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype)
# jax only
if array_api_compat.is_jax_array(force):
from deepmd.jax.common import (
scatter_sum,
)

force = scatter_sum(
force,
1,
mapping,
model_ret[kk_derv_r],
)
else:
raise NotImplementedError("Only JAX arrays are supported.")
force = xp_scatter_sum(
force,
1,
mapping,
model_ret[kk_derv_r],
)
new_ret[kk_derv_r] = force
else:
# name holders
Expand All @@ -127,20 +122,12 @@ def communicate_extended_output(
vldims + derv_c_ext_dims,
dtype=vv.dtype,
)
# jax only
if array_api_compat.is_jax_array(virial):
from deepmd.jax.common import (
scatter_sum,
)

virial = scatter_sum(
virial,
1,
mapping,
model_ret[kk_derv_c],
)
else:
raise NotImplementedError("Only JAX arrays are supported.")
virial = xp_scatter_sum(
virial,
1,
mapping,
model_ret[kk_derv_c],
)
new_ret[kk_derv_c] = virial
new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
else:
Expand Down
10 changes: 0 additions & 10 deletions deepmd/jax/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,3 @@ def __dlpack__(self, *args, **kwargs):

def __dlpack_device__(self, *args, **kwargs):
return self.value.__dlpack_device__(*args, **kwargs)


def scatter_sum(input, dim, index: jnp.ndarray, src: jnp.ndarray) -> jnp.ndarray:
    """Add the values of ``src`` into ``input`` at the positions named by ``index``.

    ``index`` selects positions along axis ``dim``; the selection is mapped
    onto the flattened ``input`` and the flattened ``src`` is added there.
    The result has the shape of ``input``.
    """
    original_shape = input.shape
    # Flat position of every element, laid out with the shape of ``input``.
    flat_positions = jnp.arange(input.size, dtype=jnp.int64).reshape(original_shape)
    # Flat positions addressed by ``index`` along ``dim``.
    targets = jnp.take_along_axis(flat_positions, index, axis=dim).ravel()
    accumulated = input.ravel().at[targets].add(src.ravel())
    return accumulated.reshape(original_shape)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ dependencies = [
'ml_dtypes',
'mendeleev',
'array-api-compat',
'array-api-extra>=0.5.0',
]
requires-python = ">=3.9"
keywords = ["deepmd"]
Expand Down
Loading