Skip to content

Add KNNBarycentric mesh: JAX-native Delaunay-class interpolator#318

Merged
Jammy2211 merged 2 commits into
mainfrom
feature/knn-barycentric
May 16, 2026
Merged

Add KNNBarycentric mesh: JAX-native Delaunay-class interpolator#318
Jammy2211 merged 2 commits into
mainfrom
feature/knn-barycentric

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Adds a JAX-native interpolator that approximates InterpolatorDelaunay
without the scipy.spatial.Delaunay callback. For each query point it
picks the 3 nearest mesh vertices via the existing brute-force kNN
search and computes locally-exact barycentric weights on the triangle
they form. When the 3 nearest happen to be the containing Delaunay
triangle's vertices, weights are bit-identical to Delaunay; otherwise
weights are clipped non-negative and renormalized. Degenerate
(collinear) triangles fall back to nearest-neighbor weight [1, 0, 0],
matching pix_indexes_for_sub_slim_index_delaunay_from's outside-simplex
policy.

This is the library half of the wildcard speedup investigated in issue
#317 — the scipy callback costs ~16.87 ms/element (24% of the production
batched likelihood at batch=20 on A100). Scientific validation against
real lens models at rtol=1e-3 on log_evidence happens in the
follow-up workspace PRs (autolens_workspace_developer regression
script + autolens_workspace_test smoke).

API Changes

Purely additive:

  • New aa.mesh.KNNBarycentric(pixels=...) mesh class, inherits from
    KNearestNeighbor (so all the existing kNN regularization-spacing
    knobs work) but selects the new interpolator.
  • New InterpolatorKNNBarycentric (not exported at aa. top level,
    same as InterpolatorKNearestNeighbor).
  • New barycentric_weights_from_3_nearest() helper in
    autoarray.inversion.mesh.interpolator.knn.

No removals, no signature changes, no behaviour changes to existing
mesh/interpolator classes. See full details below.

Test Plan

  • pytest test_autoarray/inversion/pixelization/interpolator/test_knn_barycentric.py — 7/7 pass
  • pytest test_autoarray/inversion/pixelization/interpolator/ test_autoarray/inversion/pixelization/mesh/ — 24/24 pass, no regression
  • Full test_autoarray/ suite — run by ship_library subagent
  • Scientific validation (rtol=1e-3 vs Delaunay log_evidence at HST fiducial 26288.321397232066) — follow-up workspace PR
Full API Changes (for automation & release notes)

Added

  • autoarray.inversion.mesh.mesh.knn.KNNBarycentric(pixels, zeroed_pixels=0, k_neighbors=10, radius_scale=1.5, areas_factor=0.5, split_neighbor_division=2) — mesh class. Inherits all knobs from KNearestNeighbor; only interpolator_cls differs. Available as aa.mesh.KNNBarycentric.
  • autoarray.inversion.mesh.interpolator.knn.InterpolatorKNNBarycentric — interpolator class. Subclass of InterpolatorKNearestNeighbor; overrides _mappings_sizes_weights and _mappings_sizes_weights_split to call kNN with k=3 and use barycentric instead of Wendland weights.
  • autoarray.inversion.mesh.interpolator.knn.barycentric_weights_from_3_nearest(query_points, mesh_points, nearest_3_indices, xp) — helper that computes signed barycentric weights on the triangle formed by the 3 nearest mesh vertices, clips non-negative, renormalizes, and falls back to nearest-neighbor on degenerate triangles.

Removed / Renamed / Signature Changed

  • None.

Behaviour Changed

  • None for existing classes. KNearestNeighbor and InterpolatorKNearestNeighbor are untouched.

Migration

  • No migration required. Users opt in by swapping aa.mesh.Delaunay(pixels=...) (or aa.mesh.KNearestNeighbor(pixels=...)) for aa.mesh.KNNBarycentric(pixels=...).

🤖 Generated with Claude Code

A JAX-native approximation to InterpolatorDelaunay that avoids the
scipy.spatial.Delaunay callback. Picks the 3 nearest mesh vertices in
source plane and computes locally-exact barycentric weights on the
triangle they form. When the 3 nearest are the containing Delaunay
triangle's vertices the weights are bit-identical to Delaunay;
otherwise weights are clipped non-negative and renormalized.
Degenerate (collinear) triangles fall back to nearest-neighbor.

Additive only: new `KNNBarycentric(KNearestNeighbor)` mesh class
selectable via `aa.mesh.KNNBarycentric(pixels=...)`, no existing
behaviour changed.
Bugfixes surfaced while integrating InterpolatorKNNBarycentric into the
autolens FitImaging eager numpy path:

- `_mappings_sizes_weights` was returning np.asarray() of jax outputs,
  which gives a read-only view. ConstantSplit's `reg_split_np_from` uses
  in-place assignment and raised "assignment destination is read-only".
  On the numpy path, materialize with np.array() to get a writable buffer.
- `_mappings_sizes_weights_split` returned a k=3 mappings array. The
  ConstantSplit code writes `splitted_mappings[i][j+1]` for the central
  pixel insertion, requiring an extra reserved column. Match
  InterpolatorDelaunay's hstack-append pattern.
- Add `import numpy as np` at module scope (was implicitly available
  through `jax.numpy as jnp` aliasing for Wendland-only paths).
- New regression-guard test confirms numpy-path mappings/weights are
  writable.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit 7c728f7 into main May 16, 2026
6 checks passed
@Jammy2211 Jammy2211 deleted the feature/knn-barycentric branch May 16, 2026 13:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant