Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
95 changes: 95 additions & 0 deletions scripts/jax_grad/interferometer/lp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Tests that jax.value_and_grad can compute finite, non-NaN gradients of the log-likelihood
for an autogalaxy interferometer model with a parametric Sersic light profile. This tests
the core JAX differentiability that enables gradient-based inference on visibility data.
"""

import numpy as np
import jax
import jax.numpy as jnp
from os import path

import autofit as af
import autogalaxy as ag

dataset_path = path.join("dataset", "interferometer", "jax_test")

"""
__Dataset Auto-Simulation__

If the dataset does not already exist on your system, it will be created by running the corresponding
simulator script. This ensures that all example scripts can be run without manually simulating data first.
"""
if not path.exists(path.join(dataset_path, "data.fits")):
import subprocess
import sys

subprocess.run(
[
sys.executable,
"scripts/jax_likelihood_functions/interferometer/simulator.py",
],
check=True,
)

real_space_mask = ag.Mask2D.circular(
shape_native=(256, 256),
pixel_scales=0.1,
radius=3.0,
)

dataset = ag.Interferometer.from_fits(
data_path=path.join(dataset_path, "data.fits"),
noise_map_path=path.join(dataset_path, "noise_map.fits"),
uv_wavelengths_path=path.join(dataset_path, "uv_wavelengths.fits"),
real_space_mask=real_space_mask,
transformer_class=ag.TransformerDFT,
)

# Single galaxy with a Sersic bulge — no lens/source split, no mass profile.

bulge = af.Model(ag.lp.Sersic)

galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge)

model = af.Collection(galaxies=af.Collection(galaxy=galaxy))

print(model.info)

analysis = ag.AnalysisInterferometer(dataset=dataset)

from autofit.non_linear.fitness import Fitness

fitness = Fitness(
model=model,
analysis=analysis,
fom_is_log_likelihood=True,
resample_figure_of_merit=-1.0e99,
)

param_vector = jnp.array(model.physical_values_from_prior_medians)

# Perturb ell_comps away from (0,0) to avoid degenerate gradients at the
# circular-profile singularity (arctan2 gradient is undefined at exactly (0,0)).
key = jax.random.PRNGKey(0)
perturbation = jax.random.uniform(
key, shape=param_vector.shape, minval=0.01, maxval=0.05
)
param_vector = param_vector + perturbation

value, grad = jax.value_and_grad(fitness.call)(param_vector)

print(f"Log likelihood = {float(value):.6f}")
print(f"Gradient shape = {grad.shape}")
print(f"Gradient = {np.array(grad)}")

assert np.isfinite(float(value)), "Log likelihood is not finite"
assert grad.shape == (
model.total_free_parameters,
), f"Gradient shape mismatch: {grad.shape}"
assert np.all(
np.isfinite(np.array(grad))
), f"Gradient contains non-finite values: {np.array(grad)}"
assert not np.all(np.array(grad) == 0.0), "Gradient is all zeros"

print("jax_grad/interferometer/lp.py JAX gradient checks passed.")
100 changes: 100 additions & 0 deletions scripts/jax_grad/interferometer/mge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
Tests that jax.value_and_grad can compute finite, non-NaN gradients of the log-likelihood
for an autogalaxy interferometer model with a Multi-Gaussian Expansion (MGE) linear basis
light profile. This tests the core JAX differentiability that enables gradient-based
inference on visibility data.
"""

import numpy as np
import jax
import jax.numpy as jnp
from os import path

import autofit as af
import autogalaxy as ag

dataset_path = path.join("dataset", "interferometer", "jax_test")

"""
__Dataset Auto-Simulation__

If the dataset does not already exist on your system, it will be created by running the corresponding
simulator script. This ensures that all example scripts can be run without manually simulating data first.
"""
if not path.exists(path.join(dataset_path, "data.fits")):
import subprocess
import sys

subprocess.run(
[
sys.executable,
"scripts/jax_likelihood_functions/interferometer/simulator.py",
],
check=True,
)

real_space_mask = ag.Mask2D.circular(
shape_native=(256, 256),
pixel_scales=0.1,
radius=3.0,
)

dataset = ag.Interferometer.from_fits(
data_path=path.join(dataset_path, "data.fits"),
noise_map_path=path.join(dataset_path, "noise_map.fits"),
uv_wavelengths_path=path.join(dataset_path, "uv_wavelengths.fits"),
real_space_mask=real_space_mask,
transformer_class=ag.TransformerDFT,
)

# Single galaxy with an MGE linear basis light profile — no lens/source split, no mass profile.

mask_radius = 3.0

bulge = ag.model_util.mge_model_from(
mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=True
)

galaxy = af.Model(ag.Galaxy, redshift=0.5, bulge=bulge)

model = af.Collection(galaxies=af.Collection(galaxy=galaxy))

print(model.info)

analysis = ag.AnalysisInterferometer(dataset=dataset)

from autofit.non_linear.fitness import Fitness

fitness = Fitness(
model=model,
analysis=analysis,
fom_is_log_likelihood=True,
resample_figure_of_merit=-1.0e99,
)

param_vector = jnp.array(model.physical_values_from_prior_medians)

# Perturb ell_comps away from (0,0) to avoid degenerate gradients at the
# circular-profile singularity (arctan2 gradient is undefined at exactly (0,0)).
key = jax.random.PRNGKey(0)
perturbation = jax.random.uniform(
key, shape=param_vector.shape, minval=0.01, maxval=0.05
)
param_vector = param_vector + perturbation

value, grad = jax.value_and_grad(fitness.call)(param_vector)

print(f"Log likelihood = {float(value):.6f}")
print(f"Gradient shape = {grad.shape}")
print(f"Gradient = {np.array(grad)}")

assert np.isfinite(float(value)), "Log likelihood is not finite"
assert grad.shape == (
model.total_free_parameters,
), f"Gradient shape mismatch: {grad.shape}"
assert np.all(
np.isfinite(np.array(grad))
), f"Gradient contains non-finite values: {np.array(grad)}"
assert not np.all(np.array(grad) == 0.0), "Gradient is all zeros"

print("jax_grad/interferometer/mge.py JAX gradient checks passed.")
2 changes: 2 additions & 0 deletions smoke_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ jax_likelihood_functions/imaging/delaunay.py
# jax_likelihood_functions/imaging/delaunay_mge.py # disabled: jax 0.7 removed jax.interpreters.xla.pytype_aval_mappings — see PyAutoPrompt/autobuild/smoke_workspace_fixes.md
jax_grad/imaging/lp.py
jax_grad/imaging/mge.py
jax_grad/interferometer/lp.py
jax_grad/interferometer/mge.py
jax_likelihood_functions/interferometer/lp.py
jax_likelihood_functions/interferometer/mge.py
jax_likelihood_functions/interferometer/mge_group.py
Expand Down
Loading