1
+ # type: ignore
1
2
from typing import TYPE_CHECKING , Dict , Iterable , Union , List , Optional , Any
2
3
from pathlib import Path
3
4
import os
4
5
import time
5
6
import logging
6
7
import requests
7
8
import ndjson
9
+ from enum import Enum
8
10
9
11
from labelbox .pagination import PaginatedCollection
10
12
from labelbox .orm .query import results_query_part
17
19
logger = logging .getLogger (__name__ )
18
20
19
21
22
+ class DataSplit (Enum ):
23
+ TRAINING = "TRAINING"
24
+ TEST = "TEST"
25
+ VALIDATION = "VALIDATION"
26
+ UNASSIGNED = "UNASSIGNED"
27
+
28
+
20
29
class ModelRun (DbObject ):
21
30
name = Field .String ("name" )
22
31
updated_at = Field .DateTime ("updated_at" )
23
32
created_at = Field .DateTime ("created_at" )
24
33
created_by_id = Field .String ("created_by_id" , "createdBy" )
25
34
model_id = Field .String ("model_id" )
26
35
36
+ class Status (Enum ):
37
+ EXPORTING_DATA = "EXPORTING_DATA"
38
+ PREPARING_DATA = "PREPARING_DATA"
39
+ TRAINING_MODEL = "TRAINING_MODEL"
40
+ COMPLETE = "COMPLETE"
41
+ FAILED = "FAILED"
42
+
27
43
def upsert_labels (self , label_ids , timeout_seconds = 60 ):
28
44
""" Adds data rows and labels to a model run
29
45
Args:
@@ -90,8 +106,9 @@ def upsert_data_rows(self, data_row_ids, timeout_seconds=60):
90
106
}})['MEADataRowRegistrationTaskStatus' ],
91
107
timeout_seconds = timeout_seconds )
92
108
93
- def _wait_until_done (self , status_fn , timeout_seconds = 60 , sleep_time = 5 ):
109
+ def _wait_until_done (self , status_fn , timeout_seconds = 120 , sleep_time = 5 ):
94
110
# Do not use this function outside of the scope of upsert_data_rows or upsert_labels. It could change.
111
+ original_timeout = timeout_seconds
95
112
while True :
96
113
res = status_fn ()
97
114
if res ['status' ] == 'COMPLETE' :
@@ -102,9 +119,8 @@ def _wait_until_done(self, status_fn, timeout_seconds=60, sleep_time=5):
102
119
timeout_seconds -= sleep_time
103
120
if timeout_seconds <= 0 :
104
121
raise TimeoutError (
105
- f"Unable to complete import within { timeout_seconds } seconds."
122
+ f"Unable to complete import within { original_timeout } seconds."
106
123
)
107
-
108
124
time .sleep (sleep_time )
109
125
110
126
def add_predictions (
@@ -161,7 +177,7 @@ def delete(self):
161
177
deleteModelRuns(where: {ids: [$%s]})}""" % (ids_param , ids_param )
162
178
self .client .execute (query_str , {ids_param : str (self .uid )})
163
179
164
- def delete_model_run_data_rows (self , data_row_ids ):
180
+ def delete_model_run_data_rows (self , data_row_ids : List [ str ] ):
165
181
""" Deletes data rows from model runs.
166
182
167
183
Args:
@@ -180,22 +196,62 @@ def delete_model_run_data_rows(self, data_row_ids):
180
196
data_row_ids_param : data_row_ids
181
197
})
182
198
199
+ @experimental
200
+ def assign_data_rows_to_split (self ,
201
+ data_row_ids : List [str ],
202
+ split : Union [DataSplit , str ],
203
+ timeout_seconds = 120 ):
204
+
205
+ split_value = split .value if isinstance (split , DataSplit ) else split
206
+
207
+ if split_value == DataSplit .UNASSIGNED .value :
208
+ raise ValueError (
209
+ f"Cannot assign split value of `{ DataSplit .UNASSIGNED .value } `." )
210
+
211
+ valid_splits = filter (lambda name : name != DataSplit .UNASSIGNED .value ,
212
+ DataSplit ._member_names_ )
213
+
214
+ if split_value not in valid_splits :
215
+ raise ValueError (
216
+ f"`split` must be one of : `{ valid_splits } `. Found : `{ split } `" )
217
+
218
+ task_id = self .client .execute (
219
+ """mutation assignDataSplitPyApi($modelRunId: ID!, $data: CreateAssignDataRowsToDataSplitTaskInput!){
220
+ createAssignDataRowsToDataSplitTask(modelRun : {id: $modelRunId}, data: $data)}
221
+ """ , {
222
+ 'modelRunId' : self .uid ,
223
+ 'data' : {
224
+ 'assignments' : [{
225
+ 'split' : split_value ,
226
+ 'dataRowIds' : data_row_ids
227
+ }]
228
+ }
229
+ },
230
+ experimental = True )['createAssignDataRowsToDataSplitTask' ]
231
+
232
+ status_query_str = """query assignDataRowsToDataSplitTaskStatusPyApi($id: ID!){
233
+ assignDataRowsToDataSplitTaskStatus(where: {id : $id}){status errorMessage}}
234
+ """
235
+
236
+ return self ._wait_until_done (lambda : self .client .execute (
237
+ status_query_str , {'id' : task_id }, experimental = True )[
238
+ 'assignDataRowsToDataSplitTaskStatus' ],
239
+ timeout_seconds = timeout_seconds )
240
+
183
241
@experimental
184
242
def update_status (self ,
185
- status : str ,
243
+ status : Union [ str , "ModelRun.Status" ] ,
186
244
metadata : Optional [Dict [str , str ]] = None ,
187
245
error_message : Optional [str ] = None ):
188
246
189
- valid_statuses = [
190
- "EXPORTING_DATA" , "PREPARING_DATA" , "TRAINING_MODEL" , "COMPLETE" ,
191
- "FAILED"
192
- ]
193
- if status not in valid_statuses :
247
+ status_value = status .value if isinstance (status ,
248
+ ModelRun .Status ) else status
249
+ if status_value not in ModelRun .Status ._member_names_ :
194
250
raise ValueError (
195
- f"Status must be one of : `{ valid_statuses } `. Found : `{ status } `"
251
+ f"Status must be one of : `{ ModelRun . Status . _member_names_ } `. Found : `{ status_value } `"
196
252
)
197
253
198
- data : Dict [str , Any ] = {'status' : status }
254
+ data : Dict [str , Any ] = {'status' : status_value }
199
255
if error_message :
200
256
data ['errorMessage' ] = error_message
201
257
@@ -264,6 +320,7 @@ def export_labels(
264
320
class ModelRunDataRow (DbObject ):
265
321
label_id = Field .String ("label_id" )
266
322
model_run_id = Field .String ("model_run_id" )
323
+ data_split = Field .Enum (DataSplit , "data_split" )
267
324
data_row = Relationship .ToOne ("DataRow" , False , cache = True )
268
325
269
326
def __init__ (self , client , model_id , * args , ** kwargs ):
0 commit comments