Skip to content

Commit b9e0920

Browse files
Jammy2211Jammy2211
authored andcommitted
knn without batch fully working with factory fixes
1 parent f5de839 commit b9e0920

4 files changed

Lines changed: 78 additions & 15 deletions

File tree

autoarray/inversion/pixelization/mappers/factory.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,21 +70,22 @@ def mapper_from(
7070
preloads=preloads,
7171
xp=xp,
7272
)
73-
elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DDelaunay):
74-
return MapperDelaunay(
73+
elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DDelaunayKNN):
74+
return MapperKNNInterpolator(
7575
mapper_grids=mapper_grids,
7676
border_relocator=border_relocator,
7777
regularization=regularization,
7878
settings=settings,
7979
preloads=preloads,
8080
xp=xp,
8181
)
82-
elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DDelaunayKNN):
83-
return MapperKNNInterpolator(
82+
83+
elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DDelaunay):
84+
return MapperDelaunay(
8485
mapper_grids=mapper_grids,
8586
border_relocator=border_relocator,
8687
regularization=regularization,
8788
settings=settings,
8889
preloads=preloads,
8990
xp=xp,
90-
)
91+
)

autoarray/inversion/pixelization/mappers/knn.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,12 @@ def _pix_sub_weights_from_query_points(self, query_points) -> PixSubWeights:
5050
# ------------------------------------------------------------------
5151
# Convert outputs to xp backend *only if needed*
5252
# ------------------------------------------------------------------
53-
if xp is jnp:
54-
weights = weights_jax
55-
mappings = indices_jax
56-
else:
57-
# xp is numpy
53+
if xp is np:
5854
weights = np.asarray(weights_jax)
5955
mappings = np.asarray(indices_jax)
56+
else:
57+
weights = weights_jax
58+
mappings = indices_jax
6059

6160
# ------------------------------------------------------------------
6261
# Sizes: always k for kNN
@@ -90,11 +89,40 @@ def pix_sub_weights(self) -> PixSubWeights:
9089
@property
9190
def pix_sub_weights_split_points(self) -> PixSubWeights:
9291
"""
93-
kNN mappings + kernel weights computed at split points (for split regularization schemes).
92+
kNN mappings + kernel weights computed at split points (for split regularization schemes),
93+
with split-point step sizes derived from kNN local spacing (no Delaunay / simplices).
9494
"""
95-
# Your Delaunay mesh exposes split points via self.delaunay.split_points.
96-
# For KNN mesh, you should expose the same property. If not, route appropriately:
97-
# split_points = self.mesh.split_points
98-
split_points = self.delaunay.split_points # keep consistent with existing API
95+
from autoarray.structures.mesh.delaunay_2d import split_points_from
9996

97+
# TODO: wire these to your pixelization / regularization config rather than hard-code.
98+
k_neighbors = 10
99+
kernel = "wendland_c4"
100+
radius_scale = 1.5
101+
areas_factor = 0.5
102+
103+
xp = self._xp # np or jnp
104+
105+
# Mesh points (N, 2)
106+
points = xp.asarray(self.source_plane_mesh_grid.array, dtype=xp.float64)
107+
108+
# kNN distances of each point to its neighbors (include self, then drop it)
109+
_, _, dist_self = get_interpolation_weights(
110+
points=points,
111+
query_points=points,
112+
k_neighbors=int(k_neighbors) + 1,
113+
kernel=kernel,
114+
radius_scale=float(radius_scale),
115+
)
116+
117+
# Local spacing scale: distance to k-th nearest OTHER point
118+
r_k = dist_self[:, 1:][:, -1] # (N,)
119+
120+
# Split cross step size (length): sqrt(area) ~ r_k
121+
split_step = xp.asarray(areas_factor, dtype=xp.float64) * r_k # (N,)
122+
123+
# Split points (xp-native)
124+
split_points = split_points_from(points=points, area_weights=split_step, xp=xp)
125+
126+
# Compute kNN mappings/weights at split points
100127
return self._pix_sub_weights_from_query_points(query_points=split_points)
128+

autoarray/inversion/pixelization/mesh/knn.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
Uses Wendland compactly supported kernels with normalized weights (partition of unity).
44
More robust and faster than MLS, better accuracy than simple IDW.
55
"""
6+
import numpy as np
67
import jax
78
import jax.numpy as jnp
89
from functools import partial
910

11+
from autoarray.structures.mesh.knn_delaunay_2d import Mesh2DDelaunayKNN
12+
1013

1114
def get_interpolation_weights(points, query_points, k_neighbors=10, kernel='wendland_c4',
1215
radius_scale=1.5):
@@ -268,3 +271,33 @@ def __init__(self, k_neighbors=10, kernel='wendland_c4',
268271

269272
super().__init__()
270273

274+
def mesh_grid_from(
275+
self,
276+
source_plane_data_grid=None,
277+
source_plane_mesh_grid=None,
278+
preloads=None,
279+
xp=np,
280+
):
281+
"""
282+
Return the Delaunay ``source_plane_mesh_grid`` as a ``Mesh2DDelaunay`` object, which provides additional
283+
functionality for performing operations that exploit the geometry of a Delaunay mesh.
284+
285+
Parameters
286+
----------
287+
source_plane_data_grid
288+
A 2D grid of (y,x) coordinates associated with the unmasked 2D data after it has been transformed to the
289+
``source`` reference frame.
290+
source_plane_mesh_grid
291+
The centres of every Delaunay pixel in the ``source`` frame, which are initially derived by computing a sparse
292+
set of (y,x) coordinates computed from the unmasked data in the image-plane and applying a transformation
293+
to this.
294+
settings
295+
Settings controlling the pixelization for example if a border is used to relocate its exterior coordinates.
296+
"""
297+
298+
return Mesh2DDelaunayKNN(
299+
values=source_plane_mesh_grid,
300+
source_plane_data_grid_over_sampled=source_plane_data_grid,
301+
preloads=preloads,
302+
_xp=xp,
303+
)

autoarray/preloads.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
use_voronoi_areas: bool = True,
2828
areas_factor: float = 0.5,
2929
skip_areas: bool = False,
30+
splitted_only : bool = False
3031
):
3132
"""
3233
Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance

0 commit comments

Comments
 (0)