Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions framework/py/flwr/supercore/interceptors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
AppIoTokenServerInterceptor,
create_clientappio_token_auth_server_interceptor,
create_serverappio_token_auth_server_interceptor,
get_authenticated_task_id,
)
from .runtime_version_interceptor import (
RuntimeVersionClientInterceptor,
Expand Down Expand Up @@ -49,4 +50,5 @@
"create_serverappio_runtime_version_server_interceptor",
"create_serverappio_superexec_auth_server_interceptor",
"create_serverappio_token_auth_server_interceptor",
"get_authenticated_task_id",
]
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

from collections.abc import Callable, Mapping
from contextvars import ContextVar
from typing import Any, NoReturn, Protocol, cast

import grpc
Expand All @@ -28,12 +29,15 @@
SERVERAPPIO_METHOD_AUTH_POLICY,
MethodTokenPolicy,
)
from flwr.supercore.utils import get_metadata_str
from flwr.supercore.utils import find_metadata_keys, get_metadata_str

APP_TOKEN_HEADER = "flwr-app-token"
AUTHENTICATION_FAILED_MESSAGE = "Authentication failed."


_current_task_id: ContextVar[int | None] = ContextVar("current_task_id", default=None)


class _TokenState(Protocol):
"""State methods required by token auth."""

Expand All @@ -43,6 +47,9 @@ def get_run_id_by_token(self, token: str) -> int | None:
def verify_token(self, run_id: int, token: str) -> bool:
"""Return whether token is valid for run_id."""

def get_task_id_by_token(self, token: str) -> int | None:
"""Return the task ID associated with the task token, if valid."""


def _abort_auth_denied(context: grpc.ServicerContext) -> NoReturn:
context.abort(grpc.StatusCode.UNAUTHENTICATED, AUTHENTICATION_FAILED_MESSAGE)
Expand Down Expand Up @@ -70,9 +77,12 @@ def intercept_unary_unary(
request: GrpcMessage,
) -> grpc.Call:
"""Add/replace the App token metadata on outbound unary requests."""
metadata = list(client_call_details.metadata or [])
metadata = [(key, value) for key, value in metadata if key != APP_TOKEN_HEADER]
metadata.append((APP_TOKEN_HEADER, self._token))
metadata = tuple(client_call_details.metadata or ())
if find_metadata_keys(metadata, (APP_TOKEN_HEADER,)):
raise RuntimeError(
f"{APP_TOKEN_HEADER} already present in outbound metadata."
)
metadata += ((APP_TOKEN_HEADER, self._token),)
Comment thread
panh99 marked this conversation as resolved.
details = client_call_details._replace(metadata=metadata)
return continuation(details, request)

