Skip to content

Commit

Permalink
feat: torch segmentation model prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Sep 4, 2021
1 parent bbfcce6 commit d72a138
Show file tree
Hide file tree
Showing 11 changed files with 224 additions and 60 deletions.
30 changes: 21 additions & 9 deletions demo/segmentation/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,21 @@
parser.add_argument("--width",help="image width",type=int,default=480)
parser.add_argument("--height",help="image height",type=int,default=480)
parser.add_argument("--model-dir",help="model directory")
parser.add_argument("--mllib",default="caffe",help="caffe or torch")
parser.add_argument("--scale",type=float,default=1.0,help="scaling factor, e.g. 0.0044")
parser.add_argument("--rgb",action="store_true",help="whether to use RGB output, e.g. for torch pretrained models")
parser.add_argument("--confidences",type=str,help="whether to output the confidence map, e.g. best",default='')
args = parser.parse_args();

host = 'localhost'
port = 8080
sname = 'segserv'
description = 'image segmentation'
mllib = 'caffe'
mltype = 'unsupervised'
mllib = args.mllib
if mllib == 'caffe':
mltype = 'unsupervised'
else:
mltype = 'supervised'
nclasses = args.nclasses
width = args.width
height = args.height
Expand All @@ -38,7 +45,7 @@ def random_color():
model_repo = os.getcwd() + '/model/'
model = {'repository':model_repo}
parameters_input = {'connector':'image','width':width,'height':height}
parameters_mllib = {'nclasses':nclasses}
parameters_mllib = {'nclasses':nclasses,'segmentation':True,'gpu':True,'gpuid':0}
parameters_output = {}
try:
servput = dd.put_service(sname,model,description,mllib,
Expand All @@ -47,13 +54,15 @@ def random_color():
pass

# prediction call
parameters_input = {'segmentation':True}
parameters_mllib = {'gpu':True,'gpuid':0}
parameters_output = {}
parameters_input = {'scale':args.scale,'rgb':args.rgb}
parameters_mllib = {'segmentation':True}
parameters_output = {'confidences':[args.confidences]}
data = [args.image]
detect = dd.post_predict(sname,data,parameters_input,parameters_mllib,parameters_output)

pixels = np.array((map(int,detect['body']['predictions'][0]['vals'])))
#print(detect['body']['predictions'][0])
pixels = np.array(detect['body']['predictions'][0]['vals']).astype(int)
if args.confidences:
pixels_confs = np.array(detect['body']['predictions'][0]['confidences']['best'])
imgsize = detect['body']['predictions'][0]['imgsize']

# visual output
Expand All @@ -79,5 +88,8 @@ def random_color():
rgb[:,:,2] = b/255.0

plt.figure()
plt.imshow(rgb,vmin=0,vmax=1)
if not args.confidences:
plt.imshow(rgb,vmin=0,vmax=1)
else:
plt.imshow(np.reshape(pixels_confs,(imgsize['height'],imgsize['width'])),vmin=0,vmax=1)
plt.show()
46 changes: 12 additions & 34 deletions src/backends/caffe/caffelib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3221,20 +3221,21 @@ namespace dd
* (*bit).second.second) // resizing output
// segmentation array
{
vals = img_resize(vals, inputc.height(),
inputc.width(), (*bit).second.first,
(*bit).second.second, true);
vals = ImgInputFileConn::img_resize_vector(
vals, inputc.height(), inputc.width(),
(*bit).second.first, (*bit).second.second, true);
if (conf_best)
conf_map_best
= img_resize(conf_map_best, inputc.height(),
inputc.width(), (*bit).second.first,
(*bit).second.second, false);
conf_map_best = ImgInputFileConn::img_resize_vector(
conf_map_best, inputc.height(), inputc.width(),
(*bit).second.first, (*bit).second.second,
false);
for (int ci = 0; ci < _nclasses; ++ci)
if (confidences[ci])
confidence_maps[ci] = img_resize(
confidence_maps[ci], inputc.height(),
inputc.width(), (*bit).second.first,
(*bit).second.second, false);
confidence_maps[ci]
= ImgInputFileConn::img_resize_vector(
confidence_maps[ci], inputc.height(),
inputc.width(), (*bit).second.first,
(*bit).second.second, false);
}
rad.add("vals", vals);
if (conf_best || !confidence_maps.empty())
Expand Down Expand Up @@ -5195,29 +5196,6 @@ namespace dd
return nullptr;
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
class TMLModel>
std::vector<double>
CaffeLib<TInputConnectorStrategy, TOutputConnectorStrategy,
TMLModel>::img_resize(const std::vector<double> &vals,
const int height_net, const int width_net,
const int height_dest, const int width_dest,
bool resize_nn)
{
cv::Mat segimg = cv::Mat(height_net, width_net, CV_64FC1);
std::memcpy(segimg.data, vals.data(), vals.size() * sizeof(double));
cv::Mat segimg_res;
if (resize_nn)
cv::resize(segimg, segimg_res, cv::Size(width_dest, height_dest), 0, 0,
cv::INTER_NEAREST);
else
cv::resize(segimg, segimg_res, cv::Size(width_dest, height_dest), 0, 0,
cv::INTER_LINEAR);
return std::vector<double>((double *)segimg_res.data,
(double *)segimg_res.data
+ segimg_res.rows * segimg_res.cols);
}

