feat: random crops for object detector training with torch backend
beniz committed Feb 16, 2022
1 parent cd23bad commit 385122d
Showing 4 changed files with 125 additions and 8 deletions.
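At its core, the change makes random crops bbox-aware: when a crop_size x crop_size window is taken at a random offset (crop_x, crop_y), each ground-truth box is translated into crop coordinates, clipped to the window, and dropped entirely if it no longer overlaps it. A minimal standalone sketch of that remapping, using a hypothetical Box struct rather than the repository's types (the actual code below clamps against the already-cropped image dimensions instead of crop_size):

#include <algorithm>
#include <vector>

struct Box { float xmin, ymin, xmax, ymax; int cls; };

// Translate boxes into the frame of a crop_size x crop_size window taken at
// (crop_x, crop_y); boxes with no overlap with the window are discarded.
std::vector<Box> remap_boxes_to_crop(const std::vector<Box> &boxes,
                                     float crop_x, float crop_y,
                                     float crop_size)
{
  std::vector<Box> out;
  for (const Box &b : boxes)
    {
      if (b.xmax < crop_x || b.xmin > crop_x + crop_size)
        continue; // no horizontal overlap
      if (b.ymax < crop_y || b.ymin > crop_y + crop_size)
        continue; // no vertical overlap
      Box n;
      n.xmin = std::max(0.f, b.xmin - crop_x);
      n.ymin = std::max(0.f, b.ymin - crop_y);
      n.xmax = std::min(crop_size, b.xmax - crop_x);
      n.ymax = std::min(crop_size, b.ymax - crop_y);
      n.cls = b.cls;
      out.push_back(n);
    }
  return out;
}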
121 changes: 116 additions & 5 deletions src/backends/torch/torchdataaug.cc
@@ -20,6 +20,7 @@
*/

#include "torchdataaug.h"
#include "torchdataset.h"

namespace dd
{
@@ -67,8 +68,10 @@
std::vector<torch::Tensor> &targets)
{
torch::Tensor t = targets[0];
torch::Tensor c = targets[1]; // classes
int nbbox = t.size(0);
std::vector<std::vector<float>> bboxes;
std::vector<int> classes;
for (int bb = 0; bb < nbbox; ++bb)
{
std::vector<float> bbox;
@@ -77,6 +80,7 @@
bbox.push_back(t[bb][d].item<float>());
}
bboxes.push_back(bbox); // add (xmin, ymin, xmax, ymax)
classes.push_back(c[bb].item<int>());
}

bool mirror = applyMirror(src);
@@ -90,6 +94,15 @@
applyRotateBBox(bboxes, static_cast<float>(src.cols),
static_cast<float>(src.rows), rot);
}
int crop_x = 0;
int crop_y = 0;
bool cropped = applyCrop(src, _crop_params, crop_x, crop_y);
if (cropped)
{
applyCropBBox(bboxes, classes, _crop_params,
static_cast<float>(src.cols),
static_cast<float>(src.rows), crop_x, crop_y);
}
applyCutout(src, _cutout_params);
GeometryParams geoparams = _geometry_params;
cv::Mat src_c = src.clone();
@@ -109,17 +122,77 @@

// replacing the initial bboxes with the transformed ones.
nbbox = bboxes.size();
      if (!cropped)
        {
          for (int bb = 0; bb < nbbox; ++bb)
            {
              for (int d = 0; d < 4; ++d)
                {
                  t[bb][d] = bboxes.at(bb).at(d);
                }
            }
        }
else
{
std::vector<torch::Tensor> tbboxes;
std::vector<torch::Tensor> tclasses;
for (int bb = 0; bb < nbbox; ++bb)
{
std::vector<double> fbbox(bboxes.at(bb).begin(),
bboxes.at(bb).end());
TorchDataset td;
tbboxes.push_back(td.target_to_tensor(fbbox));
tclasses.push_back(td.target_to_tensor(classes.at(bb)));
}
targets = { torch::stack(tbboxes), torch::cat(tclasses) };
}
applyNoise(src);
applyDistort(src);
}
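Since cropping can remove boxes (or leave only the dummy box), the cropped branch above rebuilds the two target tensors rather than writing them back in place. A rough, self-contained sketch of what that rebuild amounts to in plain libtorch, not the repository's TorchDataset::target_to_tensor helpers (pack_targets and the exact dtypes are illustrative):

#include <torch/torch.h>
#include <vector>

// Pack per-box coordinates and class ids into the two-tensor target layout:
// an [N, 4] float tensor of boxes and an [N] integer tensor of labels.
std::vector<torch::Tensor> pack_targets(
    const std::vector<std::vector<float>> &bboxes,
    const std::vector<int> &classes)
{
  std::vector<torch::Tensor> rows;
  for (const auto &b : bboxes)
    rows.push_back(torch::tensor(std::vector<float>(b.begin(), b.end())));
  std::vector<int64_t> cls64(classes.begin(), classes.end());
  return { torch::stack(rows), torch::tensor(cls64) };
}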

void TorchImgRandAugCV::augment_test_with_bbox(
cv::Mat &src, std::vector<torch::Tensor> &targets)
{
torch::Tensor t = targets[0];
torch::Tensor c = targets[1]; // classes
int nbbox = t.size(0);
std::vector<std::vector<float>> bboxes;
std::vector<int> classes;
for (int bb = 0; bb < nbbox; ++bb)
{
std::vector<float> bbox;
for (int d = 0; d < 4; ++d)
{
bbox.push_back(t[bb][d].item<float>());
}
bboxes.push_back(bbox); // add (xmin, ymin, xmax, ymax)
classes.push_back(c[bb].item<int>());
}
int crop_x = 0;
int crop_y = 0;
bool cropped = applyCrop(src, _crop_params, crop_x, crop_y);
if (cropped)
{
applyCropBBox(bboxes, classes, _crop_params,
static_cast<float>(src.cols),
static_cast<float>(src.rows), crop_x, crop_y);

// replacing the initial bboxes with the transformed ones.
std::vector<torch::Tensor> tbboxes;
std::vector<torch::Tensor> tclasses;
nbbox = bboxes.size();
for (int bb = 0; bb < nbbox; ++bb)
{
std::vector<double> fbbox(bboxes.at(bb).begin(),
bboxes.at(bb).end());
TorchDataset td;
tbboxes.push_back(td.target_to_tensor(fbbox));
tclasses.push_back(td.target_to_tensor(classes.at(bb)));
}
targets = { torch::stack(tbboxes), torch::cat(tclasses) };
}
}

void TorchImgRandAugCV::augment_with_segmap(cv::Mat &src, cv::Mat &tgt)
{
GeometryParams geoparams = _geometry_params;
@@ -292,6 +365,44 @@
return true;
}

void TorchImgRandAugCV::applyCropBBox(
std::vector<std::vector<float>> &bboxes, std::vector<int> &classes,
const CropParams &cp, const float &img_width, const float &img_height,
const float &crop_x, const float &crop_y)
{
// apply crop to bboxes.
std::vector<std::vector<float>> nbboxes;
std::vector<int> nclasses;
for (size_t i = 0; i < bboxes.size(); ++i)
{
std::vector<float> bbox = bboxes.at(i);

          if (bbox[2] < crop_x
              || bbox[0] > crop_x + cp._crop_size) // xmax < crop_x or
                                                   // xmin > crop_x + crop_size
            continue; // box entirely outside the crop horizontally
          if (bbox[3] < crop_y
              || bbox[1] > crop_y + cp._crop_size) // ymax < crop_y or
                                                   // ymin > crop_y + crop_size
            continue; // box entirely outside the crop vertically

std::vector<float> nbox;
nbox.push_back(std::max(0.f, bbox[0] - crop_x)); // xmin = xmin-crop_x
nbox.push_back(std::max(0.f, bbox[1] - crop_y));
nbox.push_back(std::min(img_width, bbox[2] - crop_x));
nbox.push_back(std::min(img_height, bbox[3] - crop_y));
nbboxes.push_back(nbox);
nclasses.push_back(classes.at(i));
}
if (nbboxes.empty())
{
nbboxes.push_back({ 0.0, 0.0, 0.0, 0.0 });
nclasses.push_back(0);
}
bboxes = nbboxes;
classes = nclasses;
}
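A concrete example with illustrative numbers: for crop_x = 100, crop_y = 50 and _crop_size = 512, a box (90, 40, 200, 150) is kept and becomes (0, 0, 100, 100) in crop coordinates, while a box with xmax < 100 or xmin > 612 is discarded. If every box is discarded, a single (0, 0, 0, 0) box with class 0 is kept as a placeholder so the target tensors never end up empty. Note that at both call sites in this diff, img_width and img_height are the dimensions of the already-cropped image, so the min() clamps bound the boxes to the crop window.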

void TorchImgRandAugCV::applyCutout(cv::Mat &src, CutoutParams &cp,
const bool &store_rparams)
{
Expand Down Expand Up @@ -498,8 +609,8 @@ namespace dd
cv::Mat lambda
= (sample ? cv::getPerspectiveTransform(inputQuad, outputQuad)
: cp._lambda);
-      int inter_flag
-          = cv::INTER_NEAREST; //(sample ? cv::INTER_LINEAR : cv::INTER_NEAREST);
+      int inter_flag = cv::INTER_NEAREST; //(sample ? cv::INTER_LINEAR :
+                                          // cv::INTER_NEAREST);
int border_mode = (cp._geometry_pad_mode == 1 ? cv::BORDER_CONSTANT
: cv::BORDER_REPLICATE);
cv::warpPerspective(src_enlarged, src, lambda, src.size(), inter_flag,
6 changes: 6 additions & 0 deletions src/backends/torch/torchdataaug.h
@@ -286,6 +286,8 @@
void augment_with_segmap(cv::Mat &src, cv::Mat &tgt);

void augment_test(cv::Mat &src);
void augment_test_with_bbox(cv::Mat &src,
std::vector<torch::Tensor> &targets);
void augment_test_with_segmap(cv::Mat &src, cv::Mat &tgt);

protected:
@@ -299,6 +301,10 @@
const int &rot);
bool applyCrop(cv::Mat &src, CropParams &cp, int &crop_x, int &crop_y,
const bool &sample = true);
void applyCropBBox(std::vector<std::vector<float>> &bboxes,
std::vector<int> &classes, const CropParams &cp,
const float &img_width, const float &img_height,
const float &crop_x, const float &crop_y);
void applyCutout(cv::Mat &src, CutoutParams &cp,
const bool &store_rparams = false);
void applyGeometry(cv::Mat &src, GeometryParams &cp,
4 changes: 2 additions & 2 deletions src/backends/torch/torchdataset.cc
@@ -548,9 +548,9 @@
// cropping requires test set 'augmentation'
if (_bbox)
{
-                      // no cropping yet with bboxes
+                      _img_rand_aug_cv.augment_test_with_bbox(bgr, t);
                     }
-                  if (_segmentation)
+                  else if (_segmentation)
_img_rand_aug_cv.augment_test_with_segmap(bgr,
bw_target);
else
2 changes: 1 addition & 1 deletion tests/ut-torchapi.cc
@@ -1638,7 +1638,7 @@ TEST(torchapi, service_train_object_detection_yolox)
+ ",\"iter_size\":2,\"solver_"
"type\":\"ADAM\",\"test_interval\":200},\"net\":{\"batch_size\":2,"
"\"test_batch_size\":2,\"reg_weight\":0.5},\"resume\":false,"
"\"mirror\":true,\"rotate\":true,\"crop_size\":640,"
"\"mirror\":true,\"rotate\":true,\"crop_size\":512,"
"\"cutout\":0.1,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":\"constant\"},\"noise\":{\"prob\":0.01},\"distort\":{"
