10
10
import logging
11
11
import os
12
12
import typing
13
+ import pydantic
13
14
from concurrent .futures import ThreadPoolExecutor , as_completed
15
+ from pandas import DataFrame
14
16
15
17
import requests
16
18
21
23
)
22
24
from .serialization import deserialize_data
23
25
from .types import DeserializedContent
24
- from .utilities import check_extra , get_auth
26
+ from .utilities import check_extra , get_auth , prettify_pydantic
27
+ from .models import FOLDER_REGEX , NAME_REGEX
25
28
26
29
if typing .TYPE_CHECKING :
27
- from matplotlib .figure import Figure
28
- from pandas import DataFrame
30
+ pass
29
31
30
32
CONCURRENT_DOWNLOADS = 10
31
33
DOWNLOAD_CHUNK_SIZE = 8192
@@ -133,7 +135,11 @@ def _get_json_from_response(
133
135
134
136
raise RuntimeError (error_str )
135
137
136
- def get_run_id_from_name (self , name : str ) -> str :
138
+ @prettify_pydantic
139
+ @pydantic .validate_call
140
+ def get_run_id_from_name (
141
+ self , name : typing .Annotated [str , pydantic .Field (pattern = NAME_REGEX )]
142
+ ) -> str :
137
143
"""Get Run ID from the server matching the specified name
138
144
139
145
Assumes a unique name for this run. If multiple results are found this
@@ -186,6 +192,8 @@ def get_run_id_from_name(self, name: str) -> str:
186
192
raise RuntimeError ("Failed to retrieve identifier for run." )
187
193
return first_id
188
194
195
+ @prettify_pydantic
196
+ @pydantic .validate_call
189
197
def get_run (self , run_id : str ) -> typing .Optional [dict [str , typing .Any ]]:
190
198
"""Retrieve a single run
191
199
@@ -225,6 +233,8 @@ def get_run(self, run_id: str) -> typing.Optional[dict[str, typing.Any]]:
225
233
)
226
234
return json_response
227
235
236
+ @prettify_pydantic
237
+ @pydantic .validate_call
228
238
def get_run_name_from_id (self , run_id : str ) -> str :
229
239
"""Retrieve the name of a run from its identifier
230
240
@@ -250,6 +260,8 @@ def get_run_name_from_id(self, run_id: str) -> str:
250
260
raise RuntimeError ("Expected key 'name' in server response" )
251
261
return _name
252
262
263
+ @prettify_pydantic
264
+ @pydantic .validate_call
253
265
def get_runs (
254
266
self ,
255
267
filters : typing .Optional [list [str ]],
@@ -261,7 +273,7 @@ def get_runs(
261
273
count : int = 100 ,
262
274
start_index : int = 0 ,
263
275
) -> typing .Union [
264
- " DataFrame" , list [dict [str , typing .Union [int , str , float , None ]]], None
276
+ DataFrame , list [dict [str , typing .Union [int , str , float , None ]]], None
265
277
]:
266
278
"""Retrieve all runs matching filters.
267
279
@@ -337,6 +349,8 @@ def get_runs(
337
349
else :
338
350
raise RuntimeError ("Failed to retrieve runs data" )
339
351
352
+ @prettify_pydantic
353
+ @pydantic .validate_call
340
354
def delete_run (self , run_identifier : str ) -> typing .Optional [dict ]:
341
355
"""Delete run by identifier
342
356
@@ -404,13 +418,18 @@ def _get_folder_id_from_path(self, path: str) -> typing.Optional[str]:
404
418
405
419
return None
406
420
407
- def delete_runs (self , folder_name : str ) -> typing .Optional [list ]:
421
+ @prettify_pydantic
422
+ @pydantic .validate_call
423
+ def delete_runs (
424
+ self , folder_path : typing .Annotated [str , pydantic .Field (pattern = FOLDER_REGEX )]
425
+ ) -> typing .Optional [list ]:
408
426
"""Delete runs in a named folder
409
427
410
428
Parameters
411
429
----------
412
- folder_name : str
413
- the name of the folder on which to perform deletion
430
+ folder_path : str
431
+ the path of the folder on which to perform deletion. All folder
432
+ paths are prefixed with `/`
414
433
415
434
Returns
416
435
-------
@@ -422,10 +441,10 @@ def delete_runs(self, folder_name: str) -> typing.Optional[list]:
422
441
RuntimeError
423
442
if deletion fails due to server request error
424
443
"""
425
- folder_id = self ._get_folder_id_from_path (folder_name )
444
+ folder_id = self ._get_folder_id_from_path (folder_path )
426
445
427
446
if not folder_id :
428
- raise ValueError (f"Could not find a folder matching '{ folder_name } '" )
447
+ raise ValueError (f"Could not find a folder matching '{ folder_path } '" )
429
448
430
449
params : dict [str , bool ] = {"runs_only" : True , "runs" : True }
431
450
@@ -435,19 +454,21 @@ def delete_runs(self, folder_name: str) -> typing.Optional[list]:
435
454
436
455
if response .status_code == 200 :
437
456
if runs := response .json ().get ("runs" , []):
438
- logger .debug (f"Runs from '{ folder_name } ' deleted successfully: { runs } " )
457
+ logger .debug (f"Runs from '{ folder_path } ' deleted successfully: { runs } " )
439
458
else :
440
459
logger .debug ("Folder empty, no runs deleted." )
441
460
return runs
442
461
443
462
raise RuntimeError (
444
- f"Deletion of runs from folder '{ folder_name } ' failed"
463
+ f"Deletion of runs from folder '{ folder_path } ' failed"
445
464
f"with code { response .status_code } : { response .text } "
446
465
)
447
466
467
+ @prettify_pydantic
468
+ @pydantic .validate_call
448
469
def delete_folder (
449
470
self ,
450
- folder_name : str ,
471
+ folder_path : typing . Annotated [ str , pydantic . Field ( pattern = FOLDER_REGEX )] ,
451
472
recursive : bool = False ,
452
473
remove_runs : bool = False ,
453
474
allow_missing : bool = False ,
@@ -456,8 +477,8 @@ def delete_folder(
456
477
457
478
Parameters
458
479
----------
459
- folder_name : str
460
- name of the folder to delete
480
+ folder_path : str
481
+ name of the folder to delete. All paths are prefixed with `/`
461
482
recursive : bool, optional
462
483
if folder contains additional folders remove these, else return an
463
484
error. Default False.
@@ -477,14 +498,14 @@ def delete_folder(
477
498
RuntimeError
478
499
if deletion of the folder from the server failed
479
500
"""
480
- folder_id = self ._get_folder_id_from_path (folder_name )
501
+ folder_id = self ._get_folder_id_from_path (folder_path )
481
502
482
503
if not folder_id :
483
504
if allow_missing :
484
505
return None
485
506
else :
486
507
raise RuntimeError (
487
- f"Deletion of folder '{ folder_name } ' failed, "
508
+ f"Deletion of folder '{ folder_path } ' failed, "
488
509
"folder does not exist."
489
510
)
490
511
@@ -497,7 +518,7 @@ def delete_folder(
497
518
498
519
json_response = self ._get_json_from_response (
499
520
expected_status = [200 , 404 ],
500
- scenario = f"Deletion of folder '{ folder_name } '" ,
521
+ scenario = f"Deletion of folder '{ folder_path } '" ,
501
522
response = response ,
502
523
)
503
524
@@ -510,6 +531,8 @@ def delete_folder(
510
531
runs : list [dict ] = json_response .get ("runs" , [])
511
532
return runs
512
533
534
+ @prettify_pydantic
535
+ @pydantic .validate_call
513
536
def list_artifacts (self , run_id : str ) -> list [dict [str , typing .Any ]]:
514
537
"""Retrieve artifacts for a given run
515
538
@@ -574,9 +597,11 @@ def _retrieve_artifact_from_server(
574
597
575
598
return json_response
576
599
600
+ @prettify_pydantic
601
+ @pydantic .validate_call
577
602
def get_artifact (
578
603
self , run_id : str , name : str , allow_pickle : bool = False
579
- ) -> typing .Optional [ DeserializedContent ] :
604
+ ) -> typing .Any :
580
605
"""Return the contents of a specified artifact
581
606
582
607
Parameters
@@ -618,6 +643,8 @@ def get_artifact(
618
643
619
644
return content or response .content
620
645
646
+ @prettify_pydantic
647
+ @pydantic .validate_call
621
648
def get_artifact_as_file (
622
649
self , run_id : str , name : str , path : typing .Optional [str ] = None
623
650
) -> None :
@@ -708,6 +735,8 @@ def _assemble_artifact_downloads(
708
735
709
736
return downloads
710
737
738
+ @prettify_pydantic
739
+ @pydantic .validate_call
711
740
def get_artifacts_as_files (
712
741
self ,
713
742
run_id : str ,
@@ -771,13 +800,18 @@ def get_artifacts_as_files(
771
800
f"failed with exception: { e } "
772
801
)
773
802
774
- def get_folder (self , folder_id : str ) -> typing .Optional [dict [str , typing .Any ]]:
803
+ @prettify_pydantic
804
+ @pydantic .validate_call
805
+ def get_folder (
806
+ self , folder_path : typing .Annotated [str , pydantic .Field (pattern = FOLDER_REGEX )]
807
+ ) -> typing .Optional [dict [str , typing .Any ]]:
775
808
"""Retrieve a folder by identifier
776
809
777
810
Parameters
778
811
----------
779
- folder_id : str
780
- unique identifier for the folder
812
+ folder_path : str
813
+ the path of the folder to retrieve on the server.
814
+ Paths are prefixed with `/`
781
815
782
816
Returns
783
817
-------
@@ -789,15 +823,16 @@ def get_folder(self, folder_id: str) -> typing.Optional[dict[str, typing.Any]]:
789
823
RuntimeError
790
824
if there was a failure when retrieving information from the server
791
825
"""
792
- if not (_folders := self .get_folders (filters = [f"path == { folder_id } " ])):
826
+ if not (_folders := self .get_folders (filters = [f"path == { folder_path } " ])):
793
827
return None
794
828
return _folders [0 ]
795
829
830
+ @pydantic .validate_call
796
831
def get_folders (
797
832
self ,
798
833
filters : typing .Optional [list [str ]] = None ,
799
- count : int = 100 ,
800
- start_index : int = 0 ,
834
+ count : pydantic . PositiveInt = 100 ,
835
+ start_index : pydantic . NonNegativeInt = 0 ,
801
836
) -> list [dict [str , typing .Any ]]:
802
837
"""Retrieve folders from the server
803
838
@@ -847,6 +882,8 @@ def get_folders(
847
882
848
883
return data
849
884
885
+ @prettify_pydantic
886
+ @pydantic .validate_call
850
887
def get_metrics_names (self , run_id : str ) -> list [str ]:
851
888
"""Return information on all metrics within a run
852
889
@@ -918,6 +955,8 @@ def _get_run_metrics_from_server(
918
955
919
956
return json_response
920
957
958
+ @prettify_pydantic
959
+ @pydantic .validate_call
921
960
def get_metric_values (
922
961
self ,
923
962
metric_names : list [str ],
@@ -927,8 +966,8 @@ def get_metric_values(
927
966
run_filters : typing .Optional [list [str ]] = None ,
928
967
use_run_names : bool = False ,
929
968
aggregate : bool = False ,
930
- max_points : int = - 1 ,
931
- ) -> typing .Union [dict , " DataFrame" , None ]:
969
+ max_points : typing . Optional [ pydantic . PositiveInt ] = None ,
970
+ ) -> typing .Union [dict , DataFrame , None ]:
932
971
"""Retrieve the values for a given metric across multiple runs
933
972
934
973
Uses filters to specify which runs should be retrieved.
@@ -955,7 +994,7 @@ def get_metric_values(
955
994
return results as averages (not compatible with xaxis=timestamp),
956
995
default is False
957
996
max_points : int, optional
958
- maximum number of data points, by default -1 (all)
997
+ maximum number of data points, by default None (all)
959
998
960
999
Returns
961
1000
-------
@@ -1010,7 +1049,7 @@ def get_metric_values(
1010
1049
run_ids = run_ids ,
1011
1050
xaxis = xaxis ,
1012
1051
aggregate = aggregate ,
1013
- max_points = max_points ,
1052
+ max_points = max_points or - 1 ,
1014
1053
)
1015
1054
1016
1055
if aggregate :
@@ -1023,13 +1062,15 @@ def get_metric_values(
1023
1062
)
1024
1063
1025
1064
@check_extra ("plot" )
1065
+ @prettify_pydantic
1066
+ @pydantic .validate_call
1026
1067
def plot_metrics (
1027
1068
self ,
1028
1069
run_ids : list [str ],
1029
1070
metric_names : list [str ],
1030
1071
xaxis : typing .Literal ["step" , "time" ],
1031
- max_points : int = - 1 ,
1032
- ) -> "Figure" :
1072
+ max_points : typing . Optional [ int ] = None ,
1073
+ ) -> typing . Any :
1033
1074
"""Plt the time series values for multiple metrics/runs
1034
1075
1035
1076
Parameters
@@ -1041,7 +1082,7 @@ def plot_metrics(
1041
1082
xaxis : str, ('step' | 'time' | 'timestep')
1042
1083
the x axis to plot against
1043
1084
max_points : int, optional
1044
- maximum number of data points, by default -1 (all)
1085
+ maximum number of data points, by default None (all)
1045
1086
1046
1087
Returns
1047
1088
-------
@@ -1059,7 +1100,7 @@ def plot_metrics(
1059
1100
if not isinstance (metric_names , list ):
1060
1101
raise ValueError ("Invalid names specified, must be a list of metric names." )
1061
1102
1062
- data : " DataFrame" = self .get_metric_values ( # type: ignore
1103
+ data : DataFrame = self .get_metric_values ( # type: ignore
1063
1104
run_ids = run_ids ,
1064
1105
metric_names = metric_names ,
1065
1106
xaxis = xaxis ,
@@ -1099,12 +1140,14 @@ def plot_metrics(
1099
1140
1100
1141
return plt .figure ()
1101
1142
1143
+ @prettify_pydantic
1144
+ @pydantic .validate_call
1102
1145
def get_events (
1103
1146
self ,
1104
1147
run_id : str ,
1105
1148
message_contains : typing .Optional [str ] = None ,
1106
- start_index : typing .Optional [int ] = None ,
1107
- count_limit : typing .Optional [int ] = None ,
1149
+ start_index : typing .Optional [pydantic . NonNegativeInt ] = None ,
1150
+ count_limit : typing .Optional [pydantic . PositiveInt ] = None ,
1108
1151
) -> list [dict [str , str ]]:
1109
1152
"""Return events for a specified run
1110
1153
@@ -1160,6 +1203,8 @@ def get_events(
1160
1203
1161
1204
return response .json ().get ("data" , [])
1162
1205
1206
+ @prettify_pydantic
1207
+ @pydantic .validate_call
1163
1208
def get_alerts (
1164
1209
self , run_id : str , critical_only : bool = True , names_only : bool = True
1165
1210
) -> list [dict [str , typing .Any ]]:
0 commit comments