Skip to content

Commit 26d3fe8

Browse files
jaingauravtensorflower-gardener
authored andcommitted
Make default Keras ConfigProto use tf.config
PiperOrigin-RevId: 251659257
1 parent d15c612 commit 26d3fe8

File tree

2 files changed

+38
-8
lines changed

2 files changed

+38
-8
lines changed

tensorflow/python/keras/backend.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
7070
from tensorflow.python.ops import tensor_array_ops
7171
from tensorflow.python.ops import variables as variables_module
72+
from tensorflow.python.platform import tf_logging as logging
7273
from tensorflow.python.training import server_lib
7374
from tensorflow.python.util import nest
7475
from tensorflow.python.util import tf_contextlib
@@ -521,14 +522,14 @@ def set_session(session):
521522

522523

523524
def get_default_session_config():
524-
if not os.environ.get('OMP_NUM_THREADS'):
525-
config = config_pb2.ConfigProto(allow_soft_placement=True)
526-
else:
527-
num_thread = int(os.environ.get('OMP_NUM_THREADS'))
528-
config = config_pb2.ConfigProto(
529-
intra_op_parallelism_threads=num_thread,
530-
inter_op_parallelism_threads=num_thread,
531-
allow_soft_placement=True)
525+
if os.environ.get('OMP_NUM_THREADS'):
526+
logging.warning(
527+
'OMP_NUM_THREADS is no longer used by the default Keras config. '
528+
'To configure the number of threads, use tf.config.threading APIs.')
529+
530+
config = context.context().config
531+
config.allow_soft_placement = True
532+
532533
return config
533534

534535

tensorflow/python/keras/backend_test.py

+29
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tensorflow.core.protobuf import config_pb2
2525
from tensorflow.python import keras
2626
from tensorflow.python.eager import context
27+
from tensorflow.python.framework import config
2728
from tensorflow.python.framework import errors_impl
2829
from tensorflow.python.framework import ops
2930
from tensorflow.python.framework import sparse_tensor
@@ -94,6 +95,34 @@ def compare_two_inputs_op_to_numpy(keras_op,
9495

9596
class BackendResetTest(test.TestCase, parameterized.TestCase):
9697

98+
@test_util.run_all_in_graph_and_eager_modes
99+
def test_new_config(self):
100+
# User defined jit setting
101+
config.set_optimizer_jit(False)
102+
sess = keras.backend.get_session()
103+
default_config = context.context().config
104+
self.assertEqual(
105+
sess._config.graph_options.optimizer_options.global_jit_level,
106+
default_config.graph_options.optimizer_options.global_jit_level)
107+
keras.backend.clear_session()
108+
109+
# New session has the same jit setting
110+
sess = keras.backend.get_session()
111+
default_config = context.context().config
112+
self.assertEqual(
113+
sess._config.graph_options.optimizer_options.global_jit_level,
114+
default_config.graph_options.optimizer_options.global_jit_level)
115+
keras.backend.clear_session()
116+
117+
# Change respected
118+
config.set_optimizer_jit(True)
119+
sess = keras.backend.get_session()
120+
default_config = context.context().config
121+
self.assertEqual(
122+
sess._config.graph_options.optimizer_options.global_jit_level,
123+
default_config.graph_options.optimizer_options.global_jit_level)
124+
keras.backend.clear_session()
125+
97126
# We can't use the normal parameterized decorator because the test session
98127
# will block graph clearing.
99128
@parameterized.named_parameters(('_v1', context.graph_mode),

0 commit comments

Comments
 (0)