Skip to content

Commit f59ae1d

Browse files
Jammy2211Jammy2211
authored andcommitted
Add KNNBarycentric mesh: pure-JAX barycentric weights on top-3 nearest
A JAX-native approximation to InterpolatorDelaunay that avoids the scipy.spatial.Delaunay callback. Picks the 3 nearest mesh vertices in source plane and computes locally-exact barycentric weights on the triangle they form. When the 3 nearest are the containing Delaunay triangle's vertices the weights are bit-identical to Delaunay; otherwise weights are clipped non-negative and renormalized. Degenerate (collinear) triangles fall back to nearest-neighbor. Additive only: new `KNNBarycentric(KNearestNeighbor)` mesh class selectable via `aa.mesh.KNNBarycentric(pixels=...)`, no existing behaviour changed.
1 parent aa418a5 commit f59ae1d

4 files changed

Lines changed: 324 additions & 0 deletions

File tree

autoarray/inversion/mesh/interpolator/knn.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,172 @@ def _mappings_sizes_weights_split(self):
241241
# k=self.mesh.k_neighbors,
242242
# radius_scale=self.mesh.radius_scale,
243243
# )
244+
245+
246+
def barycentric_weights_from_3_nearest(
247+
query_points,
248+
mesh_points,
249+
nearest_3_indices,
250+
xp,
251+
):
252+
"""
253+
Compute barycentric weights for each query point on the triangle formed by its
254+
3 nearest mesh vertices.
255+
256+
Signed barycentric coordinates are computed, then clipped to be non-negative
257+
and renormalized so each row sums to 1. Queries inside the triangle return
258+
the exact Delaunay weights; queries outside return a clipped approximation
259+
(a convex combination of the 3 nearest, biased toward whichever vertices are
260+
on the same side of the triangle as the query).
261+
262+
Degenerate triangles (collinear vertices) get zero weights to avoid NaN.
263+
264+
Parameters
265+
----------
266+
query_points : (Q, 2)
267+
Query point (x, y) coordinates.
268+
mesh_points : (N, 2)
269+
Mesh vertex (x, y) coordinates.
270+
nearest_3_indices : (Q, 3)
271+
Indices into mesh_points of the 3 nearest vertices for each query.
272+
xp : module
273+
numpy or jax.numpy.
274+
275+
Returns
276+
-------
277+
weights : (Q, 3)
278+
Barycentric weights, clipped non-negative and row-normalized.
279+
"""
280+
vertices = mesh_points[nearest_3_indices] # (Q, 3, 2)
281+
p0 = vertices[:, 0]
282+
p1 = vertices[:, 1]
283+
p2 = vertices[:, 2]
284+
q = query_points
285+
286+
def signed_cross(a, b, c):
287+
return (b[..., 0] - a[..., 0]) * (c[..., 1] - a[..., 1]) - (
288+
b[..., 1] - a[..., 1]
289+
) * (c[..., 0] - a[..., 0])
290+
291+
total = signed_cross(p0, p1, p2)
292+
w0 = signed_cross(q, p1, p2)
293+
w1 = signed_cross(p0, q, p2)
294+
w2 = signed_cross(p0, p1, q)
295+
296+
eps = xp.asarray(1e-12, dtype=total.dtype)
297+
safe_total = xp.where(xp.abs(total) > eps, total, 1.0)
298+
299+
bary = xp.stack([w0, w1, w2], axis=1) / safe_total[:, None]
300+
301+
clipped = xp.maximum(bary, 0.0)
302+
row_sum = xp.sum(clipped, axis=1, keepdims=True)
303+
safe_sum = xp.where(row_sum > eps, row_sum, 1.0)
304+
weights = clipped / safe_sum
305+
306+
# Degenerate triangles fall back to nearest-neighbor (weight 1 on column 0,
307+
# which `get_interpolation_weights` orders as the closest mesh vertex).
308+
# Same fallback policy as `pix_indexes_for_sub_slim_index_delaunay_from`
309+
# for outside-simplex points.
310+
nearest_only = xp.asarray([1.0, 0.0, 0.0], dtype=weights.dtype)
311+
312+
degenerate = xp.abs(total) <= eps
313+
weights = xp.where(degenerate[:, None], nearest_only[None, :], weights)
314+
315+
return weights
316+
317+
318+
class InterpolatorKNNBarycentric(InterpolatorKNearestNeighbor):
319+
"""
320+
Interpolator that picks the 3 nearest mesh vertices in the source plane and
321+
computes locally-exact barycentric weights on the triangle they form.
322+
323+
Approximates :class:`InterpolatorDelaunay` without the scipy.spatial.Delaunay
324+
callback: when the 3 nearest are the containing Delaunay triangle's vertices,
325+
the weights are bit-identical to Delaunay; otherwise they are clipped-and-
326+
renormalized barycentric weights on whichever triangle the 3 nearest form.
327+
328+
The kNN connectivity knobs (``k_neighbors``, ``radius_scale``,
329+
``split_neighbor_division``) on the parent :class:`KNearestNeighbor` mesh are
330+
inherited and still control the regularization-spacing computation via
331+
``distance_to_self``. Interpolation always uses k=3, irrespective of
332+
``mesh.k_neighbors``.
333+
"""
334+
335+
@cached_property
336+
def _mappings_sizes_weights(self):
337+
338+
try:
339+
query_points = self.data_grid.over_sampled.array
340+
except AttributeError:
341+
try:
342+
query_points = self.data_grid.array
343+
except AttributeError:
344+
query_points = self.data_grid
345+
346+
mappings, _, _ = get_interpolation_weights(
347+
points=self.mesh_grid_xy,
348+
query_points=query_points,
349+
k_neighbors=3,
350+
radius_scale=1.0,
351+
)
352+
353+
weights = barycentric_weights_from_3_nearest(
354+
query_points=query_points,
355+
mesh_points=self.mesh_grid_xy,
356+
nearest_3_indices=mappings,
357+
xp=self._xp,
358+
)
359+
360+
mappings = self._xp.asarray(mappings)
361+
weights = self._xp.asarray(weights)
362+
363+
sizes = self._xp.full(
364+
(mappings.shape[0],),
365+
mappings.shape[1],
366+
)
367+
368+
return mappings, sizes, weights
369+
370+
@cached_property
371+
def _mappings_sizes_weights_split(self):
372+
"""
373+
Same spacing scheme as :class:`InterpolatorKNearestNeighbor` but the
374+
split-point interpolator is :class:`InterpolatorKNNBarycentric` so the
375+
split-regularization weights are also barycentric rather than Wendland.
376+
"""
377+
from autoarray.inversion.regularization.regularization_util import (
378+
split_points_from,
379+
)
380+
381+
neighbor_index = int(self.mesh.k_neighbors) // self.mesh.split_neighbor_division
382+
383+
distance_to_self = self.distance_to_self
384+
others = distance_to_self[:, 1:]
385+
idx = int(neighbor_index) - 1
386+
idx = max(0, min(idx, others.shape[1] - 1))
387+
r_k = others[:, idx]
388+
389+
split_step = self.mesh.areas_factor * r_k
390+
391+
split_points = split_points_from(
392+
points=self.mesh_grid.array,
393+
area_weights=split_step,
394+
xp=self._xp,
395+
)
396+
397+
interpolator = InterpolatorKNNBarycentric(
398+
mesh=self.mesh,
399+
mesh_grid=self.mesh_grid,
400+
data_grid=split_points,
401+
xp=self._xp,
402+
)
403+
404+
mappings = interpolator.mappings
405+
weights = interpolator.weights
406+
407+
sizes = self._xp.full(
408+
(mappings.shape[0],),
409+
mappings.shape[1],
410+
)
411+
412+
return mappings, sizes, weights

