From c12b51439b44b9b7fb4a7f0eaebac3f9c1149876 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Wed, 22 Oct 2025 01:14:45 +0500 Subject: [PATCH 01/10] feat: enhance types in `numpyro.distributions.constraints` module --- numpyro/_typing.py | 18 +- numpyro/distributions/constraints.py | 339 +++++++++++++++------------ numpyro/distributions/transforms.py | 14 +- pyproject.toml | 1 + 4 files changed, 217 insertions(+), 155 deletions(-) diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 10cc1c9e8..4f7375c8a 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -4,7 +4,7 @@ from collections import OrderedDict from collections.abc import Callable -from typing import Any, Optional, Protocol, Union, runtime_checkable +from typing import Any, Optional, Protocol, TypeVar, Union, runtime_checkable import weakref try: @@ -12,6 +12,7 @@ except ImportError: from typing_extensions import ParamSpec, TypeAlias + import numpy as np import jax @@ -36,19 +37,26 @@ """A generic type for a pytree, i.e. a nested structure of lists, tuples, dicts, and arrays.""" +NumLikeT = TypeVar("NumLikeT", bound=NumLike) + + @runtime_checkable -class ConstraintT(Protocol): +class ConstraintT(Protocol[NumLikeT]): """A protocol for typing constraints.""" @property def is_discrete(self) -> bool: ... @property def event_dim(self) -> int: ... + @is_discrete.setter + def is_discrete(self, value: bool): ... + @event_dim.setter + def event_dim(self, value: int): ... - def __call__(self, x: ArrayLike) -> ArrayLike: ... + def __call__(self, x: NumLikeT) -> ArrayLike: ... def __repr__(self) -> str: ... - def check(self, value: ArrayLike) -> ArrayLike: ... - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: ... + def check(self, value: NumLikeT) -> ArrayLike: ... + def feasible_like(self, prototype: NumLikeT) -> NumLikeT: ... @runtime_checkable diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 417b5f260..454570aeb 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -64,7 +64,7 @@ ] import math -from typing import Optional +from typing import Generic, Optional, TypeVar import numpy as np @@ -75,8 +75,10 @@ from numpyro._typing import ConstraintT, NonScalarArray, NumLike +NumLikeT = TypeVar("NumLikeT", bound=NumLike) -class Constraint(object): + +class Constraint(Generic[NumLikeT]): """ Abstract base class for constraints. @@ -84,32 +86,48 @@ class Constraint(object): e.g. within which a variable can be optimized. """ - is_discrete = False - event_dim = 0 + _is_discrete = False + _event_dim = 0 def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten) - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLikeT) -> ArrayLike: raise NotImplementedError def __repr__(self) -> str: return self.__class__.__name__[1:] + "()" - def check(self, value: ArrayLike) -> ArrayLike: + def check(self, value: NumLikeT) -> ArrayLike: """ Returns a byte tensor of `sample_shape + batch_shape` indicating whether each event in value satisfies this constraint. """ return self(value) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLikeT) -> NumLikeT: """ Get a feasible value which has the same shape as dtype as `prototype`. """ raise NotImplementedError + @property + def is_discrete(self) -> bool: + return self._is_discrete + + @property + def event_dim(self) -> int: + return self._event_dim + + @is_discrete.setter # type: ignore[attr-defined] + def is_discrete(self, value: bool): + self._is_discrete = value + + @event_dim.setter # type: ignore[attr-defined] + def event_dim(self, value: int): + self._event_dim = value + @classmethod def tree_unflatten(cls, aux_data, params): params_keys, aux_data = aux_data @@ -122,12 +140,12 @@ def tree_unflatten(cls, aux_data, params): return self -class ParameterFreeConstraint(Constraint): +class ParameterFreeConstraint(Constraint[NumLikeT]): def tree_flatten(self): return (), ((), dict()) -class _SingletonConstraint(ParameterFreeConstraint): +class _SingletonConstraint(ParameterFreeConstraint[NumLikeT]): """ A constraint type which has only one canonical instance, like constraints.real, and unlike constraints.interval. @@ -140,20 +158,20 @@ def __new__(cls): return cls._instance -class _Boolean(_SingletonConstraint): - is_discrete = True +class _Boolean(_SingletonConstraint[NumLike]): + _is_discrete = True - def __call__(self, x: ArrayLike) -> ArrayLike: - return (x == 0) | (x == 1) + def __call__(self, x: NumLike) -> ArrayLike: + return jnp.logical_or(jnp.equal(x, 0), jnp.equal(x, 1)) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) -class _CorrCholesky(_SingletonConstraint): - event_dim = 2 +class _CorrCholesky(_SingletonConstraint[NonScalarArray]): + _event_dim = 2 - def __call__(self, x: NonScalarArray) -> NonScalarArray: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tril = jnp.tril(x) lower_triangular = jnp.all( @@ -171,10 +189,10 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _CorrMatrix(_SingletonConstraint): - event_dim = 2 +class _CorrMatrix(_SingletonConstraint[NonScalarArray]): + _event_dim = 2 - def __call__(self, x: NonScalarArray) -> NonScalarArray: + def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) @@ -184,7 +202,7 @@ def __call__(self, x: NonScalarArray) -> NonScalarArray: unit_variance = jnp.all( jnp.abs(jnp.diagonal(x, axis1=-2, axis2=-1) - 1) < 1e-6, axis=-1 ) - return symmetric & positive & unit_variance + return jnp.logical_and(jnp.logical_and(symmetric, positive), unit_variance) def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( @@ -192,7 +210,7 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _Dependent(Constraint): +class _Dependent(Constraint[NumLikeT]): """ Placeholder for variables whose support depends on other variables. These variables obey no simple coordinate-wise constraints. @@ -226,7 +244,7 @@ def event_dim(self) -> int: def __call__( self, - x: Optional[ArrayLike] = None, + x: Optional[NumLikeT] = None, *, is_discrete: bool = NotImplemented, event_dim: int = NotImplemented, @@ -242,21 +260,22 @@ def __call__( event_dim = self._event_dim return _Dependent(is_discrete=is_discrete, event_dim=event_dim) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, _Dependent): + return False return ( - type(self) is type(other) - and self._is_discrete == other._is_discrete + self._is_discrete == other._is_discrete and self._event_dim == other._event_dim ) def tree_flatten(self): return (), ( (), - dict(_is_discrete=self._is_discrete, _event_dim=self._event_dim), + dict(is_discrete=self._is_discrete, event_dim=self._event_dim), ) -class dependent_property(property, _Dependent): +class dependent_property(property, _Dependent[NumLikeT]): # XXX: this should not need to be pytree-able since it simply wraps a method # and thus is automatically present once the method's object is created def __init__( @@ -266,7 +285,7 @@ def __init__( self._is_discrete = is_discrete self._event_dim = event_dim - def __call__(self, x: ArrayLike) -> ArrayLike: + def __call__(self, x: NumLikeT) -> ArrayLike: if not callable(x): return super().__call__(x) @@ -283,12 +302,12 @@ def is_dependent(constraint): return isinstance(constraint, _Dependent) -class _GreaterThan(Constraint): +class _GreaterThan(Constraint[NumLike]): def __init__(self, lower_bound: NumLike) -> None: self.lower_bound = lower_bound def __call__(self, x: NumLike) -> ArrayLike: - return x > self.lower_bound + return jnp.greater(x, self.lower_bound) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -301,40 +320,42 @@ def feasible_like(self, prototype: NumLike) -> NumLike: def tree_flatten(self): return (self.lower_bound,), (("lower_bound",), dict()) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _GreaterThan): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) + return jnp.array_equal(self.lower_bound, other.lower_bound) # type: ignore[return-value] class _GreaterThanEq(_GreaterThan): def __call__(self, x: NumLike) -> ArrayLike: - return x >= self.lower_bound + return jnp.greater_equal(x, self.lower_bound) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _GreaterThanEq): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) + return jnp.array_equal(self.lower_bound, other.lower_bound) # type: ignore[return-value] -class _Positive(_SingletonConstraint, _GreaterThan): +class _Positive(_SingletonConstraint[NumLike], _GreaterThan): def __init__(self) -> None: super().__init__(0.0) -class _Nonnegative(_SingletonConstraint, _GreaterThanEq): +class _Nonnegative(_SingletonConstraint[NumLike], _GreaterThanEq): def __init__(self) -> None: super().__init__(0.0) -class _IndependentConstraint(Constraint): +class _IndependentConstraint(Constraint[NumLikeT]): """ Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many dims in :meth:`check`, so that an event is valid only if all its independent entries are valid. """ - def __init__(self, base_constraint, reinterpreted_batch_ndims): + def __init__( + self, base_constraint: ConstraintT[NumLikeT], reinterpreted_batch_ndims: int + ): assert isinstance(base_constraint, Constraint) assert isinstance(reinterpreted_batch_ndims, int) assert reinterpreted_batch_ndims >= 0 @@ -345,17 +366,11 @@ def __init__(self, base_constraint, reinterpreted_batch_ndims): base_constraint = base_constraint.base_constraint self.base_constraint = base_constraint self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + self._is_discrete = base_constraint.is_discrete + self._event_dim = base_constraint.event_dim + reinterpreted_batch_ndims super().__init__() - @property - def is_discrete(self) -> bool: - return self.base_constraint.is_discrete - - @property - def event_dim(self) -> int: - return self.base_constraint.event_dim + self.reinterpreted_batch_ndims - - def __call__(self, value: ArrayLike) -> ArrayLike: + def __call__(self, value: NumLikeT) -> ArrayLike: result = self.base_constraint(value) if self.reinterpreted_batch_ndims == 0: return result @@ -364,7 +379,9 @@ def __call__(self, value: ArrayLike) -> ArrayLike: raise ValueError( f"Expected value.dim() >= {expected} but got {jax.numpy.ndim(value)}" ) - result = result.reshape( + # jax>=0.7.2 introduced `TypedNdArray` to represent constants in jaxpr, and they + # have no reshape method. + result = result.reshape( # type: ignore[union-attr] jax.numpy.shape(result)[ : jax.numpy.ndim(result) - self.reinterpreted_batch_ndims ] @@ -380,7 +397,7 @@ def __repr__(self) -> str: self.reinterpreted_batch_ndims, ) - def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + def feasible_like(self, prototype: NumLikeT) -> NumLikeT: return self.base_constraint.feasible_like(prototype) def tree_flatten(self): @@ -389,7 +406,7 @@ def tree_flatten(self): {"reinterpreted_batch_ndims": self.reinterpreted_batch_ndims}, ) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _IndependentConstraint): return False @@ -398,22 +415,26 @@ def __eq__(self, other: ConstraintT) -> bool: ) -class _RealVector(_IndependentConstraint, _SingletonConstraint): +class _RealVector( + _IndependentConstraint[NumLike], _SingletonConstraint[NonScalarArray] +): def __init__(self) -> None: super().__init__(_Real(), 1) -class _RealMatrix(_IndependentConstraint, _SingletonConstraint): +class _RealMatrix( + _IndependentConstraint[NumLike], _SingletonConstraint[NonScalarArray] +): def __init__(self) -> None: super().__init__(_Real(), 2) -class _LessThan(Constraint): +class _LessThan(Constraint[NumLike]): def __init__(self, upper_bound: NumLike) -> None: self.upper_bound = upper_bound def __call__(self, x: NumLike) -> ArrayLike: - return x < self.upper_bound + return jnp.less(x, self.upper_bound) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -426,31 +447,37 @@ def feasible_like(self, prototype: NumLike) -> NumLike: def tree_flatten(self): return (self.upper_bound,), (("upper_bound",), dict()) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _LessThan): return False - return jnp.array_equal(self.upper_bound, other.upper_bound) + return jnp.array_equal(self.upper_bound, other.upper_bound) # type: ignore[return-value] class _LessThanEq(_LessThan): def __call__(self, x: NumLike) -> ArrayLike: - return x <= self.upper_bound + return jnp.less_equal(x, self.upper_bound) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _LessThanEq): return False - return jnp.array_equal(self.upper_bound, other.upper_bound) + return jnp.array_equal(self.upper_bound, other.upper_bound) # type: ignore[return-value] -class _IntegerInterval(Constraint): - is_discrete = True +class _IntegerInterval(Constraint[NumLike]): + _is_discrete = True def __init__(self, lower_bound: NumLike, upper_bound: NumLike) -> None: self.lower_bound = lower_bound self.upper_bound = upper_bound def __call__(self, x: NumLike) -> ArrayLike: - return (x >= self.lower_bound) & (x <= self.upper_bound) & (x % 1 == 0) + return jnp.logical_and( + jnp.logical_and( + jnp.greater_equal(x, self.lower_bound), + jnp.less_equal(x, self.upper_bound), + ), + jnp.equal(jnp.mod(x, 1), 0), + ) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -468,23 +495,25 @@ def tree_flatten(self): dict(), ) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _IntegerInterval): return False + return jnp.logical_and( + jnp.array_equal(self.lower_bound, other.lower_bound), + jnp.array_equal(self.upper_bound, other.upper_bound), + ) # type: ignore[return-value] - return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( - self.upper_bound, other.upper_bound - ) - -class _IntegerGreaterThan(Constraint): - is_discrete = True +class _IntegerGreaterThan(Constraint[NumLike]): + _is_discrete = True def __init__(self, lower_bound: NumLike) -> None: self.lower_bound = lower_bound def __call__(self, x: NumLike) -> ArrayLike: - return (x % 1 == 0) & (x >= self.lower_bound) + return jnp.logical_and( + jnp.equal(jnp.mod(x, 1), 0), jnp.greater_equal(x, self.lower_bound) + ) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -497,29 +526,31 @@ def feasible_like(self, prototype: NumLike) -> NumLike: def tree_flatten(self): return (self.lower_bound,), (("lower_bound",), dict()) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _IntegerGreaterThan): return False - return jnp.array_equal(self.lower_bound, other.lower_bound) + return jnp.array_equal(self.lower_bound, other.lower_bound) # type: ignore[return-value] -class _IntegerPositive(_SingletonConstraint, _IntegerGreaterThan): +class _IntegerPositive(_SingletonConstraint[NumLike], _IntegerGreaterThan): def __init__(self) -> None: super().__init__(1) -class _IntegerNonnegative(_SingletonConstraint, _IntegerGreaterThan): +class _IntegerNonnegative(_SingletonConstraint[NumLike], _IntegerGreaterThan): def __init__(self) -> None: super().__init__(0) -class _Interval(Constraint): +class _Interval(Constraint[NumLike]): def __init__(self, lower_bound: NumLike, upper_bound: NumLike) -> None: self.lower_bound = lower_bound self.upper_bound = upper_bound def __call__(self, x: NumLike) -> ArrayLike: - return (x >= self.lower_bound) & (x <= self.upper_bound) + return jnp.logical_and( + jnp.greater_equal(x, self.lower_bound), jnp.less_equal(x, self.upper_bound) + ) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -533,12 +564,12 @@ def feasible_like(self, prototype: NumLike) -> NumLike: (self.lower_bound + self.upper_bound) / 2, jax.numpy.shape(prototype) ) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _Interval): return False return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal( self.upper_bound, other.upper_bound - ) + ) # type: ignore[return-value] def tree_flatten(self): return (self.lower_bound, self.upper_bound), ( @@ -547,19 +578,22 @@ def tree_flatten(self): ) -class _Circular(_SingletonConstraint, _Interval): +class _Circular(_SingletonConstraint[NumLike], _Interval): def __init__(self) -> None: super().__init__(-math.pi, math.pi) -class _UnitInterval(_SingletonConstraint, _Interval): +class _UnitInterval(_SingletonConstraint[NumLike], _Interval): def __init__(self) -> None: super().__init__(0.0, 1.0) class _OpenInterval(_Interval): def __call__(self, x: NumLike) -> ArrayLike: - return (x > self.lower_bound) & (x < self.upper_bound) + return jnp.logical_and( + jnp.greater(x, self.lower_bound), + jnp.less(x, self.upper_bound), + ) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -569,8 +603,8 @@ def __repr__(self) -> str: return fmt_string -class _LowerCholesky(_SingletonConstraint): - event_dim = 2 +class _LowerCholesky(_SingletonConstraint[NonScalarArray]): + _event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy @@ -579,7 +613,7 @@ def __call__(self, x: NonScalarArray) -> ArrayLike: jnp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1 ) positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1) - return lower_triangular & positive_diagonal + return jnp.logical_and(lower_triangular, positive_diagonal) def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( @@ -587,15 +621,18 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _Multinomial(Constraint): - is_discrete = True - event_dim = 1 +class _Multinomial(Constraint[NonScalarArray]): + _is_discrete = True + _event_dim = 1 def __init__(self, upper_bound: ArrayLike) -> None: self.upper_bound = upper_bound def __call__(self, x: NonScalarArray) -> ArrayLike: - return (x >= 0).all(axis=-1) & (x.sum(axis=-1) == self.upper_bound) + return jnp.logical_and( + (x >= 0).all(axis=-1), + jnp.equal(x.sum(axis=-1), self.upper_bound), + ) def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: pad_width = ((0, 0),) * jax.numpy.ndim(self.upper_bound) + ( @@ -607,31 +644,31 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: def tree_flatten(self): return (self.upper_bound,), (("upper_bound",), dict()) - def __eq__(self, other: ConstraintT) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, _Multinomial): return False - return jnp.array_equal(self.upper_bound, other.upper_bound) + return jnp.array_equal(self.upper_bound, other.upper_bound) # type: ignore[return-value] -class _L1Ball(_SingletonConstraint): +class _L1Ball(_SingletonConstraint[NumLike]): """ Constrain to the L1 ball of any dimension. """ - event_dim = 1 + _event_dim = 1 reltol = 10.0 # Relative to finfo.eps. def __call__(self, x: NumLike) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - eps = jnp.finfo(x.dtype).eps + eps = jnp.finfo(x.dtype if isinstance(x, jnp.ndarray) else type(x)).eps return jnp.abs(x).sum(axis=-1) < 1 + self.reltol * eps def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) -class _OrderedVector(_SingletonConstraint): - event_dim = 1 +class _OrderedVector(_SingletonConstraint[NonScalarArray]): + _event_dim = 1 def __call__(self, x: NonScalarArray) -> ArrayLike: return (x[..., 1:] > x[..., :-1]).all(axis=-1) @@ -642,8 +679,8 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _PositiveDefinite(_SingletonConstraint): - event_dim = 2 +class _PositiveDefinite(_SingletonConstraint[NonScalarArray]): + _event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy @@ -659,21 +696,21 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _PositiveDefiniteCirculantVector(_SingletonConstraint): - event_dim = 1 +class _PositiveDefiniteCirculantVector(_SingletonConstraint[NonScalarArray]): + _event_dim = 1 def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy tol = 10 * jnp.finfo(x.dtype).eps rfft = jnp.fft.rfft(x) - return (jnp.abs(rfft.imag) < tol) & (rfft.real > -tol) + return jnp.logical_and(jnp.abs(rfft.imag) < tol, jnp.greater(rfft.real, -tol)) def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jnp.zeros_like(prototype).at[..., 0].set(1.0) -class _PositiveSemiDefinite(_SingletonConstraint): - event_dim = 2 +class _PositiveSemiDefinite(_SingletonConstraint[NonScalarArray]): + _event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy @@ -681,7 +718,7 @@ def __call__(self, x: NonScalarArray) -> ArrayLike: symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) # check for the smallest eigenvalue is nonnegative nonnegative = jnp.linalg.eigh(x)[0][..., 0] >= 0 - return symmetric & nonnegative + return jnp.logical_and(symmetric, nonnegative) def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( @@ -689,16 +726,18 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _PositiveOrderedVector(_SingletonConstraint): +class _PositiveOrderedVector(_SingletonConstraint[NonScalarArray]): """ Constrains to a positive real-valued tensor where the elements are monotonically increasing along the `event_shape` dimension. """ - event_dim = 1 + _event_dim = 1 def __call__(self, x: NonScalarArray) -> ArrayLike: - return ordered_vector.check(x) & independent(positive, 1).check(x) + return jnp.logical_and( + ordered_vector.check(x), independent[NumLike](positive, 1).check(x) + ) def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( @@ -706,16 +745,22 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: ) -class _Complex(_SingletonConstraint): +class _Complex(_SingletonConstraint[NumLike]): def __call__(self, x: NumLike) -> ArrayLike: # XXX: consider to relax this condition to [-inf, inf] interval - return (x == x) & (x != float("inf")) & (x != float("-inf")) + return jnp.logical_and( + jnp.equal(x, x), + jnp.logical_and( + jnp.not_equal(x, float("inf")), + jnp.not_equal(x, float("-inf")), + ), + ) def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) -class _Real(_SingletonConstraint): +class _Real(_SingletonConstraint[NumLike]): def __call__(self, x: NumLike) -> ArrayLike: # XXX: consider to relax this condition to [-inf, inf] interval return (x == x) & (x != float("inf")) & (x != float("-inf")) @@ -724,8 +769,8 @@ def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) -class _Simplex(_SingletonConstraint): - event_dim = 1 +class _Simplex(_SingletonConstraint[NonScalarArray]): + _event_dim = 1 def __call__(self, x: NonScalarArray) -> ArrayLike: x_sum = x.sum(axis=-1) @@ -735,7 +780,7 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.full_like(prototype, 1 / prototype.shape[-1]) -class _SoftplusPositive(_SingletonConstraint, _GreaterThan): +class _SoftplusPositive(_SingletonConstraint[NumLike], _GreaterThan): def __init__(self) -> None: super().__init__(lower_bound=0.0) @@ -754,12 +799,12 @@ class _ScaledUnitLowerCholesky(_LowerCholesky): pass -class _Sphere(_SingletonConstraint): +class _Sphere(_SingletonConstraint[NonScalarArray]): """ Constrain to the Euclidean sphere of any dimension. """ - event_dim = 1 + _event_dim = 1 reltol = 10.0 # Relative to finfo.eps. def __call__(self, x: NonScalarArray) -> ArrayLike: @@ -773,9 +818,9 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.full_like(prototype, prototype.shape[-1] ** (-0.5)) -class _ZeroSum(Constraint): +class _ZeroSum(Constraint[NonScalarArray]): def __init__(self, event_dim: int = 1) -> None: - self.event_dim = event_dim + self._event_dim = event_dim super().__init__() def __call__(self, x: NonScalarArray) -> ArrayLike: @@ -786,8 +831,10 @@ def __call__(self, x: NonScalarArray) -> ArrayLike: zerosum_true = zerosum_true & jnp.allclose(x.sum(dim), 0, atol=tol) return zerosum_true - def __eq__(self, other: ConstraintT) -> bool: - return type(self) is type(other) and self.event_dim == other.event_dim + def __eq__(self, other: object) -> bool: + if not isinstance(other, _ZeroSum): + return False + return self.event_dim == other.event_dim def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.zeros_like(prototype) @@ -799,11 +846,12 @@ def tree_flatten(self): # TODO: Make types consistent # See https://github.com/pytorch/pytorch/issues/50616 -boolean: ConstraintT = _Boolean() -circular: ConstraintT = _Circular() -complex: ConstraintT = _Complex() -corr_cholesky: ConstraintT = _CorrCholesky() -corr_matrix: ConstraintT = _CorrMatrix() +# fmt: off +boolean: ConstraintT[NumLike] = _Boolean() +circular: ConstraintT[NumLike] = _Circular() +complex: ConstraintT[NumLike] = _Complex() +corr_cholesky: ConstraintT[NonScalarArray] = _CorrCholesky() +corr_matrix: ConstraintT[NonScalarArray] = _CorrMatrix() dependent: ConstraintT = _Dependent() greater_than = _GreaterThan greater_than_eq = _GreaterThanEq @@ -813,26 +861,27 @@ def tree_flatten(self): integer_interval = _IntegerInterval integer_greater_than = _IntegerGreaterThan interval = _Interval -l1_ball: ConstraintT = _L1Ball() -lower_cholesky: ConstraintT = _LowerCholesky() -scaled_unit_lower_cholesky: ConstraintT = _ScaledUnitLowerCholesky() +l1_ball: ConstraintT[NumLike] = _L1Ball() +lower_cholesky: ConstraintT[NonScalarArray] = _LowerCholesky() +scaled_unit_lower_cholesky: ConstraintT[NonScalarArray] = _ScaledUnitLowerCholesky() multinomial = _Multinomial -nonnegative: ConstraintT = _Nonnegative() -nonnegative_integer: ConstraintT = _IntegerNonnegative() -ordered_vector: ConstraintT = _OrderedVector() -positive: ConstraintT = _Positive() -positive_definite: ConstraintT = _PositiveDefinite() -positive_definite_circulant_vector: ConstraintT = _PositiveDefiniteCirculantVector() -positive_semidefinite: ConstraintT = _PositiveSemiDefinite() -positive_integer: ConstraintT = _IntegerPositive() -positive_ordered_vector: ConstraintT = _PositiveOrderedVector() -real: ConstraintT = _Real() -real_vector: ConstraintT = _RealVector() -real_matrix: ConstraintT = _RealMatrix() -simplex: ConstraintT = _Simplex() -softplus_lower_cholesky: ConstraintT = _SoftplusLowerCholesky() -softplus_positive: ConstraintT = _SoftplusPositive() -sphere: ConstraintT = _Sphere() -unit_interval: ConstraintT = _UnitInterval() +nonnegative: ConstraintT[NumLike] = _Nonnegative() +nonnegative_integer: ConstraintT[NumLike] = _IntegerNonnegative() +ordered_vector: ConstraintT[NonScalarArray] = _OrderedVector() +positive: ConstraintT[NumLike] = _Positive() +positive_definite: ConstraintT[NonScalarArray] = _PositiveDefinite() +positive_definite_circulant_vector: ConstraintT[NonScalarArray] = _PositiveDefiniteCirculantVector() +positive_semidefinite: ConstraintT[NonScalarArray] = _PositiveSemiDefinite() +positive_integer: ConstraintT[NumLike] = _IntegerPositive() +positive_ordered_vector: ConstraintT[NonScalarArray] = _PositiveOrderedVector() +real: ConstraintT[NumLike] = _Real() +real_vector: ConstraintT[NumLike] = _RealVector() +real_matrix: ConstraintT[NumLike] = _RealMatrix() +simplex: ConstraintT[NonScalarArray] = _Simplex() +softplus_lower_cholesky: ConstraintT[NonScalarArray] = _SoftplusLowerCholesky() +softplus_positive: ConstraintT[NumLike] = _SoftplusPositive() +sphere: ConstraintT[NonScalarArray] = _Sphere() +unit_interval: ConstraintT[NumLike] = _UnitInterval() open_interval = _OpenInterval zero_sum = _ZeroSum +# fmt: on diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 4512225c8..11cf179e0 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -3,7 +3,7 @@ import math -from typing import Generic, Optional, Sequence, Tuple, TypeVar, Union, cast +from typing import Generic, Optional, Sequence, Tuple, Union, cast import warnings import weakref @@ -18,7 +18,14 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import ConstraintT, NonScalarArray, NumLike, PyTree, TransformT +from numpyro._typing import ( + ConstraintT, + NonScalarArray, + NumLike, + NumLikeT, + PyTree, + TransformT, +) from numpyro.distributions import constraints from numpyro.distributions.util import ( add_diag, @@ -65,9 +72,6 @@ def _clipped_expit(x: NumLike) -> NumLike: return jnp.clip(expit(x), finfo.tiny, 1.0 - finfo.eps) -NumLikeT = TypeVar("NumLikeT", bound=NumLike) - - class Transform(Generic[NumLikeT]): _inv: Optional[Union[TransformT, weakref.ref]] = None diff --git a/pyproject.toml b/pyproject.toml index 47925bf7a..f3b504427 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,6 +130,7 @@ module = [ "numpyro.primitives.*", "numpyro.patch.*", "numpyro.util.*", + "numpyro.distributions.constraints", "numpyro.distributions.transforms", ] ignore_errors = false From cb01e0991224aaef0e8969b26991fa95154b3ae6 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Wed, 22 Oct 2025 01:58:59 +0500 Subject: [PATCH 02/10] refactor: simplify `ConstraintT` protocol by removing generic type parameter --- numpyro/_typing.py | 8 ++-- numpyro/distributions/constraints.py | 63 +++++++++++++--------------- 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 4f7375c8a..184063dd2 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -41,7 +41,7 @@ @runtime_checkable -class ConstraintT(Protocol[NumLikeT]): +class ConstraintT(Protocol): """A protocol for typing constraints.""" @property @@ -53,10 +53,10 @@ def is_discrete(self, value: bool): ... @event_dim.setter def event_dim(self, value: int): ... - def __call__(self, x: NumLikeT) -> ArrayLike: ... + def __call__(self, x: NumLike) -> ArrayLike: ... def __repr__(self) -> str: ... - def check(self, value: NumLikeT) -> ArrayLike: ... - def feasible_like(self, prototype: NumLikeT) -> NumLikeT: ... + def check(self, value: NumLike) -> ArrayLike: ... + def feasible_like(self, prototype: NumLike) -> NumLike: ... @runtime_checkable diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 454570aeb..ef99b9750 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -64,7 +64,7 @@ ] import math -from typing import Generic, Optional, TypeVar +from typing import Generic, Optional import numpy as np @@ -73,9 +73,7 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import ConstraintT, NonScalarArray, NumLike - -NumLikeT = TypeVar("NumLikeT", bound=NumLike) +from numpyro._typing import ConstraintT, NonScalarArray, NumLike, NumLikeT class Constraint(Generic[NumLikeT]): @@ -353,9 +351,7 @@ class _IndependentConstraint(Constraint[NumLikeT]): independent entries are valid. """ - def __init__( - self, base_constraint: ConstraintT[NumLikeT], reinterpreted_batch_ndims: int - ): + def __init__(self, base_constraint: ConstraintT, reinterpreted_batch_ndims: int): assert isinstance(base_constraint, Constraint) assert isinstance(reinterpreted_batch_ndims, int) assert reinterpreted_batch_ndims >= 0 @@ -846,12 +842,12 @@ def tree_flatten(self): # TODO: Make types consistent # See https://github.com/pytorch/pytorch/issues/50616 -# fmt: off -boolean: ConstraintT[NumLike] = _Boolean() -circular: ConstraintT[NumLike] = _Circular() -complex: ConstraintT[NumLike] = _Complex() -corr_cholesky: ConstraintT[NonScalarArray] = _CorrCholesky() -corr_matrix: ConstraintT[NonScalarArray] = _CorrMatrix() + +boolean: ConstraintT = _Boolean() +circular: ConstraintT = _Circular() +complex: ConstraintT = _Complex() +corr_cholesky: ConstraintT = _CorrCholesky() +corr_matrix: ConstraintT = _CorrMatrix() dependent: ConstraintT = _Dependent() greater_than = _GreaterThan greater_than_eq = _GreaterThanEq @@ -861,27 +857,26 @@ def tree_flatten(self): integer_interval = _IntegerInterval integer_greater_than = _IntegerGreaterThan interval = _Interval -l1_ball: ConstraintT[NumLike] = _L1Ball() -lower_cholesky: ConstraintT[NonScalarArray] = _LowerCholesky() -scaled_unit_lower_cholesky: ConstraintT[NonScalarArray] = _ScaledUnitLowerCholesky() +l1_ball: ConstraintT = _L1Ball() +lower_cholesky: ConstraintT = _LowerCholesky() +scaled_unit_lower_cholesky: ConstraintT = _ScaledUnitLowerCholesky() multinomial = _Multinomial -nonnegative: ConstraintT[NumLike] = _Nonnegative() -nonnegative_integer: ConstraintT[NumLike] = _IntegerNonnegative() -ordered_vector: ConstraintT[NonScalarArray] = _OrderedVector() -positive: ConstraintT[NumLike] = _Positive() -positive_definite: ConstraintT[NonScalarArray] = _PositiveDefinite() -positive_definite_circulant_vector: ConstraintT[NonScalarArray] = _PositiveDefiniteCirculantVector() -positive_semidefinite: ConstraintT[NonScalarArray] = _PositiveSemiDefinite() -positive_integer: ConstraintT[NumLike] = _IntegerPositive() -positive_ordered_vector: ConstraintT[NonScalarArray] = _PositiveOrderedVector() -real: ConstraintT[NumLike] = _Real() -real_vector: ConstraintT[NumLike] = _RealVector() -real_matrix: ConstraintT[NumLike] = _RealMatrix() -simplex: ConstraintT[NonScalarArray] = _Simplex() -softplus_lower_cholesky: ConstraintT[NonScalarArray] = _SoftplusLowerCholesky() -softplus_positive: ConstraintT[NumLike] = _SoftplusPositive() -sphere: ConstraintT[NonScalarArray] = _Sphere() -unit_interval: ConstraintT[NumLike] = _UnitInterval() +nonnegative: ConstraintT = _Nonnegative() +nonnegative_integer: ConstraintT = _IntegerNonnegative() +ordered_vector: ConstraintT = _OrderedVector() +positive: ConstraintT = _Positive() +positive_definite: ConstraintT = _PositiveDefinite() +positive_definite_circulant_vector: ConstraintT = _PositiveDefiniteCirculantVector() +positive_semidefinite: ConstraintT = _PositiveSemiDefinite() +positive_integer: ConstraintT = _IntegerPositive() +positive_ordered_vector: ConstraintT = _PositiveOrderedVector() +real: ConstraintT = _Real() +real_vector: ConstraintT = _RealVector() +real_matrix: ConstraintT = _RealMatrix() +simplex: ConstraintT = _Simplex() +softplus_lower_cholesky: ConstraintT = _SoftplusLowerCholesky() +softplus_positive: ConstraintT = _SoftplusPositive() +sphere: ConstraintT = _Sphere() +unit_interval: ConstraintT = _UnitInterval() open_interval = _OpenInterval zero_sum = _ZeroSum -# fmt: on From 78d8e93af378a19399d590c8ac52ed93cd5dfc2e Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Fri, 24 Oct 2025 01:38:21 +0500 Subject: [PATCH 03/10] refactor: remove unused `NumLikeT` type variable and streamline imports in typing modules --- numpyro/_typing.py | 3 --- numpyro/distributions/constraints.py | 6 ++++-- numpyro/distributions/transforms.py | 14 +++++--------- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 184063dd2..a205a9ca8 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -37,9 +37,6 @@ """A generic type for a pytree, i.e. a nested structure of lists, tuples, dicts, and arrays.""" -NumLikeT = TypeVar("NumLikeT", bound=NumLike) - - @runtime_checkable class ConstraintT(Protocol): """A protocol for typing constraints.""" diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index ef99b9750..0febac81f 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -64,7 +64,7 @@ ] import math -from typing import Generic, Optional +from typing import Generic, Optional, TypeVar import numpy as np @@ -73,7 +73,9 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import ConstraintT, NonScalarArray, NumLike, NumLikeT +from numpyro._typing import ConstraintT, NonScalarArray, NumLike + +NumLikeT = TypeVar("NumLikeT", bound=NumLike) class Constraint(Generic[NumLikeT]): diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 11cf179e0..4512225c8 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -3,7 +3,7 @@ import math -from typing import Generic, Optional, Sequence, Tuple, Union, cast +from typing import Generic, Optional, Sequence, Tuple, TypeVar, Union, cast import warnings import weakref @@ -18,14 +18,7 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import ( - ConstraintT, - NonScalarArray, - NumLike, - NumLikeT, - PyTree, - TransformT, -) +from numpyro._typing import ConstraintT, NonScalarArray, NumLike, PyTree, TransformT from numpyro.distributions import constraints from numpyro.distributions.util import ( add_diag, @@ -72,6 +65,9 @@ def _clipped_expit(x: NumLike) -> NumLike: return jnp.clip(expit(x), finfo.tiny, 1.0 - finfo.eps) +NumLikeT = TypeVar("NumLikeT", bound=NumLike) + + class Transform(Generic[NumLikeT]): _inv: Optional[Union[TransformT, weakref.ref]] = None From cd2bc8fe6411fe94bb596bf263e346068f9fb1cc Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Fri, 24 Oct 2025 01:55:23 +0500 Subject: [PATCH 04/10] refactor: remove setter methods for `is_discrete` and `event_dim` in `Constraint` class --- numpyro/_typing.py | 6 +----- numpyro/distributions/constraints.py | 10 +--------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/numpyro/_typing.py b/numpyro/_typing.py index a205a9ca8..861496ba8 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -4,7 +4,7 @@ from collections import OrderedDict from collections.abc import Callable -from typing import Any, Optional, Protocol, TypeVar, Union, runtime_checkable +from typing import Any, Optional, Protocol, Union, runtime_checkable import weakref try: @@ -45,10 +45,6 @@ class ConstraintT(Protocol): def is_discrete(self) -> bool: ... @property def event_dim(self) -> int: ... - @is_discrete.setter - def is_discrete(self, value: bool): ... - @event_dim.setter - def event_dim(self, value: int): ... def __call__(self, x: NumLike) -> ArrayLike: ... def __repr__(self) -> str: ... diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 0febac81f..abb788735 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -120,14 +120,6 @@ def is_discrete(self) -> bool: def event_dim(self) -> int: return self._event_dim - @is_discrete.setter # type: ignore[attr-defined] - def is_discrete(self, value: bool): - self._is_discrete = value - - @event_dim.setter # type: ignore[attr-defined] - def event_dim(self, value: int): - self._event_dim = value - @classmethod def tree_unflatten(cls, aux_data, params): params_keys, aux_data = aux_data @@ -271,7 +263,7 @@ def __eq__(self, other: object) -> bool: def tree_flatten(self): return (), ( (), - dict(is_discrete=self._is_discrete, event_dim=self._event_dim), + dict(_is_discrete=self._is_discrete, _event_dim=self._event_dim), ) From a6e52e156c3166216569b917ec5548923e385506 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Fri, 24 Oct 2025 01:59:06 +0500 Subject: [PATCH 05/10] refactor: replace `reshape` method with `jnp.reshape` in `_IndependentConstraint` class --- numpyro/distributions/constraints.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index abb788735..7fa6f2c0b 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -369,13 +369,12 @@ def __call__(self, value: NumLikeT) -> ArrayLike: raise ValueError( f"Expected value.dim() >= {expected} but got {jax.numpy.ndim(value)}" ) - # jax>=0.7.2 introduced `TypedNdArray` to represent constants in jaxpr, and they - # have no reshape method. - result = result.reshape( # type: ignore[union-attr] + result = jnp.reshape( + result, jax.numpy.shape(result)[ : jax.numpy.ndim(result) - self.reinterpreted_batch_ndims ] - + (-1,) + + (-1,), ) result = result.all(-1) return result From 5c331a17f7f335869ab5ad8f06ffcfaa9ed41fc1 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Fri, 24 Oct 2025 02:08:40 +0500 Subject: [PATCH 06/10] refactor: replace `jnp` with dynamic use of `jax.numpy` or `numpy` --- numpyro/distributions/constraints.py | 48 +++++++++++++++++----------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 7fa6f2c0b..13de30c24 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -154,7 +154,8 @@ class _Boolean(_SingletonConstraint[NumLike]): _is_discrete = True def __call__(self, x: NumLike) -> ArrayLike: - return jnp.logical_or(jnp.equal(x, 0), jnp.equal(x, 1)) + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.logical_or(xp.equal(x, 0), xp.equal(x, 1)) def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) @@ -299,7 +300,8 @@ def __init__(self, lower_bound: NumLike) -> None: self.lower_bound = lower_bound def __call__(self, x: NumLike) -> ArrayLike: - return jnp.greater(x, self.lower_bound) + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.greater(x, self.lower_bound) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -320,7 +322,8 @@ def __eq__(self, other: object) -> bool: class _GreaterThanEq(_GreaterThan): def __call__(self, x: NumLike) -> ArrayLike: - return jnp.greater_equal(x, self.lower_bound) + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.greater_equal(x, self.lower_bound) def __eq__(self, other: object) -> bool: if not isinstance(other, _GreaterThanEq): @@ -423,7 +426,8 @@ def __init__(self, upper_bound: NumLike) -> None: self.upper_bound = upper_bound def __call__(self, x: NumLike) -> ArrayLike: - return jnp.less(x, self.upper_bound) + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.less(x, self.upper_bound) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -444,7 +448,8 @@ def __eq__(self, other: object) -> bool: class _LessThanEq(_LessThan): def __call__(self, x: NumLike) -> ArrayLike: - return jnp.less_equal(x, self.upper_bound) + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.less_equal(x, self.upper_bound) def __eq__(self, other: object) -> bool: if not isinstance(other, _LessThanEq): @@ -500,8 +505,9 @@ def __init__(self, lower_bound: NumLike) -> None: self.lower_bound = lower_bound def __call__(self, x: NumLike) -> ArrayLike: - return jnp.logical_and( - jnp.equal(jnp.mod(x, 1), 0), jnp.greater_equal(x, self.lower_bound) + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.logical_and( + xp.equal(xp.mod(x, 1), 0), xp.greater_equal(x, self.lower_bound) ) def __repr__(self) -> str: @@ -537,8 +543,9 @@ def __init__(self, lower_bound: NumLike, upper_bound: NumLike) -> None: self.upper_bound = upper_bound def __call__(self, x: NumLike) -> ArrayLike: - return jnp.logical_and( - jnp.greater_equal(x, self.lower_bound), jnp.less_equal(x, self.upper_bound) + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.logical_and( + xp.greater_equal(x, self.lower_bound), xp.less_equal(x, self.upper_bound) ) def __repr__(self) -> str: @@ -579,9 +586,10 @@ def __init__(self) -> None: class _OpenInterval(_Interval): def __call__(self, x: NumLike) -> ArrayLike: - return jnp.logical_and( - jnp.greater(x, self.lower_bound), - jnp.less(x, self.upper_bound), + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.logical_and( + xp.greater(x, self.lower_bound), + xp.less(x, self.upper_bound), ) def __repr__(self) -> str: @@ -618,9 +626,10 @@ def __init__(self, upper_bound: ArrayLike) -> None: self.upper_bound = upper_bound def __call__(self, x: NonScalarArray) -> ArrayLike: - return jnp.logical_and( + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.logical_and( (x >= 0).all(axis=-1), - jnp.equal(x.sum(axis=-1), self.upper_bound), + xp.equal(x.sum(axis=-1), self.upper_bound), ) def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: @@ -737,11 +746,12 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: class _Complex(_SingletonConstraint[NumLike]): def __call__(self, x: NumLike) -> ArrayLike: # XXX: consider to relax this condition to [-inf, inf] interval - return jnp.logical_and( - jnp.equal(x, x), - jnp.logical_and( - jnp.not_equal(x, float("inf")), - jnp.not_equal(x, float("-inf")), + xp = jax.numpy if isinstance(x, jax.Array) else np + return xp.logical_and( + xp.equal(x, x), + xp.logical_and( + xp.not_equal(x, float("inf")), + xp.not_equal(x, float("-inf")), ), ) From 642ce7e1766d798d6162ca4e3f019c5ed257bccc Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Fri, 24 Oct 2025 23:09:01 +0500 Subject: [PATCH 07/10] Revert "refactor: remove unused `NumLikeT` type variable and streamline imports in typing modules" This reverts commit 78d8e93af378a19399d590c8ac52ed93cd5dfc2e. --- numpyro/_typing.py | 5 ++++- numpyro/distributions/constraints.py | 6 ++---- numpyro/distributions/transforms.py | 14 +++++++++----- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/numpyro/_typing.py b/numpyro/_typing.py index 861496ba8..684f05a3d 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -4,7 +4,7 @@ from collections import OrderedDict from collections.abc import Callable -from typing import Any, Optional, Protocol, Union, runtime_checkable +from typing import Any, Optional, Protocol, TypeVar, Union, runtime_checkable import weakref try: @@ -37,6 +37,9 @@ """A generic type for a pytree, i.e. a nested structure of lists, tuples, dicts, and arrays.""" +NumLikeT = TypeVar("NumLikeT", bound=NumLike) + + @runtime_checkable class ConstraintT(Protocol): """A protocol for typing constraints.""" diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 13de30c24..894ce9364 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -64,7 +64,7 @@ ] import math -from typing import Generic, Optional, TypeVar +from typing import Generic, Optional import numpy as np @@ -73,9 +73,7 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import ConstraintT, NonScalarArray, NumLike - -NumLikeT = TypeVar("NumLikeT", bound=NumLike) +from numpyro._typing import ConstraintT, NonScalarArray, NumLike, NumLikeT class Constraint(Generic[NumLikeT]): diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 4512225c8..11cf179e0 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -3,7 +3,7 @@ import math -from typing import Generic, Optional, Sequence, Tuple, TypeVar, Union, cast +from typing import Generic, Optional, Sequence, Tuple, Union, cast import warnings import weakref @@ -18,7 +18,14 @@ from jax.tree_util import register_pytree_node from jax.typing import ArrayLike -from numpyro._typing import ConstraintT, NonScalarArray, NumLike, PyTree, TransformT +from numpyro._typing import ( + ConstraintT, + NonScalarArray, + NumLike, + NumLikeT, + PyTree, + TransformT, +) from numpyro.distributions import constraints from numpyro.distributions.util import ( add_diag, @@ -65,9 +72,6 @@ def _clipped_expit(x: NumLike) -> NumLike: return jnp.clip(expit(x), finfo.tiny, 1.0 - finfo.eps) -NumLikeT = TypeVar("NumLikeT", bound=NumLike) - - class Transform(Generic[NumLikeT]): _inv: Optional[Union[TransformT, weakref.ref]] = None From def4e0be7fe62a0750d1a4a27575da4fcad4a547 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Sat, 25 Oct 2025 00:48:11 +0500 Subject: [PATCH 08/10] refactor: replace `jnp` with `xp` and replace logical operations with bitwise operators --- numpyro/distributions/constraints.py | 127 ++++++++++++--------------- 1 file changed, 56 insertions(+), 71 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 894ce9364..18a0734a2 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -153,7 +153,7 @@ class _Boolean(_SingletonConstraint[NumLike]): def __call__(self, x: NumLike) -> ArrayLike: xp = jax.numpy if isinstance(x, jax.Array) else np - return xp.logical_or(xp.equal(x, 0), xp.equal(x, 1)) + return xp.equal(x, 0) | xp.equal(x, 1) def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) @@ -163,15 +163,13 @@ class _CorrCholesky(_SingletonConstraint[NonScalarArray]): _event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - tril = jnp.tril(x) - lower_triangular = jnp.all( - jnp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1 - ) - positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1) - x_norm = jnp.linalg.norm(x, axis=-1) - tol = jnp.finfo(x.dtype).eps * x.shape[-1] * 10 - unit_norm_row = jnp.all(jnp.abs(x_norm - 1) <= tol, axis=-1) + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + tril = xp.tril(x) + lower_triangular = xp.all(xp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1) + positive_diagonal = xp.all(xp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1) + x_norm = xp.linalg.norm(x, axis=-1) + tol = xp.finfo(x.dtype).eps * x.shape[-1] * 10 + unit_norm_row = xp.all(xp.abs(x_norm - 1) <= tol, axis=-1) return lower_triangular & positive_diagonal & unit_norm_row def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: @@ -184,16 +182,16 @@ class _CorrMatrix(_SingletonConstraint[NonScalarArray]): _event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric - symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) + symmetric = xp.all(xp.isclose(x, xp.swapaxes(x, -2, -1)), axis=(-2, -1)) # check for the smallest eigenvalue is positive - positive = jnp.linalg.eigvalsh(x)[..., 0] > 0 + positive = xp.linalg.eigvalsh(x)[..., 0] > 0 # check for diagonal equal to 1 - unit_variance = jnp.all( - jnp.abs(jnp.diagonal(x, axis1=-2, axis2=-1) - 1) < 1e-6, axis=-1 + unit_variance = xp.all( + xp.abs(xp.diagonal(x, axis1=-2, axis2=-1) - 1) < 1e-6, axis=-1 ) - return jnp.logical_and(jnp.logical_and(symmetric, positive), unit_variance) + return symmetric & positive & unit_variance def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( @@ -355,8 +353,8 @@ def __init__(self, base_constraint: ConstraintT, reinterpreted_batch_ndims: int) reinterpreted_batch_ndims + base_constraint.reinterpreted_batch_ndims ) base_constraint = base_constraint.base_constraint - self.base_constraint = base_constraint - self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + self.base_constraint: Constraint = base_constraint + self.reinterpreted_batch_ndims: int = reinterpreted_batch_ndims self._is_discrete = base_constraint.is_discrete self._event_dim = base_constraint.event_dim + reinterpreted_batch_ndims super().__init__() @@ -463,12 +461,11 @@ def __init__(self, lower_bound: NumLike, upper_bound: NumLike) -> None: self.upper_bound = upper_bound def __call__(self, x: NumLike) -> ArrayLike: - return jnp.logical_and( - jnp.logical_and( - jnp.greater_equal(x, self.lower_bound), - jnp.less_equal(x, self.upper_bound), - ), - jnp.equal(jnp.mod(x, 1), 0), + xp = jax.numpy if isinstance(x, jax.Array) else np + return ( + xp.greater_equal(x, self.lower_bound) + & xp.less_equal(x, self.upper_bound) + & xp.equal(xp.mod(x, 1), 0) ) def __repr__(self) -> str: @@ -504,9 +501,7 @@ def __init__(self, lower_bound: NumLike) -> None: def __call__(self, x: NumLike) -> ArrayLike: xp = jax.numpy if isinstance(x, jax.Array) else np - return xp.logical_and( - xp.equal(xp.mod(x, 1), 0), xp.greater_equal(x, self.lower_bound) - ) + return (xp.mod(x, 1) == 0) & xp.greater_equal(x, self.lower_bound) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -542,8 +537,8 @@ def __init__(self, lower_bound: NumLike, upper_bound: NumLike) -> None: def __call__(self, x: NumLike) -> ArrayLike: xp = jax.numpy if isinstance(x, jax.Array) else np - return xp.logical_and( - xp.greater_equal(x, self.lower_bound), xp.less_equal(x, self.upper_bound) + return xp.greater_equal(x, self.lower_bound) & xp.less_equal( + x, self.upper_bound ) def __repr__(self) -> str: @@ -585,10 +580,7 @@ def __init__(self) -> None: class _OpenInterval(_Interval): def __call__(self, x: NumLike) -> ArrayLike: xp = jax.numpy if isinstance(x, jax.Array) else np - return xp.logical_and( - xp.greater(x, self.lower_bound), - xp.less(x, self.upper_bound), - ) + return xp.greater(x, self.lower_bound) & xp.less(x, self.upper_bound) def __repr__(self) -> str: fmt_string = self.__class__.__name__[1:] @@ -602,13 +594,11 @@ class _LowerCholesky(_SingletonConstraint[NonScalarArray]): _event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - tril = jnp.tril(x) - lower_triangular = jnp.all( - jnp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1 - ) - positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1) - return jnp.logical_and(lower_triangular, positive_diagonal) + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + tril = xp.tril(x) + lower_triangular = xp.all(xp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1) + positive_diagonal = xp.all(xp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1) + return lower_triangular & positive_diagonal def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( @@ -625,10 +615,7 @@ def __init__(self, upper_bound: ArrayLike) -> None: def __call__(self, x: NonScalarArray) -> ArrayLike: xp = jax.numpy if isinstance(x, jax.Array) else np - return xp.logical_and( - (x >= 0).all(axis=-1), - xp.equal(x.sum(axis=-1), self.upper_bound), - ) + return (x >= 0).all(axis=-1) & xp.equal(x.sum(axis=-1), self.upper_bound) def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: pad_width = ((0, 0),) * jax.numpy.ndim(self.upper_bound) + ( @@ -655,9 +642,9 @@ class _L1Ball(_SingletonConstraint[NumLike]): reltol = 10.0 # Relative to finfo.eps. def __call__(self, x: NumLike) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - eps = jnp.finfo(x.dtype if isinstance(x, jnp.ndarray) else type(x)).eps - return jnp.abs(x).sum(axis=-1) < 1 + self.reltol * eps + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + eps = xp.finfo(x.dtype if isinstance(x, xp.ndarray) else type(x)).eps + return xp.abs(x).sum(axis=-1) < 1 + self.reltol * eps def feasible_like(self, prototype: NumLike) -> NumLike: return jax.numpy.zeros_like(prototype) @@ -679,11 +666,11 @@ class _PositiveDefinite(_SingletonConstraint[NonScalarArray]): _event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric - symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) + symmetric = xp.all(xp.isclose(x, xp.swapaxes(x, -2, -1)), axis=(-2, -1)) # check for the smallest eigenvalue is positive - positive = jnp.linalg.eigh(x)[0][..., 0] > 0 + positive = xp.linalg.eigh(x)[0][..., 0] > 0 return symmetric & positive def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: @@ -696,10 +683,10 @@ class _PositiveDefiniteCirculantVector(_SingletonConstraint[NonScalarArray]): _event_dim = 1 def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - tol = 10 * jnp.finfo(x.dtype).eps - rfft = jnp.fft.rfft(x) - return jnp.logical_and(jnp.abs(rfft.imag) < tol, jnp.greater(rfft.real, -tol)) + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + tol = 10 * xp.finfo(x.dtype).eps + rfft = xp.fft.rfft(x) + return (xp.abs(rfft.imag) < tol) & (rfft.real > -tol) def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jnp.zeros_like(prototype).at[..., 0].set(1.0) @@ -709,12 +696,12 @@ class _PositiveSemiDefinite(_SingletonConstraint[NonScalarArray]): _event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric - symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) + symmetric = xp.all(xp.isclose(x, xp.swapaxes(x, -2, -1)), axis=(-2, -1)) # check for the smallest eigenvalue is nonnegative - nonnegative = jnp.linalg.eigh(x)[0][..., 0] >= 0 - return jnp.logical_and(symmetric, nonnegative) + nonnegative = xp.linalg.eigh(x)[0][..., 0] >= 0 + return symmetric & nonnegative def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: return jax.numpy.broadcast_to( @@ -745,12 +732,10 @@ class _Complex(_SingletonConstraint[NumLike]): def __call__(self, x: NumLike) -> ArrayLike: # XXX: consider to relax this condition to [-inf, inf] interval xp = jax.numpy if isinstance(x, jax.Array) else np - return xp.logical_and( - xp.equal(x, x), - xp.logical_and( - xp.not_equal(x, float("inf")), - xp.not_equal(x, float("-inf")), - ), + return ( + xp.equal(x, x) + & xp.not_equal(x, float("inf")) + & xp.not_equal(x, float("-inf")) ) def feasible_like(self, prototype: NumLike) -> NumLike: @@ -805,10 +790,10 @@ class _Sphere(_SingletonConstraint[NonScalarArray]): reltol = 10.0 # Relative to finfo.eps. def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - eps = jnp.finfo(x.dtype).eps - norm = jnp.linalg.norm(x, axis=-1) - error = jnp.abs(norm - 1) + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + eps = xp.finfo(x.dtype).eps + norm = xp.linalg.norm(x, axis=-1) + error = xp.abs(norm - 1) return error < self.reltol * eps * x.shape[-1] ** 0.5 def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: @@ -821,11 +806,11 @@ def __init__(self, event_dim: int = 1) -> None: super().__init__() def __call__(self, x: NonScalarArray) -> ArrayLike: - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - tol = jnp.finfo(x.dtype).eps * x.shape[-1] * 10 + xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + tol = xp.finfo(x.dtype).eps * x.shape[-1] * 10 zerosum_true = True for dim in range(-self.event_dim, 0): - zerosum_true = zerosum_true & jnp.allclose(x.sum(dim), 0, atol=tol) + zerosum_true = zerosum_true & xp.allclose(x.sum(dim), 0, atol=tol) return zerosum_true def __eq__(self, other: object) -> bool: From b4e2b816a757f8425c76bae113193fbd1fdb98bc Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Sat, 25 Oct 2025 00:54:09 +0500 Subject: [PATCH 09/10] refactor: update type parameters in `_RealVector` and `_RealMatrix` classes to use `NonScalarArray` --- numpyro/distributions/constraints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 18a0734a2..0f299dd3d 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -404,14 +404,14 @@ def __eq__(self, other: object) -> bool: class _RealVector( - _IndependentConstraint[NumLike], _SingletonConstraint[NonScalarArray] + _IndependentConstraint[NonScalarArray], _SingletonConstraint[NonScalarArray] ): def __init__(self) -> None: super().__init__(_Real(), 1) class _RealMatrix( - _IndependentConstraint[NumLike], _SingletonConstraint[NonScalarArray] + _IndependentConstraint[NonScalarArray], _SingletonConstraint[NonScalarArray] ): def __init__(self) -> None: super().__init__(_Real(), 2) From d609a7f231ae9ef382945697cceca721b015435d Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Sat, 25 Oct 2025 01:06:26 +0500 Subject: [PATCH 10/10] refactor: replace private attributes with public properties for `is_discrete` and `event_dim` in `Constraint` class --- numpyro/distributions/constraints.py | 58 ++++++++++++---------------- 1 file changed, 25 insertions(+), 33 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 0f299dd3d..f7c1e846a 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -84,8 +84,8 @@ class Constraint(Generic[NumLikeT]): e.g. within which a variable can be optimized. """ - _is_discrete = False - _event_dim = 0 + is_discrete: bool = False + event_dim: int = 0 def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) @@ -110,14 +110,6 @@ def feasible_like(self, prototype: NumLikeT) -> NumLikeT: """ raise NotImplementedError - @property - def is_discrete(self) -> bool: - return self._is_discrete - - @property - def event_dim(self) -> int: - return self._event_dim - @classmethod def tree_unflatten(cls, aux_data, params): params_keys, aux_data = aux_data @@ -149,7 +141,7 @@ def __new__(cls): class _Boolean(_SingletonConstraint[NumLike]): - _is_discrete = True + is_discrete = True def __call__(self, x: NumLike) -> ArrayLike: xp = jax.numpy if isinstance(x, jax.Array) else np @@ -160,7 +152,7 @@ def feasible_like(self, prototype: NumLike) -> NumLike: class _CorrCholesky(_SingletonConstraint[NonScalarArray]): - _event_dim = 2 + event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy @@ -179,7 +171,7 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: class _CorrMatrix(_SingletonConstraint[NonScalarArray]): - _event_dim = 2 + event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy @@ -226,7 +218,7 @@ def is_discrete(self): return self._is_discrete @property - def event_dim(self) -> int: + def event_dim(self) -> int: # type: ignore[override] if self._event_dim is NotImplemented: raise NotImplementedError(".event_dim cannot be determined statically") return self._event_dim @@ -260,7 +252,7 @@ def __eq__(self, other: object) -> bool: def tree_flatten(self): return (), ( (), - dict(_is_discrete=self._is_discrete, _event_dim=self._event_dim), + dict(_is_discrete=self._is_discrete, _event_dim=self.event_dim), ) @@ -272,7 +264,7 @@ def __init__( ): super().__init__(fn) self._is_discrete = is_discrete - self._event_dim = event_dim + self.event_dim = event_dim def __call__(self, x: NumLikeT) -> ArrayLike: if not callable(x): @@ -283,7 +275,7 @@ def __call__(self, x: NumLikeT) -> ArrayLike: # def support(self): # ... return dependent_property( - x, is_discrete=self._is_discrete, event_dim=self._event_dim + x, is_discrete=self._is_discrete, event_dim=self.event_dim ) @@ -355,8 +347,8 @@ def __init__(self, base_constraint: ConstraintT, reinterpreted_batch_ndims: int) base_constraint = base_constraint.base_constraint self.base_constraint: Constraint = base_constraint self.reinterpreted_batch_ndims: int = reinterpreted_batch_ndims - self._is_discrete = base_constraint.is_discrete - self._event_dim = base_constraint.event_dim + reinterpreted_batch_ndims + self.is_discrete = base_constraint.is_discrete + self.event_dim = base_constraint.event_dim + reinterpreted_batch_ndims super().__init__() def __call__(self, value: NumLikeT) -> ArrayLike: @@ -454,7 +446,7 @@ def __eq__(self, other: object) -> bool: class _IntegerInterval(Constraint[NumLike]): - _is_discrete = True + is_discrete = True def __init__(self, lower_bound: NumLike, upper_bound: NumLike) -> None: self.lower_bound = lower_bound @@ -494,7 +486,7 @@ def __eq__(self, other: object) -> bool: class _IntegerGreaterThan(Constraint[NumLike]): - _is_discrete = True + is_discrete = True def __init__(self, lower_bound: NumLike) -> None: self.lower_bound = lower_bound @@ -591,7 +583,7 @@ def __repr__(self) -> str: class _LowerCholesky(_SingletonConstraint[NonScalarArray]): - _event_dim = 2 + event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy @@ -607,8 +599,8 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: class _Multinomial(Constraint[NonScalarArray]): - _is_discrete = True - _event_dim = 1 + is_discrete = True + event_dim = 1 def __init__(self, upper_bound: ArrayLike) -> None: self.upper_bound = upper_bound @@ -638,7 +630,7 @@ class _L1Ball(_SingletonConstraint[NumLike]): Constrain to the L1 ball of any dimension. """ - _event_dim = 1 + event_dim = 1 reltol = 10.0 # Relative to finfo.eps. def __call__(self, x: NumLike) -> ArrayLike: @@ -651,7 +643,7 @@ def feasible_like(self, prototype: NumLike) -> NumLike: class _OrderedVector(_SingletonConstraint[NonScalarArray]): - _event_dim = 1 + event_dim = 1 def __call__(self, x: NonScalarArray) -> ArrayLike: return (x[..., 1:] > x[..., :-1]).all(axis=-1) @@ -663,7 +655,7 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: class _PositiveDefinite(_SingletonConstraint[NonScalarArray]): - _event_dim = 2 + event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy @@ -680,7 +672,7 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: class _PositiveDefiniteCirculantVector(_SingletonConstraint[NonScalarArray]): - _event_dim = 1 + event_dim = 1 def __call__(self, x: NonScalarArray) -> ArrayLike: xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy @@ -693,7 +685,7 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: class _PositiveSemiDefinite(_SingletonConstraint[NonScalarArray]): - _event_dim = 2 + event_dim = 2 def __call__(self, x: NonScalarArray) -> ArrayLike: xp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy @@ -715,7 +707,7 @@ class _PositiveOrderedVector(_SingletonConstraint[NonScalarArray]): increasing along the `event_shape` dimension. """ - _event_dim = 1 + event_dim = 1 def __call__(self, x: NonScalarArray) -> ArrayLike: return jnp.logical_and( @@ -752,7 +744,7 @@ def feasible_like(self, prototype: NumLike) -> NumLike: class _Simplex(_SingletonConstraint[NonScalarArray]): - _event_dim = 1 + event_dim = 1 def __call__(self, x: NonScalarArray) -> ArrayLike: x_sum = x.sum(axis=-1) @@ -786,7 +778,7 @@ class _Sphere(_SingletonConstraint[NonScalarArray]): Constrain to the Euclidean sphere of any dimension. """ - _event_dim = 1 + event_dim = 1 reltol = 10.0 # Relative to finfo.eps. def __call__(self, x: NonScalarArray) -> ArrayLike: @@ -802,7 +794,7 @@ def feasible_like(self, prototype: NonScalarArray) -> NonScalarArray: class _ZeroSum(Constraint[NonScalarArray]): def __init__(self, event_dim: int = 1) -> None: - self._event_dim = event_dim + self.event_dim = event_dim super().__init__() def __call__(self, x: NonScalarArray) -> ArrayLike: