def test_RecLayer_loss_subsubnet():
  """
  Build a rec layer whose per-step ``output`` layer is a subnetwork that itself defines a CE loss,
  then evaluate the network objective once with the train flag fed as True.

  There are no explicit assertions: the test passes when network construction,
  parameter initialization and the single ``session.run`` complete without raising.
  """
  from returnn.tf.util.basic import DimensionTag, get_global_train_flag_placeholder

  num_classes = 3

  # Subnetwork used as the rec unit's "output" layer: an LSTM cell feeding the softmax layer with the loss.
  inner_subnet = {
    'hidden': {
      'class': 'rnn_cell',
      'n_out': 10,
      'unit': 'LSTMBlock',
    },
    'output': {
      'class': 'linear',
      'from': 'hidden',
      'n_out': num_classes,
      'target': 'classes',
      'activation': 'softmax',
      'loss': 'ce',
    },
  }

  # Layers evaluated inside the rec loop.
  rec_unit = {
    'output': {
      'class': 'subnetwork',
      'from': 'prev:tag_embedding',
      'subnetwork': inner_subnet,
    },
    'tag_embedding': {
      'activation': None,
      'class': 'linear',
      'from': 'data:source',
      'n_out': 10,
      'with_bias': True,
    },
  }

  net_dict = {
    'output': {
      'class': 'rec',
      'from': 'data:classes',
      'optimize_move_layers_out': False,  # just to trigger the potential bug
      'unit': rec_unit,
    },
  }

  with make_scope() as session:
    config = Config()
    dec_time = DimensionTag(kind=DimensionTag.Types.Spatial, description="combined_time")
    classes_opts = {
      'available_for_inference': True,
      'dim': num_classes,
      'same_dim_tags_as': {'t': dec_time},
      'sparse': True,
      "batch_dim_axis": 0,
      "time_dim_axis": 1,
    }
    config.update({
      "debug_print_layer_output_template": True,
      "calculate_exp_loss": True,
      "extern_data": {'classes': classes_opts},
      "network": net_dict,
    })

    train_flag = get_global_train_flag_placeholder()
    net = TFNetwork(config=config, train_flag=train_flag)
    net.construct_from_dict(config.typed_dict["network"])
    objective = net.get_objective()
    net.initialize_params(session)

    from test_TFNetworkLayer import make_feed_dict
    feeds = make_feed_dict(net.extern_data)
    feeds[train_flag] = True
    session.run(objective, feed_dict=feeds)


def test_untrainable_sublayers():
  with make_scope() as session:
    config = Config()