feat(torch): dice loss https://arxiv.org/abs/1707.03237
fantes authored and mergify[bot] committed Dec 14, 2021
1 parent 5cce134 commit 542bcb4
Showing 4 changed files with 154 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/api.md
@@ -271,7 +271,7 @@ mirror | bool | yes | false |
finetuning | bool | yes | false | Whether to prepare neural net template for finetuning (requires `weights`)
db | bool | yes | false | whether to set a database as input of neural net, useful for handling large datasets and training in constant-memory (requires `mlp` or `convnet`)
scaling_temperature | real | yes | 1.0 | sets the softmax temperature of an existing network (e.g. useful for model calibration)
loss | string | yes | N/A | Special network losses, from `dice`, `dice_multiclass`, `dice_weighted`, `dice_weighted_batch` or `dice_weighted_all`, useful for image segmentation, and `L1` or `L2`, useful for time-series via `csvts` connector
loss | string | yes | N/A | Special network losses, from `dice` (direct IOU maximization), `dice_multiclass` (same as `dice` for the torch backend, different implementation for the caffe backend), `dice_weighted` (dice augmented with inter-class weighting based on image stats), `dice_weighted_batch` (dice augmented with inter-class weighting based on batch stats) or `dice_weighted_all` (dice augmented with inter-class weighting based on running stats over all seen data), useful for image segmentation, and `L1` or `L2`, useful for time-series via the `csvts` connector (see the example after this table)
ssd_expand_prob | float | yes | between 0 and 1, probability of expanding the image (to improve detection of small/very small objects)
ssd_max_expand_ratio | float | yes | bbox zoom out ratio, e.g. 4.0
ssd_mining_type | str | yes | N/A | "HARD_EXAMPLE" or "MAX_NEGATIVE"
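The dice variants are selected at service creation. Below, a minimal service-creation body for a segmentation model, mirroring the unit test added by this commit (the repository path is a placeholder):

```json
{
  "mllib": "torch",
  "description": "image",
  "type": "supervised",
  "model": { "repository": "/path/to/segmentation/model" },
  "parameters": {
    "input": { "connector": "image", "width": 480, "height": 480,
               "db": true, "segmentation": true },
    "mllib": { "nclasses": 13, "gpu": true, "segmentation": true,
               "loss": "dice_weighted_all" }
  }
}
```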
69 changes: 67 additions & 2 deletions src/backends/torch/torchloss.cc
Expand Up @@ -21,6 +21,7 @@

#include "torchloss.h"
#pragma GCC diagnostic pop
#include <iostream>

namespace dd
{
@@ -88,8 +89,72 @@ namespace dd
}
else if (_segmentation)
{
loss = torch::nn::functional::cross_entropy(
y_pred, y.squeeze(1).to(torch::kLong)); // TODO: options
if (_loss.empty())
{

loss = torch::nn::functional::cross_entropy(
y_pred, y.squeeze(1).to(torch::kLong)); // TODO: options
}
else if (_loss == "dice" || _loss == "dice_multiclass"
|| _loss == "dice_weighted" || _loss == "dice_weighted_batch"
|| _loss == "dice_weighted_all")
{
// see https://arxiv.org/abs/1707.03237
double smooth = 1e-7;
torch::Tensor y_true_f
= torch::one_hot(y.to(torch::kInt64), y_pred.size(1))
.squeeze(1)
.permute({ 0, 3, 1, 2 })
.flatten(2)
.to(torch::kFloat32);
torch::Tensor y_pred_f = torch::flatten(torch::sigmoid(y_pred), 2);

torch::Tensor intersect;
torch::Tensor denom;

if (_loss == "dice" || _loss == "dice_multiclass")
{
intersect = y_true_f * y_pred_f;
denom = y_true_f + y_pred_f;
}
else if (_loss == "dice_weighted")
{
torch::Tensor sum = torch::sum(y_true_f, { 2 }) + 1.0;
torch::Tensor weights = 1.0 / sum / sum;
intersect = torch::sum(y_true_f * y_pred_f, { 2 }) * weights;
denom = torch::sum(y_true_f + y_pred_f, { 2 }) * weights;
}
else if (_loss == "dice_weighted_batch"
|| _loss == "dice_weighted_all")
{
torch::Tensor sum
= torch::sum(y_true_f, std::vector<int64_t>({ 0, 2 }))
+ 1.0;
torch::Tensor weights = 1.0 / sum / sum;
if (_loss == "dice_weighted_all")
{
if (_num_batches == 0)
_class_weights = weights;
else
{
weights = (_class_weights * _num_batches + weights)
/ (_num_batches + 1);
_class_weights = weights;
}
_num_batches++;
}
intersect = torch::sum(y_true_f * y_pred_f,
std::vector<int64_t>({ 0, 2 }))
* weights;
denom = torch::sum(y_true_f + y_pred_f,
std::vector<int64_t>({ 0, 2 }))
* weights;
}

return 1.0 - torch::mean(2.0 * intersect / (denom + smooth));
}
else
throw MLLibBadParamException("unknown loss: " + _loss);
}
else
{
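For reference, the weighted variants follow the generalized Dice loss of Sudre et al. (https://arxiv.org/abs/1707.03237), which for one-hot ground truth $r_{ln}$ and predicted probability $p_{ln}$ (class $l$, pixel $n$) is defined as

```latex
\mathrm{GDL} = 1 - 2\,\frac{\sum_l w_l \sum_n r_{ln}\, p_{ln}}
                           {\sum_l w_l \sum_n \left( r_{ln} + p_{ln} \right)},
\qquad w_l = \frac{1}{\left( \sum_n r_{ln} \right)^{2}}
```

The code above departs slightly from the paper: for `dice_weighted_batch` it averages the per-class weighted ratios rather than taking one global ratio, adds 1 to the weight denominator to guard against empty classes, and smooths with $\varepsilon = 10^{-7}$. A reconstruction of what it computes over $C$ classes, with $n$ running over all pixels in the batch:

```latex
\mathcal{L} = 1 - \frac{1}{C} \sum_{l=1}^{C}
  \frac{2\, w_l \sum_n r_{ln}\, p_{ln}}
       {w_l \sum_n \left( r_{ln} + p_{ln} \right) + \varepsilon},
\qquad w_l = \frac{1}{\left( 1 + \sum_n r_{ln} \right)^{2}}
```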
1 change: 1 addition & 0 deletions src/backends/torch/torchloss.h
@@ -78,6 +78,7 @@ namespace dd
torch::Tensor _y_pred;
torch::Tensor _y;
std::vector<c10::IValue> _ivx;
long int _num_batches = 0;
};
}
#endif
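The new `_num_batches` member backs `dice_weighted_all`: the class weights in `torchloss.cc` above are kept as an incremental mean over every batch seen so far. With $\bar{w}_k$ the stored `_class_weights` after $k$ batches and $w_{k+1}$ the weights computed on the incoming batch, the update is

```latex
\bar{w}_{k+1} = \frac{k\,\bar{w}_{k} + w_{k+1}}{k + 1}
```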
85 changes: 85 additions & 0 deletions tests/ut-torchapi.cc
@@ -826,6 +826,91 @@ TEST(torchapi, service_train_image_segmentation)
fileops::remove_dir(deeplabv3_train_repo + "test_0.lmdb");
}

TEST(torchapi, service_train_image_segmentation_dice)
{
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
torch::manual_seed(torch_seed);
at::globalContext().setDeterministicCuDNN(true);

// Create service
JsonAPI japi;
std::string sname = "imgserv";
std::string jstr
= "{\"mllib\":\"torch\",\"description\":\"image\",\"type\":"
"\"supervised\",\"model\":{\"repository\":\""
+ deeplabv3_train_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":\"image\","
"\"width\":480,\"height\":480,\"db\":true,\"segmentation\":true},"
"\"mllib\":{\"nclasses\":"
"13,\"gpu\":true,\"segmentation\":true,\"loss\":\"dice_weighted_"
"all\"}}"
"}";
std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);

// Train
std::string jtrainstr
= "{\"service\":\"imgserv\",\"async\":false,\"parameters\":{"
"\"mllib\":{\"solver\":{\"iterations\":"
+ iterations_deeplabv3 + ",\"base_lr\":" + torch_lr
+ ",\"iter_size\":1,\"solver_type\":\"ADAM\",\"test_"
"interval\":100},\"net\":{\"batch_size\":4},"
"\"resume\":false,\"mirror\":true,\"rotate\":true,\"crop_size\":224,"
"\"cutout\":0.5,\"geometry\":{\"prob\":0.1,\"persp_horizontal\":"
"true,\"persp_vertical\":true,\"zoom_in\":true,\"zoom_out\":true,"
"\"pad_mode\":1},\"noise\":{\"prob\":0.01},\"distort\":{\"prob\":0."
"01}},"
"\"input\":{\"seed\":12345,\"db\":true,\"shuffle\":true,"
"\"segmentation\":true,\"scale\":0.0039,\"mean\":[0.485,0.456,0.406]"
",\"std\":[0.229,0.224,0.225]},"
"\"output\":{\"measure\":[\"meaniou\",\"acc\"]}},\"data\":[\""
+ deeplabv3_train_data + "\",\"" + deeplabv3_test_data + "\"]}";
joutstr = japi.jrender(japi.service_train(jtrainstr));
JDoc jd;
std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(201, jd["status"]["code"]);

ASSERT_TRUE(jd["body"]["measure"]["meanacc"].GetDouble() <= 1) << "accuracy";
ASSERT_TRUE(jd["body"]["measure"]["meanacc"].GetDouble() >= 0.007)
<< "accuracy good";
ASSERT_TRUE(jd["body"]["measure"]["meaniou"].GetDouble() <= 1) << "meaniou";

std::string jpredictstr
= "{\"service\":\"imgserv\",\"parameters\":{"
"\"input\":{\"height\":480,"
"\"width\":480,\"scale\":0.0039,\"mean\":[0.485,0.456,0.406],\"std\":["
"0.229,0.224,0.225]},\"output\":{\"segmentation\":true, "
"\"confidences\":[\"best\"]}},\"data\":[\""
+ deeplabv3_test_image + "\"]}";

joutstr = japi.jrender(japi.service_predict(jpredictstr));
std::cout << "joutstr=" << joutstr << std::endl;
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
ASSERT_TRUE(!jd.HasParseError());
ASSERT_EQ(200, jd["status"]["code"]);
ASSERT_TRUE(jd["body"]["predictions"].IsArray());

std::unordered_set<std::string> lfiles;
fileops::list_directory(deeplabv3_train_repo, true, false, false, lfiles);
for (std::string ff : lfiles)
{
if (ff.find("checkpoint") != std::string::npos
|| ff.find("solver") != std::string::npos)
remove(ff.c_str());
}
ASSERT_TRUE(!fileops::file_exists(deeplabv3_train_repo + "checkpoint-"
+ iterations_deeplabv3 + ".ptw"));
ASSERT_TRUE(!fileops::file_exists(deeplabv3_train_repo + "checkpoint-"
+ iterations_deeplabv3 + ".pt"));

fileops::clear_directory(deeplabv3_train_repo + "train.lmdb");
fileops::clear_directory(deeplabv3_train_repo + "test_0.lmdb");
fileops::remove_dir(deeplabv3_train_repo + "train.lmdb");
fileops::remove_dir(deeplabv3_train_repo + "test_0.lmdb");
}

TEST(torchapi, service_publish_trained_model)
{
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
