Skip to content

Commit bcb23d8

Browse files
authored
Merge pull request #617 from Labelbox/ms/assign-data-splits
assign data row split
2 parents f12bb3f + 6461e2c commit bcb23d8

File tree

5 files changed

+119
-19
lines changed

5 files changed

+119
-19
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ client = Client( endpoint = "<local deployment>")
8585
client = Client(api_key=os.environ['LABELBOX_TEST_API_KEY_LOCAL'], endpoint="http://localhost:8080/graphql")
8686
8787
# Staging
88-
client = Client(api_key=os.environ['LABELBOX_TEST_API_KEY_LOCAL'], endpoint="https://staging-api.labelbox.com/graphql")
88+
client = Client(api_key=os.environ['LABELBOX_TEST_API_KEY_LOCAL'], endpoint="https://api.lb-stage.xyz/graphql")
8989
```
9090

9191
## Contribution
@@ -122,5 +122,5 @@ make test-prod # with an optional flag: PATH_TO_TEST=tests/integration/...etc LA
122122
make -B {build|test-staging|test-prod}
123123
```
124124

125-
6. Testing against Delegated Access will be skipped unless the local env contains the key:
126-
DA_GCP_LABELBOX_API_KEY. These tests will be included when run against a PR. If you would like to test it manually, please reach out to the Devops team for information on the key.
125+
6. Testing against Delegated Access will be skipped unless the local env contains the key:
126+
DA_GCP_LABELBOX_API_KEY. These tests will be included when run against a PR. If you would like to run them manually, please reach out to the DevOps team for information on the key.

labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from labelbox.schema.role import Role, ProjectRole
2222
from labelbox.schema.invite import Invite, InviteLimit
2323
from labelbox.schema.data_row_metadata import DataRowMetadataOntology
24-
from labelbox.schema.model_run import ModelRun
24+
from labelbox.schema.model_run import ModelRun, DataSplit
2525
from labelbox.schema.benchmark import Benchmark
2626
from labelbox.schema.iam_integration import IAMIntegration
2727
from labelbox.schema.resource_tag import ResourceTag

labelbox/schema/model_run.py

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
# type: ignore
12
from typing import TYPE_CHECKING, Dict, Iterable, Union, List, Optional, Any
23
from pathlib import Path
34
import os
45
import time
56
import logging
67
import requests
78
import ndjson
9+
from enum import Enum
810

911
from labelbox.pagination import PaginatedCollection
1012
from labelbox.orm.query import results_query_part
@@ -17,13 +19,27 @@
1719
logger = logging.getLogger(__name__)
1820

1921

22+
class DataSplit(Enum):
    """Data split a model run data row can belong to.

    Each member's value is the same string as its name.
    """
    TRAINING = "TRAINING"
    TEST = "TEST"
    VALIDATION = "VALIDATION"
    # Rejected with a ValueError by `ModelRun.assign_data_rows_to_split`.
    UNASSIGNED = "UNASSIGNED"
27+
28+
2029
class ModelRun(DbObject):
2130
name = Field.String("name")
2231
updated_at = Field.DateTime("updated_at")
2332
created_at = Field.DateTime("created_at")
2433
created_by_id = Field.String("created_by_id", "createdBy")
2534
model_id = Field.String("model_id")
2635

36+
class Status(Enum):
    """Status values accepted by `ModelRun.update_status`.

    Each member's value is the same string as its name.
    """
    EXPORTING_DATA = "EXPORTING_DATA"
    PREPARING_DATA = "PREPARING_DATA"
    TRAINING_MODEL = "TRAINING_MODEL"
    COMPLETE = "COMPLETE"
    FAILED = "FAILED"
