From 1b5435a0aa6ff38792f1839c515ecb165b7c4824 Mon Sep 17 00:00:00 2001 From: Melih Elibol Date: Wed, 2 Mar 2022 11:14:45 -0800 Subject: [PATCH 1/2] specify backend from init. --- nums/api.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nums/api.py b/nums/api.py index 5d7bae84..e13ffd96 100644 --- a/nums/api.py +++ b/nums/api.py @@ -25,10 +25,14 @@ def init( address: Optional[str] = None, num_cpus: Optional[int] = None, cluster_shape: Optional[tuple] = None, + backend: Optional[str] = None, ): # pylint: disable = import-outside-toplevel import nums.core.settings as settings + if backend is not None: + assert backend in {"serial", "ray", "dask", "mpi"} + settings.backend_name = backend if cluster_shape is not None: settings.cluster_shape = cluster_shape settings.num_cpus = num_cpus From a3504e85b9e3a5af805d346b459608b53c9506f0 Mon Sep 17 00:00:00 2001 From: Melih Elibol Date: Wed, 2 Mar 2022 11:18:50 -0800 Subject: [PATCH 2/2] add test for init backend. --- tests/test_api.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_api.py b/tests/test_api.py index abb5cee1..8f459e98 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -24,6 +24,23 @@ # pylint: disable=import-outside-toplevel +def test_init(): + import nums + from nums.core import application_manager + from nums.core.backends.serial import SerialBackend + from nums.core.backends.ray import RayBackend + + nums.init(backend="serial") + instance = application_manager.instance() + assert isinstance(instance.km.backend, SerialBackend) + application_manager.destroy() + + nums.init(backend="ray") + instance = application_manager.instance() + assert isinstance(instance.km.backend, RayBackend) + application_manager.destroy() + + def test_rwd(): import nums from nums.core import application_manager