Skip to content

Commit

Permalink
Add tensorpack changes (armandmcqueen#64)
Browse files Browse the repository at this point in the history
* add tensorpack changes

* minor change

* name change

* name change
  • Loading branch information
YangFei1990 authored Jun 12, 2019
1 parent 3eb6901 commit 5fb96b1
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 29 deletions.
2 changes: 2 additions & 0 deletions MaskRCNN/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ def __ne__(self, _):
_C.RPN.TRAIN_PER_LEVEL_NMS_TOPK = 2000
_C.RPN.TEST_PER_LEVEL_NMS_TOPK = 1000
_C.RPN.TOPK_PER_IMAGE = True
_C.RPN.UNQUANTIZED_ANCHOR = True # From tensorpack https://github.com/tensorpack/tensorpack/commit/141ab53cc37dce728802803747584fc0fb82863b
_C.RPN.SLOW_ACCURATE_MASK = True # If on, mask calculation will be slower but more accurate. From tensorpack https://github.com/tensorpack/tensorpack/commit/141ab53cc37dce728802803747584fc0fb82863b

# fastrcnn training ---------------------
_C.FRCNN.BATCH_PER_IM = 512
Expand Down
31 changes: 22 additions & 9 deletions MaskRCNN/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,28 @@ def get_all_anchors(stride=None, sizes=None, tile=True):
# Generates a NAx4 matrix of anchor boxes in (x1, y1, x2, y2) format. Anchors
# are centered on stride / 2, have (approximate) sqrt areas of the specified
# sizes, and aspect ratios as given.
cell_anchors = generate_anchors(
stride,
scales=np.array(sizes, dtype=np.float) / stride,
ratios=np.array(cfg.RPN.ANCHOR_RATIOS, dtype=np.float))
if not cfg.RPN.UNQUANTIZED_ANCHOR:
cell_anchors = generate_anchors(
stride,
scales=np.array(sizes, dtype=np.float) / stride,
ratios=np.array(cfg.RPN.ANCHOR_RATIOS, dtype=np.float))
else:
anchors = []
ratios=np.array(cfg.RPN.ANCHOR_RATIOS, dtype=np.float)
for sz in sizes:
for ratio in ratios:
w = np.sqrt(sz * sz / ratio)
h = ratio * w
anchors.append([-w, -h, w, h])
cell_anchors = np.asarray(anchors) * 0.5
# anchors are intbox here.
# anchors at featuremap [0,0] are centered at fpcoor (8,8) (half of stride)

if tile:
max_size = cfg.PREPROC.MAX_SIZE
field_size = int(np.ceil(max_size / stride))
shifts = np.arange(0, field_size) * stride
if not cfg.RPN.UNQUANTIZED_ANCHOR: shifts = np.arange(0, field_size) * stride
else: shifts = (np.arange(0, field_size) * stride).astype("float32")
shift_x, shift_y = np.meshgrid(shifts, shifts)
shift_x = shift_x.flatten()
shift_y = shift_y.flatten()
Expand All @@ -167,15 +178,17 @@ def get_all_anchors(stride=None, sizes=None, tile=True):
K = shifts.shape[0]

A = cell_anchors.shape[0]
field_of_anchors = (
cell_anchors.reshape((1, A, 4)) +
shifts.reshape((1, K, 4)).transpose((1, 0, 2)))
if not cfg.RPN.UNQUANTIZED_ANCHOR:
field_of_anchors = (
cell_anchors.reshape((1, A, 4)) +
shifts.reshape((1, K, 4)).transpose((1, 0, 2)))
else: field_of_anchors = cell_anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2))
field_of_anchors = field_of_anchors.reshape((field_size, field_size, A, 4))
# FSxFSxAx4
# Many rounding happens inside the anchor code anyway
# assert np.all(field_of_anchors == field_of_anchors.astype('int32'))
field_of_anchors = field_of_anchors.astype('float32')
field_of_anchors[:, :, :, [2, 3]] += 1
if not cfg.RPN.UNQUANTIZED_ANCHOR: field_of_anchors[:, :, :, [2, 3]] += 1
return field_of_anchors
else:
cell_anchors = cell_anchors.astype('float32')
Expand Down
74 changes: 54 additions & 20 deletions MaskRCNN/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pycocotools.mask as cocomask
import tqdm
import tensorflow as tf
from scipy import interpolate

