From b0c3900b163ff0e9bca26c46e1a246f1be62b132 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Tue, 13 Aug 2024 18:44:42 +0200 Subject: [PATCH 1/3] use superclass checker --- boxmot/trackers/basetracker.py | 14 ++++++++++++++ boxmot/trackers/botsort/bot_sort.py | 14 ++------------ boxmot/trackers/bytetrack/byte_tracker.py | 11 ++--------- boxmot/trackers/deepocsort/deep_ocsort.py | 5 +---- boxmot/trackers/hybridsort/hybridsort.py | 3 +++ boxmot/trackers/imprassoc/impr_assoc_tracker.py | 13 +------------ boxmot/trackers/ocsort/ocsort.py | 10 +--------- boxmot/trackers/strongsort/strong_sort.py | 13 +------------ 8 files changed, 25 insertions(+), 58 deletions(-) diff --git a/boxmot/trackers/basetracker.py b/boxmot/trackers/basetracker.py index 7e28ab811e..544d74b20e 100644 --- a/boxmot/trackers/basetracker.py +++ b/boxmot/trackers/basetracker.py @@ -60,6 +60,20 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> """ raise NotImplementedError("The update method needs to be implemented by the subclass.") + def check_inputs(self, dets, im): + assert isinstance( + dets, np.ndarray + ), f"Unsupported 'dets' input format '{type(dets)}', valid format is np.ndarray" + assert isinstance( + img, np.ndarray + ), f"Unsupported 'img_numpy' input format '{type(img)}', valid format is np.ndarray" + assert ( + len(dets.shape) == 2 + ), "Unsupported 'dets' dimensions, valid number of dimensions is two" + assert ( + dets.shape[1] == 6 + ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6" + def id_to_color(self, id: int, saturation: float = 0.75, value: float = 0.95) -> tuple: """ Generates a consistent unique BGR color for a given ID using hashing. diff --git a/boxmot/trackers/botsort/bot_sort.py b/boxmot/trackers/botsort/bot_sort.py index 8356eb6859..020a4df3d8 100644 --- a/boxmot/trackers/botsort/bot_sort.py +++ b/boxmot/trackers/botsort/bot_sort.py @@ -234,18 +234,8 @@ def __init__( @PerClassDecorator def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray: - assert isinstance( - dets, np.ndarray - ), f"Unsupported 'dets' input format '{type(dets)}', valid format is np.ndarray" - assert isinstance( - img, np.ndarray - ), f"Unsupported 'img_numpy' input format '{type(img)}', valid format is np.ndarray" - assert ( - len(dets.shape) == 2 - ), "Unsupported 'dets' dimensions, valid number of dimensions is two" - assert ( - dets.shape[1] == 6 - ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6" + + self.check_inputs(dets, img) self.frame_count += 1 activated_starcks = [] diff --git a/boxmot/trackers/bytetrack/byte_tracker.py b/boxmot/trackers/bytetrack/byte_tracker.py index f566837ca3..cd6cb79cf3 100644 --- a/boxmot/trackers/bytetrack/byte_tracker.py +++ b/boxmot/trackers/bytetrack/byte_tracker.py @@ -143,15 +143,8 @@ def __init__( @PerClassDecorator def update(self, dets: np.ndarray, img: np.ndarray = None, embs: np.ndarray = None) -> np.ndarray: - assert isinstance( - dets, np.ndarray - ), f"Unsupported 'dets' input format '{type(dets)}', valid format is np.ndarray" - assert ( - len(dets.shape) == 2 - ), "Unsupported 'dets' dimensions, valid number of dimensions is two" - assert ( - dets.shape[1] == 6 - ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6" + + self.check_inputs(dets, img) dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)]) self.frame_count += 1 diff --git a/boxmot/trackers/deepocsort/deep_ocsort.py b/boxmot/trackers/deepocsort/deep_ocsort.py index 1e5d3f6f90..d88f11b5ee 100644 --- a/boxmot/trackers/deepocsort/deep_ocsort.py +++ b/boxmot/trackers/deepocsort/deep_ocsort.py @@ -286,10 +286,7 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> """ #dets, s, c = dets.data #print(dets, s, c) - assert isinstance(dets, np.ndarray), f"Unsupported 'dets' input type '{type(dets)}', valid format is np.ndarray" - assert isinstance(img, np.ndarray), f"Unsupported 'img' input type '{type(img)}', valid format is np.ndarray" - assert len(dets.shape) == 2, "Unsupported 'dets' dimensions, valid number of dimensions is two" - assert dets.shape[1] == 6, "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6" + self.check_inputs(dets, img) self.frame_count += 1 self.height, self.width = img.shape[:2] diff --git a/boxmot/trackers/hybridsort/hybridsort.py b/boxmot/trackers/hybridsort/hybridsort.py index 946c876d49..d5b76079ae 100644 --- a/boxmot/trackers/hybridsort/hybridsort.py +++ b/boxmot/trackers/hybridsort/hybridsort.py @@ -386,6 +386,9 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> Returns the a similar array, where the last column is the object ID. NOTE: The number of objects returned may differ from the number of detections provided. """ + + self.check_inputs(dets, img) + if dets is None: return np.empty((0, 7)) diff --git a/boxmot/trackers/imprassoc/impr_assoc_tracker.py b/boxmot/trackers/imprassoc/impr_assoc_tracker.py index 0c79d6775c..8ffb5bdd6c 100644 --- a/boxmot/trackers/imprassoc/impr_assoc_tracker.py +++ b/boxmot/trackers/imprassoc/impr_assoc_tracker.py @@ -242,18 +242,7 @@ def __init__( @PerClassDecorator def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray: - assert isinstance( - dets, np.ndarray - ), f"Unsupported 'dets' input format '{type(dets)}', valid format is np.ndarray" - assert isinstance( - img, np.ndarray - ), f"Unsupported 'img_numpy' input format '{type(img)}', valid format is np.ndarray" - assert ( - len(dets.shape) == 2 - ), "Unsupported 'dets' dimensions, valid number of dimensions is two" - assert ( - dets.shape[1] == 6 - ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6" + self.check_inputs(dets, img) self.frame_count += 1 activated_starcks = [] diff --git a/boxmot/trackers/ocsort/ocsort.py b/boxmot/trackers/ocsort/ocsort.py index 21515bd247..1ae0963ed3 100644 --- a/boxmot/trackers/ocsort/ocsort.py +++ b/boxmot/trackers/ocsort/ocsort.py @@ -226,15 +226,7 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> NOTE: The number of objects returned may differ from the number of detections provided. """ - assert isinstance( - dets, np.ndarray - ), f"Unsupported 'dets' input format '{type(dets)}', valid format is np.ndarray" - assert ( - len(dets.shape) == 2 - ), "Unsupported 'dets' dimensions, valid number of dimensions is two" - assert ( - dets.shape[1] == 6 - ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6" + self.check_inputs(dets, img) self.frame_count += 1 h, w = img.shape[0:2] diff --git a/boxmot/trackers/strongsort/strong_sort.py b/boxmot/trackers/strongsort/strong_sort.py index 66bc73b4fd..1c84d9e4cd 100644 --- a/boxmot/trackers/strongsort/strong_sort.py +++ b/boxmot/trackers/strongsort/strong_sort.py @@ -44,18 +44,7 @@ def __init__( @PerClassDecorator def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray: - assert isinstance( - dets, np.ndarray - ), f"Unsupported 'dets' input format '{type(dets)}', valid format is np.ndarray" - assert isinstance( - img, np.ndarray - ), f"Unsupported 'img' input format '{type(img)}', valid format is np.ndarray" - assert ( - len(dets.shape) == 2 - ), "Unsupported 'dets' dimensions, valid number of dimensions is two" - assert ( - dets.shape[1] == 6 - ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6" + self.check_inputs(dets, img) dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)]) xyxy = dets[:, 0:4] From 6a4cb94754fd7785993ec5500a104d903f82be5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Tue, 13 Aug 2024 18:49:09 +0200 Subject: [PATCH 2/3] use superclass checker --- boxmot/trackers/basetracker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/boxmot/trackers/basetracker.py b/boxmot/trackers/basetracker.py index 544d74b20e..47c45be842 100644 --- a/boxmot/trackers/basetracker.py +++ b/boxmot/trackers/basetracker.py @@ -60,7 +60,7 @@ def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> """ raise NotImplementedError("The update method needs to be implemented by the subclass.") - def check_inputs(self, dets, im): + def check_inputs(self, dets, img): assert isinstance( dets, np.ndarray ), f"Unsupported 'dets' input format '{type(dets)}', valid format is np.ndarray" From ca2605c535ea13a300645b10e06e31451b430890 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikel=20Brostr=C3=B6m?= Date: Tue, 13 Aug 2024 18:52:25 +0200 Subject: [PATCH 3/3] use superclass checker --- boxmot/trackers/strongsort/strong_sort.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/boxmot/trackers/strongsort/strong_sort.py b/boxmot/trackers/strongsort/strong_sort.py index 1c84d9e4cd..66bc73b4fd 100644 --- a/boxmot/trackers/strongsort/strong_sort.py +++ b/boxmot/trackers/strongsort/strong_sort.py @@ -44,7 +44,18 @@ def __init__( @PerClassDecorator def update(self, dets: np.ndarray, img: np.ndarray, embs: np.ndarray = None) -> np.ndarray: - self.check_inputs(dets, img) + assert isinstance( + dets, np.ndarray + ), f"Unsupported 'dets' input format '{type(dets)}', valid format is np.ndarray" + assert isinstance( + img, np.ndarray + ), f"Unsupported 'img' input format '{type(img)}', valid format is np.ndarray" + assert ( + len(dets.shape) == 2 + ), "Unsupported 'dets' dimensions, valid number of dimensions is two" + assert ( + dets.shape[1] == 6 + ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6" dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)]) xyxy = dets[:, 0:4]