Skip to content

Commit

Permalink
feat(torch): data augmentation w/o db for bbox
Browse files Browse the repository at this point in the history
  • Loading branch information
royale authored and mergify[bot] committed Jan 25, 2023
1 parent b1accb7 commit a99ca7b
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 7 deletions.
52 changes: 47 additions & 5 deletions src/backends/torch/torchdataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ namespace dd
{
data_size = _lfilesseg.size();
}
else if (!_lfilesbbox.empty())
{
data_size = _lfilesbbox.size();
}
}
else // below db case
{
Expand Down Expand Up @@ -528,6 +532,24 @@ namespace dd
}
}
}
else if (!_lfilesbbox.empty()) // bbox with no db
{
cv::Mat timg;

for (int64_t id : ids)
{
auto lfile = _lfilesbbox.at(id);

cv::Mat dimg;
std::vector<at::Tensor> t;
int res
= read_image_bbox_file(lfile.first, lfile.second, dimg, t);
if (res == 0)
{
dataaug_then_push_back(dimg, t, timg, data, target);
}
}
}
else // batches
{
bool first_iter = true;
Expand Down Expand Up @@ -783,8 +805,10 @@ namespace dd
return 0;
}

int TorchDataset::add_image_bbox_file(const std::string &fname,
const std::string &bboxfname)
int TorchDataset::read_image_bbox_file(const std::string &fname,
const std::string &bboxfname,
cv::Mat &out_img,
std::vector<at::Tensor> &out_targett)
{
// read image before reading bboxes to get the size of the image
ImgTorchInputFileConn *inputc
Expand Down Expand Up @@ -856,9 +880,27 @@ namespace dd
classes.push_back(target_to_tensor(cls));
}

// add image
add_image_batch(dimg._imgs[0],
{ torch::stack(bboxes), torch::cat(classes) });
out_img = dimg._imgs[0];
out_targett = { torch::stack(bboxes), torch::cat(classes) };
return 0;
}

int TorchDataset::add_image_bbox_file(const std::string &fname,
const std::string &bboxfname)
{
if (_db)
{
cv::Mat img;
std::vector<at::Tensor> targett;
int res = read_image_bbox_file(fname, bboxfname, img, targett);
if (res != 0)
return res;
add_image_batch(img, targett);
}
else
#pragma omp ordered
_lfilesbbox.push_back(
std::pair<std::string, std::string>(fname, bboxfname));
return 0;
}

Expand Down
12 changes: 11 additions & 1 deletion src/backends/torch/torchdataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ namespace dd
_lfiles; /**< list of files */
std::vector<std::pair<std::string, std::string>>
_lfilesseg; /**< list of files for segmentation */
std::vector<std::pair<std::string, std::string>>
_lfilesbbox; /**< list of files for bbox */

std::vector<TorchBatch> _batches; /**< Vector containing the whole dataset
(the "cached data") */
Expand Down Expand Up @@ -114,7 +116,8 @@ namespace dd
_batches_per_transaction(d._batches_per_transaction), _txn(d._txn),
_logger(d._logger), _shuffle(d._shuffle), _dbData(d._dbData),
_indices(d._indices), _lfiles(d._lfiles), _lfilesseg(d._lfilesseg),
_batches(d._batches), _dbFullName(d._dbFullName), _inputc(d._inputc),
_lfilesbbox(d._lfilesbbox), _batches(d._batches),
_dbFullName(d._dbFullName), _inputc(d._inputc),
_classification(d._classification), _image(d._image), _bbox(d._bbox),
_segmentation(d._segmentation), _test(d._test),
_img_rand_aug_cv(d._img_rand_aug_cv)
Expand Down Expand Up @@ -315,6 +318,13 @@ namespace dd
int add_image_image_file(const std::string &fname,
const std::string &fname_target);

/**
* \brief reads image with a bbox list file as target, returning both
* through output parameters instead of adding them to a batch.
* \param fname image file; it is read before the bboxes so that the image
* size is known
* \param bboxfname bbox list file used as target
* \param out_img on success, receives the image that was read
* \param out_targett on success, receives two tensors: the stacked bboxes
* followed by the concatenated classes
* \return 0 on success; callers in torchdataset.cc treat any other value as
* a failure (NOTE(review): the exact error codes are not visible from this
* declaration, confirm against the implementation)
*/
int read_image_bbox_file(const std::string &fname,
const std::string &bboxfname, cv::Mat &out_img,
std::vector<at::Tensor> &out_targett);

