Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions autofit/non_linear/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
from typing import Optional, Dict

from autofit import exc
from autofit.mapper.prior_model.abstract import AbstractPriorModel
from autofit.non_linear.paths.abstract import AbstractPaths
from autofit.non_linear.samples.summary import SamplesSummary
Expand Down Expand Up @@ -173,8 +174,17 @@ def compute_latent_samples(self, samples: Samples, batch_size : Optional[int] =
batched_compute_latent = jax.jit(jax.vmap(compute_latent_for_model))
logger.info(f"JAX: vmap and jit applied in {time.time() - start} seconds.")
else:
n_latents = len(self.LATENT_KEYS)
nan_row = np.full(n_latents, np.nan)

def _safe_compute(xx):
try:
return compute_latent_for_model(xx)
except exc.FitException:
return nan_row

def batched_compute_latent(x):
return np.array([compute_latent_for_model(xx) for xx in x])
return np.array([_safe_compute(xx) for xx in x])

parameter_array = np.array(samples.parameter_lists)
latent_samples = []
Expand All @@ -183,6 +193,7 @@ def batched_compute_latent(x):
for i in range(0, len(parameter_array), batch_size):

batch = parameter_array[i:i + batch_size]
batch_samples = samples.sample_list[i:i + batch_size]

# batched JAX call on this chunk
latent_values_batch = batched_compute_latent(batch)
Expand All @@ -193,10 +204,19 @@ def batched_compute_latent(x):
mask = jnp.all(jnp.isfinite(latent_values_batch), axis=0)
latent_values_batch = latent_values_batch[:, mask]
else:
mask = np.all(np.isfinite(latent_values_batch), axis=0)
latent_values_batch = latent_values_batch[:, mask]

for sample, values in zip(samples.sample_list[i:i + batch_size], latent_values_batch):
# Drop samples whose latent computation failed (e.g. FitException from
# model assertions surfaced as a NaN row in _safe_compute). This leaves
# the per-latent column mask to continue handling degenerate latent
# dimensions that produce NaN for all remaining samples.
row_mask = np.all(np.isfinite(latent_values_batch), axis=1)
latent_values_batch = latent_values_batch[row_mask]
batch_samples = [s for s, keep in zip(batch_samples, row_mask) if keep]

if len(latent_values_batch):
col_mask = np.all(np.isfinite(latent_values_batch), axis=0)
latent_values_batch = latent_values_batch[:, col_mask]

for sample, values in zip(batch_samples, latent_values_batch):

kwargs = {k: float(v) for k, v in zip(self.LATENT_KEYS, values)}

Expand Down
39 changes: 39 additions & 0 deletions test_autofit/analysis/test_latent_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,45 @@ def test_compute_latent_samples(latent_samples):
assert latent_samples.model.instance_from_vector([1.0]).fwhm == 1.0


class AssertionAnalysis(af.Analysis):
    """
    Test double whose latent-variable computation raises a ``FitException``
    for any parameter vector with a negative first entry, mimicking a model
    assertion violation during latent-sample computation.
    """

    # Single latent quantity exposed by this analysis.
    LATENT_KEYS = ["fwhm"]

    def log_likelihood_function(self, instance):
        # Constant likelihood; only latent computation matters for these tests.
        return 1.0

    def compute_latent_variables(self, parameters, model):
        # A negative leading parameter stands in for a violated model assertion.
        if not parameters[0] < 0:
            instance = model.instance_from_vector(vector=parameters)
            return (instance.fwhm,)
        raise af.exc.FitException("assertion violated")


def test_compute_latent_samples_skips_fit_exception_samples():
    """
    A sample whose latent computation raises ``FitException`` (here: negative
    centre) is dropped, while the valid sample's latent kwargs are kept.
    """
    # (log_likelihood, weight, centre) per sample; the negative centre
    # triggers the FitException inside AssertionAnalysis.
    specs = [
        (1.0, 1.0, 1.0),
        (-1.0, 0.0, -1.0),
    ]
    sample_list = [
        af.Sample(
            log_likelihood=log_likelihood,
            log_prior=0.0,
            weight=weight,
            kwargs={"centre": centre, "normalization": 2.0, "sigma": 3.0},
        )
        for log_likelihood, weight, centre in specs
    ]
    samples = SamplesPDF(
        model=af.Model(af.ex.Gaussian),
        sample_list=sample_list,
    )

    latent_samples = AssertionAnalysis().compute_latent_samples(samples)

    # Only the valid sample survives, carrying its computed fwhm.
    assert len(latent_samples.sample_list) == 1
    assert latent_samples.sample_list[0].kwargs == {"fwhm": 7.0644601350928475}


def test_info(latent_samples):
info = result_info_from(latent_samples)
assert (
Expand Down
Loading