Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 190 additions & 0 deletions autoarray/inversion/mesh/interpolator/knn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from autoconf import cached_property

from autoarray.inversion.mesh.interpolator.delaunay import InterpolatorDelaunay
Expand Down Expand Up @@ -241,3 +243,191 @@ def _mappings_sizes_weights_split(self):
# k=self.mesh.k_neighbors,
# radius_scale=self.mesh.radius_scale,
# )


def barycentric_weights_from_3_nearest(
query_points,
mesh_points,
nearest_3_indices,
xp,
):
"""
Compute barycentric weights for each query point on the triangle formed by its
3 nearest mesh vertices.

Signed barycentric coordinates are computed, then clipped to be non-negative
and renormalized so each row sums to 1. Queries inside the triangle return
the exact Delaunay weights; queries outside return a clipped approximation
(a convex combination of the 3 nearest, biased toward whichever vertices are
on the same side of the triangle as the query).

Degenerate triangles (collinear vertices) get zero weights to avoid NaN.

Parameters
----------
query_points : (Q, 2)
Query point (x, y) coordinates.
mesh_points : (N, 2)
Mesh vertex (x, y) coordinates.
nearest_3_indices : (Q, 3)
Indices into mesh_points of the 3 nearest vertices for each query.
xp : module
numpy or jax.numpy.

Returns
-------
weights : (Q, 3)
Barycentric weights, clipped non-negative and row-normalized.
"""
vertices = mesh_points[nearest_3_indices] # (Q, 3, 2)
p0 = vertices[:, 0]
p1 = vertices[:, 1]
p2 = vertices[:, 2]
q = query_points

def signed_cross(a, b, c):
return (b[..., 0] - a[..., 0]) * (c[..., 1] - a[..., 1]) - (
b[..., 1] - a[..., 1]
) * (c[..., 0] - a[..., 0])

total = signed_cross(p0, p1, p2)
w0 = signed_cross(q, p1, p2)
w1 = signed_cross(p0, q, p2)
w2 = signed_cross(p0, p1, q)

eps = xp.asarray(1e-12, dtype=total.dtype)
safe_total = xp.where(xp.abs(total) > eps, total, 1.0)

bary = xp.stack([w0, w1, w2], axis=1) / safe_total[:, None]

clipped = xp.maximum(bary, 0.0)
row_sum = xp.sum(clipped, axis=1, keepdims=True)
safe_sum = xp.where(row_sum > eps, row_sum, 1.0)
weights = clipped / safe_sum

# Degenerate triangles fall back to nearest-neighbor (weight 1 on column 0,
# which `get_interpolation_weights` orders as the closest mesh vertex).
# Same fallback policy as `pix_indexes_for_sub_slim_index_delaunay_from`
# for outside-simplex points.
nearest_only = xp.asarray([1.0, 0.0, 0.0], dtype=weights.dtype)

degenerate = xp.abs(total) <= eps
weights = xp.where(degenerate[:, None], nearest_only[None, :], weights)

return weights


class InterpolatorKNNBarycentric(InterpolatorKNearestNeighbor):
"""
Interpolator that picks the 3 nearest mesh vertices in the source plane and
computes locally-exact barycentric weights on the triangle they form.

Approximates :class:`InterpolatorDelaunay` without the scipy.spatial.Delaunay
callback: when the 3 nearest are the containing Delaunay triangle's vertices,
the weights are bit-identical to Delaunay; otherwise they are clipped-and-
renormalized barycentric weights on whichever triangle the 3 nearest form.

The kNN connectivity knobs (``k_neighbors``, ``radius_scale``,
``split_neighbor_division``) on the parent :class:`KNearestNeighbor` mesh are
inherited and still control the regularization-spacing computation via
``distance_to_self``. Interpolation always uses k=3, irrespective of
``mesh.k_neighbors``.
"""

@cached_property
def _mappings_sizes_weights(self):

try:
query_points = self.data_grid.over_sampled.array
except AttributeError:
try:
query_points = self.data_grid.array
except AttributeError:
query_points = self.data_grid

