Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Improve sparse pull performance for gluon trainer #11429

Merged
merged 18 commits into from
Jul 9, 2018
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,7 @@ integrationtest_ubuntu_gpu_dist_kvstore() {
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision
../../tools/launch.py -n 7 --launcher local python dist_device_sync_kvstore.py
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=invalid
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=gluon
}

Expand Down
8 changes: 6 additions & 2 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1900,27 +1900,31 @@ MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle,
* \param keys the list of keys
* \param vals the list of values
* \param priority the priority of the action
* \param ignore_sparse whether to ignore sparse arrays in the request
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStorePull(KVStoreHandle handle,
mx_uint num,
const int* keys,
NDArrayHandle* vals,
int priority);
int priority,
bool ignore_sparse = true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

C API doesn't support default value

/*!
* \brief pull a list of (key, value) pairs from the kvstore, where each key is a string
* \param handle handle to the kvstore
* \param num the number of key-value pairs
* \param keys the list of keys
* \param vals the list of values
* \param priority the priority of the action
* \param ignore_sparse whether to ignore sparse arrays in the request
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
int priority);
int priority,
bool ignore_sparse = true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added extra CAPIs instead of adding default value to this one


/*!
* \brief pull a list of (key, value) pairs from the kvstore, where each key is an integer.
Expand Down
2 changes: 2 additions & 0 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ enum class FnProperty {
kCopyToGPU,
/*! \brief Prioritized sync operation on CPU */
kCPUPrioritized,
/*! \brief Prioritized sync operation on GPU */
kGPUPrioritized,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add it at the end

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to the end.

/*! \brief Asynchronous function call */
kAsync,
/*! \brief Delete variable call */
Expand Down
6 changes: 4 additions & 2 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,19 +170,21 @@ class KVStore {
* \param keys the list of keys
* \param values the list of buffers for the pulled data, they should be preallocated
* \param priority Priority of the action.
* \param ignore_sparse whether to ignore sparse arrays in the request
*/
virtual void Pull(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority = 0) = 0;
int priority = 0, bool ignore_sparse = true) = 0;
/*!
* \brief pull a list of key-value pairs from the store
* \param keys the list of keys in string format
* \param values the list of buffers for the pulled data, they should be preallocated
* \param priority Priority of the action.
* \param ignore_sparse whether to ignore sparse arrays in the request
*/
virtual void Pull(const std::vector<std::string>& str_keys,
const std::vector<NDArray*>& values,
int priority = 0) = 0;
int priority = 0, bool ignore_sparse = true) = 0;

/*!
* \brief pull a list of key-value pairs from the store.
Expand Down
4 changes: 3 additions & 1 deletion include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -1024,10 +1024,12 @@ void CopyFromTo(const NDArray &from, const NDArray *to, int priority = 0);
* \param from the ndarray we want to copy data from
* \param to the target ndarray
* \param priority Priority of the action.
* \param is_opr whether it is invoked by an operator. For example, false if invoked from
KVStore, true if invoked from `_copyto` operator.
* \note The function name explicitly marks the order of from and to
* due to different possible convention carried by copy function.
*/
void CopyFromTo(const NDArray &from, const NDArray& to, int priority = 0);
void CopyFromTo(const NDArray &from, const NDArray& to, int priority = 0, bool is_opr = false);

