Skip to content

Commit 10cffc1

Browse files
committed
Add test for optimize_out_slice_nd
1 parent 6410df5 commit 10cffc1

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

tests/test_TFNetworkRecLayer.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3328,7 +3328,7 @@ def test_rec_subnet_simple_rnn():
33283328
print("rnn_cell also fine.")
33293329

33303330

3331-
def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, shared_base_net=None, rtol=1e-4):
3331+
def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, shared_base_net=None, from_=None, rtol=1e-4):
33323332
"""
33333333
:param dict[str] subnet_layer_dict: opts for the output layer inside the rec-layer subnet
33343334
:param dict[str,dict[str]] other_subnet_layers: other layers for the rec-layer subnet
@@ -3344,7 +3344,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
33443344
subnet_layer_dict.setdefault("from", ["data:source"])
33453345
rec_layer_dict = {
33463346
"class": "rec",
3347-
"from": ["data"],
3347+
"from": ["data"] if from_ is None else [from_],
33483348
"unit": {"output": subnet_layer_dict},
33493349
"n_out": n_out,
33503350
"is_output_layer": True
@@ -3598,6 +3598,31 @@ def test_reclayer_optimize_out_access_split():
35983598
other_subnet_layers={"split": {"class": "split", "from": ["data:source"], "size_splits": [5, 8]}})
35993599

36003600

3601+
def test_reclayer_optimize_out_slice_nd():
3602+
def random_start_positions(source, **kwargs):
3603+
import tensorflow as tf
3604+
enc = source(0, as_data=True, enforce_batch_major=True, auto_convert=False)
3605+
enc_shape = tf.shape(enc.placeholder)
3606+
enc_time_dim = enc_shape[enc.time_dim_axis]
3607+
return tf.random.uniform(enc_shape[:-1], 0, enc_time_dim-2, dtype=tf.dtypes.int32)
3608+
3609+
check_reclayer_optimize_out(
3610+
{"class": "linear", "activation": None, "from": ["encoder_reduced"]},
3611+
from_="position",
3612+
other_subnet_layers={
3613+
"window": {"class": "slice_nd", "from": "base:encoder", "start": "data:source", "size": None, "min_size": 1, "is_output_layer": True},
3614+
"encoder_reduced": {"class": "reduce", "mode": "sum", "axis": "T", "from": ["base:encoder"], "is_output_layer": True}
3615+
},
3616+
shared_base_net={
3617+
"encoder": {"class": "copy", "from": "data", "is_output_layer": True},
3618+
"position": {
3619+
"class": "eval", "from": "encoder", "is_output_layer": True,
3620+
"eval": random_start_positions,
3621+
"out_type": {"batch_dim_axis": 0, "time_dim_axis": 1, "shape": (None,), "sparse": True, "dtype": "int32", "dim": None}}
3622+
}
3623+
)
3624+
3625+
36013626
def test_reclayer_att_with_kv_in_rec():
36023627
net_dict = {
36033628
'decision': {'class': 'decide', 'from': ['output'], 'loss': 'edit_distance', 'loss_opts': {}, 'target': 'classes'},

0 commit comments

Comments
 (0)