feat: recompose action to recreate an image from a GAN + crop
Bycob authored and mergify[bot] committed Sep 4, 2023
1 parent 7a3641e commit e1118b1
Showing 15 changed files with 392 additions and 39 deletions.
9 changes: 9 additions & 0 deletions src/apidata.cc
@@ -21,6 +21,8 @@

#include "apidata.h"

#include "utils/utils.hpp"

namespace dd
{
/*- visitor_vad -*/
@@ -244,4 +246,11 @@ namespace dd
}
}

std::string APIData::toJSONString() const
{
JDoc jd;
jd.SetObject();
toJDoc(jd);
return dd_utils::jrender(jd);
}
}
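A minimal usage sketch for the new helper, with hypothetical keys and values (the add() overloads are the ones used elsewhere in this changeset):

dd::APIData ad;
ad.add("uri", std::string("img0.png"));
ad.add("cids", std::vector<std::string>{ "0", "1" });
std::string json = ad.toJSONString();
// json now holds something like {"uri":"img0.png","cids":["0","1"]}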
5 changes: 5 additions & 0 deletions src/apidata.h
@@ -285,6 +285,11 @@
*/
void toJVal(JDoc &jd, JVal &jv) const;

/**
* \brief converts APIData to json string
*/
std::string toJSONString() const;

/**
* \brief converts APIData to oat++ DTO
*/
30 changes: 24 additions & 6 deletions src/backends/tensorrt/tensorrtlib.cc
@@ -139,6 +139,12 @@ namespace dd
_floatOut = tl._floatOut;
_keepCount = tl._keepCount;
_dims = tl._dims;
_error_recorder = tl._error_recorder;
_calibrator = tl._calibrator;
_engine = tl._engine;
_builder = tl._builder;
_context = tl._context;
_builderc = tl._builderc;
_runtime = tl._runtime;
}

@@ -147,6 +153,12 @@
TensorRTLib<TInputConnectorStrategy, TOutputConnectorStrategy,
TMLModel>::~TensorRTLib()
{
// Delete objects in the correct order
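// (an execution context must be released before the engine it was created
// from, and the engine before the builder / runtime that produced it)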
_calibrator = nullptr;
_context = nullptr;
_engine = nullptr;
_builderc = nullptr;
_builder = nullptr;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
@@ -158,7 +170,8 @@
initLibNvInferPlugins(&trtLogger, "");
_runtime = std::shared_ptr<nvinfer1::IRuntime>(
nvinfer1::createInferRuntime(trtLogger));
_runtime->setErrorRecorder(new TRTErrorRecorder(this->_logger));
_error_recorder.reset(new TRTErrorRecorder(this->_logger));
_runtime->setErrorRecorder(_error_recorder.get());
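// TensorRT does not take ownership of the error recorder, so keeping it
// as a member ensures it outlives the runtime that uses it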

if (ad.has("tensorRTEngineFile"))
_engineFileName = ad.get("tensorRTEngineFile").get<std::string>();
@@ -377,7 +390,8 @@ namespace dd
break;
}

nvinfer1::INetworkDefinition *network = _builder->createNetworkV2(0U);
std::unique_ptr<nvinfer1::INetworkDefinition> network(
_builder->createNetworkV2(0U));
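// owning the network definition through a unique_ptr releases it when
// this function returns, including on error paths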
nvcaffeparser1::ICaffeParser *caffeParser
= nvcaffeparser1::createCaffeParser();

@@ -426,7 +440,6 @@ namespace dd
outl->setPrecision(nvinfer1::DataType::kFLOAT);
nvinfer1::IHostMemory *n
= _builder->buildSerializedNetwork(*network, *_builderc);

return _runtime->deserializeCudaEngine(n->data(), n->size());
}

@@ -439,8 +452,8 @@
const auto explicitBatch
= 1U << static_cast<uint32_t>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
nvinfer1::INetworkDefinition *network
= _builder->createNetworkV2(explicitBatch);
std::unique_ptr<nvinfer1::INetworkDefinition> network(
_builder->createNetworkV2(explicitBatch));
_explicit_batch = true;

nvonnxparser::IParser *onnxParser
@@ -473,7 +486,6 @@
if (n == nullptr)
throw MLLibInternalException("Could not build model: "
+ this->_mlmodel._model);

return _runtime->deserializeCudaEngine(n->data(), n->size());
}

@@ -506,6 +518,12 @@
if (ad.has("data_raw_img"))
predict_dto->_data_raw_img
= ad.get("data_raw_img").get<std::vector<cv::Mat>>();
#ifdef USE_CUDA_CV
if (ad.has("data_raw_img_cuda"))
predict_dto->_data_raw_img_cuda
= ad.get("data_raw_img_cuda")
.get<std::vector<cv::cuda::GpuMat>>();
#endif
if (ad.has("ids"))
predict_dto->_ids = ad.get("ids").get<std::vector<std::string>>();
if (ad.has("meta_uris"))
5 changes: 3 additions & 2 deletions src/backends/tensorrt/tensorrtlib.h
@@ -20,11 +20,11 @@
#ifndef TENSORRTLIB_H
#define TENSORRTLIB_H

#include "tensorrtmodel.h"
#include "apidata.h"
#include "NvCaffeParser.h"
#include "NvInfer.h"

#include "apidata.h"
#include "tensorrtmodel.h"
#include "error_recorder.hpp"

namespace dd
@@ -127,6 +127,7 @@ namespace dd
_template; /**< template for models that require specific treatment */

//!< The TensorRT engine used to run the network
std::shared_ptr<TRTErrorRecorder> _error_recorder = nullptr;
std::shared_ptr<nvinfer1::IInt8Calibrator> _calibrator = nullptr;
std::shared_ptr<nvinfer1::ICudaEngine> _engine = nullptr;
std::shared_ptr<nvinfer1::IBuilder> _builder = nullptr;
4 changes: 2 additions & 2 deletions src/chain.cc
@@ -96,7 +96,7 @@ namespace dd
for (auto p : *body->predictions)
{
std::string uri = p->uri;
p->uri = model_name.c_str();
p->uri = model_name;
other_models_out.insert(
std::pair<std::string, oatpp::Object<DTO::Prediction>>(uri,
p));
@@ -123,7 +123,7 @@ namespace dd
for (auto p : *out_body->predictions)
{
std::string uri = p->uri;
p->uri = action_id.c_str();
p->uri = action_id;
other_models_out.insert(
std::pair<std::string, oatpp::Object<DTO::Prediction>>(uri,
p));
210 changes: 207 additions & 3 deletions src/chain_actions.cc
@@ -185,7 +185,6 @@ namespace dd
}

cv::Rect roi(cxmin, cymin, cxmax - cxmin, cymax - cymin);

#ifdef USE_CUDA_CV
if (!cuda_imgs.empty())
{
@@ -235,11 +234,16 @@
}
// store crops into action output store
APIData action_out;
action_out.add("data_raw_img", cropped_imgs);
#ifdef USE_CUDA_CV
if (!cropped_cuda_imgs.empty())
action_out.add("data_cuda_img", cropped_cuda_imgs);
{
action_out.add("data_cuda_img", cropped_cuda_imgs);
}
else
#endif
{
action_out.add("data_raw_img", cropped_imgs);
}
action_out.add("cids", bbox_ids);
cdata.add_action_data(_action_id, action_out);

@@ -326,6 +330,205 @@ namespace dd
cdata.add_action_data(_action_id, action_out);
}

void make_even(cv::Mat &mat, int &width, int &height)
{
if (width % 2 != 0)
width -= 1;
if (height % 2 != 0)
height -= 1;

if (width != mat.cols || height != mat.rows)
{
cv::Rect roi{ 0, 0, width, height };
mat = mat(roi);
}
}
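make_even trims a generated crop with an odd width or height by one pixel, so that the gen_width / 2, gen_height / 2 centering used below cannot push the paste ROI outside the image. A quick illustration with hypothetical sizes:

cv::Mat crop(101, 75, CV_8UC3);   // 75 px wide, 101 px high
int w = crop.cols, h = crop.rows; // w = 75, h = 101
make_even(crop, w, h);            // w = 74, h = 100; crop is now a 74x100 view on the same data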

void ImgsCropRecomposeAction::apply(APIData &model_out, ChainData &cdata)
{
APIData first_model = cdata.get_model_data("0");
APIData input_ad = first_model.getobj("input");
// image
std::vector<cv::Mat> imgs;
#ifdef USE_CUDA_CV
std::vector<cv::cuda::GpuMat> cuda_imgs;
std::vector<cv::cuda::GpuMat> cropped_cuda_imgs;

if (input_ad.has("cuda_imgs"))
{
cuda_imgs
= input_ad.get("cuda_imgs").get<std::vector<cv::cuda::GpuMat>>();
}
else
#endif
{
imgs = input_ad.get("imgs").get<std::vector<cv::Mat>>();
}
std::vector<std::pair<int, int>> imgs_size
= input_ad.get("imgs_size").get<std::vector<std::pair<int, int>>>();

// bbox
std::vector<APIData> vad = first_model.getv("predictions");

// generated images
// XXX: Images are always written to RAM first.
// This may change in the future.
std::map<std::string, cv::Mat> gen_imgs;
if (model_out.has("dto"))
{
auto dto = model_out.get("dto")
.get<oatpp::Any>()
.retrieve<oatpp::Object<DTO::PredictBody>>();
for (size_t i = 0; i < dto->predictions->size(); ++i)
{
auto images = dto->predictions->at(i)->images;
if (images->size() == 0)
throw ActionBadParamException(
"Recompose requires output.image = true in previous model");
gen_imgs.insert(
{ *dto->predictions->at(i)->uri, images->at(0)->get_img() });
}
}
else
{
throw ActionBadParamException("Recompose action requires GAN output");
}

std::vector<cv::Mat> rimgs;
std::vector<std::string> uris;

bool save_img = _params->save_img;
std::string save_path = _params->save_path;
if (!save_path.empty())
save_path += "/";

auto pred_body = DTO::PredictBody::createShared();

// needed: the original image, the bbox coordinates and the generated image
for (size_t i = 0; i < vad.size(); i++)
{
std::string uri = vad.at(i).get("uri").get<std::string>();
uris.push_back(uri);

cv::Mat input_img;
int input_width, input_height;
cv::Mat rimg;
#ifdef USE_CUDA_CV
cv::cuda::GpuMat cuda_input_img;
cv::cuda::GpuMat cuda_rimg;
if (!cuda_imgs.empty())
{
cuda_input_img = cuda_imgs.at(i);
input_width = cuda_input_img.cols;
input_height = cuda_input_img.rows;
cuda_rimg = cuda_input_img.clone();
}
else
#endif
{
input_img = imgs.at(i);
input_width = input_img.cols;
input_height = input_img.rows;
rimg = input_img.clone();
}
int orig_width = imgs_size.at(i).second;
int orig_height = imgs_size.at(i).first;

std::vector<APIData> ad_cls = vad.at(i).getv("classes");
APIData bbox;

for (size_t j = 0; j < ad_cls.size(); j++)
{
bbox = ad_cls.at(j).getobj("bbox");
std::string cls_id
= ad_cls.at(j).get("class_id").get<std::string>();

cv::Mat gen_img = gen_imgs.at(cls_id);
int gen_width = gen_img.cols;
int gen_height = gen_img.rows;

// handle odd width & height by trimming the crop to even dimensions
make_even(gen_img, gen_width, gen_height);

if (gen_width > input_width || gen_height > input_height)
{
throw ActionBadParamException(
"Recomposing image is impossible, crop is too big: "
+ std::to_string(gen_width) + ","
+ std::to_string(gen_height) + "/"
+ std::to_string(orig_width) + ","
+ std::to_string(orig_height));
}

double xmin
= bbox.get("xmin").get<double>() / orig_width * input_width;
double ymin
= bbox.get("ymin").get<double>() / orig_height * input_height;
double xmax
= bbox.get("xmax").get<double>() / orig_width * input_width;
double ymax
= bbox.get("ymax").get<double>() / orig_height * input_height;

int cx = static_cast<int>((xmin + xmax) / 2);
int cy = static_cast<int>((ymin + ymax) / 2);
cx = std::min(std::max(cx, gen_width / 2),
input_width - gen_width / 2);
cy = std::min(std::max(cy, gen_height / 2),
input_height - gen_height / 2);

cv::Rect roi{ cx - gen_width / 2, cy - gen_height / 2, gen_width,
gen_height };
#ifdef USE_CUDA_CV
if (cuda_rimg.cols != 0)
{
cv::cuda::GpuMat cuda_gen_img;
cuda_gen_img.upload(gen_img);
cuda_gen_img.copyTo(cuda_rimg(roi));
}
else
#endif
{
gen_img.copyTo(rimg(roi));
}
}

rimgs.push_back(rimg);

auto action_pred = DTO::Prediction::createShared();
action_pred->uri = uri.c_str();
action_pred->images = oatpp::Vector<DTO::DTOImage>::createShared();
#ifdef USE_CUDA_CV
if (cuda_rimg.cols != 0)
{
action_pred->images->push_back({ cuda_rimg });
}
else
#endif
{
action_pred->images->push_back({ rimg });
}
pred_body->predictions->push_back(action_pred);

// save image if requested
if (save_img)
{
std::string puri = dd_utils::split(uri, '/').back();
#ifdef USE_CUDA_CV
if (cuda_rimg.cols != 0)
cuda_rimg.download(rimg);
#endif
cv::imwrite(save_path + "recompose_" + puri + ".png", rimg);
}
}

// Output: new image -> only works in "image" output mode
APIData action_out;
action_out.add("data_raw_img", rimgs);
action_out.add("cids", uris);
action_out.add("output", pred_body);
cdata.add_action_data(_action_id, action_out);
}

cv::Scalar bbox_palette[]
= { { 82, 188, 227 }, { 196, 110, 49 }, { 39, 54, 227 },
{ 68, 227, 81 }, { 77, 157, 255 }, { 255, 112, 207 },
@@ -533,6 +736,7 @@ namespace dd

CHAIN_ACTION("crop", ImgsCropAction)
CHAIN_ACTION("rotate", ImgsRotateAction)
CHAIN_ACTION("recompose", ImgsCropRecomposeAction)
CHAIN_ACTION("draw_bbox", ImgsDrawBBoxAction)
CHAIN_ACTION("filter", ClassFilter)
#ifdef USE_DLIB
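For reference, the paste position computed in ImgsCropRecomposeAction::apply is derived from each detection bbox: the bbox is rescaled from the original image resolution to the network input resolution, its center is clamped so the generated crop stays fully inside the image, and the crop is copied over the resulting ROI. A standalone sketch of the same arithmetic with hypothetical values (OpenCV only):

#include <opencv2/core.hpp>
#include <algorithm>
#include <iostream>

int main()
{
  // hypothetical numbers: 512x512 network input, 1920x1080 original image,
  // 128x128 generated crop for a bbox detected near the right edge
  int input_width = 512, input_height = 512;
  int orig_width = 1920, orig_height = 1080;
  int gen_width = 128, gen_height = 128;
  double xmin = 1700, ymin = 400, xmax = 1900, ymax = 600; // bbox in original coordinates

  // rescale the bbox to the input resolution
  double sxmin = xmin / orig_width * input_width;
  double sxmax = xmax / orig_width * input_width;
  double symin = ymin / orig_height * input_height;
  double symax = ymax / orig_height * input_height;

  // bbox center, clamped so the crop fits entirely inside the input
  int cx = static_cast<int>((sxmin + sxmax) / 2);
  int cy = static_cast<int>((symin + symax) / 2);
  cx = std::min(std::max(cx, gen_width / 2), input_width - gen_width / 2);
  cy = std::min(std::max(cy, gen_height / 2), input_height - gen_height / 2);

  cv::Rect roi{ cx - gen_width / 2, cy - gen_height / 2, gen_width, gen_height };
  std::cout << roi << std::endl; // [128 x 128 from (384, 173)]: cx was pulled
                                 // back from 480 to 448 so the crop fits
  return 0;
}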