/*!
* \brief Perform elementwise sum over each data from source, store result into out.
Expand Down
49 changes: 35 additions & 14 deletions python/mxnet/gluon/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
"got %s."%(type(params)))
self._params = []
# parameters to initialize on the kvstore
self._contains_sparse = False
self._contains_sparse_weight = False
self._contains_sparse_grad = False
self._param2idx = {}
for i, param in enumerate(params):
if not isinstance(param, Parameter):
Expand All @@ -80,7 +81,9 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
self._params.append(param)
param._set_trainer(self)
if param._stype != 'default':
self._contains_sparse = True
self._contains_sparse_weight = True
if param._grad_stype != 'default':
self._contains_sparse_grad = True
self._compression_params = compression_params
optimizer_params = optimizer_params if optimizer_params else {}
self._scale = float(optimizer_params.get('rescale_grad', 1.0))
Expand Down Expand Up @@ -153,13 +156,31 @@ def _reset_kvstore(self):
def _init_kvstore(self):
"""Create kvstore."""
config = self._kvstore_params
if self._contains_sparse:
# if weight is sparse, the weight must be updated on KVStore.
# training loop contains:
# - row_sparse_pull(sparse_weight)
# - forward()
# - backward()
# - push(sparse_grad), push(dense_grad)
# - pull(dense_weight)
if self._contains_sparse_weight:
kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore'])
# update_on_kvstore is set to False by the user
# raise Error if update_on_kvstore is set to False by the user
if config['update_on_kvstore'] is False:
raise RuntimeError("Cannot set update_on_kvstore to False when sparse "
"gradients and/or sparse weights are present for "
"Parameter '%s'."%param.name)
raise RuntimeError("Cannot set update_on_kvstore to False when sparse weights "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this have to be a error, or can it be a warning and automatically use update_on_kvstore ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, shouldn't this be outside the if contains_sparse_weight condition?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default update_on_kvstore is None. It's only set if user provides a value on purpose. I think an explicit err is better, since we cannot satisfy user's original intent.
If user set update_on_kvstore to False and the model contains no sparse weight, it's totally fine. Why should this be outside the if condition?

"are present.")
# if weight is dense and grad is sparse, the weight better not be updated on KVStore.
# training loop contains:
# - forward()
# - backward()
# - push(grad)
# - pull(grad)
# - update(grad, weight)
elif self._contains_sparse_grad:
arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
kvstore, _ = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays)
update_on_kvstore = False
# normal case
else:
arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts),
Expand All @@ -169,9 +190,9 @@ def _init_kvstore(self):
if kvstore:
if self._compression_params:
kvstore.set_gradient_compression(self._compression_params)
# kv.pull(row_sparse_grad) is not supported
if 'dist' in kvstore.type and not self._contains_sparse:
update_on_kvstore = False
if 'dist' in kvstore.type:
# kv.pull(row_sparse_grad) is not supported for dist kvstore
update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from the comment I'm guessing you meant not self._contains_sparse_weight and not self._contains_sparse_grad

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intended. kv.pull(row_sparse_grad) is not supported for dist kvstore, so we want to set update_on_kvstore = True if there's sparse grad.

if update_on_kvstore:
# optimizer preferably needs to be set before init for multiprecision
kvstore.set_optimizer(self._optimizer)
Expand Down Expand Up @@ -211,8 +232,8 @@ def _row_sparse_pull(self, parameter, out, row_id):
self._init_kvstore()
if self._params_to_init:
self._init_params()
self._kvstore.row_sparse_pull(self._param2idx[parameter.name], \
out=out, row_ids=row_id)
idx = self._param2idx[parameter.name]
self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)

def step(self, batch_size, ignore_stale_grad=False):
"""Makes one step of parameter update. Should be called after
Expand Down Expand Up @@ -272,7 +293,7 @@ def _allreduce_grads(self):
self._kvstore.push(i, param.list_grad(), priority=-i)

if not self._update_on_kvstore:
self._kvstore.pull(i, param.list_grad(), priority=-i)
self._kvstore.pull(i, param.list_grad(), priority=-i, ignore_sparse=False)

def update(self, batch_size, ignore_stale_grad=False):
"""Makes one step of parameter update.
Expand Down Expand Up @@ -327,7 +348,7 @@ def _update(self, ignore_stale_grad=False):
if self._kvstore and self._update_on_kvstore:
if param._stype == 'default':
# 'row_sparse' parameters are not pulled immediately - they're pulled
# in `SparseBlock.sparse_forward`
# in `Block.forward`
self._kvstore.pull(i, param.list_data(), priority=-i)
continue

Expand Down
17 changes: 10 additions & 7 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def push(self, key, value, priority=0):
self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))


def pull(self, key, out=None, priority=0):
def pull(self, key, out=None, priority=0, ignore_sparse=True):
""" Pulls a single value or a sequence of values from the store.

This function returns immediately after adding an operator to the engine.
Expand All @@ -247,8 +247,8 @@ def pull(self, key, out=None, priority=0):

The returned values are guaranteed to be the latest values in the store.

For `RowSparseNDArray` values, this call is ignored,
please use ``row_sparse_pull`` instead.
pull with `RowSparseNDArray` is not supported for dist kvstore.
Please use ``row_sparse_pull`` instead.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should ignore_sparse be defaulted to false to be consistent with previous behavior?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previous behavior is to always ignore sparse. So it's consistent


Parameters
----------
Expand All @@ -263,6 +263,9 @@ def pull(self, key, out=None, priority=0):
Higher priority pull operations are likely to be executed before
other pull actions.

ignore_sparse: bool, optional, default True
Whether to ignore sparse arrays in the request.

Examples
--------
>>> # pull a single key-value pair
Expand Down Expand Up @@ -298,11 +301,11 @@ def pull(self, key, out=None, priority=0):
assert(out is not None)
ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
if use_str_keys:
check_call(_LIB.MXKVStorePullEx(
self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
check_call(_LIB.MXKVStorePullEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals,
ctypes.c_int(priority), ctypes.c_bool(ignore_sparse)))
else:
check_call(_LIB.MXKVStorePull(
self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
check_call(_LIB.MXKVStorePull(self.handle, mx_uint(len(ckeys)), ckeys, cvals,
ctypes.c_int(priority), ctypes.c_bool(ignore_sparse)))

def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
""" Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \
Expand Down
10 changes: 6 additions & 4 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -861,31 +861,33 @@ int MXKVStorePull(KVStoreHandle handle,
mx_uint num,
const int* keys,
NDArrayHandle* vals,
int priority) {
int priority,
bool ignore_sparse) {
API_BEGIN();
std::vector<int> v_keys(num);
std::vector<NDArray*> v_vals(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_vals[i] = static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority);
static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, ignore_sparse);
API_END();
}

int MXKVStorePullEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
int priority) {
int priority,
bool ignore_sparse) {
API_BEGIN();
std::vector<std::string> v_keys(num);
std::vector<NDArray*> v_vals(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_vals[i] = static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority);
static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority, ignore_sparse);
API_END();
}

Expand Down
57 changes: 41 additions & 16 deletions src/engine/threaded_engine_perdevice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
void StopNoWait() {
SignalQueuesForKill();
gpu_normal_workers_.Clear();
gpu_priority_workers_.Clear();
gpu_copy_workers_.Clear();
cpu_normal_workers_.Clear();
cpu_priority_worker_.reset(nullptr);
Expand Down Expand Up @@ -101,6 +102,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
this->ExecuteOprBlock(RunContext{ctx, nullptr}, opr_block);
} else {
if (ctx.dev_mask() == Context::kCPU) {
// CPU execution.
if (opr_block->opr->prop == FnProperty::kCPUPrioritized) {
cpu_priority_worker_->task_queue.Push(opr_block, opr_block->priority);
} else {
Expand Down Expand Up @@ -152,24 +154,44 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
}
} else {
const size_t nthread = gpu_worker_nthreads_;
auto ptr = gpu_normal_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() {
// Signify to kernel that GPU is being used, so reserve cores as necessary
OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true));
auto blk = new ThreadWorkerBlock<kWorkerQueue>();
blk->pool.reset(new ThreadPool(
nthread,
[this, ctx, is_copy, blk]
(std::shared_ptr<dmlc::ManualEvent> ready_event) {
this->GPUWorker(ctx, is_copy, blk, ready_event);
}, true));
return blk;
});
if (ptr) {
if (opr_block->opr->prop == FnProperty::kDeleteVar) {
ptr->task_queue.PushFront(opr_block, opr_block->priority);
} else {
// GPU priority task
if (opr_block->opr->prop == FnProperty::kGPUPrioritized) {
auto ptr = gpu_priority_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() {
// Signify to kernel that GPU is being used, so reserve cores as necessary
OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true));
auto blk = new ThreadWorkerBlock<kPriorityQueue>();
blk->pool.reset(new ThreadPool(
nthread,
[this, ctx, is_copy, blk]
(std::shared_ptr<dmlc::ManualEvent> ready_event) {
this->GPUWorker(ctx, is_copy, blk, ready_event);
}, true));
return blk;
});
if (ptr) {
ptr->task_queue.Push(opr_block, opr_block->priority);
}
} else {
// GPU normal task
auto ptr = gpu_normal_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() {
// Signify to kernel that GPU is being used, so reserve cores as necessary
OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true));
auto blk = new ThreadWorkerBlock<kWorkerQueue>();
blk->pool.reset(new ThreadPool(
nthread,
[this, ctx, is_copy, blk]
(std::shared_ptr<dmlc::ManualEvent> ready_event) {
this->GPUWorker(ctx, is_copy, blk, ready_event);
}, true));
return blk;
});
if (ptr) {
if (opr_block->opr->prop == FnProperty::kDeleteVar) {
ptr->task_queue.PushFront(opr_block, opr_block->priority);
} else {
ptr->task_queue.Push(opr_block, opr_block->priority);
}
}
}
}
}
Expand Down Expand Up @@ -206,6 +228,8 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
common::LazyAllocArray<ThreadWorkerBlock<kWorkerQueue> > gpu_normal_workers_;
// workers doing copy works from/to GPU
common::LazyAllocArray<ThreadWorkerBlock<kCopyQueue> > gpu_copy_workers_;
// gpu priority workers
common::LazyAllocArray<ThreadWorkerBlock<kPriorityQueue> > gpu_priority_workers_;
/*!
* \brief GPU worker that performs operations on a certain device.
* \param dev_id The device id of the worker.
Expand Down Expand Up @@ -304,6 +328,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {

/*! Signal all queues for shutdown */
void SignalQueuesForKill() {
SignalQueueForKill(&gpu_priority_workers_);
SignalQueueForKill(&gpu_normal_workers_);
SignalQueueForKill(&gpu_copy_workers_);
SignalQueueForKill(&cpu_normal_workers_);
Expand Down
5 changes: 3 additions & 2 deletions src/kvstore/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ class CommDevice : public Comm {
"next time row_sparse_pull() is called. To avoid such an issue,"
"consider create a new NDArray buffer to store the output.");
}

bool is_gpu = retained_gpu.ctx().dev_mask() == gpu::kDevMask;
Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob& indices = row_id.data();
using namespace mxnet::common;
Expand All @@ -653,7 +653,8 @@ class CommDevice : public Comm {
}
on_complete();
}, retained_gpu.ctx(), {src.var(), row_id.var()}, {retained_gpu.var()},
FnProperty::kNormal, priority, "KVStoreSparseRetain");
is_gpu ? FnProperty::kGPUPrioritized : FnProperty::kCPUPrioritized,
priority, "KVStoreSparseRetain");
CopyFromTo(retained_gpu, out, priority);
}
}
Expand Down
Loading