
Add stop_gradient
NeilGirdhar committed Feb 11, 2025
1 parent a35d74e commit c1c668c
Showing 4 changed files with 214 additions and 2 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -46,6 +46,7 @@ dev-dependencies = [
     "isort>=5.13",
     "jupyter>=1",
     "mypy>=1.12",
+    "torch>=2.6",
     "networkx>=3.1",
     "pre-commit>=4",
     "pylint>=3.3",
3 changes: 2 additions & 1 deletion tjax/__init__.py
@@ -23,7 +23,7 @@
 from ._src.leaky_integral import diffused_leaky_integrate, leaky_data_weight, leaky_integrate
 from ._src.math_tools import (abs_square, create_diagonal_array, divide_where, inverse_softplus,
                               matrix_dot_product, matrix_vector_mul, normalize, outer_product,
-                              softplus)
+                              softplus, stop_gradient)
 from ._src.partial import Partial
 from ._src.rng import RngStream, create_streams, fork_streams, sample_streams
 from ._src.shims import custom_jvp, custom_jvp_method, custom_vjp, custom_vjp_method, hessian, jit
@@ -113,6 +113,7 @@
     'sample_streams',
     'scale_cotangent',
     'softplus',
+    'stop_gradient',
     'tree_allclose',
     'zero_from_primal',
 ]
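
For reference, a minimal sketch (not part of the commit) showing what the re-export enables: importing the new function directly from the package root, alongside the other math tools.

from tjax import stop_gradient  # now public at the top level, like softplus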
19 changes: 18 additions & 1 deletion tjax/_src/math_tools.py
@@ -1,10 +1,11 @@
 from __future__ import annotations
 
+from types import ModuleType
 from typing import Literal, TypeVar, overload
 
 import array_api_extra as xpx
 import numpy as np
-from array_api_compat import get_namespace
+from array_api_compat import get_namespace, is_jax_namespace, is_torch_namespace
 
 from .annotations import (BooleanArray, ComplexArray, IntegralArray, JaxBooleanArray,
                           JaxComplexArray, JaxIntegralArray, JaxRealArray, NumpyBooleanArray,
@@ -173,3 +174,19 @@ def create_diagonal_array(m: T) -> T:
         target_index = (*index, slice(None, None, n + 1))
         retval = xpx.at(retval)[target_index].set(m[*index, :])
     return xp.reshape(retval, (*m.shape, n))
+
+
+U = TypeVar('U')
+
+
+def stop_gradient(x: U, *, xp: ModuleType | None = None) -> U:
+    if xp is None:
+        xp = get_namespace(x)
+    if is_jax_namespace(xp):
+        from jax.lax import stop_gradient as sg  # noqa: PLC0415
+        return sg(x)
+    if is_torch_namespace(xp):
+        from torch import Tensor  # noqa: PLC0415
+        assert isinstance(x, Tensor)
+        return x.detach()  # pyright: ignore
+    return x
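
For context, a minimal usage sketch (not part of the commit; the test values are illustrative). Because stop_gradient dispatches on the array's namespace, the same call blocks differentiation under JAX, detaches under PyTorch, and passes NumPy arrays through unchanged:

import jax
import jax.numpy as jnp
import torch
from tjax import stop_gradient

# JAX: the stopped factor is treated as a constant by autodiff,
# so only the un-stopped factor contributes to the gradient.
def loss(theta: jax.Array) -> jax.Array:
    return jnp.sum(theta * stop_gradient(theta))

theta = jnp.asarray([1.0, 2.0, 3.0])
print(jax.grad(loss)(theta))  # [1. 2. 3.] rather than [2. 4. 6.]

# PyTorch: dispatches to Tensor.detach, dropping the autograd graph.
t = torch.ones(3, requires_grad=True)
print(stop_gradient(t).requires_grad)  # False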
193 changes: 193 additions & 0 deletions uv.lock

Some generated files are not rendered by default.
