Enable support for dense weight and sparse grad Adagrad updates (#11355)
* Support dense weight and sparse grad AdagradUpdate

* Simplify AdagradStorageType

* Add test
leezu authored and eric-haibin-lin committed Jun 25, 2018
Parent: adec280 · Commit: 9b27262
Showing 3 changed files with 29 additions and 9 deletions.
python/mxnet/optimizer.py (1 addition, 1 deletion)
@@ -1107,7 +1107,7 @@ def update(self, index, weight, grad, state):
         lr = self._get_lr(index)
         wd = self._get_wd(index)

-        is_sparse = weight.stype == 'row_sparse' and grad.stype == 'row_sparse'
+        is_sparse = grad.stype == 'row_sparse'
         history = state

         if is_sparse:
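
With this change, `AdaGrad.update` takes the sparse code path whenever the gradient is `row_sparse`, regardless of the weight's storage type. A minimal usage sketch of the newly supported combination (shapes and hyperparameters are illustrative, not taken from the commit):

```python
import mxnet as mx

# Dense weight with a row_sparse gradient: the combination this commit enables.
weight = mx.nd.ones((4, 2))
grad = mx.nd.ones((4, 2)).tostype('row_sparse')

opt = mx.optimizer.AdaGrad(learning_rate=0.1, wd=0.0)  # wd must stay 0 for the sparse path
state = opt.create_state(0, weight)                    # AdaGrad history, dense like the weight
opt.update(0, weight, grad, state)                     # dispatches to the sparse update
```
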
src/operator/optimizer_op-inl.h (26 additions, 8 deletions)
@@ -1663,16 +1663,20 @@ inline bool AdagradStorageType(const nnvm::NodeAttrs& attrs,
                                DispatchMode* dispatch_mode,
                                std::vector<int>* in_attrs,
                                std::vector<int>* out_attrs) {
-  const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), 1U);
+  const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+  const int weight_stype = in_attrs->at(0);
+  const int grad_stype = in_attrs->at(1);
+  const int state_stype = in_attrs->at(2);
   bool dispatched = false;
-  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
-      common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
-      param.wd == 0.0f) {
-    // rsp, rsp, rsp -> rsp with wd = 0.0
-    dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
-                                     dispatch_mode, DispatchMode::kFComputeEx);
+  if (!dispatched && grad_stype == kRowSparseStorage &&
+      (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
+      state_stype == weight_stype && param.wd == 0.0f) {
+    // weight and state share stype, grad's stype = rsp
+    dispatched = storage_type_assign(
+        out_attrs, static_cast<NDArrayStorageType>(weight_stype), dispatch_mode,
+        DispatchMode::kFComputeEx);
   }
   return dispatched;
 }
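
The rewritten storage-type inference now accepts two input patterns instead of one: all-`row_sparse` inputs, or a dense weight and state with a `row_sparse` gradient, with the output following the weight's storage type. A pure-Python paraphrase of that rule, for illustration only (names are mine, not from the source):

```python
def adagrad_storage_type(weight_stype, grad_stype, state_stype, wd):
    """Paraphrase of the dispatch rule: grad must be row_sparse, the weight and
    the AdaGrad history must share a storage type (row_sparse or dense/default),
    and weight decay must be zero. Returns the output stype, or None to fall back."""
    if (grad_stype == 'row_sparse'
            and weight_stype in ('row_sparse', 'default')
            and state_stype == weight_stype
            and wd == 0.0):
        return weight_stype  # output storage type follows the weight
    return None

assert adagrad_storage_type('default', 'row_sparse', 'default', 0.0) == 'default'
assert adagrad_storage_type('row_sparse', 'row_sparse', 'row_sparse', 0.0) == 'row_sparse'
assert adagrad_storage_type('default', 'default', 'default', 0.0) is None
```
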

@@ -1802,10 +1806,24 @@ inline void AdagradUpdateEx(const nnvm::NodeAttrs& attrs,
                             const std::vector<NDArray> &outputs) {
   using namespace mxnet_op;
   const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+
+  const auto weight_stype = inputs[0].storage_type();
+  const auto grad_stype = inputs[1].storage_type();
+  const auto state_stype = inputs[2].storage_type();
+  const auto output_stype = outputs[0].storage_type();
+
   if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
       common::ContainsOnlyStorage(outputs, kRowSparseStorage)) {
     NDArray out = outputs[0];
-    AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], req[0], &out);
+    AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+                                    req[0], &out);
+  } else if (state_stype == weight_stype && output_stype == weight_stype &&
+             weight_stype == kDefaultStorage &&
+             grad_stype == kRowSparseStorage) {
+    TBlob out_blob = outputs[0].data();
+    AdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(), inputs[1],
+                                    inputs[2].data(), req[0],
+                                    &out_blob);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
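
The new `AdagradUpdateDnsRspDnsImpl` branch updates only the weight rows that appear in the `row_sparse` gradient, leaving all other rows and their history untouched. A NumPy sketch of that per-row arithmetic, assuming the standard AdaGrad rule with `wd == 0` (the epsilon value and placement here are illustrative):

```python
import numpy as np

def adagrad_update_dns_rsp(weight, history, grad_data, grad_indices, lr=0.1, epsilon=1e-7):
    """Dense weight/history, row_sparse grad: touch only the rows listed in grad_indices."""
    for data_row, row in enumerate(grad_indices):
        g = grad_data[data_row]
        history[row] += g * g                                      # accumulate squared gradient
        weight[row] -= lr * g / (np.sqrt(history[row]) + epsilon)  # scale step per coordinate

weight = np.ones((4, 2), dtype=np.float32)
history = np.zeros((4, 2), dtype=np.float32)
# The gradient has non-zero rows 0 and 2 only.
adagrad_update_dns_rsp(weight, history,
                       grad_data=np.full((2, 2), 0.5, dtype=np.float32),
                       grad_indices=np.array([0, 2]),
                       lr=0.1)
print(weight)  # rows 1 and 3 remain all ones
```
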
tests/python/unittest/test_optimizer.py (2 additions, 0 deletions)
@@ -1034,6 +1034,8 @@ def test_adagrad():
         if wd_option.get('wd', 0.0) == 0.0:
             compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
                               w_stype='row_sparse', g_stype='row_sparse')
+            compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+                              g_stype='row_sparse')



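
The added `compare_optimizer` call checks the native AdaGrad against the Python reference with a dense weight and a `row_sparse` gradient. An equivalent spot check outside the test harness, sketched with illustrative shapes, is to run the same update once with a dense gradient and once with its row_sparse view and compare the resulting weights:

```python
import mxnet as mx

opt = mx.optimizer.AdaGrad(learning_rate=0.1, wd=0.0)

w_dense, w_sparse = mx.nd.ones((4, 2)), mx.nd.ones((4, 2))
grad = mx.nd.random.uniform(shape=(4, 2))

# Same update, once with the dense gradient and once with its row_sparse view.
opt.update(0, w_dense, grad, opt.create_state(0, w_dense))
opt.update(1, w_sparse, grad.tostype('row_sparse'), opt.create_state(1, w_sparse))

# The two weights should agree up to floating-point rounding.
print(mx.nd.norm(w_dense - w_sparse).asscalar())
```
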
