diff --git a/framework/proto/flwr/proto/run.proto b/framework/proto/flwr/proto/run.proto index 064d4787cc69..9c0c9806be4b 100644 --- a/framework/proto/flwr/proto/run.proto +++ b/framework/proto/flwr/proto/run.proto @@ -38,6 +38,7 @@ message Run { uint64 bytes_recv = 14; double clientapp_runtime = 15; string run_type = 16; + optional uint64 primary_task_id = 17; } message RunStatus { diff --git a/framework/py/flwr/common/serde.py b/framework/py/flwr/common/serde.py index 5a2cade30b56..1f65077bbf4c 100644 --- a/framework/py/flwr/common/serde.py +++ b/framework/py/flwr/common/serde.py @@ -639,6 +639,8 @@ def run_to_proto(run: typing.Run) -> ProtoRun: clientapp_runtime=run.clientapp_runtime, run_type=run.run_type, ) + if run.primary_task_id is not None: + proto.primary_task_id = run.primary_task_id return proto @@ -657,6 +659,9 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run: status=run_status_from_proto(run_proto.status), flwr_aid=run_proto.flwr_aid, federation=run_proto.federation, + primary_task_id=( + run_proto.primary_task_id if run_proto.HasField("primary_task_id") else None + ), bytes_sent=run_proto.bytes_sent, bytes_recv=run_proto.bytes_recv, clientapp_runtime=run_proto.clientapp_runtime, diff --git a/framework/py/flwr/common/serde_test.py b/framework/py/flwr/common/serde_test.py index 44766f23ca33..3eed5324bbbf 100644 --- a/framework/py/flwr/common/serde_test.py +++ b/framework/py/flwr/common/serde_test.py @@ -486,6 +486,7 @@ def test_run_serialization_deserialization() -> None: status=typing.RunStatus(status="running", sub_status="", details="OK"), flwr_aid="user123", federation="mock-fed", + primary_task_id=42, bytes_sent=2048, bytes_recv=1024, clientapp_runtime=3.14, diff --git a/framework/py/flwr/common/typing.py b/framework/py/flwr/common/typing.py index f9efaf313948..e8f43820d4e5 100644 --- a/framework/py/flwr/common/typing.py +++ b/framework/py/flwr/common/typing.py @@ -237,6 +237,7 @@ class Run: # pylint: disable=too-many-instance-attributes status: RunStatus flwr_aid: str federation: str + primary_task_id: int | None bytes_sent: int bytes_recv: int clientapp_runtime: float @@ -258,6 +259,7 @@ def create_empty(cls, run_id: int) -> "Run": status=RunStatus(status="", sub_status="", details=""), flwr_aid="", federation="", + primary_task_id=None, bytes_sent=0, bytes_recv=0, clientapp_runtime=0.0, diff --git a/framework/py/flwr/proto/run_pb2.py b/framework/py/flwr/proto/run_pb2.py index fd944a83bb6e..25154b108187 100644 --- a/framework/py/flwr/proto/run_pb2.py +++ b/framework/py/flwr/proto/run_pb2.py @@ -27,7 +27,7 @@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xc9\x03\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x12\x12\n\npending_at\x18\x06 \x01(\t\x12\x13\n\x0bstarting_at\x18\x07 \x01(\t\x12\x12\n\nrunning_at\x18\x08 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\t \x01(\t\x12%\n\x06status\x18\n \x01(\x0b\x32\x15.flwr.proto.RunStatus\x12\x10\n\x08\x66lwr_aid\x18\x0b \x01(\t\x12\x12\n\nfederation\x18\x0c \x01(\t\x12\x12\n\nbytes_sent\x18\r \x01(\x04\x12\x12\n\nbytes_recv\x18\x0e \x01(\x04\x12\x19\n\x11\x63lientapp_runtime\x18\x0f \x01(\x01\x12\x10\n\x08run_type\x18\x10 \x01(\t\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"?\n\rGetRunRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"-\n\x1bGetFederationOptionsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"T\n\x1cGetFederationOptionsResponse\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x01 \x01(\x0b\x32\x18.flwr.proto.ConfigRecordb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1b\x66lwr/proto/recorddict.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xfb\x03\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x12\x10\n\x08\x66\x61\x62_hash\x18\x05 \x01(\t\x12\x12\n\npending_at\x18\x06 \x01(\t\x12\x13\n\x0bstarting_at\x18\x07 \x01(\t\x12\x12\n\nrunning_at\x18\x08 \x01(\t\x12\x13\n\x0b\x66inished_at\x18\t \x01(\t\x12%\n\x06status\x18\n \x01(\x0b\x32\x15.flwr.proto.RunStatus\x12\x10\n\x08\x66lwr_aid\x18\x0b \x01(\t\x12\x12\n\nfederation\x18\x0c \x01(\t\x12\x12\n\nbytes_sent\x18\r \x01(\x04\x12\x12\n\nbytes_recv\x18\x0e \x01(\x04\x12\x19\n\x11\x63lientapp_runtime\x18\x0f \x01(\x01\x12\x10\n\x08run_type\x18\x10 \x01(\t\x12\x1c\n\x0fprimary_task_id\x18\x11 \x01(\x04H\x00\x88\x01\x01\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x42\x12\n\x10_primary_task_id\"@\n\tRunStatus\x12\x0e\n\x06status\x18\x01 \x01(\t\x12\x12\n\nsub_status\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"?\n\rGetRunRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Run\"S\n\x16UpdateRunStatusRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12)\n\nrun_status\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RunStatus\"\x19\n\x17UpdateRunStatusResponse\"-\n\x1bGetFederationOptionsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"T\n\x1cGetFederationOptionsResponse\x12\x34\n\x12\x66\x65\x64\x65ration_options\x18\x01 \x01(\x0b\x32\x18.flwr.proto.ConfigRecordb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -37,21 +37,21 @@ _globals['_RUN_OVERRIDECONFIGENTRY']._loaded_options = None _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001' _globals['_RUN']._serialized_start=117 - _globals['_RUN']._serialized_end=574 - _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=501 - _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=574 - _globals['_RUNSTATUS']._serialized_start=576 - _globals['_RUNSTATUS']._serialized_end=640 - _globals['_GETRUNREQUEST']._serialized_start=642 - _globals['_GETRUNREQUEST']._serialized_end=705 - _globals['_GETRUNRESPONSE']._serialized_start=707 - _globals['_GETRUNRESPONSE']._serialized_end=753 - _globals['_UPDATERUNSTATUSREQUEST']._serialized_start=755 - _globals['_UPDATERUNSTATUSREQUEST']._serialized_end=838 - _globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=840 - _globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=865 - _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_start=867 - _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_end=912 - _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_start=914 - _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_end=998 + _globals['_RUN']._serialized_end=624 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=531 + _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=604 + _globals['_RUNSTATUS']._serialized_start=626 + _globals['_RUNSTATUS']._serialized_end=690 + _globals['_GETRUNREQUEST']._serialized_start=692 + _globals['_GETRUNREQUEST']._serialized_end=755 + _globals['_GETRUNRESPONSE']._serialized_start=757 + _globals['_GETRUNRESPONSE']._serialized_end=803 + _globals['_UPDATERUNSTATUSREQUEST']._serialized_start=805 + _globals['_UPDATERUNSTATUSREQUEST']._serialized_end=888 + _globals['_UPDATERUNSTATUSRESPONSE']._serialized_start=890 + _globals['_UPDATERUNSTATUSRESPONSE']._serialized_end=915 + _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_start=917 + _globals['_GETFEDERATIONOPTIONSREQUEST']._serialized_end=962 + _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_start=964 + _globals['_GETFEDERATIONOPTIONSRESPONSE']._serialized_end=1048 # @@protoc_insertion_point(module_scope) diff --git a/framework/py/flwr/proto/run_pb2.pyi b/framework/py/flwr/proto/run_pb2.pyi index c4d9110db8ef..78469f0a1341 100644 --- a/framework/py/flwr/proto/run_pb2.pyi +++ b/framework/py/flwr/proto/run_pb2.pyi @@ -67,6 +67,7 @@ class Run(google.protobuf.message.Message): BYTES_RECV_FIELD_NUMBER: builtins.int CLIENTAPP_RUNTIME_FIELD_NUMBER: builtins.int RUN_TYPE_FIELD_NUMBER: builtins.int + PRIMARY_TASK_ID_FIELD_NUMBER: builtins.int run_id: builtins.int fab_id: builtins.str fab_version: builtins.str @@ -81,6 +82,7 @@ class Run(google.protobuf.message.Message): bytes_recv: builtins.int clientapp_runtime: builtins.float run_type: builtins.str + primary_task_id: builtins.int @property def override_config(self) -> google.protobuf.internal.containers.MessageMap[builtins.str, flwr.proto.transport_pb2.Scalar]: ... @property @@ -104,9 +106,11 @@ class Run(google.protobuf.message.Message): bytes_recv: builtins.int = ..., clientapp_runtime: builtins.float = ..., run_type: builtins.str = ..., + primary_task_id: builtins.int | None = ..., ) -> None: ... - def HasField(self, field_name: typing.Literal["status", b"status"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["bytes_recv", b"bytes_recv", "bytes_sent", b"bytes_sent", "clientapp_runtime", b"clientapp_runtime", "fab_hash", b"fab_hash", "fab_id", b"fab_id", "fab_version", b"fab_version", "federation", b"federation", "finished_at", b"finished_at", "flwr_aid", b"flwr_aid", "override_config", b"override_config", "pending_at", b"pending_at", "run_id", b"run_id", "run_type", b"run_type", "running_at", b"running_at", "starting_at", b"starting_at", "status", b"status"]) -> None: ... + def HasField(self, field_name: typing.Literal["_primary_task_id", b"_primary_task_id", "primary_task_id", b"primary_task_id", "status", b"status"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_primary_task_id", b"_primary_task_id", "bytes_recv", b"bytes_recv", "bytes_sent", b"bytes_sent", "clientapp_runtime", b"clientapp_runtime", "fab_hash", b"fab_hash", "fab_id", b"fab_id", "fab_version", b"fab_version", "federation", b"federation", "finished_at", b"finished_at", "flwr_aid", b"flwr_aid", "override_config", b"override_config", "pending_at", b"pending_at", "primary_task_id", b"primary_task_id", "run_id", b"run_id", "run_type", b"run_type", "running_at", b"running_at", "starting_at", b"starting_at", "status", b"status"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["_primary_task_id", b"_primary_task_id"]) -> typing.Literal["primary_task_id"] | None: ... global___Run = Run diff --git a/framework/py/flwr/server/grid/inmemory_grid_test.py b/framework/py/flwr/server/grid/inmemory_grid_test.py index d36b4588730f..0ad12a1a990a 100644 --- a/framework/py/flwr/server/grid/inmemory_grid_test.py +++ b/framework/py/flwr/server/grid/inmemory_grid_test.py @@ -104,6 +104,7 @@ def setUp(self) -> None: status=RunStatus(status=Status.PENDING, sub_status="", details=""), flwr_aid="user123", federation="mock-fed", + primary_task_id=None, bytes_sent=0, bytes_recv=0, clientapp_runtime=0.0, diff --git a/framework/py/flwr/server/superlink/fleet/vce/vce_api_test.py b/framework/py/flwr/server/superlink/fleet/vce/vce_api_test.py index 779785a367ab..c41c9ec1057e 100644 --- a/framework/py/flwr/server/superlink/fleet/vce/vce_api_test.py +++ b/framework/py/flwr/server/superlink/fleet/vce/vce_api_test.py @@ -136,6 +136,7 @@ def register_messages_into_state( ), flwr_aid="user123", federation="mock-fed", + primary_task_id=None, bytes_sent=0, bytes_recv=0, clientapp_runtime=0.0, diff --git a/framework/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/framework/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index eb956ad6d304..517a05a4aa75 100644 --- a/framework/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/framework/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -571,6 +571,7 @@ def create_run( ), flwr_aid=flwr_aid if flwr_aid else "", federation=federation, + primary_task_id=None, bytes_sent=0, bytes_recv=0, clientapp_runtime=0.0, @@ -755,12 +756,12 @@ def acknowledge_node_heartbeat( return True return False - def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None: + def _on_tokens_expired(self, expired_records: list[tuple[int, int]]) -> None: """Transition runs with expired tokens to failed status. Parameters ---------- - expired_records : list[tuple[int, float]] + expired_records : list[tuple[int, int]] List of tuples containing (run_id, active_until timestamp) for expired tokens. """ diff --git a/framework/py/flwr/server/superlink/linkstate/sql_linkstate.py b/framework/py/flwr/server/superlink/linkstate/sql_linkstate.py index 621fd74ae746..b806c8d0e7f4 100644 --- a/framework/py/flwr/server/superlink/linkstate/sql_linkstate.py +++ b/framework/py/flwr/server/superlink/linkstate/sql_linkstate.py @@ -815,14 +815,15 @@ def create_run( # pylint: disable=R0913, R0914, R0917 query = """ INSERT INTO run (run_id, fab_id, fab_version, fab_hash, override_config, federation, - federation_config, run_type, pending_at, starting_at, running_at, - finished_at, usage_reported_at, sub_status, details, flwr_aid, - bytes_sent, bytes_recv, clientapp_runtime) + primary_task_id, federation_config, run_type, pending_at, + starting_at, running_at, finished_at, usage_reported_at, + sub_status, details, flwr_aid, bytes_sent, bytes_recv, + clientapp_runtime) VALUES (:run_id, :fab_id, :fab_version, :fab_hash, :override_config, - :federation, :federation_config, :run_type, :pending_at, - :starting_at, :running_at, :finished_at, :usage_reported_at, - :sub_status, :details, :flwr_aid, :bytes_sent, :bytes_recv, - :clientapp_runtime) + :federation, :primary_task_id, :federation_config, :run_type, + :pending_at, :starting_at, :running_at, :finished_at, + :usage_reported_at, :sub_status, :details, :flwr_aid, + :bytes_sent, :bytes_recv, :clientapp_runtime) """ override_config_json = json.dumps(override_config) params = { @@ -832,6 +833,7 @@ def create_run( # pylint: disable=R0913, R0914, R0917 "fab_hash": fab_hash or "", "override_config": override_config_json, "federation": federation, + "primary_task_id": None, "federation_config": fed_config_json, "run_type": run_type, "pending_at": now().isoformat(), @@ -954,6 +956,11 @@ def get_run_info( # pylint: disable=too-many-arguments, too-many-branches ), flwr_aid=row["flwr_aid"], federation=row["federation"], + primary_task_id=( + int64_to_uint64(row["primary_task_id"]) + if row["primary_task_id"] is not None + else None + ), bytes_sent=row["bytes_sent"], bytes_recv=row["bytes_recv"], clientapp_runtime=row["clientapp_runtime"], diff --git a/framework/py/flwr/server/superlink/serverappio/serverappio_grpc.py b/framework/py/flwr/server/superlink/serverappio/serverappio_grpc.py index f7994528661b..156a8e919adb 100644 --- a/framework/py/flwr/server/superlink/serverappio/serverappio_grpc.py +++ b/framework/py/flwr/server/superlink/serverappio/serverappio_grpc.py @@ -27,6 +27,7 @@ ) from flwr.server.superlink.linkstate import LinkStateFactory from flwr.supercore.interceptors import ( + create_serverappio_runtime_version_server_interceptor, create_serverappio_superexec_auth_server_interceptor, create_serverappio_token_auth_server_interceptor, ) @@ -69,6 +70,7 @@ def run_serverappio_api_grpc( # pylint: disable=R0913,R0917 master_secret=superexec_auth_secret, ) ) + interceptors.append(create_serverappio_runtime_version_server_interceptor()) serverappio_add_servicer_to_server_fn = add_ServerAppIoServicer_to_server serverappio_grpc_server = generic_create_grpc_server( servicer_and_add_fn=( diff --git a/framework/py/flwr/simulation/simulationio_connection.py b/framework/py/flwr/simulation/simulationio_connection.py index f052c7cb93b9..df9d96397843 100644 --- a/framework/py/flwr/simulation/simulationio_connection.py +++ b/framework/py/flwr/simulation/simulationio_connection.py @@ -25,7 +25,10 @@ from flwr.common.logger import log from flwr.common.retry_invoker import make_simple_grpc_retry_invoker, wrap_stub from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611 -from flwr.supercore.interceptors import AppIoTokenClientInterceptor +from flwr.supercore.interceptors import ( + AppIoTokenClientInterceptor, + RuntimeVersionClientInterceptor, +) class SimulationIoConnection: @@ -85,7 +88,10 @@ def _connect(self) -> None: server_address=self._addr, insecure=self._insecure, root_certificates=self._cert, - interceptors=[AppIoTokenClientInterceptor(token=self._token)], + interceptors=[ + RuntimeVersionClientInterceptor(component_name="flwr-simulation"), + AppIoTokenClientInterceptor(token=self._token), + ], ) self._channel.subscribe(on_channel_state_change) self._grpc_stub = ServerAppIoStub(self._channel) diff --git a/framework/py/flwr/simulation/simulationio_connection_test.py b/framework/py/flwr/simulation/simulationio_connection_test.py index f5b6a235a92b..24f9cbbdb3ee 100644 --- a/framework/py/flwr/simulation/simulationio_connection_test.py +++ b/framework/py/flwr/simulation/simulationio_connection_test.py @@ -18,7 +18,10 @@ import unittest from unittest.mock import Mock, patch -from flwr.supercore.interceptors import AppIoTokenClientInterceptor +from flwr.supercore.interceptors import ( + AppIoTokenClientInterceptor, + RuntimeVersionClientInterceptor, +) from .simulationio_connection import SimulationIoConnection @@ -29,13 +32,13 @@ class TestSimulationIoConnection(unittest.TestCase): @patch("flwr.simulation.simulationio_connection.wrap_stub") @patch("flwr.simulation.simulationio_connection.ServerAppIoStub") @patch("flwr.simulation.simulationio_connection.create_channel") - def test_connect_adds_client_interceptor( + def test_connect_adds_client_interceptors( self, mock_create_channel: Mock, _mock_serverappio_stub: Mock, _mock_wrap_stub: Mock, ) -> None: - """`_connect` should pass the token interceptor to create_channel.""" + """`_connect` should pass version and token interceptors to create_channel.""" mock_create_channel.return_value = Mock() conn = SimulationIoConnection(token="test-token") @@ -45,8 +48,9 @@ def test_connect_adds_client_interceptor( interceptors = kwargs["interceptors"] self.assertIsNotNone(interceptors) assert interceptors is not None - self.assertEqual(len(interceptors), 1) - self.assertIsInstance(interceptors[0], AppIoTokenClientInterceptor) + self.assertEqual(len(interceptors), 2) + self.assertIsInstance(interceptors[0], RuntimeVersionClientInterceptor) + self.assertIsInstance(interceptors[1], AppIoTokenClientInterceptor) def test_init_requires_token(self) -> None: """`SimulationIoConnection` should require token values.""" diff --git a/framework/py/flwr/supercore/constant.py b/framework/py/flwr/supercore/constant.py index 0510f1c0924a..eeea029f3915 100644 --- a/framework/py/flwr/supercore/constant.py +++ b/framework/py/flwr/supercore/constant.py @@ -132,6 +132,11 @@ MIN_TIMESTAMP_DIFF_SECONDS = -SYSTEM_TIME_TOLERANCE MAX_TIMESTAMP_DIFF_SECONDS = TIMESTAMP_TOLERANCE + SYSTEM_TIME_TOLERANCE +# Constants for Flower runtime version metadata +FLWR_PACKAGE_NAME_METADATA_KEY = "flwr-package-name" +FLWR_PACKAGE_VERSION_METADATA_KEY = "flwr-package-version" +FLWR_COMPONENT_NAME_METADATA_KEY = "flwr-component-name" +VERSION_INCOMPATIBILITY_MESSAGE_METADATA_KEY = "flwr-version-incompatibility-message" # System message type SYSTEM_MESSAGE_TYPE = "system" diff --git a/framework/py/flwr/supercore/corestate/corestate.py b/framework/py/flwr/supercore/corestate/corestate.py index 903db7113a99..baf5f5b1db94 100644 --- a/framework/py/flwr/supercore/corestate/corestate.py +++ b/framework/py/flwr/supercore/corestate/corestate.py @@ -114,6 +114,92 @@ def get_tasks( # pylint: disable=too-many-arguments filters. """ + @abstractmethod + def claim_task(self, task_id: int) -> str | None: + """Atomically claim a pending task. + + Claiming a task creates a task token, initializes heartbeat state, and + moves the task from pending to starting. + + Parameters + ---------- + task_id : int + The ID of the task to claim. + + Returns + ------- + Optional[str] + The generated task token if the claim succeeds, otherwise `None`. + """ + + @abstractmethod + def activate_task(self, task_id: int) -> bool: + """Move a task from starting to running. + + Parameters + ---------- + task_id : int + The ID of the task to activate. + + Returns + ------- + bool + True if the task existed and transitioned from starting to running, + otherwise False. + """ + + @abstractmethod + def finish_task(self, task_id: int, sub_status: str, details: str) -> bool: + """Move an unfinished task to finished. + + Parameters + ---------- + task_id : int + The ID of the task to finish. + sub_status : str + Terminal task sub-status, such as completed, failed, or stopped. + Only RUNNING status can be transitioned to FINISHED:COMPLETED + details : str + Additional terminal status details. + + Returns + ------- + bool + True if the task existed and was not already finished, otherwise + False. + """ + + @abstractmethod + def acknowledge_task_heartbeat(self, task_id: int) -> bool: + """Extend heartbeat state for the claimed task. + + Parameters + ---------- + task_id : int + The ID of the task whose heartbeat should be acknowledged. + + Returns + ------- + bool + True if the task heartbeat was acknowledged successfully, otherwise + False. + """ + + @abstractmethod + def get_task_id_by_token(self, token: str) -> int | None: + """Return the task ID associated with the task token, if valid. + + Parameters + ---------- + token : str + The task token to look up. + + Returns + ------- + Optional[int] + The task ID if the token is valid, otherwise None. + """ + @abstractmethod def create_token(self, run_id: int) -> str | None: """Create a token for the given run ID. diff --git a/framework/py/flwr/supercore/corestate/corestate_test.py b/framework/py/flwr/supercore/corestate/corestate_test.py index 55e4a474ce15..bf3792a17e8f 100644 --- a/framework/py/flwr/supercore/corestate/corestate_test.py +++ b/framework/py/flwr/supercore/corestate/corestate_test.py @@ -20,14 +20,21 @@ from typing import Any, cast from unittest.mock import patch +from parameterized import parameterized + from flwr.common import now -from flwr.common.constant import HEARTBEAT_DEFAULT_INTERVAL, Status +from flwr.common.constant import ( + HEARTBEAT_DEFAULT_INTERVAL, + HEARTBEAT_PATIENCE, + Status, + SubStatus, +) from flwr.proto.task_pb2 import TaskStatus # pylint: disable=E0611 from . import CoreState -class StateTest(unittest.TestCase): +class StateTest(unittest.TestCase): # pylint: disable=R0904 """Test all CoreState implementations.""" # This is to True in each child class @@ -119,6 +126,178 @@ def test_get_task_returns_copy(self) -> None: reloaded = reloaded_tasks[0] self.assertEqual(reloaded.fab_hash, "fab-hash") + def test_claim_task_transitions_pending_to_starting(self) -> None: + """Claiming a task should create a token and move it to starting.""" + state = self.state_factory() + task_id = state.create_task(task_type="flwr-model", run_id=42) + assert task_id is not None + + # Claim should persist token ownership and move the task to STARTING. + token = state.claim_task(task_id) + + self.assertIsNotNone(token) + assert token is not None + self.assertEqual(state.get_task_id_by_token(token), task_id) + tasks = state.get_tasks(task_ids=[task_id]) + self.assertEqual(len(tasks), 1) + self.assertEqual(tasks[0].status.status, Status.STARTING) + self.assertTrue(tasks[0].starting_at) + self.assertEqual(tasks[0].running_at, "") + self.assertEqual(tasks[0].finished_at, "") + + def test_claim_task_rejects_missing_claimed_and_non_pending(self) -> None: + """Only existing pending unclaimed tasks should be claimable.""" + state = self.state_factory() + + # Missing tasks cannot be claimed. + self.assertIsNone(state.claim_task(61016)) + + claimed_task_id = state.create_task(task_type="flwr-model", run_id=42) + finished_task_id = state.create_task(task_type="flwr-model", run_id=42) + assert claimed_task_id is not None and finished_task_id is not None + + # Claiming is single-owner and cannot be repeated. + self.assertIsNotNone(state.claim_task(claimed_task_id)) + self.assertIsNone(state.claim_task(claimed_task_id)) + + # Finished tasks are not claimable. + self.assertTrue(state.finish_task(finished_task_id, SubStatus.FAILED, "done")) + self.assertIsNone(state.claim_task(finished_task_id)) + + def test_activate_task_transitions_starting_to_running(self) -> None: + """Only starting tasks should transition to running.""" + state = self.state_factory() + task_id = state.create_task(task_type="flwr-model", run_id=42) + assert task_id is not None + + # Task does not exist, so it cannot be activated. + self.assertFalse(state.activate_task(61016)) + # Task exists but is pending, so it must be claimed before activation. + self.assertFalse(state.activate_task(task_id)) + # Claiming the task returns a token. + self.assertIsNotNone(state.claim_task(task_id)) + # The task is in starting status, so it can be activated. + self.assertTrue(state.activate_task(task_id)) + # The task is already in running status, so it cannot be activated again. + self.assertFalse(state.activate_task(task_id)) + + tasks = state.get_tasks(task_ids=[task_id]) + self.assertEqual(len(tasks), 1) + self.assertEqual(tasks[0].status.status, Status.RUNNING) + self.assertTrue(tasks[0].running_at) + self.assertEqual(tasks[0].finished_at, "") + + @parameterized.expand( # type: ignore + [ + (SubStatus.FAILED, False), + (SubStatus.STOPPED, False), + (SubStatus.COMPLETED, True), + ] + ) + def test_finish_task_transitions_unfinished_task_to_finished( + self, sub_status: str, requires_running: bool + ) -> None: + """Finishing a task should store the terminal status details.""" + state = self.state_factory() + task_id = state.create_task(task_type="flwr-model", run_id=42) + assert task_id is not None + + # Task does not exist. + self.assertFalse(state.finish_task(61016, SubStatus.FAILED, "missing")) + + if requires_running: + # FINISHED:COMPLETED is only valid once the task is RUNNING. + self.assertFalse(state.finish_task(task_id, sub_status, "boom")) + self.assertIsNotNone(state.claim_task(task_id)) + self.assertTrue(state.activate_task(task_id)) + + # Valid unfinished task transition should succeed. + self.assertTrue(state.finish_task(task_id, sub_status, "boom")) + # Task is already finished, so it cannot be finished again. + self.assertFalse(state.finish_task(task_id, SubStatus.FAILED, "again")) + # Finished tasks cannot be claimed. + self.assertIsNone(state.claim_task(task_id)) + + tasks = state.get_tasks(task_ids=[task_id]) + self.assertEqual(len(tasks), 1) + task = tasks[0] + self.assertEqual( + task.status, + TaskStatus( + status=Status.FINISHED, + sub_status=sub_status, + details="boom", + ), + ) + self.assertTrue(task.finished_at) + + def test_task_heartbeat_extends_token_expiration(self) -> None: + """Task heartbeat should keep a claimed task token valid.""" + state = self.state_factory() + fixed_now = now() + + with patch("datetime.datetime") as mock_dt: + mock_dt.now.return_value = fixed_now + task_id = state.create_task(task_type="flwr-model", run_id=42) + assert task_id is not None + token = state.claim_task(task_id) + assert token is not None + + # Heartbeat extends only existing claimed task leases. + self.assertFalse(state.acknowledge_task_heartbeat(61016)) + self.assertTrue(state.acknowledge_task_heartbeat(task_id)) + + # The heartbeat extension should keep the token valid past its + # initial claim deadline. + mock_dt.now.return_value = fixed_now + timedelta( + seconds=HEARTBEAT_DEFAULT_INTERVAL + 1 + ) + self.assertEqual(state.get_task_id_by_token(token), task_id) + + # Once the extended deadline passes, the token no longer resolves. + mock_dt.now.return_value = fixed_now + timedelta( + seconds=HEARTBEAT_PATIENCE * HEARTBEAT_DEFAULT_INTERVAL + 1 + ) + self.assertIsNone(state.get_task_id_by_token(token)) + self.assertFalse(state.acknowledge_task_heartbeat(task_id)) + + def test_expired_task_token_transitions_task_to_finished_failed(self) -> None: + """Expired task claims should transition tasks to FINISHED:FAILED.""" + state = self.state_factory() + fixed_now = now() + + with patch("datetime.datetime") as mock_dt: + mock_dt.now.return_value = fixed_now + task_id = state.create_task(task_type="flwr-model", run_id=42) + assert task_id is not None + + token = state.claim_task(task_id) + assert token is not None + + mock_dt.now.return_value = fixed_now + timedelta( + seconds=HEARTBEAT_DEFAULT_INTERVAL + 1 + ) + self.assertIsNone(state.get_task_id_by_token(token)) + self.assertFalse(state.acknowledge_task_heartbeat(task_id)) + + tasks = state.get_tasks(task_ids=[task_id]) + self.assertEqual(len(tasks), 1) + self.assertEqual( + tasks[0].status, + TaskStatus( + status=Status.FINISHED, + sub_status=SubStatus.FAILED, + details="No heartbeat received from the task", + ), + ) + self.assertTrue(tasks[0].finished_at) + + def test_get_task_id_by_token_returns_none_for_unknown_token(self) -> None: + """Unknown task tokens should not resolve to a task.""" + state = self.state_factory() + + self.assertIsNone(state.get_task_id_by_token("missing-token")) + def test_create_verify_and_delete_token(self) -> None: """Test creating, verifying, and deleting tokens.""" # Prepare diff --git a/framework/py/flwr/supercore/corestate/in_memory_corestate.py b/framework/py/flwr/supercore/corestate/in_memory_corestate.py index fb276ee38e39..a8c0f5774ae7 100644 --- a/framework/py/flwr/supercore/corestate/in_memory_corestate.py +++ b/framework/py/flwr/supercore/corestate/in_memory_corestate.py @@ -29,6 +29,7 @@ HEARTBEAT_PATIENCE, TASK_ID_NUM_BYTES, Status, + SubStatus, ) from flwr.common.typing import Fab from flwr.proto.task_pb2 import Task, TaskStatus # pylint: disable=E0611 @@ -43,7 +44,7 @@ class TokenRecord: """Record containing token and heartbeat information.""" token: str - active_until: float + active_until: int class InMemoryCoreState(CoreState): # pylint: disable=too-many-instance-attributes @@ -60,6 +61,10 @@ def __init__(self, object_store: ObjectStore) -> None: self.nonce_store: dict[tuple[str, str], float] = {} self.lock_nonce_store = Lock() self.task_store: dict[int, Task] = {} + # Store task ID to token mapping + self.task_token_store: dict[int, TokenRecord] = {} + # Store token to task ID mapping + self.task_token_to_task_id: dict[str, int] = {} self.lock_task_store = Lock() @property @@ -113,6 +118,7 @@ def create_task( # pylint: disable=too-many-arguments,too-many-positional-argum task_id=task_id, type=task_type, run_id=run_id, + status=TaskStatus(status=Status.PENDING, sub_status="", details=""), pending_at=now().isoformat(), fab_hash=fab_hash, model_ref=model_ref, @@ -156,8 +162,7 @@ def get_tasks( # pylint: disable=too-many-arguments matched_task_ids &= { task_id for task_id in matched_task_ids - if determine_task_status(self.task_store[task_id]).status - in status_set + if self.task_store[task_id].status.status in status_set } tasks = [self.task_store[task_id] for task_id in matched_task_ids] @@ -176,10 +181,126 @@ def get_tasks( # pylint: disable=too-many-arguments for task in tasks: task_copy = Task() task_copy.CopyFrom(task) - task_copy.status.CopyFrom(determine_task_status(task)) result.append(task_copy) return result + def claim_task(self, task_id: int) -> str | None: + """Atomically claim a pending task.""" + token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) + with self.lock_task_store: + task = self.task_store.get(task_id) + if task is None or task_id in self.task_token_store: + return None + if task.status.status != Status.PENDING: + return None + + # Claiming moves the task into STARTING and records the heartbeat state. + claimed_at = now() + task.starting_at = claimed_at.isoformat() + task.status.CopyFrom( + TaskStatus(status=Status.STARTING, sub_status="", details="") + ) + self.task_token_store[task_id] = TokenRecord( + token=token, + active_until=int(claimed_at.timestamp()) + HEARTBEAT_DEFAULT_INTERVAL, + ) + self.task_token_to_task_id[token] = task_id + return token + + def activate_task(self, task_id: int) -> bool: + """Move a task from starting to running.""" + with self.lock_task_store: + # Expire non-responsive tasks before transitioning task status. + self._cleanup_expired_task_tokens_locked() + + # Transition task from STARTING -> RUNNING. + task = self.task_store.get(task_id) + if task is None or task.status.status != Status.STARTING: + return False + + task.running_at = now().isoformat() + task.status.CopyFrom( + TaskStatus(status=Status.RUNNING, sub_status="", details="") + ) + return True + + def finish_task(self, task_id: int, sub_status: str, details: str) -> bool: + """Move an unfinished task to finished.""" + with self.lock_task_store: + # Expire non-responsive tasks before transitioning task status. + self._cleanup_expired_task_tokens_locked() + + # Transition task to FINISHED + task = self.task_store.get(task_id) + if task is None or task.status.status == Status.FINISHED: + return False + + if sub_status == SubStatus.COMPLETED: + # Only allow transition to COMPLETED if currently RUNNING + if task.status.status != Status.RUNNING: + return False + + task.finished_at = now().isoformat() + task.status.CopyFrom( + TaskStatus( + status=Status.FINISHED, sub_status=sub_status, details=details + ) + ) + + # Revoke any existing task token now that the task is finished. + if (record := self.task_token_store.pop(task_id, None)) is not None: + self.task_token_to_task_id.pop(record.token, None) + return True + + def acknowledge_task_heartbeat(self, task_id: int) -> bool: + """Extend heartbeat state for the claimed task.""" + with self.lock_task_store: + # Heartbeats are accepted only for starting and running tasks + self._cleanup_expired_task_tokens_locked() + task = self.task_store.get(task_id) + record = self.task_token_store.get(task_id) + if task is None or record is None or task.status.status == Status.FINISHED: + return False + + now_int = int(now().timestamp()) + record.active_until = ( + now_int + HEARTBEAT_PATIENCE * HEARTBEAT_DEFAULT_INTERVAL + ) + return True + + def get_task_id_by_token(self, token: str) -> int | None: + """Return the task ID associated with the task token, if valid.""" + with self.lock_task_store: + # Resolve tokens after cleanup so callers never receive expired claims. + self._cleanup_expired_task_tokens_locked() + return self.task_token_to_task_id.get(token) + + def _cleanup_expired_task_tokens_locked(self) -> None: + """Remove expired task tokens. + + Callers must acquire `lock_task_store` before calling this method. + Expired tasks are marked as finished with a failed status, and their + tokens are removed. + """ + expired_at = now() + current = int(expired_at.timestamp()) + for task_id, record in list(self.task_token_store.items()): + if record.active_until < current: + # The task is considered expired. Mark it as finished with a failed + # status if it's not already finished, and remove the token. + task = self.task_store.get(task_id) + if task and task.status.status != Status.FINISHED: + task.finished_at = expired_at.isoformat() + task.status.CopyFrom( + TaskStatus( + status=Status.FINISHED, + sub_status=SubStatus.FAILED, + details="No heartbeat received from the task", + ) + ) + del self.task_token_store[task_id] + self.task_token_to_task_id.pop(record.token, None) + def create_token(self, run_id: int) -> str | None: """Create a token for the given run ID.""" token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token @@ -187,8 +308,9 @@ def create_token(self, run_id: int) -> str | None: if run_id in self.token_store: return None # Token already created for this run ID + active_until = int(now().timestamp()) + HEARTBEAT_DEFAULT_INTERVAL self.token_store[run_id] = TokenRecord( - token=token, active_until=now().timestamp() + HEARTBEAT_DEFAULT_INTERVAL + token=token, active_until=active_until ) self.token_to_run_id[token] = run_id return token @@ -226,7 +348,7 @@ def acknowledge_app_heartbeat(self, token: str) -> bool: # Get the run_id and update heartbeat info run_id = self.token_to_run_id[token] record = self.token_store[run_id] - current = now().timestamp() + current = int(now().timestamp()) record.active_until = ( current + HEARTBEAT_PATIENCE * HEARTBEAT_DEFAULT_INTERVAL ) @@ -239,8 +361,8 @@ def _cleanup_expired_tokens(self) -> None: Subclasses can override `_on_tokens_expired` to add custom cleanup logic. """ with self.lock_token_store: - current = now().timestamp() - expired_records: list[tuple[int, float]] = [] + current = int(now().timestamp()) + expired_records: list[tuple[int, int]] = [] for run_id, record in list(self.token_store.items()): if record.active_until < current: expired_records.append((run_id, record.active_until)) @@ -252,14 +374,14 @@ def _cleanup_expired_tokens(self) -> None: if expired_records: self._on_tokens_expired(expired_records) - def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None: + def _on_tokens_expired(self, expired_records: list[tuple[int, int]]) -> None: """Handle cleanup of expired tokens. Override in subclasses to add custom cleanup logic. Parameters ---------- - expired_records : list[tuple[int, float]] + expired_records : list[tuple[int, int]] List of tuples containing (run_id, active_until timestamp) for expired tokens. """ @@ -282,16 +404,3 @@ def _cleanup_expired_nonces(self) -> None: for key, expires_at in list(self.nonce_store.items()): if expires_at < current: del self.nonce_store[key] - - -def determine_task_status(task: Task) -> TaskStatus: - """Determine the status of a task based on timestamp fields.""" - if task.pending_at: - if task.finished_at: - return TaskStatus(status=Status.FINISHED, sub_status="", details="") - if task.starting_at: - if task.running_at: - return TaskStatus(status=Status.RUNNING, sub_status="", details="") - return TaskStatus(status=Status.STARTING, sub_status="", details="") - return TaskStatus(status=Status.PENDING, sub_status="", details="") - raise ValueError(f"The task {task.task_id} does not have a valid status.") diff --git a/framework/py/flwr/supercore/corestate/sql_corestate.py b/framework/py/flwr/supercore/corestate/sql_corestate.py index cc0921637f46..e9fc9193ae3f 100644 --- a/framework/py/flwr/supercore/corestate/sql_corestate.py +++ b/framework/py/flwr/supercore/corestate/sql_corestate.py @@ -31,6 +31,7 @@ HEARTBEAT_PATIENCE, TASK_ID_NUM_BYTES, Status, + SubStatus, ) from flwr.common.typing import Fab from flwr.proto.task_pb2 import Task, TaskStatus # pylint: disable=E0611 @@ -42,6 +43,15 @@ from .corestate import CoreState from .utils import generate_rand_int_from_bytes +# Define SQL conditions for task statuses to ensure consistency across queries +STATUS_CONDITIONS = { + Status.PENDING: "(starting_at IS NULL AND finished_at IS NULL)", + Status.STARTING: "(starting_at IS NOT NULL AND running_at IS NULL " + "AND finished_at IS NULL)", + Status.RUNNING: "(running_at IS NOT NULL AND finished_at IS NULL)", + Status.FINISHED: "(finished_at IS NOT NULL)", +} + class SqlCoreState(CoreState, SqlMixin): """SQLAlchemy-based CoreState implementation.""" @@ -113,10 +123,12 @@ def create_task( # pylint: disable=too-many-arguments,too-many-positional-argum insert_query = """ INSERT INTO task (task_id, type, run_id, fab_hash, model_ref, connector_ref, token, - pending_at, starting_at, running_at, finished_at) + active_until, pending_at, starting_at, running_at, finished_at, + sub_status, details) VALUES (:task_id, :type, :run_id, :fab_hash, :model_ref, :connector_ref, :token, - :pending_at, :starting_at, :running_at, :finished_at); + :active_until, :pending_at, :starting_at, :running_at, :finished_at, + :sub_status, :details); """ params = { @@ -127,10 +139,13 @@ def create_task( # pylint: disable=too-many-arguments,too-many-positional-argum "model_ref": model_ref, "connector_ref": connector_ref, "token": None, + "active_until": None, "pending_at": now().isoformat(), "starting_at": None, "running_at": None, "finished_at": None, + "sub_status": "", + "details": "", } with self.session(): @@ -195,7 +210,8 @@ def get_tasks( # pylint: disable=too-many-arguments,too-many-locals,too-many-br query = """ SELECT task_id, type, run_id, fab_hash, model_ref, connector_ref, - pending_at, starting_at, running_at, finished_at + pending_at, starting_at, running_at, finished_at, + sub_status, details FROM task """ if conditions: @@ -230,6 +246,158 @@ def get_metadata(self) -> MetaData: """Return SQLAlchemy MetaData needed for CoreState tables.""" return create_corestate_metadata() + def claim_task(self, task_id: int) -> str | None: + """Atomically claim a pending task.""" + token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) + claimed_at = now() + active_until = int(claimed_at.timestamp()) + HEARTBEAT_DEFAULT_INTERVAL + sint64_task_id = uint64_to_int64(task_id) + try: + # The conditional UPDATE is the atomic claim: exactly one caller can + # move a pending, unclaimed task to STARTING and attach a token. + rows = self.query( + f""" + UPDATE task + SET token = :token, + active_until = :active_until, + starting_at = :starting_at + WHERE task_id = :task_id AND token IS NULL + AND {STATUS_CONDITIONS[Status.PENDING]} + RETURNING task_id + """, + { + "task_id": sint64_task_id, + "token": token, + "active_until": active_until, + "starting_at": claimed_at.isoformat(), + }, + ) + if not rows: + return None + + return token + except IntegrityError: + # Rare failure: generated token already exists (duplicate) + return None + + def activate_task(self, task_id: int) -> bool: + """Move a task from starting to running.""" + # Expire non-responsive tasks before transitioning task status. + + with self.session(): + self._cleanup_expired_task_tokens() + + # Activation is a strict STARTING -> RUNNING transition. + rows = self.query( + f""" + UPDATE task + SET running_at = :running_at + WHERE task_id = :task_id AND {STATUS_CONDITIONS[Status.STARTING]} + RETURNING task_id + """, + {"task_id": uint64_to_int64(task_id), "running_at": now().isoformat()}, + ) + return len(rows) > 0 + + def finish_task(self, task_id: int, sub_status: str, details: str) -> bool: + """Move an unfinished task to finished.""" + sint64_task_id = uint64_to_int64(task_id) + with self.session(): + self._cleanup_expired_task_tokens() + # FINISHED:COMPLETED is only valid from RUNNING. + completion_constraint = "" + if sub_status == SubStatus.COMPLETED: + completion_constraint = "AND running_at IS NOT NULL" + + rows = self.query( + f""" + UPDATE task + SET finished_at = :finished_at, + sub_status = :sub_status, + details = :details, + active_until = NULL, + token = NULL + WHERE task_id = :task_id + AND finished_at IS NULL {completion_constraint} + RETURNING task_id + """, + { + "task_id": sint64_task_id, + "finished_at": now().isoformat(), + "sub_status": sub_status, + "details": details, + }, + ) + if not rows: + return False + + return True + + def acknowledge_task_heartbeat(self, task_id: int) -> bool: + """Extend heartbeat state for the claimed task.""" + # Heartbeats are accepted only for active, unexpired task claims. + with self.session(): + current = int(now().timestamp()) + self._cleanup_expired_task_tokens() + rows = self.query( + """ + UPDATE task + SET active_until = :active_until + WHERE task_id = :task_id + AND active_until >= :current + AND finished_at IS NULL + RETURNING task_id + """, + { + "task_id": uint64_to_int64(task_id), + "current": current, + "active_until": ( + current + HEARTBEAT_PATIENCE * HEARTBEAT_DEFAULT_INTERVAL + ), + }, + ) + return len(rows) > 0 + + def get_task_id_by_token(self, token: str) -> int | None: + """Return the task ID associated with the task token, if valid.""" + rows = self.query( + """ + SELECT task_id FROM task + WHERE token = :token AND active_until >= :current AND finished_at IS NULL + """, + {"token": token, "current": int(now().timestamp())}, + ) + if not rows: + return None + return int64_to_uint64(rows[0]["task_id"]) + + def _cleanup_expired_task_tokens(self) -> None: + """Remove expired task heartbeat records. + + Expired tasks are marked as finished with a failed status, and their tokens are + removed. + """ + expired_at = now() + current = int(expired_at.timestamp()) + # Expired task claims are terminal failures and lose their token. + self.query( + """ + UPDATE task + SET token = NULL, + active_until = NULL, + finished_at = :finished_at, + sub_status = :sub_status, + details = :details + WHERE token IS NOT NULL AND active_until < :current + """, + { + "current": current, + "finished_at": expired_at.isoformat(), + "sub_status": SubStatus.FAILED, + "details": "No heartbeat received from the task", + }, + ) + def create_token(self, run_id: int) -> str | None: """Create a token for the given run ID.""" token = secrets.token_hex(FLWR_APP_TOKEN_LENGTH) # Generate a random token @@ -360,7 +528,11 @@ def determine_task_status(row: dict[str, Any]) -> TaskStatus: """Determine the status of the task based on timestamp fields.""" if row["pending_at"]: if row["finished_at"]: - return TaskStatus(status=Status.FINISHED, sub_status="", details="") + return TaskStatus( + status=Status.FINISHED, + sub_status=row["sub_status"], + details=row["details"], + ) if row["starting_at"]: if row["running_at"]: return TaskStatus(status=Status.RUNNING, sub_status="", details="") diff --git a/framework/py/flwr/supercore/interceptors/__init__.py b/framework/py/flwr/supercore/interceptors/__init__.py index c58b146284b4..b29dcf39e9d5 100644 --- a/framework/py/flwr/supercore/interceptors/__init__.py +++ b/framework/py/flwr/supercore/interceptors/__init__.py @@ -23,6 +23,11 @@ create_clientappio_token_auth_server_interceptor, create_serverappio_token_auth_server_interceptor, ) +from .runtime_version_interceptor import ( + RuntimeVersionClientInterceptor, + RuntimeVersionServerInterceptor, + create_serverappio_runtime_version_server_interceptor, +) from .superexec_auth_interceptor import ( SuperExecAuthClientInterceptor, SuperExecAuthServerInterceptor, @@ -35,10 +40,13 @@ "AUTHENTICATION_FAILED_MESSAGE", "AppIoTokenClientInterceptor", "AppIoTokenServerInterceptor", + "RuntimeVersionClientInterceptor", + "RuntimeVersionServerInterceptor", "SuperExecAuthClientInterceptor", "SuperExecAuthServerInterceptor", "create_clientappio_superexec_auth_server_interceptor", "create_clientappio_token_auth_server_interceptor", + "create_serverappio_runtime_version_server_interceptor", "create_serverappio_superexec_auth_server_interceptor", "create_serverappio_token_auth_server_interceptor", ] diff --git a/framework/py/flwr/supercore/interceptors/runtime_version_interceptor.py b/framework/py/flwr/supercore/interceptors/runtime_version_interceptor.py new file mode 100644 index 000000000000..ea873668eb17 --- /dev/null +++ b/framework/py/flwr/supercore/interceptors/runtime_version_interceptor.py @@ -0,0 +1,123 @@ +# Copyright 2026 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Runtime version metadata interceptors.""" + + +from __future__ import annotations + +from collections.abc import Callable +from logging import WARN +from typing import Any + +import grpc +from google.protobuf.message import Message as GrpcMessage + +from flwr.common.logger import log +from flwr.supercore.constant import VERSION_INCOMPATIBILITY_MESSAGE_METADATA_KEY +from flwr.supercore.runtime_version_compatibility import RuntimeVersionMetadata +from flwr.supercore.utils import get_metadata_str + + +class RuntimeVersionClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore + """Attach Flower runtime version metadata to outbound unary RPCs.""" + + def __init__(self, component_name: str) -> None: + self._metadata = RuntimeVersionMetadata.from_local_component(component_name) + self._compatibility_warning_logged = False + + def intercept_unary_unary( + self, + continuation: Callable[[Any, Any], Any], + client_call_details: grpc.ClientCallDetails, + request: GrpcMessage, + ) -> grpc.Call: + """Add the runtime version metadata headers.""" + details = client_call_details._replace( + metadata=self._metadata.append_to_grpc_metadata( + client_call_details.metadata + ) + ) + call: grpc.Call = continuation(details, request) + + # Log the incompatibility message from the response metadata + if not self._compatibility_warning_logged: + incompat_message = get_metadata_str( + call.trailing_metadata(), VERSION_INCOMPATIBILITY_MESSAGE_METADATA_KEY + ) + if incompat_message: + self._compatibility_warning_logged = True + log(WARN, incompat_message) + + return call + + +class RuntimeVersionServerInterceptor(grpc.ServerInterceptor): # type: ignore + """Observe Flower runtime version metadata on inbound unary RPCs.""" + + def __init__( + self, + *, + connection_name: str, + local_metadata: RuntimeVersionMetadata, + ) -> None: + self._connection_name = connection_name + self._local_metadata = local_metadata + + def intercept_service( + self, + continuation: Callable[[Any], Any], + handler_call_details: grpc.HandlerCallDetails, + ) -> grpc.RpcMethodHandler: + """Parse peer runtime metadata, then continue normal RPC handling.""" + method_handler: grpc.RpcMethodHandler = continuation(handler_call_details) + if method_handler is None or method_handler.unary_unary is None: + return method_handler + + # Parse and validate peer metadata + peer_metadata, incompat_details = RuntimeVersionMetadata.from_grpc_metadata( + handler_call_details.invocation_metadata + ) + + # Check compatibility and return any rejection message + if incompat_details is None: + incompat_details = self._local_metadata.check_compatibility(peer_metadata) + + # Attach the incompatibility message to the trailing metadata if present + def wrapped(request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage: + if incompat_details: + incompat_message = ( + "Runtime version compatibility check failed for " + f"{self._connection_name}. {incompat_details}" + ) + context.set_trailing_metadata( + ((VERSION_INCOMPATIBILITY_MESSAGE_METADATA_KEY, incompat_message),) + ) + return method_handler.unary_unary(request, context) # type: ignore + + return grpc.unary_unary_rpc_method_handler( + wrapped, + request_deserializer=method_handler.request_deserializer, + response_serializer=method_handler.response_serializer, + ) + + +def create_serverappio_runtime_version_server_interceptor( + connection_name: str = "Caller <-> SuperLink ServerAppIo API", +) -> RuntimeVersionServerInterceptor: + """Create the default runtime version interceptor for ServerAppIo.""" + return RuntimeVersionServerInterceptor( + connection_name=connection_name, + local_metadata=RuntimeVersionMetadata.from_local_component("SuperLink"), + ) diff --git a/framework/py/flwr/supercore/interceptors/runtime_version_interceptor_test.py b/framework/py/flwr/supercore/interceptors/runtime_version_interceptor_test.py new file mode 100644 index 000000000000..40a6fc8131ab --- /dev/null +++ b/framework/py/flwr/supercore/interceptors/runtime_version_interceptor_test.py @@ -0,0 +1,207 @@ +# Copyright 2026 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for runtime version metadata interceptors.""" + + +from collections import namedtuple +from unittest import TestCase +from unittest.mock import Mock + +import grpc +from google.protobuf.message import Message as GrpcMessage + +from flwr.proto.serverappio_pb2 import GetNodesRequest # pylint: disable=E0611 +from flwr.supercore.constant import ( + FLWR_COMPONENT_NAME_METADATA_KEY, + FLWR_PACKAGE_NAME_METADATA_KEY, + FLWR_PACKAGE_VERSION_METADATA_KEY, +) +from flwr.supercore.interceptors import ( + RuntimeVersionClientInterceptor, + RuntimeVersionServerInterceptor, +) +from flwr.supercore.runtime_version_compatibility import RuntimeVersionMetadata + +_ClientCallDetails = namedtuple( + "_ClientCallDetails", + ["method", "timeout", "metadata", "credentials", "wait_for_ready", "compression"], +) + + +class _HandlerCallDetails: + def __init__( + self, + method: str, + invocation_metadata: tuple[tuple[str, str | bytes], ...], + ) -> None: + self.method = method + self.invocation_metadata = invocation_metadata + + +def _make_unary_handler() -> grpc.RpcMethodHandler: + def _handler(_request: GrpcMessage, _context: grpc.ServicerContext) -> str: + return "ok" + + return grpc.unary_unary_rpc_method_handler(_handler) + + +class TestRuntimeVersionClientInterceptor(TestCase): + """Unit tests for RuntimeVersionClientInterceptor.""" + + def _make_call(self) -> Mock: + call = Mock(spec=grpc.Call) + call.trailing_metadata.return_value = () + return call + + def test_attach_runtime_version_headers(self) -> None: + """The interceptor should add the shared version metadata keys.""" + interceptor = RuntimeVersionClientInterceptor(component_name="simulation") + details = _ClientCallDetails( + method="/flwr.proto.ServerAppIo/GetNodes", + timeout=None, + metadata=(("x-test", "value"),), + credentials=None, + wait_for_ready=None, + compression=None, + ) + captured: dict[str, list[tuple[str, str | bytes]]] = {} + call = self._make_call() + + def continuation( + client_call_details: grpc.ClientCallDetails, + _request: GrpcMessage, + ) -> Mock: + captured["metadata"] = list(client_call_details.metadata or []) + return call + + response = interceptor.intercept_unary_unary( + continuation=continuation, + client_call_details=details, + request=GetNodesRequest(run_id=1), + ) + + self.assertIs(response, call) + metadata = dict(captured["metadata"]) + self.assertEqual(metadata["x-test"], "value") + self.assertIn(FLWR_PACKAGE_NAME_METADATA_KEY, metadata) + self.assertIn(FLWR_PACKAGE_VERSION_METADATA_KEY, metadata) + self.assertEqual(metadata[FLWR_COMPONENT_NAME_METADATA_KEY], "simulation") + + def test_attach_runtime_version_headers_rejects_preexisting_runtime_keys( + self, + ) -> None: + """Fail fast when runtime-version keys are already present outbound.""" + interceptor = RuntimeVersionClientInterceptor(component_name="simulation") + details = _ClientCallDetails( + method="/flwr.proto.ServerAppIo/GetNodes", + timeout=None, + metadata=((FLWR_PACKAGE_NAME_METADATA_KEY, "old"), ("x-test", "value")), + credentials=None, + wait_for_ready=None, + compression=None, + ) + with self.assertRaisesRegex( + RuntimeError, + "gRPC metadata already contains runtime version keys: flwr-package-name", + ): + interceptor.intercept_unary_unary( + continuation=lambda _details, _request: self._make_call(), + client_call_details=details, + request=GetNodesRequest(run_id=1), + ) + + +class TestRuntimeVersionServerInterceptor(TestCase): + """Unit tests for RuntimeVersionServerInterceptor.""" + + def setUp(self) -> None: + """Create a baseline interceptor for each test.""" + self.interceptor = RuntimeVersionServerInterceptor( + connection_name="flwr-simulation <-> SuperLink ServerAppIo API", + local_metadata=RuntimeVersionMetadata.from_local_component( + "superlink", + package_name_value="flwr", + package_version_value="1.29.0", + ), + ) + + def test_missing_metadata_is_tolerated(self) -> None: + """Missing runtime metadata should pass during rollout.""" + intercepted = self.interceptor.intercept_service( + lambda _: _make_unary_handler(), + _HandlerCallDetails("/flwr.proto.ServerAppIo/GetNodes", ()), + ) + + context = Mock() + response = intercepted.unary_unary(GetNodesRequest(run_id=1), context) + self.assertEqual(response, "ok") + context.set_trailing_metadata.assert_not_called() + + def test_unparseable_peer_version_is_warned(self) -> None: + """Explicit unparseable peer versions should be warned.""" + intercepted = self.interceptor.intercept_service( + lambda _: _make_unary_handler(), + _HandlerCallDetails( + "/flwr.proto.ServerAppIo/GetNodes", + ( + (FLWR_PACKAGE_NAME_METADATA_KEY, "flwr"), + (FLWR_PACKAGE_VERSION_METADATA_KEY, "main"), + (FLWR_COMPONENT_NAME_METADATA_KEY, "simulation"), + ), + ), + ) + + context = Mock() + response = intercepted.unary_unary(GetNodesRequest(run_id=1), context) + self.assertEqual(response, "ok") + context.set_trailing_metadata.assert_called_once() + + def test_incompatible_metadata_is_warned(self) -> None: + """Different major.minor versions should still be warned.""" + intercepted = self.interceptor.intercept_service( + lambda _: _make_unary_handler(), + _HandlerCallDetails( + "/flwr.proto.ServerAppIo/GetNodes", + ( + (FLWR_PACKAGE_NAME_METADATA_KEY, "flwr"), + (FLWR_PACKAGE_VERSION_METADATA_KEY, "1.30.1"), + (FLWR_COMPONENT_NAME_METADATA_KEY, "simulation"), + ), + ), + ) + + context = Mock() + response = intercepted.unary_unary(GetNodesRequest(run_id=1), context) + self.assertEqual(response, "ok") + context.set_trailing_metadata.assert_called_once() + + def test_compatible_metadata_is_accepted(self) -> None: + """Same major.minor versions should pass.""" + intercepted = self.interceptor.intercept_service( + lambda _: _make_unary_handler(), + _HandlerCallDetails( + "/flwr.proto.ServerAppIo/GetNodes", + ( + (FLWR_PACKAGE_NAME_METADATA_KEY, "flwr"), + (FLWR_PACKAGE_VERSION_METADATA_KEY, "1.29.7"), + (FLWR_COMPONENT_NAME_METADATA_KEY, "simulation"), + ), + ), + ) + + context = Mock() + response = intercepted.unary_unary(GetNodesRequest(run_id=1), context) + self.assertEqual(response, "ok") + context.set_trailing_metadata.assert_not_called() diff --git a/framework/py/flwr/supercore/runtime_version_compatibility.py b/framework/py/flwr/supercore/runtime_version_compatibility.py new file mode 100644 index 000000000000..3615a0911263 --- /dev/null +++ b/framework/py/flwr/supercore/runtime_version_compatibility.py @@ -0,0 +1,196 @@ +# Copyright 2026 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helpers for Flower runtime version metadata and compatibility checks.""" + + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass + +from packaging.version import InvalidVersion, Version + +from flwr.supercore.constant import ( + FLWR_COMPONENT_NAME_METADATA_KEY, + FLWR_PACKAGE_NAME_METADATA_KEY, + FLWR_PACKAGE_VERSION_METADATA_KEY, +) +from flwr.supercore.utils import ( + MetadataLookupError, + find_metadata_keys, + get_metadata_str_checked, +) +from flwr.supercore.version import package_name as flwr_package_name +from flwr.supercore.version import package_version as flwr_package_version + +_SUPPORTED_FLOWER_PACKAGE_NAMES = frozenset({"flwr", "flwr-nightly"}) +_RUNTIME_METADATA_KEYS = ( + FLWR_PACKAGE_NAME_METADATA_KEY, + FLWR_PACKAGE_VERSION_METADATA_KEY, + FLWR_COMPONENT_NAME_METADATA_KEY, +) + + +@dataclass(frozen=True) +class RuntimeVersionMetadata: + """Flower runtime version metadata attached to a caller.""" + + package_name: str + package_version: str + component_name: str + + @classmethod + def from_local_component( + cls, + component_name: str, + *, + package_name_value: str = flwr_package_name, + package_version_value: str = flwr_package_version, + ) -> RuntimeVersionMetadata: + """Build metadata for the local Flower runtime component.""" + component_name = component_name.strip() + if component_name == "": + raise ValueError("`component_name` must be a non-empty string") + + # Check version validity when the package name is recognized + if package_name_value != "unknown": + try: + Version(package_version_value) + except InvalidVersion: + raise ValueError( + "`package_version_value` is not a valid version: " + f"{package_version_value!r}" + ) from None + + return cls( + package_name=package_name_value, + package_version=package_version_value, + component_name=component_name, + ) + + @classmethod + def from_grpc_metadata( + cls, + grpc_metadata: Sequence[tuple[str, str | bytes]] | None, + ) -> tuple[RuntimeVersionMetadata | None, str | None]: + """Parse runtime version metadata from a gRPC metadata sequence.""" + # TEMPORARY: allow continuation when all runtime metadata keys are missing + # to avoid hard-failing older clients without metadata + if _metadata_is_missing(grpc_metadata): + return None, None + + try: + ret = RuntimeVersionMetadata( + package_name=get_metadata_str_checked( + grpc_metadata, FLWR_PACKAGE_NAME_METADATA_KEY + ), + package_version=get_metadata_str_checked( + grpc_metadata, FLWR_PACKAGE_VERSION_METADATA_KEY + ), + component_name=get_metadata_str_checked( + grpc_metadata, FLWR_COMPONENT_NAME_METADATA_KEY + ), + ) + return ret, None + + except MetadataLookupError as e: + return None, f"Invalid Flower runtime metadata: {str(e)}" + + def append_to_grpc_metadata( + self, + grpc_metadata: Sequence[tuple[str, str | bytes]] | None, + ) -> tuple[tuple[str, str | bytes], ...]: + """Return gRPC metadata with runtime version values added.""" + metadata = tuple(grpc_metadata or ()) + existing_runtime_keys = find_metadata_keys(metadata, _RUNTIME_METADATA_KEYS) + if existing_runtime_keys: + raise RuntimeError( + "gRPC metadata already contains runtime version keys: " + f"{', '.join(sorted(existing_runtime_keys))}" + ) + runtime_metadata = ( + (FLWR_PACKAGE_NAME_METADATA_KEY, self.package_name), + (FLWR_PACKAGE_VERSION_METADATA_KEY, self.package_version), + (FLWR_COMPONENT_NAME_METADATA_KEY, self.component_name), + ) + return metadata + runtime_metadata + + def check_compatibility(self, peer: RuntimeVersionMetadata | None) -> str | None: + """Return a rejection message, or ``None`` if the peer is accepted. + + Rejects the peer if any of the following are true: + - The peer's Flower package name is not recognized. + - The peer's Flower version cannot be parsed as a valid version. + - The peer's major or minor version differs from the local version. + + Accepts the peer (returns ``None``) if any of the following are true: + - The peer metadata is missing (temporary allowance for older clients). + - The local package name is not recognized. + - The peer's major and minor version match the local version. + """ + # TEMPORARY: allow continuation when peer metadata is missing to avoid + # hard-failing older clients without metadata + if peer is None: + return None + + # Reject suspicious peer package name + peer_package_name = peer.package_name.strip() + if peer_package_name not in _SUPPORTED_FLOWER_PACKAGE_NAMES: + return f"Peer Flower package name is not recognized: {peer_package_name!r}." + + # Allow continuation when the local package name is not recognized + if self.package_name.strip() not in _SUPPORTED_FLOWER_PACKAGE_NAMES: + return None + + # Parse versions + local_version = Version(self.package_version) + try: + peer_version = Version(peer.package_version) + except InvalidVersion: + return ( + f"Peer Flower version metadata cannot be parsed: " + f"{peer.package_version!r}." + ) + + # Check major.minor compatibility + if ( + local_version.major != peer_version.major + or local_version.minor != peer_version.minor + ): + return ( + f"{self.component_name} version {self.package_version} only accepts " + "peers from the same major.minor release, but received " + f"{peer.component_name} version {peer.package_version}." + ) + + # Versions are compatible + return None + + +def _metadata_is_missing( + metadata: Sequence[tuple[str, str | bytes]] | None, +) -> bool: + """Return `True` if all runtime metadata keys are missing from the gRPC metadata. + + This is a TEMPORARY helper to allow older clients without runtime metadata to + continue working rather than being rejected for missing metadata. It is safe to + remove this once the minimum supported Flower version is new enough that all clients + are expected to include runtime metadata. + """ + if metadata is None: + return True + + metadata_keys = {key for key, _ in metadata} + return all(key not in metadata_keys for key in _RUNTIME_METADATA_KEYS) diff --git a/framework/py/flwr/supercore/runtime_version_compatibility_test.py b/framework/py/flwr/supercore/runtime_version_compatibility_test.py new file mode 100644 index 000000000000..ff5b706d2fa3 --- /dev/null +++ b/framework/py/flwr/supercore/runtime_version_compatibility_test.py @@ -0,0 +1,227 @@ +# Copyright 2026 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Flower runtime version metadata helpers.""" + + +import pytest + +from flwr.supercore.constant import ( + FLWR_COMPONENT_NAME_METADATA_KEY, + FLWR_PACKAGE_NAME_METADATA_KEY, + FLWR_PACKAGE_VERSION_METADATA_KEY, +) + +from .runtime_version_compatibility import RuntimeVersionMetadata + + +def test_runtime_version_metadata_appends_new_metadata() -> None: + """Runtime metadata should append the shared key names.""" + metadata = RuntimeVersionMetadata.from_local_component( + "supernode", + package_name_value="flwr", + package_version_value="1.29.0", + ) + + assert metadata.append_to_grpc_metadata(None) == ( + (FLWR_PACKAGE_NAME_METADATA_KEY, "flwr"), + (FLWR_PACKAGE_VERSION_METADATA_KEY, "1.29.0"), + (FLWR_COMPONENT_NAME_METADATA_KEY, "supernode"), + ) + + +def test_runtime_version_metadata_appends_to_grpc_metadata() -> None: + """Runtime metadata should preserve unrelated metadata when appending.""" + metadata = RuntimeVersionMetadata.from_local_component( + "simulation", + package_name_value="flwr", + package_version_value="1.29.0", + ) + + grpc_metadata = metadata.append_to_grpc_metadata((("x-test", "value"),)) + + assert grpc_metadata == ( + ("x-test", "value"), + (FLWR_PACKAGE_NAME_METADATA_KEY, "flwr"), + (FLWR_PACKAGE_VERSION_METADATA_KEY, "1.29.0"), + (FLWR_COMPONENT_NAME_METADATA_KEY, "simulation"), + ) + + +def test_runtime_version_metadata_append_rejects_preexisting_runtime_keys() -> None: + """Appending should fail fast when runtime-version keys already exist.""" + metadata = RuntimeVersionMetadata.from_local_component( + "simulation", + package_name_value="flwr", + package_version_value="1.29.0", + ) + + with pytest.raises( + RuntimeError, + match="gRPC metadata already contains runtime version keys: " + "flwr-package-name", + ): + metadata.append_to_grpc_metadata( + ( + (FLWR_PACKAGE_NAME_METADATA_KEY, "old"), + ("x-test", "value"), + ) + ) + + +def test_build_runtime_version_metadata_rejects_empty_component_name() -> None: + """Component names must not be empty.""" + with pytest.raises(ValueError, match="component_name"): + RuntimeVersionMetadata.from_local_component("") + + +def test_runtime_version_metadata_from_grpc_returns_missing_for_absent_keys() -> None: + """Absent Flower metadata should be treated as the rollout missing case.""" + metadata, error = RuntimeVersionMetadata.from_grpc_metadata( + (("other-header", "value"),) + ) + + assert metadata is None + assert error is None + + +def test_runtime_version_metadata_from_grpc_accepts_metadata_item_iterables() -> None: + """GRPC metadata-style iterables should be supported directly.""" + metadata, error = RuntimeVersionMetadata.from_grpc_metadata( + ( + (FLWR_PACKAGE_NAME_METADATA_KEY, "flwr"), + (FLWR_PACKAGE_VERSION_METADATA_KEY, "1.29.0"), + (FLWR_COMPONENT_NAME_METADATA_KEY, "cli"), + ) + ) + + assert error is None + assert metadata == RuntimeVersionMetadata( + package_name="flwr", + package_version="1.29.0", + component_name="cli", + ) + + +@pytest.mark.parametrize( + ("grpc_metadata", "expected_error"), + [ + ( + ((FLWR_PACKAGE_NAME_METADATA_KEY, "flwr"),), + "Invalid Flower runtime metadata: " + "Metadata key 'flwr-package-version' is missing.", + ), + ( + ( + (FLWR_PACKAGE_NAME_METADATA_KEY, b"flwr"), + (FLWR_PACKAGE_VERSION_METADATA_KEY, b"1.29.0"), + (FLWR_COMPONENT_NAME_METADATA_KEY, b"cli"), + ), + "Invalid Flower runtime metadata: " + "Metadata key 'flwr-package-name' has a value of the wrong type.", + ), + ( + ( + (FLWR_PACKAGE_NAME_METADATA_KEY, "flwr"), + (FLWR_PACKAGE_VERSION_METADATA_KEY, b"\xff\xfe"), + (FLWR_COMPONENT_NAME_METADATA_KEY, "cli"), + ), + "Invalid Flower runtime metadata: " + "Metadata key 'flwr-package-version' has a value of the wrong type.", + ), + ( + ( + (FLWR_PACKAGE_NAME_METADATA_KEY, "flwr"), + (FLWR_PACKAGE_VERSION_METADATA_KEY, "1.29.0"), + (FLWR_PACKAGE_VERSION_METADATA_KEY, "1.29.1"), + (FLWR_COMPONENT_NAME_METADATA_KEY, "cli"), + ), + "Invalid Flower runtime metadata: " + "Metadata key 'flwr-package-version' has duplicate values.", + ), + ], +) +def test_runtime_version_metadata_from_grpc_rejects_invalid_metadata( + grpc_metadata: tuple[tuple[str, str | bytes], ...], + expected_error: str, +) -> None: + """Malformed runtime metadata should be rejected explicitly.""" + metadata, error = RuntimeVersionMetadata.from_grpc_metadata(grpc_metadata) + + assert metadata is None + assert error == expected_error + + +@pytest.mark.parametrize( + ("local", "peer"), + [ + ( + RuntimeVersionMetadata("flwr", "1.29.0", "superlink"), + RuntimeVersionMetadata("flwr", "1.29.7", "supernode"), + ), + ( + RuntimeVersionMetadata("flwr", "1.30.0.dev20260425", "superlink"), + RuntimeVersionMetadata("flwr", "1.30.0rc1", "supernode"), + ), + ( + RuntimeVersionMetadata("flwr", "1.30.0", "superlink"), + RuntimeVersionMetadata("flwr-nightly", "1.30.1.dev20260425", "supernode"), + ), + ( + RuntimeVersionMetadata("flwr", "1.29.0", "superlink"), + None, + ), + ], +) +def test_runtime_version_metadata_allows_expected_cases( + local: RuntimeVersionMetadata, + peer: RuntimeVersionMetadata | None, +) -> None: + """Compatible peers and absent metadata should continue.""" + assert local.check_compatibility(peer) is None + + +@pytest.mark.parametrize( + ("local", "peer", "expected_rejection"), + [ + ( + RuntimeVersionMetadata("flwr", "1.29.2", "SuperLink"), + RuntimeVersionMetadata("flwr", "1.30.0", "SuperNode"), + "SuperLink version 1.29.2 only accepts peers from the same " + "major.minor release, but received SuperNode version 1.30.0.", + ), + ( + RuntimeVersionMetadata("unknown", "unknown", "SuperLink"), + RuntimeVersionMetadata("flwr", "1.29.0", "flwr-simulation"), + None, + ), + ( + RuntimeVersionMetadata("flwr", "1.29.0", "SuperLink"), + RuntimeVersionMetadata("flwr", "main", "SuperNode"), + "Peer Flower version metadata cannot be parsed: 'main'.", + ), + ( + RuntimeVersionMetadata("flwr", "1.29.0", "SuperLink"), + RuntimeVersionMetadata("forked-flower", "1.29.1", "SuperNode"), + "Peer Flower package name is not recognized: 'forked-flower'.", + ), + ], +) +def test_runtime_version_metadata_rejects_expected_cases( + local: RuntimeVersionMetadata, + peer: RuntimeVersionMetadata, + expected_rejection: str | None, +) -> None: + """Explicitly invalid or incompatible peers should be rejected.""" + assert local.check_compatibility(peer) == expected_rejection diff --git a/framework/py/flwr/supercore/state/alembic/versions/rev_2026_04_27_add_primary_task_to_run_and_more.py b/framework/py/flwr/supercore/state/alembic/versions/rev_2026_04_27_add_primary_task_to_run_and_more.py new file mode 100644 index 000000000000..715d6bd13345 --- /dev/null +++ b/framework/py/flwr/supercore/state/alembic/versions/rev_2026_04_27_add_primary_task_to_run_and_more.py @@ -0,0 +1,74 @@ +# Copyright 2026 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Add primary task metadata to run and task tables. + +Revision ID: 8253e456d570 +Revises: dee9b802b5c9 +Create Date: 2026-04-27 13:00:44.155029 +""" +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# pylint: disable=no-member + +# revision identifiers, used by Alembic. +revision: str = "8253e456d570" +down_revision: str | Sequence[str] | None = "dee9b802b5c9" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("run", schema=None) as batch_op: + batch_op.add_column( + sa.Column("primary_task_id", sa.BigInteger(), nullable=True) + ) + + with op.batch_alter_table("task", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "sub_status", + sa.String(), + server_default=sa.text("''"), + nullable=False, + ) + ) + batch_op.add_column( + sa.Column( + "details", + sa.String(), + server_default=sa.text("''"), + nullable=False, + ) + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("task", schema=None) as batch_op: + batch_op.drop_column("details") + batch_op.drop_column("sub_status") + + with op.batch_alter_table("run", schema=None) as batch_op: + batch_op.drop_column("primary_task_id") + + # ### end Alembic commands ### diff --git a/framework/py/flwr/supercore/state/alembic/versions/rev_2026_04_28_add_active_until_to_task_table.py b/framework/py/flwr/supercore/state/alembic/versions/rev_2026_04_28_add_active_until_to_task_table.py new file mode 100644 index 000000000000..0cd0880a543f --- /dev/null +++ b/framework/py/flwr/supercore/state/alembic/versions/rev_2026_04_28_add_active_until_to_task_table.py @@ -0,0 +1,50 @@ +# Copyright 2026 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Add active_until to task table. + +Revision ID: c7eb009ce75c +Revises: 8253e456d570 +Create Date: 2026-04-28 19:24:40.811386 +""" +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# pylint: disable=no-member + +# revision identifiers, used by Alembic. +revision: str = "c7eb009ce75c" +down_revision: str | Sequence[str] | None = "8253e456d570" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("task", schema=None) as batch_op: + batch_op.add_column(sa.Column("active_until", sa.BigInteger(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("task", schema=None) as batch_op: + batch_op.drop_column("active_until") + + # ### end Alembic commands ### diff --git a/framework/py/flwr/supercore/state/schema/README.md b/framework/py/flwr/supercore/state/schema/README.md index 2c2803f81843..c320f556ce7f 100644 --- a/framework/py/flwr/supercore/state/schema/README.md +++ b/framework/py/flwr/supercore/state/schema/README.md @@ -104,6 +104,7 @@ erDiagram VARCHAR flwr_aid "nullable" VARCHAR override_config "nullable" VARCHAR pending_at "nullable" + BIGINT primary_task_id "nullable" BIGINT run_id UK "nullable" VARCHAR run_type VARCHAR running_at "nullable" @@ -118,7 +119,9 @@ erDiagram } task { + BIGINT active_until "nullable" VARCHAR connector_ref "nullable" + VARCHAR details VARCHAR fab_hash "nullable" VARCHAR finished_at "nullable" VARCHAR model_ref "nullable" @@ -126,6 +129,7 @@ erDiagram BIGINT run_id VARCHAR running_at "nullable" VARCHAR starting_at "nullable" + VARCHAR sub_status BIGINT task_id UK VARCHAR token "nullable" VARCHAR type diff --git a/framework/py/flwr/supercore/state/schema/corestate_tables.py b/framework/py/flwr/supercore/state/schema/corestate_tables.py index 10485ee2b911..136bb959f542 100644 --- a/framework/py/flwr/supercore/state/schema/corestate_tables.py +++ b/framework/py/flwr/supercore/state/schema/corestate_tables.py @@ -24,6 +24,7 @@ MetaData, String, Table, + text, ) @@ -75,10 +76,13 @@ def create_corestate_metadata() -> MetaData: Column("model_ref", String, nullable=True), Column("connector_ref", String, nullable=True), Column("token", String, nullable=True), + Column("active_until", BigInteger, nullable=True), Column("pending_at", String, nullable=False), Column("starting_at", String, nullable=True), Column("running_at", String, nullable=True), Column("finished_at", String, nullable=True), + Column("sub_status", String, nullable=False, server_default=text("''")), + Column("details", String, nullable=False, server_default=text("''")), ) return metadata diff --git a/framework/py/flwr/supercore/state/schema/linkstate_tables.py b/framework/py/flwr/supercore/state/schema/linkstate_tables.py index b2e86bfde54b..7b8b10a87799 100644 --- a/framework/py/flwr/supercore/state/schema/linkstate_tables.py +++ b/framework/py/flwr/supercore/state/schema/linkstate_tables.py @@ -82,6 +82,7 @@ def create_linkstate_metadata() -> MetaData: Column("sub_status", String), Column("details", String), Column("federation", String), + Column("primary_task_id", BigInteger, nullable=True), Column("federation_config", String), Column("run_type", String, nullable=False, server_default=RunType.SERVER_APP), Column("flwr_aid", String), diff --git a/framework/py/flwr/supercore/superexec/plugin/base_ephemeral_exec_plugin.py b/framework/py/flwr/supercore/superexec/plugin/base_ephemeral_exec_plugin.py index 4aea328f89c2..7c3ef745b11c 100644 --- a/framework/py/flwr/supercore/superexec/plugin/base_ephemeral_exec_plugin.py +++ b/framework/py/flwr/supercore/superexec/plugin/base_ephemeral_exec_plugin.py @@ -20,6 +20,7 @@ from collections.abc import Callable, Sequence from flwr.common.exit import ExitCode, flwr_exit +from flwr.supercore.superexec.app_supervisor import launch_with_lifeline from .exec_plugin import ExecPlugin @@ -44,6 +45,7 @@ def select_run_id(self, candidate_run_ids: Sequence[int]) -> int | None: def launch_app(self, token: str, run_id: int) -> None: """Launch the application associated with a given run ID and token.""" + use_lifeline_supervisor = os.name == "posix" cmds = [self.command] if self.insecure: cmds += ["--insecure"] @@ -51,12 +53,16 @@ def launch_app(self, token: str, run_id: int) -> None: cmds += ["--root-certificates", self.root_certificates_path] cmds += [self.appio_api_address_arg, self.appio_api_address] cmds += ["--token", token] - cmds += ["--parent-pid", str(os.getpid())] + if not use_lifeline_supervisor: + cmds += ["--parent-pid", str(os.getpid())] if self.runtime_dependency_install: cmds += ["--allow-runtime-dependency-installation"] # Perform any cleanup before launching the app if self.cleanup_before_launch is not None: self.cleanup_before_launch() - # Launch the app process and wait for it to finish - subprocess.run(cmds, check=False) + if use_lifeline_supervisor: + launch_with_lifeline(cmds, wait=True, popen_kwargs={}) + else: + # Launch the app directly on non-POSIX and wait for it to finish. + subprocess.run(cmds, check=False) flwr_exit(ExitCode.SUCCESS, "App process finished, exiting SuperExec.") diff --git a/framework/py/flwr/supercore/superexec/plugin/base_ephemeral_exec_plugin_test.py b/framework/py/flwr/supercore/superexec/plugin/base_ephemeral_exec_plugin_test.py index 46215b59b676..1f4c67521062 100644 --- a/framework/py/flwr/supercore/superexec/plugin/base_ephemeral_exec_plugin_test.py +++ b/framework/py/flwr/supercore/superexec/plugin/base_ephemeral_exec_plugin_test.py @@ -33,12 +33,15 @@ class _EphemeralExecPlugin(BaseEphemeralExecPlugin): appio_api_address_arg = "--serverappio-api-address" -def _get_ephemeral_plugin() -> _EphemeralExecPlugin: +def _get_ephemeral_plugin( + runtime_dependency_install: bool = False, +) -> _EphemeralExecPlugin: return _EphemeralExecPlugin( appio_api_address="127.0.0.1:9091", get_run=_get_run, insecure=True, root_certificates_path=None, + runtime_dependency_install=runtime_dependency_install, ) @@ -55,24 +58,25 @@ def test_select_run_id_returns_first_candidate() -> None: def test_launch_app_runs_expected_command_and_exits() -> None: - """Launch should invoke the app with token and parent PID, then exit.""" + """POSIX launch should invoke the app through the supervisor, then exit.""" plugin = _get_ephemeral_plugin() with ( patch( - "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin.os.getpid", - return_value=1234, + "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin.os.name", + "posix", ), patch( - "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin.subprocess.run" - ) as run, + "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin." + "launch_with_lifeline" + ) as launch, patch( "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin.flwr_exit" ) as flwr_exit, ): plugin.launch_app(token="token-123", run_id=5) - run.assert_called_once_with( + launch.assert_called_once_with( [ "flwr-serverapp", "--insecure", @@ -80,10 +84,9 @@ def test_launch_app_runs_expected_command_and_exits() -> None: "127.0.0.1:9091", "--token", "token-123", - "--parent-pid", - "1234", ], - check=False, + wait=True, + popen_kwargs={}, ) flwr_exit.assert_called_once_with( ExitCode.SUCCESS, @@ -92,7 +95,7 @@ def test_launch_app_runs_expected_command_and_exits() -> None: def test_launch_app_calls_cleanup_before_launch() -> None: - """Launch should invoke cleanup_before_launch before running the subprocess.""" + """Launch should invoke cleanup before supervision, then exit.""" # Prepare call_log: list[str] = [] plugin = _get_ephemeral_plugin() @@ -101,12 +104,85 @@ def test_launch_app_calls_cleanup_before_launch() -> None: # Execute with ( patch( - "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin.subprocess.run", - side_effect=lambda *_, **__: call_log.append("subprocess"), + "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin.os.name", + "posix", + ), + patch( + "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin." + "launch_with_lifeline", + side_effect=lambda *_, **__: call_log.append("launch"), + ), + patch( + "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin.flwr_exit", + side_effect=lambda *_, **__: call_log.append("exit"), ), - patch("flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin.flwr_exit"), ): plugin.launch_app(token="token-abc", run_id=1) # Assert - assert call_log == ["cleanup", "subprocess"] + assert call_log == ["cleanup", "launch", "exit"] + + +def test_launch_app_forwards_runtime_dependency_install_flag() -> None: + """POSIX launch should preserve optional runtime install flags.""" + plugin = _get_ephemeral_plugin(runtime_dependency_install=True) + + with ( + patch( + "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin.os.name", + "posix", + ), + patch( + "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin." + "launch_with_lifeline" + ) as launch, + patch("flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin.flwr_exit"), + ): + plugin.launch_app(token="token-123", run_id=5) + + assert launch.call_args.args[0] == [ + "flwr-serverapp", + "--insecure", + "--serverappio-api-address", + "127.0.0.1:9091", + "--token", + "token-123", + "--allow-runtime-dependency-installation", + ] + assert "--parent-pid" not in launch.call_args.args[0] + + +def test_launch_app_non_posix_fallback_passes_parent_pid() -> None: + """Non-POSIX launch should keep the existing parent PID behavior.""" + plugin = _get_ephemeral_plugin() + + with ( + patch( + "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin.os.name", + "nt", + ), + patch( + "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin.os.getpid", + return_value=1234, + ), + patch( + "flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin." + "subprocess.run" + ) as run, + patch("flwr.supercore.superexec.plugin.base_ephemeral_exec_plugin.flwr_exit"), + ): + plugin.launch_app(token="token-123", run_id=5) + + run.assert_called_once_with( + [ + "flwr-serverapp", + "--insecure", + "--serverappio-api-address", + "127.0.0.1:9091", + "--token", + "token-123", + "--parent-pid", + "1234", + ], + check=False, + ) diff --git a/framework/py/flwr/supercore/superexec/plugin/base_exec_plugin.py b/framework/py/flwr/supercore/superexec/plugin/base_exec_plugin.py index 795c245e0553..7bb4e24afa0d 100644 --- a/framework/py/flwr/supercore/superexec/plugin/base_exec_plugin.py +++ b/framework/py/flwr/supercore/superexec/plugin/base_exec_plugin.py @@ -20,6 +20,8 @@ from collections.abc import Sequence from typing import Any +from flwr.supercore.superexec.app_supervisor import launch_with_lifeline + from .exec_plugin import ExecPlugin @@ -41,6 +43,7 @@ def select_run_id(self, candidate_run_ids: Sequence[int]) -> int | None: def launch_app(self, token: str, run_id: int) -> None: """Launch the application associated with a given run ID and token.""" + use_lifeline_supervisor = os.name == "posix" cmds = [self.command] if self.insecure: cmds.append("--insecure") @@ -48,10 +51,18 @@ def launch_app(self, token: str, run_id: int) -> None: cmds += ["--root-certificates", self.root_certificates_path] cmds += [self.appio_api_address_arg, self.appio_api_address] cmds += ["--token", token] - cmds += ["--parent-pid", str(os.getpid())] + if not use_lifeline_supervisor: + cmds += ["--parent-pid", str(os.getpid())] if self.runtime_dependency_install: cmds += ["--allow-runtime-dependency-installation"] - # Launch the client app without waiting for it to complete. + if use_lifeline_supervisor: + launch_with_lifeline( + cmds, + wait=False, + popen_kwargs=self.get_popen_kwargs(), + ) + return + # Launch the app directly on non-POSIX without waiting for it to complete. # Since we don't need to manage the process, we intentionally avoid using # a `with` statement. Suppress the pylint warning for it in this case. # pylint: disable-next=consider-using-with diff --git a/framework/py/flwr/supercore/superexec/plugin/base_exec_plugin_test.py b/framework/py/flwr/supercore/superexec/plugin/base_exec_plugin_test.py index d73ebd7c6e12..360ee537351d 100644 --- a/framework/py/flwr/supercore/superexec/plugin/base_exec_plugin_test.py +++ b/framework/py/flwr/supercore/superexec/plugin/base_exec_plugin_test.py @@ -39,11 +39,16 @@ def test_clientapp_launch_inherits_default_stdio() -> None: get_run=_get_run, ) - with patch("subprocess.Popen") as popen: + with ( + patch("flwr.supercore.superexec.plugin.base_exec_plugin.os.name", "posix"), + patch( + "flwr.supercore.superexec.plugin.base_exec_plugin.launch_with_lifeline" + ) as launch, + ): plugin.launch_app(token="token", run_id=7) - assert "stdout" not in popen.call_args.kwargs - assert "stderr" not in popen.call_args.kwargs + assert "stdout" not in launch.call_args.kwargs["popen_kwargs"] + assert "stderr" not in launch.call_args.kwargs["popen_kwargs"] def test_serverapp_launch_isolates_stdio() -> None: @@ -55,11 +60,16 @@ def test_serverapp_launch_isolates_stdio() -> None: get_run=_get_run, ) - with patch("subprocess.Popen") as popen: + with ( + patch("flwr.supercore.superexec.plugin.base_exec_plugin.os.name", "posix"), + patch( + "flwr.supercore.superexec.plugin.base_exec_plugin.launch_with_lifeline" + ) as launch, + ): plugin.launch_app(token="token", run_id=5) - assert popen.call_args.kwargs["stdout"] is subprocess.DEVNULL - assert popen.call_args.kwargs["stderr"] is subprocess.DEVNULL + assert launch.call_args.kwargs["popen_kwargs"]["stdout"] is subprocess.DEVNULL + assert launch.call_args.kwargs["popen_kwargs"]["stderr"] is subprocess.DEVNULL class DummyExecPlugin(BaseExecPlugin): @@ -81,26 +91,26 @@ def test_launch_app_forwards_runtime_dependency_install_flag() -> None: with ( patch( - "flwr.supercore.superexec.plugin.base_exec_plugin.os.getpid", - return_value=1234, + "flwr.supercore.superexec.plugin.base_exec_plugin.os.name", + "posix", ), patch( - "flwr.supercore.superexec.plugin.base_exec_plugin.subprocess.Popen" - ) as popen, + "flwr.supercore.superexec.plugin.base_exec_plugin.launch_with_lifeline" + ) as launch, ): plugin.launch_app(token="token-123", run_id=7) - assert popen.call_args.args[0] == [ + assert launch.call_args.args[0] == [ "dummy-app", "--insecure", "--appio-api-address", "127.0.0.1:9091", "--token", "token-123", - "--parent-pid", - "1234", "--allow-runtime-dependency-installation", ] + assert "--parent-pid" not in launch.call_args.args[0] + assert launch.call_args.kwargs["wait"] is False def test_launch_app_skips_optional_runtime_flags_by_default() -> None: @@ -112,12 +122,15 @@ def test_launch_app_skips_optional_runtime_flags_by_default() -> None: get_run=Mock(), ) - with patch( - "flwr.supercore.superexec.plugin.base_exec_plugin.subprocess.Popen" - ) as popen: + with ( + patch("flwr.supercore.superexec.plugin.base_exec_plugin.os.name", "posix"), + patch( + "flwr.supercore.superexec.plugin.base_exec_plugin.launch_with_lifeline" + ) as launch, + ): plugin.launch_app(token="token-123", run_id=7) - assert "--allow-runtime-dependency-installation" not in popen.call_args.args[0] + assert "--allow-runtime-dependency-installation" not in launch.call_args.args[0] def test_clientapp_launch_forwards_root_certificate() -> None: @@ -129,10 +142,15 @@ def test_clientapp_launch_forwards_root_certificate() -> None: get_run=_get_run, ) - with patch("subprocess.Popen") as mock_popen: + with ( + patch("flwr.supercore.superexec.plugin.base_exec_plugin.os.name", "posix"), + patch( + "flwr.supercore.superexec.plugin.base_exec_plugin.launch_with_lifeline" + ) as launch, + ): plugin.launch_app(token="token", run_id=7) - assert mock_popen.call_args.args[0][:3] == [ + assert launch.call_args.args[0][:3] == [ "flwr-clientapp", "--root-certificates", "/tmp/root.pem", @@ -148,8 +166,46 @@ def test_clientapp_launch_omits_tls_flags_when_using_system_certificates() -> No get_run=_get_run, ) - with patch("subprocess.Popen") as mock_popen: + with ( + patch("flwr.supercore.superexec.plugin.base_exec_plugin.os.name", "posix"), + patch( + "flwr.supercore.superexec.plugin.base_exec_plugin.launch_with_lifeline" + ) as launch, + ): plugin.launch_app(token="token", run_id=7) - assert "--insecure" not in mock_popen.call_args.args[0] - assert "--root-certificates" not in mock_popen.call_args.args[0] + assert "--insecure" not in launch.call_args.args[0] + assert "--root-certificates" not in launch.call_args.args[0] + + +def test_launch_app_non_posix_fallback_passes_parent_pid() -> None: + """Non-POSIX launch should keep the existing parent PID behavior.""" + plugin = DummyExecPlugin( + appio_api_address="127.0.0.1:9091", + insecure=True, + root_certificates_path=None, + get_run=Mock(), + ) + + with ( + patch("flwr.supercore.superexec.plugin.base_exec_plugin.os.name", "nt"), + patch( + "flwr.supercore.superexec.plugin.base_exec_plugin.os.getpid", + return_value=1234, + ), + patch( + "flwr.supercore.superexec.plugin.base_exec_plugin.subprocess.Popen" + ) as popen, + ): + plugin.launch_app(token="token-123", run_id=7) + + assert popen.call_args.args[0] == [ + "dummy-app", + "--insecure", + "--appio-api-address", + "127.0.0.1:9091", + "--token", + "token-123", + "--parent-pid", + "1234", + ] diff --git a/framework/py/flwr/supercore/utils.py b/framework/py/flwr/supercore/utils.py index 15dfea40f4c1..846321b57aef 100644 --- a/framework/py/flwr/supercore/utils.py +++ b/framework/py/flwr/supercore/utils.py @@ -20,10 +20,10 @@ import os import re import sys -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from logging import WARN from pathlib import Path -from typing import Any, TypeVar +from typing import Any, Literal, TypeVar import requests @@ -38,6 +38,28 @@ PR_SET_DUMPABLE = 4 # from /usr/include/linux/prctl.h +MetadataLookupErrorType = Literal["missing", "duplicate", "wrong_type", "empty"] + + +class MetadataLookupError(Exception): + """Error type for metadata lookup failures.""" + + def __init__(self, key: str, error_type: MetadataLookupErrorType) -> None: + self.key = key + self.error_type = error_type + if error_type == "missing": + message = f"Metadata key '{key}' is missing." + elif error_type == "duplicate": + message = f"Metadata key '{key}' has duplicate values." + elif error_type == "wrong_type": + message = f"Metadata key '{key}' has a value of the wrong type." + elif error_type == "empty": + message = f"Metadata key '{key}' has an empty value." + else: + message = f"Metadata key '{key}' has an unknown error: {error_type}." + super().__init__(message) + + def mask_string(value: str, head: int = 4, tail: int = 4) -> str: """Mask a string by preserving only the head and tail characters. @@ -343,41 +365,98 @@ def is_valid_name(name: str) -> tuple[bool, str]: return True, "" -def _get_metadata_typed( +def find_metadata_keys( + metadata: Sequence[tuple[str, str | bytes]] | None, + keys: Iterable[str], +) -> set[str]: + """Return the subset of `keys` present in the gRPC metadata sequence.""" + if metadata is None: + return set() + + key_set = set(keys) + return {metadata_key for metadata_key, _ in metadata if metadata_key in key_set} + + +def _get_metadata_typed_checked( metadata: Sequence[tuple[str, str | bytes]] | None, key: str, value_type: type[T], -) -> T | None: - """Return exactly one non-empty string or bytes metadata value for `key`.""" - if metadata is None: - return None +) -> T: + """Return exactly one non-empty string or bytes metadata value for `key`. + + Raises + ------ + MetadataLookupError + If the metadata value for `key` is missing, duplicated, of the wrong type, + or empty. + """ values: list[Any] = [ - value for metadata_key, value in metadata if metadata_key == key + value for metadata_key, value in metadata or [] if metadata_key == key ] - if len(values) != 1: - return None + if not values: + raise MetadataLookupError(key, "missing") + if len(values) > 1: + raise MetadataLookupError(key, "duplicate") value = values[0] if not isinstance(value, value_type): - return None + raise MetadataLookupError(key, "wrong_type") if value in ("", b""): - return None + raise MetadataLookupError(key, "empty") return value +def get_metadata_str_checked( + metadata: Sequence[tuple[str, str | bytes]] | None, + key: str, +) -> str: + """Return exactly one non-empty string metadata value for `key`. + + Raises + ------ + MetadataLookupError + If the metadata value for `key` is missing, duplicated, of the wrong type, + or empty. + """ + return _get_metadata_typed_checked(metadata, key, str) + + +def get_metadata_bytes_checked( + metadata: Sequence[tuple[str, str | bytes]] | None, + key: str, +) -> bytes: + """Return exactly one non-empty bytes metadata value for `key`. + + Raises + ------ + MetadataLookupError + If the metadata value for `key` is missing, duplicated, of the wrong type, + or empty. + """ + return _get_metadata_typed_checked(metadata, key, bytes) + + def get_metadata_str( metadata: Sequence[tuple[str, str | bytes]] | None, key: str, ) -> str | None: - """Return exactly one non-empty string metadata value for `key`.""" - return _get_metadata_typed(metadata, key, str) + """Return exactly one non-empty string metadata value for `key`, or None if not + found or invalid.""" + try: + return get_metadata_str_checked(metadata, key) + except MetadataLookupError: + return None def get_metadata_bytes( metadata: Sequence[tuple[str, str | bytes]] | None, key: str, ) -> bytes | None: - """Return exactly one non-empty bytes metadata value for `key`.""" - return _get_metadata_typed(metadata, key, bytes) + """Return exactly one non-empty bytes metadata value for `key`, or None if not found + or invalid.""" + try: + return get_metadata_bytes_checked(metadata, key) + except MetadataLookupError: + return None def disable_process_dumping(strict: bool) -> None: diff --git a/framework/py/flwr/supercore/utils_test.py b/framework/py/flwr/supercore/utils_test.py index 1cd5f4591257..1b9e679b151a 100644 --- a/framework/py/flwr/supercore/utils_test.py +++ b/framework/py/flwr/supercore/utils_test.py @@ -25,8 +25,11 @@ from flwr.proto.federation_config_pb2 import SimulationConfig # pylint: disable=E0611 from .utils import ( + MetadataLookupError, + find_metadata_keys, get_metadata_bytes, get_metadata_str, + get_metadata_str_checked, humanize_bytes, humanize_duration, int64_to_uint64, @@ -39,6 +42,18 @@ ) +def test_find_metadata_keys() -> None: + """Return the subset of requested keys present in metadata.""" + assert find_metadata_keys( + [ + ("x-token", "value"), + ("x-trace-id", "abc"), + ("x-token", "other"), + ], + ("x-token", "missing"), + ) == {"x-token"} + + def test_mask_string() -> None: """Test the `mask_string` function.""" assert mask_string("abcdefghi") == "abcd...fghi" @@ -72,6 +87,33 @@ def test_get_metadata_str( assert get_metadata_str(metadata, key) == expected +@pytest.mark.parametrize( + ("metadata", "key", "expected_value", "expected_error"), + [ + ([("x-token", "value")], "x-token", "value", None), + ([("x-token", "")], "x-token", None, "empty"), + ([("x-token", "value"), ("x-token", "other")], "x-token", None, "duplicate"), + ([("x-token", b"value")], "x-token", None, "wrong_type"), + ([("other", "value")], "x-token", None, "missing"), + ], +) +def test_get_metadata_str_checked( + metadata: list[tuple[str, str | bytes]], + key: str, + expected_value: str | None, + expected_error: str | None, +) -> None: + """Preserve metadata validation outcomes for callers that need them.""" + value, error_type = None, None + try: + value = get_metadata_str_checked(metadata, key) + except MetadataLookupError as e: + error_type = e.error_type + + assert value == expected_value + assert error_type == expected_error + + @pytest.mark.parametrize( ("metadata", "key", "expected"), [ diff --git a/framework/py/flwr/superlink/federation/noop_federation_manager_test.py b/framework/py/flwr/superlink/federation/noop_federation_manager_test.py index 49797132d2e0..7e6b3e8056db 100644 --- a/framework/py/flwr/superlink/federation/noop_federation_manager_test.py +++ b/framework/py/flwr/superlink/federation/noop_federation_manager_test.py @@ -77,6 +77,7 @@ def test_get_details_with_valid_federation() -> None: status=RunStatus(status="running", sub_status="", details=""), flwr_aid=NOOP_FLWR_AID, federation=NOOP_FEDERATION, + primary_task_id=None, bytes_sent=1024, bytes_recv=512, clientapp_runtime=1.1, @@ -94,6 +95,7 @@ def test_get_details_with_valid_federation() -> None: status=RunStatus(status="finished", sub_status="", details=""), flwr_aid=NOOP_FLWR_AID, federation=NOOP_FEDERATION, + primary_task_id=None, bytes_sent=2048, bytes_recv=1024, clientapp_runtime=1.2, diff --git a/framework/py/flwr/supernode/nodestate/in_memory_nodestate.py b/framework/py/flwr/supernode/nodestate/in_memory_nodestate.py index de68dd9b839a..a546e2510a1b 100644 --- a/framework/py/flwr/supernode/nodestate/in_memory_nodestate.py +++ b/framework/py/flwr/supernode/nodestate/in_memory_nodestate.py @@ -194,7 +194,7 @@ def get_run_ids_with_pending_messages(self) -> Sequence[int]: ret -= set(self.token_store.keys()) return list(ret) - def _on_tokens_expired(self, expired_records: list[tuple[int, float]]) -> None: + def _on_tokens_expired(self, expired_records: list[tuple[int, int]]) -> None: """Insert error replies for messages associated with expired tokens.""" with self.lock_msg_store: # Find all retrieved messages associated with expired run IDs