diff --git a/.gitignore b/.gitignore index 1b6adc6c..3ee8a62f 100644 --- a/.gitignore +++ b/.gitignore @@ -134,6 +134,7 @@ dmypy.json # Simvue files simvue.ini +simvue.toml offline/ # Pyenv diff --git a/README.md b/README.md index bea1c0d1..18232d9c 100644 --- a/README.md +++ b/README.md @@ -35,10 +35,10 @@ export SIMVUE_URL=... export SIMVUE_TOKEN=... ``` or a file `simvue.ini` can be created containing: -```ini +```toml [server] -url = ... -token = ... +url = "..." +token = "..." ``` The exact contents of both of the above options can be obtained directly by clicking the **Create new run** button on the web UI. Note that the environment variables have preference over the config file. diff --git a/pyproject.toml b/pyproject.toml index 28d8da30..189f27d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,6 +99,7 @@ markers = [ "utilities: test simvue utilities module", "scenario: test scenarios", "executor: tests of executors", + "config: tests of simvue configuration", "api: tests of RestAPI functionality", "unix: tests for UNIX systems only", "metadata: tests of metadata gathering functions", diff --git a/simvue/client.py b/simvue/client.py index 713f457a..28f448df 100644 --- a/simvue/client.py +++ b/simvue/client.py @@ -24,8 +24,9 @@ ) from .serialization import deserialize_data from .simvue_types import DeserializedContent -from .utilities import check_extra, get_auth, prettify_pydantic +from .utilities import check_extra, prettify_pydantic from .models import FOLDER_REGEX, NAME_REGEX +from .config import SimvueConfiguration if typing.TYPE_CHECKING: pass @@ -90,18 +91,33 @@ class Client: Class for querying Simvue """ - def __init__(self) -> None: - """Initialise an instance of the Simvue client""" - self._url: typing.Optional[str] - self._token: typing.Optional[str] + def __init__( + self, + server_token: typing.Optional[str] = None, + server_url: typing.Optional[str] = None, + ) -> None: + """Initialise an instance of the Simvue client - self._url, self._token = get_auth() + Parameters + ---------- + server_token : str, optional + specify token, if unset this is read from the config file + server_url : str, optional + specify URL, if unset this is read from the config file + """ + self._config = SimvueConfiguration.fetch( + server_token=server_token, server_url=server_url + ) - for label, value in zip(("URL", "API token"), (self._url, self._token)): + for label, value in zip( + ("URL", "API token"), (self._config.server.url, self._config.server.url) + ): if not value: logger.warning(f"No {label} specified") - self._headers: dict[str, str] = {"Authorization": f"Bearer {self._token}"} + self._headers: dict[str, str] = { + "Authorization": f"Bearer {self._config.server.token}" + } def _get_json_from_response( self, @@ -165,7 +181,7 @@ def get_run_id_from_name( params: dict[str, str] = {"filters": json.dumps([f"name == {name}"])} response: requests.Response = requests.get( - f"{self._url}/api/runs", headers=self._headers, params=params + f"{self._config.server.url}/api/runs", headers=self._headers, params=params ) json_response = self._get_json_from_response( @@ -215,7 +231,7 @@ def get_run(self, run_id: str) -> typing.Optional[dict[str, typing.Any]]: """ response: requests.Response = requests.get( - f"{self._url}/api/runs/{run_id}", headers=self._headers + f"{self._config.server.url}/api/runs/{run_id}", headers=self._headers ) json_response = self._get_json_from_response( @@ -331,7 +347,7 @@ def get_runs( } response = requests.get( - f"{self._url}/api/runs", headers=self._headers, params=params + f"{self._config.server.url}/api/runs", headers=self._headers, params=params ) response.raise_for_status() @@ -380,7 +396,8 @@ def delete_run(self, run_identifier: str) -> typing.Optional[dict]: """ response = requests.delete( - f"{self._url}/api/runs/{run_identifier}", headers=self._headers + f"{self._config.server.url}/api/runs/{run_identifier}", + headers=self._headers, ) json_response = self._get_json_from_response( @@ -415,7 +432,9 @@ def _get_folder_id_from_path(self, path: str) -> typing.Optional[str]: params: dict[str, str] = {"filters": json.dumps([f"path == {path}"])} response: requests.Response = requests.get( - f"{self._url}/api/folders", headers=self._headers, params=params + f"{self._config.server.url}/api/folders", + headers=self._headers, + params=params, ) if ( @@ -458,7 +477,9 @@ def delete_runs( params: dict[str, bool] = {"runs_only": True, "runs": True} response = requests.delete( - f"{self._url}/api/folders/{folder_id}", headers=self._headers, params=params + f"{self._config.server.url}/api/folders/{folder_id}", + headers=self._headers, + params=params, ) if response.status_code == http.HTTPStatus.OK: @@ -522,7 +543,9 @@ def delete_folder( params |= {"recursive": recursive} response = requests.delete( - f"{self._url}/api/folders/{folder_id}", headers=self._headers, params=params + f"{self._config.server.url}/api/folders/{folder_id}", + headers=self._headers, + params=params, ) json_response = self._get_json_from_response( @@ -551,7 +574,7 @@ def delete_alert(self, alert_id: str) -> None: the unique identifier for the alert """ response = requests.delete( - f"{self._url}/api/alerts/{alert_id}", headers=self._headers + f"{self._config.server.url}/api/alerts/{alert_id}", headers=self._headers ) if response.status_code == http.HTTPStatus.OK: @@ -586,7 +609,9 @@ def list_artifacts(self, run_id: str) -> list[dict[str, typing.Any]]: params: dict[str, str] = {"runs": json.dumps([run_id])} response: requests.Response = requests.get( - f"{self._url}/api/artifacts", headers=self._headers, params=params + f"{self._config.server.url}/api/artifacts", + headers=self._headers, + params=params, ) json_response = self._get_json_from_response( @@ -610,7 +635,7 @@ def _retrieve_artifact_from_server( params: dict[str, str | None] = {"name": name} response = requests.get( - f"{self._url}/api/runs/{run_id}/artifacts", + f"{self._config.server.url}/api/runs/{run_id}/artifacts", headers=self._headers, params=params, ) @@ -649,7 +674,7 @@ def abort_run(self, run_id: str, reason: str) -> typing.Union[dict, list]: body: dict[str, str | None] = {"id": run_id, "reason": reason} response = requests.put( - f"{self._url}/api/runs/abort", + f"{self._config.server.url}/api/runs/abort", headers=self._headers, json=body, ) @@ -843,7 +868,7 @@ def get_artifacts_as_files( params: dict[str, typing.Optional[str]] = {"category": category} response: requests.Response = requests.get( - f"{self._url}/api/runs/{run_id}/artifacts", + f"{self._config.server.url}/api/runs/{run_id}/artifacts", headers=self._headers, params=params, ) @@ -935,7 +960,9 @@ def get_folders( } response: requests.Response = requests.get( - f"{self._url}/api/folders", headers=self._headers, params=params + f"{self._config.server.url}/api/folders", + headers=self._headers, + params=params, ) json_response = self._get_json_from_response( @@ -980,7 +1007,9 @@ def get_metrics_names(self, run_id: str) -> list[str]: params = {"runs": json.dumps([run_id])} response: requests.Response = requests.get( - f"{self._url}/api/metrics/names", headers=self._headers, params=params + f"{self._config.server.url}/api/metrics/names", + headers=self._headers, + params=params, ) json_response = self._get_json_from_response( @@ -1014,7 +1043,9 @@ def _get_run_metrics_from_server( } metrics_response: requests.Response = requests.get( - f"{self._url}/api/metrics", headers=self._headers, params=params + f"{self._config.server.url}/api/metrics", + headers=self._headers, + params=params, ) json_response = self._get_json_from_response( @@ -1271,7 +1302,9 @@ def get_events( } response = requests.get( - f"{self._url}/api/events", headers=self._headers, params=params + f"{self._config.server.url}/api/events", + headers=self._headers, + params=params, ) json_response = self._get_json_from_response( @@ -1317,7 +1350,9 @@ def get_alerts( if there was a failure retrieving data from the server """ if not run_id: - response = requests.get(f"{self._url}/api/alerts/", headers=self._headers) + response = requests.get( + f"{self._config.server.url}/api/alerts/", headers=self._headers + ) json_response = self._get_json_from_response( expected_status=[http.HTTPStatus.OK], @@ -1326,7 +1361,7 @@ def get_alerts( ) else: response = requests.get( - f"{self._url}/api/runs/{run_id}", headers=self._headers + f"{self._config.server.url}/api/runs/{run_id}", headers=self._headers ) json_response = self._get_json_from_response( diff --git a/simvue/config/__init__.py b/simvue/config/__init__.py new file mode 100644 index 00000000..db386d99 --- /dev/null +++ b/simvue/config/__init__.py @@ -0,0 +1,9 @@ +""" +Simvue Configuration +==================== + +This module contains definitions for the Simvue configuration options + +""" + +from .user import SimvueConfiguration as SimvueConfiguration diff --git a/simvue/config/parameters.py b/simvue/config/parameters.py new file mode 100644 index 00000000..7b378d72 --- /dev/null +++ b/simvue/config/parameters.py @@ -0,0 +1,86 @@ +""" +Simvue Configuration File Models +================================ + +Pydantic models for elements of the Simvue configuration file + +""" + +import logging +import os +import time +import pydantic +import typing +import pathlib +import http + +import simvue.models as sv_models +from simvue.utilities import get_expiry +from simvue.version import __version__ +from simvue.api import get + +CONFIG_FILE_NAMES: list[str] = ["simvue.toml", ".simvue.toml"] + +CONFIG_INI_FILE_NAMES: list[str] = [ + f'{pathlib.Path.cwd().joinpath("simvue.ini")}', + f'{pathlib.Path.home().joinpath(".simvue.ini")}', +] + +logger = logging.getLogger(__file__) + + +class ServerSpecifications(pydantic.BaseModel): + url: pydantic.AnyHttpUrl + token: pydantic.SecretStr + + @pydantic.field_validator("url") + @classmethod + def url_to_str(cls, v: typing.Any) -> str: + return f"{v}" + + @pydantic.field_validator("token") + def check_token(cls, v: typing.Any) -> str: + value = v.get_secret_value() + if not (expiry := get_expiry(value)): + raise AssertionError("Failed to parse Simvue token - invalid token form") + if time.time() - expiry > 0: + raise AssertionError("Simvue token has expired") + return value + + @pydantic.model_validator(mode="after") + @classmethod + def check_valid_server(cls, values: "ServerSpecifications") -> bool: + if os.environ.get("SIMVUE_NO_SERVER_CHECK"): + return values + + headers: dict[str, str] = { + "Authorization": f"Bearer {values.token}", + "User-Agent": f"Simvue Python client {__version__}", + } + try: + response = get(f"{values.url}/api/version", headers) + + if response.status_code != http.HTTPStatus.OK or not response.json().get( + "version" + ): + raise AssertionError + + if response.status_code == http.HTTPStatus.UNAUTHORIZED: + raise AssertionError("Unauthorised token") + + except Exception as err: + raise AssertionError(f"Exception retrieving server version: {str(err)}") + + return values + + +class DefaultRunSpecifications(pydantic.BaseModel): + name: typing.Optional[str] = None + description: typing.Optional[str] = None + tags: typing.Optional[list[str]] = None + folder: str = pydantic.Field("/", pattern=sv_models.FOLDER_REGEX) + metadata: typing.Optional[dict[str, typing.Union[str, int, float, bool]]] = None + + +class ClientGeneralOptions(pydantic.BaseModel): + debug: bool = False diff --git a/simvue/config/user.py b/simvue/config/user.py new file mode 100644 index 00000000..b0ab6d8e --- /dev/null +++ b/simvue/config/user.py @@ -0,0 +1,132 @@ +""" +Simvue Configuration File Model +=============================== + +Pydantic model for the Simvue TOML configuration file + +""" + +import functools +import logging +import os +import typing +import pathlib +import configparser +import contextlib +import warnings + +import pydantic +import toml + +import simvue.utilities as sv_util +from simvue.config.parameters import ( + CONFIG_FILE_NAMES, + CONFIG_INI_FILE_NAMES, + ClientGeneralOptions, + DefaultRunSpecifications, + ServerSpecifications, +) + +logger = logging.getLogger(__file__) + + +class SimvueConfiguration(pydantic.BaseModel): + # Hide values as they contain token and URL + model_config = pydantic.ConfigDict(hide_input_in_errors=True) + client: ClientGeneralOptions = ClientGeneralOptions() + server: ServerSpecifications = pydantic.Field( + ..., description="Specifications for Simvue server" + ) + run: DefaultRunSpecifications = DefaultRunSpecifications() + + @classmethod + def _parse_ini_config(cls, ini_file: pathlib.Path) -> dict[str, dict[str, str]]: + """Parse a legacy INI config file if found.""" + # NOTE: Legacy INI support, this will be removed + warnings.warn( + "Support for legacy INI based configuration files will be dropped in simvue>=1.2, " + "please switch to TOML based configuration.", + DeprecationWarning, + stacklevel=2, + ) + + config_dict: dict[str, dict[str, str]] = {"server": {}} + + with contextlib.suppress(Exception): + parser = configparser.ConfigParser() + parser.read(f"{ini_file}") + if token := parser.get("server", "token"): + config_dict["server"]["token"] = token + if url := parser.get("server", "url"): + config_dict["server"]["url"] = url + + return config_dict + + @classmethod + @sv_util.prettify_pydantic + def fetch( + cls, + server_url: typing.Optional[str] = None, + server_token: typing.Optional[str] = None, + ) -> "SimvueConfiguration": + _config_dict: dict[str, dict[str, str]] = {} + + try: + logger.info(f"Using config file '{cls.config_file()}'") + + # NOTE: Legacy INI support, this will be removed + if cls.config_file().suffix == ".toml": + _config_dict = toml.load(cls.config_file()) + else: + _config_dict = cls._parse_ini_config(cls.config_file()) + + except FileNotFoundError: + if not server_token or not server_url: + _config_dict = {"server": {}} + logger.warning("No config file found, checking environment variables") + + _config_dict["server"] = _config_dict.get("server", {}) + + # Ranking of configurations for token and URl is: + # Envionment Variables > Run Definition > Configuration File + + _server_url = os.environ.get( + "SIMVUE_URL", server_url or _config_dict["server"].get("url") + ) + + _server_token = os.environ.get( + "SIMVUE_TOKEN", server_token or _config_dict["server"].get("token") + ) + + if not _server_url: + raise RuntimeError("No server URL was specified") + + if not _server_token: + raise RuntimeError("No server token was specified") + + _config_dict["server"]["token"] = _server_token + _config_dict["server"]["url"] = _server_url + + return SimvueConfiguration(**_config_dict) + + @classmethod + @functools.lru_cache + def config_file(cls) -> pathlib.Path: + _config_file: typing.Optional[pathlib.Path] = ( + sv_util.find_first_instance_of_file( + CONFIG_FILE_NAMES, check_user_space=True + ) + ) + + # NOTE: Legacy INI support, this will be removed + if not _config_file: + _config_file: typing.Optional[pathlib.Path] = ( + sv_util.find_first_instance_of_file( + CONFIG_INI_FILE_NAMES, check_user_space=True + ) + ) + + if not _config_file: + raise FileNotFoundError("Failed to find Simvue configuration file") + + return _config_file diff --git a/simvue/factory/proxy/__init__.py b/simvue/factory/proxy/__init__.py index 0f3267d1..91065558 100644 --- a/simvue/factory/proxy/__init__.py +++ b/simvue/factory/proxy/__init__.py @@ -9,15 +9,22 @@ if typing.TYPE_CHECKING: from .base import SimvueBaseClass + from simvue.config import SimvueConfiguration from .offline import Offline from .remote import Remote def Simvue( - name: typing.Optional[str], uniq_id: str, mode: str, suppress_errors: bool = True + name: typing.Optional[str], + uniq_id: str, + mode: str, + config: "SimvueConfiguration", + suppress_errors: bool = True, ) -> "SimvueBaseClass": if mode == "offline": - return Offline(name, uniq_id, suppress_errors) + return Offline(name=name, uniq_id=uniq_id, suppress_errors=suppress_errors) else: - return Remote(name, uniq_id, suppress_errors) + return Remote( + name=name, uniq_id=uniq_id, config=config, suppress_errors=suppress_errors + ) diff --git a/simvue/factory/proxy/base.py b/simvue/factory/proxy/base.py index 1ca9060c..2dc3c13d 100644 --- a/simvue/factory/proxy/base.py +++ b/simvue/factory/proxy/base.py @@ -86,10 +86,6 @@ def send_event( def send_heartbeat(self) -> typing.Optional[dict[str, typing.Any]]: pass - @abc.abstractmethod - def check_token(self) -> bool: - pass - @abc.abstractmethod def get_abort_status(self) -> bool: pass diff --git a/simvue/factory/proxy/offline.py b/simvue/factory/proxy/offline.py index 31f20e3f..5ec1ccc6 100644 --- a/simvue/factory/proxy/offline.py +++ b/simvue/factory/proxy/offline.py @@ -16,6 +16,9 @@ skip_if_failed, ) +if typing.TYPE_CHECKING: + pass + logger = logging.getLogger(__name__) @@ -216,7 +219,3 @@ def send_heartbeat(self) -> typing.Optional[dict[str, typing.Any]]: ) pathlib.Path(os.path.join(self._directory, "heartbeat")).touch() return {"success": True} - - @skip_if_failed("_aborted", "_suppress_errors", False) - def check_token(self) -> bool: - return True diff --git a/simvue/factory/proxy/remote.py b/simvue/factory/proxy/remote.py index 0636fdca..f61ed0d3 100644 --- a/simvue/factory/proxy/remote.py +++ b/simvue/factory/proxy/remote.py @@ -1,11 +1,13 @@ import logging -import time import typing import http +if typing.TYPE_CHECKING: + from simvue.config import SimvueConfiguration + from simvue.api import get, post, put from simvue.factory.proxy.base import SimvueBaseClass -from simvue.utilities import get_auth, get_expiry, prepare_for_api, skip_if_failed +from simvue.utilities import prepare_for_api, skip_if_failed from simvue.version import __version__ logger = logging.getLogger(__name__) @@ -20,19 +22,22 @@ class Remote(SimvueBaseClass): """ def __init__( - self, name: typing.Optional[str], uniq_id: str, suppress_errors: bool = True + self, + name: typing.Optional[str], + uniq_id: str, + config: "SimvueConfiguration", + suppress_errors: bool = True, ) -> None: - self._url, self._token = get_auth() + self._config = config self._headers: dict[str, str] = { - "Authorization": f"Bearer {self._token}", + "Authorization": f"Bearer {self._config.server.token}", "User-Agent": f"Simvue Python client {__version__}", } self._headers_mp: dict[str, str] = self._headers | { "Content-Type": "application/msgpack" } super().__init__(name, uniq_id, suppress_errors) - self.check_token() self._id = uniq_id @@ -40,7 +45,9 @@ def __init__( def list_tags(self) -> list[str]: logger.debug("Retrieving existing tags") try: - response = get(f"{self._url}/api/runs/{self._id}", self._headers) + response = get( + f"{self._config.server.url}/api/runs/{self._id}", self._headers + ) except Exception as err: self._error(f"Exception retrieving tags: {str(err)}") return [] @@ -73,7 +80,7 @@ def create_run(self, data) -> tuple[typing.Optional[str], typing.Optional[int]]: logger.debug("Creating folder %s if necessary", data.get("folder")) try: response = post( - f"{self._url}/api/folders", + f"{self._config.server.url}/api/folders", self._headers, {"path": data.get("folder")}, ) @@ -97,7 +104,7 @@ def create_run(self, data) -> tuple[typing.Optional[str], typing.Optional[int]]: logger.debug('Creating run with data: "%s"', data) try: - response = post(f"{self._url}/api/runs", self._headers, data) + response = post(f"{self._config.server.url}/api/runs", self._headers, data) except Exception as err: self._error(f"Exception creating run: {str(err)}") return (None, None) @@ -136,7 +143,7 @@ def update( logger.debug('Updating run with data: "%s"', data) try: - response = put(f"{self._url}/api/runs", self._headers, data) + response = put(f"{self._config.server.url}/api/runs", self._headers, data) except Exception as err: self._error(f"Exception updating run: {err}") return None @@ -164,7 +171,9 @@ def set_folder_details( data["name"] = run try: - response = post(f"{self._url}/api/folders", self._headers, data) + response = post( + f"{self._config.server.url}/api/folders", self._headers, data + ) except Exception as err: self._error(f"Exception creating folder: {err}") return None @@ -181,7 +190,9 @@ def set_folder_details( logger.debug('Setting folder details with data: "%s"', data) try: - response = put(f"{self._url}/api/folders", self._headers, data) + response = put( + f"{self._config.server.url}/api/folders", self._headers, data + ) except Exception as err: self._error(f"Exception setting folder details: {err}") return None @@ -212,7 +223,9 @@ def save_file( # Get presigned URL try: response = post( - f"{self._url}/api/artifacts", self._headers, prepare_for_api(data) + f"{self._config.server.url}/api/artifacts", + self._headers, + prepare_for_api(data), ) except Exception as err: self._error( @@ -294,7 +307,7 @@ def save_file( return None if storage_id: - path = f"{self._url}/api/runs/{self._id}/artifacts" + path = f"{self._config.server.url}/api/runs/{self._id}/artifacts" data["storage"] = storage_id try: @@ -324,7 +337,9 @@ def add_alert(self, data, run=None): logger.debug('Adding alert with data: "%s"', data) try: - response = post(f"{self._url}/api/alerts", self._headers, data) + response = post( + f"{self._config.server.url}/api/alerts", self._headers, data + ) except Exception as err: self._error(f"Got exception when creating an alert: {str(err)}") return False @@ -350,7 +365,9 @@ def set_alert_state( """ data = {"run": self._id, "alert": alert_id, "status": status} try: - response = put(f"{self._url}/api/alerts/status", self._headers, data) + response = put( + f"{self._config.server.url}/api/alerts/status", self._headers, data + ) except Exception as err: self._error(f"Got exception when setting alert state: {err}") return {} @@ -366,7 +383,7 @@ def list_alerts(self) -> list[dict[str, typing.Any]]: List alerts """ try: - response = get(f"{self._url}/api/alerts", self._headers) + response = get(f"{self._config.server.url}/api/alerts", self._headers) except Exception as err: self._error(f"Got exception when listing alerts: {str(err)}") return [] @@ -395,7 +412,10 @@ def send_metrics( try: response = post( - f"{self._url}/api/metrics", self._headers_mp, data, is_json=False + f"{self._config.server.url}/api/metrics", + self._headers_mp, + data, + is_json=False, ) except Exception as err: self._error(f"Exception sending metrics: {str(err)}") @@ -420,7 +440,10 @@ def send_event( try: response = post( - f"{self._url}/api/events", self._headers_mp, data, is_json=False + f"{self._config.server.url}/api/events", + self._headers_mp, + data, + is_json=False, ) except Exception as err: self._error(f"Exception sending event: {str(err)}") @@ -443,7 +466,9 @@ def send_heartbeat(self) -> typing.Optional[dict[str, typing.Any]]: try: response = put( - f"{self._url}/api/runs/heartbeat", self._headers, {"id": self._id} + f"{self._config.server.url}/api/runs/heartbeat", + self._headers, + {"id": self._id}, ) except Exception as err: self._error(f"Exception creating run: {str(err)}") @@ -457,42 +482,14 @@ def send_heartbeat(self) -> typing.Optional[dict[str, typing.Any]]: self._error(f"Got status code {response.status_code} when sending heartbeat") return None - @skip_if_failed("_aborted", "_suppress_errors", False) - def check_token(self) -> bool: - """ - Check token - """ - if not (expiry := get_expiry(self._token)): - self._error("Failed to parse user token") - return False - - if time.time() - expiry > 0: - self._error("Token has expired") - return False - - try: - response = get(f"{self._url}/api/version", self._headers) - - if response.status_code != http.HTTPStatus.OK or not response.json().get( - "version" - ): - raise AssertionError - - if response.status_code == http.HTTPStatus.UNAUTHORIZED: - self._error("Unauthorised token") - return False - - except Exception as err: - self._error(f"Exception retrieving server version: {str(err)}") - return False - return True - @skip_if_failed("_aborted", "_suppress_errors", False) def get_abort_status(self) -> bool: logger.debug("Retrieving alert status") try: - response = get(f"{self._url}/api/runs/{self._id}/abort", self._headers_mp) + response = get( + f"{self._config.server.url}/api/runs/{self._id}/abort", self._headers_mp + ) except Exception as err: self._error(f"Exception retrieving abort status: {str(err)}") return False diff --git a/simvue/run.py b/simvue/run.py index 8210820a..a6d8f8ae 100644 --- a/simvue/run.py +++ b/simvue/run.py @@ -32,6 +32,7 @@ import psutil from pydantic import ValidationError +from .config import SimvueConfiguration import simvue.api as sv_api from .factory.dispatch import Dispatcher @@ -47,7 +48,6 @@ calculate_sha256, compare_alerts, skip_if_failed, - get_auth, get_offline_directory, validate_timestamp, simvue_timestamp, @@ -101,6 +101,9 @@ def __init__( self, mode: typing.Literal["online", "offline", "disabled"] = "online", abort_callback: typing.Optional[typing.Callable[[Self], None]] = None, + server_token: typing.Optional[str] = None, + server_url: typing.Optional[str] = None, + debug: bool = False, ) -> None: """Initialise a new Simvue run @@ -115,6 +118,12 @@ def __init__( disabled - disable monitoring completely abort_callback : Callable | None, optional callback executed when the run is aborted + server_token : str, optional + overwrite value for server token, default is None + server_url : str, optional + overwrite value for server URL, default is None + debug : bool, optional + run in debug mode, default is False """ self._uuid: str = f"{uuid.uuid4()}" self._mode: typing.Literal["online", "offline", "disabled"] = mode @@ -142,10 +151,22 @@ def __init__( self._data: dict[str, typing.Any] = {} self._step: int = 0 self._active: bool = False + self._config = SimvueConfiguration.fetch( + server_token=server_token, server_url=server_url + ) + + logging.getLogger(self.__class__.__module__).setLevel( + logging.DEBUG + if (debug is not None and debug) + or (debug is None and self._config.client.debug) + else logging.INFO + ) + self._aborted: bool = False - self._url, self._token = get_auth() self._resources_metrics_interval: typing.Optional[int] = HEARTBEAT_INTERVAL - self._headers: dict[str, str] = {"Authorization": f"Bearer {self._token}"} + self._headers: dict[str, str] = { + "Authorization": f"Bearer {self._config.server.token}" + } self._simvue: typing.Optional[SimvueBaseClass] = None self._pid: typing.Optional[int] = 0 self._shutdown_event: typing.Optional[threading.Event] = None @@ -286,7 +307,7 @@ def _create_heartbeat_callback( self, ) -> typing.Callable[[threading.Event], None]: if ( - self._mode == "online" and (not self._url or not self._id) + self._mode == "online" and (not self._config.server.url or not self._id) ) or not self._heartbeat_termination_trigger: raise RuntimeError("Could not commence heartbeat, run not initialised") @@ -368,7 +389,7 @@ def _create_dispatch_callback( if self._mode == "online" and not self._id: raise RuntimeError("Expected identifier for run") - if not self._url: + if not self._config.server.url: raise RuntimeError("Cannot commence dispatch, run not initialised") def _offline_dispatch_callback( @@ -403,7 +424,7 @@ def _offline_dispatch_callback( def _online_dispatch_callback( buffer: list[typing.Any], category: str, - url: str = self._url, + url: str = self._config.server.url, run_id: typing.Optional[str] = self._id, headers: dict[str, str] = self._headers, ) -> None: @@ -450,9 +471,6 @@ def _start(self, reconnect: bool = False) -> bool: logger.debug("Starting run") - if self._simvue and not self._simvue.check_token(): - return False - data: dict[str, typing.Any] = {"status": self._status} if reconnect: @@ -551,10 +569,14 @@ def init( typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)] ] = None, *, - metadata: typing.Optional[dict[str, typing.Any]] = None, + metadata: typing.Optional[ + dict[str, typing.Union[str, int, float, bool]] + ] = None, tags: typing.Optional[list[str]] = None, description: typing.Optional[str] = None, - folder: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)] = "/", + folder: typing.Annotated[ + str, pydantic.Field(None, pattern=FOLDER_REGEX) + ] = None, running: bool = True, retention_period: typing.Optional[str] = None, timeout: typing.Optional[int] = 180, @@ -605,6 +627,12 @@ def init( ) return True + description = description or self._config.run.description + tags = (tags or []) + (self._config.run.tags or []) + folder = folder or self._config.run.folder + name = name or self._config.run.name + metadata = (metadata or {}) | (self._config.run.metadata or {}) + self._term_color = not no_color if isinstance(visibility, str) and visibility not in ("public", "tenant"): @@ -616,7 +644,7 @@ def init( self._error("invalid mode specified, must be online, offline or disabled") return False - if not self._token or not self._url: + if not self._config.server.token or not self._config.server.url: self._error( "Unable to get URL and token from environment variables or config file" ) @@ -666,7 +694,13 @@ def init( self._error(f"{err}") return False - self._simvue = Simvue(self._name, self._uuid, self._mode, self._suppress_errors) + self._simvue = Simvue( + name=self._name, + uniq_id=self._uuid, + mode=self._mode, + config=self._config, + suppress_errors=self._suppress_errors, + ) name, self._id = self._simvue.create_run(data) self._data = data @@ -687,7 +721,7 @@ def init( fg="green" if self._term_color else None, ) click.secho( - f"[simvue] Monitor in the UI at {self._url}/dashboard/runs/run/{self._id}", + f"[simvue] Monitor in the UI at {self._config.server.url}/dashboard/runs/run/{self._id}", bold=self._term_color, fg="green" if self._term_color else None, ) @@ -894,7 +928,9 @@ def reconnect(self, run_id: str) -> bool: self._status = "running" self._id = run_id - self._simvue = Simvue(self._name, self._id, self._mode, self._suppress_errors) + self._simvue = Simvue( + self._name, self._id, self._mode, self._config, self._suppress_errors + ) self._start(reconnect=True) return True diff --git a/simvue/sender.py b/simvue/sender.py index 8bee03d7..7a95b610 100644 --- a/simvue/sender.py +++ b/simvue/sender.py @@ -8,6 +8,8 @@ import msgpack +from simvue.config.user import SimvueConfiguration + from .factory.proxy.remote import Remote from .utilities import create_file, get_offline_directory, remove_file @@ -170,11 +172,11 @@ def process(run): # Create run if it hasn't previously been created created_file = f"{current}/init" name = None + config = SimvueConfiguration() if not os.path.isfile(created_file): - remote = Remote(run_init["name"], id, suppress_errors=False) - - # Check token - remote.check_token() + remote = Remote( + name=run_init["name"], uniq_id=id, config=config, suppress_errors=False + ) name, run_id = remote.create_run(run_init) if name: @@ -187,10 +189,9 @@ def process(run): else: name, run_id = get_details(created_file) run_init["name"] = name - remote = Remote(run_init["name"], run_id, suppress_errors=False) - - # Check token - remote.check_token() + remote = Remote( + name=run_init["name"], uniq_id=run_id, config=config, suppress_errors=False + ) if status == "running": # Check for recent heartbeat diff --git a/simvue/utilities.py b/simvue/utilities.py index cadcf279..f7bc0724 100644 --- a/simvue/utilities.py +++ b/simvue/utilities.py @@ -9,6 +9,7 @@ import functools import contextlib import os +import pathlib import typing import jwt @@ -24,6 +25,45 @@ from simvue.run import Run +def find_first_instance_of_file( + file_names: typing.Union[list[str], str], check_user_space: bool = True +) -> typing.Optional[pathlib.Path]: + """Traverses a file hierarchy from bottom upwards to find file + + Returns the first instance of 'file_names' found when moving + upward from the current directory. + + Parameters + ---------- + file_name: list[str] | str + candidate names of file to locate + check_user_space: bool, optional + check the users home area if current working directory is not + within it. Default is True. + + Returns + ------- + pathlib.Path | None + first matching file if found + """ + if isinstance(file_names, str): + file_names = [file_names] + + for root, _, files in os.walk(os.getcwd(), topdown=False): + for file_name in file_names: + if file_name in files: + return pathlib.Path(root).joinpath(file_name) + + # If the user is running on different mounted volume or outside + # of their user space then the above will not return the file + if check_user_space: + for file_name in file_names: + if os.path.exists(_user_file := pathlib.Path.home().joinpath(file_name)): + return _user_file + + return None + + def parse_validation_response( response: dict[str, list[dict[str, str]]], ) -> str: @@ -209,45 +249,6 @@ def wrapper(self, *args, **kwargs) -> typing.Any: return wrapper -def get_auth() -> tuple[str, str]: - """ - Get the URL and access token - """ - url: typing.Optional[str] = None - token: typing.Optional[str] = None - token_source: str = "" - url_source: str = "" - - # Try reading from config file - for filename in ( - os.path.join(os.path.expanduser("~"), ".simvue.ini"), - "simvue.ini", - ): - with contextlib.suppress(Exception): - config = configparser.ConfigParser() - config.read(filename) - token = config.get("server", "token") - token_source = filename - url = config.get("server", "url") - url_source = filename - - # Try environment variables - if not token and (token := os.getenv("SIMVUE_TOKEN")): - token_source = "env:SIMVUE_TOKEN" - if not url and (url := os.getenv("SIMVUE_URL")): - url_source = "env:SIMVUE_URL" - - if not token: - raise ValueError("No Simvue server token was specified") - if not url: - raise ValueError("No Simvue server URL was specified") - - logger.info(f"Using '{token_source}' as source for Simvue token") - logger.info(f"Using '{url_source}' as source for Simvue URL") - - return url, token - - def get_offline_directory() -> str: """ Get directory for offline cache diff --git a/tests/refactor/conftest.py b/tests/refactor/conftest.py index 106208e6..61214b10 100644 --- a/tests/refactor/conftest.py +++ b/tests/refactor/conftest.py @@ -113,7 +113,7 @@ def setup_test_run(run: sv_run.Run, create_objects: bool, request: pytest.Fixtur TEST_DATA["metrics"] = ("metric_counter", "metric_val") TEST_DATA["run_id"] = run._id TEST_DATA["run_name"] = run._name - TEST_DATA["url"] = run._url + TEST_DATA["url"] = run._config.server.url TEST_DATA["headers"] = run._headers TEST_DATA["pid"] = run._pid TEST_DATA["resources_metrics_interval"] = run._resources_metrics_interval @@ -139,3 +139,4 @@ def setup_test_run(run: sv_run.Run, create_objects: bool, request: pytest.Fixtur time.sleep(1.) return TEST_DATA + diff --git a/tests/refactor/test_config.py b/tests/refactor/test_config.py new file mode 100644 index 00000000..f7236a1d --- /dev/null +++ b/tests/refactor/test_config.py @@ -0,0 +1,118 @@ +import pytest +import typing +import os +import uuid +import pathlib +import pytest_mock +import tempfile +from simvue.config import SimvueConfiguration + + +@pytest.mark.config +@pytest.mark.parametrize( + "use_env", (True, False), + ids=("use_env", "no_env") +) +@pytest.mark.parametrize( + "use_file", (None, "basic", "extended", "ini"), + ids=("no_file", "basic_file", "extended_file", "legacy_file") +) +@pytest.mark.parametrize( + "use_args", (True, False), + ids=("args", "no_args") +) +def test_config_setup( + use_env: bool, + use_file: str | None, + use_args: bool, + monkeypatch: pytest.MonkeyPatch, + mocker: pytest_mock.MockerFixture +) -> None: + _token: str = f"{uuid.uuid4()}".replace('-', '') + _other_token: str = f"{uuid.uuid4()}".replace('-', '') + _arg_token: str = f"{uuid.uuid4()}".replace('-', '') + _url: str = "https://simvue.example.com/" + _other_url: str = "http://simvue.example.com/" + _arg_url: str = "http://simvue.example.io/" + _description: str = "test case for runs" + _folder: str = "/test-case" + _tags: list[str] = ["tag-test", "other-tag"] + + # Deactivate the server checks for this test + monkeypatch.setenv("SIMVUE_NO_SERVER_CHECK", True) + + if use_env: + monkeypatch.setenv("SIMVUE_TOKEN", _other_token) + monkeypatch.setenv("SIMVUE_URL", _other_url) + else: + monkeypatch.delenv("SIMVUE_TOKEN", False) + monkeypatch.delenv("SIMVUE_URL", False) + + with tempfile.TemporaryDirectory() as temp_d: + _config_file = None + if use_file: + with open(_config_file := pathlib.Path(temp_d).joinpath(f"simvue.{'toml' if use_file != 'ini' else 'ini'}"), "w") as out_f: + if use_file != "ini": + _lines: str = f""" +[server] +url = "{_url}" +token = "{_token}" +""" + else: + _lines = f""" +[server] +url = {_url} +token = {_token} +""" + + if use_file == "extended": + _lines += f""" +[run] +description = "{_description}" +folder = "{_folder}" +tags = {_tags} +""" + out_f.write(_lines) + SimvueConfiguration.config_file.cache_clear() + + mocker.patch("simvue.config.parameters.get_expiry", lambda *_, **__: 1e10) + mocker.patch("simvue.config.user.sv_util.find_first_instance_of_file", lambda *_, **__: _config_file) + + import simvue.config + + if not use_file and not use_env and not use_args: + with pytest.raises(RuntimeError): + simvue.config.SimvueConfiguration.fetch() + return + elif use_args: + _config = simvue.config.SimvueConfiguration.fetch( + server_url=_arg_url, + server_token=_arg_token + ) + else: + _config = simvue.config.SimvueConfiguration.fetch() + + if use_file: + assert _config.config_file() == _config_file + + if use_env: + assert _config.server.url == _other_url + assert _config.server.token == _other_token + elif use_args: + assert _config.server.url == _arg_url + assert _config.server.token == _arg_token + elif use_file: + assert _config.server.url == _url + assert _config.server.token == _token + + if use_file == "extended": + assert _config.run.description == _description + assert _config.run.folder == _folder + assert _config.run.tags == _tags + elif use_file: + assert _config.run.folder == "/" + assert not _config.run.description + assert not _config.run.tags + + simvue.config.SimvueConfiguration.config_file.cache_clear() + diff --git a/tests/refactor/test_proxies.py b/tests/refactor/test_proxies.py deleted file mode 100644 index 00a09c3b..00000000 --- a/tests/refactor/test_proxies.py +++ /dev/null @@ -1,15 +0,0 @@ -import pytest - -from simvue.factory.proxy import Simvue - -@pytest.mark.proxy -def test_simvue_url_check() -> None: - """Checks the Token/URL checker""" - remote = Simvue( - name="", - uniq_id="", - mode="online", - suppress_errors=False - ) - assert remote.check_token() - diff --git a/tests/refactor/test_run_class.py b/tests/refactor/test_run_class.py index 28ea434b..8ff44a89 100644 --- a/tests/refactor/test_run_class.py +++ b/tests/refactor/test_run_class.py @@ -502,6 +502,7 @@ def abort_callback(abort_run=trigger) -> None: client = sv_cl.Client() client.abort_run(run._id, reason="testing abort") time.sleep(4) + assert run._resources_metrics_interval == 1 for child in child_processes: assert not child.is_running() if not run._status == "terminated": @@ -528,10 +529,10 @@ def testing_exit(status: int) -> None: if i == 4: client.abort_run(run._id, reason="testing abort") i += 1 - if abort_set.is_set() or i > 9: + if abort_set.is_set() or i > 11: break - assert i < 7 + assert i < 10 assert run._status == "terminated" diff --git a/tests/unit/test_suppress_errors.py b/tests/unit/test_suppress_errors.py index 650ba16c..8a1b9f29 100644 --- a/tests/unit/test_suppress_errors.py +++ b/tests/unit/test_suppress_errors.py @@ -13,7 +13,6 @@ def test_suppress_errors_false(): suppress_errors=False, disable_resources_metrics=123, ) - print(e.value) assert "Input should be a valid boolean, unable to interpret input" in f"{e.value}" def test_suppress_errors_true(caplog): @@ -44,4 +43,4 @@ def test_suppress_errors_default(caplog): caplog.set_level(logging.ERROR) - assert "Input should be a valid boolean, unable to interpret input" in caplog.text \ No newline at end of file + assert "Input should be a valid boolean, unable to interpret input" in caplog.text