Skip to content

Commit fbf22dc

Browse files
authored
DimensionTag better same_as logic, validate graph (#677)
Improve DimensionTag.declare_same_as and fix dyn_size_ext handling for the control flow ctx in some cases. DimensionTag.get_for_batch_ctx now validates that the tag belongs to the current graph. This partly addresses the concerns raised in #672 and mostly fixes it.
1 parent 45362e2 commit fbf22dc

File tree

1 file changed

+126
-29
lines changed

1 file changed

+126
-29
lines changed

returnn/tf/util/data.py

Lines changed: 126 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,14 @@ def __init__(self, kind=Types.Unspecified, description=None,
8282
if dyn_size_ext:
8383
assert batch == dyn_size_ext.batch
8484
self.dyn_size_ext = dyn_size_ext # type: typing.Optional[Data]
85-
if dyn_size is not None:
86-
assert not dyn_size_ext
87-
self.dyn_size = dyn_size
8885
self._dyn_size_same = set() # type: typing.Set[tf.Tensor]
8986
self._undefined = undefined
9087
# We can have different tag variants per batch info (e.g. with beam), or per control flow ctx.
9188
# They each have same_as = self. The same_base should have the base (global) batch info.
9289
self._same_for_batch_ctx = {} # type: typing.Dict[typing.Tuple[BatchInfo,typing.Optional[ControlFlowContext]],DimensionTag] # nopep8
90+
if dyn_size is not None:
91+
assert not dyn_size_ext
92+
self.dyn_size = dyn_size
9393

9494
def __repr__(self):
9595
return "DimensionTag{%s}" % self.short_repr()
@@ -149,13 +149,54 @@ def _can_use_in_ctx(self, ctx):
149149
return False
150150
return True
151151

152-
def get_for_batch_ctx(self, batch, ctx):
152+
def _validate_in_current_graph(self):
153+
"""
154+
:rtype: bool
155+
"""
156+
tensor = None
157+
if self.batch:
158+
batch_base = self.batch.get_global_base()
159+
if batch_base.is_global_batch():
160+
tensor = batch_base.get_global_batch_dim().size
161+
if not isinstance(tensor, tf.Tensor):
162+
if self.dyn_size_ext and self.dyn_size_ext.placeholder is not None:
163+
tensor = self.dyn_size_ext.placeholder
164+
if isinstance(tensor, tf.Tensor):
165+
g = tf_compat.v1.get_default_graph()
166+
if tensor.graph is not g: # maybe from an earlier run which reuses the dim tag
167+
# Reset and cleanup.
168+
self.dyn_size_ext = None
169+
same_base = self.get_same_base()
170+
same_base._same_for_batch_ctx.pop((self.batch, self.control_flow_ctx), None)
171+
self.batch = None # it is invalid in the new graph
172+
self.control_flow_ctx = None # also invalid
173+
return False
174+
return True
175+
176+
def _maybe_update(self):
177+
if self.is_batch_dim():
178+
return
179+
if isinstance(self.dimension, int):
180+
return
181+
if self.dyn_size_ext:
182+
return
183+
if not self.batch:
184+
return
185+
# Check if we can find more (a defined dyn_size_ext) in the tag variant for our batch and ctx.
186+
same = self.get_for_batch_ctx(self.batch, self.control_flow_ctx, allow_none=True)
187+
if self is same or not same or not same.dyn_size_ext:
188+
return
189+
self.dyn_size_ext = same.dyn_size_ext
190+
191+
def get_for_batch_ctx(self, batch, ctx, allow_none=False):
153192
"""
154193
:param BatchInfo batch:
155194
:param ControlFlowContext|None ctx:
156-
:rtype: DimensionTag
195+
:param bool allow_none:
196+
:rtype: DimensionTag|None
157197
"""
158-
if self.batch == batch and self.control_flow_ctx == ctx:
198+
if self.batch == batch and self.control_flow_ctx == ctx and self.dyn_size_ext:
199+
self._validate_in_current_graph()
159200
return self
160201
if self.is_batch_dim():
161202
# We ignore the ctx for the batch dim currently.
@@ -169,6 +210,7 @@ def get_for_batch_ctx(self, batch, ctx):
169210
if batch.is_broadcast():
170211
return self # just leave as-is. should not matter.
171212
same_base = self.get_same_base()
213+
same_base._validate_in_current_graph()
172214
# Might be uninitialized in some cases. Assume batch is global.
173215
if not same_base.batch:
174216
batch_base = batch.get_global_base()
@@ -182,27 +224,24 @@ def get_for_batch_ctx(self, batch, ctx):
182224
if same_base.dyn_size_ext:
183225
assert same_base.batch == same_base.dyn_size_ext.batch
184226
assert same_base.control_flow_ctx == same_base.dyn_size_ext.control_flow_ctx
185-
tag = same_base._same_for_batch_ctx.get((batch, ctx), None)
186-
if tag:
187-
return tag
188-
if same_base.batch == batch and same_base._can_use_in_ctx(ctx):
189-
return same_base
190227
for ctx_ in ControlFlowContext.abs_ctx_stack_with_root(ctx):
191228
tag = same_base._same_for_batch_ctx.get((batch, ctx_), None)
192-
if tag and tag._can_use_in_ctx(ctx):
229+
if tag and tag._can_use_in_ctx(ctx) and tag._validate_in_current_graph():
193230
return tag
231+
if same_base.batch == batch and same_base._can_use_in_ctx(ctx) and same_base.dyn_size_ext:
232+
return same_base
194233
# Ok, nothing matching found.
195234
dyn_size_ext = None
196235
# Maybe we have something with the base batch without beam which we can extend.
197236
if batch.copy_remove_beam() == batch.get_global_base() and batch.beam:
198237
batch_base = batch.get_global_base()
199238
base_can_use_in_ctx = None
200-
if same_base.batch == batch_base and same_base._can_use_in_ctx(ctx):
239+
if same_base.batch == batch_base and same_base._can_use_in_ctx(ctx) and same_base.dyn_size_ext:
201240
base_can_use_in_ctx = same_base
202241
else:
203242
for ctx_ in ControlFlowContext.abs_ctx_stack_with_root(ctx):
204243
tag = same_base._same_for_batch_ctx.get((batch_base, ctx_), None)
205-
if tag and tag._can_use_in_ctx(ctx):
244+
if tag and tag._can_use_in_ctx(ctx) and tag._validate_in_current_graph() and tag.dyn_size_ext:
206245
base_can_use_in_ctx = tag
207246
break
208247
if base_can_use_in_ctx and base_can_use_in_ctx.dyn_size_ext:
@@ -224,6 +263,8 @@ def get_for_batch_ctx(self, batch, ctx):
224263
name=get_valid_scope_name_from_str("%s_identity_for_beam_%s" % (dyn_size_ext.name, batch.beam.name)))
225264
dyn_size_ext.placeholder._RETURNN_dyn_size_beam = batch.beam
226265
dyn_size_ext.placeholder._RETURNN_beam_expanded_base_data = beam_expanded_base_data
266+
if not dyn_size_ext and allow_none:
267+
return None
227268
dim_tag = DimensionTag(
228269
kind=self.kind, description=self.description, dimension=self.dimension,
229270
batch=batch, control_flow_ctx=dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx,
@@ -235,6 +276,34 @@ def get_for_batch_ctx(self, batch, ctx):
235276
same_base._same_for_batch_ctx[(dim_tag.batch, dim_tag.control_flow_ctx)] = dim_tag
236277
return dim_tag
237278

279+
def set_dyn_size_ext_for_batch_ctx(self, batch, ctx, dyn_size_ext):
280+
"""
281+
:param BatchInfo batch:
282+
:param ControlFlowContext|None ctx:
283+
:param Data dyn_size_ext:
284+
"""
285+
same = self.get_for_batch_ctx(batch, ctx)
286+
same.dyn_size_ext = dyn_size_ext
287+
self._maybe_update()
288+
289+
def get_dyn_size_ext_for_batch_ctx(self, batch, ctx):
290+
"""
291+
:param BatchInfo|None batch:
292+
:param ControlFlowContext|None ctx:
293+
:rtype: Data|None
294+
"""
295+
if not batch and self.batch:
296+
# Assume global batch.
297+
batch = self.batch.get_global_base()
298+
if not batch:
299+
# This is usually not valid. However, this case can happen early at initialization.
300+
assert batch == self.batch and ctx == self.control_flow_ctx
301+
return self.dyn_size_ext
302+
same = self.get_for_batch_ctx(batch, ctx, allow_none=True)
303+
if not same:
304+
return None
305+
return same.dyn_size_ext
306+
238307
@property
239308
def dyn_size(self):
240309
"""
@@ -507,6 +576,8 @@ def declare_same_as(self, other):
507576
"""
508577
:param DimensionTag other:
509578
"""
579+
self._maybe_update()
580+
self._validate_in_current_graph()
510581
if self is other:
511582
return
512583
other_same_base = other.get_same_base()
@@ -517,40 +588,66 @@ def declare_same_as(self, other):
517588
assert not self_same_as.same_as
518589
if self_same_as is other_same_base:
519590
return
591+
other_same_base._merge_same_for_batch_ctx_dict(self_same_as)
520592
self_same_as.same_as = other_same_base
521593
self_same_as._same_as_tb = traceback.extract_stack()
522-
if self_same_as.dyn_size_ext is None:
523-
self_same_as.dyn_size_ext = other_same_base.dyn_size_ext
524-
elif other_same_base.dyn_size_ext is None:
525-
other_same_base.dyn_size_ext = self_same_as.dyn_size_ext
526-
if self.dyn_size_ext is None and self_same_as.dyn_size_ext:
527-
self.dyn_size_ext = self_same_as.dyn_size_ext.copy_extend_with_beam(self.batch.beam if self.batch else None)
594+
if self_same_as.dyn_size_ext is None or not self_same_as._validate_in_current_graph():
595+
self_same_as.dyn_size_ext = other_same_base.get_dyn_size_ext_for_batch_ctx(
596+
self_same_as.batch, self_same_as.control_flow_ctx)
597+
elif other_same_base.dyn_size_ext is None or not other_same_base._validate_in_current_graph():
598+
other_same_base.dyn_size_ext = self_same_as.get_dyn_size_ext_for_batch_ctx(
599+
other_same_base.batch, other_same_base.control_flow_ctx)
600+
if (self.dyn_size_ext is None or not self._validate_in_current_graph()) and self_same_as.dyn_size_ext:
601+
self.dyn_size_ext = self_same_as.get_dyn_size_ext_for_batch_ctx(self.batch, self.control_flow_ctx)
602+
other_same_base._merge_same_for_batch_ctx_dict(self)
528603
self.same_as = other_same_base
529604
self._same_as_tb = traceback.extract_stack()
605+
self._maybe_update()
530606
if self.dyn_size is not None and other_same_base.dyn_size is not None:
531607
if self.dyn_size is not other_same_base.dyn_size:
532-
if self.batch == other_same_base.batch:
608+
if self.batch == other_same_base.batch and self.control_flow_ctx == other_same_base.control_flow_ctx:
533609
# Note: Instead of making this a warning, we could also enforce this at some point.
534610
# The user should be able to fix `extern_data` in the config such that this is correct in the first place.
535611
# Also, in addition to this warning, we might want to add some runtime check on the eq of the dyn sizes.
536612
print(
537-
"Warning: assuming dim tags are same with different size placeholders: %r vs %r" % (self, other_same_base))
613+
"Warning: assuming dim tags are same with different size placeholders: %r vs %r" % (
614+
self.dyn_size, other_same_base.dyn_size))
538615
# If we have a defined source, and this is a dynamic spatial axis, and it was undefined before,
539616
# maybe we can take over the size_placeholder now.
540617
if self.same_as.dyn_size is not None and self.src_data:
541618
assert isinstance(self.src_axis, int)
542619
# Maybe it changed in the meanwhile, so check.
543-
if self.src_data.get_dim_tag(self.src_axis).description == self.description:
544-
self.src_data.size_placeholder[
545-
self.src_data.get_batch_axis_excluding_batch(self.src_axis)] = self.same_as.dyn_size
620+
tag = self.src_data.get_dim_tag(self.src_axis)
621+
if tag.description == self.description and (not tag.dyn_size_ext or not tag._validate_in_current_graph()):
622+
tag.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx(tag.batch, tag.control_flow_ctx)
546623
# If the other's dyn_size is None but we have a dyn_size, maybe update the other's dyn_size.
547624
if self.dyn_size is not None and self.same_as.dyn_size is not self.dyn_size:
548625
# Could be unset if it comes from the config, or from prev graph creation.
549626
# This is important such that self.can_compare() is sane.
550-
if self.same_as.dyn_size is None or self.same_as.dyn_size.graph is not self.dyn_size.graph:
551-
self.same_as.dyn_size_ext = self.dyn_size_ext
552-
if not self.dyn_size_ext and other.dyn_size_ext:
553-
self.dyn_size_ext = other.dyn_size_ext.copy()
627+
if self.same_as.dyn_size is None or not self.same_as._validate_in_current_graph():
628+
self.same_as.dyn_size_ext = self.get_dyn_size_ext_for_batch_ctx(
629+
self.same_as.batch, self.same_as.control_flow_ctx)
630+
if (not self.dyn_size_ext or not self._validate_in_current_graph()) and other.dyn_size_ext:
631+
self.dyn_size_ext = other.get_dyn_size_ext_for_batch_ctx(self.batch, self.control_flow_ctx)
632+
633+
def _merge_same_for_batch_ctx_dict(self, other):
634+
"""
635+
:param DimensionTag other:
636+
"""
637+
self._validate_in_current_graph()
638+
for _, dim in list(self._same_for_batch_ctx.items()):
639+
assert isinstance(dim, DimensionTag)
640+
dim._validate_in_current_graph()
641+
for key, dim in other._same_for_batch_ctx.items():
642+
if not dim._validate_in_current_graph():
643+
continue
644+
self_dim = self._same_for_batch_ctx.get(key, None)
645+
if self_dim and (self_dim.dyn_size_ext or not dim.dyn_size_ext):
646+
continue # keep ours
647+
if not dim.dyn_size_ext:
648+
continue # undefined, do not overtake
649+
self._same_for_batch_ctx[key] = dim
650+
other._same_for_batch_ctx.clear() # we only want to have it once
554651

555652
@classmethod
556653
def get_existing_tag_from_collection(cls, other, tags, is_equal_opts=None):

0 commit comments

Comments
 (0)