From 15ff2f5075115f122208aaba5dbeabffe9b485f4 Mon Sep 17 00:00:00 2001 From: "Ching Yi, Chan" Date: Sat, 5 Oct 2019 00:47:45 +0800 Subject: [PATCH] fix TensorBoardWSGIApp method signature changed --- jupyter_tensorboard/tensorboard_manager.py | 63 +++++++++++++++++++++- tests/test_tensorboard_integration.py | 7 ++- tox.ini | 6 ++- 3 files changed, 71 insertions(+), 5 deletions(-) diff --git a/jupyter_tensorboard/tensorboard_manager.py b/jupyter_tensorboard/tensorboard_manager.py index f45ddfa..33389ab 100644 --- a/jupyter_tensorboard/tensorboard_manager.py +++ b/jupyter_tensorboard/tensorboard_manager.py @@ -108,7 +108,63 @@ def _ReloadForever(): return thread -def TensorBoardWSGIApp(logdir, plugins, multiplexer, +def is_tensorboard_greater_than_or_equal_to20(): + import tensorflow + version = tensorflow.__version__.split(".") + return int(version[0]) >= 2 + + +def TensorBoardWSGIApp_2x( + flags, + plugins, + data_provider=None, + assets_zip_provider=None, + deprecated_multiplexer=None): + + logdir = flags.logdir + multiplexer = deprecated_multiplexer + reload_interval = flags.reload_interval + + path_to_run = application.parse_event_files_spec(logdir) + if reload_interval: + thread = start_reloading_multiplexer( + multiplexer, path_to_run, reload_interval) + else: + application.reload_multiplexer(multiplexer, path_to_run) + thread = None + + + db_uri = None + db_connection_provider = None + + plugin_name_to_instance = {} + + from tensorboard.plugins import base_plugin + context = base_plugin.TBContext( + data_provider=data_provider, + db_connection_provider=db_connection_provider, + db_uri=db_uri, + flags=flags, + logdir=flags.logdir, + multiplexer=deprecated_multiplexer, + assets_zip_provider=assets_zip_provider, + plugin_name_to_instance=plugin_name_to_instance, + window_title=flags.window_title) + + tbplugins = [] + for loader in plugins: + plugin = loader.load(context) + if plugin is None: + continue + tbplugins.append(plugin) + plugin_name_to_instance[plugin.plugin_name] = plugin + + tb_app = application.TensorBoardWSGI(tbplugins) + manager.add_instance(logdir, tb_app, thread) + return tb_app + + +def TensorBoardWSGIApp_1x(logdir, plugins, multiplexer, reload_interval, path_prefix="", reload_task="auto"): path_to_run = application.parse_event_files_spec(logdir) if reload_interval: @@ -122,7 +178,10 @@ def TensorBoardWSGIApp(logdir, plugins, multiplexer, return tb_app -application.TensorBoardWSGIApp = TensorBoardWSGIApp +if is_tensorboard_greater_than_or_equal_to20(): + application.TensorBoardWSGIApp = TensorBoardWSGIApp_2x +else: + application.TensorBoardWSGIApp = TensorBoardWSGIApp_1x class TensorboardManger(dict): diff --git a/tests/test_tensorboard_integration.py b/tests/test_tensorboard_integration.py index 1cd2674..47f6aab 100644 --- a/tests/test_tensorboard_integration.py +++ b/tests/test_tensorboard_integration.py @@ -13,7 +13,12 @@ def tf_logs(tmpdir_factory): import numpy as np - import tensorflow as tf + try: + import tensorflow.compat.v1 as tf + tf.disable_v2_behavior() + except: + import tensorflow as tf + x = np.random.rand(5) y = 3 * x + 1 + 0.05 * np.random.rand(5) diff --git a/tox.ini b/tox.ini index 35fe143..89832f2 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,6 @@ [tox] -envlist = {py27,py34,py35,py36}-tensorflow{13,14,15,16,17,18,19,110,111,112,113} +envlist = {py34,py35,py36}-tensorflow{13,14,15,16,17,18,19,110,111,112,113,200} + [testenv] deps = @@ -15,6 +16,7 @@ deps = tensorflow111: tensorflow>=1.11, <1.12 tensorflow112: tensorflow>=1.12, <1.13 tensorflow113: tensorflow<=1.13, <1.14 + tensorflow200: tensorflow<=2.0, <2.1 commands = pytest @@ -23,4 +25,4 @@ alwayscopy = True [testenv:py36] commands = - flake8 \ No newline at end of file + flake8