Skip to content

Commit 8654adf

Browse files
authored
Merge pull request #282 from PyAutoLabs/feature/nnls-target-kappa-fix
fix(jax): bump default NNLS target_kappa to 1e-2 for finite backward gradients
2 parents 19d7957 + f7ff5da commit 8654adf

2 files changed

Lines changed: 21 additions & 2 deletions

File tree

autoarray/config/general.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ inversion:
77
no_regularization_add_to_curvature_diag_value : 1.0e-3 # The default value added to the curvature matrix's diagonal when regularization is not applied to a linear object, which prevents the inversion from failing due to the matrix being singular.
88
use_border_relocator: false # If True, by default a pixelization's border is used to relocate all pixels outside its border to the border.
99
nnls_jacobi_preconditioning: true # If True (default), the curvature matrix passed to jaxnnls.solve_nnls_primal is Jacobi-preconditioned (D Q D y = D q, x = D y). Fixes NaN backward-pass gradients on ill-conditioned Q and roughly halves forward solve time. Set False to restore the raw unpreconditioned solve.
10+
nnls_target_kappa: 1.0e-2 # Central-path relaxation parameter passed to jaxnnls.solve_nnls_primal. Larger values smooth the relaxed-KKT backward pass and prevent NaN gradients on ill-conditioned Q; smaller values tighten the primal solve. 1.0e-2 is the smallest value empirically verified to produce finite gradients across MGE pipelines. jaxnnls's own default (1e-3) is too aggressive for the backward pass.
1011
reconstruction_vmax_factor: 0.5 # Plots of an Inversion's reconstruction use the reconstructed data's brightest value multiplied by this factor.
1112
numba:
1213
use_numba: true

autoarray/inversion/inversion/inversion_util.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,19 @@ def reconstruction_positive_only_from(
287287
# explicitly disables preconditioning in the shadowing config.
288288
use_jacobi = True
289289

290+
try:
291+
target_kappa = conf.instance["general"]["inversion"][
292+
"nnls_target_kappa"
293+
]
294+
except KeyError:
295+
# jaxnnls's hardcoded default (1e-3) produces NaN in the relaxed-KKT
296+
# backward pass on ill-conditioned curvature matrices, even after
297+
# Jacobi preconditioning. 1e-2 is the smallest value empirically
298+
# verified to produce finite gradients across MGE pipelines (see
299+
# autolens_workspace_developer/jax_profiling/imaging/mge_gradients.py
300+
# _diagnose_kappa).
301+
target_kappa = 1.0e-2
302+
290303
if use_jacobi:
291304
# Ill-conditioned Q makes jaxnnls's relaxed-KKT backward pass
292305
# produce NaN gradients. Rescale Q so its diagonal is unit:
@@ -297,9 +310,14 @@ def reconstruction_positive_only_from(
297310
D = 1.0 / d
298311
Q_pc = (curvature_reg_matrix * D[:, None]) * D[None, :]
299312
q_pc = data_vector * D
300-
return jaxnnls.solve_nnls_primal(Q_pc, q_pc) * D
313+
return (
314+
jaxnnls.solve_nnls_primal(Q_pc, q_pc, target_kappa=target_kappa)
315+
* D
316+
)
301317

302-
return jaxnnls.solve_nnls_primal(curvature_reg_matrix, data_vector)
318+
return jaxnnls.solve_nnls_primal(
319+
curvature_reg_matrix, data_vector, target_kappa=target_kappa
320+
)
303321

304322
try:
305323

0 commit comments

Comments
 (0)