Skip to content

Paginate results from server #762

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "simvue"
version = "2.0.1"
version = "2.1.0"
description = "Simulation tracking and monitoring"
authors = [
{name = "Simvue Development Team", email = "[email protected]"}
Expand Down
93 changes: 44 additions & 49 deletions simvue/api/objects/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from simvue.version import __version__
from simvue.api.request import (
get as sv_get,
get_paginated,
post as sv_post,
put as sv_put,
delete as sv_delete,
Expand Down Expand Up @@ -347,7 +348,7 @@ def new(cls, **_) -> Self:
@classmethod
def ids(
cls, count: int | None = None, offset: int | None = None, **kwargs
) -> list[str]:
) -> typing.Generator[str, None, None]:
"""Retrieve a list of all object identifiers.

Parameters
Expand All @@ -357,17 +358,23 @@ def ids(
offset : int | None, optional
set start index for objects list

Returns
Yields
-------
list[str]
str
identifiers for all objects of this type.
"""
_class_instance = cls(_read_only=True, _local=True)
if (_data := cls._get_all_objects(count, offset, **kwargs).get("data")) is None:
raise RuntimeError(
f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s"
)
return [_entry["id"] for _entry in _data]
_count: int = 0
for response in cls._get_all_objects(offset):
if (_data := response.get("data")) is None:
raise RuntimeError(
f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s"
)
for entry in _data:
yield entry["id"]
_count += 1
if count and _count > count:
return

@classmethod
@pydantic.validate_call
Expand Down Expand Up @@ -396,23 +403,19 @@ def get(
Generator[tuple[str, SimvueObject | None], None, None]
"""
_class_instance = cls(_read_only=True, _local=True)
if (
_data := cls._get_all_objects(
count=count,
offset=offset,
**kwargs,
).get("data")
) is None:
raise RuntimeError(
f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s"
)

for _entry in _data:
if not (_id := _entry.pop("id", None)):
_count: int = 0
for _response in cls._get_all_objects(offset, **kwargs):
if count and _count > count:
return
if (_data := _response.get("data")) is None:
raise RuntimeError(
f"Expected key 'id' for {_class_instance.__class__.__name__.lower()}"
f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s"
)
yield _id, cls(_read_only=True, identifier=_id, _local=True, **_entry)

for entry in _data:
_id = entry["id"]
yield _id, cls(_read_only=True, identifier=_id, _local=True, **entry)
_count += 1

@classmethod
def count(cls, **kwargs) -> int:
Expand All @@ -424,42 +427,34 @@ def count(cls, **kwargs) -> int:
total from server database for current user.
"""
_class_instance = cls(_read_only=True)
if (
_count := cls._get_all_objects(count=None, offset=None, **kwargs).get(
"count"
)
) is None:
raise RuntimeError(
f"Expected key 'count' for retrieval of {_class_instance.__class__.__name__.lower()}s"
)
return _count
_count_total: int = 0
for _data in cls._get_all_objects(**kwargs):
if not (_count := _data.get("count")):
raise RuntimeError(
f"Expected key 'count' for retrieval of {_class_instance.__class__.__name__.lower()}s"
)
_count_total += _count
return _count_total

@classmethod
def _get_all_objects(
cls,
count: int | None,
offset: int | None,
**kwargs,
) -> dict[str, typing.Any]:
cls, offset: int | None, **kwargs
) -> typing.Generator[dict, None, None]:
_class_instance = cls(_read_only=True)
_url = f"{_class_instance._base_url}"
_params: dict[str, int | str] = {"start": offset, "count": count}

_response = sv_get(
_url,
headers=_class_instance._headers,
params=_params | kwargs,
)

_label = _class_instance.__class__.__name__.lower()
if _label.endswith("s"):
_label = _label[:-1]

return get_json_from_response(
response=_response,
expected_status=[http.HTTPStatus.OK],
scenario=f"Retrieval of {_label}s",
)
for response in get_paginated(
_url, headers=_class_instance._headers, offset=offset, **kwargs
):
yield get_json_from_response(
response=response,
expected_status=[http.HTTPStatus.OK],
scenario=f"Retrieval of {_label}s",
) # type: ignore

def read_only(self, is_read_only: bool) -> None:
"""Set whether this object is in read only state.
Expand Down
25 changes: 13 additions & 12 deletions simvue/api/objects/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,19 @@ def get(
**kwargs,
) -> typing.Generator[EventSet, None, None]:
_class_instance = cls(_read_only=True, _local=True)

if (
_data := cls._get_all_objects(count, offset, run=run_id, **kwargs).get(
"data"
)
) is None:
raise RuntimeError(
f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s"
)

for _entry in _data:
yield EventSet(**_entry)
_count: int = 0

for response in cls._get_all_objects(offset, run=run_id, **kwargs):
if (_data := response.get("data")) is None:
raise RuntimeError(
f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s"
)

for _entry in _data:
yield EventSet(**_entry)
_count += 1
if _count > count:
return

@classmethod
@pydantic.validate_call
Expand Down
11 changes: 4 additions & 7 deletions simvue/api/objects/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def get(
count: pydantic.PositiveInt | None = None,
offset: pydantic.PositiveInt | None = None,
**kwargs,
) -> typing.Generator[MetricSet, None, None]:
) -> typing.Generator[dict[str, dict[str, list[dict[str, float]]]], None, None]:
"""Retrieve metrics from the server for a given set of runs.

Parameters
Expand All @@ -100,20 +100,17 @@ def get(

Yields
------
MetricSet
dict[str, dict[str, list[dict[str, float]]]]
metric set object containing metrics for run.
"""
_class_instance = cls(_read_only=True, _local=True)
_data = cls._get_all_objects(
count,
yield from cls._get_all_objects(
offset,
metrics=json.dumps(metrics),
runs=json.dumps(runs),
xaxis=xaxis,
count=count,
**kwargs,
)
# TODO: Temp fix, just return the dictionary. Not sure what format we really want this in...
return _data

@pydantic.validate_call
def span(self, run_ids: list[str]) -> dict[str, int | float]:
Expand Down
49 changes: 49 additions & 0 deletions simvue/api/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
RETRY_MIN = 4
RETRY_MAX = 10
RETRY_STOP = 5
MAX_ENTRIES_PER_PAGE: int = 100
RETRY_STATUS_CODES = (
http.HTTPStatus.BAD_REQUEST,
http.HTTPStatus.SERVICE_UNAVAILABLE,
Expand Down Expand Up @@ -273,3 +274,51 @@ def get_json_from_response(
error_str += f": {txt_response}"

raise RuntimeError(error_str)


def get_paginated(
    url: str,
    headers: dict[str, str] | None = None,
    timeout: int = DEFAULT_API_TIMEOUT,
    json: dict[str, typing.Any] | None = None,
    offset: int | None = None,
    **params,
) -> typing.Generator[requests.Response, None, None]:
    """Paginate results of a server query.

    Issues successive GET requests, each asking for at most
    ``MAX_ENTRIES_PER_PAGE`` entries, advancing the start index by that
    amount after every page whose JSON body holds a non-empty 'data'
    entry. The first response without such an entry is yielded as well,
    so at least one response is always handed to the caller for validation.

    Parameters
    ----------
    url : str
        URL to send the GET requests to
    headers : dict[str, str] | None, optional
        headers for the GET requests
    timeout : int, optional
        timeout of each request, by default DEFAULT_API_TIMEOUT
    json : dict[str, Any] | None, optional
        any json to send in each request
    offset : int | None, optional
        index of the first entry to request, by default None (start at 0)
    **params
        additional query parameters sent with every request; any 'count'
        or 'start' entries are overridden by the paging values

    Yields
    ------
    requests.Response
        server response for each page, ending with the first response
        which holds no data
    """
    _offset: int = offset or 0

    while True:
        _response = get(
            url=url,
            headers=headers,
            # paging keys go last so they always win over caller-supplied ones
            params=params | {"count": MAX_ENTRIES_PER_PAGE, "start": _offset},
            timeout=timeout,
            json=json,
        )

        try:
            _payload = _response.json()
        except ValueError:
            # Body is not valid JSON (e.g. a plain-text error); treat it as
            # a page without data so the response is still yielded below
            # and the caller decides how to report it.
            _payload = None

        # Only a JSON object with a non-empty 'data' entry counts as a page
        _has_data: bool = isinstance(_payload, dict) and bool(_payload.get("data"))

        yield _response

        if not _has_data:
            return

        _offset += MAX_ENTRIES_PER_PAGE
5 changes: 4 additions & 1 deletion simvue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,10 @@ def _get_folder_id_from_path(self, path: str) -> str | None:
"""
_ids = Folder.ids(filters=json.dumps([f"path == {path}"]))

return _ids[0] if _ids else None
try:
return next(_ids)
except StopIteration:
return None

@prettify_pydantic
@pydantic.validate_call
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def test_get_alerts(
_id_1 = run.create_user_alert(
name=f"user_alert_1_{unique_id}",
)
_id_2 = run.create_user_alert(
run.create_user_alert(
name=f"user_alert_2_{unique_id}",
)
_id_3 = run.create_user_alert(
run.create_user_alert(
name=f"user_alert_3_{unique_id}",
attach_to_run=False
)
Expand Down
Loading