-
Notifications
You must be signed in to change notification settings - Fork 56
Refresh JWT on-demand to avoid 401 Invalid JWT on long-running requests #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
derek-tml
merged 1 commit into
thinking-machines-lab:main
from
yusudz:yury/fix/jwt-on-demand-refresh
May 14, 2026
+197
−12
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,8 +4,11 @@ | |
|
|
||
| When the server sets pjwt_auth_enabled, the SDK exchanges the caller's | ||
| credential for a short-lived JWT minted by the Tinker server. The JWT is | ||
| cached and refreshed in the background before it expires, so callers always | ||
| send a valid token without any per-request overhead. | ||
| cached and refreshed in the background before it expires. As a safety | ||
| net, get_token() also refreshes on demand if the cached token is near or | ||
| past expiry — so a delayed/failed background refresh cannot leave callers | ||
| sending a stale JWT (which the server rejects with 401 Invalid JWT, and | ||
| 401 is not retried by the request layer). | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
@@ -22,8 +25,9 @@ | |
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| _REFRESH_BEFORE_EXPIRY_SECS = 300 # refresh 5 min before expiry | ||
| _RETRY_DELAY_SECS = 60 | ||
| _REFRESH_BEFORE_EXPIRY_SECS = 300 # background loop refreshes 5 min before expiry | ||
| _REFRESH_ON_DEMAND_SECS = 60 # get_token() refreshes if <= this many seconds left | ||
| _RETRY_DELAY_SECS = 60 # backoff after a failed refresh | ||
|
|
||
|
|
||
| def _jwt_expiry(jwt: str) -> float: | ||
|
|
@@ -36,23 +40,51 @@ def _jwt_expiry(jwt: str) -> float: | |
| raise ValueError(f"Failed to parse JWT expiry: {e}") from e | ||
|
|
||
|
|
||
| def _seconds_until_expiry(jwt: str) -> float: | ||
| """Seconds until the JWT expires; 0 if expiry can't be parsed.""" | ||
| try: | ||
| return _jwt_expiry(jwt) - time.time() | ||
| except ValueError: | ||
| return 0.0 | ||
|
|
||
|
|
||
| class JwtAuthProvider(AuthTokenProvider): | ||
| """AuthTokenProvider that exchanges a credential for a short-lived JWT. | ||
|
|
||
| After init(), get_token() returns the current JWT. A background task | ||
| refreshes the JWT before it expires. | ||
| proactively refreshes the JWT before it expires. get_token() also | ||
| refreshes on demand if the cached token is near or past expiry, so a | ||
| stuck or delayed background refresh cannot leak a stale JWT into a | ||
| request. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| aclient_fn: Callable[[], AbstractContextManager], | ||
| seed_token: str | None = None, | ||
| ) -> None: | ||
| self._token = seed_token or "" | ||
| self._token: str = seed_token or "" | ||
| self._aclient_fn = aclient_fn | ||
| self._refresh_lock = asyncio.Lock() | ||
|
|
||
| async def get_token(self) -> str | None: | ||
| return self._token | ||
| # Fast path: cached token has comfortable runway. | ||
| if self._token and _seconds_until_expiry(self._token) > _REFRESH_ON_DEMAND_SECS: | ||
| return self._token | ||
|
|
||
| async with self._refresh_lock: | ||
| # Re-check after acquiring the lock — another caller may have | ||
| # just refreshed the token while we were waiting. | ||
| if self._token and _seconds_until_expiry(self._token) > _REFRESH_ON_DEMAND_SECS: | ||
| return self._token | ||
| try: | ||
| return await self._fetch() | ||
| except Exception as e: | ||
| # If the refresh fails, fall back to whatever we have. | ||
| # The background loop keeps trying; if the server is | ||
| # genuinely down the request will surface the error. | ||
| logger.warning("On-demand JWT refresh failed: %s", e) | ||
| return self._token or None | ||
|
|
||
| async def init(self) -> None: | ||
| """Fetch a JWT (unless seeded) then start the background refresh loop. | ||
|
|
@@ -68,23 +100,34 @@ async def _fetch(self) -> str: | |
| """Exchange the current credential for a JWT via /api/v1/auth/token.""" | ||
| with self._aclient_fn() as client: | ||
| response = await client.service.auth_token() | ||
| self._token = response.jwt | ||
| return response.jwt | ||
| jwt: str = response.jwt | ||
| self._token = jwt | ||
| return jwt | ||
|
|
||
| async def _refresh_loop(self, token: str) -> None: | ||
| while True: | ||
| try: | ||
| delay = max( | ||
| _RETRY_DELAY_SECS, | ||
| 0.0, | ||
| _jwt_expiry(token) - time.time() - _REFRESH_BEFORE_EXPIRY_SECS, | ||
| ) | ||
| except ValueError: | ||
| logger.debug("Failed to parse JWT expiry, retrying in %ds", _RETRY_DELAY_SECS) | ||
| delay = _RETRY_DELAY_SECS | ||
| delay = float(_RETRY_DELAY_SECS) | ||
| try: | ||
| await asyncio.sleep(delay) | ||
| token = await self._fetch() | ||
| # Coordinate with on-demand refreshes in get_token() so we | ||
| # don't fire two concurrent /auth/token requests. | ||
| async with self._refresh_lock: | ||
| token = await self._fetch() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: is slightly better. would prefer this |
||
| except asyncio.CancelledError: | ||
| return | ||
| except Exception as e: | ||
| logger.debug("JWT refresh failed, retrying in %ds: %s", _RETRY_DELAY_SECS, e) | ||
| # Explicit backoff: without the old max(60, ...) floor on | ||
| # `delay`, a stale token would otherwise compute delay=0 | ||
| # next iteration and tight-loop on persistent failures. | ||
| try: | ||
| await asyncio.sleep(_RETRY_DELAY_SECS) | ||
| except asyncio.CancelledError: | ||
| return | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the cached token is expired, unparseable, or absent, this hides a retryable /auth/token failure and sends a request that will likely fail with non-retryable 401/no auth. Only fall back when the cached JWT is still usable; otherwise re-raise so the normal request retry path can retry the auth refresh. Could do something like this: