|
| 1 | +import numpy as np |
| 2 | + |
1 | 3 | from autoconf import cached_property |
2 | 4 |
|
3 | 5 | from autoarray.inversion.mesh.interpolator.delaunay import InterpolatorDelaunay |
@@ -241,3 +243,191 @@ def _mappings_sizes_weights_split(self): |
241 | 243 | # k=self.mesh.k_neighbors, |
242 | 244 | # radius_scale=self.mesh.radius_scale, |
243 | 245 | # ) |
| 246 | + |
| 247 | + |
| 248 | +def barycentric_weights_from_3_nearest( |
| 249 | + query_points, |
| 250 | + mesh_points, |
| 251 | + nearest_3_indices, |
| 252 | + xp, |
| 253 | +): |
| 254 | + """ |
| 255 | + Compute barycentric weights for each query point on the triangle formed by its |
| 256 | + 3 nearest mesh vertices. |
| 257 | +
|
| 258 | + Signed barycentric coordinates are computed, then clipped to be non-negative |
| 259 | + and renormalized so each row sums to 1. Queries inside the triangle return |
| 260 | + the exact Delaunay weights; queries outside return a clipped approximation |
| 261 | + (a convex combination of the 3 nearest, biased toward whichever vertices are |
| 262 | + on the same side of the triangle as the query). |
| 263 | +
|
| 264 | + Degenerate triangles (collinear vertices) get zero weights to avoid NaN. |
| 265 | +
|
| 266 | + Parameters |
| 267 | + ---------- |
| 268 | + query_points : (Q, 2) |
| 269 | + Query point (x, y) coordinates. |
| 270 | + mesh_points : (N, 2) |
| 271 | + Mesh vertex (x, y) coordinates. |
| 272 | + nearest_3_indices : (Q, 3) |
| 273 | + Indices into mesh_points of the 3 nearest vertices for each query. |
| 274 | + xp : module |
| 275 | + numpy or jax.numpy. |
| 276 | +
|
| 277 | + Returns |
| 278 | + ------- |
| 279 | + weights : (Q, 3) |
| 280 | + Barycentric weights, clipped non-negative and row-normalized. |
| 281 | + """ |
| 282 | + vertices = mesh_points[nearest_3_indices] # (Q, 3, 2) |
| 283 | + p0 = vertices[:, 0] |
| 284 | + p1 = vertices[:, 1] |
| 285 | + p2 = vertices[:, 2] |
| 286 | + q = query_points |
| 287 | + |
| 288 | + def signed_cross(a, b, c): |
| 289 | + return (b[..., 0] - a[..., 0]) * (c[..., 1] - a[..., 1]) - ( |
| 290 | + b[..., 1] - a[..., 1] |
| 291 | + ) * (c[..., 0] - a[..., 0]) |
| 292 | + |
| 293 | + total = signed_cross(p0, p1, p2) |
| 294 | + w0 = signed_cross(q, p1, p2) |
| 295 | + w1 = signed_cross(p0, q, p2) |
| 296 | + w2 = signed_cross(p0, p1, q) |
| 297 | + |
| 298 | + eps = xp.asarray(1e-12, dtype=total.dtype) |
| 299 | + safe_total = xp.where(xp.abs(total) > eps, total, 1.0) |
| 300 | + |
| 301 | + bary = xp.stack([w0, w1, w2], axis=1) / safe_total[:, None] |
| 302 | + |
| 303 | + clipped = xp.maximum(bary, 0.0) |
| 304 | + row_sum = xp.sum(clipped, axis=1, keepdims=True) |
| 305 | + safe_sum = xp.where(row_sum > eps, row_sum, 1.0) |
| 306 | + weights = clipped / safe_sum |
| 307 | + |
| 308 | + # Degenerate triangles fall back to nearest-neighbor (weight 1 on column 0, |
| 309 | + # which `get_interpolation_weights` orders as the closest mesh vertex). |
| 310 | + # Same fallback policy as `pix_indexes_for_sub_slim_index_delaunay_from` |
| 311 | + # for outside-simplex points. |
| 312 | + nearest_only = xp.asarray([1.0, 0.0, 0.0], dtype=weights.dtype) |
| 313 | + |
| 314 | + degenerate = xp.abs(total) <= eps |
| 315 | + weights = xp.where(degenerate[:, None], nearest_only[None, :], weights) |
| 316 | + |
| 317 | + return weights |
| 318 | + |
| 319 | + |
| 320 | +class InterpolatorKNNBarycentric(InterpolatorKNearestNeighbor): |
| 321 | + """ |
| 322 | + Interpolator that picks the 3 nearest mesh vertices in the source plane and |
| 323 | + computes locally-exact barycentric weights on the triangle they form. |
| 324 | +
|
| 325 | + Approximates :class:`InterpolatorDelaunay` without the scipy.spatial.Delaunay |
| 326 | + callback: when the 3 nearest are the containing Delaunay triangle's vertices, |
| 327 | + the weights are bit-identical to Delaunay; otherwise they are clipped-and- |
| 328 | + renormalized barycentric weights on whichever triangle the 3 nearest form. |
| 329 | +
|
| 330 | + The kNN connectivity knobs (``k_neighbors``, ``radius_scale``, |
| 331 | + ``split_neighbor_division``) on the parent :class:`KNearestNeighbor` mesh are |
| 332 | + inherited and still control the regularization-spacing computation via |
| 333 | + ``distance_to_self``. Interpolation always uses k=3, irrespective of |
| 334 | + ``mesh.k_neighbors``. |
| 335 | + """ |
| 336 | + |
| 337 | + @cached_property |
| 338 | + def _mappings_sizes_weights(self): |
| 339 | + |
| 340 | + try: |
| 341 | + query_points = self.data_grid.over_sampled.array |
| 342 | + except AttributeError: |
| 343 | + try: |
| 344 | + query_points = self.data_grid.array |
| 345 | + except AttributeError: |
| 346 | + query_points = self.data_grid |
| 347 | + |
| 348 | + mappings, _, _ = get_interpolation_weights( |
| 349 | + points=self.mesh_grid_xy, |
| 350 | + query_points=query_points, |
| 351 | + k_neighbors=3, |
| 352 | + radius_scale=1.0, |
| 353 | + ) |
| 354 | + |
| 355 | + weights = barycentric_weights_from_3_nearest( |
| 356 | + query_points=query_points, |
| 357 | + mesh_points=self.mesh_grid_xy, |
| 358 | + nearest_3_indices=mappings, |
| 359 | + xp=self._xp, |
| 360 | + ) |
| 361 | + |
| 362 | + # On the numpy path, materialize with `np.array(...)` so the regularization |
| 363 | + # code (which uses in-place assignment, e.g. `reg_split_np_from`) gets a |
| 364 | + # writable buffer rather than a read-only view of a jax.Array. On the jax |
| 365 | + # path, asarray is the right cast (no copy in a JIT trace). |
| 366 | + if self._xp is np: |
| 367 | + mappings = np.array(mappings) |
| 368 | + weights = np.array(weights) |
| 369 | + else: |
| 370 | + mappings = self._xp.asarray(mappings) |
| 371 | + weights = self._xp.asarray(weights) |
| 372 | + |
| 373 | + sizes = self._xp.full( |
| 374 | + (mappings.shape[0],), |
| 375 | + mappings.shape[1], |
| 376 | + ) |
| 377 | + |
| 378 | + return mappings, sizes, weights |
| 379 | + |
| 380 | + @cached_property |
| 381 | + def _mappings_sizes_weights_split(self): |
| 382 | + """ |
| 383 | + Same spacing scheme as :class:`InterpolatorKNearestNeighbor` but the |
| 384 | + split-point interpolator is :class:`InterpolatorKNNBarycentric` so the |
| 385 | + split-regularization weights are also barycentric rather than Wendland. |
| 386 | + """ |
| 387 | + from autoarray.inversion.regularization.regularization_util import ( |
| 388 | + split_points_from, |
| 389 | + ) |
| 390 | + |
| 391 | + neighbor_index = int(self.mesh.k_neighbors) // self.mesh.split_neighbor_division |
| 392 | + |
| 393 | + distance_to_self = self.distance_to_self |
| 394 | + others = distance_to_self[:, 1:] |
| 395 | + idx = int(neighbor_index) - 1 |
| 396 | + idx = max(0, min(idx, others.shape[1] - 1)) |
| 397 | + r_k = others[:, idx] |
| 398 | + |
| 399 | + split_step = self.mesh.areas_factor * r_k |
| 400 | + |
| 401 | + split_points = split_points_from( |
| 402 | + points=self.mesh_grid.array, |
| 403 | + area_weights=split_step, |
| 404 | + xp=self._xp, |
| 405 | + ) |
| 406 | + |
| 407 | + interpolator = InterpolatorKNNBarycentric( |
| 408 | + mesh=self.mesh, |
| 409 | + mesh_grid=self.mesh_grid, |
| 410 | + data_grid=split_points, |
| 411 | + xp=self._xp, |
| 412 | + ) |
| 413 | + |
| 414 | + mappings = interpolator.mappings |
| 415 | + weights = interpolator.weights |
| 416 | + |
| 417 | + # `reg_split_np_from` writes `splitted_mappings[i][j+1] = pixel_index` |
| 418 | + # for the "flag-zero" insertion of the central pixel, so the buffer |
| 419 | + # must have an extra column reserved past the k=3 mappings — matching |
| 420 | + # `InterpolatorDelaunay._mappings_sizes_weights_split`'s hstack-append. |
| 421 | + # `sizes` reports 3 (the actual mappings); `reg_split_np_from` grows it |
| 422 | + # to 4 in-place when it inserts. |
| 423 | + sizes = self._xp.full( |
| 424 | + (mappings.shape[0],), |
| 425 | + mappings.shape[1], |
| 426 | + ) |
| 427 | + |
| 428 | + pad_int = self._xp.full((mappings.shape[0], 1), -1, dtype=mappings.dtype) |
| 429 | + pad_float = self._xp.zeros((weights.shape[0], 1), dtype=weights.dtype) |
| 430 | + mappings = self._xp.hstack((mappings, pad_int)) |
| 431 | + weights = self._xp.hstack((weights, pad_float)) |
| 432 | + |
| 433 | + return mappings, sizes, weights |
0 commit comments