@@ -50,13 +50,12 @@ def _pix_sub_weights_from_query_points(self, query_points) -> PixSubWeights:
5050 # ------------------------------------------------------------------
5151 # Convert outputs to xp backend *only if needed*
5252 # ------------------------------------------------------------------
53- if xp is jnp :
54- weights = weights_jax
55- mappings = indices_jax
56- else :
57- # xp is numpy
53+ if xp is np :
5854 weights = np .asarray (weights_jax )
5955 mappings = np .asarray (indices_jax )
56+ else :
57+ weights = weights_jax
58+ mappings = indices_jax
6059
6160 # ------------------------------------------------------------------
6261 # Sizes: always k for kNN
@@ -90,11 +89,40 @@ def pix_sub_weights(self) -> PixSubWeights:
9089 @property
9190 def pix_sub_weights_split_points (self ) -> PixSubWeights :
9291 """
93- kNN mappings + kernel weights computed at split points (for split regularization schemes).
92+ kNN mappings + kernel weights computed at split points (for split regularization schemes),
93+ with split-point step sizes derived from kNN local spacing (no Delaunay / simplices).
9494 """
95- # Your Delaunay mesh exposes split points via self.delaunay.split_points.
96- # For KNN mesh, you should expose the same property. If not, route appropriately:
97- # split_points = self.mesh.split_points
98- split_points = self .delaunay .split_points # keep consistent with existing API
95+ from autoarray .structures .mesh .delaunay_2d import split_points_from
9996
97+ # TODO: wire these to your pixelization / regularization config rather than hard-code.
98+ k_neighbors = 10
99+ kernel = "wendland_c4"
100+ radius_scale = 1.5
101+ areas_factor = 0.5
102+
103+ xp = self ._xp # np or jnp
104+
105+ # Mesh points (N, 2)
106+ points = xp .asarray (self .source_plane_mesh_grid .array , dtype = xp .float64 )
107+
108+ # kNN distances of each point to its neighbors (include self, then drop it)
109+ _ , _ , dist_self = get_interpolation_weights (
110+ points = points ,
111+ query_points = points ,
112+ k_neighbors = int (k_neighbors ) + 1 ,
113+ kernel = kernel ,
114+ radius_scale = float (radius_scale ),
115+ )
116+
117+ # Local spacing scale: distance to k-th nearest OTHER point
118+ r_k = dist_self [:, 1 :][:, - 1 ] # (N,)
119+
120+ # Split cross step size (length): sqrt(area) ~ r_k
121+ split_step = xp .asarray (areas_factor , dtype = xp .float64 ) * r_k # (N,)
122+
123+ # Split points (xp-native)
124+ split_points = split_points_from (points = points , area_weights = split_step , xp = xp )
125+
126+ # Compute kNN mappings/weights at split points
100127 return self ._pix_sub_weights_from_query_points (query_points = split_points )
128+
0 commit comments