Skip to content

Commit

Permalink
fix(torch): when reading bbox dataset, also check that the class is n…
Browse files Browse the repository at this point in the history
…ot >= nclasses
  • Loading branch information
Bycob authored and mergify[bot] committed Mar 7, 2023
1 parent df318cb commit 7b2de88
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
7 changes: 4 additions & 3 deletions src/backends/torch/torchdataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -871,11 +871,12 @@ namespace dd
std::string val;
iss >> val;
int cls = std::stoi(val);
if (cls <= 0)
if (cls <= 0 || cls >= static_cast<int>(inputc->_nclasses))
{
throw InputConnectorBadParamException(
"Dataset contains an invalid class:" + std::to_string(cls)
+ " in file " + bboxfname);
"Dataset contains an invalid class: " + std::to_string(cls)
+ " in file " + bboxfname
+ " (nclasses=" + std::to_string(inputc->_nclasses) + ")");
}
classes.push_back(target_to_tensor(cls));

Expand Down
6 changes: 4 additions & 2 deletions src/backends/torch/torchinputconns.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ namespace dd
TorchInputInterface(const TorchInputInterface &i)
: _lm_params(i._lm_params), _dataset(i._dataset),
_test_datasets(i._test_datasets), _input_format(i._input_format),
_ctc(i._ctc), _ntargets(i._ntargets),
_ctc(i._ctc), _nclasses(i._nclasses), _ntargets(i._ntargets),
_alphabet_size(i._alphabet_size), _tilogger(i._tilogger), _db(i._db)
{
}
Expand Down Expand Up @@ -191,7 +191,9 @@ namespace dd
TorchMultipleDataset _test_datasets; /**< test datasets */
std::string _input_format; /**< for text, "bert" or nothing */

bool _ctc = false; /**< whether this is a CTC service */
bool _ctc = false; /**< whether this is a CTC service */
unsigned int _nclasses = 0; /**< number of classes for classification /
detection / segmentation */
unsigned int _ntargets
= 0; /**< number of targets for regression / timeseries */
int _alphabet_size = 0; /**< alphabet size for text prediction model */
Expand Down
1 change: 1 addition & 0 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ namespace dd
if (mllib_dto->nclasses != 0)
{
_nclasses = mllib_dto->nclasses;
this->_inputc._nclasses = _nclasses;
}
else if (mllib_dto->ntargets != 0)
{
Expand Down

0 comments on commit 7b2de88

Please sign in to comment.