
Eliminate storage views. (pytorch#9466)
Summary:
Storage views were previously used to implement CUDA IPC sharing,
but they weren't necessary.  The new strategy is described in
Note [CUDA IPC and the caching allocator].

This also fixes an unrelated bug: the Tensor forking pickler was
not actually being used, because no pickler was registered for
torch.Tensor.

Fixes pytorch#9447.  Fixes pytorch#46.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

CC apaszke
Pull Request resolved: pytorch#9466

Reviewed By: apaszke

Differential Revision: D8859698

Pulled By: ezyang

fbshipit-source-id: 3362cb92f6ae4aa37084c57d79b31004bd0b4a97
ezyang authored and jramseyer committed Jul 30, 2018
1 parent fdd8460 commit f5a87e8
Showing 12 changed files with 139 additions and 193 deletions.
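
The summary above notes that the Tensor forking pickler was never actually used because no pickler was registered for torch.Tensor. The sketch below shows what such a registration typically looks like; the helper names (reduce_tensor, rebuild_tensor) and the simplified reduction are assumptions modeled on torch.multiprocessing's reductions module, not code taken from this commit.

```python
import torch
from multiprocessing.reduction import ForkingPickler

def rebuild_tensor(cls, storage, metadata):
    storage_offset, size, stride = metadata
    t = cls()
    return t.set_(storage, storage_offset, size, stride)

def reduce_tensor(tensor):
    # The real reduction shares the underlying storage (file descriptors on
    # CPU, CUDA IPC handles on GPU); this placeholder simply passes the
    # storage along and only handles the default tensor type.
    metadata = (tensor.storage_offset(), tuple(tensor.size()), tensor.stride())
    return (rebuild_tensor, (type(tensor), tensor.storage(), metadata))

# The missing piece described in the summary: without this registration,
# torch.Tensor instances fall back to the default pickler and are copied
# rather than shared.
ForkingPickler.register(torch.Tensor, reduce_tensor)
```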
4 changes: 0 additions & 4 deletions aten/src/TH/THStorage.cpp
@@ -27,9 +27,6 @@ void THStorage_free(THStorage *storage) {
}
storage->finalizer.~unique_ptr<THFinalizer>();
storage->data_ptr.~DataPtr();
if (storage->flag & TH_STORAGE_VIEW) {
THStorage_free(storage->view);
}
THStorage_weakFree(storage);
}
}
@@ -227,6 +224,5 @@ void THStorage_swap(THStorage *storage1, THStorage *storage2)
SWAP(flag);
SWAP(allocator);
SWAP(finalizer);
SWAP(view);
#undef SWAP
}
1 change: 0 additions & 1 deletion aten/src/TH/THStorage.hpp
@@ -47,7 +47,6 @@ typedef struct THStorage
char flag;
at::Allocator *allocator;
std::unique_ptr<THFinalizer> finalizer;
struct THStorage *view;

template <typename T>
inline T * data() const {
1 change: 0 additions & 1 deletion aten/src/TH/generic/THStorage.h
@@ -22,7 +22,6 @@

#define TH_STORAGE_REFCOUNTED 1
#define TH_STORAGE_RESIZABLE 2
#define TH_STORAGE_VIEW 8

// Struct definition is moved to THStorage.hpp (so this file stays C compatible)
typedef struct THStorage THStorage;
3 changes: 0 additions & 3 deletions aten/src/THC/THCStorage.cpp
@@ -55,9 +55,6 @@ void THCStorage_free(THCState *state, THCStorage *storage)
}
storage->finalizer.~unique_ptr<THFinalizer>();
storage->data_ptr.~DataPtr();
if (storage->flag & TH_STORAGE_VIEW) {
THCStorage_free(state, storage->view);
}
THStorage_weakFree(storage);
}
}
15 changes: 9 additions & 6 deletions test/test_multiprocessing.py
@@ -207,18 +207,15 @@ def test_receive():
def _test_preserve_sharing(self, ctx=mp, repeat=1):
def do_test():
x = torch.randn(5, 5)
data = [x.storage(), x.storage()[1:4], x, x[2], x[:, 1]]
data = [x.storage(), x, x[2], x[:, 1]]
q = ctx.Queue()
q.put(data)
new_data = q.get(timeout=1)
self.assertEqual(new_data, data, 0)
storage_cdata = data[0]._cdata
self.assertEqual(new_data[0]._cdata, storage_cdata)
for t in new_data[2:]:
for t in new_data[1:]:
self.assertEqual(t.storage()._cdata, storage_cdata)
# TODO: enable after fixing #46
# new_data[0].fill_(10)
# self.assertEqual(new_data[1], new_data[0][1:4], 0)

with leak_checker(self):
for i in range(repeat):
@@ -335,7 +332,13 @@ def test_cuda_small_tensors(self):
self.assertEqual(v, torch.arange(i * 5., (i + 1) * 5).sum())
self.assertEqual(device, i % 2)
self.assertEqual(tensor_size, 5)
self.assertEqual(storage_size, 5)
# You might think this should be the case, but it's not! After
# data from the CUDA caching allocator goes through IPC, the
# size of the storage is the size of the *cached cudaMalloc for
# the entire memory block* of the storage, not just the storage.
# See Note [CUDA IPC and the caching allocator] for more info
#
# self.assertEqual(storage_size, 5)

@unittest.skipIf(IS_WINDOWS, 'not applicable to Windows (only fails with fork)')
@unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
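
The comment added in the test above (explaining why storage_size is no longer asserted to be 5) describes behavior that can be observed directly. Below is an illustrative sketch, not part of this commit, assuming a CUDA-enabled build with at least one GPU; the exact storage size seen in the child depends on the caching allocator's block size for the allocation.

```python
import torch
import torch.multiprocessing as mp

def child(q):
    t = q.get()                    # arrives in this process via CUDA IPC
    print(t.numel())               # 5: the tensor itself is unchanged
    print(t.storage().size())      # can be larger than 5: the received storage
                                   # spans the whole cached cudaMalloc block,
                                   # not just the tensor's elements

if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # CUDA tensors cannot be shared after fork
    q = ctx.Queue()
    p = ctx.Process(target=child, args=(q,))
    p.start()
    x = torch.arange(5., device="cuda")
    q.put(x)
    p.join()                       # keep x alive until the child is done with it
```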
35 changes: 13 additions & 22 deletions test/test_torch.py
@@ -6456,17 +6456,6 @@ def test_storage(self):
self.assertEqual(v.storage()[0], v.data[0][0])
self.assertEqual(v.storage()[14], v.data[2][4])

def test_storageview(self):
s1 = torch.LongStorage((3, 4, 5))
s2 = torch.LongStorage(s1, 1)

self.assertEqual(s2.size(), 2)
self.assertEqual(s2[0], s1[1])
self.assertEqual(s2[1], s1[2])

s2[1] = 13
self.assertEqual(13, s1[2])

def test_nonzero(self):
num_src = 12

@@ -6732,14 +6721,14 @@ def test_parsing_intlist(self):

def _test_serialization_data(self):
a = [torch.randn(5, 5).float() for i in range(2)]
b = [a[i % 2] for i in range(4)]
b += [a[0].storage()]
b += [a[0].storage()[1:4]]
b += [torch.arange(1, 11).int()]
t1 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
t2 = torch.FloatTensor().set_(a[0].storage()[1:4], 0, (3,), (1,))
b += [(t1.storage(), t1.storage(), t2.storage())]
b += [a[0].storage()[0:2]]
b = [a[i % 2] for i in range(4)] # 0-3
b += [a[0].storage()] # 4
b += [a[0].reshape(-1)[1:4].storage()] # 5
b += [torch.arange(1, 11).int()] # 6
t1 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().storage(), 0, (3,), (1,))
t2 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().storage(), 0, (3,), (1,))
b += [(t1.storage(), t1.storage(), t2.storage())] # 7
b += [a[0].reshape(-1)[0:2].storage()] # 8
return b

def _test_serialization_assert(self, b, c):
@@ -6754,7 +6743,10 @@ def _test_serialization_assert(self, b, c):
self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
c[1].fill_(20)
self.assertEqual(c[1], c[3], 0)
self.assertEqual(c[4][1:4], c[5], 0)
# I have to do it in this roundabout fashion, because there's no
# way to slice storages
for i in range(4):
self.assertEqual(c[4][i + 1], c[5][i])

# check that serializing the same storage view object unpickles
# it as one object not two (and vice versa)
@@ -6914,7 +6906,7 @@ def test_serialization_backwards_compat(self):
a = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)]
b = [a[i % 2] for i in range(4)]
b += [a[0].storage()]
b += [a[0].storage()[1:4]]
b += [a[0].reshape(-1)[1:4].clone().storage()]
path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt')
c = torch.load(path)
self.assertEqual(b, c, 0)
@@ -6928,7 +6920,6 @@ def test_serialization_backwards_compat(self):
self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0)
c[1].fill_(20)
self.assertEqual(c[1], c[3], 0)
self.assertEqual(c[4][1:4], c[5], 0)

# test some old tensor serialization mechanism
class OldTensorBase(object):
65 changes: 2 additions & 63 deletions torch/csrc/generic/Storage.cpp
@@ -88,44 +88,8 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec

// torch.Storage(view_source, [offset, [size]])
if (num_args < 4 && THPStorage_(Check)(first_arg)) {
#ifdef THD_GENERIC_FILE
THPUtils_setError("distributed storages don't support storage views");
THPUtils_setError("storage views not supported");
return NULL;
#else
THPStorage *storage_arg = (THPStorage *)first_arg;
int64_t numel = storage_arg->cdata->size;
int64_t offset = 0;

if (num_args >= 2) {
PyObject *second_arg = PyTuple_GET_ITEM(args, 1);
if (!THPUtils_checkLong(second_arg))
goto invalid_arguments;
offset = THPUtils_unpackLong(second_arg);
}

int64_t size = numel - offset;
if (num_args >= 3) {
PyObject *third_arg = PyTuple_GET_ITEM(args, 2);
if (!THPUtils_checkLong(third_arg))
goto invalid_arguments;
size = THPUtils_unpackLong(third_arg);
}

THPUtils_assert(offset >= 0 && offset <= numel, "specified an offset of "
"%" PRId64 ", but the viewed storage has only %" PRId64 " element(s)", offset, numel);
THPUtils_assert(size >= 1 && size <= numel - offset, "specified a size of "
"%" PRId64 ", but the viewed storage has only %" PRId64 " element(s) after offset %" PRId64,
size, numel - offset, offset);

real *data_ptr = THWStorage_(data)(LIBRARY_STATE storage_arg->cdata) + offset;
// TODO: Hmmmm
THWStoragePtr storage(THWStorage_(newWithDataAndAllocator)(LIBRARY_STATE {data_ptr, storage_arg->cdata->data_ptr.device()} /* non-owning */, size, nullptr));
storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW;
storage->view = storage_arg->cdata;
THWStorage_(retain)(LIBRARY_STATE storage_arg->cdata);
self->cdata = storage.release();
return (PyObject*)self.release();
#endif
}

// torch.Storage(sequence)
@@ -161,9 +125,6 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
#endif
}

#ifndef THD_GENERIC_FILE
invalid_arguments:
#endif
THPUtils_invalidArguments(args, kwargs, THPStorageStr " constructor", 6,
"no arguments",
"(int size)",
@@ -199,30 +160,8 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
return THPUtils_(newReal)(value);
/* Slice index */
} else if (PySlice_Check(index)) {
#ifdef THD_GENERIC_FILE
THPUtils_setError("distributed storages don't support slicing");
THPUtils_setError("storages don't support slicing");
return NULL;
#else
Py_ssize_t start, stop, slicelength, step;
int64_t len = THWStorage_(size)(LIBRARY_STATE self->cdata);
if (!THPUtils_parseSlice(index, len, &start, &stop, &step, &slicelength))
return NULL;
if (step != 1) {
THPUtils_setError("Trying to slice with a step of %" PRId64 ", but only a step of "
"1 is supported", (int64_t)step);
return NULL;
}

real *data = THWStorage_(data)(LIBRARY_STATE self->cdata);
THWStoragePtr new_storage(THWStorage_(newWithDataAndAllocator)(LIBRARY_STATE {static_cast<void*>(data + start), self->cdata->data_ptr.device()} /* non-owning */, slicelength, nullptr));
new_storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_VIEW;
new_storage->view = self->cdata;
THWStorage_(retain)(LIBRARY_STATE self->cdata);

PyObject *_ret = THPStorage_(New)(new_storage);
new_storage.release();
return _ret;
#endif
}
PyErr_Format(PyExc_TypeError, "can't index a " THPStorageStr " with %s",
THPUtils_typename(index));
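
With the view constructor and the slicing path in Storage.cpp replaced by errors, both former storage-view entry points now raise instead of returning a view. A quick sketch of the user-visible effect (the exception type is an assumption; the messages are the ones added in this file):

```python
import torch

s = torch.FloatStorage(5)

try:
    torch.FloatStorage(s, 1)   # old torch.Storage(view_source, offset) form
except Exception as e:
    print(e)                   # "storage views not supported"

try:
    s[1:4]                     # slicing used to return a storage view
except Exception as e:
    print(e)                   # "storages don't support slicing"
```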
21 changes: 0 additions & 21 deletions torch/csrc/generic/StorageMethods.cpp
@@ -292,26 +292,6 @@ PyObject * THPStorage_(_setCdata)(THPStorage *self, PyObject *new_cdata)
END_HANDLE_TH_ERRORS
}

#ifndef THD_GENERIC_FILE
PyObject * THPStorage_(_rootStorage)(THPStorage *self)
{
HANDLE_TH_ERRORS
if (!(self->cdata->flag & TH_STORAGE_VIEW)) {
return Py_BuildValue("(ON)", self, PyLong_FromLong(0));
}
THWStorage *root = self->cdata;
while (root->flag & TH_STORAGE_VIEW)
root = root->view;
size_t offset = THWStorage_(data)(LIBRARY_STATE self->cdata) - THWStorage_(data)(LIBRARY_STATE root);
THWStorage_(retain)(LIBRARY_STATE root);
THPObjectPtr storage(THPStorage_(New)(root));
PyObject *result = Py_BuildValue("(NN)", storage.get(), PyLong_FromLong(offset));
storage.release();
return result;
END_HANDLE_TH_ERRORS
}
#endif

static PyMethodDef THPStorage_(methods)[] = {
{"copy_", (PyCFunction)THPStorage_(copy_), METH_VARARGS | METH_KEYWORDS, NULL},
{"element_size", (PyCFunction)THPStorage_(elementSize), METH_NOARGS, NULL},
@@ -335,7 +315,6 @@ static PyMethodDef THPStorage_(methods)[] = {
#endif
{"_set_cdata", (PyCFunction)THPStorage_(_setCdata), METH_O, NULL},
#ifndef THD_GENERIC_FILE
{"_root_storage", (PyCFunction)THPStorage_(_rootStorage), METH_NOARGS, NULL},
#endif
{NULL}
};