diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index d86c3e6ce4f3..abde51b433af 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -381,7 +381,8 @@ class Embedding(HybridBlock):
         Data type of output embeddings.
     weight_initializer : Initializer
         Initializer for the `embeddings` matrix.
-
+    sparse_grad: bool
+        If True, gradient w.r.t. weight will be a 'row_sparse' NDArray.
 
     Inputs:
         - **data**: (N-1)-D tensor with shape: `(x1, x2, ..., xN-1)`.
@@ -390,13 +391,14 @@ class Embedding(HybridBlock):
         - **out**: N-D tensor with shape: `(x1, x2, ..., xN-1, output_dim)`.
     """
     def __init__(self, input_dim, output_dim, dtype='float32',
-                 weight_initializer=None, **kwargs):
+                 weight_initializer=None, sparse_grad=False, **kwargs):
         super(Embedding, self).__init__(**kwargs)
+        grad_stype = 'row_sparse' if sparse_grad else 'default'
         self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
-                        'dtype': dtype}
+                        'dtype': dtype, 'sparse_grad': sparse_grad}
         self.weight = self.params.get('weight', shape=(input_dim, output_dim),
                                       init=weight_initializer, dtype=dtype,
-                                      allow_deferred_init=True)
+                                      allow_deferred_init=True, grad_stype=grad_stype)
 
     def hybrid_forward(self, F, x, weight):
         return F.Embedding(x, weight, name='fwd', **self._kwargs)
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 320b376fe0b0..c7cbcccc95ec 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -81,6 +81,8 @@ class Parameter(object):
         Weight decay multiplier (L2 regularizer coefficient). Works similar to lr_mult.
     init : Initializer, default None
         Initializer of this parameter. Will use the global initializer by default.
+    grad_stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
+        The storage type of the parameter's gradient.
 
     Attributes
     ----------
@@ -97,7 +99,7 @@ class Parameter(object):
     """
     def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
                  lr_mult=1.0, wd_mult=1.0, init=None, allow_deferred_init=False,
-                 differentiable=True):
+                 differentiable=True, grad_stype='default'):
         self._var = None
         self._data = None
         self._grad = None
@@ -114,6 +116,11 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
         self.wd_mult = wd_mult
         self.grad_req = grad_req
         self.init = init
+        assert grad_stype in ['default', 'row_sparse', 'csr'], \
+            "grad_stype for Parameter '%s' must be one of 'default', 'row_sparse', or 'csr'," \
+            " but got '%s'" % (name, grad_stype)
+        self._grad_stype = grad_stype
+
 
     def __repr__(self):
         s = 'Parameter {name} (shape={shape}, dtype={dtype})'
@@ -261,7 +268,9 @@ def _init_grad(self):
             self._grad = None
             return
 
-        self._grad = [ndarray.zeros_like(i) for i in self._data]
+        self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
+                                    stype=self._grad_stype) for i in self._data]
+
         autograd.mark_variables(self.list_data(), self.list_grad(), self.grad_req)
 
     def _reduce(self):
@@ -431,7 +440,7 @@ def zero_grad(self):
         if self._grad is None:
             return
         for i in self._grad:
-            i[:] = 0
+            ndarray.zeros_like(i, out=i)
 
     def var(self):
         """Returns a symbol representing this parameter."""
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index da67fc0b1d99..39c4a1fd6104 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -110,7 +110,17 @@ def _init_optimizer(self, optimizer, optimizer_params):
                                 for _ in self._contexts]
 
     def _init_kvstore(self):
-        arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
+        arg_arrays = {}
+        contains_sparse = False
+        for param in self._params:
+            arg_arrays[param.name] = param.data(self._contexts[0])
+            if param._grad_stype != 'default':
+                contains_sparse = True
+                # update_on_kvstore is set to False by the user
+                if self._update_on_kvstore is False:
+                    raise RuntimeError("Cannot set update_on_kvstore to False when sparse "
+                                       "gradients and/or sparse weights are present for "
+                                       "Parameter %s." % param.name)
         kvstore, update_on_kvstore = _create_kvstore(self._kvstore, len(self._contexts),
                                                      arg_arrays)
         update_on_kvstore = self._update_on_kvstore if self._update_on_kvstore is not None \
@@ -118,8 +128,12 @@ def _init_kvstore(self):
         if kvstore:
             if self._compression_params:
                 kvstore.set_gradient_compression(self._compression_params)
-            if 'dist' in kvstore.type:
-                update_on_kvstore = False
+            # kv.pull(row_sparse_grad) is not supported
+            if contains_sparse:
+                update_on_kvstore = True
+            else:
+                if 'dist' in kvstore.type:
+                    update_on_kvstore = False
             if update_on_kvstore:
                 kvstore.set_optimizer(self._optimizer)
             # optimizer preferably needs to be set before init for multiprecision
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 2ac6c11a1675..38ecf121dfeb 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -276,7 +276,7 @@ class KVStoreLocal : public KVStore {
       // invalid, print warning messages once
       if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) {
         LOG(INFO) << "Warning: non-default weights detected during kvstore pull. "
-                     "This call has been ignored. Please make sure to use"
+                     "This call has been ignored. Please make sure to use "
                      "kv.row_sparse_pull() or module.prepare() with row_ids.";
         this->warnings_printed_.insert(key);
       }
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 0c74cac2dca5..6c9660556371 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -383,8 +383,8 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
   Stream<xpu> *s = ctx.get_stream<xpu>();
   CHECK_EQ(outputs.size(), 1);
   auto stype = outputs[0].storage_type();
-  if (req[0] == kNullOp) return;
-  CHECK_EQ(req[0], kWriteTo) << "kWriteTo is expected for FillComputeZerosEx";
+  // x + 0 == x
+  if (req[0] == kNullOp || req[0] == kAddTo) return;
   if (stype == kRowSparseStorage) {
     FillZerosRspImpl(s, outputs[0]);
   } else if (stype == kCSRStorage) {
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index b054aa6555f8..946b1406e78a 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -38,6 +38,21 @@ def test_parameter():
     assert p.data(mx.cpu(1)).context == mx.cpu(1)
     assert p.data(mx.cpu(0)).shape == (10, 10)
     assert p.var().name == 'weight'
+    assert p.grad(mx.cpu(0)).stype == 'default'
+
+    p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)])
+    assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)]
+
+@with_seed()
+def test_sparse_parameter():
+    p = gluon.Parameter('weight', shape=(10, 10), grad_stype='row_sparse')
+    p.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
+    assert len(p.list_data()) == 2
+    assert len(p.list_grad()) == 2
+    assert p.data(mx.cpu(1)).context == mx.cpu(1)
+    assert p.data(mx.cpu(0)).shape == (10, 10)
+    assert p.var().name == 'weight'
+    assert p.grad(mx.cpu(0)).stype == 'row_sparse'
 
     p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)])
     assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)]
@@ -676,15 +691,17 @@ def test_global_norm_clip():
 
 @with_seed()
 def test_embedding():
-    layer = gluon.nn.Embedding(10, 100)
-    layer.initialize()
-    x = mx.nd.array([3,4,2,0,1])
-    with mx.autograd.record():
-        y = layer(x)
-    y.backward()
-    assert (layer.weight.grad()[:5] == 1).asnumpy().all()
-    assert (layer.weight.grad()[5:] == 0).asnumpy().all()
-
+    def check_embedding(sparse_grad):
+        layer = gluon.nn.Embedding(10, 100, sparse_grad=sparse_grad)
+        layer.initialize()
+        x = mx.nd.array([3,4,2,0,1])
+        with mx.autograd.record():
+            y = layer(x)
+        y.backward()
+        assert (layer.weight.grad().asnumpy()[:5] == 1).all()
+        assert (layer.weight.grad().asnumpy()[5:] == 0).all()
+    check_embedding(True)
+    check_embedding(False)
 
 @with_seed()
 def test_export():
@@ -977,6 +994,7 @@ def test_req():
     assert_almost_equal(grad * 2, grad_double)
 
 
+@with_seed()
 def test_save_load():
     net = mx.gluon.model_zoo.vision.get_resnet(1, 18, pretrained=True)
     net.save_params('test.params')
@@ -987,6 +1005,7 @@ def test_save_load():
     net.load_params('test.params')
 
 
+@with_seed()
 def test_hybrid_multi_context():
     net = mx.gluon.model_zoo.vision.get_resnet(1, 18)
     net.initialize(ctx=[mx.cpu(0), mx.cpu(1)])
@@ -994,6 +1013,19 @@ def test_hybrid_multi_context():
     net.hybridize()
     net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy()
 
+@with_seed()
+def test_zero_grad():
+    data = mx.nd.random.uniform(shape=(3,3))
+    net = nn.Embedding(3, 4, sparse_grad=True, prefix='test_zero_grad_')
+    net.initialize()
+    with mx.autograd.record():
+        l = net(data)
+    l.backward()
+    net.collect_params().zero_grad()
+    grad = net.collect_params()['test_zero_grad_weight'].grad()
+    assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
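For reviewers who want to try the change end to end, below is a minimal usage sketch of the new `sparse_grad` option. It is not part of the patch; the `'sgd'` optimizer and the hyperparameters are illustrative assumptions only, chosen because SGD accepts `row_sparse` gradients with dense weights.

```python
# Illustrative only -- not part of this patch. Assumes an MXNet build that includes this diff.
import mxnet as mx
from mxnet import gluon, autograd

# Embedding whose weight gradient is stored as a 'row_sparse' NDArray.
embedding = gluon.nn.Embedding(input_dim=10, output_dim=100, sparse_grad=True)
embedding.initialize()
trainer = gluon.Trainer(embedding.collect_params(), 'sgd', {'learning_rate': 0.1})

x = mx.nd.array([3, 4, 2, 0, 1])
with autograd.record():
    y = embedding(x)
y.backward()

# Only the rows indexed by `x` hold non-zero gradient values.
assert embedding.weight.grad().stype == 'row_sparse'
trainer.step(batch_size=x.shape[0])
```

The same pattern is exercised by `test_embedding(sparse_grad=True)` and `test_zero_grad` in the updated test_gluon.py above.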