@@ -40,7 +40,13 @@ def log_partition(self):
4040 This ensures normalization of the exponential-family distribution.
4141 """
4242 eta1 , eta2 = self .natural_parameters
43- return - (eta1 ** 2 ) / 4 / eta2 - np .log (- 2 * eta2 ) / 2
43+
44+ if isinstance (eta1 , (np .ndarray , np .float64 )):
45+ return - (eta1 ** 2 ) / 4 / eta2 - np .log (- 2 * eta2 ) / 2
46+
47+ import jax .numpy as jnp
48+
49+ return - (eta1 ** 2 ) / 4 / eta2 - jnp .log (- 2 * eta2 ) / 2
4450
4551 log_base_measure = - 0.5 * np .log (2 * np .pi )
4652 _support = ((- np .inf , np .inf ),)
@@ -73,8 +79,9 @@ def __init__(
7379 id_
7480 An optional unique identifier used to track the message in larger probabilistic graphs or models.
7581 """
76- if (np .array (sigma ) < 0 ).any ():
77- raise exc .MessageException ("Sigma cannot be negative" )
82+ if isinstance (sigma , (float , int , np .ndarray )):
83+ if (np .array (sigma ) < 0 ).any ():
84+ raise exc .MessageException ("Sigma cannot be negative" )
7885
7986 super ().__init__ (
8087 mean ,
@@ -158,7 +165,13 @@ def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float,
158165 η₂ = -1 / (2σ²)
159166 """
160167 precision = 1 / sigma ** 2
161- return np .array ([mu * precision , - precision / 2 ])
168+
169+ if isinstance (mu , (np .ndarray , np .float64 )):
170+ return np .array ([mu * precision , - precision / 2 ])
171+
172+ import jax .numpy as jnp
173+
174+ return jnp .array ([mu * precision , - precision / 2 ])
162175
163176 @staticmethod
164177 def invert_natural_parameters (natural_parameters : np .ndarray ) -> Tuple [float , float ]:
@@ -197,7 +210,13 @@ def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray:
197210 -------
198211 The sufficient statistics [x, x²].
199212 """
200- return np .array ([x , x ** 2 ])
213+
214+ if isinstance (x , (np .ndarray , np .float64 )):
215+ return np .array ([x , x ** 2 ])
216+
217+ import jax .numpy as jnp
218+
219+ return jnp .array ([x , x ** 2 ])
201220
202221 @classmethod
203222 def invert_sufficient_statistics (cls , suff_stats : Tuple [float , float ]) -> np .ndarray :
0 commit comments