From 26f13a3e08c48b1f51a3bffa38fbe48a8b13814f Mon Sep 17 00:00:00 2001 From: Nikolas Markou Date: Mon, 6 Apr 2020 15:20:39 +0300 Subject: [PATCH 1/2] #426 Fixed ResizeNearest and ResizeBilinear plugins for trt branch 5.1 --- ResizeBilinear.cu | 172 ++++++++++++++++++++++ ResizeBilinear.hpp | 288 ++++++++++++++++++++++++++++++++++++ ResizeNearest.cu | 186 ++++++++++++++---------- ResizeNearest.hpp | 306 +++++++++++++++++++++++++++++++-------- builtin_op_importers.cpp | 24 ++- 5 files changed, 834 insertions(+), 142 deletions(-) create mode 100644 ResizeBilinear.cu create mode 100644 ResizeBilinear.hpp diff --git a/ResizeBilinear.cu b/ResizeBilinear.cu new file mode 100644 index 00000000..ec9dbc5d --- /dev/null +++ b/ResizeBilinear.cu @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include +#include +#include "ResizeBilinear.hpp" + +//================================================================== + +// TODO: Move this to a common header +inline bool +is_CHW( + nvinfer1::Dims const& dims) +{ + return (dims.nbDims == 3 && + dims.type[0] == nvinfer1::DimensionType::kCHANNEL && + dims.type[1] == nvinfer1::DimensionType::kSPATIAL && + dims.type[2] == nvinfer1::DimensionType::kSPATIAL); +} + +//================================================================== + +nvinfer1::Dims +ResizeBilinearPlugin::getOutputDimensions( + int index, + const nvinfer1::Dims *inputDims, + int nbInputs) +{ + assert(nbInputs == 1); + nvinfer1::Dims const& input = inputDims[0]; + assert(is_CHW(input)); + assert(_ndims == 2); + assert(index == 0); + nvinfer1::Dims output = {0}; + output.nbDims = input.nbDims; + + output.d[0] = input.d[0]; + output.d[1] = input.d[1] * _scale[0]; + output.d[2] = input.d[2] * _scale[1]; + + output.type[0] = input.type[0]; + output.type[1] = input.type[1]; + output.type[2] = input.type[2]; + + return output; +} + +//================================================================== + +int +ResizeBilinearPlugin::initialize() +{ + _output_dims = this->getOutputDimensions(0, &this->getInputDims(0), 1); + assert(is_CHW(this->getInputDims(0))); + assert(is_CHW(_output_dims)); + assert(_ndims == 2); + return 0; +} + +//================================================================== + +template __global__ +void +resize_bilinear_kernel_2d( + int nbatch, + float2 scale, + int2 osize, + Data const* idata, int istride, int ibatchstride, + Data* odata, int ostride, int obatchstride) +{ + const int x0 = threadIdx.x + blockIdx.x * blockDim.x; + const int y0 = threadIdx.y + blockIdx.y * blockDim.y; + const int z0 = blockIdx.z; + const int src_cols = int(osize.x / scale.x); + const int src_rows = int(osize.y / scale.y); + + for( int batch=z0; batchgetInputDims(0); + int nchan = input_dims.d[0]; + + if (_ndims != 2) + { + return -1; + } + + const float2 scale = {_scale[1], _scale[0]}; + const int2 osize = {_output_dims.d[2], _output_dims.d[1]}; + const int istride = input_dims.d[2]; + const int ostride = _output_dims.d[2]; + const int ibatchstride = input_dims.d[1] * istride; + const int obatchstride = _output_dims.d[1] * ostride; + dim3 block(32, 16); + dim3 grid((osize.x - 1) / block.x + 1, + (osize.y - 1) / block.y + 1, + std::min(batchSize * nchan, 65535)); + + if (getDataType() == nvinfer1::DataType::kFLOAT) + { + resize_bilinear_kernel_2d<<>>( + batchSize * nchan, scale, osize, + static_cast( inputs[0]), istride, ibatchstride, + static_cast(outputs[0]), ostride, obatchstride); + } + else + { + resize_bilinear_kernel_2d<<>>( + batchSize * nchan, scale, osize, + static_cast<__half const*>( inputs[0]), istride, ibatchstride, + static_cast<__half* >(outputs[0]), ostride, obatchstride); + } + + return cudaGetLastError() != cudaSuccess; +} + +//================================================================== diff --git a/ResizeBilinear.hpp b/ResizeBilinear.hpp new file mode 100644 index 00000000..d3b3efec --- /dev/null +++ b/ResizeBilinear.hpp @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#pragma once +#include + +#include "plugin.hpp" +#include "serialize.hpp" + +#include + +namespace +{ +constexpr const char *RESIZE_BILINEAR_PLUGIN_VERSION{"001"}; +constexpr const char *RESIZE_BILINEAR_PLUGIN_NAME{"ResizeBilinear"}; +} // namespace + +//================================================================== + +class ResizeBilinearPlugin final : + public onnx2trt::PluginV2 +{ + //-------------------------------------------------------------------- + + int _ndims = 0; + float _scale[nvinfer1::Dims::MAX_DIMS] = {0.0f}; + nvinfer1::Dims _output_dims = {0}; + + //-------------------------------------------------------------------- + protected: + + //-------------------------------------------------------------------- + + void deserialize( + void const *serialData, + size_t serialLength) + { + deserializeBase(serialData, serialLength); + deserialize_value(&serialData, &serialLength, &_ndims); + deserialize_value(&serialData, &serialLength, &_scale); + } + + //-------------------------------------------------------------------- + + size_t getSerializationSize() const override + { + return + serialized_size(_ndims) + + serialized_size(_scale) + + getBaseSerializationSize(); + } + + //-------------------------------------------------------------------- + + void serialize( + void *buffer) const override + { + serializeBase(buffer); + serialize_value(&buffer, _ndims); + serialize_value(&buffer, _scale); + } + + //-------------------------------------------------------------------- + + public: + + //-------------------------------------------------------------------- + + ResizeBilinearPlugin( + std::vector const &scale) : + _ndims(scale.size()) + { + assert(scale.size() <= nvinfer1::Dims::MAX_DIMS); + std::copy(scale.begin(), scale.end(), _scale); + } + + //-------------------------------------------------------------------- + + ResizeBilinearPlugin( + void const *serialData, + size_t serialLength) + { + this->deserialize( + serialData, + serialLength); + } + + //-------------------------------------------------------------------- + + virtual + const char + *getPluginType() const override + { + return RESIZE_BILINEAR_PLUGIN_NAME; + } + + //-------------------------------------------------------------------- + + virtual + void + destroy() override + { + delete this; + } + + //-------------------------------------------------------------------- + + virtual + nvinfer1::IPluginV2 + *clone() const override + { + return new ResizeBilinearPlugin{ + std::vector(_scale, _scale + _ndims)}; + } + + //-------------------------------------------------------------------- + + virtual + const char + *getPluginVersion() const override + { + return RESIZE_BILINEAR_PLUGIN_VERSION; + } + + //-------------------------------------------------------------------- + + virtual + void + setPluginNamespace( + const char *pluginNamespace) override + {} + + //-------------------------------------------------------------------- + + virtual + const char* + getPluginNamespace() const override + { + return ""; + } + + //-------------------------------------------------------------------- + + virtual + int + getNbOutputs() const override + { + return 1; + } + + //-------------------------------------------------------------------- + + virtual + nvinfer1::Dims + getOutputDimensions( + int index, + const nvinfer1::Dims *inputs, + int nbInputDims) override; + + //-------------------------------------------------------------------- + + virtual int initialize() override; + + //-------------------------------------------------------------------- + + int enqueue( + int batchSize, + const void *const *inputs, + void **outputs, + void *workspace, + cudaStream_t stream) override; + + //-------------------------------------------------------------------- +}; + +//================================================================== + +class ResizeBilinearPluginCreator : + public nvinfer1::IPluginCreator +{ + private: + //-------------------------------------------------------------------- + + std::string mNamespace; + + //-------------------------------------------------------------------- + public: + //-------------------------------------------------------------------- + + ResizeBilinearPluginCreator() + {} + + //-------------------------------------------------------------------- + + ~ResizeBilinearPluginCreator() + {} + + //-------------------------------------------------------------------- + + const char* + getPluginName() const + { + return RESIZE_BILINEAR_PLUGIN_NAME; + } + + //-------------------------------------------------------------------- + + const char* + getPluginVersion() const + { + return RESIZE_BILINEAR_PLUGIN_VERSION; + } + + //-------------------------------------------------------------------- + + const nvinfer1::PluginFieldCollection* + getFieldNames() + { + std::cerr << "Function not implemented" << std::endl; + return nullptr; + } + + //-------------------------------------------------------------------- + + nvinfer1::IPluginV2* + createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) + { + std::cerr << "Function not implemented" << std::endl; + return nullptr; + } + + //-------------------------------------------------------------------- + + nvinfer1::IPluginV2* + deserializePlugin( + const char* name, + const void *serialData, + size_t serialLength) + { + return new ResizeBilinearPlugin{ + serialData, + serialLength}; + } + + //-------------------------------------------------------------------- + + void + setPluginNamespace( + const char *libNamespace) + { + mNamespace = libNamespace; + } + + //-------------------------------------------------------------------- + + const char* + getPluginNamespace() const + { + return mNamespace.c_str(); + } + + //-------------------------------------------------------------------- +}; + +//================================================================== + +REGISTER_TENSORRT_PLUGIN(ResizeBilinearPluginCreator); diff --git a/ResizeNearest.cu b/ResizeNearest.cu index e45b6c42..564b3e06 100644 --- a/ResizeNearest.cu +++ b/ResizeNearest.cu @@ -20,101 +20,137 @@ * DEALINGS IN THE SOFTWARE. */ -#include + #include #include +#include #include "ResizeNearest.hpp" +//================================================================== + // TODO: Move this to a common header -inline bool is_CHW(nvinfer1::Dims const& dims) { +inline bool +is_CHW( + nvinfer1::Dims const& dims) +{ return (dims.nbDims == 3 && dims.type[0] == nvinfer1::DimensionType::kCHANNEL && dims.type[1] == nvinfer1::DimensionType::kSPATIAL && dims.type[2] == nvinfer1::DimensionType::kSPATIAL); } -nvinfer1::Dims ResizeNearestPlugin::getOutputDimensions(int index, - const nvinfer1::Dims *inputDims, - int nbInputs) { - assert(nbInputs == 1); - nvinfer1::Dims const& input = inputDims[0]; - assert(is_CHW(input)); - assert(_ndims == 2); - assert(index == 0); - nvinfer1::Dims output; - output.nbDims = input.nbDims; - int s = 0; - for( int d=0; dgetOutputDimensions(0, &this->getInputDims(0), 1); - assert(is_CHW(this->getInputDims(0))); - assert(is_CHW(_output_dims)); - assert(_ndims == 2); - return 0; +//================================================================== + +int +ResizeNearestPlugin::initialize() +{ + _output_dims = this->getOutputDimensions(0, &this->getInputDims(0), 1); + assert(is_CHW(this->getInputDims(0))); + assert(is_CHW(_output_dims)); + assert(_ndims == 2); + return 0; } -template -__global__ -void resize_nearest_kernel_2d(int nbatch, - float2 scale, - int2 osize, - Data const* idata, int istride, int ibatchstride, - Data* odata, int ostride, int obatchstride) { - int x0 = threadIdx.x + blockIdx.x * blockDim.x; - int y0 = threadIdx.y + blockIdx.y * blockDim.y; - int z0 = blockIdx.z; - for( int batch=z0; batch __global__ +void +resize_nearest_kernel_2d( + int nbatch, + float2 scale, + int2 osize, + Data const* idata, int istride, int ibatchstride, + Data* odata, int ostride, int obatchstride) +{ + const int x0 = threadIdx.x + blockIdx.x * blockDim.x; + const int y0 = threadIdx.y + blockIdx.y * blockDim.y; + const int z0 = blockIdx.z; + + for( int batch=z0; batchgetInputDims(0); - int nchan = input_dims.d[0]; - switch( _ndims ) { - case 2: { - float2 scale = {_scale[1], _scale[0]}; - int2 osize = {_output_dims.d[2], _output_dims.d[1]}; - int istride = input_dims.d[2]; - int ostride = _output_dims.d[2]; - int ibatchstride = input_dims.d[1] * istride; - int obatchstride = _output_dims.d[1] * ostride; +//================================================================== + +int +ResizeNearestPlugin::enqueue( + int batchSize, + const void *const *inputs, void **outputs, + void *workspace, cudaStream_t stream) +{ + auto const& input_dims = this->getInputDims(0); + int nchan = input_dims.d[0]; + + if (_ndims != 2) + { + return -1; + } + + const float2 scale = {_scale[1], _scale[0]}; + const int2 osize = {_output_dims.d[2], _output_dims.d[1]}; + const int istride = input_dims.d[2]; + const int ostride = _output_dims.d[2]; + const int ibatchstride = input_dims.d[1] * istride; + const int obatchstride = _output_dims.d[1] * ostride; dim3 block(32, 16); dim3 grid((osize.x - 1) / block.x + 1, - (osize.y - 1) / block.y + 1, - std::min(batchSize * nchan, 65535)); - if (getDataType()==nvinfer1::DataType::kFLOAT) { - resize_nearest_kernel_2d<<>> - (batchSize * nchan, scale, osize, - static_cast( inputs[0]), istride, ibatchstride, - static_cast(outputs[0]), ostride, obatchstride); - } else { - resize_nearest_kernel_2d<<>> - (batchSize * nchan, scale, osize, - static_cast<__half const*>( inputs[0]), istride, ibatchstride, - static_cast<__half* >(outputs[0]), ostride, obatchstride); + (osize.y - 1) / block.y + 1, + std::min(batchSize * nchan, 65535)); + + if (getDataType()==nvinfer1::DataType::kFLOAT) + { + resize_nearest_kernel_2d<<>>( + batchSize * nchan, scale, osize, + static_cast( inputs[0]), istride, ibatchstride, + static_cast(outputs[0]), ostride, obatchstride); + } + else + { + resize_nearest_kernel_2d<<>>( + batchSize * nchan, scale, osize, + static_cast<__half const*>( inputs[0]), istride, ibatchstride, + static_cast<__half* >(outputs[0]), ostride, obatchstride); } + return cudaGetLastError() != cudaSuccess; - } - default: return -1; - } } + +//================================================================== \ No newline at end of file diff --git a/ResizeNearest.hpp b/ResizeNearest.hpp index 1a43db3c..235951c4 100644 --- a/ResizeNearest.hpp +++ b/ResizeNearest.hpp @@ -28,81 +28,259 @@ #include -namespace { - constexpr const char* RESIZE_PLUGIN_VERSION{"001"}; - constexpr const char* RESIZE_PLUGIN_NAME{"ResizeNearest"}; -} - -class ResizeNearestPlugin final : public onnx2trt::PluginV2 { - int _ndims; - float _scale[nvinfer1::Dims::MAX_DIMS]; - nvinfer1::Dims _output_dims; -protected: - void deserialize(void const* serialData, size_t serialLength) { - deserializeBase(serialData, serialLength); - deserialize_value(&serialData, &serialLength, &_ndims); - deserialize_value(&serialData, &serialLength, &_scale); - } - size_t getSerializationSize() const override { - return serialized_size(_ndims) + serialized_size(_scale) + getBaseSerializationSize(); - } - void serialize(void *buffer) const override { - serializeBase(buffer); - serialize_value(&buffer, _ndims); - serialize_value(&buffer, _scale); - } -public: - ResizeNearestPlugin(std::vector const& scale) - : _ndims(scale.size()) { - assert(scale.size() <= nvinfer1::Dims::MAX_DIMS); - std::copy(scale.begin(), scale.end(), _scale); - } - ResizeNearestPlugin(void const* serialData, size_t serialLength) { - this->deserialize(serialData, serialLength); - } - virtual const char* getPluginType() const override { return RESIZE_PLUGIN_NAME; } - - virtual void destroy() override { delete this; } - - virtual nvinfer1::IPluginV2* clone() const override { return new ResizeNearestPlugin{std::vector(_scale, _scale + _ndims)}; } - - virtual const char* getPluginVersion() const override { return RESIZE_PLUGIN_VERSION; } - - virtual void setPluginNamespace(const char* pluginNamespace) override {} - - virtual const char* getPluginNamespace() const override { return ""; } - - virtual int getNbOutputs() const override { return 1; } - virtual nvinfer1::Dims getOutputDimensions(int index, - const nvinfer1::Dims *inputs, int nbInputDims) override; - virtual int initialize() override; - int enqueue(int batchSize, - const void *const *inputs, void **outputs, - void *workspace, cudaStream_t stream) override; -}; +namespace +{ +constexpr const char *RESIZE_PLUGIN_VERSION{"001"}; +constexpr const char *RESIZE_PLUGIN_NAME{"ResizeNearest"}; +} // namespace -class ResizeNearestPluginCreator : public nvinfer1::IPluginCreator +//================================================================== + +class ResizeNearestPlugin final : + public onnx2trt::PluginV2 { -public: - ResizeNearestPluginCreator() {} + //-------------------------------------------------------------------- + + int _ndims = 0; + float _scale[nvinfer1::Dims::MAX_DIMS] = {0.0f}; + nvinfer1::Dims _output_dims = {0}; + + //-------------------------------------------------------------------- + protected: + + //-------------------------------------------------------------------- + + void deserialize( + void const *serialData, + size_t serialLength) + { + deserializeBase(serialData, serialLength); + deserialize_value(&serialData, &serialLength, &_ndims); + deserialize_value(&serialData, &serialLength, &_scale); + } + + //-------------------------------------------------------------------- + + size_t getSerializationSize() const override + { + return + serialized_size(_ndims) + + serialized_size(_scale) + + getBaseSerializationSize(); + } + + //-------------------------------------------------------------------- + + void serialize( + void *buffer) const override + { + serializeBase(buffer); + serialize_value(&buffer, _ndims); + serialize_value(&buffer, _scale); + } + + //-------------------------------------------------------------------- + + public: + + //-------------------------------------------------------------------- + + ResizeNearestPlugin( + std::vector const &scale) : + _ndims(scale.size()) + { + assert(scale.size() <= nvinfer1::Dims::MAX_DIMS); + std::copy(scale.begin(), scale.end(), _scale); + } + + //-------------------------------------------------------------------- + + ResizeNearestPlugin( + void const *serialData, + size_t serialLength) + { + this->deserialize( + serialData, + serialLength); + } + + //-------------------------------------------------------------------- + + virtual + const char + *getPluginType() const override + { + return RESIZE_PLUGIN_NAME; + } + + //-------------------------------------------------------------------- + + virtual + void + destroy() override + { + delete this; + } + + //-------------------------------------------------------------------- + + virtual + nvinfer1::IPluginV2 + *clone() const override + { + return new ResizeNearestPlugin{ + std::vector(_scale, _scale + _ndims)}; + } + + //-------------------------------------------------------------------- + + virtual + const char + *getPluginVersion() const override + { + return RESIZE_PLUGIN_VERSION; + } + + //-------------------------------------------------------------------- + + virtual + void + setPluginNamespace( + const char *pluginNamespace) override + {} - ~ResizeNearestPluginCreator() {} + //-------------------------------------------------------------------- - const char* getPluginName() const { return RESIZE_PLUGIN_NAME; } + virtual + const char* + getPluginNamespace() const override + { + return ""; + } - const char* getPluginVersion() const { return RESIZE_PLUGIN_VERSION; } + //-------------------------------------------------------------------- - const nvinfer1::PluginFieldCollection* getFieldNames() { std::cerr<< "Function not implemented" << std::endl; return nullptr; } + virtual + int + getNbOutputs() const override + { + return 1; + } - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) { std::cerr<< "Function not implemented" << std::endl; return nullptr; } + //-------------------------------------------------------------------- - nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) { return new ResizeNearestPlugin{serialData, serialLength}; } + virtual + nvinfer1::Dims + getOutputDimensions( + int index, + const nvinfer1::Dims *inputs, + int nbInputDims) override; - void setPluginNamespace(const char* libNamespace) { mNamespace = libNamespace; } + //-------------------------------------------------------------------- + + virtual int initialize() override; + + //-------------------------------------------------------------------- + + int enqueue( + int batchSize, + const void *const *inputs, + void **outputs, + void *workspace, + cudaStream_t stream) override; + + //-------------------------------------------------------------------- +}; + +//================================================================== + +class ResizeNearestPluginCreator : + public nvinfer1::IPluginCreator +{ + private: + //-------------------------------------------------------------------- - const char* getPluginNamespace() const { return mNamespace.c_str(); } -private: std::string mNamespace; + + //-------------------------------------------------------------------- + public: + //-------------------------------------------------------------------- + + ResizeNearestPluginCreator() + {} + + //-------------------------------------------------------------------- + + ~ResizeNearestPluginCreator() + {} + + //-------------------------------------------------------------------- + + const char* + getPluginName() const + { + return RESIZE_PLUGIN_NAME; + } + + //-------------------------------------------------------------------- + + const char* + getPluginVersion() const + { + return RESIZE_PLUGIN_VERSION; + } + + //-------------------------------------------------------------------- + + const nvinfer1::PluginFieldCollection* + getFieldNames() + { + std::cerr << "Function not implemented" << std::endl; + return nullptr; + } + + //-------------------------------------------------------------------- + + nvinfer1::IPluginV2* + createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) + { + std::cerr << "Function not implemented" << std::endl; + return nullptr; + } + + //-------------------------------------------------------------------- + + nvinfer1::IPluginV2* + deserializePlugin( + const char* name, + const void *serialData, + size_t serialLength) + { + return new ResizeNearestPlugin{serialData, serialLength}; + } + + //-------------------------------------------------------------------- + + void + setPluginNamespace( + const char *libNamespace) + { + mNamespace = libNamespace; + } + + //-------------------------------------------------------------------- + + const char* + getPluginNamespace() const + { + return mNamespace.c_str(); + } + + //-------------------------------------------------------------------- }; +//================================================================== + REGISTER_TENSORRT_PLUGIN(ResizeNearestPluginCreator); diff --git a/builtin_op_importers.cpp b/builtin_op_importers.cpp index 655976d4..d19422da 100644 --- a/builtin_op_importers.cpp +++ b/builtin_op_importers.cpp @@ -25,6 +25,7 @@ #include "plugin.hpp" #include "FancyActivation.hpp" #include "ResizeNearest.hpp" +#include "ResizeBilinear.hpp" #include "Split.hpp" #include "InstanceNormalization.hpp" @@ -2006,9 +2007,26 @@ DEFINE_BUILTIN_OP_IMPORTER(Upsample) { } auto scale = {height_scale, width_scale}; auto mode = attrs.get("mode", "nearest"); - ASSERT(mode == "nearest", ErrorCode::kUNSUPPORTED_NODE); - RETURN_FIRST_OUTPUT( - ctx->addPluginV2(new ResizeNearestPlugin(scale), {&inputs.at(0).tensor()})); + + ASSERT( + mode == "nearest" || mode == "linear", + ErrorCode::kUNSUPPORTED_NODE); + + if (mode == "nearest"){ + RETURN_FIRST_OUTPUT( + ctx->addPluginV2( + new ResizeNearestPlugin(scale), + {&inputs.at(0).tensor()})); + } + + if (mode == "linear"){ + RETURN_FIRST_OUTPUT( + ctx->addPluginV2( + new ResizeBilinearPlugin(scale), + {&inputs.at(0).tensor()})); + } + + ASSERT(false, ErrorCode::kUNSUPPORTED_NODE); } } // namespace From 145f490d4f56bccdab0167db709b27aa27c9bf4b Mon Sep 17 00:00:00 2001 From: Nikolas Markou Date: Mon, 6 Apr 2020 15:22:56 +0300 Subject: [PATCH 2/2] #426 Fixed ResizeNearest and ResizeBilinear plugins for trt branch 5.1, added new plugin in CMakeLists --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 33e9c346..9176086c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -70,6 +70,7 @@ set(ONNX2TRT_PATCH 0) set(PLUGIN_SOURCES FancyActivation.cu ResizeNearest.cu + ResizeBilinear.cu Split.cu InstanceNormalization.cpp plugin.cpp