@@ -82,14 +82,14 @@ def __init__(self, kind=Types.Unspecified, description=None,
8282 if dyn_size_ext :
8383 assert batch == dyn_size_ext .batch
8484 self .dyn_size_ext = dyn_size_ext # type: typing.Optional[Data]
85- if dyn_size is not None :
86- assert not dyn_size_ext
87- self .dyn_size = dyn_size
8885 self ._dyn_size_same = set () # type: typing.Set[tf.Tensor]
8986 self ._undefined = undefined
9087 # We can have different tag variants per batch info (e.g. with beam), or per control flow ctx.
9188 # They each have same_as = self. The same_base should have the base (global) batch info.
9289 self ._same_for_batch_ctx = {} # type: typing.Dict[typing.Tuple[BatchInfo,typing.Optional[ControlFlowContext]],DimensionTag] # nopep8
90+ if dyn_size is not None :
91+ assert not dyn_size_ext
92+ self .dyn_size = dyn_size
9393
9494 def __repr__ (self ):
9595 return "DimensionTag{%s}" % self .short_repr ()
@@ -149,13 +149,54 @@ def _can_use_in_ctx(self, ctx):
149149 return False
150150 return True
151151
152- def get_for_batch_ctx (self , batch , ctx ):
152+ def _validate_in_current_graph (self ):
153+ """
154+ :rtype: bool
155+ """
156+ tensor = None
157+ if self .batch :
158+ batch_base = self .batch .get_global_base ()
159+ if batch_base .is_global_batch ():
160+ tensor = batch_base .get_global_batch_dim ().size
161+ if not isinstance (tensor , tf .Tensor ):
162+ if self .dyn_size_ext and self .dyn_size_ext .placeholder is not None :
163+ tensor = self .dyn_size_ext .placeholder
164+ if isinstance (tensor , tf .Tensor ):
165+ g = tf_compat .v1 .get_default_graph ()
166+ if tensor .graph is not g : # maybe from an earlier run which reuses the dim tag
167+ # Reset and cleanup.
168+ self .dyn_size_ext = None
169+ same_base = self .get_same_base ()
170+ same_base ._same_for_batch_ctx .pop ((self .batch , self .control_flow_ctx ), None )
171+ self .batch = None # it is invalid in the new graph
172+ self .control_flow_ctx = None # also invalid
173+ return False
174+ return True
175+
176+ def _maybe_update (self ):
177+ if self .is_batch_dim ():
178+ return
179+ if isinstance (self .dimension , int ):
180+ return
181+ if self .dyn_size_ext :
182+ return
183+ if not self .batch :
184+ return
185+ # Check if we can find more in
186+ same = self .get_for_batch_ctx (self .batch , self .control_flow_ctx , allow_none = True )
187+ if self is same or not same or not same .dyn_size_ext :
188+ return
189+ self .dyn_size_ext = same .dyn_size_ext
190+
191+ def get_for_batch_ctx (self , batch , ctx , allow_none = False ):
153192 """
154193 :param BatchInfo batch:
155194 :param ControlFlowContext|None ctx:
156- :rtype: DimensionTag
195+ :param bool allow_none:
196+ :rtype: DimensionTag|None
157197 """
158- if self .batch == batch and self .control_flow_ctx == ctx :
198+ if self .batch == batch and self .control_flow_ctx == ctx and self .dyn_size_ext :
199+ self ._validate_in_current_graph ()
159200 return self
160201 if self .is_batch_dim ():
161202 # We ignore the ctx for the batch dim currently.
@@ -169,6 +210,7 @@ def get_for_batch_ctx(self, batch, ctx):
169210 if batch .is_broadcast ():
170211 return self # just leave as-is. should not matter.
171212 same_base = self .get_same_base ()
213+ same_base ._validate_in_current_graph ()
172214 # Might be uninitialized in some cases. Assume batch is global.
173215 if not same_base .batch :
174216 batch_base = batch .get_global_base ()
@@ -182,27 +224,24 @@ def get_for_batch_ctx(self, batch, ctx):
182224 if same_base .dyn_size_ext :
183225 assert same_base .batch == same_base .dyn_size_ext .batch
184226 assert same_base .control_flow_ctx == same_base .dyn_size_ext .control_flow_ctx
185- tag = same_base ._same_for_batch_ctx .get ((batch , ctx ), None )
186- if tag :
187- return tag
188- if same_base .batch == batch and same_base ._can_use_in_ctx (ctx ):
189- return same_base
190227 for ctx_ in ControlFlowContext .abs_ctx_stack_with_root (ctx ):
191228 tag = same_base ._same_for_batch_ctx .get ((batch , ctx_ ), None )
192- if tag and tag ._can_use_in_ctx (ctx ):
229+ if tag and tag ._can_use_in_ctx (ctx ) and tag . _validate_in_current_graph () :
193230 return tag
231+ if same_base .batch == batch and same_base ._can_use_in_ctx (ctx ) and same_base .dyn_size_ext :
232+ return same_base
194233 # Ok, nothing matching found.
195234 dyn_size_ext = None
196235 # Maybe we have sth with the base batch without beam which we can extend.
197236 if batch .copy_remove_beam () == batch .get_global_base () and batch .beam :
198237 batch_base = batch .get_global_base ()
199238 base_can_use_in_ctx = None
200- if same_base .batch == batch_base and same_base ._can_use_in_ctx (ctx ):
239+ if same_base .batch == batch_base and same_base ._can_use_in_ctx (ctx ) and same_base . dyn_size_ext :
201240 base_can_use_in_ctx = same_base
202241 else :
203242 for ctx_ in ControlFlowContext .abs_ctx_stack_with_root (ctx ):
204243 tag = same_base ._same_for_batch_ctx .get ((batch_base , ctx_ ), None )
205- if tag and tag ._can_use_in_ctx (ctx ):
244+ if tag and tag ._can_use_in_ctx (ctx ) and tag . _validate_in_current_graph () and tag . dyn_size_ext :
206245 base_can_use_in_ctx = tag
207246 break
208247 if base_can_use_in_ctx and base_can_use_in_ctx .dyn_size_ext :
@@ -224,6 +263,8 @@ def get_for_batch_ctx(self, batch, ctx):
224263 name = get_valid_scope_name_from_str ("%s_identity_for_beam_%s" % (dyn_size_ext .name , batch .beam .name )))
225264 dyn_size_ext .placeholder ._RETURNN_dyn_size_beam = batch .beam
226265 dyn_size_ext .placeholder ._RETURNN_beam_expanded_base_data = beam_expanded_base_data
266+ if not dyn_size_ext and allow_none :
267+ return None
227268 dim_tag = DimensionTag (
228269 kind = self .kind , description = self .description , dimension = self .dimension ,
229270 batch = batch , control_flow_ctx = dyn_size_ext .control_flow_ctx if dyn_size_ext else ctx ,
@@ -235,6 +276,34 @@ def get_for_batch_ctx(self, batch, ctx):
235276 same_base ._same_for_batch_ctx [(dim_tag .batch , dim_tag .control_flow_ctx )] = dim_tag
236277 return dim_tag
237278
279+ def set_dyn_size_ext_for_batch_ctx (self , batch , ctx , dyn_size_ext ):
280+ """
281+ :param BatchInfo batch:
282+ :param ControlFlowContext|None ctx:
283+ :param Data dyn_size_ext:
284+ """
285+ same = self .get_for_batch_ctx (batch , ctx )
286+ same .dyn_size_ext = dyn_size_ext
287+ self ._maybe_update ()
288+
289+ def get_dyn_size_ext_for_batch_ctx (self , batch , ctx ):
290+ """
291+ :param BatchInfo|None batch:
292+ :param ControlFlowContext|None ctx:
293+ :rtype: Data|None
294+ """
295+ if not batch and self .batch :
296+ # Assume global batch.
297+ batch = self .batch .get_global_base ()
298+ if not batch :
299+ # This is usually not valid. However, this case can happen early at initialization.
300+ assert batch == self .batch and ctx == self .control_flow_ctx
301+ return self .dyn_size_ext
302+ same = self .get_for_batch_ctx (batch , ctx , allow_none = True )
303+ if not same :
304+ return None
305+ return same .dyn_size_ext
306+
238307 @property
239308 def dyn_size (self ):
240309 """
@@ -507,6 +576,8 @@ def declare_same_as(self, other):
507576 """
508577 :param DimensionTag other:
509578 """
579+ self ._maybe_update ()
580+ self ._validate_in_current_graph ()
510581 if self is other :
511582 return
512583 other_same_base = other .get_same_base ()
@@ -517,40 +588,66 @@ def declare_same_as(self, other):
517588 assert not self_same_as .same_as
518589 if self_same_as is other_same_base :
519590 return
591+ other_same_base ._merge_same_for_batch_ctx_dict (self_same_as )
520592 self_same_as .same_as = other_same_base
521593 self_same_as ._same_as_tb = traceback .extract_stack ()
522- if self_same_as .dyn_size_ext is None :
523- self_same_as .dyn_size_ext = other_same_base .dyn_size_ext
524- elif other_same_base .dyn_size_ext is None :
525- other_same_base .dyn_size_ext = self_same_as .dyn_size_ext
526- if self .dyn_size_ext is None and self_same_as .dyn_size_ext :
527- self .dyn_size_ext = self_same_as .dyn_size_ext .copy_extend_with_beam (self .batch .beam if self .batch else None )
594+ if self_same_as .dyn_size_ext is None or not self_same_as ._validate_in_current_graph ():
595+ self_same_as .dyn_size_ext = other_same_base .get_dyn_size_ext_for_batch_ctx (
596+ self_same_as .batch , self_same_as .control_flow_ctx )
597+ elif other_same_base .dyn_size_ext is None or not other_same_base ._validate_in_current_graph ():
598+ other_same_base .dyn_size_ext = self_same_as .get_dyn_size_ext_for_batch_ctx (
599+ other_same_base .batch , other_same_base .control_flow_ctx )
600+ if (self .dyn_size_ext is None or not self ._validate_in_current_graph ()) and self_same_as .dyn_size_ext :
601+ self .dyn_size_ext = self_same_as .get_dyn_size_ext_for_batch_ctx (self .batch , self .control_flow_ctx )
602+ other_same_base ._merge_same_for_batch_ctx_dict (self )
528603 self .same_as = other_same_base
529604 self ._same_as_tb = traceback .extract_stack ()
605+ self ._maybe_update ()
530606 if self .dyn_size is not None and other_same_base .dyn_size is not None :
531607 if self .dyn_size is not other_same_base .dyn_size :
532- if self .batch == other_same_base .batch :
608+ if self .batch == other_same_base .batch and self . control_flow_ctx == other_same_base . control_flow_ctx :
533609 # Note: Instead of making this a warning, we could also enforce this at some point.
534610 # The user should be able to fix `extern_data` in the config such that this is correct in the first place.
535611 # Also, in addition to this warning, we might want to add some runtime check on the eq of the dyn sizes.
536612 print (
537- "Warning: assuming dim tags are same with different size placeholders: %r vs %r" % (self , other_same_base ))
613+ "Warning: assuming dim tags are same with different size placeholders: %r vs %r" % (
614+ self .dyn_size , other_same_base .dyn_size ))
538615 # If we have a defined source, and this is a dynamic spatial axis, and it was undefined before,
539616 # maybe we can overtake the size_placeholder now.
540617 if self .same_as .dyn_size is not None and self .src_data :
541618 assert isinstance (self .src_axis , int )
542619 # Maybe it changed in the meanwhile, so check.
543- if self .src_data .get_dim_tag (self .src_axis ). description == self . description :
544- self .src_data . size_placeholder [
545- self .src_data . get_batch_axis_excluding_batch ( self . src_axis )] = self . same_as . dyn_size
620+ tag = self .src_data .get_dim_tag (self .src_axis )
621+ if tag . description == self .description and ( not tag . dyn_size_ext or not tag . _validate_in_current_graph ()):
622+ tag . dyn_size_ext = self .get_dyn_size_ext_for_batch_ctx ( tag . batch , tag . control_flow_ctx )
546623 # If others dyn_size is None but we have a dyn_size, maybe update others dyn_size.
547624 if self .dyn_size is not None and self .same_as .dyn_size is not self .dyn_size :
548625 # Could be unset if it comes from the config, or from prev graph creation.
549626 # This is important such that self.can_compare() is sane.
550- if self .same_as .dyn_size is None or self .same_as .dyn_size .graph is not self .dyn_size .graph :
551- self .same_as .dyn_size_ext = self .dyn_size_ext
552- if not self .dyn_size_ext and other .dyn_size_ext :
553- self .dyn_size_ext = other .dyn_size_ext .copy ()
627+ if self .same_as .dyn_size is None or not self .same_as ._validate_in_current_graph ():
628+ self .same_as .dyn_size_ext = self .get_dyn_size_ext_for_batch_ctx (
629+ self .same_as .batch , self .same_as .control_flow_ctx )
630+ if (not self .dyn_size_ext or not self ._validate_in_current_graph ()) and other .dyn_size_ext :
631+ self .dyn_size_ext = other .get_dyn_size_ext_for_batch_ctx (self .batch , self .control_flow_ctx )
632+
633+ def _merge_same_for_batch_ctx_dict (self , other ):
634+ """
635+ :param DimensionTag other:
636+ """
637+ self ._validate_in_current_graph ()
638+ for _ , dim in list (self ._same_for_batch_ctx .items ()):
639+ assert isinstance (dim , DimensionTag )
640+ dim ._validate_in_current_graph ()
641+ for key , dim in other ._same_for_batch_ctx .items ():
642+ if not dim ._validate_in_current_graph ():
643+ continue
644+ self_dim = self ._same_for_batch_ctx .get (key , None )
645+ if self_dim and (self_dim .dyn_size_ext or not dim .dyn_size_ext ):
646+ continue # keep ours
647+ if not dim .dyn_size_ext :
648+ continue # undefined, do not overtake
649+ self ._same_for_batch_ctx [key ] = dim
650+ other ._same_for_batch_ctx .clear () # we only want to have it once
554651
555652 @classmethod
556653 def get_existing_tag_from_collection (cls , other , tags , is_equal_opts = None ):
0 commit comments