
Commit

Add trt instance norm plugin (open-mmlab#16)
* add trt instance norm plugin

* last line empty

* fix clang format

* fix grid_sample clang format

* remove redundant

* fix lint

* refine codes

* fix clang format

* clang format

* clang format

* clang format
AllentDan authored Jul 16, 2021
1 parent 1718501 commit b9e64f9
Showing 12 changed files with 418 additions and 55 deletions.
2 changes: 1 addition & 1 deletion backend_ops/CMakeLists.txt
@@ -5,7 +5,7 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
 
-# build TensorRT ops
+# build ONNXRUNTIME ops
 if (BUILD_ONNXRUNTIME_OPS)
   message("Build ONNXRUNTIME custom ops.")
   add_subdirectory (onnxruntime)
3 changes: 2 additions & 1 deletion backend_ops/onnxruntime/grid_sample/grid_sample.cpp
@@ -1,6 +1,7 @@
+#include "grid_sample.h"
+
 #include <cmath>
 
-#include "grid_sample.h"
 #include "ort_mmcv_utils.h"
 
 #define MIN(a, b) (((a) < (b)) ? (a) : (b))
1 change: 1 addition & 0 deletions backend_ops/tensorrt/CMakeLists.txt
@@ -52,6 +52,7 @@ set(PLUGIN_LISTS scatternd
                  nms
                  roi_align
                  batched_nms
+                 instance_norm
                  multi_level_roi_align)
 
 foreach(PLUGIN_ITER ${PLUGIN_LISTS})
5 changes: 5 additions & 0 deletions backend_ops/tensorrt/common/trt_plugin_helper.hpp
@@ -1,10 +1,15 @@
 #ifndef TRT_PLUGIN_HELPER_HPP
 #define TRT_PLUGIN_HELPER_HPP
+#include <cudnn.h>
+
 #include <iostream>
 #include <stdexcept>
 
 #include "NvInferPlugin.h"
 
+cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype,
+                                      cudnnDataType_t* cudnn_dtype);
+
 // Enumerator for status
 typedef enum {
   STATUS_SUCCESS = 0,
15 changes: 15 additions & 0 deletions backend_ops/tensorrt/common_impl/trt_cuda_helper.cu
@@ -63,3 +63,18 @@ void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
 template void memcpyPermute<float>(float *dst, const float *src, int *src_size,
                                    int *permute, int src_dim,
                                    cudaStream_t stream);
+
+cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype,
+                                      cudnnDataType_t *cudnn_dtype) {
+  switch (trt_dtype) {
+    case nvinfer1::DataType::kFLOAT:
+      *cudnn_dtype = CUDNN_DATA_FLOAT;
+      break;
+    case nvinfer1::DataType::kHALF:
+      *cudnn_dtype = CUDNN_DATA_HALF;
+      break;
+    default:
+      return CUDNN_STATUS_BAD_PARAM;
+  }
+  return CUDNN_STATUS_SUCCESS;
+}
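
The new helper maps TensorRT's kFLOAT/kHALF onto the matching cuDNN dtypes and returns CUDNN_STATUS_BAD_PARAM for anything else (kINT8, kINT32, ...). A minimal caller sketch, assuming only the headers added in this commit; the describe_as_4d wrapper is illustrative, not part of the change:

#include <NvInfer.h>
#include <cudnn.h>

#include "trt_plugin_helper.hpp"  // declares convert_trt2cudnn_dtype

// Hypothetical helper: convert a TensorRT dtype before configuring a cuDNN
// tensor descriptor, surfacing unsupported types instead of ignoring the
// returned status.
bool describe_as_4d(cudnnTensorDescriptor_t desc,
                    const nvinfer1::PluginTensorDesc& in) {
  cudnnDataType_t dtype{};
  if (convert_trt2cudnn_dtype(in.type, &dtype) != CUDNN_STATUS_SUCCESS)
    return false;  // e.g. an int8 tensor: no cuDNN equivalent here
  return cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, dtype,
                                    in.dims.d[0], in.dims.d[1], in.dims.d[2],
                                    in.dims.d[3]) == CUDNN_STATUS_SUCCESS;
}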
200 changes: 200 additions & 0 deletions backend_ops/tensorrt/instance_norm/trt_instance_norm.cpp
@@ -0,0 +1,200 @@
// Modified from:
// https://github.com/NVIDIA/TensorRT/blob/master/plugin/instanceNormalizationPlugin/instanceNormalizationPlugin.cpp

#include "trt_instance_norm.hpp"

#include <cuda_fp16.h>

#include <stdexcept>

#include "trt_serialize.hpp"

using namespace nvinfer1;

namespace mmlab {
namespace {
constexpr const char* PLUGIN_VERSION{"1"};
constexpr const char* PLUGIN_NAME{"TRTInstanceNormalization"};
} // namespace

TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name,
float epsilon)
: TRTPluginBase(name), mEpsilon(epsilon) {}

TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name,
void const* serialData,
size_t serialLength)
: TRTPluginBase(name) {
deserialize_value(&serialData, &serialLength, &mEpsilon);
}

TRTInstanceNormalization::~TRTInstanceNormalization() {}

// TRTInstanceNormalization returns one output.
int TRTInstanceNormalization::getNbOutputs() const { return 1; }

DimsExprs TRTInstanceNormalization::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) {
nvinfer1::DimsExprs output(inputs[0]);
return output;
}

size_t TRTInstanceNormalization::getWorkspaceSize(
const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
int n = inputs[0].dims.d[0];
int c = inputs[0].dims.d[1];
int elem_size = getElementSize(inputs[1].type);
return getAlignedSize(n * c * elem_size) * 2;
}

int TRTInstanceNormalization::enqueue(
const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
void* const* outputs, void* workspace, cudaStream_t stream) {
nvinfer1::Dims input_dims = inputDesc[0].dims;
int n = input_dims.d[0];
int c = input_dims.d[1];
int h = input_dims.d[2];
int w = input_dims.nbDims > 3 ? input_dims.d[3] : 1;
int elem_size = getElementSize(inputDesc[1].type);

void* n_scales = (void*)workspace;
void* n_bias = (void*)(workspace + getAlignedSize(n * c * elem_size));

const void* scales = (const void*)inputs[1];
const void* bias = (const void*)inputs[2];

for (int i = 0; i < n; ++i) {
cudaMemcpyAsync(n_scales + i * c * elem_size, scales, c * elem_size,
cudaMemcpyDeviceToDevice, stream);
cudaMemcpyAsync(n_bias + i * c * elem_size, bias, c * elem_size,
cudaMemcpyDeviceToDevice, stream);
}

cudnnSetTensor4dDescriptor(_b_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1,
n * c, 1, 1);
cudnnDataType_t cudnn_dtype{};
convert_trt2cudnn_dtype(inputDesc[0].type, &cudnn_dtype);
cudnnSetTensor4dDescriptor(_x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c,
h, w);
cudnnSetTensor4dDescriptor(_y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c,
h, w);
float alpha = 1;
float beta = 0;
void const* x_ptr = inputs[0];
void* y_ptr = outputs[0];
cudnnSetStream(_cudnn_handle, stream);
// Note: Use of CUDNN_BATCHNORM_SPATIAL_PERSISTENT can cause numerical
// overflows (NaNs) for fp32 data in some circumstances. The lower-
// performance CUDNN_BATCHNORM_SPATIAL should be used if this is not
// acceptable.
cudnnBatchNormalizationForwardTraining(
_cudnn_handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, &beta, _x_desc,
x_ptr, _y_desc, y_ptr, _b_desc, n_scales, n_bias, 1., nullptr, nullptr,
mEpsilon, nullptr, nullptr);
return 0;
}

size_t TRTInstanceNormalization::getSerializationSize() const {
return serialized_size(mEpsilon);
}

void TRTInstanceNormalization::serialize(void* buffer) const {
serialize_value(&buffer, mEpsilon);
}

bool TRTInstanceNormalization::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
int nbOutputs) {
return ((inOut[pos].type == nvinfer1::DataType::kFLOAT ||
inOut[pos].type == nvinfer1::DataType::kHALF) &&
inOut[pos].format == nvinfer1::PluginFormat::kLINEAR &&
inOut[pos].type == inOut[0].type);
}

const char* TRTInstanceNormalization::getPluginType() const {
return PLUGIN_NAME;
}

const char* TRTInstanceNormalization::getPluginVersion() const {
return PLUGIN_VERSION;
}

IPluginV2DynamicExt* TRTInstanceNormalization::clone() const {
auto* plugin = new TRTInstanceNormalization{mLayerName, mEpsilon};
plugin->setPluginNamespace(mPluginNamespace.c_str());
return plugin;
}

nvinfer1::DataType TRTInstanceNormalization::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
return inputTypes[0];
}

// Attach the plugin object to an execution context and grant the plugin the
// access to some context resource.
void TRTInstanceNormalization::attachToContext(cudnnContext* cudnnContext,
cublasContext* cublasContext,
IGpuAllocator* gpuAllocator) {
_cudnn_handle = cudnnContext;
cudnnCreateTensorDescriptor(&_b_desc);
cudnnCreateTensorDescriptor(&_x_desc);
cudnnCreateTensorDescriptor(&_y_desc);
}

// Detach the plugin object from its execution context.
void TRTInstanceNormalization::detachFromContext() {
cudnnDestroyTensorDescriptor(_y_desc);
cudnnDestroyTensorDescriptor(_x_desc);
cudnnDestroyTensorDescriptor(_b_desc);
}

void TRTInstanceNormalization::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {}

// TRTInstanceNormalizationCreator methods
TRTInstanceNormalizationCreator::TRTInstanceNormalizationCreator() {
mPluginAttributes.clear();
mPluginAttributes.emplace_back(
PluginField("epsilon", nullptr, PluginFieldType::kFLOAT32, 1));

mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}

const char* TRTInstanceNormalizationCreator::getPluginName() const {
return PLUGIN_NAME;
}

const char* TRTInstanceNormalizationCreator::getPluginVersion() const {
return PLUGIN_VERSION;
}

IPluginV2DynamicExt* TRTInstanceNormalizationCreator::createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) {
float epsilon = 1e-5;
const PluginField* fields = fc->fields;
for (int i = 0; i < fc->nbFields; ++i) {
const char* attrName = fields[i].name;
if (!strcmp(attrName, "epsilon")) {
epsilon = *(static_cast<const float*>(fields[i].data));
}
}

TRTInstanceNormalization* obj = new TRTInstanceNormalization(name, epsilon);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}

IPluginV2DynamicExt* TRTInstanceNormalizationCreator::deserializePlugin(
const char* name, const void* serialData, size_t serialLength) {
TRTInstanceNormalization* obj =
new TRTInstanceNormalization{name, serialData, serialLength};
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
REGISTER_TENSORRT_PLUGIN(TRTInstanceNormalizationCreator);
} // namespace mmlab
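
The enqueue path above tiles the per-channel scale/bias N times and views the input as a 1 x (N*C) x H x W tensor, so cuDNN's batch-norm kernel ends up computing per-instance statistics. For reference, a CPU sketch of the instance normalization this is meant to reproduce; illustrative only, not part of this commit:

#include <cmath>
#include <vector>

// Reference instance norm over an NCHW tensor: per-(n, c) spatial mean and
// variance, then a per-channel scale and bias.
void instance_norm_ref(const std::vector<float>& x, std::vector<float>& y,
                       const std::vector<float>& scale,
                       const std::vector<float>& bias, int n, int c, int h,
                       int w, float epsilon) {
  const int hw = h * w;
  for (int i = 0; i < n * c; ++i) {  // one (n, c) slice per iteration
    const float* xp = x.data() + i * hw;
    float mean = 0.f, var = 0.f;
    for (int j = 0; j < hw; ++j) mean += xp[j];
    mean /= hw;
    for (int j = 0; j < hw; ++j) var += (xp[j] - mean) * (xp[j] - mean);
    var /= hw;
    const float s = scale[i % c] / std::sqrt(var + epsilon);
    const float b = bias[i % c];
    for (int j = 0; j < hw; ++j) y[i * hw + j] = (xp[j] - mean) * s + b;
  }
}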
98 changes: 98 additions & 0 deletions backend_ops/tensorrt/instance_norm/trt_instance_norm.hpp
@@ -0,0 +1,98 @@
// Modified from:
// https://github.com/NVIDIA/TensorRT/blob/master/plugin/instanceNormalizationPlugin/instanceNormalizationPlugin.h

#ifndef TRT_INSTANCE_NORMALIZATION_HPP
#define TRT_INSTANCE_NORMALIZATION_HPP
#include <cudnn.h>

#include <iostream>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

typedef unsigned short half_type;

namespace mmlab {
class TRTInstanceNormalization final : public TRTPluginBase {
 public:
  TRTInstanceNormalization(const std::string& name, float epsilon);

  TRTInstanceNormalization(const std::string& name, void const* serialData,
                           size_t serialLength);

  TRTInstanceNormalization() = delete;

  ~TRTInstanceNormalization() override;

  int getNbOutputs() const override;

  // DynamicExt plugins return the DimsExprs class instead of Dims.
  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
      nvinfer1::IExprBuilder& exprBuilder) override;

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const override;

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) override;

  size_t getSerializationSize() const override;

  void serialize(void* buffer) const override;

  // DynamicExt plugin supportsFormat update.
  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs, int nbOutputs) override;

  const char* getPluginType() const override;

  const char* getPluginVersion() const override;

  nvinfer1::IPluginV2DynamicExt* clone() const override;

  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const override;

  void attachToContext(cudnnContext* cudnn, cublasContext* cublas,
                       nvinfer1::IGpuAllocator* allocator) override;

  void detachFromContext() override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) override;

 private:
  float mEpsilon{};
  cudnnHandle_t _cudnn_handle{};
  cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _b_desc{};
  std::string mPluginNamespace{};
};

class TRTInstanceNormalizationCreator : public TRTPluginCreatorBase {
 public:
  TRTInstanceNormalizationCreator();

  ~TRTInstanceNormalizationCreator() override = default;

  const char* getPluginName() const override;

  const char* getPluginVersion() const override;

  nvinfer1::IPluginV2DynamicExt* createPlugin(
      const char* name, const nvinfer1::PluginFieldCollection* fc) override;

  nvinfer1::IPluginV2DynamicExt* deserializePlugin(
      const char* name, const void* serialData, size_t serialLength) override;
};
}  // namespace mmlab
#endif  // TRT_INSTANCE_NORMALIZATION_HPP
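
A hedged sketch of driving the creator directly, e.g. from a unit test, mirroring what TensorRT's plugin registry does for a TRTInstanceNormalization node once REGISTER_TENSORRT_PLUGIN has run. The field wiring follows createPlugin above; the wrapper function itself is an assumption for illustration:

#include <NvInfer.h>

#include "trt_instance_norm.hpp"

// Build the plugin through its creator. At enqueue time the plugin expects
// inputs[0] = x, inputs[1] = scale, inputs[2] = bias.
nvinfer1::IPluginV2DynamicExt* make_instance_norm_plugin(float epsilon) {
  mmlab::TRTInstanceNormalizationCreator creator;
  nvinfer1::PluginField field{"epsilon", &epsilon,
                              nvinfer1::PluginFieldType::kFLOAT32, 1};
  nvinfer1::PluginFieldCollection fc{};
  fc.nbFields = 1;
  fc.fields = &field;
  return creator.createPlugin("TRTInstanceNormalization", &fc);
}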
