diff --git a/mtplx/sampling.py b/mtplx/sampling.py index d71af75..9161f12 100644 --- a/mtplx/sampling.py +++ b/mtplx/sampling.py @@ -17,7 +17,6 @@ class SamplerConfig: top_p: float = 0.95 top_k: int = 20 - @dataclass(frozen=True) class SparseDistribution: token_ids: np.ndarray @@ -37,7 +36,12 @@ def __post_init__(self): raise ValueError("SparseDistribution probabilities must be non-negative") total = probs.sum() if not np.isfinite(total) or total <= 0: - raise ValueError("SparseDistribution probabilities must have positive mass") + if token_ids.size > 0: + token_ids = token_ids[:1] + else: + token_ids = np.array([0], dtype=np.int64) + probs = np.array([1.0], dtype=np.float64) + total = 1.0 object.__setattr__(self, "token_ids", token_ids) object.__setattr__(self, "probs", probs / total) @@ -158,7 +162,7 @@ def residual_distribution(target_p: Distribution, draft_q: Distribution) -> Dist ) keep = residual > 0 total = residual[keep].sum() - if total <= 0: + if total <= 0 or not np.isfinite(total): return target_p return SparseDistribution(token_ids[keep], residual[keep] / total, _vocab_size(target_p)) @@ -166,19 +170,19 @@ def residual_distribution(target_p: Distribution, draft_q: Distribution) -> Dist dense_draft = _as_dense(draft_q) residual = np.maximum(dense_target - dense_draft, 0.0) total = residual.sum() - if total <= 0: + if total <= 0 or not np.isfinite(total): residual = dense_target.copy() total = residual.sum() - if total <= 0: + if total <= 0 or not np.isfinite(total): raise ValueError("Cannot build residual distribution from empty target") return residual / total residual = np.maximum(np.asarray(target_p) - np.asarray(draft_q), 0.0) total = residual.sum() - if total <= 0: + if total <= 0 or not np.isfinite(total): residual = np.asarray(target_p, dtype=np.float64).copy() total = residual.sum() - if total <= 0: + if total <= 0 or not np.isfinite(total): raise ValueError("Cannot build residual distribution from empty target") return residual / total diff --git a/tests/test_sampling.py b/tests/test_sampling.py index f9bb72a..fb2ebe1 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -76,3 +76,44 @@ def test_sparse_distribution_sampling_returns_original_token_ids(): vocab_size=12, ) assert sample_from_distribution(dist, np.random.default_rng(0)) == 11 + + +def test_sparse_distribution_fallback_on_zero_or_nan_mass(): + # Test fallback in constructor when probs has zero sum + dist_zero = SparseDistribution( + token_ids=np.array([1, 2]), + probs=np.array([0.0, 0.0]), + vocab_size=5, + ) + assert dist_zero.probs.sum() == 1.0 + assert dist_zero.token_ids.tolist() == [1] + + # Test fallback in constructor when probs has nan value + dist_nan = SparseDistribution( + token_ids=np.array([3, 4]), + probs=np.array([np.nan, 1.0]), + vocab_size=5, + ) + assert dist_nan.probs.sum() == 1.0 + assert dist_nan.token_ids.tolist() == [3] + + +def test_residual_distribution_nan_robustness(): + # Test residual_distribution robustness against NaN values + target = SparseDistribution( + token_ids=np.array([1, 2]), + probs=np.array([0.2, 0.8]), + vocab_size=5, + ) + draft = SparseDistribution( + token_ids=np.array([1, 2]), + probs=np.array([0.8, 0.2]), + vocab_size=5, + ) + # Mock probability to return NaN for one token + from unittest.mock import patch + original_probability = SparseDistribution.probability + with patch.object(SparseDistribution, "probability", lambda self, token_id: np.nan if token_id == 1 else original_probability(self, token_id)): + res = residual_distribution(target, draft) + # Should gracefully return target because NaN is checked/handled + assert isinstance(res, SparseDistribution) diff --git a/uv.lock b/uv.lock index 085a3ca..19e368c 100644 --- a/uv.lock +++ b/uv.lock @@ -678,7 +678,7 @@ wheels = [ [[package]] name = "mtplx" -version = "0.3.6" +version = "0.3.7" source = { editable = "." } dependencies = [ { name = "fastapi" },