Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
7600127
fix(framework): Validate run_id
danieljanes Apr 21, 2026
b52b2a8
Pass missing run_id
danieljanes Apr 21, 2026
1ea3270
Update framework/py/flwr/supercore/interceptors/appio_token_intercept…
danieljanes Apr 21, 2026
ac207b5
Fix
danieljanes Apr 21, 2026
8b72851
Merge remote-tracking branch 'refs/remotes/origin/validate-run-id' in…
danieljanes Apr 21, 2026
b030fdb
Fix
danieljanes Apr 21, 2026
d39ce8d
fix serverappio servicer test
panh99 Apr 21, 2026
655d5c1
Merge branch 'main' into validate-run-id
danieljanes Apr 21, 2026
b982797
Merge branch 'main' into validate-run-id
danieljanes Apr 23, 2026
26eb792
Merge branch 'main' into validate-run-id
danieljanes Apr 27, 2026
fa298d6
Merge branch 'main' into validate-run-id
danieljanes May 3, 2026
43209a1
Merge branch 'main' into validate-run-id
danieljanes May 4, 2026
b5af48a
ci(framework): Add framework/AGENTS.md
danieljanes May 4, 2026
88cce57
Merge branch 'main' into validate-run-id
danieljanes May 4, 2026
a30e013
Merge branch 'framework-agents-md' into validate-run-id
danieljanes May 4, 2026
f822794
Merge remote-tracking branch 'refs/remotes/origin/validate-run-id' in…
danieljanes May 4, 2026
83a67e2
Update test
danieljanes May 4, 2026
5675d6c
Merge branch 'main' into validate-run-id
danieljanes May 4, 2026
5e0c589
Update framework/py/flwr/supercore/interceptors/appio_token_intercept…
danieljanes May 4, 2026
8412f77
Merge branch 'main' into validate-run-id
danieljanes May 7, 2026
956655b
Merge branch 'main' into validate-run-id
danieljanes May 7, 2026
a0531d0
Simplfy tests
danieljanes May 7, 2026
89a20cb
Handle omitted run_id
danieljanes May 7, 2026
8ee0adc
Merge branch 'main' into validate-run-id
danieljanes May 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -332,24 +332,10 @@ def setUp(self) -> None:
superexec_auth_secret=_SUPEREXEC_SECRET,
)

