Skip to content

Commit 3d5234c

Browse files
dulinrileyfacebook-github-bot
authored andcommitted
Stop the SetupActor after running its endpoint (#1693)
Summary: The SetupActor runs a single function and then does nothing. There's no need to keep the actor loop running and listening for messages. Furthermore, if the setup function was holding onto resources somehow, this will let them get cleaned up. The motivation for this change is that, if the setup function happens to spawn ActorMeshes, there's no way to access them, so we should try to delete them. People should not rely on SetupActor staying around. If they want that, they can spawn their own Actor on the proc mesh that can hold onto process-wide state. Also, SetupActor was an async endpoint calling a synchronous function which made it easy to accidentally block on a PythonTask. Add support for sync and async callback functions. Reviewed By: mariusae Differential Revision: D85721318
1 parent f40f6db commit 3d5234c

File tree

4 files changed

+80
-10
lines changed

4 files changed

+80
-10
lines changed

python/monarch/_src/actor/proc_mesh.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import asyncio
1010
import importlib.metadata
11+
import inspect
1112
import json
1213
import logging
1314
import os
@@ -21,6 +22,7 @@
2122

2223
from typing import (
2324
Any,
25+
Awaitable,
2426
Callable,
2527
cast,
2628
Dict,
@@ -116,13 +118,29 @@ def __init__(self, env: Callable[[], None]) -> None:
116118
self._setup_method = env
117119

118120
@endpoint
119-
async def setup(self) -> None:
121+
def setup(self) -> None:
120122
"""
121123
Call the user defined setup method with the monarch context.
122124
"""
123125
self._setup_method()
124126

125127

128+
class AsyncSetupActor(Actor):
129+
"""
130+
A helper actor to set up the actor mesh with user defined setup method.
131+
"""
132+
133+
def __init__(self, env: Callable[[], Awaitable[None]]) -> None:
134+
self._setup_method = env
135+
136+
@endpoint
137+
async def setup(self) -> None:
138+
"""
139+
Call the user defined setup method with the monarch context.
140+
"""
141+
await self._setup_method()
142+
143+
126144
T = TypeVar("T")
127145
try:
128146
from __manifest__ import fbmake # noqa
@@ -382,7 +400,7 @@ async def task() -> HyProcMeshV0:
382400
async def task(
383401
pm: "ProcMeshV0",
384402
hy_proc_mesh_task: "Shared[HyProcMeshV0]",
385-
setup_actor: Optional[SetupActor],
403+
setup_actor: SetupActor | AsyncSetupActor | None,
386404
stream_log_to_client: bool,
387405
) -> HyProcMeshV0:
388406
hy_proc_mesh = await hy_proc_mesh_task
@@ -391,6 +409,9 @@ async def task(
391409

392410
if setup_actor is not None:
393411
await setup_actor.setup.call()
412+
# Cleanup resources held by the setup actor. We only need to run
413+
# the function, and not keep anything it created around.
414+
await cast(ActorMesh[SetupActor], setup_actor).stop()
394415

395416
return hy_proc_mesh
396417

@@ -399,8 +420,11 @@ async def task(
399420
# If the user has passed the setup lambda, we need to call
400421
# it here before any of the other actors are spawned so that
401422
# the environment variables are set up before cuda init.
423+
actor_type = (
424+
AsyncSetupActor if inspect.iscoroutinefunction(setup) else SetupActor
425+
)
402426
setup_actor = pm._spawn_nonblocking_on(
403-
hy_proc_mesh, "setup", SetupActor, setup
427+
hy_proc_mesh, "setup", actor_type, setup
404428
)
405429

406430
pm._proc_mesh = PythonTask.from_coroutine(

python/monarch/_src/actor/v1/host_mesh.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
from typing import Any, Callable, Dict, Literal, Optional, Tuple
9+
from typing import Any, Awaitable, Callable, Dict, Literal, Optional, Tuple
1010

1111
from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints, AllocSpec
1212
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
@@ -158,7 +158,7 @@ async def task() -> HyHostMesh:
158158
def spawn_procs(
159159
self,
160160
per_host: Dict[str, int] | None = None,
161-
bootstrap: Callable[[], None] | None = None,
161+
bootstrap: Callable[[], None] | Callable[[], Awaitable[None]] | None = None,
162162
name: str | None = None,
163163
) -> "ProcMesh":
164164
if not per_host:
@@ -178,7 +178,7 @@ def _spawn_nonblocking(
178178
self,
179179
name: str,
180180
per_host: Extent,
181-
setup: Callable[[], None] | None,
181+
setup: Callable[[], None] | Callable[[], Awaitable[None]] | None,
182182
_attach_controller_controller: bool,
183183
) -> "ProcMesh":
184184
if set(per_host.labels) & set(self._labels):

python/monarch/_src/actor/v1/proc_mesh.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import asyncio
1010
import importlib
11+
import inspect
1112
import json
1213
import logging
1314
import os
@@ -17,6 +18,7 @@
1718

1819
from typing import (
1920
Any,
21+
Awaitable,
2022
Callable,
2123
cast,
2224
Dict,
@@ -238,7 +240,7 @@ def from_host_mesh(
238240
host_mesh: "HostMesh",
239241
hy_proc_mesh: "Shared[HyProcMesh]",
240242
region: Region,
241-
setup: Callable[[], None] | None = None,
243+
setup: Callable[[], None] | Callable[[], Awaitable[None]] | None = None,
242244
_attach_controller_controller: bool = True,
243245
) -> "ProcMesh":
244246
pm = ProcMesh(hy_proc_mesh, host_mesh, region, region, None)
@@ -251,7 +253,7 @@ def from_host_mesh(
251253
async def task(
252254
pm: "ProcMesh",
253255
hy_proc_mesh_task: "Shared[HyProcMesh]",
254-
setup_actor: Optional["SetupActor"],
256+
setup_actor: "SetupActor | AsyncSetupActor | None",
255257
stream_log_to_client: bool,
256258
) -> HyProcMesh:
257259
hy_proc_mesh = await hy_proc_mesh_task
@@ -268,13 +270,16 @@ async def task(
268270

269271
setup_actor = None
270272
if setup is not None:
271-
from monarch._src.actor.proc_mesh import SetupActor # noqa
273+
from monarch._src.actor.proc_mesh import AsyncSetupActor, SetupActor # noqa
272274

275+
actor_type = (
276+
AsyncSetupActor if inspect.iscoroutinefunction(setup) else SetupActor
277+
)
273278
# The SetupActor needs to be spawned outside of `task` for now,
274279
# since spawning a python actor requires a blocking call to
275280
# pickle the proc mesh, and we can't do that from the tokio runtime.
276281
setup_actor = pm._spawn_nonblocking_on(
277-
hy_proc_mesh, "setup", SetupActor, setup
282+
hy_proc_mesh, "setup", actor_type, setup
278283
)
279284

280285
pm._proc_mesh = PythonTask.from_coroutine(

python/tests/test_python_actors.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,3 +1686,44 @@ def test_login_job():
16861686
assert v == "hello!"
16871687

16881688
j.kill()
1689+
1690+
1691+
_global_foo = None
1692+
1693+
1694+
def setup_with_spawn() -> None:
1695+
global _global_foo
1696+
proc = this_proc()
1697+
# Doesn't matter which actor is spawned, just make sure it persists.
1698+
_global_foo = proc.spawn("foo", Counter, 0)
1699+
_global_foo.incr.call().get()
1700+
# Spawn one that dies to make sure it doesn't cause any issues.
1701+
bar = proc.spawn("bar", Counter, 0)
1702+
bar.incr.call().get()
1703+
1704+
1705+
async def async_setup_with_spawn() -> None:
1706+
global _global_foo
1707+
proc = this_proc()
1708+
# Doesn't matter which actor is spawned, just make sure it persists.
1709+
_global_foo = proc.spawn("foo", Counter, 0)
1710+
await _global_foo.incr.call()
1711+
# Spawn one that dies to make sure it doesn't cause any issues.
1712+
bar = proc.spawn("bar", Counter, 0)
1713+
await bar.incr.call()
1714+
1715+
1716+
def test_setup() -> None:
1717+
procs = this_host().spawn_procs(bootstrap=setup_with_spawn)
1718+
counter = procs.spawn("counter", Counter, 0)
1719+
counter.incr.call().get()
1720+
# Make sure no errors occur in the meantime
1721+
time.sleep(10)
1722+
1723+
1724+
def test_setup_async() -> None:
1725+
procs = this_host().spawn_procs(bootstrap=async_setup_with_spawn)
1726+
counter = procs.spawn("counter", Counter, 0)
1727+
counter.incr.call().get()
1728+
# Make sure no errors occur in the meantime
1729+
time.sleep(10)

0 commit comments

Comments
 (0)