42+
2743
def upsert_labels(self, label_ids, timeout_seconds=60):
2844
""" Adds data rows and labels to a model run
2945
Args:
@@ -90,8 +106,9 @@ def upsert_data_rows(self, data_row_ids, timeout_seconds=60):
90106
}})['MEADataRowRegistrationTaskStatus'],
91107
timeout_seconds=timeout_seconds)
92108

93-
def _wait_until_done(self, status_fn, timeout_seconds=60, sleep_time=5):
109+
def _wait_until_done(self, status_fn, timeout_seconds=120, sleep_time=5):
94110
# Do not use this function outside of the scope of upsert_data_rows or upsert_labels. It could change.
111+
original_timeout = timeout_seconds
95112
while True:
96113
res = status_fn()
97114
if res['status'] == 'COMPLETE':
@@ -102,9 +119,8 @@ def _wait_until_done(self, status_fn, timeout_seconds=60, sleep_time=5):
102119
timeout_seconds -= sleep_time
103120
if timeout_seconds <= 0:
104121
raise TimeoutError(
105-
f"Unable to complete import within {timeout_seconds} seconds."
122+
f"Unable to complete import within {original_timeout} seconds."
106123
)
107-
108124
time.sleep(sleep_time)
109125

110126
def add_predictions(
@@ -161,7 +177,7 @@ def delete(self):
161177
deleteModelRuns(where: {ids: [$%s]})}""" % (ids_param, ids_param)
162178
self.client.execute(query_str, {ids_param: str(self.uid)})
163179

