Skip to content

Commit 37ba243

Browse files
Jammy2211Jammy2211
authored andcommitted
black
1 parent 1825d2e commit 37ba243

9 files changed

Lines changed: 26 additions & 31 deletions

File tree

autolens/aggregator/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from autolens.aggregator.subplot import SubplotFitX1Plane as subplot_fit_x1_plane
2828
from autolens.aggregator.subplot import SubplotFit as subplot_fit
2929
from autolens.aggregator.subplot import SubplotFitLog10 as subplot_fit_log10
30-
from autolens.aggregator.subplot import FITSModelGalaxyImages as fits_model_galaxy_images
30+
from autolens.aggregator.subplot import (
31+
FITSModelGalaxyImages as fits_model_galaxy_images,
32+
)
3133
from autolens.aggregator.subplot import FITSTracer as fits_tracer
3234
from autolens.aggregator.subplot import FITSFit as fits_fit

autolens/analysis/positions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,4 +199,3 @@ def log_likelihood_penalty_from(
199199
)
200200

201201
return penalty if max_separation > self.threshold else np.array(0.0)
202-

autolens/mock.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def deflections_between_planes_from(self, grid, xp=np, plane_i=0, plane_j=-1):
2323
return xp.zeros_like(grid.array)
2424

2525
def magnification_2d_via_hessian_from(
26-
self, grid, buffer: float = 0.01, deflections_func=None, xp=np,
26+
self,
27+
grid,
28+
buffer: float = 0.01,
29+
deflections_func=None,
30+
xp=np,
2731
) -> aa.ArrayIrregular:
2832
return aa.ArrayIrregular(values=xp.ones(grid.shape[0]))

autolens/point/fit/abstract.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,7 @@ def magnifications_at_positions(self) -> aa.ArrayIrregular:
133133
"""
134134
return abs(
135135
self.tracer.magnification_2d_via_hessian_from(
136-
grid=self.positions,
137-
deflections_func=self.deflections_func,
138-
xp=self._xp
136+
grid=self.positions, deflections_func=self.deflections_func, xp=self._xp
139137
)
140138
)
141139

autolens/point/fit/positions/image/abstract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,5 @@ def model_data(self) -> aa.Grid2DIrregular:
104104
tracer=self.tracer,
105105
source_plane_coordinate=self.source_plane_coordinate,
106106
plane_redshift=self.plane_redshift,
107-
remove_infinities=False
107+
remove_infinities=False,
108108
)

autolens/point/mock/mock_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ def solve(
1010
tracer,
1111
source_plane_coordinate,
1212
plane_redshift: Optional[float] = None,
13-
remove_infinities : bool = True
13+
remove_infinities: bool = True,
1414
):
1515
return self.model_positions

autolens/point/solver/point_solver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def solve(
6060
tracer=tracer, points=kept_triangles.means
6161
)
6262

63-
solution = aa.Grid2DIrregular([pair for pair in filtered_means], xp=self._xp).array
63+
solution = aa.Grid2DIrregular(
64+
[pair for pair in filtered_means], xp=self._xp
65+
).array
6466

6567
is_nan = self._xp.isnan(solution).any(axis=1)
6668
sentinel = self._xp.full_like(solution[0], fill_value=self._xp.inf)

autolens/point/solver/shape_solver.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(
2323
pixel_scale_precision: float,
2424
magnification_threshold=0.1,
2525
neighbor_degree: int = 1,
26-
xp = np
26+
xp=np,
2727
):
2828
"""
2929
Determine the image plane coordinates that are traced to be a source plane coordinate.
@@ -144,9 +144,13 @@ def for_limits_and_scale(
144144
"""
145145

146146
if xp.__name__.startswith("jax"):
147-
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles as triangle_cls
147+
from autoarray.structures.triangles.coordinate_array import (
148+
CoordinateArrayTriangles as triangle_cls,
149+
)
148150
else:
149-
from autoarray.structures.triangles.coordinate_array_np import CoordinateArrayTrianglesNp as triangle_cls
151+
from autoarray.structures.triangles.coordinate_array_np import (
152+
CoordinateArrayTrianglesNp as triangle_cls,
153+
)
150154

151155
initial_triangles = triangle_cls.for_limits_and_scale(
152156
y_min=y_min,
@@ -196,10 +200,7 @@ def _plane_grid(
196200
plane_index = tracer.plane_index_via_redshift_from(redshift=plane_redshift)
197201

198202
deflections = tracer.deflections_between_planes_from(
199-
grid=grid,
200-
plane_i=0,
201-
plane_j=plane_index,
202-
xp=self._xp
203+
grid=grid, plane_i=0, plane_j=plane_index, xp=self._xp
203204
)
204205
# noinspection PyTypeChecker
205206
return grid.grid_2d_via_deflection_grid_from(deflection_grid=deflections)
@@ -265,9 +266,7 @@ def _filter_low_magnification(
265266
"""
266267
points = self._xp.array(points)
267268
magnifications = tracer.magnification_2d_via_hessian_from(
268-
grid=aa.Grid2DIrregular(points).array,
269-
buffer=self.scale,
270-
xp=self._xp
269+
grid=aa.Grid2DIrregular(points).array, buffer=self.scale, xp=self._xp
271270
)
272271
mask = self._xp.abs(magnifications.array) > self.magnification_threshold
273272
return self._xp.where(mask[:, None], points, self._xp.nan)

test_autolens/point/triangles/test_solver.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,25 +71,16 @@ def test_real_example_jax(grid, tracer):
7171

7272
import jax.numpy as jnp
7373

74-
jax_solver = PointSolver.for_grid(
75-
grid=grid,
76-
pixel_scale_precision=0.001,
77-
xp=jnp
78-
)
74+
jax_solver = PointSolver.for_grid(grid=grid, pixel_scale_precision=0.001, xp=jnp)
7975

8076
result = jax_solver.solve(
81-
tracer=tracer,
82-
source_plane_coordinate=(0.07, 0.07),
83-
remove_infinities=True
77+
tracer=tracer, source_plane_coordinate=(0.07, 0.07), remove_infinities=True
8478
)
8579

8680
assert len(result) == 5
8781

88-
8982
result = jax_solver.solve(
90-
tracer=tracer,
91-
source_plane_coordinate=(0.07, 0.07),
92-
remove_infinities=False
83+
tracer=tracer, source_plane_coordinate=(0.07, 0.07), remove_infinities=False
9384
)
9485

9586
assert len(result) == 15

0 commit comments

Comments
 (0)