Skip to content

Commit 91d228e

Browse files
committed
test check rec opt out, use masking before comparison
1 parent e95e485 commit 91d228e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/test_TFNetworkRecLayer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3425,8 +3425,8 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
34253425
net1_params = net1.layers["output_not_opt"].get_param_values_dict(session=session)
34263426
net2.layers["output_opt"].set_param_values_by_dict(values_dict=net1_params, session=session)
34273427
x_np = net1.random.normal(size=(n_batch, n_time, n_in))
3428-
net1_output = net1.layers["output_not_opt"].output.get_placeholder_as_batch_major()
3429-
net2_output = net2.layers["output_opt"].output.get_placeholder_as_batch_major()
3428+
net1_output = net1.layers["output_not_opt"].output.copy_masked(0.).get_placeholder_as_batch_major()
3429+
net2_output = net2.layers["output_opt"].output.copy_masked(0.).get_placeholder_as_batch_major()
34303430
feed_dict = {
34313431
net1.extern_data.data["data"].placeholder: x_np,
34323432
net1.extern_data.data["data"].size_placeholder[0]: [n_time] * n_batch}

0 commit comments

Comments
 (0)