diff --git a/torchx/schedulers/__init__.py b/torchx/schedulers/__init__.py index aa773ea54..b4d2881bf 100644 --- a/torchx/schedulers/__init__.py +++ b/torchx/schedulers/__init__.py @@ -13,7 +13,7 @@ from torchx.schedulers.api import Scheduler from torchx.util.entrypoints import load_group -DEFAULT_SCHEDULER_MODULES: Mapping[str, str] = { +BUILTIN_SCHEDULER_MODULES: Mapping[str, str] = { "local_docker": "torchx.schedulers.docker_scheduler", "local_cwd": "torchx.schedulers.local_scheduler", "slurm": "torchx.schedulers.slurm_scheduler", @@ -39,8 +39,17 @@ def run(*args: object, **kwargs: object) -> Scheduler: return run +def default_schedulers() -> dict[str, SchedulerFactory]: + """Build default schedulers (built-in + extras from torchx.schedulers.extra).""" + return { + **{s: _defer_load_scheduler(p) for s, p in BUILTIN_SCHEDULER_MODULES.items()}, + **load_group("torchx.schedulers.extra", default={}), + } + + def get_scheduler_factories( - group: str = "torchx.schedulers", skip_defaults: bool = False + group: str = "torchx.schedulers", + skip_defaults: bool = False, ) -> dict[str, SchedulerFactory]: """ get_scheduler_factories returns all the available schedulers names under `group` and the @@ -48,15 +57,7 @@ def get_scheduler_factories( The first scheduler in the dictionary is used as the default scheduler. """ - - if skip_defaults: - default_schedulers = {} - else: - default_schedulers: dict[str, SchedulerFactory] = {} - for scheduler, path in DEFAULT_SCHEDULER_MODULES.items(): - default_schedulers[scheduler] = _defer_load_scheduler(path) - - return load_group(group, default=default_schedulers) + return load_group(group, default={} if skip_defaults else default_schedulers()) def get_default_scheduler_name() -> str: diff --git a/torchx/schedulers/test/registry_test.py b/torchx/schedulers/test/registry_test.py index e133aafcf..21b27c071 100644 --- a/torchx/schedulers/test/registry_test.py +++ b/torchx/schedulers/test/registry_test.py @@ -43,3 +43,75 @@ def test_get_local_schedulers(self, mock_load_group: MagicMock) -> None: for scheduler in schedulers.values(): self.assertEqual("test_session", scheduler.session_name) + + @patch("torchx.schedulers.load_group") + def test_torchx_schedulers_overrides_all(self, mock_load_group: MagicMock) -> None: + """torchx.schedulers completely overrides defaults and ignores extras""" + mock_custom: MagicMock = MagicMock() + mock_extra: MagicMock = MagicMock() + + mock_load_group.side_effect = lambda group, default: ( + {"custom": mock_custom} + if group == "torchx.schedulers" + else {"extra": mock_extra} if group == "torchx.schedulers.extra" else {} + ) + + factories = get_scheduler_factories() + + self.assertEqual(factories, {"custom": mock_custom}) + self.assertNotIn("local_docker", factories) + self.assertNotIn("extra", factories) + + @patch("torchx.schedulers.load_group") + def test_no_custom_returns_defaults_and_extras( + self, mock_load_group: MagicMock + ) -> None: + """no custom schedulers returns built-in + extras""" + mock_extra: MagicMock = MagicMock() + + mock_load_group.side_effect = lambda group, default: ( + {"extra": mock_extra} if group == "torchx.schedulers.extra" else default + ) + + factories = get_scheduler_factories() + + self.assertIn("local_docker", factories) + self.assertIn("slurm", factories) + self.assertIn("extra", factories) + + @patch("torchx.schedulers.load_group") + def test_no_custom_no_extras_returns_builtins( + self, mock_load_group: MagicMock + ) -> None: + """no custom, no extras returns only built-in schedulers""" + mock_load_group.side_effect = lambda group, default: default + + factories = get_scheduler_factories() + + self.assertIn("local_docker", factories) + self.assertIn("slurm", factories) + + @patch("torchx.schedulers.load_group") + def test_skip_defaults_returns_empty(self, mock_load_group: MagicMock) -> None: + """skip_defaults=True with no custom schedulers returns empty""" + mock_load_group.side_effect = lambda group, default: default + + factories = get_scheduler_factories(skip_defaults=True) + + self.assertEqual(factories, {}) + + @patch("torchx.schedulers.load_group") + def test_custom_scheduler_is_default(self, mock_load_group: MagicMock) -> None: + """first custom scheduler becomes the default""" + mock_aws: MagicMock = MagicMock() + mock_custom: MagicMock = MagicMock() + + mock_load_group.side_effect = lambda group, default: ( + {"aws_batch": mock_aws, "custom_1": mock_custom} + if group == "torchx.schedulers" + else {} + ) + + default_name = get_default_scheduler_name() + + self.assertIn(default_name, ["aws_batch", "custom_1"])