diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 5076870d7d..1f363318e2 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -120,7 +120,7 @@ jobs: if: matrix.subset != 'dragon' run: smart build --device cpu -v - + - name: Install ML Runtimes (with dragon) if: matrix.subset == 'dragon' env: diff --git a/doc/changelog.md b/doc/changelog.md index a25599f003..5269584554 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -13,6 +13,7 @@ Jump to: Description +- Add `TensorFlowWorker` and `ONNXWorker` - Fix symlink operation - RequestBatch rewrite - Fix regression on hostlist param to DragonRunRequest @@ -51,6 +52,7 @@ To be released at some point in the future Description +- Allow specifying Model and Ensemble parameters with number-like types - Implement workaround for Tensorflow that allows RedisAI to build with GCC-14 - Add instructions for installing SmartSim on PML's Scylla @@ -102,6 +104,9 @@ Description Detailed Notes +- The serializer would fail if a parameter for a Model or Ensemble + was specified as a numpy dtype. The constructors for these + classes now validate that the input is number-like and convert them to strings. - On Frontier, the MIOPEN cache may need to be set prior to using RedisAI in the ``smart validate``. The instructions for Frontier have been updated accordingly. diff --git a/doc/installation_instructions/basic.rst b/doc/installation_instructions/basic.rst index a5db285ca8..98e136fe0b 100644 --- a/doc/installation_instructions/basic.rst +++ b/doc/installation_instructions/basic.rst @@ -319,6 +319,20 @@ in combination to customize the Dragon installation. For example: smart build --device cpu --dragon-repo userfork/dragon --dragon-version 0.91 +``smart build`` supports installing a specific version of dragon. It exposes the +parameters ``--dragon-repo`` and ``--dragon-version``, which can be used alone or +in combination to customize the Dragon installation. For example: + +.. code-block:: bash + + # using the --dragon-repo and --dragon-version flags to customize the Dragon installation + smart build --device cpu --dragon-repo userfork/dragon # install Dragon from a specific repo + smart build --device cpu --dragon-version 0.10 # install a specific Dragon release + + # combining both flags + smart build --device cpu --dragon-repo userfork/dragon --dragon-version 0.91 + + .. note:: Dragon is only supported on Linux systems. For further information, you can read :ref:`the dedicated documentation page `.
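The detailed changelog note above only summarizes the numpy-dtype fix. Below is a minimal sketch of the validate-and-convert behavior it describes; the helper name is hypothetical and this is not the actual SmartSim implementation.

```python
# Hypothetical sketch of the behavior described in the changelog note:
# accept strings and number-like values (including numpy scalars) as
# Model/Ensemble parameters and normalize them to strings for serialization.
import numbers

import numpy as np


def _normalize_param(name: str, value) -> str:
    if isinstance(value, str):
        return value
    if isinstance(value, (numbers.Number, np.number)):
        return str(value)
    raise TypeError(f"Parameter '{name}' must be a string or number-like value")


# e.g. _normalize_param("steps", np.int64(100)) -> "100"
```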
diff --git a/ex/high_throughput_inference/mli_driver.py b/ex/high_throughput_inference/mli_driver.py index 36f427937c..b53cb0b9cc 100644 --- a/ex/high_throughput_inference/mli_driver.py +++ b/ex/high_throughput_inference/mli_driver.py @@ -1,31 +1,66 @@ -import os +import argparse import base64 -import cloudpickle +import os +import shutil import sys -from smartsim import Experiment -from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker -from smartsim.status import TERMINAL_STATUSES -from smartsim.settings import DragonRunSettings import time import typing as t +import cloudpickle + +from smartsim import Experiment +from smartsim.settings import DragonRunSettings +from smartsim.status import TERMINAL_STATUSES + +parser = argparse.ArgumentParser("Mock application") +parser.add_argument("--log_max_batchsize", default=8, type=int) +parser.add_argument("--num_nodes_app", default=1, type=int) +parser.add_argument("--toolkit", default="torch", choices=["torch","tensorflow","onnx"], type=str) +args = parser.parse_args() + DEVICE = "gpu" -NUM_RANKS = 4 +NUM_RANKS_PER_NODE = 1 +NUM_NODES_APP = args.num_nodes_app NUM_WORKERS = 1 +BATCH_SIZE = 2 +BATCH_TIMEOUT = 0.0 filedir = os.path.dirname(__file__) worker_manager_script_name = os.path.join(filedir, "standalone_worker_manager.py") -app_script_name = os.path.join(filedir, "mock_app.py") -model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt") +if args.toolkit == "torch": + # keeping old name for backward compatibility + app_script_name = os.path.join(filedir, "mock_app.py") + model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt") +else: + app_script_name = os.path.join(filedir, f"mock_app_{args.toolkit}.py") transport: t.Literal["hsta", "tcp"] = "hsta" os.environ["SMARTSIM_DRAGON_TRANSPORT"] = transport -exp_path = os.path.join(filedir, f"MLI_proto_{transport.upper()}") +exp_path = os.path.join( + filedir, + "benchmark", + args.toolkit, + f"throughput_n{NUM_NODES_APP}_rpn{NUM_RANKS_PER_NODE}_timeout{BATCH_TIMEOUT}", + f"samples{2**args.log_max_batchsize}", +) +try: + shutil.rmtree(exp_path) + time.sleep(2) +except: + pass os.makedirs(exp_path, exist_ok=True) -exp = Experiment("MLI_proto", launcher="dragon", exp_path=exp_path) +exp = Experiment("MLI_benchmark", launcher="dragon", exp_path=exp_path) -torch_worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii") +if args.toolkit == "torch": + from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker + worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii") +elif args.toolkit == "tensorflow": + from smartsim._core.mli.infrastructure.worker.tensorflow_worker import TensorFlowWorker + worker_str = base64.b64encode(cloudpickle.dumps(TensorFlowWorker)).decode("ascii") +elif args.toolkit == "onnx": + from smartsim._core.mli.infrastructure.worker.onnx_worker import ONNXWorker + worker_str = base64.b64encode(cloudpickle.dumps(ONNXWorker)).decode("ascii") worker_manager_rs: DragonRunSettings = exp.create_run_settings( sys.executable, @@ -34,35 +69,44 @@ "--device", DEVICE, "--worker_class", - torch_worker_str, + worker_str, "--batch_size", - str(NUM_RANKS//NUM_WORKERS), + str(BATCH_SIZE), "--batch_timeout", - str(0.00), + str(BATCH_TIMEOUT), "--num_workers", - str(NUM_WORKERS) + str(NUM_WORKERS), ], ) aff = [] worker_manager_rs.set_cpu_affinity(aff) - +worker_manager_rs.set_gpu_affinity([0, 1, 2, 3]) +worker_manager_rs.set_hostlist(["pinoak0037"]) worker_manager = exp.create_model("worker_manager", 
run_settings=worker_manager_rs) worker_manager.attach_generator_files(to_copy=[worker_manager_script_name]) app_rs: DragonRunSettings = exp.create_run_settings( sys.executable, - exe_args=[app_script_name, "--device", DEVICE, "--log_max_batchsize", str(6)], + exe_args=[ + app_script_name, + "--device", + DEVICE, + "--log_max_batchsize", + str(args.log_max_batchsize), + ], ) -app_rs.set_tasks_per_node(NUM_RANKS) - +app_rs.set_tasks_per_node(NUM_RANKS_PER_NODE) +app_rs.set_nodes(NUM_NODES_APP) app = exp.create_model("app", run_settings=app_rs) -app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name]) +if args.toolkit == "torch": + app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name]) exp.generate(worker_manager, app, overwrite=True) -exp.start(worker_manager, app, block=False) +exp.start(worker_manager, block=False) +exp.start(app, block=False) while True: if exp.get_status(app)[0] in TERMINAL_STATUSES: diff --git a/ex/high_throughput_inference/mock_app.py b/ex/high_throughput_inference/mock_app.py index c3b3eaaf4c..2564918450 100644 --- a/ex/high_throughput_inference/mock_app.py +++ b/ex/high_throughput_inference/mock_app.py @@ -24,23 +24,14 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# isort: off -import dragon -from dragon import fli -from dragon.channels import Channel -import dragon.channels -from dragon.data.ddict.ddict import DDict -from dragon.globalservices.api_setup import connect_to_infrastructure -from dragon.utils import b64decode, b64encode - -# isort: on import argparse import io - +from mpi4py import MPI import torch from smartsim.log import get_logger +from smartsim._core.mli.client.protoclient import ProtoClient torch.set_num_interop_threads(16) torch.set_num_threads(1) @@ -48,31 +39,26 @@ logger = get_logger("App") logger.info("Started app") -from collections import OrderedDict - -from smartsim.log import get_logger, log_to_file -from smartsim._core.mli.client.protoclient import ProtoClient logger = get_logger("App") -CHECK_RESULTS_AND_MAKE_ALL_SLOWER = False - - class ResNetWrapper: """Wrapper around a pre-rained ResNet model.""" + def __init__(self, name: str, model: str): """Initialize the instance. :param name: The name to use for the model :param model: The path to the pre-trained PyTorch model""" - self._model = torch.jit.load(model) + self._model = None # torch.jit.load(model) self._name = name - buffer = io.BytesIO() - scripted = torch.jit.trace(self._model, self.get_batch()) - torch.jit.save(scripted, buffer) + + with open(model, "rb") as model_file: + buffer = io.BytesIO(model_file.read()) self._serialized_model = buffer.getvalue() + # pylint: disable-next=no-self-use def get_batch(self, batch_size: int = 32): """Create a random batch of data with the correct dimensions to invoke a ResNet model. 
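Usage note for the updated driver (mli_driver.py) above: the worker implementation is now chosen with `--toolkit` (`torch`, `tensorflow`, or `onnx`) and the run is sized with `--num_nodes_app` and `--log_max_batchsize`, with results written under `benchmark/<toolkit>/throughput_*`. A typical invocation, assuming a Dragon-capable allocation and the example assets in place, would be `python mli_driver.py --toolkit tensorflow --num_nodes_app 2 --log_max_batchsize 6`.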
@@ -96,6 +82,11 @@ def name(self) -> str: return self._name +def log(msg: str, rank_: int) -> None: + if rank_ == 0: + logger.info(msg) + + if __name__ == "__main__": parser = argparse.ArgumentParser("Mock application") @@ -105,38 +96,24 @@ def name(self) -> str: resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt") - client = ProtoClient(timing_on=True) - client.set_model(resnet.name, resnet.model) + comm_world = MPI.COMM_WORLD + rank = comm_world.Get_rank() + client = ProtoClient(timing_on=True, rank=rank) + + if rank == 0: + client.set_model(resnet.name, resnet.model) - if CHECK_RESULTS_AND_MAKE_ALL_SLOWER: - # TODO: adapt to non-Nvidia devices - torch_device = args.device.replace("gpu", "cuda") - pt_model = torch.jit.load(io.BytesIO(initial_bytes=(resnet.model))).to( - torch_device - ) + comm_world.Barrier() TOTAL_ITERATIONS = 100 - for log2_bsize in range(args.log_max_batchsize + 1): + for log2_bsize in range(args.log_max_batchsize, args.log_max_batchsize + 1): b_size: int = 2**log2_bsize - logger.info(f"Batch size: {b_size}") - for iteration_number in range(TOTAL_ITERATIONS + int(b_size == 1)): - logger.info(f"Iteration: {iteration_number}") - sample_batch = resnet.get_batch(b_size) + log(f"Batch size: {b_size}", rank) + for iteration_number in range(TOTAL_ITERATIONS): + sample_batch = resnet.get_batch(b_size).numpy() remote_result = client.run_model(resnet.name, sample_batch) + comm_world.Barrier() logger.info(client.perf_timer.get_last("total_time")) - if CHECK_RESULTS_AND_MAKE_ALL_SLOWER: - local_res = pt_model(sample_batch.to(torch_device)) - err_norm = torch.linalg.vector_norm( - torch.flatten(remote_result).to(torch_device) - - torch.flatten(local_res), - ord=1, - ).cpu() - res_norm = torch.linalg.vector_norm(remote_result, ord=1).item() - local_res_norm = torch.linalg.vector_norm(local_res, ord=1).item() - logger.info( - f"Avg norm of error {err_norm.item()/b_size} compared to result norm of {res_norm/b_size}:{local_res_norm/b_size}" - ) - torch.cuda.synchronize() - - client.perf_timer.print_timings(to_file=True) + + client.perf_timer.print_timings(to_file=True, to_stdout=rank == 0) diff --git a/ex/high_throughput_inference/mock_app_onnx.py b/ex/high_throughput_inference/mock_app_onnx.py new file mode 100644 index 0000000000..2ae5e0dbdb --- /dev/null +++ b/ex/high_throughput_inference/mock_app_onnx.py @@ -0,0 +1,130 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import argparse + +from mpi4py import MPI +import numpy +from numpy.polynomial import Polynomial + +import onnx +from sklearn.preprocessing import PolynomialFeatures +from sklearn.linear_model import LinearRegression + +from skl2onnx import to_onnx + +from smartsim.log import get_logger +from smartsim._core.mli.client.protoclient import ProtoClient + +logger = get_logger("App") + + +class LinRegWrapper: + def __init__( + self, + name: str, + model: onnx.onnx_ml_pb2.ModelProto, + ): + self._get_onnx_model(model) + self._name = name + self._poly = PolynomialFeatures(degree=2, include_bias=False) + + def _get_onnx_model(self, model: onnx.onnx_ml_pb2.ModelProto): + self._serialized_model = model.SerializeToString() + + def get_batch(self, batch_size: int = 32): + """Create a random batch of data with the correct dimensions to + invoke the polynomial regression model. + + :param batch_size: The desired number of samples to produce + :returns: A NumPy array of polynomial features""" + x = numpy.random.randn(batch_size, 1).astype(numpy.float32) + return self._poly.fit_transform(x.reshape(-1, 1)) + + @property + def model(self): + """The content of a model file. + + :returns: The model bytes""" + return self._serialized_model + + @property + def name(self): + """The name applied to the model.
+ + :returns: The name""" + return self._name + + +def log(msg: str, rank_: int) -> None: + if rank_ == 0: + logger.info(msg) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser("Mock application") + parser.add_argument("--device", default="cpu", type=str) + parser.add_argument("--log_max_batchsize", default=8, type=int) + args = parser.parse_args() + + X = numpy.linspace(0, 10, 10).astype(numpy.float32) + poly = PolynomialFeatures(degree=2, include_bias=False) + p = Polynomial([1.4, -10, 4]) + poly_features = poly.fit_transform(X.reshape(-1, 1)) + poly_reg_model = LinearRegression() + poly_reg_model.fit(poly_features, p(X)) + + onnx_model = to_onnx(poly_reg_model, poly_features, target_opset=13) + + linreg = LinRegWrapper("LinReg", onnx_model) + + comm_world = MPI.COMM_WORLD + rank = comm_world.Get_rank() + client = ProtoClient(timing_on=True, rank=rank) + + if rank == 0: + client.set_model(linreg.name, linreg.model) + + MPI.COMM_WORLD.Barrier() + + TOTAL_ITERATIONS = 100 + + for log2_bsize in range(args.log_max_batchsize, args.log_max_batchsize + 1): + b_size: int = 2**log2_bsize + log(f"Batch size: {b_size}", rank) + for iteration_number in range(TOTAL_ITERATIONS): + sample_batch = linreg.get_batch(b_size) + remote_result = client.run_model(linreg.name, sample_batch) + log( + f"Completed iteration: {iteration_number} " + f"in {client.perf_timer.get_last('total_time')} seconds", + rank, + ) + + client.perf_timer.print_timings(to_file=True, to_stdout=rank == 0) diff --git a/ex/high_throughput_inference/mock_app_redis.py b/ex/high_throughput_inference/mock_app_redis.py index 8978bcea23..d2a3d4a05a 100644 --- a/ex/high_throughput_inference/mock_app_redis.py +++ b/ex/high_throughput_inference/mock_app_redis.py @@ -26,26 +26,28 @@ import argparse import io -import numpy import time + +import numpy import torch from mpi4py import MPI -from smartsim.log import get_logger -from smartsim._core.utils.timings import PerfTimer from smartredis import Client +from smartsim._core.utils.timings import PerfTimer +from smartsim.log import get_logger + logger = get_logger("App") -class ResNetWrapper(): + +class ResNetWrapper: def __init__(self, name: str, model: str): - self._model = torch.jit.load(model) + self._model = None self._name = name - buffer = io.BytesIO() - scripted = torch.jit.trace(self._model, self.get_batch()) - torch.jit.save(scripted, buffer) + with open(model, "rb") as model_file: + buffer = io.BytesIO(model_file.read()) self._serialized_model = buffer.getvalue() - def get_batch(self, batch_size: int=32): + def get_batch(self, batch_size: int = 32): return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32) @property @@ -56,6 +58,12 @@ def model(self): def name(self): return self._name + +def log(msg: str, rank: int) -> None: + if rank == 0: + logger.info(msg) + + if __name__ == "__main__": comm = MPI.COMM_WORLD @@ -63,28 +71,48 @@ def name(self): parser = argparse.ArgumentParser("Mock application") parser.add_argument("--device", default="cpu") + parser.add_argument("--log_max_batchsize", default=8, type=int) args = parser.parse_args() - resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt") + resnet = ResNetWrapper("resnet50", f"resnet50.{args.device}.pt") client = Client(cluster=False, address=None) - client.set_model(resnet.name, resnet.model, backend='TORCH', device=args.device.upper()) - perf_timer: PerfTimer = PerfTimer(debug=False, timing_on=timing_on, prefix=f"redis{rank}_") + if rank == 0: + client.set_model( + resnet.name, 
resnet.model, backend="TORCH", device=args.device.upper() + ) + + comm.Barrier() + + perf_timer: PerfTimer = PerfTimer( + debug=False, timing_on=True, prefix=f"redis{rank}_" + ) total_iterations = 100 - timings=[] - for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]: - logger.info(f"Batch size: {batch_size}") - for iteration_number in range(total_iterations + int(batch_size==1)): + timings = [] + for log2_bsize in range(args.log_max_batchsize, args.log_max_batchsize + 1): + batch_size: int = 2**log2_bsize + log(f"Batch size: {batch_size}", rank) + for iteration_number in range(total_iterations): perf_timer.start_timings("batch_size", batch_size) - logger.info(f"Iteration: {iteration_number}") input_name = f"batch_{rank}" output_name = f"result_{rank}" - client.put_tensor(name=input_name, data=resnet.get_batch(batch_size).numpy()) - client.run_model(name=resnet.name, inputs=[input_name], outputs=[output_name]) + client.put_tensor( + name=input_name, data=resnet.get_batch(batch_size).numpy() + ) + perf_timer.measure_time("send_request") + client.run_model( + name=resnet.name, inputs=[input_name], outputs=[output_name] + ) + perf_timer.measure_time("run_model") result = client.get_tensor(name=output_name) + perf_timer.measure_time("receive_response") perf_timer.end_timings() + comm.Barrier() + log( + f"Completed iteration: {iteration_number} in {perf_timer.get_last('total_time')} seconds", + rank, + ) - - perf_timer.print_timings(True) + perf_timer.print_timings(True, to_stdout=rank == 0) diff --git a/ex/high_throughput_inference/mock_app_tensorflow.py b/ex/high_throughput_inference/mock_app_tensorflow.py new file mode 100644 index 0000000000..704d51ee48 --- /dev/null +++ b/ex/high_throughput_inference/mock_app_tensorflow.py @@ -0,0 +1,120 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +import argparse + +from mpi4py import MPI +import numpy +import tensorflow as tf +from tensorflow.python.framework.convert_to_constants import ( + convert_variables_to_constants_v2_as_graph, +) + +from smartsim.log import get_logger +from smartsim._core.mli.client.protoclient import ProtoClient + +logger = get_logger("App") + + +class ResNetWrapper: + def __init__( + self, + name: str, + model: tf.keras.Model, + ): + self._get_tf_model(model) + self._name = name + + def _get_tf_model(self, model: tf.keras.Model): + real_model = tf.function(model).get_concrete_function( + tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype) + ) + _, graph_def = convert_variables_to_constants_v2_as_graph(real_model) + self._serialized_model = graph_def.SerializeToString() + + # pylint: disable-next=no-self-use + def get_batch(self, batch_size: int = 32): + """Create a random batch of data with the correct dimensions to + invoke a ResNet model. + + :param batch_size: The desired number of samples to produce + :returns: A PyTorch tensor""" + return numpy.random.randn(batch_size, 224, 224, 3).astype(numpy.float32) + + @property + def model(self): + """The content of a model file. + + :returns: The model bytes""" + return self._serialized_model + + @property + def name(self): + """The name applied to the model. + + :returns: The name""" + return self._name + + +def log(msg: str, rank_: int) -> None: + if rank_ == 0: + logger.info(msg) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser("Mock application") + parser.add_argument("--device", default="cpu", type=str) + parser.add_argument("--log_max_batchsize", default=8, type=int) + args = parser.parse_args() + + resnet = ResNetWrapper("resnet50", tf.keras.applications.ResNet50()) + + comm_world = MPI.COMM_WORLD + rank = comm_world.Get_rank() + client = ProtoClient(timing_on=True, rank=rank) + + if rank == 0: + client.set_model(resnet.name, resnet.model) + + comm_world.Barrier() + + TOTAL_ITERATIONS = 100 + + for log2_bsize in range(args.log_max_batchsize, args.log_max_batchsize + 1): + b_size: int = 2**log2_bsize + log(f"Batch size: {b_size}", rank) + for iteration_number in range(TOTAL_ITERATIONS): + sample_batch = resnet.get_batch(b_size) + remote_result = client.run_model(resnet.name, sample_batch) + log( + f"Completed iteration: {iteration_number} in " + f"{client.perf_timer.get_last('total_time')} seconds", + rank, + ) + + client.perf_timer.print_timings(to_file=True, to_stdout=rank == 0) diff --git a/ex/high_throughput_inference/redis_driver.py b/ex/high_throughput_inference/redis_driver.py index ff57725d40..b3b2424723 100644 --- a/ex/high_throughput_inference/redis_driver.py +++ b/ex/high_throughput_inference/redis_driver.py @@ -24,29 +24,57 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import argparse import os +import shutil import sys +import time + from smartsim import Experiment from smartsim.status import TERMINAL_STATUSES -import time DEVICE = "gpu" +NUM_TASKS_PER_NODE = 16 + filedir = os.path.dirname(__file__) app_script_name = os.path.join(filedir, "mock_app_redis.py") model_name = os.path.join(filedir, f"resnet50.{DEVICE}.pt") +parser = argparse.ArgumentParser("Mock application") +parser.add_argument("--num_nodes_app", default=1, type=int) +parser.add_argument("--log_max_batchsize", default=8, type=int) +args = parser.parse_args() + +NUM_NODES = args.num_nodes_app + +exp_path = os.path.join( + filedir, + "benchmark", + f"redis_ai_multi_n{NUM_NODES}_rpn{NUM_TASKS_PER_NODE}", + f"samples{2**args.log_max_batchsize}", +) +try: + shutil.rmtree(exp_path) + time.sleep(2) +except FileNotFoundError: + pass -exp_path = os.path.join(filedir, "redis_ai_multi") os.makedirs(exp_path, exist_ok=True) exp = Experiment("redis_ai_multi", launcher="slurm", exp_path=exp_path) -db = exp.create_database(interface="hsn0") +db = exp.create_database(interface="hsn0", hosts=["pinoak0036"]) app_rs = exp.create_run_settings( - sys.executable, exe_args = [app_script_name, "--device", DEVICE] - ) -app_rs.set_nodes(1) -app_rs.set_tasks(4) + sys.executable, + exe_args=[ + app_script_name, + "--device", + DEVICE, + "--log_max_batchsize", + str(args.log_max_batchsize), + ], +) +app_rs.set_nodes(NUM_NODES) +app_rs.set_tasks(NUM_NODES * NUM_TASKS_PER_NODE) app = exp.create_model("app", run_settings=app_rs) app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name]) @@ -63,4 +91,4 @@ break time.sleep(5) -print("Exiting.") \ No newline at end of file +print("Exiting.")
-import dragon - -# pylint disable=import-error +# pylint: disable=import-error import dragon.infrastructure.policy as dragon_policy import dragon.infrastructure.process_desc as dragon_process_desc import dragon.native.process as dragon_process @@ -35,10 +33,8 @@ from dragon.channels import Channel from dragon.data.ddict.ddict import DDict from dragon.globalservices.api_setup import connect_to_infrastructure -from dragon.managed_memory import MemoryPool -from dragon.utils import b64decode, b64encode -# pylint enable=import-error +# pylint: enable=import-error # isort: off # isort: on @@ -49,7 +45,6 @@ import os import socket import time -import typing as t import cloudpickle @@ -124,13 +119,15 @@ def service_as_dragon_proc( "--batch_size", type=int, default=1, - help="How many requests the workers will try to aggregate before processing them", + help="How many requests the workers will try " + "to aggregate before processing them", ) parser.add_argument( "--batch_timeout", type=float, default=0.001, - help="How much time (in seconds) should be waited before processing an incomplete aggregated request", + help="How much time (in seconds) should be waited " + "before processing an incomplete aggregated request", ) args = parser.parse_args() @@ -143,7 +140,9 @@ def service_as_dragon_proc( to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) to_worker_fli_comm_ch = DragonFLIChannel(to_worker_fli) + backbone._allow_reserved_writes = True backbone.worker_queue = to_worker_fli_comm_ch.descriptor + backbone._allow_reserved_writes = False os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = to_worker_fli_comm_ch.descriptor os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor @@ -163,6 +162,7 @@ def service_as_dragon_proc( batch_size=args.batch_size, config_loader=config_loader, worker_type=arg_worker_type, + mem_pool_size=128 * 1024**3, ) wms = [] @@ -210,9 +210,11 @@ def service_as_dragon_proc( # TODO: use ProcessGroup and restart=True? 
all_procs = [dispatcher_proc, *worker_manager_procs] - print(f"Dispatcher proc: {dispatcher_proc}") for proc in all_procs: proc.start() while all(proc.is_alive for proc in all_procs): time.sleep(1) + + for proc in all_procs: + logger.info(f"{proc} is alive: {proc.is_alive}") diff --git a/smartsim/_core/launcher/dragon/dragon_backend.py b/smartsim/_core/launcher/dragon/dragon_backend.py index 82863d73b5..0f00bf4c41 100644 --- a/smartsim/_core/launcher/dragon/dragon_backend.py +++ b/smartsim/_core/launcher/dragon/dragon_backend.py @@ -532,10 +532,10 @@ def _stop_steps(self) -> None: and proc_group.status == DragonStatus.RUNNING ): try: - proc_group.kill() + proc_group.stop() except dragon_process_group.DragonProcessGroupError: try: - proc_group.stop() + proc_group.kill() except dragon_process_group.DragonProcessGroupError: logger.error("Process group already stopped") redir_group = self._group_infos[step_id].redir_workers diff --git a/smartsim/_core/mli/client/protoclient.py b/smartsim/_core/mli/client/protoclient.py index 46598a8171..1f01198183 100644 --- a/smartsim/_core/mli/client/protoclient.py +++ b/smartsim/_core/mli/client/protoclient.py @@ -30,23 +30,15 @@ import dragon.channels from dragon.globalservices.api_setup import connect_to_infrastructure -try: - from mpi4py import MPI # type: ignore[import-not-found] -except Exception: - MPI = None - print("Unable to import `mpi4py` package") # isort: on # pylint: enable=unused-import,import-error -import numbers import os -import time import typing as t from collections import OrderedDict import numpy -import torch from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel @@ -82,7 +74,7 @@ class ProtoClient: triggering QueueFull exceptions.""" _EVENT_SOURCE = "proto-client" - """A user-friendly name for this class instance to identify + """A user-friendly name for this class instance to identify the client as the publisher of an event.""" @staticmethod @@ -131,7 +123,9 @@ def _attach_to_worker_queue(self) -> DragonFLIChannel: ) raise SmartSimError("Unable to locate worker queue using backbone") from ex - return DragonFLIChannel.from_descriptor(descriptor) + fli_channel = DragonFLIChannel.from_descriptor(descriptor) + + return fli_channel def _create_broadcaster(self) -> EventBroadcaster: """Create an EventBroadcaster that broadcasts events to @@ -148,6 +142,7 @@ def __init__( self, timing_on: bool, backbone_timeout: float = _DEFAULT_BACKBONE_TIMEOUT, + rank: int = 0, ) -> None: """Initialize the client instance. 
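With the MPI lookup removed from `ProtoClient` (above), rank discovery is now owned by the caller, which passes the rank explicitly; the mock apps follow the pattern sketched below.

```python
# Calling pattern used by the updated mock apps: the application initializes MPI,
# hands its rank to the client, and only rank 0 registers the model.
from mpi4py import MPI

from smartsim._core.mli.client.protoclient import ProtoClient

comm_world = MPI.COMM_WORLD
rank = comm_world.Get_rank()

client = ProtoClient(timing_on=True, rank=rank)  # rank defaults to 0 when omitted

if rank == 0:
    # serialized model bytes, e.g. the pre-traced resnet50.<device>.pt used by mock_app.py
    with open("resnet50.gpu.pt", "rb") as model_file:
        client.set_model("resnet50", model_file.read())

comm_world.Barrier()  # all ranks wait until the model has been registered
```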
@@ -158,13 +153,7 @@ def __init__( :raises SmartSimError: If unable to attach to a backbone featurestore :raises ValueError: If an invalid backbone timeout is specified """ - if MPI is not None: - # TODO: determine a way to make MPI work in the test environment - # - consider catching the import exception and defaulting rank to 0 - comm = MPI.COMM_WORLD - rank: int = comm.Get_rank() - else: - rank = 0 + self._rank = rank if backbone_timeout <= 0: raise ValueError( @@ -179,18 +168,16 @@ def __init__( self._backbone.wait_timeout = self.backbone_timeout self._to_worker_fli = self._attach_to_worker_queue() - self._from_worker_ch = create_local(self._DEFAULT_WORK_QUEUE_SIZE) + self._from_worker_ch = DragonCommChannel( + create_local(self._DEFAULT_WORK_QUEUE_SIZE) + ) self._to_worker_ch = create_local(self._DEFAULT_WORK_QUEUE_SIZE) self._publisher = self._create_broadcaster() self.perf_timer: PerfTimer = PerfTimer( - debug=False, timing_on=timing_on, prefix=f"a{rank}_" + debug=False, timing_on=timing_on, prefix=f"a{self._rank}_" ) - self._start: t.Optional[float] = None - self._interm: t.Optional[float] = None - self._timings: _TimingDict = OrderedDict() - self._timing_on = timing_on @property def backbone_timeout(self) -> float: @@ -200,83 +187,18 @@ def backbone_timeout(self) -> float: :returns: A float indicating the number of seconds to allow""" return self._backbone_timeout - def _add_label_to_timings(self, label: str) -> None: - """Adds a new label into the timing dictionary to prepare for - receiving timing events. - - :param label: The label to create storage for - """ - if label not in self._timings: - self._timings[label] = [] - - @staticmethod - def _format_number(number: t.Union[numbers.Number, float]) -> str: - """Utility function for formatting numbers consistently for logs. - - :param number: The number to convert to a formatted string - :returns: The formatted string containing the number - """ - return f"{number:0.4e}" - - def start_timings(self, batch_size: numbers.Number) -> None: - """Configure the client to begin storing timing information. - - :param batch_size: The size of batches to generate as inputs - to the model - """ - if self._timing_on: - self._add_label_to_timings("batch_size") - self._timings["batch_size"].append(self._format_number(batch_size)) - self._start = time.perf_counter() - self._interm = time.perf_counter() - - def end_timings(self) -> None: - """Configure the client to stop storing timing information.""" - if self._timing_on and self._start is not None: - self._add_label_to_timings("total_time") - self._timings["total_time"].append( - self._format_number(time.perf_counter() - self._start) - ) - - def measure_time(self, label: str) -> None: - """Measures elapsed time since the last recorded signal. - - :param label: The label to measure time for - """ - if self._timing_on and self._interm is not None: - self._add_label_to_timings(label) - self._timings[label].append( - self._format_number(time.perf_counter() - self._interm) - ) - self._interm = time.perf_counter() - - def print_timings(self, to_file: bool = False) -> None: - """Print timing information to standard output. If `to_file` - is `True`, also write results to a file. 
- - :param to_file: If `True`, also saves timing information - to the files `timings.npy` and `timings.txt` - """ - print(" ".join(self._timings.keys())) - - value_array = numpy.array(self._timings.values(), dtype=float) - value_array = numpy.transpose(value_array) - for i in range(value_array.shape[0]): - print(" ".join(self._format_number(value) for value in value_array[i])) - if to_file: - numpy.save("timings.npy", value_array) - numpy.savetxt("timings.txt", value_array) - - def run_model(self, model: t.Union[bytes, str], batch: torch.Tensor) -> t.Any: + def run_model( + self, model: t.Union[bytes, str], batch: numpy.ndarray[t.Any, t.Any] + ) -> t.Any: """Execute a batch of inference requests with the supplied ML model. - :param model: The raw bytes or path to a pytorch model + :param model: The raw bytes or path to a model :param batch: The tensor batch to perform inference on :returns: The inference results :raises ValueError: if the worker queue is not configured properly in the environment variables """ - tensors = [batch.numpy()] + tensors = [batch] self.perf_timer.start_timings("batch_size", batch.shape[0]) built_tensor_desc = MessageHandler.build_tensor_descriptor( "c", "float32", list(batch.shape) @@ -304,14 +226,9 @@ def run_model(self, model: t.Union[bytes, str], batch: torch.Tensor) -> t.Any: raise ValueError("No worker queue available.") # pylint: disable-next=protected-access - with self._to_worker_fli._channel.sendh( # type: ignore - timeout=None, - stream_channel=self._to_worker_ch.channel, - ) as to_sendh: - to_sendh.send_bytes(request_bytes) - self.perf_timer.measure_time("send_request") - for tensor in tensors: - to_sendh.send_bytes(tensor.tobytes()) # TODO NOT FAST ENOUGH!!! + self._to_worker_fli.send_multiple( + [request_bytes, *[tensor.tobytes() for tensor in tensors]], timeout=None + ) logger.info(f"Message size: {len(request_bytes)} bytes") self.perf_timer.measure_time("send_tensors") @@ -324,15 +241,15 @@ def run_model(self, model: t.Union[bytes, str], batch: torch.Tensor) -> t.Any: # recv depending on the len(response.result.descriptors)? data_blob: bytes = from_recvh.recv_bytes(timeout=None) self.perf_timer.measure_time("receive_tensor") - result = torch.from_numpy( - numpy.frombuffer( - data_blob, - dtype=str(response.result.descriptors[0].dataType), - ) + result = numpy.frombuffer( + data_blob, + dtype=str(response.result.descriptors[0].dataType), ) + self.perf_timer.measure_time("deserialize_tensor") self.perf_timer.end_timings() + return result def set_model(self, key: str, model: bytes) -> None: diff --git a/smartsim/_core/mli/comm/channel/channel.py b/smartsim/_core/mli/comm/channel/channel.py index 104333ce7f..afc2e65d79 100644 --- a/smartsim/_core/mli/comm/channel/channel.py +++ b/smartsim/_core/mli/comm/channel/channel.py @@ -52,19 +52,32 @@ def __init__( """A user-friendly identifier for channel-related logging""" @abstractmethod - def send(self, value: bytes, timeout: float = 0.001) -> None: + def send( + self, + value: bytes, + timeout: t.Optional[float] = 0.001, + handle_timeout: t.Optional[float] = 0.001, + ) -> None: """Send a message through the underlying communication channel. 
:param value: The value to send :param timeout: Maximum time to wait (in seconds) for messages to send + :param handle_timeout: Maximum time to wait (in seconds) to obtain + new send handle :raises SmartSimError: If sending message fails """ @abstractmethod - def recv(self, timeout: float = 0.001) -> t.List[bytes]: + def recv( + self, + timeout: t.Optional[float] = 0.001, + handle_timeout: t.Optional[float] = 0.001, + ) -> t.List[bytes]: """Receives message(s) through the underlying communication channel. :param timeout: Maximum time to wait (in seconds) for messages to arrive + :param handle_timeout: Maximum time to wait (in seconds) to obtain new + receive handle :returns: The received message """ diff --git a/smartsim/_core/mli/comm/channel/dragon_channel.py b/smartsim/_core/mli/comm/channel/dragon_channel.py index 110f19258a..3444c05454 100644 --- a/smartsim/_core/mli/comm/channel/dragon_channel.py +++ b/smartsim/_core/mli/comm/channel/dragon_channel.py @@ -57,29 +57,41 @@ def channel(self) -> "dch.Channel": """ return self._channel - def send(self, value: bytes, timeout: float = 0.001) -> None: + def send( + self, + value: bytes, + timeout: t.Optional[float] = 0.001, + handle_timeout: t.Optional[float] = 0.001, + ) -> None: """Send a message through the underlying communication channel. :param value: The value to send :param timeout: Maximum time to wait (in seconds) for messages to send + :param handle_timeout: Maximum time to wait (in seconds) to obtain + new send handle :raises SmartSimError: If sending message fails """ try: - with self._channel.sendh(timeout=timeout) as sendh: - sendh.send_bytes(value, blocking=False) + with self._channel.sendh(timeout=handle_timeout) as sendh: + sendh.send_bytes(value, timeout=timeout) logger.debug(f"DragonCommChannel {self.descriptor} sent message") except Exception as e: raise SmartSimError( f"Error sending via DragonCommChannel {self.descriptor}" ) from e - def recv(self, timeout: float = 0.001) -> t.List[bytes]: + def recv( + self, + timeout: t.Optional[float] = 0.001, + handle_timeout: t.Optional[float] = 0.001, + ) -> t.List[bytes]: """Receives message(s) through the underlying communication channel. :param timeout: Maximum time to wait (in seconds) for messages to arrive + :param handle_timeout: Maximum time to wait (in seconds) to obtain new receive handle :returns: The received message(s) """ - with self._channel.recvh(timeout=timeout) as recvh: + with self._channel.recvh(timeout=handle_timeout) as recvh: messages: t.List[bytes] = [] try: diff --git a/smartsim/_core/mli/comm/channel/dragon_fli.py b/smartsim/_core/mli/comm/channel/dragon_fli.py index 01849247cd..9438efc25c 100644 --- a/smartsim/_core/mli/comm/channel/dragon_fli.py +++ b/smartsim/_core/mli/comm/channel/dragon_fli.py @@ -68,18 +68,27 @@ def __init__( self._buffer_size: int = buffer_size """Maximum number of messages that can be buffered before sending""" - def send(self, value: bytes, timeout: float = 0.001) -> None: + def send( + self, + value: bytes, + timeout: t.Optional[float] = 0.001, + handle_timeout: t.Optional[float] = 0.001, + ) -> None: """Send a message through the underlying communication channel.
:param value: The value to send :param timeout: Maximum time to wait (in seconds) for messages to send + :param handle_timeout: Maximum time to wait (in seconds) to obtain new + send handle :raises SmartSimError: If sending message fails """ try: if self._channel is None: self._channel = drg_util.create_local(self._buffer_size) - with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh: + with self._fli.sendh( + timeout=handle_timeout, stream_channel=self._channel + ) as sendh: sendh.send_bytes(value, timeout=timeout) logger.debug(f"DragonFLIChannel {self.descriptor} sent message") except Exception as e: @@ -91,21 +100,26 @@ def send(self, value: bytes, timeout: float = 0.001) -> None: def send_multiple( self, values: t.Sequence[bytes], - timeout: float = 0.001, + timeout: t.Optional[float] = 0.001, + handle_timeout: t.Optional[float] = 0.001, ) -> None: """Send a message through the underlying communication channel. :param values: The values to send :param timeout: Maximum time to wait (in seconds) for messages to send + :param handle_timeout: Maximum time to wait (in seconds) to obtain new send + handle :raises SmartSimError: If sending message fails """ try: if self._channel is None: self._channel = drg_util.create_local(self._buffer_size) - with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh: + with self._fli.sendh( + timeout=handle_timeout, stream_channel=self._channel + ) as sendh: for value in values: - sendh.send_bytes(value) + sendh.send_bytes(value, timeout=timeout) logger.debug(f"DragonFLIChannel {self.descriptor} sent message") except Exception as e: self._channel = None @@ -113,16 +127,22 @@ def send_multiple( f"Error sending via DragonFLIChannel {self.descriptor} {e}" ) from e - def recv(self, timeout: float = 0.001) -> t.List[bytes]: + def recv( + self, + timeout: t.Optional[float] = 0.001, + handle_timeout: t.Optional[float] = 0.001, + ) -> t.List[bytes]: """Receives message(s) through the underlying communication channel. :param timeout: Maximum time to wait (in seconds) for messages to arrive + :param handle_timeout: Maximum time to wait (in seconds) to obtain new + receive handle :returns: The received message(s) :raises SmartSimError: If receiving message(s) fails """ messages = [] eot = False - with self._fli.recvh(timeout=timeout) as recvh: + with self._fli.recvh(timeout=handle_timeout) as recvh: while not eot: try: message, _ = recvh.recv_bytes(timeout=timeout) diff --git a/smartsim/_core/mli/infrastructure/control/worker_manager.py b/smartsim/_core/mli/infrastructure/control/worker_manager.py index 1c93276074..8e953d87ec 100644 --- a/smartsim/_core/mli/infrastructure/control/worker_manager.py +++ b/smartsim/_core/mli/infrastructure/control/worker_manager.py @@ -330,9 +330,6 @@ def _on_iteration(self) -> None: self._perf_timer.end_timings() - if self._perf_timer.max_length == 801: - self._perf_timer.print_timings(True) - def _can_shutdown(self) -> bool: """Determine if the service can be shutdown. diff --git a/smartsim/_core/mli/infrastructure/worker/onnx_worker.py b/smartsim/_core/mli/infrastructure/worker/onnx_worker.py new file mode 100644 index 0000000000..4299863315 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/worker/onnx_worker.py @@ -0,0 +1,291 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os + +import numpy as np +from onnx import load_model_from_string +from onnxruntime import InferenceSession # type: ignore + +# isort: off +# isort: on + +# pylint: disable=import-error +from dragon.managed_memory import MemoryAlloc, MemoryPool + +from .....error import SmartSimError +from .....log import get_logger +from .worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + LoadModelResult, + MachineLearningWorkerBase, + RequestBatch, + TensorMeta, + TransformInputResult, + TransformOutputResult, +) + +# pylint: enable=import-error + + +logger = get_logger(__name__) + + +class ONNXWorker(MachineLearningWorkerBase): + """A worker that executes an ONNX model.""" + + @staticmethod + def load_model( + batch: RequestBatch, fetch_result: FetchModelResult, device: str + ) -> LoadModelResult: + """Given a loaded MachineLearningModel, ensure it is loaded into + device memory. 
+ + :param request: The request that triggered the pipeline + :param fetch_result: Raw outputs from fetching model + :param device: The device on which the model must be placed + :returns: LoadModelResult wrapping the model loaded for the request + :raises ValueError: If model reference object is not found + :raises RuntimeError: If loading and evaluating the model failed + """ + if fetch_result.model_bytes: + model_bytes = fetch_result.model_bytes + elif batch.raw_model and batch.raw_model.data: + model_bytes = batch.raw_model.data + else: + raise ValueError("Unable to load model without reference object") + + providers = [] + provider_options = [] + + if "gpu" in device.lower(): + device_split = device.split(":") + if len(device_split) > 1: + provider_options.append({"device_id": device_split[-1]}) + else: + provider_options.append({}) + if "ROCR_VISIBLE_DEVICES" in os.environ: + providers = ["ROCMExecutionProvider"] + else: + providers = ["CUDAExecutionProvider"] + + # Fallback + providers.append("CPUExecutionProvider") + provider_options.append({}) + + try: + onnx_deserialized = load_model_from_string(model_bytes) + output_tensors = [n.name for n in onnx_deserialized.graph.output] + input_layers = [n.name for n in onnx_deserialized.graph.input] + session = InferenceSession( + model_bytes, providers=providers, provider_options=provider_options + ) + except Exception as e: + raise RuntimeError( + "Failed to load and evaluate the model: " + f"Model key {batch.model_id.key}, Device {device}" + ) from e + result = LoadModelResult( + session, + input_layers, + output_tensors, + ) + return result + + @staticmethod + def transform_input( + batch: RequestBatch, + fetch_results: FetchInputResult, + mem_pool: MemoryPool, + ) -> TransformInputResult: + """Given a collection of data, perform a transformation on the data and put + the raw tensor data on a MemoryPool allocation. 
+ + :param batch: The batch that triggered the pipeline + :param fetch_result: Raw outputs from fetching inputs from feature store or + request + :param mem_pool: The memory pool used to access batched input tensors + :returns: The transformed inputs wrapped in a TransformInputResult + :raises ValueError: If tensors cannot be reconstructed + :raises IndexError: If index out of range + """ + results: list[bytes] = [] + total_samples = 0 + slices: list[slice] = [] + + all_dims: list[list[int]] = [] + all_dtypes: list[str] = [] + if fetch_results.meta is None: + raise ValueError("Cannot reconstruct tensor without meta information") + # Traverse inputs to get total number of samples and compute slices + # Assumption: first dimension is samples, all tensors in the same input + # have same number of samples + # thus we only look at the first tensor for each input + for res_idx, res_meta_list in enumerate(fetch_results.meta): + if res_meta_list is None or any( + item_meta is None for item_meta in res_meta_list + ): + raise ValueError("Cannot reconstruct tensor without meta information") + first_tensor_desc: TensorMeta = res_meta_list[0] # type: ignore + num_samples = first_tensor_desc.dimensions[0] + slices.append(slice(total_samples, total_samples + num_samples)) + total_samples = total_samples + num_samples + + if res_idx == len(fetch_results.meta) - 1: + # For each tensor in the last input, get remaining dimensions + # Assumptions: all inputs have the same number of tensors and + # last N-1 dimensions match across inputs for corresponding tensors + # thus: resulting array will be of size (num_samples, all_other_dims) + for item_meta in res_meta_list: + tensor_desc: TensorMeta = item_meta # type: ignore + tensor_dims = tensor_desc.dimensions + all_dims.append([total_samples, *tensor_dims[1:]]) + all_dtypes.append(tensor_desc.datatype) + + for result_tensor_idx, (dims, dtype) in enumerate(zip(all_dims, all_dtypes)): + itemsize = np.empty((1), dtype=dtype).itemsize + alloc_size = int(np.prod(dims) * itemsize) + mem_alloc = mem_pool.alloc(alloc_size) + mem_view = mem_alloc.get_memview() + try: + joined = b"".join( + [ + fetch_result[result_tensor_idx] + for fetch_result in fetch_results.inputs + ] + ) + mem_view[:alloc_size] = joined + except IndexError as e: + raise IndexError( + "Error accessing elements in fetch_result.inputs " + f"with index {result_tensor_idx}" + ) from e + + results.append(mem_alloc.serialize()) + + return TransformInputResult(results, slices, all_dims, all_dtypes) + + # pylint: disable-next=unused-argument + @staticmethod + def execute( + batch: RequestBatch, + load_result: LoadModelResult, + transform_result: TransformInputResult, + device: str, + ) -> ExecuteResult: + """Execute an ML model on inputs transformed for use by the model. 
+ + :param batch: The batch of requests that triggered the pipeline + :param load_result: The result of loading the model onto device memory + :param transform_result: The result of transforming inputs for model consumption + :param device: The device on which the model will be executed + :returns: The result of inference wrapped in an ExecuteResult + :raises SmartSimError: If model is not loaded + :raises IndexError: If memory slicing is out of range + :raises ValueError: If tensor creation fails or is unable to evaluate the model + """ + if not load_result.model: + raise SmartSimError("Model must be loaded to execute") + + tensors = [] + mem_allocs = [] + for transformed, dims, dtype in zip( + transform_result.transformed, transform_result.dims, transform_result.dtypes + ): + mem_alloc = MemoryAlloc.attach(transformed) + mem_allocs.append(mem_alloc) + itemsize = np.empty((1), dtype=dtype).itemsize + try: + tensors.append( + np.frombuffer( + mem_alloc.get_memview()[0 : np.prod(dims) * itemsize], + dtype=dtype, + ).reshape(dims) + ) + except IndexError as e: + raise IndexError("Error during memory slicing") from e + except Exception as e: + raise ValueError("Error during tensor creation") from e + + sess = load_result.model + if load_result.inputs is None: + raise ValueError("Model was stored without inputs") + try: + results = sess.run( + load_result.outputs, + input_feed=dict(zip(load_result.inputs, tensors)), + ) + except Exception as e: + raise ValueError( + f"Error while evaluating the model: Model {batch.model_id.key}" + ) from e + + transform_result.transformed = [] + + execute_result = ExecuteResult(results, transform_result.slices) + for mem_alloc in mem_allocs: + mem_alloc.free() + return execute_result + + @staticmethod + def transform_output( + batch: RequestBatch, + execute_result: ExecuteResult, + ) -> list[TransformOutputResult]: + """Given inference results, perform transformations required to + transmit results to the requestor. + + :param batch: The batch of requests that triggered the pipeline + :param execute_result: The result of inference wrapped in an ExecuteResult + :returns: A list of transformed outputs + :raises IndexError: If indexing is out of range + :raises ValueError: If transforming output fails + """ + transformed_list: list[TransformOutputResult] = [] + cpu_predictions = execute_result.predictions + + for result_slice in execute_result.slices: + transformed = [] + for cpu_item in cpu_predictions: + try: + transformed.append(cpu_item[result_slice].tobytes()) + + # todo: need the shape from latest schemas added here. + transformed_list.append( + TransformOutputResult(transformed, None, "c", "float32") + ) # fixme + except IndexError as e: + raise IndexError( + f"Error accessing elements: result_slice {result_slice}" + ) from e + except Exception as e: + raise ValueError("Error transforming output") from e + + execute_result.predictions = [] + + return transformed_list diff --git a/smartsim/_core/mli/infrastructure/worker/tensorflow_worker.py b/smartsim/_core/mli/infrastructure/worker/tensorflow_worker.py new file mode 100644 index 0000000000..bd1f8b7cee --- /dev/null +++ b/smartsim/_core/mli/infrastructure/worker/tensorflow_worker.py @@ -0,0 +1,322 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import logging +import os + +import numpy as np +import tensorflow as tf + +# pylint: disable-next=no-name-in-module +from tensorflow.python.framework.ops import disable_eager_execution + +# isort: off +# isort: on + +# pylint: disable=import-error +from dragon.managed_memory import MemoryAlloc, MemoryPool + +from .....error import SmartSimError +from .....log import get_logger +from .worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + LoadModelResult, + MachineLearningWorkerBase, + RequestBatch, + TensorMeta, + TransformInputResult, + TransformOutputResult, +) + +# pylint: enable=import-error + + +tf.get_logger().setLevel(logging.ERROR) +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +logger = get_logger(__name__) + +disable_eager_execution() + + +class TensorFlowWorker(MachineLearningWorkerBase): + """A worker that executes a TensorFlow model.""" + + @staticmethod + def load_model( + batch: RequestBatch, fetch_result: FetchModelResult, device: str + ) -> LoadModelResult: + """Given a loaded MachineLearningModel, ensure it is loaded into + device memory. 
+ + :param request: The request that triggered the pipeline + :param fetch_result: Raw outputs from fetching model + :param device: The device on which the model must be placed + :returns: LoadModelResult wrapping the model loaded for the request + :raises ValueError: If model reference object is not found + :raises RuntimeError: If loading and evaluating the model failed + """ + if fetch_result.model_bytes: + model_bytes = fetch_result.model_bytes + elif batch.raw_model and batch.raw_model.data: + model_bytes = batch.raw_model.data + else: + raise ValueError("Unable to load model without reference object") + + device_to_tf = {"cpu": "/CPU", "gpu": "/GPU"} + for old, new in device_to_tf.items(): + device = device.replace(old, new) + + try: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(model_bytes) + + # pylint: disable-next=not-context-manager + with tf.Graph().as_default() as graph, tf.device(device): + tf.import_graph_def(graph_def, name="") + ops = graph.get_operations() + except Exception as e: + raise RuntimeError( + "Failed to load and evaluate the model: " + f"Model key {batch.model_id.key}, Device {device}" + ) from e + + input_layers = set() + for operation in ops: + if operation.type == "Placeholder": + logger.debug( + f"Input op name: {operation.name}, " + f"output shape: {operation.outputs[0].get_shape()}" + ) + input_layers.add(f"{operation.name}:0") + + # Code initially taken from + # apple.github.io/coremltools/docs-guides/source/tensorflow-1-workflow.html + output_tensors = set() + input_tensors = set() + for operation in ops: + for x in operation.inputs: + input_tensors.add(x.name) + for operation in ops: + if len(operation.outputs) > 0: + x = operation.outputs[0] + potential_names = [x.name] + name_split = x.name.split(":") + potential_names.append( + ":".join((name_split[0] + "/resource", name_split[-1])) + ) + if all(name not in input_tensors for name in potential_names): + logger.debug( + f"Output tensor name: {x.name}, " + f"tensor shape: {x.get_shape()}, " + f"parent op type: {operation.type}" + ) + output_tensors.add(x.name) + + with tf.device(device): + result = LoadModelResult( + tf.compat.v1.Session(graph=graph), + list(input_layers), + list(output_tensors), + ) + return result + + @staticmethod + def transform_input( + batch: RequestBatch, + fetch_results: FetchInputResult, + mem_pool: MemoryPool, + ) -> TransformInputResult: + """Given a collection of data, perform a transformation on the data and put + the raw tensor data on a MemoryPool allocation. 
+ + :param batch: The request batch that triggered the pipeline + :param fetch_result: Raw outputs from fetching inputs + :param mem_pool: The memory pool used to access batched input tensors + :returns: The transformed inputs wrapped in a TransformInputResult + :raises ValueError: If tensors cannot be reconstructed + :raises IndexError: If index out of range + """ + results: list[bytes] = [] + total_samples = 0 + slices: list[slice] = [] + + all_dims: list[list[int]] = [] + all_dtypes: list[str] = [] + if fetch_results.meta is None: + raise ValueError("Cannot reconstruct tensor without meta information") + # Traverse inputs to get total number of samples and compute slices + # Assumption: first dimension is samples, all tensors in the same input + # have same number of samples + # thus we only look at the first tensor for each input + for res_idx, res_meta_list in enumerate(fetch_results.meta): + if res_meta_list is None or any( + item_meta is None for item_meta in res_meta_list + ): + raise ValueError("Cannot reconstruct tensor without meta information") + first_tensor_desc: TensorMeta = res_meta_list[0] # type: ignore + num_samples = first_tensor_desc.dimensions[0] + slices.append(slice(total_samples, total_samples + num_samples)) + total_samples = total_samples + num_samples + + if res_idx == len(fetch_results.meta) - 1: + # For each tensor in the last input, get remaining dimensions + # Assumptions: all inputs have the same number of tensors and + # last N-1 dimensions match across inputs for corresponding tensors + # thus: resulting array will be of size (num_samples, all_other_dims) + for item_meta in res_meta_list: + tensor_desc: TensorMeta = item_meta # type: ignore + tensor_dims = tensor_desc.dimensions + all_dims.append([total_samples, *tensor_dims[1:]]) + all_dtypes.append(tensor_desc.datatype) + + for result_tensor_idx, (dims, dtype) in enumerate(zip(all_dims, all_dtypes)): + itemsize = np.empty((1), dtype=dtype).itemsize + alloc_size = int(np.prod(dims) * itemsize) + mem_alloc = mem_pool.alloc(alloc_size) + mem_view = mem_alloc.get_memview() + try: + joined = b"".join( + [ + fetch_result[result_tensor_idx] + for fetch_result in fetch_results.inputs + ] + ) + mem_view[:alloc_size] = joined + except IndexError as e: + raise IndexError( + "Error accessing elements in fetch_result.inputs " + f"with index {result_tensor_idx}" + ) from e + + results.append(mem_alloc.serialize()) + + return TransformInputResult(results, slices, all_dims, all_dtypes) + + # pylint: disable-next=unused-argument + @staticmethod + def execute( + batch: RequestBatch, + load_result: LoadModelResult, + transform_result: TransformInputResult, + device: str, + ) -> ExecuteResult: + """Execute an ML model on inputs transformed for use by the model. 
+ + :param batch: The batch of requests that triggered the pipeline + :param load_result: The result of loading the model onto device memory + :param transform_result: The result of transforming inputs for model consumption + :param device: The device on which the model will be executed + :returns: The result of inference wrapped in an ExecuteResult + :raises SmartSimError: If model is not loaded + :raises IndexError: If memory slicing is out of range + :raises ValueError: If tensor creation fails or is unable to evaluate the model + """ + if not load_result.model: + raise SmartSimError("Model must be loaded to execute") + device_to_tf = {"cpu": "/CPU", "gpu": "/GPU"} + for old, new in device_to_tf.items(): + device = device.replace(old, new) + + tensors = [] + mem_allocs = [] + for transformed, dims, dtype in zip( + transform_result.transformed, transform_result.dims, transform_result.dtypes + ): + mem_alloc = MemoryAlloc.attach(transformed) + mem_allocs.append(mem_alloc) + itemsize = np.empty((1), dtype=dtype).itemsize + try: + tensors.append( + np.frombuffer( + mem_alloc.get_memview()[0 : np.prod(dims) * itemsize], + dtype=dtype, + ).reshape(dims) + ) + except IndexError as e: + raise IndexError("Error during memory slicing") from e + except Exception as e: + raise ValueError("Error during tensor creation") from e + + sess = load_result.model + if load_result.inputs is None: + raise ValueError("Model was stored without inputs") + try: + with tf.device(device): + results = sess.run( + load_result.outputs, + feed_dict=dict(zip(load_result.inputs, tensors)), + ) + except Exception as e: + raise ValueError( + f"Error while evaluating the model: Model {batch.model_id.key}" + ) from e + + transform_result.transformed = [] + + execute_result = ExecuteResult(results, transform_result.slices) + for mem_alloc in mem_allocs: + mem_alloc.free() + return execute_result + + @staticmethod + def transform_output( + batch: RequestBatch, + execute_result: ExecuteResult, + ) -> list[TransformOutputResult]: + """Given inference results, perform transformations required to + transmit results to the requestor. + + :param batch: The batch of requests that triggered the pipeline + :param execute_result: The result of inference wrapped in an ExecuteResult + :returns: A list of transformed outputs + :raises IndexError: If indexing is out of range + :raises ValueError: If transforming output fails + """ + transformed_list: list[TransformOutputResult] = [] + cpu_predictions = execute_result.predictions + + for result_slice in execute_result.slices: + transformed = [] + for cpu_item in cpu_predictions: + try: + transformed.append(cpu_item[result_slice].tobytes()) + + # todo: need the shape from latest schemas added here. + transformed_list.append( + TransformOutputResult(transformed, None, "c", "float32") + ) # fixme + except IndexError as e: + raise IndexError( + f"Error accessing elements: result_slice {result_slice}" + ) from e + except Exception as e: + raise ValueError("Error transforming output") from e + + execute_result.predictions = [] + + return transformed_list diff --git a/smartsim/_core/mli/infrastructure/worker/torch_worker.py b/smartsim/_core/mli/infrastructure/worker/torch_worker.py index f8d6e7c2de..4642369ef5 100644 --- a/smartsim/_core/mli/infrastructure/worker/torch_worker.py +++ b/smartsim/_core/mli/infrastructure/worker/torch_worker.py @@ -105,13 +105,14 @@ def transform_input( the raw tensor data on a MemoryPool allocation. 
:param batch: The batch that triggered the pipeline - :param fetch_result: Raw outputs from fetching inputs out of a feature store + :param fetch_result: Raw outputs from fetching inputs from feature store or + request :param mem_pool: The memory pool used to access batched input tensors :returns: The transformed inputs wrapped in a TransformInputResult :raises ValueError: If tensors cannot be reconstructed :raises IndexError: If index out of range """ - results: list[torch.Tensor] = [] + results: list[bytes] = [] total_samples = 0 slices: list[slice] = [] diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py index b122a1d9ba..96aaffd85c 100644 --- a/smartsim/_core/mli/infrastructure/worker/worker.py +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -208,13 +208,22 @@ class TensorMeta: class LoadModelResult: """A wrapper around a loaded model.""" - def __init__(self, model: t.Any) -> None: + def __init__( + self, + model: t.Any, + inputs: t.Optional[t.List[str]] = None, + outputs: t.Optional[t.List[str]] = None, + ) -> None: """Initialize the LoadModelResult. :param model: The loaded model """ self.model = model """The loaded model (e.g. a TensorFlow, PyTorch, ONNX, etc. model)""" + self.inputs = inputs + """List of input layer names, only used in TensorFlow""" + self.outputs = outputs + """List of output tensor names, only used in TensorFlow""" class TransformInputResult: diff --git a/tests/dragon/__init__.py b/tests/dragon/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dragon/channel.py b/tests/dragon/channel.py new file mode 100644 index 0000000000..4c46359c2d --- /dev/null +++ b/tests/dragon/channel.py @@ -0,0 +1,125 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
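The `LoadModelResult` change above adds optional `inputs`/`outputs` layer-name lists that only the TensorFlow worker populates. A minimal illustrative sketch (not part of the patch; the toy graph and names below are assumptions) of how those fields are meant to be consumed by `execute()`:

```python
# Sketch: store a TF1-style session plus discovered layer names in the extended
# LoadModelResult, then feed them to Session.run() the way the worker does.
import numpy as np
import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution

from smartsim._core.mli.infrastructure.worker.worker import LoadModelResult

disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=(None, 2), name="x")
    y = tf.identity(x * 2.0, name="y")

load_result = LoadModelResult(
    tf.compat.v1.Session(graph=graph),  # model
    inputs=["x:0"],                     # input layer names (TensorFlow only)
    outputs=["y:0"],                    # output tensor names (TensorFlow only)
)

# execute() builds the feed dict from these stored names
samples = np.ones((4, 2), dtype="float32")
predictions = load_result.model.run(
    load_result.outputs, feed_dict=dict(zip(load_result.inputs, [samples]))
)
```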
+ +import base64 +import pathlib +import threading +import typing as t + +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class FileSystemCommChannel(CommChannelBase): + """Passes messages by writing to a file""" + + def __init__(self, key: pathlib.Path) -> None: + """Initialize the FileSystemCommChannel instance. + + :param key: a path to the root directory of the feature store + """ + self._lock = threading.RLock() + + super().__init__(key.as_posix()) + self._file_path = key + + if not self._file_path.parent.exists(): + self._file_path.parent.mkdir(parents=True) + + self._file_path.touch() + + def send(self, value: bytes, timeout: float = 0) -> None: + """Send a message throuh the underlying communication channel. + + :param value: The value to send + :param timeout: maximum time to wait (in seconds) for messages to send + """ + with self._lock: + # write as text so we can add newlines as delimiters + with open(self._file_path, "a") as fp: + encoded_value = base64.b64encode(value).decode("utf-8") + fp.write(f"{encoded_value}\n") + logger.debug(f"FileSystemCommChannel {self._file_path} sent message") + + def recv(self, timeout: float = 0) -> t.List[bytes]: + """Receives message(s) through the underlying communication channel. + + :param timeout: maximum time to wait (in seconds) for messages to arrive + :returns: the received message + :raises SmartSimError: if the descriptor points to a missing file + """ + with self._lock: + messages: t.List[bytes] = [] + if not self._file_path.exists(): + raise SmartSimError("Empty channel") + + # read as text so we can split on newlines + with open(self._file_path, "r") as fp: + lines = fp.readlines() + + if lines: + line = lines.pop(0) + event_bytes = base64.b64decode(line.encode("utf-8")) + messages.append(event_bytes) + + self.clear() + + # remove the first message only, write remainder back... + if len(lines) > 0: + with open(self._file_path, "w") as fp: + fp.writelines(lines) + + logger.debug( + f"FileSystemCommChannel {self._file_path} received message" + ) + + return messages + + def clear(self) -> None: + """Create an empty file for events.""" + if self._file_path.exists(): + self._file_path.unlink() + self._file_path.touch() + + @classmethod + def from_descriptor( + cls, + descriptor: str, + ) -> "FileSystemCommChannel": + """A factory method that creates an instance from a descriptor string. + + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached FileSystemCommChannel + """ + try: + path = pathlib.Path(descriptor) + return FileSystemCommChannel(path) + except: + logger.warning(f"failed to create fs comm channel: {descriptor}") + raise diff --git a/tests/dragon/conftest.py b/tests/dragon/conftest.py new file mode 100644 index 0000000000..d542700175 --- /dev/null +++ b/tests/dragon/conftest.py @@ -0,0 +1,129 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from __future__ import annotations + +import os +import pathlib +import socket +import subprocess +import sys +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +# isort: off +import dragon.data.ddict.ddict as dragon_ddict +import dragon.infrastructure.policy as dragon_policy +import dragon.infrastructure.process_desc as dragon_process_desc +import dragon.native.process as dragon_process + +from dragon.fli import FLInterface + +# isort: on + +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.storage import dragon_util +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +@pytest.fixture(scope="module") +def the_storage() -> dragon_ddict.DDict: + """Fixture to instantiate a dragon distributed dictionary.""" + return dragon_util.create_ddict(1, 2, 32 * 1024**2) + + +@pytest.fixture(scope="module") +def the_worker_channel() -> DragonFLIChannel: + """Fixture to create a valid descriptor for a worker channel + that can be attached to.""" + channel_ = create_local() + fli_ = FLInterface(main_ch=channel_, manager_ch=None) + comm_channel = DragonFLIChannel(fli_) + return comm_channel + + +@pytest.fixture(scope="module") +def the_backbone( + the_storage: t.Any, the_worker_channel: DragonFLIChannel +) -> BackboneFeatureStore: + """Fixture to create a distributed dragon dictionary and wrap it + in a BackboneFeatureStore. + + :param the_storage: The dragon storage engine to use + :param the_worker_channel: Pre-configured worker channel + """ + + backbone = BackboneFeatureStore(the_storage, allow_reserved_writes=True) + backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] = the_worker_channel.descriptor + + return backbone + + +@pytest.fixture(scope="module") +def backbone_descriptor(the_backbone: BackboneFeatureStore) -> str: + # create a shared backbone featurestore + return the_backbone.descriptor + + +def function_as_dragon_proc( + entrypoint_fn: t.Callable[[t.Any], None], + args: t.List[t.Any], + cpu_affinity: t.List[int], + gpu_affinity: t.List[int], +) -> dragon_process.Process: + """Execute a function as an independent dragon process. 
+ + :param entrypoint_fn: The function to execute + :param args: The arguments for the entrypoint function + :param cpu_affinity: The cpu affinity for the process + :param gpu_affinity: The gpu affinity for the process + :returns: The dragon process handle + """ + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) + local_policy = dragon_policy.Policy( + placement=dragon_policy.Policy.Placement.HOST_NAME, + host_name=socket.gethostname(), + cpu_affinity=cpu_affinity, + gpu_affinity=gpu_affinity, + ) + return dragon_process.Process( + target=entrypoint_fn, + args=args, + cwd=os.getcwd(), + policy=local_policy, + options=options, + stderr=dragon_process.Popen.STDOUT, + stdout=dragon_process.Popen.STDOUT, + ) diff --git a/tests/dragon/feature_store.py b/tests/dragon/feature_store.py new file mode 100644 index 0000000000..d06b0b334e --- /dev/null +++ b/tests/dragon/feature_store.py @@ -0,0 +1,152 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
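A hedged usage sketch for the `function_as_dragon_proc` helper defined above (not part of the patch; `my_entrypoint` and the affinity values are illustrative, and the `start`/`join` calls assume Dragon's multiprocessing-like process API):

```python
# Sketch: launch a plain Python function as an independent Dragon process
# pinned to specific CPU/GPU devices, mirroring how the tests spawn workers.
from tests.dragon.conftest import function_as_dragon_proc


def my_entrypoint(descriptor: str) -> None:
    # stand-in worker body; a real test would attach to the backbone here
    print(f"worker attached to {descriptor}")


proc = function_as_dragon_proc(
    my_entrypoint,
    args=["some-descriptor"],
    cpu_affinity=[0, 1],
    gpu_affinity=[],
)
proc.start()  # begin execution on the local node
proc.join()   # wait for the entrypoint to finish
```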
+ +import pathlib +import typing as t + +import smartsim.error as sse +from smartsim._core.mli.infrastructure.storage.feature_store import FeatureStore +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class MemoryFeatureStore(FeatureStore): + """A feature store with values persisted only in local memory""" + + def __init__( + self, storage: t.Optional[t.Dict[str, t.Union[str, bytes]]] = None + ) -> None: + """Initialize the MemoryFeatureStore instance""" + super().__init__("in-memory-fs") + if storage is None: + storage = {"_": "abc"} + self._storage = storage + + def _get(self, key: str) -> t.Union[str, bytes]: + """Retrieve a value from the underlying storage mechanism + + :param key: The unique key that identifies the resource + :returns: the value identified by the key + :raises KeyError: if the key has not been used to store a value""" + return self._storage[key] + + def _set(self, key: str, value: t.Union[str, bytes]) -> None: + """Store a value into the underlying storage mechanism + + :param key: The unique key that identifies the resource + :param value: The value to store + :returns: the value identified by the key + :raises KeyError: if the key has not been used to store a value""" + self._storage[key] = value + + def _contains(self, key: str) -> bool: + """Determine if the storage mechanism contains a given key + + :param key: The unique key that identifies the resource + :returns: True if the key is defined, False otherwise""" + return key in self._storage + + +class FileSystemFeatureStore(FeatureStore): + """Alternative feature store implementation for testing. Stores all + data on the file system""" + + def __init__(self, storage_dir: t.Union[pathlib.Path, str]) -> None: + """Initialize the FileSystemFeatureStore instance + + :param storage_dir: (optional) root directory to store all data relative to""" + if isinstance(storage_dir, str): + storage_dir = pathlib.Path(storage_dir) + self._storage_dir = storage_dir + super().__init__(storage_dir.as_posix()) + + def _get(self, key: str) -> t.Union[str, bytes]: + """Retrieve a value from the underlying storage mechanism + + :param key: The unique key that identifies the resource + :returns: the value identified by the key + :raises KeyError: if the key has not been used to store a value""" + path = self._key_path(key) + if not path.exists(): + raise sse.SmartSimError(f"{path} not found in feature store") + return path.read_bytes() + + def _set(self, key: str, value: t.Union[str, bytes]) -> None: + """Store a value into the underlying storage mechanism + + :param key: The unique key that identifies the resource + :param value: The value to store + :returns: the value identified by the key + :raises KeyError: if the key has not been used to store a value""" + path = self._key_path(key, create=True) + if isinstance(value, str): + value = value.encode("utf-8") + path.write_bytes(value) + + def _contains(self, key: str) -> bool: + """Determine if the storage mechanism contains a given key + + :param key: The unique key that identifies the resource + :returns: True if the key is defined, False otherwise""" + path = self._key_path(key) + return path.exists() + + def _key_path(self, key: str, create: bool = False) -> pathlib.Path: + """Given a key, return a path that is optionally combined with a base + directory used by the FileSystemFeatureStore. 
+ + :param key: Unique key of an item to retrieve from the feature store""" + value = pathlib.Path(key) + + if self._storage_dir is not None: + value = self._storage_dir / key + + if create: + value.parent.mkdir(parents=True, exist_ok=True) + + return value + + @classmethod + def from_descriptor( + cls, + descriptor: str, + ) -> "FileSystemFeatureStore": + """A factory method that creates an instance from a descriptor string + + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached FileSystemFeatureStore""" + try: + path = pathlib.Path(descriptor) + path.mkdir(parents=True, exist_ok=True) + if not path.is_dir(): + raise ValueError("FileSystemFeatureStore requires a directory path") + if not path.exists(): + path.mkdir(parents=True, exist_ok=True) + return FileSystemFeatureStore(path) + except: + logger.error(f"Error while creating FileSystemFeatureStore: {descriptor}") + raise diff --git a/tests/dragon/test_core_machine_learning_worker.py b/tests/dragon/test_core_machine_learning_worker.py new file mode 100644 index 0000000000..e9c356b4e0 --- /dev/null +++ b/tests/dragon/test_core_machine_learning_worker.py @@ -0,0 +1,377 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
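The two test feature stores above are used through the dict-style `FeatureStore` interface, which dispatches to the `_get`/`_set`/`_contains` overrides. A short sketch (not part of the patch; the `/tmp/fs-root` path and keys are illustrative):

```python
# Sketch: exercise the in-memory and file-system test feature stores through
# the same item-access API the dragon tests rely on.
import pathlib

from tests.dragon.feature_store import FileSystemFeatureStore, MemoryFeatureStore

mem_fs = MemoryFeatureStore()
mem_fs["my-key"] = b"some-bytes"          # dispatches to _set
assert "my-key" in mem_fs                 # dispatches to _contains
assert mem_fs["my-key"] == b"some-bytes"  # dispatches to _get

disk_fs = FileSystemFeatureStore(pathlib.Path("/tmp/fs-root"))
disk_fs["model"] = b"\x00\x01"            # written to /tmp/fs-root/model
assert disk_fs.descriptor == "/tmp/fs-root"
```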
+ +import pathlib +import time + +import pytest + +dragon = pytest.importorskip("dragon") + +import torch + +import smartsim.error as sse +from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey, TensorKey +from smartsim._core.mli.infrastructure.worker.worker import ( + InferenceRequest, + MachineLearningWorkerCore, + RequestBatch, + TransformInputResult, + TransformOutputResult, +) +from smartsim._core.utils import installed_redisai_backends + +from .feature_store import FileSystemFeatureStore, MemoryFeatureStore + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + +# retrieved from pytest fixtures +is_dragon = ( + pytest.test_launcher == "dragon" if hasattr(pytest, "test_launcher") else False +) +torch_available = "torch" in installed_redisai_backends() + + +@pytest.fixture +def persist_torch_model(test_dir: str) -> pathlib.Path: + ts_start = time.time_ns() + print("Starting model file creation...") + test_path = pathlib.Path(test_dir) + model_path = test_path / "basic.pt" + + model = torch.nn.Linear(2, 1) + torch.save(model, model_path) + ts_end = time.time_ns() + + ts_elapsed = (ts_end - ts_start) / 1000000000 + print(f"Model file creation took {ts_elapsed} seconds") + return model_path + + +@pytest.fixture +def persist_torch_tensor(test_dir: str) -> pathlib.Path: + ts_start = time.time_ns() + print("Starting model file creation...") + test_path = pathlib.Path(test_dir) + file_path = test_path / "tensor.pt" + + tensor = torch.randn((100, 100, 2)) + torch.save(tensor, file_path) + ts_end = time.time_ns() + + ts_elapsed = (ts_end - ts_start) / 1000000000 + print(f"Tensor file creation took {ts_elapsed} seconds") + return file_path + + +@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +def test_fetch_model_disk(persist_torch_model: pathlib.Path, test_dir: str) -> None: + """Verify that the ML worker successfully retrieves a model + when given a valid (file system) key""" + worker = MachineLearningWorkerCore + key = str(persist_torch_model) + feature_store = FileSystemFeatureStore(test_dir) + fsd = feature_store.descriptor + feature_store[str(persist_torch_model)] = persist_torch_model.read_bytes() + + model_key = ModelKey(key=key, descriptor=fsd) + request = InferenceRequest(model_key=model_key) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_model(batch, {fsd: feature_store}) + assert fetch_result.model_bytes + assert fetch_result.model_bytes == persist_torch_model.read_bytes() + + +def test_fetch_model_disk_missing() -> None: + """Verify that the ML worker fails to retrieves a model + when given an invalid (file system) key""" + worker = MachineLearningWorkerCore + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + + key = "/path/that/doesnt/exist" + + model_key = ModelKey(key=key, descriptor=fsd) + request = InferenceRequest(model_key=model_key) + batch = RequestBatch([request], None, model_key) + + with pytest.raises(sse.SmartSimError) as ex: + worker.fetch_model(batch, {fsd: feature_store}) + + # ensure the error message includes key-identifying information + assert key in ex.value.args[0] + + +@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +def test_fetch_model_feature_store(persist_torch_model: pathlib.Path) -> None: + """Verify that the ML worker successfully retrieves a model + when given a valid (file system) key""" + worker = MachineLearningWorkerCore + + # create a key to retrieve from the feature 
store + key = "test-model" + + # put model bytes into the feature store + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + feature_store[key] = persist_torch_model.read_bytes() + + model_key = ModelKey(key=key, descriptor=feature_store.descriptor) + request = InferenceRequest(model_key=model_key) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_model(batch, {fsd: feature_store}) + assert fetch_result.model_bytes + assert fetch_result.model_bytes == persist_torch_model.read_bytes() + + +def test_fetch_model_feature_store_missing() -> None: + """Verify that the ML worker fails to retrieves a model + when given an invalid (feature store) key""" + worker = MachineLearningWorkerCore + + key = "some-key" + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + + model_key = ModelKey(key=key, descriptor=feature_store.descriptor) + request = InferenceRequest(model_key=model_key) + batch = RequestBatch([request], None, model_key) + + # todo: consider that raising this exception shows impl. replace... + with pytest.raises(sse.SmartSimError) as ex: + worker.fetch_model(batch, {fsd: feature_store}) + + # ensure the error message includes key-identifying information + assert key in ex.value.args[0] + + +@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +def test_fetch_model_memory(persist_torch_model: pathlib.Path) -> None: + """Verify that the ML worker successfully retrieves a model + when given a valid (file system) key""" + worker = MachineLearningWorkerCore + + key = "test-model" + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + feature_store[key] = persist_torch_model.read_bytes() + + model_key = ModelKey(key=key, descriptor=feature_store.descriptor) + request = InferenceRequest(model_key=model_key) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_model(batch, {fsd: feature_store}) + assert fetch_result.model_bytes + assert fetch_result.model_bytes == persist_torch_model.read_bytes() + + +@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +def test_fetch_input_disk(persist_torch_tensor: pathlib.Path) -> None: + """Verify that the ML worker successfully retrieves a tensor/input + when given a valid (file system) key""" + tensor_name = str(persist_torch_tensor) + + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + request = InferenceRequest(input_keys=[TensorKey(key=tensor_name, descriptor=fsd)]) + + model_key = ModelKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + worker = MachineLearningWorkerCore + + feature_store[tensor_name] = persist_torch_tensor.read_bytes() + + fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) + assert fetch_result[0].inputs is not None + + +def test_fetch_input_disk_missing() -> None: + """Verify that the ML worker fails to retrieves a tensor/input + when given an invalid (file system) key""" + worker = MachineLearningWorkerCore + + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + key = "/path/that/doesnt/exist" + + request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) + + model_key = ModelKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + with pytest.raises(sse.SmartSimError) as ex: + worker.fetch_inputs(batch, {fsd: feature_store}) + + # ensure the error message includes key-identifying information + assert key[0] 
in ex.value.args[0] + + +@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +def test_fetch_input_feature_store(persist_torch_tensor: pathlib.Path) -> None: + """Verify that the ML worker successfully retrieves a tensor/input + when given a valid (feature store) key""" + worker = MachineLearningWorkerCore + + tensor_name = "test-tensor" + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + + request = InferenceRequest(input_keys=[TensorKey(key=tensor_name, descriptor=fsd)]) + + # put model bytes into the feature store + feature_store[tensor_name] = persist_torch_tensor.read_bytes() + + model_key = ModelKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) + assert fetch_result[0].inputs + assert ( + list(fetch_result[0].inputs)[0][:10] == persist_torch_tensor.read_bytes()[:10] + ) + + +@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +def test_fetch_multi_input_feature_store(persist_torch_tensor: pathlib.Path) -> None: + """Verify that the ML worker successfully retrieves multiple tensor/input + when given a valid collection of (feature store) keys""" + worker = MachineLearningWorkerCore + + tensor_name = "test-tensor" + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + + # put model bytes into the feature store + body1 = persist_torch_tensor.read_bytes() + feature_store[tensor_name + "1"] = body1 + + body2 = b"abcdefghijklmnopqrstuvwxyz" + feature_store[tensor_name + "2"] = body2 + + body3 = b"mnopqrstuvwxyzabcdefghijkl" + feature_store[tensor_name + "3"] = body3 + + request = InferenceRequest( + input_keys=[ + TensorKey(key=tensor_name + "1", descriptor=fsd), + TensorKey(key=tensor_name + "2", descriptor=fsd), + TensorKey(key=tensor_name + "3", descriptor=fsd), + ] + ) + + model_key = ModelKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) + + raw_bytes = list(fetch_result[0].inputs) + assert raw_bytes + assert raw_bytes[0][:10] == persist_torch_tensor.read_bytes()[:10] + assert raw_bytes[1][:10] == body2[:10] + assert raw_bytes[2][:10] == body3[:10] + + +def test_fetch_input_feature_store_missing() -> None: + """Verify that the ML worker fails to retrieves a tensor/input + when given an invalid (feature store) key""" + worker = MachineLearningWorkerCore + + key = "bad-key" + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) + + model_key = ModelKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + with pytest.raises(sse.SmartSimError) as ex: + worker.fetch_inputs(batch, {fsd: feature_store}) + + # ensure the error message includes key-identifying information + assert key in ex.value.args[0] + + +@pytest.mark.skipif(not torch_available, reason="Torch backend is not installed") +def test_fetch_input_memory(persist_torch_tensor: pathlib.Path) -> None: + """Verify that the ML worker successfully retrieves a tensor/input + when given a valid (file system) key""" + worker = MachineLearningWorkerCore + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + + key = "test-model" + feature_store[key] = persist_torch_tensor.read_bytes() + request = InferenceRequest(input_keys=[TensorKey(key=key, descriptor=fsd)]) + + model_key = 
ModelKey(key="test-model", descriptor=fsd) + batch = RequestBatch([request], None, model_key) + + fetch_result = worker.fetch_inputs(batch, {fsd: feature_store}) + assert fetch_result[0].inputs is not None + + +def test_place_outputs() -> None: + """Verify outputs are shared using the feature store""" + worker = MachineLearningWorkerCore + + key_name = "test-model" + feature_store = MemoryFeatureStore() + fsd = feature_store.descriptor + + # create a key to retrieve from the feature store + keys = [ + TensorKey(key=key_name + "1", descriptor=fsd), + TensorKey(key=key_name + "2", descriptor=fsd), + TensorKey(key=key_name + "3", descriptor=fsd), + ] + data = [b"abcdef", b"ghijkl", b"mnopqr"] + + for fsk, v in zip(keys, data): + feature_store[fsk.key] = v + + request = InferenceRequest(output_keys=keys) + transform_result = TransformOutputResult(data, [1], "c", "float32") + + worker.place_output(request, transform_result, {fsd: feature_store}) + + for i in range(3): + assert feature_store[keys[i].key] == data[i] + + +@pytest.mark.parametrize( + "key, descriptor", + [ + pytest.param("", "desc", id="invalid key"), + pytest.param("key", "", id="invalid descriptor"), + ], +) +def test_invalid_tensorkey(key, descriptor) -> None: + with pytest.raises(ValueError): + fsk = TensorKey(key, descriptor) diff --git a/tests/dragon/test_device_manager.py b/tests/dragon/test_device_manager.py new file mode 100644 index 0000000000..6b22c8bd66 --- /dev/null +++ b/tests/dragon/test_device_manager.py @@ -0,0 +1,186 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
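The fetch tests above repeat one pattern: wrap a key in a `ModelKey`, attach it to an `InferenceRequest`, batch it, and resolve it against a `{descriptor: feature_store}` map. A condensed sketch of that flow (not part of the patch; the key and payload are illustrative):

```python
# Sketch: fetch a model through MachineLearningWorkerCore using a test
# feature store keyed by its descriptor.
from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey
from smartsim._core.mli.infrastructure.worker.worker import (
    InferenceRequest,
    MachineLearningWorkerCore,
    RequestBatch,
)

from tests.dragon.feature_store import MemoryFeatureStore

feature_store = MemoryFeatureStore()
feature_store["test-model"] = b"serialized-model-bytes"  # illustrative payload

model_key = ModelKey(key="test-model", descriptor=feature_store.descriptor)
request = InferenceRequest(model_key=model_key)
batch = RequestBatch([request], None, model_key)

fetch_result = MachineLearningWorkerCore.fetch_model(
    batch, {feature_store.descriptor: feature_store}
)
assert fetch_result.model_bytes == b"serialized-model-bytes"
```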
+ +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.control.device_manager import ( + DeviceManager, + WorkerDevice, +) +from smartsim._core.mli.infrastructure.storage.feature_store import ( + FeatureStore, + ModelKey, + TensorKey, +) +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + MachineLearningWorkerBase, + RequestBatch, + TransformInputResult, + TransformOutputResult, +) + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +class MockWorker(MachineLearningWorkerBase): + @staticmethod + def fetch_model( + batch: RequestBatch, feature_stores: t.Dict[str, FeatureStore] + ) -> FetchModelResult: + if batch.has_raw_model: + return FetchModelResult(batch.raw_model) + return FetchModelResult(b"fetched_model") + + @staticmethod + def load_model( + batch: RequestBatch, fetch_result: FetchModelResult, device: str + ) -> LoadModelResult: + return LoadModelResult(fetch_result.model_bytes) + + @staticmethod + def transform_input( + batch: RequestBatch, + fetch_results: list[FetchInputResult], + mem_pool: "MemoryPool", + ) -> TransformInputResult: + return TransformInputResult(b"result", [slice(0, 1)], [[1, 2]], ["float32"]) + + @staticmethod + def execute( + batch: RequestBatch, + load_result: LoadModelResult, + transform_result: TransformInputResult, + device: str, + ) -> ExecuteResult: + return ExecuteResult(b"result", [slice(0, 1)]) + + @staticmethod + def transform_output( + batch: RequestBatch, execute_result: ExecuteResult + ) -> t.List[TransformOutputResult]: + return [TransformOutputResult(b"result", None, "c", "float32")] + + +def test_worker_device(): + worker_device = WorkerDevice("gpu:0") + assert worker_device.name == "gpu:0" + + model_key = "my_model_key" + model = b"the model" + + worker_device.add_model(model_key, model) + + assert model_key in worker_device + assert worker_device.get_model(model_key) == model + worker_device.remove_model(model_key) + + assert model_key not in worker_device + + +def test_device_manager_model_in_request(): + + worker_device = WorkerDevice("gpu:0") + device_manager = DeviceManager(worker_device) + + worker = MockWorker() + + tensor_key = TensorKey(key="key", descriptor="desc") + output_key = TensorKey(key="key", descriptor="desc") + model_key = ModelKey(key="model key", descriptor="desc") + + request = InferenceRequest( + model_key=model_key, + callback=None, + raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=b"raw model", + batch_size=0, + ) + + request_batch = RequestBatch( + [request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_key, + ) + + with device_manager.get_device( + worker=worker, batch=request_batch, feature_stores={} + ) as returned_device: + + assert returned_device == worker_device + assert worker_device.get_model(model_key.key).model == b"raw model" + + assert model_key.key not in worker_device + + +def test_device_manager_model_key(): + + worker_device = WorkerDevice("gpu:0") + device_manager = DeviceManager(worker_device) + + worker = MockWorker() + + tensor_key = TensorKey(key="key", descriptor="desc") + output_key = TensorKey(key="key", descriptor="desc") + model_key = ModelKey(key="model key", descriptor="desc") + + request = InferenceRequest( + model_key=model_key, + callback=None, + raw_inputs=None, + 
input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=None, + batch_size=0, + ) + + request_batch = RequestBatch( + [request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_key, + ) + + with device_manager.get_device( + worker=worker, batch=request_batch, feature_stores={} + ) as returned_device: + + assert returned_device == worker_device + assert worker_device.get_model(model_key.key).model == b"fetched_model" + + assert model_key.key in worker_device diff --git a/tests/dragon/test_dragon_backend.py b/tests/dragon/test_dragon_backend.py new file mode 100644 index 0000000000..2b2ef50f99 --- /dev/null +++ b/tests/dragon/test_dragon_backend.py @@ -0,0 +1,307 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
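The device-manager tests above verify that `DeviceManager.get_device` stages the batch's model on a `WorkerDevice` and evicts models that arrived as raw in-request payloads when the context exits. A small sketch of the `WorkerDevice` surface they exercise (not part of the patch; key and payload are illustrative):

```python
# Sketch: stage, look up, and evict a model on a WorkerDevice, then wrap the
# device in a DeviceManager as the tests do.
from smartsim._core.mli.infrastructure.control.device_manager import (
    DeviceManager,
    WorkerDevice,
)

device = WorkerDevice("gpu:0")
device.add_model("my-model", b"model-bytes")  # stage a model manually
assert "my-model" in device
assert device.get_model("my-model") == b"model-bytes"
device.remove_model("my-model")

manager = DeviceManager(device)
# with manager.get_device(worker=worker, batch=request_batch, feature_stores={}) as dev:
#     ...  # dev has request_batch's model loaded for the duration of the block
```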
+ +import os +import time +import uuid + +import pytest + +dragon = pytest.importorskip("dragon") + + +from smartsim._core.launcher.dragon.dragonBackend import DragonBackend +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.infrastructure.comm.event import ( + OnCreateConsumer, + OnShutdownRequested, +) +from smartsim._core.mli.infrastructure.control.listener import ( + ConsumerRegistrationListener, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.log import get_logger + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon +logger = get_logger(__name__) + + +@pytest.fixture(scope="module") +def the_backend() -> DragonBackend: + return DragonBackend(pid=9999) + + +def test_dragonbackend_start_listener(the_backend: DragonBackend): + """Verify the background process listening to consumer registration events + is up and processing messages as expected.""" + + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor + + with pytest.raises(KeyError) as ex: + # we expect the value of the consumer to be empty until + # the listener start-up completes. + backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + + assert "not found" in ex.value.args[0] + + drg_process = the_backend.start_event_listener(cpu_affinity=[], gpu_affinity=[]) + + # # confirm there is a process still running + logger.info(f"Dragon process started: {drg_process}") + assert drg_process is not None, "Backend was unable to start event listener" + assert drg_process.puid != 0, "Process unique ID is empty" + assert drg_process.returncode is None, "Listener terminated early" + + # wait for the event listener to come up + try: + config = backbone.wait_for( + [BackboneFeatureStore.MLI_REGISTRAR_CONSUMER], timeout=30 + ) + # verify result was in the returned configuration map + assert config[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + except Exception: + raise KeyError( + f"Unable to locate {BackboneFeatureStore.MLI_REGISTRAR_CONSUMER}" + "in the backbone" + ) + + # wait_for ensures the normal retrieval will now work, error-free + descriptor = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + assert descriptor is not None + + # register a new listener channel + comm_channel = DragonCommChannel.from_descriptor(descriptor) + mock_descriptor = str(uuid.uuid4()) + event = OnCreateConsumer("test_dragonbackend_start_listener", mock_descriptor, []) + + event_bytes = bytes(event) + comm_channel.send(event_bytes) + + subscriber_list = [] + + # Give the channel time to write the message and the listener time to handle it + for i in range(20): + time.sleep(1) + # Retrieve the subscriber list from the backbone and verify it is updated + if subscriber_list := backbone.notification_channels: + logger.debug(f"The subscriber list was populated after {i} iterations") + break + + assert mock_descriptor in subscriber_list + + # now send a shutdown message to terminate the listener + return_code = drg_process.returncode + + # clean up if the OnShutdownRequested wasn't properly handled + if return_code is None and drg_process.is_alive: + drg_process.kill() + drg_process.join() + + +def test_dragonbackend_backend_consumer(the_backend: DragonBackend): + 
"""Verify the listener background process updates the appropriate + value in the backbone.""" + + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + assert backbone._allow_reserved_writes + + # create listener with `as_service=False` to perform a single loop iteration + listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=False) + + logger.debug(f"backbone loaded? {listener._backbone}") + logger.debug(f"listener created? {listener}") + + try: + # call the service execute method directly to trigger + # the entire service lifecycle + listener.execute() + + consumer_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + logger.debug(f"MLI_REGISTRAR_CONSUMER: {consumer_desc}") + + assert consumer_desc + except Exception as ex: + logger.info("") + finally: + listener._on_shutdown() + + +def test_dragonbackend_event_handled(the_backend: DragonBackend): + """Verify the event listener process updates the appropriate + value in the backbone when an event is received and again on shutdown. + """ + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + # create the listener to be tested + listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=False) + + assert listener._backbone, "The listener is not attached to a backbone" + + try: + # set up the listener but don't let the service event loop start + listener._create_eventing() # listener.execute() + + # grab the channel descriptor so we can simulate registrations + channel_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + comm_channel = DragonCommChannel.from_descriptor(channel_desc) + + num_events = 5 + events = [] + for i in range(num_events): + # register some mock consumers using the backend channel + event = OnCreateConsumer( + "test_dragonbackend_event_handled", + f"mock-consumer-descriptor-{uuid.uuid4()}", + [], + ) + event_bytes = bytes(event) + comm_channel.send(event_bytes) + events.append(event) + + # run few iterations of the event loop in case it takes a few cycles to write + for _ in range(20): + listener._on_iteration() + # Grab the value that should be getting updated + notify_consumers = set(backbone.notification_channels) + if len(notify_consumers) == len(events): + logger.info(f"Retrieved all consumers after {i} listen cycles") + break + + # ... 
and confirm that all the mock consumer descriptors are registered + assert set([e.descriptor for e in events]) == set(notify_consumers) + logger.info(f"Number of registered consumers: {len(notify_consumers)}") + + except Exception as ex: + logger.exception(f"test_dragonbackend_event_handled - exception occurred: {ex}") + assert False + finally: + # shutdown should unregister a registration listener + listener._on_shutdown() + + for i in range(10): + if BackboneFeatureStore.MLI_REGISTRAR_CONSUMER not in backbone: + logger.debug(f"The listener was removed after {i} iterations") + channel_desc = None + break + + # we should see that there is no listener registered + assert not channel_desc, "Listener shutdown failed to clean up the backbone" + + +def test_dragonbackend_shutdown_event(the_backend: DragonBackend): + """Verify the background process shuts down when it receives a + shutdown request.""" + + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + listener = ConsumerRegistrationListener(backbone, 1.0, 1.0, as_service=True) + + # set up the listener but don't let the listener loop start + listener._create_eventing() # listener.execute() + + # grab the channel descriptor so we can publish to it + channel_desc = backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] + comm_channel = DragonCommChannel.from_descriptor(channel_desc) + + assert listener._consumer.listening, "Listener isn't ready to listen" + + # send a shutdown request... + event = OnShutdownRequested("test_dragonbackend_shutdown_event") + event_bytes = bytes(event) + comm_channel.send(event_bytes, 0.1) + + # execute should encounter the shutdown and exit + listener.execute() + + # ...and confirm the listener is now cancelled + assert not listener._consumer.listening + + +@pytest.mark.parametrize("health_check_frequency", [10, 20]) +def test_dragonbackend_shutdown_on_health_check( + the_backend: DragonBackend, + health_check_frequency: float, +): + """Verify that the event listener automatically shuts down when + a new listener is registered in its place. 
+ + :param health_check_frequency: The expected frequency of service health check + invocations""" + + # We need to let the backend create the backbone to continue + backbone = the_backend._create_backbone() + backbone.pop(BackboneFeatureStore.MLI_NOTIFY_CONSUMERS) + backbone.pop(BackboneFeatureStore.MLI_REGISTRAR_CONSUMER) + + listener = ConsumerRegistrationListener( + backbone, + 1.0, + 1.0, + as_service=True, # allow service to run long enough to health check + health_check_frequency=health_check_frequency, + ) + + # set up the listener but don't let the listener loop start + listener._create_eventing() # listener.execute() + assert listener._consumer.listening, "Listener wasn't ready to listen" + + # Replace the consumer descriptor in the backbone to trigger + # an automatic shutdown + backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = str(uuid.uuid4()) + + # set the last health check manually to verify the duration + start_at = time.time() + listener._last_health_check = time.time() + + # run execute to let the service trigger health checks + listener.execute() + elapsed = time.time() - start_at + + # confirm the frequency of the health check was honored + assert elapsed >= health_check_frequency + + # ...and confirm the listener is now cancelled + assert ( + not listener._consumer.listening + ), "Listener was not automatically shutdown by the health check" diff --git a/tests/dragon/test_dragon_ddict_utils.py b/tests/dragon/test_dragon_ddict_utils.py new file mode 100644 index 0000000000..c8bf687ef1 --- /dev/null +++ b/tests/dragon/test_dragon_ddict_utils.py @@ -0,0 +1,117 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
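The listener tests above all follow the same registration handshake: the backend publishes the registrar channel descriptor into the backbone, and a consumer announces itself by sending an `OnCreateConsumer` event to that channel. A sketch of that handshake from the consumer side (not part of the patch; the function name and timeout are illustrative):

```python
# Sketch: wait for the registrar channel to appear in the backbone, then
# register a consumer descriptor with the listener.
from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
from smartsim._core.mli.infrastructure.comm.event import OnCreateConsumer
from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
    BackboneFeatureStore,
)


def register_consumer(backbone: BackboneFeatureStore, my_descriptor: str) -> None:
    # block until the listener has published its channel (illustrative 30s timeout)
    config = backbone.wait_for(
        [BackboneFeatureStore.MLI_REGISTRAR_CONSUMER], timeout=30
    )
    registrar = DragonCommChannel.from_descriptor(
        config[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER]
    )
    event = OnCreateConsumer("register_consumer", my_descriptor, [])
    registrar.send(bytes(event))
```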
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+# isort: off
+import dragon.data.ddict.ddict as dragon_ddict
+
+# isort: on
+
+from smartsim._core.mli.infrastructure.storage import dragon_util
+from smartsim.log import get_logger
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+logger = get_logger(__name__)
+
+
+@pytest.mark.parametrize(
+    "num_nodes, num_managers, mem_per_node",
+    [
+        pytest.param(1, 1, 3 * 1024**2, id="3MB, Bare minimum allocation"),
+        pytest.param(2, 2, 128 * 1024**2, id="128 MB allocation, 2 nodes, 2 mgr"),
+        pytest.param(2, 1, 512 * 1024**2, id="512 MB allocation, 2 nodes, 1 mgr"),
+    ],
+)
+def test_dragon_storage_util_create_ddict(
+    num_nodes: int,
+    num_managers: int,
+    mem_per_node: int,
+):
+    """Verify that a dragon dictionary is successfully created.
+
+    :param num_nodes: Number of ddict nodes to attempt to create
+    :param num_managers: Number of managers per node to request
+    :param mem_per_node: Memory to allocate per node
+    """
+    ddict = dragon_util.create_ddict(num_nodes, num_managers, mem_per_node)
+
+    assert ddict is not None
+
+
+@pytest.mark.parametrize(
+    "num_nodes, num_managers, mem_per_node",
+    [
+        pytest.param(-1, 1, 3 * 1024**2, id="Negative Node Count"),
+        pytest.param(0, 1, 3 * 1024**2, id="Invalid Node Count"),
+        pytest.param(1, -1, 3 * 1024**2, id="Negative Mgr Count"),
+        pytest.param(1, 0, 3 * 1024**2, id="Invalid Mgr Count"),
+        pytest.param(1, 1, -3 * 1024**2, id="Negative Mem Per Node"),
+        pytest.param(1, 1, (3 * 1024**2) - 1, id="Invalid Mem Per Node"),
+        pytest.param(1, 1, 0 * 1024**2, id="No Mem Per Node"),
+    ],
+)
+def test_dragon_storage_util_create_ddict_validators(
+    num_nodes: int,
+    num_managers: int,
+    mem_per_node: int,
+):
+    """Verify that invalid arguments to ddict creation raise a ValueError.
+
+    :param num_nodes: Number of ddict nodes to attempt to create
+    :param num_managers: Number of managers per node to request
+    :param mem_per_node: Memory to allocate per node
+    """
+    with pytest.raises(ValueError):
+        dragon_util.create_ddict(num_nodes, num_managers, mem_per_node)
+
+
+def test_dragon_storage_util_get_ddict_descriptor(the_storage: dragon_ddict.DDict):
+    """Verify that a descriptor is created.
+
+    :param the_storage: A pre-allocated ddict
+    """
+    value = dragon_util.ddict_to_descriptor(the_storage)
+
+    assert isinstance(value, str)
+    assert len(value) > 0
+
+
+def test_dragon_storage_util_get_ddict_from_descriptor(the_storage: dragon_ddict.DDict):
+    """Verify that a ddict is created from a descriptor.
+
+    :param the_storage: A pre-allocated ddict
+    """
+    descriptor = dragon_util.ddict_to_descriptor(the_storage)
+
+    value = dragon_util.descriptor_to_ddict(descriptor)
+
+    assert value is not None
+    assert isinstance(value, dragon_ddict.DDict)
+    assert dragon_util.ddict_to_descriptor(value) == descriptor
diff --git a/tests/dragon/test_environment_loader.py b/tests/dragon/test_environment_loader.py
new file mode 100644
index 0000000000..07b2a45c1c
--- /dev/null
+++ b/tests/dragon/test_environment_loader.py
@@ -0,0 +1,147 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+import dragon.data.ddict.ddict as dragon_ddict
+import dragon.utils as du
+from dragon.fli import FLInterface
+
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel
+from smartsim._core.mli.comm.channel.dragon_util import create_local
+from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    DragonFeatureStore,
+)
+from smartsim.error.errors import SmartSimError
+
+# The tests in this file belong to the dragon group
+pytestmark = pytest.mark.dragon
+
+
+@pytest.mark.parametrize(
+    "content",
+    [
+        pytest.param(b"a"),
+        pytest.param(b"new byte string"),
+    ],
+)
+def test_environment_loader_attach_fli(content: bytes, monkeypatch: pytest.MonkeyPatch):
+    """A descriptor can be stored, loaded, and reattached."""
+    chan = create_local()
+    queue = FLInterface(main_ch=chan)
+    monkeypatch.setenv(
+        EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR,
+        du.B64.bytes_to_str(queue.serialize()),
+    )
+
+    config = EnvironmentConfigLoader(
+        featurestore_factory=DragonFeatureStore.from_descriptor,
+        callback_factory=DragonCommChannel.from_descriptor,
+        queue_factory=DragonFLIChannel.from_descriptor,
+    )
+    config_queue = config.get_queue()
+
+    _ = config_queue.send(content)
+
+    old_recv = queue.recvh()
+    result, _ = old_recv.recv_bytes()
+    assert result == content
+
+
+def test_environment_loader_serialize_fli(monkeypatch: pytest.MonkeyPatch):
+    """The serialized descriptors of the original queue and the
+    loaded queue are the same."""
+    chan = create_local()
+    queue = FLInterface(main_ch=chan)
+    monkeypatch.setenv(
+        EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR,
+        du.B64.bytes_to_str(queue.serialize()),
+    )
+
+    config = EnvironmentConfigLoader(
+        featurestore_factory=DragonFeatureStore.from_descriptor,
+        callback_factory=DragonCommChannel.from_descriptor,
+        queue_factory=DragonFLIChannel.from_descriptor,
+    )
+    config_queue = config.get_queue()
+    assert config_queue._fli.serialize() == queue.serialize()
+
+
+def test_environment_loader_flifails(monkeypatch: pytest.MonkeyPatch):
+    """An incorrect serialized descriptor will fail to attach."""
+
+    monkeypatch.setenv(EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR, "randomstring")
+
+    config = EnvironmentConfigLoader(
+        featurestore_factory=DragonFeatureStore.from_descriptor,
+        callback_factory=None,
+
queue_factory=DragonFLIChannel.from_descriptor, + ) + + with pytest.raises(SmartSimError): + config.get_queue() + + +def test_environment_loader_backbone_load_dfs( + monkeypatch: pytest.MonkeyPatch, the_storage: dragon_ddict.DDict +): + """Verify the dragon feature store is loaded correctly by the + EnvironmentConfigLoader to demonstrate featurestore_factory correctness.""" + feature_store = DragonFeatureStore(the_storage) + monkeypatch.setenv( + EnvironmentConfigLoader.BACKBONE_ENV_VAR, feature_store.descriptor + ) + + config = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=None, + queue_factory=None, + ) + + print(f"calling config.get_backbone: `{feature_store.descriptor}`") + + backbone = config.get_backbone() + assert backbone is not None + + +def test_environment_variables_not_set(monkeypatch: pytest.MonkeyPatch): + """EnvironmentConfigLoader getters return None when environment + variables are not set.""" + with monkeypatch.context() as patch: + patch.setenv(EnvironmentConfigLoader.BACKBONE_ENV_VAR, "") + patch.setenv(EnvironmentConfigLoader.REQUEST_QUEUE_ENV_VAR, "") + + config = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=DragonCommChannel.from_descriptor, + queue_factory=DragonCommChannel.from_descriptor, + ) + assert config.get_backbone() is None + assert config.get_queue() is None diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py new file mode 100644 index 0000000000..aacd47b556 --- /dev/null +++ b/tests/dragon/test_error_handling.py @@ -0,0 +1,511 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
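+# The fixtures and tests below verify that failures raised inside worker manager
+# and request dispatcher pipeline stages are routed through exception_handler and
+# reported as failure replies instead of crashing the services.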
+ +import typing as t +from unittest.mock import MagicMock + +import pytest + +dragon = pytest.importorskip("dragon") + +import multiprocessing as mp + +from dragon.channels import Channel +from dragon.data.ddict.ddict import DDict +from dragon.fli import FLInterface +from dragon.mpbridge.queues import DragonQueue + +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.infrastructure.control.request_dispatcher import ( + RequestDispatcher, +) +from smartsim._core.mli.infrastructure.control.worker_manager import ( + WorkerManager, + exception_handler, +) +from smartsim._core.mli.infrastructure.environment_loader import EnvironmentConfigLoader +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.feature_store import ( + FeatureStore, + ModelKey, + TensorKey, +) +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + MachineLearningWorkerBase, + RequestBatch, + TransformInputResult, + TransformOutputResult, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim._core.mli.mli_schemas.response.response_capnp import ResponseBuilder + +from .utils.channel import FileSystemCommChannel +from .utils.worker import IntegratedTorchWorker + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +@pytest.fixture(scope="module") +def app_feature_store(the_storage) -> FeatureStore: + # create a standalone feature store to mimic a user application putting + # data into an application-owned resource (app should not access backbone) + app_fs = DragonFeatureStore(the_storage) + return app_fs + + +@pytest.fixture +def setup_worker_manager_model_bytes( + test_dir: str, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, +): + integrated_worker_type = IntegratedTorchWorker + + monkeypatch.setenv( + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0) + + worker_manager = WorkerManager( + config_loader=config_loader, + worker_type=integrated_worker_type, + dispatcher_queue=dispatcher_task_queue, + as_service=False, + cooldown=3, + ) + + tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + + inf_request = InferenceRequest( + model_key=None, + callback=None, + raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=b"model", + batch_size=0, + ) + + model_id = ModelKey(key="key", descriptor=app_feature_store.descriptor) + + request_batch = RequestBatch( + [inf_request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + 
model_id=model_id, + ) + + dispatcher_task_queue.put(request_batch) + return worker_manager, integrated_worker_type + + +@pytest.fixture +def setup_worker_manager_model_key( + test_dir: str, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, +): + integrated_worker_type = IntegratedTorchWorker + + monkeypatch.setenv( + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + dispatcher_task_queue: mp.Queue[RequestBatch] = mp.Queue(maxsize=0) + + worker_manager = WorkerManager( + config_loader=config_loader, + worker_type=integrated_worker_type, + dispatcher_queue=dispatcher_task_queue, + as_service=False, + cooldown=3, + ) + + tensor_key = TensorKey(key="key", descriptor=app_feature_store.descriptor) + output_key = TensorKey(key="key", descriptor=app_feature_store.descriptor) + model_id = ModelKey(key="model key", descriptor=app_feature_store.descriptor) + + request = InferenceRequest( + model_key=model_id, + callback=None, + raw_inputs=None, + input_keys=[tensor_key], + input_meta=None, + output_keys=[output_key], + raw_model=b"model", + batch_size=0, + ) + request_batch = RequestBatch( + [request], + TransformInputResult(b"transformed", [slice(0, 1)], [[1, 2]], ["float32"]), + model_id=model_id, + ) + + dispatcher_task_queue.put(request_batch) + return worker_manager, integrated_worker_type + + +@pytest.fixture +def setup_request_dispatcher_model_bytes( + test_dir: str, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, +): + integrated_worker_type = IntegratedTorchWorker + + monkeypatch.setenv( + BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + request_dispatcher = RequestDispatcher( + batch_timeout=0, + batch_size=0, + config_loader=config_loader, + worker_type=integrated_worker_type, + ) + request_dispatcher._on_start() + + tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + model = MessageHandler.build_model(b"model", "model name", "v 0.0.1") + request = MessageHandler.build_request( + test_dir, model, [tensor_key], [output_key], [], None + ) + ser_request = MessageHandler.serialize_request(request) + + request_dispatcher._incoming_channel.send(ser_request) + + return request_dispatcher, integrated_worker_type + + +@pytest.fixture +def setup_request_dispatcher_model_key( + test_dir: str, + monkeypatch: pytest.MonkeyPatch, + backbone_descriptor: str, + app_feature_store: FeatureStore, + the_worker_channel: DragonFLIChannel, +): + integrated_worker_type = IntegratedTorchWorker + + monkeypatch.setenv( + 
BackboneFeatureStore.MLI_WORKER_QUEUE, the_worker_channel.descriptor + ) + # Put backbone descriptor into env var for the `EnvironmentConfigLoader` + monkeypatch.setenv(BackboneFeatureStore.MLI_BACKBONE, backbone_descriptor) + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + request_dispatcher = RequestDispatcher( + batch_timeout=0, + batch_size=0, + config_loader=config_loader, + worker_type=integrated_worker_type, + ) + request_dispatcher._on_start() + + tensor_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + output_key = MessageHandler.build_tensor_key("key", app_feature_store.descriptor) + model_key = MessageHandler.build_model_key( + key="model key", descriptor=app_feature_store.descriptor + ) + request = MessageHandler.build_request( + test_dir, model_key, [tensor_key], [output_key], [], None + ) + ser_request = MessageHandler.serialize_request(request) + + request_dispatcher._incoming_channel.send(ser_request) + + return request_dispatcher, integrated_worker_type + + +def mock_pipeline_stage( + monkeypatch: pytest.MonkeyPatch, + integrated_worker: MachineLearningWorkerBase, + stage: str, +) -> t.Callable[[t.Any], ResponseBuilder]: + def mock_stage(*args: t.Any, **kwargs: t.Any) -> None: + raise ValueError(f"Simulated error in {stage}") + + monkeypatch.setattr(integrated_worker, stage, mock_stage) + mock_reply_fn = MagicMock() + mock_response = MagicMock() + mock_response.schema.node.displayName = "Response" + mock_reply_fn.return_value = mock_response + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.error_handling.build_failure_reply", + mock_reply_fn, + ) + + mock_reply_channel = MagicMock() + mock_reply_channel.send = MagicMock() + + def mock_exception_handler( + exc: Exception, reply_channel: CommChannelBase, failure_message: str + ) -> None: + exception_handler(exc, mock_reply_channel, failure_message) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.worker_manager.exception_handler", + mock_exception_handler, + ) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.request_dispatcher.exception_handler", + mock_exception_handler, + ) + + return mock_reply_fn + + +@pytest.mark.parametrize( + "setup_worker_manager", + [ + pytest.param("setup_worker_manager_model_bytes"), + pytest.param("setup_worker_manager_model_key"), + ], +) +@pytest.mark.parametrize( + "stage, error_message", + [ + pytest.param( + "fetch_model", + "Error loading model on device or getting device.", + id="fetch model", + ), + pytest.param( + "load_model", + "Error loading model on device or getting device.", + id="load model", + ), + pytest.param("execute", "Error while executing.", id="execute"), + pytest.param( + "transform_output", + "Error while transforming the output.", + id="transform output", + ), + pytest.param( + "place_output", "Error while placing the output.", id="place output" + ), + ], +) +def test_wm_pipeline_stage_errors_handled( + request: pytest.FixtureRequest, + setup_worker_manager: str, + monkeypatch: pytest.MonkeyPatch, + stage: str, + error_message: str, +) -> None: + """Ensures that the worker manager does not crash after a failure in various pipeline stages""" + worker_manager, integrated_worker_type = request.getfixturevalue( + setup_worker_manager + ) + integrated_worker = worker_manager._worker + + worker_manager._on_start() + 
device = worker_manager._device_manager._device + mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, stage) + + if stage not in ["fetch_model"]: + monkeypatch.setattr( + integrated_worker, + "fetch_model", + MagicMock(return_value=FetchModelResult(b"result_bytes")), + ) + if stage not in ["fetch_model", "load_model"]: + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + monkeypatch.setattr( + device, + "get_model", + MagicMock(return_value=b"result_bytes"), + ) + if stage not in [ + "fetch_model", + "execute", + ]: + monkeypatch.setattr( + integrated_worker, + "execute", + MagicMock(return_value=ExecuteResult(b"result_bytes", [slice(0, 1)])), + ) + if stage not in [ + "fetch_model", + "execute", + "transform_output", + ]: + monkeypatch.setattr( + integrated_worker, + "transform_output", + MagicMock( + return_value=[TransformOutputResult(b"result", [], "c", "float32")] + ), + ) + + worker_manager._on_iteration() + + mock_reply_fn.assert_called_once() + mock_reply_fn.assert_called_with("fail", error_message) + + +@pytest.mark.parametrize( + "setup_request_dispatcher", + [ + pytest.param("setup_request_dispatcher_model_bytes"), + pytest.param("setup_request_dispatcher_model_key"), + ], +) +@pytest.mark.parametrize( + "stage, error_message", + [ + pytest.param( + "fetch_inputs", + "Error fetching input.", + id="fetch input", + ), + pytest.param( + "transform_input", + "Error transforming input.", + id="transform input", + ), + ], +) +def test_dispatcher_pipeline_stage_errors_handled( + request: pytest.FixtureRequest, + setup_request_dispatcher: str, + monkeypatch: pytest.MonkeyPatch, + stage: str, + error_message: str, +) -> None: + """Ensures that the request dispatcher does not crash after a failure in various pipeline stages""" + request_dispatcher, integrated_worker_type = request.getfixturevalue( + setup_request_dispatcher + ) + integrated_worker = request_dispatcher._worker + + mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, stage) + + if stage not in ["fetch_inputs"]: + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=[FetchInputResult(result=[b"result"], meta=None)]), + ) + + request_dispatcher._on_iteration() + + mock_reply_fn.assert_called_once() + mock_reply_fn.assert_called_with("fail", error_message) + + +def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch) -> None: + """Ensures that the worker manager does not crash after a failure in the + execute pipeline stage""" + + mock_reply_channel = MagicMock() + mock_reply_channel.send = MagicMock() + + mock_reply_fn = MagicMock() + + mock_response = MagicMock() + mock_response.schema.node.displayName = "Response" + mock_reply_fn.return_value = mock_response + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.error_handling.build_failure_reply", + mock_reply_fn, + ) + + test_exception = ValueError("Test ValueError") + exception_handler( + test_exception, mock_reply_channel, "Failure while fetching the model." 
+ ) + + mock_reply_fn.assert_called_once() + mock_reply_fn.assert_called_with("fail", "Failure while fetching the model.") + + +def test_dragon_feature_store_invalid_storage(): + """Verify that attempting to create a DragonFeatureStore without storage fails.""" + storage = None + + with pytest.raises(ValueError) as ex: + DragonFeatureStore(storage) + + assert "storage" in ex.value.args[0].lower() + assert "required" in ex.value.args[0].lower() diff --git a/tests/dragon/test_event_consumer.py b/tests/dragon/test_event_consumer.py new file mode 100644 index 0000000000..8a241bab19 --- /dev/null +++ b/tests/dragon/test_event_consumer.py @@ -0,0 +1,386 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import time +import typing as t +from unittest import mock + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster +from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer +from smartsim._core.mli.infrastructure.comm.event import ( + OnCreateConsumer, + OnShutdownRequested, + OnWriteFeatureStore, +) +from smartsim._core.mli.infrastructure.control.listener import ( + ConsumerRegistrationListener, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.log import get_logger + +logger = get_logger(__name__) + +# isort: off +from dragon import fli +from dragon.channels import Channel + +# isort: on + +if t.TYPE_CHECKING: + import conftest + + +# The tests in this file must run in a dragon environment +pytestmark = pytest.mark.dragon + + +def test_eventconsumer_eventpublisher_integration( + the_backbone: t.Any, test_dir: str +) -> None: + """Verify that the publisher and consumer integrate as expected when + multiple publishers and consumers are sending simultaneously. This + test closely tracks the test in tests/test_featurestore_base.py also named + test_eventconsumer_eventpublisher_integration but requires dragon entities. 
+ + :param the_backbone: The BackboneFeatureStore to use + :param test_dir: Automatically generated unique working + directories for individual test outputs + """ + + wmgr_channel = DragonCommChannel(create_local()) + capp_channel = DragonCommChannel(create_local()) + back_channel = DragonCommChannel(create_local()) + + wmgr_consumer_descriptor = wmgr_channel.descriptor + capp_consumer_descriptor = capp_channel.descriptor + back_consumer_descriptor = back_channel.descriptor + + # create some consumers to receive messages + wmgr_consumer = EventConsumer( + wmgr_channel, + the_backbone, + filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN], + ) + capp_consumer = EventConsumer( + capp_channel, + the_backbone, + ) + back_consumer = EventConsumer( + back_channel, + the_backbone, + filters=[OnCreateConsumer.CONSUMER_CREATED], + ) + + # create some broadcasters to publish messages + mock_worker_mgr = EventBroadcaster( + the_backbone, + channel_factory=DragonCommChannel.from_descriptor, + ) + mock_client_app = EventBroadcaster( + the_backbone, + channel_factory=DragonCommChannel.from_descriptor, + ) + + # register all of the consumers even though the OnCreateConsumer really should + # trigger its registration. event processing is tested elsewhere. + the_backbone.notification_channels = [ + wmgr_consumer_descriptor, + capp_consumer_descriptor, + back_consumer_descriptor, + ] + + # simulate worker manager sending a notification to backend that it's alive + event_1 = OnCreateConsumer( + "test_eventconsumer_eventpublisher_integration", + wmgr_consumer_descriptor, + filters=[], + ) + mock_worker_mgr.send(event_1) + + # simulate the app updating a model a few times + for key in ["key-1", "key-2", "key-1"]: + event = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", + the_backbone.descriptor, + key, + ) + mock_client_app.send(event, timeout=0.1) + + # worker manager should only get updates about feature update + wmgr_messages = wmgr_consumer.recv() + assert len(wmgr_messages) == 3 + + # the backend should only receive messages about consumer creation + back_messages = back_consumer.recv() + assert len(back_messages) == 1 + + # hypothetical app has no filters and will get all events + app_messages = capp_consumer.recv() + assert len(app_messages) == 4 + + +@pytest.mark.parametrize( + " timeout, batch_timeout, exp_err_msg", + [(-1, 1, " timeout"), (1, -1, "batch_timeout")], +) +def test_eventconsumer_invalid_timeout( + timeout: float, + batch_timeout: float, + exp_err_msg: str, + test_dir: str, + the_backbone: BackboneFeatureStore, +) -> None: + """Verify that the event consumer raises an exception + when provided an invalid request timeout. 
+
+    :param timeout: The request timeout for the event consumer recv call
+    :param batch_timeout: The batch timeout for the event consumer recv call
+    :param exp_err_msg: A unique value from the error message that should be raised
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+        directories for individual test outputs
+    """
+
+    wmgr_channel = DragonCommChannel(create_local())
+
+    # create a consumer to receive messages
+    wmgr_consumer = EventConsumer(
+        wmgr_channel,
+        the_backbone,
+        filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+    )
+
+    # the consumer should report an error for the invalid timeout value
+    with pytest.raises(ValueError) as ex:
+        wmgr_consumer.recv(timeout=timeout, batch_timeout=batch_timeout)
+
+    assert exp_err_msg in ex.value.args[0]
+
+
+def test_eventconsumer_no_event_handler_registered(
+    the_backbone: t.Any, test_dir: str
+) -> None:
+    """Verify that a consumer discards messages received
+    on a channel when no handler is registered.
+
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+        directories for individual test outputs
+    """
+
+    wmgr_channel = DragonCommChannel(create_local())
+
+    # create a consumer to receive messages
+    wmgr_consumer = EventConsumer(wmgr_channel, the_backbone, event_handler=None)
+
+    # create a broadcaster to publish messages
+    mock_worker_mgr = EventBroadcaster(
+        the_backbone,
+        channel_factory=DragonCommChannel.from_descriptor,
+    )
+
+    # manually register the consumer since we don't have a backend running
+    the_backbone.notification_channels = [wmgr_channel.descriptor]
+
+    # simulate the app updating a model a few times
+    for key in ["key-1", "key-2", "key-1"]:
+        event = OnWriteFeatureStore(
+            "test_eventconsumer_no_event_handler_registered",
+            the_backbone.descriptor,
+            key,
+        )
+        mock_worker_mgr.send(event, timeout=0.1)
+
+    # run the handler and let it discard messages
+    for _ in range(15):
+        wmgr_consumer.listen_once(0.2, 2.0)
+
+    assert wmgr_consumer.listening
+
+
+def test_eventconsumer_no_event_handler_registered_shutdown(
+    the_backbone: t.Any, test_dir: str
+) -> None:
+    """Verify that a consumer without an event handler
+    registered still honors shutdown requests.
+
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+        directories for individual test outputs
+    """
+
+    wmgr_channel = DragonCommChannel(create_local())
+    capp_channel = DragonCommChannel(create_local())
+
+    # create a consumer to receive messages
+    wmgr_consumer = EventConsumer(wmgr_channel, the_backbone)
+
+    # create a broadcaster to publish messages
+    mock_worker_mgr = EventBroadcaster(
+        the_backbone,
+        channel_factory=DragonCommChannel.from_descriptor,
+    )
+
+    # manually register the consumers since we don't have a backend running
+    the_backbone.notification_channels = [
+        wmgr_channel.descriptor,
+        capp_channel.descriptor,
+    ]
+
+    # simulate the app updating a model a few times
+    for key in ["key-1", "key-2", "key-1"]:
+        event = OnWriteFeatureStore(
+            "test_eventconsumer_no_event_handler_registered_shutdown",
+            the_backbone.descriptor,
+            key,
+        )
+        mock_worker_mgr.send(event, timeout=0.1)
+
+    event = OnShutdownRequested(
+        "test_eventconsumer_no_event_handler_registered_shutdown"
+    )
+    mock_worker_mgr.send(event, timeout=0.1)
+
+    # wmgr will stop listening to messages when it is told to stop listening
+    wmgr_consumer.listen(timeout=0.1, batch_timeout=2.0)
+
+    for _ in range(15):
+        wmgr_consumer.listen_once(timeout=0.1, batch_timeout=2.0)
+
+    # confirm the messages were processed, discarded, and the shutdown was received
+    assert wmgr_consumer.listening == False
+
+
+def test_eventconsumer_registration(
+    the_backbone: t.Any, test_dir: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Verify that a consumer is correctly registered in
+    the backbone after sending a registration request, then
+    confirm the consumer is unregistered after sending the
+    un-register request.
+
+    :param the_backbone: The BackboneFeatureStore to use
+    :param test_dir: Automatically generated unique working
+        directories for individual test outputs
+    """
+
+    with monkeypatch.context() as patch:
+        registrar = ConsumerRegistrationListener(
+            the_backbone, 1.0, 2.0, as_service=False
+        )
+
+        # NOTE: service.execute(as_service=False) will complete the service life-
+        # cycle and remove the registrar from the backbone, so mock _on_shutdown
+        disabled_shutdown = mock.MagicMock()
+        patch.setattr(registrar, "_on_shutdown", disabled_shutdown)
+
+        # initialize registrar resources
+        registrar.execute()
+
+        # create a consumer that will be registered
+        wmgr_channel = DragonCommChannel(create_local())
+        wmgr_consumer = EventConsumer(wmgr_channel, the_backbone)
+
+        registered_channels = the_backbone.notification_channels
+
+        # trigger the consumer-to-registrar handshake
+        wmgr_consumer.register()
+
+        current_registrations: t.List[str] = []
+
+        # have the registrar run a few times to pick up the msg
+        for i in range(15):
+            registrar.execute()
+            current_registrations = the_backbone.notification_channels
+            if len(current_registrations) != len(registered_channels):
+                logger.debug(f"The event was processed on iteration {i}")
+                break
+
+        # confirm the consumer is registered
+        assert wmgr_channel.descriptor in current_registrations
+
+        # copy old list so we can compare against it.
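+        # (the unregistration below should shrink the list relative to this snapshot)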
+ registered_channels = list(current_registrations) + + # trigger the consumer removal + wmgr_consumer.unregister() + + # have the registrar run a few times to pick up the msg + for i in range(15): + registrar.execute() + current_registrations = the_backbone.notification_channels + if len(current_registrations) != len(registered_channels): + logger.debug(f"The event was processed on iteration {i}") + break + + # confirm the consumer is no longer registered + assert wmgr_channel.descriptor not in current_registrations + + +def test_registrar_teardown( + the_backbone: t.Any, test_dir: str, monkeypatch: pytest.MonkeyPatch +) -> None: + """Verify that the consumer registrar removes itself from + the backbone when it shuts down. + + :param the_backbone: The BackboneFeatureStore to use + :param test_dir: Automatically generated unique working + directories for individual test outputs + """ + + with monkeypatch.context() as patch: + registrar = ConsumerRegistrationListener( + the_backbone, 1.0, 2.0, as_service=False + ) + + # directly initialze registrar resources to avoid service life-cycle + registrar._create_eventing() + + # confirm the registrar is published to the backbone + cfg = the_backbone.wait_for([BackboneFeatureStore.MLI_REGISTRAR_CONSUMER], 10) + assert BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in cfg + + # execute the entire service lifecycle 1x + registrar.execute() + + consumer_found = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in the_backbone + + for i in range(15): + time.sleep(0.1) + consumer_found = BackboneFeatureStore.MLI_REGISTRAR_CONSUMER in the_backbone + if not consumer_found: + logger.debug(f"Registrar removed from the backbone on iteration {i}") + break + + assert BackboneFeatureStore.MLI_REGISTRAR_CONSUMER not in the_backbone diff --git a/tests/dragon/test_featurestore.py b/tests/dragon/test_featurestore.py new file mode 100644 index 0000000000..019dcde7a0 --- /dev/null +++ b/tests/dragon/test_featurestore.py @@ -0,0 +1,327 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
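+# The tests below focus on BackboneFeatureStore.wait_for: immediate returns for
+# empty or already-populated key lists, and polling until late-arriving values appear.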
+ + +import multiprocessing as mp +import random +import time +import typing as t +import unittest.mock as mock +import uuid + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + time as bbtime, +) +from smartsim.log import get_logger + +logger = get_logger(__name__) + +# isort: off +from dragon import fli +from dragon.channels import Channel + +# isort: on + +if t.TYPE_CHECKING: + import conftest + + +# The tests in this file must run in a dragon environment +pytestmark = pytest.mark.dragon + + +def test_backbone_wait_for_no_keys( + the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch +) -> None: + """Verify that asking the backbone to wait for a value succeeds + immediately and does not cause a wait to occur if the supplied key + list is empty. + + :param the_backbone: the storage engine to use, prepopulated with + """ + # set a very low timeout to confirm that it does not wait + + with monkeypatch.context() as ctx: + # all keys should be found and the timeout should never be checked. + ctx.setattr(bbtime, "sleep", mock.MagicMock()) + + values = the_backbone.wait_for([]) + assert len(values) == 0 + + # confirm that no wait occurred + bbtime.sleep.assert_not_called() + + +def test_backbone_wait_for_prepopulated( + the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch +) -> None: + """Verify that asking the backbone to wait for a value succeed + immediately and do not cause a wait to occur if the data exists. + + :param the_backbone: the storage engine to use, prepopulated with + """ + # set a very low timeout to confirm that it does not wait + + with monkeypatch.context() as ctx: + # all keys should be found and the timeout should never be checked. + ctx.setattr(bbtime, "sleep", mock.MagicMock()) + + values = the_backbone.wait_for([BackboneFeatureStore.MLI_WORKER_QUEUE], 0.1) + + # confirm that wait_for with one key returns one value + assert len(values) == 1 + + # confirm that the descriptor is non-null w/some non-trivial value + assert len(values[BackboneFeatureStore.MLI_WORKER_QUEUE]) > 5 + + # confirm that no wait occurred + bbtime.sleep.assert_not_called() + + +def test_backbone_wait_for_prepopulated_dupe( + the_backbone: BackboneFeatureStore, monkeypatch: pytest.MonkeyPatch +) -> None: + """Verify that asking the backbone to wait for keys that are duplicated + results in a single value being returned for each key. + + :param the_backbone: the storage engine to use, prepopulated with + """ + # set a very low timeout to confirm that it does not wait + + key1, key2 = "key-1", "key-2" + value1, value2 = "i-am-value-1", "i-am-value-2" + the_backbone[key1] = value1 + the_backbone[key2] = value2 + + with monkeypatch.context() as ctx: + # all keys should be found and the timeout should never be checked. 
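+        # (sleep is patched so an unexpected backoff cannot silently stall the test)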
+ ctx.setattr(bbtime, "sleep", mock.MagicMock()) + + values = the_backbone.wait_for([key1, key2, key1]) # key1 is duplicated + + # confirm that wait_for with one key returns one value + assert len(values) == 2 + assert key1 in values + assert key2 in values + + assert values[key1] == value1 + assert values[key2] == value2 + + +def set_value_after_delay( + descriptor: str, key: str, value: str, delay: float = 5 +) -> None: + """Helper method to persist a random value into the backbone + + :param descriptor: the backbone feature store descriptor to attach to + :param key: the key to write to + :param value: a value to write to the key + :param delay: amount of delay to apply before writing the key + """ + time.sleep(delay) + + backbone = BackboneFeatureStore.from_descriptor(descriptor) + backbone[key] = value + logger.debug(f"set_value_after_delay wrote `{value} to backbone[`{key}`]") + + +@pytest.mark.parametrize( + "delay", + [ + pytest.param( + 0, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 1, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 2, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 4, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 8, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + ], +) +def test_backbone_wait_for_partial_prepopulated( + the_backbone: BackboneFeatureStore, delay: float +) -> None: + """Verify that when data is not all in the backbone, the `wait_for` operation + continues to poll until it finds everything it needs. 
+ + :param the_backbone: the storage engine to use, prepopulated with + :param delay: the number of seconds the second process will wait before + setting the target value in the backbone featurestore + """ + # set a very low timeout to confirm that it does not wait + wait_timeout = 10 + + key, value = str(uuid.uuid4()), str(random.random() * 10) + + logger.debug(f"Starting process to write {key} after {delay}s") + p = mp.Process( + target=set_value_after_delay, args=(the_backbone.descriptor, key, value, delay) + ) + p.start() + + p2 = mp.Process( + target=the_backbone.wait_for, + args=([BackboneFeatureStore.MLI_WORKER_QUEUE, key],), + kwargs={"timeout": wait_timeout}, + ) + p2.start() + + p.join() + p2.join() + + # both values should be written at this time + ret_vals = the_backbone.wait_for( + [key, BackboneFeatureStore.MLI_WORKER_QUEUE, key], 0.1 + ) + # confirm that wait_for with two keys returns two values + assert len(ret_vals) == 2, "values should contain values for both awaited keys" + + # confirm the pre-populated value has the correct output + assert ( + ret_vals[BackboneFeatureStore.MLI_WORKER_QUEUE] == "12345" + ) # mock descriptor value from fixture + + # confirm the population process completed and the awaited value is correct + assert ret_vals[key] == value, "verify order of values " + + +@pytest.mark.parametrize( + "num_keys", + [ + pytest.param( + 0, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 1, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 3, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 7, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + pytest.param( + 11, + marks=pytest.mark.skip( + "Must use entrypoint instead of mp.Process to run on build agent" + ), + ), + ], +) +def test_backbone_wait_for_multikey( + the_backbone: BackboneFeatureStore, + num_keys: int, + test_dir: str, +) -> None: + """Verify that asking the backbone to wait for multiple keys results + in that number of values being returned. 
+ + :param the_backbone: the storage engine to use, prepopulated with + :param num_keys: the number of extra keys to set & request in the backbone + """ + # maximum delay allowed for setter processes + max_delay = 5 + + extra_keys = [str(uuid.uuid4()) for _ in range(num_keys)] + extra_values = [str(uuid.uuid4()) for _ in range(num_keys)] + extras = dict(zip(extra_keys, extra_values)) + delays = [random.random() * max_delay for _ in range(num_keys)] + processes = [] + + for key, value, delay in zip(extra_keys, extra_values, delays): + assert delay < max_delay, "write delay exceeds test timeout" + logger.debug(f"Delaying {key} write by {delay} seconds") + p = mp.Process( + target=set_value_after_delay, + args=(the_backbone.descriptor, key, value, delay), + ) + p.start() + processes.append(p) + + p2 = mp.Process( + target=the_backbone.wait_for, + args=(extra_keys,), + kwargs={"timeout": max_delay * 2}, + ) + p2.start() + for p in processes: + p.join(timeout=max_delay * 2) + p2.join( + timeout=max_delay * 2 + ) # give it 10 seconds longer than p2 timeout for backoff + + # use without a wait to verify all values are written + num_keys = len(extra_keys) + actual_values = the_backbone.wait_for(extra_keys, timeout=0.01) + assert len(extra_keys) == num_keys + + # confirm that wait_for returns all the expected values + assert len(actual_values) == num_keys + + # confirm that the returned values match (e.g. are returned in the right order) + for k in extras: + assert extras[k] == actual_values[k] diff --git a/tests/dragon/test_featurestore_base.py b/tests/dragon/test_featurestore_base.py new file mode 100644 index 0000000000..6daceb9061 --- /dev/null +++ b/tests/dragon/test_featurestore_base.py @@ -0,0 +1,844 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
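+# The tests below cover reserved-key handling across feature store implementations
+# and EventBroadcaster behavior (buffering, pruning, and error propagation) using
+# in-memory mock storage.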
+import pathlib +import time +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster +from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer +from smartsim._core.mli.infrastructure.comm.event import ( + OnCreateConsumer, + OnWriteFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.feature_store import ReservedKeys +from smartsim.error import SmartSimError + +from .channel import FileSystemCommChannel +from .feature_store import MemoryFeatureStore + +if t.TYPE_CHECKING: + import conftest + + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +def boom(*args, **kwargs) -> None: + """Helper function that blows up when used to mock up + some other function.""" + raise Exception(f"you shall not pass! {args}, {kwargs}") + + +def test_event_uid() -> None: + """Verify that all events include a unique identifier.""" + uids: t.Set[str] = set() + num_iters = 1000 + + # generate a bunch of events and keep track all the IDs + for i in range(num_iters): + event_a = OnCreateConsumer("test_event_uid", str(i), filters=[]) + event_b = OnWriteFeatureStore("test_event_uid", "test_event_uid", str(i)) + + uids.add(event_a.uid) + uids.add(event_b.uid) + + # verify each event created a unique ID + assert len(uids) == 2 * num_iters + + +def test_mli_reserved_keys_conversion() -> None: + """Verify that conversion from a string to an enum member + works as expected.""" + + for reserved_key in ReservedKeys: + # iterate through all keys and verify `from_string` works + assert ReservedKeys.contains(reserved_key.value) + + # show that the value (actual key) not the enum member name + # will not be incorrectly identified as reserved + assert not ReservedKeys.contains(str(reserved_key).split(".")[1]) + + +def test_mli_reserved_keys_writes() -> None: + """Verify that attempts to write to reserved keys are blocked from a + standard DragonFeatureStore but enabled with the BackboneFeatureStore.""" + + mock_storage = {} + dfs = DragonFeatureStore(mock_storage) + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + other = MemoryFeatureStore(mock_storage) + + expected_value = "value" + + for reserved_key in ReservedKeys: + # we expect every reserved key to fail using DragonFeatureStore... + with pytest.raises(SmartSimError) as ex: + dfs[reserved_key] = expected_value + + assert "reserved key" in ex.value.args[0] + + # ... 
and expect other feature stores to respect reserved keys + with pytest.raises(SmartSimError) as ex: + other[reserved_key] = expected_value + + assert "reserved key" in ex.value.args[0] + + # ...and those same keys to succeed on the backbone + backbone[reserved_key] = expected_value + actual_value = backbone[reserved_key] + assert actual_value == expected_value + + +def test_mli_consumers_read_by_key() -> None: + """Verify that the value returned from the mli consumers method is written + to the correct key and reads are allowed via standard dragon feature store.""" + + mock_storage = {} + dfs = DragonFeatureStore(mock_storage) + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + other = MemoryFeatureStore(mock_storage) + + expected_value = "value" + + # write using backbone that has permission to write reserved keys + backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] = expected_value + + # confirm read-only access to reserved keys from any FeatureStore + for fs in [dfs, backbone, other]: + assert fs[ReservedKeys.MLI_NOTIFY_CONSUMERS] == expected_value + + +def test_mli_consumers_read_by_backbone() -> None: + """Verify that the backbone reads the correct location + when using the backbone feature store API instead of mapping API.""" + + mock_storage = {} + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + expected_value = "value" + + backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] = expected_value + + # confirm reading via convenience method returns expected value + assert backbone.notification_channels[0] == expected_value + + +def test_mli_consumers_write_by_backbone() -> None: + """Verify that the backbone writes the correct location + when using the backbone feature store API instead of mapping API.""" + + mock_storage = {} + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + expected_value = ["value"] + + backbone.notification_channels = expected_value + + # confirm write using convenience method targets expected key + assert backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] == ",".join(expected_value) + + +def test_eventpublisher_broadcast_no_factory(test_dir: str) -> None: + """Verify that a broadcast operation without any registered subscribers + succeeds without raising Exceptions. + + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + mock_storage = {} + consumer_descriptor = storage_path / "test-consumer" + + # NOTE: we're not putting any consumers into the backbone here! 
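+    # With no registered consumers, the broadcast has nobody to deliver to and
+    # must buffer the event for a later send.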
+ backbone = BackboneFeatureStore(mock_storage) + + event = OnCreateConsumer( + "test_eventpublisher_broadcast_no_factory", consumer_descriptor, filters=[] + ) + + publisher = EventBroadcaster(backbone) + num_receivers = 0 + + # publishing this event without any known consumers registered should succeed + # but report that it didn't have anybody to send the event to + consumer_descriptor = storage_path / f"test-consumer" + event = OnCreateConsumer( + "test_eventpublisher_broadcast_no_factory", consumer_descriptor, filters=[] + ) + + num_receivers += publisher.send(event) + + # confirm no changes to the backbone occur when fetching the empty consumer key + key_in_features_store = ReservedKeys.MLI_NOTIFY_CONSUMERS in backbone + assert not key_in_features_store + + # confirm that the broadcast reports no events published + assert num_receivers == 0 + # confirm that the broadcast buffered the event for a later send + assert publisher.num_buffered == 1 + + +def test_eventpublisher_broadcast_to_empty_consumer_list(test_dir: str) -> None: + """Verify that a broadcast operation without any registered subscribers + succeeds without raising Exceptions. + + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + mock_storage = {} + + # note: file-system descriptors are just paths + consumer_descriptor = storage_path / "test-consumer" + + # prep our backbone with a consumer list + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + backbone.notification_channels = [] + + event = OnCreateConsumer( + "test_eventpublisher_broadcast_to_empty_consumer_list", + consumer_descriptor, + filters=[], + ) + publisher = EventBroadcaster( + backbone, channel_factory=FileSystemCommChannel.from_descriptor + ) + num_receivers = publisher.send(event) + + registered_consumers = backbone[ReservedKeys.MLI_NOTIFY_CONSUMERS] + + # confirm that no consumers exist in backbone to send to + assert not registered_consumers + # confirm that the broadcast reports no events published + assert num_receivers == 0 + # confirm that the broadcast buffered the event for a later send + assert publisher.num_buffered == 1 + + +def test_eventpublisher_broadcast_without_channel_factory(test_dir: str) -> None: + """Verify that a broadcast operation reports an error if no channel + factory was supplied for constructing the consumer channels. + + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + mock_storage = {} + + # note: file-system descriptors are just paths + consumer_descriptor = storage_path / "test-consumer" + + # prep our backbone with a consumer list + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + backbone.notification_channels = [consumer_descriptor] + + event = OnCreateConsumer( + "test_eventpublisher_broadcast_without_channel_factory", + consumer_descriptor, + filters=[], + ) + publisher = EventBroadcaster( + backbone, + # channel_factory=FileSystemCommChannel.from_descriptor # <--- not supplied + ) + + with pytest.raises(SmartSimError) as ex: + publisher.send(event) + + assert "factory" in ex.value.args[0] + + +def test_eventpublisher_broadcast_empties_buffer(test_dir: str) -> None: + """Verify that a successful broadcast clears messages from the event + buffer when a new message is sent and consumers are registered. 
+ + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + mock_storage = {} + + # note: file-system descriptors are just paths + consumer_descriptor = storage_path / "test-consumer" + + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + backbone.notification_channels = (consumer_descriptor,) + + publisher = EventBroadcaster( + backbone, channel_factory=FileSystemCommChannel.from_descriptor + ) + + # mock building up some buffered events + num_buffered_events = 14 + for i in range(num_buffered_events): + event = OnCreateConsumer( + "test_eventpublisher_broadcast_empties_buffer", + storage_path / f"test-consumer-{str(i)}", + [], + ) + publisher._event_buffer.append(bytes(event)) + + event0 = OnCreateConsumer( + "test_eventpublisher_broadcast_empties_buffer", + storage_path / f"test-consumer-{str(num_buffered_events + 1)}", + [], + ) + + num_receivers = publisher.send(event0) + # 1 receiver x 15 total events == 15 events + assert num_receivers == num_buffered_events + 1 + + +@pytest.mark.parametrize( + "num_consumers, num_buffered, expected_num_sent", + [ + pytest.param(0, 7, 0, id="0 x (7+1) - no consumers, multi-buffer"), + pytest.param(1, 7, 8, id="1 x (7+1) - single consumer, multi-buffer"), + pytest.param(2, 7, 16, id="2 x (7+1) - multi-consumer, multi-buffer"), + pytest.param(4, 4, 20, id="4 x (4+1) - multi-consumer, multi-buffer (odd #)"), + pytest.param(9, 0, 9, id="13 x (0+1) - multi-consumer, empty buffer"), + ], +) +def test_eventpublisher_broadcast_returns_total_sent( + test_dir: str, num_consumers: int, num_buffered: int, expected_num_sent: int +) -> None: + """Verify that a successful broadcast returns the total number of events + sent, including buffered messages. + + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + :param num_consumers: the number of consumers to mock setting up prior to send + :param num_buffered: the number of pre-buffered events to mock up + :param expected_num_sent: the expected result from calling send + """ + storage_path = pathlib.Path(test_dir) / "features" + mock_storage = {} + + # note: file-system descriptors are just paths + consumers = [] + for i in range(num_consumers): + consumers.append(storage_path / f"test-consumer-{i}") + + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + backbone.notification_channels = consumers + + publisher = EventBroadcaster( + backbone, channel_factory=FileSystemCommChannel.from_descriptor + ) + + # mock building up some buffered events + for i in range(num_buffered): + event = OnCreateConsumer( + "test_eventpublisher_broadcast_returns_total_sent", + storage_path / f"test-consumer-{str(i)}", + [], + ) + publisher._event_buffer.append(bytes(event)) + + assert publisher.num_buffered == num_buffered + + # this event will trigger clearing anything already in buffer + event0 = OnCreateConsumer( + "test_eventpublisher_broadcast_returns_total_sent", + storage_path / f"test-consumer-{num_buffered}", + [], + ) + + # num_receivers should contain a number that computes w/all consumers and all events + num_receivers = publisher.send(event0) + + assert num_receivers == expected_num_sent + + +def test_eventpublisher_prune_unused_consumer(test_dir: str) -> None: + """Verify that any unused consumers are pruned each time a new event is sent. 
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    consumer_descriptor = storage_path / "test-consumer"
+
+    backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+
+    publisher = EventBroadcaster(
+        backbone, channel_factory=FileSystemCommChannel.from_descriptor
+    )
+
+    event = OnCreateConsumer(
+        "test_eventpublisher_prune_unused_consumer",
+        consumer_descriptor,
+        filters=[],
+    )
+
+    # the only registered consumer is in the event, expect no pruning
+    backbone.notification_channels = (consumer_descriptor,)
+
+    publisher.send(event)
+    assert str(consumer_descriptor) in publisher._channel_cache
+    assert len(publisher._channel_cache) == 1
+
+    # add a new descriptor for another event...
+    consumer_descriptor2 = storage_path / "test-consumer-2"
+    # ... and remove the old descriptor from the backbone when it's looked up
+    backbone.notification_channels = (consumer_descriptor2,)
+
+    event = OnCreateConsumer(
+        "test_eventpublisher_prune_unused_consumer", consumer_descriptor2, filters=[]
+    )
+
+    publisher.send(event)
+
+    assert str(consumer_descriptor2) in publisher._channel_cache
+    assert str(consumer_descriptor) not in publisher._channel_cache
+    assert len(publisher._channel_cache) == 1
+
+    # test multi-consumer pruning by caching some extra channels
+    prune0, prune1, prune2 = "abc", "def", "ghi"
+    publisher._channel_cache[prune0] = "doesnt-matter-if-it-is-pruned"
+    publisher._channel_cache[prune1] = "doesnt-matter-if-it-is-pruned"
+    publisher._channel_cache[prune2] = "doesnt-matter-if-it-is-pruned"
+
+    # add in one of our old channels so we prune the above items, send to these
+    backbone.notification_channels = (consumer_descriptor, consumer_descriptor2)
+
+    publisher.send(event)
+
+    assert str(consumer_descriptor2) in publisher._channel_cache
+
+    # NOTE: we should NOT prune something that isn't used by this message but
+    # does appear in `backbone.notification_channels`
+    assert str(consumer_descriptor) in publisher._channel_cache
+
+    # confirm all of our items that were not in the notification channels are gone
+    for pruned in [prune0, prune1, prune2]:
+        assert pruned not in publisher._channel_cache
+
+    # confirm we have only the two expected items in the channel cache
+    assert len(publisher._channel_cache) == 2
+
+
+def test_eventpublisher_serialize_failure(
+    test_dir: str, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    """Verify that errors during message serialization are raised to the caller.
+ + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + :param monkeypatch: pytest fixture for modifying behavior of existing code + with mock implementations + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + + # note: file-system descriptors are just paths + target_descriptor = str(storage_path / "test-consumer") + + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + publisher = EventBroadcaster( + backbone, channel_factory=FileSystemCommChannel.from_descriptor + ) + + with monkeypatch.context() as patch: + event = OnCreateConsumer( + "test_eventpublisher_serialize_failure", target_descriptor, filters=[] + ) + + # patch the __bytes__ implementation to cause pickling to fail during send + def bad_bytes(self) -> bytes: + return b"abc" + + # this patch causes an attribute error when event pickling is attempted + patch.setattr(event, "__bytes__", bad_bytes) + + backbone.notification_channels = (target_descriptor,) + + # send a message into the channel + with pytest.raises(AttributeError) as ex: + publisher.send(event) + + assert "serialize" in ex.value.args[0] + + +def test_eventpublisher_factory_failure( + test_dir: str, monkeypatch: pytest.MonkeyPatch +) -> None: + """Verify that errors during channel construction are raised to the caller. + + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + :param monkeypatch: pytest fixture for modifying behavior of existing code + with mock implementations + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + + # note: file-system descriptors are just paths + target_descriptor = str(storage_path / "test-consumer") + + def boom(descriptor: str) -> None: + raise Exception(f"you shall not pass! {descriptor}") + + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + publisher = EventBroadcaster(backbone, channel_factory=boom) + + with monkeypatch.context() as patch: + event = OnCreateConsumer( + "test_eventpublisher_factory_failure", target_descriptor, filters=[] + ) + + backbone.notification_channels = (target_descriptor,) + + # send a message into the channel + with pytest.raises(SmartSimError) as ex: + publisher.send(event) + + assert "construct" in ex.value.args[0] + + +def test_eventpublisher_failure(test_dir: str, monkeypatch: pytest.MonkeyPatch) -> None: + """Verify that unexpected errors during message send are caught and wrapped in a + SmartSimError so they are not propagated directly to the caller. 
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    :param monkeypatch: pytest fixture for modifying behavior of existing code
+    with mock implementations
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    storage_path.mkdir(parents=True, exist_ok=True)
+
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    target_descriptor = str(storage_path / "test-consumer")
+
+    backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+    publisher = EventBroadcaster(
+        backbone, channel_factory=FileSystemCommChannel.from_descriptor
+    )
+
+    def boom(self) -> None:
+        raise Exception("That was unexpected...")
+
+    with monkeypatch.context() as patch:
+        event = OnCreateConsumer(
+            "test_eventpublisher_failure", target_descriptor, filters=[]
+        )
+
+        # patch the _broadcast implementation to cause send to fail after
+        # the event has been pickled
+        patch.setattr(publisher, "_broadcast", boom)
+
+        backbone.notification_channels = (target_descriptor,)
+
+        # the unexpected exception raised by _broadcast must not escape directly;
+        # instead, it should be wrapped in a SmartSimError
+        with pytest.raises(SmartSimError) as ex:
+            publisher.send(event)
+
+        assert "unexpected" in ex.value.args[0]
+
+
+def test_eventconsumer_receive(test_dir: str) -> None:
+    """Verify that a consumer retrieves a message from the given channel.
+
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    storage_path.mkdir(parents=True, exist_ok=True)
+
+    mock_storage = {}
+
+    # note: file-system descriptors are just paths
+    target_descriptor = str(storage_path / "test-consumer")
+
+    backbone = BackboneFeatureStore(mock_storage)
+    comm_channel = FileSystemCommChannel.from_descriptor(target_descriptor)
+    event = OnCreateConsumer(
+        "test_eventconsumer_receive", target_descriptor, filters=[]
+    )
+
+    # simulate a sent event by writing directly to the input comm channel
+    comm_channel.send(bytes(event))
+
+    consumer = EventConsumer(comm_channel, backbone)
+
+    all_received: t.List[OnCreateConsumer] = consumer.recv()
+    assert len(all_received) == 1
+
+    # verify we received the same event that was raised
+    assert all_received[0].category == event.category
+    assert all_received[0].descriptor == event.descriptor
+
+
+@pytest.mark.parametrize("num_sent", [0, 1, 2, 4, 8, 16])
+def test_eventconsumer_receive_multi(test_dir: str, num_sent: int) -> None:
+    """Verify that a consumer retrieves multiple messages from the given channel.
+ + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + :param num_sent: parameterized value used to vary the number of events + that are enqueued and validations are checked at multiple queue sizes + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + + # note: file-system descriptors are just paths + target_descriptor = str(storage_path / "test-consumer") + + backbone = BackboneFeatureStore(mock_storage) + comm_channel = FileSystemCommChannel.from_descriptor(target_descriptor) + + # simulate multiple sent events by writing directly to the input comm channel + for _ in range(num_sent): + event = OnCreateConsumer( + "test_eventconsumer_receive_multi", target_descriptor, filters=[] + ) + comm_channel.send(bytes(event)) + + consumer = EventConsumer(comm_channel, backbone) + + all_received: t.List[OnCreateConsumer] = consumer.recv() + assert len(all_received) == num_sent + + +def test_eventconsumer_receive_empty(test_dir: str) -> None: + """Verify that a consumer receiving an empty message ignores the + message and continues processing. + + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + + # note: file-system descriptors are just paths + target_descriptor = str(storage_path / "test-consumer") + + backbone = BackboneFeatureStore(mock_storage) + comm_channel = FileSystemCommChannel.from_descriptor(target_descriptor) + + # simulate a sent event by writing directly to the input comm channel + comm_channel.send(bytes(b"")) + + consumer = EventConsumer(comm_channel, backbone) + + messages = consumer.recv() + + # the messages array should be empty + assert not messages + + +def test_eventconsumer_eventpublisher_integration(test_dir: str) -> None: + """Verify that the publisher and consumer integrate as expected when + multiple publishers and consumers are sending simultaneously. 
+ + :param test_dir: pytest fixture automatically generating unique working + directories for individual test outputs + """ + storage_path = pathlib.Path(test_dir) / "features" + storage_path.mkdir(parents=True, exist_ok=True) + + mock_storage = {} + backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True) + mock_fs_descriptor = str(storage_path / f"mock-feature-store") + + wmgr_channel = FileSystemCommChannel(storage_path / "test-wmgr") + capp_channel = FileSystemCommChannel(storage_path / "test-capp") + back_channel = FileSystemCommChannel(storage_path / "test-backend") + + wmgr_consumer_descriptor = wmgr_channel.descriptor + capp_consumer_descriptor = capp_channel.descriptor + back_consumer_descriptor = back_channel.descriptor + + # create some consumers to receive messages + wmgr_consumer = EventConsumer( + wmgr_channel, + backbone, + filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN], + ) + capp_consumer = EventConsumer( + capp_channel, + backbone, + ) + back_consumer = EventConsumer( + back_channel, + backbone, + filters=[OnCreateConsumer.CONSUMER_CREATED], + ) + + # create some broadcasters to publish messages + mock_worker_mgr = EventBroadcaster( + backbone, + channel_factory=FileSystemCommChannel.from_descriptor, + ) + mock_client_app = EventBroadcaster( + backbone, + channel_factory=FileSystemCommChannel.from_descriptor, + ) + + # register all of the consumers even though the OnCreateConsumer really should + # trigger its registration. event processing is tested elsewhere. + backbone.notification_channels = [ + wmgr_consumer_descriptor, + capp_consumer_descriptor, + back_consumer_descriptor, + ] + + # simulate worker manager sending a notification to backend that it's alive + event_1 = OnCreateConsumer( + "test_eventconsumer_eventpublisher_integration", + wmgr_consumer_descriptor, + filters=[], + ) + mock_worker_mgr.send(event_1) + + # simulate the app updating a model a few times + event_2 = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-1" + ) + event_3 = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-2" + ) + event_4 = OnWriteFeatureStore( + "test_eventconsumer_eventpublisher_integration", mock_fs_descriptor, "key-1" + ) + + mock_client_app.send(event_2) + mock_client_app.send(event_3) + mock_client_app.send(event_4) + + # worker manager should only get updates about feature update + wmgr_messages = wmgr_consumer.recv() + assert len(wmgr_messages) == 3 + + # the backend should only receive messages about consumer creation + back_messages = back_consumer.recv() + assert len(back_messages) == 1 + + # hypothetical app has no filters and will get all events + app_messages = capp_consumer.recv() + assert len(app_messages) == 4 + + +@pytest.mark.parametrize("invalid_timeout", [-100.0, -1.0, 0.0]) +def test_eventconsumer_batch_timeout( + invalid_timeout: float, + test_dir: str, +) -> None: + """Verify that a consumer allows only positive, non-zero values for timeout + if it is supplied. 
+
+    :param invalid_timeout: any invalid timeout that should fail validation
+    :param test_dir: pytest fixture automatically generating unique working
+    directories for individual test outputs
+    """
+    storage_path = pathlib.Path(test_dir) / "features"
+    storage_path.mkdir(parents=True, exist_ok=True)
+
+    mock_storage = {}
+    backbone = BackboneFeatureStore(mock_storage)
+
+    channel = FileSystemCommChannel(storage_path / "test-wmgr")
+
+    with pytest.raises(ValueError) as ex:
+        # attempt to receive with an invalid (non-positive) batch timeout
+        consumer = EventConsumer(
+            channel,
+            backbone,
+            filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+        )
+        consumer.recv(batch_timeout=invalid_timeout)
+
+    assert "positive" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+    "wait_timeout, exp_wait_max",
+    [
+        # aggregate the 1+1+1 into 3 on remaining parameters
+        pytest.param(1, 1 + 1 + 1, id="1s wait, 3 cycle steps"),
+        pytest.param(2, 3 + 2, id="2s wait, 4 cycle steps"),
+        pytest.param(4, 3 + 2 + 4, id="4s wait, 5 cycle steps"),
+        pytest.param(9, 3 + 2 + 4 + 8, id="9s wait, 6 cycle steps"),
+        # aggregate an entire cycle into 16
+        pytest.param(19.5, 16 + 3 + 2 + 4, id="20s wait, repeat cycle"),
+    ],
+)
+def test_backbone_wait_timeout(wait_timeout: float, exp_wait_max: float) -> None:
+    """Verify that waiting on a key that never appears in the backbone times out
+    in an appropriate amount of time. Note: due to the backoff, we verify
+    the elapsed time is less than the 15s of a cycle of waits.
+
+    :param wait_timeout: Maximum amount of time (in seconds) to allow the backbone
+    to wait for the requested value to exist
+    :param exp_wait_max: Maximum amount of time (in seconds) to set as the upper
+    bound to allow the delays with backoff to occur
+    """
+
+    # NOTE: exp_wait_time maps to the cycled backoff of [0.1, 0.2, 0.4, 0.8]
+    # with leeway added (by allowing 1s each for the 0.1 and 0.5 steps)
+    start_time = time.time()
+
+    storage = {}
+    backbone = BackboneFeatureStore(storage)
+
+    with pytest.raises(SmartSimError) as ex:
+        backbone.wait_for(["does-not-exist"], wait_timeout)
+
+    assert "timeout" in str(ex.value.args[0]).lower()
+
+    end_time = time.time()
+    elapsed = end_time - start_time
+
+    # confirm that we met our timeout
+    assert elapsed > wait_timeout, f"below configured timeout {wait_timeout}"
+
+    # confirm that the total wait time is aligned with the sleep cycle
+    assert elapsed < exp_wait_max, f"above expected max wait {exp_wait_max}"
diff --git a/tests/dragon/test_featurestore_integration.py b/tests/dragon/test_featurestore_integration.py
new file mode 100644
index 0000000000..23fdc55ab6
--- /dev/null
+++ b/tests/dragon/test_featurestore_integration.py
@@ -0,0 +1,213 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import typing as t
+
+import pytest
+
+dragon = pytest.importorskip("dragon")
+
+from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel
+from smartsim._core.mli.comm.channel.dragon_util import (
+    DEFAULT_CHANNEL_BUFFER_SIZE,
+    create_local,
+)
+from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster
+from smartsim._core.mli.infrastructure.comm.consumer import EventConsumer
+from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore
+from smartsim._core.mli.infrastructure.storage.backbone_feature_store import (
+    BackboneFeatureStore,
+)
+
+# isort: off
+from dragon.channels import Channel
+
+# isort: on
+
+if t.TYPE_CHECKING:
+    import conftest
+
+
+# The tests in this file must run in a dragon environment
+pytestmark = pytest.mark.dragon
+
+
+@pytest.fixture(scope="module")
+def the_worker_channel() -> DragonCommChannel:
+    """Fixture to create a valid descriptor for a worker channel
+    that can be attached to."""
+    wmgr_channel_ = create_local()
+    wmgr_channel = DragonCommChannel(wmgr_channel_)
+    return wmgr_channel
+
+
+@pytest.mark.parametrize(
+    "num_events, batch_timeout, max_batches_expected",
+    [
+        pytest.param(1, 1.0, 2, id="under 1s timeout"),
+        pytest.param(20, 1.0, 3, id="test 1s timeout 20x"),
+        pytest.param(30, 0.2, 5, id="test 0.2s timeout 30x"),
+        pytest.param(60, 0.4, 4, id="small batches"),
+        pytest.param(100, 0.1, 10, id="many small batches"),
+    ],
+)
+def test_eventconsumer_max_dequeue(
+    num_events: int,
+    batch_timeout: float,
+    max_batches_expected: int,
+    the_worker_channel: DragonCommChannel,
+    the_backbone: BackboneFeatureStore,
+) -> None:
+    """Verify that a consumer does not sit and collect messages indefinitely
+    by checking that a consumer returns after a maximum timeout is exceeded.
+
+    :param num_events: Total number of events to raise in the test
+    :param batch_timeout: Maximum wait time (in seconds) for a batch of messages
+    :param max_batches_expected: Maximum number of receives that should occur
+    :param the_worker_channel: Dragon comm channel the consumer receives from
+    :param the_backbone: The backbone feature store to use
+    """
+
+    # create some consumers to receive messages
+    wmgr_consumer = EventConsumer(
+        the_worker_channel,
+        the_backbone,
+        filters=[OnWriteFeatureStore.FEATURE_STORE_WRITTEN],
+    )
+
+    # create a broadcaster to publish messages
+    mock_client_app = EventBroadcaster(
+        the_backbone,
+        channel_factory=DragonCommChannel.from_descriptor,
+    )
+
+    # register all of the consumers even though the OnCreateConsumer really should
+    # trigger its registration. Event processing is tested elsewhere.
+    the_backbone.notification_channels = [the_worker_channel.descriptor]
+
+    # simulate the app updating a model a lot of times
+    for key in (f"key-{i}" for i in range(num_events)):
+        event = OnWriteFeatureStore(
+            "test_eventconsumer_max_dequeue", the_backbone.descriptor, key
+        )
+        mock_client_app.send(event, timeout=0.01)
+
+    num_dequeued = 0
+    num_batches = 0
+
+    while wmgr_messages := wmgr_consumer.recv(
+        timeout=0.1,
+        batch_timeout=batch_timeout,
+    ):
+        # the worker manager should not need more than `max_batches_expected` batches
+        num_dequeued += len(wmgr_messages)
+        num_batches += 1
+
+    # make sure we made all the expected dequeue calls and got everything
+    assert num_dequeued == num_events
+    assert num_batches > 0
+    assert num_batches < max_batches_expected, "too many recv calls were made"
+
+
+@pytest.mark.parametrize(
+    "buffer_size",
+    [
+        pytest.param(
+            -1,
+            id="replace negative, default to 500",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            0,
+            id="replace zero, default to 500",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            1,
+            id="non-zero buffer size: 1",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        # pytest.param(500, id="maximum size edge case: 500"),
+        pytest.param(
+            550,
+            id="larger than default: 550",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            800,
+            id="much larger than default: 800",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+        pytest.param(
+            1000,
+            id="very large buffer: 1000, unreliable in dragon-v0.10",
+            marks=pytest.mark.skip("create_local issue w/MPI must be mitigated"),
+        ),
+    ],
+)
+def test_channel_buffer_size(
+    buffer_size: int,
+    the_storage: t.Any,
+) -> None:
+    """Verify that a channel used by an EventBroadcaster can buffer messages
+    until a configured maximum value is exceeded.
+
+    :param buffer_size: Maximum number of messages allowed in a channel buffer
+    :param the_storage: The dragon storage engine to use
+    """
+
+    mock_storage = the_storage
+    backbone = BackboneFeatureStore(mock_storage, allow_reserved_writes=True)
+
+    wmgr_channel_ = create_local(buffer_size)  # <--- vary buffer size
+    wmgr_channel = DragonCommChannel(wmgr_channel_)
+    wmgr_consumer_descriptor = wmgr_channel.descriptor
+
+    # create a broadcaster to publish messages. create no consumers to
+    # push the number of sent messages past the allotted buffer size
+    mock_client_app = EventBroadcaster(
+        backbone,
+        channel_factory=DragonCommChannel.from_descriptor,
+    )
+
+    # register all of the consumers even though the OnCreateConsumer really should
+    # trigger its registration. Event processing is tested elsewhere.
+ backbone.notification_channels = [wmgr_consumer_descriptor] + + if buffer_size < 1: + # NOTE: we set this after creating the channel above to ensure + # the default parameter value was used during instantiation + buffer_size = DEFAULT_CHANNEL_BUFFER_SIZE + + # simulate the app updating a model a lot of times + for key in (f"key-{i}" for i in range(buffer_size)): + event = OnWriteFeatureStore( + "test_channel_buffer_size", backbone.descriptor, key + ) + mock_client_app.send(event, timeout=0.01) + + # adding 1 more over the configured buffer size should report the error + with pytest.raises(Exception) as ex: + mock_client_app.send(event, timeout=0.01) diff --git a/tests/dragon/test_inference_reply.py b/tests/dragon/test_inference_reply.py new file mode 100644 index 0000000000..bdc7be14bc --- /dev/null +++ b/tests/dragon/test_inference_reply.py @@ -0,0 +1,76 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.storage.feature_store import TensorKey +from smartsim._core.mli.infrastructure.worker.worker import InferenceReply +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + +handler = MessageHandler() + + +@pytest.fixture +def inference_reply() -> InferenceReply: + return InferenceReply() + + +@pytest.fixture +def fs_key() -> TensorKey: + return TensorKey("key", "descriptor") + + +@pytest.mark.parametrize( + "outputs, expected", + [ + ([b"output bytes"], True), + (None, False), + ([], False), + ], +) +def test_has_outputs(monkeypatch, inference_reply, outputs, expected): + """Test the has_outputs property with different values for outputs.""" + monkeypatch.setattr(inference_reply, "outputs", outputs) + assert inference_reply.has_outputs == expected + + +@pytest.mark.parametrize( + "output_keys, expected", + [ + ([fs_key], True), + (None, False), + ([], False), + ], +) +def test_has_output_keys(monkeypatch, inference_reply, output_keys, expected): + """Test the has_output_keys property with different values for output_keys.""" + monkeypatch.setattr(inference_reply, "output_keys", output_keys) + assert inference_reply.has_output_keys == expected diff --git a/tests/dragon/test_inference_request.py b/tests/dragon/test_inference_request.py new file mode 100644 index 0000000000..f5c8b9bdc7 --- /dev/null +++ b/tests/dragon/test_inference_request.py @@ -0,0 +1,118 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.storage.feature_store import TensorKey +from smartsim._core.mli.infrastructure.worker.worker import InferenceRequest +from smartsim._core.mli.message_handler import MessageHandler + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + +handler = MessageHandler() + + +@pytest.fixture +def inference_request() -> InferenceRequest: + return InferenceRequest() + + +@pytest.fixture +def fs_key() -> TensorKey: + return TensorKey("key", "descriptor") + + +@pytest.mark.parametrize( + "raw_model, expected", + [ + (handler.build_model(b"bytes", "Model Name", "V1"), True), + (None, False), + ], +) +def test_has_raw_model(monkeypatch, inference_request, raw_model, expected): + """Test the has_raw_model property with different values for raw_model.""" + monkeypatch.setattr(inference_request, "raw_model", raw_model) + assert inference_request.has_raw_model == expected + + +@pytest.mark.parametrize( + "model_key, expected", + [ + (fs_key, True), + (None, False), + ], +) +def test_has_model_key(monkeypatch, inference_request, model_key, expected): + """Test the has_model_key property with different values for model_key.""" + monkeypatch.setattr(inference_request, "model_key", model_key) + assert inference_request.has_model_key == expected + + +@pytest.mark.parametrize( + "raw_inputs, expected", + [([b"raw input bytes"], True), (None, False), ([], False)], +) +def test_has_raw_inputs(monkeypatch, inference_request, raw_inputs, expected): + """Test the has_raw_inputs property with different values for raw_inputs.""" + monkeypatch.setattr(inference_request, "raw_inputs", raw_inputs) + assert inference_request.has_raw_inputs == expected + + +@pytest.mark.parametrize( + "input_keys, expected", + [([fs_key], True), (None, False), ([], False)], +) +def test_has_input_keys(monkeypatch, inference_request, input_keys, expected): + """Test the has_input_keys property with different values for input_keys.""" + monkeypatch.setattr(inference_request, "input_keys", input_keys) + assert inference_request.has_input_keys == expected + + +@pytest.mark.parametrize( + "output_keys, expected", + [([fs_key], True), (None, False), ([], False)], +) +def test_has_output_keys(monkeypatch, inference_request, output_keys, expected): + """Test the has_output_keys property with different values for output_keys.""" + monkeypatch.setattr(inference_request, "output_keys", output_keys) + assert inference_request.has_output_keys == expected + + +@pytest.mark.parametrize( + "input_meta, expected", + [ + ([handler.build_tensor_descriptor("c", "float32", [1, 2, 3])], True), + (None, False), + ([], False), + ], +) +def test_has_input_meta(monkeypatch, inference_request, input_meta, expected): + """Test the has_input_meta property with different values for input_meta.""" + monkeypatch.setattr(inference_request, "input_meta", input_meta) + assert inference_request.has_input_meta == expected diff --git a/tests/dragon/test_onnx_worker.py b/tests/dragon/test_onnx_worker.py new file mode 100644 index 0000000000..c9cfeccd26 --- /dev/null +++ b/tests/dragon/test_onnx_worker.py @@ -0,0 +1,215 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import io +import typing as t + +import numpy as np +import numpy.typing as npt +import pytest + +onnx = pytest.importorskip("onnx") +from onnx import load_model_from_string + +pytest.importorskip("onnxruntime") +from onnxruntime import InferenceSession + +dragon = pytest.importorskip("dragon") + +import dragon.globalservices.pool as dragon_gs_pool +from dragon.managed_memory import MemoryAlloc, MemoryPool +from skl2onnx import to_onnx +from sklearn.linear_model import LinearRegression +from sklearn.preprocessing import PolynomialFeatures + +from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey +from smartsim._core.mli.infrastructure.worker.onnx_worker import ONNXWorker +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + RequestBatch, + TransformInputResult, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +logger = get_logger(__name__) +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +def get_X() -> npt.ArrayLike: + return np.linspace(0, 10, 10).astype(np.float32) + + +def get_poly_features() -> npt.ArrayLike: + + poly = PolynomialFeatures(degree=2, include_bias=False) + return poly.fit_transform(get_X().reshape(-1, 1)) + + +def get_Y() -> npt.ArrayLike: + p = np.polynomial.Polynomial([1.4, -10, 4]) + return p(get_X()) + + +def create_onnx_model(): + + poly_features = get_poly_features() + poly_reg_model = LinearRegression() + + poly_reg_model.fit(poly_features, get_Y()) + onnx_model = to_onnx(poly_reg_model, poly_features, target_opset=13) + + onnx_serialized = onnx_model.SerializeToString() + output_names = [n.name for n in onnx_model.graph.output] + input_names = [n.name for n in onnx_model.graph.input] + + return onnx_serialized, input_names, output_names + + +def get_request() -> InferenceRequest: + + tensors = [get_poly_features()] + serialized_tensors_descriptors = [ + MessageHandler.build_tensor_descriptor("c", "float32", list(tensor.shape)) + for tensor in tensors + ] + + return InferenceRequest( + model_key=ModelKey(key="model", descriptor="xyz"), + callback=None, + raw_inputs=tensors, + input_keys=None, + input_meta=serialized_tensors_descriptors, + output_keys=None, + raw_model=create_onnx_model()[0], + batch_size=0, + ) + + +def 
get_request_batch_from_request( + request: InferenceRequest, inputs: t.Optional[TransformInputResult] = None +) -> RequestBatch: + + return RequestBatch([request], inputs, request.model_key) + + +sample_request: InferenceRequest = get_request() +sample_request_batch: RequestBatch = get_request_batch_from_request(sample_request) +worker = ONNXWorker() + + +def test_load_model(mlutils) -> None: + fetch_model_result = FetchModelResult(sample_request.raw_model) + load_model_result = worker.load_model( + sample_request_batch, fetch_model_result, mlutils.get_test_device().lower() + ) + + results = load_model_result.model.run( + load_model_result.outputs, + input_feed=dict(zip(load_model_result.inputs, [get_poly_features()])), + ) + + assert results[0].shape == (10, 1) + + +def test_transform_input(mlutils) -> None: + fetch_input_result = FetchInputResult( + sample_request.raw_inputs, sample_request.input_meta + ) + + mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) + + transform_input_result = worker.transform_input( + sample_request_batch, [fetch_input_result], mem_pool + ) + + batch = get_poly_features() + assert transform_input_result.slices[0] == slice(0, batch.shape[0]) + + tensor_index = 0 + assert tuple(transform_input_result.dims[tensor_index]) == batch.shape + assert transform_input_result.dtypes[tensor_index] == str(batch.dtype) + mem_alloc = MemoryAlloc.attach(transform_input_result.transformed[tensor_index]) + itemsize = batch.itemsize + tensor = np.frombuffer( + mem_alloc.get_memview()[ + 0 : np.prod(transform_input_result.dims[tensor_index]) * itemsize + ], + dtype=transform_input_result.dtypes[tensor_index], + ).reshape(transform_input_result.dims[tensor_index]) + + np.testing.assert_allclose(tensor, sample_request.raw_inputs[tensor_index]) + + mem_pool.destroy() + + +def test_execute(mlutils) -> None: + + onnx_serialized, inputs, outputs = create_onnx_model() + + providers = ["CPUExecutionProvider"] + session = InferenceSession(onnx_serialized, providers=providers) + load_model_result = LoadModelResult(session, inputs=inputs, outputs=outputs) + + fetch_input_result = FetchInputResult( + sample_request.raw_inputs, sample_request.input_meta + ) + + request_batch = get_request_batch_from_request(sample_request, fetch_input_result) + + mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) + + transform_result = worker.transform_input( + request_batch, [fetch_input_result], mem_pool + ) + + execute_result = worker.execute( + request_batch, + load_model_result, + transform_result, + mlutils.get_test_device().lower(), + ) + + assert all(result.shape == (10, 1) for result in execute_result.predictions) + + mem_pool.destroy() + + +def test_transform_output(mlutils): + tensors = [np.zeros((10, 1))] + execute_result = ExecuteResult(tensors, [slice(0, 10)]) + + transformed_output = worker.transform_output(sample_request_batch, execute_result) + + assert transformed_output[0].outputs == [item.tobytes() for item in tensors] + assert transformed_output[0].shape == None + assert transformed_output[0].order == "c" + assert transformed_output[0].dtype == "float32" diff --git a/tests/dragon/test_protoclient.py b/tests/dragon/test_protoclient.py new file mode 100644 index 0000000000..008fe313df --- /dev/null +++ b/tests/dragon/test_protoclient.py @@ -0,0 +1,313 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import pickle +import time +import typing as t +from unittest.mock import MagicMock + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.comm.broadcaster import EventBroadcaster +from smartsim._core.mli.infrastructure.comm.event import OnWriteFeatureStore +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim.error.errors import SmartSimError +from smartsim.log import get_logger + +# isort: off +from dragon import fli +from dragon.data.ddict.ddict import DDict + +# from ..ex..high_throughput_inference.mock_app import ProtoClient +from smartsim._core.mli.client.protoclient import ProtoClient + + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon +WORK_QUEUE_KEY = BackboneFeatureStore.MLI_WORKER_QUEUE +logger = get_logger(__name__) + + +@pytest.fixture(scope="module") +def the_worker_queue(the_backbone: BackboneFeatureStore) -> DragonFLIChannel: + """Fixture that creates a dragon FLI channel as a stand-in for the + worker queue created by the worker. + + :param the_backbone: The backbone feature store to update + with the worker queue descriptor. 
+ :returns: The attached `DragonFLIChannel` + """ + + # create the FLI + to_worker_channel = create_local() + fli_ = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) + comm_channel = DragonFLIChannel(fli_) + + # store the descriptor in the backbone + the_backbone.worker_queue = comm_channel.descriptor + + try: + comm_channel.send(b"foo") + except Exception as ex: + logger.exception(f"Test send from worker channel failed", exc_info=True) + + return comm_channel + + +@pytest.mark.parametrize( + "backbone_timeout, exp_wait_max", + [ + # aggregate the 1+1+1 into 3 on remaining parameters + pytest.param(0.5, 1 + 1 + 1, id="0.5s wait, 3 cycle steps"), + pytest.param(2, 3 + 2, id="2s wait, 4 cycle steps"), + pytest.param(4, 3 + 2 + 4, id="4s wait, 5 cycle steps"), + ], +) +def test_protoclient_timeout( + backbone_timeout: float, + exp_wait_max: float, + the_backbone: BackboneFeatureStore, + monkeypatch: pytest.MonkeyPatch, +): + """Verify that attempts to attach to the worker queue from the protoclient + timeout in an appropriate amount of time. Note: due to the backoff, we verify + the elapsed time is less than the 15s of a cycle of waits. + + :param backbone_timeout: a timeout for use when configuring a proto client + :param exp_wait_max: a ceiling for the expected time spent waiting for + the timeout + :param the_backbone: a pre-initialized backbone featurestore for setting up + the environment variable required by the client + """ + + # NOTE: exp_wait_time maps to the cycled backoff of [0.1, 0.2, 0.4, 0.8] + # with leeway added (by allowing 1s each for the 0.1 and 0.5 steps) + + with monkeypatch.context() as ctx, pytest.raises(SmartSimError) as ex: + start_time = time.time() + # remove the worker queue value from the backbone if it exists + # to ensure the timeout occurs + the_backbone.pop(BackboneFeatureStore.MLI_WORKER_QUEUE) + + ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor) + + ProtoClient(timing_on=False, backbone_timeout=backbone_timeout) + elapsed = time.time() - start_time + logger.info(f"ProtoClient timeout occurred in {elapsed} seconds") + + # confirm that we met our timeout + assert ( + elapsed >= backbone_timeout + ), f"below configured timeout {backbone_timeout}" + + # confirm that the total wait time is aligned with the sleep cycle + assert elapsed < exp_wait_max, f"above expected max wait {exp_wait_max}" + + +def test_protoclient_initialization_no_backbone( + monkeypatch: pytest.MonkeyPatch, the_worker_queue: DragonFLIChannel +): + """Verify that attempting to start the client without required environment variables + results in an exception. + + :param the_worker_queue: Passing the worker queue fixture to ensure + the worker queue environment is correctly configured. + + NOTE: os.environ[BackboneFeatureStore.MLI_BACKBONE] is not set""" + + with monkeypatch.context() as patch, pytest.raises(SmartSimError) as ex: + patch.setenv(BackboneFeatureStore.MLI_BACKBONE, "") + + ProtoClient(timing_on=False) + + # confirm the missing value error has been raised + assert {"backbone", "configuration"}.issubset(set(ex.value.args[0].split(" "))) + + +def test_protoclient_initialization( + the_backbone: BackboneFeatureStore, + the_worker_queue: DragonFLIChannel, + monkeypatch: pytest.MonkeyPatch, +): + """Verify that attempting to start the client with required env vars results + in a fully initialized client. 
+
+    :param the_backbone: a pre-initialized backbone featurestore
+    :param the_worker_queue: an FLI channel the client will retrieve
+    from the backbone"""
+
+    with monkeypatch.context() as ctx:
+        ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+        # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone
+
+        client = ProtoClient(timing_on=False)
+
+        fs_descriptor = the_backbone.descriptor
+        wq_descriptor = the_worker_queue.descriptor
+
+        # confirm the backbone was attached correctly
+        assert client._backbone is not None
+        assert client._backbone.descriptor == fs_descriptor
+
+        # we expect the backbone to add its descriptor to the local env
+        assert os.environ[BackboneFeatureStore.MLI_BACKBONE] == fs_descriptor
+
+        # confirm the worker queue is created and attached correctly
+        assert client._to_worker_fli is not None
+        assert client._to_worker_fli.descriptor == wq_descriptor
+
+        # we expect the worker queue descriptor to be placed into the backbone
+        # we do NOT expect _from_worker_ch to be placed anywhere. it's a specific callback
+        assert the_backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] == wq_descriptor
+
+        # confirm the worker channels are created
+        assert client._from_worker_ch is not None
+        assert client._to_worker_ch is not None
+
+        # wrap the channels just to easily verify they produce a descriptor
+        assert DragonCommChannel(client._from_worker_ch.channel).descriptor
+        assert DragonCommChannel(client._to_worker_ch).descriptor
+
+        # confirm a publisher is created
+        assert client._publisher is not None
+
+
+def test_protoclient_write_model(
+    the_backbone: BackboneFeatureStore,
+    the_worker_queue: DragonFLIChannel,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Verify that writing a model using the client causes the model data to be
+    written to a feature store.
+
+    :param the_backbone: a pre-initialized backbone featurestore
+    :param the_worker_queue: the worker queue fixture, included to ensure
+    the worker queue environment is correctly configured
+    """
+
+    with monkeypatch.context() as ctx:
+        ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+        # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone
+
+        client = ProtoClient(timing_on=False)
+
+        model_key = "my-model"
+        model_bytes = b"12345"
+
+        client.set_model(model_key, model_bytes)
+
+        # confirm the client modified the underlying feature store
+        assert client._backbone[model_key] == model_bytes
+
+
+@pytest.mark.parametrize(
+    "num_listeners, num_model_updates",
+    [(1, 1), (1, 4), (2, 4), (16, 4), (64, 8)],
+)
+def test_protoclient_write_model_notification_sent(
+    the_backbone: BackboneFeatureStore,
+    the_worker_queue: DragonFLIChannel,
+    monkeypatch: pytest.MonkeyPatch,
+    num_listeners: int,
+    num_model_updates: int,
+):
+    """Verify that writing a model sends a key-written event.
+
+    :param the_backbone: a pre-initialized backbone featurestore
+    :param the_worker_queue: an FLI channel the client will retrieve
+    from the backbone
+    :param num_listeners: vary the number of registered listeners
+    to verify that the event is broadcast to everyone
+    :param num_model_updates: vary the number of model updates sent
+    to verify the broadcast counts messages sent correctly
+    """
+
+    # we won't actually send here, but it won't try without registered listeners
+    listeners = [f"mock-ch-desc-{i}" for i in range(num_listeners)]
+
+    the_backbone[BackboneFeatureStore.MLI_BACKBONE] = the_backbone.descriptor
+    the_backbone[BackboneFeatureStore.MLI_WORKER_QUEUE] = the_worker_queue.descriptor
+    the_backbone[BackboneFeatureStore.MLI_NOTIFY_CONSUMERS] = ",".join(listeners)
+    the_backbone[BackboneFeatureStore.MLI_REGISTRAR_CONSUMER] = None
+
+    with monkeypatch.context() as ctx:
+        ctx.setenv(BackboneFeatureStore.MLI_BACKBONE, the_backbone.descriptor)
+        # NOTE: rely on `the_worker_queue` fixture to put MLI_WORKER_QUEUE in backbone
+
+        client = ProtoClient(timing_on=False)
+
+        publisher = t.cast(EventBroadcaster, client._publisher)
+
+        # mock attaching to a channel given the mock-ch-desc in backbone
+        mock_send = MagicMock(return_value=None)
+        mock_comm_channel = MagicMock(**{"send": mock_send}, spec=DragonCommChannel)
+        mock_get_comm_channel = MagicMock(return_value=mock_comm_channel)
+        ctx.setattr(publisher, "_get_comm_channel", mock_get_comm_channel)
+
+        model_key = "my-model"
+        model_bytes = b"12345"
+
+        for i in range(num_model_updates):
+            client.set_model(model_key, model_bytes)
+
+        # confirm that a listener channel was attached
+        # once for each registered listener in backbone
+        assert mock_get_comm_channel.call_count == num_listeners * num_model_updates
+
+        # confirm the client raised the key-written event
+        assert (
+            mock_send.call_count == num_listeners * num_model_updates
+        ), f"Expected {num_listeners * num_model_updates} sends with {num_listeners} registrations"
+
+        # with at least 1 consumer registered, we can verify the message is sent
+        for call_args in mock_send.call_args_list:
+            send_args = call_args.args
+            event_bytes, timeout = send_args[0], send_args[1]
+
+            assert event_bytes, "Expected event bytes to be supplied to send"
+            assert (
+                timeout == 0.001
+            ), "Expected default timeout on call to `publisher.send`"
+
+            # confirm the correct event was raised
+            event = t.cast(
+                OnWriteFeatureStore,
+                pickle.loads(event_bytes),
+            )
+            assert event.descriptor == the_backbone.descriptor
+            assert event.key == model_key
diff --git a/tests/dragon/test_reply_building.py b/tests/dragon/test_reply_building.py
new file mode 100644
index 0000000000..48493b3c4d
--- /dev/null
+++ b/tests/dragon/test_reply_building.py
@@ -0,0 +1,64 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.control.worker_manager import build_failure_reply + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +@pytest.mark.parametrize( + "status, message", + [ + pytest.param("timeout", "Worker timed out", id="timeout"), + pytest.param("fail", "Failed while executing", id="fail"), + ], +) +def test_build_failure_reply(status: "Status", message: str): + "Ensures failure replies can be built successfully" + response = build_failure_reply(status, message) + display_name = response.schema.node.displayName # type: ignore + class_name = display_name.split(":")[-1] + assert class_name == "Response" + assert response.status == status + assert response.message == message + + +def test_build_failure_reply_fails(): + "Ensures ValueError is raised if a Status Enum is not used" + with pytest.raises(ValueError) as ex: + response = build_failure_reply("not a status enum", "message") + + assert "Error assigning status to response" in ex.value.args[0] diff --git a/tests/dragon/test_request_dispatcher.py b/tests/dragon/test_request_dispatcher.py new file mode 100644 index 0000000000..70d73e243f --- /dev/null +++ b/tests/dragon/test_request_dispatcher.py @@ -0,0 +1,233 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import gc +import os +import subprocess as sp +import time +import typing as t +from queue import Empty + +import numpy as np +import pytest + +from . import conftest +from .utils import msg_pump + +pytest.importorskip("dragon") + + +# isort: off +import dragon +import multiprocessing as mp + +import torch + +# isort: on + +from dragon import fli +from dragon.data.ddict.ddict import DDict +from dragon.managed_memory import MemoryAlloc + +from smartsim._core.mli.comm.channel.dragon_channel import DragonCommChannel +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.control.request_dispatcher import ( + RequestBatch, + RequestDispatcher, +) +from smartsim._core.mli.infrastructure.control.worker_manager import ( + EnvironmentConfigLoader, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim.log import get_logger + +logger = get_logger(__name__) + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +try: + mp.set_start_method("dragon") +except Exception: + pass + + +@pytest.mark.parametrize("num_iterations", [4]) +def test_request_dispatcher( + num_iterations: int, + the_storage: DDict, + test_dir: str, +) -> None: + """Test the request dispatcher batching and queueing system + + This also includes setting a queue to disposable, checking that it is no + longer referenced by the dispatcher. 
+ """ + + to_worker_channel = create_local() + to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) + to_worker_fli_comm_ch = DragonFLIChannel(to_worker_fli) + + backbone_fs = BackboneFeatureStore(the_storage, allow_reserved_writes=True) + + # NOTE: env vars should be set prior to instantiating EnvironmentConfigLoader + # or test environment may be unable to send messages w/queue + os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = to_worker_fli_comm_ch.descriptor + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone_fs.descriptor + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=DragonCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + request_dispatcher = RequestDispatcher( + batch_timeout=1000, + batch_size=2, + config_loader=config_loader, + worker_type=TorchWorker, + mem_pool_size=2 * 1024**2, + ) + + worker_queue = config_loader.get_queue() + if worker_queue is None: + logger.warning( + "FLI input queue not loaded correctly from config_loader: " + f"{config_loader._queue_descriptor}" + ) + + request_dispatcher._on_start() + + # put some messages into the work queue for the dispatcher to pickup + channels = [] + processes = [] + for i in range(num_iterations): + batch: t.Optional[RequestBatch] = None + mem_allocs = [] + tensors = [] + + # NOTE: creating callbacks in test to avoid a local channel being torn + # down when mock_messages terms but before the final response message is sent + + callback_channel = DragonCommChannel.from_local() + channels.append(callback_channel) + + process = conftest.function_as_dragon_proc( + msg_pump.mock_messages, + [ + worker_queue.descriptor, + backbone_fs.descriptor, + i, + callback_channel.descriptor, + ], + [], + [], + ) + processes.append(process) + process.start() + assert process.returncode is None, "The message pump failed to start" + + # give dragon some time to populate the message queues + for i in range(15): + try: + request_dispatcher._on_iteration() + batch = request_dispatcher.task_queue.get(timeout=1.0) + break + except Empty: + time.sleep(2) + logger.warning(f"Task queue is empty on iteration {i}") + continue + except Exception as exc: + logger.error(f"Task queue exception on iteration {i}") + raise exc + + assert batch is not None + assert batch.has_valid_requests + + model_key = batch.model_id.key + + try: + transform_result = batch.inputs + for transformed, dims, dtype in zip( + transform_result.transformed, + transform_result.dims, + transform_result.dtypes, + ): + mem_alloc = MemoryAlloc.attach(transformed) + mem_allocs.append(mem_alloc) + itemsize = np.empty((1), dtype=dtype).itemsize + tensors.append( + torch.from_numpy( + np.frombuffer( + mem_alloc.get_memview()[0 : np.prod(dims) * itemsize], + dtype=dtype, + ).reshape(dims) + ) + ) + + assert len(batch.requests) == 2 + assert batch.model_id.key == model_key + assert model_key in request_dispatcher._queues + assert model_key in request_dispatcher._active_queues + assert len(request_dispatcher._queues[model_key]) == 1 + assert request_dispatcher._queues[model_key][0].empty() + assert request_dispatcher._queues[model_key][0].model_id.key == model_key + assert len(tensors) == 1 + assert tensors[0].shape == torch.Size([2, 2]) + + for tensor in tensors: + for sample_idx in range(tensor.shape[0]): + tensor_in = tensor[sample_idx] + tensor_out = (sample_idx + 1) * torch.ones( + (2,), dtype=torch.float32 + ) + assert torch.equal(tensor_in, tensor_out) 
+ + except Exception as exc: + raise exc + finally: + for mem_alloc in mem_allocs: + mem_alloc.free() + + request_dispatcher._active_queues[model_key].make_disposable() + assert request_dispatcher._active_queues[model_key].can_be_removed + + request_dispatcher._on_iteration() + + assert model_key not in request_dispatcher._active_queues + assert model_key not in request_dispatcher._queues + + # Try to remove the dispatcher and free the memory + del request_dispatcher + gc.collect() diff --git a/tests/dragon/test_tensorflow_worker.py b/tests/dragon/test_tensorflow_worker.py new file mode 100644 index 0000000000..67014f1694 --- /dev/null +++ b/tests/dragon/test_tensorflow_worker.py @@ -0,0 +1,222 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import io +import typing as t + +import numpy as np +import pytest + +tf = pytest.importorskip("tensorflow") +from tensorflow import keras +from tensorflow.python.framework.convert_to_constants import ( + convert_var_to_const_function_in_v1, +) + +dragon = pytest.importorskip("dragon") +import dragon.globalservices.pool as dragon_gs_pool +from dragon.managed_memory import MemoryAlloc, MemoryPool + +from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey +from smartsim._core.mli.infrastructure.worker.tensorflow_worker import TensorFlowWorker +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + RequestBatch, + TransformInputResult, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +logger = get_logger(__name__) +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +def get_batch() -> np.typing.ArrayLike: + return np.random.randn(20, 28, 28).astype(np.float32) + + +def create_tf_model(): + model = keras.Sequential( + layers=[ + keras.layers.InputLayer(input_shape=(28, 28), name="input"), + keras.layers.Flatten(input_shape=(28, 28), name="flatten"), + keras.layers.Dense(128, activation="relu", name="dense"), + keras.layers.Dense(10, activation="softmax", name="output"), + ], + name="FCN", + ) + + # Compile model with optimizer + model.compile( + optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"] + ) + + real_model = tf.function(model).get_concrete_function( + tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype) + ) + with tf.compat.v1.Session() as sess: + ffunc = convert_var_to_const_function_in_v1(real_model) + graph_def_orig = ffunc.graph.as_graph_def() + + graph_def_str = graph_def_orig.SerializeToString() + + names = lambda l: [x.name for x in l] + + return graph_def_str, names(ffunc.inputs), names(ffunc.outputs) + + +tensorflow_device = {"cpu": "/CPU", "gpu": "/GPU"} + + +def get_request() -> InferenceRequest: + + tensors = [get_batch()] + serialized_tensors_descriptors = [ + MessageHandler.build_tensor_descriptor("c", "float32", list(tensor.shape)) + for tensor in tensors + ] + + return InferenceRequest( + model_key=ModelKey(key="model", descriptor="xyz"), + callback=None, + raw_inputs=tensors, + input_keys=None, + input_meta=serialized_tensors_descriptors, + output_keys=None, + raw_model=create_tf_model()[0], + batch_size=0, + ) + + +def get_request_batch_from_request( + request: InferenceRequest, inputs: t.Optional[TransformInputResult] = None +) -> RequestBatch: + + return RequestBatch([request], inputs, request.model_key) + + +sample_request: InferenceRequest = get_request() +sample_request_batch: RequestBatch = get_request_batch_from_request(sample_request) +worker = TensorFlowWorker() + + +def test_load_model(mlutils) -> None: + fetch_model_result = FetchModelResult(sample_request.raw_model) + load_model_result = worker.load_model( + sample_request_batch, fetch_model_result, mlutils.get_test_device().lower() + ) + + with tf.device(tensorflow_device[mlutils.get_test_device().lower()]): + results = load_model_result.model.run( + load_model_result.outputs, + feed_dict=dict(zip(load_model_result.inputs, [get_batch()])), + ) + + assert results[0].shape == (20, 10) + + +def test_transform_input(mlutils) -> None: + fetch_input_result = FetchInputResult( + sample_request.raw_inputs, sample_request.input_meta + ) + + mem_pool = 
MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) + + transform_input_result = worker.transform_input( + sample_request_batch, [fetch_input_result], mem_pool + ) + + batch = get_batch() + assert transform_input_result.slices[0] == slice(0, batch.shape[0]) + + tensor_index = 0 + assert tuple(transform_input_result.dims[tensor_index]) == batch.shape + assert transform_input_result.dtypes[tensor_index] == str(batch.dtype) + mem_alloc = MemoryAlloc.attach(transform_input_result.transformed[tensor_index]) + itemsize = batch.itemsize + tensor = np.frombuffer( + mem_alloc.get_memview()[ + 0 : np.prod(transform_input_result.dims[tensor_index]) * itemsize + ], + dtype=transform_input_result.dtypes[tensor_index], + ).reshape(transform_input_result.dims[tensor_index]) + + np.testing.assert_allclose(tensor, sample_request.raw_inputs[tensor_index]) + + mem_pool.destroy() + + +def test_execute(mlutils) -> None: + + graph_def_str, inputs, outputs = create_tf_model() + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(graph_def_str) + + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name="") + load_model_result = LoadModelResult( + tf.compat.v1.Session(graph=graph), inputs=inputs, outputs=outputs + ) + + fetch_input_result = FetchInputResult( + sample_request.raw_inputs, sample_request.input_meta + ) + + request_batch = get_request_batch_from_request(sample_request, fetch_input_result) + + mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) + + transform_result = worker.transform_input( + request_batch, [fetch_input_result], mem_pool + ) + + execute_result = worker.execute( + request_batch, + load_model_result, + transform_result, + mlutils.get_test_device().lower(), + ) + + assert all(result.shape == (20, 10) for result in execute_result.predictions) + + mem_pool.destroy() + + +def test_transform_output(mlutils): + tensors = [np.zeros((20, 10))] + execute_result = ExecuteResult(tensors, [slice(0, 20)]) + + transformed_output = worker.transform_output(sample_request_batch, execute_result) + + assert transformed_output[0].outputs == [item.tobytes() for item in tensors] + assert transformed_output[0].shape == None + assert transformed_output[0].order == "c" + assert transformed_output[0].dtype == "float32" diff --git a/tests/dragon/test_torch_worker.py b/tests/dragon/test_torch_worker.py new file mode 100644 index 0000000000..2a9e7d01bd --- /dev/null +++ b/tests/dragon/test_torch_worker.py @@ -0,0 +1,221 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import io +import typing as t + +import numpy as np +import pytest +import torch + +dragon = pytest.importorskip("dragon") +import dragon.globalservices.pool as dragon_gs_pool +from dragon.managed_memory import MemoryAlloc, MemoryPool +from torch import nn +from torch.nn import functional as F + +from smartsim._core.mli.infrastructure.storage.feature_store import ModelKey +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + RequestBatch, + TransformInputResult, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +logger = get_logger(__name__) +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +# simple MNIST in PyTorch +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x, y): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +torch_device = {"cpu": "cpu", "gpu": "cuda"} + + +def get_batch() -> torch.Tensor: + return torch.rand(20, 1, 28, 28) + + +def create_torch_model(): + n = Net() + example_forward_input = get_batch() + module = torch.jit.trace(n, [example_forward_input, example_forward_input]) + model_buffer = io.BytesIO() + torch.jit.save(module, model_buffer) + return model_buffer.getvalue() + + +def get_request() -> InferenceRequest: + + tensors = [get_batch() for _ in range(2)] + tensor_numpy = [tensor.numpy() for tensor in tensors] + serialized_tensors_descriptors = [ + MessageHandler.build_tensor_descriptor("c", "float32", list(tensor.shape)) + for tensor in tensors + ] + + return InferenceRequest( + model_key=ModelKey(key="model", descriptor="xyz"), + callback=None, + raw_inputs=tensor_numpy, + input_keys=None, + input_meta=serialized_tensors_descriptors, + output_keys=None, + raw_model=create_torch_model(), + batch_size=0, + ) + + +def get_request_batch_from_request( + request: InferenceRequest, inputs: t.Optional[TransformInputResult] = None +) -> RequestBatch: + + return RequestBatch([request], inputs, request.model_key) + + +sample_request: InferenceRequest = get_request() +sample_request_batch: RequestBatch = get_request_batch_from_request(sample_request) +worker = TorchWorker() + + +def test_load_model(mlutils) -> None: + fetch_model_result = FetchModelResult(sample_request.raw_model) + load_model_result = worker.load_model( + sample_request_batch, fetch_model_result, 
mlutils.get_test_device().lower() + ) + + assert load_model_result.model( + get_batch().to(torch_device[mlutils.get_test_device().lower()]), + get_batch().to(torch_device[mlutils.get_test_device().lower()]), + ).shape == torch.Size((20, 10)) + + +def test_transform_input(mlutils) -> None: + fetch_input_result = FetchInputResult( + sample_request.raw_inputs, sample_request.input_meta + ) + + mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) + + transform_input_result = worker.transform_input( + sample_request_batch, [fetch_input_result], mem_pool + ) + + batch = get_batch().numpy() + assert transform_input_result.slices[0] == slice(0, batch.shape[0]) + + for tensor_index in range(2): + assert torch.Size(transform_input_result.dims[tensor_index]) == batch.shape + assert transform_input_result.dtypes[tensor_index] == str(batch.dtype) + mem_alloc = MemoryAlloc.attach(transform_input_result.transformed[tensor_index]) + itemsize = batch.itemsize + tensor = torch.from_numpy( + np.frombuffer( + mem_alloc.get_memview()[ + 0 : np.prod(transform_input_result.dims[tensor_index]) * itemsize + ], + dtype=transform_input_result.dtypes[tensor_index], + ).reshape(transform_input_result.dims[tensor_index]) + ) + + assert torch.equal( + tensor, torch.from_numpy(sample_request.raw_inputs[tensor_index]) + ) + + mem_pool.destroy() + + +def test_execute(mlutils) -> None: + load_model_result = LoadModelResult( + Net().to(torch_device[mlutils.get_test_device().lower()]) + ) + fetch_input_result = FetchInputResult( + sample_request.raw_inputs, sample_request.input_meta + ) + + request_batch = get_request_batch_from_request(sample_request, fetch_input_result) + + mem_pool = MemoryPool.attach(dragon_gs_pool.create(1024**2).sdesc) + + transform_result = worker.transform_input( + request_batch, [fetch_input_result], mem_pool + ) + + execute_result = worker.execute( + request_batch, + load_model_result, + transform_result, + mlutils.get_test_device().lower(), + ) + + assert all( + result.shape == torch.Size((20, 10)) for result in execute_result.predictions + ) + + mem_pool.destroy() + + +def test_transform_output(mlutils): + tensors = [torch.rand((20, 10)) for _ in range(2)] + execute_result = ExecuteResult(tensors, [slice(0, 20)]) + + transformed_output = worker.transform_output(sample_request_batch, execute_result) + + assert transformed_output[0].outputs == [item.numpy().tobytes() for item in tensors] + assert transformed_output[0].shape == None + assert transformed_output[0].order == "c" + assert transformed_output[0].dtype == "float32" diff --git a/tests/dragon/test_worker_manager.py b/tests/dragon/test_worker_manager.py new file mode 100644 index 0000000000..4047a731fc --- /dev/null +++ b/tests/dragon/test_worker_manager.py @@ -0,0 +1,314 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import io +import logging +import pathlib +import time + +import pytest + +torch = pytest.importorskip("torch") +dragon = pytest.importorskip("dragon") + +import multiprocessing as mp + +try: + mp.set_start_method("dragon") +except Exception: + pass + +import os + +import torch.nn as nn +from dragon import fli + +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.comm.channel.dragon_util import create_local +from smartsim._core.mli.infrastructure.control.worker_manager import ( + EnvironmentConfigLoader, + WorkerManager, +) +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_feature_store import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.storage.dragon_util import create_ddict +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +from .utils.channel import FileSystemCommChannel + +logger = get_logger(__name__) +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +class MiniModel(nn.Module): + """A torch model that can be executed by the default torch worker""" + + def __init__(self): + """Initialize the model.""" + super().__init__() + + self._name = "mini-model" + self._net = torch.nn.Linear(2, 1) + + def forward(self, input): + """Execute a forward pass.""" + return self._net(input) + + @property + def bytes(self) -> bytes: + """Retrieve the serialized model + + :returns: The byte stream of the model file + """ + buffer = io.BytesIO() + scripted = torch.jit.trace(self._net, self.get_batch()) + torch.jit.save(scripted, buffer) + return buffer.getvalue() + + @classmethod + def get_batch(cls) -> "torch.Tensor": + """Generate a single batch of data with the correct + shape for inference. + + :returns: The batch as a torch tensor + """ + return torch.randn((100, 2), dtype=torch.float32) + + +def create_model(model_path: pathlib.Path) -> pathlib.Path: + """Create a simple torch model and persist to disk for + testing purposes. 
+ + :param model_path: The path to the torch model file + """ + if not model_path.parent.exists(): + model_path.parent.mkdir(parents=True, exist_ok=True) + + model_path.unlink(missing_ok=True) + + mini_model = MiniModel() + torch.save(mini_model, model_path) + + return model_path + + +def load_model() -> bytes: + """Create a simple torch model in memory for testing.""" + mini_model = MiniModel() + return mini_model.bytes + + +def mock_messages( + feature_store_root_dir: pathlib.Path, + comm_channel_root_dir: pathlib.Path, + kill_queue: mp.Queue, +) -> None: + """Mock event producer for triggering the inference pipeline. + + :param feature_store_root_dir: Path to a directory where a + FileSystemFeatureStore can read & write results + :param comm_channel_root_dir: Path to a directory where a + FileSystemCommChannel can read & write messages + :param kill_queue: Queue used by unit test to stop mock_message process + """ + feature_store_root_dir.mkdir(parents=True, exist_ok=True) + comm_channel_root_dir.mkdir(parents=True, exist_ok=True) + + iteration_number = 0 + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + backbone = config_loader.get_backbone() + + worker_queue = config_loader.get_queue() + if worker_queue is None: + queue_desc = config_loader._queue_descriptor + logger.warn( + f"FLI input queue not loaded correctly from config_loader: {queue_desc}" + ) + + model_key = "mini-model" + model_bytes = load_model() + backbone[model_key] = model_bytes + + while True: + if not kill_queue.empty(): + return + iteration_number += 1 + time.sleep(1) + + channel_key = comm_channel_root_dir / f"{iteration_number}/channel.txt" + callback_channel = FileSystemCommChannel(pathlib.Path(channel_key)) + + batch = MiniModel.get_batch() + shape = batch.shape + batch_bytes = batch.numpy().tobytes() + + logger.debug(f"Model content: {backbone[model_key][:20]}") + + input_descriptor = MessageHandler.build_tensor_descriptor( + "f", "float32", list(shape) + ) + + # The first request is always the metadata... + request = MessageHandler.build_request( + reply_channel=callback_channel.descriptor, + model=MessageHandler.build_model(model_bytes, "mini-model", "1.0"), + inputs=[input_descriptor], + outputs=[], + output_descriptors=[], + custom_attributes=None, + ) + request_bytes = MessageHandler.serialize_request(request) + fli: DragonFLIChannel = worker_queue + + with fli._fli.sendh(timeout=None, stream_channel=fli._channel) as sendh: + sendh.send_bytes(request_bytes) + sendh.send_bytes(batch_bytes) + + logger.info("published message") + + if iteration_number > 5: + return + + +def mock_mli_infrastructure_mgr() -> None: + """Create resources normally instanatiated by the infrastructure + management portion of the DragonBackend. + """ + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + + integrated_worker = TorchWorker + + worker_manager = WorkerManager( + config_loader, + integrated_worker, + as_service=True, + cooldown=10, + device="cpu", + dispatcher_queue=mp.Queue(maxsize=0), + ) + worker_manager.execute() + + +@pytest.fixture +def prepare_environment(test_dir: str) -> pathlib.Path: + """Cleanup prior outputs to run demo repeatedly. 
+ + :param test_dir: the directory to prepare + :returns: The path to the log file + """ + path = pathlib.Path(f"{test_dir}/workermanager.log") + logging.basicConfig(filename=path.absolute(), level=logging.DEBUG) + return path + + +def test_worker_manager(prepare_environment: pathlib.Path) -> None: + """Test the worker manager. + + :param prepare_environment: Pass this fixture to configure + global resources before the worker manager executes + """ + + test_path = prepare_environment + fs_path = test_path / "feature_store" + comm_path = test_path / "comm_store" + + mgr_per_node = 1 + num_nodes = 2 + mem_per_node = 128 * 1024**2 + + storage = create_ddict(num_nodes, mgr_per_node, mem_per_node) + backbone = BackboneFeatureStore(storage, allow_reserved_writes=True) + + to_worker_channel = create_local() + to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) + + to_worker_fli_comm_channel = DragonFLIChannel(to_worker_fli) + + # NOTE: env vars must be set prior to instantiating EnvironmentConfigLoader + # or test environment may be unable to send messages w/queue + os.environ[BackboneFeatureStore.MLI_WORKER_QUEUE] = ( + to_worker_fli_comm_channel.descriptor + ) + os.environ[BackboneFeatureStore.MLI_BACKBONE] = backbone.descriptor + + config_loader = EnvironmentConfigLoader( + featurestore_factory=DragonFeatureStore.from_descriptor, + callback_factory=FileSystemCommChannel.from_descriptor, + queue_factory=DragonFLIChannel.from_descriptor, + ) + integrated_worker_type = TorchWorker + + worker_manager = WorkerManager( + config_loader, + integrated_worker_type, + as_service=True, + cooldown=5, + device="cpu", + dispatcher_queue=mp.Queue(maxsize=0), + ) + + worker_queue = config_loader.get_queue() + if worker_queue is None: + logger.warn( + f"FLI input queue not loaded correctly from config_loader: {config_loader._queue_descriptor}" + ) + backbone.worker_queue = to_worker_fli_comm_channel.descriptor + + # create a mock client application to populate the request queue + kill_queue = mp.Queue() + msg_pump = mp.Process( + target=mock_messages, + args=(fs_path, comm_path, kill_queue), + ) + msg_pump.start() + + # create a process to execute commands + process = mp.Process(target=mock_mli_infrastructure_mgr) + + # let it send some messages before starting the worker manager + msg_pump.join(timeout=5) + process.start() + msg_pump.join(timeout=5) + kill_queue.put_nowait("kill!") + process.join(timeout=5) + msg_pump.kill() + process.kill() diff --git a/tests/dragon/utils/__init__.py b/tests/dragon/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/dragon/utils/channel.py b/tests/dragon/utils/channel.py new file mode 100644 index 0000000000..4c46359c2d --- /dev/null +++ b/tests/dragon/utils/channel.py @@ -0,0 +1,125 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import base64
+import pathlib
+import threading
+import typing as t
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim.error.errors import SmartSimError
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class FileSystemCommChannel(CommChannelBase):
+    """Passes messages by writing to a file"""
+
+    def __init__(self, key: pathlib.Path) -> None:
+        """Initialize the FileSystemCommChannel instance.
+
+        :param key: path to the file used to pass messages
+        """
+        self._lock = threading.RLock()
+
+        super().__init__(key.as_posix())
+        self._file_path = key
+
+        if not self._file_path.parent.exists():
+            self._file_path.parent.mkdir(parents=True)
+
+        self._file_path.touch()
+
+    def send(self, value: bytes, timeout: float = 0) -> None:
+        """Send a message through the underlying communication channel.
+
+        :param value: The value to send
+        :param timeout: maximum time to wait (in seconds) for messages to send
+        """
+        with self._lock:
+            # write as text so we can add newlines as delimiters
+            with open(self._file_path, "a") as fp:
+                encoded_value = base64.b64encode(value).decode("utf-8")
+                fp.write(f"{encoded_value}\n")
+                logger.debug(f"FileSystemCommChannel {self._file_path} sent message")
+
+    def recv(self, timeout: float = 0) -> t.List[bytes]:
+        """Receives message(s) through the underlying communication channel.
+
+        :param timeout: maximum time to wait (in seconds) for messages to arrive
+        :returns: the received message
+        :raises SmartSimError: if the descriptor points to a missing file
+        """
+        with self._lock:
+            messages: t.List[bytes] = []
+            if not self._file_path.exists():
+                raise SmartSimError("Empty channel")
+
+            # read as text so we can split on newlines
+            with open(self._file_path, "r") as fp:
+                lines = fp.readlines()
+
+            if lines:
+                line = lines.pop(0)
+                event_bytes = base64.b64decode(line.encode("utf-8"))
+                messages.append(event_bytes)
+
+                self.clear()
+
+                # remove the first message only, write remainder back...
+                if len(lines) > 0:
+                    with open(self._file_path, "w") as fp:
+                        fp.writelines(lines)
+
+                logger.debug(
+                    f"FileSystemCommChannel {self._file_path} received message"
+                )
+
+            return messages
+
+    def clear(self) -> None:
+        """Create an empty file for events."""
+        if self._file_path.exists():
+            self._file_path.unlink()
+            self._file_path.touch()
+
+    @classmethod
+    def from_descriptor(
+        cls,
+        descriptor: str,
+    ) -> "FileSystemCommChannel":
+        """A factory method that creates an instance from a descriptor string.
+ + :param descriptor: The descriptor that uniquely identifies the resource + :returns: An attached FileSystemCommChannel + """ + try: + path = pathlib.Path(descriptor) + return FileSystemCommChannel(path) + except: + logger.warning(f"failed to create fs comm channel: {descriptor}") + raise diff --git a/tests/dragon/utils/msg_pump.py b/tests/dragon/utils/msg_pump.py new file mode 100644 index 0000000000..f3beaa8134 --- /dev/null +++ b/tests/dragon/utils/msg_pump.py @@ -0,0 +1,227 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import io +import logging +import pathlib +import sys +import time +import typing as t + +import pytest + +pytest.importorskip("torch") +pytest.importorskip("dragon") + + +# isort: off +import dragon +import multiprocessing as mp +import torch +import torch.nn as nn + +# isort: on + +from smartsim._core.mli.comm.channel.dragon_fli import DragonFLIChannel +from smartsim._core.mli.infrastructure.storage.backbone_feature_store import ( + BackboneFeatureStore, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +logger = get_logger(__name__, log_level=logging.DEBUG) + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + +try: + mp.set_start_method("dragon") +except Exception: + pass + + +class MiniModel(nn.Module): + def __init__(self): + super().__init__() + + self._name = "mini-model" + self._net = torch.nn.Linear(2, 1) + + def forward(self, input): + return self._net(input) + + @property + def bytes(self) -> bytes: + """Returns the model serialized to a byte stream""" + buffer = io.BytesIO() + scripted = torch.jit.trace(self._net, self.get_batch()) + torch.jit.save(scripted, buffer) + return buffer.getvalue() + + @classmethod + def get_batch(cls) -> "torch.Tensor": + return torch.randn((100, 2), dtype=torch.float32) + + +def load_model() -> bytes: + """Create a simple torch model in memory for testing""" + mini_model = MiniModel() + return mini_model.bytes + + +def persist_model_file(model_path: pathlib.Path) -> pathlib.Path: + """Create a simple torch model and persist to disk for + testing purposes. 
+ + :returns: Path to the model file + """ + # test_path = pathlib.Path(work_dir) + if not model_path.parent.exists(): + model_path.parent.mkdir(parents=True, exist_ok=True) + + model_path.unlink(missing_ok=True) + + model = torch.nn.Linear(2, 1) + torch.save(model, model_path) + + return model_path + + +def _mock_messages( + dispatch_fli_descriptor: str, + fs_descriptor: str, + parent_iteration: int, + callback_descriptor: str, +) -> None: + """Mock event producer for triggering the inference pipeline.""" + model_key = "mini-model" + # mock_message sends 2 messages, so we offset by 2 * (# of iterations in caller) + offset = 2 * parent_iteration + + feature_store = BackboneFeatureStore.from_descriptor(fs_descriptor) + request_dispatcher_queue = DragonFLIChannel.from_descriptor(dispatch_fli_descriptor) + + feature_store[model_key] = load_model() + + for iteration_number in range(2): + logged_iteration = offset + iteration_number + logger.debug(f"Sending mock message {logged_iteration}") + + output_key = f"output-{iteration_number}" + + tensor = ( + (iteration_number + 1) * torch.ones((1, 2), dtype=torch.float32) + ).numpy() + fsd = feature_store.descriptor + + tensor_desc = MessageHandler.build_tensor_descriptor( + "c", "float32", list(tensor.shape) + ) + + message_tensor_output_key = MessageHandler.build_tensor_key(output_key, fsd) + message_model_key = MessageHandler.build_model_key(model_key, fsd) + + request = MessageHandler.build_request( + reply_channel=callback_descriptor, + model=message_model_key, + inputs=[tensor_desc], + outputs=[message_tensor_output_key], + output_descriptors=[], + custom_attributes=None, + ) + + logger.info(f"Sending request {iteration_number} to request_dispatcher_queue") + request_bytes = MessageHandler.serialize_request(request) + + logger.info("Sending msg_envelope") + + # cuid = request_dispatcher_queue._channel.cuid + # logger.info(f"\tInternal cuid: {cuid}") + + # send the header & body together so they arrive together + try: + request_dispatcher_queue.send_multiple( + [request_bytes, tensor.tobytes()], timeout=None, handle_timeout=None + ) + logger.info(f"\tenvelope 0: {request_bytes[:5]}...") + logger.info(f"\tenvelope 1: {tensor.tobytes()[:5]}...") + except Exception as ex: + logger.exception("Unable to send request envelope") + + logger.info("All messages sent") + + # keep the process alive for an extra 15 seconds to let the processor + # have access to the channels before they're destroyed + for _ in range(15): + time.sleep(1) + + +def mock_messages( + dispatch_fli_descriptor: str, + fs_descriptor: str, + parent_iteration: int, + callback_descriptor: str, +) -> int: + """Mock event producer for triggering the inference pipeline. 
+    Used when started via multiprocessing."""
+    logger.info(f"{dispatch_fli_descriptor=}")
+    logger.info(f"{fs_descriptor=}")
+    logger.info(f"{parent_iteration=}")
+    logger.info(f"{callback_descriptor=}")
+
+    try:
+        _mock_messages(
+            dispatch_fli_descriptor,
+            fs_descriptor,
+            parent_iteration,
+            callback_descriptor,
+        )
+    except Exception:
+        logger.exception("The mock message pump failed")
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    import argparse
+
+    args = argparse.ArgumentParser()
+
+    args.add_argument("--dispatch-fli-descriptor", type=str)
+    args.add_argument("--fs-descriptor", type=str)
+    args.add_argument("--parent-iteration", type=int)
+    args.add_argument("--callback-descriptor", type=str)
+
+    args = args.parse_args()
+
+    return_code = mock_messages(
+        args.dispatch_fli_descriptor,
+        args.fs_descriptor,
+        args.parent_iteration,
+        args.callback_descriptor,
+    )
+    sys.exit(return_code)
diff --git a/tests/dragon/utils/worker.py b/tests/dragon/utils/worker.py
new file mode 100644
index 0000000000..0582cae566
--- /dev/null
+++ b/tests/dragon/utils/worker.py
@@ -0,0 +1,104 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import io
+import typing as t
+
+import torch
+
+import smartsim._core.mli.infrastructure.worker.worker as mliw
+import smartsim.error as sse
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class IntegratedTorchWorker(mliw.MachineLearningWorkerBase):
+    """A minimum implementation of a worker that executes a PyTorch model"""
+
+    # @staticmethod
+    # def deserialize(request: InferenceRequest) -> t.List[t.Any]:
+    #     # request.input_meta
+    #     # request.raw_inputs
+    #     return request
+
+    @staticmethod
+    def load_model(
+        request: mliw.InferenceRequest, fetch_result: mliw.FetchModelResult, device: str
+    ) -> mliw.LoadModelResult:
+        model_bytes = fetch_result.model_bytes or request.raw_model
+        if not model_bytes:
+            raise ValueError("Unable to load model without reference object")
+
+        model: torch.nn.Module = torch.load(io.BytesIO(model_bytes))
+        result = mliw.LoadModelResult(model)
+        return result
+
+    @staticmethod
+    def transform_input(
+        request: mliw.InferenceRequest,
+        fetch_result: mliw.FetchInputResult,
+        device: str,
+    ) -> mliw.TransformInputResult:
+        # extra metadata for assembly can be found in request.input_meta
+        raw_inputs = request.raw_inputs or fetch_result.inputs
+
+        result: t.List[torch.Tensor] = []
+        # should this happen here?
+        # consider - fortran to c data layout
+        # is there an intermediate representation before really doing torch.load?
+        if raw_inputs:
+            result = [torch.load(io.BytesIO(item)) for item in raw_inputs]
+
+        return mliw.TransformInputResult(result)
+
+    @staticmethod
+    def execute(
+        request: mliw.InferenceRequest,
+        load_result: mliw.LoadModelResult,
+        transform_result: mliw.TransformInputResult,
+    ) -> mliw.ExecuteResult:
+        if not load_result.model:
+            raise sse.SmartSimError("Model must be loaded to execute")
+
+        model = load_result.model
+        results = [model(tensor) for tensor in transform_result.transformed]
+
+        execute_result = mliw.ExecuteResult(results)
+        return execute_result
+
+    @staticmethod
+    def transform_output(
+        request: mliw.InferenceRequest,
+        execute_result: mliw.ExecuteResult,
+        result_device: str,
+    ) -> mliw.TransformOutputResult:
+        # send the original tensors...
+        execute_result.predictions = [
+            tensor.detach() for tensor in execute_result.predictions
+        ]
+        # todo: solve sending all tensor metadata that coincides with each prediction
+        return mliw.TransformOutputResult(
+            execute_result.predictions, [1], "c", "float32"
+        )
diff --git a/tests/test_dragon_runsettings.py b/tests/test_dragon_runsettings.py
new file mode 100644
index 0000000000..8c7600c74c
--- /dev/null
+++ b/tests/test_dragon_runsettings.py
@@ -0,0 +1,217 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pytest
+
+from smartsim.settings import DragonRunSettings
+
+# The tests in this file belong to the group_a group
+pytestmark = pytest.mark.group_a
+
+
+def test_dragon_runsettings_nodes():
+    """Verify that node count is set correctly"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    exp_value = 3
+    rs.set_nodes(exp_value)
+    assert rs.run_args["nodes"] == exp_value
+
+    exp_value = 9
+    rs.set_nodes(exp_value)
+    assert rs.run_args["nodes"] == exp_value
+
+
+def test_dragon_runsettings_tasks_per_node():
+    """Verify that tasks per node is set correctly"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    exp_value = 3
+    rs.set_tasks_per_node(exp_value)
+    assert rs.run_args["tasks-per-node"] == exp_value
+
+    exp_value = 7
+    rs.set_tasks_per_node(exp_value)
+    assert rs.run_args["tasks-per-node"] == exp_value
+
+
+def test_dragon_runsettings_cpu_affinity():
+    """Verify that the CPU affinity is set correctly"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    exp_value = [0, 1, 2, 3]
+    rs.set_cpu_affinity([0, 1, 2, 3])
+    assert rs.run_args["cpu-affinity"] == ",".join(str(val) for val in exp_value)
+
+    # ensure the value is not changed when we extend the list
+    exp_value.extend([4, 5, 6])
+    assert rs.run_args["cpu-affinity"] != ",".join(str(val) for val in exp_value)
+
+    rs.set_cpu_affinity(exp_value)
+    assert rs.run_args["cpu-affinity"] == ",".join(str(val) for val in exp_value)
+
+    # ensure a direct overwrite of the run arg replaces the previously set value
+    rs.run_args["cpu-affinity"] = "7,8,9"
+    assert rs.run_args["cpu-affinity"] != ",".join(str(val) for val in exp_value)
+
+
+def test_dragon_runsettings_gpu_affinity():
+    """Verify that the GPU affinity is set correctly"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    exp_value = [0, 1, 2, 3]
+    rs.set_gpu_affinity([0, 1, 2, 3])
+    assert rs.run_args["gpu-affinity"] == ",".join(str(val) for val in exp_value)
+
+    # ensure the value is not changed when we extend the list
+    exp_value.extend([4, 5, 6])
+    assert rs.run_args["gpu-affinity"] != ",".join(str(val) for val in exp_value)
+
+    rs.set_gpu_affinity(exp_value)
+    assert rs.run_args["gpu-affinity"] == ",".join(str(val) for val in exp_value)
+
+    # ensure a direct overwrite of the run arg replaces the previously set value
+    rs.run_args["gpu-affinity"] = "7,8,9"
+    assert rs.run_args["gpu-affinity"] != ",".join(str(val) for val in exp_value)
+
+
+def test_dragon_runsettings_hostlist_null():
+    """Verify that passing a null hostlist is treated as a failure"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    # baseline check that no host list exists
+    stored_list = rs.run_args.get("host-list", None)
+    assert stored_list is None
+
+    with pytest.raises(ValueError) as ex:
+        rs.set_hostlist(None)
+
+    assert "empty hostlist" in ex.value.args[0]
+
+
+def test_dragon_runsettings_hostlist_empty():
+    """Verify that passing an empty hostlist is treated as a failure"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    # baseline check that no host list exists
+    stored_list = rs.run_args.get("host-list", None)
+    assert stored_list is None
+
+    with pytest.raises(ValueError) as ex:
+        rs.set_hostlist([])
+
+    assert "empty hostlist" in ex.value.args[0]
+
+
+@pytest.mark.parametrize("hostlist_csv", [" ", " , , , ", ",", ",,,"])
+def test_dragon_runsettings_hostlist_whitespace_handling(hostlist_csv: str):
+    """Verify that passing a hostlist with empty-string host names is treated as a failure"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    # baseline check that no host list exists
+    stored_list = rs.run_args.get("host-list", None)
+    assert stored_list is None
+
+    # empty string as hostname in list
+    with pytest.raises(ValueError) as ex:
+        rs.set_hostlist(hostlist_csv)
+
+    assert "invalid names" in ex.value.args[0]
+
+
+@pytest.mark.parametrize(
+    "hostlist_csv", [[" "], [" ", "", " ", " "], ["", " "], ["", "", "", ""]]
+)
+def test_dragon_runsettings_hostlist_whitespace_handling_list(hostlist_csv: list):
+    """Verify that passing a hostlist with empty-string host names contained in a list
+    is treated as a failure"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    # baseline check that no host list exists
+    stored_list = rs.run_args.get("host-list", None)
+    assert stored_list is None
+
+    # empty string as hostname in list
+    with pytest.raises(ValueError) as ex:
+        rs.set_hostlist(hostlist_csv)
+
+    assert "invalid names" in ex.value.args[0]
+
+
+def test_dragon_runsettings_hostlist_as_csv():
+    """Verify that a hostlist is stored properly when passing in a CSV string"""
+    rs = DragonRunSettings(exe="sleep", exe_args=["1"])
+
+    # baseline check that no host list exists
+    stored_list = rs.run_args.get("host-list", None)
+    assert stored_list is None
+
+    hostnames = ["host0", "host1", "host2", "host3", "host4"]
+
+    # set the host list with ideal comma separated values
+    input0 = ",".join(hostnames)
+
+    # set the host list with a string of comma separated values
+    # including extra whitespace
+    input1 = ", ".join(hostnames)
+
+    for hosts_input in [input0, input1]:
+        rs.set_hostlist(hosts_input)
+
+        stored_list = rs.run_args.get("host-list", None)
+        assert stored_list
+
+        # confirm that all values from the original list are retrieved
+        split_stored_list = stored_list.split(",")
+        assert set(hostnames) == set(split_stored_list)
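For readers unfamiliar with the DragonRunSettings host-list API exercised by the tests above, a minimal usage sketch follows. It is hedged: it assumes only the behavior asserted in test_dragon_runsettings.py (CSV strings and lists are accepted, the value is stored comma-joined under "host-list", and empty or whitespace-only host names raise ValueError); the exact error messages and normalization rules live in DragonRunSettings.set_hostlist.

    from smartsim.settings import DragonRunSettings

    rs = DragonRunSettings(exe="sleep", exe_args=["1"])

    # hosts may be given as a CSV string (extra whitespace tolerated) or as a list
    rs.set_hostlist("host0, host1, host2")
    assert set(rs.run_args["host-list"].split(",")) == {"host0", "host1", "host2"}

    # empty or whitespace-only host names are rejected
    try:
        rs.set_hostlist([])
    except ValueError as exc:
        print(exc)  # mentions "empty hostlist", per the tests above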