Fix NJT serialization (pytorch#137031)
Fixes pytorch#129366

Since NJT has custom serialization logic, we need an NJT-specific fix to clear out cached sizes / strides PyCapsules. Eventually, we should switch NJT to use the default serialization logic, but this depends on pytorch#125622 being addressed.

This PR also makes serialization more complete by explicitly handling `lengths`, `ragged_idx`, and the `metadata_cache`, so that both contiguous and non-contiguous NJTs round-trip correctly.
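For illustration, a minimal round-trip sketch (it mirrors the new test added below and uses only APIs that already appear in the diff; the buffer and variable names are illustrative):

```python
import io

import torch

# Non-contiguous NJT with holes: lengths are smaller than the offset gaps.
values = torch.randn(10, 3)
offsets = torch.tensor([0, 3, 7, 10], dtype=torch.int64)
lengths = torch.tensor([1, 2, 3], dtype=torch.int64)
nt = torch.nested.nested_tensor_from_jagged(
    values=values, offsets=offsets, lengths=lengths
)

# Populate the size / stride caches; previously the cached PyCapsules made the
# tensor unserializable (pytorch#129366).
nt.size()
nt.stride()

buffer = io.BytesIO()
torch.save(nt, buffer)
buffer.seek(0)
nt_loaded = torch.load(buffer)

# lengths and ragged_idx now survive the round trip.
assert torch.equal(nt._lengths, nt_loaded._lengths)
assert nt._ragged_idx == nt_loaded._ragged_idx
```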
Pull Request resolved: pytorch#137031
Approved by: https://github.com/soulitzer
ghstack dependencies: pytorch#137030
jbschlosser authored and pytorchmergebot committed Oct 2, 2024
1 parent be423a8 commit e95b230
Showing 2 changed files with 110 additions and 28 deletions.
115 changes: 90 additions & 25 deletions test/test_nestedtensor.py
@@ -5,6 +5,7 @@
import itertools
import math
import sys
import tempfile
import unittest
from contextlib import nullcontext
from functools import partial
@@ -898,6 +899,31 @@ def test_detach(self, device, dtype):
self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))

@dtypes(torch.float, torch.double, torch.half)
@parametrize("requires_grad", [False, True])
@parametrize("weights_only", [False, True])
def test_serialization(self, device, dtype, requires_grad, weights_only):
def compare_metadata(nt1, nt2):
self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size())
self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides())
self.assertEqual(
nt1._nested_tensor_storage_offsets(),
nt2._nested_tensor_storage_offsets(),
)

nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
for a in [nt_contiguous, nt_noncontiguous]:
buffer = io.BytesIO()
serialized = torch.save(a, buffer)
buffer.seek(0)
b = torch.load(buffer, weights_only=weights_only)
# should be both conceptually equal and metadata equivalent
self.assertEqual(a, b)
compare_metadata(a, b)
# should be conceptually equal but not necessarily metadata equivalent
self.assertEqual(b, nt_contiguous)
self.assertEqual(b, nt_noncontiguous)

@dtypes(torch.float, torch.float16, torch.double)
def test_unbind_noncontiguous(self, device, dtype):
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
@@ -3621,6 +3647,70 @@ def _make_tensor(

return example_lists

@dtypes(torch.float32)
@parametrize(
"contiguity",
["contig", "noncontig_transposed", "noncontig_with_holes"],
name_fn=lambda c: c,
)
def test_serialization(self, device, dtype, contiguity):
# Test with 3 cases:
# 1. contiguous
# 2. non-contiguous transposed
# 3. non-contiguous with holes
if contiguity == "contig":
nt = random_nt_from_dims(
[4, None, 10],
device=device,
dtype=dtype,
layout=torch.jagged,
)
elif contiguity == "noncontig_transposed":
nt = random_nt_from_dims(
[3, None, 5, 2],
device=device,
dtype=dtype,
layout=torch.jagged,
).transpose(-3, -2)
elif contiguity == "noncontig_with_holes":
nt = torch.nested.nested_tensor_from_jagged(
values=torch.randn(10, 3, device=device, dtype=dtype),
offsets=torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int64),
# these lengths specify holes
lengths=torch.tensor([1, 2, 3], device=device, dtype=torch.int64),
)
else:
raise ValueError("invalid contiguity specified for test_serialization()")

# Access sizes / strides to ensure cache doesn't break serialization.
# See https://github.com/pytorch/pytorch/issues/129366
nt.size()
nt.stride()

with tempfile.TemporaryFile() as f:
torch.save(nt, f)
f.seek(0)
nt_loaded = torch.load(f)

self.assertIsNot(nt, nt_loaded)
# we expect a new offsets tensor -> different nested int upon load
self.assertEqualIgnoringNestedInts(nt, nt_loaded)
self.assertEqual(nt._ragged_idx, nt_loaded._ragged_idx)
# ensure shapes are equal except nested int
nt_rest_of_shape = (
*nt.shape[: nt._ragged_idx],
*nt.shape[nt._ragged_idx + 1 :],
)
nt_loaded_rest_of_shape = (
*nt_loaded.shape[: nt_loaded._ragged_idx],
*nt_loaded.shape[nt_loaded._ragged_idx + 1 :],
)
self.assertEqual(nt_rest_of_shape, nt_loaded_rest_of_shape)
# ensure metadata cache is carried through serialization
self.assertEqual(nt._metadata_cache, nt_loaded._metadata_cache)
# ensure lengths are carried through if present
self.assertEqual(nt._lengths, nt_loaded._lengths)

def test_tensor_attributes(self, device):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
@@ -5150,31 +5240,6 @@ def test_mean_dim_keepdim_False(
else:
out = torch.mean(nt, dim=reduce_dim, keepdim=keepdim)

@dtypes(torch.float, torch.double, torch.half)
@parametrize("requires_grad", [False, True])
@parametrize("weights_only", [False, True])
def test_serialization(self, device, dtype, requires_grad, weights_only):
def compare_metadata(nt1, nt2):
self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size())
self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides())
self.assertEqual(
nt1._nested_tensor_storage_offsets(),
nt2._nested_tensor_storage_offsets(),
)

nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
for a in [nt_contiguous, nt_noncontiguous]:
buffer = io.BytesIO()
serialized = torch.save(a, buffer)
buffer.seek(0)
b = torch.load(buffer, weights_only=weights_only)
# should be both conceptually equal and metadata equivalent
self.assertEqual(a, b)
compare_metadata(a, b)
# should be conceptually equal but not necessarily metadata equivalent
self.assertEqual(b, nt_contiguous)
self.assertEqual(b, nt_noncontiguous)

@unittest.skipIf(
PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property"
)
23 changes: 20 additions & 3 deletions torch/nested/_internal/nested_tensor.py
@@ -45,6 +45,11 @@ def _load_val_from_tensor(t: torch.Tensor):
return t.shape[0]


# serialization function must be defined at top level
def _rebuild_njt(constructor_kwargs):
return NestedTensor(**constructor_kwargs)


class NestedTensor(torch.Tensor):
_values: torch.Tensor # type: ignore[assignment]
_offsets: torch.Tensor
@@ -218,18 +223,30 @@ def __repr__(self):  # type: ignore[override]
grad_fn_str = f", grad_fn={self.grad_fn}"
return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str}, contiguous={self._lengths is None})"

# TODO: Remove this in favor of the default tensor subclass serialization logic.
# We don't do this today because of https://github.com/pytorch/pytorch/issues/125622.
def __reduce_ex__(self, proto):
state = torch._utils._get_obj_state(self)

# Cached PyCapsules for sizes / strides are not serializable.
# See Note [Tensor Subclass custom size/stride caching strategy]
self._clear_non_serializable_cached_data()
# SymNodes are not serializable
assert "_size" in state and "_strides" in state
state = dict(state)
del state["_size"]
del state["_strides"]

# TODO: Update this to handle the other inner tensors
func = NestedTensor
args = (self._values, self._offsets)
func = _rebuild_njt
constructor_kwargs = {
"values": self._values,
"offsets": self._offsets,
"lengths": self._lengths,
"_ragged_idx": self._ragged_idx,
"_metadata_cache": self._metadata_cache,
"requires_grad": self.requires_grad,
}
args = (constructor_kwargs,)
return (torch._tensor._rebuild_from_type_v2, (func, type(self), args, state))

def __tensor_flatten__(self):
