From 8c92f9905ebe7d658087dad9817c167e69f4e401 Mon Sep 17 00:00:00 2001 From: Huilin Qu Date: Mon, 13 Aug 2018 15:51:05 +0200 Subject: [PATCH] Modify MXNet source to avoid rebinding. --- .../MXNet/interface/MXNetCppPredictor.h | 10 +++------- PhysicsTools/MXNet/src/MXNetCppPredictor.cc | 20 ++++++++----------- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/PhysicsTools/MXNet/interface/MXNetCppPredictor.h b/PhysicsTools/MXNet/interface/MXNetCppPredictor.h index 140944f921db8..ed97731ac1574 100644 --- a/PhysicsTools/MXNet/interface/MXNetCppPredictor.h +++ b/PhysicsTools/MXNet/interface/MXNetCppPredictor.h @@ -54,13 +54,15 @@ class MXNetCppPredictor { MXNetCppPredictor(const Block &block, const std::string &output_node); virtual ~MXNetCppPredictor(); + // set input array shapes void set_input_shapes(const std::vector& input_names, const std::vector>& input_shapes); + + // run prediction const std::vector& predict(const std::vector>& input_data); private: static std::mutex mutex_; - void infer_shapes(); void bind_executor(); // context @@ -78,12 +80,6 @@ class MXNetCppPredictor { // names of the input nodes std::vector input_names_; - // internal states - std::vector arg_arrays; - std::vector grad_arrays; - std::vector grad_reqs; - std::vector aux_arrays; - }; } /* namespace cpp */ diff --git a/PhysicsTools/MXNet/src/MXNetCppPredictor.cc b/PhysicsTools/MXNet/src/MXNetCppPredictor.cc index 2511f58f9947c..2ed1bb6b72be6 100644 --- a/PhysicsTools/MXNet/src/MXNetCppPredictor.cc +++ b/PhysicsTools/MXNet/src/MXNetCppPredictor.cc @@ -66,16 +66,15 @@ void MXNetCppPredictor::set_input_shapes(const std::vector& input_n NDArray nd(input_shapes[i], context_, false); arg_map_[name] = nd; } - // infer parameter shapes from input shapes - infer_shapes(); } const std::vector& MXNetCppPredictor::predict(const std::vector >& input_data) { assert(input_names_.size() == input_data.size()); try { - // bind executor - bind_executor(); + // create the executor (if not done yet) + if (!exec_) { bind_executor(); } + assert(exec_); // set the inputs for (unsigned i=0; i& MXNetCppPredictor::predict(const std::vector lock(mutex_); @@ -109,7 +108,7 @@ void MXNetCppPredictor::infer_shapes() { sym_.InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes); // init argument arrays - arg_arrays.clear(); + std::vector arg_arrays; for (size_t i = 0; i < in_shapes.size(); ++i) { const auto &shape = in_shapes[i]; const auto &arg_name = arg_name_list[i]; @@ -120,11 +119,11 @@ void MXNetCppPredictor::infer_shapes() { arg_arrays.push_back(NDArray(shape, context_, false)); } } - grad_arrays = std::vector(arg_arrays.size()); - grad_reqs = std::vector(arg_arrays.size(), kNullOp); + std::vector grad_arrays(arg_arrays.size()); + std::vector grad_reqs(arg_arrays.size(), kNullOp); // init auxiliary array - aux_arrays.clear(); + std::vector aux_arrays; const auto aux_name_list = sym_.ListAuxiliaryStates(); for (size_t i = 0; i < aux_shapes.size(); ++i) { const auto &shape = aux_shapes[i]; @@ -137,9 +136,6 @@ void MXNetCppPredictor::infer_shapes() { } } -} - -void MXNetCppPredictor::bind_executor() { // bind executor exec_.reset(new Executor(sym_, context_, arg_arrays, grad_arrays, grad_reqs, aux_arrays)); }