diff --git a/promptlens/utils/retry.py b/promptlens/utils/retry.py index 7834517..06e7fea 100644 --- a/promptlens/utils/retry.py +++ b/promptlens/utils/retry.py @@ -33,6 +33,18 @@ async def retry_with_exponential_backoff( Raises: Exception: The last exception if all retries fail """ + if max_attempts < 1: + raise ValueError("max_attempts must be >= 1") + + if initial_delay < 0: + raise ValueError("initial_delay must be >= 0") + + if backoff_factor <= 0: + raise ValueError("backoff_factor must be > 0") + + if max_delay <= 0: + raise ValueError("max_delay must be > 0") + delay = initial_delay last_exception = None diff --git a/tests/test_retry_hardening.py b/tests/test_retry_hardening.py new file mode 100644 index 0000000..2d20729 --- /dev/null +++ b/tests/test_retry_hardening.py @@ -0,0 +1,39 @@ +import pytest + +from promptlens.utils.retry import retry_with_exponential_backoff + + +@pytest.mark.asyncio +async def test_retry_rejects_non_positive_attempts() -> None: + async def always_fails() -> str: + raise RuntimeError("boom") + + with pytest.raises(ValueError, match="max_attempts"): + await retry_with_exponential_backoff(always_fails, max_attempts=0) + + +@pytest.mark.asyncio +async def test_retry_rejects_negative_initial_delay() -> None: + async def always_fails() -> str: + raise RuntimeError("boom") + + with pytest.raises(ValueError, match="initial_delay"): + await retry_with_exponential_backoff(always_fails, initial_delay=-0.1) + + +@pytest.mark.asyncio +async def test_retry_rejects_invalid_backoff_factor() -> None: + async def always_fails() -> str: + raise RuntimeError("boom") + + with pytest.raises(ValueError, match="backoff_factor"): + await retry_with_exponential_backoff(always_fails, backoff_factor=0) + + +@pytest.mark.asyncio +async def test_retry_rejects_non_positive_max_delay() -> None: + async def always_fails() -> str: + raise RuntimeError("boom") + + with pytest.raises(ValueError, match="max_delay"): + await retry_with_exponential_backoff(always_fails, max_delay=0)