Skip to content

Commit 0b424c4

Browse files
Jammy2211Jammy2211
authored andcommitted
adaptive_pixel_signals_from JAX-d
1 parent 93157b8 commit 0b424c4

6 files changed

Lines changed: 94 additions & 44 deletions

File tree

autoarray/inversion/inversion/abstract.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,6 @@ def __init__(
6868
Settings controlling how an inversion is fitted for example which linear algebra formalism is used.
6969
"""
7070

71-
try:
72-
import numba
73-
except ModuleNotFoundError:
74-
raise exc.InversionException(
75-
"Inversion functionality (linear light profiles, pixelized reconstructions) is "
76-
"disabled if numba is not installed.\n\n"
77-
"This is because the run-times without numba are too slow.\n\n"
78-
"Please install numba, which is described at the following web page:\n\n"
79-
"https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
80-
)
81-
8271
self.dataset = dataset
8372

8473
self.linear_obj_list = linear_obj_list
@@ -160,17 +149,10 @@ def param_range_list_from(self, cls: Type) -> List[List[int]]:
160149
-------
161150
A list of the index range of the parameters of each linear object in the inversion of the input cls type.
162151
"""
163-
index_list = []
164-
165-
pixel_count = 0
166-
167-
for linear_obj in self.linear_obj_list:
168-
if isinstance(linear_obj, cls):
169-
index_list.append([pixel_count, pixel_count + linear_obj.params])
170-
171-
pixel_count += linear_obj.params
172-
173-
return index_list
152+
return inversion_util.param_range_list_from(
153+
cls=cls,
154+
linear_obj_list=self.linear_obj_list
155+
)
174156

175157
def cls_list_from(self, cls: Type, cls_filtered: Optional[Type] = None) -> List:
176158
"""

autoarray/inversion/inversion/imaging/abstract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from typing import Dict, List, Optional, Union, Type
2+
from typing import Dict, List, Union, Type
33

44
from autoconf import cached_property
55

autoarray/inversion/inversion/imaging/w_tilde.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,17 @@ def __init__(
4949
the simultaneous linear equations are combined and solved simultaneously.
5050
"""
5151

52+
try:
53+
import numba
54+
except ModuleNotFoundError:
55+
raise exc.InversionException(
56+
"Inversion functionality (linear light profiles, pixelized reconstructions) is "
57+
"disabled if numba is not installed.\n\n"
58+
"This is because the run-times without numba are too slow.\n\n"
59+
"Please install numba, which is described at the following web page:\n\n"
60+
"https://pyautolens.readthedocs.io/en/latest/installation/overview.html"
61+
)
62+
5263
super().__init__(
5364
dataset=dataset,
5465
linear_obj_list=linear_obj_list,

autoarray/inversion/inversion/inversion_util.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import jax.lax as lax
33
import numpy as np
44

5-
from typing import List, Optional
5+
from typing import List, Optional, Type
66

77
from autoconf import conf
88

@@ -346,3 +346,48 @@ def preconditioner_matrix_via_mapping_matrix_from(
346346
return (
347347
preconditioner_noise_normalization * curvature_matrix
348348
) + regularization_matrix
349+
350+
351+
def param_range_list_from(cls: Type, linear_obj_list) -> List[List[int]]:
352+
"""
353+
Each linear object in the `Inversion` has N parameters, and these parameters correspond to a certain range
354+
of indexing values in the matrices used to perform the inversion.
355+
356+
This function returns the `param_range_list` of an input type of linear object, which gives the indexing range
357+
of each linear object of the input type.
358+
359+
For example, if an `Inversion` has:
360+
361+
- A `LinearFuncList` linear object with 3 `params`.
362+
- A `Mapper` with 100 `params`.
363+
- A `Mapper` with 200 `params`.
364+
365+
The corresponding matrices of this inversion (e.g. the `curvature_matrix`) have `shape=(303, 303)` where:
366+
367+
- The `LinearFuncList` values are in the entries `[0:3]`.
368+
- The first `Mapper` values are in the entries `[3:103]`.
369+
- The second `Mapper` values are in the entries `[103:303]
370+
371+
For this example, `param_range_list_from(cls=AbstractMapper)` therefore returns the
372+
list `[[3, 103], [103, 303]]`.
373+
374+
Parameters
375+
----------
376+
cls
377+
The type of class that the list of their parameter range index values are returned for.
378+
379+
Returns
380+
-------
381+
A list of the index range of the parameters of each linear object in the inversion of the input cls type.
382+
"""
383+
index_list = []
384+
385+
pixel_count = 0
386+
387+
for linear_obj in linear_obj_list:
388+
if isinstance(linear_obj, cls):
389+
index_list.append([pixel_count, pixel_count + linear_obj.params])
390+
391+
pixel_count += linear_obj.params
392+
393+
return index_list

autoarray/inversion/pixelization/mappers/mapper_util.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,6 @@ def remove_bad_entries_voronoi_nn(
498498
return pix_weights_for_sub_slim_index, pix_indexes_for_sub_slim_index
499499

500500

501-
@numba_util.jit()
502501
def adaptive_pixel_signals_from(
503502
pixels: int,
504503
pixel_weights: np.ndarray,
@@ -536,30 +535,43 @@ def adaptive_pixel_signals_from(
536535
The image of the galaxy which is used to compute the weigghted pixel signals.
537536
"""
538537

539-
pixel_signals = np.zeros((pixels,))
540-
pixel_sizes = np.zeros((pixels,))
538+
M_sub, B = pix_indexes_for_sub_slim_index.shape
541539

542-
for sub_slim_index in range(len(pix_indexes_for_sub_slim_index)):
543-
vertices_indexes = pix_indexes_for_sub_slim_index[sub_slim_index]
540+
# 1) Flatten the per‐mapping tables:
541+
flat_pixidx = pix_indexes_for_sub_slim_index.reshape(-1) # (M_sub*B,)
542+
flat_weights = pixel_weights.reshape(-1) # (M_sub*B,)
544543

545-
mask_1d_index = slim_index_for_sub_slim_index[sub_slim_index]
544+
# 2) Build a matching “parent‐slim” index for each flattened entry:
545+
I_sub = jnp.repeat(jnp.arange(M_sub), B) # (M_sub*B,)
546546

547-
pix_size_tem = pix_size_for_sub_slim_index[sub_slim_index]
547+
# 3) Mask out any k >= pix_size_for_sub_slim_index[i]
548+
valid = (I_sub < 0) # dummy to get shape
549+
# better:
550+
valid = (jnp.arange(B)[None, :] < pix_size_for_sub_slim_index[:, None]).reshape(-1)
548551

549-
if pix_size_tem > 1:
550-
pixel_signals[vertices_indexes[:pix_size_tem]] += (
551-
adapt_data[mask_1d_index] * pixel_weights[sub_slim_index]
552-
)
553-
pixel_sizes[vertices_indexes] += 1
554-
else:
555-
pixel_signals[vertices_indexes[0]] += adapt_data[mask_1d_index]
556-
pixel_sizes[vertices_indexes[0]] += 1
552+
flat_weights = jnp.where(valid, flat_weights, 0.0)
553+
flat_pixidx = jnp.where(valid, flat_pixidx, pixels) # send invalid indices to an out-of-bounds slot
554+
555+
# 4) Look up data & multiply by mapping weights:
556+
flat_data_vals = adapt_data[slim_index_for_sub_slim_index][I_sub] # (M_sub*B,)
557+
flat_contrib = flat_data_vals * flat_weights # (M_sub*B,)
558+
559+
# 5) Scatter‐add into signal sums and counts:
560+
pixel_signals = jnp.zeros((pixels+1,)).at[flat_pixidx].add(flat_contrib)
561+
pixel_counts = jnp.zeros((pixels+1,)).at[flat_pixidx].add(valid.astype(float))
562+
563+
# 6) Drop the extra “out-of-bounds” slot:
564+
pixel_signals = pixel_signals[:pixels]
565+
pixel_counts = pixel_counts[:pixels]
557566

558-
pixel_sizes[pixel_sizes == 0] = 1
559-
pixel_signals /= pixel_sizes
560-
pixel_signals /= np.max(pixel_signals)
567+
# 7) Normalize
568+
pixel_counts = jnp.where(pixel_counts > 0, pixel_counts, 1.0)
569+
pixel_signals = pixel_signals / pixel_counts
570+
max_sig = jnp.max(pixel_signals)
571+
pixel_signals = jnp.where(max_sig > 0, pixel_signals / max_sig, pixel_signals)
561572

562-
return pixel_signals**signal_scale
573+
# 8) Exponentiate
574+
return pixel_signals ** signal_scale
563575

564576

565577
def mapping_matrix_from(

test_autoarray/inversion/pixelization/mappers/test_rectangular.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test__pixel_signals_from__matches_util(grid_2d_sub_1_7x7, image_7x7):
6868
pix_size_for_sub_slim_index=mapper.pix_sizes_for_sub_slim_index,
6969
pixel_weights=mapper.pix_weights_for_sub_slim_index,
7070
slim_index_for_sub_slim_index=grid_2d_sub_1_7x7.over_sampler.slim_for_sub_slim,
71-
adapt_data=np.array(image_7x7),
71+
adapt_data=image_7x7,
7272
)
7373

7474
assert (pixel_signals == pixel_signals_util).all()

0 commit comments

Comments
 (0)