Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@ jobs:

steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
# - name: Setup GStreamer
# run: sudo apt-get install -y --no-install-recommends ca-certificates build-essential \
# python3-gst-1.0 gstreamer1.0-python3-plugin-loader gstreamer1.0-plugins-base gstreamer1.0-tools gstreamer1.0-libav \
# gstreamer1.0-plugins-base-apps gstreamer1.0-plugins-good gstreamer1.0-plugins-bad gstreamer1.0-plugins-ugly \
# gstreamer1.0-x gstreamer1.0-alsa gstreamer1.0-gl gstreamer1.0-gtk3 gstreamer1.0-qt5 gstreamer1.0-pulseaudio
- name: Set up Python 3.12
uses: actions/setup-python@v2
with:
python-version: 3.9
python-version: 3.12
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -31,10 +36,10 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=121 --statistics
- name: Run verify program
run: |
PYTHONPATH=. python stream/verify.py
PYTHONPATH=. python monaistream/verify.py
# - name: Test with unittest
# run: |
# ./runtests.sh --coverage
# ./runtests.sh --coverage -u
# - name: Upload coverage
# uses: codecov/codecov-action@v1
# with:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Experimental revival of this repo, please comment but don't expect anything to b

2. Run `xhost +local:docker` to grant X permissions to Docker containers

3. Run: `docker run -ti --rm -e DISPLAY --gpus device=1 -v $PWD:/opt/monaistream monaistream`
3. Run: `docker run -ti --rm -e DISPLAY --network host --gpus device=1 -v $PWD:/opt/monaistream monaistream`


## Numpy Transform Test
Expand Down
1 change: 1 addition & 0 deletions monaistream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@

from .threadsafe import *
from .verify import *
from .callable_evaluator import *
100 changes: 49 additions & 51 deletions monaistream/simple_inference.py → monaistream/callable_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,70 +9,70 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Callable, Sequence
from monai.inferers import Inferer
from monai.transforms import apply_transform, Transform
from monai.engines import SupervisedEvaluator, default_metric_cmp_fn, default_prepare_batch
from monai.utils import ForwardMode,CommonKeys
from monai.data import Dataset
from monai.handlers import MeanSquaredError, from_engine
import logging
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence

import torch
from torch.nn import Module

from monai.utils import IgniteInfo, min_version, optional_import
from monai.data import Dataset, DataLoader
from monai.engines import SupervisedEvaluator, default_metric_cmp_fn, default_prepare_batch
from monai.inferers import Inferer
from monai.transforms import Transform
from monai.utils import CommonKeys, ForwardMode, min_version, optional_import
from monai.utils.enums import IgniteInfo

if TYPE_CHECKING:
from ignite.engine import Engine, EventEnum
from ignite.engine import Engine, Events, EventEnum
from ignite.metrics import Metric
else:
version = IgniteInfo.OPT_IMPORT_VERSION
Engine, _ = optional_import("ignite.engine", version, min_version, "Engine", as_type="decorator")
Metric, _ = optional_import("ignite.metrics", version, min_version, "Metric", as_type="decorator")
Events, _ = optional_import("ignite.engine", version, min_version, "Events", as_type="decorator")
EventEnum, _ = optional_import("ignite.engine", version, min_version, "EventEnum", as_type="decorator")


class SimpleInferenceEngine:
"""
A simple engine-like class is for running inference on a per-input basis, such as with per-frame data in a
video stream. It relies on a supplied Inferer instance and a network.
"""

def __init__(
self, inferer: Inferer, network: Module, preprocess: Callable | None = None, postprocess: Callable | None = None
):
self.inferer = inferer
self.network = network
self.preprocess = preprocess
self.postprocess = postprocess

def __call__(self, inputs: torch.Tensor, *args: Any, **kwargs: Any) -> Any:
if self.preprocess:
inputs = apply_transform(self.preprocess, inputs)

outputs = self.inferer(inputs, self.network, *args, **kwargs)

if self.postprocess:
outputs = apply_transform(self.postprocess, outputs)

return outputs
__all__ = ["SingleItemDataset", "RingBufferDataset", "CallableEvaluator"]


class SingleItemDataset(Dataset):
"""
This simple dataset only ever has one item and acts as its own iterable. This is used with InferenceEngine to
This simple dataset only ever has one item and acts as its own iterable. This is used with InferenceEngine to
represent a changeable single item epoch.
"""

def __init__(self, transform: Sequence[Callable] | Callable | None = None) -> None:
super().__init__([None], transform)

def set_item(self, item):
self.data[0] = item

def __iter__(self):
yield self.data[0]
item = self[0]

# TODO: use standard way of adding batch dimensions, or do something specific here
# for how groups of frames would be passed?
if isinstance(item, torch.Tensor):
yield item[None]
elif isinstance(item, Sequence):
yield tuple(v[None] for v in item)
else:
yield {k: v[None] for k, v in item.items()}


class RingBufferDataset(SingleItemDataset):
def __init__(self, num_values=1, transform: Sequence[Callable] | Callable | None = None) -> None:
super().__init__(transform)
self.num_values = num_values