164-
def delete_model_run_data_rows(self, data_row_ids):
180+
def delete_model_run_data_rows(self, data_row_ids: List[str]):
165181
""" Deletes data rows from model runs.
166182
167183
Args:
@@ -180,22 +196,62 @@ def delete_model_run_data_rows(self, data_row_ids):
180196
data_row_ids_param: data_row_ids
181197
})
182198

199+
@experimental
def assign_data_rows_to_split(self,
                              data_row_ids: List[str],
                              split: Union[DataSplit, str],
                              timeout_seconds=120):
    """ Assigns data rows of this model run to a data split.

    Args:
        data_row_ids (list): ids of the data rows to assign.
        split (DataSplit or str): target split, either a `DataSplit`
            member or its string value. `UNASSIGNED` is not accepted.
        timeout_seconds (float): passed to `_wait_until_done` as the
            limit for polling the assignment task status.
    Returns:
        Whatever `_wait_until_done` returns for the assignment task.
    Raises:
        ValueError: if `split` is `UNASSIGNED` or not a valid split.
        TimeoutError: raised by `_wait_until_done` when the task does
            not complete within `timeout_seconds`.
    """
    split_value = split.value if isinstance(split, DataSplit) else split

    if split_value == DataSplit.UNASSIGNED.value:
        raise ValueError(
            f"Cannot assign split value of `{DataSplit.UNASSIGNED.value}`.")

    # Build a concrete list: a lazy `filter` object would be exhausted by
    # the membership test and would render as `<filter object ...>` in the
    # error message instead of listing the allowed values.
    valid_splits = [
        member.value
        for member in DataSplit
        if member is not DataSplit.UNASSIGNED
    ]

    if split_value not in valid_splits:
        raise ValueError(
            f"`split` must be one of : `{valid_splits}`. Found : `{split}`")

    task_id = self.client.execute(
        """mutation assignDataSplitPyApi($modelRunId: ID!, $data: CreateAssignDataRowsToDataSplitTaskInput!){
              createAssignDataRowsToDataSplitTask(modelRun : {id: $modelRunId}, data: $data)}
        """, {
            'modelRunId': self.uid,
            'data': {
                'assignments': [{
                    'split': split_value,
                    'dataRowIds': data_row_ids
                }]
            }
        },
        experimental=True)['createAssignDataRowsToDataSplitTask']

    status_query_str = """query assignDataRowsToDataSplitTaskStatusPyApi($id: ID!){
        assignDataRowsToDataSplitTaskStatus(where: {id : $id}){status errorMessage}}
        """

    # Poll the task status until the helper reports completion or times out.
    return self._wait_until_done(lambda: self.client.execute(
        status_query_str, {'id': task_id}, experimental=True)[
            'assignDataRowsToDataSplitTaskStatus'],
                                 timeout_seconds=timeout_seconds)
240+
183241
@experimental
184242
def update_status(self,
185-
status: str,
243+
status: Union[str, "ModelRun.Status"],
186244
metadata: Optional[Dict[str, str]] = None,
187245
error_message: Optional[str] = None):
188246

189-
valid_statuses = [
190-
"EXPORTING_DATA", "PREPARING_DATA", "TRAINING_MODEL", "COMPLETE",
191-
"FAILED"
192-
]
193-
if status not in valid_statuses:
247+
status_value = status.value if isinstance(status,
248+
ModelRun.Status) else status
249+
if status_value not in ModelRun.Status._member_names_:
194250
raise ValueError(
195-
f"Status must be one of : `{valid_statuses}`. Found : `{status}`"
251+
f"Status must be one of : `{ModelRun.Status._member_names_}`. Found : `{status_value}`"
196252
)
197253

198-
data: Dict[str, Any] = {'status': status}
254+
data: Dict[str, Any] = {'status': status_value}
199255
if error_message:
200256
data['errorMessage'] = error_message
201257

@@ -264,6 +320,7 @@ def export_labels(
264320
class ModelRunDataRow(DbObject):
265321
label_id = Field.String("label_id")
266322
model_run_id = Field.String("model_run_id")
323+
data_split = Field.Enum(DataSplit, "data_split")
267324
data_row = Relationship.ToOne("DataRow", False, cache=True)
268325

269326
def __init__(self, client, model_id, *args, **kwargs):

tests/integration/annotation_import/test_model_run.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import os
33
import pytest
44

5+
from collections import Counter
6+
from labelbox import DataSplit, ModelRun
7+
58

69
def test_model_run(client, configured_project_with_label, rand_gen):
710
project, _, _, label = configured_project_with_label
@@ -119,3 +122,40 @@ def get_model_run_status():
119122
assert model_run_status['status'] == status
120123
assert model_run_status['metadata'] == {**metadata, **extra_metadata}
121124
assert model_run_status['errorMessage'] == errorMessage
125+
126+
status = ModelRun.Status.FAILED
127+
model_run_with_model_run_data_rows.update_status(status, metadata,
128+
errorMessage)
129+
model_run_status = get_model_run_status()
130+
assert model_run_status['status'] == status.value
131+
132+
with pytest.raises(ValueError):
133+
model_run_with_model_run_data_rows.update_status(
134+
"INVALID", metadata, errorMessage)
135+
136+
137+
def test_model_run_split_assignment(model_run, dataset, image_url):
    """Check that data rows can be assigned to every assignable split.

    Invalid split names and `DataSplit.UNASSIGNED` must raise ValueError.
    Every other split is exercised both as a plain string and as a
    `DataSplit` member, and all data rows must report that split afterwards.
    """
    n_data_rows = 10
    creation_task = dataset.create_data_rows([{
        "row_data": image_url
    } for _ in range(n_data_rows)])
    data_row_ids = [row['id'] for row in creation_task.result]

    model_run.upsert_data_rows(data_row_ids)

    with pytest.raises(ValueError):
        model_run.assign_data_rows_to_split(data_row_ids, "INVALID SPLIT")

    with pytest.raises(ValueError):
        model_run.assign_data_rows_to_split(data_row_ids, DataSplit.UNASSIGNED)

    for split in ["TRAINING", "TEST", "VALIDATION", *DataSplit]:
        if split == DataSplit.UNASSIGNED:
            continue

        model_run.assign_data_rows_to_split(data_row_ids, split)
        counts = Counter(row.data_split.value
                         for row in model_run.model_run_data_rows())
        # Compare on the string value without rebinding the loop variable.
        expected_split = split.value if isinstance(split, DataSplit) else split
        assert counts[expected_split] == n_data_rows

tests/integration/conftest.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,10 @@ def client(environ: str):
145145

146146
@pytest.fixture(scope="session")
147147
def image_url(client):
148-
return client.upload_data(requests.get(IMG_URL).content, sign=True)
148+
return client.upload_data(requests.get(IMG_URL).content,
149+
content_type="application/json",
150+
filename="json_import.json",
151+
sign=True)
149152

150153

151154
@pytest.fixture
@@ -181,7 +184,7 @@ def iframe_url(environ) -> str:
181184
if environ in [Environ.PROD, Environ.LOCAL]:
182185
return 'https://editor.labelbox.com'
183186
elif environ == Environ.STAGING:
184-
return 'https://staging.labelbox.dev/editor'
187+
return 'https://editor.lb-stage.xyz'
185188

186189

187190
@pytest.fixture
@@ -290,7 +293,7 @@ def configured_project_with_label(client, rand_gen, image_url, project, dataset,
290293

291294
def create_label():
292295
""" Ad-hoc function to create a LabelImport
293-
296+
294297
Creates a LabelImport task which will create a label
295298
"""
296299
upload_task = LabelImport.create_from_objects(

0 commit comments

Comments
 (0)