Skip to content

Commit 02be5e2

Browse files
Example files for blog post: Getting started with MAX (modular#1834)
1 parent 3f908b7 commit 02be5e2

File tree

3 files changed

+103
-0
lines changed

3 files changed

+103
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Create the model repository and download ResNet-50 model
2+
MODEL_REPOSITORY=model-repository
3+
MODEL_NAME=resnet50
4+
SAVED_MODEL_DIR=resnet50_saved_model
5+
6+
mkdir -p $MODEL_REPOSITORY/$MODEL_NAME/1
7+
cp -a $SAVED_MODEL_DIR $MODEL_REPOSITORY/$MODEL_NAME/1/
8+
9+
# Create Triton config
10+
cat >$MODEL_REPOSITORY/$MODEL_NAME/config.pbtxt <<EOL
11+
instance_group {
12+
kind: KIND_CPU
13+
}
14+
default_model_filename: "$SAVED_MODEL_DIR"
15+
backend: "max"
16+
EOL
17+
18+
# run the recently built max_serving_local container
19+
docker run -it --rm --network=host \
20+
-v $PWD/$MODEL_REPOSITORY/:/models \
21+
public.ecr.aws/modular/max-serving-de \
22+
tritonserver --model-repository=/models --model-control-mode=explicit \
23+
--load-model=$MODEL_NAME
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
2+
import numpy as np
3+
import tritonclient.http as httpclient
4+
from PIL import Image
5+
import tensorflow as tf
6+
import numpy as np
7+
import tensorflow as tf
8+
9+
### Triton client ###
10+
client = httpclient.InferenceServerClient(url="localhost:8000")
11+
12+
### Image pre-processing ###
13+
def image_preprocess(img):
14+
img = np.asarray(img.resize((224, 224)))
15+
img = np.stack([img])
16+
img = tf.keras.applications.resnet50.preprocess_input(img)
17+
return img
18+
19+
### Image to classify ###
20+
img= Image.open('max/examples/inference/resnet50-python-tensorflow/input/leatherback_turtle.jpg')
21+
img = image_preprocess(img)
22+
23+
### Inference request format ###
24+
inputs = httpclient.InferInput("input_1",
25+
img.shape,
26+
datatype="FP32")
27+
inputs.set_data_from_numpy(img, binary_data=True)
28+
29+
outputs = httpclient.InferRequestedOutput("predictions", binary_data=True, class_count=1000)
30+
31+
### Submit inference request ###
32+
results = client.infer(model_name="resnet50",
33+
inputs=[inputs],
34+
outputs=[outputs])
35+
inference_output = results.as_numpy('predictions')
36+
37+
### Process request ###
38+
idx = [int(out.decode().split(':')[1]) for out in inference_output]
39+
probs = [float(out.decode().split(':')[0]) for out in inference_output]
40+
41+
### Decoding predictions ###
42+
probs = np.array(probs)[np.argsort(idx)]
43+
print(tf.keras.applications.resnet.decode_predictions(np.expand_dims(probs, axis=0), top=5))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
2+
import shutil
3+
import tensorflow as tf
4+
import numpy as np
5+
from tensorflow.keras.applications.resnet50 import ResNet50
6+
from PIL import Image
7+
8+
def load_save_resnet50_model(saved_model_dir = 'resnet50_saved_model'):
9+
model = ResNet50(weights='imagenet')
10+
shutil.rmtree(saved_model_dir, ignore_errors=True)
11+
model.save(saved_model_dir, include_optimizer=False, save_format='tf')
12+
saved_model_dir = 'resnet50_saved_model'
13+
load_save_resnet50_model(saved_model_dir)
14+
15+
#============================================#
16+
### MAX Engine Python API ###
17+
from max import engine
18+
sess = engine.InferenceSession()
19+
model = sess.load('resnet50_saved_model')
20+
#============================================#
21+
22+
def image_preprocess(img, reps=1):
23+
img = np.asarray(img.resize((224, 224)))
24+
img = np.stack([img]*reps)
25+
img = tf.keras.applications.resnet50.preprocess_input(img)
26+
return img
27+
28+
img= Image.open('max/examples/inference/resnet50-python-tensorflow/input/leatherback_turtle.jpg')
29+
img = image_preprocess(img)
30+
31+
### MAX Engine Python API ###
32+
#============================================#
33+
outputs = model.execute(input_1=img)
34+
#============================================#
35+
36+
probs = np.array(outputs['predictions'][0])
37+
print(tf.keras.applications.resnet.decode_predictions(np.expand_dims(probs, axis=0), top=5))

0 commit comments

Comments
 (0)