diff --git a/jupyter_tensorboard/handlers.py b/jupyter_tensorboard/handlers.py index 391b96c..90578a4 100644 --- a/jupyter_tensorboard/handlers.py +++ b/jupyter_tensorboard/handlers.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright (c) 2017-2019, Shengpeng Liu. All rights reserved. +# Copyright (c) 2019, Alex Ford. All rights reserved. # Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. from tornado import web @@ -8,23 +9,6 @@ from notebook.utils import url_path_join as ujoin from notebook.base.handlers import path_regex -try: - import wrapt - @wrapt.patch_function_wrapper(web.RequestHandler, 'check_xsrf_cookie') - def translate_check_xsrf_cookie(wrapped, instance, args, kwargs): - - if ((instance.request.headers.get("X-XSRF-TOKEN")) and - not (instance.get_argument("_xsrf", None) - or instance.request.headers.get("X-Xsrftoken") - or instance.request.headers.get("X-Csrftoken"))): - - instance.request.headers.add("X-Xsrftoken", instance.request.headers.get("X-XSRF-TOKEN")) - - wrapped(*args, **kwargs) -except: - pass - - notebook_dir = None def load_jupyter_server_extension(nb_app): @@ -70,7 +54,6 @@ def load_jupyter_server_extension(nb_app): class TensorboardHandler(IPythonHandler): - def _impl(self, name, path): self.request.path = path @@ -105,6 +88,74 @@ def post(self, name, path): self._impl(name, path) + def check_xsrf_cookie(self): + """Expand XSRF check exceptions for POST requests. + + For TensorBoard versions <= 2.4.x, expand check_xsrf_cookie exceptions, + normally only applied to GET and HEAD requests, to POST requests + for TensorBoard API calls, as TensorBoard doesn't return back the + header used for XSRF checks. + + Provides support for TensorBoard plugins that use POST to retrieve + experiment information. + + """ + + # Set our own expectations as to whether TB will implement this + # See https://github.com/tensorflow/tensorboard/issues/4685 + # Presently assuming TB > 2.4 will fix this + from distutils.version import LooseVersion + import tensorboard as tb + tb_has_xsrf = (LooseVersion(tb.__version__) >= LooseVersion("2.5")) + + # Check whether any XSRF token is provided + req_has_xsrf = (self.get_argument("_xsrf", None) + or self.request.headers.get("X-Xsrftoken") + or self.request.headers.get("X-Csrftoken")) + + # If applicable, translate Angular's XSRF header to Tornado's name + if (self.request.headers.get("X-XSRF-TOKEN") and not req_has_xsrf): + + self.request.headers.add("X-Xsrftoken", + self.request.headers.get("X-XSRF-TOKEN")) + + req_has_xsrf = True + + + # Check XSRF token + try: + return super(TensorboardHandler, self).check_xsrf_cookie() + + except web.HTTPError: + # Note: GET and HEAD exceptions are already handled in + # IPythonHandler.check_xsrf_cookie and will not normally throw 403 + + # If the request had one of the XSRF tokens, or if TB > 2.4.x, + # then no exceptions, and we're done + if req_has_xsrf or tb_has_xsrf: + raise + + # Otherwise, we must loosen our expectations a bit. IPythonHandler + # has some existing exceptions to consider a matching Referer as + # sufficient for GET and HEAD requests; we extend that here to POST + # for TB versions <= 2.4.x that don't handle XSRF tokens + + if self.request.method in {"POST"}: + # Consider Referer a sufficient cross-origin check, mirroring + # logic in IPythonHandler.check_xsrf_cookie. + if not self.check_referer(): + referer = self.request.headers.get("Referer") + if referer: + msg = ( + "Blocking Cross Origin request from {}." + .format(referer) + ) + else: + msg = "Blocking request from unknown origin" + raise web.HTTPError(403, msg) + else: + raise + class TbFontHandler(IPythonHandler): diff --git a/requirements.txt b/requirements.txt index fc1853c..35018a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -e . flake8 pytest -wrapt diff --git a/setup.py b/setup.py index f723186..eb191d5 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,6 @@ def run(self): ]], install_requires=[ 'notebook>=5.0', - 'wrapt>=1.12', ], classifiers=[ 'Intended Audience :: Developers',