Add KNNBarycentric mesh: JAX-native Delaunay-class interpolator#318
Merged
Conversation
A JAX-native approximation to InterpolatorDelaunay that avoids the scipy.spatial.Delaunay callback. Picks the 3 nearest mesh vertices in source plane and computes locally-exact barycentric weights on the triangle they form. When the 3 nearest are the containing Delaunay triangle's vertices the weights are bit-identical to Delaunay; otherwise weights are clipped non-negative and renormalized. Degenerate (collinear) triangles fall back to nearest-neighbor. Additive only: new `KNNBarycentric(KNearestNeighbor)` mesh class selectable via `aa.mesh.KNNBarycentric(pixels=...)`, no existing behaviour changed.
Bugfixes surfaced while integrating InterpolatorKNNBarycentric into the autolens FitImaging eager numpy path: - `_mappings_sizes_weights` was returning np.asarray() of jax outputs, which gives a read-only view. ConstantSplit's `reg_split_np_from` uses in-place assignment and raised "assignment destination is read-only". On the numpy path, materialize with np.array() to get a writable buffer. - `_mappings_sizes_weights_split` returned a k=3 mappings array. The ConstantSplit code writes `splitted_mappings[i][j+1]` for the central pixel insertion, requiring an extra reserved column. Match InterpolatorDelaunay's hstack-append pattern. - Add `import numpy as np` at module scope (was implicitly available through `jax.numpy as jnp` aliasing for Wendland-only paths). - New regression-guard test confirms numpy-path mappings/weights are writable. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a JAX-native interpolator that approximates
InterpolatorDelaunaywithout the scipy.spatial.Delaunay callback. For each query point it
picks the 3 nearest mesh vertices via the existing brute-force kNN
search and computes locally-exact barycentric weights on the triangle
they form. When the 3 nearest happen to be the containing Delaunay
triangle's vertices, weights are bit-identical to Delaunay; otherwise
weights are clipped non-negative and renormalized. Degenerate
(collinear) triangles fall back to nearest-neighbor weight
[1, 0, 0],matching
pix_indexes_for_sub_slim_index_delaunay_from's outside-simplexpolicy.
This is the library half of the wildcard speedup investigated in issue
#317 — the scipy callback costs ~16.87 ms/element (24% of the production
batched likelihood at batch=20 on A100). Scientific validation against
real lens models at
rtol=1e-3onlog_evidencehappens in thefollow-up workspace PRs (
autolens_workspace_developerregressionscript +
autolens_workspace_testsmoke).API Changes
Purely additive:
aa.mesh.KNNBarycentric(pixels=...)mesh class, inherits fromKNearestNeighbor(so all the existing kNN regularization-spacingknobs work) but selects the new interpolator.
InterpolatorKNNBarycentric(not exported ataa.top level,same as
InterpolatorKNearestNeighbor).barycentric_weights_from_3_nearest()helper inautoarray.inversion.mesh.interpolator.knn.No removals, no signature changes, no behaviour changes to existing
mesh/interpolator classes. See full details below.
Test Plan
pytest test_autoarray/inversion/pixelization/interpolator/test_knn_barycentric.py— 7/7 passpytest test_autoarray/inversion/pixelization/interpolator/ test_autoarray/inversion/pixelization/mesh/— 24/24 pass, no regressiontest_autoarray/suite — run by ship_library subagentlog_evidenceat HST fiducial 26288.321397232066) — follow-up workspace PRFull API Changes (for automation & release notes)
Added
autoarray.inversion.mesh.mesh.knn.KNNBarycentric(pixels, zeroed_pixels=0, k_neighbors=10, radius_scale=1.5, areas_factor=0.5, split_neighbor_division=2)— mesh class. Inherits all knobs fromKNearestNeighbor; onlyinterpolator_clsdiffers. Available asaa.mesh.KNNBarycentric.autoarray.inversion.mesh.interpolator.knn.InterpolatorKNNBarycentric— interpolator class. Subclass ofInterpolatorKNearestNeighbor; overrides_mappings_sizes_weightsand_mappings_sizes_weights_splitto call kNN with k=3 and use barycentric instead of Wendland weights.autoarray.inversion.mesh.interpolator.knn.barycentric_weights_from_3_nearest(query_points, mesh_points, nearest_3_indices, xp)— helper that computes signed barycentric weights on the triangle formed by the 3 nearest mesh vertices, clips non-negative, renormalizes, and falls back to nearest-neighbor on degenerate triangles.Removed / Renamed / Signature Changed
Behaviour Changed
KNearestNeighborandInterpolatorKNearestNeighborare untouched.Migration
aa.mesh.Delaunay(pixels=...)(oraa.mesh.KNearestNeighbor(pixels=...)) foraa.mesh.KNNBarycentric(pixels=...).🤖 Generated with Claude Code