diff --git a/autolens/analysis/latent.py b/autolens/analysis/latent.py index 2058d4293..0ef44bbd7 100644 --- a/autolens/analysis/latent.py +++ b/autolens/analysis/latent.py @@ -76,12 +76,18 @@ def total_lensed_source_flux_mujy(fit, magzero, xp=np): def total_source_flux_mujy(fit, magzero, xp=np): """ - Source-plane intrinsic flux of the source galaxy, via - ``fit.tracer.galaxies[-1].image_2d_from(grid=fit.dataset.grids.lp)``. + Source-plane intrinsic flux of the source galaxy, in microjanskies. + + Reads from ``fit.tracer_linear_light_profiles_to_light_profiles`` rather + than ``fit.tracer`` so that linear light profiles (whose ``intensity`` + is solved by the inversion at fit time) contribute the correct image. + For non-linear fits this property is a no-op pass-through (returns + ``fit.tracer``), so the numpy-only and JAX paths both work uniformly. """ _require_magzero(magzero, "total_source_flux_mujy") try: - source_image = fit.tracer.galaxies[-1].image_2d_from( + tracer = fit.tracer_linear_light_profiles_to_light_profiles + source_image = tracer.galaxies[-1].image_2d_from( grid=fit.dataset.grids.lp, xp=xp ) except (AttributeError, IndexError): diff --git a/test_autolens/analysis/test_latent.py b/test_autolens/analysis/test_latent.py index 1e97375fb..848f79353 100644 --- a/test_autolens/analysis/test_latent.py +++ b/test_autolens/analysis/test_latent.py @@ -82,8 +82,13 @@ def test_total_source_flux_mujy_against_known_image(): array=np.array([2.0, 3.0, 5.0]) ) ) + # Both `tracer` and `tracer_linear_light_profiles_to_light_profiles` + # point at the same galaxies for non-linear fits — the conversion + # property is a no-op pass-through. + galaxies_namespace = SimpleNamespace(galaxies=[object(), source]) fit = SimpleNamespace( - tracer=SimpleNamespace(galaxies=[object(), source]), + tracer=galaxies_namespace, + tracer_linear_light_profiles_to_light_profiles=galaxies_namespace, dataset=SimpleNamespace(grids=SimpleNamespace(lp=object())), ) value = total_source_flux_mujy(fit=fit, magzero=25.0) @@ -93,6 +98,38 @@ def test_total_source_flux_mujy_against_known_image(): assert value == pytest.approx(expected_muJy) +def test_total_source_flux_mujy_uses_converted_tracer_for_linear_profiles(): + """When the source has a linear light profile, ``fit.tracer.galaxies[-1]`` + is un-solved (``image_2d_from`` returns zeros). The library must read from + ``fit.tracer_linear_light_profiles_to_light_profiles`` where intensities + are filled in from the inversion.""" + + unsolved_source = SimpleNamespace( + image_2d_from=lambda grid, xp=np: SimpleNamespace(array=np.zeros(4)) + ) + solved_source = SimpleNamespace( + image_2d_from=lambda grid, xp=np: SimpleNamespace( + array=np.array([1.0, 2.0, 3.0, 4.0]) + ) + ) + fit = SimpleNamespace( + tracer=SimpleNamespace(galaxies=[object(), unsolved_source]), + tracer_linear_light_profiles_to_light_profiles=SimpleNamespace( + galaxies=[object(), solved_source] + ), + dataset=SimpleNamespace(grids=SimpleNamespace(lp=object())), + ) + + value = total_source_flux_mujy(fit=fit, magzero=25.0) + + # Expected from the solved source (sum = 10): + expected_ab_mag = -2.5 * np.log10(10.0) + 25.0 + expected_muJy = 10 ** ((23.9 - expected_ab_mag) / 2.5) + assert value == pytest.approx(expected_muJy) + # Confirm we did NOT read from the unsolved tracer (which would give 0). + assert value != 0.0 + + def test_total_source_flux_mujy_missing_magzero_raises(): with pytest.raises(ValueError, match="magzero"): total_source_flux_mujy(fit=MagicMock(), magzero=None) @@ -108,8 +145,12 @@ def image_2d_from(self, grid, xp=np): return SimpleNamespace(array=np.array([2.0])) source = _FakeSourceGalaxy() + # Same fakes for tracer and the converted-tracer property (no-op + # pass-through for non-linear profile fixtures). + galaxies_namespace = SimpleNamespace(galaxies=[object(), source]) fit = SimpleNamespace( - tracer=SimpleNamespace(galaxies=[object(), source]), + tracer=galaxies_namespace, + tracer_linear_light_profiles_to_light_profiles=galaxies_namespace, galaxy_image_dict={source: SimpleNamespace(array=np.array([10.0]))}, dataset=SimpleNamespace(grids=SimpleNamespace(lp=object())), )