diff --git a/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py b/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py index 70083d224fa6..cc4afb3e85ff 100644 --- a/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py +++ b/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py @@ -332,24 +332,10 @@ def setUp(self) -> None: superexec_auth_secret=_SUPEREXEC_SECRET, ) - # Provide a valid metadata token on the default test channel so existing - # servicer behavior tests continue to exercise business logic paths. - self._auth_run_id = self.state.create_run( - "", "", "", {}, NOOP_FEDERATION, None, "", RunType.SERVER_APP - ) - auth_token = self.state.create_token(self._auth_run_id) - assert auth_token is not None - self._auth_token = auth_token - self._appio_auth_interceptor = AppIoTokenClientInterceptor(auth_token) - _ = self.state.update_run_status( - self._auth_run_id, RunStatus(Status.STARTING, "", "") - ) - _ = self.state.update_run_status( - self._auth_run_id, RunStatus(Status.RUNNING, "", "") - ) + self._appio_token_client_interceptor = AppIoTokenClientInterceptor(token="") self._channel = grpc.intercept_channel( grpc.insecure_channel("localhost:9091"), - self._appio_auth_interceptor, + self._appio_token_client_interceptor, SuperExecAuthClientInterceptor( master_secret=_SUPEREXEC_SECRET, protected_methods=SERVERAPPIO_SUPEREXEC_METHODS, @@ -429,10 +415,19 @@ def _create_dummy_run(self, running: bool = True, *, fab_hash: str = "") -> int: "", RunType.SERVER_APP, ) + + # Set token in the client interceptor to pass authentication + assert (token := self.state.create_token(run_id)) is not None + self._set_token(token) + + # Transition run status if required if running: self._transition_run_status(run_id, 2) return run_id + def _set_token(self, token: str) -> None: + self._appio_token_client_interceptor._token = token # pylint: disable=W0212 + def test_successful_get_node_if_running(self) -> None: """Test `GetNode` success.""" # Prepare @@ -542,13 +537,20 @@ def test_create_task_rejects_missing_required_fields( assert err.exception.code() == grpc.StatusCode.FAILED_PRECONDITION assert err.exception.details() == error_msg - def test_create_task_fast_fails_missing_run(self) -> None: - """Test `CreateTask` propagates an unknown run ID as an RPC failure.""" + def test_create_task_rejects_missing_run(self) -> None: + """Test `CreateTask` rejects unknown run IDs.""" + # Seed a stale token binding so auth passes and the servicer still + # exercises the missing-run branch. + run_id = 42 + token = self.state.create_token(run_id) + assert token is not None + self._set_token(token) + with self.assertRaises(grpc.RpcError) as err: self._create_task.with_call( request=CreateTaskRequest( type=TaskType.MODEL, - run_id=42, + run_id=run_id, model_ref="model://test", ) ) diff --git a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py index 58a21de8f978..891e70f9dd8c 100644 --- a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py +++ b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py @@ -34,6 +34,7 @@ APP_TOKEN_HEADER = "flwr-app-token" AUTHENTICATION_FAILED_MESSAGE = "Authentication failed." +RUN_BINDING_FAILED_MESSAGE = "Token is not valid for requested run." _current_task: ContextVar[Task | None] = ContextVar("current_task", default=None) @@ -57,6 +58,11 @@ def _abort_auth_denied(context: grpc.ServicerContext) -> NoReturn: raise RuntimeError("Should not reach this point") +def _abort_run_denied(context: grpc.ServicerContext) -> NoReturn: + context.abort(grpc.StatusCode.PERMISSION_DENIED, RUN_BINDING_FAILED_MESSAGE) + raise RuntimeError("Should not reach this point") + + def _unauthenticated_terminator() -> grpc.RpcMethodHandler: def _terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage: context.abort(grpc.StatusCode.UNAUTHENTICATED, AUTHENTICATION_FAILED_MESSAGE) @@ -65,6 +71,14 @@ def _terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMess return grpc.unary_unary_rpc_method_handler(_terminate) +def _get_request_run_id(request: GrpcMessage) -> int | None: + run_id = cast(int | None, getattr(request, "run_id", None)) + # Proto3 scalar fields without presence expose an omitted uint64 as 0. + if run_id == 0: + return None + return run_id + + class AppIoTokenClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore """Attach App token metadata to outbound unary RPCs.""" @@ -139,14 +153,20 @@ def _authenticated_handler( state = self._state_provider() - # Legacy: Validate both token->run lookup and run->token mapping. + # Legacy: validate both token->run lookup and run->token mapping. run_id = state.get_run_id_by_token(token) if run_id is not None and state.verify_token(run_id, token): + request_run_id: int | None = _get_request_run_id(request) + if request_run_id is not None and request_run_id != run_id: + _abort_run_denied(context) return unary_handler(request, context) - # Validate task token and set task context for downstream handlers + # Validate task token and set task context for downstream handlers. task = state.get_task_by_token(token) if task is not None: + request_run_id = _get_request_run_id(request) + if request_run_id is not None and request_run_id != task.run_id: + _abort_run_denied(context) ctx_token = _current_task.set(task) try: return unary_handler(request, context) diff --git a/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py b/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py index 954604e08c52..2a2ce5f6c303 100644 --- a/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py +++ b/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py @@ -47,6 +47,9 @@ create_serverappio_token_auth_server_interceptor, get_authenticated_task, ) +from flwr.supercore.interceptors.appio_token_interceptor import ( + RUN_BINDING_FAILED_MESSAGE, +) _ClientCallDetails = namedtuple( "_ClientCallDetails", @@ -256,14 +259,36 @@ def test_valid_token_passes_for_protected_method(self) -> None: # cross-run use is expected. response = cast(str, intercepted.unary_unary(GetNodesRequest(run_id=7), Mock())) self.assertEqual(response, "ok") - # Run-id mismatch deny coverage belongs to the - # follow-up PR that enforces run binding. + + def test_run_id_mismatch_denied_for_protected_method(self) -> None: + """Protected methods should deny token use against a different run.""" + interceptor = self._new_interceptor(token_to_run_id={"valid": 7}) + method = self._find_serverappio_method(requires_token=True) + if method is None: + self.skipTest("No token-required ServerAppIo method found in policy table.") + context = Mock() + context.abort.side_effect = grpc.RpcError() + + intercepted = interceptor.intercept_service( + lambda _: _make_unary_handler(), + _HandlerCallDetails( + method, + invocation_metadata=((APP_TOKEN_HEADER, "valid"),), + ), + ) + + with self.assertRaises(grpc.RpcError): + intercepted.unary_unary(GetNodesRequest(run_id=8), context) + context.abort.assert_called_once_with( + grpc.StatusCode.PERMISSION_DENIED, + RUN_BINDING_FAILED_MESSAGE, + ) def test_valid_task_token_passes_and_sets_task_id(self) -> None: """Protected methods should pass with a valid task token.""" state = Mock() state.get_run_id_by_token.return_value = None - state.get_task_by_token.return_value = Mock(task_id=123) + state.get_task_by_token.return_value = Mock(task_id=123, run_id=7) interceptor = create_serverappio_token_auth_server_interceptor(lambda: state) method = self._find_serverappio_method(requires_token=True) if method is None: @@ -290,6 +315,56 @@ def _handler(_request: GrpcMessage, _context: grpc.ServicerContext) -> str: state.get_run_id_by_token.assert_called_once_with("task-token") state.get_task_by_token.assert_called_once_with("task-token") + def test_task_token_passes_when_run_id_omitted(self) -> None: + """Task-token protected methods should accept omitted request run IDs.""" + state = Mock() + state.get_run_id_by_token.return_value = None + state.get_task_by_token.return_value = Mock(task_id=123, run_id=7) + interceptor = create_serverappio_token_auth_server_interceptor(lambda: state) + captured_task = None + + def _handler(_request: GrpcMessage, _context: grpc.ServicerContext) -> str: + nonlocal captured_task + captured_task = get_authenticated_task() + return "ok" + + intercepted = interceptor.intercept_service( + lambda _: grpc.unary_unary_rpc_method_handler(_handler), + _HandlerCallDetails( + "/flwr.proto.ServerAppIo/PushAppOutputs", + invocation_metadata=((APP_TOKEN_HEADER, "task-token"),), + ), + ) + + response = intercepted.unary_unary(PushAppOutputsRequest(), Mock()) + self.assertEqual(response, "ok") + self.assertIsNotNone(captured_task) + self.assertEqual(cast(Mock, captured_task).task_id, 123) + + def test_task_token_run_id_mismatch_denied(self) -> None: + """Task-token protected methods should deny a different request run.""" + state = Mock() + state.get_run_id_by_token.return_value = None + state.get_task_by_token.return_value = Mock(task_id=123, run_id=7) + interceptor = create_serverappio_token_auth_server_interceptor(lambda: state) + context = Mock() + context.abort.side_effect = grpc.RpcError() + + intercepted = interceptor.intercept_service( + lambda _: _make_unary_handler(), + _HandlerCallDetails( + "/flwr.proto.ServerAppIo/PushAppOutputs", + invocation_metadata=((APP_TOKEN_HEADER, "task-token"),), + ), + ) + + with self.assertRaises(grpc.RpcError): + intercepted.unary_unary(PushAppOutputsRequest(run_id=8), context) + context.abort.assert_called_once_with( + grpc.StatusCode.PERMISSION_DENIED, + RUN_BINDING_FAILED_MESSAGE, + ) + def test_metadata_token_used_even_when_request_has_token(self) -> None: """Metadata token should be authoritative when both sources exist.""" interceptor = self._new_interceptor(token_to_run_id={"metadata-token": 5}) @@ -485,7 +560,7 @@ def test_clientappio_factory_uses_client_policy(self) -> None: response = cast( str, intercepted.unary_unary( - PushObjectRequest(object_id="obj", object_content=b"x"), + PushObjectRequest(run_id=1, object_id="obj", object_content=b"x"), Mock(), ), ) diff --git a/framework/py/flwr/supernode/runtime/run_clientapp.py b/framework/py/flwr/supernode/runtime/run_clientapp.py index 07fe7750cacb..a3282dab3899 100644 --- a/framework/py/flwr/supernode/runtime/run_clientapp.py +++ b/framework/py/flwr/supernode/runtime/run_clientapp.py @@ -234,7 +234,7 @@ def pull_appinputs(stub: ClientAppIoStub) -> tuple[Message, Context, Run, Fab]: # Pull and inflate the message pull_msg_res: PullAppMessagesResponse = stub.PullMessage( - PullAppMessagesRequest() + PullAppMessagesRequest(run_id=run.run_id) ) run_id = context.run_id node = Node(node_id=context.node_id) @@ -280,7 +280,9 @@ def push_appoutputs( # pylint: disable=R0913, R0917 # This is temporary. The message should not contain its content push_msg_res = stub.PushMessage( PushAppMessagesRequest( - messages_list=[proto_message], message_object_trees=[object_tree] + messages_list=[proto_message], + run_id=context.run_id, + message_object_trees=[object_tree], ) ) del proto_message @@ -304,7 +306,10 @@ def push_appoutputs( # pylint: disable=R0913, R0917 # Push Context res: PushAppOutputsResponse = stub.PushAppOutputs( PushAppOutputsRequest( - context=proto_context, sub_status=sub_status, details=details + run_id=context.run_id, + context=proto_context, + sub_status=sub_status, + details=details, ) ) return res diff --git a/framework/py/flwr/supernode/servicer/clientappio/clientappio_auth_integration_test.py b/framework/py/flwr/supernode/servicer/clientappio/clientappio_auth_integration_test.py index d32b6b7718f4..7268a4a55a8a 100644 --- a/framework/py/flwr/supernode/servicer/clientappio/clientappio_auth_integration_test.py +++ b/framework/py/flwr/supernode/servicer/clientappio/clientappio_auth_integration_test.py @@ -56,7 +56,8 @@ def setUp(self) -> None: state_factory = NodeStateFactory(objectstore_factory=objectstore_factory) state = state_factory.state() - token = state.create_token(99) + self.valid_run_id = 99 + token = state.create_token(self.valid_run_id) assert token is not None self.valid_token = token @@ -100,7 +101,9 @@ def tearDown(self) -> None: def test_pull_object_denied_without_metadata_token(self) -> None: """Protected RPC should deny requests missing metadata token.""" with self.assertRaises(grpc.RpcError) as err: - self._pull_object.with_call(request=PullObjectRequest(object_id="obj-1")) + self._pull_object.with_call( + request=PullObjectRequest(run_id=self.valid_run_id, object_id="obj-1") + ) assert err.exception.code() == grpc.StatusCode.UNAUTHENTICATED assert err.exception.details() == AUTHENTICATION_FAILED_MESSAGE @@ -108,7 +111,7 @@ def test_pull_object_denied_with_invalid_metadata_token(self) -> None: """Protected RPC should deny requests with invalid metadata token.""" with self.assertRaises(grpc.RpcError) as err: self._pull_object.with_call( - request=PullObjectRequest(object_id="obj-2"), + request=PullObjectRequest(run_id=self.valid_run_id, object_id="obj-2"), metadata=((APP_TOKEN_HEADER, "invalid-token"),), ) assert err.exception.code() == grpc.StatusCode.UNAUTHENTICATED @@ -117,7 +120,7 @@ def test_pull_object_denied_with_invalid_metadata_token(self) -> None: def test_pull_object_allows_with_valid_metadata_token(self) -> None: """Protected RPC should allow requests with valid metadata token.""" response, call = self._pull_object.with_call( - request=PullObjectRequest(object_id="obj-3"), + request=PullObjectRequest(run_id=self.valid_run_id, object_id="obj-3"), metadata=((APP_TOKEN_HEADER, self.valid_token),), )