Commit 9baf23d
ControlFlowContext in dim tag, network, Data (#647)
Needed for #589
1 parent 1aa9361 commit 9baf23d

5 files changed: +242 −41 lines
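For orientation before the per-file diffs: every hunk below either threads a ControlFlowContext through template copies (copy_template_set_ctx) or switches DimensionTag.get_for_batch to get_for_batch_ctx. The diff only shows how the context is constructed and passed around; as a rough mental model, here is a minimal sketch of its shape, grounded only in what the hunks below use (Types.Loop, outer_ctx, loop_spatial_dim) and not the actual returnn.tf.util.data implementation:

# Minimal sketch of the ControlFlowContext shape this diff relies on (illustration only).
class ControlFlowContext:
  class Types:
    Loop = "loop"

  def __init__(self, kind, outer_ctx=None):
    self.kind = kind              # e.g. Types.Loop for a rec-layer loop body
    self.outer_ctx = outer_ctx    # enclosing context, or None at the root
    self.loop_spatial_dim = None  # for Loop: the dim tag the loop iterates over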

returnn/tf/layers/base.py

Lines changed: 8 additions & 1 deletion
@@ -393,6 +393,13 @@ def fixup_out_data(cls, output, network):
       extern_data.init_batch_info()  # this should create it and also set it
     assert output.batch
     output.batch = output.batch.copy_set_beam(output.beam)
+    if output.control_flow_ctx != network.get_control_flow_ctx():
+      x = output.placeholder
+      output = output.copy_template_set_ctx(network.get_control_flow_ctx())
+      if x is not None:
+        # Some layers might just copy the input. But the input might have buggy ctx.
+        # Just leave the placeholder as-is. Most layers should anyway reset this.
+        output.placeholder = x
     return output
 
   def get_full_ctx_name(self):
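The save/restore of the placeholder in fixup_out_data exists because copy_template_set_ctx, like the other copy_template_* methods, returns a template without the placeholder tensor. Distilled into a standalone helper (an illustrative sketch of the pattern, not RETURNN API beyond the methods shown above):

def move_output_to_ctx(output, target_ctx):
  # Sketch of the fixup_out_data pattern: rebase a Data into target_ctx
  # while keeping an already-set placeholder, since template copies drop it.
  if output.control_flow_ctx != target_ctx:
    x = output.placeholder  # may be None for pure templates
    output = output.copy_template_set_ctx(target_ctx)
    if x is not None:
      output.placeholder = x  # most layers overwrite this later anyway
  return output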
@@ -1730,7 +1737,7 @@ def opt_get_layer(layer_name):
         # Don't return layer, could be inside loop and that wont work.
         output = net.layers[layer_name].output.copy_template()
         if not output.have_time_axis() and with_time_dim:
-          output = output.copy_template_adding_time_dim()
+          output = output.copy_template_adding_time_dim().copy_template_set_ctx(network.get_control_flow_ctx())
       if not output:
         layer_desc_ = net.layers_desc[layer_name].copy()
         class_name_ = layer_desc_.pop("class")

returnn/tf/layers/rec.py

Lines changed: 32 additions & 14 deletions
@@ -406,8 +406,10 @@ def get_out_data_from_opts(cls, network, unit, _time_dim_tag=None, sources=(), i
     out = None
     if isinstance(unit, _SubnetworkRecCell):  # subnetwork
       subnet = unit
-      sub_out = subnet.layer_data_templates["output"].output.copy_template_adding_time_dim(
-        name="%s_output" % kwargs["name"], time_dim_axis=0)
+      sub_out = (
+        subnet.layer_data_templates["output"].output
+        .copy_template_adding_time_dim(name="%s_output" % kwargs["name"], time_dim_axis=0)
+        .copy_template_set_ctx(network.get_control_flow_ctx()))
       if out:
         assert sub_out.dim == out.dim
         assert sub_out.shape == out.shape
@@ -993,21 +995,26 @@ def __init__(self, net_dict, source_data, time_dim_tag, rec_layer_name, parent_n
     self.parent_net = parent_net
     self.net_dict = safe_deep_copy(net_dict)
     from returnn.tf.network import TFNetwork, ExternData, LossHolder
+    from returnn.tf.util.data import ControlFlowContext
+    control_flow_ctx = ControlFlowContext(
+      kind=ControlFlowContext.Types.Loop, outer_ctx=parent_net.get_control_flow_ctx())
+    control_flow_ctx.loop_spatial_dim = time_dim_tag
     self.net = TFNetwork(
       name="%s/%s(rec-subnet)" % (parent_net.name, rec_layer_name),
       extern_data=ExternData(),
       train_flag=parent_net.train_flag,
       search_flag=parent_net.search_flag,
       eval_flag=False,
       inside_rec_time_dim=time_dim_tag,
+      control_flow_ctx=control_flow_ctx,
       absolute_name_prefix="%s%s/" % (parent_net.get_absolute_name_prefix(), rec_layer_name),
       parent_net=parent_net)
     self.net.is_root_in_ctx = True
     self.net.layers_desc.update(self.net_dict)
     self.source_data = source_data
     if source_data:
       self.net.extern_data.data["source"] = (
-        source_data.copy_template_excluding_time_dim())
+        source_data.copy_template_excluding_time_dim().copy_template_set_ctx(control_flow_ctx))
     self.time_dim_tag = time_dim_tag
     self._time_dim_tags = {time_dim_tag}  # type: typing.Set[DimensionTag]
     if source_data:
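The context wiring above is the core of the commit: each _SubnetworkRecCell opens a Loop context chained under whatever context the parent network already has, and rebases all extern-data templates into it. Extracted into a small helper for readability (make_loop_ctx is illustrative, not a real RETURNN function; the calls inside it are exactly those of the hunk above):

from returnn.tf.util.data import ControlFlowContext

def make_loop_ctx(parent_net, time_dim_tag):
  # Chain a new Loop context under the parent's context (None at the root net),
  # exactly as _SubnetworkRecCell.__init__ does above.
  ctx = ControlFlowContext(
    kind=ControlFlowContext.Types.Loop, outer_ctx=parent_net.get_control_flow_ctx())
  ctx.loop_spatial_dim = time_dim_tag  # the axis the rec loop iterates over
  return ctx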
@@ -1020,7 +1027,7 @@ def __init__(self, net_dict, source_data, time_dim_tag, rec_layer_name, parent_n
       # These are just templates. You can use them as possible targets for dimension information,
       # but not as actual sources or targets.
       # Note: We maybe should check data.is_same_time_dim()...
-      self.net.extern_data.data[key] = data.copy_template_excluding_time_dim()
+      self.net.extern_data.data[key] = data.copy_template_excluding_time_dim().copy_template_set_ctx(control_flow_ctx)
     self.layer_data_templates = {}  # type: typing.Dict[str,_TemplateLayer]
     self.prev_layers_needed = set()  # type: typing.Set[str]
     self.prev_layer_templates = {}  # type: typing.Dict[str,_TemplateLayer]
@@ -1545,7 +1552,7 @@ def get_input_moved_out(name):
           self.parent_rec_layer, self.parent_rec_layer.output.get_time_dim_tag(),
           layer, layer.output.get_time_dim_tag())
         return layer
-      output = layer.output.copy_template_excluding_time_dim()
+      output = layer.output.copy_template_excluding_time_dim().copy_template_set_ctx(self.net.control_flow_ctx)
       with tf.name_scope("%s_moved_input" % name.replace(":", "_")):
         if prev:
           output.placeholder = tf.cond(
@@ -2513,7 +2520,8 @@ def cond(i, net_vars, acc_tas, seq_len_info=None):
             self.parent_rec_layer, input_beam, output_beam,
             self.parent_rec_layer.sources, self.parent_rec_layer.target))
         assert output_template.output.batch.beam == output_beam
-        time_dim_tag = time_dim_tag.get_for_batch(output_template.output.batch)
+        time_dim_tag = time_dim_tag.get_for_batch_ctx(
+          batch=output_template.output.batch, ctx=self.net.control_flow_ctx)
         assert time_dim_tag.dyn_size is not None
         seq_len = time_dim_tag.dyn_size
       else:
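This hunk and the remaining get_for_batch call sites (below and in returnn/tf/util/basic.py) all follow the same migration: tag.get_for_batch(batch) becomes tag.get_for_batch_ctx(batch=..., ctx=...), so one logical dim tag can hold distinct dyn_size tensors per control flow context, not just per batch/beam. A toy sketch of that keying (not the real DimensionTag):

class DimTagSketch:
  # Toy model: variants of one logical dim tag are keyed by (batch, ctx),
  # each variant carrying its own dyn_size (seq lens differ inside vs. outside a loop).
  def __init__(self):
    self.dyn_size = None
    self._variants = {}  # (batch, ctx) -> DimTagSketch

  def get_for_batch_ctx(self, batch, ctx):
    key = (batch, ctx)
    if key not in self._variants:
      self._variants[key] = DimTagSketch()  # dyn_size gets set for that ctx later
    return self._variants[key]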
@@ -2772,7 +2780,7 @@ def get_choice_seq(choice_base):
         latest_batch = (
           latest_layer_choice.output.batch
           or self.parent_rec_layer.output.batch.copy_set_beam(latest_layer_choice.output.beam))
-        tag = tag.get_for_batch(latest_batch)
+        tag = tag.get_for_batch_ctx(batch=latest_batch, ctx=self.net.control_flow_ctx)
         assert tag.dyn_size is not None
         assert tag.batch == latest_batch and tag.batch.beam == latest_layer_choice.output.beam
         seq_len = tag.dyn_size
@@ -3216,7 +3224,10 @@ def get_loop_acc_layer(name):
       acc_ta, latest_layer_choice_name, search_choices, resolved_seq_len = self._opt_search_resolve(
         layer_name=name, acc_ta=acc_ta, final_net_vars=final_net_vars, seq_len=seq_len,
         search_choices_cache=search_choices_cache)
-      output = self.layer_data_templates[name].output.copy_template_adding_time_dim(time_dim_axis=0)
+      output = (
+        self.layer_data_templates[name].output
+        .copy_template_adding_time_dim(time_dim_axis=0)
+        .copy_template_set_ctx(self.parent_net.get_control_flow_ctx()))
       if latest_layer_choice_name:
         output.beam = self.net.layers[latest_layer_choice_name].search_choices.get_beam_info()
       elif search_choices:
@@ -3303,7 +3314,10 @@ def get_layer(name):
     for name, search_choices in search_choices_cache.items():
       if name not in self.output_layers_net.layers:
         # Create dummy layer.
-        output = self.layer_data_templates[name].output.copy_template_adding_time_dim(time_dim_axis=0)
+        output = (
+          self.layer_data_templates[name].output
+          .copy_template_adding_time_dim(time_dim_axis=0)
+          .copy_template_set_ctx(self.output_layers_net.get_control_flow_ctx()))
         output.beam = search_choices.get_beam_info()
         layer = InternalLayer(name=name, network=self.output_layers_net, output=output)
         self.output_layers_net.layers[name] = layer
@@ -3350,7 +3364,8 @@ def __init__(self, network, name, construct_stack=None, cell=None):
       output=Data(
         name="dummy_initial_template_data",
         batch_dim_axis=0, time_dim_axis=None,
-        shape=()),  # (B,). no time-dim
+        shape=(),
+        control_flow_ctx=network.get_control_flow_ctx()),  # (B,). no time-dim
       name=name, network=network)
     self.output.size_placeholder = {}  # must be initialized
     self.layer_class = ":uninitialized-template"
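Data itself accepts control_flow_ctx directly in its constructor now (that side of the change lives in the fifth changed file, which is not expanded in this view). The template construction above, as a free-standing example:

from returnn.tf.util.data import Data

def dummy_initial_template(network):
  # Same construction as in _TemplateLayer.__init__ above:
  # a (B,) template without time dim, created directly in the network's ctx.
  return Data(
    name="dummy_initial_template_data",
    batch_dim_axis=0, time_dim_axis=None, shape=(),
    control_flow_ctx=network.get_control_flow_ctx())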
@@ -5226,7 +5241,7 @@ def decide(cls, src, output=None, owner=None, name=None, length_normalization=Fa
     for i, size in src_data.size_placeholder.items():
       tag = DimensionTag.get_tag_from_size_tensor(size)
       assert tag
-      tag = tag.get_for_batch(output.batch)
+      tag = tag.get_for_batch_ctx(batch=output.batch, ctx=output.control_flow_ctx)
       if tag.dyn_size is None:
         size = tf.reshape(size, [batch_dim, beam_size])  # (batch, beam)
         size = tf.gather_nd(size, indices=beam_idxs_ext)  # (batch,)
@@ -7071,8 +7086,11 @@ def _create_template(cls, name, network, sources, masked_from, unit,
     # We don't care about the right masked input here, but just about deriving the right output shape.
     if masked_from:
       if network.is_inside_rec_layer(inside_loop=True):
-        source_data = masked_from.output.copy_template_excluding_time_dim(
-          name="%s_%s_masked_input_frame" % (masked_from.output.name, name))
+        source_data = (
+          masked_from.output
+          .copy_template_excluding_time_dim(
+            name="%s_%s_masked_input_frame" % (masked_from.output.name, name))
+          .copy_template_set_ctx(network.get_control_flow_ctx()))
       else:
         source_data = masked_from.output.copy_template(
           name="%s_%s_masked_input" % (masked_from.output.name, name))
@@ -7347,7 +7365,7 @@ def get_out_data_from_opts(cls, name, network, sources, mask, **kwargs):
         # thus when we unroll it to get into the loop, the RecLayer would have kept it as-is,
         # i.e. it should still have that time-dim-axis.
         # Maybe we should do some extra checks if that is like we assume, but for now, just assume that.
-        return out.copy_template_excluding_time_dim()
+        return out.copy_template_excluding_time_dim().copy_template_set_ctx(network.get_control_flow_ctx())
       return out
     assert out.have_time_axis()
     out = out.copy_as_time_major()

returnn/tf/network.py

Lines changed: 14 additions & 0 deletions
@@ -354,6 +354,7 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
                train_flag=None, eval_flag=None, search_flag=None,
                parent_layer=None, parent_net=None, extra_parent_net=None, extra_name_prefix=None,
                inside_rec_time_dim=None, over_rec_time_dim=None, over_rec_time_dim_subs=None,
+               control_flow_ctx=None,
                absolute_name_prefix=None, name=None):
     """
     :param returnn.config.Config config: only needed to init extern_data if not specified explicitly
@@ -370,6 +371,7 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
     :param DimensionTag|None inside_rec_time_dim: dim tag of outer rec layer, when run inside the loop (not optimized)
     :param DimensionTag|None over_rec_time_dim: dim tag of outer rec layer, when optimized out of the loop
     :param set[DimensionTag]|None over_rec_time_dim_subs: outer rec layer, out of loop, potential shorter
+    :param returnn.tf.util.data.ControlFlowContext control_flow_ctx:
     :param str|None absolute_name_prefix:
     :param str name: only for debugging
     """
@@ -432,6 +434,7 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
     self._inside_rec_time_dim = inside_rec_time_dim
     self._over_rec_time_dim = over_rec_time_dim
     self._over_rec_time_dim_subs = over_rec_time_dim_subs
+    self.control_flow_ctx = control_flow_ctx
     self.extra_parent_net = extra_parent_net
     self.extra_name_prefix = extra_name_prefix
     self.extra_deps_in_extra = False
@@ -519,6 +522,17 @@ def get_root_ctx_network(self):
         break
     return net, "".join(reversed(path))
 
+  def get_control_flow_ctx(self):
+    """
+    :rtype: returnn.tf.util.data.ControlFlowContext|None
+    """
+    net = self
+    while net:
+      if net.control_flow_ctx:
+        return net.control_flow_ctx
+      net = net.parent_net
+    return None
+
   def is_extra_internal_template_construction(self):
     """
     :rtype: LayerBase|None
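Note that get_control_flow_ctx resolves through the parent_net chain, so only networks that open a new context (like the rec subnet above) set control_flow_ctx; everything nested inside inherits it. With minimal stand-in stubs (NetStub is illustrative; the lookup logic is verbatim from the hunk above):

class NetStub:
  # Stand-in for TFNetwork, just enough to exercise the method added above.
  def __init__(self, parent_net=None, control_flow_ctx=None):
    self.parent_net = parent_net
    self.control_flow_ctx = control_flow_ctx

  def get_control_flow_ctx(self):
    net = self
    while net:
      if net.control_flow_ctx:
        return net.control_flow_ctx
      net = net.parent_net
    return None

root = NetStub()
rec_subnet = NetStub(parent_net=root, control_flow_ctx="loop-ctx")
inner_subnet = NetStub(parent_net=rec_subnet)  # e.g. a subnetwork inside the loop
assert root.get_control_flow_ctx() is None
assert inner_subnet.get_control_flow_ctx() == "loop-ctx"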

returnn/tf/util/basic.py

Lines changed: 1 addition & 1 deletion
@@ -4578,7 +4578,7 @@ def _maybe_to_base_seq_len(v):
     base_out_tag.set_tag_on_size_tensor(base_out_seq_len)
 
   assert base_out_tag.batch
-  out_tag = base_out_tag.get_for_batch(in_tag.batch)
+  out_tag = base_out_tag.get_for_batch_ctx(batch=in_tag.batch, ctx=in_tag.control_flow_ctx)
   assert out_tag.dyn_size is not None
   return out_tag.dyn_size