Skip to content

Commit

Permalink
fix: dataset output dimension after crop augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz authored and mergify[bot] committed Jan 8, 2022
1 parent ac7ce0f commit 636d455
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 15 deletions.
15 changes: 2 additions & 13 deletions src/backends/torch/torchdataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,18 +242,10 @@ namespace dd

if (_segmentation)
{

cv::resize(bw_target, bw_target, cv::Size(width, height), 0, 0,
cv::INTER_NEAREST);
}
}

if (_segmentation)
{
at::Tensor targett_seg
= image_to_tensor(bw_target, height, width, true);
targett.push_back(targett_seg);
}
}

// add image batch
Expand Down Expand Up @@ -554,14 +546,13 @@ namespace dd
_img_rand_aug_cv.augment(bgr);
}

torch::Tensor imgt
= image_to_tensor(bgr, inputc->height(), inputc->width());
torch::Tensor imgt = image_to_tensor(bgr, bgr.rows, bgr.cols);
d.push_back(imgt);

if (_segmentation)
{
at::Tensor targett_seg = image_to_tensor(
bw_target, inputc->height(), inputc->width(), true);
bw_target, bw_target.rows, bw_target.cols, true);
t.push_back(targett_seg);
}
}
Expand Down Expand Up @@ -834,8 +825,6 @@ namespace dd
return imgt;
}

// TODO: segmentation target image to tensor

at::Tensor TorchDataset::target_to_tensor(const int &target)
{
at::Tensor targett{ torch::full(1, target, torch::kLong) };
Expand Down
2 changes: 1 addition & 1 deletion src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ namespace dd
throw;
}

// TODO: set inputc dataset data augmentation options
// set inputc dataset data augmentation options
APIData ad_mllib = ad.getobj("parameters").getobj("mllib");
if (typeid(inputc) == typeid(ImgTorchInputFileConn))
{
Expand Down
1 change: 0 additions & 1 deletion src/backends/torch/torchloss.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ namespace dd
{
if (_loss.empty())
{

loss = torch::nn::functional::cross_entropy(
y_pred, y.squeeze(1).to(torch::kLong)); // TODO: options
}
Expand Down

0 comments on commit 636d455

Please sign in to comment.