Skip to content

Commit 7c728f7

Browse files
authored
Merge pull request #318 from PyAutoLabs/feature/knn-barycentric
Add KNNBarycentric mesh: JAX-native Delaunay-class interpolator
2 parents ae375ae + 58b314c commit 7c728f7

4 files changed

Lines changed: 386 additions & 0 deletions

File tree

autoarray/inversion/mesh/interpolator/knn.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
from autoconf import cached_property
24

35
from autoarray.inversion.mesh.interpolator.delaunay import InterpolatorDelaunay
@@ -241,3 +243,191 @@ def _mappings_sizes_weights_split(self):
241243
# k=self.mesh.k_neighbors,
242244
# radius_scale=self.mesh.radius_scale,
243245
# )
246+
247+
248+
def barycentric_weights_from_3_nearest(
249+
query_points,
250+
mesh_points,
251+
nearest_3_indices,
252+
xp,
253+
):
254+
"""
255+
Compute barycentric weights for each query point on the triangle formed by its
256+
3 nearest mesh vertices.
257+
258+
Signed barycentric coordinates are computed, then clipped to be non-negative
259+
and renormalized so each row sums to 1. Queries inside the triangle return
260+
the exact Delaunay weights; queries outside return a clipped approximation
261+
(a convex combination of the 3 nearest, biased toward whichever vertices are
262+
on the same side of the triangle as the query).
263+
264+
Degenerate triangles (collinear vertices) get zero weights to avoid NaN.
265+
266+
Parameters
267+
----------
268+
query_points : (Q, 2)
269+
Query point (x, y) coordinates.
270+
mesh_points : (N, 2)
271+
Mesh vertex (x, y) coordinates.
272+
nearest_3_indices : (Q, 3)
273+
Indices into mesh_points of the 3 nearest vertices for each query.
274+
xp : module
275+
numpy or jax.numpy.
276+
277+
Returns
278+
-------
279+
weights : (Q, 3)
280+
Barycentric weights, clipped non-negative and row-normalized.
281+
"""
282+
vertices = mesh_points[nearest_3_indices] # (Q, 3, 2)
283+
p0 = vertices[:, 0]
284+
p1 = vertices[:, 1]
285+
p2 = vertices[:, 2]
286+
q = query_points
287+
288+
def signed_cross(a, b, c):
289+
return (b[..., 0] - a[..., 0]) * (c[..., 1] - a[..., 1]) - (
290+
b[..., 1] - a[..., 1]
291+
) * (c[..., 0] - a[..., 0])
292+
293+
total = signed_cross(p0, p1, p2)
294+
w0 = signed_cross(q, p1, p2)
295+
w1 = signed_cross(p0, q, p2)
296+
w2 = signed_cross(p0, p1, q)
297+
298+
eps = xp.asarray(1e-12, dtype=total.dtype)
299+
safe_total = xp.where(xp.abs(total) > eps, total, 1.0)
300+
301+
bary = xp.stack([w0, w1, w2], axis=1) / safe_total[:, None]
302+
303+
clipped = xp.maximum(bary, 0.0)
304+
row_sum = xp.sum(clipped, axis=1, keepdims=True)
305+
safe_sum = xp.where(row_sum > eps, row_sum, 1.0)
306+
weights = clipped / safe_sum
307+
308+
# Degenerate triangles fall back to nearest-neighbor (weight 1 on column 0,
309+
# which `get_interpolation_weights` orders as the closest mesh vertex).
310+
# Same fallback policy as `pix_indexes_for_sub_slim_index_delaunay_from`
311+
# for outside-simplex points.
312+
nearest_only = xp.asarray([1.0, 0.0, 0.0], dtype=weights.dtype)
313+
314+
degenerate = xp.abs(total) <= eps
315+
weights = xp.where(degenerate[:, None], nearest_only[None, :], weights)
316+
317+
return weights
318+
319+
320+
class InterpolatorKNNBarycentric(InterpolatorKNearestNeighbor):
321+
"""
322+
Interpolator that picks the 3 nearest mesh vertices in the source plane and
323+
computes locally-exact barycentric weights on the triangle they form.
324+
325+
Approximates :class:`InterpolatorDelaunay` without the scipy.spatial.Delaunay
326+
callback: when the 3 nearest are the containing Delaunay triangle's vertices,
327+
the weights are bit-identical to Delaunay; otherwise they are clipped-and-
328+
renormalized barycentric weights on whichever triangle the 3 nearest form.
329+
330+
The kNN connectivity knobs (``k_neighbors``, ``radius_scale``,
331+
``split_neighbor_division``) on the parent :class:`KNearestNeighbor` mesh are
332+
inherited and still control the regularization-spacing computation via
333+
``distance_to_self``. Interpolation always uses k=3, irrespective of
334+
``mesh.k_neighbors``.
335+
"""
336+
337+
@cached_property
338+
def _mappings_sizes_weights(self):
339+
340+
try:
341+
query_points = self.data_grid.over_sampled.array
342+
except AttributeError:
343+
try:
344+
query_points = self.data_grid.array
345+
except AttributeError:
346+
query_points = self.data_grid
347+
348+
mappings, _, _ = get_interpolation_weights(
349+
points=self.mesh_grid_xy,
350+
query_points=query_points,
351+
k_neighbors=3,
352+
radius_scale=1.0,
353+
)
354+
355+
weights = barycentric_weights_from_3_nearest(
356+
query_points=query_points,
357+
mesh_points=self.mesh_grid_xy,
358+
nearest_3_indices=mappings,
359+
xp=self._xp,
360+
)
361+
362+
# On the numpy path, materialize with `np.array(...)` so the regularization
363+
# code (which uses in-place assignment, e.g. `reg_split_np_from`) gets a
364+
# writable buffer rather than a read-only view of a jax.Array. On the jax
365+
# path, asarray is the right cast (no copy in a JIT trace).
366+
if self._xp is np:
367+
mappings = np.array(mappings)
368+
weights = np.array(weights)
369+
else:
370+
mappings = self._xp.asarray(mappings)
371+
weights = self._xp.asarray(weights)
372+
373+
sizes = self._xp.full(
374+
(mappings.shape[0],),
375+
mappings.shape[1],
376+
)
377+
378+
return mappings, sizes, weights
379+
380+
@cached_property
381+
def _mappings_sizes_weights_split(self):
382+
"""
383+
Same spacing scheme as :class:`InterpolatorKNearestNeighbor` but the
384+
split-point interpolator is :class:`InterpolatorKNNBarycentric` so the
385+
split-regularization weights are also barycentric rather than Wendland.
386+
"""
387+
from autoarray.inversion.regularization.regularization_util import (
388+
split_points_from,
389+
)
390+
391+
neighbor_index = int(self.mesh.k_neighbors) // self.mesh.split_neighbor_division
392+
393+
distance_to_self = self.distance_to_self
394+
others = distance_to_self[:, 1:]
395+
idx = int(neighbor_index) - 1
396+
idx = max(0, min(idx, others.shape[1] - 1))
397+
r_k = others[:, idx]
398+
399+
split_step = self.mesh.areas_factor * r_k
400+
401+
split_points = split_points_from(
402+
points=self.mesh_grid.array,
403+
area_weights=split_step,
404+
xp=self._xp,
405+
)
406+
407+
interpolator = InterpolatorKNNBarycentric(
408+
mesh=self.mesh,
409+
mesh_grid=self.mesh_grid,
410+
data_grid=split_points,
411+
xp=self._xp,
412+
)
413+
414+
mappings = interpolator.mappings
415+
weights = interpolator.weights
416+
417+
# `reg_split_np_from` writes `splitted_mappings[i][j+1] = pixel_index`
418+
# for the "flag-zero" insertion of the central pixel, so the buffer
419+
# must have an extra column reserved past the k=3 mappings — matching
420+
# `InterpolatorDelaunay._mappings_sizes_weights_split`'s hstack-append.
421+
# `sizes` reports 3 (the actual mappings); `reg_split_np_from` grows it
422+
# to 4 in-place when it inserts.
423+
sizes = self._xp.full(
424+
(mappings.shape[0],),
425+
mappings.shape[1],
426+
)
427+
428+
pad_int = self._xp.full((mappings.shape[0], 1), -1, dtype=mappings.dtype)
429+
pad_float = self._xp.zeros((weights.shape[0], 1), dtype=weights.dtype)
430+
mappings = self._xp.hstack((mappings, pad_int))
431+
weights = self._xp.hstack((weights, pad_float))
432+
433+
return mappings, sizes, weights

