diff --git a/src/tinker/lib/_jwt_auth.py b/src/tinker/lib/_jwt_auth.py index 0bf8c3b..2252789 100644 --- a/src/tinker/lib/_jwt_auth.py +++ b/src/tinker/lib/_jwt_auth.py @@ -4,8 +4,11 @@ When the server sets pjwt_auth_enabled, the SDK exchanges the caller's credential for a short-lived JWT minted by the Tinker server. The JWT is -cached and refreshed in the background before it expires, so callers always -send a valid token without any per-request overhead. +cached and refreshed in the background before it expires. As a safety +net, get_token() also refreshes on demand if the cached token is near or +past expiry — so a delayed/failed background refresh cannot leave callers +sending a stale JWT (which the server rejects with 401 Invalid JWT, and +401 is not retried by the request layer). """ from __future__ import annotations @@ -22,8 +25,9 @@ logger = logging.getLogger(__name__) -_REFRESH_BEFORE_EXPIRY_SECS = 300 # refresh 5 min before expiry -_RETRY_DELAY_SECS = 60 +_REFRESH_BEFORE_EXPIRY_SECS = 300 # background loop refreshes 5 min before expiry +_REFRESH_ON_DEMAND_SECS = 60 # get_token() refreshes if <= this many seconds left +_RETRY_DELAY_SECS = 60 # backoff after a failed refresh def _jwt_expiry(jwt: str) -> float: @@ -36,11 +40,22 @@ def _jwt_expiry(jwt: str) -> float: raise ValueError(f"Failed to parse JWT expiry: {e}") from e +def _seconds_until_expiry(jwt: str) -> float: + """Seconds until the JWT expires; 0 if expiry can't be parsed.""" + try: + return _jwt_expiry(jwt) - time.time() + except ValueError: + return 0.0 + + class JwtAuthProvider(AuthTokenProvider): """AuthTokenProvider that exchanges a credential for a short-lived JWT. After init(), get_token() returns the current JWT. A background task - refreshes the JWT before it expires. + proactively refreshes the JWT before it expires. get_token() also + refreshes on demand if the cached token is near or past expiry, so a + stuck or delayed background refresh cannot leak a stale JWT into a + request. """ def __init__( @@ -48,11 +63,28 @@ def __init__( aclient_fn: Callable[[], AbstractContextManager], seed_token: str | None = None, ) -> None: - self._token = seed_token or "" + self._token: str = seed_token or "" self._aclient_fn = aclient_fn + self._refresh_lock = asyncio.Lock() async def get_token(self) -> str | None: - return self._token + # Fast path: cached token has comfortable runway. + if self._token and _seconds_until_expiry(self._token) > _REFRESH_ON_DEMAND_SECS: + return self._token + + async with self._refresh_lock: + # Re-check after acquiring the lock — another caller may have + # just refreshed the token while we were waiting. + if self._token and _seconds_until_expiry(self._token) > _REFRESH_ON_DEMAND_SECS: + return self._token + try: + return await self._fetch() + except Exception as e: + # If the refresh fails, fall back to whatever we have. + # The background loop keeps trying; if the server is + # genuinely down the request will surface the error. + logger.warning("On-demand JWT refresh failed: %s", e) + return self._token or None async def init(self) -> None: """Fetch a JWT (unless seeded) then start the background refresh loop. @@ -68,23 +100,34 @@ async def _fetch(self) -> str: """Exchange the current credential for a JWT via /api/v1/auth/token.""" with self._aclient_fn() as client: response = await client.service.auth_token() - self._token = response.jwt - return response.jwt + jwt: str = response.jwt + self._token = jwt + return jwt async def _refresh_loop(self, token: str) -> None: while True: try: delay = max( - _RETRY_DELAY_SECS, + 0.0, _jwt_expiry(token) - time.time() - _REFRESH_BEFORE_EXPIRY_SECS, ) except ValueError: logger.debug("Failed to parse JWT expiry, retrying in %ds", _RETRY_DELAY_SECS) - delay = _RETRY_DELAY_SECS + delay = float(_RETRY_DELAY_SECS) try: await asyncio.sleep(delay) - token = await self._fetch() + # Coordinate with on-demand refreshes in get_token() so we + # don't fire two concurrent /auth/token requests. + async with self._refresh_lock: + token = await self._fetch() except asyncio.CancelledError: return except Exception as e: logger.debug("JWT refresh failed, retrying in %ds: %s", _RETRY_DELAY_SECS, e) + # Explicit backoff: without the old max(60, ...) floor on + # `delay`, a stale token would otherwise compute delay=0 + # next iteration and tight-loop on persistent failures. + try: + await asyncio.sleep(_RETRY_DELAY_SECS) + except asyncio.CancelledError: + return diff --git a/src/tinker/lib/_jwt_auth_test.py b/src/tinker/lib/_jwt_auth_test.py index e26f155..0c28aa2 100644 --- a/src/tinker/lib/_jwt_auth_test.py +++ b/src/tinker/lib/_jwt_auth_test.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import base64 import json import time @@ -18,6 +19,7 @@ from tinker.lib._jwt_auth import ( JwtAuthProvider, _jwt_expiry, + _seconds_until_expiry, ) # --------------------------------------------------------------------------- @@ -154,3 +156,143 @@ async def test_fetch_returns_and_stores_token(): assert result == jwt assert await provider.get_token() == jwt + + +# --------------------------------------------------------------------------- +# _seconds_until_expiry +# --------------------------------------------------------------------------- + + +def test_seconds_until_expiry_returns_remaining_for_valid_jwt(): + exp = time.time() + 3600 + assert abs(_seconds_until_expiry(_make_jwt(exp)) - 3600) < 1 + + +def test_seconds_until_expiry_returns_zero_for_unparseable_jwt(): + assert _seconds_until_expiry("not.a.jwt") == 0.0 + + +def test_seconds_until_expiry_returns_negative_for_expired_jwt(): + exp = time.time() - 60 + assert _seconds_until_expiry(_make_jwt(exp)) < 0 + + +# --------------------------------------------------------------------------- +# JwtAuthProvider.get_token on-demand refresh +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_token_returns_cached_when_fresh(): + """Cached token with comfortable runway is returned without refetching.""" + fresh_jwt = _make_jwt(time.time() + 7200) + holder = _MockHolder("should-not-be-fetched") + provider = JwtAuthProvider(holder.aclient, seed_token=fresh_jwt) + + assert await provider.get_token() == fresh_jwt + holder._cm.__enter__.return_value.service.auth_token.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_token_refreshes_when_near_expiry(): + """If cached token has <= _REFRESH_ON_DEMAND_SECS left, fetch a new one.""" + near_expiry_jwt = _make_jwt(time.time() + 30) # 30s left, under threshold + refreshed_jwt = _make_jwt(time.time() + 7200) + holder = _MockHolder(refreshed_jwt) + provider = JwtAuthProvider(holder.aclient, seed_token=near_expiry_jwt) + + assert await provider.get_token() == refreshed_jwt + holder._cm.__enter__.return_value.service.auth_token.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_token_refreshes_when_already_expired(): + """An expired cached token must trigger refresh, not be served as-is.""" + expired_jwt = _make_jwt(time.time() - 30) + refreshed_jwt = _make_jwt(time.time() + 7200) + holder = _MockHolder(refreshed_jwt) + provider = JwtAuthProvider(holder.aclient, seed_token=expired_jwt) + + assert await provider.get_token() == refreshed_jwt + holder._cm.__enter__.return_value.service.auth_token.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_token_refreshes_when_cached_token_is_unparseable(): + """A garbled cached token (e.g. corrupt seed) is treated as expired.""" + refreshed_jwt = _make_jwt(time.time() + 7200) + holder = _MockHolder(refreshed_jwt) + provider = JwtAuthProvider(holder.aclient, seed_token="not.a.jwt") + + assert await provider.get_token() == refreshed_jwt + holder._cm.__enter__.return_value.service.auth_token.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_token_concurrent_refresh_only_fires_once(): + """Many concurrent get_token() calls share a single in-flight refresh.""" + near_expiry_jwt = _make_jwt(time.time() + 30) + refreshed_jwt = _make_jwt(time.time() + 7200) + + fetch_started = asyncio.Event() + fetch_release = asyncio.Event() + fetch_count = 0 + + async def slow_auth_token(): + nonlocal fetch_count + fetch_count += 1 + fetch_started.set() + await fetch_release.wait() + return _MockAuthResponse(refreshed_jwt) + + service = MagicMock() + service.auth_token = slow_auth_token + client = MagicMock() + client.service = service + cm = MagicMock() + cm.__enter__ = MagicMock(return_value=client) + cm.__exit__ = MagicMock(return_value=None) + + provider = JwtAuthProvider(lambda: cm, seed_token=near_expiry_jwt) + + tasks = [asyncio.create_task(provider.get_token()) for _ in range(5)] + + # Wait for the first task to enter the fetch, then let the others queue + # up at the lock before releasing the in-flight fetch. + await fetch_started.wait() + await asyncio.sleep(0) + fetch_release.set() + results = await asyncio.gather(*tasks) + + assert fetch_count == 1 + assert all(r == refreshed_jwt for r in results) + + +@pytest.mark.asyncio +async def test_get_token_returns_stale_token_when_refresh_fails( + caplog: pytest.LogCaptureFixture, +): + """If on-demand refresh fails, return cached token + log a warning. + + Better than raising — the request will surface its own error if the + token really is rejected, and other in-flight requests sharing this + provider can still make progress on transient refresh failures. + """ + near_expiry_jwt = _make_jwt(time.time() + 30) + holder = _MockHolder("unused", fail=True) + provider = JwtAuthProvider(holder.aclient, seed_token=near_expiry_jwt) + + with caplog.at_level("WARNING"): + result = await provider.get_token() + + assert result == near_expiry_jwt + assert "On-demand JWT refresh failed" in caplog.text + + +@pytest.mark.asyncio +async def test_get_token_returns_none_when_no_token_and_refresh_fails(): + """No cached token + failed refresh => return None (no header sent).""" + holder = _MockHolder("unused", fail=True) + provider = JwtAuthProvider(holder.aclient, seed_token=None) + + assert await provider.get_token() is None