Skip to content

Commit e8bfcbd

Browse files
committed
Error on mixed async/sync
Pull Request resolved: #1626 ghstack-source-id: 317845131 @exported-using-ghexport Differential Revision: [D85090856](https://our.internmc.facebook.com/intern/diff/D85090856/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D85090856/)!
1 parent b0f9777 commit e8bfcbd

File tree

4 files changed

+42
-35
lines changed

4 files changed

+42
-35
lines changed

python/monarch/_src/actor/actor_mesh.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,6 +1122,8 @@ def __init__(
11221122
# can be collected even without any endpoints being awaited.
11231123
self._inner.start_supervision(context().actor_instance._as_rust())
11241124

1125+
async_endpoints = []
1126+
sync_endpoints = []
11251127
for attr_name in dir(self._class):
11261128
attr_value = getattr(self._class, attr_name, None)
11271129
if isinstance(attr_value, EndpointProperty):
@@ -1141,6 +1143,17 @@ def __init__(
11411143
attr_value._explicit_response_port,
11421144
),
11431145
)
1146+
if inspect.iscoroutinefunction(attr_value._method):
1147+
async_endpoints.append(attr_name)
1148+
else:
1149+
sync_endpoints.append(attr_name)
1150+
1151+
if sync_endpoints and async_endpoints:
1152+
raise ValueError(
1153+
f"{self._class} mixes both async and sync endpoints."
1154+
"Synchronous endpoints cannot be mixed with async endpoints because they can cause the asyncio loop to deadlock if they wait."
1155+
f"sync: {sync_endpoints} async: {async_endpoints}"
1156+
)
11441157

11451158
def __getattr__(self, attr: str) -> NotAnEndpoint:
11461159
if attr in dir(self._class):

python/tests/test_actor_error.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -423,13 +423,23 @@ class ActorFailureError(BaseException):
423423
pass
424424

425425

426-
class ErrorActor(Actor):
426+
class SyncErrorActor(Actor):
427427
@endpoint
428428
def fail_with_supervision_error(self) -> None:
429429
raise ActorFailureError("Simulated actor failure for supervision testing")
430430

431431
@endpoint
432-
async def fail_with_supervision_error_async(self) -> None:
432+
def check(self) -> str:
433+
return "this is a healthy check"
434+
435+
@endpoint
436+
def check_with_exception(self) -> None:
437+
raise RuntimeError("failed the check with app error")
438+
439+
440+
class ErrorActor(Actor):
441+
@endpoint
442+
async def fail_with_supervision_error(self) -> None:
433443
raise ActorFailureError("Simulated actor failure for supervision testing")
434444

435445
@endpoint
@@ -579,10 +589,10 @@ async def test_actor_mesh_supervision_handling_chained_error() -> None:
579589
ids=["local_proc_mesh", "proc_mesh"],
580590
)
581591
@pytest.mark.parametrize(
582-
"method_name",
583-
["fail_with_supervision_error", "fail_with_supervision_error_async"],
592+
"error_actor_cls",
593+
[ErrorActor, SyncErrorActor],
584594
)
585-
async def test_base_exception_handling(mesh, method_name) -> None:
595+
async def test_base_exception_handling(mesh, error_actor_cls) -> None:
586596
"""Test that BaseException subclasses trigger supervision errors.
587597
588598
This test verifies that both synchronous and asynchronous methods
@@ -591,10 +601,10 @@ async def test_base_exception_handling(mesh, method_name) -> None:
591601
592602
"""
593603
proc = mesh({"gpus": 1})
594-
error_actor = proc.spawn("error", ErrorActor)
604+
error_actor = proc.spawn("error", error_actor_cls)
595605

596606
# Get the method to call based on the parameter
597-
method = getattr(error_actor, method_name)
607+
method = error_actor.fail_with_supervision_error
598608

599609
# The call should raise a SupervisionError
600610
with pytest.raises(
@@ -807,16 +817,6 @@ async def subworker_fail(self, sleep: float | None = None) -> None:
807817
"Should never get here, SupervisionError should be raised by the above call"
808818
)
809819

810-
@endpoint
811-
async def subworker_broadcast_fail(self) -> None:
812-
self.mesh.check.broadcast()
813-
# When not awaiting the result of an endpoint which experiences a
814-
# failure, it should still propagate back to __supervise__.
815-
self.mesh.fail_with_supervision_error.broadcast()
816-
# Give time for the failure to occur before returning, so get_failures
817-
# will have a non-empty result.
818-
await asyncio.sleep(10)
819-
820820
@endpoint
821821
async def get_failures(self) -> list[str]:
822822
# MeshFailure is not picklable, so we convert it to a string.

python/tests/test_debugger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1151,7 +1151,7 @@ def debug_func(self, func, class_closure) -> int:
11511151
return func(class_closure)
11521152

11531153
@endpoint
1154-
async def name(self) -> str:
1154+
def name(self) -> str:
11551155
return context().actor_instance.actor_id.actor_name
11561156

11571157

python/tests/test_python_actors.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,14 @@ async def incr(self):
8787
async def value(self) -> int:
8888
return self.v
8989

90+
91+
class SyncCounter(Actor):
92+
def __init__(self, c):
93+
self.c = c
94+
9095
@endpoint
9196
def value_sync_endpoint(self) -> int:
92-
return self.v
97+
return self.c.value.choose().get()
9398

9499

95100
class Indirect(Actor):
@@ -112,7 +117,8 @@ async def test_choose():
112117

113118
assert result == result2
114119

115-
result3 = await v.value_sync_endpoint.choose()
120+
v2 = proc.spawn("sync_counter", SyncCounter, v)
121+
result3 = v2.value_sync_endpoint.choose().get()
116122
assert_type(result, int)
117123
assert result2 == result3
118124

@@ -335,15 +341,15 @@ def __init__(self):
335341
self.local.value = 0
336342

337343
@endpoint
338-
def increment(self):
344+
async def increment(self):
339345
self.local.value += 1
340346

341347
@endpoint
342348
async def increment_async(self):
343349
self.local.value += 1
344350

345351
@endpoint
346-
def get_value(self):
352+
async def get_value(self):
347353
return self.local.value
348354

349355
@endpoint
@@ -408,18 +414,6 @@ async def no_more(self) -> None:
408414
self.should_exit = True
409415

410416

411-
@pytest.mark.timeout(60)
412-
async def test_async_concurrency():
413-
"""Test that async endpoints will be processed concurrently."""
414-
pm = this_host().spawn_procs(per_host={})
415-
am = pm.spawn("async", AsyncActor)
416-
fut = am.sleep.call()
417-
# This call should go through and exit the sleep loop, as long as we are
418-
# actually concurrently processing messages.
419-
await am.no_more.call()
420-
await fut
421-
422-
423417
async def awaitit(f):
424418
return await f
425419

@@ -1543,7 +1537,7 @@ def __init__(self, root="None"):
15431537
self._root = root
15441538

15451539
@endpoint
1546-
def return_root(self):
1540+
async def return_root(self):
15471541
return self._root
15481542

15491543
@endpoint

0 commit comments

Comments
 (0)