Skip to content

Commit 4fa4cd9

Browse files
Jammy2211Jammy2211
authored andcommitted
this is a mess
1 parent e90fb4f commit 4fa4cd9

3 files changed

Lines changed: 52 additions & 10 deletions

File tree

autofit/messages/abstract.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,21 @@ def __init__(
5252

5353
self.id = next(self.ids) if id_ is None else id_
5454
self.log_norm = log_norm
55-
self._broadcast = np.broadcast(*parameters)
55+
56+
self._broadcast = None
57+
self._broadcast_jnp = None
58+
59+
if isinstance(parameters[0], (np.float64, float, int)):
60+
self._broadcast = np.broadcast(*parameters)
61+
else:
62+
import jax.numpy as jnp
63+
self._broadcast_jnp = jnp.broadcast_arrays(*parameters)
5664

5765
if self.shape:
58-
self.parameters = tuple(np.asanyarray(p) for p in parameters)
66+
if isinstance(parameters[0], (np.float64, float, int)):
67+
self.parameters = tuple(np.asanyarray(p) for p in parameters)
68+
else:
69+
self.parameters = tuple(jnp.asarray(p) for p in parameters)
5970
else:
6071
self.parameters = tuple(parameters)
6172

autofit/messages/interface.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ def broadcast(self):
2323

2424
@property
2525
def shape(self) -> Tuple[int, ...]:
26-
return self.broadcast.shape
26+
27+
if self.broadcast is not None:
28+
return self.broadcast.shape
29+
30+
return ()
2731

2832
@property
2933
def size(self) -> int:
@@ -47,6 +51,7 @@ def logpdf(self, x: Union[np.ndarray, float]) -> np.ndarray:
4751

4852
def _broadcast_natural_parameters(self, x):
4953
shape = np.shape(x)
54+
print(shape, self.shape)
5055
if shape == self.shape:
5156
return self.natural_parameters
5257
elif shape[1:] == self.shape:
@@ -78,8 +83,15 @@ def log_partition(self) -> np.ndarray:
7883

7984
@classmethod
8085
def natural_logpdf(cls, eta, t, log_base, log_partition):
81-
eta_t = np.multiply(eta, t).sum(0)
82-
return np.nan_to_num(log_base + eta_t - log_partition, nan=-np.inf)
86+
87+
if isinstance(eta, (np.ndarray, np.float64)):
88+
eta_t = np.multiply(eta, t).sum(0)
89+
return np.nan_to_num(log_base + eta_t - log_partition, nan=-np.inf)
90+
91+
import jax.numpy as jnp
92+
93+
eta_t = jnp.multiply(eta, t).sum(0)
94+
return jnp.nan_to_num(log_base + eta_t - log_partition, nan=-jnp.inf)
8395

8496
def numerical_logpdf_gradient(
8597
self, x: np.ndarray, eps: float = 1e-6

autofit/messages/normal.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@ def log_partition(self):
4040
This ensures normalization of the exponential-family distribution.
4141
"""
4242
eta1, eta2 = self.natural_parameters
43-
return -(eta1**2) / 4 / eta2 - np.log(-2 * eta2) / 2
43+
44+
if isinstance(eta1, (np.ndarray, np.float64)):
45+
return -(eta1**2) / 4 / eta2 - np.log(-2 * eta2) / 2
46+
47+
import jax.numpy as jnp
48+
49+
return -(eta1**2) / 4 / eta2 - jnp.log(-2 * eta2) / 2
4450

4551
log_base_measure = -0.5 * np.log(2 * np.pi)
4652
_support = ((-np.inf, np.inf),)
@@ -73,8 +79,9 @@ def __init__(
7379
id_
7480
An optional unique identifier used to track the message in larger probabilistic graphs or models.
7581
"""
76-
if (np.array(sigma) < 0).any():
77-
raise exc.MessageException("Sigma cannot be negative")
82+
if isinstance(sigma, (float, int, np.ndarray)):
83+
if (np.array(sigma) < 0).any():
84+
raise exc.MessageException("Sigma cannot be negative")
7885

7986
super().__init__(
8087
mean,
@@ -158,7 +165,13 @@ def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float,
158165
η₂ = -1 / (2σ²)
159166
"""
160167
precision = 1 / sigma**2
161-
return np.array([mu * precision, -precision / 2])
168+
169+
if isinstance(mu, (np.ndarray, np.float64)):
170+
return np.array([mu * precision, -precision / 2])
171+
172+
import jax.numpy as jnp
173+
174+
return jnp.array([mu * precision, -precision / 2])
162175

163176
@staticmethod
164177
def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, float]:
@@ -197,7 +210,13 @@ def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray:
197210
-------
198211
The sufficient statistics [x, x²].
199212
"""
200-
return np.array([x, x**2])
213+
214+
if isinstance(x, (np.ndarray, np.float64)):
215+
return np.array([x, x**2])
216+
217+
import jax.numpy as jnp
218+
219+
return jnp.array([x, x**2])
201220

202221
@classmethod
203222
def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray:

0 commit comments

Comments
 (0)