Skip to content

Commit

Permalink
Add mmseg2torchserve tool (open-mmlab#552)
Browse files Browse the repository at this point in the history
* Add docker/serve

* Add handler

* Add mmseg2torchserve

* Fix mmv minimum version

* Update docs with model serving section

* Update useful_tools.md

* pre-commit

* Update useful_tools.md

* Add 3dogs to resources

* Move mask to resources
  • Loading branch information
daavoo authored Jul 5, 2021
1 parent 2fd8e60 commit 61ca8c7
Show file tree
Hide file tree
Showing 9 changed files with 289 additions and 1 deletion.
47 changes: 47 additions & 0 deletions docker/serve/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
ARG PYTORCH="1.6.0"
ARG CUDA="10.1"
ARG CUDNN="7"
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel

ARG MMCV="1.3.1"
ARG MMSEG="0.13.0"

ENV PYTHONUNBUFFERED TRUE

RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
ca-certificates \
g++ \
openjdk-11-jre-headless \
# MMDet Requirements
ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \
&& rm -rf /var/lib/apt/lists/*

ENV PATH="/opt/conda/bin:$PATH"
RUN export FORCE_CUDA=1

# TORCHSEVER
RUN pip install torchserve torch-model-archiver

# MMLAB
RUN pip install mmcv-full==${MMCV} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
RUN pip install mmsegmentation==${MMSEG}

RUN useradd -m model-server \
&& mkdir -p /home/model-server/tmp

COPY entrypoint.sh /usr/local/bin/entrypoint.sh

RUN chmod +x /usr/local/bin/entrypoint.sh \
&& chown -R model-server /home/model-server

COPY config.properties /home/model-server/config.properties
RUN mkdir /home/model-server/model-store && chown -R model-server /home/model-server/model-store

EXPOSE 8080 8081 8082

USER model-server
WORKDIR /home/model-server
ENV TEMP=/home/model-server/tmp
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["serve"]
5 changes: 5 additions & 0 deletions docker/serve/config.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
model_store=/home/model-server/model-store
load_models=all
12 changes: 12 additions & 0 deletions docker/serve/entrypoint.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash
set -e

if [[ "$1" = "serve" ]]; then
shift 1
torchserve --start --ts-config /home/model-server/config.properties
else
eval "$@"
fi

# prevent docker exit
tail -f /dev/null
61 changes: 61 additions & 0 deletions docs/useful_tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,64 @@ Examples:
```shell
python tools/analyze_logs.py log.json --keys loss --legend loss
```

## Model Serving

In order to serve an `MMSegmentation` model with [`TorchServe`](https://pytorch.org/serve/), you can follow the steps:

### 1. Convert model from MMSegmentation to TorchServe

```shell
python tools/mmseg2torchserve.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \
--output-folder ${MODEL_STORE} \
--model-name ${MODEL_NAME}
```

**Note**: ${MODEL_STORE} needs to be an absolute path to a folder.

### 2. Build `mmseg-serve` docker image

```shell
docker build -t mmseg-serve:latest docker/serve/
```

### 3. Run `mmseg-serve`

Check the official docs for [running TorchServe with docker](https://github.com/pytorch/serve/blob/master/docker/README.md#running-torchserve-in-a-production-docker-environment).

In order to run in GPU, you need to install [nvidia-docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). You can omit the `--gpus` argument in order to run in CPU.

Example:

```shell
docker run --rm \
--cpus 8 \
--gpus device=0 \
-p8080:8080 -p8081:8081 -p8082:8082 \
--mount type=bind,source=$MODEL_STORE,target=/home/model-server/model-store \
mmseg-serve:latest
```

[Read the docs](https://github.com/pytorch/serve/blob/072f5d088cce9bb64b2a18af065886c9b01b317b/docs/rest_api.md) about the Inference (8080), Management (8081) and Metrics (8082) APis

### 4. Test deployment

```shell
curl -O https://raw.githubusercontent.com/open-mmlab/mmsegmentation/master/resources/3dogs.jpg
curl http://127.0.0.1:8080/predictions/${MODEL_NAME} -T 3dogs.jpg -o 3dogs_mask.png
```

The response will be a ".png" mask.

You can visualize the output as follows:

```python
import matplotlib.pyplot as plt
import mmcv
plt.imshow(mmcv.imread("3dogs_mask.png", "grayscale"))
plt.show()
```

You should see something similar to:

![3dogs_mask](../resources/3dogs_mask.png)
Binary file added resources/3dogs.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resources/3dogs_mask.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmseg
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
110 changes: 110 additions & 0 deletions tools/mmseg2torchserve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from argparse import ArgumentParser, Namespace
from pathlib import Path
from tempfile import TemporaryDirectory

import mmcv

try:
from model_archiver.model_packaging import package_model
from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
package_model = None


def mmseg2torchserve(
config_file: str,
checkpoint_file: str,
output_folder: str,
model_name: str,
model_version: str = '1.0',
force: bool = False,
):
"""Converts mmsegmentation model (config + checkpoint) to TorchServe
`.mar`.
Args:
config_file:
In MMSegmentation config format.
The contents vary for each task repository.
checkpoint_file:
In MMSegmentation checkpoint format.
The contents vary for each task repository.
output_folder:
Folder where `{model_name}.mar` will be created.
The file created will be in TorchServe archive format.
model_name:
If not None, used for naming the `{model_name}.mar` file
that will be created under `output_folder`.
If None, `{Path(checkpoint_file).stem}` will be used.
model_version:
Model's version.
force:
If True, if there is an existing `{model_name}.mar`
file under `output_folder` it will be overwritten.
"""
mmcv.mkdir_or_exist(output_folder)

config = mmcv.Config.fromfile(config_file)

with TemporaryDirectory() as tmpdir:
config.dump(f'{tmpdir}/config.py')

args = Namespace(
**{
'model_file': f'{tmpdir}/config.py',
'serialized_file': checkpoint_file,
'handler': f'{Path(__file__).parent}/mmseg_handler.py',
'model_name': model_name or Path(checkpoint_file).stem,
'version': model_version,
'export_path': output_folder,
'force': force,
'requirements_file': None,
'extra_files': None,
'runtime': 'python',
'archive_format': 'default'
})
manifest = ModelExportUtils.generate_manifest_json(args)
package_model(args, manifest)


def parse_args():
parser = ArgumentParser(
description='Convert mmseg models to TorchServe `.mar` format.')
parser.add_argument('config', type=str, help='config file path')
parser.add_argument('checkpoint', type=str, help='checkpoint file path')
parser.add_argument(
'--output-folder',
type=str,
required=True,
help='Folder where `{model_name}.mar` will be created.')
parser.add_argument(
'--model-name',
type=str,
default=None,
help='If not None, used for naming the `{model_name}.mar`'
'file that will be created under `output_folder`.'
'If None, `{Path(checkpoint_file).stem}` will be used.')
parser.add_argument(
'--model-version',
type=str,
default='1.0',
help='Number used for versioning.')
parser.add_argument(
'-f',
'--force',
action='store_true',
help='overwrite the existing `{model_name}.mar`')
args = parser.parse_args()

return args


if __name__ == '__main__':
args = parse_args()

if package_model is None:
raise ImportError('`torch-model-archiver` is required.'
'Try: pip install torch-model-archiver')

mmseg2torchserve(args.config, args.checkpoint, args.output_folder,
args.model_name, args.model_version, args.force)
53 changes: 53 additions & 0 deletions tools/mmseg_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import base64
import io
import os

import cv2
import mmcv
import torch
from ts.torch_handler.base_handler import BaseHandler

from mmseg.apis import inference_segmentor, init_segmentor


class MMsegHandler(BaseHandler):

def initialize(self, context):
properties = context.system_properties
self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = torch.device(self.map_location + ':' +
str(properties.get('gpu_id')) if torch.cuda.
is_available() else self.map_location)
self.manifest = context.manifest

model_dir = properties.get('model_dir')
serialized_file = self.manifest['model']['serializedFile']
checkpoint = os.path.join(model_dir, serialized_file)
self.config_file = os.path.join(model_dir, 'config.py')

self.model = init_segmentor(self.config_file, checkpoint, self.device)
self.initialized = True

def preprocess(self, data):
images = []

for row in data:
image = row.get('data') or row.get('body')
if isinstance(image, str):
image = base64.b64decode(image)
image = mmcv.imfrombytes(image)
images.append(image)

return images

def inference(self, data, *args, **kwargs):
results = [inference_segmentor(self.model, img) for img in data]
return results

def postprocess(self, data):
output = []
for image_result in data:
buffer = io.BytesIO()
_, buffer = cv2.imencode('.png', image_result[0].astype('uint8'))
output.append(buffer.tobytes())
return output

0 comments on commit 61ca8c7

Please sign in to comment.