diff --git a/torchx/cli/test/cmd_run_test.py b/torchx/cli/test/cmd_run_test.py index b9219dc29..38b3d4b43 100644 --- a/torchx/cli/test/cmd_run_test.py +++ b/torchx/cli/test/cmd_run_test.py @@ -22,6 +22,7 @@ from torchx.cli.argparse_util import ArgOnceAction, torchxconfig from torchx.cli.cmd_run import _parse_component_name_and_args, CmdBuiltins, CmdRun +from torchx.runner.config import ENV_TORCHXCONFIG from torchx.schedulers.local_scheduler import SignalException from torchx.specs import AppDryRunInfo @@ -40,11 +41,19 @@ def cwd(path: str) -> Generator[None, None, None]: class CmdRunTest(unittest.TestCase): def setUp(self) -> None: self.tmpdir = Path(tempfile.mkdtemp()) + + # create empty .torchxconfig so that user .torchxconfig is not picked up + empty_config = self.tmpdir / ".torchxconfig" + empty_config.touch() + self.mock_env = patch.dict(os.environ, {ENV_TORCHXCONFIG: str(empty_config)}) + self.mock_env.start() + self.parser = argparse.ArgumentParser() self.cmd_run = CmdRun() self.cmd_run.add_arguments(self.parser) def tearDown(self) -> None: + self.mock_env.stop() shutil.rmtree(self.tmpdir, ignore_errors=True) ArgOnceAction.called_args = set() torchxconfig.called_args = set() diff --git a/torchx/specs/__init__.py b/torchx/specs/__init__.py index c43cfa0c9..40b0d1202 100644 --- a/torchx/specs/__init__.py +++ b/torchx/specs/__init__.py @@ -13,7 +13,7 @@ scheduler or pipeline adapter. """ import difflib -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Mapping, Optional from torchx.specs.api import ( ALL, @@ -48,12 +48,17 @@ ) from torchx.specs.builders import make_app_handle, materialize_appdef, parse_mounts -from torchx.specs.named_resources_aws import NAMED_RESOURCES as AWS_NAMED_RESOURCES -from torchx.specs.named_resources_generic import ( - NAMED_RESOURCES as GENERIC_NAMED_RESOURCES, -) from torchx.util.entrypoints import load_group +from torchx.util.modules import import_attr + +AWS_NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = import_attr( + "torchx.specs.named_resources_aws", "NAMED_RESOURCES", default={} +) +GENERIC_NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = import_attr( + "torchx.specs.named_resources_generic", "NAMED_RESOURCES", default={} +) + GiB: int = 1024 diff --git a/torchx/specs/finder.py b/torchx/specs/finder.py index ab1284a7b..dabf744eb 100644 --- a/torchx/specs/finder.py +++ b/torchx/specs/finder.py @@ -18,7 +18,6 @@ from types import ModuleType from typing import Any, Callable, Dict, Generator, List, Optional, Union -from torchx.specs import AppDef from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate from torchx.util import entrypoints from torchx.util.io import read_conf_file