Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7600127
fix(framework): Validate run_id
danieljanes Apr 21, 2026
b52b2a8
Pass missing run_id
danieljanes Apr 21, 2026
1ea3270
Update framework/py/flwr/supercore/interceptors/appio_token_intercept…
danieljanes Apr 21, 2026
ac207b5
Fix
danieljanes Apr 21, 2026
8b72851
Merge remote-tracking branch 'refs/remotes/origin/validate-run-id' in…
danieljanes Apr 21, 2026
b030fdb
Fix
danieljanes Apr 21, 2026
d39ce8d
fix serverappio servicer test
panh99 Apr 21, 2026
655d5c1
Merge branch 'main' into validate-run-id
danieljanes Apr 21, 2026
f2266b1
feat(framework): Prevent simulation runtime from accessing grid endpo…
danieljanes Apr 22, 2026
b982797
Merge branch 'main' into validate-run-id
danieljanes Apr 23, 2026
26eb792
Merge branch 'main' into validate-run-id
danieljanes Apr 27, 2026
fa298d6
Merge branch 'main' into validate-run-id
danieljanes May 3, 2026
43209a1
Merge branch 'main' into validate-run-id
danieljanes May 4, 2026
b5af48a
ci(framework): Add framework/AGENTS.md
danieljanes May 4, 2026
88cce57
Merge branch 'main' into validate-run-id
danieljanes May 4, 2026
a30e013
Merge branch 'framework-agents-md' into validate-run-id
danieljanes May 4, 2026
f822794
Merge remote-tracking branch 'refs/remotes/origin/validate-run-id' in…
danieljanes May 4, 2026
83a67e2
Update test
danieljanes May 4, 2026
5675d6c
Merge branch 'main' into validate-run-id
danieljanes May 4, 2026
5e0c589
Update framework/py/flwr/supercore/interceptors/appio_token_intercept…
danieljanes May 4, 2026
734b5e6
Merge branch 'validate-run-id' into validate-run-type
danieljanes May 4, 2026
1a9cb30
fix(framework): Prefer task tokens in AppIo auth and enforce run bind…
danieljanes May 9, 2026
955bdaf
Merge branch 'main' into validate-run-type
danieljanes May 10, 2026
7fa18cf
Merge branch 'main' into validate-run-type
danieljanes May 14, 2026
c0d4c57
Refactor
danieljanes May 14, 2026
498ef2a
Fix test
danieljanes May 14, 2026
339e231
Merge branch 'main' into validate-run-type
danieljanes May 14, 2026
dd08d53
Simplify
danieljanes May 14, 2026
159713d
Merge branch 'main' into validate-run-type
danieljanes May 14, 2026
e73664c
Update serverappio_servicer.py
panh99 May 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,33 @@

import tempfile
import unittest
from collections.abc import Callable

import grpc
from google.protobuf.message import Message as GrpcMessage
from parameterized import parameterized

from flwr.common.constant import SERVERAPPIO_API_DEFAULT_SERVER_ADDRESS
from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
PullAppMessagesRequest,
PullAppMessagesResponse,
PushAppMessagesRequest,
PushAppMessagesResponse,
SendTaskHeartbeatRequest,
SendTaskHeartbeatResponse,
)
from flwr.proto.log_pb2 import ( # pylint: disable=E0611
PushLogsRequest,
PushLogsResponse,
)
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
ConfirmMessageReceivedRequest,
ConfirmMessageReceivedResponse,
PullObjectRequest,
PullObjectResponse,
PushObjectRequest,
PushObjectResponse,
)
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
GetNodesRequest,
GetNodesResponse,
Expand Down Expand Up @@ -70,17 +93,20 @@ def setUp(self) -> None:

# Seed one authenticated task token and reuse it for token-protected RPC
# checks.
self._auth_run_id, auth_token = self._create_running_run()
_, auth_token = self._create_running_run()
_, self._simulation_token = self._create_running_run(
run_type=RunType.SIMULATION
)

