Skip to content

Commit

Permalink
SDK/python: Add support for self-signed certificates with or without …
Browse files Browse the repository at this point in the history
…verification

Signed-off-by: Aaron Wilson <[email protected]>
  • Loading branch information
aaronnw committed Nov 30, 2023
1 parent 143a38d commit b810ac5
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 52 deletions.
4 changes: 2 additions & 2 deletions python/aistore/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class Client:
endpoint (str): AIStore endpoint
"""

def __init__(self, endpoint: str):
self._request_client = RequestClient(endpoint)
def __init__(self, endpoint: str, skip_verify: bool = False, ca_cert: str = None):
self._request_client = RequestClient(endpoint, skip_verify, ca_cert)

def bucket(
self, bck_name: str, provider: str = PROVIDER_AIS, namespace: Namespace = None
Expand Down
23 changes: 18 additions & 5 deletions python/aistore/sdk/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#

from __future__ import annotations # pylint: disable=unused-variable

import logging
from typing import List, Optional

from aistore.sdk.const import (
Expand All @@ -26,6 +28,8 @@
from aistore.sdk.request_client import RequestClient
from aistore.sdk.types import ActionMsg, Smap

logger = logging.getLogger("cluster")


# pylint: disable=unused-variable
class Cluster:
Expand Down Expand Up @@ -62,6 +66,12 @@ def get_info(self) -> Smap:
params={QPARAM_WHAT: WHAT_SMAP},
)

def get_primary_url(self) -> str:
"""
Returns: URL of primary proxy
"""
return self.get_info().proxy_si.public_net.direct_url

def list_buckets(self, provider: str = PROVIDER_AIS):
"""
Returns list of buckets in AIStore cluster.
Expand Down Expand Up @@ -144,20 +154,23 @@ def list_running_etls(self) -> List[ETLInfo]:
HTTP_METHOD_GET, path=URL_PATH_ETL, res_model=List[ETLInfo]
)

def is_aistore_running(self) -> bool:
def is_ready(self) -> bool:
"""
Checks if cluster is ready or still setting up.
Returns:
bool: True if cluster is ready, or false if cluster is still setting up
"""

