@@ -3328,7 +3328,7 @@ def test_rec_subnet_simple_rnn():
33283328 print ("rnn_cell also fine." )
33293329
33303330
3331- def check_reclayer_optimize_out (subnet_layer_dict , other_subnet_layers = None , shared_base_net = None , rtol = 1e-4 ):
3331+ def check_reclayer_optimize_out (subnet_layer_dict , other_subnet_layers = None , shared_base_net = None , from_ = None , rtol = 1e-4 ):
33323332 """
33333333 :param dict[str] subnet_layer_dict: opts for the output layer inside the rec-layer subnet
33343334 :param dict[str,dict[str]] other_subnet_layers: other layers for the rec-layer subnet
@@ -3344,7 +3344,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
33443344 subnet_layer_dict .setdefault ("from" , ["data:source" ])
33453345 rec_layer_dict = {
33463346 "class" : "rec" ,
3347- "from" : ["data" ],
3347+ "from" : ["data" ] if from_ is None else [ from_ ] ,
33483348 "unit" : {"output" : subnet_layer_dict },
33493349 "n_out" : n_out ,
33503350 "is_output_layer" : True
@@ -3598,6 +3598,31 @@ def test_reclayer_optimize_out_access_split():
35983598 other_subnet_layers = {"split" : {"class" : "split" , "from" : ["data:source" ], "size_splits" : [5 , 8 ]}})
35993599
36003600
3601+ def test_reclayer_optimize_out_slice_nd ():
3602+ def random_start_positions (source , ** kwargs ):
3603+ import tensorflow as tf
3604+ enc = source (0 , as_data = True , enforce_batch_major = True , auto_convert = False )
3605+ enc_shape = tf .shape (enc .placeholder )
3606+ enc_time_dim = enc_shape [enc .time_dim_axis ]
3607+ return tf .random .uniform (enc_shape [:- 1 ], 0 , enc_time_dim - 2 , dtype = tf .dtypes .int32 )
3608+
3609+ check_reclayer_optimize_out (
3610+ {"class" : "linear" , "activation" : None , "from" : ["encoder_reduced" ]},
3611+ from_ = "position" ,
3612+ other_subnet_layers = {
3613+ "window" : {"class" : "slice_nd" , "from" : "base:encoder" , "start" : "data:source" , "size" : None , "min_size" : 1 , "is_output_layer" : True },
3614+ "encoder_reduced" : {"class" : "reduce" , "mode" : "sum" , "axis" : "T" , "from" : ["base:encoder" ], "is_output_layer" : True }
3615+ },
3616+ shared_base_net = {
3617+ "encoder" : {"class" : "copy" , "from" : "data" , "is_output_layer" : True },
3618+ "position" : {
3619+ "class" : "eval" , "from" : "encoder" , "is_output_layer" : True ,
3620+ "eval" : random_start_positions ,
3621+ "out_type" : {"batch_dim_axis" : 0 , "time_dim_axis" : 1 , "shape" : (None ,), "sparse" : True , "dtype" : "int32" , "dim" : None }}
3622+ }
3623+ )
3624+
3625+
36013626def test_reclayer_att_with_kv_in_rec ():
36023627 net_dict = {
36033628 'decision' : {'class' : 'decide' , 'from' : ['output' ], 'loss' : 'edit_distance' , 'loss_opts' : {}, 'target' : 'classes' },
0 commit comments