Skip to content

Commit 6b06f7f

Browse files
author
GreenHatHG
committed
refactor(dvc_webdav): ensure BearerAuthClient is initialized only once
1 parent 2a40b8d commit 6b06f7f

File tree

2 files changed

+19
-22
lines changed

2 files changed

+19
-22
lines changed

dvc_webdav/__init__.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,9 @@ def ask_password(host, user):
2323

2424
@wrap_with(threading.Lock())
2525
@memoize
26-
def get_bearer_auth_client(bearer_token_command: str, token: Optional[str] = None, save_token_cb=None):
27-
logger.debug(
28-
"Bearer token command provided, using BearerAuthClient, command: %s",
29-
bearer_token_command,
30-
)
31-
return BearerAuthClient(
32-
bearer_token_command, token=token, save_token_cb=save_token_cb
33-
)
26+
def get_bearer_auth_client(bearer_token_command: str):
27+
logger.debug("Bearer token command provided, using BearerAuthClient, command: %s", bearer_token_command, )
28+
return BearerAuthClient(bearer_token_command)
3429

3530

3631
class WebDAVFileSystem(FileSystem): # pylint:disable=abstract-method
@@ -54,9 +49,11 @@ def __init__(self, **config):
5449
}
5550
)
5651
if bearer_token_command := config.get("bearer_token_command"):
57-
self.fs_args["http_client"] = get_bearer_auth_client(
58-
bearer_token_command, token=config.get('token'), save_token_cb=self._save_token
59-
)
52+
client = get_bearer_auth_client(bearer_token_command)
53+
client.save_token_cb = self._save_token
54+
if token := config.get("token"):
55+
client.update_token(token)
56+
self.fs_args["http_client"] = client
6057

6158
def unstrip_protocol(self, path: str) -> str:
6259
return self.fs_args["base_url"] + "/" + path

dvc_webdav/bearer_auth_client.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -92,26 +92,16 @@ def __init__(
9292
self,
9393
bearer_token_command: str,
9494
save_token_cb: Optional[TokenSaver] = None,
95-
token: Optional[str] = None,
9695
**kwargs,
9796
):
9897
super().__init__(**kwargs)
9998
if not isinstance(bearer_token_command, str) or not bearer_token_command.strip():
10099
raise ValueError("[BearerAuthClient] bearer_token_command must be a non-empty string")
101100
self.bearer_token_command = bearer_token_command
102101
self.save_token_cb = save_token_cb
103-
self._token: Optional[str] = token
102+
self._token: Optional[str] = None
104103
self._lock = threading.Lock()
105104

106-
if not self._token:
107-
auth_header = self.headers.get("Authorization")
108-
if auth_header and auth_header.startswith("Bearer "):
109-
self._token = auth_header.split(" ", 1)[1]
110-
111-
if self._token:
112-
logger.debug("[BearerAuthClient] Initial token found, setting Authorization header.")
113-
self.headers["Authorization"] = f"Bearer {self._token}"
114-
115105
def _refresh_token_locked(self) -> None:
116106
"""Execute token command and update state."""
117107
_log_with_thread(logging.DEBUG, "[BearerAuthClient] Refreshing token via command...")
@@ -142,6 +132,16 @@ def _ensure_token(self) -> None:
142132
if not self._token:
143133
self._refresh_token_locked()
144134

135+
def update_token(self, token: Optional[str]) -> None:
136+
"""Update the token with a new one"""
137+
if not token:
138+
return
139+
140+
with self._lock:
141+
if self._token != token:
142+
self._token = token
143+
self.headers["Authorization"] = f"Bearer {token}"
144+
145145
def request(self, *args, **kwargs) -> httpx.Response:
146146
"""Wraps httpx.request with auto-refresh logic for 401 Unauthorized."""
147147
self._ensure_token()

0 commit comments

Comments
 (0)