Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions autolens/analysis/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def image_plane_multiple_image_positions(
grid = self.analysis.dataset.mask.derive_grid.all_false

solver = PointSolver.for_grid(
grid=grid, pixel_scale_precision=0.001, xp=self.analysis._xp
grid=grid, pixel_scale_precision=0.001
)

source_plane_centre = self.source_plane_centre_from(
Expand All @@ -114,6 +114,7 @@ def image_plane_multiple_image_positions(
multiple_images = solver.solve(
tracer=self.max_log_likelihood_tracer,
source_plane_coordinate=source_plane_centre.in_list[0],
xp=self.analysis._xp,
plane_redshift=plane_redshift,
)

Expand Down Expand Up @@ -166,7 +167,7 @@ def image_plane_multiple_image_positions_for_single_image_from(
centre = self.source_plane_centre_from(plane_redshift=plane_redshift).in_list[0]

solver = PointSolver.for_grid(
grid=grid, pixel_scale_precision=0.001, xp=self.analysis._xp
grid=grid, pixel_scale_precision=0.001
)

for i in range(1, increments):
Expand All @@ -175,6 +176,7 @@ def image_plane_multiple_image_positions_for_single_image_from(
multiple_images = solver.solve(
tracer=self.max_log_likelihood_tracer,
source_plane_coordinate=(centre[0] * factor, centre[1] * factor),
xp=self.analysis._xp,
plane_redshift=plane_redshift,
)

Expand Down
1 change: 1 addition & 0 deletions autolens/point/fit/positions/image/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def model_data(self) -> aa.Grid2DIrregular:
return self.solver.solve(
tracer=self.tracer,
source_plane_coordinate=self.source_plane_coordinate,
xp=self._xp,
plane_redshift=self.plane_redshift,
remove_infinities=False,
)
3 changes: 3 additions & 0 deletions autolens/point/mock/mock_solver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional

import numpy as np


class MockPointSolver:
def __init__(self, model_positions):
Expand All @@ -9,6 +11,7 @@ def solve(
self,
tracer,
source_plane_coordinate,
xp=np,
plane_redshift: Optional[float] = None,
remove_infinities: bool = True,
):
Expand Down
21 changes: 13 additions & 8 deletions autolens/point/solver/point_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def solve(
self,
tracer: Tracer,
source_plane_coordinate: Tuple[float, float],
xp=np,
plane_redshift: Optional[float] = None,
remove_infinities: bool = True,
) -> aa.Grid2DIrregular:
Expand All @@ -57,11 +58,14 @@ def solve(

Parameters
----------
tracer
The tracer that traces the image plane coordinates to the source plane.
source_plane_coordinate
The plane coordinate to trace to the image plane, which by default in the source-plane coordinate
but could be a coordinate in another plane is `plane_redshift` is input.
tracer
The tracer that traces the image plane coordinates to the source plane
xp
The array module (``numpy`` or ``jax.numpy``) the solve runs in. ``AnalysisPoint``
passes ``jax.numpy`` when ``use_jax=True`` is set on the analysis.
plane_redshift
The redshift of the plane coordinate, which for multi-plane systems may not be the source-plane.

Expand All @@ -73,23 +77,24 @@ def solve(
kept_triangles = super().solve_triangles(
tracer=tracer,
shape=Point(*source_plane_coordinate),
xp=xp,
plane_redshift=plane_redshift,
)

filtered_means = self._filter_low_magnification(
tracer=tracer, points=kept_triangles.means
tracer=tracer, points=kept_triangles.means, xp=xp
)

solution = aa.Grid2DIrregular(
[pair for pair in filtered_means], xp=self._xp
[pair for pair in filtered_means], xp=xp
).array

is_nan = self._xp.isnan(solution).any(axis=1)
sentinel = self._xp.full_like(solution[0], fill_value=self._xp.inf)
solution = self._xp.where(is_nan[:, None], sentinel, solution)
is_nan = xp.isnan(solution).any(axis=1)
sentinel = xp.full_like(solution[0], fill_value=xp.inf)
solution = xp.where(is_nan[:, None], sentinel, solution)

if remove_infinities:

solution = solution[~self._xp.isinf(solution).any(axis=1)]
solution = solution[~xp.isinf(solution).any(axis=1)]

return aa.Grid2DIrregular(solution)
Loading
Loading