# compare with AIS Go API (api/cluster.go) for additional supported options
params = {QPARAM_PRIMARY_READY_REB: "true"}
try:
resp = self.client.request(
HTTP_METHOD_GET, path=URL_PATH_HEALTH, params=params
resp = self._client.request(
HTTP_METHOD_GET,
path=URL_PATH_HEALTH,
endpoint=self.get_primary_url(),
params=params,
)
return resp.ok
except Exception:
except Exception as err:
logger.debug(err)
return False
4 changes: 3 additions & 1 deletion python/aistore/sdk/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
#

HEADERS_KW = "headers"
# Standard Header Keys
HEADER_ACCEPT = "Accept"
HEADER_USER_AGENT = "User-Agent"
Expand Down Expand Up @@ -108,3 +107,6 @@
STATUS_OK = 200
STATUS_BAD_REQUEST = 400
STATUS_PARTIAL_CONTENT = 206

# Environment Variables
AIS_SERVER_CRT = "AIS_SERVER_CRT"
47 changes: 39 additions & 8 deletions python/aistore/sdk/request_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
#
import os
from urllib.parse import urljoin, urlencode
from typing import TypeVar, Type, Any, Dict

Expand All @@ -11,7 +12,7 @@
HEADER_USER_AGENT,
USER_AGENT_BASE,
HEADER_CONTENT_TYPE,
HEADERS_KW,
AIS_SERVER_CRT,
)
from aistore.sdk.utils import handle_errors, decode_response
from aistore.version import __version__ as sdk_version
Expand All @@ -28,10 +29,30 @@ class RequestClient:
endpoint (str): AIStore endpoint
"""

def __init__(self, endpoint: str):
def __init__(self, endpoint: str, skip_verify: bool = True, ca_cert: str = None):
self._endpoint = endpoint
self._base_url = urljoin(endpoint, "v1")
self._session = requests.session()
self._session = requests.sessions.session()
if "https" in self._endpoint:
self._set_session_verification(skip_verify, ca_cert)

def _set_session_verification(self, skip_verify: bool, ca_cert: str):
"""
Set session verify value for validating the server's SSL certificate
The requests library allows this to be a boolean or a string path to the cert
If we do not skip verification, the order is:
1. Provided cert path
2. Cert path from env var.
3. True (verify with system's approved CA list)
"""
if skip_verify:
self._session.verify = False
return
if ca_cert:
self._session.verify = ca_cert
return
env_crt = os.getenv(AIS_SERVER_CRT)
self._session.verify = env_crt if env_crt else True

@property
def base_url(self):
Expand Down Expand Up @@ -71,26 +92,36 @@ def request_deserialize(
resp = self.request(method, path, **kwargs)
return decode_response(res_model, resp)

def request(self, method: str, path: str, **kwargs) -> requests.Response:
def request(
self,
method: str,
path: str,
endpoint: str = None,
headers: dict = None,
**kwargs,
) -> requests.Response:
"""
Make a request to the AIS cluster
Args:
method (str): HTTP method, e.g. POST, GET, PUT, DELETE
path (str): URL path to call
endpoint (str): Alternative endpoint for the AIS cluster, e.g. for connecting to a specific proxy
headers (dict): Extra headers to be passed with the request. Content-Type and User-Agent will be overridden
**kwargs (optional): Optional keyword arguments to pass with the call to request
Returns:
Raw response from the API
"""
url = f"{self._base_url}/{path.lstrip('/')}"
if HEADERS_KW not in kwargs:
kwargs[HEADERS_KW] = {}
headers = kwargs.get(HEADERS_KW, {})
base = urljoin(endpoint, "v1") if endpoint else self._base_url
url = f"{base}/{path.lstrip('/')}"
if headers is None:
headers = {}
headers[HEADER_CONTENT_TYPE] = JSON_CONTENT_TYPE
headers[HEADER_USER_AGENT] = f"{USER_AGENT_BASE}/{sdk_version}"
resp = self._session.request(
method,
url,
headers=headers,
**kwargs,
)
if resp.status_code < 200 or resp.status_code >= 300:
Expand Down
2 changes: 1 addition & 1 deletion python/examples/sdk/sdk-basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
},
"outputs": [],
"source": [
"client.cluster().is_aistore_running()"
"client.cluster().is_ready()"
]
},
{
Expand Down
19 changes: 9 additions & 10 deletions python/tests/integration/sdk/test_cluster_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,17 @@
class TestClusterOps(unittest.TestCase): # pylint: disable=unused-variable
def setUp(self) -> None:
self.client = Client(CLUSTER_ENDPOINT)
self.cluster = self.client.cluster()

def test_health_success(self):
self.assertEqual(Client(CLUSTER_ENDPOINT).cluster().is_aistore_running(), True)
self.assertTrue(self.cluster.is_ready())

def test_health_failure(self):
# url not exisiting or URL down
self.assertEqual(
Client("http://localhost:1234").cluster().is_aistore_running(), False
)
# url not existing or URL down
self.assertFalse(Client("http://localhost:1234").cluster().is_ready())

def test_cluster_map(self):
smap = self.client.cluster().get_info()
smap = self.cluster.get_info()

self.assertIsNotNone(smap)
self.assertIsNotNone(smap.proxy_si)
Expand All @@ -52,11 +51,11 @@ def test_list_jobs_status(self):
job_3_id = self.client.job(job_kind="cleanup").start()

self._check_jobs_in_result(
[job_1_id, job_2_id], self.client.cluster().list_jobs_status()
[job_1_id, job_2_id], self.cluster.list_jobs_status()
)
self._check_jobs_in_result(
[job_1_id, job_2_id],
self.client.cluster().list_jobs_status(job_kind=job_kind),
self.cluster.list_jobs_status(job_kind=job_kind),
[job_3_id],
)

Expand All @@ -75,10 +74,10 @@ def test_list_running_jobs(self):
self.assertIn(expected_res, self.client.cluster().list_running_jobs())
self.assertIn(
expected_res,
self.client.cluster().list_running_jobs(job_kind=ACT_COPY_OBJECTS),
self.cluster.list_running_jobs(job_kind=ACT_COPY_OBJECTS),
)
self.assertNotIn(
expected_res, self.client.cluster().list_running_jobs(job_kind="lru")
expected_res, self.cluster.list_running_jobs(job_kind="lru")
)
finally:
bck.delete()
Expand Down
2 changes: 1 addition & 1 deletion python/tests/unit/sdk/test_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_default_props(self):
def test_properties(self):
self.assertEqual(self.mock_client, self.ais_bck.client)
expected_ns = Namespace(uuid="ns-id", name="ns-name")
client = RequestClient("test client name")
client = RequestClient("test client name", skip_verify=False, ca_cert="")
bck = Bucket(
client=client,
name=BCK_NAME,
Expand Down
14 changes: 14 additions & 0 deletions python/tests/unit/sdk/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,34 @@
#

import unittest
from unittest.mock import patch

from aistore.sdk import Client
from aistore.sdk.cluster import Cluster
from aistore.sdk.etl import Etl
from aistore.sdk.request_client import RequestClient
from aistore.sdk.types import Namespace
from aistore.sdk.job import Job
from tests.unit.sdk.test_utils import test_cases


class TestClient(unittest.TestCase): # pylint: disable=unused-variable
def setUp(self) -> None:
self.endpoint = "https://aistore-endpoint"
self.client = Client(self.endpoint)

@patch("aistore.sdk.client.RequestClient")
def test_init_defaults(self, mock_request_client):
Client(self.endpoint)
mock_request_client.assert_called_with(self.endpoint, False, None)

@test_cases((True, None), (False, "ca_cert_location"))
@patch("aistore.sdk.client.RequestClient")
def test_init(self, test_case, mock_request_client):
skip_verify, ca_cert = test_case
Client(self.endpoint, skip_verify=skip_verify, ca_cert=ca_cert)
mock_request_client.assert_called_with(self.endpoint, skip_verify, ca_cert)

def test_bucket(self):
bck_name = "bucket_123"
provider = "bucketProvider"
Expand Down
46 changes: 33 additions & 13 deletions python/tests/unit/sdk/test_cluster.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
from typing import List, Optional
from unittest.mock import Mock, create_autospec
from unittest.mock import Mock

from aistore.sdk.bucket import Bucket
from aistore.sdk.cluster import Cluster
Expand All @@ -21,7 +21,18 @@
URL_PATH_ETL,
)
from aistore.sdk.request_client import RequestClient
from aistore.sdk.types import Smap, ActionMsg, BucketModel, JobStatus, JobQuery, ETLInfo
from aistore.sdk.types import (
Smap,
ActionMsg,
BucketModel,
JobStatus,
JobQuery,
ETLInfo,
Snode,
NetInfo,
)

from tests.unit.sdk.test_utils import test_cases


class TestCluster(unittest.TestCase): # pylint: disable=unused-variable
Expand All @@ -30,7 +41,7 @@ def setUp(self) -> None:
self.cluster = Cluster(self.mock_client)

def test_get_info(self):
expected_result = create_autospec(Smap)
expected_result = Mock()
self.mock_client.request_deserialize.return_value = expected_result
result = self.cluster.get_info()
self.assertEqual(result, expected_result)
Expand Down Expand Up @@ -65,22 +76,31 @@ def list_buckets_exec_assert(self, expected_params, **kwargs):
params=expected_params,
)

def test_is_aistore_running_exception(self):
def test_is_ready_exception(self):
self.mock_client.request.side_effect = Exception
self.assertFalse(self.cluster.is_aistore_running())
self.assertFalse(self.cluster.is_ready())

def test_is_aistore_running(self):
@test_cases(True, False)
def test_is_ready(self, test_case):
expected_params = {QPARAM_PRIMARY_READY_REB: "true"}
response = Mock()
response.ok = True
self.mock_client.request.return_value = response
self.assertTrue(self.cluster.is_aistore_running())
response.ok = False
self.mock_client.request.return_value = response
self.assertFalse(self.cluster.is_aistore_running())
primary_proxy_endpoint = "primary_proxy_url"

mock_response = Mock()
mock_response.ok = test_case
self.mock_client.request.return_value = mock_response
mock_smap = Mock(spec=Smap)
mock_snode = Mock(spec=Snode)
mock_netinfo = Mock(spec=NetInfo)
mock_netinfo.direct_url = primary_proxy_endpoint
mock_snode.public_net = mock_netinfo
mock_smap.proxy_si = mock_snode
self.mock_client.request_deserialize.return_value = mock_smap

self.assertEqual(test_case, self.cluster.is_ready())
self.mock_client.request.assert_called_with(
HTTP_METHOD_GET,
path=URL_PATH_HEALTH,
endpoint=primary_proxy_endpoint,
params=expected_params,
)

Expand Down
Loading

0 comments on commit b810ac5

Please sign in to comment.