Skip to content

Commit 4ce3ba6

Browse files
Jammy2211claude
authored andcommitted
fix(latent): global masking in compute_latent_samples to prevent KeyError
Latent finite-masking was computed per batch on the JAX path, so a latent that went NaN for one sample in a batch had its column dropped for that batch only. Samples then carried inconsistent kwargs key sets and Samples.summary() raised KeyError. Accumulate all batches and mask once globally (col-then-row); return None when nothing finite remains. Depends on autoconf.test_mode hook. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 682d673 commit 4ce3ba6

1 file changed

Lines changed: 65 additions & 29 deletions

File tree

autofit/non_linear/analysis/analysis.py

Lines changed: 65 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,23 @@ def _safe_compute(xx):
210210
def batched_compute_latent(x):
211211
return np.array([_safe_compute(xx) for xx in x])
212212

213+
from autoconf.test_mode import inject_latent_nans
214+
213215
parameter_array = np.array(samples.parameter_lists)
214-
latent_samples = []
215216

216-
# process in batches
217+
# Compute every batch first and accumulate the raw, UN-masked latent
218+
# values into one (n_samples, n_latents) array. Masking is then done
219+
# ONCE, globally, after the loop (see below).
220+
#
221+
# Doing the finite mask per batch (the previous behaviour) was a bug:
222+
# a latent that went NaN for a single sample in one batch had its whole
223+
# column dropped *for that batch only*, while other batches kept it.
224+
# The resulting `Sample` objects then carried inconsistent kwargs key
225+
# sets, and `Samples.summary()` raised `KeyError` building its model
226+
# from the first sample's keys. Masking globally guarantees every
227+
# retained sample shares one identical key set.
228+
all_values = []
229+
all_samples = []
217230
for i in range(0, len(parameter_array), batch_size):
218231

219232
batch = parameter_array[i:i + batch_size]
@@ -225,36 +238,59 @@ def batched_compute_latent(x):
225238
if self._use_jax:
226239
import jax.numpy as jnp
227240
latent_values_batch = jnp.stack(latent_values_batch, axis=-1) # (batch, n_latents)
228-
mask = jnp.all(jnp.isfinite(latent_values_batch), axis=0)
229-
latent_values_batch = latent_values_batch[:, mask]
230-
else:
231-
# Drop samples whose latent computation failed (e.g. FitException from
232-
# model assertions surfaced as a NaN row in _safe_compute). This leaves
233-
# the per-latent column mask to continue handling degenerate latent
234-
# dimensions that produce NaN for all remaining samples.
235-
row_mask = np.all(np.isfinite(latent_values_batch), axis=1)
236-
latent_values_batch = latent_values_batch[row_mask]
237-
batch_samples = [s for s, keep in zip(batch_samples, row_mask) if keep]
238-
239-
if len(latent_values_batch):
240-
col_mask = np.all(np.isfinite(latent_values_batch), axis=0)
241-
latent_values_batch = latent_values_batch[:, col_mask]
242-
243-
for sample, values in zip(batch_samples, latent_values_batch):
244-
245-
kwargs = {k: float(v) for k, v in zip(self.LATENT_KEYS, values)}
246-
247-
latent_samples.append(
248-
Sample(
249-
log_likelihood=sample.log_likelihood,
250-
log_prior=sample.log_prior,
251-
weight=sample.weight,
252-
kwargs=kwargs,
253-
)
254-
)
241+
242+
# Unify to NumPy so the global masking below is a single code path
243+
# for both backends (latent values are scalars, host transfer is
244+
# cheap and was already forced by the downstream `float(v)`).
245+
latent_values_batch = np.asarray(latent_values_batch)
246+
247+
# Test-only NaN injection (no-op unless PYAUTO_LATENT_NAN_INJECT set).
248+
latent_values_batch = inject_latent_nans(latent_values_batch, start_index=i)
249+
250+
all_values.append(latent_values_batch)
251+
all_samples.extend(batch_samples)
252+
253+
if all_values:
254+
all_values = np.concatenate(all_values, axis=0)
255+
else:
256+
all_values = np.empty((0, len(self.LATENT_KEYS)))
257+
258+
# Global masking, in two stages:
259+
# 1. Drop a latent column only if it is non-finite for EVERY sample
260+
# (a genuinely degenerate latent, e.g. a µJy latent with no magzero).
261+
col_mask = np.any(np.isfinite(all_values), axis=0)
262+
kept_keys = [k for k, keep in zip(self.LATENT_KEYS, col_mask) if keep]
263+
kept_values = all_values[:, col_mask]
264+
265+
# 2. Drop individual samples that still carry a NaN in a surviving
266+
# latent (e.g. a FitException NaN row, or a latent that went NaN
267+
# for just that sample). Every survivor now has all `kept_keys`.
268+
if kept_values.size:
269+
row_mask = np.all(np.isfinite(kept_values), axis=1)
270+
else:
271+
row_mask = np.zeros(len(all_samples), dtype=bool)
272+
kept_values = kept_values[row_mask]
273+
kept_samples = [s for s, keep in zip(all_samples, row_mask) if keep]
255274

256275
print(f"Time to compute latent variables: {time.time() - start_latent} seconds for {len(samples)} samples.")
257276

277+
if not kept_keys or len(kept_samples) == 0:
278+
logger.warning(
279+
"compute_latent_samples: no finite latent samples remained "
280+
"after masking; skipping latent output."
281+
)
282+
return None
283+
284+
latent_samples = [
285+
Sample(
286+
log_likelihood=sample.log_likelihood,
287+
log_prior=sample.log_prior,
288+
weight=sample.weight,
289+
kwargs={k: float(v) for k, v in zip(kept_keys, values)},
290+
)
291+
for sample, values in zip(kept_samples, kept_values)
292+
]
293+
258294
return type(samples)(
259295
sample_list=latent_samples,
260296
model=simple_model_for_kwargs(latent_samples[0].kwargs),

0 commit comments

Comments
 (0)