diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 0c3fc904fb1f..267a402f246b 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -1107,7 +1107,7 @@ def update(self, index, weight, grad, state):
         lr = self._get_lr(index)
         wd = self._get_wd(index)
 
-        is_sparse = weight.stype == 'row_sparse' and grad.stype == 'row_sparse'
+        is_sparse = grad.stype == 'row_sparse'
         history = state
 
         if is_sparse:
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 28b382c92fbf..9251b8614806 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -1663,16 +1663,20 @@ inline bool AdagradStorageType(const nnvm::NodeAttrs& attrs,
                                DispatchMode* dispatch_mode,
                                std::vector<int>* in_attrs,
                                std::vector<int>* out_attrs) {
+  const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), 1U);
-  const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+  const int weight_stype = in_attrs->at(0);
+  const int grad_stype = in_attrs->at(1);
+  const int state_stype = in_attrs->at(2);
   bool dispatched = false;
-  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
-      common::ContainsOnlyStorage(*in_attrs, kRowSparseStorage) &&
-      param.wd == 0.0f) {
-    // rsp, rsp, rsp -> rsp with wd = 0.0
-    dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
-                                     dispatch_mode, DispatchMode::kFComputeEx);
+  if (!dispatched && grad_stype == kRowSparseStorage &&
+      (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
+      state_stype == weight_stype && param.wd == 0.0f) {
+    // weight and state share stype, grad's stype = rsp
+    dispatched = storage_type_assign(
+        out_attrs, static_cast<NDArrayStorageType>(weight_stype), dispatch_mode,
+        DispatchMode::kFComputeEx);
   }
   return dispatched;
 }
@@ -1802,10 +1806,24 @@ inline void AdagradUpdateEx(const nnvm::NodeAttrs& attrs,
                             const std::vector<NDArray> &outputs) {
   using namespace mxnet_op;
   const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
+
+  const auto weight_stype = inputs[0].storage_type();
+  const auto grad_stype = inputs[1].storage_type();
+  const auto state_stype = inputs[2].storage_type();
+  const auto output_stype = outputs[0].storage_type();
+
   if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
       common::ContainsOnlyStorage(outputs, kRowSparseStorage)) {
     NDArray out = outputs[0];
-    AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2], req[0], &out);
+    AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
+                                    req[0], &out);
+  } else if (state_stype == weight_stype && output_stype == weight_stype &&
+             weight_stype == kDefaultStorage &&
+             grad_stype == kRowSparseStorage) {
+    TBlob out_blob = outputs[0].data();
+    AdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(), inputs[1],
+                                    inputs[2].data(), req[0],
+                                    &out_blob);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index fba10fb522a2..a5b3d4047df9 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -1034,6 +1034,8 @@ def test_adagrad():
             if wd_option.get('wd', 0.0) == 0.0:
                 compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
                                   w_stype='row_sparse', g_stype='row_sparse')
+                compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype,
+                                  g_stype='row_sparse')
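
For context, a minimal Python sketch (not part of the diff, using the standard mx.optimizer.AdaGrad API) of the dense-weight / row_sparse-gradient combination this change enables. The weight and the AdaGrad history stay dense, only the gradient is row_sparse, and wd must be 0.0 for the sparse code path; the shapes and hyperparameters below are illustrative only.

import mxnet as mx

shape = (5, 2)
weight = mx.nd.random.uniform(shape=shape)                      # dense weight
grad = mx.nd.random.uniform(shape=shape).tostype('row_sparse')  # row_sparse gradient
opt = mx.optimizer.AdaGrad(learning_rate=0.1, wd=0.0)           # wd must be 0.0 for the sparse path
state = opt.create_state(0, weight)                             # history inherits the weight's (dense) stype
opt.update(0, weight, grad, state)                              # exercises the dns/rsp/dns update added here
print(weight.stype, grad.stype)                                 # 'default' 'row_sparse'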