Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 55 additions & 12 deletions src/tinker/lib/_jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

When the server sets pjwt_auth_enabled, the SDK exchanges the caller's
credential for a short-lived JWT minted by the Tinker server. The JWT is
cached and refreshed in the background before it expires, so callers always
send a valid token without any per-request overhead.
cached and refreshed in the background before it expires. As a safety
net, get_token() also refreshes on demand if the cached token is near or
past expiry — so a delayed/failed background refresh cannot leave callers
sending a stale JWT (which the server rejects with 401 Invalid JWT, and
401 is not retried by the request layer).
"""

from __future__ import annotations
Expand All @@ -22,8 +25,9 @@

logger = logging.getLogger(__name__)

_REFRESH_BEFORE_EXPIRY_SECS = 300 # refresh 5 min before expiry
_RETRY_DELAY_SECS = 60
_REFRESH_BEFORE_EXPIRY_SECS = 300 # background loop refreshes 5 min before expiry
_REFRESH_ON_DEMAND_SECS = 60 # get_token() refreshes if <= this many seconds left
_RETRY_DELAY_SECS = 60 # backoff after a failed refresh


def _jwt_expiry(jwt: str) -> float:
Expand All @@ -36,23 +40,51 @@ def _jwt_expiry(jwt: str) -> float:
raise ValueError(f"Failed to parse JWT expiry: {e}") from e


def _seconds_until_expiry(jwt: str) -> float:
"""Seconds until the JWT expires; 0 if expiry can't be parsed."""
try:
return _jwt_expiry(jwt) - time.time()
except ValueError:
return 0.0


class JwtAuthProvider(AuthTokenProvider):
"""AuthTokenProvider that exchanges a credential for a short-lived JWT.

After init(), get_token() returns the current JWT. A background task
refreshes the JWT before it expires.
proactively refreshes the JWT before it expires. get_token() also
refreshes on demand if the cached token is near or past expiry, so a
stuck or delayed background refresh cannot leak a stale JWT into a
request.
"""

def __init__(
self,
aclient_fn: Callable[[], AbstractContextManager],
seed_token: str | None = None,
) -> None:
self._token = seed_token or ""
self._token: str = seed_token or ""
self._aclient_fn = aclient_fn
self._refresh_lock = asyncio.Lock()

async def get_token(self) -> str | None:
return self._token
# Fast path: cached token has comfortable runway.
if self._token and _seconds_until_expiry(self._token) > _REFRESH_ON_DEMAND_SECS:
return self._token

async with self._refresh_lock:
# Re-check after acquiring the lock — another caller may have
# just refreshed the token while we were waiting.
if self._token and _seconds_until_expiry(self._token) > _REFRESH_ON_DEMAND_SECS:
return self._token
try:
return await self._fetch()
except Exception as e:
# If the refresh fails, fall back to whatever we have.
# The background loop keeps trying; if the server is
# genuinely down the request will surface the error.
logger.warning("On-demand JWT refresh failed: %s", e)
return self._token or None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the cached token is expired, unparseable, or absent, this hides a retryable /auth/token failure and sends a request that will likely fail with non-retryable 401/no auth. Only fall back when the cached JWT is still usable; otherwise re-raise so the normal request retry path can retry the auth refresh. Could do something like this:

except Exception:
    seconds_left = _seconds_until_expiry(self._token) if self._token else 0.0
    if seconds_left > 0:
        logger.warning(
            "On-demand JWT refresh failed; using cached JWT with %.1fs remaining",
            seconds_left,
            exc_info=True,
        )
        return self._token

    logger.warning("On-demand JWT refresh failed with no usable cached JWT", exc_info=True)
    raise


