diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 80323ff..7a6639d 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -13,10 +13,15 @@ jobs:
     steps:
     - uses: actions/checkout@v2
-    - name: Set up Python 3.9
+    # - name: Setup GStreamer
+    #   run: sudo apt-get install -y --no-install-recommends ca-certificates build-essential \
+    #     python3-gst-1.0 gstreamer1.0-python3-plugin-loader gstreamer1.0-plugins-base gstreamer1.0-tools gstreamer1.0-libav \
+    #     gstreamer1.0-plugins-base-apps gstreamer1.0-plugins-good gstreamer1.0-plugins-bad gstreamer1.0-plugins-ugly \
+    #     gstreamer1.0-x gstreamer1.0-alsa gstreamer1.0-gl gstreamer1.0-gtk3 gstreamer1.0-qt5 gstreamer1.0-pulseaudio
+    - name: Set up Python 3.12
       uses: actions/setup-python@v2
       with:
-        python-version: 3.9
+        python-version: 3.12
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
@@ -31,10 +36,10 @@ jobs:
         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=121 --statistics
     - name: Run verify program
       run: |
-        PYTHONPATH=. python stream/verify.py
+        PYTHONPATH=. python monaistream/verify.py
     # - name: Test with unittest
     #   run: |
-    #     ./runtests.sh --coverage
+    #     ./runtests.sh --coverage -u
     # - name: Upload coverage
     #   uses: codecov/codecov-action@v1
     #   with:
diff --git a/README.md b/README.md
index 302ae27..f127169 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ Experimental revival of this repo, please comment but don't expect anything to b
 
 2. Run `xhost +local:docker` to grant X permissions to Docker containers
 
-3. Run: `docker run -ti --rm -e DISPLAY --gpus device=1 -v $PWD:/opt/monaistream monaistream`
+3. Run: `docker run -ti --rm -e DISPLAY --network host --gpus device=1 -v $PWD:/opt/monaistream monaistream`
 
 ## Numpy Transform Test
 
diff --git a/monaistream/__init__.py b/monaistream/__init__.py
index cc92023..ccfccac 100644
--- a/monaistream/__init__.py
+++ b/monaistream/__init__.py
@@ -14,3 +14,4 @@
 from .threadsafe import *
 
 from .verify import *
+from .callable_evaluator import *
\ No newline at end of file
diff --git a/monaistream/simple_inference.py b/monaistream/callable_evaluator.py
similarity index 66%
rename from monaistream/simple_inference.py
rename to monaistream/callable_evaluator.py
index 1c32f13..a072ac1 100644
--- a/monaistream/simple_inference.py
+++ b/monaistream/callable_evaluator.py
@@ -9,59 +9,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Callable, Sequence
-from monai.inferers import Inferer
-from monai.transforms import apply_transform, Transform
-from monai.engines import SupervisedEvaluator, default_metric_cmp_fn, default_prepare_batch
-from monai.utils import ForwardMode,CommonKeys
-from monai.data import Dataset
-from monai.handlers import MeanSquaredError, from_engine
+import logging
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
+
 import torch
-from torch.nn import Module
-from monai.utils import IgniteInfo, min_version, optional_import
+from monai.data import Dataset, DataLoader
+from monai.engines import SupervisedEvaluator, default_metric_cmp_fn, default_prepare_batch
+from monai.inferers import Inferer
+from monai.transforms import Transform
+from monai.utils import CommonKeys, ForwardMode, min_version, optional_import
+from monai.utils.enums import IgniteInfo
 
 if TYPE_CHECKING:
-    from ignite.engine import Engine, EventEnum
+    from ignite.engine import Engine, Events, EventEnum
     from ignite.metrics import Metric
 else:
     version = IgniteInfo.OPT_IMPORT_VERSION
     Engine, _ = optional_import("ignite.engine", version, min_version, "Engine", as_type="decorator")
     Metric, _ = optional_import("ignite.metrics", version, min_version, "Metric", as_type="decorator")
+    Events, _ = optional_import("ignite.engine", version, min_version, "Events", as_type="decorator")
     EventEnum, _ = optional_import("ignite.engine", version, min_version, "EventEnum", as_type="decorator")
 
 
-class SimpleInferenceEngine:
-    """
-    A simple engine-like class is for running inference on a per-input basis, such as with per-frame data in a
-    video stream. It relies on a supplied Inferer instance and a network.
-    """
-
-    def __init__(
-        self, inferer: Inferer, network: Module, preprocess: Callable | None = None, postprocess: Callable | None = None
-    ):
-        self.inferer = inferer
-        self.network = network
-        self.preprocess = preprocess
-        self.postprocess = postprocess
-
-    def __call__(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> Any:
-        if self.preprocess:
-            inputs = apply_transform(self.preprocess, inputs)
-
-        outputs = self.inferer(inputs, self.network, *args, **kwargs)
-
-        if self.postprocess:
-            outputs = apply_transform(self.postprocess, outputs)
-
-        return outputs
+__all__ = ["SingleItemDataset", "RingBufferDataset", "CallableEvaluator"]
 
 
 class SingleItemDataset(Dataset):
     """
-    This simple dataset only ever has one item and acts as its own iterable. This is used with InferenceEngine to 
+    This simple dataset only ever has one item and acts as its own iterable. This is used with CallableEvaluator to
    represent a changeable single item epoch.
    """
+
    def __init__(self, transform: Sequence[Callable] | Callable | None = None) -> None:
        super().__init__([None], transform)

@@ -69,10 +48,31 @@ def set_item(self, item):
         self.data[0] = item
 
     def __iter__(self):
-        yield self.data[0]
+        item = self[0]
+
+        # TODO: use standard way of adding batch dimensions, or do something specific here
+        # for how groups of frames would be passed?
+        if isinstance(item, torch.Tensor):
+            yield item[None]
+        elif isinstance(item, Sequence):
+            yield tuple(v[None] for v in item)
+        else:
+            yield {k: v[None] for k, v in item.items()}
+
+
+class RingBufferDataset(SingleItemDataset):
+    def __init__(self, num_values=1, transform: Sequence[Callable] | Callable | None = None) -> None:
+        super().__init__(transform)
+        self.num_values = num_values
+
+    def set_item(self, item):
+        if self.data[0] is None:
+            self.data[0] = (item,) * self.num_values
+        else:
+            self.data[0] = self.data[0][1:] + (item,)
 
 
-class InferenceEngine(SupervisedEvaluator):
+class CallableEvaluator(SupervisedEvaluator):
     """
     A simple inference engine type for applying inference to one input at a time as a callable. This is meant to be used
     for inference on per-frame video stream data where the state of the engine and other setup should be done initially
@@ -84,6 +84,7 @@ def __init__(
         self,
         device: torch.device,
         network: torch.nn.Module,
+        data_loader: Iterable | DataLoader | None = None,
         preprocessing: Transform | None = None,
         non_blocking: bool = False,
         prepare_batch: Callable = default_prepare_batch,
@@ -102,12 +103,13 @@ def __init__(
         amp_kwargs: dict | None = None,
         compile: bool = False,
         compile_kwargs: dict | None = None,
+        use_interrupt: bool = False,
     ) -> None:
         super().__init__(
             device=device,
-            val_data_loader=SingleItemDataset(preprocessing),
+            val_data_loader=data_loader if data_loader is not None else SingleItemDataset(preprocessing),
             epoch_length=1,
-            network=network,
+            network=network,  # TODO: auto-convert to given device?
             inferer=inferer,
             non_blocking=non_blocking,
             prepare_batch=prepare_batch,
@@ -128,24 +130,20 @@ def __init__(
             compile_kwargs=compile_kwargs,
         )
 
+        self.logger.setLevel(logging.ERROR)  # probably don't want output for every frame
+        self.use_interrupt = use_interrupt
+
+        if use_interrupt:
+            self.add_event_handler(Events.ITERATION_COMPLETED, self.interrupt)
+
     def __call__(self, item: Any, include_metrics: bool = False) -> Any:
         self.data_loader.set_item(item)
+
         self.run()
 
         out = self.state.output[0][CommonKeys.PRED]
 
         if include_metrics:
-            return out, dict(engine.state.metrics)
+            return out, dict(self.state.metrics)
         else:
             return out
-
-
-if __name__ == "__main__":
-    net = torch.nn.Identity()
-    engine = InferenceEngine(
-        network=net,
-        device="cpu",
-        key_val_metric={"mse": MeanSquaredError(output_transform=from_engine([CommonKeys.IMAGE, CommonKeys.PRED]))},
-    )
-    print(engine(torch.rand(1, 5, 5)))
-    print(engine(torch.rand(1, 6, 6), True))
diff --git a/monaistream/gstreamer/__init__.py b/monaistream/gstreamer/__init__.py
index aceb2d4..c643911 100644
--- a/monaistream/gstreamer/__init__.py
+++ b/monaistream/gstreamer/__init__.py
@@ -32,9 +32,10 @@
     os.environ["GST_PLUGIN_PATH"] = plugin_path  # set the plugin path so that this directory is searched
     gi.require_version("Gst", "1.0")
+    gi.require_version("GObject", "2.0")
     gi.require_version("GstBase", "1.0")
     gi.require_version("GstVideo", "1.0")
-    from gi.repository import Gst
+    from gi.repository import Gst, GLib, GObject, GstBase, GstVideo
 
     Gst.init([])
 
     # use GST_DEBUG instead https://gstreamer.freedesktop.org/documentation/gstreamer/running.html
@@ -42,7 +43,7 @@
     # Gst.debug_set_default_threshold(5)
 
     from monaistream.gstreamer.utils import *
-    from monaistream.gstreamer.numpy_transforms import *
+    # from monaistream.gstreamer.numpy_transforms import *
 
     # TODO: import more things here
 
diff --git a/monaistream/gstreamer/launch.py b/monaistream/gstreamer/launch.py
new file mode 100644
index 0000000..82b6c4e
--- /dev/null
+++ b/monaistream/gstreamer/launch.py
@@ -0,0 +1,87 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from typing import Callable
+from monaistream.gstreamer import Gst, GLib
+
+
+def default_message_handler(bus: Gst.Bus, message: Gst.Message, loop: GLib.MainLoop):
+    if message.type == Gst.MessageType.EOS:
+        loop.quit()
+    elif message.type == Gst.MessageType.ERROR:
+        err, debug = message.parse_error()
+        print(err, debug, file=sys.stderr)
+        loop.quit()
+    elif message.type == Gst.MessageType.WARNING:
+        err, debug = message.parse_warning()
+        print(err, debug, file=sys.stderr)
+
+    return True
+
+
+def default_loop_runner(
+    pipeline: Gst.Pipeline,
+    loop: GLib.MainLoop | None = None,
+    message_handler: Callable | None = default_message_handler,
+):
+    if loop is None:
+        loop = GLib.MainLoop()
+
+    if message_handler is not None:
+        bus = pipeline.get_bus()
+        bus.add_signal_watch()
+        bus.connect("message", message_handler, loop)
+
+    pipeline.set_state(Gst.State.PLAYING)
+
+    try:
+        loop.run()
+    except KeyboardInterrupt:
+        raise
+    finally:
+        pipeline.set_state(Gst.State.NULL)
+        if loop and loop.is_running():
+            loop.quit()
+
+
+def launch(args: list[str] | str, message_handler: Callable = default_message_handler):
+    """
+    A launching program for GStreamer pipelines, like gst-launch-1.0 but with the MONAIStream plugin classes loaded.
+    The `args` argument is a single string (or one-element list) giving the full pipeline description, a list of
+    per-element descriptions to be joined with `!` separators, or the individual words of a command line.
+
+    Example::
+
+        python -m monaistream.gstreamer.launch videotestsrc num-buffers=1 ! jpegenc ! filesink location=img.jpg
+
+    """
+
+    if isinstance(args, str):
+        command = args
+    else:
+        args = list(map(str, args))
+        if not args:
+            raise ValueError("No arguments provided, a list of elements or a pipeline string is required.")
+
+        if len(args) == 1:
+            command = args[0]
+        elif "!" in args:
+            command = " ".join(args)
+        else:
+            command = " ! ".join(args)
+
+    pipeline = Gst.parse_launch(command)
+    default_loop_runner(pipeline, None, message_handler)
+
+
+if __name__ == "__main__":
+    launch(sys.argv[1:])
diff --git a/monaistream/gstreamer/utils.py b/monaistream/gstreamer/utils.py
index 7b92b01..421115c 100644
--- a/monaistream/gstreamer/utils.py
+++ b/monaistream/gstreamer/utils.py
@@ -10,21 +10,33 @@
 # limitations under the License.
 
 from contextlib import contextmanager
+
 import numpy as np
+import torch
+
+from monaistream.gstreamer import Gst, GstVideo
 
-from gi.repository import Gst, GstVideo
 
-__all__ = ["BYTE_FORMATS", "get_dtype_from_bits", "map_buffer_to_numpy"]
+__all__ = [
+    "BYTE_FORMATS",
+    "DEFAULT_CAPS_STR",
+    "get_dtype_from_bits",
+    "map_buffer_to_numpy",
+    "map_buffer_to_tensor",
+    "get_buffer_tensor",
+]
 
 BYTE_FORMATS = "{RGBx,BGRx,xRGB,xBGR,RGBA,BGRA,ARGB,ABGR,RGB,BGR,GRAY8,GRAY16_BE,GRAY16_LE}"
 
+DEFAULT_CAPS_STR = f"video/x-raw,format={BYTE_FORMATS}"
+
 
 def get_video_pad_template(
-    name, direction=Gst.PadDirection.SRC, presence=Gst.PadPresence.ALWAYS, caps_str=f"video/x-raw,format={BYTE_FORMATS}"
+    name, direction=Gst.PadDirection.SRC, presence=Gst.PadPresence.ALWAYS, caps_str=DEFAULT_CAPS_STR
 ):
     """
-    Create a pad from the given template components. 
+    Create a pad from the given template components.
     """
     return Gst.PadTemplate.new(name, direction, presence, Gst.Caps.from_string(caps_str))
 
@@ -47,22 +59,22 @@ def get_components(cformat):
     """
     Get the number of components for each pixel format, including padded components such as in RGBx.
     """
-    if cformat in ("RGB","BGR"):
+    if cformat in ("RGB", "BGR"):
         return 3
-    if cformat in ("RGBx","BGRx","xRGB","xBGR","RGBA","BGRA","ARGB","ABGR"):
+    if cformat in ("RGBx", "BGRx", "xRGB", "xBGR", "RGBA", "BGRA", "ARGB", "ABGR"):
         return 4
-    if cformat in ("GRAY8","GRAY16_BE","GRAY16_LE"):
+    if cformat in ("GRAY8", "GRAY16_BE", "GRAY16_LE"):
         return 1
 
-    raise ValueError(f"Format `{cformat}` does not have a known number of components.")
-    
+    raise ValueError(f"Format `{cformat}` does not have a known number of components.")
+
 
 @contextmanager
 def map_buffer_to_numpy(buffer, flags, caps, dtype=None):
     """
     Map the given buffer with the given flags and the capabilities from its associated pad. The dtype is inferred if
     not given which may be inaccurate for certain formats. The context object is a Numpy array for the buffer which is
-    unmapped when the context exits. 
""" cstruct = caps.get_structure(0) height = cstruct.get_value("height") @@ -75,7 +87,7 @@ def map_buffer_to_numpy(buffer, flags, caps, dtype=None): if dtype is None: dtype = get_dtype_from_bits(ifstruct.bits) - dtype=np.dtype(dtype) + dtype = np.dtype(dtype) is_mapped, map_info = buffer.map(flags) if not is_mapped: raise ValueError(f"Buffer {buffer} failed to map with flags `{flags}`.") @@ -92,8 +104,21 @@ def map_buffer_to_numpy(buffer, flags, caps, dtype=None): # TODO: byte order for gray formats bufarray = np.ndarray(shape, dtype=dtype, buffer=map_info.data) - + try: yield bufarray finally: buffer.unmap(map_info) + + +@contextmanager +def map_buffer_to_tensor(buffer, flags, caps, dtype=None): + with map_buffer_to_numpy(buffer, flags, caps, dtype) as npbuf: + yield torch.as_tensor(npbuf) + + +def get_buffer_tensor(buffer, caps, flags=Gst.MapFlags.WRITE, dtype=None, device="cpu"): + with map_buffer_to_tensor(buffer, flags, caps, dtype) as tbuf: + out = torch.zeros_like(tbuf, device=device) + out[:] = tbuf + return out diff --git a/pyproject.toml b/pyproject.toml index 8066c35..2202622 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [project] -name = "MONAIStream" +name = "monaistream" readme = "README.md" requires-python = ">=3.9" license = {text = "Apache2.0"} diff --git a/requirements-dev.txt b/requirements-dev.txt index 004e447..730702a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,5 @@ -r requirements.txt +pytorch-ignite pep8-naming pycodestyle pyflakes @@ -9,3 +10,4 @@ isort coverage pytype>=2020.6.1; platform_system != "Windows" mypy>=0.790 +parameterized diff --git a/runtests.sh b/runtests.sh index 0b7fd61..0f24b64 100755 --- a/runtests.sh +++ b/runtests.sh @@ -128,7 +128,7 @@ function clean_py() { find ${TO_CLEAN} -maxdepth 1 -type f -name ".coverage.*" -delete find ${TO_CLEAN} -depth -maxdepth 1 -type d -name ".eggs" -exec rm -r "{}" + - find ${TO_CLEAN} -depth -maxdepth 1 -type d -name "monaistream.egg-info" -exec rm -r "{}" + + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name "MONAIStream.egg-info" -exec rm -r "{}" + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name "build" -exec rm -r "{}" + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name "dist" -exec rm -r "{}" + find ${TO_CLEAN} -depth -maxdepth 1 -type d -name ".mypy_cache" -exec rm -r "{}" + @@ -416,6 +416,6 @@ fi # report on coverage if [ $doCoverage = true ]; then echo "${separator}${blue}coverage${noColor}" - ${cmdPrefix}${PY_EXE} -m coverage combine --append .coverage/ - ${cmdPrefix}${PY_EXE} -m coverage report + # ${cmdPrefix}${PY_EXE} -m coverage combine --append .coverage/ + ${cmdPrefix}${PY_EXE} -m coverage report -i fi diff --git a/tests/test_bundles/blur/configs/stream.json b/tests/test_bundles/blur/configs/stream.json new file mode 100644 index 0000000..48dfd5a --- /dev/null +++ b/tests/test_bundles/blur/configs/stream.json @@ -0,0 +1,61 @@ +{ + "imports": [ + "$import numpy" + ], + "bundle_root": ".", + "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", + "network_def": { + "_target_": "GaussianFilter", + "spatial_dims": 2, + "sigma": 10.0 + }, + "network": "$@network_def.to(@device)", + "preprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "EnsureChannelFirst", + "channel_dim": 2 + }, + { + "_target_": "EnsureType", + "dtype": "$torch.float32" + }, + { + "_target_": "ScaleIntensity" + } + ] + }, + "postprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "ScaleIntensityRanged", + "keys": 
"pred", + "a_min": 0, + "a_max": 1.0, + "b_min": 0, + "b_max": 255.0, + "clip": true, + "dtype": "$numpy.uint8" + }, + { + "_target_": "AsChannelLastd", + "keys": "pred" + } + ] + }, + "streamer": { + "_target_": "monaistream.CallableEvaluator", + "device": "@device", + "network": "@network", + "preprocessing": "@preprocessing", + "postprocessing": "@postprocessing", + "use_interrupt": false + }, + "initialize": [], + "run": [ + "@streamer" + ], + "finalize": [] +} diff --git a/tests/test_numpy_transforms.py b/tests/test_numpy_transforms.py deleted file mode 100644 index d7a4c14..0000000 --- a/tests/test_numpy_transforms.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys -from tempfile import TemporaryDirectory -from subprocess import check_call, CalledProcessError -import unittest -from tests.utils import SkipIfNoModule - - -@SkipIfNoModule("gi") -class TestNumpyInplaceTransform(unittest.TestCase): - def test_import(self): - """ - Test importation of the transform. - """ - from monaistream.gstreamer import NumpyInplaceTransform - - def test_pipeline(self): - """ - Test the transform can be loaded with `parse_launchv`. - """ - from gi.repository import Gst - - pipeline = Gst.parse_launchv(["videotestsrc", "numpyinplacetransform"]) - self.assertIsNotNone(pipeline) - - def test_gst_launch(self): - """ - Test launching a separate pipeline subprocess with gst-launch-1.0 correctly imports the transform. - """ - pipeline = "videotestsrc num-buffers=1 ! numpyinplacetransform ! jpegenc ! filesink location=img.jpg" - - with TemporaryDirectory() as td: - try: - check_call(["gst-launch-1.0"] + list(pipeline.split()), cwd=td) - except CalledProcessError as cpe: - print("Output gst-launch-1.0:\n", repr(cpe.output), file=sys.stderr) - raise - - self.assertTrue(os.path.isfile(os.path.join(td, "img.jpg"))) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_stream_runner.py b/tests/test_stream_runner.py new file mode 100644 index 0000000..a890600 --- /dev/null +++ b/tests/test_stream_runner.py @@ -0,0 +1,236 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+import sys
+import logging
+import unittest
+from typing import Any, Callable
+from tempfile import TemporaryDirectory
+
+import torch
+from monai.handlers import MeanSquaredError, from_engine
+from monai.bundle import ConfigWorkflow
+from monai.utils import CommonKeys, first
+from parameterized import parameterized
+
+from monaistream.gstreamer import Gst, GstBase, GObject
+from monaistream.gstreamer.utils import get_video_pad_template, map_buffer_to_tensor, get_buffer_tensor
+from monaistream import SingleItemDataset, RingBufferDataset, CallableEvaluator
+from monaistream.gstreamer.launch import default_loop_runner
+from tests.utils import SkipIfNoModule
+
+
+DEVICES = ["cpu"]
+if torch.cuda.is_available():
+    DEVICES.append("cuda:0")
+
+
+class TensorCallbackTransform(GstBase.BaseTransform):
+    __gstmetadata__ = ("Tensor Callback Transform", "Transform", "Description", "Author")  # TODO: correct info
+
+    __gsttemplates__ = (
+        get_video_pad_template("src", Gst.PadDirection.SRC),
+        get_video_pad_template("sink", Gst.PadDirection.SINK),
+    )
+
+    def __init__(self, trans_fn: Callable | None = None):
+        super().__init__()
+        self.trans_fn = trans_fn
+        self.device = "cpu"
+
+    def do_transform(self, inbuf: Gst.Buffer, outbuf: Gst.Buffer) -> Gst.FlowReturn:
+        intensor = get_buffer_tensor(inbuf, self.sinkpad.get_current_caps(), device=self.device)
+        with map_buffer_to_tensor(outbuf, Gst.MapFlags.WRITE, self.srcpad.get_current_caps()) as outtensor:
+            outtensor[:] = self.trans_fn(intensor)
+
+        return Gst.FlowReturn.OK
+
+
+class TestSingleItemDataset(unittest.TestCase):
+    def setUp(self):
+        self.rand_input = torch.rand(1, 3, 3)
+
+    def test_single_input(self):
+        ds = SingleItemDataset()
+        ds.set_item(self.rand_input)
+        out = first(ds)
+
+        self.assertEqual(out.shape, (1,) + tuple(self.rand_input.shape))
+
+    def test_list_input(self):
+        ds = SingleItemDataset()
+        ds.set_item([self.rand_input] * 2)
+        out = first(ds)
+
+        self.assertIsInstance(out, tuple)
+        self.assertEqual(len(out), 2)
+        self.assertEqual(out[0].shape, (1,) + tuple(self.rand_input.shape))
+        self.assertEqual(out[1].shape, (1,) + tuple(self.rand_input.shape))
+
+
+class TestRingBufferDataset(unittest.TestCase):
+    def setUp(self):
+        self.rand_input = torch.rand(1, 3, 3)
+
+    def test_single_input(self):
+        ds = RingBufferDataset(5)
+        ds.set_item(self.rand_input)
+
+        out = first(ds)
+
+        self.assertIsInstance(out, tuple)
+        self.assertEqual(len(out), 5)
+
+        for i in out:
+            self.assertEqual(i.shape, (1,) + tuple(self.rand_input.shape))
+
+
+@SkipIfNoModule("ignite")
+class TestStreamRunner(unittest.TestCase):
+    def setUp(self):
+        self.rand_input = torch.rand(1, 3, 5)
+        self.bundle_dir = os.path.dirname(__file__) + "/test_bundles/blur"
+        # fileConfig(os.path.join(self.bundle_dir, "configs", "logging.conf"))
+
+    @parameterized.expand(DEVICES)
+    def test_single_input(self, device):
+        net = torch.nn.Identity()
+        engine = CallableEvaluator(network=net, device=device, use_interrupt=False)
+
+        result = engine(self.rand_input)
+
+        self.assertIsInstance(result, torch.Tensor)
+        self.assertEqual(result.shape, self.rand_input.shape)
+        self.assertEqual(result.device, torch.device(device))
+
+    @parameterized.expand(DEVICES)
+    def test_two_inputs(self, device):
+        net = torch.nn.Identity()
+        engine = CallableEvaluator(network=net, device=device, use_interrupt=False)
+
+        result1 = engine(self.rand_input.to(device))
+        result2 = engine(self.rand_input.to(device))
+
+        self.assertIsInstance(result1, torch.Tensor)
+        self.assertIsInstance(result2, torch.Tensor)
+
+        self.assertEqual(result1.shape, self.rand_input.shape)
+        self.assertEqual(result2.shape, self.rand_input.shape)
+
+        self.assertEqual(engine.state.iteration, 1)
+
+    # @parameterized.expand(DEVICES)
+    # def test_ring_buffer(self, device):
+    #     from monai.engines.utils import default_prepare_batch, PrepareBatch
+
+    #     class TuplePrepareBatch(PrepareBatch):
+    #         def __call__(
+    #             self,
+    #             batchdata: dict[str, torch.Tensor],
+    #             device: str | torch.device | None = None,
+    #             non_blocking: bool = False,
+    #             **kwargs: Any,
+    #         ) -> Any:
+    #             assert isinstance(batchdata, tuple)
+    #             batches = [default_prepare_batch(b, device, non_blocking, **kwargs) for b in batchdata]
+    #             if isinstance(batches[0], tuple) and len(batches[0]) == 2:
+    #                 inputs, outputs = zip(*batches)
+    #             else:
+    #                 inputs = tuple(batches)
+    #                 outputs = (None,) * len(inputs)
+
+    #             return tuple(inputs), tuple(outputs)
+
+    #     class FakeMultiInputNet(torch.nn.Module):
+    #         def forward(self, x):
+    #             assert isinstance(x, tuple)
+    #             assert isinstance(x[0], torch.Tensor), type(x[0])
+    #             return torch.as_tensor([i.mean() for i in x])
+
+    #     with self.subTest("Identity Net"):
+    #         engine = CallableEvaluator(
+    #             data_loader=RingBufferDataset(5),
+    #             prepare_batch=TuplePrepareBatch(),
+    #             network=torch.nn.Identity(),
+    #             device=device,
+    #             use_interrupt=False,
+    #         )
+
+    #         result = engine(self.rand_input)
+
+    #         self.assertIsInstance(result, list)
+    #         self.assertEqual(len(result), 5)
+
+    #         for r, _ in result:
+    #             self.assertIsInstance(r, torch.Tensor)
+    #             self.assertEqual(r.shape, self.rand_input.shape)
+    #             self.assertEqual(r.device, torch.device(device))
+
+    #     with self.subTest("MultiInput Net"):
+    #         engine = CallableEvaluator(
+    #             data_loader=RingBufferDataset(5),
+    #             prepare_batch=TuplePrepareBatch(),
+    #             network=FakeMultiInputNet(),
+    #             device=device,
+    #             use_interrupt=False,
+    #         )
+
+    #         result = engine(self.rand_input)
+
+    #         self.assertIsInstance(result, torch.Tensor)
+    #         self.assertEqual(result.shape, (1, 5))
+
+    @parameterized.expand(DEVICES)
+    def test_metric(self, device):
+        net = torch.nn.Identity()
+        metric = MeanSquaredError(output_transform=from_engine([CommonKeys.IMAGE, CommonKeys.PRED]))
+        engine = CallableEvaluator(network=net, device=device, key_val_metric={"mse": metric}, use_interrupt=False)
+
+        result, mets = engine(self.rand_input.to(device), include_metrics=True)
+
+        self.assertIsInstance(result, torch.Tensor)
+        self.assertEqual(result.shape, self.rand_input.shape)
+        self.assertIsInstance(mets, dict)
+        self.assertIn("mse", mets)
+        self.assertEqual(mets["mse"], 0)
+
+    @parameterized.expand(DEVICES)
+    def test_bundle_stream(self, device):
+        bw = ConfigWorkflow(
+            self.bundle_dir + "/configs/stream.json", self.bundle_dir + "/configs/metadata.json", workflow_type="infer"
+        )
+        bw.device = device
+
+        bw.initialize()
+        cb = bw.run()
+        self.assertEqual(len(cb), 1)
+        self.assertIsInstance(cb[0], CallableEvaluator)
+
+        with TemporaryDirectory() as td:
+            RunnerType = GObject.type_register(TensorCallbackTransform)
+            Gst.Element.register(None, "tensorcallbacktransform", Gst.Rank.NONE, RunnerType)
+            img = os.path.join(td, "img.jpg")
+
+            pipeline = Gst.parse_launch(
+                f"videotestsrc num-buffers=1 ! tensorcallbacktransform name=t ! jpegenc ! filesink location={img}"
+            )
+
+            tcbt = pipeline.get_by_name("t")
+            tcbt.device = device
+            tcbt.trans_fn = cb[0]
+
+            default_loop_runner(pipeline, None)
+            self.assertTrue(os.path.isfile(img))
+
+
+if __name__ == "__main__":
+    unittest.main()
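
Usage note: a minimal sketch of driving CallableEvaluator as a per-frame callable, adapted from the `__main__` demo this patch removes from simple_inference.py; it assumes pytorch-ignite is installed (now listed in requirements-dev.txt):

    import torch
    from monai.handlers import MeanSquaredError, from_engine
    from monai.utils import CommonKeys
    from monaistream import CallableEvaluator

    # an identity network stands in for a real model; each call runs one single-item "epoch"
    net = torch.nn.Identity()
    engine = CallableEvaluator(
        network=net,
        device=torch.device("cpu"),
        key_val_metric={"mse": MeanSquaredError(output_transform=from_engine([CommonKeys.IMAGE, CommonKeys.PRED]))},
    )

    print(engine(torch.rand(1, 5, 5)))  # prediction tensor for one frame
    print(engine(torch.rand(1, 6, 6), include_metrics=True))  # (prediction, {"mse": ...})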