Skip to content

Commit

Permalink
fix: enable caffe chain with DTO & custom actions
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Jun 2, 2022
1 parent 5146da7 commit d3e722e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 2 deletions.
28 changes: 27 additions & 1 deletion src/backends/caffe/caffelib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2815,12 +2815,38 @@ namespace dd
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
class TMLModel>
int CaffeLib<TInputConnectorStrategy, TOutputConnectorStrategy,
TMLModel>::predict(const APIData &ad, APIData &out)
TMLModel>::predict(const APIData &ad_in, APIData &out)
{
std::lock_guard<std::mutex> lock(
_net_mutex); // no concurrent calls since the net is not
// re-instantiated

APIData ad;
if (ad_in.has("dto"))
{
// cast to ServicePredict...
auto any = ad_in.get("dto").get<oatpp::Any>();
oatpp::Object<DTO::ServicePredict> predict_dto
= oatpp::Object<DTO::ServicePredict>(
std::static_pointer_cast<typename DTO::ServicePredict>(
any->ptr));
ad = APIData::fromDTO<oatpp::Void>(predict_dto);

if (predict_dto->_chain)
{
ad.add("chain", predict_dto->_chain);
if (!predict_dto->_data_raw_img.empty())
ad.add("data_raw_img", predict_dto->_data_raw_img);
ad.add("ids", predict_dto->_ids);
ad.add("meta_uris", predict_dto->_meta_uris);
ad.add("index_uris", predict_dto->_index_uris);
}
}
else
{
ad = ad_in;
}

// check for net
if (!_net || _net->phase() == caffe::TRAIN)
{
Expand Down
8 changes: 8 additions & 0 deletions src/dto/chain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ namespace dd

// dlib image align
DTO_FIELD(Int32, chip_size) = 150;

// custom action
DTO_FIELD_INFO(custom)
{
info->description = "[custom] parameters for custom action";
}
DTO_FIELD(UnorderedFields<Any>, custom)
= UnorderedFields<Any>::createShared();
};

class ChainAction : public oatpp::DTO
Expand Down
19 changes: 18 additions & 1 deletion src/utils/oatpp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ namespace dd
DTO::vectorSerialize<uint8_t>);
ser->setSerializerMethod(DTO::DTOVector<bool>::Class::CLASS_ID,
DTO::vectorSerialize<bool>);

ser->getConfig()->includeNullFields = false;
return object_mapper;
}

Expand Down Expand Up @@ -180,6 +182,19 @@ namespace dd
jval.PushBack(elemJVal, jdoc.GetAllocator());
}
}
else if (polymorph.valueType->classId.id
== oatpp::data::mapping::type::__class::AbstractList::CLASS_ID
.id)
{
auto list = polymorph.staticCast<oatpp::AbstractList>();
jval = JVal(rapidjson::kArrayType);
for (auto &elem : *list)
{
JVal elemJVal;
dtoToJVal(elem, jdoc, elemJVal, ignore_null);
jval.PushBack(elemJVal, jdoc.GetAllocator());
}
}
else if (polymorph.valueType->classId.id
== oatpp::data::mapping::type::__class::AbstractPairList::
CLASS_ID.id)
Expand Down Expand Up @@ -247,7 +262,9 @@ namespace dd
}
else
{
throw std::runtime_error("dtoToJVal: Type not recognised");
std::string type_name = polymorph.valueType->classId.name;
throw std::runtime_error("dtoToJVal: \"" + type_name
+ "\": type not recognised");
}
}
}
Expand Down

0 comments on commit d3e722e

Please sign in to comment.