Skip to content

feat: generate_script_content update #206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 10 additions & 16 deletions depthai_nodes/node/utils/detection_config_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@


def generate_script_content(
platform: str,
resize_width: int,
resize_height: int,
resize_mode: str = "NONE",
padding: float = 0,
valid_labels: Optional[List[int]] = None,
) -> str:
Expand All @@ -14,12 +14,13 @@ def generate_script_content(
also work with padding around the detection bounding box and filter detections by
labels.

@param platform: Target platform for the script. Supported values: 'rvc2', 'rvc4'
@type platform: str
@param resize_width: Target width for the resized image
@type resize_width: int
@param resize_height: Target height for the resized image
@type resize_height: int
@param resize_mode: Resize mode for the image. Supported values: "CENTER_CROP",
"LETTERBOX", "NONE", "STRETCH". Default: "NONE"
@type resize_mode: str
@param padding: Additional padding around the detection in normalized coordinates
(0-1)
@type padding: float
Expand All @@ -30,15 +31,10 @@ def generate_script_content(
@rtype: str
"""

if platform.lower() == "rvc2":
cfg_content = f"""
cfg = ImageManipConfig()
cfg.setCropRect(det.xmin - {padding}, det.ymin - {padding}, det.xmax + {padding}, det.ymax + {padding})
cfg.setResize({resize_width}, {resize_height})
cfg.setKeepAspectRatio(False)
"""
elif platform.lower() == "rvc4":
cfg_content = f"""
if resize_mode not in ["CENTER_CROP", "LETTERBOX", "NONE", "STRETCH"]:
raise ValueError("Unsupported resize mode")

cfg_content = f"""
cfg = ImageManipConfigV2()
rect = RotatedRect()
rect.center.x = (det.xmin + det.xmax) / 2
Expand All @@ -49,11 +45,9 @@ def generate_script_content(
rect.size.height = rect.size.height + {padding} * 2
rect.angle = 0

cfg.addCropRotatedRect(rect=rect, normalizedCoords=True)
cfg.setOutputSize({resize_width}, {resize_height})
cfg.addCropRotatedRect(rect, normalizedCoords=True)
cfg.setOutputSize({resize_width}, {resize_height}, ImageManipConfigV2.ResizeMode.{resize_mode})
"""
else:
raise ValueError("Unsupported platform")
validate_label = (
f"""
if det.label not in {valid_labels}:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import List, Optional

import depthai as dai
import pytest
Expand All @@ -16,11 +16,6 @@ def resize_height():
return 256


def test_rvc3_unsupported(resize_width, resize_height):
    """Check that an unrecognised platform name is rejected.

    Passing "rvc3" as the platform is expected to raise a ValueError whose
    message matches "Unsupported".
    """
    unsupported_platform = "rvc3"
    expect_rejection = pytest.raises(ValueError, match="Unsupported")
    with expect_rejection:
        generate_script_content(unsupported_platform, resize_width, resize_height)


class ImageManipConfigV2(dai.ImageManipConfigV2):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -168,16 +163,14 @@ def node_input_detections(node) -> List[dai.ImgDetections]:
return node.inputs[Node.INPUT_DETECTIONS_KEY].items


@pytest.mark.parametrize("platform", ["rvc2", "rvc4"])
def test_passthrough(
node,
node_input_detections,
node_input_frames,
platform,
resize_width,
resize_height,
):
script = generate_script_content(platform, resize_width, resize_height)
script = generate_script_content(resize_width, resize_height)
expected_frames = []
for frame, detections in zip(node_input_frames, node_input_detections):
for _ in detections.detections:
Expand All @@ -191,12 +184,11 @@ def test_passthrough(
assert len(get_output_config(node)) == len(expected_frames)


@pytest.mark.parametrize(("platform", "labels"), [("rvc2", [1]), ("rvc4", [1, 2])])
@pytest.mark.parametrize("labels", [[1], [1, 2]])
def test_label_validation(
node,
node_input_detections,
node_input_frames,
platform,
labels,
resize_width,
resize_height,
Expand All @@ -207,30 +199,16 @@ def test_label_validation(
if detection.label not in labels:
continue
expected_frames.append(frame)
script = generate_script_content(
platform, resize_width, resize_height, valid_labels=labels
)
script = generate_script_content(resize_width, resize_height, valid_labels=labels)
try:
run_script(node, script)
except Warning:
assert expected_frames == get_output_frames(node)


@pytest.mark.parametrize("resize", [(128, 128), (128, 256), (256, 256)])
def test_rvc2_output_size(node, resize):
    """Check that every config emitted by the rvc2 script carries the requested size.

    @param node: Node fixture that the generated script is executed against.
    @param resize: (width, height) pair passed to generate_script_content.
    @type resize: Tuple[int, int]
    """
    script = generate_script_content("rvc2", *resize)
    try:
        run_script(node, script)
    except Warning:
        # The checks only run when run_script raises Warning.
        # NOTE(review): presumably that is how the test node signals that its
        # inputs are exhausted -- confirm against the fixture implementation.
        output_cfg = get_output_config(node)
        for cfg in output_cfg:
            assert isinstance(cfg, dai.ImageManipConfig)
            # The tuple must be parenthesised: `assert a, b == c` parses as
            # `assert a` with `b == c` as the failure message, so the size
            # comparison would never actually be evaluated.
            assert (cfg.getResizeWidth(), cfg.getResizeHeight()) == resize


@pytest.mark.parametrize("resize", [(128, 128), (128, 256), (256, 256)])
def test_rvc4_output_size(node, resize):
script = generate_script_content("rvc4", *resize)
def test_output_size(node, resize):
script = generate_script_content(*resize)
try:
run_script(node, script)
except Warning:
Expand All @@ -241,36 +219,7 @@ def test_rvc4_output_size(node, resize):


@pytest.mark.parametrize("padding", [0, 0.1, 0.2, -0.1, -0.2])
def test_rvc2_crop(node, node_input_detections, padding, resize_width, resize_height):
    """Check the crop rectangle that the rvc2 script configures for each detection.

    The expected rectangle is the detection box grown by `padding` on every
    side, with the minimum edges clamped at 0 and the maximum edges clamped at 1.

    @param node: Node fixture that the generated script is executed against.
    @param node_input_detections: Detection messages queued on the node input.
    @param padding: Padding applied around each detection; may be negative.
    @type padding: float
    @param resize_width: Output width passed to generate_script_content.
    @param resize_height: Output height passed to generate_script_content.
    """

    def _padded_rect(det) -> dai.ImageManipConfig.CropRect:
        # Build the clamped, padded rectangle expected for one detection.
        box = dai.ImageManipConfig.CropRect()
        box.xmin = max(det.xmin - padding, 0)
        box.xmax = min(det.xmax + padding, 1)
        box.ymin = max(det.ymin - padding, 0)
        box.ymax = min(det.ymax + padding, 1)
        return box

    expected_rects: List[dai.ImageManipConfig.CropRect] = [
        _padded_rect(det)
        for message in node_input_detections
        for det in message.detections
    ]

    script = generate_script_content(
        "rvc2", resize_width, resize_height, padding=padding
    )
    try:
        run_script(node, script)
    except Warning:
        # The comparisons only run when run_script raises Warning.
        produced = get_output_config(node)
        for cfg, wanted in zip(produced, expected_rects):
            assert isinstance(cfg, dai.ImageManipConfig)
            actual = cfg.getCropConfig().cropRect
            actual_edges = (actual.xmin, actual.xmax, actual.ymin, actual.ymax)
            wanted_edges = (wanted.xmin, wanted.xmax, wanted.ymin, wanted.ymax)
            assert actual_edges == wanted_edges


@pytest.mark.parametrize("padding", [0, 0.1, 0.2, -0.1, -0.2])
def test_rvc4_crop(node, node_input_detections, padding, resize_width, resize_height):
def test_crop(node, node_input_detections, padding, resize_width, resize_height):
ANGLE = 0
expected_rects: List[dai.RotatedRect] = []
for input_dets in node_input_detections:
Expand All @@ -283,9 +232,7 @@ def test_rvc4_crop(node, node_input_detections, padding, resize_width, resize_he
rect.size.width = detection.xmax - detection.xmin + rect_padding
rect.size.height = detection.ymax - detection.ymin + rect_padding
expected_rects.append(rect)
script = generate_script_content(
"rvc4", resize_width, resize_height, padding=padding
)
script = generate_script_content(resize_width, resize_height, padding=padding)
try:
run_script(node, script)
except Warning:
Expand All @@ -310,7 +257,6 @@ def run_script(node, script):
{
"node": node,
"ImageManipConfigV2": ImageManipConfigV2,
"ImageManipConfig": dai.ImageManipConfig,
"RotatedRect": dai.RotatedRect,
},
)
Expand All @@ -322,5 +268,5 @@ def get_output_frames(node: Node) -> List[Frame]:

def get_output_config(
node: Node,
) -> Union[List[dai.ImageManipConfig], List[ImageManipConfigV2]]:
) -> List[ImageManipConfigV2]:
return node.outputs[Node.OUTPUT_CONFIG_KEY].items
Loading