@@ -8495,3 +8495,172 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs):
      kind=DimensionTag.Types.Spatial, description="%s_rel_pos_enc_time" % name, dimension=None)
    data = data.copy_template_new_dim_tags((dummy_dim_tag, time_dim_tag, feature_dim_tag))
    return data


class CumConcatLayer(_ConcatInputLayer):
  """
  Concatenates all previous frames of a time axis.
  Where :class:`CumsumLayer` uses `sum`, this layer uses `concat`.

  This layer expects to be inside a :class:`RecLayer`.

  Inside a rec loop (not optimized out),
  this will concatenate the current input
  to the previously accumulated inputs.
  For an input of shape `input_shape`,
  it will output a tensor of shape `[new_dim] + input_shape`.
  `new_dim` is a special dimension, usually of length `i`,
  where `i` is the current loop frame,
  i.e. the length increases in every loop frame.
  `new_dim` is specified by its own separate dim tag.
  For example, in the first frame,
  the output will be of shape `[1] + input_shape`,
  in the second frame of shape `[2] + input_shape`,
  and so on,
  and in the last frame of shape `[T] + input_shape`.

  Outside the rec loop (optimized out),
  this layer expects an input with the time dim of the rec layer,
  and returns the input as-is,
  but with the time dim tag replaced by the dim tag `new_dim`,
  converted to its outside-the-loop (accumulated) variant.

  Normally this optimization should not matter to the user,
  i.e. for the user, the logical behavior is always as if inside the rec loop.
  Outside the loop,
  the output logically represents a tensor of shape `[T, new_dim] + input_shape`,
  although the actual tensor only has the accumulated `new_dim` axis
  and not an explicit `T` axis;
  no information is lost,
  because the last frame already contains all previous frames.

  This layer can be used as a basis for auto-regressive self-attention,
  as sketched in the example below.
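
  Example (a rough sketch only, not a tested config;
  the `DimensionTag` import path, the layer names,
  and the surrounding attention layers are illustrative assumptions)::

      from returnn.tf.util.data import DimensionTag

      accum_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="self-att-accum")
      rec_unit = {
        # inside the rec loop; "key", "value" and "query" are assumed per-frame projections
        "keys_accum": {"class": "cum_concat", "from": "key", "new_dim": accum_dim},
        "values_accum": {"class": "cum_concat", "from": "value", "new_dim": accum_dim},
        # ... compute energies from "query" and "keys_accum" over accum_dim,
        # softmax over accum_dim, and take the weighted sum over "values_accum" ...
      }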
  """
  layer_class = "cum_concat"
  recurrent = True  # order matters

  def __init__(self, new_dim, **kwargs):
    """
    :param DimensionTag new_dim:
    """
    super(CumConcatLayer, self).__init__(**kwargs)
    rec_layer = self.network.get_rec_parent_layer(inside_loop=False)
    assert rec_layer, "%r must be used inside a RecLayer" % self
    out_axis = self.output.get_axis_from_description(new_dim)
    new_dim_ = self.output.dim_tags[out_axis]

    if self.network.is_inside_rec_layer(inside_loop=True):
      current_data = self.input_data.copy_compatible_to(self.output, unbroadcast=False)
      current_frame = current_data.placeholder  # [B, 1, ..., D]
      last_frames = self._rec_previous_layer.rec_vars_outputs["state"]  # [B, t, ..., D]
      concat_frames = tf.concat([last_frames, current_frame], axis=out_axis)  # [B, t+1, ..., D]
      self.rec_vars_outputs["state"] = concat_frames
      self.output.placeholder = concat_frames

      if not new_dim_.dyn_size_ext:
        # Unbroadcasting to [B] is not needed because any layers operating on this
        # should be able to handle extended dyn sizes.
        # Clipping it to the max length for sequences in the loop which have already ended
        # (i.e. considering the end flag)
        # is also not needed because any calculations after the end are irrelevant.
        # Note: In case we have some initial state/output, this can be extended.
        dyn_size = self.network.get_rec_step_index() + 1  # scalar
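        # I.e. in loop frame i (0-based), the new_dim axis has length i + 1.
        # The size is a scalar, the same for all batch entries (see the note on unbroadcasting above).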
        new_dim_.dyn_size_ext = Data(
          name="%s:cum-concat:size-inside" % self.name,
          dim_tags=[],  # scalar
          placeholder=dyn_size)

    else:
      # If not inside a rec loop, this layer is a no-op on the tensor.
      self.output.placeholder = self.input_data.placeholder

      # However, we used new dim tags, which were already prepared.
      # We now must fill in the extended dynamic size information.
      if not new_dim_.dyn_size_ext:
        # This must match the logic above for inside the loop.
        # Note: In case we have some initial state/output, this can be extended.
        dyn_size = tf.range(tf.math.reduce_max(rec_layer.time_dim_tag.dyn_size)) + 1  # [T]
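        # I.e. dyn_size is [1, 2, ..., T_max]: for rec frame t (0-based), the accumulated length
        # along new_dim is t + 1, consistent with the per-frame size inside the loop.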
        new_dim_.dyn_size_ext = Data(
          name="%s:cum-concat:size-outside" % self.name,
          dim_tags=[rec_layer.time_dim_tag],
          placeholder=dyn_size)

  @classmethod
  def get_out_data_from_opts(cls, name, network, sources, new_dim, **kwargs):
    """
    :param str name:
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param DimensionTag new_dim:
    :rtype: Data
    """
    rec_layer = network.get_rec_parent_layer(inside_loop=False)
    assert rec_layer, "CumConcatLayer %r must be used inside a RecLayer" % name
    new_dim_base = new_dim.get_same_base()
    if new_dim_base.per_spatial_frame is None:
      new_dim_base.per_spatial_frame = rec_layer.time_dim_tag
    else:
      assert new_dim_base.per_spatial_frame == rec_layer.time_dim_tag

    input_data = get_concat_sources_data_template(sources, name="%s_output" % name)
    if network.is_inside_rec_layer(inside_loop=True):
      # Currently, SelectSearchSourcesLayer assumes that all rec_vars_outputs are batch-major.
      # Therefore we copy the input as batch-major here, and then add the new axis at axis 1.
      # In the future, when SelectSearchSourcesLayer has support for this,
      # we can change this to operate on axis 0, which should be more efficient.
      out = input_data.copy_as_batch_major()
      out = out.copy_add_dim_by_tag(new_dim_base, unbroadcast=True, axis=1)
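      # The template inside the loop is thus [B, new_dim, ...];
      # the actual length along new_dim grows by one per loop frame (see __init__ above).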
      return out

    else:  # outside loop
      if not new_dim_base.per_spatial_frame_accumulated:
        new_dim_accum = DimensionTag(
          kind=new_dim_base.kind, description="%s:accumulated" % name)
        new_dim_accum.same_as = new_dim_base
        new_dim_base.per_spatial_frame_accumulated = new_dim_accum
      else:
        new_dim_accum = new_dim_base.per_spatial_frame_accumulated
      # Assume that the input has the time dim from the rec layer.
      axis = input_data.get_axis_from_description(rec_layer.time_dim_tag)
      return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=new_dim_accum)

  # noinspection PyMethodOverriding
  @classmethod
  def get_rec_initial_extra_outputs(cls, network, batch_dim, rec_layer, sources, output, new_dim, **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param tf.Tensor batch_dim:
    :param TFNetworkRecLayer.RecLayer|LayerBase rec_layer:
    :param list[LayerBase] sources:
    :param Data output:
    :param DimensionTag new_dim:
    :rtype: dict[str,tf.Tensor]
    """
    if network.is_inside_rec_layer():
      shape = []
      for tag in output.dim_tags:
        if tag.is_batch_dim():
          shape.append(batch_dim)
        elif tag == new_dim:
          shape.append(0)
        elif tag.dimension is not None:
          shape.append(tag.dimension)
        else:
          assert tag.dyn_size is not None
          shape.append(tf.math.reduce_max(tag.dyn_size))
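      # The initial state is empty along new_dim (length 0),
      # so the concat in the first loop frame yields length 1, as described in the class docstring.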
      return {"state": tf.zeros(shape, dtype=output.dtype)}
    else:
      return {}

  @classmethod
  def get_rec_initial_extra_outputs_shape_invariants(cls, network, sources, output, **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param Data output:
    :rtype: dict[str, tf.TensorShape]
    """
    if network.is_inside_rec_layer():
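      # output.batch_shape has None for the batch dim, for new_dim, and for any other dynamic axes,
      # so this shape invariant lets the accumulated state grow along new_dim across loop frames.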
      return {"state": tf.TensorShape(output.batch_shape)}
    else:
      return {}