Skip to content

Commit

Permalink
Modify MXNet source to avoid rebinding.
Browse files Browse the repository at this point in the history
  • Loading branch information
hqucms committed Sep 11, 2018
1 parent 03dd2cd commit 8c92f99
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 19 deletions.
10 changes: 3 additions & 7 deletions PhysicsTools/MXNet/interface/MXNetCppPredictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ class MXNetCppPredictor {
MXNetCppPredictor(const Block &block, const std::string &output_node);
virtual ~MXNetCppPredictor();

// set input array shapes
void set_input_shapes(const std::vector<std::string>& input_names, const std::vector<std::vector<mx_uint>>& input_shapes);

// run prediction
const std::vector<float>& predict(const std::vector<std::vector<mx_float>>& input_data);

private:
static std::mutex mutex_;

void infer_shapes();
void bind_executor();

// context
Expand All @@ -78,12 +80,6 @@ class MXNetCppPredictor {
// names of the input nodes
std::vector<std::string> input_names_;

// internal states
std::vector<NDArray> arg_arrays;
std::vector<NDArray> grad_arrays;
std::vector<OpReqType> grad_reqs;
std::vector<NDArray> aux_arrays;

};

} /* namespace cpp */
Expand Down
20 changes: 8 additions & 12 deletions PhysicsTools/MXNet/src/MXNetCppPredictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,15 @@ void MXNetCppPredictor::set_input_shapes(const std::vector<std::string>& input_n
NDArray nd(input_shapes[i], context_, false);
arg_map_[name] = nd;
}
// infer parameter shapes from input shapes
infer_shapes();
}

const std::vector<float>& MXNetCppPredictor::predict(const std::vector<std::vector<mx_float> >& input_data) {
assert(input_names_.size() == input_data.size());

try {
// bind executor
bind_executor();
// create the executor (if not done yet)
if (!exec_) { bind_executor(); }
assert(exec_);
// set the inputs
for (unsigned i=0; i<input_names_.size(); ++i){
const auto& name = input_names_[i];
Expand All @@ -91,7 +90,7 @@ const std::vector<float>& MXNetCppPredictor::predict(const std::vector<std::vect
}
}

void MXNetCppPredictor::infer_shapes() {
void MXNetCppPredictor::bind_executor() {
// acquire lock
std::lock_guard<std::mutex> lock(mutex_);

Expand All @@ -109,7 +108,7 @@ void MXNetCppPredictor::infer_shapes() {
sym_.InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);

// init argument arrays
arg_arrays.clear();
std::vector<NDArray> arg_arrays;
for (size_t i = 0; i < in_shapes.size(); ++i) {
const auto &shape = in_shapes[i];
const auto &arg_name = arg_name_list[i];
Expand All @@ -120,11 +119,11 @@ void MXNetCppPredictor::infer_shapes() {
arg_arrays.push_back(NDArray(shape, context_, false));
}
}
grad_arrays = std::vector<NDArray>(arg_arrays.size());
grad_reqs = std::vector<OpReqType>(arg_arrays.size(), kNullOp);
std::vector<NDArray> grad_arrays(arg_arrays.size());
std::vector<OpReqType> grad_reqs(arg_arrays.size(), kNullOp);

// init auxiliary array
aux_arrays.clear();
std::vector<NDArray> aux_arrays;
const auto aux_name_list = sym_.ListAuxiliaryStates();
for (size_t i = 0; i < aux_shapes.size(); ++i) {
const auto &shape = aux_shapes[i];
Expand All @@ -137,9 +136,6 @@ void MXNetCppPredictor::infer_shapes() {
}
}

}

void MXNetCppPredictor::bind_executor() {
// bind executor
exec_.reset(new Executor(sym_, context_, arg_arrays, grad_arrays, grad_reqs, aux_arrays));
}
Expand Down

0 comments on commit 8c92f99

Please sign in to comment.