Skip to content

Commit

Permalink
enable obb tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikel Broström committed Jan 31, 2025
1 parent e0a1177 commit 93010bc
Showing 1 changed file with 159 additions and 6 deletions.
165 changes: 159 additions & 6 deletions boxmot/trackers/ocsort/ocsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,18 @@
import numpy as np
from collections import deque

import cv2 as cv

from boxmot.motion.kalman_filters.aabb.xysr_kf import KalmanFilterXYSR
from boxmot.utils.association import associate, linear_assignment
from boxmot.trackers.basetracker import BaseTracker
from boxmot.utils.ops import xyxy2xysr
from boxmot.motion.kalman_filters.obb.xywha_kf import KalmanBoxTrackerOBB


def k_previous_obs(observations, cur_age, k, is_obb=False):

def k_previous_obs(observations, cur_age, k):
if len(observations) == 0:
if is_obb:
return [-1, -1, -1, -1, -1, -1]
else :
return [-1, -1, -1, -1, -1]
return [-1, -1, -1, -1, -1]
for i in range(k):
dt = k - i
if cur_age - dt in observations:
Expand All @@ -28,6 +26,161 @@ def k_previous_obs(observations, cur_age, k, is_obb=False):
return observations[max_age]


def convert_x_to_bbox(x, score=None):
"""
Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
[x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
"""
w = np.sqrt(x[2] * x[3])
h = x[2] / w
if score is None:
return np.array(
[x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0]
).reshape((1, 4))
else:
return np.array(
[x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0, score]
).reshape((1, 5))


def speed_direction(bbox1, bbox2):
cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
speed = np.array([cy2 - cy1, cx2 - cx1])
norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6
return speed / norm


class KalmanBoxTracker(object):
"""
This class represents the internal state of individual tracked objects observed as bbox.
"""

count = 0

def __init__(self, bbox, cls, det_ind, delta_t=3, max_obs=50, Q_xy_scaling = 0.01, Q_s_scaling = 0.0001):
"""
Initialises a tracker using initial bounding box.
"""
# define constant velocity model
self.det_ind = det_ind

self.Q_xy_scaling = Q_xy_scaling
self.Q_s_scaling = Q_s_scaling

self.kf = KalmanFilterXYSR(dim_x=7, dim_z=4, max_obs=max_obs)
self.kf.F = np.array(
[
[1, 0, 0, 0, 1, 0, 0],
[0, 1, 0, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 0, 1],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1],
]
)
self.kf.H = np.array(
[
[1, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
]
)

self.kf.R[2:, 2:] *= 10.0
self.kf.P[
4:, 4:
] *= 1000.0 # give high uncertainty to the unobservable initial velocities
self.kf.P *= 10.0

self.kf.Q[4:6, 4:6] *= self.Q_xy_scaling
self.kf.Q[-1, -1] *= self.Q_s_scaling

self.kf.x[:4] = xyxy2xysr(bbox)
self.time_since_update = 0
self.id = KalmanBoxTracker.count
KalmanBoxTracker.count += 1
self.max_obs = max_obs
self.history = deque([], maxlen=self.max_obs)
self.hits = 0
self.hit_streak = 0
self.age = 0
self.conf = bbox[-1]
self.cls = cls
"""
NOTE: [-1,-1,-1,-1,-1] is a compromising placeholder for non-observation status, the same for the return of
function k_previous_obs. It is ugly and I do not like it. But to support generate observation array in a
fast and unified way, which you would see below k_observations = np.array([k_previous_obs(...]]),
let's bear it for now.
"""
self.last_observation = np.array([-1, -1, -1, -1, -1]) # placeholder
self.observations = dict()
self.history_observations = deque([], maxlen=self.max_obs)
self.velocity = None
self.delta_t = delta_t

def update(self, bbox, cls, det_ind):
"""
Updates the state vector with observed bbox.
"""
self.det_ind = det_ind
if bbox is not None:
self.conf = bbox[-1]
self.cls = cls
if self.last_observation.sum() >= 0: # no previous observation
previous_box = None
for i in range(self.delta_t):
dt = self.delta_t - i
if self.age - dt in self.observations:
previous_box = self.observations[self.age - dt]
break
if previous_box is None:
previous_box = self.last_observation
"""
Estimate the track speed direction with observations \Delta t steps away
"""
self.velocity = speed_direction(previous_box, bbox)

"""
Insert new observations. This is a ugly way to maintain both self.observations
and self.history_observations. Bear it for the moment.
"""
self.last_observation = bbox
self.observations[self.age] = bbox
self.history_observations.append(bbox)

self.time_since_update = 0
self.hits += 1
self.hit_streak += 1
self.kf.update(xyxy2xysr(bbox))
else:
self.kf.update(bbox)

def predict(self):
"""
Advances the state vector and returns the predicted bounding box estimate.
"""
if (self.kf.x[6] + self.kf.x[2]) <= 0:
self.kf.x[6] *= 0.0

self.kf.predict()
self.age += 1
if self.time_since_update > 0:
self.hit_streak = 0
self.time_since_update += 1
self.history.append(convert_x_to_bbox(self.kf.x))
return self.history[-1]

def get_state(self):
"""
Returns the current bounding box estimate.
"""
return convert_x_to_bbox(self.kf.x)


class OcSort(BaseTracker):
"""
OCSort Tracker: A tracking algorithm that utilizes motion-based tracking.
Expand Down

0 comments on commit 93010bc

Please sign in to comment.