Skip to content

Commit

Permalink
Merge pull request #1771 from wkentaro/enhance-vgg16
Browse files Browse the repository at this point in the history
Enhance vgg16 object recognition node
  • Loading branch information
k-okada authored Jul 7, 2016
2 parents 6f72a36 + c819cd2 commit cc71ce3
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions jsk_perception/node_scripts/vgg16_object_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import cv_bridge
from jsk_topic_tools import ConnectionBasedTransport
from jsk_topic_tools.log_utils import logerr_throttle
from jsk_recognition_utils.chainermodels import VGG16BatchNormalization
from jsk_recognition_msgs.msg import ClassificationResult
import message_filters
Expand Down Expand Up @@ -46,8 +47,12 @@ def __init__(self):

def subscribe(self):
if rospy.get_param('~use_mask', False):
sub = message_filters.Subscriber('~input', Image)
sub_mask = message_filters.Subscriber('~input/mask', Image)
# larger buff_size is necessary for taking time callback
# http://stackoverflow.com/questions/26415699/ros-subscriber-not-up-to-date/29160379#29160379 # NOQA
sub = message_filters.Subscriber(
'~input', Image, queue_size=1, buff_size=2**24)
sub_mask = message_filters.Subscriber(
'~input/mask', Image, queue_size=1, buff_size=2**24)
self.subs = [sub, sub_mask]
queue_size = rospy.get_param('~queue_size', 10)
if rospy.get_param('~approximate_sync', False):
Expand All @@ -59,7 +64,9 @@ def subscribe(self):
self.subs, queue_size=queue_size)
sync.registerCallback(self._recognize)
else:
sub = rospy.Subscriber('~input', Image, self._recognize, callback_args=None)
sub = rospy.Subscriber(
'~input', Image, self._recognize, callback_args=None,
queue_size=1, buff_size=2**24)
self.subs = [sub]

def unsubscribe(self):
Expand All @@ -71,6 +78,13 @@ def _recognize(self, imgmsg, mask_msg=None):
bgr = bridge.imgmsg_to_cv2(imgmsg, desired_encoding='bgr8')
if mask_msg is not None:
mask = bridge.imgmsg_to_cv2(mask_msg)
if mask.shape != bgr.shape[:2]:
logerr_throttle(10,
'Size of input image and mask is different')
return
elif mask.size == 0:
logerr_throttle(10, 'Size of input mask is 0')
return
bgr[mask == 0] = self.mean_bgr
bgr = skimage.transform.resize(bgr, (224, 224), preserve_range=True)
input_msg = bridge.cv2_to_imgmsg(bgr.astype(np.uint8), encoding='bgr8')
Expand Down

0 comments on commit cc71ce3

Please sign in to comment.