This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Sparse-Gluon] embedding with sparse grad #10924

Merged 8 commits on May 16, 2018
10 changes: 6 additions & 4 deletions python/mxnet/gluon/nn/basic_layers.py
@@ -381,7 +381,8 @@ class Embedding(HybridBlock):
Data type of output embeddings.
weight_initializer : Initializer
Initializer for the `embeddings` matrix.

sparse_grad: bool
If True, gradient w.r.t. weight will be a 'row_sparse' NDArray.

Inputs:
- **data**: (N-1)-D tensor with shape: `(x1, x2, ..., xN-1)`.
@@ -390,13 +391,14 @@ class Embedding(HybridBlock):
- **out**: N-D tensor with shape: `(x1, x2, ..., xN-1, output_dim)`.
"""
def __init__(self, input_dim, output_dim, dtype='float32',
weight_initializer=None, **kwargs):
weight_initializer=None, sparse_grad=False, **kwargs):
Contributor: Please add documentation for this argument.

Member Author: Added.
super(Embedding, self).__init__(**kwargs)
grad_stype = 'row_sparse' if sparse_grad else 'default'
self._kwargs = {'input_dim': input_dim, 'output_dim': output_dim,
'dtype': dtype}
'dtype': dtype, 'sparse_grad': sparse_grad}
self.weight = self.params.get('weight', shape=(input_dim, output_dim),
init=weight_initializer, dtype=dtype,
allow_deferred_init=True)
allow_deferred_init=True, grad_stype=grad_stype)

def hybrid_forward(self, F, x, weight):
return F.Embedding(x, weight, name='fwd', **self._kwargs)
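A minimal usage sketch of the flag added above (hedged; it mirrors the test added later in this PR): with sparse_grad=True, the weight gradient produced by backward() has 'row_sparse' storage.

import mxnet as mx
from mxnet import gluon, autograd

# Embedding layer with sparse gradients enabled, per the diff above.
layer = gluon.nn.Embedding(input_dim=10000, output_dim=50, sparse_grad=True)
layer.initialize()
x = mx.nd.array([1, 7, 42])        # a small batch of indices
with autograd.record():
    y = layer(x)
y.backward()
print(layer.weight.grad().stype)   # expected: 'row_sparse'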
15 changes: 12 additions & 3 deletions python/mxnet/gluon/parameter.py
@@ -81,6 +81,8 @@ class Parameter(object):
Weight decay multiplier (L2 regularizer coefficient). Works similar to lr_mult.
init : Initializer, default None
Initializer of this parameter. Will use the global initializer by default.
grad_stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
The storage type of the parameter's gradient.

Attributes
----------
@@ -97,7 +99,7 @@
"""
def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
lr_mult=1.0, wd_mult=1.0, init=None, allow_deferred_init=False,
differentiable=True):
differentiable=True, grad_stype='default'):
self._var = None
self._data = None
self._grad = None
@@ -114,6 +116,11 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
self.wd_mult = wd_mult
self.grad_req = grad_req
self.init = init
assert grad_stype in ['default', 'row_sparse', 'csr'], \
"grad_stype for Parameter '%s' must be one of 'default', 'row_sparse', or 'csr'," \
" but got '%s'" % (name, grad_stype)
self._grad_stype = grad_stype


def __repr__(self):
s = 'Parameter {name} (shape={shape}, dtype={dtype})'
@@ -261,7 +268,9 @@ def _init_grad(self):
self._grad = None
return

self._grad = [ndarray.zeros_like(i) for i in self._data]
self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
stype=self._grad_stype) for i in self._data]

autograd.mark_variables(self.list_data(), self.list_grad(), self.grad_req)

def _reduce(self):
@@ -431,7 +440,7 @@ def zero_grad(self):
if self._grad is None:
return
for i in self._grad:
i[:] = 0
ndarray.zeros_like(i, out=i)

def var(self):
"""Returns a symbol representing this parameter."""
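A small sketch of the new grad_stype argument (hedged; it follows the test_sparse_parameter test added later in this PR): _init_grad() now allocates gradient buffers with the requested storage type, and zero_grad() clears them in place with zeros_like(out=...), which also works for sparse storage where slice assignment (grad[:] = 0) does not.

import mxnet as mx
from mxnet import gluon

# Parameter whose gradient is stored as row_sparse, per the diff above.
p = gluon.Parameter('weight', shape=(10, 10), grad_stype='row_sparse')
p.initialize(ctx=[mx.cpu(0), mx.cpu(1)])
print([g.stype for g in p.list_grad()])   # expected: ['row_sparse', 'row_sparse']
p.zero_grad()                             # zeros each gradient buffer in place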
20 changes: 17 additions & 3 deletions python/mxnet/gluon/trainer.py
@@ -110,16 +110,30 @@ def _init_optimizer(self, optimizer, optimizer_params):
for _ in self._contexts]

def _init_kvstore(self):
arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
arg_arrays = {}
contains_sparse = False
for param in self._params:
arg_arrays[param.name] = param.data(self._contexts[0])
Member: Did you try calling step() on a parameter with deferred initialization? What message did you get?

Member Author: Yes, I tried the language model example. param.data() was already called in the previous implementation (line 113).

if param._grad_stype != 'default':
contains_sparse = True
# update_on_kvstore is set to False by the user
if self._update_on_kvstore is False:
Member: if not self._update_on_kvstore

Member Author: If self._update_on_kvstore is None, I don't need to throw the error here.

Member: OK

raise RuntimeError("Cannot set update_on_kvstore to False when sparse "
"gradients and/or sparse weights are present for "
"Parameter %s." % param.name)
kvstore, update_on_kvstore = _create_kvstore(self._kvstore, len(self._contexts),
arg_arrays)
update_on_kvstore = self._update_on_kvstore if self._update_on_kvstore is not None \
else update_on_kvstore
if kvstore:
if self._compression_params:
kvstore.set_gradient_compression(self._compression_params)
if 'dist' in kvstore.type:
update_on_kvstore = False
# kv.pull(row_sparse_grad) is not supported
if contains_sparse:
update_on_kvstore = True
else:
if 'dist' in kvstore.type:
update_on_kvstore = False
if update_on_kvstore:
kvstore.set_optimizer(self._optimizer)
# optimizer preferably needs to be set before init for multiprecision
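An end-to-end sketch of the trainer behavior described above (hedged illustration, not part of the PR): with a sparse-grad parameter present, update_on_kvstore is forced to True, and explicitly constructing the Trainer with update_on_kvstore=False would raise the RuntimeError added in this diff.

import mxnet as mx
from mxnet import gluon, autograd

net = gluon.nn.Embedding(10000, 50, sparse_grad=True)
net.initialize()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

x = mx.nd.array([3, 14, 15])
with autograd.record():
    loss = net(x).sum()
loss.backward()
trainer.step(batch_size=x.shape[0])   # the row_sparse gradient is handled on the kvstore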
2 changes: 1 addition & 1 deletion src/kvstore/kvstore_local.h
@@ -276,7 +276,7 @@ class KVStoreLocal : public KVStore {
// invalid, print warning messages once
if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) {
LOG(INFO) << "Warning: non-default weights detected during kvstore pull. "
"This call has been ignored. Please make sure to use"
"This call has been ignored. Please make sure to use "
"kv.row_sparse_pull() or module.prepare() with row_ids.";
this->warnings_printed_.insert(key);
}
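A brief sketch of the pull path this warning points to (hedged; kv.init and row_sparse_pull are existing KVStore APIs, the key name is made up for illustration): for a row_sparse value you pull only the rows you need with row_sparse_pull instead of a plain kv.pull, which is the call the warning above ignores.

import mxnet as mx

kv = mx.kv.create('local')
weight = mx.nd.zeros((10, 4)).tostype('row_sparse')
kv.init('embed_weight', weight)            # 'embed_weight' is just an illustrative key

out = mx.nd.zeros((10, 4)).tostype('row_sparse')
row_ids = mx.nd.array([0, 3, 7])
kv.row_sparse_pull('embed_weight', out=out, row_ids=row_ids)   # fetch only rows 0, 3 and 7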
4 changes: 2 additions & 2 deletions src/operator/tensor/init_op.h
@@ -383,8 +383,8 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
Stream<xpu> *s = ctx.get_stream<xpu>();
CHECK_EQ(outputs.size(), 1);
auto stype = outputs[0].storage_type();
if (req[0] == kNullOp) return;
CHECK_EQ(req[0], kWriteTo) << "kWriteTo is expected for FillComputeZerosEx";
// x + 0 == x
if (req[0] == kNullOp || req[0] == kAddTo) return;
if (stype == kRowSparseStorage) {
FillZerosRspImpl(s, outputs[0]);
} else if (stype == kCSRStorage) {
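A Python-level sketch of what this change backs (hedged): zeros_like with out= fills a row_sparse NDArray with zeros in place (the kWriteTo path used by the updated Parameter.zero_grad()), while an add-to request now returns early because adding zeros changes nothing.

import mxnet as mx

x = mx.nd.array([[1, 2], [0, 0], [3, 4]]).tostype('row_sparse')
mx.nd.zeros_like(x, out=x)   # in-place zero fill, storage type preserved
print(x.asnumpy())           # expected: all zeros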
50 changes: 41 additions & 9 deletions tests/python/unittest/test_gluon.py
@@ -38,6 +38,21 @@ def test_parameter():
assert p.data(mx.cpu(1)).context == mx.cpu(1)
assert p.data(mx.cpu(0)).shape == (10, 10)
assert p.var().name == 'weight'
assert p.grad(mx.cpu(0)).stype == 'default'

p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)])
assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)]

@with_seed()
def test_sparse_parameter():
p = gluon.Parameter('weight', shape=(10, 10), grad_stype='row_sparse')
p.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
assert len(p.list_data()) == 2
assert len(p.list_grad()) == 2
assert p.data(mx.cpu(1)).context == mx.cpu(1)
assert p.data(mx.cpu(0)).shape == (10, 10)
assert p.var().name == 'weight'
assert p.grad(mx.cpu(0)).stype == 'row_sparse'

p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)])
assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)]
@@ -676,15 +691,17 @@ def test_global_norm_clip():

@with_seed()
def test_embedding():
layer = gluon.nn.Embedding(10, 100)
layer.initialize()
x = mx.nd.array([3,4,2,0,1])
with mx.autograd.record():
y = layer(x)
y.backward()
assert (layer.weight.grad()[:5] == 1).asnumpy().all()
assert (layer.weight.grad()[5:] == 0).asnumpy().all()

def check_embedding(sparse_grad):
layer = gluon.nn.Embedding(10, 100, sparse_grad=sparse_grad)
layer.initialize()
x = mx.nd.array([3,4,2,0,1])
with mx.autograd.record():
y = layer(x)
y.backward()
assert (layer.weight.grad().asnumpy()[:5] == 1).all()
assert (layer.weight.grad().asnumpy()[5:] == 0).all()
check_embedding(True)
check_embedding(False)

@with_seed()
def test_export():
@@ -977,6 +994,7 @@ def test_req():
assert_almost_equal(grad * 2, grad_double)


@with_seed()
def test_save_load():
net = mx.gluon.model_zoo.vision.get_resnet(1, 18, pretrained=True)
net.save_params('test.params')
@@ -987,13 +1005,27 @@ def test_save_load():
net.load_params('test.params')


@with_seed()
def test_hybrid_multi_context():
net = mx.gluon.model_zoo.vision.get_resnet(1, 18)
net.initialize(ctx=[mx.cpu(0), mx.cpu(1)])
net.hybridize()
net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy()


@with_seed()
def test_zero_grad():
data = mx.nd.random.uniform(shape=(3,3))
net = nn.Embedding(3, 4, sparse_grad=True, prefix='test_zero_grad_')
net.initialize()
with mx.autograd.record():
l = net(data)
l.backward()
net.collect_params().zero_grad()
grad = net.collect_params()['test_zero_grad_weight'].grad()
assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)


if __name__ == '__main__':
import nose
nose.runmodule()