Skip to content

Commit 5cf510c

Browse files
authored
Merge pull request #331 from Jammy2211/feature/andrew_implementation
feature/andrew implementation
2 parents 8bd18ea + 07f27f6 commit 5cf510c

3 files changed

Lines changed: 236 additions & 1 deletion

File tree

autolens/point/fit/positions/image/pair_repeat.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,134 @@ def residual_map(self) -> aa.ArrayIrregular:
3030
residual_map.append(np.sqrt(min(distances)))
3131

3232
return aa.ArrayIrregular(values=residual_map)
33+
34+
35+
class Fit:
36+
def __init__(
37+
self,
38+
data: aa.Grid2DIrregular,
39+
noise_map: aa.ArrayIrregular,
40+
model_positions: np.ndarray,
41+
):
42+
"""
43+
Compare the multiple image points observed to those produced by a model.
44+
45+
Parameters
46+
----------
47+
data
48+
Observed multiple image coordinates
49+
noise_map
50+
The noise associated with each observed image coordinate
51+
model_positions
52+
The multiple image coordinates produced by the model
53+
"""
54+
self.data = data
55+
self.noise_map = noise_map
56+
self.model_positions = model_positions
57+
58+
@staticmethod
59+
def square_distance(
60+
coord1: np.array,
61+
coord2: np.array,
62+
) -> float:
63+
"""
64+
Calculate the square distance between two points.
65+
66+
Parameters
67+
----------
68+
coord1
69+
coord2
70+
The two points to calculate the distance between
71+
72+
Returns
73+
-------
74+
The square distance between the two points
75+
"""
76+
return (coord1[0] - coord2[0]) ** 2 + (coord1[1] - coord2[1]) ** 2
77+
78+
def log_p(
79+
self,
80+
data_position: np.array,
81+
model_position: np.array,
82+
sigma: float,
83+
) -> float:
84+
"""
85+
Compute the log probability of a given model coordinate explaining
86+
a given observed coordinate. Accounts for noise, with noiser image
87+
coordinates having a comparatively lower log probability.
88+
89+
Parameters
90+
----------
91+
data_position
92+
The observed coordinate
93+
model_position
94+
The model coordinate
95+
sigma
96+
The noise associated with the observed coordinate
97+
98+
Returns
99+
-------
100+
The log probability of the model coordinate explaining the observed coordinate
101+
"""
102+
chi2 = self.square_distance(data_position, model_position) / sigma**2
103+
return -np.log(np.sqrt(2 * np.pi * sigma**2)) - 0.5 * chi2
104+
105+
def log_likelihood(self) -> float:
106+
"""
107+
Compute the log likelihood of the model image coordinates explaining the observed image coordinates.
108+
109+
This is the sum across all permutations of the observed image coordinates of the log probability of each
110+
model image coordinate explaining the observed image coordinate.
111+
112+
For example, if there are two observed image coordinates and two model image coordinates, the log likelihood
113+
is the sum of the log probabilities:
114+
115+
P(data_0 | model_0) * P(data_1 | model_1)
116+
+ P(data_0 | model_1) * P(data_1 | model_0)
117+
+ P(data_0 | model_0) * P(data_1 | model_0)
118+
+ P(data_0 | model_1) * P(data_1 | model_1)
119+
120+
This is every way in which the coordinates generated by the model can explain the observed coordinates.
121+
"""
122+
n_non_nan_model_positions = np.count_nonzero(
123+
~np.isnan(
124+
self.model_positions,
125+
).any(axis=1)
126+
)
127+
n_permutations = n_non_nan_model_positions ** len(self.data)
128+
return -np.log(n_permutations) + np.sum(self.all_permutations_log_likelihoods())
129+
130+
def all_permutations_log_likelihoods(self) -> np.array:
131+
"""
132+
Compute the log likelihood for each permutation whereby the model could explain the observed image coordinates.
133+
134+
For example, if there are two observed image coordinates and two model image coordinates, the log likelihood
135+
for each permutation is:
136+
137+
P(data_0 | model_0) * P(data_1 | model_1)
138+
P(data_0 | model_1) * P(data_1 | model_0)
139+
P(data_0 | model_0) * P(data_1 | model_0)
140+
P(data_0 | model_1) * P(data_1 | model_1)
141+
142+
This is every way in which the coordinates generated by the model can explain the observed coordinates.
143+
"""
144+
return np.array(
145+
[
146+
np.log(
147+
np.sum(
148+
[
149+
np.exp(
150+
self.log_p(
151+
data_position,
152+
model_position,
153+
sigma,
154+
)
155+
)
156+
for model_position in self.model_positions
157+
if not np.isnan(model_position).any()
158+
]
159+
)
160+
)
161+
for data_position, sigma in zip(self.data, self.noise_map)
162+
]
163+
)