# Provide a valid metadata token on the default test channel so existing
# servicer behavior tests continue to exercise business logic paths.
self._auth_run_id = self.state.create_run(
"", "", "", {}, NOOP_FEDERATION, None, "", RunType.SERVER_APP
)
auth_token = self.state.create_token(self._auth_run_id)
assert auth_token is not None
self._auth_token = auth_token
self._appio_auth_interceptor = AppIoTokenClientInterceptor(auth_token)
_ = self.state.update_run_status(
self._auth_run_id, RunStatus(Status.STARTING, "", "")
)
_ = self.state.update_run_status(
self._auth_run_id, RunStatus(Status.RUNNING, "", "")
)
self._appio_token_client_interceptor = AppIoTokenClientInterceptor(token="")
self._channel = grpc.intercept_channel(
grpc.insecure_channel("localhost:9091"),
self._appio_auth_interceptor,
self._appio_token_client_interceptor,
SuperExecAuthClientInterceptor(
master_secret=_SUPEREXEC_SECRET,
protected_methods=SERVERAPPIO_SUPEREXEC_METHODS,
Expand Down Expand Up @@ -429,10 +415,19 @@ def _create_dummy_run(self, running: bool = True, *, fab_hash: str = "") -> int:
"",
RunType.SERVER_APP,
)

# Set token in the client interceptor to pass authentication
assert (token := self.state.create_token(run_id)) is not None
self._set_token(token)

# Transition run status if required
if running:
self._transition_run_status(run_id, 2)
return run_id

def _set_token(self, token: str) -> None:
    """Replace the token held by the test's App token client interceptor.

    Parameters
    ----------
    token : str
        Token value stored on ``self._appio_token_client_interceptor``.
    """
    interceptor = self._appio_token_client_interceptor
    # Writes the interceptor's private attribute directly, hence the waiver.
    interceptor._token = token  # pylint: disable=W0212

def test_successful_get_node_if_running(self) -> None:
"""Test `GetNode` success."""
# Prepare
Expand Down Expand Up @@ -542,13 +537,20 @@ def test_create_task_rejects_missing_required_fields(
assert err.exception.code() == grpc.StatusCode.FAILED_PRECONDITION
assert err.exception.details() == error_msg

def test_create_task_fast_fails_missing_run(self) -> None:
"""Test `CreateTask` propagates an unknown run ID as an RPC failure."""
def test_create_task_rejects_missing_run(self) -> None:
"""Test `CreateTask` rejects unknown run IDs."""
# Seed a stale token binding so auth passes and the servicer still
# exercises the missing-run branch.
run_id = 42
token = self.state.create_token(run_id)
assert token is not None
self._set_token(token)

with self.assertRaises(grpc.RpcError) as err:
self._create_task.with_call(
request=CreateTaskRequest(
type=TaskType.MODEL,
run_id=42,
run_id=run_id,
model_ref="model://test",
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

# gRPC metadata key under which the App token is attached to a call.
APP_TOKEN_HEADER = "flwr-app-token"
# Abort details sent with UNAUTHENTICATED when token authentication fails.
AUTHENTICATION_FAILED_MESSAGE = "Authentication failed."
# Abort details sent with PERMISSION_DENIED when the run ID in the request
# differs from the run associated with the presented token.
RUN_BINDING_FAILED_MESSAGE = "Token is not valid for requested run."


# Task set by the token-auth handler before the downstream handler runs;
# defaults to None when no task has been set.
_current_task: ContextVar[Task | None] = ContextVar("current_task", default=None)
Expand All @@ -57,6 +58,11 @@ def _abort_auth_denied(context: grpc.ServicerContext) -> NoReturn:
raise RuntimeError("Should not reach this point")


def _abort_run_denied(context: grpc.ServicerContext) -> NoReturn:
    """Abort the RPC because the token does not match the requested run.

    Parameters
    ----------
    context : grpc.ServicerContext
        Context of the RPC being rejected.

    Raises
    ------
    RuntimeError
        Only if ``context.abort`` returns instead of raising.
    """
    denial_status = grpc.StatusCode.PERMISSION_DENIED
    # Arguments stay positional: callers and tests observe this exact call.
    context.abort(denial_status, RUN_BINDING_FAILED_MESSAGE)
    # `abort` is expected to raise; this guard upholds the NoReturn contract.
    raise RuntimeError("Should not reach this point")


def _unauthenticated_terminator() -> grpc.RpcMethodHandler:
def _terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage:
context.abort(grpc.StatusCode.UNAUTHENTICATED, AUTHENTICATION_FAILED_MESSAGE)
Expand All @@ -65,6 +71,14 @@ def _terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMess
return grpc.unary_unary_rpc_method_handler(_terminate)


def _get_request_run_id(request: GrpcMessage) -> int | None:
run_id = cast(int | None, getattr(request, "run_id", None))
# Proto3 scalar fields without presence expose an omitted uint64 as 0.
if run_id == 0:
return None
return run_id


class AppIoTokenClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
"""Attach App token metadata to outbound unary RPCs."""

Expand Down Expand Up @@ -139,14 +153,20 @@ def _authenticated_handler(

state = self._state_provider()

# Legacy: Validate both token->run lookup and run->token mapping.
# Legacy: validate both token->run lookup and run->token mapping.
run_id = state.get_run_id_by_token(token)
if run_id is not None and state.verify_token(run_id, token):
request_run_id: int | None = _get_request_run_id(request)
if request_run_id is not None and request_run_id != run_id:
_abort_run_denied(context)
return unary_handler(request, context)

# Validate task token and set task context for downstream handlers
# Validate task token and set task context for downstream handlers.
task = state.get_task_by_token(token)
if task is not None:
request_run_id = _get_request_run_id(request)
if request_run_id is not None and request_run_id != task.run_id:
_abort_run_denied(context)
Comment thread
danieljanes marked this conversation as resolved.
ctx_token = _current_task.set(task)
try:
return unary_handler(request, context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
create_serverappio_token_auth_server_interceptor,
get_authenticated_task,
)
from flwr.supercore.interceptors.appio_token_interceptor import (
RUN_BINDING_FAILED_MESSAGE,
)

_ClientCallDetails = namedtuple(
"_ClientCallDetails",
Expand Down Expand Up @@ -256,14 +259,36 @@ def test_valid_token_passes_for_protected_method(self) -> None:
# cross-run use is expected.
response = cast(str, intercepted.unary_unary(GetNodesRequest(run_id=7), Mock()))
self.assertEqual(response, "ok")
# Run-id mismatch deny coverage belongs to the
# follow-up PR that enforces run binding.

def test_run_id_mismatch_denied_for_protected_method(self) -> None:
    """A run token must be rejected when the request names another run."""
    # The interceptor is set up with token "valid" mapped to run 7.
    interceptor = self._new_interceptor(token_to_run_id={"valid": 7})
    protected_method = self._find_serverappio_method(requires_token=True)
    if protected_method is None:
        self.skipTest("No token-required ServerAppIo method found in policy table.")

    # Make `abort` raise so the denial surfaces as an RpcError in the test.
    servicer_context = Mock()
    servicer_context.abort.side_effect = grpc.RpcError()

    call_details = _HandlerCallDetails(
        protected_method,
        invocation_metadata=((APP_TOKEN_HEADER, "valid"),),
    )
    wrapped_handler = interceptor.intercept_service(
        lambda _: _make_unary_handler(),
        call_details,
    )

    # The request targets run 8, which differs from the token's run 7.
    mismatched_request = GetNodesRequest(run_id=8)
    with self.assertRaises(grpc.RpcError):
        wrapped_handler.unary_unary(mismatched_request, servicer_context)

    servicer_context.abort.assert_called_once_with(
        grpc.StatusCode.PERMISSION_DENIED,
        RUN_BINDING_FAILED_MESSAGE,
    )

def test_valid_task_token_passes_and_sets_task_id(self) -> None:
"""Protected methods should pass with a valid task token."""
state = Mock()
state.get_run_id_by_token.return_value = None
state.get_task_by_token.return_value = Mock(task_id=123)
state.get_task_by_token.return_value = Mock(task_id=123, run_id=7)
interceptor = create_serverappio_token_auth_server_interceptor(lambda: state)
method = self._find_serverappio_method(requires_token=True)
if method is None:
Expand All @@ -290,6 +315,56 @@ def _handler(_request: GrpcMessage, _context: grpc.ServicerContext) -> str:
state.get_run_id_by_token.assert_called_once_with("task-token")
state.get_task_by_token.assert_called_once_with("task-token")

def test_task_token_passes_when_run_id_omitted(self) -> None:
    """A task token is accepted when the request carries no run ID."""
    # No run-token match, so authentication goes through the task lookup.
    state_stub = Mock()
    state_stub.get_run_id_by_token.return_value = None
    state_stub.get_task_by_token.return_value = Mock(task_id=123, run_id=7)
    interceptor = create_serverappio_token_auth_server_interceptor(
        lambda: state_stub
    )

    # Holder for the task the downstream handler sees during the call.
    seen = {"task": None}

    def _handler(_request: GrpcMessage, _context: grpc.ServicerContext) -> str:
        seen["task"] = get_authenticated_task()
        return "ok"

    call_details = _HandlerCallDetails(
        "/flwr.proto.ServerAppIo/PushAppOutputs",
        invocation_metadata=((APP_TOKEN_HEADER, "task-token"),),
    )
    wrapped_handler = interceptor.intercept_service(
        lambda _: grpc.unary_unary_rpc_method_handler(_handler),
        call_details,
    )

    # The request is built without a run_id.
    response = wrapped_handler.unary_unary(PushAppOutputsRequest(), Mock())

    self.assertEqual(response, "ok")
    captured_task = seen["task"]
    self.assertIsNotNone(captured_task)
    self.assertEqual(cast(Mock, captured_task).task_id, 123)

def test_task_token_run_id_mismatch_denied(self) -> None:
    """A task token must be rejected when the request names another run."""
    # No run-token match; the task lookup yields a task belonging to run 7.
    state_stub = Mock()
    state_stub.get_run_id_by_token.return_value = None
    state_stub.get_task_by_token.return_value = Mock(task_id=123, run_id=7)
    interceptor = create_serverappio_token_auth_server_interceptor(
        lambda: state_stub
    )

    # Make `abort` raise so the denial surfaces as an RpcError in the test.
    servicer_context = Mock()
    servicer_context.abort.side_effect = grpc.RpcError()

    call_details = _HandlerCallDetails(
        "/flwr.proto.ServerAppIo/PushAppOutputs",
        invocation_metadata=((APP_TOKEN_HEADER, "task-token"),),
    )
    wrapped_handler = interceptor.intercept_service(
        lambda _: _make_unary_handler(),
        call_details,
    )

    # The request targets run 8, which differs from the task's run 7.
    mismatched_request = PushAppOutputsRequest(run_id=8)
    with self.assertRaises(grpc.RpcError):
        wrapped_handler.unary_unary(mismatched_request, servicer_context)

    servicer_context.abort.assert_called_once_with(
        grpc.StatusCode.PERMISSION_DENIED,
        RUN_BINDING_FAILED_MESSAGE,
    )

def test_metadata_token_used_even_when_request_has_token(self) -> None:
"""Metadata token should be authoritative when both sources exist."""
interceptor = self._new_interceptor(token_to_run_id={"metadata-token": 5})
Expand Down Expand Up @@ -485,7 +560,7 @@ def test_clientappio_factory_uses_client_policy(self) -> None:
response = cast(
str,
intercepted.unary_unary(
PushObjectRequest(object_id="obj", object_content=b"x"),
PushObjectRequest(run_id=1, object_id="obj", object_content=b"x"),
Mock(),
),
)
Expand Down
11 changes: 8 additions & 3 deletions framework/py/flwr/supernode/runtime/run_clientapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def pull_appinputs(stub: ClientAppIoStub) -> tuple[Message, Context, Run, Fab]:

# Pull and inflate the message
pull_msg_res: PullAppMessagesResponse = stub.PullMessage(
PullAppMessagesRequest()
PullAppMessagesRequest(run_id=run.run_id)
)
run_id = context.run_id
node = Node(node_id=context.node_id)
Expand Down Expand Up @@ -280,7 +280,9 @@ def push_appoutputs( # pylint: disable=R0913, R0917
# This is temporary. The message should not contain its content
push_msg_res = stub.PushMessage(
PushAppMessagesRequest(
messages_list=[proto_message], message_object_trees=[object_tree]
messages_list=[proto_message],
run_id=context.run_id,
message_object_trees=[object_tree],
)
)
del proto_message
Expand All @@ -304,7 +306,10 @@ def push_appoutputs( # pylint: disable=R0913, R0917
# Push Context
res: PushAppOutputsResponse = stub.PushAppOutputs(
PushAppOutputsRequest(
context=proto_context, sub_status=sub_status, details=details
run_id=context.run_id,
context=proto_context,
sub_status=sub_status,
details=details,
)
)
return res
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def setUp(self) -> None:
state_factory = NodeStateFactory(objectstore_factory=objectstore_factory)

state = state_factory.state()
token = state.create_token(99)
self.valid_run_id = 99
token = state.create_token(self.valid_run_id)
assert token is not None
self.valid_token = token

Expand Down Expand Up @@ -100,15 +101,17 @@ def tearDown(self) -> None:
def test_pull_object_denied_without_metadata_token(self) -> None:
"""Protected RPC should deny requests missing metadata token."""
with self.assertRaises(grpc.RpcError) as err:
self._pull_object.with_call(request=PullObjectRequest(object_id="obj-1"))
self._pull_object.with_call(
request=PullObjectRequest(run_id=self.valid_run_id, object_id="obj-1")
)
assert err.exception.code() == grpc.StatusCode.UNAUTHENTICATED
assert err.exception.details() == AUTHENTICATION_FAILED_MESSAGE

def test_pull_object_denied_with_invalid_metadata_token(self) -> None:
"""Protected RPC should deny requests with invalid metadata token."""
with self.assertRaises(grpc.RpcError) as err:
self._pull_object.with_call(
request=PullObjectRequest(object_id="obj-2"),
request=PullObjectRequest(run_id=self.valid_run_id, object_id="obj-2"),
metadata=((APP_TOKEN_HEADER, "invalid-token"),),
)
assert err.exception.code() == grpc.StatusCode.UNAUTHENTICATED
Expand All @@ -117,7 +120,7 @@ def test_pull_object_denied_with_invalid_metadata_token(self) -> None:
def test_pull_object_allows_with_valid_metadata_token(self) -> None:
"""Protected RPC should allow requests with valid metadata token."""
response, call = self._pull_object.with_call(
request=PullObjectRequest(object_id="obj-3"),
request=PullObjectRequest(run_id=self.valid_run_id, object_id="obj-3"),
metadata=((APP_TOKEN_HEADER, self.valid_token),),
)

Expand Down
Loading