# Create a single base channel and wrap it for authenticated calls.
base_channel = grpc.insecure_channel("localhost:9091")
self._get_nodes_no_auth = base_channel.unary_unary(
self._base_channel = grpc.insecure_channel("localhost:9091")
self._get_nodes_no_auth = self._base_channel.unary_unary(
"/flwr.proto.ServerAppIo/GetNodes",
request_serializer=GetNodesRequest.SerializeToString,
response_deserializer=GetNodesResponse.FromString,
)
auth_channel = grpc.intercept_channel(
base_channel,
self._base_channel,
AppIoTokenClientInterceptor(token=auth_token),
SuperExecAuthClientInterceptor(
master_secret=_SUPEREXEC_SECRET,
Expand All @@ -95,11 +121,14 @@ def setUp(self) -> None:

def tearDown(self) -> None:
"""Stop the gRPC API server."""
self._base_channel.close()
self._server.stop(None)

def _create_running_run(self) -> tuple[int, str]:
def _create_running_run(
self, run_type: str = RunType.SERVER_APP
) -> tuple[int, str]:
run_id = self.state.create_run(
"", "", "", {}, NOOP_FEDERATION, None, "", RunType.SERVER_APP
"", "", "", {}, NOOP_FEDERATION, None, "", run_type
)
run = self.state.get_run_info(run_ids=[run_id])[0]
assert run.primary_task_id is not None
Expand Down Expand Up @@ -131,3 +160,99 @@ def test_get_nodes_allows_with_valid_metadata_token(self) -> None:

assert isinstance(response, GetNodesResponse)
assert call.code() == grpc.StatusCode.OK

@parameterized.expand(
[
(
"get_nodes",
"/flwr.proto.ServerAppIo/GetNodes",
GetNodesRequest(),
GetNodesResponse.FromString,
),
(
"push_messages",
"/flwr.proto.ServerAppIo/PushMessages",
PushAppMessagesRequest(),
PushAppMessagesResponse.FromString,
),
(
"pull_messages",
"/flwr.proto.ServerAppIo/PullMessages",
PullAppMessagesRequest(),
PullAppMessagesResponse.FromString,
),
(
"push_object",
"/flwr.proto.ServerAppIo/PushObject",
PushObjectRequest(),
PushObjectResponse.FromString,
),
(
"pull_object",
"/flwr.proto.ServerAppIo/PullObject",
PullObjectRequest(),
PullObjectResponse.FromString,
),
(
"confirm_message_received",
"/flwr.proto.ServerAppIo/ConfirmMessageReceived",
ConfirmMessageReceivedRequest(),
ConfirmMessageReceivedResponse.FromString,
),
]
) # type: ignore
def test_serverapp_only_endpoint_denied_for_simulation_run(
self,
_case_name: str,
method: str,
request: GrpcMessage,
response_deserializer: Callable[[bytes], object],
) -> None:
"""ServerApp-only RPCs should deny simulation-run tokens."""
rpc = self._base_channel.unary_unary(
method,
request_serializer=type(request).SerializeToString,
response_deserializer=response_deserializer,
)
with self.assertRaises(grpc.RpcError) as err:
rpc.with_call(
request=request,
metadata=((TASK_TOKEN_HEADER, self._simulation_token),),
)
assert err.exception.code() == grpc.StatusCode.PERMISSION_DENIED

@parameterized.expand(
[
(
"send_task_heartbeat",
"/flwr.proto.ServerAppIo/SendTaskHeartbeat",
SendTaskHeartbeatRequest(),
SendTaskHeartbeatResponse.FromString,
),
(
"push_logs",
"/flwr.proto.ServerAppIo/PushLogs",
PushLogsRequest(logs=["hello"]),
PushLogsResponse.FromString,
),
]
) # type: ignore
def test_shared_task_endpoint_allows_simulation_run(
self,
_case_name: str,
method: str,
request: GrpcMessage,
response_deserializer: Callable[[bytes], object],
) -> None:
"""Shared task RPCs should still allow simulation-run tokens."""
rpc = self._base_channel.unary_unary(
method,
request_serializer=type(request).SerializeToString,
response_deserializer=response_deserializer,
)
response, call = rpc.with_call(
request=request,
metadata=((TASK_TOKEN_HEADER, self._simulation_token),),
)
assert response is not None
assert call.code() == grpc.StatusCode.OK
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
)
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory
from flwr.server.utils.validator import validate_message
from flwr.supercore.constant import TaskType
from flwr.supercore.inflatable.inflatable_object import (
UnexpectedObjectContentError,
get_all_nested_objects,
Expand All @@ -72,6 +73,10 @@
from flwr.supercore.object_store import NoObjectInStoreError, ObjectStoreFactory
from flwr.supercore.servicers import AppIoServicer

SERVERAPPIO_ENDPOINT_UNAVAILABLE_MESSAGE = (
"Some ServerAppIo API endpoints are only available for Deployment Runtime runs."
)


class ServerAppIoServicer(AppIoServicer, serverappio_pb2_grpc.ServerAppIoServicer):
"""ServerAppIo API servicer."""
Expand All @@ -97,9 +102,7 @@ def GetNodes(
# Init state
state = self.state_factory.state()

# Get authenticated task and associated run ID
task = get_authenticated_task()
run_id = task.run_id
run_id = _get_authenticated_serverapp_run_id(context)

all_ids: set[int] = state.get_nodes(run_id)
nodes: list[Node] = [Node(node_id=node_id) for node_id in all_ids]
Expand All @@ -115,9 +118,7 @@ def PushMessages(
state = self.state_factory.state()
store = self.objectstore_factory.store()

# Get authenticated task and associated run ID
task = get_authenticated_task()
run_id = task.run_id
run_id = _get_authenticated_serverapp_run_id(context)

# Validate request and insert in State
_raise_if(
Expand Down Expand Up @@ -165,9 +166,7 @@ def PullMessages( # pylint: disable=R0914
state = self.state_factory.state()
store = self.objectstore_factory.store()

# Get authenticated task and associated run ID
task = get_authenticated_task()
run_id = task.run_id
run_id = _get_authenticated_serverapp_run_id(context)

# Read from state
messages_res: list[Message] = state.get_message_res(
Expand Down Expand Up @@ -311,9 +310,11 @@ def PushObject(
"""Push an object to the ObjectStore."""
log(DEBUG, "ServerAppIoServicer.PushObject")

# Init state and store
# Init store
store = self.objectstore_factory.store()

_ = _get_authenticated_serverapp_run_id(context)

if request.node.node_id != SUPERLINK_NODE_ID:
# Cancel insertion in ObjectStore
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
Expand All @@ -337,9 +338,11 @@ def PullObject(
"""Pull an object from the ObjectStore."""
log(DEBUG, "ServerAppIoServicer.PullObject")

# Init state and store
# Init store
store = self.objectstore_factory.store()

_ = _get_authenticated_serverapp_run_id(context)

if request.node.node_id != SUPERLINK_NODE_ID:
# Cancel insertion in ObjectStore
context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Unexpected node ID.")
Expand All @@ -361,15 +364,28 @@ def ConfirmMessageReceived(
"""Confirm message received."""
log(DEBUG, "ServerAppIoServicer.ConfirmMessageReceived")

# Init state and store
# Init store
store = self.objectstore_factory.store()

_ = _get_authenticated_serverapp_run_id(context)

# Delete the message object
store.delete(request.message_object_id)

return ConfirmMessageReceivedResponse()


def _get_authenticated_serverapp_run_id(context: grpc.ServicerContext) -> int:
"""Return the authenticated run ID if it can use ServerAppIo endpoints."""
task = get_authenticated_task()
if task.type != TaskType.SERVER_APP:
context.abort(
grpc.StatusCode.PERMISSION_DENIED,
SERVERAPPIO_ENDPOINT_UNAVAILABLE_MESSAGE,
)
return task.run_id


def _raise_if(validation_error: bool, request_name: str, detail: str) -> None:
"""Raise a `ValueError` with a detailed message if a validation error occurs."""
if validation_error:
Expand Down
Loading