Skip to content

feat: retrieve tableau server product name #1631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tableauserverclient/server/endpoint/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ def log_response_safely(self, server_response: "Response") -> str:
loggable_response = helpers.strings.redact_xml(server_response.content.decode(server_response.encoding))
return loggable_response

def get_unauthenticated_request(self, url):
return self._make_request(self.parent_srv.session.get, url)
def get_unauthenticated_request(self, url, parameters=None):
return self._make_request(self.parent_srv.session.get, url, parameters=parameters)

def get_request(self, url, request_object=None, parameters=None):
if request_object is not None:
Expand Down
31 changes: 29 additions & 2 deletions tableauserverclient/server/endpoint/server_info_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Union
from typing import Literal, Union, TYPE_CHECKING

from .endpoint import Endpoint, api
from .exceptions import ServerResponseError
Expand All @@ -9,10 +9,15 @@
)
from tableauserverclient.models import ServerInfoItem

if TYPE_CHECKING:
from tableauserverclient.server import Server

Products = Literal["TableauServer", "TableauOnline"]


class ServerInfo(Endpoint):
def __init__(self, server):
self.parent_srv = server
self.parent_srv: "Server" = server
self._info = None

@property
Expand Down Expand Up @@ -80,3 +85,25 @@ def get(self) -> Union[ServerInfoItem, None]:
logging.getLogger(self.__class__.__name__).debug(e)
logging.getLogger(self.__class__.__name__).debug(server_response.content)
return self._info

def _get_product_info(self) -> Products:
"""
Retrieve the server product information to determine if the server is
Tableau Server or Tableau Online.
"""
method = "getServerSettingsUnauthenticated"
response = self.parent_srv.session.post(
f"{self.parent_srv.server_address}/vizportal/api/web/v1/{method}",
headers={"Content-Type": "application/json"},
verify=self.parent_srv.http_options.get("verify", True),
json={"method": method, "params": {}},
)
if not response.ok:
return "TableauServer"
else:
try:
return response.json().get("result", {}).get("product", "TableauServer")
except Exception as e:
logging.getLogger(self.__class__.__name__).debug(e)
logging.getLogger(self.__class__.__name__).debug("Failed to parse product info response.")
return "TableauServer"
2 changes: 2 additions & 0 deletions tableauserverclient/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def __init__(self, server_address, use_server_version=False, http_options=None,
self._site_id = None
self._user_id = None
self._ssl_context = None
self._product = "TableauServer" # default product type

# TODO: this needs to change to default to https, but without breaking existing code
if not server_address.startswith("http://") and not server_address.startswith("https://"):
Expand Down Expand Up @@ -267,6 +268,7 @@ def _determine_highest_version(self):

def use_server_version(self):
self.version = self._determine_highest_version()
self._product = self.server_info._get_product_info()

def use_highest_version(self):
self.use_server_version()
Expand Down
1 change: 1 addition & 0 deletions test/assets/getServerSettingsUnauthenticated.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"result": {"product": "TableauOnline"}}
17 changes: 16 additions & 1 deletion test/http/test_http_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ def __init__(self, status_code):
return MockResponse(200)


# This method will be used by the mock to replace requests.get
def mocked_requests_post(*args, **kwargs):
class MockResponse:
def __init__(self, status_code):
self.headers = {}
self.encoding = None
self.content = '{"result": {"product": "TableauOnline"}}'
self.status_code = status_code
self.ok = True

return MockResponse(200)


class ServerTests(unittest.TestCase):
def test_init_server_model_empty_throws(self):
with self.assertRaises(TypeError):
Expand All @@ -46,7 +59,8 @@ def test_init_server_model_bad_server_name_not_version_check(self):
server = TSC.Server("fake-url", use_server_version=False)

