diff --git a/google/auth/credentials.py b/google/auth/credentials.py index 82c73c3bf..d53c52021 100644 --- a/google/auth/credentials.py +++ b/google/auth/credentials.py @@ -69,6 +69,7 @@ def __init__(self): self._use_non_blocking_refresh = False self._refresh_worker = RefreshThreadManager() + self._custom_headers = {} @property def expired(self): @@ -185,6 +186,7 @@ def apply(self, headers, token=None): self._apply(headers, token) if self.quota_project_id: headers["x-goog-user-project"] = self.quota_project_id + headers.update(self._custom_headers) def _blocking_refresh(self, request): if not self.valid: @@ -233,6 +235,38 @@ def before_request(self, request, method, url, headers): def with_non_blocking_refresh(self): self._use_non_blocking_refresh = True + def with_headers(self, headers): + """Returns a copy of these credentials with additional custom headers. + + Args: + headers (Mapping[str, str]): The custom headers to add. + + Returns: + google.auth.credentials.Credentials: A new credentials instance. + + Raises: + ValueError: If a protected header is included in the input headers. + """ + import copy + + PROTECTED_HEADERS = { + "authorization", + "x-goog-user-project", + "x-goog-api-client", + "x-allowed-locations", + } + + for key in headers: + if key.lower() in PROTECTED_HEADERS: + raise ValueError( + f"Header '{key}' is protected and cannot be set with with_headers. " + "These headers are managed by the library." + ) + + new_creds = copy.deepcopy(self) + new_creds._custom_headers.update(headers) + return new_creds + class CredentialsWithQuotaProject(Credentials): """Abstract base for credentials supporting ``with_quota_project`` factory""" diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 1fb880096..a95db9ba3 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -71,6 +71,48 @@ def test_with_non_blocking_refresh(): assert c._use_non_blocking_refresh +def test_with_headers(): + credentials = CredentialsImpl() + request = mock.Mock() + + # 1. Add a new custom header + creds_with_header = credentials.with_headers({"X-Custom-Header": "value1"}) + headers = {} + creds_with_header.before_request(request, "http://example.com", "GET", headers) + assert headers["X-Custom-Header"] == "value1" + assert "authorization" in headers # Ensure base apply logic ran + assert creds_with_header is not credentials + assert not hasattr(credentials, "_custom_headers") or not credentials._custom_headers + + # 2. Update an existing custom header + creds_updated = creds_with_header.with_headers({"X-Custom-Header": "value2"}) + headers = {} + creds_updated.before_request(request, "http://example.com", "GET", headers) + assert headers["X-Custom-Header"] == "value2" + + # 3. Chaining with_headers calls + creds_chained = credentials.with_headers({"X-Header-1": "v1"}).with_headers( + {"X-Header-2": "v2"} + ) + headers = {} + creds_chained.before_request(request, "http://example.com", "GET", headers) + assert headers["X-Header-1"] == "v1" + assert headers["X-Header-2"] == "v2" + + # 4. Ensure protected headers cannot be set + with pytest.raises(ValueError): + credentials.with_headers({"Authorization": "Bearer token"}) + with pytest.raises(ValueError): + credentials.with_headers({"X-Goog-User-Project": "test"}) + with pytest.raises(ValueError): + credentials.with_headers({"authorization": "Bearer token"}) # Case-insensitive + + # 5. Check original credentials are not modified + headers = {} + credentials.before_request(request, "http://example.com", "GET", headers) + assert "X-Custom-Header" not in headers + + def test_expired_and_valid(): credentials = CredentialsImpl() credentials.token = "token"