@@ -3197,18 +3197,34 @@ def get_sequence_mask_broadcast(self, axis=None):
     axis += self.batch_ndim
     assert 0 <= axis < self.batch_ndim
     assert axis != self.batch_dim_axis
-    size = self.get_dynamic_size(axis)
-    if axis >= self.batch_dim_axis:
-      seq_mask = sequence_mask(size)  # (B,T)
-    else:  # axis < batch_dim_axis
-      seq_mask = sequence_mask_time_major(size)  # (T,B)
-    shape = [1] * self.batch_ndim  # type: typing.List[typing.Union[int,tf.Tensor]]
+    tag = self.dim_tags[axis]
+    assert tag.dyn_size_ext
     with tf.name_scope("get_sequence_mask_broadcast"):
-      placeholder_shape = tf.shape(self.placeholder)
-      shape[self.batch_dim_axis] = placeholder_shape[self.batch_dim_axis]
-      shape[axis] = placeholder_shape[axis]
-      seq_mask = tf.reshape(seq_mask, shape, name="seq_mask_reshape")
-      assert seq_mask.get_shape().ndims == self.batch_ndim
+      if tag.dyn_size_ext.have_batch_axis() and tag.dyn_size_ext.batch_ndim == 1:  # just [B]
+        # This is the common case where the size is of shape [B].
+        # We make use of sequence_mask or sequence_mask_time_major in that case,
+        # which is optimized by caching.
+        size = tag.dyn_size
+        if axis >= self.batch_dim_axis:
+          seq_mask = sequence_mask(size)  # (B,T)
+        else:  # axis < batch_dim_axis
+          seq_mask = sequence_mask_time_major(size)  # (T,B)
+        shape = [1] * self.batch_ndim  # type: typing.List[typing.Union[int,tf.Tensor]]
+        placeholder_shape = tf.shape(self.placeholder)
+        shape[self.batch_dim_axis] = placeholder_shape[self.batch_dim_axis]
+        shape[axis] = placeholder_shape[axis]
+        seq_mask = tf.reshape(seq_mask, shape, name="seq_mask_reshape")
+        assert seq_mask.get_shape().ndims == self.batch_ndim
+      else:  # size is something unusual
+        max_idx = tf.reduce_max(tag.dyn_size)
+        # We use the assumption that self.placeholder.shape[axis] == max_idx.
+        idx_range = tf.range(max_idx)
+        idx_range = tf.reshape(idx_range, [1] * (axis - 1) + [max_idx] + [1] * (self.batch_ndim - axis - 1))
+        assert tag.dyn_size_ext
+        assert set(tag.dyn_size_ext.dim_tags).issubset(self.dim_tags)
+        size_ext = tag.dyn_size_ext.copy_compatible_to(self, check_sparse=False, check_dtype=False)
+        seq_mask = tf.less(idx_range, size_ext.placeholder)
+        assert seq_mask.get_shape().ndims == self.batch_ndim
     return seq_mask
 
   def get_batch_dim(self):
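
For context (not part of the commit): the new else-branch builds the mask by comparing a broadcast index range against the (possibly multi-dimensional) size tensor; copy_compatible_to places each dim of dyn_size_ext at its matching axis of self before the comparison. Below is a minimal standalone sketch of that broadcast-compare idea, runnable eagerly under TF2. The names (seq_mask_broadcast, size_axes) and the example shapes are illustrative assumptions, not taken from the RETURNN code.

import tensorflow as tf

def seq_mask_broadcast(size, batch_ndim, axis, size_axes):
  """Boolean mask broadcastable to a tensor of rank batch_ndim.

  size: dynamic lengths, e.g. shape [B] or [B,F]; assumes max(size) equals
    the extent of `axis` in the tensor to be masked.
  axis: axis of the full tensor that the lengths describe (e.g. time).
  size_axes: for each axis of `size`, its position in the full tensor.
  """
  max_idx = tf.reduce_max(size)
  idx_range = tf.range(max_idx)  # [T]
  # Put the index range at position `axis`, size-1 dims everywhere else.
  idx_range = tf.reshape(idx_range, [1] * axis + [-1] + [1] * (batch_ndim - axis - 1))
  # Expand `size` so each of its axes sits at its position in the full tensor
  # (this is the role copy_compatible_to plays in the diff above).
  shape = [1] * batch_ndim
  for i, ax in enumerate(size_axes):
    shape[ax] = tf.shape(size)[i]
  size = tf.reshape(size, tf.stack(shape))
  return tf.less(idx_range, size)  # broadcasts to the full mask shape

# Lengths that vary per batch entry and per feature: shape [B=2, F=3].
size = tf.constant([[3, 1, 2], [2, 2, 1]])
mask = seq_mask_broadcast(size, batch_ndim=3, axis=1, size_axes=[0, 2])
print(mask.shape)  # (2, 3, 3), i.e. [B,T,F]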