@@ -395,27 +395,27 @@ def get_out_data_from_opts(cls, network, unit, _time_dim_tag=None, sources=(), i
395395 loss = kwargs .get ("loss" , None )
396396 deps = list (sources ) # type: typing.List[LayerBase]
397397 deps += [layer for layer in nest .flatten (initial_state ) if isinstance (layer , LayerBase )]
398- if out_type or n_out is not NotSpecified or loss :
398+ if isinstance (unit , _SubnetworkRecCell ): # subnetwork
399+ subnet = unit
400+ out = (
401+ subnet .layer_data_templates ["output" ].output
402+ .copy_template_adding_time_dim (name = "%s_output" % kwargs ["name" ], time_dim_axis = 0 )
403+ .copy_template_set_ctx (network .get_control_flow_ctx ()))
404+ if n_out is not NotSpecified :
405+ assert n_out == out .dim
406+ if out_type :
407+ for k , v in out_type .items ():
408+ assert getattr (out , k ) == v
409+ deps += subnet .get_parent_deps ()
410+ elif out_type or n_out is not NotSpecified or loss :
399411 out = super (RecLayer , cls ).get_out_data_from_opts (network = network , sources = sources , ** kwargs )
400412 if source_data and not source_data .have_time_axis ():
401413 # We expect to be inside another RecLayer, and should do a single step (like RnnCellLayer).
402414 out = out .copy_as_batch_major () # The output is then [B,F]
403415 else :
404416 out = out .copy_as_time_batch_major () # Otherwise the output is always [T,B,F]
405417 else :
406- out = None
407- if isinstance (unit , _SubnetworkRecCell ): # subnetwork
408- subnet = unit
409- sub_out = (
410- subnet .layer_data_templates ["output" ].output
411- .copy_template_adding_time_dim (name = "%s_output" % kwargs ["name" ], time_dim_axis = 0 )
412- .copy_template_set_ctx (network .get_control_flow_ctx ()))
413- if out :
414- assert sub_out .dim == out .dim
415- assert sub_out .shape == out .shape
416- out = sub_out
417- deps += subnet .get_parent_deps ()
418- assert out
418+ raise Exception ("n_out or out_type must be specified" )
419419 if out .have_time_axis () and _time_dim_tag :
420420 out = out .copy_template_replace_dim_tag (axis = out .time_dim_axis , new_dim_tag = _time_dim_tag )
421421 for dep in deps :
0 commit comments