@@ -406,8 +406,10 @@ def get_out_data_from_opts(cls, network, unit, _time_dim_tag=None, sources=(), i
     out = None
     if isinstance(unit, _SubnetworkRecCell):  # subnetwork
       subnet = unit
-      sub_out = subnet.layer_data_templates["output"].output.copy_template_adding_time_dim(
-        name="%s_output" % kwargs["name"], time_dim_axis=0)
+      sub_out = (
+        subnet.layer_data_templates["output"].output
+        .copy_template_adding_time_dim(name="%s_output" % kwargs["name"], time_dim_axis=0)
+        .copy_template_set_ctx(network.get_control_flow_ctx()))
       if out:
         assert sub_out.dim == out.dim
         assert sub_out.shape == out.shape
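
The hunk above chains two template copies: the per-step "output" template gets its time axis back (at axis 0) and is then stamped with the surrounding network's control flow context. Below is a minimal stand-alone sketch of that pattern, assuming RETURNN at this revision (where copy_template_set_ctx() exists) is importable; the names "step_output"/"rec_output", the feature dim 512 and the outer_ctx variable are invented for illustration.

from returnn.tf.util.data import Data

# Per-step template as seen inside the rec loop: [B,512], no time axis.
step_out = Data(name="step_output", shape=(512,), batch_dim_axis=0, time_dim_axis=None)
outer_ctx = None  # stands for network.get_control_flow_ctx(); None means root (no enclosing loop/cond)

# Leaving the loop: re-add the time axis at position 0 and attach the outer context.
acc_out = (
  step_out
  .copy_template_adding_time_dim(name="rec_output", time_dim_axis=0)
  .copy_template_set_ctx(outer_ctx))
print(acc_out)  # [T,B,512] template; its control_flow_ctx is whatever outer_ctx was
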
@@ -993,21 +995,26 @@ def __init__(self, net_dict, source_data, time_dim_tag, rec_layer_name, parent_n
     self.parent_net = parent_net
     self.net_dict = safe_deep_copy(net_dict)
     from returnn.tf.network import TFNetwork, ExternData, LossHolder
+    from returnn.tf.util.data import ControlFlowContext
+    control_flow_ctx = ControlFlowContext(
+      kind=ControlFlowContext.Types.Loop, outer_ctx=parent_net.get_control_flow_ctx())
+    control_flow_ctx.loop_spatial_dim = time_dim_tag
     self.net = TFNetwork(
       name="%s/%s(rec-subnet)" % (parent_net.name, rec_layer_name),
       extern_data=ExternData(),
       train_flag=parent_net.train_flag,
       search_flag=parent_net.search_flag,
       eval_flag=False,
       inside_rec_time_dim=time_dim_tag,
+      control_flow_ctx=control_flow_ctx,
       absolute_name_prefix="%s%s/" % (parent_net.get_absolute_name_prefix(), rec_layer_name),
       parent_net=parent_net)
     self.net.is_root_in_ctx = True
     self.net.layers_desc.update(self.net_dict)
     self.source_data = source_data
     if source_data:
       self.net.extern_data.data["source"] = (
-        source_data.copy_template_excluding_time_dim())
+        source_data.copy_template_excluding_time_dim().copy_template_set_ctx(control_flow_ctx))
     self.time_dim_tag = time_dim_tag
     self._time_dim_tags = {time_dim_tag}  # type: typing.Set[DimensionTag]
     if source_data:
@@ -1020,7 +1027,7 @@ def __init__(self, net_dict, source_data, time_dim_tag, rec_layer_name, parent_n
       # These are just templates. You can use them as possible targets for dimension information,
       # but not as actual sources or targets.
       # Note: We maybe should check data.is_same_time_dim()...
-      self.net.extern_data.data[key] = data.copy_template_excluding_time_dim()
+      self.net.extern_data.data[key] = data.copy_template_excluding_time_dim().copy_template_set_ctx(control_flow_ctx)
     self.layer_data_templates = {}  # type: typing.Dict[str,_TemplateLayer]
     self.prev_layers_needed = set()  # type: typing.Set[str]
     self.prev_layer_templates = {}  # type: typing.Dict[str,_TemplateLayer]
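
The two __init__ hunks above carry the core of this change: the rec subnetwork gets its own Loop-kind ControlFlowContext, and every per-frame extern-data template is copied into that context. Below is a rough stand-alone sketch of just that data/ctx handling, assuming RETURNN at this revision is importable; the [B,T,40] source shape and the variable names are invented for illustration.

from returnn.tf.util.data import Data, ControlFlowContext

# Source as the parent net sees it: [B,T,40] with a dynamic time axis.
source = Data(name="source", shape=(None, 40))
time_dim_tag = source.get_time_dim_tag()

# Loop context over that time axis, nested inside the parent's context (None here = root net).
control_flow_ctx = ControlFlowContext(kind=ControlFlowContext.Types.Loop, outer_ctx=None)
control_flow_ctx.loop_spatial_dim = time_dim_tag

# Per-frame "source" template as the rec subnet sees it: [B,40], carrying the loop context.
source_frame = source.copy_template_excluding_time_dim().copy_template_set_ctx(control_flow_ctx)
print(source_frame.control_flow_ctx)  # -> the Loop context created above
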
@@ -1545,7 +1552,7 @@ def get_input_moved_out(name):
           self.parent_rec_layer, self.parent_rec_layer.output.get_time_dim_tag(),
           layer, layer.output.get_time_dim_tag())
         return layer
-      output = layer.output.copy_template_excluding_time_dim()
+      output = layer.output.copy_template_excluding_time_dim().copy_template_set_ctx(self.net.control_flow_ctx)
       with tf.name_scope("%s_moved_input" % name.replace(":", "_")):
         if prev:
           output.placeholder = tf.cond(
@@ -2513,7 +2520,8 @@ def cond(i, net_vars, acc_tas, seq_len_info=None):
             self.parent_rec_layer, input_beam, output_beam,
             self.parent_rec_layer.sources, self.parent_rec_layer.target))
           assert output_template.output.batch.beam == output_beam
-          time_dim_tag = time_dim_tag.get_for_batch(output_template.output.batch)
+          time_dim_tag = time_dim_tag.get_for_batch_ctx(
+            batch=output_template.output.batch, ctx=self.net.control_flow_ctx)
           assert time_dim_tag.dyn_size is not None
           seq_len = time_dim_tag.dyn_size
         else:
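
This hunk, like the later ones in get_choice_seq(), get_loop_acc_layer() and decide(), makes the same substitution wherever a sequence length is pulled out of a dim tag: the tag is resolved for a (batch, control-flow-context) pair instead of for the batch alone. A hedged sketch of that call-site pattern follows, written as a helper that is only defined here; resolve_seq_len and its parameters are illustrative names, not RETURNN API.

from returnn.tf.util.data import DimensionTag, BatchInfo
from returnn.tf.network import TFNetwork


def resolve_seq_len(tag, batch, net):
  """
  :param DimensionTag tag: dynamic (time) dim tag, e.g. via DimensionTag.get_tag_from_size_tensor()
  :param BatchInfo batch: possibly beam-expanded batch info
  :param TFNetwork net: network whose control flow ctx the resolved sizes should live in
  :rtype: tf.Tensor
  """
  # Before this commit: tag = tag.get_for_batch(batch)
  tag = tag.get_for_batch_ctx(batch=batch, ctx=net.control_flow_ctx)
  assert tag.dyn_size is not None
  return tag.dyn_size  # seq lens per batch entry (and beam), inside the given ctx
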
@@ -2772,7 +2780,7 @@ def get_choice_seq(choice_base):
       latest_batch = (
         latest_layer_choice.output.batch
         or self.parent_rec_layer.output.batch.copy_set_beam(latest_layer_choice.output.beam))
-      tag = tag.get_for_batch(latest_batch)
+      tag = tag.get_for_batch_ctx(batch=latest_batch, ctx=self.net.control_flow_ctx)
       assert tag.dyn_size is not None
       assert tag.batch == latest_batch and tag.batch.beam == latest_layer_choice.output.beam
       seq_len = tag.dyn_size
@@ -3216,7 +3224,10 @@ def get_loop_acc_layer(name):
       acc_ta, latest_layer_choice_name, search_choices, resolved_seq_len = self._opt_search_resolve(
         layer_name=name, acc_ta=acc_ta, final_net_vars=final_net_vars, seq_len=seq_len,
         search_choices_cache=search_choices_cache)
-      output = self.layer_data_templates[name].output.copy_template_adding_time_dim(time_dim_axis=0)
+      output = (
+        self.layer_data_templates[name].output
+        .copy_template_adding_time_dim(time_dim_axis=0)
+        .copy_template_set_ctx(self.parent_net.get_control_flow_ctx()))
       if latest_layer_choice_name:
         output.beam = self.net.layers[latest_layer_choice_name].search_choices.get_beam_info()
       elif search_choices:
@@ -3303,7 +3314,10 @@ def get_layer(name):
     for name, search_choices in search_choices_cache.items():
       if name not in self.output_layers_net.layers:
         # Create dummy layer.
-        output = self.layer_data_templates[name].output.copy_template_adding_time_dim(time_dim_axis=0)
+        output = (
+          self.layer_data_templates[name].output
+          .copy_template_adding_time_dim(time_dim_axis=0)
+          .copy_template_set_ctx(self.output_layers_net.get_control_flow_ctx()))
         output.beam = search_choices.get_beam_info()
         layer = InternalLayer(name=name, network=self.output_layers_net, output=output)
         self.output_layers_net.layers[name] = layer
@@ -3350,7 +3364,8 @@ def __init__(self, network, name, construct_stack=None, cell=None):
       output=Data(
         name="dummy_initial_template_data",
         batch_dim_axis=0, time_dim_axis=None,
-        shape=()),  # (B,). no time-dim
+        shape=(),
+        control_flow_ctx=network.get_control_flow_ctx()),  # (B,). no time-dim
       name=name, network=network)
     self.output.size_placeholder = {}  # must be initialized
     self.layer_class = ":uninitialized-template"
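
Besides the copy_template_set_ctx() chaining used elsewhere, the hunk above shows that a Data template can be given its control flow context directly at construction time. A tiny sketch, assuming RETURNN at this revision; the standalone Loop context is created here purely for illustration.

from returnn.tf.util.data import Data, ControlFlowContext

loop_ctx = ControlFlowContext(kind=ControlFlowContext.Types.Loop, outer_ctx=None)
dummy = Data(
  name="dummy_initial_template_data",
  batch_dim_axis=0, time_dim_axis=None, shape=(),  # (B,), no time dim
  control_flow_ctx=loop_ctx)
print(dummy.control_flow_ctx)  # -> the Loop context passed in
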
@@ -5226,7 +5241,7 @@ def decide(cls, src, output=None, owner=None, name=None, length_normalization=Fa
     for i, size in src_data.size_placeholder.items():
       tag = DimensionTag.get_tag_from_size_tensor(size)
       assert tag
-      tag = tag.get_for_batch(output.batch)
+      tag = tag.get_for_batch_ctx(batch=output.batch, ctx=output.control_flow_ctx)
       if tag.dyn_size is None:
         size = tf.reshape(size, [batch_dim, beam_size])  # (batch, beam)
         size = tf.gather_nd(size, indices=beam_idxs_ext)  # (batch,)
@@ -7071,8 +7086,11 @@ def _create_template(cls, name, network, sources, masked_from, unit,
     # We don't care about the right masked input here, but just about deriving the right output shape.
     if masked_from:
       if network.is_inside_rec_layer(inside_loop=True):
-        source_data = masked_from.output.copy_template_excluding_time_dim(
-          name="%s_%s_masked_input_frame" % (masked_from.output.name, name))
+        source_data = (
+          masked_from.output
+          .copy_template_excluding_time_dim(
+            name="%s_%s_masked_input_frame" % (masked_from.output.name, name))
+          .copy_template_set_ctx(network.get_control_flow_ctx()))
       else:
         source_data = masked_from.output.copy_template(
           name="%s_%s_masked_input" % (masked_from.output.name, name))
@@ -7347,7 +7365,7 @@ def get_out_data_from_opts(cls, name, network, sources, mask, **kwargs):
         # thus when we unroll it to get into the loop, the RecLayer would have kept it as-is,
         # i.e. it should still have that time-dim-axis.
         # Maybe we should do some extra checks if that is like we assume, but for now, just assume that.
-        return out.copy_template_excluding_time_dim()
+        return out.copy_template_excluding_time_dim().copy_template_set_ctx(network.get_control_flow_ctx())
       return out
     assert out.have_time_axis()
     out = out.copy_as_time_major()