Skip to content

Commit

Permalink
automatic model downloads for actionNet
Browse files Browse the repository at this point in the history
  • Loading branch information
dusty-nv committed Dec 19, 2022
1 parent 7808c8b commit 9933995
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 127 deletions.
101 changes: 32 additions & 69 deletions c/actionNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,18 @@

#include "actionNet.h"
#include "tensorConvert.h"
#include "modelDownloader.h"

#include "commandLine.h"
#include "filesystem.h"
#include "logging.h"


// constructor
actionNet::actionNet() : tensorNet()
{
mNetworkType = CUSTOM;
mNumClasses = 0;
mNumFrames = 0;
mNumClasses = 0;
mNumFrames = 0;

mInputBuffers[0] = NULL;
mInputBuffers[1] = NULL;
Expand All @@ -53,30 +54,23 @@ actionNet::~actionNet()


// Create
actionNet* actionNet::Create( actionNet::NetworkType networkType, uint32_t maxBatchSize,
precisionType precision, deviceType device, bool allowGPUFallback )
actionNet* actionNet::Create( const char* network, uint32_t maxBatchSize,
precisionType precision, deviceType device, bool allowGPUFallback )
{
actionNet* net = NULL;

if( networkType == RESNET_18 )
net = Create("networks/Action-ResNet18/resnet-18-kinetics-moments.onnx", "networks/Action-ResNet18/labels.txt", "0", "198", maxBatchSize, precision, device, allowGPUFallback);
else if( networkType == RESNET_34 )
net = Create("networks/Action-ResNet34/resnet-34-kinetics-moments.onnx", "networks/Action-ResNet34/labels.txt", "0", "350", maxBatchSize, precision, device, allowGPUFallback);
/*else if( networkType == RESNET_50 )
net = Create("networks/Action-ResNet50/resnet-50-kinetics.onnx", "networks/Action-ResNet50/labels.txt", "0", "503", maxBatchSize, precision, device, allowGPUFallback);
else if( networkType == RESNET_101 )
net = Create("networks/Action-ResNet101/resnext-101-kinetics.onnx", "networks/Action-ResNet101/labels.txt", "0", "979", maxBatchSize, precision, device, allowGPUFallback);*/
nlohmann::json model;

if( !net )
{
LogError(LOG_TRT "actionNet -- invalid built-in model '%s' requested\n", actionNet::NetworkTypeToStr(networkType));
if( !DownloadModel(ACTIONNET_MODEL_TYPE, network, model) )
return NULL;
}

net->mNetworkType = networkType;
return net;
}
std::string model_dir = "networks/" + model["dir"].get<std::string>() + "/";
std::string model_path = model_dir + JSON_STR(model["model"]);
std::string labels = model_dir + JSON_STR(model["labels"]);
std::string input = JSON_STR_DEFAULT(model["input"], ACTIONNET_DEFAULT_INPUT);
std::string output = JSON_STR_DEFAULT(model["output"], ACTIONNET_DEFAULT_OUTPUT);

return Create(model_path.c_str(), labels.c_str(), input.c_str(), output.c_str(),
maxBatchSize, precision, device, allowGPUFallback);
}


// Create
Expand All @@ -98,17 +92,12 @@ actionNet* actionNet::Create( const commandLine& cmdLine )
modelName = cmdLine.GetString("model", "resnet-18");

// parse the network type
const actionNet::NetworkType type = NetworkTypeFromStr(modelName);

