|
| 1 | +import jax.numpy as jnp |
| 2 | +from jax import jit |
1 | 3 | import logging |
2 | 4 | import math |
3 | 5 |
|
|
6 | 8 | import autoarray as aa |
7 | 9 |
|
8 | 10 | from autoarray.structures.triangles.shape import Shape |
9 | | -from autofit.jax_wrapper import jit, use_jax, numpy as np, register_pytree_node_class |
10 | | - |
11 | | -try: |
12 | | - if use_jax: |
13 | | - from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import ( |
14 | | - CoordinateArrayTriangles, |
15 | | - ) |
16 | | - else: |
17 | | - from autoarray.structures.triangles.coordinate_array.coordinate_array import ( |
18 | | - CoordinateArrayTriangles, |
19 | | - ) |
20 | | - |
21 | | -except ImportError: |
22 | | - from autoarray.structures.triangles.coordinate_array.coordinate_array import ( |
23 | | - CoordinateArrayTriangles, |
24 | | - ) |
| 11 | +from autofit.jax_wrapper import register_pytree_node_class |
25 | 12 |
|
| 13 | +from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import ( |
| 14 | + CoordinateArrayTriangles, |
| 15 | +) |
26 | 16 | from autoarray.structures.triangles.abstract import AbstractTriangles |
27 | 17 |
|
28 | 18 | from autogalaxy import OperateDeflections |
@@ -278,13 +268,13 @@ def _filter_low_magnification( |
278 | 268 | ------- |
279 | 269 | The points with an absolute magnification above the threshold. |
280 | 270 | """ |
281 | | - points = np.array(points) |
| 271 | + points = jnp.array(points) |
282 | 272 | magnifications = tracer.magnification_2d_via_hessian_from( |
283 | 273 | grid=aa.Grid2DIrregular(points), |
284 | 274 | buffer=self.scale, |
285 | 275 | ) |
286 | | - mask = np.abs(magnifications.array) > self.magnification_threshold |
287 | | - return np.where(mask[:, None], points, np.nan) |
| 276 | + mask = jnp.abs(magnifications.array) > self.magnification_threshold |
| 277 | + return jnp.where(mask[:, None], points, jnp.nan) |
288 | 278 |
|
289 | 279 | def _source_triangles( |
290 | 280 | self, |
|
0 commit comments