Skip to content

Commit 3d17449

Browse files
committed
test_reclayer_optimize_out_accum_loop_dyn_size
1 parent 6a89d62 commit 3d17449

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

tests/test_TFNetworkRecLayer.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3506,6 +3506,42 @@ def test_reclayer_optimize_out_cum_concat_gen_self_att():
35063506
})
35073507

35083508

3509+
def test_reclayer_optimize_out_accum_loop_dyn_size():
  """
  Test that a rec-layer subnet which produces a *different* dynamic size per loop frame
  (shape [B] inside the loop) is still correct when the loop gets optimized out,
  i.e. the accumulated dyn size outside the loop must become shape [T,B] or [B,T].

  To get such varying sizes, we derive pseudo-random seq lens per frame from the
  input data (shape [B,T,D]) and the loop index ``:i``.
  """
  # We want to test for the case where some layer inside the loop
  # generates some dyn size of shape [B] which is different in each loop frame.
  # So outside the loop, the accumulated dyn size should be of shape [T,B] or [B,T].
  # To test this, we first generate some random seq lens based on the input data (shape [B,T,D]).
  from returnn.tf.util.basic import py_print

  def _eval_seq_lens(source, **_kwargs):
    """Compute pseudo-random varying seq lens (int32, >=1) from the three eval sources."""
    # source(0): sum_exp_data [B], source(1): base:max_sum_exp_data [B], source(2): loop index :i.
    # Ratio is in (0,1], so result is >= 1 (lens stay valid).
    res = tf.cast(4. * source(0) / source(1) + 0.3 * tf.cast(source(2), tf.float32), tf.int32) + 1
    res = py_print(res, ["seq lens", res, "step :i", source(2)])
    return res

  check_reclayer_optimize_out(
    subnet_layer_dict={"class": "linear", "from": "combine", "activation": None, "n_out": 3},
    other_subnet_layers={
      "exp_data": {"class": "activation", "from": "data:source", "activation": "exp"},  # >0
      "sum_exp_data": {"class": "reduce", "mode": "sum", "from": "exp_data", "axis": "F"},  # [B]
      "seq_lens": {
        "class": "eval", "from": ["sum_exp_data", "base:max_sum_exp_data", ":i"],
        "out_type": {"dtype": "int32"},
        "eval": _eval_seq_lens},  # [B]
      "range": {"class": "range_from_length", "from": "seq_lens"},  # [T_new]
      "combine": {
        "class": "eval", "from": ["data:source", "range"],
        "eval": "source(0) + 0.1 * tf.cast(source(1), tf.float32)"},  # [B,T_new,D]
    },
    shared_base_net={
      "exp_data": {"class": "activation", "from": "data", "activation": "exp"},  # >0
      "sum_exp_data": {"class": "reduce", "mode": "sum", "from": "exp_data", "axis": "F"},  # [B,T]
      "max_sum_exp_data": {
        "class": "reduce", "mode": "max", "from": "sum_exp_data", "axis": "T",
        "is_output_layer": True},  # [B]
    })
3543+
3544+
35093545
def test_reclayer_optimize_out_dot():
35103546
# Used for multi-head dot-attention.
35113547
AttNumHeads = 4

0 commit comments

Comments (0)