Skip to content

Commit 446e5bc

Browse files
Jammy2211Jammy2211
authored andcommitted
zeroing now handled internally with positive values
1 parent 2119d6b commit 446e5bc

4 files changed

Lines changed: 97 additions & 52 deletions

File tree

autoarray/inversion/inversion/abstract.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,60 @@ def curvature_reg_matrix_reduced(self) -> Optional[np.ndarray]:
379379
# Zero rows and columns in the matrix we want to ignore
380380
return self.curvature_reg_matrix[ids_to_keep][:, ids_to_keep]
381381

382+
@cached_property
383+
def zeroed_ids_to_keep(self):
384+
"""
385+
Return the **positive global indices** of linear objects to keep in the inversion,
386+
given **mesh-local positive pixel indices** to zero.
387+
388+
Assumes the full linear system parameter vector is laid out as:
389+
390+
[ non-pixel linear objects ][ pixel block ]
391+
392+
where:
393+
- the pixel block occupies the final `mesh.pixels` entries of the full vector
394+
- `mesh.zeroed_pixels` contains **positive** indices in the pixel block's own
395+
indexing (0 is top-left for rectangular meshes, increasing row-major)
396+
- all non-pixel linear objects are always kept
397+
398+
This implementation is backend-friendly (NumPy / JAX via `self._xp`) and returns
399+
indices suitable for advanced indexing of `data_vector` and curvature matrices.
400+
401+
Returns
402+
-------
403+
array-like
404+
1D array of **positive** indices to keep, sorted ascending.
405+
"""
406+
xp = self._xp
407+
408+
mapper_list = self.cls_list_from(cls=Mapper)
409+
mesh = mapper_list[0].mesh
410+
411+
n_total = int(self.total_params)
412+
n_pixels = int(mesh.pixels)
413+
414+
# Pixel block starts at this global index
415+
pixel_start = n_total - n_pixels
416+
417+
# Mesh-local positive pixel indices to zero (e.g. [0, 1, 2, ...] for edges)
418+
zeros_local = xp.asarray(mesh.zeroed_pixels)
419+
420+
# Convert to global positive indices
421+
zeros_global = pixel_start + zeros_local
422+
423+
# Keep mask over full parameter vector
424+
keep = xp.ones((n_total,), dtype=bool)
425+
426+
if hasattr(keep, "at"):
427+
keep = keep.at[zeros_global].set(False)
428+
keep_ids = xp.nonzero(keep, size=n_total)[0]
429+
keep_ids = keep_ids[: keep.sum()]
430+
else:
431+
keep[zeros_global] = False
432+
keep_ids = xp.nonzero(keep)[0]
433+
434+
return keep_ids
435+
382436
@cached_property
383437
def reconstruction(self) -> np.ndarray:
384438
"""
@@ -400,16 +454,11 @@ def reconstruction(self) -> np.ndarray:
400454

401455
if self.settings.use_edge_zeroed_pixels and self.has(cls=Mapper):
402456

403-
mapper_list = self.cls_list_from(cls=Mapper)
404-
405-
# ids of values which are not zeroed and therefore kept in soluiton, which is computed in preloads.
406-
ids_to_keep = mapper_list[0].mesh.zeroed_pixels_to_keep
407-
408457
# Use advanced indexing to select rows/columns
409-
data_vector = self.data_vector[ids_to_keep]
410-
curvature_reg_matrix = self.curvature_reg_matrix[ids_to_keep][
411-
:, ids_to_keep
412-
]
458+
data_vector = self.data_vector[self.zeroed_ids_to_keep]
459+
curvature_reg_matrix = self.curvature_reg_matrix[
460+
self.zeroed_ids_to_keep
461+
][:, self.zeroed_ids_to_keep]
413462

414463
# Perform reconstruction via fnnls
415464
reconstruction_partial = (
@@ -426,11 +475,11 @@ def reconstruction(self) -> np.ndarray:
426475

427476
# Scatter the partial solution back to the full shape
428477
if self._xp.__name__.startswith("jax"):
429-
reconstruction = reconstruction.at[ids_to_keep].set(
478+
reconstruction = reconstruction.at[self.zeroed_ids_to_keep].set(
430479
reconstruction_partial
431480
)
432481
else:
433-
reconstruction[ids_to_keep] = reconstruction_partial
482+
reconstruction[self.zeroed_ids_to_keep] = reconstruction_partial
434483

435484
return reconstruction
436485

autoarray/inversion/mesh/mesh/abstract.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -52,39 +52,6 @@ def relocated_grid_from(
5252
xp=xp,
5353
)
5454

55-
@property
56-
def zeroed_pixels_to_keep(self):
57-
"""
58-
Return the positive indices of pixels that should be kept (solved for),
59-
accounting for zeroed pixels specified using Python-style negative indexing.
60-
61-
This property assumes that `self.zeroed_pixels` contains **negative indices**
62-
referring to entries counted from the right-hand side of the parameter array
63-
(e.g. -1 is the last entry, -2 the second-to-last, etc.).
64-
65-
These negative indices are converted to their corresponding positive indices
66-
before constructing a boolean mask over the full set of mapper indices.
67-
68-
Returns
69-
-------
70-
np.ndarray
71-
A 1D array of positive indices corresponding to pixels that are *not*
72-
zeroed and should therefore be included in the solve.
73-
"""
74-
# Negative indices from the right (e.g. [-1, -2, ...])
75-
ids_zeros_neg = np.array(self.zeroed_pixels, dtype=int)
76-
77-
# Total number of values being solved for
78-
n_values = self.pixels
79-
80-
# Convert negative indices to positive
81-
ids_zeros_pos = n_values + ids_zeros_neg
82-
83-
values_to_solve = np.ones(n_values, dtype=bool)
84-
values_to_solve[ids_zeros_pos] = False
85-
86-
return np.where(values_to_solve)[0]
87-
8855
def relocated_mesh_grid_from(
8956
self,
9057
border_relocator: Optional[BorderRelocator],

autoarray/inversion/mesh/mesh/delaunay.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,30 @@ def __init__(
3838
pixels = int(pixels) + zeroed_pixels
3939

4040
super().__init__()
41+
self.pixels = pixels
4142
self.areas_factor = areas_factor
42-
self.zeroed_pixels = zeroed_pixels
43+
self._zeroed_pixels = zeroed_pixels
4344

45+
@property
4446
def zeroed_pixels(self):
45-
if not self.zeroed_pixels:
46-
return []
47+
"""
48+
Return the **positive** mesh-local pixel indices to zero for a Delaunay mesh.
49+
50+
For Delaunay meshes, `self.zeroed_pixels` is interpreted as a *count* of pixels
51+
to be zeroed at the end of the pixel block. For example:
52+
self.pixels = 780, self.zeroed_pixels = 30
53+
returns indices 750..779.
54+
55+
Returns
56+
-------
57+
np.ndarray
58+
1D array of positive pixel indices to zero.
59+
"""
60+
if self._zeroed_pixels <= 0:
61+
return np.array([], dtype=int)
4762

48-
return -np.arange(1, self.zeroed_pixels + 1).tolist()
63+
start = self.pixels - self._zeroed_pixels
64+
return np.arange(start, self.pixels, dtype=int)
4965

5066
@property
5167
def skip_areas(self):

autoarray/inversion/mesh/mesh/rectangular_adapt_density.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,16 +102,29 @@ def __init__(self, shape: Tuple[int, int] = (3, 3)):
102102

103103
@property
104104
def zeroed_pixels(self):
105+
"""
106+
Return the **positive** 1D pixel indices of the edge pixels in a rectangular mesh.
107+
108+
Indices are in row-major (C-order) flattened form for the rectangular pixel grid:
109+
- 0 corresponds to the top-left pixel (row=0, col=0)
110+
- indices increase across rows
111+
112+
These indices are defined purely within the rectangular mesh's pixel indexing
113+
scheme (size = rows * cols) and are intended to be shifted / mapped to the full
114+
inversion indexing inside the inversion logic.
105115
116+
Returns
117+
-------
118+
np.ndarray
119+
A 1D array of positive indices corresponding to edge pixels.
120+
"""
106121
from autoarray.inversion.mesh.mesh_geometry.rectangular import (
107122
rectangular_edge_pixel_list_from,
108123
)
109124

110-
edge_pixe_list = rectangular_edge_pixel_list_from(
111-
shape_native=self.shape,
112-
)
125+
edge_pixel_list = rectangular_edge_pixel_list_from(shape_native=self.shape)
113126

114-
return -(np.array(edge_pixe_list) + 1)
127+
return np.array(edge_pixel_list, dtype=int)
115128

116129
@property
117130
def interpolator_cls(self):

0 commit comments

Comments
 (0)