Skip to content

Commit

Permalink
fix: exported yolox have the correct number of classes
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Jun 9, 2022
1 parent e8a70cf commit 4dac269
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 20 deletions.
16 changes: 12 additions & 4 deletions src/backends/tensorrt/models/yolo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <vector>
#include <algorithm>

#include "mllibstrategy.h"

namespace dd
{
namespace yolo_utils
Expand All @@ -28,14 +30,19 @@ namespace dd
* sorted batch id | class_id | class confidence | bbox * 4*/
static std::vector<float>
parse_yolo_output(const std::vector<float> &model_out, size_t batch_size,
size_t top_k, size_t n_classes, size_t im_width,
size_t im_height)
size_t top_k, size_t step, size_t n_classes,
size_t im_width, size_t im_height)
{
std::vector<float> vals;
vals.reserve(batch_size * top_k * 7);
size_t step = n_classes + 5;
auto batch_it = model_out.begin();

if (step < n_classes + 4 || step > n_classes + 5)
throw MLLibBadParamException("YOLOX: wrong number of classes");
// model can have a background class or not, but dede always
// requires it. We vary the offset to take account of this.
int cls_offset = step - n_classes;

for (size_t batch = 0; batch < batch_size; ++batch)
{
std::vector<std::vector<float>> result;
Expand All @@ -47,7 +54,8 @@ namespace dd
// get class id & confidence
auto max_batch_it
= std::max_element(batch_it + 5, batch_it + step);
float cls_pred = std::distance(batch_it + 5, max_batch_it);
auto ref_it = batch_it + cls_offset;
float cls_pred = std::distance(ref_it, max_batch_it);
float prob = *max_batch_it * (*(batch_it + 4));

// convert center, dims to xyxy
Expand Down
4 changes: 2 additions & 2 deletions src/backends/tensorrt/tensorrtlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -882,8 +882,8 @@ namespace dd
if (_template == "yolox")
{
yolo_out = yolo_utils::parse_yolo_output(
_floatOut, num_processed, _results_height, _nclasses,
inputc._width, inputc._height);
_floatOut, num_processed, _results_height, _dims.d[2],
_nclasses, inputc._width, inputc._height);
outr = yolo_out.data();
};

Expand Down
6 changes: 3 additions & 3 deletions tests/ut-tensorrtapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ TEST(tensorrtapi, service_predict_bbox_onnx)
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
"640,\"width\":640,\"rgb\":true},\"mllib\":{\"template\":\"yolox\","
"\"maxBatchSize\":2,\"maxWorkspaceSize\":256,\"gpuid\":0,"
"\"nclasses\":80}}}";
"\"nclasses\":81}}}";
std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);

Expand Down Expand Up @@ -285,7 +285,7 @@ TEST(tensorrtapi, service_predict_bbox_onnx)
auto &preds = jd["body"]["predictions"][cat_id]["classes"];
ASSERT_EQ(preds.Size(), 1);
std::string cl1 = preds[0]["cat"].GetString();
ASSERT_EQ(cl1, "15");
ASSERT_EQ(cl1, "16");
ASSERT_TRUE(preds[0]["prob"].GetDouble() > 0.9);
auto &bbox = preds[0]["bbox"];
ASSERT_TRUE(bbox["xmin"].GetDouble() < 50 && bbox["xmax"].GetDouble() > 200
Expand All @@ -298,7 +298,7 @@ TEST(tensorrtapi, service_predict_bbox_onnx)
auto &preds2 = jd["body"]["predictions"][dog_id]["classes"];
ASSERT_EQ(preds2.Size(), 1);
std::string cl2 = preds2[0]["cat"].GetString();
ASSERT_EQ(cl2, "16");
ASSERT_EQ(cl2, "17");
ASSERT_TRUE(preds2[0]["prob"].GetDouble() > 0.8);
auto &bbox2 = preds[0]["bbox"];
ASSERT_TRUE(bbox2["xmin"].GetDouble() < 50 && bbox2["xmax"].GetDouble() > 200
Expand Down
46 changes: 35 additions & 11 deletions tools/torch/trace_yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def main():
from yolox.models.network_blocks import SiLU

exp = get_exp(None, args.model)
exp.num_classes = args.num_classes
# dede assumes a background class (index 0) that yolox does not have, so yolox is built with one class fewer
exp.num_classes = args.num_classes - 1
logging.info("num_classes == %d" % args.num_classes)

model = exp.get_model()
Expand All @@ -59,16 +60,32 @@ def main():

if args.weights:
logging.info("Load weights from %s" % args.weights)

def load_yolox_weights():
try:
# state_dict
weights = torch.load(args.weights)["model"]
except:
# torchscript
logging.info("Detected torchscript weights")
weights = torch.jit.load(args.weights).state_dict()
weights = {k[6:] : w for k, w in weights.items()} # skip "model." prefix

model.load_state_dict(weights, strict=True)

try:
# state_dict
weights = torch.load(args.weights)["model"]
load_yolox_weights()
except:
# torchscript
logging.info("Detected torchscript weights")
weights = torch.jit.load(args.weights).state_dict()
weights = {k[6:] : w for k, w in weights.items()} # skip "model." prefix
# Legacy model
exp.num_classes = args.num_classes

exp.model = None
model = exp.get_model()
model.eval()
model.head.decode_in_inference = True

model.load_state_dict(weights, strict=True)
load_yolox_weights()
logging.info("Detected yolox trained with a background class")

elif args.backbone_weights:
logging.info("Load weights from %s" % args.backbone_weights)
Expand All @@ -85,7 +102,11 @@ def main():
if args.to_onnx:
model = replace_module(model, nn.SiLU, SiLU)

model = YoloXWrapper_TRT(model, topk = args.top_k, raw_output = not args.use_wrapper)
model = YoloXWrapper_TRT(
model,
topk = args.top_k,
raw_output = not args.use_wrapper
)
model.to(device)
model.eval()

Expand All @@ -102,7 +123,7 @@ def main():
dynamic_axes = dynamic_axes)
else:
# wrap model
model = YoloXWrapper(model, args.num_classes, postprocess)
model = YoloXWrapper(model, exp.num_classes, postprocess)
model.to(device)
model.eval()

Expand Down Expand Up @@ -166,6 +187,8 @@ def forward(self, x, ids = None, bboxes = None, labels = None):
labels[start:stop].unsqueeze(1),
self.convert_targs(bboxes[start:stop])
), dim=1)
# dd uses 0 as the background class, whereas YOLOX has no background class
targ = targ - 1
l_targs.append(targ)
max_count = max(max_count, targ.shape[0])

Expand All @@ -190,7 +213,8 @@ def forward(self, x, ids = None, bboxes = None, labels = None):
preds.append({
"boxes": pred[:,:4],
"scores": pred[:,4]*pred[:,5],
"labels": pred[:,6].to(torch.int64)
# dd uses 0 as the background class, whereas YOLOX has none: shift labels by one
"labels": pred[:,6].to(torch.int64) + 1
})

return losses, preds
Expand Down

0 comments on commit 4dac269

Please sign in to comment.