/**
* \brief adds image to batch, with a bbox list file as target.
* \param width of preprocessed image
Expand Down
98 changes: 97 additions & 1 deletion tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2338,7 +2338,7 @@ TEST(torchapi, service_train_object_detection_yolox)
"\"supervised\",\"model\":{\"repository\":\""
+ detect_train_repo_yolox
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
"640,\"width\":640,\"rgb\":true,\"bbox\":true},"
"640,\"width\":640,\"rgb\":true,\"bbox\":true,\"db\":true},"
"\"mllib\":{\"template\":\"yolox\",\"gpu\":true,"
"\"nclasses\":2}}}";

Expand Down Expand Up @@ -2421,6 +2421,102 @@ TEST(torchapi, service_train_object_detection_yolox)
fileops::remove_dir(detect_train_repo_yolox + "test_0.lmdb");
}

// Trains a yolox detection service without a db ("db":false at train time),
// checks the reported loss measures, runs one prediction and cleans up the
// model repository.
TEST(torchapi, service_train_object_detection_yolox_no_db)
{
  // deterministic setup
  setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
  torch::manual_seed(torch_seed);
  at::globalContext().setDeterministicCuDNN(true);

  // --- service creation (no "db" flag in the input parameters)
  JsonAPI japi;
  std::string sname = "detectserv";
  std::string jstr
      = "{\"mllib\":\"torch\",\"description\":\"yolox\","
        "\"type\":\"supervised\",\"model\":{\"repository\":\""
        + detect_train_repo_yolox
        + "\"},\"parameters\":{\"input\":{\"connector\":\"image\","
          "\"height\":640,\"width\":640,\"rgb\":true,\"bbox\":true},"
          "\"mllib\":{\"template\":\"yolox\",\"gpu\":true,\"nclasses\":2}}}";

  std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
  ASSERT_EQ(created_str, joutstr);

  // --- training: 3 iterations, data augmentation enabled, db disabled
  std::string jtrainstr
      = std::string(
            "{\"service\":\"detectserv\",\"async\":false,\"parameters\":{")
        + "\"mllib\":{\"solver\":{\"iterations\":3,\"iter_size\":2,"
          "\"solver_type\":\"ADAM\",\"test_interval\":200},"
          "\"net\":{\"batch_size\":2,\"test_batch_size\":2,"
          "\"reg_weight\":0.5},"
          "\"resume\":false,\"mirror\":true,\"rotate\":true,"
          "\"crop_size\":512,\"test_crop_samples\":10,\"cutout\":0.1,"
          "\"geometry\":{\"prob\":0.1,\"persp_horizontal\":true,"
          "\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
          "\"pad_mode\":\"constant\"},"
          "\"noise\":{\"prob\":0.01},\"distort\":{\"prob\":0.01}},"
        + "\"input\":{\"seed\":12347,\"db\":false,\"shuffle\":true},"
          "\"output\":{\"measure\":[\"map\"]}},\"data\":[\""
        + fasterrcnn_train_data + "\",\"" + fasterrcnn_test_data + "\"]}";

  joutstr = japi.jrender(japi.service_train(jtrainstr));
  JDoc jd;
  std::cout << "joutstr=" << joutstr << std::endl;
  jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
  ASSERT_FALSE(jd.HasParseError());
  ASSERT_EQ(201, jd["status"]["code"]);

  // map is only bounded from above here; the iteration-count and
  // strictly-positive map checks are left disabled
  auto &meas = jd["body"]["measure"];
  ASSERT_TRUE(meas["map"].GetDouble() <= 1.0) << "map";

  // every loss component must be reported
  const char *const loss_keys[]
      = { "iou_loss", "conf_loss", "cls_loss", "l1_loss", "train_loss" };
  for (const char *loss_key : loss_keys)
    ASSERT_TRUE(meas.HasMember(loss_key));

  // train_loss must be the sum of its components, with iou_loss scaled by
  // 0.5 (the reg_weight passed in the training call)
  const double weighted_iou = meas["iou_loss"].GetDouble() * 0.5;
  const double summed_loss = weighted_iou + meas["cls_loss"].GetDouble()
                             + meas["l1_loss"].GetDouble()
                             + meas["conf_loss"].GetDouble();
  const double loss_gap
      = std::abs(meas["train_loss"].GetDouble() - summed_loss);
  ASSERT_TRUE(loss_gap < 0.0001);

  // --- prediction must succeed on the trained service
  std::string jpredictstr
      = "{\"service\":\"detectserv\",\"parameters\":{"
        "\"input\":{\"height\":640,\"width\":640},"
        "\"output\":{\"bbox\":true, \"confidence_threshold\":0.8}},"
        "\"data\":[\""
        + detect_train_repo_fasterrcnn + "/imgs/la_melrose_ave-000020.jpg\"]}";
  joutstr = japi.jrender(japi.service_predict(jpredictstr));
  jd = JDoc();
  std::cout << "joutstr=" << joutstr << std::endl;
  jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
  ASSERT_FALSE(jd.HasParseError());
  ASSERT_EQ(200, jd["status"]["code"]);

  // --- cleanup: drop checkpoint and solver files from the repository
  std::unordered_set<std::string> repo_files;
  fileops::list_directory(detect_train_repo_yolox, true, false, false,
                          repo_files);
  for (const std::string &path : repo_files)
    {
      const bool is_checkpoint = path.find("checkpoint") != std::string::npos;
      const bool is_solver = path.find("solver") != std::string::npos;
      if (is_checkpoint || is_solver)
        remove(path.c_str());
    }
  const std::string ckpt_base
      = detect_train_repo_yolox + "checkpoint-" + iterations_detection;
  ASSERT_FALSE(fileops::file_exists(ckpt_base + ".ptw"));
  ASSERT_FALSE(fileops::file_exists(ckpt_base + ".pt"));

  // same lmdb cleanup as the db-backed test: empty both dirs, then remove
  // them
  const std::string lmdb_dirs[] = { detect_train_repo_yolox + "train.lmdb",
                                    detect_train_repo_yolox + "test_0.lmdb" };
  for (const std::string &dir : lmdb_dirs)
    fileops::clear_directory(dir);
  for (const std::string &dir : lmdb_dirs)
    fileops::remove_dir(dir);
}

TEST(torchapi, service_train_object_detection_yolox_multigpu)
{
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
Expand Down

0 comments on commit a99ca7b

Please sign in to comment.