From 76001271a70b4c79b3a48b25f50b18a690e59fd9 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Mon, 20 Apr 2026 19:12:07 -0700 Subject: [PATCH 01/11] fix(framework): Validate run_id --- .../interceptors/appio_token_interceptor.py | 19 ++++++++++++ .../appio_token_interceptor_test.py | 31 +++++++++++++++++-- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py index 94a95bc3b8f7..dcc950386877 100644 --- a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py +++ b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py @@ -32,6 +32,7 @@ APP_TOKEN_HEADER = "flwr-app-token" AUTHENTICATION_FAILED_MESSAGE = "Authentication failed." +RUN_BINDING_FAILED_MESSAGE = "Token is not valid for requested run." class _TokenState(Protocol): @@ -49,6 +50,11 @@ def _abort_auth_denied(context: grpc.ServicerContext) -> NoReturn: raise RuntimeError("Should not reach this point") +def _abort_run_denied(context: grpc.ServicerContext) -> NoReturn: + context.abort(grpc.StatusCode.PERMISSION_DENIED, RUN_BINDING_FAILED_MESSAGE) + raise RuntimeError("Should not reach this point") + + def _unauthenticated_terminator() -> grpc.RpcMethodHandler: def _terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage: context.abort(grpc.StatusCode.UNAUTHENTICATED, AUTHENTICATION_FAILED_MESSAGE) @@ -57,6 +63,13 @@ def _terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMess return grpc.unary_unary_rpc_method_handler(_terminate) +def _get_request_run_id(request: GrpcMessage) -> int | None: + descriptor = getattr(request, "DESCRIPTOR", None) + if descriptor is None or "run_id" not in descriptor.fields_by_name: + return None + return cast(int, getattr(request, "run_id")) + + class AppIoTokenClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore """Attach App token metadata to outbound unary RPCs.""" @@ -128,10 +141,16 @@ def _authenticated_handler( state = self._state_provider() run_id = state.get_run_id_by_token(token) + # Validate both token->run lookup and run->token mapping. if run_id is None or not state.verify_token(run_id, token): _abort_auth_denied(context) + # Validate request.run_id matches the run_id associated with the token + request_run_id: int | None = _get_request_run_id(request) + if request_run_id is not None and request_run_id != run_id: + _abort_run_denied(context) + return unary_handler(request, context) return grpc.unary_unary_rpc_method_handler( diff --git a/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py b/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py index 5a5a6f278496..dc5b3b41f4fe 100644 --- a/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py +++ b/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py @@ -45,6 +45,9 @@ create_clientappio_token_auth_server_interceptor, create_serverappio_token_auth_server_interceptor, ) +from flwr.supercore.interceptors.appio_token_interceptor import ( + RUN_BINDING_FAILED_MESSAGE, +) _ClientCallDetails = namedtuple( "_ClientCallDetails", @@ -230,8 +233,30 @@ def test_valid_token_passes_for_protected_method(self) -> None: # cross-run use is expected. response = cast(str, intercepted.unary_unary(GetNodesRequest(run_id=7), Mock())) self.assertEqual(response, "ok") - # Run-id mismatch deny coverage belongs to the - # follow-up PR that enforces run binding. + + def test_run_id_mismatch_denied_for_protected_method(self) -> None: + """Protected methods should deny token use against a different run.""" + interceptor = self._new_interceptor(token_to_run_id={"valid": 7}) + method = self._find_serverappio_method(requires_token=True) + if method is None: + self.skipTest("No token-required ServerAppIo method found in policy table.") + context = Mock() + context.abort.side_effect = grpc.RpcError() + + intercepted = interceptor.intercept_service( + lambda _: _make_unary_handler(), + _HandlerCallDetails( + method, + invocation_metadata=((APP_TOKEN_HEADER, "valid"),), + ), + ) + + with self.assertRaises(grpc.RpcError): + intercepted.unary_unary(GetNodesRequest(run_id=8), context) + context.abort.assert_called_once_with( + grpc.StatusCode.PERMISSION_DENIED, + RUN_BINDING_FAILED_MESSAGE, + ) def test_metadata_token_used_even_when_request_has_token(self) -> None: """Metadata token should be authoritative when both sources exist.""" @@ -424,7 +449,7 @@ def test_clientappio_factory_uses_client_policy(self) -> None: response = cast( str, intercepted.unary_unary( - PushObjectRequest(object_id="obj", object_content=b"x"), + PushObjectRequest(run_id=1, object_id="obj", object_content=b"x"), Mock(), ), ) From b52b2a8e91119c24c50a4004299b9ef17bbbf148 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Mon, 20 Apr 2026 19:23:39 -0700 Subject: [PATCH 02/11] Pass missing run_id --- framework/py/flwr/supernode/runtime/run_clientapp.py | 9 +++++++-- .../clientappio/clientappio_auth_integration_test.py | 11 +++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/framework/py/flwr/supernode/runtime/run_clientapp.py b/framework/py/flwr/supernode/runtime/run_clientapp.py index 61e8af73ee7b..214fdefd92cc 100644 --- a/framework/py/flwr/supernode/runtime/run_clientapp.py +++ b/framework/py/flwr/supernode/runtime/run_clientapp.py @@ -213,7 +213,7 @@ def pull_appinputs( # Pull and inflate the message pull_msg_res: PullAppMessagesResponse = stub.PullMessage( - PullAppMessagesRequest(token=token) + PullAppMessagesRequest(token=token, run_id=run.run_id) ) run_id = context.run_id node = Node(node_id=context.node_id) @@ -259,6 +259,7 @@ def push_appoutputs( PushAppMessagesRequest( token=token, messages_list=[proto_message], + run_id=context.run_id, message_object_trees=[object_tree], ) ) @@ -282,7 +283,11 @@ def push_appoutputs( # Push Context res: PushAppOutputsResponse = stub.PushAppOutputs( - PushAppOutputsRequest(token=token, context=proto_context) + PushAppOutputsRequest( + token=token, + run_id=context.run_id, + context=proto_context, + ) ) return res except grpc.RpcError as e: diff --git a/framework/py/flwr/supernode/servicer/clientappio/clientappio_auth_integration_test.py b/framework/py/flwr/supernode/servicer/clientappio/clientappio_auth_integration_test.py index 5b356acc361f..2b2306348b3b 100644 --- a/framework/py/flwr/supernode/servicer/clientappio/clientappio_auth_integration_test.py +++ b/framework/py/flwr/supernode/servicer/clientappio/clientappio_auth_integration_test.py @@ -56,7 +56,8 @@ def setUp(self) -> None: state_factory = NodeStateFactory(objectstore_factory=objectstore_factory) state = state_factory.state() - token = state.create_token(99) + self.valid_run_id = 99 + token = state.create_token(self.valid_run_id) assert token is not None self.valid_token = token @@ -100,7 +101,9 @@ def tearDown(self) -> None: def test_pull_object_denied_without_metadata_token(self) -> None: """Protected RPC should deny requests missing metadata token.""" with self.assertRaises(grpc.RpcError) as err: - self._pull_object.with_call(request=PullObjectRequest(object_id="obj-1")) + self._pull_object.with_call( + request=PullObjectRequest(run_id=self.valid_run_id, object_id="obj-1") + ) assert err.exception.code() == grpc.StatusCode.UNAUTHENTICATED assert err.exception.details() == AUTHENTICATION_FAILED_MESSAGE @@ -108,7 +111,7 @@ def test_pull_object_denied_with_invalid_metadata_token(self) -> None: """Protected RPC should deny requests with invalid metadata token.""" with self.assertRaises(grpc.RpcError) as err: self._pull_object.with_call( - request=PullObjectRequest(object_id="obj-2"), + request=PullObjectRequest(run_id=self.valid_run_id, object_id="obj-2"), metadata=((APP_TOKEN_HEADER, "invalid-token"),), ) assert err.exception.code() == grpc.StatusCode.UNAUTHENTICATED @@ -117,7 +120,7 @@ def test_pull_object_denied_with_invalid_metadata_token(self) -> None: def test_pull_object_allows_with_valid_metadata_token(self) -> None: """Protected RPC should allow requests with valid metadata token.""" response, call = self._pull_object.with_call( - request=PullObjectRequest(object_id="obj-3"), + request=PullObjectRequest(run_id=self.valid_run_id, object_id="obj-3"), metadata=((APP_TOKEN_HEADER, self.valid_token),), ) From 1ea32706a1b69e3a42e25549e2656faeb2d86473 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Mon, 20 Apr 2026 19:24:44 -0700 Subject: [PATCH 03/11] Update framework/py/flwr/supercore/interceptors/appio_token_interceptor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../py/flwr/supercore/interceptors/appio_token_interceptor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py index dcc950386877..2fde515a5897 100644 --- a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py +++ b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py @@ -141,7 +141,6 @@ def _authenticated_handler( state = self._state_provider() run_id = state.get_run_id_by_token(token) - # Validate both token->run lookup and run->token mapping. if run_id is None or not state.verify_token(run_id, token): _abort_auth_denied(context) From ac207b51313a4867b4d887b0587ab34598d5b38b Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Mon, 20 Apr 2026 19:58:07 -0700 Subject: [PATCH 04/11] Fix --- .../py/flwr/supercore/interceptors/appio_token_interceptor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py index dcc950386877..afa4db63eda6 100644 --- a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py +++ b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py @@ -67,7 +67,7 @@ def _get_request_run_id(request: GrpcMessage) -> int | None: descriptor = getattr(request, "DESCRIPTOR", None) if descriptor is None or "run_id" not in descriptor.fields_by_name: return None - return cast(int, getattr(request, "run_id")) + return cast(int, request.run_id) class AppIoTokenClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore @@ -141,7 +141,7 @@ def _authenticated_handler( state = self._state_provider() run_id = state.get_run_id_by_token(token) - + # Validate both token->run lookup and run->token mapping. if run_id is None or not state.verify_token(run_id, token): _abort_auth_denied(context) From b030fdb137f9b90029db2c403a322aa19d4937da Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Mon, 20 Apr 2026 20:24:38 -0700 Subject: [PATCH 05/11] Fix --- .../py/flwr/supercore/interceptors/appio_token_interceptor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py index afa4db63eda6..95149d0fa709 100644 --- a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py +++ b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py @@ -67,7 +67,8 @@ def _get_request_run_id(request: GrpcMessage) -> int | None: descriptor = getattr(request, "DESCRIPTOR", None) if descriptor is None or "run_id" not in descriptor.fields_by_name: return None - return cast(int, request.run_id) + request_any = cast(Any, request) + return cast(int, request_any.run_id) class AppIoTokenClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore From d39ce8de7ad7a81bb500ecd1112df6cd893456d0 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 21 Apr 2026 18:51:04 +0100 Subject: [PATCH 06/11] fix serverappio servicer test --- .../serverappio/serverappio_servicer_test.py | 69 +++++++++---------- 1 file changed, 33 insertions(+), 36 deletions(-) diff --git a/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py b/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py index 26534761b05c..8d34d9a97478 100644 --- a/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py +++ b/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py @@ -316,6 +316,7 @@ def setUp(self) -> None: "mock_owner", "fake_name", self.node_pk, 30 ) self.state.acknowledge_node_heartbeat(self.node_id, 1e3) + self._run_id_to_token: dict[int, str] = {} self.status_to_msg = _STATUS_TO_MSG @@ -327,23 +328,10 @@ def setUp(self) -> None: superexec_auth_secret=_SUPEREXEC_SECRET, ) - # Provide a valid metadata token on the default test channel so existing - # servicer behavior tests continue to exercise business logic paths. - self._auth_run_id = self.state.create_run( - "", "", "", {}, NOOP_FEDERATION, None, "", RunType.SERVER_APP - ) - auth_token = self.state.create_token(self._auth_run_id) - assert auth_token is not None - self._auth_token = auth_token - _ = self.state.update_run_status( - self._auth_run_id, RunStatus(Status.STARTING, "", "") - ) - _ = self.state.update_run_status( - self._auth_run_id, RunStatus(Status.RUNNING, "", "") - ) + self._appio_token_client_interceptor = AppIoTokenClientInterceptor(token="") self._channel = grpc.intercept_channel( grpc.insecure_channel("localhost:9091"), - AppIoTokenClientInterceptor(token=self._auth_token), + self._appio_token_client_interceptor, SuperExecAuthClientInterceptor( master_secret=_SUPEREXEC_SECRET, protected_methods=SERVERAPPIO_SUPEREXEC_METHODS, @@ -422,7 +410,9 @@ def _transition_run_status(self, run_id: int, num_transitions: int) -> None: if num_transitions > 2: _ = self.state.update_run_status(run_id, RunStatus(Status.FINISHED, "", "")) - def _create_dummy_run(self, running: bool = True, *, fab_hash: str = "") -> int: + def _create_dummy_run( + self, running: bool = True, *, fab_hash: str = "", create_token: bool = True + ) -> int: run_id = self.state.create_run( "", "", @@ -433,10 +423,24 @@ def _create_dummy_run(self, running: bool = True, *, fab_hash: str = "") -> int: "", RunType.SERVER_APP, ) + + # Set token in the client interceptor to pass authentication + if create_token: + assert (token := self.state.create_token(run_id)) is not None + self._set_token(token) + self._run_id_to_token[run_id] = token + + # Transition run status if required if running: self._transition_run_status(run_id, 2) return run_id + def _get_token(self, run_id: int) -> str: + return self._run_id_to_token[run_id] + + def _set_token(self, token: str) -> None: + self._appio_token_client_interceptor._token = token # pylint: disable=W0212 + def test_successful_get_node_if_running(self) -> None: """Test `GetNode` success.""" # Prepare @@ -694,6 +698,8 @@ def test_pull_message_from_expired_message_error(self) -> None: of an Error message created by the LinkState due to an expired TTL.""" # Prepare run_id = self._create_dummy_run() + token = self._get_token(run_id) + self.state.acknowledge_app_heartbeat(token) # Push Messages and reply message_ins = message_from_proto( @@ -701,6 +707,7 @@ def test_pull_message_from_expired_message_error(self) -> None: src_node_id=SUPERLINK_NODE_ID, dst_node_id=self.node_id, run_id=run_id ) ) + message_ins.metadata.ttl = 1 # Use short message TTL for testing msg_id = self.state.store_message_ins(message=message_ins) # Simulate situation where the message has expired in the LinkState @@ -708,21 +715,10 @@ def test_pull_message_from_expired_message_error(self) -> None: future_dt = now() + timedelta(seconds=message_ins.metadata.ttl + 0.1) with patch("datetime.datetime") as mock_dt: mock_dt.now.return_value = future_dt # over TTL limit - - token = self.state.create_token(run_id) - assert token is not None request = PullAppMessagesRequest(message_ids=[str(msg_id)], run_id=run_id) - pull_messages_plain = grpc.insecure_channel("localhost:9091").unary_unary( - "/flwr.proto.ServerAppIo/PullMessages", - request_serializer=PullAppMessagesRequest.SerializeToString, - response_deserializer=PullAppMessagesResponse.FromString, - ) # Execute - response, call = pull_messages_plain.with_call( - request=request, - metadata=((APP_TOKEN_HEADER, token),), - ) + response, call = self._pull_messages.with_call(request=request) # Assert assert isinstance(response, PullAppMessagesResponse) @@ -742,8 +738,7 @@ def test_push_serverapp_outputs_successful_if_running(self) -> None: """Test `PushServerAppOutputs` success.""" # Prepare run_id = self._create_dummy_run(running=False) - token = self.state.create_token(run_id) - assert token is not None + token = self._get_token(run_id) maker = RecordMaker() context = Context( @@ -797,8 +792,7 @@ def test_push_serverapp_outputs_not_successful_if_not_running( """Test `PushServerAppOutputs` not successful if RunStatus is not running.""" # Prepare run_id = self._create_dummy_run(running=False) - token = self.state.create_token(run_id) - assert token is not None + token = self._get_token(run_id) maker = RecordMaker() context = Context( @@ -870,7 +864,7 @@ def test_update_run_status_not_successful_if_finished(self) -> None: def test_send_app_heartbeat(self, success: bool) -> None: """Test sending an app heartbeat.""" # Prepare - token = "test-token" + token = self._get_token(self._create_dummy_run()) request = SendAppHeartbeatRequest(token=token) mock_ack_method = Mock(return_value=success) self.state.acknowledge_app_heartbeat = mock_ack_method # type: ignore @@ -1067,7 +1061,7 @@ def test_list_apps_to_launch(self) -> None: def test_request_token(self) -> None: """Test `RequestToken`.""" # Prepare - run_id = self._create_dummy_run(running=False) + run_id = self._create_dummy_run(running=False, create_token=False) # Execute request = RequestTokenRequest(run_id=run_id) @@ -1087,7 +1081,7 @@ def test_request_token(self) -> None: def test_request_token_fail_closed_for_finished_run(self) -> None: """Ensure `RequestToken` returns empty token for finished runs.""" # Prepare - run_id = self._create_dummy_run(running=False) + run_id = self._create_dummy_run(running=False, create_token=False) self._transition_run_status(run_id, 2) assert self.state.update_run_status( run_id, @@ -1124,7 +1118,9 @@ def test_run_status_transitions(self) -> None: fab_hash = self.state.store_fab( Fab(hashlib.sha256(fab_content).hexdigest(), fab_content, {}) ) - run_id = self._create_dummy_run(running=False, fab_hash=fab_hash) + run_id = self._create_dummy_run( + running=False, fab_hash=fab_hash, create_token=False + ) # Set serverapp context context = Context(run_id, SUPERLINK_NODE_ID, {}, RecordDict(), {}) @@ -1134,6 +1130,7 @@ def test_run_status_transitions(self) -> None: token_request = RequestTokenRequest(run_id=run_id) token_response, call = self._request_token.with_call(request=token_request) token = token_response.token + self._set_token(token) # Assert: Response is successful and run status is STARTING assert isinstance(token_response, RequestTokenResponse) From b5af48a6bfb84c7e06394ac654ccb8cdad5e7582 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Mon, 4 May 2026 15:00:46 +0200 Subject: [PATCH 07/11] ci(framework): Add framework/AGENTS.md --- framework/AGENTS.md | 199 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 framework/AGENTS.md diff --git a/framework/AGENTS.md b/framework/AGENTS.md new file mode 100644 index 000000000000..6f9401c5384b --- /dev/null +++ b/framework/AGENTS.md @@ -0,0 +1,199 @@ +# AGENTS.md + +These instructions apply to the `framework/` subtree. Also follow the repository +root `AGENTS.md`. + +## Project shape + +- `py/flwr/` is the Python package. It is a typed package (`py.typed`) and the + main source of truth for framework behavior. +- Unit tests live next to the code as `*_test.py` under `py/flwr/`. Follow nearby + test style; many tests use `pytest` parametrization, some older tests use + `unittest`. Prefer `pytest` for new tests. +- `proto/flwr/proto/` contains protobuf sources. Generated Python protobuf files + live under `py/flwr/proto/`. +- `py/flwr/supercore/state/schema/` contains SQLAlchemy Core table metadata. + Alembic revisions live in `py/flwr/supercore/state/alembic/versions/`. +- `docs/source/` contains Sphinx docs. `docs/build/` is generated output. +- `e2e/` contains standalone Flower apps and shell scripts used by CI. +- `docker/`, `swift/`, `kotlin/`, and `cc/` contain non-Python packaging/client + surfaces. Keep Python framework changes separate from those unless the behavior + actually crosses the boundary. +- Avoid editing or committing generated/cache outputs such as `.venv/`, + `.mypy_cache/`, `.pytest_cache/`, `.ruff_cache/`, `docs/build/`, and `dist/`. + +## Environment and commands + +Run framework commands from `framework/` unless a command is explicitly shown from +the repository root. The checked-in `.python-version` is `3.10.19`; CI also checks +newer Python versions, but project tooling targets Python 3.10 syntax. + +Preferred local agent form: + +```bash +cd framework +uv run --no-sync --python=3.10.19 +``` + +If the environment is missing or dependencies changed, synchronize intentionally: + +```bash +cd framework +uv sync --python=3.10.19 --all-extras --all-groups +``` + +CI installs with Poetry, so when reproducing CI setup exactly use: + +```bash +cd framework +python -m poetry install --all-extras +``` + +Useful checks: + +```bash +# Fast package-only quality gate, skips e2e/docs/copyright extras in dev/test.sh +uv run --no-sync --python=3.10.19 ./dev/test.sh false + +# Full framework quality gate used by pre-commit/CI +uv run --no-sync --python=3.10.19 ./dev/test.sh + +# Narrow tests +uv run --no-sync --python=3.10.19 python -m pytest py/flwr/path/to_test.py +uv run --no-sync --python=3.10.19 python -m pytest py/flwr -k "name" + +# Targeted type/lint checks +uv run --no-sync --python=3.10.19 python -m mypy py +uv run --no-sync --python=3.10.19 python -m ruff check py/flwr --no-respect-gitignore +uv run --no-sync --python=3.10.19 python -m pylint --ignore=py/flwr/proto py/flwr +``` + +`dev/test.sh` sets `RAY_ENABLE_UV_RUN_RUNTIME_ENV=0` for pytest because Ray's uv +runtime-env hook can stall under `uv run`. Use the same environment variable when +debugging Ray/simulation tests directly. + +Format broadly only when appropriate: + +```bash +uv run --no-sync --python=3.10.19 ./dev/format.sh +``` + +For narrow edits, prefer targeted `isort`, `black`, `docformatter`, and `ruff` +commands on the touched files. Keep `py/flwr/proto/` excluded from Python +formatters because it is generated. + +## Python conventions + +- Keep core `flwr` framework code ML-framework-agnostic. Do not add PyTorch, + TensorFlow, JAX, sklearn, or similar dependencies to the core package for a + narrow feature; e2e apps and examples carry framework-specific dependencies. +- Use explicit, typed interfaces compatible with strict mypy. Avoid untyped + helpers unless nearby code already establishes the pattern. +- Use NumPy-style docstrings for public classes/functions. `ruff` enforces + pydocstyle with the NumPy convention, and `docsig` checks signatures. +- Keep imports consistent with `isort` and `black` line length 88. The project + commonly uses `from __future__ import annotations` in newer modules. +- New source files should include the Flower Apache license header matching + nearby files. +- Prefer existing helpers for logging, exit handling, serialization, config + parsing, and CLI errors. Do not introduce parallel utility layers without a + concrete need. + +## Public API rules + +Flower's Python public API is defined by recursively following `__all__` from +`flwr/__init__.py`. See `docs/source/contributor-explanation-public-and-private-apis.rst`. + +- Adding a public symbol usually requires importing it in the relevant + `__init__.py`, adding it to `__all__`, adding/updating API docs under + `docs/source/ref-api/`, and adding tests for the public import path. +- Do not expose implementation modules accidentally. Prefer + `from .module import Name as Name` plus `__all__ = ["Name"]`, as nearby code + does. +- Treat existing public behavior as compatibility-sensitive. If changing or + removing public APIs, add deprecation handling/tests instead of hard breaks + unless the task explicitly calls for a breaking change. + +## Protobufs + +- Edit `.proto` files in `proto/flwr/proto/`, not generated files in + `py/flwr/proto/`. +- Regenerate Python protobuf outputs with: + +```bash +uv run --no-sync --python=3.10.19 ./dev/protoc.sh +``` + +- CI runs `./framework/dev/check-protos.sh` from the repository root. That script + reruns generation and fails if generated files differ from `HEAD`, so in a + local dirty worktree it can report expected uncommitted generated changes. + Use it as a clean-tree/CI parity check. +- Wire-format changes need serialization/deserialization tests, usually near + `py/flwr/common/serde_test.py` or the module-specific test. + +## Database schema and migrations + +Use the Alembic generator for schema diffs. Do not hand-write a new migration for +a normal SQLAlchemy metadata change. + +```bash +cd framework +uv run --no-sync --python=3.10.19 python -m dev.generate_migration "Describe schema change" +``` + +After generation: + +- Confirm the new revision's `down_revision` and branch target are correct. +- Confirm generated operations match the SQLAlchemy metadata change. +- Review SQLite compatibility, especially `batch_alter_table` blocks. +- Update `py/flwr/supercore/state/schema/README.md` when table metadata changes. + `dev/format.sh` regenerates this schema documentation through `paracelsus`. +- Run the migration check when schema work is involved: + +```bash +uv run --no-sync --python=3.10.19 ./dev/check-migrations.sh +``` + +## Tests and e2e + +- Put focused unit coverage next to the changed module. Prefer a narrow pytest + command first, then broader checks. +- For CLI changes, add tests around Typer command parsing and removed/deprecated + flags. Existing tests under `py/flwr/cli/`, `py/flwr/supernode/cli/`, and + `py/flwr/supercore/cli/` show the expected patterns. +- For SuperLink/SuperNode/state changes, consider targeted tests under + `py/flwr/server/superlink/`, `py/flwr/supernode/`, and `py/flwr/supercore/`. +- E2E scripts often mutate app-local `pyproject.toml`, generate certs, create + sqlite DB files, and start background processes. Do not run them in an e2e app + directory with unrelated local edits. +- Common e2e commands, from an app directory such as `framework/e2e/e2e-bare/`: + +```bash +python simulation.py +./../test_superlink.sh e2e-bare +./../test_superlink.sh e2e-bare rest +./../test_superlink.sh e2e-bare sqlite +``` + +## Docs, packaging, and locks + +- Build docs with: + +```bash +uv run --no-sync --python=3.10.19 ./dev/build-docs.sh +``` + +Docs builds require system `pandoc`. Do not commit `docs/build/` output. + +- Build and check release artifacts with: + +```bash +uv run --no-sync --python=3.10.19 ./dev/build.sh +uv run --no-sync --python=3.10.19 ./dev/test-wheel.sh +``` + +Do not commit `dist/` artifacts. + +- If dependency constraints change, update `pyproject.toml`, `uv.lock`, and + `poetry.lock` intentionally. CI checks `uv.lock` freshness, while framework CI + still installs with Poetry. From 83a67e26740b2c664627a09a3baf93c72c6ba12d Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Mon, 4 May 2026 18:43:03 +0200 Subject: [PATCH 08/11] Update test --- .../superlink/serverappio/serverappio_servicer_test.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py b/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py index 2ff8d01b16c5..a3542e58fd5e 100644 --- a/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py +++ b/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py @@ -566,11 +566,18 @@ def test_create_task_rejects_missing_required_fields( def test_create_task_rejects_missing_run(self) -> None: """Test `CreateTask` rejects unknown run IDs.""" + # Seed a stale token binding so auth passes and the servicer still + # exercises the missing-run branch. + run_id = 42 + token = self.state.create_token(run_id) + assert token is not None + self._set_token(token) + with self.assertRaises(grpc.RpcError) as err: self._create_task.with_call( request=CreateTaskRequest( type=TaskType.MODEL, - run_id=42, + run_id=run_id, model_ref="model://test", ) ) From 5e0c5899f3ee1fb627633ed35fa931eae1aae591 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Mon, 4 May 2026 18:55:47 +0200 Subject: [PATCH 09/11] Update framework/py/flwr/supercore/interceptors/appio_token_interceptor.py Co-authored-by: Heng Pan --- .../flwr/supercore/interceptors/appio_token_interceptor.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py index 95149d0fa709..e01d7c3fa4ae 100644 --- a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py +++ b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py @@ -64,11 +64,7 @@ def _terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMess def _get_request_run_id(request: GrpcMessage) -> int | None: - descriptor = getattr(request, "DESCRIPTOR", None) - if descriptor is None or "run_id" not in descriptor.fields_by_name: - return None - request_any = cast(Any, request) - return cast(int, request_any.run_id) + return cast(int, getattr(request, "run_id", None)) class AppIoTokenClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore From a0531d0f70b57f7087ffd228fe2683bf0a0460f0 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Thu, 7 May 2026 21:17:13 +0200 Subject: [PATCH 10/11] Simplfy tests --- .../serverappio/serverappio_servicer_test.py | 91 ++++--------------- 1 file changed, 17 insertions(+), 74 deletions(-) diff --git a/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py b/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py index 888dfa012cb2..cc4afb3e85ff 100644 --- a/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py +++ b/framework/py/flwr/server/superlink/serverappio/serverappio_servicer_test.py @@ -33,11 +33,9 @@ SERVERAPPIO_API_DEFAULT_SERVER_ADDRESS, SUPERLINK_NODE_ID, Status, - SubStatus, ) from flwr.common.message import get_message_to_descendant_id_mapping from flwr.common.serde import context_to_proto, message_from_proto -from flwr.common.serde_test import RecordMaker from flwr.common.typing import Fab, RunStatus from flwr.proto.appio_pb2 import ( # pylint: disable=E0611 ClaimTaskRequest, @@ -323,7 +321,6 @@ def setUp(self) -> None: "mock_owner", "fake_name", self.node_pk, 30 ) self.state.acknowledge_node_heartbeat(self.node_id, 1e3) - self._run_id_to_token: dict[int, str] = {} self.status_to_msg = _STATUS_TO_MSG @@ -407,9 +404,7 @@ def _transition_run_status(self, run_id: int, num_transitions: int) -> None: if num_transitions > 2: _ = self.state.update_run_status(run_id, RunStatus(Status.FINISHED, "", "")) - def _create_dummy_run( - self, running: bool = True, *, fab_hash: str = "", create_token: bool = True - ) -> int: + def _create_dummy_run(self, running: bool = True, *, fab_hash: str = "") -> int: run_id = self.state.create_run( "", "", @@ -422,30 +417,17 @@ def _create_dummy_run( ) # Set token in the client interceptor to pass authentication - if create_token: - assert (token := self.state.create_token(run_id)) is not None - self._set_token(token) - self._run_id_to_token[run_id] = token + assert (token := self.state.create_token(run_id)) is not None + self._set_token(token) # Transition run status if required if running: self._transition_run_status(run_id, 2) return run_id - def _get_token(self, run_id: int) -> str: - return self._run_id_to_token[run_id] - def _set_token(self, token: str) -> None: self._appio_token_client_interceptor._token = token # pylint: disable=W0212 - def _create_claimed_task_token(self, run_id: int) -> tuple[int, str]: - task_id = self.state.create_task(task_type=TaskType.SERVER_APP, run_id=run_id) - assert task_id is not None - token = self.state.claim_task(task_id) - assert token - self._set_token(token) - return task_id, token - def test_successful_get_node_if_running(self) -> None: """Test `GetNode` success.""" # Prepare @@ -822,8 +804,6 @@ def test_pull_message_from_expired_message_error(self) -> None: of an Error message created by the LinkState due to an expired TTL.""" # Prepare run_id = self._create_dummy_run() - token = self._get_token(run_id) - self.state.acknowledge_app_heartbeat(token) # Push Messages and reply message_ins = message_from_proto( @@ -831,7 +811,7 @@ def test_pull_message_from_expired_message_error(self) -> None: src_node_id=SUPERLINK_NODE_ID, dst_node_id=self.node_id, run_id=run_id ) ) - message_ins.metadata.ttl = 1 # Use short message TTL for testing + message_ins.metadata.ttl = 1 # set short TTL for testing msg_id = self.state.store_message_ins(message=message_ins) # Simulate situation where the message has expired in the LinkState @@ -839,9 +819,9 @@ def test_pull_message_from_expired_message_error(self) -> None: future_dt = now() + timedelta(seconds=message_ins.metadata.ttl + 0.1) with patch("datetime.datetime") as mock_dt: mock_dt.now.return_value = future_dt # over TTL limit - request = PullAppMessagesRequest(message_ids=[str(msg_id)], run_id=run_id) # Execute + request = PullAppMessagesRequest(message_ids=[str(msg_id)], run_id=run_id) response, call = self._pull_messages.with_call(request=request) # Assert @@ -858,58 +838,21 @@ def test_pull_message_from_expired_message_error(self) -> None: # expected a single object id (that of the error message) assert list(object_ids_in_response) == [msg_res.object_id] - def test_push_serverapp_outputs_successful_if_running(self) -> None: - """Test `PushServerAppOutputs` success.""" - # Prepare - run_id = self._create_dummy_run(running=False, create_token=False) - task_id, _ = self._create_claimed_task_token(run_id) - assert self.state.activate_task(task_id) - - maker = RecordMaker() - context = Context( - run_id=run_id, - node_id=0, - node_config=maker.user_config(), - state=maker.recorddict(1, 1, 1), - run_config=maker.user_config(), - ) - - # Keep run status aligned with the claimed task lifecycle. - self._transition_run_status(run_id, 2) + def _assert_push_serverapp_outputs_not_allowed( + self, token: str, context: Context + ) -> None: + """Assert `PushServerAppOutputs` not allowed.""" + run_id = self.state.get_run_id_by_token(token) + assert run_id is not None, "Invalid token is provided." + run_status = self.state.get_run_status({run_id})[run_id] request = PushAppOutputsRequest( - run_id=run_id, - context=context_to_proto(context), - sub_status=SubStatus.COMPLETED, - details="", + token=token, run_id=run_id, context=context_to_proto(context) ) - # Execute - response, call = self._push_serverapp_outputs.with_call(request=request) - - # Assert - assert isinstance(response, PushAppOutputsResponse) - assert grpc.StatusCode.OK == call.code() - task = self.state.get_tasks(task_ids=[task_id])[0] - self.assertEqual(task.status.status, Status.FINISHED) - self.assertEqual(task.status.sub_status, SubStatus.COMPLETED) - - @parameterized.expand([(True,), (False,)]) # type: ignore - def test_send_task_heartbeat(self, success: bool) -> None: - """Test sending a task heartbeat.""" - # Prepare - run_id = self._create_dummy_run(create_token=False) - task_id, _ = self._create_claimed_task_token(run_id) - request = SendTaskHeartbeatRequest() - mock_ack_method = Mock(return_value=success) - self.state.acknowledge_task_heartbeat = mock_ack_method # type: ignore - - # Execute - response, _ = self._send_task_heartbeat.with_call(request=request) - - # Assert - self.assertIsInstance(response, SendTaskHeartbeatResponse) - self.assertEqual(response.success, success) - mock_ack_method.assert_called_once_with(task_id) + with self.assertRaises(grpc.RpcError) as e: + self._push_serverapp_outputs.with_call(request=request) + assert e.exception.code() == grpc.StatusCode.PERMISSION_DENIED + assert e.exception.details() == self.status_to_msg[run_status.status] def test_push_object_succesful(self) -> None: """Test `PushObject`.""" From 89a20cb90263a47c8c7213efff36161b3b6372c2 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Thu, 7 May 2026 22:19:06 +0200 Subject: [PATCH 11/11] Handle omitted run_id --- .../interceptors/appio_token_interceptor.py | 6 ++- .../appio_token_interceptor_test.py | 50 +++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py index 572a65230ce5..891e70f9dd8c 100644 --- a/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py +++ b/framework/py/flwr/supercore/interceptors/appio_token_interceptor.py @@ -72,7 +72,11 @@ def _terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMess def _get_request_run_id(request: GrpcMessage) -> int | None: - return cast(int, getattr(request, "run_id", None)) + run_id = cast(int | None, getattr(request, "run_id", None)) + # Proto3 scalar fields without presence expose an omitted uint64 as 0. + if run_id == 0: + return None + return run_id class AppIoTokenClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore diff --git a/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py b/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py index 19df0db7f78e..2a2ce5f6c303 100644 --- a/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py +++ b/framework/py/flwr/supercore/interceptors/appio_token_interceptor_test.py @@ -315,6 +315,56 @@ def _handler(_request: GrpcMessage, _context: grpc.ServicerContext) -> str: state.get_run_id_by_token.assert_called_once_with("task-token") state.get_task_by_token.assert_called_once_with("task-token") + def test_task_token_passes_when_run_id_omitted(self) -> None: + """Task-token protected methods should accept omitted request run IDs.""" + state = Mock() + state.get_run_id_by_token.return_value = None + state.get_task_by_token.return_value = Mock(task_id=123, run_id=7) + interceptor = create_serverappio_token_auth_server_interceptor(lambda: state) + captured_task = None + + def _handler(_request: GrpcMessage, _context: grpc.ServicerContext) -> str: + nonlocal captured_task + captured_task = get_authenticated_task() + return "ok" + + intercepted = interceptor.intercept_service( + lambda _: grpc.unary_unary_rpc_method_handler(_handler), + _HandlerCallDetails( + "/flwr.proto.ServerAppIo/PushAppOutputs", + invocation_metadata=((APP_TOKEN_HEADER, "task-token"),), + ), + ) + + response = intercepted.unary_unary(PushAppOutputsRequest(), Mock()) + self.assertEqual(response, "ok") + self.assertIsNotNone(captured_task) + self.assertEqual(cast(Mock, captured_task).task_id, 123) + + def test_task_token_run_id_mismatch_denied(self) -> None: + """Task-token protected methods should deny a different request run.""" + state = Mock() + state.get_run_id_by_token.return_value = None + state.get_task_by_token.return_value = Mock(task_id=123, run_id=7) + interceptor = create_serverappio_token_auth_server_interceptor(lambda: state) + context = Mock() + context.abort.side_effect = grpc.RpcError() + + intercepted = interceptor.intercept_service( + lambda _: _make_unary_handler(), + _HandlerCallDetails( + "/flwr.proto.ServerAppIo/PushAppOutputs", + invocation_metadata=((APP_TOKEN_HEADER, "task-token"),), + ), + ) + + with self.assertRaises(grpc.RpcError): + intercepted.unary_unary(PushAppOutputsRequest(run_id=8), context) + context.abort.assert_called_once_with( + grpc.StatusCode.PERMISSION_DENIED, + RUN_BINDING_FAILED_MESSAGE, + ) + def test_metadata_token_used_even_when_request_has_token(self) -> None: """Metadata token should be authoritative when both sources exist.""" interceptor = self._new_interceptor(token_to_run_id={"metadata-token": 5})