diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index 21462af5f..6a3ea1913 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -189,8 +189,8 @@ def log_response_safely(self, server_response: "Response") -> str: loggable_response = helpers.strings.redact_xml(server_response.content.decode(server_response.encoding)) return loggable_response - def get_unauthenticated_request(self, url): - return self._make_request(self.parent_srv.session.get, url) + def get_unauthenticated_request(self, url, parameters=None): + return self._make_request(self.parent_srv.session.get, url, parameters=parameters) def get_request(self, url, request_object=None, parameters=None): if request_object is not None: diff --git a/tableauserverclient/server/endpoint/server_info_endpoint.py b/tableauserverclient/server/endpoint/server_info_endpoint.py index dc934496a..58e687fe7 100644 --- a/tableauserverclient/server/endpoint/server_info_endpoint.py +++ b/tableauserverclient/server/endpoint/server_info_endpoint.py @@ -1,5 +1,5 @@ import logging -from typing import Union +from typing import Literal, Union, TYPE_CHECKING from .endpoint import Endpoint, api from .exceptions import ServerResponseError @@ -9,10 +9,15 @@ ) from tableauserverclient.models import ServerInfoItem +if TYPE_CHECKING: + from tableauserverclient.server import Server + +Products = Literal["TableauServer", "TableauOnline"] + class ServerInfo(Endpoint): def __init__(self, server): - self.parent_srv = server + self.parent_srv: "Server" = server self._info = None @property @@ -80,3 +85,25 @@ def get(self) -> Union[ServerInfoItem, None]: logging.getLogger(self.__class__.__name__).debug(e) logging.getLogger(self.__class__.__name__).debug(server_response.content) return self._info + + def _get_product_info(self) -> Products: + """ + Retrieve the server product information to determine if the server is + Tableau Server or Tableau Online. + """ + method = "getServerSettingsUnauthenticated" + response = self.parent_srv.session.post( + f"{self.parent_srv.server_address}/vizportal/api/web/v1/{method}", + headers={"Content-Type": "application/json"}, + verify=self.parent_srv.http_options.get("verify", True), + json={"method": method, "params": {}}, + ) + if not response.ok: + return "TableauServer" + else: + try: + return response.json().get("result", {}).get("product", "TableauServer") + except Exception as e: + logging.getLogger(self.__class__.__name__).debug(e) + logging.getLogger(self.__class__.__name__).debug("Failed to parse product info response.") + return "TableauServer" diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index d5d163db3..482778ec6 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -144,6 +144,7 @@ def __init__(self, server_address, use_server_version=False, http_options=None, self._site_id = None self._user_id = None self._ssl_context = None + self._product = "TableauServer" # default product type # TODO: this needs to change to default to https, but without breaking existing code if not server_address.startswith("http://") and not server_address.startswith("https://"): @@ -267,6 +268,7 @@ def _determine_highest_version(self): def use_server_version(self): self.version = self._determine_highest_version() + self._product = self.server_info._get_product_info() def use_highest_version(self): self.use_server_version() diff --git a/test/assets/getServerSettingsUnauthenticated.json b/test/assets/getServerSettingsUnauthenticated.json new file mode 100644 index 000000000..9c3464353 --- /dev/null +++ b/test/assets/getServerSettingsUnauthenticated.json @@ -0,0 +1 @@ +{"result": {"product": "TableauOnline"}} diff --git a/test/http/test_http_requests.py b/test/http/test_http_requests.py index ce845502d..d96c4389b 100644 --- a/test/http/test_http_requests.py +++ b/test/http/test_http_requests.py @@ -27,6 +27,19 @@ def __init__(self, status_code): return MockResponse(200) +# This method will be used by the mock to replace requests.get +def mocked_requests_post(*args, **kwargs): + class MockResponse: + def __init__(self, status_code): + self.headers = {} + self.encoding = None + self.content = '{"result": {"product": "TableauOnline"}}' + self.status_code = status_code + self.ok = True + + return MockResponse(200) + + class ServerTests(unittest.TestCase): def test_init_server_model_empty_throws(self): with self.assertRaises(TypeError): @@ -46,7 +59,8 @@ def test_init_server_model_bad_server_name_not_version_check(self): server = TSC.Server("fake-url", use_server_version=False) @mock.patch("requests.sessions.Session.get", side_effect=mocked_requests_get) - def test_init_server_model_bad_server_name_do_version_check(self, mock_get): + @mock.patch("requests.sessions.Session.post", side_effect=mocked_requests_post) + def test_init_server_model_bad_server_name_do_version_check(self, mock_get, mock_post): server = TSC.Server("fake-url", use_server_version=True) def test_init_server_model_bad_server_name_not_version_check_random_options(self): @@ -114,4 +128,5 @@ def test_session_factory_adds_headers(self): test_request_bin = "http://capture-this-with-mock.com" with requests_mock.mock() as m: m.get(url="http://capture-this-with-mock.com/api/2.4/serverInfo", request_headers=SessionTests.test_header) + m.post(f"{test_request_bin}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={}) server = TSC.Server(test_request_bin, use_server_version=True, session_factory=SessionTests.session_factory) diff --git a/test/test_server_info.py b/test/test_server_info.py index fa1472c9a..eb5809fab 100644 --- a/test/test_server_info.py +++ b/test/test_server_info.py @@ -1,3 +1,4 @@ +import json import os.path import unittest @@ -13,6 +14,7 @@ SERVER_INFO_404 = os.path.join(TEST_ASSET_DIR, "server_info_404.xml") SERVER_INFO_AUTH_INFO_XML = os.path.join(TEST_ASSET_DIR, "server_info_auth_info.xml") SERVER_INFO_WRONG_SITE = os.path.join(TEST_ASSET_DIR, "server_info_wrong_site.html") +SERVER_PRODUCT_INFO = os.path.join(TEST_ASSET_DIR, "getServerSettingsUnauthenticated.json") class ServerInfoTests(unittest.TestCase): @@ -26,6 +28,7 @@ def test_server_info_get(self): response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: m.get(self.server.server_info.baseurl, text=response_xml) + m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={}) actual = self.server.server_info.get() self.assertEqual("10.1.0", actual.product_version) @@ -43,6 +46,8 @@ def test_server_info_use_highest_version_downgrades(self): # Return a 404 for serverInfo so we can pretend this is an old Server m.get(self.server.server_address + "/api/2.4/serverInfo", text=si_response_xml, status_code=404) m.get(self.server.server_address + "/auth?format=xml", text=auth_response_xml) + m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={}) + self.server.use_server_version() # does server-version[9.2] lookup in PRODUCT_TO_REST_VERSION self.assertEqual(self.server.version, "2.2") @@ -52,6 +57,7 @@ def test_server_info_use_highest_version_upgrades(self): si_response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: m.get(self.server.server_address + "/api/2.8/serverInfo", text=si_response_xml) + m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={}) # Pretend we're old self.server.version = "2.8" self.server.use_server_version() @@ -63,6 +69,7 @@ def test_server_use_server_version_flag(self): si_response_xml = f.read().decode("utf-8") with requests_mock.mock() as m: m.get("http://test/api/2.4/serverInfo", text=si_response_xml) + m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={}) server = TSC.Server("http://test", use_server_version=True) self.assertEqual(server.version, "2.5") @@ -73,3 +80,21 @@ def test_server_wrong_site(self): m.get(self.server.server_info.baseurl, text=response, status_code=404) with self.assertRaises(NonXMLResponseError): self.server.server_info.get() + + def test_server_info_product(self): + with open(SERVER_PRODUCT_INFO) as f: + product_info_json = json.load(f) + + with requests_mock.mock() as m: + m.post( + f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", + json=product_info_json, + ) + self.server.use_server_version() + assert self.server._product == "TableauOnline" + + def test_server_info_product_no_response(self): + with requests_mock.mock() as m: + m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={}) + self.server.use_server_version() + assert self.server._product == "TableauServer"