diff --git a/amalgamation/python/mxnet_predict.py b/amalgamation/python/mxnet_predict.py index 627f375e1411..ca72e9affaa1 100644 --- a/amalgamation/python/mxnet_predict.py +++ b/amalgamation/python/mxnet_predict.py @@ -26,6 +26,7 @@ import os import sys import ctypes +import logging import numpy as np __all__ = ["Predictor", "load_ndarray_file"] @@ -51,15 +52,25 @@ def c_array(ctype, values): def _find_lib_path(): """Find mxnet library.""" curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - api_path = os.path.join(curr_path, '../../lib/') - dll_path = [curr_path, api_path] - dll_path = [os.path.join(p, 'libmxnet.so') for p in dll_path] + \ - [os.path.join(p, 'libmxnet_predict.so') for p in dll_path] - lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] - if len(lib_path) == 0: - raise RuntimeError('Cannot find the files.\n' + - 'List of candidates:\n' + str('\n'.join(dll_path))) - return lib_path + amalgamation_lib_path = os.path.join(curr_path, '../../lib/libmxnet_predict.so') + if os.path.exists(amalgamation_lib_path) and os.path.isfile(amalgamation_lib_path): + lib_path = [amalgamation_lib_path] + return lib_path + else: + logging.info('Cannot find libmxnet_predict.so. Will search for MXNet library using libinfo.py then.') + try: + from mxnet.libinfo import find_lib_path + lib_path = find_lib_path() + return lib_path + except ImportError: + libinfo_path = os.path.join(curr_path, '../../python/mxnet/libinfo.py') + if os.path.exists(libinfo_path) and os.path.isfile(libinfo_path): + libinfo = {'__file__': libinfo_py} + exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo) + lib_path = libinfo['find_lib_path']() + return lib_path + else: + raise RuntimeError('Cannot find libinfo.py at %s.' % libinfo_path) def _load_lib(): @@ -159,6 +170,39 @@ def forward(self, **kwargs): mx_uint(v.size))) _check_call(_LIB.MXPredForward(self.handle)) + def reshape(self, input_shapes): + """Change the input shape of the predictor. + + Parameters + ---------- + input_shapes : dict of str to tuple + The new shape of input data. + + Examples + -------- + >>> predictor.reshape({'data':data_shape_tuple}) + """ + indptr = [0] + sdata = [] + keys = [] + for k, v in input_shapes.items(): + if not isinstance(v, tuple): + raise ValueError("Expect input_shapes to be dict str->tuple") + keys.append(c_str(k)) + sdata.extend(v) + indptr.append(len(sdata)) + + new_handle = PredictorHandle() + _check_call(_LIB.MXPredReshape( + mx_uint(len(indptr) - 1), + c_array(ctypes.c_char_p, keys), + c_array(mx_uint, indptr), + c_array(mx_uint, sdata), + self.handle, + ctypes.byref(new_handle))) + _check_call(_LIB.MXPredFree(self.handle)) + self.handle = new_handle + def get_output(self, index): """Get the index-th output. diff --git a/src/c_api/c_predict_api.cc b/src/c_api/c_predict_api.cc index becb0cb364f6..d84a89ab2133 100644 --- a/src/c_api/c_predict_api.cc +++ b/src/c_api/c_predict_api.cc @@ -140,6 +140,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str, } sym = nnvm::Symbol::CreateGroup(out_syms); } + ret->sym = sym; // load the parameters std::unordered_map arg_params, aux_params; @@ -214,6 +215,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str, } Context ctx = Context::Create(static_cast(dev_type), dev_id); + ret->ctx = ctx; std::vector arg_arrays, aux_arrays; for (size_t i = 0; i < arg_shapes.size(); ++i) { @@ -231,6 +233,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str, aux_arrays.push_back(nd); } ret->arg_arrays = arg_arrays; + ret->aux_arrays = aux_arrays; // bind { std::map ctx_map; @@ -309,7 +312,6 @@ int MXPredReshape(mx_uint num_input_nodes, << " shape has been changed, only allow to change the shape of input data."; } } - p->arg_arrays.clear(); for (size_t i=0; i < aux_names.size(); ++i) { TShape newShape = aux_shapes[i]; @@ -319,7 +321,6 @@ int MXPredReshape(mx_uint num_input_nodes, << " shape has been changed, only allow to change the shape of input data."; } ret->aux_arrays = p->aux_arrays; - p->aux_arrays.clear(); // bind { diff --git a/tests/python/unittest/test_predictor.py b/tests/python/unittest/test_predictor.py new file mode 100644 index 000000000000..fc2fbf600cbc --- /dev/null +++ b/tests/python/unittest/test_predictor.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import print_function +import sys, os +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.append(os.path.join(curr_path, "../../../amalgamation/python/")) +from mxnet_predict import Predictor, load_ndarray_file + +import numpy as np +import mxnet as mx +import mxnet.ndarray as nd +from mxnet import gluon +from mxnet.test_utils import assert_almost_equal +from common import setup_module, with_seed, teardown + +@with_seed() +def test_predictor(): + prefix = 'test_predictor_simple_dense' + symbol_file = "%s-symbol.json" % prefix + param_file = "%s-0000.params" % prefix + + # two inputs with different batch sizes + input1 = np.random.uniform(size=(1,3)) + input2 = np.random.uniform(size=(3,3)) + + # define a simple model + block = gluon.nn.HybridSequential() + block.add(gluon.nn.Dense(7)) + block.add(gluon.nn.Dense(3)) + block.hybridize() + block.initialize() + out1 = block.forward(nd.array(input1)) + out2 = block.forward(nd.array(input2)) + block.export(prefix) + + # create a predictor + predictor = Predictor(open(symbol_file, "r").read(), + open(param_file, "rb").read(), + {'data':input1.shape}) + + # forward and get output + predictor.forward(data=input1) + predictor_out1 = predictor.get_output(0) + assert_almost_equal(out1.asnumpy(), predictor_out1, rtol=1e-5, atol=1e-6) + + # reshape + predictor.reshape({'data':input2.shape}) + predictor.forward(data=input2) + predictor_out2 = predictor.get_output(0) + assert_almost_equal(out2.asnumpy(), predictor_out2, rtol=1e-5, atol=1e-6) + + # destroy the predictor + del predictor + +@with_seed() +def test_load_ndarray(): + nd_file = 'test_predictor_load_ndarray.params' + a = nd.random.uniform(shape=(7, 3)) + b = nd.random.uniform(shape=(7,)) + nd_data = {'a':a, 'b':b} + nd.save(nd_file, nd_data) + + # test load_ndarray_file + nd_load = load_ndarray_file(open(nd_file, "rb").read()) + assert(set(nd_data.keys()) == set(nd_load.keys())) + for k in nd_data.keys(): + assert_almost_equal(nd_data[k].asnumpy(), nd_load[k], rtol=1e-5, atol=1e-6) + + +if __name__ == '__main__': + import nose + nose.runmodule()