Skip to content

Commit

Permalink
updated tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
dusty-nv committed Dec 22, 2022
1 parent c72b2f1 commit 7368198
Show file tree
Hide file tree
Showing 10 changed files with 232 additions and 94 deletions.
70 changes: 41 additions & 29 deletions c/detectNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ int detectNet::postProcessSSD_UFF( Detection* detections, uint32_t width, uint32
if( object_data[2] < mConfidenceThreshold )
continue;

detections[numDetections].Instance = -1; //numDetections; //(uint32_t)object_data[0];
detections[numDetections].TrackID = -1; //numDetections; //(uint32_t)object_data[0];
detections[numDetections].ClassID = (uint32_t)object_data[1];
detections[numDetections].Confidence = object_data[2];
detections[numDetections].Left = object_data[3] * width;
Expand Down Expand Up @@ -722,7 +722,7 @@ int detectNet::postProcessSSD_ONNX( Detection* detections, uint32_t width, uint3
// populate a new detection entry
const float* coord = bbox + n * numCoord;

detections[numDetections].Instance = -1; //numDetections;
detections[numDetections].TrackID = -1; //numDetections;
detections[numDetections].ClassID = maxClass;
detections[numDetections].Confidence = maxScore;
detections[numDetections].Left = coord[0] * width;
Expand Down Expand Up @@ -802,7 +802,7 @@ int detectNet::postProcessDetectNet( Detection* detections, uint32_t width, uint
// create new entry if the detection wasn't merged with another detection
if( !detectionMerged )
{
detections[numDetections].Instance = -1; //numDetections;
detections[numDetections].TrackID = -1; //numDetections;
detections[numDetections].ClassID = z;
detections[numDetections].Confidence = coverage;

Expand Down Expand Up @@ -872,7 +872,7 @@ int detectNet::postProcessDetectNet_v2( Detection* detections, uint32_t width, u
LogDebug(LOG_TRT "rect x=%u y=%u conf=%f (%f, %f) (%f, %f) \n", x, y, confidence, x1, y1, x2, y2);
#endif

detections[numDetections].Instance = -1; //numDetections;
detections[numDetections].TrackID = -1; //numDetections;
detections[numDetections].ClassID = c;
detections[numDetections].Confidence = confidence;
detections[numDetections].Left = x1;
Expand Down Expand Up @@ -902,24 +902,30 @@ int detectNet::clusterDetections( Detection* detections, int n )
{
// if the intersecting detections have different classes, pick the one with highest confidence
// otherwise if they have the same object class, expand the detection bounding box
#ifdef CLUSTER_INTERCLASS
if( detections[n].ClassID != detections[m].ClassID )
{
if( detections[n].Confidence > detections[m].Confidence )
{
detections[m] = detections[n];

detections[m].Instance = -1; //m;
detections[m].TrackID = -1; //m;
detections[m].ClassID = detections[n].ClassID;
detections[m].Confidence = detections[n].Confidence;
detections[m].Confidence = detections[n].Confidence;
}

return 0; // merged detection
}
else
#else
if( detections[n].ClassID == detections[m].ClassID )
#endif
{
detections[m].Expand(detections[n]);
detections[m].Confidence = fmaxf(detections[n].Confidence, detections[m].Confidence);
}

return 0; // merged detection
return 0; // merged detection
}
}
}

Expand Down Expand Up @@ -949,7 +955,7 @@ void detectNet::sortDetections( Detection* detections, int numDetections )

// renumber the instance ID's
//for( int i=0; i < numDetections; i++ )
// detections[i].Instance = i;
// detections[i].TrackID = i;
}


Expand Down Expand Up @@ -1008,7 +1014,7 @@ bool detectNet::Overlay( void* input, void* output, uint32_t width, uint32_t hei
}

// class label overlay
if( (flags & OVERLAY_LABEL) || (flags & OVERLAY_CONFIDENCE) )
if( (flags & OVERLAY_LABEL) || (flags & OVERLAY_CONFIDENCE) || (flags & OVERLAY_TRACKING) )
{
static cudaFont* font = NULL;

Expand All @@ -1025,38 +1031,42 @@ bool detectNet::Overlay( void* input, void* output, uint32_t width, uint32_t hei
}

// draw each object's description
std::vector< std::pair< std::string, int2 > > labels;

#ifdef BATCH_TEXT
std::vector<std::pair<std::string, int2>> labels;
#endif
for( uint32_t n=0; n < numDetections; n++ )
{
const char* className = GetClassDesc(detections[n].ClassID);
const float confidence = detections[n].Confidence * 100.0f;
const int2 position = make_int2(detections[n].Left+5, detections[n].Top+3);

char buffer[256];
char* str = buffer;

if( flags & OVERLAY_LABEL )
str += sprintf(str, "%s ", className);

if( flags & OVERLAY_TRACKING && detections[n].TrackID >= 0 )
str += sprintf(str, "%i ", detections[n].TrackID);

if( flags & OVERLAY_CONFIDENCE )
{
char str[256];

if( (flags & OVERLAY_LABEL) && (flags & OVERLAY_CONFIDENCE) )
{
if( detections[n].Instance >= 0 )
sprintf(str, "%s %i %.1f%%", className, detections[n].Instance, confidence);
else
sprintf(str, "%s %.1f%%", className, confidence);
}
else
sprintf(str, "%.1f%%", confidence);
str += sprintf(str, "%.1f%%", confidence);

labels.push_back(std::pair<std::string, int2>(str, position));
}
else
#ifdef BATCH_TEXT
labels.push_back(std::pair<std::string, int2>(buffer, position));
#else
if( detections[n].TrackID >= 0 )
{
// overlay label only
labels.push_back(std::pair<std::string, int2>(className, position));
float4 color = make_float4(255,255,255,255);
color.w *= 1.0f - (fminf(detections[n].TrackLost, 15.0f) / 15.0f);
font->OverlayText(output, format, width, height, buffer, position.x, position.y, color);
}
#endif
}

#ifdef BATCH_TEXT
font->OverlayText(output, format, width, height, labels, make_float4(255,255,255,255));
#endif
}

PROFILER_END(PROFILER_VISUALIZE);
Expand Down Expand Up @@ -1105,6 +1115,8 @@ uint32_t detectNet::OverlayFlagsFromStr( const char* str_user )
flags |= OVERLAY_LABEL;
else if( strcasecmp(token, "conf") == 0 || strcasecmp(token, "confidence") == 0 )
flags |= OVERLAY_CONFIDENCE;
else if( strcasecmp(token, "track") == 0 || strcasecmp(token, "tracking") == 0 )
flags |= OVERLAY_TRACKING;
else if( strcasecmp(token, "line") == 0 || strcasecmp(token, "lines") == 0 )
flags |= OVERLAY_LINES;
else if( strcasecmp(token, "default") == 0 )
Expand Down
7 changes: 6 additions & 1 deletion c/detectNet.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,12 @@ cudaError_t launchDetectionOverlay( T* input, T* output, uint32_t width, uint32_
const dim3 blockDim(8, 8);
const dim3 gridDim(iDivUp(boxWidth,blockDim.x), iDivUp(boxHeight,blockDim.y));

gpuDetectionOverlayBox<T><<<gridDim, blockDim>>>(input, output, width, height, (int)detections[n].Left, (int)detections[n].Top, boxWidth, boxHeight, colors[detections[n].ClassID]);
float4 color = colors[detections[n].ClassID];

if( detections[n].TrackID >= 0 )
color.w *= 1.0f - (fminf(detections[n].TrackLost, 15.0f) / 15.0f);

gpuDetectionOverlayBox<T><<<gridDim, blockDim>>>(input, output, width, height, (int)detections[n].Left, (int)detections[n].Top, boxWidth, boxHeight, color);
}

return cudaGetLastError();
Expand Down
12 changes: 7 additions & 5 deletions c/detectNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
" --alpha=ALPHA overlay alpha blending value, range 0-255 (default: 120)\n" \
" --overlay=OVERLAY detection overlay flags (e.g. --overlay=box,labels,conf)\n" \
" valid combinations are: 'box', 'lines', 'labels', 'conf', 'none'\n" \
" --profile enable layer profiling in TensorRT\n\n"
" --profile enable layer profiling in TensorRT\n\n" \


// forward declarations
Expand All @@ -128,8 +128,9 @@ class detectNet : public tensorNet
float Confidence; /**< Confidence value of the detected object. */

// Tracking Info
int Instance; /**< Unique tracking ID (or -1 if untracked) */
int TrackFrames; /**< The number of frames the object has been positively tracked for */
int TrackID; /**< Unique tracking ID (or -1 if untracked) */
int TrackStatus; /**< -1 for dropped, 0 for initializing, 1 for active/valid */
int TrackFrames; /**< The number of frames the object has been re-identified for */
int TrackLost; /**< The number of consecutive frames tracking has been lost for */

// Bounding Box Coordinates
Expand Down Expand Up @@ -193,7 +194,7 @@ class detectNet : public tensorNet
inline bool Expand( const Detection& det ) { if(!Overlaps(det)) return false; Left = fminf(det.Left, Left); Top = fminf(det.Top, Top); Right = fmaxf(det.Right, Right); Bottom = fmaxf(det.Bottom, Bottom); return true; }

/**< Reset all member variables to zero */
inline void Reset() { ClassID = 0; Confidence = 0; Instance = -1; TrackFrames = 0; TrackLost = 0; Left = 0; Right = 0; Top = 0; Bottom = 0; }
inline void Reset() { ClassID = 0; Confidence = 0; TrackID = -1; TrackStatus = -1; TrackFrames = 0; TrackLost = 0; Left = 0; Right = 0; Top = 0; Bottom = 0; }

/**< Default constructor */
inline Detection() { Reset(); }
Expand All @@ -208,7 +209,8 @@ class detectNet : public tensorNet
OVERLAY_BOX = (1 << 0), /**< Overlay the object bounding boxes (filled) */
OVERLAY_LABEL = (1 << 1), /**< Overlay the class description labels */
OVERLAY_CONFIDENCE = (1 << 2), /**< Overlay the detection confidence values */
OVERLAY_LINES = (1 << 3), /**< Overlay the bounding box lines (unfilled) */
OVERLAY_TRACKING = (1 << 3), /**< Overlay tracking information (like track ID) */
OVERLAY_LINES = (1 << 4), /**< Overlay the bounding box lines (unfilled) */
OVERLAY_DEFAULT = OVERLAY_BOX|OVERLAY_LABEL|OVERLAY_CONFIDENCE, /**< The default choice of overlay */
};

Expand Down
35 changes: 26 additions & 9 deletions c/trackers/objectTracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,56 @@
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/


#include "detectNet.h"
#include "objectTracker.h"


#include "objectTrackerIOU.h"
#include "objectTrackerKLT.h"


// Create
objectTracker* objectTracker::Create( objectTracker::Type type )
{
objectTracker* tracker = NULL;

if( type == KLT )
{
#if HAS_VPI
return objectTrackerKLT::Create();
tracker = objectTrackerKLT::Create();
#else
LogError(LOG_TRACKER "couldn't create KLT tracker (not built with VPI enabled)\n");
return NULL;
#endif
}
else if( type == IOU )
{
return objectTrackerIOU::Create();
tracker = objectTrackerIOU::Create();
}

return NULL;
if( !tracker )
return NULL;

if( !tracker->Init() )
{
delete tracker;
return NULL;
}

return tracker;
}


// Create
objectTracker* objectTracker::Create( const commandLine& cmdLine )
{
const char* str = cmdLine.GetString("tracker", cmdLine.GetString("tracking"));
const Type type = TypeFromStr(str);
Type type = IOU;

const bool useDefault = cmdLine.GetFlag("tracking");
const char* typeStr = cmdLine.GetString("tracker", cmdLine.GetString("tracking"));

if( !useDefault )
type = TypeFromStr(typeStr);

if( type == KLT )
{
Expand All @@ -68,8 +85,8 @@ objectTracker* objectTracker::Create( const commandLine& cmdLine )
}
else
{
if( str != NULL )
LogError(LOG_TRACKER "tried to create invalid object tracker type: %s\n", str);
if( typeStr != NULL )
LogError(LOG_TRACKER "tried to create invalid object tracker type: %s\n", typeStr);
}

return NULL;
Expand Down
25 changes: 23 additions & 2 deletions c/trackers/objectTracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,23 @@


/**
* Tracker logging prefix
* Standard command-line options able to be passed to detectNet::Create()
* @ingroup objectTracker
*/
#define OBJECT_TRACKER_USAGE_STRING "objectTracker arguments: \n" \
" --tracking flag to enable default tracker (IOU)\n" \
" --tracker=TRACKER enable tracking with 'IOU' or 'KLT'\n" \
" --tracker-min-frames=N the number of re-identified frames for a track to be considered valid (default: 3)\n" \
" --tracker-lost-frames=N number of consecutive lost frames before a track is removed (default: 15)\n" \
" --tracker-overlap=N how much IOU overlap is required for a bounding box to be matched (default: 0.5)\n\n" \

/**
* Object tracker logging prefix
* @ingroup objectTracker
*/
#define LOG_TRACKER "[tracker] "


/**
* Object tracker interface
* @ingroup objectTracker
Expand Down Expand Up @@ -66,6 +77,11 @@ class objectTracker
*/
static objectTracker* Create( const commandLine& cmdLine );

/**
* Init (optional)
*/
virtual bool Init() { return true; }

/**
* Process
*/
Expand All @@ -81,6 +97,11 @@ class objectTracker
*/
virtual Type GetType() const = 0;

/**
* Usage string for command line arguments to Create()
*/
static inline const char* Usage() { return OBJECT_TRACKER_USAGE_STRING; }

/**
* Convert a Type enum to string.
*/
Expand Down
Loading

0 comments on commit 7368198

Please sign in to comment.