autoarray/inversion/mesh/mesh/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
from .rectangular_uniform import RectangularUniform
77
from .delaunay import Delaunay
88
from .knn import KNearestNeighbor
9+
from .knn import KNNBarycentric

autoarray/inversion/mesh/mesh/knn.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,28 @@ def interpolator_cls(self):
7777
)
7878

7979
return InterpolatorKNearestNeighbor
80+
81+
82+
class KNNBarycentric(KNearestNeighbor):
83+
"""
84+
A mesh that inherits k-nearest-neighbour connectivity from
85+
:class:`KNearestNeighbor` but uses :class:`InterpolatorKNNBarycentric` to
86+
compute interpolation weights as locally-exact barycentric coordinates on
87+
the triangle formed by the 3 nearest source-plane mesh vertices, rather
88+
than a Wendland kernel.
89+
90+
This is a JAX-native approximation to Delaunay barycentric interpolation
91+
that avoids the scipy.spatial.Delaunay callback. The kNN connectivity knobs
92+
(``k_neighbors``, ``radius_scale``, ``split_neighbor_division``) are
93+
inherited and still control the regularization-spacing computation, but the
94+
*interpolation* weights always use k=3 + barycentric and ignore them.
95+
"""
96+
97+
@property
98+
def interpolator_cls(self):
99+
100+
from autoarray.inversion.mesh.interpolator.knn import (
101+
InterpolatorKNNBarycentric,
102+
)
103+
104+
return InterpolatorKNNBarycentric
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import numpy as np
2+
import pytest
3+
4+
from autoarray.inversion.mesh.interpolator.knn import (
5+
barycentric_weights_from_3_nearest,
6+
)
7+
8+
9+
def test__weights__interior_query_matches_delaunay_exactly():
10+
mesh = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
11+
query = np.array([[0.25, 0.25]])
12+
indices = np.array([[0, 1, 2]])
13+
14+
weights = barycentric_weights_from_3_nearest(
15+
query_points=query,
16+
mesh_points=mesh,
17+
nearest_3_indices=indices,
18+
xp=np,
19+
)
20+
21+
assert weights[0] == pytest.approx([0.5, 0.25, 0.25], rel=1e-12)
22+
assert weights[0].sum() == pytest.approx(1.0, rel=1e-12)
23+
24+
25+
def test__weights__query_on_edge_splits_50_50():
26+
mesh = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
27+
query = np.array([[0.5, 0.5]])
28+
indices = np.array([[0, 1, 2]])
29+
30+
weights = barycentric_weights_from_3_nearest(
31+
query_points=query,
32+
mesh_points=mesh,
33+
nearest_3_indices=indices,
34+
xp=np,
35+
)
36+
37+
assert weights[0] == pytest.approx([0.0, 0.5, 0.5], abs=1e-12)
38+
39+
40+
def test__weights__outside_query_clipped_and_renormalized():
41+
mesh = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
42+
query = np.array([[-0.5, -0.5]])
43+
indices = np.array([[0, 1, 2]])
44+
45+
weights = barycentric_weights_from_3_nearest(
46+
query_points=query,
47+
mesh_points=mesh,
48+
nearest_3_indices=indices,
49+
xp=np,
50+
)
51+
52+
assert (weights[0] >= 0.0).all()
53+
assert weights[0].sum() == pytest.approx(1.0, rel=1e-12)
54+
# p0 is opposite both edges the query is "outside" of → keeps all weight
55+
assert weights[0, 0] == pytest.approx(1.0, rel=1e-12)
56+
assert weights[0, 1] == pytest.approx(0.0, abs=1e-12)
57+
assert weights[0, 2] == pytest.approx(0.0, abs=1e-12)
58+
59+
60+
def test__weights__degenerate_triangle_falls_back_to_nearest_neighbor():
61+
mesh = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
62+
query = np.array([[1.0, 0.0]])
63+
# column 0 = closest by construction (set by get_interpolation_weights)
64+
indices = np.array([[0, 1, 2]])
65+
66+
weights = barycentric_weights_from_3_nearest(
67+
query_points=query,
68+
mesh_points=mesh,
69+
nearest_3_indices=indices,
70+
xp=np,
71+
)
72+
73+
assert np.isfinite(weights).all()
74+
# Degenerate → fall back to nearest-neighbor: [1, 0, 0]
75+
assert weights[0] == pytest.approx([1.0, 0.0, 0.0], abs=1e-12)
76+
77+
78+
def test__weights__matches_pixel_weights_delaunay_from__for_inside_triangle_grid():
79+
from autoarray.inversion.mesh.interpolator.delaunay import (
80+
pixel_weights_delaunay_from,
81+
)
82+
83+
mesh = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
84+
rng = np.random.default_rng(seed=42)
85+
# 200 random barycentric-coordinate samples → guaranteed inside-triangle queries
86+
u = rng.uniform(0, 1, size=200)
87+
v = rng.uniform(0, 1, size=200)
88+
inside = u + v <= 1.0
89+
u, v = u[inside], v[inside]
90+
queries = np.stack([u, v], axis=1) # (M, 2) — inside the unit-right triangle
91+
92+
indices = np.tile(np.array([[0, 1, 2]]), (queries.shape[0], 1))
93+
94+
bary = barycentric_weights_from_3_nearest(
95+
query_points=queries,
96+
mesh_points=mesh,
97+
nearest_3_indices=indices,
98+
xp=np,
99+
)
100+
101+
delaunay = pixel_weights_delaunay_from(
102+
data_grid=queries,
103+
mesh_grid=mesh,
104+
pix_indexes_for_sub_slim_index=indices,
105+
xp=np,
106+
)
107+
108+
assert bary == pytest.approx(delaunay, rel=1e-12, abs=1e-12)
109+
110+
111+
def test__mesh_class__interpolator_cls_is_knn_barycentric():
112+
import autoarray as aa
113+
from autoarray.inversion.mesh.interpolator.knn import (
114+
InterpolatorKNNBarycentric,
115+
)
116+
117+
mesh = aa.mesh.KNNBarycentric(pixels=10)
118+
assert mesh.interpolator_cls is InterpolatorKNNBarycentric
119+
120+
121+
def test__mesh_class__inherits_knn_knobs():
122+
import autoarray as aa
123+
124+
mesh = aa.mesh.KNNBarycentric(pixels=10)
125+
# Knobs inherited from KNearestNeighbor still control regularization spacing
126+
assert mesh.k_neighbors == 10
127+
assert mesh.radius_scale == 1.5
128+
assert mesh.areas_factor == 0.5
129+
assert mesh.split_neighbor_division == 2

0 commit comments

Comments
 (0)