Skip to content

Commit

Permalink
feat(ml): random cropping for training segmentation models with torch
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz authored and mergify[bot] committed Jan 8, 2022
1 parent 04ef758 commit ac7ce0f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
37 changes: 21 additions & 16 deletions src/backends/torch/torchdataaug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ namespace dd
applyCutout(src, _cutout_params);

// these transforms do affect dimensions
applyCrop(src, _crop_params);
int crop_x = 0;
int crop_y = 0;
applyCrop(src, _crop_params, crop_x, crop_y);
applyMirror(src);
applyRotate(src);
applyNoise(src);
Expand Down Expand Up @@ -119,6 +121,13 @@ namespace dd
applyGeometry(tgt, geoparams, false, false); // reuses geoparams

applyCutout(src, _cutout_params);

int crop_x = 0;
int crop_y = 0;
bool cropped = applyCrop(src, _crop_params, crop_x, crop_y);
if (cropped)
applyCrop(tgt, _crop_params, crop_x, crop_y, false);

bool mirrored = applyMirror(src);
if (mirrored)
applyMirror(tgt, false);
Expand Down Expand Up @@ -247,28 +256,24 @@ namespace dd
bboxes = nbboxes;
}

void TorchImgRandAugCV::applyCrop(cv::Mat &src, CropParams &cp,
const bool &store_rparams)
bool TorchImgRandAugCV::applyCrop(cv::Mat &src, CropParams &cp, int &crop_x,
int &crop_y, const bool &sample)
{
if (cp._crop_size <= 0)
return;
return false;

int crop_x = 0;
int crop_y = 0;
if (sample)
{
#pragma omp critical
{
crop_x = cp._uniform_int_crop_x(_rnd_gen);
crop_y = cp._uniform_int_crop_y(_rnd_gen);
}
{
crop_x = cp._uniform_int_crop_x(_rnd_gen);
crop_y = cp._uniform_int_crop_y(_rnd_gen);
}
}
cv::Rect crop(crop_x, crop_y, cp._crop_size, cp._crop_size);
cv::Mat dst = src(crop).clone();
src = dst;

if (store_rparams)
{
cp._crop_x = crop_x;
cp._crop_y = crop_y;
}
return true;
}

void TorchImgRandAugCV::applyCutout(cv::Mat &src, CutoutParams &cp,
Expand Down
8 changes: 2 additions & 6 deletions src/backends/torch/torchdataaug.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ namespace dd
int _crop_size = -1;
std::uniform_int_distribution<int> _uniform_int_crop_x;
std::uniform_int_distribution<int> _uniform_int_crop_y;

// randomized params
int _crop_x = 0;
int _crop_y = 0;
};

class CutoutParams : public ImgAugParams
Expand Down Expand Up @@ -288,8 +284,8 @@ namespace dd
void applyRotateBBox(std::vector<std::vector<float>> &bboxes,
const float &img_width, const float &img_height,
const int &rot);
void applyCrop(cv::Mat &src, CropParams &cp,
const bool &store_rparams = false);
bool applyCrop(cv::Mat &src, CropParams &cp, int &crop_x, int &crop_y,
const bool &sample = true);
void applyCutout(cv::Mat &src, CutoutParams &cp,
const bool &store_rparams = false);
void applyGeometry(cv::Mat &src, GeometryParams &cp,
Expand Down

0 comments on commit ac7ce0f

Please sign in to comment.