from tensorpack.callbacks import Callback
from tensorpack.tfutils.common import get_tf_version_tuple
Expand Down Expand Up @@ -40,6 +41,22 @@
mask: None, or a binary image of the original image shape
"""

def _scale_box(box, scale):
w_half = (box[2] - box[0]) * 0.5
h_half = (box[3] - box[1]) * 0.5
x_c = (box[2] + box[0]) * 0.5
y_c = (box[3] + box[1]) * 0.5

w_half *= scale
h_half *= scale

scaled_box = np.zeros_like(box)
scaled_box[0] = x_c - w_half
scaled_box[2] = x_c + w_half
scaled_box[1] = y_c - h_half
scaled_box[3] = y_c + h_half
return scaled_box


def _paste_mask(box, mask, shape):
"""
Expand All @@ -50,23 +67,40 @@ def _paste_mask(box, mask, shape):
Returns:
A uint8 binary image of hxw.
"""
# int() is floor
# box fpcoor=0.0 -> intcoor=0.0
x0, y0 = list(map(int, box[:2] + 0.5))
# box fpcoor=h -> intcoor=h-1, inclusive
x1, y1 = list(map(int, box[2:] - 0.5)) # inclusive
x1 = max(x0, x1) # require at least 1x1
y1 = max(y0, y1)

w = x1 + 1 - x0
h = y1 + 1 - y0
assert mask.shape[0] == mask.shape[1], mask.shape
if not cfg.RPN.SLOW_ACCURATE_MASK:
# This method (inspired by Detectron) is less accurate but fast.
# int() is floor
# box fpcoor=0.0 -> intcoor=0.0
x0, y0 = list(map(int, box[:2] + 0.5))
# box fpcoor=h -> intcoor=h-1, inclusive
x1, y1 = list(map(int, box[2:] - 0.5)) # inclusive
x1 = max(x0, x1) # require at least 1x1
y1 = max(y0, y1)

w = x1 + 1 - x0
h = y1 + 1 - y0

# rounding errors could happen here, because masks were not originally computed for this shape.
# but it's hard to do better, because the network does not know the "original" scale
mask = (cv2.resize(mask, (w, h)) > 0.5).astype('uint8')
ret = np.zeros(shape, dtype='uint8')
ret[y0:y1 + 1, x0:x1 + 1] = mask
return ret
else:
# This method is accurate but much slower.
mask = np.pad(mask, [(1, 1), (1, 1)], mode='constant')
box = _scale_box(box, float(mask.shape[0]) / (mask.shape[0] - 2))

# rounding errors could happen here, because masks were not originally computed for this shape.
# but it's hard to do better, because the network does not know the "original" scale
mask = (cv2.resize(mask, (w, h)) > 0.5).astype('uint8')
ret = np.zeros(shape, dtype='uint8')
ret[y0:y1 + 1, x0:x1 + 1] = mask
return ret
mask_pixels = np.arange(0.0, mask.shape[0]) + 0.5
mask_continuous = interpolate.interp2d(mask_pixels, mask_pixels, mask, fill_value=0.0)
h, w = shape
ys = np.arange(0.0, h) + 0.5
xs = np.arange(0.0, w) + 0.5
ys = (ys - box[1]) / (box[3] - box[1]) * mask.shape[0]
xs = (xs - box[0]) / (box[2] - box[0]) * mask.shape[1]
res = mask_continuous(xs, ys)
return (res >= 0.5).astype('uint8')


def predict_image(img, model_func):
Expand Down Expand Up @@ -119,13 +153,13 @@ def predict_image_batch(img_batch, model_func, resized_sizes, scales, orig_sizes
"""

resized_sizes = np.stack(resized_sizes)
resized_sizes_in = np.concatenate((resized_sizes, 3*np.ones((resized_sizes.shape[0], 1))), axis=1)
resized_sizes_in = np.concatenate((resized_sizes, 3*np.ones((resized_sizes.shape[0], 1))), axis=1)

indices, boxes, probs, labels, *masks = model_func(img_batch, resized_sizes_in)

results = []
for i in range(len(scales)):
ind = np.where(indices.astype(np.int32) == i)[0]
for i in range(len(scales)):
ind = np.where(indices.astype(np.int32) == i)[0]

if len(ind) > 0:
boxes[ind, :] = boxes[ind, :]/scales[i]
Expand Down Expand Up @@ -293,7 +327,7 @@ def __init__(self, eval_dataset, in_names, out_names, output_dir, batch_size):
self._eval_dataset = eval_dataset
self._in_names, self._out_names = in_names, out_names
self._output_dir = output_dir
self.batched = batch_size > 0
self.batched = batch_size > 0
self.batch_size = batch_size

def _setup_graph(self):
Expand Down

0 comments on commit 5fb96b1

Please sign in to comment.