@mock.patch("requests.sessions.Session.get", side_effect=mocked_requests_get)
def test_init_server_model_bad_server_name_do_version_check(self, mock_get):
@mock.patch("requests.sessions.Session.post", side_effect=mocked_requests_post)
def test_init_server_model_bad_server_name_do_version_check(self, mock_get, mock_post):
server = TSC.Server("fake-url", use_server_version=True)

def test_init_server_model_bad_server_name_not_version_check_random_options(self):
Expand Down Expand Up @@ -114,4 +128,5 @@ def test_session_factory_adds_headers(self):
test_request_bin = "http://capture-this-with-mock.com"
with requests_mock.mock() as m:
m.get(url="http://capture-this-with-mock.com/api/2.4/serverInfo", request_headers=SessionTests.test_header)
m.post(f"{test_request_bin}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={})
server = TSC.Server(test_request_bin, use_server_version=True, session_factory=SessionTests.session_factory)
25 changes: 25 additions & 0 deletions test/test_server_info.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os.path
import unittest

Expand All @@ -13,6 +14,7 @@
SERVER_INFO_404 = os.path.join(TEST_ASSET_DIR, "server_info_404.xml")
SERVER_INFO_AUTH_INFO_XML = os.path.join(TEST_ASSET_DIR, "server_info_auth_info.xml")
SERVER_INFO_WRONG_SITE = os.path.join(TEST_ASSET_DIR, "server_info_wrong_site.html")
SERVER_PRODUCT_INFO = os.path.join(TEST_ASSET_DIR, "getServerSettingsUnauthenticated.json")


class ServerInfoTests(unittest.TestCase):
Expand All @@ -26,6 +28,7 @@ def test_server_info_get(self):
response_xml = f.read().decode("utf-8")
with requests_mock.mock() as m:
m.get(self.server.server_info.baseurl, text=response_xml)
m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={})
actual = self.server.server_info.get()

self.assertEqual("10.1.0", actual.product_version)
Expand All @@ -43,6 +46,8 @@ def test_server_info_use_highest_version_downgrades(self):
# Return a 404 for serverInfo so we can pretend this is an old Server
m.get(self.server.server_address + "/api/2.4/serverInfo", text=si_response_xml, status_code=404)
m.get(self.server.server_address + "/auth?format=xml", text=auth_response_xml)
m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={})

self.server.use_server_version()
# does server-version[9.2] lookup in PRODUCT_TO_REST_VERSION
self.assertEqual(self.server.version, "2.2")
Expand All @@ -52,6 +57,7 @@ def test_server_info_use_highest_version_upgrades(self):
si_response_xml = f.read().decode("utf-8")
with requests_mock.mock() as m:
m.get(self.server.server_address + "/api/2.8/serverInfo", text=si_response_xml)
m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={})
# Pretend we're old
self.server.version = "2.8"
self.server.use_server_version()
Expand All @@ -63,6 +69,7 @@ def test_server_use_server_version_flag(self):
si_response_xml = f.read().decode("utf-8")
with requests_mock.mock() as m:
m.get("http://test/api/2.4/serverInfo", text=si_response_xml)
m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={})
server = TSC.Server("http://test", use_server_version=True)
self.assertEqual(server.version, "2.5")

Expand All @@ -73,3 +80,21 @@ def test_server_wrong_site(self):
m.get(self.server.server_info.baseurl, text=response, status_code=404)
with self.assertRaises(NonXMLResponseError):
self.server.server_info.get()

def test_server_info_product(self):
with open(SERVER_PRODUCT_INFO) as f:
product_info_json = json.load(f)

with requests_mock.mock() as m:
m.post(
f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated",
json=product_info_json,
)
self.server.use_server_version()
assert self.server._product == "TableauOnline"

def test_server_info_product_no_response(self):
with requests_mock.mock() as m:
m.post(f"{self.server.server_address}/vizportal/api/web/v1/getServerSettingsUnauthenticated", json={})
self.server.use_server_version()
assert self.server._product == "TableauServer"
Loading