@@ -3506,6 +3506,42 @@ def test_reclayer_optimize_out_cum_concat_gen_self_att():
 })
 
 
+def test_reclayer_optimize_out_accum_loop_dyn_size():
+  # We want to test for the case where some layer inside the loop
+  # generates some dyn size of shape [B] which is different in each loop frame.
+  # So outside the loop, the accumulated dyn size should be of shape [T,B] or [B,T].
+  # To test this, we first generate some random seq lens based on the input data (shape [B,T,D]).
+  from returnn.tf.util.basic import py_print
+
+  def _eval_seq_lens(source, **_kwargs):
+    # Get some random varying seq lens.
+    res = tf.cast(4. * source(0) / source(1) + 0.3 * tf.cast(source(2), tf.float32), tf.int32) + 1
+    res = py_print(res, ["seq lens", res, "step :i", source(2)])
+    return res
+
+  check_reclayer_optimize_out(
+    subnet_layer_dict={"class": "linear", "from": "combine", "activation": None, "n_out": 3},
+    other_subnet_layers={
+      "exp_data": {"class": "activation", "from": "data:source", "activation": "exp"},  # >0
+      "sum_exp_data": {"class": "reduce", "mode": "sum", "from": "exp_data", "axis": "F"},  # [B]
+      "seq_lens": {
+        "class": "eval", "from": ["sum_exp_data", "base:max_sum_exp_data", ":i"],
+        "out_type": {"dtype": "int32"},
+        "eval": _eval_seq_lens},  # [B]
+      "range": {"class": "range_from_length", "from": "seq_lens"},  # [T_new]
+      "combine": {
+        "class": "eval", "from": ["data:source", "range"],
+        "eval": "source(0) + 0.1 * tf.cast(source(1), tf.float32)"},  # [B,T_new,D]
+    },
+    shared_base_net={
+      "exp_data": {"class": "activation", "from": "data", "activation": "exp"},  # >0
+      "sum_exp_data": {"class": "reduce", "mode": "sum", "from": "exp_data", "axis": "F"},  # [B,T]
+      "max_sum_exp_data": {
+        "class": "reduce", "mode": "max", "from": "sum_exp_data", "axis": "T",
+        "is_output_layer": True},  # [B]
+    })
+
+
 def test_reclayer_optimize_out_dot():
   # Used for multi-head dot-attention.
   AttNumHeads = 4
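To illustrate what "accumulated dyn size of shape [T,B]" means in the added test, here is a minimal standalone sketch in plain TensorFlow, separate from the commit itself. It assumes TF2 eager mode, and the function name and the per-frame seq lens formula are hypothetical, not taken from RETURNN: each loop frame produces per-batch seq lens of shape [B], and stacking them across frames yields a [T,B] size tensor outside the loop.

import tensorflow as tf

def accumulate_dyn_sizes(num_frames=3, batch_size=2):
  # TensorArray collects one [B] size vector per loop frame.
  ta = tf.TensorArray(tf.int32, size=num_frames)

  def body(i, ta):
    # Hypothetical per-frame seq lens, varying with the frame index i, shape [B].
    seq_lens = tf.fill([batch_size], i + 1)
    return i + 1, ta.write(i, seq_lens)

  _, ta = tf.while_loop(lambda i, _ta: i < num_frames, body, [tf.constant(0), ta])
  return ta.stack()  # shape [T,B]: the accumulated dyn size outside the loop

print(accumulate_dyn_sizes())  # [[1 1] [2 2] [3 3]]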