Skip to content

Commit bd72b69

Browse files
committed
RecLayer subnet, more flexible out type logic
1 parent 4b088a5 commit bd72b69

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

returnn/tf/layers/rec.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -395,27 +395,27 @@ def get_out_data_from_opts(cls, network, unit, _time_dim_tag=None, sources=(), i
395395
loss = kwargs.get("loss", None)
396396
deps = list(sources) # type: typing.List[LayerBase]
397397
deps += [layer for layer in nest.flatten(initial_state) if isinstance(layer, LayerBase)]
398-
if out_type or n_out is not NotSpecified or loss:
398+
if isinstance(unit, _SubnetworkRecCell): # subnetwork
399+
subnet = unit
400+
out = (
401+
subnet.layer_data_templates["output"].output
402+
.copy_template_adding_time_dim(name="%s_output" % kwargs["name"], time_dim_axis=0)
403+
.copy_template_set_ctx(network.get_control_flow_ctx()))
404+
if n_out is not NotSpecified:
405+
assert n_out == out.dim
406+
if out_type:
407+
for k, v in out_type.items():
408+
assert getattr(out, k) == v
409+
deps += subnet.get_parent_deps()
410+
elif out_type or n_out is not NotSpecified or loss:
399411
out = super(RecLayer, cls).get_out_data_from_opts(network=network, sources=sources, **kwargs)
400412
if source_data and not source_data.have_time_axis():
401413
# We expect to be inside another RecLayer, and should do a single step (like RnnCellLayer).
402414
out = out.copy_as_batch_major() # The output is then [B,F]
403415
else:
404416
out = out.copy_as_time_batch_major() # Otherwise the output is always [T,B,F]
405417
else:
406-
out = None
407-
if isinstance(unit, _SubnetworkRecCell): # subnetwork
408-
subnet = unit
409-
sub_out = (
410-
subnet.layer_data_templates["output"].output
411-
.copy_template_adding_time_dim(name="%s_output" % kwargs["name"], time_dim_axis=0)
412-
.copy_template_set_ctx(network.get_control_flow_ctx()))
413-
if out:
414-
assert sub_out.dim == out.dim
415-
assert sub_out.shape == out.shape
416-
out = sub_out
417-
deps += subnet.get_parent_deps()
418-
assert out
418+
raise Exception("n_out or out_type must be specified")
419419
if out.have_time_axis() and _time_dim_tag:
420420
out = out.copy_template_replace_dim_tag(axis=out.time_dim_axis, new_dim_tag=_time_dim_tag)
421421
for dep in deps:

0 commit comments

Comments
 (0)