Skip to content

Commit 234011f

Browse files
Merge pull request #5 from srihari-humbarwadi/training
* set segmentation loss weight to 0.5 * skip boxes with zero area while pasting masks
2 parents d1f0183 + b14f82a commit 234011f

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

official/vision/beta/projects/panoptic_maskrcnn/configs/panoptic_maskrcnn.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class Losses(maskrcnn.Losses):
107107
semantic_segmentation_use_groundtruth_dimension: bool = True
108108
semantic_segmentation_top_k_percent_pixels: float = 1.0
109109
instance_segmentation_weight: float = 1.0
110-
semantic_segmentation_weight: float = 1.0
110+
semantic_segmentation_weight: float = 0.5
111111

112112

113113
@dataclasses.dataclass
@@ -170,7 +170,8 @@ def panoptic_fpn_coco() -> cfg.ExperimentConfig:
170170
is_thing.append(True if idx <= num_thing_categories else False)
171171

172172
config = cfg.ExperimentConfig(
173-
runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
173+
runtime=cfg.RuntimeConfig(
174+
mixed_precision_dtype='bfloat16', enable_xla=True),
174175
task=PanopticMaskRCNNTask(
175176
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080', # pylint: disable=line-too-long
176177
init_checkpoint_modules=['backbone'],

official/vision/beta/projects/panoptic_maskrcnn/modeling/layers/panoptic_segmentation_generator.py

+16-15
Original file line numberDiff line numberDiff line change
@@ -79,26 +79,27 @@ def _paste_mask(self, box, mask):
7979
pasted_mask = tf.ones(
8080
self._output_size + [1], dtype=mask.dtype) * self._void_class_label
8181

82-
ymin = box[0]
83-
xmin = box[1]
82+
ymin = tf.clip_by_value(box[0], 0, self._output_size[0])
83+
xmin = tf.clip_by_value(box[1], 0, self._output_size[1])
8484
ymax = tf.clip_by_value(box[2] + 1, 0, self._output_size[0])
8585
xmax = tf.clip_by_value(box[3] + 1, 0, self._output_size[1])
8686
box_height = ymax - ymin
8787
box_width = xmax - xmin
8888

89-
# resize mask to match the shape of the instance bounding box
90-
resized_mask = tf.image.resize(
91-
mask,
92-
size=(box_height, box_width),
93-
method='nearest')
94-
95-
# paste resized mask on a blank mask that matches image shape
96-
pasted_mask = tf.raw_ops.TensorStridedSliceUpdate(
97-
input=pasted_mask,
98-
begin=[ymin, xmin],
99-
end=[ymax, xmax],
100-
strides=[1, 1],
101-
value=resized_mask)
89+
if not (box_height == 0 or box_width == 0):
90+
# resize mask to match the shape of the instance bounding box
91+
resized_mask = tf.image.resize(
92+
mask,
93+
size=(box_height, box_width),
94+
method='nearest')
95+
96+
# paste resized mask on a blank mask that matches image shape
97+
pasted_mask = tf.raw_ops.TensorStridedSliceUpdate(
98+
input=pasted_mask,
99+
begin=[ymin, xmin],
100+
end=[ymax, xmax],
101+
strides=[1, 1],
102+
value=resized_mask)
102103

103104
return pasted_mask
104105

0 commit comments

Comments
 (0)