Skip to content

Commit

Permalink
feat: data augmentation for training segmentation models with torch b…
Browse files Browse the repository at this point in the history
…ackend
  • Loading branch information
beniz authored and mergify[bot] committed Nov 30, 2021
1 parent 0b5c7f3 commit b55c218
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 41 deletions.
68 changes: 51 additions & 17 deletions src/backends/torch/torchdataaug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ namespace dd
{
// geometry on bboxes
std::vector<std::vector<float>> bboxes_c = bboxes;
applyGeometryBBox(bboxes_c, geoparams, src_c.cols, src_c.rows);
if (!bboxes_c.empty()) // some bboxes remain
applyGeometryBBox(bboxes_c, geoparams, src_c.cols,
src_c.rows); // uses the stored lambda
if (!bboxes_c.empty()) // some bboxes remain
{
src = src_c;
bboxes = bboxes_c;
Expand All @@ -106,15 +107,35 @@ namespace dd
}
}

bool TorchImgRandAugCV::applyMirror(cv::Mat &src)
void TorchImgRandAugCV::augment_with_segmap(cv::Mat &src, cv::Mat &tgt)
{
GeometryParams geoparams = _geometry_params;
applyGeometry(src, geoparams, true, true);
if (!geoparams._lambda.empty())
applyGeometry(tgt, geoparams, false, false); // reuses geoparams

applyCutout(src, _cutout_params);
bool mirrored = applyMirror(src);
if (mirrored)
applyMirror(tgt, false);
int rot = applyRotate(src);
if (rot != 0)
applyRotate(tgt, false, rot);
}

/*- transforms -*/
bool TorchImgRandAugCV::applyMirror(cv::Mat &src, const bool &sample)
{
if (!_mirror)
return false;

bool mirror = false;
#pragma omp critical
{
mirror = _bernouilli(_rnd_gen);
if (sample)
mirror = _bernouilli(_rnd_gen);
else
mirror = true;
}
if (mirror)
{
Expand All @@ -137,15 +158,15 @@ namespace dd
}
}

int TorchImgRandAugCV::applyRotate(cv::Mat &src)
int TorchImgRandAugCV::applyRotate(cv::Mat &src, const bool &sample, int rot)
{
if (!_rotate)
return -1;

int rot = 0;
#pragma omp critical
{
rot = _uniform_int_rotate(_rnd_gen);
if (sample)
rot = _uniform_int_rotate(_rnd_gen);
}
if (rot == 0)
return rot;
Expand Down Expand Up @@ -355,6 +376,7 @@ namespace dd
x0min = x0;
y0min = y0;
}

x0 = ((x0max - x0min) * _uniform_real_1(_rnd_gen) + x0min);
x1 = 3 * cols - x0;
y0 = ((y0max - y0min) * _uniform_real_1(_rnd_gen) + y0min);
Expand Down Expand Up @@ -409,19 +431,23 @@ namespace dd
}

