diff --git a/framework/py/flwr/supercore/interceptors/__init__.py b/framework/py/flwr/supercore/interceptors/__init__.py index b29dcf39e9d5..8838fe9e49d2 100644 --- a/framework/py/flwr/supercore/interceptors/__init__.py +++ b/framework/py/flwr/supercore/interceptors/__init__.py @@ -22,6 +22,7 @@ AppIoTokenServerInterceptor, create_clientappio_token_auth_server_interceptor, create_serverappio_token_auth_server_interceptor, + get_authenticated_task_id, ) from .runtime_version_interceptor import ( RuntimeVersionClientInterceptor, @@ -49,4 +50,5 @@ "create_serverappio_runtime_version_server_interceptor", "create_serverappio_superexec_auth_server_interceptor", "create_serverappio_token_auth_server_interceptor", + "get_authenticated_task_id", ] diff --git a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py index 94a95bc3b8f7..86a50d24ab08 100644 --- a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py +++ b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py @@ -18,6 +18,7 @@ from __future__ import annotations from collections.abc import Callable, Mapping +from contextvars import ContextVar from typing import Any, NoReturn, Protocol, cast import grpc @@ -28,12 +29,15 @@ SERVERAPPIO_METHOD_AUTH_POLICY, MethodTokenPolicy, ) -from flwr.supercore.utils import get_metadata_str +from flwr.supercore.utils import find_metadata_keys, get_metadata_str APP_TOKEN_HEADER = "flwr-app-token" AUTHENTICATION_FAILED_MESSAGE = "Authentication failed." +_current_task_id: ContextVar[int | None] = ContextVar("current_task_id", default=None) + + class _TokenState(Protocol): """State methods required by token auth.""" @@ -43,6 +47,9 @@ def get_run_id_by_token(self, token: str) -> int | None: def verify_token(self, run_id: int, token: str) -> bool: """Return whether token is valid for run_id.""" + def get_task_id_by_token(self, token: str) -> int | None: + """Return the task ID associated with the task token, if valid.""" + def _abort_auth_denied(context: grpc.ServicerContext) -> NoReturn: context.abort(grpc.StatusCode.UNAUTHENTICATED, AUTHENTICATION_FAILED_MESSAGE) @@ -70,9 +77,12 @@ def intercept_unary_unary( request: GrpcMessage, ) -> grpc.Call: """Add/replace the App token metadata on outbound unary requests.""" - metadata = list(client_call_details.metadata or []) - metadata = [(key, value) for key, value in metadata if key != APP_TOKEN_HEADER] - metadata.append((APP_TOKEN_HEADER, self._token)) + metadata = tuple(client_call_details.metadata or ()) + if find_metadata_keys(metadata, (APP_TOKEN_HEADER,)): + raise RuntimeError( + f"{APP_TOKEN_HEADER} already present in outbound metadata." + ) + metadata += ((APP_TOKEN_HEADER, self._token),) details = client_call_details._replace(metadata=metadata) return continuation(details, request) @@ -127,12 +137,22 @@ def _authenticated_handler( _abort_auth_denied(context) state = self._state_provider() + + # Legacy: Validate both token->run lookup and run->token mapping. run_id = state.get_run_id_by_token(token) - # Validate both token->run lookup and run->token mapping. - if run_id is None or not state.verify_token(run_id, token): - _abort_auth_denied(context) + if run_id is not None and state.verify_token(run_id, token): + return unary_handler(request, context) - return unary_handler(request, context) + # Validate task token and set task context for downstream handlers + task_id = state.get_task_id_by_token(token) + if task_id is not None: + ctx_token = _current_task_id.set(task_id) + try: + return unary_handler(request, context) + finally: + _current_task_id.reset(ctx_token) + + _abort_auth_denied(context) return grpc.unary_unary_rpc_method_handler( _authenticated_handler, @@ -141,6 +161,21 @@ def _authenticated_handler( ) +def get_authenticated_task_id() -> int: + """Return the task ID authenticated for the current RPC. + + The task ID is available only while handling an RPC authenticated with an AppIo task + token. + """ + ret = _current_task_id.get() + if ret is None: + raise RuntimeError( + "No authenticated task ID in the current RPC context. " + "This function must be called from a task-token-authenticated RPC handler." + ) + return ret + + def create_serverappio_token_auth_server_interceptor( state_provider: Callable[[], _TokenState], ) -> AppIoTokenServerInterceptor: diff --git a/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py b/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py index 14bda6d5c8f1..ed10aa78cc1f 100644 --- a/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py +++ b/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py @@ -44,6 +44,7 @@ AppIoTokenServerInterceptor, create_clientappio_token_auth_server_interceptor, create_serverappio_token_auth_server_interceptor, + get_authenticated_task_id, ) _ClientCallDetails = namedtuple( @@ -74,6 +75,11 @@ def verify_token(self, run_id: int, token: str) -> bool: """Return whether the token is bound to the given run id.""" return self._token_to_run_id.get(token) == run_id + def get_task_id_by_token(self, token: str) -> int | None: # pylint: disable=R1711 + """Return the task ID for a task token, if present.""" + _ = token + return None # make mypy happy + def _make_unary_handler() -> grpc.RpcMethodHandler: def _handler(_request: GrpcMessage, _context: grpc.ServicerContext) -> str: @@ -94,13 +100,13 @@ def _handler( class TestAppIoTokenClientInterceptor(TestCase): """Unit tests for AppIoTokenClientInterceptor.""" - def test_attach_and_replace_app_token_header(self) -> None: - """The interceptor should enforce a single App token header.""" + def test_attach_app_token_header(self) -> None: + """The interceptor should attach App token metadata.""" interceptor = AppIoTokenClientInterceptor(token="new-token") details = _ClientCallDetails( method="/flwr.proto.ServerAppIo/GetNodes", timeout=None, - metadata=(("x-test", "value"), (APP_TOKEN_HEADER, "old-token")), + metadata=(("x-test", "value"),), credentials=None, wait_for_ready=None, compression=None, @@ -128,6 +134,25 @@ def continuation( [(APP_TOKEN_HEADER, "new-token")], ) + def test_raise_if_app_token_header_already_present(self) -> None: + """The interceptor should reject duplicate App token metadata.""" + interceptor = AppIoTokenClientInterceptor(token="new-token") + details = _ClientCallDetails( + method="/flwr.proto.ServerAppIo/GetNodes", + timeout=None, + metadata=(("x-test", "value"), (APP_TOKEN_HEADER, "old-token")), + credentials=None, + wait_for_ready=None, + compression=None, + ) + + with self.assertRaises(RuntimeError): + interceptor.intercept_unary_unary( + continuation=Mock(), + client_call_details=details, + request=GetNodesRequest(run_id=1), + ) + class TestAppIoTokenServerInterceptor(TestCase): """Unit tests for AppIoTokenServerInterceptor.""" @@ -233,6 +258,36 @@ def test_valid_token_passes_for_protected_method(self) -> None: # Run-id mismatch deny coverage belongs to the # follow-up PR that enforces run binding. + def test_valid_task_token_passes_and_sets_task_id(self) -> None: + """Protected methods should pass with a valid task token.""" + state = Mock() + state.get_run_id_by_token.return_value = None + state.get_task_id_by_token.return_value = 123 + interceptor = create_serverappio_token_auth_server_interceptor(lambda: state) + method = self._find_serverappio_method(requires_token=True) + if method is None: + self.skipTest("No token-required ServerAppIo method found in policy table.") + captured_task_id = None + + def _handler(_request: GrpcMessage, _context: grpc.ServicerContext) -> str: + nonlocal captured_task_id + captured_task_id = get_authenticated_task_id() + return "ok" + + intercepted = interceptor.intercept_service( + lambda _: grpc.unary_unary_rpc_method_handler(_handler), + _HandlerCallDetails( + method, + invocation_metadata=((APP_TOKEN_HEADER, "task-token"),), + ), + ) + + response = intercepted.unary_unary(GetNodesRequest(run_id=7), Mock()) + self.assertEqual(response, "ok") + self.assertEqual(captured_task_id, 123) + state.get_run_id_by_token.assert_called_once_with("task-token") + state.get_task_id_by_token.assert_called_once_with("task-token") + def test_metadata_token_used_even_when_request_has_token(self) -> None: """Metadata token should be authoritative when both sources exist.""" interceptor = self._new_interceptor(token_to_run_id={"metadata-token": 5})