def set_item(self, item):
if self.data[0] is None:
self.data[0] = (item,) * self.num_values
else:
self.data[0] = self.data[0][1:] + (item,)


class InferenceEngine(SupervisedEvaluator):
class CallableEvaluator(SupervisedEvaluator):
"""
A simple inference engine type for applying inference to one input at a time as a callable. This is meant to be used
for inference on per-frame video stream data where the state of the engine and other setup should be done initially
Expand All @@ -84,6 +84,7 @@ def __init__(
self,
device: torch.device,
network: torch.nn.Module,
data_loader: Iterable | DataLoader | None = None,
preprocessing: Transform | None = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
Expand All @@ -102,12 +103,13 @@ def __init__(
amp_kwargs: dict | None = None,
compile: bool = False,
compile_kwargs: dict | None = None,
use_interrupt: bool = False,
) -> None:
super().__init__(
device=device,
val_data_loader=SingleItemDataset(preprocessing),
val_data_loader=data_loader if data_loader is not None else SingleItemDataset(preprocessing),
epoch_length=1,
network=network,
network=network, # TODO: auto-convert to given device?
inferer=inferer,
non_blocking=non_blocking,
prepare_batch=prepare_batch,
Expand All @@ -128,24 +130,20 @@ def __init__(
compile_kwargs=compile_kwargs,
)

self.logger.setLevel(logging.ERROR) # probably don't want output for every frame
self.use_interrupt = use_interrupt

if use_interrupt:
self.add_event_handler(Events.ITERATION_COMPLETED, self.interrupt)

def __call__(self, item: Any, include_metrics: bool = False) -> Any:
self.data_loader.set_item(item)

self.run()

out = self.state.output[0][CommonKeys.PRED]

if include_metrics:
return out, dict(engine.state.metrics)
return out, dict(self.state.metrics)
else:
return out


if __name__ == "__main__":
net = torch.nn.Identity()
engine = InferenceEngine(
network=net,
device="cpu",
key_val_metric={"mse": MeanSquaredError(output_transform=from_engine([CommonKeys.IMAGE, CommonKeys.PRED]))},
)
print(engine(torch.rand(1, 5, 5)))
print(engine(torch.rand(1, 6, 6), True))
5 changes: 3 additions & 2 deletions monaistream/gstreamer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@
os.environ["GST_PLUGIN_PATH"] = plugin_path # set the plugin path so that this directory is searched

gi.require_version("Gst", "1.0")
gi.require_version("GObject", "2.0")
gi.require_version("GstBase", "1.0")
gi.require_version("GstVideo", "1.0")
from gi.repository import Gst
from gi.repository import Gst, GLib, GObject, GstBase, GstVideo

Gst.init([])
# use GST_DEBUG instead https://gstreamer.freedesktop.org/documentation/gstreamer/running.html
# Gst.debug_set_active(True)
# Gst.debug_set_default_threshold(5)

from monaistream.gstreamer.utils import *
from monaistream.gstreamer.numpy_transforms import *
# from monaistream.gstreamer.numpy_transforms import *

# TODO: import more things here

Expand Down
87 changes: 87 additions & 0 deletions monaistream/gstreamer/launch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from typing import Callable
from monaistream.gstreamer import Gst, GLib


def default_message_handler(bus: Gst.Bus, message: Gst.Message, loop: GLib.MainLoop):
if message.type == Gst.MessageType.EOS:
loop.quit()
elif message.type == Gst.MessageType.ERROR:
err, debug = message.parse_error()
print(err, debug, file=sys.stderr)
loop.quit()
elif message.type == Gst.MessageType.WARNING:
err, debug = message.parse_warning()
print(err, debug, file=sys.stderr)

return True


def default_loop_runner(
pipeline: Gst.Pipeline,
loop: GLib.MainLoop | None = None,
message_handler: Callable | None = default_message_handler,
):
if loop is None:
loop = GLib.MainLoop()

if message_handler is not None:
bus = pipeline.get_bus()
bus.add_signal_watch()
bus.connect("message", message_handler, loop)

pipeline.set_state(Gst.State.PLAYING)

try:
loop.run()
except KeyboardInterrupt:
raise
finally:
pipeline.set_state(Gst.State.NULL)
if loop and loop.is_running():
loop.quit()


def launch(args: list[str] | str, message_handler: Callable = default_message_handler):
"""
Defines a launching program for GStreamer like gst-launch but with MONAIStream plugin classes loaded. The `args`
argument is a list of strings containing a single full pipeline command line, a list of strings containing a
transform and its arguments (with no ! separator), or a list of words as if on the command line (with ! separator).

Example::

python -m monaistream.gstreamer.launch videotestsrc num-buffers=1 ! jpegenc ! filesink location=img.jpg

"""

if isinstance(args, str):
command = args
else:
args = list(map(str, args))
if not args:
raise ValueError("No arguments provided, a list of elements or a pipeline string is required.")

if len(args) == 1:
command = args[0]
elif "!" in args:
command = " ".join(args)
else:
command = " ! ".join(args)

pipeline = Gst.parse_launch(command)
default_loop_runner(pipeline, None, message_handler)


if __name__ == "__main__":
launch(sys.argv[1:])
Loading