automatic model downloads for backgroundNet
dusty-nv committed Dec 19, 2022
1 parent 9933995 commit 17c7ba4
Showing 4 changed files with 49 additions and 104 deletions.
75 changes: 28 additions & 47 deletions c/backgroundNet.cpp
@@ -22,14 +22,16 @@

#include "backgroundNet.h"
#include "tensorConvert.h"
#include "modelDownloader.h"

#include "commandLine.h"
#include "filesystem.h"


// constructor
backgroundNet::backgroundNet() : tensorNet()
{
mNetworkType = CUSTOM;

}


@@ -41,23 +43,21 @@ backgroundNet::~backgroundNet()


// Create
backgroundNet* backgroundNet::Create( backgroundNet::NetworkType networkType, uint32_t maxBatchSize,
backgroundNet* backgroundNet::Create( const char* network, uint32_t maxBatchSize,
precisionType precision, deviceType device, bool allowGPUFallback )
{
backgroundNet* net = NULL;
nlohmann::json model;

if( networkType == U2NET )
net = Create("networks/Background-U2Net/u2net.onnx", BACKGROUNDNET_DEFAULT_INPUT, BACKGROUNDNET_DEFAULT_OUTPUT, maxBatchSize, precision, device, allowGPUFallback);

if( !net )
{
LogError(LOG_TRT "backgroundNet -- invalid built-in model '%s' requested\n", backgroundNet::NetworkTypeToStr(networkType));
if( !DownloadModel(BACKGROUNDNET_MODEL_TYPE, network, model) )
return NULL;
}

net->mNetworkType = networkType;

return net;
std::string model_dir = "networks/" + model["dir"].get<std::string>() + "/";
std::string model_path = model_dir + JSON_STR(model["model"]);
std::string input = JSON_STR_DEFAULT(model["input"], BACKGROUNDNET_DEFAULT_INPUT);
std::string output = JSON_STR_DEFAULT(model["output"], BACKGROUNDNET_DEFAULT_OUTPUT);

return Create(model_path.c_str(), input.c_str(), output.c_str(),
maxBatchSize, precision, device, allowGPUFallback);
}
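
As context for the new lookup flow (not part of the diff): a minimal sketch of calling the reworked string-based Create(), assuming the jetson-inference headers are installed under jetson-inference/; "my_model.onnx" and its blob names are hypothetical placeholders, and DEFAULT_MAX_BATCH_SIZE comes from tensorNet.h.

#include <jetson-inference/backgroundNet.h>   // installed header path assumed

int main()
{
	// built-in model -- the name is resolved through data/networks/models.json
	// and downloaded into networks/Background-U2Net/ on first use; passing
	// maxBatchSize explicitly selects the name-based overload.
	backgroundNet* net = backgroundNet::Create("u2net", DEFAULT_MAX_BATCH_SIZE);

	if( !net )
		return 1;

	// a user-supplied ONNX model is passed by file path instead, along with
	// its input/output blob names (hypothetical values shown):
	//   backgroundNet::Create("my_model.onnx", "input_0", "output_0");

	delete net;
	return 0;
}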


@@ -80,9 +80,7 @@ backgroundNet* backgroundNet::Create( const commandLine& cmdLine )
modelName = cmdLine.GetString("model", "u2net");

// parse the network type
const backgroundNet::NetworkType type = NetworkTypeFromStr(modelName);

if( type == backgroundNet::CUSTOM )
if( !FindModel(BACKGROUNDNET_MODEL_TYPE, modelName) )
{
const char* input = cmdLine.GetString("input_blob", BACKGROUNDNET_DEFAULT_INPUT);
const char* output = cmdLine.GetString("output_blob", BACKGROUNDNET_DEFAULT_OUTPUT);
@@ -97,7 +95,7 @@ backgroundNet* backgroundNet::Create( const commandLine& cmdLine )
else
{
// create from pretrained model
net = backgroundNet::Create(type);
net = backgroundNet::Create(modelName, DEFAULT_MAX_BATCH_SIZE);
}

if( !net )
@@ -116,6 +114,18 @@ backgroundNet* backgroundNet::Create( const char* model_path, const char* input,
uint32_t maxBatchSize, precisionType precision,
deviceType device, bool allowGPUFallback )
{
// check for built-in model string
if( FindModel(BACKGROUNDNET_MODEL_TYPE, model_path) )
{
return Create(model_path, maxBatchSize, precision, device, allowGPUFallback);
}
else if( fileExtension(model_path).length() == 0 )
{
LogError(LOG_TRT "couldn't find built-in background model '%s'\n", model_path);
return NULL;
}

// load custom model
backgroundNet* net = new backgroundNet();

if( !net )
@@ -137,7 +147,7 @@ bool backgroundNet::init( const char* model_path, const char* input, const char*
return NULL;

LogInfo("\n");
LogInfo("backgroundNet -- loading feature matching network model from:\n");
LogInfo("backgroundNet -- loading background network from:\n");
LogInfo(" -- model %s\n", model_path);
LogInfo(" -- input_blob '%s'\n", input);
LogInfo(" -- output_blob '%s'\n", output);
@@ -153,35 +163,6 @@ bool backgroundNet::init( const char* model_path, const char* input, const char*

return true;
}


// NetworkTypeFromStr
backgroundNet::NetworkType backgroundNet::NetworkTypeFromStr( const char* modelName )
{
if( !modelName )
return backgroundNet::CUSTOM;

backgroundNet::NetworkType type = backgroundNet::CUSTOM;

if( strcasecmp(modelName, "u2net") == 0 )
type = backgroundNet::U2NET;
else
type = backgroundNet::CUSTOM;

return type;
}


// NetworkTypeToStr
const char* backgroundNet::NetworkTypeToStr( backgroundNet::NetworkType network )
{
switch(network)
{
case backgroundNet::U2NET: return "u2net";
}

return "Custom";
}


// from backgroundNet.cu
41 changes: 7 additions & 34 deletions c/backgroundNet.h
@@ -40,6 +40,11 @@
*/
#define BACKGROUNDNET_DEFAULT_OUTPUT "output_0"

/**
* The model type for backgroundNet in data/networks/models.json
* @ingroup backgroundNet
*/
#define BACKGROUNDNET_MODEL_TYPE "background"

/**
* Standard command-line options able to be passed to backgroundNet::Create()
@@ -63,30 +68,9 @@ class backgroundNet : public tensorNet
{
public:
/**
* Network choice enumeration.
*/
enum NetworkType
{
CUSTOM, /**< Custom model provided by the user */
U2NET, /**< U2-Net (U-Square Net) for Salient Object Detection */
};

/**
* Parse a string to one of the built-in pretrained models.
* Valid names are "u2net", ect.
* @returns one of the backgroundNet::NetworkType enums, or backgroundNet::CUSTOM on invalid string.
* Load a pre-trained model.
*/
static NetworkType NetworkTypeFromStr( const char* model_name );

/**
* Convert a NetworkType enum to a string.
*/
static const char* NetworkTypeToStr( NetworkType network );

/**
* Load a new network instance
*/
static backgroundNet* Create( NetworkType networkType=U2NET, uint32_t maxBatchSize=DEFAULT_MAX_BATCH_SIZE,
static backgroundNet* Create( const char* network="u2net", uint32_t maxBatchSize=DEFAULT_MAX_BATCH_SIZE,
precisionType precision=TYPE_FASTEST, deviceType device=DEVICE_GPU, bool allowGPUFallback=true );

/**
@@ -172,23 +156,12 @@
*/
bool Process( void* input, void* output, uint32_t width, uint32_t height, imageFormat format,
cudaFilterMode filter=FILTER_LINEAR, bool maskAlpha=true );

/**
* Retrieve the network type (alexnet or googlenet)
*/
inline NetworkType GetNetworkType() const { return mNetworkType; }

/**
* Retrieve a string describing the network name.
*/
inline const char* GetNetworkName() const { return NetworkTypeToStr(mNetworkType); }

protected:
backgroundNet();

bool init(const char* model_path, const char* input, const char* output, uint32_t maxBatchSize, precisionType precision, deviceType device, bool allowGPUFallback );

NetworkType mNetworkType;
};
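
For context, a hedged end-to-end sketch of the slimmed-down class above: allocate mapped CUDA buffers with jetson-utils, load the default model by name, and run the two-image Process() overload declared earlier; the image dimensions are placeholders and the pixel data is assumed to be filled elsewhere (e.g. by loadImage() or a camera stream).

#include <jetson-inference/backgroundNet.h>   // installed header paths assumed
#include <jetson-utils/cudaMappedMemory.h>

int main()
{
	const uint32_t width  = 1280;   // placeholder dimensions
	const uint32_t height = 720;

	// allocate shared CPU/GPU buffers for RGBA float images
	float4* imgInput  = NULL;
	float4* imgOutput = NULL;

	if( !cudaAllocMapped(&imgInput, width * height * sizeof(float4)) ||
	    !cudaAllocMapped(&imgOutput, width * height * sizeof(float4)) )
		return 1;

	// ... fill imgInput with pixel data here ...

	backgroundNet* net = backgroundNet::Create("u2net", DEFAULT_MAX_BATCH_SIZE);

	if( !net )
		return 1;

	// segment the image; maskAlpha=true presumably writes the mask into the alpha channel
	if( !net->Process((void*)imgInput, (void*)imgOutput, width, height,
	                  IMAGE_RGBA32F, FILTER_LINEAR, true) )
		return 1;

	delete net;   // CUDA buffers are not freed here for brevity
	return 0;
}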


Expand Down
10 changes: 10 additions & 0 deletions data/networks/models.json
@@ -438,5 +438,15 @@
"output": "350",
"description": "ResNet-34 trained on 1040-class Kinetics-700 and Moments-In-Time dataset"
}
},

"background": {
"u2net": {
"url": "https://nvidia.box.com/shared/static/pp72renayt4do23sxyqtzu406nsftxg4.gz",
"tar": "Background-U2Net.tar.gz",
"dir": "Background-U2Net",
"model": "u2net.onnx",
"description": "U2-Net (U-Square Net) for Salient Object Detection"
}
}
}
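
The Create() code above reads these fields through JSON_STR()/JSON_STR_DEFAULT(); as a rough standalone illustration of the same lookup (assuming the nlohmann/json header is available, the file is opened relative to the repo root, and "input_0"/"output_0" stand in for the BACKGROUNDNET_DEFAULT_* blob names):

#include <nlohmann/json.hpp>
#include <fstream>
#include <iostream>
#include <string>

int main()
{
	std::ifstream file("data/networks/models.json");   // path relative to repo root (assumed)
	nlohmann::json models;
	file >> models;

	// look up the new "background" group added above
	const nlohmann::json& model = models["background"]["u2net"];

	const std::string dir  = model["dir"].get<std::string>();
	const std::string onnx = model["model"].get<std::string>();

	// optional keys fall back to the network defaults, as in backgroundNet::Create()
	const std::string input  = model.contains("input")  ? model["input"].get<std::string>()  : "input_0";
	const std::string output = model.contains("output") ? model["output"].get<std::string>() : "output_0";

	std::cout << "model: networks/" << dir << "/" << onnx << std::endl;
	std::cout << "url:   " << model["url"].get<std::string>() << std::endl;
	std::cout << "blobs: " << input << " -> " << output << std::endl;
	return 0;
}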
27 changes: 4 additions & 23 deletions python/bindings/PyBackgroundNet.cpp
@@ -32,7 +32,6 @@
typedef struct {
PyTensorNet_Object base;
backgroundNet* net; // object instance
PyObject* depthField; // depth field cudaImage
} PyBackgroundNet_Object;


Expand Down Expand Up @@ -73,7 +72,7 @@ static int PyBackgroundNet_Init( PyBackgroundNet_Object* self, PyObject *args, P
// determine whether to use argv or built-in network
if( argList != NULL && PyList_Check(argList) && PyList_Size(argList) > 0 )
{
LogVerbose(LOG_PY_INFERENCE "backgroundNet loading network using argv command line params\n");
LogDebug(LOG_PY_INFERENCE "backgroundNet loading network using argv command line params\n");

// parse the python list into char**
const size_t argc = PyList_Size(argList);
@@ -113,31 +112,13 @@ static int PyBackgroundNet_Init( PyBackgroundNet_Object* self, PyObject *args, P
// free the arguments array
free(argv);
}
else if( model != NULL )
{
LogVerbose(LOG_PY_INFERENCE "backgroundNet loading custom model '%s'\n", model);

// load the network using custom model parameters
Py_BEGIN_ALLOW_THREADS
self->net = backgroundNet::Create(model, input_blob, output_blob);
Py_END_ALLOW_THREADS
}
else
{
LogVerbose(LOG_PY_INFERENCE "backgroundNet loading build-in network '%s'\n", network);
LogDebug(LOG_PY_INFERENCE "backgroundNet loading custom model '%s'\n", model);

// parse the selected built-in network
backgroundNet::NetworkType networkType = backgroundNet::NetworkTypeFromStr(network);

if( networkType == backgroundNet::CUSTOM )
{
PyErr_SetString(PyExc_Exception, LOG_PY_INFERENCE "backgroundNet invalid built-in network was requested");
return -1;
}

// load the built-in network
// load the network using custom model parameters
Py_BEGIN_ALLOW_THREADS
self->net = backgroundNet::Create(networkType);
self->net = backgroundNet::Create(model != NULL ? model : network, input_blob, output_blob);
Py_END_ALLOW_THREADS
}

