diff --git a/framework/py/flwr/server/superlink/serverappio/serverappio_auth_integration_test.py b/framework/py/flwr/server/superlink/serverappio/serverappio_auth_integration_test.py index 38fc4ac768b5..6b8b4c8c37e4 100644 --- a/framework/py/flwr/server/superlink/serverappio/serverappio_auth_integration_test.py +++ b/framework/py/flwr/server/superlink/serverappio/serverappio_auth_integration_test.py @@ -17,10 +17,33 @@ import tempfile import unittest +from collections.abc import Callable import grpc +from google.protobuf.message import Message as GrpcMessage +from parameterized import parameterized from flwr.common.constant import SERVERAPPIO_API_DEFAULT_SERVER_ADDRESS +from flwr.proto.appio_pb2 import ( # pylint: disable=E0611 + PullAppMessagesRequest, + PullAppMessagesResponse, + PushAppMessagesRequest, + PushAppMessagesResponse, + SendTaskHeartbeatRequest, + SendTaskHeartbeatResponse, +) +from flwr.proto.log_pb2 import ( # pylint: disable=E0611 + PushLogsRequest, + PushLogsResponse, +) +from flwr.proto.message_pb2 import ( # pylint: disable=E0611 + ConfirmMessageReceivedRequest, + ConfirmMessageReceivedResponse, + PullObjectRequest, + PullObjectResponse, + PushObjectRequest, + PushObjectResponse, +) from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611 GetNodesRequest, GetNodesResponse, @@ -70,17 +93,20 @@ def setUp(self) -> None: # Seed one authenticated task token and reuse it for token-protected RPC # checks. - self._auth_run_id, auth_token = self._create_running_run() + _, auth_token = self._create_running_run() + _, self._simulation_token = self._create_running_run( + run_type=RunType.SIMULATION + ) # Create a single base channel and wrap it for authenticated calls. - base_channel = grpc.insecure_channel("localhost:9091") - self._get_nodes_no_auth = base_channel.unary_unary( + self._base_channel = grpc.insecure_channel("localhost:9091") + self._get_nodes_no_auth = self._base_channel.unary_unary( "/flwr.proto.ServerAppIo/GetNodes", request_serializer=GetNodesRequest.SerializeToString, response_deserializer=GetNodesResponse.FromString, ) auth_channel = grpc.intercept_channel( - base_channel, + self._base_channel, AppIoTokenClientInterceptor(token=auth_token), SuperExecAuthClientInterceptor( master_secret=_SUPEREXEC_SECRET, @@ -95,11 +121,14 @@ def setUp(self) -> None: def tearDown(self) -> None: """Stop the gRPC API server.""" + self._base_channel.close() self._server.stop(None) - def _create_running_run(self) -> tuple[int, str]: + def _create_running_run( + self, run_type: str = RunType.SERVER_APP + ) -> tuple[int, str]: run_id = self.state.create_run( - "", "", "", {}, NOOP_FEDERATION, None, "", RunType.SERVER_APP + "", "", "", {}, NOOP_FEDERATION, None, "", run_type ) run = self.state.get_run_info(run_ids=[run_id])[0] assert run.primary_task_id is not None @@ -131,3 +160,99 @@ def test_get_nodes_allows_with_valid_metadata_token(self) -> None: assert isinstance(response, GetNodesResponse) assert call.code() == grpc.StatusCode.OK + + @parameterized.expand( + [ + ( + "get_nodes", + "/flwr.proto.ServerAppIo/GetNodes", + GetNodesRequest(), + GetNodesResponse.FromString, + ), + ( + "push_messages", + "/flwr.proto.ServerAppIo/PushMessages", + PushAppMessagesRequest(), + PushAppMessagesResponse.FromString, + ), + ( + "pull_messages", + "/flwr.proto.ServerAppIo/PullMessages", + PullAppMessagesRequest(), + PullAppMessagesResponse.FromString, + ), + ( + "push_object", + "/flwr.proto.ServerAppIo/PushObject", + PushObjectRequest(), + PushObjectResponse.FromString, + ), + ( + "pull_object", + "/flwr.proto.ServerAppIo/PullObject", + PullObjectRequest(), + PullObjectResponse.FromString, + ), + ( + "confirm_message_received", + "/flwr.proto.ServerAppIo/ConfirmMessageReceived", + ConfirmMessageReceivedRequest(), + ConfirmMessageReceivedResponse.FromString, + ), + ] + ) # type: ignore + def test_serverapp_only_endpoint_denied_for_simulation_run( + self, + _case_name: str, + method: str, + request: GrpcMessage, + response_deserializer: Callable[[bytes], object], + ) -> None: + """ServerApp-only RPCs should deny simulation-run tokens.""" + rpc = self._base_channel.unary_unary( + method, + request_serializer=type(request).SerializeToString, + response_deserializer=response_deserializer, + ) + with self.assertRaises(grpc.RpcError) as err: + rpc.with_call( + request=request, + metadata=((TASK_TOKEN_HEADER, self._simulation_token),), + ) + assert err.exception.code() == grpc.StatusCode.PERMISSION_DENIED + + @parameterized.expand( + [ + ( + "send_task_heartbeat", + "/flwr.proto.ServerAppIo/SendTaskHeartbeat", + SendTaskHeartbeatRequest(), + SendTaskHeartbeatResponse.FromString, + ), + ( + "push_logs", + "/flwr.proto.ServerAppIo/PushLogs", + PushLogsRequest(logs=["hello"]), + PushLogsResponse.FromString, + ), + ] + ) # type: ignore + def test_shared_task_endpoint_allows_simulation_run( + self, + _case_name: str, + method: str, + request: GrpcMessage, + response_deserializer: Callable[[bytes], object], + ) -> None: + """Shared task RPCs should still allow simulation-run tokens.""" + rpc = self._base_channel.unary_unary( + method, + request_serializer=type(request).SerializeToString, + response_deserializer=response_deserializer, + ) + response, call = rpc.with_call( + request=request, + metadata=((TASK_TOKEN_HEADER, self._simulation_token),), + ) + assert response is not None + assert call.code() == grpc.StatusCode.OK diff --git a/framework/py/flwr/server/superlink/serverappio/serverappio_servicer.py b/framework/py/flwr/server/superlink/serverappio/serverappio_servicer.py index ad30d2386eb4..170d8f7ef195 100644 --- a/framework/py/flwr/server/superlink/serverappio/serverappio_servicer.py +++ b/framework/py/flwr/server/superlink/serverappio/serverappio_servicer.py @@ -62,6 +62,7 @@ ) from flwr.server.superlink.linkstate import LinkState, LinkStateFactory from flwr.server.utils.validator import validate_message +from flwr.supercore.constant import TaskType from flwr.supercore.inflatable.inflatable_object import ( UnexpectedObjectContentError, get_all_nested_objects, @@ -72,6 +73,10 @@ from flwr.supercore.object_store import NoObjectInStoreError, ObjectStoreFactory from flwr.supercore.servicers import AppIoServicer +SERVERAPPIO_ENDPOINT_UNAVAILABLE_MESSAGE = ( + "Some ServerAppIo API endpoints are only available for Deployment Runtime runs." +) + class ServerAppIoServicer(AppIoServicer, serverappio_pb2_grpc.ServerAppIoServicer): """ServerAppIo API servicer.""" @@ -97,9 +102,7 @@ def GetNodes( # Init state state = self.state_factory.state() - # Get authenticated task and associated run ID - task = get_authenticated_task() - run_id = task.run_id + run_id = _get_authenticated_serverapp_run_id(context) all_ids: set[int] = state.get_nodes(run_id) nodes: list[Node] = [Node(node_id=node_id) for node_id in all_ids] @@ -115,9 +118,7 @@ def PushMessages( state = self.state_factory.state() store = self.objectstore_factory.store() - # Get authenticated task and associated run ID - task = get_authenticated_task() - run_id = task.run_id + run_id = _get_authenticated_serverapp_run_id(context) # Validate request and insert in State _raise_if( @@ -165,9 +166,7 @@ def PullMessages( # pylint: disable=R0914 state = self.state_factory.state() store = self.objectstore_factory.store() - # Get authenticated task and associated run ID - task = get_authenticated_task() - run_id = task.run_id + run_id = _get_authenticated_serverapp_run_id(context) # Read from state messages_res: list[Message] = state.get_message_res( @@ -311,9 +310,11 @@ def PushObject( """Push an object to the ObjectStore.""" log(DEBUG, "ServerAppIoServicer.PushObject") - # Init state and store + # Init store store = self.objectstore_factory.store() + _ = _get_authenticated_serverapp_run_id(context) + if request.node.node_id != SUPERLINK_NODE_ID: # Cancel insertion in ObjectStore context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.") @@ -337,9 +338,11 @@ def PullObject( """Pull an object from the ObjectStore.""" log(DEBUG, "ServerAppIoServicer.PullObject") - # Init state and store + # Init store store = self.objectstore_factory.store() + _ = _get_authenticated_serverapp_run_id(context) + if request.node.node_id != SUPERLINK_NODE_ID: # Cancel insertion in ObjectStore context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.") @@ -361,15 +364,28 @@ def ConfirmMessageReceived( """Confirm message received.""" log(DEBUG, "ServerAppIoServicer.ConfirmMessageReceived") - # Init state and store + # Init store store = self.objectstore_factory.store() + _ = _get_authenticated_serverapp_run_id(context) + # Delete the message object store.delete(request.message_object_id) return ConfirmMessageReceivedResponse() +def _get_authenticated_serverapp_run_id(context: grpc.ServicerContext) -> int: + """Return the authenticated run ID if it can use ServerAppIo endpoints.""" + task = get_authenticated_task() + if task.type != TaskType.SERVER_APP: + context.abort( + grpc.StatusCode.PERMISSION_DENIED, + SERVERAPPIO_ENDPOINT_UNAVAILABLE_MESSAGE, + ) + return task.run_id + + def _raise_if(validation_error: bool, request_name: str, detail: str) -> None: """Raise a `ValueError` with a detailed message if a validation error occurs.""" if validation_error: