Skip to content

Commit c758e7a

Browse files
samlurye authored and meta-codesync[bot] committed
Fix race condition during ProcMesh stop (#1662)
Summary:
Pull Request resolved: #1662

Await `ProcMesh._proc_mesh` before flushing the logging manager so that we can be sure the logging client has been spawned and avoid throwing an assertion error.

ghstack-source-id: 318971110
exported-using-ghexport

Reviewed By: amirafzali

Differential Revision: D85472096

fbshipit-source-id: 29e2d2a77bfb36ff0c05c85436db351a747a7d11
1 parent 1ee6502 commit c758e7a

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

python/monarch/_src/actor/v1/proc_mesh.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -258,6 +258,9 @@ async def task(
258258

259259
await pm._logging_manager.init(hy_proc_mesh, stream_log_to_client)
260260

261+
# If the user has passed the setup lambda, we need to call
262+
# it here before any of the other python actors are spawned so
263+
# that the environment variables are set up before cuda init.
261264
if setup_actor is not None:
262265
await setup_actor.setup.call()
263266

@@ -267,9 +270,9 @@ async def task(
267270
if setup is not None:
268271
from monarch._src.actor.proc_mesh import SetupActor # noqa
269272

270-
# If the user has passed the setup lambda, we need to call
271-
# it here before any of the other actors are spawned so that
272-
# the environment variables are set up before cuda init.
273+
# The SetupActor needs to be spawned outside of `task` for now,
274+
# since spawning a python actor requires a blocking call to
275+
# pickle the proc mesh, and we can't do that from the tokio runtime.
273276
setup_actor = pm._spawn_nonblocking_on(
274277
hy_proc_mesh, "setup", SetupActor, setup
275278
)
@@ -402,8 +405,9 @@ def stop(self) -> Future[None]:
402405
instance = context().actor_instance._as_rust()
403406

404407
async def _stop_nonblocking(instance: HyInstance) -> None:
408+
pm = await self._proc_mesh
405409
await PythonTask.spawn_blocking(lambda: self._logging_manager.flush())
406-
await (await self._proc_mesh).stop_nonblocking(instance)
410+
await pm.stop_nonblocking(instance)
407411
self._stopped = True
408412

409413
return Future(coro=_stop_nonblocking(instance))

python/tests/test_python_actors.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,9 @@ async def test_mesh_passed_to_mesh():
149149
proc = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
150150
f = proc.spawn("from", From)
151151
t = proc.spawn("to", To)
152+
# Make sure t is initialized before sending to f. Otherwise
153+
# f might call t.whoami before t.__init__.
154+
await t.whoami.call()
152155
all = [y for x in f.fetch.stream(t) for y in await x]
153156
assert len(all) == 4
154157
assert all[0] != all[1]
@@ -160,6 +163,9 @@ async def test_mesh_passed_to_mesh_on_different_proc_mesh():
160163
proc2 = fake_in_process_host().spawn_procs(per_host={"gpus": 2})
161164
f = proc.spawn("from", From)
162165
t = proc2.spawn("to", To)
166+
# Make sure t is initialized before sending to f. Otherwise
167+
# f might call t.whoami before t.__init__.
168+
await t.whoami.call()
163169
all = [y for x in f.fetch.stream(t) for y in await x]
164170
assert len(all) == 4
165171
assert all[0] != all[1]

0 commit comments

Comments
 (0)