Skip to content

Commit 7339d55

Browse files
authored
[SDK-17] Optimize sdk tests via reducing fixture times (#1211)
2 parents becff66 + e585e8c commit 7339d55

17 files changed

+347
-186
lines changed

.github/workflows/python-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ jobs:
3939
echo "LABELBOX_TEST_ENVIRON=prod" >> $GITHUB_ENV
4040
else
4141
echo "LABELBOX_TEST_ENVIRON=staging" >> $GITHUB_ENV
42+
echo "FIXTURE_PROFILE=true" >> $GITHUB_ENV
4243
fi
4344
4445
- uses: actions/checkout@v2

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ test-local: build-image
1313
-e LABELBOX_TEST_ENVIRON="local" \
1414
-e DA_GCP_LABELBOX_API_KEY=${DA_GCP_LABELBOX_API_KEY} \
1515
-e LABELBOX_TEST_API_KEY_LOCAL=${LABELBOX_TEST_API_KEY_LOCAL} \
16+
-e FIXTURE_PROFILE=true \
1617
local/labelbox-python:test pytest $(PATH_TO_TEST)
1718

1819
test-staging: build-image

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
]
1313

1414

15-
@pytest.fixture
15+
@pytest.fixture(scope="session")
1616
def rand_gen():
1717

1818
def gen(field_type):

tests/integration/annotation_import/conftest.py

Lines changed: 78 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Type
1010
from labelbox.schema.labeling_frontend import LabelingFrontend
1111
from labelbox.schema.annotation_import import LabelImport, AnnotationImportState
12+
from labelbox.schema.project import Project
1213
from labelbox.schema.queue_mode import QueueMode
1314

1415
DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS = 40
@@ -210,7 +211,7 @@ def annotations_by_data_type_v2(
210211
}
211212

212213

