Skip to content

Commit a75c364

Browse files
committed
Add test for optimize_out_slice_nd
1 parent 6410df5 commit a75c364

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

tests/test_TFNetworkRecLayer.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3328,7 +3328,7 @@ def test_rec_subnet_simple_rnn():
33283328
print("rnn_cell also fine.")
33293329

33303330

3331-
def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, shared_base_net=None, rtol=1e-4):
3331+
def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, shared_base_net=None, from_=None, rtol=1e-4):
33323332
"""
33333333
:param dict[str] subnet_layer_dict: opts for the output layer inside the rec-layer subnet
33343334
:param dict[str,dict[str]] other_subnet_layers: other layers for the rec-layer subnet
@@ -3344,7 +3344,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
33443344
subnet_layer_dict.setdefault("from", ["data:source"])
33453345
rec_layer_dict = {
33463346
"class": "rec",
3347-
"from": ["data"],
3347+
"from": ["data"] if from_ is None else [from_],
33483348
"unit": {"output": subnet_layer_dict},
33493349
"n_out": n_out,
33503350
"is_output_layer": True
@@ -3353,7 +3353,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
33533353
assert "output" not in other_subnet_layers
33543354
rec_layer_dict["unit"].update(other_subnet_layers)
33553355
config = Config({
3356-
"debug_print_layer_output_template": True,
3356+
"debug_print_layer_output_template": False,
33573357
"num_inputs": n_in,
33583358
"num_outputs": n_out
33593359
})
@@ -3598,6 +3598,21 @@ def test_reclayer_optimize_out_access_split():
35983598
other_subnet_layers={"split": {"class": "split", "from": ["data:source"], "size_splits": [5, 8]}})
35993599

36003600

3601+
def test_reclayer_optimize_out_slice_nd():
3602+
check_reclayer_optimize_out(
3603+
{"class": "linear", "activation": None, "from": ["encoder_reduced"]},
3604+
from_="position",
3605+
other_subnet_layers={
3606+
"window": {"class": "slice_nd", "from": "base:encoder", "start": "data:source", "size": None, "min_size": 1, "is_output_layer": True},
3607+
"encoder_reduced": {"class": "reduce", "mode": "sum", "axis": "T", "from": ["base:encoder"], "is_output_layer": True}},
3608+
shared_base_net={
3609+
"encoder": {"class": "copy", "from": ["data"], "is_output_layer": True},
3610+
"position": {"class": "eval", "from": ["encoder"], "is_output_layer": True,
3611+
"eval": "tf.zeros(tf.shape(source(0, enforce_batch_major=True, auto_convert=False))[:-1], dtype=tf.dtypes.int32)",
3612+
"out_type": {"batch_dim_axis": 0, "time_dim_axis": 1, "shape": (None,),
3613+
"sparse": True, "dtype": "int32", "dim": None}}})
3614+
3615+
36013616
def test_reclayer_att_with_kv_in_rec():
36023617
net_dict = {
36033618
'decision': {'class': 'decide', 'from': ['output'], 'loss': 'edit_distance', 'loss_opts': {}, 'target': 'classes'},

0 commit comments

Comments
 (0)