void TorchImgRandAugCV::applyGeometry(cv::Mat &src, GeometryParams &cp,
const bool &store_rparams)
const bool &store_rparams,
const bool &sample)
{
if (!cp._prob)
return;

// enlarge image
float g1 = 0.0;
if (sample)
{
float g1 = 0.0;
#pragma omp critical
{
g1 = _uniform_real_1(_rnd_gen);
}
if (g1 > cp._prob)
return;
{
g1 = _uniform_real_1(_rnd_gen);
}
if (g1 > cp._prob)
return;
}

cv::Mat src_enlarged;
getEnlargedImage(src, cp, src_enlarged);
Expand All @@ -434,12 +460,20 @@ namespace dd
// get perpective matrix
#pragma omp critical
{
getQuads(src.rows, src.cols, cp, inputQuad, outputQuad);
if (sample)
getQuads(src.rows, src.cols, cp, inputQuad, outputQuad);
}

// warp perspective
cv::Mat lambda = cv::getPerspectiveTransform(inputQuad, outputQuad);
cv::warpPerspective(src_enlarged, src, lambda, src.size());
cv::Mat lambda
= (sample ? cv::getPerspectiveTransform(inputQuad, outputQuad)
: cp._lambda);
int inter_flag
= cv::INTER_NEAREST; //(sample ? cv::INTER_LINEAR : cv::INTER_NEAREST);
int border_mode = (cp._geometry_pad_mode == 1 ? cv::BORDER_CONSTANT
: cv::BORDER_REPLICATE);
cv::warpPerspective(src_enlarged, src, lambda, src.size(), inter_flag,
border_mode);

if (store_rparams)
cp._lambda = lambda;
Expand Down
10 changes: 6 additions & 4 deletions src/backends/torch/torchdataaug.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ namespace dd
= 0.25; /**< persp factor: 0.25 means that new
image corners be in 1.25 or 0.75. */
uint8_t _geometry_pad_mode = 1; /**< filling around images, 1: constant, 2:
mirrored, 3: repeat nearest. */
repeat nearest (replicate). */
float _geometry_bbox_intersect
= 0.75; /**< warped bboxes must at least have a 75% intersect with the
original bbox, otherwise they are filtered out.*/
Expand Down Expand Up @@ -192,12 +192,13 @@ namespace dd

void augment(cv::Mat &src);
void augment_with_bbox(cv::Mat &src, std::vector<torch::Tensor> &targets);
void augment_with_segmap(cv::Mat &src, cv::Mat &tgt);

protected:
bool applyMirror(cv::Mat &src);
bool applyMirror(cv::Mat &src, const bool &sample = true);
void applyMirrorBBox(std::vector<std::vector<float>> &bboxes,
const float &img_width);
int applyRotate(cv::Mat &src);
int applyRotate(cv::Mat &src, const bool &sample = true, int rot = 0);
void applyRotateBBox(std::vector<std::vector<float>> &bboxes,
const float &img_width, const float &img_height,
const int &rot);
Expand All @@ -206,7 +207,8 @@ namespace dd
void applyCutout(cv::Mat &src, CutoutParams &cp,
const bool &store_rparams = false);
void applyGeometry(cv::Mat &src, GeometryParams &cp,
const bool &store_rparams = false);
const bool &store_rparams = false,
const bool &sample = true);
void applyGeometryBBox(std::vector<std::vector<float>> &bboxes,
const GeometryParams &cp, const int &img_width,
const int &img_height);
Expand Down
47 changes: 33 additions & 14 deletions src/backends/torch/torchdataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,23 @@ namespace dd
}