213-
@pytest.fixture
214+
@pytest.fixture(scope='session')
214215
def ontology():
215216
bbox_tool_with_nested_text = {
216217
'required':
@@ -478,34 +479,49 @@ def func(project):
478479

479480

480481
@pytest.fixture
481-
def initial_dataset(client, rand_gen):
482-
dataset = client.create_dataset(name=rand_gen(str))
483-
yield dataset
484-
dataset.delete()
482+
def configured_project_datarow_id(configured_project):
483+
484+
def get_data_row_id(indx=0):
485+
return configured_project.data_row_ids[indx]
486+
487+
yield get_data_row_id
488+
489+
490+
@pytest.fixture
491+
def configured_project_one_datarow_id(configured_project_with_one_data_row):
492+
493+
def get_data_row_id(indx=0):
494+
return configured_project_with_one_data_row.data_row_ids[0]
495+
496+
yield get_data_row_id
485497

486498

487499
@pytest.fixture
488500
def configured_project(client, initial_dataset, ontology, rand_gen, image_url):
489501
dataset = initial_dataset
490-
project = client.create_project(
491-
name=rand_gen(str),
492-
queue_mode=QueueMode.Batch,
493-
)
502+
project = client.create_project(name=rand_gen(str),
503+
queue_mode=QueueMode.Batch)
494504
editor = list(
495505
client.get_labeling_frontends(
496506
where=LabelingFrontend.name == "editor"))[0]
497507
project.setup(editor, ontology)
508+
498509
data_row_ids = []
499510

500511
for _ in range(len(ontology['tools']) + len(ontology['classifications'])):
501512
data_row_ids.append(dataset.create_data_row(row_data=image_url).uid)
502-
project.create_batch(
513+
project._wait_until_data_rows_are_processed(data_row_ids=data_row_ids,
514+
sleep_interval=3)
515+
516+
batch = project.create_batch(
503517
rand_gen(str),
504518
data_row_ids, # sample of data row objects
505519
5 # priority between 1(Highest) - 5(lowest)
506520
)
507521
project.data_row_ids = data_row_ids
522+
508523
yield project
524+
509525
project.delete()
510526

511527

@@ -556,27 +572,74 @@ def dataset_conversation_entity(client, rand_gen, conversation_entity_data_row,
556572

557573

558574
@pytest.fixture
559-
def configured_project_without_data_rows(client, ontology, rand_gen):
575+
def configured_project_with_one_data_row(client, ontology, rand_gen,
576+
initial_dataset, image_url):
560577
project = client.create_project(name=rand_gen(str),
561578
description=rand_gen(str),
562579
queue_mode=QueueMode.Batch)
563580
editor = list(
564581
client.get_labeling_frontends(
565582
where=LabelingFrontend.name == "editor"))[0]
566583
project.setup(editor, ontology)
584+
585+
data_row = initial_dataset.create_data_row(row_data=image_url)
586+
data_row_ids = [data_row.uid]
587+
project._wait_until_data_rows_are_processed(data_row_ids=data_row_ids,
588+
sleep_interval=3)
589+
590+
batch = project.create_batch(
591+
rand_gen(str),
592+
data_row_ids, # sample of data row objects
593+
5 # priority between 1(Highest) - 5(lowest)
594+
)
595+
project.data_row_ids = data_row_ids
596+
567597
yield project
598+
599+
batch.delete()
568600
project.delete()
569601

570602

571603
# This function allows converting an ontology feature to an actual annotation
572604
# At the moment it expects only one feature per tool type and this creates unnecessary coupling between different tests
573605
# As an example, for the 'rectangle' tool we have extended it to support multiple instances of the same tool type
574606
# TODO: we will support this approach in the future for all tools
607+
#
608+
"""
609+
Please note that this fixture now offers the flexibility to configure three different strategies for generating data row ids for predictions:
610+
Default(configured_project fixture):
611+
configured_project that generates a data row for each member of ontology.
612+
This makes sure each prediction has its own data row id. This is applicable to prediction upload cases when last label overwrites existing ones
613+
614+
Optimized Strategy (configured_project_with_one_data_row fixture):
615+
This fixture has only one data row and all predictions will be mapped to it
616+
617+
Custom Data Row IDs Strategy:
618+
Individuals can supply hard-coded data row ids when creation of a data row is not required.
619+
This particular fixture, termed "hardcoded_datarow_id," should be defined locally within a test file.
620+
In the future, we can use this approach to inject the correct number of rows instead of using the configured_project fixture
621+
that creates a data row for each member of ontology (14 in total) for each run.
622+
"""
623+
624+
575625
@pytest.fixture
576-
def prediction_id_mapping(configured_project):
626+
def prediction_id_mapping(ontology, request):
577627
# Maps tool types to feature schema ids
578-
project = configured_project
628+
if 'configured_project' in request.fixturenames:
629+
data_row_id_factory = request.getfixturevalue(
630+
'configured_project_datarow_id')
631+
project = request.getfixturevalue('configured_project')
632+
elif 'hardcoded_datarow_id' in request.fixturenames:
633+
data_row_id_factory = request.getfixturevalue('hardcoded_datarow_id')
634+
project = request.getfixturevalue('configured_project_with_ontology')
635+
else:
636+
data_row_id_factory = request.getfixturevalue(
637+
'configured_project_one_datarow_id')
638+
project = request.getfixturevalue(
639+
'configured_project_with_one_data_row')
640+
579641
ontology = project.ontology().normalized
642+
580643
result = {}
581644

582645
for idx, tool in enumerate(ontology['tools'] + ontology['classifications']):
@@ -593,7 +656,7 @@ def prediction_id_mapping(configured_project):
593656
"schemaId": tool['featureSchemaId'],
594657
"name": tool['name'],
595658
"dataRow": {
596-
"id": project.data_row_ids[idx],
659+
"id": data_row_id_factory(idx),
597660
},
598661
'tool': tool
599662
}
@@ -606,7 +669,7 @@ def prediction_id_mapping(configured_project):
606669
"schemaId": tool['featureSchemaId'],
607670
"name": tool['name'],
608671
"dataRow": {
609-
"id": project.data_row_ids[idx],
672+
"id": data_row_id_factory(idx),
610673
},
611674
'tool': tool
612675
}

tests/integration/annotation_import/test_bulk_import_request.py

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,40 +25,40 @@
2525
"""
2626

2727

28-
def test_create_from_url(configured_project):
28+
def test_create_from_url(project):
2929
name = str(uuid.uuid4())
3030
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
3131

32-
bulk_import_request = configured_project.upload_annotations(name=name,
33-
annotations=url,
34-
validate=False)
32+
bulk_import_request = project.upload_annotations(name=name,
33+
annotations=url,
34+
validate=False)
3535

36-
assert bulk_import_request.project() == configured_project
36+
assert bulk_import_request.project() == project
3737
assert bulk_import_request.name == name
3838
assert bulk_import_request.input_file_url == url
3939
assert bulk_import_request.error_file_url is None
4040
assert bulk_import_request.status_file_url is None
4141
assert bulk_import_request.state == BulkImportRequestState.RUNNING
4242

4343

44-
def test_validate_file(configured_project):
44+
def test_validate_file(project_with_empty_ontology):
4545
name = str(uuid.uuid4())
4646
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
4747
with pytest.raises(MALValidationError):
48-
configured_project.upload_annotations(name=name,
49-
annotations=url,
50-
validate=True)
48+
project_with_empty_ontology.upload_annotations(name=name,
49+
annotations=url,
50+
validate=True)
5151
#Schema ids shouldn't match
5252

5353

54-
def test_create_from_objects(configured_project, predictions,
54+
def test_create_from_objects(configured_project_with_one_data_row, predictions,
5555
annotation_import_test_helpers):
5656
name = str(uuid.uuid4())
5757

58-
bulk_import_request = configured_project.upload_annotations(
58+
bulk_import_request = configured_project_with_one_data_row.upload_annotations(
5959
name=name, annotations=predictions)
6060

61-
assert bulk_import_request.project() == configured_project
61+
assert bulk_import_request.project() == configured_project_with_one_data_row
6262
assert bulk_import_request.name == name
6363
assert bulk_import_request.error_file_url is None
6464
assert bulk_import_request.status_file_url is None
@@ -105,34 +105,33 @@ def test_create_from_local_file(tmp_path, predictions, configured_project,
105105
bulk_import_request.input_file_url, predictions)
106106

107107

108-
def test_get(client, configured_project):
108+
def test_get(client, configured_project_with_one_data_row):
109109
name = str(uuid.uuid4())
110110
url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
111-
configured_project.upload_annotations(name=name,
112-
annotations=url,
113-
validate=False)
111+
configured_project_with_one_data_row.upload_annotations(name=name,
112+
annotations=url,
113+
validate=False)
114114

115115
bulk_import_request = BulkImportRequest.from_name(
116-
client, project_id=configured_project.uid, name=name)
116+
client, project_id=configured_project_with_one_data_row.uid, name=name)
117117

118-
assert bulk_import_request.project() == configured_project
118+
assert bulk_import_request.project() == configured_project_with_one_data_row
119119
assert bulk_import_request.name == name
120120
assert bulk_import_request.input_file_url == url
121121
assert bulk_import_request.error_file_url is None
122122
assert bulk_import_request.status_file_url is None
123123
assert bulk_import_request.state == BulkImportRequestState.RUNNING
124124

125125

126-
def test_validate_ndjson(tmp_path, configured_project):
126+
def test_validate_ndjson(tmp_path, configured_project_with_one_data_row):
127127
file_name = f"broken.ndjson"
128128
file_path = tmp_path / file_name
129129
with file_path.open("w") as f:
130130
f.write("test")
131131

132132
with pytest.raises(ValueError):
133-
configured_project.upload_annotations(name="name",
134-
validate=True,
135-
annotations=str(file_path))
133+
configured_project_with_one_data_row.upload_annotations(
134+
name="name", validate=True, annotations=str(file_path))
136135

137136

138137
def test_validate_ndjson_uuid(tmp_path, configured_project, predictions):
@@ -158,14 +157,13 @@ def test_validate_ndjson_uuid(tmp_path, configured_project, predictions):
158157

159158

160159
@pytest.mark.slow
161-
def test_wait_till_done(rectangle_inference, configured_project):
160+
def test_wait_till_done(rectangle_inference,
161+
configured_project_with_one_data_row):
162162
name = str(uuid.uuid4())
163-
url = configured_project.client.upload_data(content=parser.dumps(
164-
[rectangle_inference]),
165-
sign=True)
166-
bulk_import_request = configured_project.upload_annotations(name=name,
167-
annotations=url,
168-
validate=False)
163+
url = configured_project_with_one_data_row.client.upload_data(
164+
content=parser.dumps([rectangle_inference]), sign=True)
165+
bulk_import_request = configured_project_with_one_data_row.upload_annotations(
166+
name=name, annotations=url, validate=False)
169167

170168
assert len(bulk_import_request.inputs) == 1
171169
bulk_import_request.wait_until_done()
@@ -299,7 +297,7 @@ def test_pdf_mal_bbox(client, configured_project_pdf):
299297
assert import_annotations.errors == []
300298

301299

302-
def test_pdf_document_entity(client, configured_project_without_data_rows,
300+
def test_pdf_document_entity(client, configured_project_with_one_data_row,
303301
dataset_pdf_entity, rand_gen):
304302
# for content "Metal-insulator (MI) transitions have been one of the" in OCR JSON extract tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483-lb-textlayer.json
305303
document_text_selection = DocumentTextSelection(
@@ -323,7 +321,7 @@ def test_pdf_document_entity(client, configured_project_without_data_rows,
323321

324322
labels = []
325323
_, data_row_uids = dataset_pdf_entity
326-
configured_project_without_data_rows.create_batch(
324+
configured_project_with_one_data_row.create_batch(
327325
rand_gen(str),
328326
data_row_uids, # sample of data row objects
329327
5 # priority between 1(Highest) - 5(lowest)
@@ -338,7 +336,7 @@ def test_pdf_document_entity(client, configured_project_without_data_rows,
338336

339337
import_annotations = MALPredictionImport.create_from_objects(
340338
client=client,
341-
project_id=configured_project_without_data_rows.uid,
339+
project_id=configured_project_with_one_data_row.uid,
342340
name=f"import {str(uuid.uuid4())}",
343341
predictions=labels)
344342
import_annotations.wait_until_done()
@@ -347,14 +345,14 @@ def test_pdf_document_entity(client, configured_project_without_data_rows,
347345

348346

349347
def test_nested_video_object_annotations(client,
350-
configured_project_without_data_rows,
348+
configured_project_with_one_data_row,
351349
video_data,
352350
bbox_video_annotation_objects,
353351
rand_gen):
354352
labels = []
355353
_, data_row_uids = video_data
356-
configured_project_without_data_rows.update(media_type=MediaType.Video)
357-
configured_project_without_data_rows.create_batch(
354+
configured_project_with_one_data_row.update(media_type=MediaType.Video)
355+
configured_project_with_one_data_row.create_batch(
358356
rand_gen(str),
359357
data_row_uids, # sample of data row objects
360358
5 # priority between 1(Highest) - 5(lowest)
@@ -366,7 +364,7 @@ def test_nested_video_object_annotations(client,
366364
annotations=bbox_video_annotation_objects))
367365
import_annotations = MALPredictionImport.create_from_objects(
368366
client=client,
369-
project_id=configured_project_without_data_rows.uid,
367+
project_id=configured_project_with_one_data_row.uid,
370368
name=f"import {str(uuid.uuid4())}",
371369
predictions=labels)
372370
import_annotations.wait_until_done()

tests/integration/annotation_import/test_conversation_import.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from labelbox.schema.annotation_import import MALPredictionImport
88

99

10-
def test_conversation_entity(client, configured_project_without_data_rows,
10+
def test_conversation_entity(client, configured_project_with_one_data_row,
1111
dataset_conversation_entity, rand_gen):
1212

1313
conversation_entity_annotation = ConversationEntity(start=0,
@@ -20,7 +20,7 @@ def test_conversation_entity(client, configured_project_without_data_rows,
2020
labels = []
2121
_, data_row_uids = dataset_conversation_entity
2222

23-
configured_project_without_data_rows.create_batch(
23+
configured_project_with_one_data_row.create_batch(
2424
rand_gen(str),
2525
data_row_uids, # sample of data row objects
2626
5 # priority between 1(Highest) - 5(lowest)
@@ -35,7 +35,7 @@ def test_conversation_entity(client, configured_project_without_data_rows,
3535

3636
import_annotations = MALPredictionImport.create_from_objects(
3737
client=client,
38-
project_id=configured_project_without_data_rows.uid,
38+
project_id=configured_project_with_one_data_row.uid,
3939
name=f"import {str(uuid.uuid4())}",
4040
predictions=labels)
4141

0 commit comments

Comments
 (0)