Skip to content

Commit

Permalink
integrated object tracking into detectNet
Browse files Browse the repository at this point in the history
  • Loading branch information
dusty-nv committed Dec 21, 2022
1 parent e367b2c commit c72b2f1
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 26 deletions.
15 changes: 13 additions & 2 deletions c/detectNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
*/

#include "detectNet.h"
#include "objectTracker.h"
#include "tensorConvert.h"
#include "modelDownloader.h"

Expand Down Expand Up @@ -48,6 +49,7 @@
// constructor
detectNet::detectNet( float meanPixel ) : tensorNet()
{
mTracker = NULL;
mMeanPixel = meanPixel;
mLineWidth = 2.0f;

Expand All @@ -67,6 +69,8 @@ detectNet::detectNet( float meanPixel ) : tensorNet()
// destructor
detectNet::~detectNet()
{
SAFE_DELETE(mTracker);

CUDA_FREE_HOST(mDetectionSets);
CUDA_FREE_HOST(mClassColors);
}
Expand Down Expand Up @@ -384,6 +388,9 @@ detectNet* detectNet::Create( const commandLine& cmdLine )
net->SetOverlayAlpha(cmdLine.GetFloat("alpha", DETECTNET_DEFAULT_ALPHA));
net->SetClusteringThreshold(cmdLine.GetFloat("clustering", DETECTNET_DEFAULT_CLUSTERING_THRESHOLD));

// enable tracking if requested
net->SetTracker(objectTracker::Create(cmdLine));

return net;
}

Expand Down Expand Up @@ -514,7 +521,7 @@ int detectNet::Detect( void* input, uint32_t width, uint32_t height, imageFormat
PROFILER_END(PROFILER_NETWORK);

// post-processing / clustering
const int numDetections = postProcess(detections, width, height);
const int numDetections = postProcess(input, width, height, format, detections);

// render the overlay
if( overlay != 0 && numDetections > 0 )
Expand Down Expand Up @@ -592,7 +599,7 @@ bool detectNet::preProcess( void* input, uint32_t width, uint32_t height, imageF


// postProcess
int detectNet::postProcess( Detection* detections, uint32_t width, uint32_t height )
int detectNet::postProcess( void* input, uint32_t width, uint32_t height, imageFormat format, Detection* detections )
{
PROFILER_BEGIN(PROFILER_POSTPROCESS);

Expand Down Expand Up @@ -629,6 +636,10 @@ int detectNet::postProcess( Detection* detections, uint32_t width, uint32_t heig
detections[n].Bottom = height - 1;
}

// update tracking
if( mTracker != NULL )
numDetections = mTracker->Process(input, width, height, format, detections, numDetections);

PROFILER_END(PROFILER_POSTPROCESS);
return numDetections;
}
Expand Down
20 changes: 18 additions & 2 deletions c/detectNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@
" --profile enable layer profiling in TensorRT\n\n"


// forward declarations
class objectTracker;


/**
* Object recognition and localization networks with TensorRT support.
* @ingroup detectNet
Expand Down Expand Up @@ -429,6 +433,16 @@ class detectNet : public tensorNet
*/
inline void SetClusteringThreshold( float threshold ) { mClusteringThreshold = threshold; }

/**
* Get the object tracker being used.
*/
inline objectTracker* GetTracker() const { return mTracker; }

/**
* Set the object tracker to be used.
*/
inline void SetTracker( objectTracker* tracker ) { mTracker = tracker; }

/**
* Retrieve the maximum number of simultaneous detections the network supports.
* Knowing this is useful for allocating the buffers to store the output detection results.
Expand Down Expand Up @@ -484,7 +498,7 @@ class detectNet : public tensorNet
* Set overlay alpha blending value for all classes (between 0-255).
*/
void SetOverlayAlpha( float alpha );

protected:

// constructor
Expand All @@ -500,8 +514,8 @@ class detectNet : public tensorNet
precisionType precision, deviceType device, bool allowGPUFallback );

bool preProcess( void* input, uint32_t width, uint32_t height, imageFormat format );
int postProcess( void* input, uint32_t width, uint32_t height, imageFormat format, Detection* detections );

int postProcess( Detection* detections, uint32_t width, uint32_t height );
int postProcessSSD_UFF( Detection* detections, uint32_t width, uint32_t height );
int postProcessSSD_ONNX( Detection* detections, uint32_t width, uint32_t height );
int postProcessDetectNet( Detection* detections, uint32_t width, uint32_t height );
Expand All @@ -510,6 +524,8 @@ class detectNet : public tensorNet
int clusterDetections( Detection* detections, int n );
void sortDetections( Detection* detections, int numDetections );

objectTracker* mTracker;

float mConfidenceThreshold; // TODO change this to per-class
float mClusteringThreshold; // TODO change this to per-class

Expand Down
6 changes: 4 additions & 2 deletions c/trackers/objectTracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ objectTracker* objectTracker::Create( const commandLine& cmdLine )
}
else
{
LogError(LOG_TRACKER "tried to create invalid object tracker type: %s\n", str);
return NULL;
if( str != NULL )
LogError(LOG_TRACKER "tried to create invalid object tracker type: %s\n", str);
}

