From f0f7bd6666cdcd600c1c2021e3dad2294c00ee94 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Wed, 13 Jun 2018 23:00:54 +0000 Subject: [PATCH 01/13] clip sparse grad. fix _reduce for rowsparse param --- python/mxnet/gluon/parameter.py | 10 +++--- python/mxnet/gluon/utils.py | 8 +++-- tests/python/unittest/test_gluon.py | 47 ++++++++++++++++------------- 3 files changed, 38 insertions(+), 27 deletions(-) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index c0d89fbd4cc1..9764d1e65139 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -310,14 +310,16 @@ def _init_grad(self): self._grad, self.grad_req) def _reduce(self): - """Reduce data from multiple context.""" + """Reduce data from multiple context to cpu.""" + ctx = context.cpu() if self._stype == 'default': block = self.list_data() - data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block) + data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block) else: # fetch all rows for 'row_sparse' param - all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=context.cpu()) - data = self.row_sparse_data(all_row_ids) + all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx) + data = ndarray.zeros(self.shape, stype='row_sparse', ctx=ctx) + self._trainer._row_sparse_pull(self, data, all_row_ids) return data def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(), diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 06b91fadcee4..fcb7c97b9809 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -118,10 +118,14 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True): def clip_global_norm(arrays, max_norm): """Rescales NDArrays so that the sum of their 2-norm is smaller than `max_norm`. """ + def _norm(array): + if array.stype == 'default': + x = array.reshape((-1,)) + return ndarray.dot(x, x) + return array.norm().square() assert len(arrays) > 0 ctx = arrays[0].context - total_norm = ndarray.add_n(*[ndarray.dot(x, x).as_in_context(ctx) - for x in (arr.reshape((-1,)) for arr in arrays)]) + total_norm = ndarray.add_n(*[_norm(arr).as_in_context(ctx) for arr in arrays]) total_norm = ndarray.sqrt(total_norm).asscalar() if not np.isfinite(total_norm): warnings.warn(UserWarning('nan or inf is detected. 
Clipping results will be undefined.'), diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index ced3063448bb..062eceb30900 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -91,15 +91,16 @@ def test_parameter_invalid_access(): @with_seed() def test_paramdict(): + ctx = mx.cpu(1) params0 = gluon.ParameterDict('net_') params0.get('w0', shape=(10, 10)) params0.get('w1', shape=(10, 10), stype='row_sparse') - all_row_ids = mx.nd.arange(0, 10, ctx=mx.cpu()) + all_row_ids = mx.nd.arange(0, 10, ctx=ctx) # check param names assert list(params0.keys()) == ['net_w0', 'net_w1'] - params0.initialize(ctx=mx.cpu()) + params0.initialize(ctx=ctx) trainer0 = mx.gluon.Trainer(params0, 'sgd') - prev_w0 = params0.get('w0').data(mx.cpu()) + prev_w0 = params0.get('w0').data(ctx) prev_w1 = params0.get('w1').row_sparse_data(all_row_ids) # save params params0.save('test_paramdict.params') @@ -108,11 +109,11 @@ def test_paramdict(): params1 = gluon.ParameterDict('net_') params1.get('w0', shape=(10, 10)) params1.get('w1', shape=(10, 10), stype='row_sparse') - params1.load('test_paramdict.params', mx.cpu()) + params1.load('test_paramdict.params', ctx) trainer1 = mx.gluon.Trainer(params1, 'sgd') # compare the values before and after save/load - cur_w0 = params1.get('w0').data(mx.cpu()) + cur_w0 = params1.get('w0').data(ctx) cur_w1 = params1.get('w1').row_sparse_data(all_row_ids) mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy()) mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy()) @@ -122,11 +123,11 @@ def test_paramdict(): params2 = gluon.ParameterDict('net_') params2.get('w0', shape=(10, 10)) params2.get('w1', shape=(10, 10)) - params2.load('test_paramdict.params', mx.cpu()) + params2.load('test_paramdict.params', ctx) # compare the values before and after save/load - cur_w0 = params2.get('w0').data(mx.cpu()) - cur_w1 = params2.get('w1').data(mx.cpu()) + cur_w0 = params2.get('w0').data(ctx) + cur_w1 = params2.get('w1').data(ctx) mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy()) mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy()) @@ -728,19 +729,23 @@ def test_sequential_warning(): @with_seed() def test_global_norm_clip(): - x1 = mx.nd.ones((3,3)) - x2 = mx.nd.ones((4,4)) - norm = gluon.utils.clip_global_norm([x1, x2], 1.0) - assert norm == 5.0 - assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5) - assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5) - - x3 = mx.nd.array([1.0, 2.0, float('nan')]) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - gluon.utils.clip_global_norm([x1, x3], 2.0) - assert len(w) == 1 - + stypes = ['default', 'row_sparse'] + def check_global_norm_clip(stype): + x1 = mx.nd.ones((3,3)).tostype(stype) + x2 = mx.nd.ones((4,4)).tostype(stype) + norm = gluon.utils.clip_global_norm([x1, x2], 1.0) + assert norm == 5.0 + assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5) + assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5) + + x3 = mx.nd.array([1.0, 2.0, float('nan')]).tostype(stype) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + gluon.utils.clip_global_norm([x1, x3], 2.0) + assert len(w) == 1 + + for stype in stypes: + check_global_norm_clip(stype) @with_seed() def test_embedding(): From 3e6c4c2408ae60bc791313ea236b5465ab5ad8c0 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Thu, 14 Jun 2018 00:35:48 +0000 Subject: [PATCH 02/13] fix kvstore init for local kv --- 
python/mxnet/gluon/trainer.py | 2 +- tests/python/unittest/test_gluon_trainer.py | 48 ++++++++++++--------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index ef20109021aa..02d68f0c39cb 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -152,7 +152,6 @@ def _reset_kvstore(self): def _init_kvstore(self): """Create kvstore.""" - arg_arrays = {} config = self._kvstore_params if self._contains_sparse: kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore']) @@ -162,6 +161,7 @@ def _init_kvstore(self): "gradients and/or sparse weights are present for " "Parameter '%s'."%param.name) else: + arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays) if config['update_on_kvstore'] is not None: diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index c2e11ebb18ee..1c59ceaa093a 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -177,24 +177,30 @@ def test_trainer_save_load(): @with_seed() def test_trainer_reset_kv(): - params = gluon.ParameterDict() - x = params.get('x', shape=(10,), lr_mult=1.0) - params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') - trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}) - params.save('test_trainer_reset_kv.params') - with mx.autograd.record(): - for w in x.list_data(): - y = w + 1 - y.backward() - trainer.step(1) - # load would reset kvstore - params.load('test_trainer_reset_kv.params') - assert trainer._kvstore is None - assert trainer._kv_initialized is False - with mx.autograd.record(): - for w in x.list_data(): - y = w + 1 - y.backward() - trainer.step(1) - # the updated parameter should be based on the loaded checkpoint - assert (x.data(mx.cpu()) == -0.2).asnumpy().all() + def check_trainer_reset_kv(kv): + params = gluon.ParameterDict() + x = params.get('x', shape=(10,), lr_mult=1.0) + params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') + trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv) + params.save('test_trainer_reset_kv.params') + with mx.autograd.record(): + for w in x.list_data(): + y = w + 1 + y.backward() + trainer.step(1) + assert trainer._kvstore.type == kv + # load would reset kvstore + params.load('test_trainer_reset_kv.params') + assert trainer._kvstore is None + assert trainer._kv_initialized is False + with mx.autograd.record(): + for w in x.list_data(): + y = w + 1 + y.backward() + trainer.step(1) + # the updated parameter should be based on the loaded checkpoint + assert (x.data(mx.cpu()) == -0.2).asnumpy().all() + + kvs = ['local', 'device'] + for kv in kvs: + check_trainer_reset_kv(kv) From 62b2c6a8b162d6b683170ce2b1a8ed9a804f1887 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Sat, 16 Jun 2018 04:06:33 +0000 Subject: [PATCH 03/13] trigger From 4cafce14be7ac50f60af264a96dffd52db38c7bc Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Fri, 22 Jun 2018 05:46:53 +0000 Subject: [PATCH 04/13] pull with ignore sparse --- include/mxnet/c_api.h | 8 ++++-- include/mxnet/kvstore.h | 6 +++-- include/mxnet/ndarray.h | 2 +- python/mxnet/gluon/trainer.py | 50 ++++++++++++++++++++++++++--------- python/mxnet/kvstore.py | 17 +++++++----- python/mxnet/model.py | 15 +++++++---- src/c_api/c_api.cc | 10 ++++--- src/kvstore/comm.h | 2 +- 
src/kvstore/kvstore_local.h | 49 ++++++++++++++++++++-------------- src/ndarray/ndarray.cc | 6 ++--- 10 files changed, 107 insertions(+), 58 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 4dd858a51c4b..68fda1941186 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1900,13 +1900,15 @@ MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle, * \param keys the list of keys * \param vals the list of values * \param priority the priority of the action + * \param ignore_sparse whether to ignore sparse arrays in the request * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStorePull(KVStoreHandle handle, mx_uint num, const int* keys, NDArrayHandle* vals, - int priority); + int priority, + bool ignore_sparse = true); /*! * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string * \param handle handle to the kvstore @@ -1914,13 +1916,15 @@ MXNET_DLL int MXKVStorePull(KVStoreHandle handle, * \param keys the list of keys * \param vals the list of values * \param priority the priority of the action + * \param ignore_sparse whether to ignore sparse arrays in the request * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle, mx_uint num, const char** keys, NDArrayHandle* vals, - int priority); + int priority, + bool ignore_sparse = true); /*! * \brief pull a list of (key, value) pairs from the kvstore, where each key is an integer. diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index 9e92207fb8db..e10bd213aa26 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -170,19 +170,21 @@ class KVStore { * \param keys the list of keys * \param values the list of buffers for the pulled data, they should be preallocated * \param priority Priority of the action. + * \param ignore_sparse whether to ignore sparse arrays in the request */ virtual void Pull(const std::vector& keys, const std::vector& values, - int priority = 0) = 0; + int priority = 0, bool ignore_sparse = true) = 0; /*! * \brief pull a list of key-value pairs from the store * \param keys the list of keys in string format * \param values the list of buffers for the pulled data, they should be preallocated * \param priority Priority of the action. + * \param ignore_sparse whether to ignore sparse arrays in the request */ virtual void Pull(const std::vector& str_keys, const std::vector& values, - int priority = 0) = 0; + int priority = 0, bool ignore_sparse = true) = 0; /*! * \brief pull a list of key-value pairs from the store. diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index faffe1bdea99..54317e7107ab 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -1027,7 +1027,7 @@ void CopyFromTo(const NDArray &from, const NDArray *to, int priority = 0); * \note The function name explicitly marks the order of from and to * due to different possible convention carried by copy function. */ -void CopyFromTo(const NDArray &from, const NDArray& to, int priority = 0); +void CopyFromTo(const NDArray &from, const NDArray& to, int priority = 0, bool is_opr = false); /*! * \brief Perform elementwise sum over each data from source, store result into out. 
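
The `ignore_sparse` flag added above threads from the C API (`MXKVStorePull` / `MXKVStorePullEx`) through `KVStore::Pull` down to the Python `KVStore.pull` call changed in the `kvstore.py` hunk below. A minimal sketch of the intended usage on a local kvstore, based only on the API introduced in this patch — the key name `'w'` and the shapes are illustrative, not taken from the diff:

    import mxnet as mx

    kv = mx.kv.create('local')
    kv.init('w', mx.nd.ones((4, 2)).tostype('row_sparse'))

    out = mx.nd.zeros((4, 2), stype='row_sparse')
    # default behaviour: sparse arrays in a pull request are ignored
    # (a warning is logged once per key, the value is not pulled)
    kv.pull('w', out=out)
    # opt in to pulling the full row_sparse value; the dist kvstore rejects
    # ignore_sparse=False and requires row_sparse_pull() instead
    kv.pull('w', out=out, ignore_sparse=False)
    print(out.asnumpy())
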
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 02d68f0c39cb..3405d9ad69c0 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -69,7 +69,8 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', "got %s."%(type(params))) self._params = [] # parameters to initialize on the kvstore - self._contains_sparse = False + self._contains_sparse_weight = False + self._contains_sparse_grad = False self._param2idx = {} for i, param in enumerate(params): if not isinstance(param, Parameter): @@ -80,7 +81,9 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', self._params.append(param) param._set_trainer(self) if param._stype != 'default': - self._contains_sparse = True + self._contains_sparse_weight = True + if param._grad_stype != 'default': + self._contains_sparse_grad = True self._compression_params = compression_params optimizer_params = optimizer_params if optimizer_params else {} self._scale = float(optimizer_params.get('rescale_grad', 1.0)) @@ -153,13 +156,31 @@ def _reset_kvstore(self): def _init_kvstore(self): """Create kvstore.""" config = self._kvstore_params - if self._contains_sparse: - kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore']) - # update_on_kvstore is set to False by the user + # if weight is sparse, the weight must be updated on KVStore. + # training loop contains: + # - row_sparse_pull(sparse_weight) + # - forward() + # - backward() + # - push(sparse_grad), push(dense_grad) + # - pull(dense_weight) + if self._contains_sparse_weight: + kvstore = _create_sparse_kvstore(config['kvstore'], len(self._contexts)) + update_on_kvstore = True + # raise Error if update_on_kvstore is set to False by the user if config['update_on_kvstore'] is False: - raise RuntimeError("Cannot set update_on_kvstore to False when sparse " - "gradients and/or sparse weights are present for " - "Parameter '%s'."%param.name) + raise RuntimeError("Cannot set update_on_kvstore to False when sparse weights " + "are present.") + # if weight is dense and grad is sparse, the weight better not be updated on KVStore. 
+ # training loop contains: + # - forward() + # - backward() + # - push(grad) + # - pull(grad) + # - update(grad, weight) + elif self._contains_sparse_grad: + kvstore = _create_sparse_kvstore(config['kvstore'], len(self._contexts)) + update_on_kvstore = False + # normal case else: arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts), @@ -169,9 +190,12 @@ def _init_kvstore(self): if kvstore: if self._compression_params: kvstore.set_gradient_compression(self._compression_params) - # kv.pull(row_sparse_grad) is not supported - if 'dist' in kvstore.type and not self._contains_sparse: - update_on_kvstore = False + if 'dist' in kvstore.type: + # kv.pull(row_sparse_grad) is not supported for dist kvstore + if self._contains_sparse_weight or self._contains_sparse_grad: + update_on_kvstore = True + else: + update_on_kvstore = False if update_on_kvstore: # optimizer preferably needs to be set before init for multiprecision kvstore.set_optimizer(self._optimizer) @@ -272,7 +296,7 @@ def _allreduce_grads(self): self._kvstore.push(i, param.list_grad(), priority=-i) if not self._update_on_kvstore: - self._kvstore.pull(i, param.list_grad(), priority=-i) + self._kvstore.pull(i, param.list_grad(), priority=-i, ignore_sparse=False) def update(self, batch_size, ignore_stale_grad=False): """Makes one step of parameter update. @@ -327,7 +351,7 @@ def _update(self, ignore_stale_grad=False): if self._kvstore and self._update_on_kvstore: if param._stype == 'default': # 'row_sparse' parameters are not pulled immediately - they're pulled - # in `SparseBlock.sparse_forward` + # in `Block.forward` self._kvstore.pull(i, param.list_data(), priority=-i) continue diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index f31dac01cd10..6b954acc7a5f 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -235,7 +235,7 @@ def push(self, key, value, priority=0): self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) - def pull(self, key, out=None, priority=0): + def pull(self, key, out=None, priority=0, ignore_sparse=True): """ Pulls a single value or a sequence of values from the store. This function returns immediately after adding an operator to the engine. @@ -247,8 +247,8 @@ def pull(self, key, out=None, priority=0): The returned values are guaranteed to be the latest values in the store. - For `RowSparseNDArray` values, this call is ignored, - please use ``row_sparse_pull`` instead. + pull with `RowSparseNDArray` is not supported for dist kvstore. + Please use ``row_sparse_pull`` instead. Parameters ---------- @@ -263,6 +263,9 @@ def pull(self, key, out=None, priority=0): Higher priority pull operations are likely to be executed before other pull actions. + ignore_sparse: bool, optional + Whether to ignore sparse arrays in the request. 
+ Examples -------- >>> # pull a single key-value pair @@ -298,11 +301,11 @@ def pull(self, key, out=None, priority=0): assert(out is not None) ckeys, cvals, use_str_keys = _ctype_key_value(key, out) if use_str_keys: - check_call(_LIB.MXKVStorePullEx( - self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) + check_call(_LIB.MXKVStorePullEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals, + ctypes.c_int(priority), ctypes.c_bool(ignore_sparse))) else: - check_call(_LIB.MXKVStorePull( - self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) + check_call(_LIB.MXKVStorePull(self.handle, mx_uint(len(ckeys)), ckeys, cvals, + ctypes.c_int(priority), ctypes.c_bool(ignore_sparse))) def row_sparse_pull(self, key, out=None, priority=0, row_ids=None): """ Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \ diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 3a50553a615c..6ce72f595cbf 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -55,24 +55,29 @@ 'eval_metric', 'locals']) -def _create_sparse_kvstore(kvstore): +def _create_sparse_kvstore(kvstore, num_device): """Create kvstore assuming some parameters' storage types are row_sparse. Parameters ---------- kvstore : KVStore or str The kvstore. + num_device : int + The number of devices """ - # always update on kvstore - update_on_kvstore = True if isinstance(kvstore, kvs.KVStore): kv = kvstore elif isinstance(kvstore, str): - kv = kvs.create(kvstore) + # create kvstore using the string type + if num_device is 1 and 'dist' not in kvstore: + # no need to use kv for single device and single machine + kv = None + else: + kv = kvs.create(kvstore) else: raise TypeError("Cannot create '%s' KVStore with row_sparse parameters. " "The type must be KVStore or str." 
% kvstore) - return (kv, update_on_kvstore) + return kv def _create_kvstore(kvstore, num_device, arg_params): """Create kvstore diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 467118b9921e..6b1012b1a4d7 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -861,7 +861,8 @@ int MXKVStorePull(KVStoreHandle handle, mx_uint num, const int* keys, NDArrayHandle* vals, - int priority) { + int priority, + bool ignore_sparse) { API_BEGIN(); std::vector v_keys(num); std::vector v_vals(num); @@ -869,7 +870,7 @@ int MXKVStorePull(KVStoreHandle handle, v_keys[i] = keys[i]; v_vals[i] = static_cast(vals[i]); } - static_cast(handle)->Pull(v_keys, v_vals, priority); + static_cast(handle)->Pull(v_keys, v_vals, priority, ignore_sparse); API_END(); } @@ -877,7 +878,8 @@ int MXKVStorePullEx(KVStoreHandle handle, mx_uint num, const char** keys, NDArrayHandle* vals, - int priority) { + int priority, + bool ignore_sparse) { API_BEGIN(); std::vector v_keys(num); std::vector v_vals(num); @@ -885,7 +887,7 @@ int MXKVStorePullEx(KVStoreHandle handle, v_keys[i] = keys[i]; v_vals[i] = static_cast(vals[i]); } - static_cast(handle)->Pull(v_keys, v_vals, priority); + static_cast(handle)->Pull(v_keys, v_vals, priority, ignore_sparse); API_END(); } diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index a5d6a1dabeff..bcaa046255b7 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -653,7 +653,7 @@ class CommDevice : public Comm { } on_complete(); }, retained_gpu.ctx(), {src.var(), row_id.var()}, {retained_gpu.var()}, - FnProperty::kNormal, priority, "KVStoreSparseRetain"); + FnProperty::kCopyFromGPU, priority, "KVStoreSparseRetain"); CopyFromTo(retained_gpu, out, priority); } } diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 38ecf121dfeb..ce1870b5dc05 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -101,9 +101,10 @@ class KVStoreLocal : public KVStore { void Pull(const std::vector& keys, const std::vector& values, - int priority) override { + int priority, + bool ignore_sparse) override { SetKeyType(kIntKey); - PullImpl(keys, values, priority); + PullImpl(keys, values, priority, ignore_sparse); } void PullRowSparse(const std::vector& keys, @@ -124,11 +125,12 @@ class KVStoreLocal : public KVStore { void Pull(const std::vector& str_keys, const std::vector& values, - int priority) override { + int priority, + bool ignore_sparse) override { SetKeyType(kStringKey); std::vector keys(str_keys.size()); LookupKeys(str_keys, &keys); - PullImpl(keys, values, priority); + PullImpl(keys, values, priority, ignore_sparse); } void PullRowSparse(const std::vector& str_keys, @@ -162,7 +164,7 @@ class KVStoreLocal : public KVStore { int priority) { std::vector uniq_keys; std::vector > grouped_vals; - GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals); + GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals, false); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; const NDArray& merged = comm_->Reduce(key, grouped_vals[i], priority); @@ -198,10 +200,11 @@ class KVStoreLocal : public KVStore { virtual void PullImpl(const std::vector& keys, const std::vector& values, - int priority) { + int priority, + bool ignore_sparse) { std::vector uniq_keys; std::vector > grouped_vals; - GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals); + GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals, ignore_sparse); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; @@ -216,7 +219,7 @@ class KVStoreLocal : public 
KVStore { int priority = 0) { std::vector uniq_keys; std::vector>> grouped_val_rowids; - GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids); + GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids, false); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; const NDArray& local = local_[key]; @@ -250,9 +253,11 @@ class KVStoreLocal : public KVStore { virtual void GroupKVPairsPush(const std::vector& keys, const std::vector& values, std::vector *uniq_keys, - std::vector> *grouped_vals) { + std::vector> *grouped_vals, + bool ignore_sparse) { // check if the storage type of a value is valid - auto validator = [this](const int key, const NDArray& nd) -> bool { + auto validator = [this](const int key, const NDArray& nd, bool ignore_sparse) -> bool { + CHECK(!ignore_sparse) << "Cannot ignore sparse arrays for push"; auto stype = nd.storage_type(); // valid NDArray if (stype == kDefaultStorage || stype == kRowSparseStorage) return true; @@ -260,7 +265,7 @@ class KVStoreLocal : public KVStore { LOG(FATAL) << "Unexpected storage type detected during kvstore push: " << stype; return false; }; - GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); + GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator, ignore_sparse); } /** * \brief group values on keys for pull @@ -268,11 +273,12 @@ class KVStoreLocal : public KVStore { virtual void GroupKVPairsPull(const std::vector& keys, const std::vector& values, std::vector *uniq_keys, - std::vector> *grouped_vals) { + std::vector> *grouped_vals, + bool ignore_sparse) { // check if the storage type of a value is valid - auto validator = [this](const int key, const NDArray* nd) -> bool { + auto validator = [this](const int key, const NDArray* nd, bool ignore_sparse) -> bool { // valid - if (nd->storage_type() == kDefaultStorage) return true; + if (nd->storage_type() == kDefaultStorage || !ignore_sparse) return true; // invalid, print warning messages once if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) { LOG(INFO) << "Warning: non-default weights detected during kvstore pull. 
" @@ -282,7 +288,7 @@ class KVStoreLocal : public KVStore { } return false; }; - GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); + GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator, ignore_sparse); } typedef std::pair RSPVal; @@ -292,9 +298,11 @@ class KVStoreLocal : public KVStore { virtual void GroupKVPairsPullRsp(const std::vector& keys, const std::vector& values, std::vector *uniq_keys, - std::vector> *grouped_vals) { + std::vector> *grouped_vals, + bool ignore_sparse) { // check if the storage type of a value is valid - auto validator = [this](const int key, const RSPVal& val_rowid) -> bool { + auto validator = [this](const int key, const RSPVal& val_rowid, bool ignore_sparse) -> bool { + CHECK(!ignore_sparse) << "Cannot ignore sparse arrays in row_sparse_pull"; auto val_stype = val_rowid.first->storage_type(); auto rowid_stype = val_rowid.second.storage_type(); // check storage types @@ -304,7 +312,7 @@ class KVStoreLocal : public KVStore { << "row_sparse_pull rowids, but detected storage type " << rowid_stype; return true; }; - GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); + GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator, ignore_sparse); } /** @@ -316,7 +324,8 @@ class KVStoreLocal : public KVStore { const std::vector& values, std::vector* uniq_keys, std::vector >* grouped_vals, - const FValidate& is_valid) { + const FValidate& is_valid, + bool ignore_sparse) { CHECK_EQ(keys.size(), values.size()); // TODO(mli) check if already sorted as an optimization using Idx = std::pair; @@ -330,7 +339,7 @@ class KVStoreLocal : public KVStore { int pre_key = idx[0].first - 1; for (auto i : idx) { - if (is_valid(i.first, values[i.second])) { + if (is_valid(i.first, values[i.second], ignore_sparse)) { if (i.first != pre_key) { uniq_keys->push_back(i.first); grouped_vals->push_back({values[i.second]}); diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 94d3d90413ab..9ef6a3829d09 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -1211,7 +1211,7 @@ void CopyFromToImpl(const NDArray& from, const NDArray& to, } } -void CopyFromTo(const NDArray& from, const NDArray& to, int priority) { +void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_opr) { if (from.var() == to.var() && from.byte_offset() == to.byte_offset()) { // skip to copy to itself return; @@ -1292,7 +1292,7 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority) { on_complete(); }, from.ctx(), const_vars, mutable_vars, from.dtype() != to.dtype() ? FnProperty::kNormal : FnProperty::kCopyFromGPU, - priority, "CopyGPU2GPU"); + priority, is_opr ? 
"_copyto_GPU2GPU" : "CopyGPU2GPU"); } else { LOG(FATAL) << "unknown device mask"; } @@ -2072,7 +2072,7 @@ void CopyFromToSimple( const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CopyFromTo(inputs[0], outputs[0], 0); + CopyFromTo(inputs[0], outputs[0], 0, true); } // copy function is special From 932bf4997481386d86ab91bb975fc71bfd58d2b0 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Fri, 22 Jun 2018 21:15:35 +0000 Subject: [PATCH 05/13] rsp pull with priority --- include/mxnet/engine.h | 2 + python/mxnet/gluon/trainer.py | 4 +- src/engine/threaded_engine_perdevice.cc | 57 ++++++++++++++++++------- src/kvstore/comm.h | 5 ++- src/kvstore/kvstore_local.h | 15 +++---- src/kvstore/kvstore_utils.cc | 4 +- src/kvstore/kvstore_utils.cu | 19 ++++----- src/kvstore/kvstore_utils.h | 7 +-- 8 files changed, 70 insertions(+), 43 deletions(-) diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index fd1fe89bdbaf..2550c0772900 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -84,6 +84,8 @@ enum class FnProperty { kCopyToGPU, /*! \brief Prioritized sync operation on CPU */ kCPUPrioritized, + /*! \brief Prioritized sync operation on GPU */ + kGPUPrioritized, /*! \brief Asynchronous function call */ kAsync, /*! \brief Delete variable call */ diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 3405d9ad69c0..7ac915d24094 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -235,8 +235,8 @@ def _row_sparse_pull(self, parameter, out, row_id): self._init_kvstore() if self._params_to_init: self._init_params() - self._kvstore.row_sparse_pull(self._param2idx[parameter.name], \ - out=out, row_ids=row_id) + idx = self._param2idx[parameter.name] + self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx) def step(self, batch_size, ignore_stale_grad=False): """Makes one step of parameter update. Should be called after diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index 97f258c10618..b6537dabb638 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -61,6 +61,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { void StopNoWait() { SignalQueuesForKill(); gpu_normal_workers_.Clear(); + gpu_priority_workers_.Clear(); gpu_copy_workers_.Clear(); cpu_normal_workers_.Clear(); cpu_priority_worker_.reset(nullptr); @@ -101,6 +102,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { this->ExecuteOprBlock(RunContext{ctx, nullptr}, opr_block); } else { if (ctx.dev_mask() == Context::kCPU) { + // CPU execution. 
if (opr_block->opr->prop == FnProperty::kCPUPrioritized) { cpu_priority_worker_->task_queue.Push(opr_block, opr_block->priority); } else { @@ -152,24 +154,44 @@ class ThreadedEnginePerDevice : public ThreadedEngine { } } else { const size_t nthread = gpu_worker_nthreads_; - auto ptr = gpu_normal_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() { - // Signify to kernel that GPU is being used, so reserve cores as necessary - OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true)); - auto blk = new ThreadWorkerBlock(); - blk->pool.reset(new ThreadPool( - nthread, - [this, ctx, is_copy, blk] - (std::shared_ptr ready_event) { - this->GPUWorker(ctx, is_copy, blk, ready_event); - }, true)); - return blk; - }); - if (ptr) { - if (opr_block->opr->prop == FnProperty::kDeleteVar) { - ptr->task_queue.PushFront(opr_block, opr_block->priority); - } else { + // GPU priority task + if (opr_block->opr->prop == FnProperty::kGPUPrioritized) { + auto ptr = gpu_priority_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() { + // Signify to kernel that GPU is being used, so reserve cores as necessary + OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true)); + auto blk = new ThreadWorkerBlock(); + blk->pool.reset(new ThreadPool( + nthread, + [this, ctx, is_copy, blk] + (std::shared_ptr ready_event) { + this->GPUWorker(ctx, is_copy, blk, ready_event); + }, true)); + return blk; + }); + if (ptr) { ptr->task_queue.Push(opr_block, opr_block->priority); } + } else { + // GPU normal task + auto ptr = gpu_normal_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() { + // Signify to kernel that GPU is being used, so reserve cores as necessary + OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true)); + auto blk = new ThreadWorkerBlock(); + blk->pool.reset(new ThreadPool( + nthread, + [this, ctx, is_copy, blk] + (std::shared_ptr ready_event) { + this->GPUWorker(ctx, is_copy, blk, ready_event); + }, true)); + return blk; + }); + if (ptr) { + if (opr_block->opr->prop == FnProperty::kDeleteVar) { + ptr->task_queue.PushFront(opr_block, opr_block->priority); + } else { + ptr->task_queue.Push(opr_block, opr_block->priority); + } + } } } } @@ -206,6 +228,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine { common::LazyAllocArray > gpu_normal_workers_; // workers doing copy works from/to GPU common::LazyAllocArray > gpu_copy_workers_; + // gpu priority workers + common::LazyAllocArray > gpu_priority_workers_; /*! * \brief GPU worker that performs operations on a certain device. * \param dev_id The device id of the worker. @@ -304,6 +328,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { /*! Signal all queues for shutdown */ void SignalQueuesForKill() { + SignalQueueForKill(&gpu_priority_workers_); SignalQueueForKill(&gpu_normal_workers_); SignalQueueForKill(&gpu_copy_workers_); SignalQueueForKill(&cpu_normal_workers_); diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index bcaa046255b7..d242dc30450f 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -629,7 +629,7 @@ class CommDevice : public Comm { "next time row_sparse_pull() is called. 
To avoid such an issue," "consider create a new NDArray buffer to store the output."); } - + bool is_gpu = retained_gpu.ctx().dev_mask() == gpu::kDevMask; Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) { const TBlob& indices = row_id.data(); using namespace mxnet::common; @@ -653,7 +653,8 @@ class CommDevice : public Comm { } on_complete(); }, retained_gpu.ctx(), {src.var(), row_id.var()}, {retained_gpu.var()}, - FnProperty::kCopyFromGPU, priority, "KVStoreSparseRetain"); + is_gpu ? FnProperty::kGPUPrioritized : FnProperty::kCPUPrioritized, + priority, "KVStoreSparseRetain"); CopyFromTo(retained_gpu, out, priority); } } diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index ce1870b5dc05..84e2700a20de 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -379,22 +379,20 @@ class KVStoreLocal : public KVStore { NDArray data_in_ctx = diff_ctx ? NDArray(data.shape(), ctx, true, data.dtype()) : data; // if data == data_in_ctx, CopyFromTo is smart enough to skip the copy CopyFromTo(data, &data_in_ctx, priority); - Resource rsc = ResourceManager::Get()->Request(out.ctx(), - ResourceRequest(ResourceRequest::kTempSpace)); // GPU requires temp resources - std::vector mutate_vars{out.var()}; - if (out.ctx().dev_mask() == gpu::kDevMask) mutate_vars.emplace_back(rsc.var); + bool is_gpu = out.ctx().dev_mask() == gpu::kDevMask; Engine::Get()->PushAsync( [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { // copy data.data() to out.data() out.CheckAndAlloc({mshadow::Shape1(num_elements)}); TBlob out_data = out.data(); + NDArray workspace; switch (out.ctx().dev_mask()) { case cpu::kDevMask: { mshadow::Stream *s = rctx.get_stream(); ndarray::Copy(data_in_ctx.data(), &out_data, ctx, ctx, rctx); - UniqueImpl(rsc, s, out); + UniqueImpl(&workspace, s, out); break; } #if MXNET_USE_CUDA @@ -402,7 +400,7 @@ class KVStoreLocal : public KVStore { mshadow::Stream *s = rctx.get_stream(); ndarray::Copy(data_in_ctx.data(), &out_data, ctx, ctx, rctx); - UniqueImpl(rsc, s, out); + UniqueImpl(&workspace, s, out); // wait for GPU operations to complete s->Wait(); break; @@ -412,8 +410,9 @@ class KVStoreLocal : public KVStore { LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } on_complete(); - }, out.ctx(), {data_in_ctx.var()}, mutate_vars, - FnProperty::kNormal, priority, "KVStoreUnique"); + }, out.ctx(), {data_in_ctx.var()}, {out.var()}, + is_gpu ? 
FnProperty::kGPUPrioritized : FnProperty::kCPUPrioritized, + priority, "KVStoreUnique"); return out; } diff --git a/src/kvstore/kvstore_utils.cc b/src/kvstore/kvstore_utils.cc index e187b0ce4890..b53eca433e97 100644 --- a/src/kvstore/kvstore_utils.cc +++ b/src/kvstore/kvstore_utils.cc @@ -29,8 +29,8 @@ namespace mxnet { namespace kvstore { template<> -void UniqueImpl(const Resource& rsc, mshadow::Stream *s, - const NDArray& out) { +void UniqueImpl(NDArray* workspace, mshadow::Stream *s, + const NDArray& out) { const size_t num_elements = out.shape().Size(); CHECK_EQ(out.storage_type(), kRowSparseStorage) << "row_sparse NDArray is expected"; MSHADOW_IDX_TYPE_SWITCH(out.dtype(), IType, { diff --git a/src/kvstore/kvstore_utils.cu b/src/kvstore/kvstore_utils.cu index 438fe29dac4e..2dab5bc1802d 100644 --- a/src/kvstore/kvstore_utils.cu +++ b/src/kvstore/kvstore_utils.cu @@ -41,8 +41,8 @@ namespace mxnet { namespace kvstore { template -size_t UniqueImplGPU(const Resource& rsc, mshadow::Stream *s, - IType *dptr, const size_t size) { +size_t UniqueImplGPU(NDArray *workspace, mshadow::Stream *s, + IType *dptr, const size_t size, Context ctx) { // estimate unique temp space. The first byte is reserved to store the number // of unique values selected const size_t num_selected_bytes = sizeof(size_t); @@ -68,12 +68,12 @@ size_t UniqueImplGPU(const Resource& rsc, mshadow::Stream *s, // request temp storage const size_t total_workspace = num_selected_bytes + sort_output_bytes + std::max(sort_temp_bytes, unique_temp_bytes); - mshadow::Tensor workspace = rsc - .get_space_typed(mshadow::Shape1(total_workspace), s); + *workspace = NDArray(mshadow::Shape1((total_workspace + 3) / 4), ctx, false); + char* workspace_dptr = reinterpret_cast(workspace->data().dptr_); // temp space layout: num_selected_ptr, sort_output_bytes, unique/sort_temp_storage - size_t* num_selected_ptr = reinterpret_cast(workspace.dptr_); - IType* sort_output_ptr = reinterpret_cast(workspace.dptr_ + num_selected_bytes); - void *temp_storage = static_cast(workspace.dptr_ + + size_t* num_selected_ptr = reinterpret_cast(workspace_dptr); + IType* sort_output_ptr = reinterpret_cast(workspace_dptr + num_selected_bytes); + void *temp_storage = static_cast(workspace_dptr + num_selected_bytes + sort_output_bytes); // execute the sort kernel #ifndef SORT_WITH_THRUST @@ -96,13 +96,12 @@ size_t UniqueImplGPU(const Resource& rsc, mshadow::Stream *s, } template<> -void UniqueImpl(const Resource& rsc, mshadow::Stream *s, - const NDArray &out) { +void UniqueImpl(NDArray *workspace, mshadow::Stream *s, const NDArray &out) { const size_t num_elements = out.shape().Size(); CHECK_EQ(out.storage_type(), kRowSparseStorage) << "row_sparse NDArray is expected"; MSHADOW_IDX_TYPE_SWITCH(out.dtype(), IType, { IType *dptr = out.data().dptr(); - size_t num_selected_out = UniqueImplGPU(rsc, s, dptr, num_elements); + size_t num_selected_out = UniqueImplGPU(workspace, s, dptr, num_elements, out.ctx()); // set the shape of data/aux_data according to the number of unique values out.set_aux_shape(rowsparse::kIdx, mshadow::Shape1(num_selected_out)); }); diff --git a/src/kvstore/kvstore_utils.h b/src/kvstore/kvstore_utils.h index ee173b4559f7..2527f7ed0ce2 100644 --- a/src/kvstore/kvstore_utils.h +++ b/src/kvstore/kvstore_utils.h @@ -36,14 +36,15 @@ namespace kvstore { /*! * \brief compute unique and sorted values in a row_sparse ndarray. - * \param rsc Temp resource for computation + * \param workspace Temp workspace for computation. 
Its a pointer to a + NDArray placeholder to make sure the NDArray is not free'd + during execution. * \param s Stream * \param out Input and output ndarray. The ndarray stores the * unique elements in out.data(). */ template -void UniqueImpl(const Resource& rsc, mshadow::Stream *s, - const NDArray& out); +void UniqueImpl(NDArray* workspace, mshadow::Stream *s, const NDArray& out); } // namespace kvstore } // namespace mxnet From 1a6dc361849e4d6eb1ec21ddd11280fff7012ccb Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Fri, 22 Jun 2018 21:44:54 +0000 Subject: [PATCH 06/13] add doc; --- include/mxnet/ndarray.h | 2 ++ python/mxnet/kvstore.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 54317e7107ab..1cc8c87b4d67 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -1024,6 +1024,8 @@ void CopyFromTo(const NDArray &from, const NDArray *to, int priority = 0); * \param from the ndarray we want to copy data from * \param to the target ndarray * \param priority Priority of the action. + * \param is_opr whether it is invoked by an operator. For example, false if invoked from + KVStore, true if invoked from `_copyto` operator. * \note The function name explicitly marks the order of from and to * due to different possible convention carried by copy function. */ diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 6b954acc7a5f..4e69e4b0f96c 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -263,7 +263,7 @@ def pull(self, key, out=None, priority=0, ignore_sparse=True): Higher priority pull operations are likely to be executed before other pull actions. - ignore_sparse: bool, optional + ignore_sparse: bool, optional, default True Whether to ignore sparse arrays in the request. Examples From 6f38f75da9b9cc2804e67599573234c12346a8e1 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Mon, 25 Jun 2018 00:56:14 +0000 Subject: [PATCH 07/13] fix bug in sparse kvstore --- python/mxnet/gluon/trainer.py | 5 +++-- python/mxnet/model.py | 11 ++--------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 7ac915d24094..5ea844a9561e 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -164,7 +164,7 @@ def _init_kvstore(self): # - push(sparse_grad), push(dense_grad) # - pull(dense_weight) if self._contains_sparse_weight: - kvstore = _create_sparse_kvstore(config['kvstore'], len(self._contexts)) + kvstore = _create_sparse_kvstore(config['kvstore']) update_on_kvstore = True # raise Error if update_on_kvstore is set to False by the user if config['update_on_kvstore'] is False: @@ -178,7 +178,8 @@ def _init_kvstore(self): # - pull(grad) # - update(grad, weight) elif self._contains_sparse_grad: - kvstore = _create_sparse_kvstore(config['kvstore'], len(self._contexts)) + arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} + kvstore, _ = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays) update_on_kvstore = False # normal case else: diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 6ce72f595cbf..5f2d8022ada9 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -55,25 +55,18 @@ 'eval_metric', 'locals']) -def _create_sparse_kvstore(kvstore, num_device): +def _create_sparse_kvstore(kvstore): """Create kvstore assuming some parameters' storage types are row_sparse. 
Parameters ---------- kvstore : KVStore or str The kvstore. - num_device : int - The number of devices """ if isinstance(kvstore, kvs.KVStore): kv = kvstore elif isinstance(kvstore, str): - # create kvstore using the string type - if num_device is 1 and 'dist' not in kvstore: - # no need to use kv for single device and single machine - kv = None - else: - kv = kvs.create(kvstore) + kv = kvs.create(kvstore) else: raise TypeError("Cannot create '%s' KVStore with row_sparse parameters. " "The type must be KVStore or str." % kvstore) From 72817979497573a797a104b20de1f7ef278de49f Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Tue, 26 Jun 2018 22:30:52 +0000 Subject: [PATCH 08/13] +kvstore test --- tests/python/gpu/test_kvstore_gpu.py | 7 +++ tests/python/unittest/test_kvstore.py | 71 +++++++++++++++------------ 2 files changed, 47 insertions(+), 31 deletions(-) diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py index 76231fbe90ee..90ed129114ff 100644 --- a/tests/python/gpu/test_kvstore_gpu.py +++ b/tests/python/gpu/test_kvstore_gpu.py @@ -77,6 +77,13 @@ def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False): expected_val += 0 if row in excluded_row_ids else 2 assert_almost_equal(retained[row], expected_val) + kv.pull('e', out=vals_to_pull, ignore_sparse=False) + for val in vals: + retained = val.asnumpy() + expected_val = np.zeros_like(retained) + expected_val[:] = 2 + assert_almost_equal(retained, expected_val) + check_rsp_pull(kv, 1, [mx.gpu(0)]) check_rsp_pull(kv, 1, [mx.cpu(0)]) check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)]) diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index 0ab61bb27483..921a5704d54b 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -160,43 +160,52 @@ def check_aggregator(kv, key, key_list, stype): @with_seed() def test_sparse_aggregator(): """aggregate sparse ndarray on muliple devices""" + def check_sparse_aggregator(sparse_pull): + stype = 'row_sparse' + kv = init_kv_with_str(stype) - stype = 'row_sparse' - kv = init_kv_with_str(stype) - - # devices - num_devs = 4 - devs = [mx.Context('cpu', i) for i in range(num_devs)] - - # single - vals = [rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)] - expected_sum = np.zeros(shape) - for v in vals: - expected_sum += v.asnumpy() - - # prepare row_ids - all_rows = mx.nd.array(np.arange(shape[0])) - kv.push('a', vals) - kv.row_sparse_pull('a', out=vals, row_ids=[all_rows] * len(vals)) - result_sum = np.zeros(shape) - for v in vals: - result_sum += v.asnumpy() - assert_almost_equal(result_sum, expected_sum * num_devs) + # devices + num_devs = 4 + devs = [mx.Context('cpu', i) for i in range(num_devs)] - # list - vals = [[rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)]] * len(keys) - expected_sum = np.zeros(shape) - for v in vals[0]: - expected_sum += v.asnumpy() - - kv.push(str_keys, vals) - kv.row_sparse_pull(str_keys, out=vals, row_ids=[[all_rows] * num_devs] * len(vals)) - for vv in vals: + # single + vals = [rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)] + expected_sum = np.zeros(shape) + for v in vals: + expected_sum += v.asnumpy() + + # prepare row_ids + kv.push('a', vals) + if sparse_pull: + all_rows = mx.nd.array(np.arange(shape[0])) + kv.row_sparse_pull('a', out=vals, row_ids=[all_rows] * len(vals)) + else: + kv.pull('a', out=vals, ignore_sparse=False) result_sum = np.zeros(shape) - for v in vv: + 
for v in vals: result_sum += v.asnumpy() assert_almost_equal(result_sum, expected_sum * num_devs) + # list + vals = [[rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)]] * len(keys) + expected_sum = np.zeros(shape) + for v in vals[0]: + expected_sum += v.asnumpy() + + kv.push(str_keys, vals) + if sparse_pull: + kv.row_sparse_pull(str_keys, out=vals, row_ids=[[all_rows] * num_devs] * len(vals)) + else: + kv.pull(str_keys, out=vals, ignore_sparse=False) + for vv in vals: + result_sum = np.zeros(shape) + for v in vv: + result_sum += v.asnumpy() + assert_almost_equal(result_sum, expected_sum * num_devs) + + check_sparse_aggregator(False) + check_sparse_aggregator(True) + def updater(key, recv, local): """use updater: += with int keys""" assert(isinstance(key, int)) From a83482613294f305d2604b6b6bd63674a01df987 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Wed, 27 Jun 2018 20:23:28 +0000 Subject: [PATCH 09/13] add dist kvstore test --- ci/docker/runtime_functions.sh | 2 +- python/mxnet/gluon/trainer.py | 3 +- python/mxnet/model.py | 4 +- src/kvstore/kvstore_dist.h | 9 +++-- tests/nightly/dist_sync_kvstore.py | 41 +++++++++++++-------- tests/python/unittest/test_gluon_trainer.py | 28 ++++++++++++++ 6 files changed, 64 insertions(+), 23 deletions(-) diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index 07980471c580..bcb590785ac9 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -730,7 +730,7 @@ integrationtest_ubuntu_gpu_dist_kvstore() { ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision ../../tools/launch.py -n 7 --launcher local python dist_device_sync_kvstore.py - ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=gluon + ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=invalid } test_ubuntu_cpu_python2() { diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 5ea844a9561e..1c4c608285ad 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -164,8 +164,7 @@ def _init_kvstore(self): # - push(sparse_grad), push(dense_grad) # - pull(dense_weight) if self._contains_sparse_weight: - kvstore = _create_sparse_kvstore(config['kvstore']) - update_on_kvstore = True + kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore']) # raise Error if update_on_kvstore is set to False by the user if config['update_on_kvstore'] is False: raise RuntimeError("Cannot set update_on_kvstore to False when sparse weights " diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 5f2d8022ada9..3a50553a615c 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -63,6 +63,8 @@ def _create_sparse_kvstore(kvstore): kvstore : KVStore or str The kvstore. """ + # always update on kvstore + update_on_kvstore = True if isinstance(kvstore, kvs.KVStore): kv = kvstore elif isinstance(kvstore, str): @@ -70,7 +72,7 @@ def _create_sparse_kvstore(kvstore): else: raise TypeError("Cannot create '%s' KVStore with row_sparse parameters. " "The type must be KVStore or str." 
% kvstore) - return kv + return (kv, update_on_kvstore) def _create_kvstore(kvstore, num_device, arg_params): """Create kvstore diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index dd3464bf6db4..7e2f5cb5faa9 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -207,10 +207,11 @@ class KVStoreDist : public KVStoreLocal { void PullImpl(const std::vector& keys, const std::vector& values, - int priority) override { + int priority, bool ignore_sparse) override { + CHECK(ignore_sparse) << "dist kvstore pull doesn't support ignore_sparse=False"; std::vector uniq_keys; std::vector > grouped_vals; - GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals); + GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals, true); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; @@ -263,7 +264,7 @@ class KVStoreDist : public KVStoreLocal { int priority = 0) override { std::vector uniq_keys; std::vector>> grouped_val_rowids; - GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids); + GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids, false); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; @@ -305,7 +306,7 @@ class KVStoreDist : public KVStoreLocal { // first aggregate the values over keys std::vector uniq_keys; std::vector > grouped_vals; - GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals); + GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals, false); for (size_t i = 0; i < uniq_keys.size(); ++i) { // merge over devices diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py index 32ed2dddb6fb..73d03420a8c6 100644 --- a/tests/nightly/dist_sync_kvstore.py +++ b/tests/nightly/dist_sync_kvstore.py @@ -38,6 +38,7 @@ def check_diff(A, x, rank=None): irregular_shape = (1211,1211) big_shape = (1200, 1200) # bigger than MXNET_KVSTORE_BIGARRAY_BOUND +keys_invalid = [999] keys_shape = ['3', '5', '7'] keys_big_shape = ['99'] fp16_keys_shape = ['4', '6', '8'] @@ -350,19 +351,29 @@ def check_init(kv, cur_keys, cur_shape, device=False): check_init(kv, init_test_keys_device_big, big_shape, device=True) print('worker ' + str(kv.rank) + ' is initialized') -def test_gluon_trainer_reset(): - params = mx.gluon.ParameterDict() - x = params.get('x', shape=(4, 2), lr_mult=1.0, stype='row_sparse') - params.initialize(ctx=mx.cpu(0), init='zeros') - trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv) - params.save('test_gluon_trainer_reset_' + str(my_rank) + '.params') - row_id = mx.nd.arange(0, 4) - w = x.row_sparse_data(row_id) - assert trainer._kv_initialized and trainer._update_on_kvstore - # load would fail to reset kvstore since update_on_kvstore is True - assert_exception(params.load, RuntimeError, 'test_gluon_trainer_reset_' + str(my_rank) + '.params') - print('worker ' + str(my_rank) + ' passed test_gluon_trainer_reset') - +def test_invalid_operations(): + def check_invalid_gluon_trainer_reset(): + params = mx.gluon.ParameterDict() + x = params.get('x', shape=(4, 2), lr_mult=1.0, stype='row_sparse') + params.initialize(ctx=mx.cpu(0), init='zeros') + trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv) + params.save('test_gluon_trainer_reset_' + str(my_rank) + '.params') + row_id = mx.nd.arange(0, 4) + w = x.row_sparse_data(row_id) + assert trainer._kv_initialized and trainer._update_on_kvstore + mx.nd.waitall() + # load would fail to reset kvstore since update_on_kvstore is True + 
assert_exception(params.load, RuntimeError, 'test_gluon_trainer_reset_' + str(my_rank) + '.params') + print('worker ' + str(my_rank) + ' passed check_invalid_gluon_trainer_reset') + + def check_invalid_pull(): + kv.init(keys_invalid[0], mx.nd.ones((2,2)).tostype('row_sparse')) + out = mx.nd.ones((2,2)).tostype('row_sparse') + assert_exception(kv.pull, mx.MXNetError, 'invalid_key', out=out, ignore_sparse=False) + print('worker ' + str(my_rank) + ' passed check_invalid_pull') + + check_invalid_gluon_trainer_reset() + check_invalid_pull() if __name__ == "__main__": parser = argparse.ArgumentParser(description='test distributed kvstore in dist_sync mode') @@ -371,8 +382,8 @@ def test_gluon_trainer_reset(): parser.add_argument('--no-gpu', dest='gpu', action='store_false') parser.add_argument('--no-multiprecision', dest='multiprecision', action='store_false') opt = parser.parse_args() - if opt.type == 'gluon': - test_gluon_trainer_reset() + if opt.type == 'invalid': + test_invalid_operations() if opt.type == 'all' or opt.type == 'init': test_sync_init(opt.gpu) if opt.type == 'all' or opt.type == 'default': diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index eac9fad45f57..2a34400d60ab 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -205,3 +205,31 @@ def check_trainer_reset_kv(kv): kvs = ['local', 'device'] for kv in kvs: check_trainer_reset_kv(kv) + +@with_seed() +def test_trainer_sparse_kv(): + def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv): + params = gluon.ParameterDict() + x = params.get('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype) + params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') + trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv) + all_rows = mx.nd.arange(0, 10, ctx=mx.cpu(0)) + ws = x.list_data() if stype == 'default' else x.list_row_sparse_data(all_rows) + with mx.autograd.record(): + for w in ws: + y = w + 1 + y.backward() + trainer.step(1) + assert trainer._kvstore.type == kv + assert trainer._kv_initialized + assert trainer._update_on_kvstore is update_on_kv + # the updated parameter should be based on the loaded checkpoint + mx.nd.waitall() + updated_w = x.data(mx.cpu(0)) if stype == 'default' else x.row_sparse_data(all_rows) + assert (updated_w == -0.2).asnumpy().all() + + kvs = ['local', 'device'] + for kv in kvs: + check_trainer_sparse_kv(kv, 'default', 'default', True) + check_trainer_sparse_kv(kv, 'default', 'row_sparse', False) + check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', True) From 26229245a2cddf3114d2c2aa823bfa740c907bb0 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Wed, 27 Jun 2018 20:36:24 +0000 Subject: [PATCH 10/13] enhance dist kv test --- ci/docker/runtime_functions.sh | 1 + tests/nightly/dist_sync_kvstore.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index bcb590785ac9..8ac7f71feb6d 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -731,6 +731,7 @@ integrationtest_ubuntu_gpu_dist_kvstore() { ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision ../../tools/launch.py -n 7 --launcher local python dist_device_sync_kvstore.py ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=invalid + ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=gluon } 
 test_ubuntu_cpu_python2() {
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index 73d03420a8c6..8ba1edab3a0d 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -375,6 +375,22 @@ def check_invalid_pull():
     check_invalid_gluon_trainer_reset()
     check_invalid_pull()
 
+def test_gluon_trainer():
+    def check_trainer_kv_type(stype, grad_stype, update_on_kv):
+        params = mx.gluon.ParameterDict()
+        x = params.get('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype)
+        params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+        trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv)
+        trainer._init_kvstore()
+        assert trainer._kv_initialized
+        assert trainer._update_on_kvstore is update_on_kv
+
+    check_trainer_kv_type('default', 'default', False)
+    check_trainer_kv_type('default', 'row_sparse', True)
+    check_trainer_kv_type('row_sparse', 'row_sparse', True)
+    print('worker ' + str(my_rank) + ' passed test_gluon_trainer')
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='test distributed kvstore in dist_sync mode')
     parser.add_argument('--nrepeat', type=int, default=7)
@@ -382,6 +398,8 @@ def check_invalid_pull():
     parser.add_argument('--no-gpu', dest='gpu', action='store_false')
     parser.add_argument('--no-multiprecision', dest='multiprecision', action='store_false')
     opt = parser.parse_args()
+    if opt.type == 'gluon':
+        test_gluon_trainer()
     if opt.type == 'invalid':
         test_invalid_operations()
     if opt.type == 'all' or opt.type == 'init':

From 46ff1a04b016557546234649f746e8ef6972d9c5 Mon Sep 17 00:00:00 2001
From: eric-haibin-lin
Date: Wed, 27 Jun 2018 20:52:08 +0000
Subject: [PATCH 11/13] fix lint

---
 python/mxnet/gluon/trainer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 1c4c608285ad..377c43cf5cd5 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -192,7 +192,8 @@ def _init_kvstore(self):
             kvstore.set_gradient_compression(self._compression_params)
             if 'dist' in kvstore.type:
                 # kv.pull(row_sparse_grad) is not supported for dist kvstore
-                if self._contains_sparse_weight or self._contains_sparse_grad:
+                contains_sparse = self._contains_sparse_weight or self._contains_sparse_grad
+                if contains_sparse:
                     update_on_kvstore = True
                 else:
                     update_on_kvstore = False

From 47b143d7cc824c6b216963cf85f4f5c58232cab7 Mon Sep 17 00:00:00 2001
From: eric-haibin-lin
Date: Wed, 27 Jun 2018 21:00:56 +0000
Subject: [PATCH 12/13] fix lint

---
 python/mxnet/gluon/trainer.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 377c43cf5cd5..09ad96314d5a 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -192,11 +192,7 @@ def _init_kvstore(self):
             kvstore.set_gradient_compression(self._compression_params)
             if 'dist' in kvstore.type:
                 # kv.pull(row_sparse_grad) is not supported for dist kvstore
-                contains_sparse = self._contains_sparse_weight or self._contains_sparse_grad
-                if contains_sparse:
-                    update_on_kvstore = True
-                else:
-                    update_on_kvstore = False
+                update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad
             if update_on_kvstore:
                 # optimizer preferably needs to be set before init for multiprecision
                 kvstore.set_optimizer(self._optimizer)

From a2b1cc9d431d90966c324dce427cdea9dc018a4e Mon Sep 17 00:00:00 2001
From: eric-haibin-lin
Date: Mon, 2 Jul 2018 20:36:21 +0000
Subject: [PATCH 13/13] CR comments

---
 include/mxnet/c_api.h   | 38 +++++++++++++++++++++++++++-----
 include/mxnet/engine.h  |  6 +++---
 python/mxnet/kvstore.py | 10 +++++----
 src/c_api/c_api.cc      | 48 ++++++++++++++++++++++++++++++++++-------
 4 files changed, 82 insertions(+), 20 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 68fda1941186..abb46ce684a8 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1903,12 +1903,42 @@ MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle,
  * \param ignore_sparse whether to ignore sparse arrays in the request
  * \return 0 when success, -1 when failure happens
  */
+MXNET_DLL int MXKVStorePullWithSparse(KVStoreHandle handle,
+                                      mx_uint num,
+                                      const int* keys,
+                                      NDArrayHandle* vals,
+                                      int priority,
+                                      bool ignore_sparse);
+/*!
+ * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string
+ * \param handle handle to the kvstore
+ * \param num the number of key-value pairs
+ * \param keys the list of keys
+ * \param vals the list of values
+ * \param priority the priority of the action
+ * \param ignore_sparse whether to ignore sparse arrays in the request
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStorePullWithSparseEx(KVStoreHandle handle,
+                                        mx_uint num,
+                                        const char** keys,
+                                        NDArrayHandle* vals,
+                                        int priority,
+                                        bool ignore_sparse);
+/*!
+ * \brief pull a list of (key, value) pairs from the kvstore
+ * \param handle handle to the kvstore
+ * \param num the number of key-value pairs
+ * \param keys the list of keys
+ * \param vals the list of values
+ * \param priority the priority of the action
+ * \return 0 when success, -1 when failure happens
+ */
 MXNET_DLL int MXKVStorePull(KVStoreHandle handle,
                             mx_uint num,
                             const int* keys,
                             NDArrayHandle* vals,
-                            int priority,
-                            bool ignore_sparse = true);
+                            int priority);
 /*!
  * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string
  * \param handle handle to the kvstore
@@ -1916,15 +1946,13 @@ MXNET_DLL int MXKVStorePull(KVStoreHandle handle,
  * \param keys the list of keys
  * \param vals the list of values
  * \param priority the priority of the action
- * \param ignore_sparse whether to ignore sparse arrays in the request
 * \return 0 when success, -1 when failure happens
 */
 MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
                               mx_uint num,
                               const char** keys,
                               NDArrayHandle* vals,
-                              int priority,
-                              bool ignore_sparse = true);
+                              int priority);
 
 /*!
  * \brief pull a list of (key, value) pairs from the kvstore, where each key is an integer.
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 2550c0772900..dc48bfb83fa3 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -84,12 +84,12 @@ enum class FnProperty {
   kCopyToGPU,
   /*! \brief Prioritized sync operation on CPU */
   kCPUPrioritized,
-  /*! \brief Prioritized sync operation on GPU */
-  kGPUPrioritized,
   /*! \brief Asynchronous function call */
   kAsync,
   /*! \brief Delete variable call */
-  kDeleteVar
+  kDeleteVar,
+  /*! \brief Prioritized sync operation on GPU */
+  kGPUPrioritized
 };  // enum class FnProperty

 /*!
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 4e69e4b0f96c..609733659753 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -301,11 +301,13 @@ def pull(self, key, out=None, priority=0, ignore_sparse=True):
         assert(out is not None)
         ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
         if use_str_keys:
-            check_call(_LIB.MXKVStorePullEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals,
-                                            ctypes.c_int(priority), ctypes.c_bool(ignore_sparse)))
+            check_call(_LIB.MXKVStorePullWithSparseEx(self.handle, mx_uint(len(ckeys)), ckeys,
+                                                      cvals, ctypes.c_int(priority),
+                                                      ctypes.c_bool(ignore_sparse)))
         else:
-            check_call(_LIB.MXKVStorePull(self.handle, mx_uint(len(ckeys)), ckeys, cvals,
-                                          ctypes.c_int(priority), ctypes.c_bool(ignore_sparse)))
+            check_call(_LIB.MXKVStorePullWithSparse(self.handle, mx_uint(len(ckeys)), ckeys,
+                                                    cvals, ctypes.c_int(priority),
+                                                    ctypes.c_bool(ignore_sparse)))
 
     def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
         """ Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 6b1012b1a4d7..efa7301d7ab2 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -861,8 +861,7 @@ int MXKVStorePull(KVStoreHandle handle,
                   mx_uint num,
                   const int* keys,
                   NDArrayHandle* vals,
-                  int priority,
-                  bool ignore_sparse) {
+                  int priority) {
   API_BEGIN();
   std::vector<int> v_keys(num);
   std::vector<NDArray*> v_vals(num);
@@ -870,16 +869,49 @@ int MXKVStorePull(KVStoreHandle handle,
     v_keys[i] = keys[i];
     v_vals[i] = static_cast<NDArray*>(vals[i]);
   }
-  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, ignore_sparse);
+  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, true);
   API_END();
 }
 
 int MXKVStorePullEx(KVStoreHandle handle,
-                    mx_uint num,
-                    const char** keys,
-                    NDArrayHandle* vals,
-                    int priority,
-                    bool ignore_sparse) {
+                    mx_uint num,
+                    const char** keys,
+                    NDArrayHandle* vals,
+                    int priority) {
+  API_BEGIN();
+  std::vector<std::string> v_keys(num);
+  std::vector<NDArray*> v_vals(num);
+  for (mx_uint i = 0; i < num; ++i) {
+    v_keys[i] = keys[i];
+    v_vals[i] = static_cast<NDArray*>(vals[i]);
+  }
+  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, true);
+  API_END();
+}
+
+int MXKVStorePullWithSparse(KVStoreHandle handle,
+                            mx_uint num,
+                            const int* keys,
+                            NDArrayHandle* vals,
+                            int priority,
+                            bool ignore_sparse) {
+  API_BEGIN();
+  std::vector<int> v_keys(num);
+  std::vector<NDArray*> v_vals(num);
+  for (mx_uint i = 0; i < num; ++i) {
+    v_keys[i] = keys[i];
+    v_vals[i] = static_cast<NDArray*>(vals[i]);
+  }
+  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, ignore_sparse);
+  API_END();
+}
+
+int MXKVStorePullWithSparseEx(KVStoreHandle handle,
+                              mx_uint num,
+                              const char** keys,
+                              NDArrayHandle* vals,
+                              int priority,
+                              bool ignore_sparse) {
   API_BEGIN();
   std::vector<std::string> v_keys(num);
   std::vector<NDArray*> v_vals(num);