Ensure noncontiguous tensor creation tests offsetting (pytorch#136396)
Pull Request resolved: pytorch#136396
Approved by: https://github.com/amjames, https://github.com/eellison
ghstack dependencies: pytorch#136055
benjaminglass1 authored and pytorchmergebot committed Oct 2, 2024
1 parent c7638da commit f984b88
Showing 5 changed files with 100 additions and 42 deletions.
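The property under test, in brief: a tensor view can begin partway into its backing storage, and an operator that indexes the raw storage from position 0 instead of from storage_offset() silently reads the wrong data. A minimal sketch of that property (illustrative names, not code from this patch):

import torch

# A view can begin partway into its backing storage; code that ignores
# storage_offset() reads the wrong elements.
base = torch.arange(6.0)            # backing storage: [0, 1, 2, 3, 4, 5]
view = base[2:5]                    # shape (3,), values [2., 3., 4.]

print(view.storage_offset())        # 2
print(torch.equal(view, base[:3]))  # False: same storage, different offset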
test/test_ops.py: 4 changes (2 additions, 2 deletions)
@@ -2058,7 +2058,7 @@ def check_inplace_view(func, input, rs, input_size, input_strides):
 # A mode that when enabled runs correctness checks to ensure
 # that operators have expected tags based on their input and
 # output tensor properties
-class TestTagsMode(TorchDispatchMode):
+class _TestTagsMode(TorchDispatchMode):
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         if isinstance(args[0], torch.Tensor):
             old_size = args[0].size()
@@ -2083,7 +2083,7 @@ def test_tags(self, device, dtype, op):
             if isinstance(input, torch.Tensor):
                 old_size = input.size()
                 old_stride = input.stride()
-                with TestTagsMode():
+                with _TestTagsMode():
                     rs = op(input, *sample.args, **sample.kwargs)
                 # TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761
                 aten_name = op.aten_name if op.aten_name is not None else op.name
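For context on the renamed class: a TorchDispatchMode intercepts every ATen operator call made while the mode is active, which is how the tags test can inspect tensor metadata around each op. A stripped-down sketch of the mechanism, not the actual test logic:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

# Minimal analogue of _TestTagsMode: observe each dispatched op and the
# metadata of its first Tensor argument, then run the op unchanged.
class PrintShapesMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if args and isinstance(args[0], torch.Tensor):
            print(f"{func}: size={tuple(args[0].size())} stride={args[0].stride()}")
        return func(*args, **kwargs)

with PrintShapesMode():
    torch.ones(2, 3).sum()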
test/test_reductions.py: 100 changes (73 additions, 27 deletions)
@@ -3093,26 +3093,40 @@ def test_histc_lowp(self, device, dtype):
     Runs torch.histogram and numpy.histogram on the specified input parameters
     and asserts that their output is equal.
     """
-    def _test_histogram_numpy(self, t, bins, bin_range, weights, density):
+    def _test_histogram_numpy(self, t, bins, bin_range, weights, density, eq_func=None):
         def to_np(t):
             if not torch.is_tensor(t):
                 return t
-            else:
-                return t.cpu().numpy()
+            return t.cpu().numpy()
 
         # Wrapper around numpy.histogram performing conversions between torch tensors and numpy arrays.
-        def reference_histogram(self, t, bins, bin_range, weights, density, dtype):
-            (np_t, np_bins, np_weights) = map(to_np, [t, bins, weights])
-            (np_hist, np_bin_edges) = np.histogram(np_t, np_bins, range=bin_range, weights=np_weights, density=density)
-            return (torch.from_numpy(np_hist).to(dtype), torch.from_numpy(np_bin_edges).to(dtype))
+        def reference_histogram(t, bins, bin_range, weights, density, dtype):
+            np_t, np_bins, np_weights = map(to_np, [t, bins, weights])
+            np_hist, np_bin_edges = np.histogram(
+                np_t, np_bins, range=bin_range, weights=np_weights, density=density
+            )
+            return (
+                torch.from_numpy(np_hist).to(dtype),
+                torch.from_numpy(np_bin_edges).to(dtype),
+            )
 
-        # Doesn't pass a 'range' kwarg unless necessary because the override of histogram with Tensor bins doesn't accept one
+        if eq_func is None:
+            eq_func = self.assertEqual
+
+        # Doesn't pass a 'range' kwarg unless necessary because the override of
+        # histogram with Tensor bins doesn't accept one.
         if bin_range:
-            (actual_hist, actual_bin_edges) = torch.histogram(t, bins, range=bin_range, weight=weights, density=density)
+            actual_hist, actual_bin_edges = torch.histogram(
+                t, bins, range=bin_range, weight=weights, density=density
+            )
         else:
-            (actual_hist, actual_bin_edges) = torch.histogram(t, bins, weight=weights, density=density)
+            actual_hist, actual_bin_edges = torch.histogram(
+                t, bins, weight=weights, density=density
+            )
 
-        (expected_hist, expected_bin_edges) = reference_histogram(self, t, bins, bin_range, weights, density, actual_hist.dtype)
+        expected_hist, expected_bin_edges = reference_histogram(
+            t, bins, bin_range, weights, density, actual_hist.dtype
+        )
 
"""
Works around linspace discrepancies by passing torch's constructed bin_edges to numpy.
Expand All @@ -3122,28 +3136,48 @@ def reference_histogram(self, t, bins, bin_range, weights, density, dtype):
Issue: https://github.com/pytorch/pytorch/issues/58758
"""
if not torch.is_tensor(bins):
self.assertEqual(actual_bin_edges, expected_bin_edges, atol=1e-5, rtol=1e-5)
# Calls numpy.histogram again, passing torch's actual_bin_edges as the bins argument
(expected_hist, expected_bin_edges) = reference_histogram(
self, t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype)
eq_func(actual_bin_edges, expected_bin_edges, atol=1e-5, rtol=1e-5)
# Calls numpy.histogram again, passing torch's actual_bin_edges as the bins
# argument.
expected_hist, expected_bin_edges = reference_histogram(
t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype,
)

self.assertEqual(actual_hist, expected_hist)
self.assertEqual(actual_bin_edges, expected_bin_edges)
eq_func(actual_hist, expected_hist)
eq_func(actual_bin_edges, expected_bin_edges)

# Test passing non-contiguous output tensors
hist_out = make_tensor(expected_hist.shape, device=expected_hist.device, dtype=expected_hist.dtype,
noncontiguous=True)
bin_edges_out = make_tensor(expected_bin_edges.shape, device=expected_bin_edges.device, dtype=expected_bin_edges.dtype,
noncontiguous=True)
hist_out = make_tensor(
expected_hist.shape,
device=expected_hist.device,
dtype=expected_hist.dtype,
noncontiguous=True,
)
bin_edges_out = make_tensor(
expected_bin_edges.shape,
device=expected_bin_edges.device,
dtype=expected_bin_edges.dtype,
noncontiguous=True,
)

# Doesn't pass a 'range' kwarg unless necessary because the override of histogram with Tensor bins doesn't accept one
# Doesn't pass a 'range' kwarg unless necessary because the override of
# histogram with Tensor bins doesn't accept one.
if bin_range:
torch.histogram(t, bins, range=bin_range, weight=weights, density=density, out=(hist_out, bin_edges_out))
torch.histogram(
t,
bins,
range=bin_range,
weight=weights,
density=density,
out=(hist_out, bin_edges_out),
)
else:
torch.histogram(t, bins, weight=weights, density=density, out=(hist_out, bin_edges_out))
torch.histogram(
t, bins, weight=weights, density=density, out=(hist_out, bin_edges_out)
)

self.assertEqual(hist_out, expected_hist)
self.assertEqual(bin_edges_out, expected_bin_edges)
eq_func(hist_out, expected_hist)
eq_func(bin_edges_out, expected_bin_edges)

@onlyCPU
@dtypes(torch.float32)
@@ -3171,7 +3205,19 @@ def test_histogram(self, device, dtype):
 
             # Tests with range min=max
             bin_range[1] = bin_range[0]
-            self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)
+            self._test_histogram_numpy(
+                values,
+                bin_ct,
+                bin_range,
+                weights,
+                density,
+                # TODO: investigate why torch.histogram differs from numpy.histogram
+                # so strongly on this particular test. There seem to be more
+                # differences here than the linspace issue, which is itself fairly
+                # easily patched around. Likely, the other tests also differ
+                # significantly, but below the default threshold for assertEqual.
+                eq_func=partial(self.assertEqual, rtol=3e-5, atol=0.0),
+            )
 
             # Tests with caller-specified bin edges
             bin_edges = make_tensor(bin_ct + 1, dtype=dtype, device=device, low=-9, high=9).msort()
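The looser eq_func exists because torch.histogram and numpy.histogram agree only to within floating-point noise on some inputs. A sketch of the comparison the test performs, with torch.testing.assert_close standing in for the suite's assertEqual (the tolerances are from the patch; everything else is illustrative):

import numpy as np
import torch

t = torch.rand(1000, dtype=torch.float32)
hist, edges = torch.histogram(t, bins=10, range=(0.0, 1.0))
np_hist, np_edges = np.histogram(t.numpy(), bins=10, range=(0.0, 1.0))

# Counts should match exactly; bin edges may drift slightly via linspace.
torch.testing.assert_close(
    hist, torch.from_numpy(np_hist).to(hist.dtype), rtol=3e-5, atol=0.0
)
torch.testing.assert_close(
    edges, torch.from_numpy(np_edges).to(edges.dtype), rtol=3e-5, atol=0.0
)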
torch/testing/_creation.py: 13 changes (10 additions, 3 deletions)
@@ -3,6 +3,7 @@
 """
 
 import collections.abc
+import functools
 import math
 import warnings
 from typing import cast, List, Optional, Tuple, Union
@@ -189,6 +190,12 @@ def clamp(a: float, l: float, h: float) -> float:
             f"`requires_grad=True` is not supported for boolean and integral dtypes, but got {dtype=}"
         )
 
+    noncontiguous = noncontiguous and functools.reduce(lambda x, y: x * y, shape, 1) > 1
+    if noncontiguous:
+        # Double the size of the shape in the last dimension, so that we have
+        # non-identical values when we make the non-contiguous operation.
+        shape = cast(Tuple[int, ...], (*shape[:-1], 2 * shape[-1]))
+
     if dtype is torch.bool:
         low, high = cast(
             Tuple[int, int],
Expand Down Expand Up @@ -252,9 +259,9 @@ def clamp(a: float, l: float, h: float) -> float:
" To request support, file an issue at: https://github.com/pytorch/pytorch/issues"
)

if noncontiguous and result.numel() > 1:
result = torch.repeat_interleave(result, 2, dim=-1)
result = result[..., ::2]
if noncontiguous:
# Offset by 1 to also catch offsetting issues
result = result[..., 1::2]
elif memory_format is not None:
result = result.clone(memory_format=memory_format)

Expand Down
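The net effect of the two _creation.py hunks, side by side (illustrative names, not from the patch): the old construction produced a non-contiguous result whose view started at offset 0, so operators that mishandled storage_offset() still passed; the new construction slices a double-width buffer starting at element 1, exercising both properties at once.

import torch

base = torch.arange(8.0)

old_style = torch.repeat_interleave(base[:4], 2, dim=-1)[..., ::2]
new_style = base[..., 1::2]        # slice of a double-width buffer

print(old_style.is_contiguous(), old_style.storage_offset())  # False 0
print(new_style.is_contiguous(), new_style.storage_offset())  # False 1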
torch/testing/_internal/common_methods_invocations.py: 6 changes (0 additions, 6 deletions)
@@ -14646,12 +14646,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
         skips=(
             # Note: This xfail is fine -- it's inherent to how as_strided works
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
-            # RuntimeError: This operator is not Composite Compliant: the
-            # storage_offset of the tensor was modified directly without
-            # going through the PyTorch dispatcher.
-            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
-            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'),
-
             # These fail because the test changes the input's in-memory layout
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_complex_half_reference_testing'),
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
torch/testing/_internal/composite_compliance.py: 19 changes (15 additions, 4 deletions)
@@ -136,10 +136,21 @@ def __new__(cls, elem, mode, *args, **kwargs):
         if elem.requires_grad:
             # CompositeCompliantTensor steals the "requires_grad"-ness.
             # Why a new copy of `elem`? Because sometimes OpInfo shares inputs between tests...
-            tmp = torch.empty_strided(elem.shape, elem.stride(), dtype=elem.dtype,
-                                      device=elem.device, layout=elem.layout,
-                                      requires_grad=False)
-            tmp.copy_(elem.detach())
+            tmp = torch.empty(
+                (),
+                dtype=elem.dtype,
+                device=elem.device,
+                layout=elem.layout,
+                requires_grad=False,
+            )
+            # Use set_ rather than empty_strided() + copy_ so that we can preserve
+            # things like storage_offset.
+            tmp.set_(
+                source=elem.untyped_storage().clone(),
+                storage_offset=elem.storage_offset(),
+                size=elem.size(),
+                stride=elem.stride(),
+            )
             r.elem = tmp
         else:
             r.elem = elem
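Why set_ instead of empty_strided() + copy_, concretely: the copy reproduces the values but rebuilds the tensor at storage offset 0, whereas set_ on a cloned storage reproduces the original view geometry exactly. A sketch with illustrative identifiers:

import torch

elem = torch.arange(10.0)[3:9:2]   # view: size (3,), stride (2,), offset 3

# Old approach: values match, but the offset is reset to zero.
old = torch.empty_strided(elem.shape, elem.stride(), dtype=elem.dtype)
old.copy_(elem.detach())
print(old.storage_offset())        # 0

# New approach: clone the backing storage and point a fresh tensor at it
# with the original offset/size/stride.
new = torch.empty((), dtype=elem.dtype)
new.set_(
    source=elem.untyped_storage().clone(),
    storage_offset=elem.storage_offset(),
    size=elem.size(),
    stride=elem.stride(),
)
print(new.storage_offset())        # 3
print(torch.equal(new, elem))      # True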
