Skip to content

Commit 79e897f

Browse files
committed
all point solver issues fixed
1 parent 61be7da commit 79e897f

3 files changed

Lines changed: 23 additions & 7 deletions

File tree

autolens/point/solver/point_solver.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import jax.numpy as jnp
12
import logging
23
from typing import Tuple, Optional
34

@@ -54,4 +55,8 @@ def solve(
5455
tracer=tracer, points=kept_triangles.means
5556
)
5657

57-
return aa.Grid2DIrregular([pair for pair in filtered_means])
58+
arr = aa.Grid2DIrregular([pair for pair in filtered_means])
59+
60+
mask = ~jnp.isnan(arr.array).any(axis=1)
61+
return aa.Grid2DIrregular(arr.array[mask])
62+

test_autolens/plot/test_get_visuals.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ def test__2d__via_tracer(tracer_x2_plane_7x7, grid_2d_7x7):
4545
visuals_2d_via.tangential_critical_curves[0]
4646
== tracer_x2_plane_7x7.tangential_critical_curve_list_from(grid=grid_2d_7x7)[0]
4747
).all()
48-
assert visuals_2d_via.radial_critical_curves == None
48+
assert (
49+
visuals_2d_via.radial_critical_curves[0]
50+
== tracer_x2_plane_7x7.radial_critical_curve_list_from(grid=grid_2d_7x7)[0]
51+
).all()
4952
assert visuals_2d_via.vectors == 2
5053

5154
include_2d = aplt.Include2D(
@@ -134,7 +137,12 @@ def test__via_fit_imaging_from(fit_imaging_x2_plane_7x7, grid_2d_7x7):
134137
grid=grid_2d_7x7
135138
)[0]
136139
).all()
137-
assert visuals_2d_via.radial_critical_curves == None
140+
assert (
141+
visuals_2d_via.radial_critical_curves[0]
142+
== fit_imaging_x2_plane_7x7.tracer.radial_critical_curve_list_from(
143+
grid=grid_2d_7x7
144+
)[0]
145+
).all()
138146
assert visuals_2d_via.vectors == 2
139147

140148
include_2d = aplt.Include2D(

test_autolens/point/fit/positions/image/test_abstract.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,12 @@ def test__multi_plane_position_solving():
7676
redshift_0=0.5, redshift_1=1.0, redshift_final=2.0
7777
)
7878

79-
assert fit_0.model_data[0, 0] == pytest.approx(
80-
scaling_factor * fit_1.model_data[1, 0], 1.0e-1
79+
print(fit_0.model_data)
80+
print(fit_1.model_data.array)
81+
82+
assert fit_0.model_data[0, :] == pytest.approx(
83+
scaling_factor * fit_1.model_data.array[0, :], 1.0e-1
8184
)
82-
assert fit_0.model_data[1, 1] == pytest.approx(
83-
scaling_factor * fit_1.model_data[0, 1], 1.0e-1
85+
assert fit_0.model_data[0, :] == pytest.approx(
86+
scaling_factor * fit_1.model_data.array[0, :], 1.0e-1
8487
)

0 commit comments

Comments
 (0)