return NULL;
}


Expand Down
2 changes: 1 addition & 1 deletion c/trackers/objectTrackerIOU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ int objectTrackerIOU::Process( void* input, uint32_t width, uint32_t height, ima

for( int n=0; n < mTracks.size(); n++ )
{
if( mTracks[n].TrackFrames >= 5 )
if( mTracks[n].TrackFrames >= 3 )
detections[numDetections++] = mTracks[n];
}

Expand Down
7 changes: 5 additions & 2 deletions examples/detectnet/detectnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,11 @@ int main( int argc, char** argv )

for( int n=0; n < numDetections; n++ )
{
LogVerbose("detected obj %i class #%u (%s) confidence=%f\n", n, detections[n].ClassID, net->GetClassDesc(detections[n].ClassID), detections[n].Confidence);
LogVerbose("bounding box %i (%f, %f) (%f, %f) w=%f h=%f\n", n, detections[n].Left, detections[n].Top, detections[n].Right, detections[n].Bottom, detections[n].Width(), detections[n].Height());
LogVerbose("\ndetected obj %i class #%u (%s) confidence=%f\n", n, detections[n].ClassID, net->GetClassDesc(detections[n].ClassID), detections[n].Confidence);
LogVerbose("bounding box %i (%.2f, %.2f) (%.2f, %.2f) w=%.2f h=%.2f\n", n, detections[n].Left, detections[n].Top, detections[n].Right, detections[n].Bottom, detections[n].Width(), detections[n].Height());

if( detections[n].Instance >= 0 ) // is this a tracked object?
LogVerbose("tracking instance %i frames=%i lost=%i\n", detections[n].Instance, detections[n].TrackFrames, detections[n].TrackLost);
}
}

Expand Down
59 changes: 42 additions & 17 deletions python/bindings/PyDetectNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,24 +134,49 @@ static PyObject* PyDetection_ToString( PyDetection_Object* self )
self->det.Center(&cx, &cy);

// format string
char str[1024];

sprintf(str,
"<detectNet.Detection object>\n"
" -- ClassID: %i\n"
" -- Confidence: %g\n"
" -- Left: %g\n"
" -- Top: %g\n"
" -- Right: %g\n"
" -- Bottom: %g\n"
" -- Width: %g\n"
" -- Height: %g\n"
" -- Area: %g\n"
" -- Center: (%g, %g)",
self->det.ClassID, self->det.Confidence,
self->det.Left, self->det.Top, self->det.Right, self->det.Bottom,
self->det.Width(), self->det.Height(), self->det.Area(), cx, cy);
char str[4096];

if( self->det.Instance >= 0 )
{
sprintf(str,
"<detectNet.Detection object>\n"
" -- ClassID: %i\n"
" -- Confidence: %g\n"
" -- Instance: %i\n"
" -- Track Frames: %i\n"
" -- Track Lost: %i\n"
" -- Left: %g\n"
" -- Top: %g\n"
" -- Right: %g\n"
" -- Bottom: %g\n"
" -- Width: %g\n"
" -- Height: %g\n"
" -- Area: %g\n"
" -- Center: (%g, %g)",
self->det.ClassID, self->det.Confidence,
self->det.Instance, self->det.TrackFrames, self->det.TrackLost,
self->det.Left, self->det.Top, self->det.Right, self->det.Bottom,
self->det.Width(), self->det.Height(), self->det.Area(), cx, cy);
}
else
{
sprintf(str,
"<detectNet.Detection object>\n"
" -- ClassID: %i\n"
" -- Confidence: %g\n"
" -- Left: %g\n"
" -- Top: %g\n"
" -- Right: %g\n"
" -- Bottom: %g\n"
" -- Width: %g\n"
" -- Height: %g\n"
" -- Area: %g\n"
" -- Center: (%g, %g)",
self->det.ClassID, self->det.Confidence,
self->det.Left, self->det.Top, self->det.Right, self->det.Bottom,
self->det.Width(), self->det.Height(), self->det.Area(), cx, cy);
}

return PYSTRING_FROM_STRING(str);
}

Expand Down

0 comments on commit c72b2f1

Please sign in to comment.