Skip to content

Commit a6ad5b7

Browse files
jkbmrzklemen1999
andauthored
feat: generate_script_content update (#206)
* feat: generate_script_content remove rvc2/4 differentiation * fix: unittests * fix: unittests 2 * feat: generate_script_content extend resize_mode options * Update depthai_nodes/node/utils/detection_config_generator.py Co-authored-by: KlemenSkrlj <[email protected]> * feat: rename generate_script_content unittests file --------- Co-authored-by: KlemenSkrlj <[email protected]>
1 parent ce751eb commit a6ad5b7

File tree

2 files changed

+19
-79
lines changed

2 files changed

+19
-79
lines changed

depthai_nodes/node/utils/detection_config_generator.py

+10-16
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33

44
def generate_script_content(
5-
platform: str,
65
resize_width: int,
76
resize_height: int,
7+
resize_mode: str = "NONE",
88
padding: float = 0,
99
valid_labels: Optional[List[int]] = None,
1010
) -> str:
@@ -14,12 +14,13 @@ def generate_script_content(
1414
also work with padding around the detection bounding box and filter detections by
1515
labels.
1616
17-
@param platform: Target platform for the script. Supported values: 'rvc2', 'rvc4'
18-
@type platform: str
1917
@param resize_width: Target width for the resized image
2018
@type resize_width: int
2119
@param resize_height: Target height for the resized image
2220
@type resize_height: int
21+
@param resize_mode: Resize mode for the image. Supported values: "CENTER_CROP",
22+
"LETTERBOX", "NONE", "STRETCH". Default: "NONE"
23+
@type resize_mode: str
2324
@param padding: Additional padding around the detection in normalized coordinates
2425
(0-1)
2526
@type padding: float
@@ -30,15 +31,10 @@ def generate_script_content(
3031
@rtype: str
3132
"""
3233

33-
if platform.lower() == "rvc2":
34-
cfg_content = f"""
35-
cfg = ImageManipConfig()
36-
cfg.setCropRect(det.xmin - {padding}, det.ymin - {padding}, det.xmax + {padding}, det.ymax + {padding})
37-
cfg.setResize({resize_width}, {resize_height})
38-
cfg.setKeepAspectRatio(False)
39-
"""
40-
elif platform.lower() == "rvc4":
41-
cfg_content = f"""
34+
if resize_mode not in ["CENTER_CROP", "LETTERBOX", "NONE", "STRETCH"]:
35+
raise ValueError("Unsupported resize mode")
36+
37+
cfg_content = f"""
4238
cfg = ImageManipConfigV2()
4339
rect = RotatedRect()
4440
rect.center.x = (det.xmin + det.xmax) / 2
@@ -49,11 +45,9 @@ def generate_script_content(
4945
rect.size.height = rect.size.height + {padding} * 2
5046
rect.angle = 0
5147
52-
cfg.addCropRotatedRect(rect=rect, normalizedCoords=True)
53-
cfg.setOutputSize({resize_width}, {resize_height})
48+
cfg.addCropRotatedRect(rect, normalizedCoords=True)
49+
cfg.setOutputSize({resize_width}, {resize_height}, ImageManipConfigV2.ResizeMode.{resize_mode})
5450
"""
55-
else:
56-
raise ValueError("Unsupported platform")
5751
validate_label = (
5852
f"""
5953
if det.label not in {valid_labels}:

tests/unittests/test_nodes/test_host_nodes/test_detection_config_generator.py renamed to tests/unittests/test_nodes/test_host_nodes/test_generate_script_content.py

+9-63
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, Union
1+
from typing import List, Optional
22

33
import depthai as dai
44
import pytest
@@ -16,11 +16,6 @@ def resize_height():
1616
return 256
1717

1818

19-
def test_rvc3_unsupported(resize_width, resize_height):
20-
with pytest.raises(ValueError, match="Unsupported"):
21-
generate_script_content("rvc3", resize_width, resize_height)
22-
23-
2419
class ImageManipConfigV2(dai.ImageManipConfigV2):
2520
def __init__(self):
2621
super().__init__()
@@ -168,16 +163,14 @@ def node_input_detections(node) -> List[dai.ImgDetections]:
168163
return node.inputs[Node.INPUT_DETECTIONS_KEY].items
169164

170165

171-
@pytest.mark.parametrize("platform", ["rvc2", "rvc4"])
172166
def test_passthrough(
173167
node,
174168
node_input_detections,
175169
node_input_frames,
176-
platform,
177170
resize_width,
178171
resize_height,
179172
):
180-
script = generate_script_content(platform, resize_width, resize_height)
173+
script = generate_script_content(resize_width, resize_height)
181174
expected_frames = []
182175
for frame, detections in zip(node_input_frames, node_input_detections):
183176
for _ in detections.detections:
@@ -191,12 +184,11 @@ def test_passthrough(
191184
assert len(get_output_config(node)) == len(expected_frames)
192185

193186

194-
@pytest.mark.parametrize(("platform", "labels"), [("rvc2", [1]), ("rvc4", [1, 2])])
187+
@pytest.mark.parametrize("labels", [[1], [1, 2]])
195188
def test_label_validation(
196189
node,
197190
node_input_detections,
198191
node_input_frames,
199-
platform,
200192
labels,
201193
resize_width,
202194
resize_height,
@@ -207,30 +199,16 @@ def test_label_validation(
207199
if detection.label not in labels:
208200
continue
209201
expected_frames.append(frame)
210-
script = generate_script_content(
211-
platform, resize_width, resize_height, valid_labels=labels
212-
)
202+
script = generate_script_content(resize_width, resize_height, valid_labels=labels)
213203
try:
214204
run_script(node, script)
215205
except Warning:
216206
assert expected_frames == get_output_frames(node)
217207

218208

219209
@pytest.mark.parametrize("resize", [(128, 128), (128, 256), (256, 256)])
220-
def test_rvc2_output_size(node, resize):
221-
script = generate_script_content("rvc2", *resize)
222-
try:
223-
run_script(node, script)
224-
except Warning:
225-
output_cfg = get_output_config(node)
226-
for cfg in output_cfg:
227-
assert isinstance(cfg, dai.ImageManipConfig)
228-
assert cfg.getResizeWidth(), cfg.getResizeHeight() == resize
229-
230-
231-
@pytest.mark.parametrize("resize", [(128, 128), (128, 256), (256, 256)])
232-
def test_rvc4_output_size(node, resize):
233-
script = generate_script_content("rvc4", *resize)
210+
def test_output_size(node, resize):
211+
script = generate_script_content(*resize)
234212
try:
235213
run_script(node, script)
236214
except Warning:
@@ -241,36 +219,7 @@ def test_rvc4_output_size(node, resize):
241219

242220

243221
@pytest.mark.parametrize("padding", [0, 0.1, 0.2, -0.1, -0.2])
244-
def test_rvc2_crop(node, node_input_detections, padding, resize_width, resize_height):
245-
expected_rects: List[dai.ImageManipConfig.CropRect] = []
246-
for input_dets in node_input_detections:
247-
for detection in input_dets.detections:
248-
rect = dai.ImageManipConfig.CropRect()
249-
rect.xmin = max(detection.xmin - padding, 0)
250-
rect.xmax = min(detection.xmax + padding, 1)
251-
rect.ymin = max(detection.ymin - padding, 0)
252-
rect.ymax = min(detection.ymax + padding, 1)
253-
expected_rects.append(rect)
254-
script = generate_script_content(
255-
"rvc2", resize_width, resize_height, padding=padding
256-
)
257-
try:
258-
run_script(node, script)
259-
except Warning:
260-
output_cfg = get_output_config(node)
261-
for cfg, expected_rect in zip(output_cfg, expected_rects):
262-
assert isinstance(cfg, dai.ImageManipConfig)
263-
crop_rect = cfg.getCropConfig().cropRect
264-
assert (crop_rect.xmin, crop_rect.xmax, crop_rect.ymin, crop_rect.ymax) == (
265-
expected_rect.xmin,
266-
expected_rect.xmax,
267-
expected_rect.ymin,
268-
expected_rect.ymax,
269-
)
270-
271-
272-
@pytest.mark.parametrize("padding", [0, 0.1, 0.2, -0.1, -0.2])
273-
def test_rvc4_crop(node, node_input_detections, padding, resize_width, resize_height):
222+
def test_crop(node, node_input_detections, padding, resize_width, resize_height):
274223
ANGLE = 0
275224
expected_rects: List[dai.RotatedRect] = []
276225
for input_dets in node_input_detections:
@@ -283,9 +232,7 @@ def test_rvc4_crop(node, node_input_detections, padding, resize_width, resize_he
283232
rect.size.width = detection.xmax - detection.xmin + rect_padding
284233
rect.size.height = detection.ymax - detection.ymin + rect_padding
285234
expected_rects.append(rect)
286-
script = generate_script_content(
287-
"rvc4", resize_width, resize_height, padding=padding
288-
)
235+
script = generate_script_content(resize_width, resize_height, padding=padding)
289236
try:
290237
run_script(node, script)
291238
except Warning:
@@ -310,7 +257,6 @@ def run_script(node, script):
310257
{
311258
"node": node,
312259
"ImageManipConfigV2": ImageManipConfigV2,
313-
"ImageManipConfig": dai.ImageManipConfig,
314260
"RotatedRect": dai.RotatedRect,
315261
},
316262
)
@@ -322,5 +268,5 @@ def get_output_frames(node: Node) -> List[Frame]:
322268

323269
def get_output_config(
324270
node: Node,
325-
) -> Union[List[dai.ImageManipConfig], List[ImageManipConfigV2]]:
271+
) -> List[ImageManipConfigV2]:
326272
return node.outputs[Node.OUTPUT_CONFIG_KEY].items

0 commit comments

Comments
 (0)