Expand Down Expand Up @@ -127,12 +137,22 @@ def _authenticated_handler(
_abort_auth_denied(context)

state = self._state_provider()

# Legacy: Validate both token->run lookup and run->token mapping.
run_id = state.get_run_id_by_token(token)
# Validate both token->run lookup and run->token mapping.
if run_id is None or not state.verify_token(run_id, token):
_abort_auth_denied(context)
if run_id is not None and state.verify_token(run_id, token):
return unary_handler(request, context)

return unary_handler(request, context)
# Validate task token and set task context for downstream handlers
task_id = state.get_task_id_by_token(token)
if task_id is not None:
ctx_token = _current_task_id.set(task_id)
try:
return unary_handler(request, context)
finally:
_current_task_id.reset(ctx_token)

_abort_auth_denied(context)

return grpc.unary_unary_rpc_method_handler(
_authenticated_handler,
Expand All @@ -141,6 +161,21 @@ def _authenticated_handler(
)


def get_authenticated_task_id() -> int:
"""Return the task ID authenticated for the current RPC.

The task ID is available only while handling an RPC authenticated with an AppIo task
token.
"""
ret = _current_task_id.get()
if ret is None:
raise RuntimeError(
"No authenticated task ID in the current RPC context. "
"This function must be called from a task-token-authenticated RPC handler."
)
return ret


def create_serverappio_token_auth_server_interceptor(
state_provider: Callable[[], _TokenState],
) -> AppIoTokenServerInterceptor:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
AppIoTokenServerInterceptor,
create_clientappio_token_auth_server_interceptor,
create_serverappio_token_auth_server_interceptor,
get_authenticated_task_id,
)

_ClientCallDetails = namedtuple(
Expand Down Expand Up @@ -74,6 +75,11 @@ def verify_token(self, run_id: int, token: str) -> bool:
"""Return whether the token is bound to the given run id."""
return self._token_to_run_id.get(token) == run_id

def get_task_id_by_token(self, token: str) -> int | None: # pylint: disable=R1711
"""Return the task ID for a task token, if present."""
_ = token
return None # make mypy happy


def _make_unary_handler() -> grpc.RpcMethodHandler:
def _handler(_request: GrpcMessage, _context: grpc.ServicerContext) -> str:
Expand All @@ -94,13 +100,13 @@ def _handler(
class TestAppIoTokenClientInterceptor(TestCase):
"""Unit tests for AppIoTokenClientInterceptor."""

def test_attach_and_replace_app_token_header(self) -> None:
"""The interceptor should enforce a single App token header."""
def test_attach_app_token_header(self) -> None:
"""The interceptor should attach App token metadata."""
interceptor = AppIoTokenClientInterceptor(token="new-token")
details = _ClientCallDetails(
method="/flwr.proto.ServerAppIo/GetNodes",
timeout=None,
metadata=(("x-test", "value"), (APP_TOKEN_HEADER, "old-token")),
metadata=(("x-test", "value"),),
credentials=None,
wait_for_ready=None,
compression=None,
Expand Down Expand Up @@ -128,6 +134,25 @@ def continuation(
[(APP_TOKEN_HEADER, "new-token")],
)

def test_raise_if_app_token_header_already_present(self) -> None:
"""The interceptor should reject duplicate App token metadata."""
interceptor = AppIoTokenClientInterceptor(token="new-token")
details = _ClientCallDetails(
method="/flwr.proto.ServerAppIo/GetNodes",
timeout=None,
metadata=(("x-test", "value"), (APP_TOKEN_HEADER, "old-token")),
credentials=None,
wait_for_ready=None,
compression=None,
)

with self.assertRaises(RuntimeError):
interceptor.intercept_unary_unary(
continuation=Mock(),
client_call_details=details,
request=GetNodesRequest(run_id=1),
)


class TestAppIoTokenServerInterceptor(TestCase):
"""Unit tests for AppIoTokenServerInterceptor."""
Expand Down Expand Up @@ -233,6 +258,36 @@ def test_valid_token_passes_for_protected_method(self) -> None:
# Run-id mismatch deny coverage belongs to the
# follow-up PR that enforces run binding.

def test_valid_task_token_passes_and_sets_task_id(self) -> None:
"""Protected methods should pass with a valid task token."""
state = Mock()
state.get_run_id_by_token.return_value = None
state.get_task_id_by_token.return_value = 123
interceptor = create_serverappio_token_auth_server_interceptor(lambda: state)
method = self._find_serverappio_method(requires_token=True)
if method is None:
self.skipTest("No token-required ServerAppIo method found in policy table.")
captured_task_id = None

def _handler(_request: GrpcMessage, _context: grpc.ServicerContext) -> str:
nonlocal captured_task_id
captured_task_id = get_authenticated_task_id()
return "ok"

intercepted = interceptor.intercept_service(
lambda _: grpc.unary_unary_rpc_method_handler(_handler),
_HandlerCallDetails(
method,
invocation_metadata=((APP_TOKEN_HEADER, "task-token"),),
),
)

response = intercepted.unary_unary(GetNodesRequest(run_id=7), Mock())
self.assertEqual(response, "ok")
self.assertEqual(captured_task_id, 123)
state.get_run_id_by_token.assert_called_once_with("task-token")
state.get_task_id_by_token.assert_called_once_with("task-token")

def test_metadata_token_used_even_when_request_has_token(self) -> None:
"""Metadata token should be authoritative when both sources exist."""
interceptor = self._new_interceptor(token_to_run_id={"metadata-token": 5})
Expand Down
Loading