diff --git a/test/test_ops.py b/test/test_ops.py index 0e4cfcabc10c5..3f1e59727f4ca 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2058,7 +2058,7 @@ def check_inplace_view(func, input, rs, input_size, input_strides): # A mode that when enabled runs correctness checks to ensure # that operators have expected tags based on their input and # output tensor properties -class TestTagsMode(TorchDispatchMode): +class _TestTagsMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): if isinstance(args[0], torch.Tensor): old_size = args[0].size() @@ -2083,7 +2083,7 @@ def test_tags(self, device, dtype, op): if isinstance(input, torch.Tensor): old_size = input.size() old_stride = input.stride() - with TestTagsMode(): + with _TestTagsMode(): rs = op(input, *sample.args, **sample.kwargs) # TODO: add test for aliases: https://github.com/pytorch/pytorch/issues/78761 aten_name = op.aten_name if op.aten_name is not None else op.name diff --git a/test/test_reductions.py b/test/test_reductions.py index 8191897bf4cc4..323866c80153c 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -3093,26 +3093,40 @@ def test_histc_lowp(self, device, dtype): Runs torch.histogram and numpy.histogram on the specified input parameters and asserts that their output is equal. """ - def _test_histogram_numpy(self, t, bins, bin_range, weights, density): + def _test_histogram_numpy(self, t, bins, bin_range, weights, density, eq_func=None): def to_np(t): if not torch.is_tensor(t): return t - else: - return t.cpu().numpy() + return t.cpu().numpy() # Wrapper around numpy.histogram performing conversions between torch tensors and numpy arrays. - def reference_histogram(self, t, bins, bin_range, weights, density, dtype): - (np_t, np_bins, np_weights) = map(to_np, [t, bins, weights]) - (np_hist, np_bin_edges) = np.histogram(np_t, np_bins, range=bin_range, weights=np_weights, density=density) - return (torch.from_numpy(np_hist).to(dtype), torch.from_numpy(np_bin_edges).to(dtype)) + def reference_histogram(t, bins, bin_range, weights, density, dtype): + np_t, np_bins, np_weights = map(to_np, [t, bins, weights]) + np_hist, np_bin_edges = np.histogram( + np_t, np_bins, range=bin_range, weights=np_weights, density=density + ) + return ( + torch.from_numpy(np_hist).to(dtype), + torch.from_numpy(np_bin_edges).to(dtype), + ) - # Doesn't pass a 'range' kwarg unless necessary because the override of histogram with Tensor bins doesn't accept one + if eq_func is None: + eq_func = self.assertEqual + + # Doesn't pass a 'range' kwarg unless necessary because the override of + # histogram with Tensor bins doesn't accept one. if bin_range: - (actual_hist, actual_bin_edges) = torch.histogram(t, bins, range=bin_range, weight=weights, density=density) + actual_hist, actual_bin_edges = torch.histogram( + t, bins, range=bin_range, weight=weights, density=density + ) else: - (actual_hist, actual_bin_edges) = torch.histogram(t, bins, weight=weights, density=density) + actual_hist, actual_bin_edges = torch.histogram( + t, bins, weight=weights, density=density + ) - (expected_hist, expected_bin_edges) = reference_histogram(self, t, bins, bin_range, weights, density, actual_hist.dtype) + expected_hist, expected_bin_edges = reference_histogram( + t, bins, bin_range, weights, density, actual_hist.dtype + ) """ Works around linspace discrepancies by passing torch's constructed bin_edges to numpy. @@ -3122,28 +3136,48 @@ def reference_histogram(self, t, bins, bin_range, weights, density, dtype): Issue: https://github.com/pytorch/pytorch/issues/58758 """ if not torch.is_tensor(bins): - self.assertEqual(actual_bin_edges, expected_bin_edges, atol=1e-5, rtol=1e-5) - # Calls numpy.histogram again, passing torch's actual_bin_edges as the bins argument - (expected_hist, expected_bin_edges) = reference_histogram( - self, t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype) + eq_func(actual_bin_edges, expected_bin_edges, atol=1e-5, rtol=1e-5) + # Calls numpy.histogram again, passing torch's actual_bin_edges as the bins + # argument. + expected_hist, expected_bin_edges = reference_histogram( + t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype, + ) - self.assertEqual(actual_hist, expected_hist) - self.assertEqual(actual_bin_edges, expected_bin_edges) + eq_func(actual_hist, expected_hist) + eq_func(actual_bin_edges, expected_bin_edges) # Test passing non-contiguous output tensors - hist_out = make_tensor(expected_hist.shape, device=expected_hist.device, dtype=expected_hist.dtype, - noncontiguous=True) - bin_edges_out = make_tensor(expected_bin_edges.shape, device=expected_bin_edges.device, dtype=expected_bin_edges.dtype, - noncontiguous=True) + hist_out = make_tensor( + expected_hist.shape, + device=expected_hist.device, + dtype=expected_hist.dtype, + noncontiguous=True, + ) + bin_edges_out = make_tensor( + expected_bin_edges.shape, + device=expected_bin_edges.device, + dtype=expected_bin_edges.dtype, + noncontiguous=True, + ) - # Doesn't pass a 'range' kwarg unless necessary because the override of histogram with Tensor bins doesn't accept one + # Doesn't pass a 'range' kwarg unless necessary because the override of + # histogram with Tensor bins doesn't accept one. if bin_range: - torch.histogram(t, bins, range=bin_range, weight=weights, density=density, out=(hist_out, bin_edges_out)) + torch.histogram( + t, + bins, + range=bin_range, + weight=weights, + density=density, + out=(hist_out, bin_edges_out), + ) else: - torch.histogram(t, bins, weight=weights, density=density, out=(hist_out, bin_edges_out)) + torch.histogram( + t, bins, weight=weights, density=density, out=(hist_out, bin_edges_out) + ) - self.assertEqual(hist_out, expected_hist) - self.assertEqual(bin_edges_out, expected_bin_edges) + eq_func(hist_out, expected_hist) + eq_func(bin_edges_out, expected_bin_edges) @onlyCPU @dtypes(torch.float32) @@ -3171,7 +3205,19 @@ def test_histogram(self, device, dtype): # Tests with range min=max bin_range[1] = bin_range[0] - self._test_histogram_numpy(values, bin_ct, bin_range, weights, density) + self._test_histogram_numpy( + values, + bin_ct, + bin_range, + weights, + density, + # TODO: investigate why torch.histogram differs from numpy.histogram + # so strongly on this particular test. There seems to be more + # differences here than the linspace issue, which is itself fairly + # easily patched around. Likely, the other tests also differ + # significantly, but below the default threshold for assertEqual. + eq_func=partial(self.assertEqual, rtol=3e-5, atol=0.0), + ) # Tests with caller-specified bin edges bin_edges = make_tensor(bin_ct + 1, dtype=dtype, device=device, low=-9, high=9).msort() diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py index 9de3bd09882e7..17910fc52da3f 100644 --- a/torch/testing/_creation.py +++ b/torch/testing/_creation.py @@ -3,6 +3,7 @@ """ import collections.abc +import functools import math import warnings from typing import cast, List, Optional, Tuple, Union @@ -189,6 +190,12 @@ def clamp(a: float, l: float, h: float) -> float: f"`requires_grad=True` is not supported for boolean and integral dtypes, but got {dtype=}" ) + noncontiguous = noncontiguous and functools.reduce(lambda x, y: x * y, shape, 1) > 1 + if noncontiguous: + # Double the size of the shape in the last dimension, so that we have + # non-identical values when we make the non-contiguous operation. + shape = cast(Tuple[int, ...], (*shape[:-1], 2 * shape[-1])) + if dtype is torch.bool: low, high = cast( Tuple[int, int], @@ -252,9 +259,9 @@ def clamp(a: float, l: float, h: float) -> float: " To request support, file an issue at: https://github.com/pytorch/pytorch/issues" ) - if noncontiguous and result.numel() > 1: - result = torch.repeat_interleave(result, 2, dim=-1) - result = result[..., ::2] + if noncontiguous: + # Offset by 1 to also catch offsetting issues + result = result[..., 1::2] elif memory_format is not None: result = result.clone(memory_format=memory_format) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index ec6fad938b435..24e70566f8f9b 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -14646,12 +14646,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): skips=( # Note: This xfail is fine -- it's inherent to how as_strided works DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'), - # RuntimeError: This operator is not Composite Compliant: the - # storage_offset of the tensor was modified directly without - # going through the PyTorch dispatcher. - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), - DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_backward'), - # These fail because the test changes the input's in-memory layout DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_complex_half_reference_testing'), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py index b3c3bd4a130e0..ab1a05d4fef40 100644 --- a/torch/testing/_internal/composite_compliance.py +++ b/torch/testing/_internal/composite_compliance.py @@ -136,10 +136,21 @@ def __new__(cls, elem, mode, *args, **kwargs): if elem.requires_grad: # CompositeCompliantTensor steals the "requires_grad"-ness. # Why a new copy of `elem`? Because sometimes OpInfo shares inputs between tests... - tmp = torch.empty_strided(elem.shape, elem.stride(), dtype=elem.dtype, - device=elem.device, layout=elem.layout, - requires_grad=False) - tmp.copy_(elem.detach()) + tmp = torch.empty( + (), + dtype=elem.dtype, + device=elem.device, + layout=elem.layout, + requires_grad=False, + ) + # Use set_ rather than empty_strided() + copy_ so that we can preserve + # things like storage_offset. + tmp.set_( + source=elem.untyped_storage().clone(), + storage_offset=elem.storage_offset(), + size=elem.size(), + stride=elem.stride(), + ) r.elem = tmp else: r.elem = elem