if( type == actionNet::CUSTOM )
if( !FindModel(ACTIONNET_MODEL_TYPE, modelName) )
{
const char* labels = cmdLine.GetString("labels");
const char* input = cmdLine.GetString("input_blob");
const char* output = cmdLine.GetString("output_blob");
const char* input = cmdLine.GetString("input_blob", ACTIONNET_DEFAULT_INPUT);
const char* output = cmdLine.GetString("output_blob", ACTIONNET_DEFAULT_OUTPUT);

if( !input ) input = ACTIONNET_DEFAULT_INPUT;
if( !output ) output = ACTIONNET_DEFAULT_OUTPUT;

int maxBatchSize = cmdLine.GetInt("batch_size");

if( maxBatchSize < 1 )
Expand All @@ -119,7 +108,7 @@ actionNet* actionNet::Create( const commandLine& cmdLine )
else
{
// create from pretrained model
net = actionNet::Create(type);
net = actionNet::Create(modelName);
}

if( !net )
Expand All @@ -139,6 +128,18 @@ actionNet* actionNet::Create( const char* model_path, const char* class_path,
uint32_t maxBatchSize, precisionType precision,
deviceType device, bool allowGPUFallback )
{
// check for built-in model string
if( FindModel(ACTIONNET_MODEL_TYPE, model_path) )
{
return Create(model_path, maxBatchSize, precision, device, allowGPUFallback);
}
else if( fileExtension(model_path).length() == 0 )
{
LogError(LOG_TRT "couldn't find built-in action model '%s'\n", model_path);
return NULL;
}

// load custom model
actionNet* net = new actionNet();

if( !net )
Expand Down Expand Up @@ -200,44 +201,6 @@ bool actionNet::init(const char* model_path, const char* class_path,
LogSuccess(LOG_TRT "actionNet -- %s initialized.\n", model_path);
return true;
}


// NetworkTypeFromStr
actionNet::NetworkType actionNet::NetworkTypeFromStr( const char* modelName )
{
if( !modelName )
return actionNet::CUSTOM;

actionNet::NetworkType type = actionNet::CUSTOM;

if( strcasecmp(modelName, "resnet-18") == 0 || strcasecmp(modelName, "resnet_18") == 0 || strcasecmp(modelName, "resnet18") == 0 )
type = actionNet::RESNET_18;
else if( strcasecmp(modelName, "resnet-34") == 0 || strcasecmp(modelName, "resnet_34") == 0 || strcasecmp(modelName, "resnet34") == 0 )
type = actionNet::RESNET_34;
/*else if( strcasecmp(modelName, "resnet-50") == 0 || strcasecmp(modelName, "resnet_50") == 0 || strcasecmp(modelName, "resnet50") == 0 )
type = actionNet::RESNET_50;
else if( strcasecmp(modelName, "resnet-101") == 0 || strcasecmp(modelName, "resnet_101") == 0 || strcasecmp(modelName, "resnet101") == 0 )
type = actionNet::RESNET_101;*/
else
type = actionNet::CUSTOM;

return type;
}


// NetworkTypeToStr
const char* actionNet::NetworkTypeToStr( actionNet::NetworkType network )
{
switch(network)
{
case actionNet::RESNET_18: return "ResNet-18";
case actionNet::RESNET_34: return "ResNet-34";
/*case actionNet::RESNET_50: return "ResNet-50";
case actionNet::RESNET_101: return "ResNet-101";*/
}

return "Custom";
}


// preProcess
Expand Down
44 changes: 8 additions & 36 deletions c/actionNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@
*/
#define ACTIONNET_DEFAULT_OUTPUT "output"

/**
* The model type for actionNet in data/networks/models.json
* @ingroup actionNet
*/
#define ACTIONNET_MODEL_TYPE "action"

/**
* Standard command-line options able to be passed to actionNet::Create()
Expand All @@ -57,38 +62,16 @@


/**
* Action/activity classification on a sequence of images, using TensorRT.
* Action/activity classification on a sequence of images or video, using TensorRT.
* @ingroup actionNet
*/
class actionNet : public tensorNet
{
public:
/**
* Network choice enumeration.
*/
enum NetworkType
{
CUSTOM, /**< Custom model provided by the user */
RESNET_18, /**< ResNet-18 trained on 1040-class Kinetics-700 and Moments In Time dataset */
RESNET_34, /**< ResNet-50 trained on 1040-class Kinetics-700 and Moments In Time dataset */
};

/**
* Parse a string to one of the built-in pretrained models.
* Valid names are "resnet-18", or "resnet-34", ect.
* @returns one of the actionNet::NetworkType enums, or actionNet::CUSTOM on invalid string.
* Load a pre-trained model, either "resnet-18" or "resnet-34".
*/
static NetworkType NetworkTypeFromStr( const char* model_name );

/**
* Convert a NetworkType enum to a string.
*/
static const char* NetworkTypeToStr( NetworkType network );

/**
* Load a new network instance
*/
static actionNet* Create( NetworkType networkType=RESNET_18, uint32_t maxBatchSize=DEFAULT_MAX_BATCH_SIZE,
static actionNet* Create( const char* network="resnet-18", uint32_t maxBatchSize=DEFAULT_MAX_BATCH_SIZE,
precisionType precision=TYPE_FASTEST, deviceType device=DEVICE_GPU,
bool allowGPUFallback=true );

Expand Down Expand Up @@ -169,16 +152,6 @@ class actionNet : public tensorNet
*/
inline const char* GetClassPath() const { return mClassPath.c_str(); }

/**
* Retrieve the network type (alexnet or googlenet)
*/
inline NetworkType GetNetworkType() const { return mNetworkType; }

/**
* Retrieve a string describing the network name.
*/
inline const char* GetNetworkName() const { return NetworkTypeToStr(mNetworkType); }

protected:
actionNet();

Expand All @@ -196,7 +169,6 @@ class actionNet : public tensorNet
std::vector<std::string> mClassDesc;

std::string mClassPath;
NetworkType mNetworkType;
};


Expand Down
26 changes: 26 additions & 0 deletions data/networks/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -412,5 +412,31 @@
"model": "monodepth_fcn_resnet50.onnx",
"description": "FCN-ResNet50 monocular depth estimation model"
}
},

"action": {
"resnet-18": {
"alias": ["resnet_18", "resnet18"],
"url": "https://nvidia.box.com/shared/static/rhuvudgijtefzno55092fgwy0xn485n0.gz",
"tar": "Action-ResNet18.tar.gz",
"dir": "Action-ResNet18",
"model": "resnet-18-kinetics-moments.onnx",
"labels": "labels.txt",
"input": "0",
"output": "198",
"description": "ResNet-18 trained on 1040-class Kinetics-700 and Moments-In-Time dataset"
},

"resnet-34": {
"alias": ["resnet_34", "resnet34"],
"url": "https://nvidia.box.com/shared/static/gr9kgox9zpwh93v0v28f9zjkzr1n4ffs.gz",
"tar": "Action-ResNet34.tar.gz",
"dir": "Action-ResNet34",
"model": "resnet-34-kinetics-moments.onnx",
"labels": "labels.txt",
"input": "0",
"output": "350",
"description": "ResNet-34 trained on 1040-class Kinetics-700 and Moments-In-Time dataset"
}
}
}
26 changes: 4 additions & 22 deletions python/bindings/PyActionNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,34 +114,16 @@ static int PyActionNet_Init( PyActionNet_Object* self, PyObject *args, PyObject
// free the arguments array
free(argv);
}
else if( model != NULL )
{
LogVerbose(LOG_PY_INFERENCE "actionNet loading custom model '%s'\n", model);

// load the network using custom model parameters
Py_BEGIN_ALLOW_THREADS
self->net = actionNet::Create(model, labels, input_blob, output_blob);
Py_END_ALLOW_THREADS
}
else
{
LogVerbose(LOG_PY_INFERENCE "actionNet loading build-in network '%s'\n", network);

// parse the selected built-in network
actionNet::NetworkType networkType = actionNet::NetworkTypeFromStr(network);
LogDebug(LOG_PY_INFERENCE "actionNet loading custom model '%s'\n", model);

if( networkType == actionNet::CUSTOM )
{
PyErr_SetString(PyExc_Exception, LOG_PY_INFERENCE "actionNet invalid built-in network was requested");
return -1;
}

// load the built-in network
// load the network using custom model parameters
Py_BEGIN_ALLOW_THREADS
self->net = actionNet::Create(networkType);
self->net = actionNet::Create(model != NULL ? model : network, labels, input_blob, output_blob);
Py_END_ALLOW_THREADS
}

// confirm the network loaded
if( !self->net )
{
Expand Down

0 comments on commit 9933995

Please sign in to comment.