Skip to content

Commit

Permalink
feat: tensorrt object detector top_k control
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz authored and mergify[bot] committed Mar 9, 2022
1 parent 6a89b83 commit 655aa48
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/backends/tensorrt/protoUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ namespace dd
return -1;
}

int findTopK(const std::string source)
int findBBoxCount(const std::string source)
{
caffe::NetParameter net;
if (!TRTReadProtoFromTextFile(source.c_str(), &net))
Expand Down Expand Up @@ -193,7 +193,7 @@ namespace dd
namespace onnx_proto
{
// https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
int findTopK(const std::string &source, const std::string &out_name)
int findBBoxCount(const std::string &source, const std::string &out_name)
{
onnx::ModelProto net;
if (!TRTReadProtoFromBinaryFile(source.c_str(), &net))
Expand Down
4 changes: 2 additions & 2 deletions src/backends/tensorrt/protoUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ namespace dd
bool findInputDimensions(const std::string &source, int &width,
int &height);
int findNClasses(const std::string source, bool bbox);
int findTopK(const std::string source);
int findBBoxCount(const std::string source);
bool isRefinedet(const std::string source);
}

namespace onnx_proto
{
int findTopK(const std::string &source, const std::string &out_name);
int findBBoxCount(const std::string &source, const std::string &out_name);
}

bool TRTReadProtoFromTextFile(const char *filename,
Expand Down
23 changes: 12 additions & 11 deletions src/backends/tensorrt/tensorrtlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ namespace dd
_datatype = tl._datatype;
_max_batch_size = tl._max_batch_size;
_max_workspace_size = tl._max_workspace_size;
_top_k = tl._top_k;
_results_height = tl._results_height;
_builder = tl._builder;
_builderc = tl._builderc;
_engineFileName = tl._engineFileName;
Expand Down Expand Up @@ -202,9 +202,6 @@ namespace dd
+ this->_mlmodel._repo);
}

if (ad.has("topk"))
_top_k = ad.get("topk").get<int>();

if (ad.has("template"))
{
_template = ad.get("template").get<std::string>();
Expand Down Expand Up @@ -553,9 +550,11 @@ namespace dd
if (_bbox)
{
if (this->_mlmodel.is_onnx_source())
_top_k = onnx_proto::findTopK(this->_mlmodel._model, out_blob);
_results_height
= onnx_proto::findBBoxCount(this->_mlmodel._model, out_blob);
else if (!this->_mlmodel._def.empty())
_top_k = caffe_proto::findTopK(this->_mlmodel._def);
_results_height
= caffe_proto::findBBoxCount(this->_mlmodel._def);
}

if (_nclasses <= 0)
Expand Down Expand Up @@ -677,8 +676,8 @@ namespace dd

_outputIndex1 = _engine->getBindingIndex("keep_count");
_buffers.resize(3);
int det_out_size = _max_batch_size * _top_k * _dims.d[2];
// int det_out_size = _max_batch_size * _top_k * 7;
int det_out_size
= _max_batch_size * _results_height * _dims.d[2];
_floatOut.resize(det_out_size);
_keepCount.resize(_max_batch_size);
if (inputc._bw)
Expand Down Expand Up @@ -873,15 +872,17 @@ namespace dd

if (_bbox)
{
int results_height = _top_k;
int top_k = _results_height;
if (output_params->top_k > 0)
top_k = output_params->top_k;
const float *outr = _floatOut.data();

// preproc yolox
std::vector<float> yolo_out;
if (_template == "yolox")
{
yolo_out = yolo_utils::parse_yolo_output(
_floatOut, num_processed, results_height, _nclasses,
_floatOut, num_processed, _results_height, _nclasses,
inputc._width, inputc._height);
outr = yolo_out.data();
};
Expand Down Expand Up @@ -913,7 +914,7 @@ namespace dd
bool leave = false;
int curi = -1;

while (true && k < results_height)
while (true && k < top_k)
{
if (!_need_nms && output_params->best_bbox > 0
&& bboxes.size() >= static_cast<size_t>(
Expand Down
3 changes: 1 addition & 2 deletions src/backends/tensorrt/tensorrtlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ namespace dd
nvinfer1::DataType _datatype = nvinfer1::DataType::kFLOAT;
int _max_batch_size = 48;
size_t _max_workspace_size = 1 << 30; // 1GB
int _top_k
= 200; // top_k parameters in ssd in dede templates, can be overriden
int _results_height = -1;
std::string _engineFileName = "TRTengine";
bool _readEngine = true;
bool _writeEngine = true;
Expand Down
1 change: 1 addition & 0 deletions src/dto/output_connector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ namespace dd
DTO_FIELD(Int32, best_bbox) = -1;
DTO_FIELD(Float32, nms_threshold) = 0.45;
DTO_FIELD(Vector<String>, confidences);
DTO_FIELD(Int32, top_k) = -1;

DTO_FIELD_INFO(image)
{
Expand Down

0 comments on commit 655aa48

Please sign in to comment.