diff --git a/pyproject.toml b/pyproject.toml index d2bbd621..38af1d38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "simvue" -version = "2.0.1" +version = "2.1.0" description = "Simulation tracking and monitoring" authors = [ {name = "Simvue Development Team", email = "info@simvue.io"} diff --git a/simvue/api/objects/base.py b/simvue/api/objects/base.py index 135b6e07..bc3227f1 100644 --- a/simvue/api/objects/base.py +++ b/simvue/api/objects/base.py @@ -23,6 +23,7 @@ from simvue.version import __version__ from simvue.api.request import ( get as sv_get, + get_paginated, post as sv_post, put as sv_put, delete as sv_delete, @@ -347,7 +348,7 @@ def new(cls, **_) -> Self: @classmethod def ids( cls, count: int | None = None, offset: int | None = None, **kwargs - ) -> list[str]: + ) -> typing.Generator[str, None, None]: """Retrieve a list of all object identifiers. Parameters @@ -357,17 +358,23 @@ def ids( offset : int | None, optional set start index for objects list - Returns + Yields ------- - list[str] + str identifiers for all objects of this type. """ _class_instance = cls(_read_only=True, _local=True) - if (_data := cls._get_all_objects(count, offset, **kwargs).get("data")) is None: - raise RuntimeError( - f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s" - ) - return [_entry["id"] for _entry in _data] + _count: int = 0 + for response in cls._get_all_objects(offset): + if (_data := response.get("data")) is None: + raise RuntimeError( + f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s" + ) + for entry in _data: + yield entry["id"] + _count += 1 + if count and _count > count: + return @classmethod @pydantic.validate_call @@ -396,23 +403,19 @@ def get( Generator[tuple[str, SimvueObject | None], None, None] """ _class_instance = cls(_read_only=True, _local=True) - if ( - _data := cls._get_all_objects( - count=count, - offset=offset, - **kwargs, - ).get("data") - ) is None: - raise RuntimeError( - f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s" - ) - - for _entry in _data: - if not (_id := _entry.pop("id", None)): + _count: int = 0 + for _response in cls._get_all_objects(offset, **kwargs): + if count and _count > count: + return + if (_data := _response.get("data")) is None: raise RuntimeError( - f"Expected key 'id' for {_class_instance.__class__.__name__.lower()}" + f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s" ) - yield _id, cls(_read_only=True, identifier=_id, _local=True, **_entry) + + for entry in _data: + _id = entry["id"] + yield _id, cls(_read_only=True, identifier=_id, _local=True, **entry) + _count += 1 @classmethod def count(cls, **kwargs) -> int: @@ -424,42 +427,34 @@ def count(cls, **kwargs) -> int: total from server database for current user. """ _class_instance = cls(_read_only=True) - if ( - _count := cls._get_all_objects(count=None, offset=None, **kwargs).get( - "count" - ) - ) is None: - raise RuntimeError( - f"Expected key 'count' for retrieval of {_class_instance.__class__.__name__.lower()}s" - ) - return _count + _count_total: int = 0 + for _data in cls._get_all_objects(**kwargs): + if not (_count := _data.get("count")): + raise RuntimeError( + f"Expected key 'count' for retrieval of {_class_instance.__class__.__name__.lower()}s" + ) + _count_total += _count + return _count_total @classmethod def _get_all_objects( - cls, - count: int | None, - offset: int | None, - **kwargs, - ) -> dict[str, typing.Any]: + cls, offset: int | None, **kwargs + ) -> typing.Generator[dict, None, None]: _class_instance = cls(_read_only=True) _url = f"{_class_instance._base_url}" - _params: dict[str, int | str] = {"start": offset, "count": count} - - _response = sv_get( - _url, - headers=_class_instance._headers, - params=_params | kwargs, - ) _label = _class_instance.__class__.__name__.lower() if _label.endswith("s"): _label = _label[:-1] - return get_json_from_response( - response=_response, - expected_status=[http.HTTPStatus.OK], - scenario=f"Retrieval of {_label}s", - ) + for response in get_paginated( + _url, headers=_class_instance._headers, offset=offset, **kwargs + ): + yield get_json_from_response( + response=response, + expected_status=[http.HTTPStatus.OK], + scenario=f"Retrieval of {_label}s", + ) # type: ignore def read_only(self, is_read_only: bool) -> None: """Set whether this object is in read only state. diff --git a/simvue/api/objects/events.py b/simvue/api/objects/events.py index 4402f708..b32a7445 100644 --- a/simvue/api/objects/events.py +++ b/simvue/api/objects/events.py @@ -49,18 +49,19 @@ def get( **kwargs, ) -> typing.Generator[EventSet, None, None]: _class_instance = cls(_read_only=True, _local=True) - - if ( - _data := cls._get_all_objects(count, offset, run=run_id, **kwargs).get( - "data" - ) - ) is None: - raise RuntimeError( - f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s" - ) - - for _entry in _data: - yield EventSet(**_entry) + _count: int = 0 + + for response in cls._get_all_objects(offset, run=run_id, **kwargs): + if (_data := response.get("data")) is None: + raise RuntimeError( + f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s" + ) + + for _entry in _data: + yield EventSet(**_entry) + _count += 1 + if _count > count: + return @classmethod @pydantic.validate_call diff --git a/simvue/api/objects/metrics.py b/simvue/api/objects/metrics.py index a414494a..5e5a1988 100644 --- a/simvue/api/objects/metrics.py +++ b/simvue/api/objects/metrics.py @@ -79,7 +79,7 @@ def get( count: pydantic.PositiveInt | None = None, offset: pydantic.PositiveInt | None = None, **kwargs, - ) -> typing.Generator[MetricSet, None, None]: + ) -> typing.Generator[dict[str, dict[str, list[dict[str, float]]]], None, None]: """Retrieve metrics from the server for a given set of runs. Parameters @@ -100,20 +100,17 @@ def get( Yields ------ - MetricSet + dict[str, dict[str, list[dict[str, float]]] metric set object containing metrics for run. """ - _class_instance = cls(_read_only=True, _local=True) - _data = cls._get_all_objects( - count, + yield from cls._get_all_objects( offset, metrics=json.dumps(metrics), runs=json.dumps(runs), xaxis=xaxis, + count=count, **kwargs, ) - # TODO: Temp fix, just return the dictionary. Not sure what format we really want this in... - return _data @pydantic.validate_call def span(self, run_ids: list[str]) -> dict[str, int | float]: diff --git a/simvue/api/request.py b/simvue/api/request.py index 8dd6a8bd..4a376749 100644 --- a/simvue/api/request.py +++ b/simvue/api/request.py @@ -27,6 +27,7 @@ RETRY_MIN = 4 RETRY_MAX = 10 RETRY_STOP = 5 +MAX_ENTRIES_PER_PAGE: int = 100 RETRY_STATUS_CODES = ( http.HTTPStatus.BAD_REQUEST, http.HTTPStatus.SERVICE_UNAVAILABLE, @@ -273,3 +274,51 @@ def get_json_from_response( error_str += f": {txt_response}" raise RuntimeError(error_str) + + +def get_paginated( + url: str, + headers: dict[str, str] | None = None, + timeout: int = DEFAULT_API_TIMEOUT, + json: dict[str, typing.Any] | None = None, + offset: int | None = None, + **params, +) -> typing.Generator[requests.Response, None, None]: + """Paginate results of a server query. + + Parameters + ---------- + url : str + URL to put to + headers : dict[str, str] + headers for the post request + timeout : int, optional + timeout of request, by default DEFAULT_API_TIMEOUT + json : dict[str, Any] | None, optional + any json to send in request + + Yield + ----- + requests.Response + server response + """ + _offset: int = offset or 0 + + while ( + ( + _response := get( + url=url, + headers=headers, + params=(params or {}) + | {"count": MAX_ENTRIES_PER_PAGE, "start": _offset}, + timeout=timeout, + json=json, + ) + ) + .json() + .get("data") + ): + yield _response + _offset += MAX_ENTRIES_PER_PAGE + + yield _response diff --git a/simvue/client.py b/simvue/client.py index f5845edc..5ba7c320 100644 --- a/simvue/client.py +++ b/simvue/client.py @@ -346,7 +346,10 @@ def _get_folder_id_from_path(self, path: str) -> str | None: """ _ids = Folder.ids(filters=json.dumps([f"path == {path}"])) - return _ids[0] if _ids else None + try: + return next(_ids) + except StopIteration: + return None @prettify_pydantic @pydantic.validate_call diff --git a/tests/functional/test_client.py b/tests/functional/test_client.py index 390d90d6..2952e35a 100644 --- a/tests/functional/test_client.py +++ b/tests/functional/test_client.py @@ -45,10 +45,10 @@ def test_get_alerts( _id_1 = run.create_user_alert( name=f"user_alert_1_{unique_id}", ) - _id_2 = run.create_user_alert( + run.create_user_alert( name=f"user_alert_2_{unique_id}", ) - _id_3 = run.create_user_alert( + run.create_user_alert( name=f"user_alert_3_{unique_id}", attach_to_run=False )