Skip to content

Commit e99bfaa

Browse files
Jammy2211Jammy2211
authored andcommitted
fix numba import
1 parent 7bd9e4d commit e99bfaa

5 files changed

Lines changed: 10 additions & 5 deletions

File tree

autoarray/config/general.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ inversion:
1313
reconstruction_vmax_factor: 0.5 # Plots of an Inversion's reconstruction use the reconstructed data's bright value multiplied by this factor.
1414
numba:
1515
use_numba: true
16-
cache: false
16+
cache: true
1717
nopython: true
1818
parallel: false
1919
pixelization:

autoarray/inversion/inversion/abstract.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import copy
2-
from jax.scipy.linalg import block_diag
2+
33
import numpy as np
44
from typing import Dict, List, Optional, Type, Union, TYPE_CHECKING
55

@@ -331,6 +331,12 @@ def regularization_matrix(self) -> Optional[np.ndarray]:
331331
If the `settings.force_edge_pixels_to_zeros` is `True`, the edge pixels of each mapper in the inversion
332332
are regularized so high their value is forced to zero.
333333
"""
334+
if self.xp.__name__.startswith("jax"):
335+
from jax.scipy.linalg import block_diag
336+
return block_diag(
337+
*[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list]
338+
)
339+
from scipy.linalg import block_diag
334340
return block_diag(
335341
*[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list]
336342
)

autoarray/numba_util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
cache = True
1515
parallel = False
1616

17-
1817
def jit(nopython=nopython, cache=cache, parallel=parallel):
1918

2019
def wrapper(func):

test_autoarray/config/general.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ inversion:
1515
no_regularization_add_to_curvature_diag_value : 1.0e-8 # The default value added to the curvature matrix's diagonal when regularization is not applied to a linear object, which prevents inversion's failing due to the matrix being singular.
1616
positive_only_uses_p_initial: false # If True, the positive-only solver of an inversion's uses an initial guess of the reconstructed data's values as which values should be positive, speeding up the solver.
1717
numba:
18-
cache: true
1918
nopython: true
19+
cache: true
2020
parallel: false
2121
use_numba: true
2222
pixelization:

test_autoarray/inversion/inversion/test_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def test__inversion_imaging__via_mapper(
6363
rectangular_mapper_7x7_3x3,
6464
delaunay_mapper_9_3x3,
6565
):
66+
6667
inversion = aa.Inversion(
6768
dataset=masked_imaging_7x7_no_blur,
6869
linear_obj_list=[rectangular_mapper_7x7_3x3],
@@ -114,7 +115,6 @@ def test__inversion_imaging__via_mapper(
114115
assert inversion.log_det_curvature_reg_matrix_term == pytest.approx(10.6674, 1.0e-4)
115116
assert inversion.mapped_reconstructed_image == pytest.approx(np.ones(9), 1.0e-4)
116117

117-
fff
118118

119119

120120
def test__inversion_imaging__via_regularizations(

0 commit comments

Comments
 (0)