Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions mtplx/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class SamplerConfig:
top_p: float = 0.95
top_k: int = 20


@dataclass(frozen=True)
class SparseDistribution:
token_ids: np.ndarray
Expand All @@ -37,7 +36,12 @@ def __post_init__(self):
raise ValueError("SparseDistribution probabilities must be non-negative")
total = probs.sum()
if not np.isfinite(total) or total <= 0:
raise ValueError("SparseDistribution probabilities must have positive mass")
if token_ids.size > 0:
token_ids = token_ids[:1]
else:
token_ids = np.array([0], dtype=np.int64)
probs = np.array([1.0], dtype=np.float64)
total = 1.0
object.__setattr__(self, "token_ids", token_ids)
object.__setattr__(self, "probs", probs / total)

Expand Down Expand Up @@ -158,27 +162,27 @@ def residual_distribution(target_p: Distribution, draft_q: Distribution) -> Dist
)
keep = residual > 0
total = residual[keep].sum()
if total <= 0:
if total <= 0 or not np.isfinite(total):
return target_p
return SparseDistribution(token_ids[keep], residual[keep] / total, _vocab_size(target_p))

dense_target = _as_dense(target_p)
dense_draft = _as_dense(draft_q)
residual = np.maximum(dense_target - dense_draft, 0.0)
total = residual.sum()
if total <= 0:
if total <= 0 or not np.isfinite(total):
residual = dense_target.copy()
total = residual.sum()
if total <= 0:
if total <= 0 or not np.isfinite(total):
raise ValueError("Cannot build residual distribution from empty target")
return residual / total

residual = np.maximum(np.asarray(target_p) - np.asarray(draft_q), 0.0)
total = residual.sum()
if total <= 0:
if total <= 0 or not np.isfinite(total):
residual = np.asarray(target_p, dtype=np.float64).copy()
total = residual.sum()
if total <= 0:
if total <= 0 or not np.isfinite(total):
raise ValueError("Cannot build residual distribution from empty target")
return residual / total

Expand Down
41 changes: 41 additions & 0 deletions tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,44 @@ def test_sparse_distribution_sampling_returns_original_token_ids():
vocab_size=12,
)
assert sample_from_distribution(dist, np.random.default_rng(0)) == 11


def test_sparse_distribution_fallback_on_zero_or_nan_mass():
    """Check the constructor fallback for inputs without usable mass.

    Two inputs are exercised: an all-zero probability vector and a vector
    containing NaN. In both cases the constructed distribution is expected
    to sum to exactly one and to keep only the first supplied token id.
    """
    degenerate_cases = (
        # (token ids, raw probabilities, token ids expected to survive)
        ([1, 2], [0.0, 0.0], [1]),  # probabilities sum to zero
        ([3, 4], [np.nan, 1.0], [3]),  # a NaN entry makes the sum NaN
    )
    for ids, raw_probs, expected_ids in degenerate_cases:
        dist = SparseDistribution(
            token_ids=np.array(ids),
            probs=np.array(raw_probs),
            vocab_size=5,
        )
        assert dist.probs.sum() == 1.0
        assert dist.token_ids.tolist() == expected_ids


def test_residual_distribution_nan_robustness():
    """Check residual_distribution yields a SparseDistribution under NaN lookups.

    ``SparseDistribution.probability`` is patched at class level so that a
    lookup for token id 1 reports NaN on every instance; all other ids are
    delegated to the real method. Only the type of the result is asserted.
    """
    from unittest.mock import patch

    p_target = SparseDistribution(
        token_ids=np.array([1, 2]),
        probs=np.array([0.2, 0.8]),
        vocab_size=5,
    )
    q_draft = SparseDistribution(
        token_ids=np.array([1, 2]),
        probs=np.array([0.8, 0.2]),
        vocab_size=5,
    )

    # Capture the unpatched method first so the stub can defer to it.
    real_probability = SparseDistribution.probability

    def nan_for_token_one(self, token_id):
        """Return NaN for token id 1, otherwise the genuine probability."""
        if token_id == 1:
            return np.nan
        return real_probability(self, token_id)

    # NOTE(review): whether residual_distribution actually calls
    # `probability` is not visible from this test -- confirm the patch
    # really drives the NaN path.
    with patch.object(SparseDistribution, "probability", nan_for_token_one):
        outcome = residual_distribution(p_target, q_draft)
        # The call must complete without raising and hand back a sparse result.
        assert isinstance(outcome, SparseDistribution)
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.