test_autolens/point/model/test_analysis_point.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
directory = path.dirname(path.realpath(__file__))
99

1010

11-
def test__make_result__result_imaging_is_returned(point_dataset):
11+
def _test__make_result__result_imaging_is_returned(point_dataset):
1212
model = af.Collection(
1313
galaxies=af.Collection(
1414
lens=al.Galaxy(redshift=0.5, point_0=al.ps.Point(centre=(0.0, 0.0)))
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
try:
2+
import jax
3+
4+
JAX_INSTALLED = True
5+
except ImportError:
6+
JAX_INSTALLED = False
7+
8+
import numpy as np
9+
import pytest
10+
11+
from autolens.point.fit.positions.image.pair_repeat import Fit
12+
13+
14+
@pytest.fixture
15+
def data():
16+
return np.array([(0.0, 0.0), (1.0, 0.0)])
17+
18+
19+
@pytest.fixture
20+
def noise_map():
21+
return np.array([1.0, 1.0])
22+
23+
24+
@pytest.fixture
25+
def fit(data, noise_map):
26+
model_positions = np.array(
27+
[
28+
(-1.0749, -1.1),
29+
(1.19117, 1.175),
30+
]
31+
)
32+
33+
return Fit(
34+
data=data,
35+
noise_map=noise_map,
36+
model_positions=model_positions,
37+
)
38+
39+
40+
def test_andrew_implementation(fit):
41+
assert np.allclose(
42+
fit.all_permutations_log_likelihoods(),
43+
[
44+
-1.51114426,
45+
-1.50631469,
46+
],
47+
)
48+
assert fit.log_likelihood() == -4.40375330990644
49+
50+
51+
@pytest.mark.skipif(not JAX_INSTALLED, reason="JAX is not installed")
52+
def test_jax(fit):
53+
assert jax.jit(fit.log_likelihood)() == -4.40375330990644
54+
55+
56+
def test_nan_model_positions(
57+
data,
58+
noise_map,
59+
):
60+
model_positions = np.array(
61+
[
62+
(-1.0749, -1.1),
63+
(1.19117, 1.175),
64+
(np.nan, np.nan),
65+
]
66+
)
67+
fit = Fit(
68+
data=data,
69+
noise_map=noise_map,
70+
model_positions=model_positions,
71+
)
72+
73+
assert np.allclose(
74+
fit.all_permutations_log_likelihoods(),
75+
[
76+
-1.51114426,
77+
-1.50631469,
78+
],
79+
)
80+
assert fit.log_likelihood() == -4.40375330990644
81+
82+
83+
def test_duplicate_model_position(
84+
data,
85+
noise_map,
86+
):
87+
model_positions = np.array(
88+
[
89+
(-1.0749, -1.1),
90+
(1.19117, 1.175),
91+
(1.19117, 1.175),
92+
]
93+
)
94+
fit = Fit(
95+
data=data,
96+
noise_map=noise_map,
97+
model_positions=model_positions,
98+
)
99+
100+
assert np.allclose(
101+
fit.all_permutations_log_likelihoods(),
102+
[-1.14237812, -0.87193683],
103+
)
104+
assert fit.log_likelihood() == -4.211539531047171

0 commit comments

Comments
 (0)