diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index d3fac9d38..9b382a2ad 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -8,6 +8,7 @@ import asyncio import importlib.metadata +import inspect import json import logging import os @@ -21,6 +22,7 @@ from typing import ( Any, + Awaitable, Callable, cast, Dict, @@ -116,13 +118,29 @@ def __init__(self, env: Callable[[], None]) -> None: self._setup_method = env @endpoint - async def setup(self) -> None: + def setup(self) -> None: """ Call the user defined setup method with the monarch context. """ self._setup_method() +class AsyncSetupActor(Actor): + """ + A helper actor to set up the actor mesh with user defined setup method. + """ + + def __init__(self, env: Callable[[], Awaitable[None]]) -> None: + self._setup_method = env + + @endpoint + async def setup(self) -> None: + """ + Call the user defined setup method with the monarch context. + """ + await self._setup_method() + + T = TypeVar("T") try: from __manifest__ import fbmake # noqa @@ -382,7 +400,7 @@ async def task() -> HyProcMeshV0: async def task( pm: "ProcMeshV0", hy_proc_mesh_task: "Shared[HyProcMeshV0]", - setup_actor: Optional[SetupActor], + setup_actor: SetupActor | AsyncSetupActor | None, stream_log_to_client: bool, ) -> HyProcMeshV0: hy_proc_mesh = await hy_proc_mesh_task @@ -391,6 +409,9 @@ async def task( if setup_actor is not None: await setup_actor.setup.call() + # Cleanup resources held by the setup actor. We only need to run + # the function, and not keep anything it created around. + await cast(ActorMesh[SetupActor], setup_actor).stop() return hy_proc_mesh @@ -399,8 +420,11 @@ async def task( # If the user has passed the setup lambda, we need to call # it here before any of the other actors are spawned so that # the environment variables are set up before cuda init. + actor_type = ( + AsyncSetupActor if inspect.iscoroutinefunction(setup) else SetupActor + ) setup_actor = pm._spawn_nonblocking_on( - hy_proc_mesh, "setup", SetupActor, setup + hy_proc_mesh, "setup", actor_type, setup ) pm._proc_mesh = PythonTask.from_coroutine( diff --git a/python/monarch/_src/actor/v1/host_mesh.py b/python/monarch/_src/actor/v1/host_mesh.py index 94ec122a1..dec78736f 100644 --- a/python/monarch/_src/actor/v1/host_mesh.py +++ b/python/monarch/_src/actor/v1/host_mesh.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Any, Callable, Dict, Literal, Optional, Tuple +from typing import Any, Awaitable, Callable, Dict, Literal, Optional, Tuple from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints, AllocSpec from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared @@ -158,7 +158,7 @@ async def task() -> HyHostMesh: def spawn_procs( self, per_host: Dict[str, int] | None = None, - bootstrap: Callable[[], None] | None = None, + bootstrap: Callable[[], None] | Callable[[], Awaitable[None]] | None = None, name: str | None = None, ) -> "ProcMesh": if not per_host: @@ -178,7 +178,7 @@ def _spawn_nonblocking( self, name: str, per_host: Extent, - setup: Callable[[], None] | None, + setup: Callable[[], None] | Callable[[], Awaitable[None]] | None, _attach_controller_controller: bool, ) -> "ProcMesh": if set(per_host.labels) & set(self._labels): diff --git a/python/monarch/_src/actor/v1/proc_mesh.py b/python/monarch/_src/actor/v1/proc_mesh.py index 3d75c04c0..00de257d3 100644 --- a/python/monarch/_src/actor/v1/proc_mesh.py +++ b/python/monarch/_src/actor/v1/proc_mesh.py @@ -8,6 +8,7 @@ import asyncio import importlib +import inspect import json import logging import os @@ -17,6 +18,7 @@ from typing import ( Any, + Awaitable, Callable, cast, Dict, @@ -238,7 +240,7 @@ def from_host_mesh( host_mesh: "HostMesh", hy_proc_mesh: "Shared[HyProcMesh]", region: Region, - setup: Callable[[], None] | None = None, + setup: Callable[[], None] | Callable[[], Awaitable[None]] | None = None, _attach_controller_controller: bool = True, ) -> "ProcMesh": pm = ProcMesh(hy_proc_mesh, host_mesh, region, region, None) @@ -251,7 +253,7 @@ def from_host_mesh( async def task( pm: "ProcMesh", hy_proc_mesh_task: "Shared[HyProcMesh]", - setup_actor: Optional["SetupActor"], + setup_actor: "SetupActor | AsyncSetupActor | None", stream_log_to_client: bool, ) -> HyProcMesh: hy_proc_mesh = await hy_proc_mesh_task @@ -268,13 +270,16 @@ async def task( setup_actor = None if setup is not None: - from monarch._src.actor.proc_mesh import SetupActor # noqa + from monarch._src.actor.proc_mesh import AsyncSetupActor, SetupActor # noqa + actor_type = ( + AsyncSetupActor if inspect.iscoroutinefunction(setup) else SetupActor + ) # The SetupActor needs to be spawned outside of `task` for now, # since spawning a python actor requires a blocking call to # pickle the proc mesh, and we can't do that from the tokio runtime. setup_actor = pm._spawn_nonblocking_on( - hy_proc_mesh, "setup", SetupActor, setup + hy_proc_mesh, "setup", actor_type, setup ) pm._proc_mesh = PythonTask.from_coroutine( diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 6da3e7230..067e8f861 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -1686,3 +1686,44 @@ def test_login_job(): assert v == "hello!" j.kill() + + +_global_foo = None + + +def setup_with_spawn() -> None: + global _global_foo + proc = this_proc() + # Doesn't matter which actor is spawned, just make sure it persists. + _global_foo = proc.spawn("foo", Counter, 0) + _global_foo.incr.call().get() + # Spawn one that dies to make sure it doesn't cause any issues. + bar = proc.spawn("bar", Counter, 0) + bar.incr.call().get() + + +async def async_setup_with_spawn() -> None: + global _global_foo + proc = this_proc() + # Doesn't matter which actor is spawned, just make sure it persists. + _global_foo = proc.spawn("foo", Counter, 0) + await _global_foo.incr.call() + # Spawn one that dies to make sure it doesn't cause any issues. + bar = proc.spawn("bar", Counter, 0) + await bar.incr.call() + + +def test_setup() -> None: + procs = this_host().spawn_procs(bootstrap=setup_with_spawn) + counter = procs.spawn("counter", Counter, 0) + counter.incr.call().get() + # Make sure no errors occur in the meantime + time.sleep(10) + + +def test_setup_async() -> None: + procs = this_host().spawn_procs(bootstrap=async_setup_with_spawn) + counter = procs.spawn("counter", Counter, 0) + counter.incr.call().get() + # Make sure no errors occur in the meantime + time.sleep(10)