Skip to content

Commit 38e87b1

Browse files
committed
Error on mixed async/sync
Pull Request resolved: #1626 ghstack-source-id: 317790411 @exported-using-ghexport Differential Revision: [D85090856](https://our.internmc.facebook.com/intern/diff/D85090856/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D85090856/)!
1 parent 073b963 commit 38e87b1

File tree

4 files changed

+44
-26
lines changed

4 files changed

+44
-26
lines changed

python/monarch/_src/actor/actor_mesh.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,6 +1114,9 @@ def __init__(
11141114
self._inner: "ActorMeshProtocol" = inner
11151115
self._shape = shape
11161116
self._proc_mesh = proc_mesh
1117+
1118+
async_endpoints = []
1119+
sync_endpoints = []
11171120
for attr_name in dir(self._class):
11181121
attr_value = getattr(self._class, attr_name, None)
11191122
if isinstance(attr_value, EndpointProperty):
@@ -1133,6 +1136,17 @@ def __init__(
11331136
attr_value._explicit_response_port,
11341137
),
11351138
)
1139+
if inspect.iscoroutinefunction(attr_value._method):
1140+
async_endpoints.append(attr_name)
1141+
else:
1142+
sync_endpoints.append(attr_name)
1143+
1144+
if sync_endpoints and async_endpoints:
1145+
raise ValueError(
1146+
f"{self._class} mixes both async and sync endpoints."
1147+
"Synchronous endpoints cannot be mixed with async endpoints because they can cause the asyncio loop to deadlock if they wait."
1148+
f"sync: {sync_endpoints} async: {async_endpoints}"
1149+
)
11361150

11371151
def __getattr__(self, attr: str) -> NotAnEndpoint:
11381152
if attr in dir(self._class):

python/tests/test_actor_error.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -421,13 +421,23 @@ class ActorFailureError(BaseException):
421421
pass
422422

423423

424-
class ErrorActor(Actor):
424+
class SyncErrorActor(Actor):
425425
@endpoint
426426
def fail_with_supervision_error(self) -> None:
427427
raise ActorFailureError("Simulated actor failure for supervision testing")
428428

429429
@endpoint
430-
async def fail_with_supervision_error_async(self) -> None:
430+
def check(self) -> str:
431+
return "this is a healthy check"
432+
433+
@endpoint
434+
def check_with_exception(self) -> None:
435+
raise RuntimeError("failed the check with app error")
436+
437+
438+
class ErrorActor(Actor):
439+
@endpoint
440+
async def fail_with_supervision_error(self) -> None:
431441
raise ActorFailureError("Simulated actor failure for supervision testing")
432442

433443
@endpoint
@@ -568,10 +578,10 @@ async def test_actor_mesh_supervision_handling_chained_error() -> None:
568578
ids=["local_proc_mesh", "proc_mesh"],
569579
)
570580
@pytest.mark.parametrize(
571-
"method_name",
572-
["fail_with_supervision_error", "fail_with_supervision_error_async"],
581+
"error_actor_cls",
582+
[ErrorActor, SyncErrorActor],
573583
)
574-
async def test_base_exception_handling(mesh, method_name) -> None:
584+
async def test_base_exception_handling(mesh, error_actor_cls) -> None:
575585
"""Test that BaseException subclasses trigger supervision errors.
576586
577587
This test verifies that both synchronous and asynchronous methods
@@ -580,10 +590,10 @@ async def test_base_exception_handling(mesh, method_name) -> None:
580590
581591
"""
582592
proc = mesh({"gpus": 1})
583-
error_actor = proc.spawn("error", ErrorActor)
593+
error_actor = proc.spawn("error", error_actor_cls)
584594

585595
# Get the method to call based on the parameter
586-
method = getattr(error_actor, method_name)
596+
method = error_actor.fail_with_supervision_error
587597

588598
# The call should raise a SupervisionError
589599
with pytest.raises(
@@ -794,7 +804,7 @@ async def subworker_fail(self) -> None:
794804
)
795805

796806
@endpoint
797-
def get_failures(self) -> list[str]:
807+
async def get_failures(self) -> list[str]:
798808
# MeshFailure is not picklable, so we convert it to a string.
799809
return [str(f) for f in self.failures]
800810

python/tests/test_debugger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1151,7 +1151,7 @@ def debug_func(self, func, class_closure) -> int:
11511151
return func(class_closure)
11521152

11531153
@endpoint
1154-
async def name(self) -> str:
1154+
def name(self) -> str:
11551155
return context().actor_instance.actor_id.actor_name
11561156

11571157

python/tests/test_python_actors.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,14 @@ async def incr(self):
8787
async def value(self) -> int:
8888
return self.v
8989

90+
91+
class SyncCounter(Actor):
92+
def __init__(self, c):
93+
self.c = c
94+
9095
@endpoint
9196
def value_sync_endpoint(self) -> int:
92-
return self.v
97+
return self.c.value.choose().get()
9398

9499

95100
class Indirect(Actor):
@@ -112,7 +117,8 @@ async def test_choose():
112117

113118
assert result == result2
114119

115-
result3 = await v.value_sync_endpoint.choose()
120+
v2 = proc.spawn("sync_counter", SyncCounter, v)
121+
result3 = v2.value_sync_endpoint.choose().get()
116122
assert_type(result, int)
117123
assert result2 == result3
118124

@@ -335,15 +341,15 @@ def __init__(self):
335341
self.local.value = 0
336342

337343
@endpoint
338-
def increment(self):
344+
async def increment(self):
339345
self.local.value += 1
340346

341347
@endpoint
342348
async def increment_async(self):
343349
self.local.value += 1
344350

345351
@endpoint
346-
def get_value(self):
352+
async def get_value(self):
347353
return self.local.value
348354

349355
@endpoint
@@ -408,18 +414,6 @@ async def no_more(self) -> None:
408414
self.should_exit = True
409415

410416

411-
@pytest.mark.timeout(60)
412-
async def test_async_concurrency():
413-
"""Test that async endpoints will be processed concurrently."""
414-
pm = this_host().spawn_procs(per_host={})
415-
am = pm.spawn("async", AsyncActor)
416-
fut = am.sleep.call()
417-
# This call should go through and exit the sleep loop, as long as we are
418-
# actually concurrently processing messages.
419-
await am.no_more.call()
420-
await fut
421-
422-
423417
async def awaitit(f):
424418
return await f
425419

@@ -1543,7 +1537,7 @@ def __init__(self, root="None"):
15431537
self._root = root
15441538

15451539
@endpoint
1546-
def return_root(self):
1540+
async def return_root(self):
15471541
return self._root
15481542

15491543
@endpoint

0 commit comments

Comments
 (0)