Skip to content

Commit

Permalink
py: client: add logging (kubeflow#563)
Browse files Browse the repository at this point in the history
* py: client: add logging

Co-authored-by: Luca Giorgi <[email protected]>
Signed-off-by: Isabella do Amaral <[email protected]>

* minor fixes

Co-authored-by: Matteo Mortari <[email protected]>
Signed-off-by: Isabella do Amaral <[email protected]>

---------

Signed-off-by: Isabella do Amaral <[email protected]>
Co-authored-by: Luca Giorgi <[email protected]>
Co-authored-by: Matteo Mortari <[email protected]>
  • Loading branch information
3 people authored Nov 26, 2024
1 parent 4c1cf00 commit 3285729
Showing 1 changed file with 56 additions and 16 deletions.
72 changes: 56 additions & 16 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import logging
import os
from collections.abc import Mapping
from pathlib import Path
Expand All @@ -22,6 +23,30 @@
ModelTypes = Union[RegisteredModel, ModelVersion, ModelArtifact]
TModel = TypeVar("TModel", bound=ModelTypes)

logging.basicConfig(
format="%(asctime)s.%(msecs)03d - %(name)s:%(levelname)s: %(message)s",
datefmt="%H:%M:%S",
level=logging.WARNING, # the default loglevel
handlers=[
# logging.FileHandler(
# LOGS
# / "log-{}-{}.log".format(
# datetime.now(tz=datetime.now().astimezone().tzinfo).strftime(
# "%Y-%m-%d-%H-%M-%S"
# ),
# os.getpid(),
# ),
# encoding="utf-8",
# delay=False,
# ),
logging.StreamHandler(),
],
)

logger = logging.getLogger("model-registry")

DEFAULT_USER_TOKEN_ENVVAR = "KF_PIPELINES_SA_TOKEN_PATH" # noqa: S105


class ModelRegistry:
"""Model registry client."""
Expand All @@ -34,7 +59,10 @@ def __init__(
author: str,
is_secure: bool = True,
user_token: str | None = None,
user_token_envvar: str = DEFAULT_USER_TOKEN_ENVVAR,
custom_ca: str | None = None,
custom_ca_envvar: str | None = None,
log_level: int = logging.WARNING,
):
"""Constructor.
Expand All @@ -45,47 +73,59 @@ def __init__(
Keyword Args:
author: Name of the author.
is_secure: Whether to use a secure connection. Defaults to True.
user_token: The PEM-encoded user token as a string. Defaults to content of path on envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: Path to the PEM-encoded root certificates as a string. Defaults to path on envvar CERT.
user_token: The PEM-encoded user token as a string.
user_token_envvar: Environment variable to read the user token from if it's not passed as an arg. Defaults to KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: Path to the PEM-encoded root certificates as a string.
custom_ca_envvar: Environment variable to read the custom CA from if it's not passed as an arg.
log_level: Log level. Defaults to logging.WARNING.
"""
logger.setLevel(log_level)

import nest_asyncio

logger.debug("Setting up reentrant async event loop")
nest_asyncio.apply()

# TODO: get remaining args from env
self._author = author

if not user_token:
if not user_token and user_token_envvar:
logger.info("Reading user token from %s", user_token_envvar)
# /var/run/secrets/kubernetes.io/serviceaccount/token
sa_token = os.environ.get("KF_PIPELINES_SA_TOKEN_PATH")
if sa_token:
if sa_token := os.environ.get(user_token_envvar):
if user_token_envvar == DEFAULT_USER_TOKEN_ENVVAR:
logger.warning(
f"Sourcing user token from default envvar: {DEFAULT_USER_TOKEN_ENVVAR}"
)
user_token = Path(sa_token).read_text()
else:
warn("User access token is missing", stacklevel=2)

if is_secure:
root_ca = None
if not custom_ca:
if cert := os.getenv("CERT"):
root_ca = cert
# client might have a default CA setup
else:
root_ca = custom_ca
if (
not custom_ca
and custom_ca_envvar
and (cert := os.getenv(custom_ca_envvar))
):
logger.info(
"Using custom CA envvar %s",
custom_ca_envvar,
)
custom_ca = cert
# client might have a default CA setup

if not user_token:
msg = "user token must be provided for secure connection"
raise StoreError(msg)

self._api = ModelRegistryAPIClient.secure_connection(
server_address, port, user_token=user_token, custom_ca=root_ca
server_address, port, user_token=user_token, custom_ca=custom_ca
)
elif custom_ca:
msg = "Custom CA provided without secure connection, conflicting options"
raise StoreError(msg)
else:
self._api = ModelRegistryAPIClient.insecure_connection(
server_address, port, user_token
)
self.get_registered_models().page_size(1)._next_page()

def async_runner(self, coro: Any) -> Any:
import asyncio
Expand Down

0 comments on commit 3285729

Please sign in to comment.