void TorchDataset::image_to_stringstream(const cv::Mat &img,
std::ostringstream &dstream)
std::ostringstream &dstream,
const bool &lossless)
{
std::vector<uint8_t> buffer;
std::vector<int> param = { cv::IMWRITE_JPEG_QUALITY, 100 };
cv::imencode(".jpg", img, buffer, param);
std::vector<int> param;
std::string ext;
if (!lossless)
{
ext = ".jpg";
param = { cv::IMWRITE_JPEG_QUALITY, 100 };
}
else
{
ext = ".png";
param = { cv::IMWRITE_PNG_COMPRESSION, 1 };
}
cv::imencode(ext, img, buffer, param);
for (uint8_t c : buffer)
dstream << c;
}
Expand All @@ -134,7 +146,7 @@ namespace dd
{
// serialize image
std::ostringstream dstream;
image_to_stringstream(bgr, dstream);
image_to_stringstream(bgr, dstream, true);

// serialize target
std::ostringstream tstream;
Expand All @@ -148,11 +160,11 @@ namespace dd
{
// serialize image
std::ostringstream dstream;
image_to_stringstream(bgr, dstream);
image_to_stringstream(bgr, dstream, true);

// serialize target
std::ostringstream tstream;
image_to_stringstream(bw_target, tstream);
image_to_stringstream(bw_target, tstream, false);

write_image_to_db(dstream, tstream, bgr.rows, bgr.cols);
}
Expand Down Expand Up @@ -193,14 +205,13 @@ namespace dd
const std::string &targets,
cv::Mat &bgr,
std::vector<torch::Tensor> &targett,
const bool &bw, const int &width,
const int &height)
cv::Mat &bw_target, const bool &bw,
const int &width, const int &height)
{
std::vector<uint8_t> img_data(datas.begin(), datas.end());
bgr = cv::Mat(img_data, true);
bgr = cv::imdecode(bgr,
bw ? CV_LOAD_IMAGE_GRAYSCALE : CV_LOAD_IMAGE_COLOR);
cv::Mat bw_target; // for segmentation only.

if (_segmentation)
{
Expand Down Expand Up @@ -231,6 +242,7 @@ namespace dd

if (_segmentation)
{

cv::resize(bw_target, bw_target, cv::Size(width, height), 0, 0,
cv::INTER_NEAREST);
}
Expand Down Expand Up @@ -524,9 +536,10 @@ namespace dd
ImgTorchInputFileConn *inputc
= reinterpret_cast<ImgTorchInputFileConn *>(_inputc);

cv::Mat bgr;
read_image_from_db(datas, targets, bgr, t, inputc->_bw,
inputc->width(), inputc->height());
cv::Mat bgr, bw_target;
read_image_from_db(datas, targets, bgr, t, bw_target,
inputc->_bw, inputc->width(),
inputc->height());

// data augmentation can apply here, with OpenCV
if (!_test)
Expand All @@ -535,16 +548,22 @@ namespace dd
_img_rand_aug_cv.augment_with_bbox(bgr, t);
else if (_segmentation)
{
// TODO: augment for segmentation
_img_rand_aug_cv.augment_with_segmap(bgr, bw_target);
}
else
_img_rand_aug_cv.augment(bgr);
}

torch::Tensor imgt
= image_to_tensor(bgr, inputc->height(), inputc->width());

d.push_back(imgt);

if (_segmentation)
{
at::Tensor targett_seg = image_to_tensor(
bw_target, inputc->height(), inputc->width(), true);
t.push_back(targett_seg);
}
}

for (unsigned int i = 0; i < d.size(); ++i)
Expand Down
8 changes: 4 additions & 4 deletions src/backends/torch/torchdataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,8 @@ namespace dd
/**
* \brief converts and image to a serialized string
*/
void image_to_stringstream(const cv::Mat &img,
std::ostringstream &dstream);
void image_to_stringstream(const cv::Mat &img, std::ostringstream &dstream,
const bool &lossless = true);

/**
* \brief writes encoded image to db with a tensor target
Expand All @@ -378,8 +378,8 @@ namespace dd
void read_image_from_db(const std::string &datas,
const std::string &targets, cv::Mat &bgr,
std::vector<torch::Tensor> &targett,
const bool &bw, const int &width,
const int &height);
cv::Mat &bw_target, const bool &bw,
const int &width, const int &height);
};

/**
Expand Down
7 changes: 5 additions & 2 deletions tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,10 @@ TEST(torchapi, service_train_image_segmentation)
+ iterations_deeplabv3 + ",\"base_lr\":" + torch_lr
+ ",\"iter_size\":1,\"solver_type\":\"ADAM\",\"test_"
"interval\":100},\"net\":{\"batch_size\":4},"
"\"resume\":false},"
"\"resume\":false,\"mirror\":true,\"rotate\":true,\"crop_size\":224,"
"\"cutout\":0.5,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":1}},"
"\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true,"
"\"segmentation\":true,\"scale\":0.0039,\"mean\":[0.485,0.456,0.406]"
",\"std\":[0.229,0.224,0.225]},"
Expand All @@ -782,7 +785,7 @@ TEST(torchapi, service_train_image_segmentation)
ASSERT_EQ(201, jd["status"]["code"]);

ASSERT_TRUE(jd["body"]["measure"]["meanacc"].GetDouble() <= 1) << "accuracy";
ASSERT_TRUE(jd["body"]["measure"]["meanacc"].GetDouble() >= 0.01)
ASSERT_TRUE(jd["body"]["measure"]["meanacc"].GetDouble() >= 0.007)
<< "accuracy good";
ASSERT_TRUE(jd["body"]["measure"]["meaniou"].GetDouble() <= 1) << "meaniou";

Expand Down

0 comments on commit b55c218

Please sign in to comment.