template class CaffeLib<ImgCaffeInputFileConn, SupervisedOutput, CaffeModel>;
template class CaffeLib<CSVCaffeInputFileConn, SupervisedOutput, CaffeModel>;
template class CaffeLib<CSVTSCaffeInputFileConn, SupervisedOutput,
Expand Down
5 changes: 0 additions & 5 deletions src/backends/caffe/caffelib.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,6 @@ namespace dd
boost::shared_ptr<Blob<float>> findBlobByName(const caffe::Net<float> *net,
const std::string blob_name);

std::vector<double> img_resize(const std::vector<double> &vals,
const int height_net, const int width_net,
const int height_dest, const int width_dest,
bool resize_nn);

bool is_refinedet(caffe::NetParameter &net_param);

public:
Expand Down
100 changes: 95 additions & 5 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ namespace dd
_regression = tl._regression;
_timeserie = tl._timeserie;
_bbox = tl._bbox;
_segmentation = tl._segmentation;
_loss = tl._loss;
_template_params = tl._template_params;
_dtype = tl._dtype;
Expand Down Expand Up @@ -272,7 +273,11 @@ namespace dd
_main_device = _devices[0];

// Set model type
if (mllib_dto->nclasses != 0)
if (mllib_dto->segmentation)
{
_segmentation = true;
}
else if (mllib_dto->nclasses != 0)
{
_classification = true;
}
Expand Down Expand Up @@ -315,7 +320,8 @@ namespace dd
{
_timeserie = true;
}
if (!_regression && !_timeserie && !_bbox && self_supervised.empty())
if (!_regression && !_timeserie && !_bbox && !_segmentation
&& self_supervised.empty())
_classification = true; // classification is default

// Set mltype
Expand All @@ -327,6 +333,8 @@ namespace dd
this->_mltype = "regression";
else if (_classification)
this->_mltype = "classification";
else if (_segmentation)
this->_mltype = "segmentation";

// Create the model
_module._device = _main_device;
Expand Down Expand Up @@ -429,7 +437,7 @@ namespace dd
throw MLLibBadParamException("invalid torch model template "
+ _template);
}
if (_classification || _regression)
if (_classification || _regression || _segmentation)
{
_module._nclasses = _nclasses;

Expand Down Expand Up @@ -1161,6 +1169,7 @@ namespace dd
double confidence_threshold = 0.0;
int best_count = _nclasses;
int best_bbox = -1;
std::vector<std::string> confidences;

if (params.has("mllib"))
{
Expand Down Expand Up @@ -1225,6 +1234,11 @@ namespace dd
{
best_bbox = output_params.get("best_bbox").get<int>();
}
if (output_params.has("confidences"))
{
confidences
= output_params.get("confidences").get<std::vector<std::string>>();
}

bool lstm_continuation = false;
TInputConnectorStrategy inputc(this->_inputc);
Expand Down Expand Up @@ -1310,7 +1324,7 @@ namespace dd
else
out_ivalue = _module.extract(in_vals, extract_layer);

if (!bbox)
if (!bbox && !_segmentation)
{
output = torch_utils::to_tensor_safe(out_ivalue);

Expand Down Expand Up @@ -1511,6 +1525,82 @@ namespace dd
results_ads.push_back(results_ad);
}
}
else if (_segmentation)
{
auto out_dict = out_ivalue.toGenericDict();
output = torch_utils::to_tensor_safe(out_dict.at("out"));
output = torch::softmax(output, 1);

int imgsize = inputc.width() * inputc.height();
torch::Tensor segmap;
torch::Tensor confmap;
std::tuple<torch::Tensor, torch::Tensor> maxmap;
if (!confidences.empty()) // applies "best" confidence lookup
{
maxmap = torch::max(output.clone().squeeze(), 0, false);
confmap = torch::flatten(std::get<0>(maxmap))
.contiguous()
.to(torch::kFloat64)
.to(cpu);
segmap = torch::flatten(std::get<1>(maxmap))
.contiguous()
.to(torch::kFloat64)
.to(cpu);
}
else
segmap = torch::flatten(torch::argmax(output.squeeze(), 0))
.contiguous()
.to(torch::kFloat64)
.to(cpu); // squeeze removes the batch size

APIData rad;
std::string uri;
if (!inputc._ids.empty())
uri = inputc._ids.at(results_ads.size());
else
uri = std::to_string(results_ads.size());
rad.add("uri", uri);
rad.add("loss", static_cast<double>(0.0));
double *startout = segmap.data_ptr<double>();
std::vector<double> vals(startout,
startout + torch::numel(segmap));
std::vector<double> confs;
if (!confidences.empty())
{
startout = confmap.data_ptr<double>();
confs = std::vector<double>(
startout, startout + torch::numel(confmap));
}

auto bit = inputc._imgs_size.find(uri);
APIData ad_imgsize;
ad_imgsize.add("height", (*bit).second.first);
ad_imgsize.add("width", (*bit).second.second);
rad.add("imgsize", ad_imgsize);

if (imgsize
!= (*bit).second.first
* (*bit).second.second) // resizing output
// segmentation array
{
vals = ImgInputFileConn::img_resize_vector(
vals, inputc.height(), inputc.width(),
(*bit).second.first, (*bit).second.second, true);
if (!confidences.empty())
confs = ImgInputFileConn::img_resize_vector(
confs, inputc.height(), inputc.width(),
(*bit).second.first, (*bit).second.second, false);
}

rad.add("vals", vals);
if (!confidences.empty())
{
APIData vconfs;
vconfs.add("best", confs);
rad.add("confidences", vconfs);
}
results_ads.push_back(rad);
}
else if (_regression)
{
auto probs_acc = output.accessor<float, 2>();
Expand Down Expand Up @@ -1570,7 +1660,7 @@ namespace dd
}
}

