
Commit 1aa9361

Data get_sequence_mask_broadcast fix for uncommon cases (#646)

1 parent b770d9d

File tree

2 files changed (+29, -20 lines)

returnn/tf/util/basic.py

Lines changed: 2 additions & 9 deletions

@@ -258,16 +258,9 @@ def mask_dyn_seq_len_nd(x, pad_value, axes):
   if set(existing_pad_values) == {pad_value}:
     return x.placeholder  # nothing to do
 
-  x_shape = get_shape(x_)
-  mask = tf.ones([1] * len(x_shape), dtype=tf.bool)
+  mask = tf.ones([1] * x.batch_ndim, dtype=tf.bool)
   for axis in axes:
-    tag = x.dim_tags[axis]
-    idx_range = tf.range(x_shape[axis])
-    idx_range = tf.reshape(idx_range, [1] * (axis - 1) + x_shape[axis:axis + 1] + [1] * (len(x_shape) - axis - 1))
-    assert tag.dyn_size_ext
-    assert set(tag.dyn_size_ext.dim_tags).issubset(x.dim_tags)
-    size_ext = tag.dyn_size_ext.copy_compatible_to(x, check_dtype=False)
-    mask_ = tf.less(idx_range, size_ext.placeholder)
+    mask_ = x.get_sequence_mask_broadcast(axis=axis)
     mask = tf.logical_and(mask, mask_)
   x_ = where_bc(mask, x_, tf.cast(tf.constant(pad_value, name="pad_value"), dtype=x_.dtype))
   d = get_padding_info_dict_ref(x_)
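After this change, mask_dyn_seq_len_nd builds one broadcastable mask per dynamic axis via Data.get_sequence_mask_broadcast, combines them with tf.logical_and, and writes the pad value through the broadcasting where_bc. Below is a minimal standalone sketch of that masking pattern in plain TensorFlow 2, not the RETURNN API; the helper name mask_padding, the leading batch axis, and the [B, T, D] layout are illustrative assumptions.

import tensorflow as tf

def mask_padding(x, seq_lens, time_axis=1, pad_value=0.0):
  # x: e.g. [B, T, D]; seq_lens: [B] int32.
  # Overwrite positions t >= seq_lens[b] along time_axis with pad_value.
  ndim = x.shape.rank
  max_len = tf.shape(x)[time_axis]
  seq_mask = tf.sequence_mask(seq_lens, maxlen=max_len)  # [B, T], bool
  # Reshape to [B, 1, ..., T, ..., 1] so it broadcasts against x.
  shape = [1] * ndim
  shape[0] = tf.shape(x)[0]  # batch axis assumed at 0
  shape[time_axis] = max_len
  seq_mask = tf.reshape(seq_mask, shape)
  # tf.where broadcasts in TF2, playing the role of RETURNN's where_bc here.
  return tf.where(seq_mask, x, tf.cast(pad_value, x.dtype))

x = tf.random.normal([2, 4, 3])
y = mask_padding(x, tf.constant([4, 2]))  # frames 2..3 of the second sequence zeroed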

returnn/tf/util/data.py

Lines changed: 27 additions & 11 deletions

@@ -3197,18 +3197,34 @@ def get_sequence_mask_broadcast(self, axis=None):
       axis += self.batch_ndim
     assert 0 <= axis < self.batch_ndim
     assert axis != self.batch_dim_axis
-    size = self.get_dynamic_size(axis)
-    if axis >= self.batch_dim_axis:
-      seq_mask = sequence_mask(size)  # (B,T)
-    else:  # axis < batch_dim_axis
-      seq_mask = sequence_mask_time_major(size)  # (T,B)
-    shape = [1] * self.batch_ndim  # type: typing.List[typing.Union[int,tf.Tensor]]
+    tag = self.dim_tags[axis]
+    assert tag.dyn_size_ext
     with tf.name_scope("get_sequence_mask_broadcast"):
-      placeholder_shape = tf.shape(self.placeholder)
-      shape[self.batch_dim_axis] = placeholder_shape[self.batch_dim_axis]
-      shape[axis] = placeholder_shape[axis]
-      seq_mask = tf.reshape(seq_mask, shape, name="seq_mask_reshape")
-      assert seq_mask.get_shape().ndims == self.batch_ndim
+      if tag.dyn_size_ext.have_batch_axis() and tag.dyn_size_ext.batch_ndim == 1:  # just [B]
+        # This is the common case where the size is of shape [B].
+        # We make use of sequence_mask or sequence_mask_time_major in that case,
+        # which is optimized by caching.
+        size = tag.dyn_size
+        if axis >= self.batch_dim_axis:
+          seq_mask = sequence_mask(size)  # (B,T)
+        else:  # axis < batch_dim_axis
+          seq_mask = sequence_mask_time_major(size)  # (T,B)
+        shape = [1] * self.batch_ndim  # type: typing.List[typing.Union[int,tf.Tensor]]
+        placeholder_shape = tf.shape(self.placeholder)
+        shape[self.batch_dim_axis] = placeholder_shape[self.batch_dim_axis]
+        shape[axis] = placeholder_shape[axis]
+        seq_mask = tf.reshape(seq_mask, shape, name="seq_mask_reshape")
+        assert seq_mask.get_shape().ndims == self.batch_ndim
+      else:  # size is something unusual
+        max_idx = tf.reduce_max(tag.dyn_size)
+        # We use the assumption that self.placeholder.shape[axis] == max_idx.
+        idx_range = tf.range(max_idx)
+        idx_range = tf.reshape(idx_range, [1] * (axis - 1) + [max_idx] + [1] * (self.batch_ndim - axis - 1))
+        assert tag.dyn_size_ext
+        assert set(tag.dyn_size_ext.dim_tags).issubset(self.dim_tags)
+        size_ext = tag.dyn_size_ext.copy_compatible_to(self, check_sparse=False, check_dtype=False)
+        seq_mask = tf.less(idx_range, size_ext.placeholder)
+        assert seq_mask.get_shape().ndims == self.batch_ndim
     return seq_mask
 
   def get_batch_dim(self):
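The new fallback branch covers sizes that are not simply of shape [B], e.g. a separate length per batch entry and per position of another axis. It compares a broadcast index range against the sizes, after copy_compatible_to has expanded dyn_size_ext to the rank of self. Below is a minimal sketch of that comparison trick in plain TensorFlow, outside the Data/dim-tag machinery; the function name seq_mask_general and the [B, T1, T2] example layout are assumptions for illustration.

import tensorflow as tf

def seq_mask_general(sizes, axis, ndim):
  # sizes: already expanded to rank `ndim` with size 1 on `axis`,
  # e.g. [B, T1, 1] describing lengths along axis=2 of a [B, T1, T2] tensor.
  max_idx = tf.reduce_max(sizes)
  idx_range = tf.range(max_idx)
  # Put the index range on `axis`, size 1 on every other axis (explicit full rank).
  idx_range = tf.reshape(idx_range, [1] * axis + [-1] + [1] * (ndim - axis - 1))
  return tf.less(idx_range, sizes)  # broadcasts to [B, T1, T2], True for valid positions

sizes = tf.constant([[3, 1], [2, 2]])  # [B=2, T1=2], lengths along T2
mask = seq_mask_general(sizes[:, :, None], axis=2, ndim=3)
# mask[0, 1] == [True, False, False]  (length 1)

The committed code reshapes the index range with one fewer leading axis ([1] * (axis - 1) + ...) and lets tf.less right-align it against size_ext.placeholder; with a leading batch axis that is equivalent to the explicit full-rank reshape in the sketch.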