async def init(self) -> None:
"""Fetch a JWT (unless seeded) then start the background refresh loop.
Expand All @@ -68,23 +100,34 @@ async def _fetch(self) -> str:
"""Exchange the current credential for a JWT via /api/v1/auth/token."""
with self._aclient_fn() as client:
response = await client.service.auth_token()
self._token = response.jwt
return response.jwt
jwt: str = response.jwt
self._token = jwt
return jwt

async def _refresh_loop(self, token: str) -> None:
while True:
try:
delay = max(
_RETRY_DELAY_SECS,
0.0,
_jwt_expiry(token) - time.time() - _REFRESH_BEFORE_EXPIRY_SECS,
)
except ValueError:
logger.debug("Failed to parse JWT expiry, retrying in %ds", _RETRY_DELAY_SECS)
delay = _RETRY_DELAY_SECS
delay = float(_RETRY_DELAY_SECS)
try:
await asyncio.sleep(delay)
token = await self._fetch()
# Coordinate with on-demand refreshes in get_token() so we
# don't fire two concurrent /auth/token requests.
async with self._refresh_lock:
token = await self._fetch()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:
your usage above of

async with self._refresh_lock:
    if _seconds_until_expiry(self._token) > _REFRESH_BEFORE_EXPIRY_SECS:
        token = self._token
    else:
        token = await self._fetch()

is slightly better. would prefer this

except asyncio.CancelledError:
return
except Exception as e:
logger.debug("JWT refresh failed, retrying in %ds: %s", _RETRY_DELAY_SECS, e)
# Explicit backoff: without the old max(60, ...) floor on
# `delay`, a stale token would otherwise compute delay=0
# next iteration and tight-loop on persistent failures.
try:
await asyncio.sleep(_RETRY_DELAY_SECS)
except asyncio.CancelledError:
return
142 changes: 142 additions & 0 deletions src/tinker/lib/_jwt_auth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import base64
import json
import time
Expand All @@ -18,6 +19,7 @@
from tinker.lib._jwt_auth import (
JwtAuthProvider,
_jwt_expiry,
_seconds_until_expiry,
)

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -154,3 +156,143 @@ async def test_fetch_returns_and_stores_token():

assert result == jwt
assert await provider.get_token() == jwt


# ---------------------------------------------------------------------------
# _seconds_until_expiry
# ---------------------------------------------------------------------------


def test_seconds_until_expiry_returns_remaining_for_valid_jwt():
exp = time.time() + 3600
assert abs(_seconds_until_expiry(_make_jwt(exp)) - 3600) < 1


def test_seconds_until_expiry_returns_zero_for_unparseable_jwt():
assert _seconds_until_expiry("not.a.jwt") == 0.0


def test_seconds_until_expiry_returns_negative_for_expired_jwt():
exp = time.time() - 60
assert _seconds_until_expiry(_make_jwt(exp)) < 0


# ---------------------------------------------------------------------------
# JwtAuthProvider.get_token on-demand refresh
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_get_token_returns_cached_when_fresh():
"""Cached token with comfortable runway is returned without refetching."""
fresh_jwt = _make_jwt(time.time() + 7200)
holder = _MockHolder("should-not-be-fetched")
provider = JwtAuthProvider(holder.aclient, seed_token=fresh_jwt)

assert await provider.get_token() == fresh_jwt
holder._cm.__enter__.return_value.service.auth_token.assert_not_called()


@pytest.mark.asyncio
async def test_get_token_refreshes_when_near_expiry():
"""If cached token has <= _REFRESH_ON_DEMAND_SECS left, fetch a new one."""
near_expiry_jwt = _make_jwt(time.time() + 30) # 30s left, under threshold
refreshed_jwt = _make_jwt(time.time() + 7200)
holder = _MockHolder(refreshed_jwt)
provider = JwtAuthProvider(holder.aclient, seed_token=near_expiry_jwt)

assert await provider.get_token() == refreshed_jwt
holder._cm.__enter__.return_value.service.auth_token.assert_called_once()


@pytest.mark.asyncio
async def test_get_token_refreshes_when_already_expired():
"""An expired cached token must trigger refresh, not be served as-is."""
expired_jwt = _make_jwt(time.time() - 30)
refreshed_jwt = _make_jwt(time.time() + 7200)
holder = _MockHolder(refreshed_jwt)
provider = JwtAuthProvider(holder.aclient, seed_token=expired_jwt)

assert await provider.get_token() == refreshed_jwt
holder._cm.__enter__.return_value.service.auth_token.assert_called_once()


@pytest.mark.asyncio
async def test_get_token_refreshes_when_cached_token_is_unparseable():
"""A garbled cached token (e.g. corrupt seed) is treated as expired."""
refreshed_jwt = _make_jwt(time.time() + 7200)
holder = _MockHolder(refreshed_jwt)
provider = JwtAuthProvider(holder.aclient, seed_token="not.a.jwt")

assert await provider.get_token() == refreshed_jwt
holder._cm.__enter__.return_value.service.auth_token.assert_called_once()


@pytest.mark.asyncio
async def test_get_token_concurrent_refresh_only_fires_once():
"""Many concurrent get_token() calls share a single in-flight refresh."""
near_expiry_jwt = _make_jwt(time.time() + 30)
refreshed_jwt = _make_jwt(time.time() + 7200)

fetch_started = asyncio.Event()
fetch_release = asyncio.Event()
fetch_count = 0

async def slow_auth_token():
nonlocal fetch_count
fetch_count += 1
fetch_started.set()
await fetch_release.wait()
return _MockAuthResponse(refreshed_jwt)

service = MagicMock()
service.auth_token = slow_auth_token
client = MagicMock()
client.service = service
cm = MagicMock()
cm.__enter__ = MagicMock(return_value=client)
cm.__exit__ = MagicMock(return_value=None)

provider = JwtAuthProvider(lambda: cm, seed_token=near_expiry_jwt)

tasks = [asyncio.create_task(provider.get_token()) for _ in range(5)]

# Wait for the first task to enter the fetch, then let the others queue
# up at the lock before releasing the in-flight fetch.
await fetch_started.wait()
await asyncio.sleep(0)
fetch_release.set()
results = await asyncio.gather(*tasks)

assert fetch_count == 1
assert all(r == refreshed_jwt for r in results)


@pytest.mark.asyncio
async def test_get_token_returns_stale_token_when_refresh_fails(
caplog: pytest.LogCaptureFixture,
):
"""If on-demand refresh fails, return cached token + log a warning.

Better than raising — the request will surface its own error if the
token really is rejected, and other in-flight requests sharing this
provider can still make progress on transient refresh failures.
"""
near_expiry_jwt = _make_jwt(time.time() + 30)
holder = _MockHolder("unused", fail=True)
provider = JwtAuthProvider(holder.aclient, seed_token=near_expiry_jwt)

with caplog.at_level("WARNING"):
result = await provider.get_token()

assert result == near_expiry_jwt
assert "On-demand JWT refresh failed" in caplog.text


@pytest.mark.asyncio
async def test_get_token_returns_none_when_no_token_and_refresh_fails():
"""No cached token + failed refresh => return None (no header sent)."""
holder = _MockHolder("unused", fail=True)
provider = JwtAuthProvider(holder.aclient, seed_token=None)

assert await provider.get_token() is None