autoarray/inversion/mesh/mesh/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
from .rectangular_uniform import RectangularUniform
77
from .delaunay import Delaunay
88
from .knn import KNearestNeighbor
9+
from .knn import KNNBarycentric

autoarray/inversion/mesh/mesh/knn.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,28 @@ def interpolator_cls(self):
7777
)
7878

7979
return InterpolatorKNearestNeighbor
80+
81+
82+
class KNNBarycentric(KNearestNeighbor):
83+
"""
84+
A mesh that inherits k-nearest-neighbour connectivity from
85+
:class:`KNearestNeighbor` but uses :class:`InterpolatorKNNBarycentric` to
86+
compute interpolation weights as locally-exact barycentric coordinates on
87+
the triangle formed by the 3 nearest source-plane mesh vertices, rather
88+
than a Wendland kernel.
89+
90+
This is a JAX-native approximation to Delaunay barycentric interpolation
91+
that avoids the scipy.spatial.Delaunay callback. The kNN connectivity knobs
92+
(``k_neighbors``, ``radius_scale``, ``split_neighbor_division``) are
93+
inherited and still control the regularization-spacing computation, but the
94+
*interpolation* weights always use k=3 + barycentric and ignore them.
95+
"""
96+
97+
@property
98+
def interpolator_cls(self):
99+
100+
from autoarray.inversion.mesh.interpolator.knn import (
101+
InterpolatorKNNBarycentric,
102+
)
103+
104+
return InterpolatorKNNBarycentric

0 commit comments

Comments
 (0)