diff --git a/demo/segmentation/segment.py b/demo/segmentation/segment.py
index 622657ef9..502757bf0 100644
--- a/demo/segmentation/segment.py
+++ b/demo/segmentation/segment.py
@@ -13,14 +13,21 @@
 parser.add_argument("--width",help="image width",type=int,default=480)
 parser.add_argument("--height",help="image height",type=int,default=480)
 parser.add_argument("--model-dir",help="model directory")
+parser.add_argument("--mllib",default="caffe",help="caffe or torch")
+parser.add_argument("--scale",type=float,default=1.0,help="scaling factor, e.g. 0.0044")
+parser.add_argument("--rgb",action="store_true",help="whether to use RGB output, e.g. for torch pretrained models")
+parser.add_argument("--confidences",type=str,help="whether to output the confidence map, e.g. best",default='')
 args = parser.parse_args();

 host = 'localhost'
 port = 8080
 sname = 'segserv'
 description = 'image segmentation'
-mllib = 'caffe'
-mltype = 'unsupervised'
+mllib = args.mllib
+if mllib == 'caffe':
+    mltype = 'unsupervised'
+else:
+    mltype = 'supervised'
 nclasses = args.nclasses
 width = args.width
 height = args.height
@@ -38,7 +45,7 @@ def random_color():
 model_repo = os.getcwd() + '/model/'
 model = {'repository':model_repo}
 parameters_input = {'connector':'image','width':width,'height':height}
-parameters_mllib = {'nclasses':nclasses}
+parameters_mllib = {'nclasses':nclasses,'segmentation':True,'gpu':True,'gpuid':0}
 parameters_output = {}
 try:
     servput = dd.put_service(sname,model,description,mllib,
@@ -47,13 +54,15 @@ def random_color():
     pass

 # prediction call
-parameters_input = {'segmentation':True}
-parameters_mllib = {'gpu':True,'gpuid':0}
-parameters_output = {}
+parameters_input = {'scale':args.scale,'rgb':args.rgb}
+parameters_mllib = {'segmentation':True}
+parameters_output = {'confidences':[args.confidences]}
 data = [args.image]
 detect = dd.post_predict(sname,data,parameters_input,parameters_mllib,parameters_output)
-
-pixels = np.array((map(int,detect['body']['predictions'][0]['vals'])))
+#print(detect['body']['predictions'][0])
+pixels = np.array(detect['body']['predictions'][0]['vals']).astype(int)
+if args.confidences:
+    pixels_confs = np.array(detect['body']['predictions'][0]['confidences']['best'])
 imgsize = detect['body']['predictions'][0]['imgsize']

 # visual output
@@ -79,5 +88,8 @@ def random_color():
     rgb[:,:,2] = b/255.0

 plt.figure()
-plt.imshow(rgb,vmin=0,vmax=1)
+if not args.confidences:
+    plt.imshow(rgb,vmin=0,vmax=1)
+else:
+    plt.imshow(np.reshape(pixels_confs,(imgsize['height'],imgsize['width'])),vmin=0,vmax=1)
 plt.show()
diff --git a/src/backends/caffe/caffelib.cc b/src/backends/caffe/caffelib.cc
index f9ba7bf1c..696111684 100644
--- a/src/backends/caffe/caffelib.cc
+++ b/src/backends/caffe/caffelib.cc
@@ -3221,20 +3221,21 @@ namespace dd
                          * (*bit).second.second) // resizing output
                                                  // segmentation array
                    {
-                      vals = img_resize(vals, inputc.height(),
-                                        inputc.width(), (*bit).second.first,
-                                        (*bit).second.second, true);
+                      vals = ImgInputFileConn::img_resize_vector(
+                          vals, inputc.height(), inputc.width(),
+                          (*bit).second.first, (*bit).second.second, true);
                       if (conf_best)
-                        conf_map_best
-                            = img_resize(conf_map_best, inputc.height(),
-                                         inputc.width(), (*bit).second.first,
-                                         (*bit).second.second, false);
+                        conf_map_best = ImgInputFileConn::img_resize_vector(
+                            conf_map_best, inputc.height(), inputc.width(),
+                            (*bit).second.first, (*bit).second.second,
+                            false);
                       for (int ci = 0; ci < _nclasses; ++ci)
                         if (confidences[ci])
-                          confidence_maps[ci] = img_resize(
-                              confidence_maps[ci], inputc.height(),
-                              inputc.width(), (*bit).second.first,
-                              (*bit).second.second, false);
+                          confidence_maps[ci]
+                              = ImgInputFileConn::img_resize_vector(
+                                  confidence_maps[ci], inputc.height(),
+                                  inputc.width(), (*bit).second.first,
+                                  (*bit).second.second, false);
                    }
                  rad.add("vals", vals);
                  if (conf_best || !confidence_maps.empty())
@@ -5195,29 +5196,6 @@ namespace dd
     return nullptr;
   }

-  template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
-            class TMLModel>
-  std::vector<double>
-  CaffeLib<TInputConnectorStrategy, TOutputConnectorStrategy,
-           TMLModel>::img_resize(const std::vector<double> &vals,
-                                 const int height_net, const int width_net,
-                                 const int height_dest, const int width_dest,
-                                 bool resize_nn)
-  {
-    cv::Mat segimg = cv::Mat(height_net, width_net, CV_64FC1);
-    std::memcpy(segimg.data, vals.data(), vals.size() * sizeof(double));
-    cv::Mat segimg_res;
-    if (resize_nn)
-      cv::resize(segimg, segimg_res, cv::Size(width_dest, height_dest), 0, 0,
-                 cv::INTER_NEAREST);
-    else
-      cv::resize(segimg, segimg_res, cv::Size(width_dest, height_dest), 0, 0,
-                 cv::INTER_LINEAR);
-    return std::vector<double>((double *)segimg_res.data,
-                               (double *)segimg_res.data
-                                   + segimg_res.rows * segimg_res.cols);
-  }
-
   template class CaffeLib<ImgCaffeInputFileConn, SupervisedOutput, CaffeModel>;
   template class CaffeLib<CSVCaffeInputFileConn, SupervisedOutput, CaffeModel>;
   template class CaffeLib<CSVTSCaffeInputFileConn, SupervisedOutput,
                           CaffeModel>;
diff --git a/src/backends/caffe/caffelib.h b/src/backends/caffe/caffelib.h
--- a/src/backends/caffe/caffelib.h
+++ b/src/backends/caffe/caffelib.h
     std::vector<caffe::Blob<float> *>
     findBlobByName(const caffe::Net<float> *net, const std::string blob_name);

-    std::vector<double> img_resize(const std::vector<double> &vals,
-                                   const int height_net, const int width_net,
-                                   const int height_dest, const int width_dest,
-                                   bool resize_nn);
-
     bool is_refinedet(caffe::NetParameter &net_param);

   public:
diff --git a/src/backends/torch/torchlib.cc b/src/backends/torch/torchlib.cc
index 486ff3664..20eac7457 100644
--- a/src/backends/torch/torchlib.cc
+++ b/src/backends/torch/torchlib.cc
@@ -80,6 +80,7 @@ namespace dd
     _regression = tl._regression;
     _timeserie = tl._timeserie;
     _bbox = tl._bbox;
+    _segmentation = tl._segmentation;
     _loss = tl._loss;
     _template_params = tl._template_params;
     _dtype = tl._dtype;
@@ -272,7 +273,11 @@ namespace dd
         _main_device = _devices[0];

     // Set model type
-    if (mllib_dto->nclasses != 0)
+    if (mllib_dto->segmentation)
+      {
+        _segmentation = true;
+      }
+    else if (mllib_dto->nclasses != 0)
       {
         _classification = true;
       }
@@ -315,7 +320,8 @@ namespace dd
       {
         _timeserie = true;
       }
-    if (!_regression && !_timeserie && !_bbox && self_supervised.empty())
+    if (!_regression && !_timeserie && !_bbox && !_segmentation
+        && self_supervised.empty())
      _classification = true; // classification is default
@@ -327,6 +333,8 @@ namespace dd
       this->_mltype = "regression";
     else if (_classification)
       this->_mltype = "classification";
+    else if (_segmentation)
+      this->_mltype = "segmentation";

     // Create the model
     _module._device = _main_device;
@@ -429,7 +437,7 @@ namespace dd
         throw MLLibBadParamException("invalid torch model template "
                                      + _template);
       }
-    if (_classification || _regression)
+    if (_classification || _regression || _segmentation)
       {
         _module._nclasses = _nclasses;
@@ -1161,6 +1169,7 @@ namespace dd
     double confidence_threshold = 0.0;
     int best_count = _nclasses;
     int best_bbox = -1;
+    std::vector<std::string> confidences;

     if (params.has("mllib"))
       {
@@ -1225,6 +1234,11 @@ namespace dd
           {
             best_bbox = output_params.get("best_bbox").get<int>();
           }
+        if (output_params.has("confidences"))
+          {
+            confidences
+                = output_params.get("confidences").get<std::vector<std::string>>();
+          }

     bool lstm_continuation = false;
     TInputConnectorStrategy inputc(this->_inputc);
@@ -1310,7 +1324,7 @@ namespace dd
             else
               out_ivalue = _module.extract(in_vals, extract_layer);

-            if (!bbox)
+            if (!bbox && !_segmentation)
              {
                output = torch_utils::to_tensor_safe(out_ivalue);
@@ -1511,6 +1525,82 @@ namespace dd
                     results_ads.push_back(results_ad);
                   }
               }
+            else if (_segmentation)
+              {
+                auto out_dict = out_ivalue.toGenericDict();
+                output = torch_utils::to_tensor_safe(out_dict.at("out"));
+                output = torch::softmax(output, 1);
+
+                int imgsize = inputc.width() * inputc.height();
+                torch::Tensor segmap;
+                torch::Tensor confmap;
+                std::tuple<torch::Tensor, torch::Tensor> maxmap;
+                if (!confidences.empty()) // applies "best" confidence lookup
+                  {
+                    maxmap = torch::max(output.clone().squeeze(), 0, false);
+                    confmap = torch::flatten(std::get<0>(maxmap))
+                                  .contiguous()
+                                  .to(torch::kFloat64)
+                                  .to(cpu);
+                    segmap = torch::flatten(std::get<1>(maxmap))
+                                 .contiguous()
+                                 .to(torch::kFloat64)
+                                 .to(cpu);
+                  }
+                else
+                  segmap = torch::flatten(torch::argmax(output.squeeze(), 0))
+                               .contiguous()
+                               .to(torch::kFloat64)
+                               .to(cpu); // squeeze removes the batch size
+
+                APIData rad;
+                std::string uri;
+                if (!inputc._ids.empty())
+                  uri = inputc._ids.at(results_ads.size());
+                else
+                  uri = std::to_string(results_ads.size());
+                rad.add("uri", uri);
+                rad.add("loss", static_cast<double>(0.0));
+                double *startout = segmap.data_ptr<double>();
+                std::vector<double> vals(startout,
+                                         startout + torch::numel(segmap));
+                std::vector<double> confs;
+                if (!confidences.empty())
+                  {
+                    startout = confmap.data_ptr<double>();
+                    confs = std::vector<double>(
+                        startout, startout + torch::numel(confmap));
+                  }
+
+                auto bit = inputc._imgs_size.find(uri);
+                APIData ad_imgsize;
+                ad_imgsize.add("height", (*bit).second.first);
+                ad_imgsize.add("width", (*bit).second.second);
+                rad.add("imgsize", ad_imgsize);
+
+                if (imgsize
+                    != (*bit).second.first
+                           * (*bit).second.second) // resizing output
+                                                   // segmentation array
+                  {
+                    vals = ImgInputFileConn::img_resize_vector(
+                        vals, inputc.height(), inputc.width(),
+                        (*bit).second.first, (*bit).second.second, true);
+                    if (!confidences.empty())
+                      confs = ImgInputFileConn::img_resize_vector(
+                          confs, inputc.height(), inputc.width(),
+                          (*bit).second.first, (*bit).second.second, false);
+                  }
+
+                rad.add("vals", vals);
+                if (!confidences.empty())
+                  {
+                    APIData vconfs;
+                    vconfs.add("best", confs);
+                    rad.add("confidences", vconfs);
+                  }
+                results_ads.push_back(rad);
+              }
             else if (_regression)
               {
                 auto probs_acc = output.accessor<float, 2>();
@@ -1570,7 +1660,7 @@ namespace dd
           }
       }

-    if (extract_layer.empty())
+    if (extract_layer.empty() && !_segmentation)
       {
         outputc.add_results(results_ads);
diff --git a/src/backends/torch/torchlib.h b/src/backends/torch/torchlib.h
index 8901cceff..38484b795 100644
--- a/src/backends/torch/torchlib.h
+++ b/src/backends/torch/torchlib.h
@@ -105,8 +105,9 @@ namespace dd
     bool _seq_training = false;   /**< true for bert/gpt2*/
     bool _classification = false; /**< select classification type problem*/
     bool _regression = false;     /**< select regression type problem */
-    bool _timeserie = false;      /**< select timeserie type problem*/
+    bool _timeserie = false;      /**< select timeserie type problem */
     bool _bbox = false;           /**< select detection type problem */
+    bool _segmentation = false;   /**< select segmentation type problem */
     std::string _loss = "";       /**< selected loss*/

     APIData _template_params;     /**< template parameters, for recurrent and
diff --git a/src/dto/mllib.hpp b/src/dto/mllib.hpp
index 24148efd2..9e7c2af62 100644
--- a/src/dto/mllib.hpp
+++ b/src/dto/mllib.hpp
@@ -54,6 +54,12 @@ namespace dd
     };
     DTO_FIELD(Int32, ntargets) = 0;

+    DTO_FIELD_INFO(segmentation)
+    {
+      info->description = "whether the model type is segmentation";
+    };
+    DTO_FIELD(Boolean, segmentation) = false;
+
     DTO_FIELD_INFO(from_repository)
     {
       info->description = "initialize model repository with checkpoint and "
diff --git a/src/dto/output_connector.hpp b/src/dto/output_connector.hpp
index e5c61ab90..1a639e46c 100644
--- a/src/dto/output_connector.hpp
+++ b/src/dto/output_connector.hpp
@@ -44,6 +44,7 @@ namespace dd
     DTO_FIELD(Float32, confidence_threshold) = 0.0;
     DTO_FIELD(Int32, best);
     DTO_FIELD(Int32, best_bbox) = -1;
+    DTO_FIELD(Vector<String>, confidences);

     DTO_FIELD_INFO(image)
     {
diff --git a/src/imginputfileconn.h b/src/imginputfileconn.h
index 1c9f2342a..9ed86dd2f 100644
--- a/src/imginputfileconn.h
+++ b/src/imginputfileconn.h
@@ -801,6 +801,25 @@ namespace dd
         throw InputConnectorBadParamException("no image could be found");
     }

+    static std::vector<double>
+    img_resize_vector(const std::vector<double> &vals, const int height_net,
+                      const int width_net, const int height_dest,
+                      const int width_dest, bool resize_nn)
+    {
+      cv::Mat segimg = cv::Mat(height_net, width_net, CV_64FC1);
+      std::memcpy(segimg.data, vals.data(), vals.size() * sizeof(double));
+      cv::Mat segimg_res;
+      if (resize_nn)
+        cv::resize(segimg, segimg_res, cv::Size(width_dest, height_dest), 0,
+                   0, cv::INTER_NEAREST);
+      else
+        cv::resize(segimg, segimg_res, cv::Size(width_dest, height_dest), 0,
+                   0, cv::INTER_LINEAR);
+      return std::vector<double>((double *)segimg_res.data,
+                                 (double *)segimg_res.data
+                                     + segimg_res.rows * segimg_res.cols);
+    }
+
     // data
     std::vector<cv::Mat> _images;
     std::vector<cv::Mat> _orig_images; /**< stored upon request. */
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 8049175c4..3ec11f718 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -385,6 +385,13 @@ if (USE_TORCH)
     "fasterrcnn_train_torch_bs2.tar.gz"
     "fasterrcnn_train_torch"
   )
+  DOWNLOAD_DATASET(
+    "Torchvision training DeepLabV3 Resnet50 model"
+    "https://www.deepdetect.com/dd/examples/torch/deeplabv3_torch.tar.gz"
+    "examples/torch"
+    "deeplabv3_torch.tar.gz"
+    "deeplabv3_torch"
+  )
   DOWNLOAD_DATASET(
     "Torch BERT classification test model"
     "https://www.deepdetect.com/dd/examples/torch/bert_inference_torch.tar.gz"
diff --git a/tests/ut-torchapi.cc b/tests/ut-torchapi.cc
index 8d14f3ff8..102972fb3 100644
--- a/tests/ut-torchapi.cc
+++ b/tests/ut-torchapi.cc
@@ -1,7 +1,9 @@
 /**
  * DeepDetect
- * Copyright (c) 2019 Jolibrain
- * Author: Louis Jean
+ * Copyright (c) 2019-2021 Jolibrain
+ * Author: Louis Jean
+ * Guillaume Infantes
+ * Emmanuel Benazera
  *
  * This file is part of deepdetect.
 *
@@ -41,6 +43,7 @@ static std::string not_found_str
 static std::string incept_repo = "../examples/torch/resnet50_torch/";
 static std::string detect_repo = "../examples/torch/fasterrcnn_torch/";
+static std::string seg_repo = "../examples/torch/deeplabv3_torch/";
 static std::string detect_train_repo
     = "../examples/torch/fasterrcnn_train_torch";
 static std::string resnet50_train_repo
@@ -320,6 +323,43 @@ TEST(torchapi, service_predict_object_detection)
   ASSERT_EQ(preds_best.Size(), 3);
 }

+TEST(torchapi, service_predict_segmentation)
+{
+  JsonAPI japi;
+  std::string sname = "segserv";
+  std::string jstr
+      = "{\"mllib\":\"torch\",\"description\":\"deeplabv3\",\"type\":"
+        "\"supervised\",\"model\":{\"repository\":\""
+        + seg_repo
+        + "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
+          "224,\"width\":224,\"rgb\":true,\"scale\":0.0039},\"mllib\":{"
+          "\"segmentation\":true,\"nclasses\":21}}}";
+
+  std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
+  ASSERT_EQ(created_str, joutstr);
+  std::string jpredictstr
+      = "{\"service\":\"segserv\",\"parameters\":{"
+        "\"input\":{\"height\":224,"
+        "\"width\":224},\"output\":{\"segmentation\":true, "
+        "\"confidences\":[\"best\"]}},\"data\":[\""
+        + seg_repo + "cat.jpg\"]}";
+
+  joutstr = japi.jrender(japi.service_predict(jpredictstr));
+  JDoc jd;
+  // std::cout << "joutstr=" << joutstr << std::endl;
+  jd.Parse(joutstr.c_str());
+  ASSERT_TRUE(!jd.HasParseError());
+  ASSERT_EQ(200, jd["status"]["code"]);
+  ASSERT_TRUE(jd["body"]["predictions"].IsArray());
+
+  auto &preds = jd["body"]["predictions"][0]["vals"];
+  auto &confs = jd["body"]["predictions"][0]["confidences"]["best"];
+  ASSERT_TRUE(preds.IsArray());
+  ASSERT_TRUE(confs.IsArray());
+  ASSERT_TRUE(preds.Size() == 500 * 374);
+  ASSERT_TRUE(confs.Size() == 500 * 374);
+}
+
 TEST(torchapi, service_predict_txt_classification)
 {
   // create service
diff --git a/tools/torch/trace_torchvision.py b/tools/torch/trace_torchvision.py
index 45f7d70e8..33ac80634 100755
--- a/tools/torch/trace_torchvision.py
+++ b/tools/torch/trace_torchvision.py
@@ -43,6 +43,7 @@
                     help="Whether the exported models should not be pretrained")
 parser.add_argument('--cpu', action='store_true', help="Force models to be exported for CPU device")
 parser.add_argument('--num_classes', type=int, help="Number of classes")
+parser.add_argument('--trace', action='store_true', help="Whether to trace model instead of scripting")

 args = parser.parse_args()
@@ -169,6 +170,15 @@ def get_detection_input():
     "retinanet_resnet50_fpn": M.detection.retinanet_resnet50_fpn,
 }
 model_classes.update(detection_model_classes)
+segmentation_model_classes = {
+    "fcn_resnet50": M.segmentation.fcn_resnet50,
+    "fcn_resnet101": M.segmentation.fcn_resnet101,
+    "deeplabv3_resnet50": M.segmentation.deeplabv3_resnet50,
+    "deeplabv3_resnet101": M.segmentation.deeplabv3_resnet101,
+    "deeplabv3_mobilenet_v3_large": M.segmentation.deeplabv3_mobilenet_v3_large,
+    "lraspp_mobilenet_v3_large": M.segmentation.lraspp_mobilenet_v3_large
+}
+model_classes.update(segmentation_model_classes)

 if args.all:
@@ -194,7 +204,8 @@ def get_detection_input():
     logging.info("Exporting model %s %s", mname, "(pretrained)" if args.pretrained else "")

     detection = mname in detection_model_classes
-
+    segmentation = mname in segmentation_model_classes
+
     if detection:
         if "fasterrcnn" in mname and version.parse(torchvision.__version__) < version.parse("0.10.0"):
             raise RuntimeError("Fasterrcnn needs torchvision >= 0.10.0 (current = %s)" % torchvision.__version__)
@@ -253,10 +264,14 @@ def get_detection_input():
     model.eval()

-    # TODO try scripting instead of tracing
-    example = torch.rand(1, 3, 224, 224)
-    script_module = torch.jit.trace(model, example)
+    # tracing or scripting the model (scripting is the default)
+    if args.trace:
+        example = torch.rand(1, 3, 224, 224)
+        script_module = torch.jit.trace(model, example)
+    else:
+        script_module = torch.jit.script(model)
+
     filename = os.path.join(args.output_dir, mname + ("-pretrained" if args.pretrained else "") + ("-" + args.backbone if args.backbone else "") + "-cls" + str(args.num_classes) + ".pt")
     logging.info("Saving to %s", filename)
     script_module.save(filename)
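
---

Usage sketch (not part of the patch): driving the new torch segmentation path through the DeepDetect Python client, mirroring the updated `demo/segmentation/segment.py` and the `service_predict_segmentation` test. The repository and image paths are illustrative; `scale=0.0039` and `rgb=True` follow the torchvision pretrained preprocessing used in the test, and the `put_service`/`post_predict` signatures are assumed to match the `dd_client` the demo imports.

```python
from dd_client import DD  # DeepDetect python client used by the demo

dd = DD('localhost', 8080)

# service creation, as in the demo with --mllib torch
model = {'repository': '/path/to/deeplabv3_torch'}  # illustrative path
parameters_input = {'connector': 'image', 'width': 224, 'height': 224}
parameters_mllib = {'nclasses': 21, 'segmentation': True, 'gpu': True, 'gpuid': 0}
dd.put_service('segserv', model, 'image segmentation', 'torch',
               parameters_input, parameters_mllib, {}, 'supervised')

# prediction, requesting the per-pixel "best" confidence map
detect = dd.post_predict('segserv', ['/path/to/cat.jpg'],
                         {'scale': 0.0039, 'rgb': True},
                         {'segmentation': True},
                         {'confidences': ['best']})
pred = detect['body']['predictions'][0]
pixels = pred['vals']                     # flattened class-id map
best_confs = pred['confidences']['best']  # flattened confidence map
```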
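The resize helper promoted from `CaffeLib::img_resize` to `ImgInputFileConn::img_resize_vector` is now shared by the caffe and torch backends, and its `resize_nn` switch is the design point: class-id maps are resized with nearest-neighbor interpolation so no fractional (non-existent) class ids are produced, while confidence maps are interpolated linearly. A NumPy/OpenCV equivalent for reference, assuming `cv2` is installed (hypothetical helper, not in the patch):

```python
import cv2
import numpy as np

def img_resize_vector(vals, height_net, width_net, height_dest, width_dest,
                      resize_nn):
    """Resize a flattened height_net x width_net map to the destination size,
    mirroring ImgInputFileConn::img_resize_vector."""
    segimg = np.asarray(vals, dtype=np.float64).reshape(height_net, width_net)
    # nearest-neighbor preserves integer class ids; linear smooths confidences
    interp = cv2.INTER_NEAREST if resize_nn else cv2.INTER_LINEAR
    resized = cv2.resize(segimg, (width_dest, height_dest), interpolation=interp)
    return resized.ravel()
```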
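The new `_segmentation` branch in `torchlib.cc` relies on the torchvision segmentation convention: the module returns a dict whose `out` entry holds `(batch, nclasses, H, W)` logits. It softmaxes over the class dimension, then takes either an argmax (label map only) or a single `torch::max`, which yields the label map and the per-pixel best confidence in one pass. A condensed PyTorch sketch of that logic (function name is illustrative):

```python
import torch

def segmentation_postprocess(out, want_confidences=True):
    # out: (1, nclasses, H, W) logits taken from the model's "out" entry
    probs = torch.softmax(out, dim=1).squeeze(0)  # squeeze removes batch dim
    if want_confidences:
        confmap, segmap = torch.max(probs, dim=0)  # best prob + class per pixel
        return segmap.flatten().double(), confmap.flatten().double()
    return torch.argmax(probs, dim=0).flatten().double(), None
```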
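On the export side, `trace_torchvision.py` now scripts models by default and only traces with `--trace`, presumably because these segmentation heads return a dict output, which strict `torch.jit.trace` rejects. A minimal standalone sketch of the default path, assuming a torchvision version that ships the listed segmentation models (output filename illustrative):

```python
import torch
import torchvision.models as M

# script (rather than trace) a pretrained DeepLabV3-ResNet50, as the updated
# trace_torchvision.py does by default; scripting keeps the {'out': ...}
# dict output that torchlib.cc's segmentation branch reads
model = M.segmentation.deeplabv3_resnet50(pretrained=True)
model.eval()
scripted = torch.jit.script(model)
scripted.save("deeplabv3_resnet50-pretrained.pt")  # illustrative filename
```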