|
24 | 24 | from tensorflow.core.protobuf import config_pb2
|
25 | 25 | from tensorflow.python import keras
|
26 | 26 | from tensorflow.python.eager import context
|
| 27 | +from tensorflow.python.framework import config |
27 | 28 | from tensorflow.python.framework import errors_impl
|
28 | 29 | from tensorflow.python.framework import ops
|
29 | 30 | from tensorflow.python.framework import sparse_tensor
|
@@ -94,6 +95,34 @@ def compare_two_inputs_op_to_numpy(keras_op,
|
94 | 95 |
|
95 | 96 | class BackendResetTest(test.TestCase, parameterized.TestCase):
|
96 | 97 |
|
| 98 | + @test_util.run_all_in_graph_and_eager_modes |
| 99 | + def test_new_config(self): |
| 100 | + # User defined jit setting |
| 101 | + config.set_optimizer_jit(False) |
| 102 | + sess = keras.backend.get_session() |
| 103 | + default_config = context.context().config |
| 104 | + self.assertEqual( |
| 105 | + sess._config.graph_options.optimizer_options.global_jit_level, |
| 106 | + default_config.graph_options.optimizer_options.global_jit_level) |
| 107 | + keras.backend.clear_session() |
| 108 | + |
| 109 | + # New session has the same jit setting |
| 110 | + sess = keras.backend.get_session() |
| 111 | + default_config = context.context().config |
| 112 | + self.assertEqual( |
| 113 | + sess._config.graph_options.optimizer_options.global_jit_level, |
| 114 | + default_config.graph_options.optimizer_options.global_jit_level) |
| 115 | + keras.backend.clear_session() |
| 116 | + |
| 117 | + # Change respected |
| 118 | + config.set_optimizer_jit(True) |
| 119 | + sess = keras.backend.get_session() |
| 120 | + default_config = context.context().config |
| 121 | + self.assertEqual( |
| 122 | + sess._config.graph_options.optimizer_options.global_jit_level, |
| 123 | + default_config.graph_options.optimizer_options.global_jit_level) |
| 124 | + keras.backend.clear_session() |
| 125 | + |
97 | 126 | # We can't use the normal parameterized decorator because the test session
|
98 | 127 | # will block graph clearing.
|
99 | 128 | @parameterized.named_parameters(('_v1', context.graph_mode),
|
|
0 commit comments