@@ -210,10 +210,23 @@ def _safe_compute(xx):
210210 def batched_compute_latent (x ):
211211 return np .array ([_safe_compute (xx ) for xx in x ])
212212
213+ from autoconf .test_mode import inject_latent_nans
214+
213215 parameter_array = np .array (samples .parameter_lists )
214- latent_samples = []
215216
216- # process in batches
217+ # Compute every batch first and accumulate the raw, UN-masked latent
218+ # values into one (n_samples, n_latents) array. Masking is then done
219+ # ONCE, globally, after the loop (see below).
220+ #
221+ # Doing the finite mask per batch (the previous behaviour) was a bug: on
222+ # the JAX path, a latent that went NaN for a single sample in one batch had
223+ # its whole column dropped *for that batch only*, while other batches kept it.
224+ # The resulting `Sample` objects then carried inconsistent kwargs key
225+ # sets, and `Samples.summary()` raised `KeyError` building its model
226+ # from the first sample's keys. Masking globally guarantees every
227+ # retained sample shares one identical key set.
228+ all_values = []
229+ all_samples = []
217230 for i in range (0 , len (parameter_array ), batch_size ):
218231
219232 batch = parameter_array [i :i + batch_size ]
@@ -225,36 +238,59 @@ def batched_compute_latent(x):
225238 if self ._use_jax :
226239 import jax .numpy as jnp
227240 latent_values_batch = jnp .stack (latent_values_batch , axis = - 1 ) # (batch, n_latents)
228- mask = jnp .all (jnp .isfinite (latent_values_batch ), axis = 0 )
229- latent_values_batch = latent_values_batch [:, mask ]
230- else :
231- # Drop samples whose latent computation failed (e.g. FitException from
232- # model assertions surfaced as a NaN row in _safe_compute). This leaves
233- # the per-latent column mask to continue handling degenerate latent
234- # dimensions that produce NaN for all remaining samples.
235- row_mask = np .all (np .isfinite (latent_values_batch ), axis = 1 )
236- latent_values_batch = latent_values_batch [row_mask ]
237- batch_samples = [s for s , keep in zip (batch_samples , row_mask ) if keep ]
238-
239- if len (latent_values_batch ):
240- col_mask = np .all (np .isfinite (latent_values_batch ), axis = 0 )
241- latent_values_batch = latent_values_batch [:, col_mask ]
242-
243- for sample , values in zip (batch_samples , latent_values_batch ):
244-
245- kwargs = {k : float (v ) for k , v in zip (self .LATENT_KEYS , values )}
246-
247- latent_samples .append (
248- Sample (
249- log_likelihood = sample .log_likelihood ,
250- log_prior = sample .log_prior ,
251- weight = sample .weight ,
252- kwargs = kwargs ,
253- )
254- )
241+
242+ # Unify to NumPy so the global masking below is a single code path
243+ # for both backends (latent values are scalars, host transfer is
244+ # cheap and was already forced by the downstream `float(v)`).
245+ latent_values_batch = np .asarray (latent_values_batch )
246+
247+ # Test-only NaN injection (no-op unless PYAUTO_LATENT_NAN_INJECT set).
248+ latent_values_batch = inject_latent_nans (latent_values_batch , start_index = i )
249+
250+ all_values .append (latent_values_batch )
251+ all_samples .extend (batch_samples )
252+
253+ if all_values :
254+ all_values = np .concatenate (all_values , axis = 0 )
255+ else :
256+ all_values = np .empty ((0 , len (self .LATENT_KEYS )))
257+
258+ # Global masking, in two stages:
259+ # 1. Drop a latent column only if it is non-finite for EVERY sample
260+ # (a genuinely degenerate latent, e.g. a µJy latent with no magzero).
261+ col_mask = np .any (np .isfinite (all_values ), axis = 0 )
262+ kept_keys = [k for k , keep in zip (self .LATENT_KEYS , col_mask ) if keep ]
263+ kept_values = all_values [:, col_mask ]
264+
265+ # 2. Drop individual samples that still carry a NaN in a surviving
266+ # latent (e.g. a FitException NaN row, or a latent that went NaN
267+ # for just that sample). Every survivor now has all `kept_keys`.
268+ if kept_values .size :
269+ row_mask = np .all (np .isfinite (kept_values ), axis = 1 )
270+ else :
271+ row_mask = np .zeros (len (all_samples ), dtype = bool )
272+ kept_values = kept_values [row_mask ]
273+ kept_samples = [s for s , keep in zip (all_samples , row_mask ) if keep ]
255274
256275 print (f"Time to compute latent variables: { time .time () - start_latent } seconds for { len (samples )} samples." )
257276
277+ if not kept_keys or len (kept_samples ) == 0 :
278+ logger .warning (
279+ "compute_latent_samples: no finite latent samples remained "
280+ "after masking; skipping latent output."
281+ )
282+ return None
283+
284+ latent_samples = [
285+ Sample (
286+ log_likelihood = sample .log_likelihood ,
287+ log_prior = sample .log_prior ,
288+ weight = sample .weight ,
289+ kwargs = {k : float (v ) for k , v in zip (kept_keys , values )},
290+ )
291+ for sample , values in zip (kept_samples , kept_values )
292+ ]
293+
258294 return type (samples )(
259295 sample_list = latent_samples ,
260296 model = simple_model_for_kwargs (latent_samples [0 ].kwargs ),
0 commit comments