mappings, _, _ = get_interpolation_weights(
points=self.mesh_grid_xy,
query_points=query_points,
k_neighbors=3,
radius_scale=1.0,
)

weights = barycentric_weights_from_3_nearest(
query_points=query_points,
mesh_points=self.mesh_grid_xy,
nearest_3_indices=mappings,
xp=self._xp,
)

# On the numpy path, materialize with `np.array(...)` so the regularization
# code (which uses in-place assignment, e.g. `reg_split_np_from`) gets a
# writable buffer rather than a read-only view of a jax.Array. On the jax
# path, asarray is the right cast (no copy in a JIT trace).
if self._xp is np:
mappings = np.array(mappings)
weights = np.array(weights)
else:
mappings = self._xp.asarray(mappings)
weights = self._xp.asarray(weights)

sizes = self._xp.full(
(mappings.shape[0],),
mappings.shape[1],
)

return mappings, sizes, weights

@cached_property
def _mappings_sizes_weights_split(self):
"""
Same spacing scheme as :class:`InterpolatorKNearestNeighbor` but the
split-point interpolator is :class:`InterpolatorKNNBarycentric` so the
split-regularization weights are also barycentric rather than Wendland.
"""
from autoarray.inversion.regularization.regularization_util import (
split_points_from,
)

neighbor_index = int(self.mesh.k_neighbors) // self.mesh.split_neighbor_division

distance_to_self = self.distance_to_self
others = distance_to_self[:, 1:]
idx = int(neighbor_index) - 1
idx = max(0, min(idx, others.shape[1] - 1))
r_k = others[:, idx]

split_step = self.mesh.areas_factor * r_k

split_points = split_points_from(
points=self.mesh_grid.array,
area_weights=split_step,
xp=self._xp,
)

interpolator = InterpolatorKNNBarycentric(
mesh=self.mesh,
mesh_grid=self.mesh_grid,
data_grid=split_points,
xp=self._xp,
)

mappings = interpolator.mappings
weights = interpolator.weights

# `reg_split_np_from` writes `splitted_mappings[i][j+1] = pixel_index`
# for the "flag-zero" insertion of the central pixel, so the buffer
# must have an extra column reserved past the k=3 mappings — matching
# `InterpolatorDelaunay._mappings_sizes_weights_split`'s hstack-append.
# `sizes` reports 3 (the actual mappings); `reg_split_np_from` grows it
# to 4 in-place when it inserts.
sizes = self._xp.full(
(mappings.shape[0],),
mappings.shape[1],
)

pad_int = self._xp.full((mappings.shape[0], 1), -1, dtype=mappings.dtype)
pad_float = self._xp.zeros((weights.shape[0], 1), dtype=weights.dtype)
mappings = self._xp.hstack((mappings, pad_int))
weights = self._xp.hstack((weights, pad_float))

return mappings, sizes, weights
1 change: 1 addition & 0 deletions autoarray/inversion/mesh/mesh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .rectangular_uniform import RectangularUniform
from .delaunay import Delaunay
from .knn import KNearestNeighbor
from .knn import KNNBarycentric
25 changes: 25 additions & 0 deletions autoarray/inversion/mesh/mesh/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,28 @@ def interpolator_cls(self):
)

return InterpolatorKNearestNeighbor


class KNNBarycentric(KNearestNeighbor):
"""
A mesh that inherits k-nearest-neighbour connectivity from
:class:`KNearestNeighbor` but uses :class:`InterpolatorKNNBarycentric` to
compute interpolation weights as locally-exact barycentric coordinates on
the triangle formed by the 3 nearest source-plane mesh vertices, rather
than a Wendland kernel.

This is a JAX-native approximation to Delaunay barycentric interpolation
that avoids the scipy.spatial.Delaunay callback. The kNN connectivity knobs
(``k_neighbors``, ``radius_scale``, ``split_neighbor_division``) are
inherited and still control the regularization-spacing computation, but the
*interpolation* weights always use k=3 + barycentric and ignore them.
"""

@property
def interpolator_cls(self):

from autoarray.inversion.mesh.interpolator.knn import (
InterpolatorKNNBarycentric,
)

return InterpolatorKNNBarycentric
Loading
Loading