diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index e49639903f92..3e2e44d9d297 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -730,6 +730,7 @@ integrationtest_ubuntu_gpu_dist_kvstore() {
     ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py
     ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision
     ../../tools/launch.py -n 7 --launcher local python dist_device_sync_kvstore.py
+    ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=invalid
     ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=gluon
 }
 
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 4dd858a51c4b..abb46ce684a8 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1893,6 +1893,38 @@ MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle,
                               const char** keys,
                               NDArrayHandle* vals,
                               int priority);
+/*!
+ * \brief pull a list of (key, value) pairs from the kvstore
+ * \param handle handle to the kvstore
+ * \param num the number of key-value pairs
+ * \param keys the list of keys
+ * \param vals the list of values
+ * \param priority the priority of the action
+ * \param ignore_sparse whether to ignore sparse arrays in the request
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStorePullWithSparse(KVStoreHandle handle,
+                                      mx_uint num,
+                                      const int* keys,
+                                      NDArrayHandle* vals,
+                                      int priority,
+                                      bool ignore_sparse);
+/*!
+ * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string
+ * \param handle handle to the kvstore
+ * \param num the number of key-value pairs
+ * \param keys the list of keys
+ * \param vals the list of values
+ * \param priority the priority of the action
+ * \param ignore_sparse whether to ignore sparse arrays in the request
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStorePullWithSparseEx(KVStoreHandle handle,
+                                        mx_uint num,
+                                        const char** keys,
+                                        NDArrayHandle* vals,
+                                        int priority,
+                                        bool ignore_sparse);
 /*!
  * \brief pull a list of (key, value) pairs from the kvstore
  * \param handle handle to the kvstore
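For illustration, the two new C API entry points follow the existing `MXKVStorePullEx` calling convention with two extra arguments. A minimal sketch of invoking `MXKVStorePullWithSparseEx` through `ctypes` (not part of the patch; `lib`, `kv_handle`, and `nd_handles` are assumed to be obtained elsewhere, and the real binding is in `python/mxnet/kvstore.py` below):

```python
# Illustrative only: calls the new C API directly via ctypes.
# Assumes `lib` is a loaded libmxnet and `kv_handle`/`nd_handles` are valid handles.
import ctypes

def pull_with_sparse_ex(lib, kv_handle, keys, nd_handles, priority=0, ignore_sparse=True):
    num = len(keys)
    c_keys = (ctypes.c_char_p * num)(*[k.encode('utf-8') for k in keys])
    c_vals = (ctypes.c_void_p * num)(*nd_handles)
    ret = lib.MXKVStorePullWithSparseEx(kv_handle, ctypes.c_uint(num), c_keys,
                                        c_vals, ctypes.c_int(priority),
                                        ctypes.c_bool(ignore_sparse))
    if ret != 0:
        raise RuntimeError("MXKVStorePullWithSparseEx failed")
```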
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index fd1fe89bdbaf..dc48bfb83fa3 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -87,7 +87,9 @@ enum class FnProperty {
   /*! \brief Asynchronous function call */
   kAsync,
   /*! \brief Delete variable call */
-  kDeleteVar
+  kDeleteVar,
+  /*! \brief Prioritized sync operation on GPU */
+  kGPUPrioritized
 };  // enum class FnProperty
 /*!
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index 9e92207fb8db..e10bd213aa26 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -170,19 +170,21 @@ class KVStore {
    * \param keys the list of keys
    * \param values the list of buffers for the pulled data, they should be preallocated
    * \param priority Priority of the action.
+   * \param ignore_sparse whether to ignore sparse arrays in the request
    */
   virtual void Pull(const std::vector<int>& keys,
                     const std::vector<NDArray*>& values,
-                    int priority = 0) = 0;
+                    int priority = 0, bool ignore_sparse = true) = 0;
   /*!
    * \brief pull a list of key-value pairs from the store
    * \param keys the list of keys in string format
    * \param values the list of buffers for the pulled data, they should be preallocated
    * \param priority Priority of the action.
+   * \param ignore_sparse whether to ignore sparse arrays in the request
    */
   virtual void Pull(const std::vector<std::string>& str_keys,
                     const std::vector<NDArray*>& values,
-                    int priority = 0) = 0;
+                    int priority = 0, bool ignore_sparse = true) = 0;
 
   /*!
    * \brief pull a list of key-value pairs from the store.
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index faffe1bdea99..1cc8c87b4d67 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -1024,10 +1024,12 @@ void CopyFromTo(const NDArray &from, const NDArray *to, int priority = 0);
 * \param from the ndarray we want to copy data from
 * \param to the target ndarray
 * \param priority Priority of the action.
+ * \param is_opr whether it is invoked by an operator. For example, false if invoked from
+                 KVStore, true if invoked from the `_copyto` operator.
 * \note The function name explicitly marks the order of from and to
 *       due to different possible conventions carried by copy functions.
 */
-void CopyFromTo(const NDArray &from, const NDArray& to, int priority = 0);
+void CopyFromTo(const NDArray &from, const NDArray& to, int priority = 0, bool is_opr = false);
 
 /*!
 * \brief Perform elementwise sum over each data from source, store result into out.
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 02d68f0c39cb..09ad96314d5a 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -69,7 +69,8 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
                             "got %s."%(type(params)))
         self._params = []
         # parameters to initialize on the kvstore
-        self._contains_sparse = False
+        self._contains_sparse_weight = False
+        self._contains_sparse_grad = False
         self._param2idx = {}
         for i, param in enumerate(params):
             if not isinstance(param, Parameter):
@@ -80,7 +81,9 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
             self._params.append(param)
             param._set_trainer(self)
             if param._stype != 'default':
-                self._contains_sparse = True
+                self._contains_sparse_weight = True
+            if param._grad_stype != 'default':
+                self._contains_sparse_grad = True
         self._compression_params = compression_params
         optimizer_params = optimizer_params if optimizer_params else {}
         self._scale = float(optimizer_params.get('rescale_grad', 1.0))
@@ -153,13 +156,31 @@ def _reset_kvstore(self):
     def _init_kvstore(self):
         """Create kvstore."""
         config = self._kvstore_params
-        if self._contains_sparse:
+        # if the weight is sparse, the weight must be updated on KVStore.
+        # the training loop contains:
+        #    - row_sparse_pull(sparse_weight)
+        #    - forward()
+        #    - backward()
+        #    - push(sparse_grad), push(dense_grad)
+        #    - pull(dense_weight)
+        if self._contains_sparse_weight:
             kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore'])
-            # update_on_kvstore is set to False by the user
+            # raise an error if update_on_kvstore is set to False by the user
             if config['update_on_kvstore'] is False:
-                raise RuntimeError("Cannot set update_on_kvstore to False when sparse "
-                                   "gradients and/or sparse weights are present for "
-                                   "Parameter '%s'."%param.name)
+                raise RuntimeError("Cannot set update_on_kvstore to False when sparse weights "
+                                   "are present.")
+        # if the weight is dense and the grad is sparse, the weight should not be
+        # updated on KVStore. The training loop contains:
+        #    - forward()
+        #    - backward()
+        #    - push(grad)
+        #    - pull(grad)
+        #    - update(grad, weight)
+        elif self._contains_sparse_grad:
+            arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
+            kvstore, _ = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays)
+            update_on_kvstore = False
+        # normal case
         else:
             arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
             kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts),
@@ -169,9 +190,9 @@ def _init_kvstore(self):
         if kvstore:
             if self._compression_params:
                 kvstore.set_gradient_compression(self._compression_params)
-            # kv.pull(row_sparse_grad) is not supported
-            if 'dist' in kvstore.type and not self._contains_sparse:
-                update_on_kvstore = False
+            if 'dist' in kvstore.type:
+                # kv.pull(row_sparse_grad) is not supported for dist kvstore
+                update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad
             if update_on_kvstore:
                 # optimizer preferably needs to be set before init for multiprecision
                 kvstore.set_optimizer(self._optimizer)
@@ -211,8 +232,8 @@ def _row_sparse_pull(self, parameter, out, row_id):
             self._init_kvstore()
         if self._params_to_init:
             self._init_params()
-        self._kvstore.row_sparse_pull(self._param2idx[parameter.name], \
-                                      out=out, row_ids=row_id)
+        idx = self._param2idx[parameter.name]
+        self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)
 
     def step(self, batch_size, ignore_stale_grad=False):
         """Makes one step of parameter update. Should be called after
@@ -272,7 +293,7 @@ def _allreduce_grads(self):
                     self._kvstore.push(i, param.list_grad(), priority=-i)
 
                     if not self._update_on_kvstore:
-                        self._kvstore.pull(i, param.list_grad(), priority=-i)
+                        self._kvstore.pull(i, param.list_grad(), priority=-i, ignore_sparse=False)
 
     def update(self, batch_size, ignore_stale_grad=False):
         """Makes one step of parameter update.
@@ -327,7 +348,7 @@ def _update(self, ignore_stale_grad=False):
             if self._kvstore and self._update_on_kvstore:
                 if param._stype == 'default':
                     # 'row_sparse' parameters are not pulled immediately - they're pulled
-                    # in `SparseBlock.sparse_forward`
+                    # in `Block.forward`
                     self._kvstore.pull(i, param.list_data(), priority=-i)
                 continue
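The practical effect of the new `_contains_sparse_grad` branch: a dense weight with a `row_sparse` gradient now trains with `update_on_kvstore=False` on local kvstores, so the aggregated gradient is pulled back and the optimizer update runs locally. A minimal sketch mirroring the new unit test below (parameter name and shapes are illustrative):

```python
import mxnet as mx
from mxnet import gluon

params = gluon.ParameterDict()
# dense weight, row_sparse gradient
x = params.get('x', shape=(10, 1), stype='default', grad_stype='row_sparse')
params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore='device')

with mx.autograd.record():
    for w in x.list_data():
        (w + 1).backward()
trainer.step(1)  # push(grad); pull(grad); local update(grad, weight)

assert trainer._update_on_kvstore is False
```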
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index f31dac01cd10..609733659753 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -235,7 +235,7 @@ def push(self, key, value, priority=0):
             self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
 
-    def pull(self, key, out=None, priority=0):
+    def pull(self, key, out=None, priority=0, ignore_sparse=True):
         """ Pulls a single value or a sequence of values from the store.
 
         This function returns immediately after adding an operator to the engine.
@@ -247,8 +247,8 @@ def pull(self, key, out=None, priority=0):
 
         The returned values are guaranteed to be the latest values in the store.
 
-        For `RowSparseNDArray` values, this call is ignored,
-        please use ``row_sparse_pull`` instead.
+        `pull` with `RowSparseNDArray` is not supported for dist kvstore.
+        Please use ``row_sparse_pull`` instead.
 
         Parameters
         ----------
@@ -263,6 +263,9 @@ def pull(self, key, out=None, priority=0):
             Higher priority pull operations are likely to be executed before
             other pull actions.
 
+        ignore_sparse: bool, optional, default True
+            Whether to ignore sparse arrays in the request.
+
         Examples
         --------
         >>> # pull a single key-value pair
@@ -298,11 +301,13 @@ def pull(self, key, out=None, priority=0):
         assert(out is not None)
         ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
         if use_str_keys:
-            check_call(_LIB.MXKVStorePullEx(
-                self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
+            check_call(_LIB.MXKVStorePullWithSparseEx(self.handle, mx_uint(len(ckeys)), ckeys,
+                                                      cvals, ctypes.c_int(priority),
+                                                      ctypes.c_bool(ignore_sparse)))
         else:
-            check_call(_LIB.MXKVStorePull(
-                self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
+            check_call(_LIB.MXKVStorePullWithSparse(self.handle, mx_uint(len(ckeys)), ckeys,
+                                                    cvals, ctypes.c_int(priority),
+                                                    ctypes.c_bool(ignore_sparse)))
 
     def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
         """ Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \
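Usage of the new flag at the Python level: by default a `pull` on a `row_sparse` value is still skipped with a warning, while `ignore_sparse=False` forces the full value to be pulled on a local or device kvstore (mirrors the updated unit tests; shapes are illustrative):

```python
import mxnet as mx

kv = mx.kv.create('local')
kv.init('w', mx.nd.ones((4, 2)).tostype('row_sparse'))
out = mx.nd.ones((4, 2)).tostype('row_sparse')

kv.pull('w', out=out)                       # sparse value: ignored with a warning
kv.pull('w', out=out, ignore_sparse=False)  # sparse value: actually pulled
```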
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 467118b9921e..efa7301d7ab2 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -869,15 +869,49 @@ int MXKVStorePull(KVStoreHandle handle,
     v_keys[i] = keys[i];
     v_vals[i] = static_cast<NDArray*>(vals[i]);
   }
-  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority);
+  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, true);
   API_END();
 }
 
 int MXKVStorePullEx(KVStoreHandle handle,
-                     mx_uint num,
-                     const char** keys,
-                     NDArrayHandle* vals,
-                     int priority) {
+                    mx_uint num,
+                    const char** keys,
+                    NDArrayHandle* vals,
+                    int priority) {
+  API_BEGIN();
+  std::vector<std::string> v_keys(num);
+  std::vector<NDArray*> v_vals(num);
+  for (mx_uint i = 0; i < num; ++i) {
+    v_keys[i] = keys[i];
+    v_vals[i] = static_cast<NDArray*>(vals[i]);
+  }
+  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, true);
+  API_END();
+}
+
+int MXKVStorePullWithSparse(KVStoreHandle handle,
+                            mx_uint num,
+                            const int* keys,
+                            NDArrayHandle* vals,
+                            int priority,
+                            bool ignore_sparse) {
+  API_BEGIN();
+  std::vector<int> v_keys(num);
+  std::vector<NDArray*> v_vals(num);
+  for (mx_uint i = 0; i < num; ++i) {
+    v_keys[i] = keys[i];
+    v_vals[i] = static_cast<NDArray*>(vals[i]);
+  }
+  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, ignore_sparse);
+  API_END();
+}
+
+int MXKVStorePullWithSparseEx(KVStoreHandle handle,
+                              mx_uint num,
+                              const char** keys,
+                              NDArrayHandle* vals,
+                              int priority,
+                              bool ignore_sparse) {
   API_BEGIN();
   std::vector<std::string> v_keys(num);
   std::vector<NDArray*> v_vals(num);
@@ -885,7 +919,7 @@ int MXKVStorePullEx(KVStoreHandle handle,
     v_keys[i] = keys[i];
     v_vals[i] = static_cast<NDArray*>(vals[i]);
   }
-  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority);
+  static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, ignore_sparse);
   API_END();
 }
diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc
index 97f258c10618..b6537dabb638 100644
--- a/src/engine/threaded_engine_perdevice.cc
+++ b/src/engine/threaded_engine_perdevice.cc
@@ -61,6 +61,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
   void StopNoWait() {
     SignalQueuesForKill();
     gpu_normal_workers_.Clear();
+    gpu_priority_workers_.Clear();
     gpu_copy_workers_.Clear();
     cpu_normal_workers_.Clear();
     cpu_priority_worker_.reset(nullptr);
@@ -101,6 +102,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
       this->ExecuteOprBlock(RunContext{ctx, nullptr}, opr_block);
     } else {
       if (ctx.dev_mask() == Context::kCPU) {
+        // CPU execution.
        if (opr_block->opr->prop == FnProperty::kCPUPrioritized) {
          cpu_priority_worker_->task_queue.Push(opr_block, opr_block->priority);
        } else {
@@ -152,24 +154,44 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
         }
       } else {
         const size_t nthread = gpu_worker_nthreads_;
-        auto ptr = gpu_normal_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() {
-          // Signify to kernel that GPU is being used, so reserve cores as necessary
-          OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true));
-          auto blk = new ThreadWorkerBlock<kWorkerQueue>();
-          blk->pool.reset(new ThreadPool(
-            nthread,
-            [this, ctx, is_copy, blk]
-              (std::shared_ptr<dmlc::ManualEvent> ready_event) {
-                this->GPUWorker(ctx, is_copy, blk, ready_event);
-              }, true));
-          return blk;
-        });
-        if (ptr) {
-          if (opr_block->opr->prop == FnProperty::kDeleteVar) {
-            ptr->task_queue.PushFront(opr_block, opr_block->priority);
-          } else {
-            ptr->task_queue.Push(opr_block, opr_block->priority);
-          }
-        }
+        // GPU priority task
+        if (opr_block->opr->prop == FnProperty::kGPUPrioritized) {
+          auto ptr = gpu_priority_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() {
+            // Signify to kernel that GPU is being used, so reserve cores as necessary
+            OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true));
+            auto blk = new ThreadWorkerBlock<kPriorityQueue>();
+            blk->pool.reset(new ThreadPool(
+              nthread,
+              [this, ctx, is_copy, blk]
+                (std::shared_ptr<dmlc::ManualEvent> ready_event) {
+                  this->GPUWorker(ctx, is_copy, blk, ready_event);
+                }, true));
+            return blk;
+          });
+          if (ptr) {
+            ptr->task_queue.Push(opr_block, opr_block->priority);
+          }
+        } else {
+          // GPU normal task
+          auto ptr = gpu_normal_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() {
+            // Signify to kernel that GPU is being used, so reserve cores as necessary
+            OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true));
+            auto blk = new ThreadWorkerBlock<kWorkerQueue>();
+            blk->pool.reset(new ThreadPool(
+              nthread,
+              [this, ctx, is_copy, blk]
+                (std::shared_ptr<dmlc::ManualEvent> ready_event) {
+                  this->GPUWorker(ctx, is_copy, blk, ready_event);
+                }, true));
+            return blk;
+          });
+          if (ptr) {
+            if (opr_block->opr->prop == FnProperty::kDeleteVar) {
+              ptr->task_queue.PushFront(opr_block, opr_block->priority);
+            } else {
+              ptr->task_queue.Push(opr_block, opr_block->priority);
+            }
+          }
         }
       }
     }
@@ -206,6 +228,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
  common::LazyAllocArray<ThreadWorkerBlock<kWorkerQueue> > gpu_normal_workers_;
  // workers doing copy works from/to GPU
  common::LazyAllocArray<ThreadWorkerBlock<kCopyQueue> > gpu_copy_workers_;
+  // gpu priority workers
+  common::LazyAllocArray<ThreadWorkerBlock<kPriorityQueue> > gpu_priority_workers_;
  /*!
   * \brief GPU worker that performs operations on a certain device.
   * \param dev_id The device id of the worker.
@@ -304,6 +328,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
  /*! Signal all queues for shutdown */
  void SignalQueuesForKill() {
+    SignalQueueForKill(&gpu_priority_workers_);
    SignalQueueForKill(&gpu_normal_workers_);
    SignalQueueForKill(&gpu_copy_workers_);
    SignalQueueForKill(&cpu_normal_workers_);
To avoid such an issue," "consider create a new NDArray buffer to store the output."); } - + bool is_gpu = retained_gpu.ctx().dev_mask() == gpu::kDevMask; Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) { const TBlob& indices = row_id.data(); using namespace mxnet::common; @@ -653,7 +653,8 @@ class CommDevice : public Comm { } on_complete(); }, retained_gpu.ctx(), {src.var(), row_id.var()}, {retained_gpu.var()}, - FnProperty::kNormal, priority, "KVStoreSparseRetain"); + is_gpu ? FnProperty::kGPUPrioritized : FnProperty::kCPUPrioritized, + priority, "KVStoreSparseRetain"); CopyFromTo(retained_gpu, out, priority); } } diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index dd3464bf6db4..7e2f5cb5faa9 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -207,10 +207,11 @@ class KVStoreDist : public KVStoreLocal { void PullImpl(const std::vector& keys, const std::vector& values, - int priority) override { + int priority, bool ignore_sparse) override { + CHECK(ignore_sparse) << "dist kvstore pull doesn't support ignore_sparse=False"; std::vector uniq_keys; std::vector > grouped_vals; - GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals); + GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals, true); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; @@ -263,7 +264,7 @@ class KVStoreDist : public KVStoreLocal { int priority = 0) override { std::vector uniq_keys; std::vector>> grouped_val_rowids; - GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids); + GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids, false); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; @@ -305,7 +306,7 @@ class KVStoreDist : public KVStoreLocal { // first aggregate the values over keys std::vector uniq_keys; std::vector > grouped_vals; - GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals); + GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals, false); for (size_t i = 0; i < uniq_keys.size(); ++i) { // merge over devices diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 38ecf121dfeb..84e2700a20de 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -101,9 +101,10 @@ class KVStoreLocal : public KVStore { void Pull(const std::vector& keys, const std::vector& values, - int priority) override { + int priority, + bool ignore_sparse) override { SetKeyType(kIntKey); - PullImpl(keys, values, priority); + PullImpl(keys, values, priority, ignore_sparse); } void PullRowSparse(const std::vector& keys, @@ -124,11 +125,12 @@ class KVStoreLocal : public KVStore { void Pull(const std::vector& str_keys, const std::vector& values, - int priority) override { + int priority, + bool ignore_sparse) override { SetKeyType(kStringKey); std::vector keys(str_keys.size()); LookupKeys(str_keys, &keys); - PullImpl(keys, values, priority); + PullImpl(keys, values, priority, ignore_sparse); } void PullRowSparse(const std::vector& str_keys, @@ -162,7 +164,7 @@ class KVStoreLocal : public KVStore { int priority) { std::vector uniq_keys; std::vector > grouped_vals; - GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals); + GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals, false); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; const NDArray& merged = comm_->Reduce(key, grouped_vals[i], priority); @@ -198,10 +200,11 @@ class KVStoreLocal : public KVStore { virtual void PullImpl(const 
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 38ecf121dfeb..84e2700a20de 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -101,9 +101,10 @@ class KVStoreLocal : public KVStore {
 
   void Pull(const std::vector<int>& keys,
             const std::vector<NDArray*>& values,
-            int priority) override {
+            int priority,
+            bool ignore_sparse) override {
     SetKeyType(kIntKey);
-    PullImpl(keys, values, priority);
+    PullImpl(keys, values, priority, ignore_sparse);
   }
 
   void PullRowSparse(const std::vector<int>& keys,
@@ -124,11 +125,12 @@ class KVStoreLocal : public KVStore {
 
   void Pull(const std::vector<std::string>& str_keys,
             const std::vector<NDArray*>& values,
-            int priority) override {
+            int priority,
+            bool ignore_sparse) override {
     SetKeyType(kStringKey);
     std::vector<int> keys(str_keys.size());
     LookupKeys(str_keys, &keys);
-    PullImpl(keys, values, priority);
+    PullImpl(keys, values, priority, ignore_sparse);
   }
 
   void PullRowSparse(const std::vector<std::string>& str_keys,
@@ -162,7 +164,7 @@ class KVStoreLocal : public KVStore {
                         int priority) {
     std::vector<int> uniq_keys;
     std::vector<std::vector<NDArray> > grouped_vals;
-    GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals);
+    GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals, false);
     for (size_t i = 0; i < uniq_keys.size(); ++i) {
       int key = uniq_keys[i];
       const NDArray& merged = comm_->Reduce(key, grouped_vals[i], priority);
@@ -198,10 +200,11 @@ class KVStoreLocal : public KVStore {
 
   virtual void PullImpl(const std::vector<int>& keys,
                         const std::vector<NDArray*>& values,
-                        int priority) {
+                        int priority,
+                        bool ignore_sparse) {
     std::vector<int> uniq_keys;
     std::vector<std::vector<NDArray*> > grouped_vals;
-    GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals);
+    GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals, ignore_sparse);
 
     for (size_t i = 0; i < uniq_keys.size(); ++i) {
       int key = uniq_keys[i];
@@ -216,7 +219,7 @@ class KVStoreLocal : public KVStore {
                                  int priority = 0) {
     std::vector<int> uniq_keys;
     std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
-    GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
+    GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids, false);
     for (size_t i = 0; i < uniq_keys.size(); ++i) {
       int key = uniq_keys[i];
       const NDArray& local = local_[key];
@@ -250,9 +253,11 @@ class KVStoreLocal : public KVStore {
   virtual void GroupKVPairsPush(const std::vector<int>& keys,
                                 const std::vector<NDArray>& values,
                                 std::vector<int> *uniq_keys,
-                                std::vector<std::vector<NDArray>> *grouped_vals) {
+                                std::vector<std::vector<NDArray>> *grouped_vals,
+                                bool ignore_sparse) {
     // check if the storage type of a value is valid
-    auto validator = [this](const int key, const NDArray& nd) -> bool {
+    auto validator = [this](const int key, const NDArray& nd, bool ignore_sparse) -> bool {
+      CHECK(!ignore_sparse) << "Cannot ignore sparse arrays for push";
       auto stype = nd.storage_type();
       // valid NDArray
       if (stype == kDefaultStorage || stype == kRowSparseStorage) return true;
@@ -260,7 +265,7 @@ class KVStoreLocal : public KVStore {
       // invalid NDArray, abort
       LOG(FATAL) << "Unexpected storage type detected during kvstore push: " << stype;
       return false;
     };
-    GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator);
+    GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator, ignore_sparse);
   }
 
   /**
   * \brief group values on keys for pull
   */
@@ -268,11 +273,12 @@ class KVStoreLocal : public KVStore {
   virtual void GroupKVPairsPull(const std::vector<int>& keys,
                                 const std::vector<NDArray*>& values,
                                 std::vector<int> *uniq_keys,
-                                std::vector<std::vector<NDArray*>> *grouped_vals) {
+                                std::vector<std::vector<NDArray*>> *grouped_vals,
+                                bool ignore_sparse) {
     // check if the storage type of a value is valid
-    auto validator = [this](const int key, const NDArray* nd) -> bool {
+    auto validator = [this](const int key, const NDArray* nd, bool ignore_sparse) -> bool {
       // valid
-      if (nd->storage_type() == kDefaultStorage) return true;
+      if (nd->storage_type() == kDefaultStorage || !ignore_sparse) return true;
       // invalid, print warning messages once
       if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) {
         LOG(INFO) << "Warning: non-default weights detected during kvstore pull. "
" @@ -282,7 +288,7 @@ class KVStoreLocal : public KVStore { } return false; }; - GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); + GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator, ignore_sparse); } typedef std::pair RSPVal; @@ -292,9 +298,11 @@ class KVStoreLocal : public KVStore { virtual void GroupKVPairsPullRsp(const std::vector& keys, const std::vector& values, std::vector *uniq_keys, - std::vector> *grouped_vals) { + std::vector> *grouped_vals, + bool ignore_sparse) { // check if the storage type of a value is valid - auto validator = [this](const int key, const RSPVal& val_rowid) -> bool { + auto validator = [this](const int key, const RSPVal& val_rowid, bool ignore_sparse) -> bool { + CHECK(!ignore_sparse) << "Cannot ignore sparse arrays in row_sparse_pull"; auto val_stype = val_rowid.first->storage_type(); auto rowid_stype = val_rowid.second.storage_type(); // check storage types @@ -304,7 +312,7 @@ class KVStoreLocal : public KVStore { << "row_sparse_pull rowids, but detected storage type " << rowid_stype; return true; }; - GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); + GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator, ignore_sparse); } /** @@ -316,7 +324,8 @@ class KVStoreLocal : public KVStore { const std::vector& values, std::vector* uniq_keys, std::vector >* grouped_vals, - const FValidate& is_valid) { + const FValidate& is_valid, + bool ignore_sparse) { CHECK_EQ(keys.size(), values.size()); // TODO(mli) check if already sorted as an optimization using Idx = std::pair; @@ -330,7 +339,7 @@ class KVStoreLocal : public KVStore { int pre_key = idx[0].first - 1; for (auto i : idx) { - if (is_valid(i.first, values[i.second])) { + if (is_valid(i.first, values[i.second], ignore_sparse)) { if (i.first != pre_key) { uniq_keys->push_back(i.first); grouped_vals->push_back({values[i.second]}); @@ -370,22 +379,20 @@ class KVStoreLocal : public KVStore { NDArray data_in_ctx = diff_ctx ? NDArray(data.shape(), ctx, true, data.dtype()) : data; // if data == data_in_ctx, CopyFromTo is smart enough to skip the copy CopyFromTo(data, &data_in_ctx, priority); - Resource rsc = ResourceManager::Get()->Request(out.ctx(), - ResourceRequest(ResourceRequest::kTempSpace)); // GPU requires temp resources - std::vector mutate_vars{out.var()}; - if (out.ctx().dev_mask() == gpu::kDevMask) mutate_vars.emplace_back(rsc.var); + bool is_gpu = out.ctx().dev_mask() == gpu::kDevMask; Engine::Get()->PushAsync( [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { // copy data.data() to out.data() out.CheckAndAlloc({mshadow::Shape1(num_elements)}); TBlob out_data = out.data(); + NDArray workspace; switch (out.ctx().dev_mask()) { case cpu::kDevMask: { mshadow::Stream *s = rctx.get_stream(); ndarray::Copy(data_in_ctx.data(), &out_data, ctx, ctx, rctx); - UniqueImpl(rsc, s, out); + UniqueImpl(&workspace, s, out); break; } #if MXNET_USE_CUDA @@ -393,7 +400,7 @@ class KVStoreLocal : public KVStore { mshadow::Stream *s = rctx.get_stream(); ndarray::Copy(data_in_ctx.data(), &out_data, ctx, ctx, rctx); - UniqueImpl(rsc, s, out); + UniqueImpl(&workspace, s, out); // wait for GPU operations to complete s->Wait(); break; @@ -403,8 +410,9 @@ class KVStoreLocal : public KVStore { LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } on_complete(); - }, out.ctx(), {data_in_ctx.var()}, mutate_vars, - FnProperty::kNormal, priority, "KVStoreUnique"); + }, out.ctx(), {data_in_ctx.var()}, {out.var()}, + is_gpu ? 
diff --git a/src/kvstore/kvstore_utils.cc b/src/kvstore/kvstore_utils.cc
index e187b0ce4890..b53eca433e97 100644
--- a/src/kvstore/kvstore_utils.cc
+++ b/src/kvstore/kvstore_utils.cc
@@ -29,8 +29,8 @@ namespace mxnet {
 namespace kvstore {
 
 template<>
-void UniqueImpl<cpu>(const Resource& rsc, mshadow::Stream<cpu> *s,
-                     const NDArray& out) {
+void UniqueImpl<cpu>(NDArray* workspace, mshadow::Stream<cpu> *s,
+                     const NDArray& out) {
   const size_t num_elements = out.shape().Size();
   CHECK_EQ(out.storage_type(), kRowSparseStorage) << "row_sparse NDArray is expected";
   MSHADOW_IDX_TYPE_SWITCH(out.dtype(), IType, {
diff --git a/src/kvstore/kvstore_utils.cu b/src/kvstore/kvstore_utils.cu
index 438fe29dac4e..2dab5bc1802d 100644
--- a/src/kvstore/kvstore_utils.cu
+++ b/src/kvstore/kvstore_utils.cu
@@ -41,8 +41,8 @@ namespace mxnet {
 namespace kvstore {
 
 template<typename IType>
-size_t UniqueImplGPU(const Resource& rsc, mshadow::Stream<gpu> *s,
-                     IType *dptr, const size_t size) {
+size_t UniqueImplGPU(NDArray *workspace, mshadow::Stream<gpu> *s,
+                     IType *dptr, const size_t size, Context ctx) {
   // estimate unique temp space. The first byte is reserved to store the number
   // of unique values selected
   const size_t num_selected_bytes = sizeof(size_t);
@@ -68,12 +68,12 @@ size_t UniqueImplGPU(const Resource& rsc, mshadow::Stream<gpu> *s,
   // request temp storage
   const size_t total_workspace = num_selected_bytes + sort_output_bytes +
                                  std::max(sort_temp_bytes, unique_temp_bytes);
-  mshadow::Tensor<gpu, 1, char> workspace = rsc
-    .get_space_typed<gpu, 1, char>(mshadow::Shape1(total_workspace), s);
+  *workspace = NDArray(mshadow::Shape1((total_workspace + 3) / 4), ctx, false);
+  char* workspace_dptr = reinterpret_cast<char*>(workspace->data().dptr_);
   // temp space layout: num_selected_ptr, sort_output_bytes, unique/sort_temp_storage
-  size_t* num_selected_ptr = reinterpret_cast<size_t*>(workspace.dptr_);
-  IType* sort_output_ptr = reinterpret_cast<IType*>(workspace.dptr_ + num_selected_bytes);
-  void *temp_storage = static_cast<void*>(workspace.dptr_ +
+  size_t* num_selected_ptr = reinterpret_cast<size_t*>(workspace_dptr);
+  IType* sort_output_ptr = reinterpret_cast<IType*>(workspace_dptr + num_selected_bytes);
+  void *temp_storage = static_cast<void*>(workspace_dptr +
                        num_selected_bytes + sort_output_bytes);
   // execute the sort kernel
 #ifndef SORT_WITH_THRUST
@@ -96,13 +96,12 @@ size_t UniqueImplGPU(const Resource& rsc, mshadow::Stream<gpu> *s,
 }
 
 template<>
-void UniqueImpl<gpu>(const Resource& rsc, mshadow::Stream<gpu> *s,
-                     const NDArray &out) {
+void UniqueImpl<gpu>(NDArray *workspace, mshadow::Stream<gpu> *s, const NDArray &out) {
   const size_t num_elements = out.shape().Size();
   CHECK_EQ(out.storage_type(), kRowSparseStorage) << "row_sparse NDArray is expected";
   MSHADOW_IDX_TYPE_SWITCH(out.dtype(), IType, {
     IType *dptr = out.data().dptr<IType>();
-    size_t num_selected_out = UniqueImplGPU(rsc, s, dptr, num_elements);
+    size_t num_selected_out = UniqueImplGPU(workspace, s, dptr, num_elements, out.ctx());
     // set the shape of data/aux_data according to the number of unique values
     out.set_aux_shape(rowsparse::kIdx, mshadow::Shape1(num_selected_out));
   });
diff --git a/src/kvstore/kvstore_utils.h b/src/kvstore/kvstore_utils.h
index ee173b4559f7..2527f7ed0ce2 100644
--- a/src/kvstore/kvstore_utils.h
+++ b/src/kvstore/kvstore_utils.h
@@ -36,14 +36,15 @@ namespace kvstore {
 
 /*!
  * \brief compute unique and sorted values in a row_sparse ndarray.
- * \param rsc Temp resource for computation
+ * \param workspace Temp workspace for computation. It's a pointer to an
+                    NDArray placeholder, to make sure the NDArray is not freed
+                    during execution.
  * \param s Stream
  * \param out Input and output ndarray. The ndarray stores the
  *            unique elements in out.data().
  */
 template<typename xpu>
-void UniqueImpl(const Resource& rsc, mshadow::Stream<xpu> *s,
-                const NDArray& out);
+void UniqueImpl(NDArray* workspace, mshadow::Stream<xpu> *s, const NDArray& out);
 
 }  // namespace kvstore
 }  // namespace mxnet
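Behaviorally, `UniqueImpl` still does what it did before: sort the row ids in `out.data()`, deduplicate them, and shrink the aux shape to the number of survivors. The only change is that the temp buffer now comes from a freshly allocated NDArray captured by the async engine op rather than a `kTempSpace` resource. A reference sketch of the contract in numpy terms (illustrative only; the real kernels run on the engine's stream):

```python
import numpy as np

def unique_impl_reference(row_ids):
    # np.unique sorts and deduplicates, like the cub-based GPU kernel
    out = np.unique(row_ids)
    return out, out.shape[0]  # data, new aux shape (number of unique ids)

ids, num_selected = unique_impl_reference(np.array([5, 1, 5, 3, 1]))
assert ids.tolist() == [1, 3, 5] and num_selected == 3
```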
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index e90fb6319d77..daa2abd1e4a6 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1143,7 +1143,7 @@ void CopyFromToImpl(const NDArray& from, const NDArray& to,
   }
 }
 
-void CopyFromTo(const NDArray& from, const NDArray& to, int priority) {
+void CopyFromTo(const NDArray& from, const NDArray& to, int priority, bool is_opr) {
   if (from.var() == to.var() && from.byte_offset() == to.byte_offset()) {
     // skip to copy to itself
     return;
@@ -1224,7 +1224,7 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority) {
         on_complete();
       }, from.ctx(), const_vars, mutable_vars,
       from.dtype() != to.dtype() ? FnProperty::kNormal : FnProperty::kCopyFromGPU,
-      priority, "CopyGPU2GPU");
+      priority, is_opr ? "_copyto_GPU2GPU" : "CopyGPU2GPU");
   } else {
     LOG(FATAL) << "unknown device mask";
   }
@@ -2004,7 +2004,7 @@ void CopyFromToSimple(
     const std::vector<NDArray>& inputs,
     const std::vector<OpReqType>& req,
     const std::vector<NDArray>& outputs) {
-  CopyFromTo(inputs[0], outputs[0], 0);
+  CopyFromTo(inputs[0], outputs[0], 0, true);
 }
 
 // copy function is special
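The only observable effect of `is_opr` is the profiler label on GPU-to-GPU copies: copies issued through the `_copyto` operator (e.g. `NDArray.copyto`) are now tagged `_copyto_GPU2GPU`, while engine-internal copies such as those issued by the kvstore keep `CopyGPU2GPU`. A sketch of how the labels would surface (profiler configuration is illustrative; needs two GPUs):

```python
import mxnet as mx

mx.profiler.set_config(profile_all=True, filename='copy_profile.json')
mx.profiler.set_state('run')

a = mx.nd.ones((1024,), ctx=mx.gpu(0))
b = mx.nd.empty((1024,), ctx=mx.gpu(1))
a.copyto(b)      # shows up as "_copyto_GPU2GPU" in the trace
mx.nd.waitall()

mx.profiler.set_state('stop')
```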
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index 32ed2dddb6fb..8ba1edab3a0d 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -38,6 +38,7 @@ def check_diff(A, x, rank=None):
 irregular_shape = (1211,1211)
 big_shape = (1200, 1200)  # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
 
+keys_invalid = [999]
 keys_shape = ['3', '5', '7']
 keys_big_shape = ['99']
 fp16_keys_shape = ['4', '6', '8']
@@ -350,18 +351,44 @@ def check_init(kv, cur_keys, cur_shape, device=False):
         check_init(kv, init_test_keys_device_big, big_shape, device=True)
     print('worker ' + str(kv.rank) + ' is initialized')
 
-def test_gluon_trainer_reset():
-    params = mx.gluon.ParameterDict()
-    x = params.get('x', shape=(4, 2), lr_mult=1.0, stype='row_sparse')
-    params.initialize(ctx=mx.cpu(0), init='zeros')
-    trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv)
-    params.save('test_gluon_trainer_reset_' + str(my_rank) + '.params')
-    row_id = mx.nd.arange(0, 4)
-    w = x.row_sparse_data(row_id)
-    assert trainer._kv_initialized and trainer._update_on_kvstore
-    # load would fail to reset kvstore since update_on_kvstore is True
-    assert_exception(params.load, RuntimeError, 'test_gluon_trainer_reset_' + str(my_rank) + '.params')
-    print('worker ' + str(my_rank) + ' passed test_gluon_trainer_reset')
+def test_invalid_operations():
+    def check_invalid_gluon_trainer_reset():
+        params = mx.gluon.ParameterDict()
+        x = params.get('x', shape=(4, 2), lr_mult=1.0, stype='row_sparse')
+        params.initialize(ctx=mx.cpu(0), init='zeros')
+        trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv)
+        params.save('test_gluon_trainer_reset_' + str(my_rank) + '.params')
+        row_id = mx.nd.arange(0, 4)
+        w = x.row_sparse_data(row_id)
+        assert trainer._kv_initialized and trainer._update_on_kvstore
+        mx.nd.waitall()
+        # load would fail to reset kvstore since update_on_kvstore is True
+        assert_exception(params.load, RuntimeError,
+                         'test_gluon_trainer_reset_' + str(my_rank) + '.params')
+        print('worker ' + str(my_rank) + ' passed check_invalid_gluon_trainer_reset')
+
+    def check_invalid_pull():
+        kv.init(keys_invalid[0], mx.nd.ones((2,2)).tostype('row_sparse'))
+        out = mx.nd.ones((2,2)).tostype('row_sparse')
+        assert_exception(kv.pull, mx.MXNetError, 'invalid_key', out=out, ignore_sparse=False)
+        print('worker ' + str(my_rank) + ' passed check_invalid_pull')
+
+    check_invalid_gluon_trainer_reset()
+    check_invalid_pull()
+
+def test_gluon_trainer():
+    def check_trainer_kv_type(stype, grad_stype, update_on_kv):
+        params = mx.gluon.ParameterDict()
+        x = params.get('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype)
+        params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+        trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv)
+        trainer._init_kvstore()
+        assert trainer._kv_initialized
+        assert trainer._update_on_kvstore is update_on_kv
+
+    check_trainer_kv_type('default', 'default', False)
+    check_trainer_kv_type('default', 'row_sparse', True)
+    check_trainer_kv_type('row_sparse', 'row_sparse', True)
+    print('worker ' + str(my_rank) + ' passed test_gluon_trainer')
 
 if __name__ == "__main__":
@@ -372,7 +399,9 @@ def test_gluon_trainer_reset():
     parser.add_argument('--no-multiprecision', dest='multiprecision', action='store_false')
     opt = parser.parse_args()
     if opt.type == 'gluon':
-        test_gluon_trainer_reset()
+        test_gluon_trainer()
+    if opt.type == 'invalid':
+        test_invalid_operations()
     if opt.type == 'all' or opt.type == 'init':
         test_sync_init(opt.gpu)
     if opt.type == 'all' or opt.type == 'default':
diff --git a/tests/python/gpu/test_kvstore_gpu.py b/tests/python/gpu/test_kvstore_gpu.py
index 76231fbe90ee..90ed129114ff 100644
--- a/tests/python/gpu/test_kvstore_gpu.py
+++ b/tests/python/gpu/test_kvstore_gpu.py
@@ -77,6 +77,13 @@ def check_rsp_pull(kv, count, ctxs, is_same_rowid=False, use_slice=False):
                 expected_val += 0 if row in excluded_row_ids else 2
             assert_almost_equal(retained[row], expected_val)
 
+        kv.pull('e', out=vals_to_pull, ignore_sparse=False)
+        for val in vals:
+            retained = val.asnumpy()
+            expected_val = np.zeros_like(retained)
+            expected_val[:] = 2
+            assert_almost_equal(retained, expected_val)
+
     check_rsp_pull(kv, 1, [mx.gpu(0)])
     check_rsp_pull(kv, 1, [mx.cpu(0)])
     check_rsp_pull(kv, 4, [mx.gpu(i//2) for i in range(4)])
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index eac9fad45f57..2a34400d60ab 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -205,3 +205,31 @@ def check_trainer_reset_kv(kv):
     kvs = ['local', 'device']
     for kv in kvs:
         check_trainer_reset_kv(kv)
+
+@with_seed()
+def test_trainer_sparse_kv():
+    def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv):
+        params = gluon.ParameterDict()
+        x = params.get('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype)
+        params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+        trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv)
+        all_rows = mx.nd.arange(0, 10, ctx=mx.cpu(0))
+        ws = x.list_data() if stype == 'default' else x.list_row_sparse_data(all_rows)
+        with mx.autograd.record():
+            for w in ws:
+                y = w + 1
+                y.backward()
+        trainer.step(1)
+        assert trainer._kvstore.type == kv
+        assert trainer._kv_initialized
+        assert trainer._update_on_kvstore is update_on_kv
+        # check that the parameter was updated by the step
+        mx.nd.waitall()
+        updated_w = x.data(mx.cpu(0)) if stype == 'default' else x.row_sparse_data(all_rows)
+        assert (updated_w == -0.2).asnumpy().all()
+
+    kvs = ['local', 'device']
+    for kv in kvs:
+        check_trainer_sparse_kv(kv, 'default', 'default', True)
+        check_trainer_sparse_kv(kv, 'default', 'row_sparse', False)
+        check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', True)
diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py
index 0ab61bb27483..921a5704d54b 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -160,43 +160,52 @@ def check_aggregator(kv, key, key_list, stype):
 @with_seed()
 def test_sparse_aggregator():
     """aggregate sparse ndarray on multiple devices"""
+    def check_sparse_aggregator(sparse_pull):
+        stype = 'row_sparse'
+        kv = init_kv_with_str(stype)
 
-    stype = 'row_sparse'
-    kv = init_kv_with_str(stype)
-
-    # devices
-    num_devs = 4
-    devs = [mx.Context('cpu', i) for i in range(num_devs)]
-
-    # single
-    vals = [rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)]
-    expected_sum = np.zeros(shape)
-    for v in vals:
-        expected_sum += v.asnumpy()
-
-    # prepare row_ids
-    all_rows = mx.nd.array(np.arange(shape[0]))
-    kv.push('a', vals)
-    kv.row_sparse_pull('a', out=vals, row_ids=[all_rows] * len(vals))
-    result_sum = np.zeros(shape)
-    for v in vals:
-        result_sum += v.asnumpy()
-    assert_almost_equal(result_sum, expected_sum * num_devs)
+        # devices
+        num_devs = 4
+        devs = [mx.Context('cpu', i) for i in range(num_devs)]
 
-    # list
-    vals = [[rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)]] * len(keys)
-    expected_sum = np.zeros(shape)
-    for v in vals[0]:
-        expected_sum += v.asnumpy()
-
-    kv.push(str_keys, vals)
-    kv.row_sparse_pull(str_keys, out=vals, row_ids=[[all_rows] * num_devs] * len(vals))
-    for vv in vals:
+        # single
+        vals = [rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)]
+        expected_sum = np.zeros(shape)
+        for v in vals:
+            expected_sum += v.asnumpy()
+
+        # prepare row_ids
+        kv.push('a', vals)
+        if sparse_pull:
+            all_rows = mx.nd.array(np.arange(shape[0]))
+            kv.row_sparse_pull('a', out=vals, row_ids=[all_rows] * len(vals))
+        else:
+            kv.pull('a', out=vals, ignore_sparse=False)
         result_sum = np.zeros(shape)
-        for v in vv:
+        for v in vals:
             result_sum += v.asnumpy()
         assert_almost_equal(result_sum, expected_sum * num_devs)
 
+        # list
+        vals = [[rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)]] * len(keys)
+        expected_sum = np.zeros(shape)
+        for v in vals[0]:
+            expected_sum += v.asnumpy()
+
+        kv.push(str_keys, vals)
+        if sparse_pull:
+            kv.row_sparse_pull(str_keys, out=vals, row_ids=[[all_rows] * num_devs] * len(vals))
+        else:
+            kv.pull(str_keys, out=vals, ignore_sparse=False)
+        for vv in vals:
+            result_sum = np.zeros(shape)
+            for v in vv:
+                result_sum += v.asnumpy()
+            assert_almost_equal(result_sum, expected_sum * num_devs)
+
+    check_sparse_aggregator(False)
+    check_sparse_aggregator(True)
+
 def updater(key, recv, local):
     """use updater: += with int keys"""
     assert(isinstance(key, int))