if (extract_layer.empty())
if (extract_layer.empty() && !_segmentation)
{
outputc.add_results(results_ads);

Expand Down
3 changes: 2 additions & 1 deletion src/backends/torch/torchlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ namespace dd
bool _seq_training = false; /**< true for bert/gpt2*/
bool _classification = false; /**< select classification type problem*/
bool _regression = false; /**< select regression type problem. */
bool _timeserie = false; /**< select timeserie type problem*/
bool _timeserie = false; /**< select timeserie type problem */
bool _bbox = false; /**< select detection type problem */
bool _segmentation = false; /**< select segmentation type problem */
std::string _loss = ""; /**< selected loss*/

APIData _template_params; /**< template parameters, for recurrent and
Expand Down
6 changes: 6 additions & 0 deletions src/dto/mllib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ namespace dd
};
DTO_FIELD(Int32, ntargets) = 0;

DTO_FIELD_INFO(segmentation)
{
info->description = "whether the model type is segmentation";
};
DTO_FIELD(Boolean, segmentation) = false;

DTO_FIELD_INFO(from_repository)
{
info->description = "initialize model repository with checkpoint and "
Expand Down
1 change: 1 addition & 0 deletions src/dto/output_connector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ namespace dd
DTO_FIELD(Float32, confidence_threshold) = 0.0;
DTO_FIELD(Int32, best);
DTO_FIELD(Int32, best_bbox) = -1;
DTO_FIELD(Vector<String>, confidences);

DTO_FIELD_INFO(image)
{
Expand Down
19 changes: 19 additions & 0 deletions src/imginputfileconn.h
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,25 @@ namespace dd
throw InputConnectorBadParamException("no image could be found");
}

static std::vector<double>
img_resize_vector(const std::vector<double> &vals, const int height_net,
const int width_net, const int height_dest,
const int width_dest, bool resize_nn)
{
cv::Mat segimg = cv::Mat(height_net, width_net, CV_64FC1);
std::memcpy(segimg.data, vals.data(), vals.size() * sizeof(double));
cv::Mat segimg_res;
if (resize_nn)
cv::resize(segimg, segimg_res, cv::Size(width_dest, height_dest), 0, 0,
cv::INTER_NEAREST);
else
cv::resize(segimg, segimg_res, cv::Size(width_dest, height_dest), 0, 0,
cv::INTER_LINEAR);
return std::vector<double>((double *)segimg_res.data,
(double *)segimg_res.data
+ segimg_res.rows * segimg_res.cols);
}

// data
std::vector<cv::Mat> _images;
std::vector<cv::Mat> _orig_images; /**< stored upon request. */
Expand Down
7 changes: 7 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,13 @@ if (USE_TORCH)
"fasterrcnn_train_torch_bs2.tar.gz"
"fasterrcnn_train_torch"
)
DOWNLOAD_DATASET(
"Torchvision training DeepLabV3 Resnet50 model"
"https://www.deepdetect.com/dd/examples/torch/deeplabv3_torch.tar.gz"
"examples/torch"
"deeplabv3_torch.tar.gz"
"deeplabv3_torch"
)
DOWNLOAD_DATASET(
"Torch BERT classification test model"
"https://www.deepdetect.com/dd/examples/torch/bert_inference_torch.tar.gz"
Expand Down
Loading

0 comments on commit d72a138

Please sign in to comment.