1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: tests"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport random 4*da0073e9SAndroid Build Coastguard Workerimport unittest 5*da0073e9SAndroid Build Coastguard Workerimport warnings 6*da0073e9SAndroid Build Coastguard Workerfrom functools import partial 7*da0073e9SAndroid Build Coastguard Workerfrom itertools import chain, combinations, permutations, product 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport numpy as np 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Workerimport torch 12*da0073e9SAndroid Build Coastguard Workerfrom torch import nan 13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor 14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 15*da0073e9SAndroid Build Coastguard Worker dtypes, 16*da0073e9SAndroid Build Coastguard Worker dtypesIfCUDA, 17*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 18*da0073e9SAndroid Build Coastguard Worker largeTensorTest, 19*da0073e9SAndroid Build Coastguard Worker onlyCPU, 20*da0073e9SAndroid Build Coastguard Worker onlyCUDA, 21*da0073e9SAndroid Build Coastguard Worker onlyNativeDeviceTypes, 22*da0073e9SAndroid Build Coastguard Worker) 23*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import ( 24*da0073e9SAndroid Build Coastguard Worker all_types, 25*da0073e9SAndroid Build Coastguard Worker all_types_and, 26*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and, 27*da0073e9SAndroid Build Coastguard Worker) 28*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 29*da0073e9SAndroid Build Coastguard Worker IS_JETSON, 30*da0073e9SAndroid Build Coastguard Worker run_tests, 31*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 32*da0073e9SAndroid Build Coastguard Worker TEST_PRIVATEUSE1_DEVICE_TYPE, 33*da0073e9SAndroid Build Coastguard Worker TestCase, 34*da0073e9SAndroid Build Coastguard Worker torch_to_numpy_dtype_dict, 35*da0073e9SAndroid Build Coastguard Worker) 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker# TODO: replace with make_tensor 39*da0073e9SAndroid Build Coastguard Workerdef _generate_input(shape, dtype, device, with_extremal): 40*da0073e9SAndroid Build Coastguard Worker if shape == (): 41*da0073e9SAndroid Build Coastguard Worker x = torch.tensor((), dtype=dtype, device=device) 42*da0073e9SAndroid Build Coastguard Worker else: 43*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point or dtype.is_complex: 44*da0073e9SAndroid Build Coastguard Worker # work around torch.randn not being implemented for bfloat16 45*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 46*da0073e9SAndroid Build Coastguard Worker x = torch.randn(*shape, device=device) * random.randint(30, 100) 47*da0073e9SAndroid Build Coastguard Worker x = x.to(torch.bfloat16) 48*da0073e9SAndroid Build Coastguard Worker else: 49*da0073e9SAndroid Build Coastguard Worker x = torch.randn(*shape, dtype=dtype, device=device) * random.randint( 50*da0073e9SAndroid Build Coastguard Worker 30, 100 51*da0073e9SAndroid Build Coastguard Worker ) 52*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = 0 53*da0073e9SAndroid Build Coastguard Worker if with_extremal and dtype.is_floating_point: 54*da0073e9SAndroid Build Coastguard Worker # Use extremal values 55*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = float("nan") 56*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = float("inf") 57*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = float("-inf") 58*da0073e9SAndroid Build Coastguard Worker elif with_extremal and dtype.is_complex: 59*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = complex("nan") 60*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = complex("inf") 61*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = complex("-inf") 62*da0073e9SAndroid Build Coastguard Worker elif dtype == torch.bool: 63*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(shape, dtype=dtype, device=device) 64*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = True 65*da0073e9SAndroid Build Coastguard Worker else: 66*da0073e9SAndroid Build Coastguard Worker x = torch.randint(15, 100, shape, dtype=dtype, device=device) 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker return x 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Workerclass TestShapeOps(TestCase): 72*da0073e9SAndroid Build Coastguard Worker # TODO: update to work on CUDA, too 73*da0073e9SAndroid Build Coastguard Worker @onlyCPU 74*da0073e9SAndroid Build Coastguard Worker def test_unbind(self, device): 75*da0073e9SAndroid Build Coastguard Worker x = torch.rand(2, 3, 4, 5) 76*da0073e9SAndroid Build Coastguard Worker for dim in range(4): 77*da0073e9SAndroid Build Coastguard Worker res = torch.unbind(x, dim) 78*da0073e9SAndroid Build Coastguard Worker res2 = x.unbind(dim) 79*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.size(dim), len(res)) 80*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.size(dim), len(res2)) 81*da0073e9SAndroid Build Coastguard Worker for i in range(dim): 82*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.select(dim, i), res[i]) 83*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.select(dim, i), res2[i]) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker # TODO: update to work on CUDA, too? 86*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with an unknown error") 87*da0073e9SAndroid Build Coastguard Worker @onlyCPU 88*da0073e9SAndroid Build Coastguard Worker def test_tolist(self, device): 89*da0073e9SAndroid Build Coastguard Worker list0D = [] 90*da0073e9SAndroid Build Coastguard Worker tensor0D = torch.tensor(list0D) 91*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor0D.tolist(), list0D) 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker table1D = [1.0, 2.0, 3.0] 94*da0073e9SAndroid Build Coastguard Worker tensor1D = torch.tensor(table1D) 95*da0073e9SAndroid Build Coastguard Worker storage = torch.Storage(table1D) 96*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor1D.tolist(), table1D) 97*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage.tolist(), table1D) 98*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor1D.tolist(), table1D) 99*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage.tolist(), table1D) 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker table2D = [[1, 2], [3, 4]] 102*da0073e9SAndroid Build Coastguard Worker tensor2D = torch.tensor(table2D) 103*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor2D.tolist(), table2D) 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker tensor3D = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) 106*da0073e9SAndroid Build Coastguard Worker tensorNonContig = tensor3D.select(1, 1) 107*da0073e9SAndroid Build Coastguard Worker self.assertFalse(tensorNonContig.is_contiguous()) 108*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensorNonContig.tolist(), [[3, 4], [7, 8]]) 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int64, torch.float, torch.complex128) 111*da0073e9SAndroid Build Coastguard Worker def test_movedim_invalid(self, device, dtype): 112*da0073e9SAndroid Build Coastguard Worker shape = self._rand_shape(4, min_size=5, max_size=10) 113*da0073e9SAndroid Build Coastguard Worker x = _generate_input(shape, dtype, device, False) 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker for fn in [torch.movedim, torch.moveaxis]: 116*da0073e9SAndroid Build Coastguard Worker # Invalid `source` and `destination` dimension 117*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, "Dimension out of range"): 118*da0073e9SAndroid Build Coastguard Worker fn(x, 5, 0) 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, "Dimension out of range"): 121*da0073e9SAndroid Build Coastguard Worker fn(x, 0, 5) 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker # Mismatch in size of `source` and `destination` 124*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 125*da0073e9SAndroid Build Coastguard Worker RuntimeError, "movedim: Invalid source or destination dims:" 126*da0073e9SAndroid Build Coastguard Worker ): 127*da0073e9SAndroid Build Coastguard Worker fn(x, (1, 0), (0,)) 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 130*da0073e9SAndroid Build Coastguard Worker RuntimeError, "movedim: repeated dim in `source`" 131*da0073e9SAndroid Build Coastguard Worker ): 132*da0073e9SAndroid Build Coastguard Worker fn(x, (0, 0), (0, 1)) 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 135*da0073e9SAndroid Build Coastguard Worker RuntimeError, "movedim: repeated dim in `source`" 136*da0073e9SAndroid Build Coastguard Worker ): 137*da0073e9SAndroid Build Coastguard Worker fn(x, (0, 1, 0), (0, 1, 2)) 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 140*da0073e9SAndroid Build Coastguard Worker RuntimeError, "movedim: repeated dim in `destination`" 141*da0073e9SAndroid Build Coastguard Worker ): 142*da0073e9SAndroid Build Coastguard Worker fn(x, (0, 1), (1, 1)) 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 145*da0073e9SAndroid Build Coastguard Worker RuntimeError, "movedim: repeated dim in `destination`" 146*da0073e9SAndroid Build Coastguard Worker ): 147*da0073e9SAndroid Build Coastguard Worker fn(x, (0, 1, 2), (1, 0, 1)) 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int64, torch.float, torch.complex128) 150*da0073e9SAndroid Build Coastguard Worker def test_movedim(self, device, dtype): 151*da0073e9SAndroid Build Coastguard Worker for fn in [torch.moveaxis, torch.movedim]: 152*da0073e9SAndroid Build Coastguard Worker for nd in range(5): 153*da0073e9SAndroid Build Coastguard Worker shape = self._rand_shape(nd, min_size=5, max_size=10) 154*da0073e9SAndroid Build Coastguard Worker x = _generate_input(shape, dtype, device, with_extremal=False) 155*da0073e9SAndroid Build Coastguard Worker for random_negative in [True, False]: 156*da0073e9SAndroid Build Coastguard Worker for src_dim, dst_dim in permutations(range(nd), r=2): 157*da0073e9SAndroid Build Coastguard Worker random_prob = random.random() 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker if random_negative and random_prob > 0.66: 160*da0073e9SAndroid Build Coastguard Worker src_dim = src_dim - nd 161*da0073e9SAndroid Build Coastguard Worker elif random_negative and random_prob > 0.33: 162*da0073e9SAndroid Build Coastguard Worker dst_dim = dst_dim - nd 163*da0073e9SAndroid Build Coastguard Worker elif random_negative: 164*da0073e9SAndroid Build Coastguard Worker src_dim = src_dim - nd 165*da0073e9SAndroid Build Coastguard Worker dst_dim = dst_dim - nd 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker # Integer `source` and `destination` 168*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(fn, source=src_dim, destination=dst_dim) 169*da0073e9SAndroid Build Coastguard Worker np_fn = partial( 170*da0073e9SAndroid Build Coastguard Worker np.moveaxis, source=src_dim, destination=dst_dim 171*da0073e9SAndroid Build Coastguard Worker ) 172*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy( 173*da0073e9SAndroid Build Coastguard Worker torch_fn, np_fn, x, device=None, dtype=None 174*da0073e9SAndroid Build Coastguard Worker ) 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker if nd == 0: 177*da0073e9SAndroid Build Coastguard Worker continue 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker def make_index_negative(sequence, idx): 180*da0073e9SAndroid Build Coastguard Worker sequence = list(sequence) 181*da0073e9SAndroid Build Coastguard Worker sequence[random_idx] = sequence[random_idx] - nd 182*da0073e9SAndroid Build Coastguard Worker return tuple(src_sequence) 183*da0073e9SAndroid Build Coastguard Worker 184*da0073e9SAndroid Build Coastguard Worker for src_sequence in permutations( 185*da0073e9SAndroid Build Coastguard Worker range(nd), r=random.randint(1, nd) 186*da0073e9SAndroid Build Coastguard Worker ): 187*da0073e9SAndroid Build Coastguard Worker # Sequence `source` and `destination` 188*da0073e9SAndroid Build Coastguard Worker dst_sequence = tuple( 189*da0073e9SAndroid Build Coastguard Worker random.sample(range(nd), len(src_sequence)) 190*da0073e9SAndroid Build Coastguard Worker ) 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker # Randomly change a dim to a negative dim representation of itself. 193*da0073e9SAndroid Build Coastguard Worker random_prob = random.random() 194*da0073e9SAndroid Build Coastguard Worker if random_negative and random_prob > 0.66: 195*da0073e9SAndroid Build Coastguard Worker random_idx = random.randint(0, len(src_sequence) - 1) 196*da0073e9SAndroid Build Coastguard Worker src_sequence = make_index_negative(src_sequence, random_idx) 197*da0073e9SAndroid Build Coastguard Worker elif random_negative and random_prob > 0.33: 198*da0073e9SAndroid Build Coastguard Worker random_idx = random.randint(0, len(src_sequence) - 1) 199*da0073e9SAndroid Build Coastguard Worker dst_sequence = make_index_negative(dst_sequence, random_idx) 200*da0073e9SAndroid Build Coastguard Worker elif random_negative: 201*da0073e9SAndroid Build Coastguard Worker random_idx = random.randint(0, len(src_sequence) - 1) 202*da0073e9SAndroid Build Coastguard Worker dst_sequence = make_index_negative(dst_sequence, random_idx) 203*da0073e9SAndroid Build Coastguard Worker random_idx = random.randint(0, len(src_sequence) - 1) 204*da0073e9SAndroid Build Coastguard Worker src_sequence = make_index_negative(src_sequence, random_idx) 205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker torch_fn = partial( 207*da0073e9SAndroid Build Coastguard Worker fn, source=src_sequence, destination=dst_sequence 208*da0073e9SAndroid Build Coastguard Worker ) 209*da0073e9SAndroid Build Coastguard Worker np_fn = partial( 210*da0073e9SAndroid Build Coastguard Worker np.moveaxis, source=src_sequence, destination=dst_sequence 211*da0073e9SAndroid Build Coastguard Worker ) 212*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy( 213*da0073e9SAndroid Build Coastguard Worker torch_fn, np_fn, x, device=None, dtype=None 214*da0073e9SAndroid Build Coastguard Worker ) 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker # Move dim to same position 217*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, 5, 7, 11) 218*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(fn, source=(0, 1), destination=(0, 1)) 219*da0073e9SAndroid Build Coastguard Worker np_fn = partial(np.moveaxis, source=(0, 1), destination=(0, 1)) 220*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(fn, source=1, destination=1) 223*da0073e9SAndroid Build Coastguard Worker np_fn = partial(np.moveaxis, source=1, destination=1) 224*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker # Empty Sequence 227*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(fn, source=(), destination=()) 228*da0073e9SAndroid Build Coastguard Worker np_fn = partial(np.moveaxis, source=(), destination=()) 229*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.bool) 232*da0073e9SAndroid Build Coastguard Worker def test_diag(self, device, dtype): 233*da0073e9SAndroid Build Coastguard Worker if dtype is torch.bool: 234*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, 100, device=device) >= 0.5 235*da0073e9SAndroid Build Coastguard Worker else: 236*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, 100, dtype=dtype, device=device) 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker res1 = torch.diag(x) 239*da0073e9SAndroid Build Coastguard Worker res2 = torch.tensor((), dtype=dtype, device=device) 240*da0073e9SAndroid Build Coastguard Worker torch.diag(x, out=res2) 241*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker def test_diagonal(self, device): 244*da0073e9SAndroid Build Coastguard Worker x = torch.randn((100, 100), device=device) 245*da0073e9SAndroid Build Coastguard Worker result = torch.diagonal(x) 246*da0073e9SAndroid Build Coastguard Worker expected = torch.diag(x) 247*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected) 248*da0073e9SAndroid Build Coastguard Worker 249*da0073e9SAndroid Build Coastguard Worker x = torch.randn((100, 100), device=device) 250*da0073e9SAndroid Build Coastguard Worker result = torch.diagonal(x, 17) 251*da0073e9SAndroid Build Coastguard Worker expected = torch.diag(x, 17) 252*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected) 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker @onlyCPU 255*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 256*da0073e9SAndroid Build Coastguard Worker def test_diagonal_multidim(self, device, dtype): 257*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 11, 12, 13, dtype=dtype, device=device) 258*da0073e9SAndroid Build Coastguard Worker xn = x.numpy() 259*da0073e9SAndroid Build Coastguard Worker for args in [(2, 2, 3), (2,), (-2, 1, 2), (0, -2, -1)]: 260*da0073e9SAndroid Build Coastguard Worker result = torch.diagonal(x, *args) 261*da0073e9SAndroid Build Coastguard Worker expected = xn.diagonal(*args) 262*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected.shape, result.shape) 263*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, result) 264*da0073e9SAndroid Build Coastguard Worker # test non-continguous 265*da0073e9SAndroid Build Coastguard Worker xp = x.permute(1, 2, 3, 0) 266*da0073e9SAndroid Build Coastguard Worker result = torch.diagonal(xp, 0, -2, -1) 267*da0073e9SAndroid Build Coastguard Worker expected = xp.numpy().diagonal(0, -2, -1) 268*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected.shape, result.shape) 269*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, result) 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 272*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types()) 273*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*all_types_and(torch.half)) 274*da0073e9SAndroid Build Coastguard Worker def test_trace(self, device, dtype): 275*da0073e9SAndroid Build Coastguard Worker def test(shape): 276*da0073e9SAndroid Build Coastguard Worker tensor = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9) 277*da0073e9SAndroid Build Coastguard Worker expected_dtype = tensor.sum().dtype 278*da0073e9SAndroid Build Coastguard Worker expected_dtype = torch_to_numpy_dtype_dict[expected_dtype] 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker result = np.trace(tensor.cpu().numpy(), dtype=expected_dtype) 281*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor(result, device=device) 282*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor.trace(), expected) 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker shapes = ( 285*da0073e9SAndroid Build Coastguard Worker [10, 1], 286*da0073e9SAndroid Build Coastguard Worker [1, 10], 287*da0073e9SAndroid Build Coastguard Worker [100, 100], 288*da0073e9SAndroid Build Coastguard Worker [20, 100], 289*da0073e9SAndroid Build Coastguard Worker [100, 20], 290*da0073e9SAndroid Build Coastguard Worker ) 291*da0073e9SAndroid Build Coastguard Worker for shape in shapes: 292*da0073e9SAndroid Build Coastguard Worker test(shape) 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nans): 295*da0073e9SAndroid Build Coastguard Worker """ 296*da0073e9SAndroid Build Coastguard Worker Creates a random tensor for a given device and dtype, and computes the expected clamped 297*da0073e9SAndroid Build Coastguard Worker values given the min_vals and/or max_vals. 298*da0073e9SAndroid Build Coastguard Worker If with_nans is provided, then some values are randomly set to nan. 299*da0073e9SAndroid Build Coastguard Worker """ 300*da0073e9SAndroid Build Coastguard Worker X = torch.rand(100, device=device).mul(50).add(-25) # uniform in [-25, 25] 301*da0073e9SAndroid Build Coastguard Worker X = X.to(dtype) 302*da0073e9SAndroid Build Coastguard Worker if with_nans: 303*da0073e9SAndroid Build Coastguard Worker mask = torch.randint(0, 2, X.shape, dtype=torch.bool, device=device) 304*da0073e9SAndroid Build Coastguard Worker X[mask] = nan 305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Worker if isinstance(min_vals, torch.Tensor): 307*da0073e9SAndroid Build Coastguard Worker min_vals = min_vals.cpu().numpy() 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker if isinstance(max_vals, torch.Tensor): 310*da0073e9SAndroid Build Coastguard Worker max_vals = max_vals.cpu().numpy() 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Worker # Use NumPy implementation as reference 313*da0073e9SAndroid Build Coastguard Worker X_clamped = torch.tensor( 314*da0073e9SAndroid Build Coastguard Worker np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device 315*da0073e9SAndroid Build Coastguard Worker ) 316*da0073e9SAndroid Build Coastguard Worker return X, X_clamped 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker # Tests clamp and its alias, clip 319*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int64, torch.float32) 320*da0073e9SAndroid Build Coastguard Worker def test_clamp(self, device, dtype): 321*da0073e9SAndroid Build Coastguard Worker op_list = ( 322*da0073e9SAndroid Build Coastguard Worker torch.clamp, 323*da0073e9SAndroid Build Coastguard Worker torch.Tensor.clamp, 324*da0073e9SAndroid Build Coastguard Worker torch.Tensor.clamp_, 325*da0073e9SAndroid Build Coastguard Worker torch.clip, 326*da0073e9SAndroid Build Coastguard Worker torch.Tensor.clip, 327*da0073e9SAndroid Build Coastguard Worker torch.Tensor.clip_, 328*da0073e9SAndroid Build Coastguard Worker ) 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker # min/max argument product 331*da0073e9SAndroid Build Coastguard Worker args = product((-10, None), (10, None)) 332*da0073e9SAndroid Build Coastguard Worker 333*da0073e9SAndroid Build Coastguard Worker for op in op_list: 334*da0073e9SAndroid Build Coastguard Worker for min_val, max_val in args: 335*da0073e9SAndroid Build Coastguard Worker if min_val is None and max_val is None: 336*da0073e9SAndroid Build Coastguard Worker continue 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Worker X, Y_expected = self.generate_clamp_baseline( 339*da0073e9SAndroid Build Coastguard Worker device, dtype, min_vals=min_val, max_vals=max_val, with_nans=False 340*da0073e9SAndroid Build Coastguard Worker ) 341*da0073e9SAndroid Build Coastguard Worker 342*da0073e9SAndroid Build Coastguard Worker # Test op 343*da0073e9SAndroid Build Coastguard Worker X1 = X.clone() # So that the in-place ops do not change X 344*da0073e9SAndroid Build Coastguard Worker Y_actual = op(X1, min_val, max_val) 345*da0073e9SAndroid Build Coastguard Worker self.assertEqual(Y_expected, Y_actual) 346*da0073e9SAndroid Build Coastguard Worker 347*da0073e9SAndroid Build Coastguard Worker # Test op-out behavior (out does not exist for method versions) 348*da0073e9SAndroid Build Coastguard Worker if op in (torch.clamp, torch.clip): 349*da0073e9SAndroid Build Coastguard Worker Y_out = torch.empty_like(X) 350*da0073e9SAndroid Build Coastguard Worker op(X, min=min_val, max=max_val, out=Y_out) 351*da0073e9SAndroid Build Coastguard Worker self.assertEqual(Y_expected, Y_out) 352*da0073e9SAndroid Build Coastguard Worker 353*da0073e9SAndroid Build Coastguard Worker def test_clamp_propagates_nans(self, device): 354*da0073e9SAndroid Build Coastguard Worker op_list = ( 355*da0073e9SAndroid Build Coastguard Worker torch.clamp, 356*da0073e9SAndroid Build Coastguard Worker torch.Tensor.clamp, 357*da0073e9SAndroid Build Coastguard Worker torch.Tensor.clamp_, 358*da0073e9SAndroid Build Coastguard Worker torch.clip, 359*da0073e9SAndroid Build Coastguard Worker torch.Tensor.clip, 360*da0073e9SAndroid Build Coastguard Worker torch.Tensor.clip_, 361*da0073e9SAndroid Build Coastguard Worker ) 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker # min/max argument product 364*da0073e9SAndroid Build Coastguard Worker args = product((-10, None), (10, None)) 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker for op in op_list: 367*da0073e9SAndroid Build Coastguard Worker for min_val, max_val in args: 368*da0073e9SAndroid Build Coastguard Worker if min_val is None and max_val is None: 369*da0073e9SAndroid Build Coastguard Worker continue 370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Worker X, Y_expected = self.generate_clamp_baseline( 372*da0073e9SAndroid Build Coastguard Worker device, 373*da0073e9SAndroid Build Coastguard Worker torch.float, 374*da0073e9SAndroid Build Coastguard Worker min_vals=min_val, 375*da0073e9SAndroid Build Coastguard Worker max_vals=max_val, 376*da0073e9SAndroid Build Coastguard Worker with_nans=True, 377*da0073e9SAndroid Build Coastguard Worker ) 378*da0073e9SAndroid Build Coastguard Worker Y_expected = torch.isnan(Y_expected) 379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker # Test op 381*da0073e9SAndroid Build Coastguard Worker X1 = X.clone() # So that the in-place ops do not change X 382*da0073e9SAndroid Build Coastguard Worker Y_actual = op(X1, min_val, max_val) 383*da0073e9SAndroid Build Coastguard Worker self.assertEqual(Y_expected, torch.isnan(Y_actual)) 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker # Test op-out behavior (out does not exist for method versions) 386*da0073e9SAndroid Build Coastguard Worker if op in (torch.clamp, torch.clip): 387*da0073e9SAndroid Build Coastguard Worker Y_out = torch.empty_like(X) 388*da0073e9SAndroid Build Coastguard Worker op(X, min_val, max_val, out=Y_out) 389*da0073e9SAndroid Build Coastguard Worker self.assertEqual(Y_expected, torch.isnan(Y_out)) 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker def test_clamp_raises_arg_errors(self, device): 392*da0073e9SAndroid Build Coastguard Worker X = torch.randn(100, dtype=torch.float, device=device) 393*da0073e9SAndroid Build Coastguard Worker error_msg = "At least one of 'min' or 'max' must not be None" 394*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, error_msg): 395*da0073e9SAndroid Build Coastguard Worker X.clamp() 396*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, error_msg): 397*da0073e9SAndroid Build Coastguard Worker X.clamp_() 398*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, error_msg): 399*da0073e9SAndroid Build Coastguard Worker torch.clamp(X) 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 402*da0073e9SAndroid Build Coastguard Worker def test_flip(self, device, dtype): 403*da0073e9SAndroid Build Coastguard Worker make_from_data = partial(torch.tensor, device=device, dtype=dtype) 404*da0073e9SAndroid Build Coastguard Worker make_from_size = partial(make_tensor, device=device, dtype=dtype) 405*da0073e9SAndroid Build Coastguard Worker 406*da0073e9SAndroid Build Coastguard Worker def test_flip_impl(input_t, dims, output_t): 407*da0073e9SAndroid Build Coastguard Worker def all_t(): 408*da0073e9SAndroid Build Coastguard Worker yield input_t, output_t 409*da0073e9SAndroid Build Coastguard Worker if dtype is torch.float: 410*da0073e9SAndroid Build Coastguard Worker # We generate quantized versions as well 411*da0073e9SAndroid Build Coastguard Worker for qdtype in (torch.quint8, torch.qint8, torch.qint32): 412*da0073e9SAndroid Build Coastguard Worker qinput_t = torch.quantize_per_tensor(input_t, 0.1, 5, qdtype) 413*da0073e9SAndroid Build Coastguard Worker qoutput_t = torch.quantize_per_tensor(output_t, 0.1, 5, qdtype) 414*da0073e9SAndroid Build Coastguard Worker yield qinput_t, qoutput_t 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker for in_t, out_t in all_t(): 417*da0073e9SAndroid Build Coastguard Worker self.assertEqual(in_t.flip(dims), out_t) 418*da0073e9SAndroid Build Coastguard Worker n = in_t.ndim 419*da0073e9SAndroid Build Coastguard Worker if not isinstance(dims, tuple): 420*da0073e9SAndroid Build Coastguard Worker # Wrap dim 421*da0073e9SAndroid Build Coastguard Worker self.assertEqual(in_t.flip(-n + dims), out_t) 422*da0073e9SAndroid Build Coastguard Worker else: 423*da0073e9SAndroid Build Coastguard Worker # Permute dimensions 424*da0073e9SAndroid Build Coastguard Worker for p_dims in permutations(dims): 425*da0073e9SAndroid Build Coastguard Worker self.assertEqual(in_t.flip(p_dims), out_t) 426*da0073e9SAndroid Build Coastguard Worker if len(p_dims) > 0: 427*da0073e9SAndroid Build Coastguard Worker # Wrap 1st dim 428*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 429*da0073e9SAndroid Build Coastguard Worker in_t.flip((-n + p_dims[0],) + p_dims[1:]), out_t 430*da0073e9SAndroid Build Coastguard Worker ) 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker def gen_data(): 433*da0073e9SAndroid Build Coastguard Worker # Basic tests 434*da0073e9SAndroid Build Coastguard Worker data = make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2) 435*da0073e9SAndroid Build Coastguard Worker nonctg = make_from_size((2, 2, 2), noncontiguous=True).copy_(data) 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Worker dims_result = ( 438*da0073e9SAndroid Build Coastguard Worker (0, make_from_data([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2)), 439*da0073e9SAndroid Build Coastguard Worker (1, make_from_data([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2)), 440*da0073e9SAndroid Build Coastguard Worker (2, make_from_data([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2)), 441*da0073e9SAndroid Build Coastguard Worker ((0, 1), make_from_data([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2)), 442*da0073e9SAndroid Build Coastguard Worker ((0, 1, 2), make_from_data([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2)), 443*da0073e9SAndroid Build Coastguard Worker ) 444*da0073e9SAndroid Build Coastguard Worker for in_tensor, (dims, out_tensor) in product((data, nonctg), dims_result): 445*da0073e9SAndroid Build Coastguard Worker yield in_tensor, dims, out_tensor 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker # Expanded 448*da0073e9SAndroid Build Coastguard Worker in_t = make_from_data([1, 2, 3]).view(3, 1).expand(3, 2) 449*da0073e9SAndroid Build Coastguard Worker dims = 0 450*da0073e9SAndroid Build Coastguard Worker out_t = make_from_data([3, 3, 2, 2, 1, 1]).view(3, 2) 451*da0073e9SAndroid Build Coastguard Worker yield in_t, dims, out_t 452*da0073e9SAndroid Build Coastguard Worker # Noop on expanded dimension 453*da0073e9SAndroid Build Coastguard Worker yield in_t, 1, in_t 454*da0073e9SAndroid Build Coastguard Worker 455*da0073e9SAndroid Build Coastguard Worker # Transposed 456*da0073e9SAndroid Build Coastguard Worker in_t = ( 457*da0073e9SAndroid Build Coastguard Worker make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2).transpose(0, 1) 458*da0073e9SAndroid Build Coastguard Worker ) 459*da0073e9SAndroid Build Coastguard Worker dims = (0, 1, 2) 460*da0073e9SAndroid Build Coastguard Worker out_t = make_from_data([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2) 461*da0073e9SAndroid Build Coastguard Worker yield in_t, dims, out_t 462*da0073e9SAndroid Build Coastguard Worker 463*da0073e9SAndroid Build Coastguard Worker # Rectangular case 464*da0073e9SAndroid Build Coastguard Worker in_t = make_from_data([1, 2, 3, 4, 5, 6]).view(2, 3) 465*da0073e9SAndroid Build Coastguard Worker dims = 0 466*da0073e9SAndroid Build Coastguard Worker out_t = make_from_data([[4, 5, 6], [1, 2, 3]]) 467*da0073e9SAndroid Build Coastguard Worker yield in_t, dims, out_t 468*da0073e9SAndroid Build Coastguard Worker dims = 1 469*da0073e9SAndroid Build Coastguard Worker out_t = make_from_data([[3, 2, 1], [6, 5, 4]]) 470*da0073e9SAndroid Build Coastguard Worker yield in_t, dims, out_t 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker # vectorized NCHW cases (images) 473*da0073e9SAndroid Build Coastguard Worker if device == "cpu" and dtype != torch.bfloat16: 474*da0073e9SAndroid Build Coastguard Worker for mf in [torch.contiguous_format, torch.channels_last]: 475*da0073e9SAndroid Build Coastguard Worker for c in [2, 3, 8, 16]: 476*da0073e9SAndroid Build Coastguard Worker in_t = make_from_size((2, c, 32, 32)).contiguous( 477*da0073e9SAndroid Build Coastguard Worker memory_format=mf 478*da0073e9SAndroid Build Coastguard Worker ) 479*da0073e9SAndroid Build Coastguard Worker np_in_t = in_t.numpy() 480*da0073e9SAndroid Build Coastguard Worker 481*da0073e9SAndroid Build Coastguard Worker np_out_t = np_in_t[:, :, :, ::-1].copy() 482*da0073e9SAndroid Build Coastguard Worker out_t = torch.from_numpy(np_out_t) 483*da0073e9SAndroid Build Coastguard Worker yield in_t, 3, out_t 484*da0073e9SAndroid Build Coastguard Worker 485*da0073e9SAndroid Build Coastguard Worker np_out_t = np_in_t[:, :, ::-1, :].copy() 486*da0073e9SAndroid Build Coastguard Worker out_t = torch.from_numpy(np_out_t) 487*da0073e9SAndroid Build Coastguard Worker yield in_t, 2, out_t 488*da0073e9SAndroid Build Coastguard Worker 489*da0073e9SAndroid Build Coastguard Worker # non-contig cases 490*da0073e9SAndroid Build Coastguard Worker in_tt = in_t[..., ::2, :] 491*da0073e9SAndroid Build Coastguard Worker np_in_t = in_tt.numpy() 492*da0073e9SAndroid Build Coastguard Worker np_out_t = np_in_t[:, :, :, ::-1].copy() 493*da0073e9SAndroid Build Coastguard Worker out_t = torch.from_numpy(np_out_t) 494*da0073e9SAndroid Build Coastguard Worker yield in_tt, 3, out_t 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker in_tt = in_t[..., ::2] 497*da0073e9SAndroid Build Coastguard Worker np_in_t = in_tt.numpy() 498*da0073e9SAndroid Build Coastguard Worker np_out_t = np_in_t[:, :, :, ::-1].copy() 499*da0073e9SAndroid Build Coastguard Worker out_t = torch.from_numpy(np_out_t) 500*da0073e9SAndroid Build Coastguard Worker yield in_tt, 3, out_t 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid Build Coastguard Worker # Noops (edge cases) 503*da0073e9SAndroid Build Coastguard Worker 504*da0073e9SAndroid Build Coastguard Worker # Size 0 505*da0073e9SAndroid Build Coastguard Worker in_t = make_from_data(()) 506*da0073e9SAndroid Build Coastguard Worker yield in_t, 0, in_t 507*da0073e9SAndroid Build Coastguard Worker yield in_t, (), in_t 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Worker # dims = () 510*da0073e9SAndroid Build Coastguard Worker in_t = make_from_size((3, 2, 1)) 511*da0073e9SAndroid Build Coastguard Worker yield in_t, (), in_t 512*da0073e9SAndroid Build Coastguard Worker 513*da0073e9SAndroid Build Coastguard Worker # Zero elements, non-zero size 514*da0073e9SAndroid Build Coastguard Worker in_t = make_from_size((3, 0, 2)) 515*da0073e9SAndroid Build Coastguard Worker for i in range(in_t.ndim): 516*da0073e9SAndroid Build Coastguard Worker yield in_t, i, in_t 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker # Size 1 519*da0073e9SAndroid Build Coastguard Worker in_t = make_from_size(()) 520*da0073e9SAndroid Build Coastguard Worker yield in_t, 0, in_t 521*da0073e9SAndroid Build Coastguard Worker in_t = make_from_size((1,)) 522*da0073e9SAndroid Build Coastguard Worker yield in_t, 0, in_t 523*da0073e9SAndroid Build Coastguard Worker 524*da0073e9SAndroid Build Coastguard Worker for in_tensor, dims, out_tensor in gen_data(): 525*da0073e9SAndroid Build Coastguard Worker test_flip_impl(in_tensor, dims, out_tensor) 526*da0073e9SAndroid Build Coastguard Worker 527*da0073e9SAndroid Build Coastguard Worker # test for shape 528*da0073e9SAndroid Build Coastguard Worker size = [2, 3, 4] 529*da0073e9SAndroid Build Coastguard Worker data = make_from_size(size) 530*da0073e9SAndroid Build Coastguard Worker possible_dims = range(len(size)) 531*da0073e9SAndroid Build Coastguard Worker test_dims = chain( 532*da0073e9SAndroid Build Coastguard Worker combinations(possible_dims, 1), combinations(possible_dims, 2) 533*da0073e9SAndroid Build Coastguard Worker ) 534*da0073e9SAndroid Build Coastguard Worker 535*da0073e9SAndroid Build Coastguard Worker for dims in test_dims: 536*da0073e9SAndroid Build Coastguard Worker self.assertEqual(size, list(data.flip(dims).size())) 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 539*da0073e9SAndroid Build Coastguard Worker def test_flip_errors(self, device, dtype): 540*da0073e9SAndroid Build Coastguard Worker make_arg = partial(make_tensor, dtype=dtype, device=device) 541*da0073e9SAndroid Build Coastguard Worker data = make_arg((2, 2, 2)) 542*da0073e9SAndroid Build Coastguard Worker 543*da0073e9SAndroid Build Coastguard Worker # not allow flip on the same dim more than once 544*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1)) 545*da0073e9SAndroid Build Coastguard Worker # not allow empty list as input 546*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: data.flip()) 547*da0073e9SAndroid Build Coastguard Worker 548*da0073e9SAndroid Build Coastguard Worker # not allow dim > max dim 549*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3)) 550*da0073e9SAndroid Build Coastguard Worker self.assertRaises(IndexError, lambda: data.flip(3)) 551*da0073e9SAndroid Build Coastguard Worker 552*da0073e9SAndroid Build Coastguard Worker def _rand_shape(self, dim, min_size, max_size): 553*da0073e9SAndroid Build Coastguard Worker return tuple(torch.randint(min_size, max_size + 1, (dim,))) 554*da0073e9SAndroid Build Coastguard Worker 555*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 556*da0073e9SAndroid Build Coastguard Worker def test_flip_numpy(self, device, dtype): 557*da0073e9SAndroid Build Coastguard Worker make_arg = partial(make_tensor, dtype=dtype, device=device) 558*da0073e9SAndroid Build Coastguard Worker 559*da0073e9SAndroid Build Coastguard Worker for ndim in [3, 4]: 560*da0073e9SAndroid Build Coastguard Worker shape = self._rand_shape(ndim, 5, 10) 561*da0073e9SAndroid Build Coastguard Worker data = make_arg(shape) 562*da0073e9SAndroid Build Coastguard Worker 563*da0073e9SAndroid Build Coastguard Worker # Axis to sample for given shape. 564*da0073e9SAndroid Build Coastguard Worker for i in range(1, ndim + 1): 565*da0073e9SAndroid Build Coastguard Worker # Check all combinations of `i` axis. 566*da0073e9SAndroid Build Coastguard Worker for flip_dim in combinations(range(ndim), i): 567*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(torch.flip, dims=flip_dim) 568*da0073e9SAndroid Build Coastguard Worker np_fn = partial(np.flip, axis=flip_dim) 569*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, data) 570*da0073e9SAndroid Build Coastguard Worker 571*da0073e9SAndroid Build Coastguard Worker @onlyCUDA # CPU is too slow 572*da0073e9SAndroid Build Coastguard Worker @largeTensorTest("17GB") # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB 573*da0073e9SAndroid Build Coastguard Worker @largeTensorTest( 574*da0073e9SAndroid Build Coastguard Worker "81GB", "cpu" 575*da0073e9SAndroid Build Coastguard Worker ) # even for CUDA test, sufficient system memory is required 576*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_JETSON, "Too large for Jetson") 577*da0073e9SAndroid Build Coastguard Worker def test_flip_large_tensor(self, device): 578*da0073e9SAndroid Build Coastguard Worker t_in = torch.empty(2**32 + 1, dtype=torch.uint8).random_() 579*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(torch.flip, dims=(0,)) 580*da0073e9SAndroid Build Coastguard Worker np_fn = partial(np.flip, axis=0) 581*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, t_in) 582*da0073e9SAndroid Build Coastguard Worker del t_in 583*da0073e9SAndroid Build Coastguard Worker 584*da0073e9SAndroid Build Coastguard Worker def _test_fliplr_flipud(self, torch_fn, np_fn, min_dim, max_dim, device, dtype): 585*da0073e9SAndroid Build Coastguard Worker for dim in range(min_dim, max_dim + 1): 586*da0073e9SAndroid Build Coastguard Worker shape = self._rand_shape(dim, 5, 10) 587*da0073e9SAndroid Build Coastguard Worker # Randomly scale the input 588*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point or dtype.is_complex: 589*da0073e9SAndroid Build Coastguard Worker data = torch.randn(*shape, device=device, dtype=dtype) 590*da0073e9SAndroid Build Coastguard Worker else: 591*da0073e9SAndroid Build Coastguard Worker data = torch.randint(0, 10, shape, device=device, dtype=dtype) 592*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, data) 593*da0073e9SAndroid Build Coastguard Worker 594*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int64, torch.double, torch.cdouble) 595*da0073e9SAndroid Build Coastguard Worker def test_fliplr(self, device, dtype): 596*da0073e9SAndroid Build Coastguard Worker self._test_fliplr_flipud(torch.fliplr, np.fliplr, 2, 4, device, dtype) 597*da0073e9SAndroid Build Coastguard Worker 598*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int64, torch.double, torch.cdouble) 599*da0073e9SAndroid Build Coastguard Worker def test_fliplr_invalid(self, device, dtype): 600*da0073e9SAndroid Build Coastguard Worker x = torch.randn(42).to(dtype) 601*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."): 602*da0073e9SAndroid Build Coastguard Worker torch.fliplr(x) 603*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."): 604*da0073e9SAndroid Build Coastguard Worker torch.fliplr(torch.tensor(42, device=device, dtype=dtype)) 605*da0073e9SAndroid Build Coastguard Worker 606*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int64, torch.double, torch.cdouble) 607*da0073e9SAndroid Build Coastguard Worker def test_flipud(self, device, dtype): 608*da0073e9SAndroid Build Coastguard Worker self._test_fliplr_flipud(torch.flipud, np.flipud, 1, 4, device, dtype) 609*da0073e9SAndroid Build Coastguard Worker 610*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int64, torch.double, torch.cdouble) 611*da0073e9SAndroid Build Coastguard Worker def test_flipud_invalid(self, device, dtype): 612*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Input must be >= 1-d."): 613*da0073e9SAndroid Build Coastguard Worker torch.flipud(torch.tensor(42, device=device, dtype=dtype)) 614*da0073e9SAndroid Build Coastguard Worker 615*da0073e9SAndroid Build Coastguard Worker def test_rot90(self, device): 616*da0073e9SAndroid Build Coastguard Worker data = torch.arange(1, 5, device=device).view(2, 2) 617*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([1, 2, 3, 4]).view(2, 2), data.rot90(0, [0, 1])) 618*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([2, 4, 1, 3]).view(2, 2), data.rot90(1, [0, 1])) 619*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([4, 3, 2, 1]).view(2, 2), data.rot90(2, [0, 1])) 620*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1])) 621*da0073e9SAndroid Build Coastguard Worker 622*da0073e9SAndroid Build Coastguard Worker # test for default args k=1, dims=[0, 1] 623*da0073e9SAndroid Build Coastguard Worker self.assertEqual(data.rot90(), data.rot90(1, [0, 1])) 624*da0073e9SAndroid Build Coastguard Worker 625*da0073e9SAndroid Build Coastguard Worker # test for reversed order of dims 626*da0073e9SAndroid Build Coastguard Worker self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0])) 627*da0073e9SAndroid Build Coastguard Worker 628*da0073e9SAndroid Build Coastguard Worker # test for modulo of k 629*da0073e9SAndroid Build Coastguard Worker self.assertEqual(data.rot90(5, [0, 1]), data.rot90(1, [0, 1])) 630*da0073e9SAndroid Build Coastguard Worker self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1])) 631*da0073e9SAndroid Build Coastguard Worker self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1])) 632*da0073e9SAndroid Build Coastguard Worker 633*da0073e9SAndroid Build Coastguard Worker # test for dims out-of-range error 634*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, -3])) 635*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 2])) 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker # test tensor with more than 2D 638*da0073e9SAndroid Build Coastguard Worker data = torch.arange(1, 9, device=device).view(2, 2, 2) 639*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 640*da0073e9SAndroid Build Coastguard Worker torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2]) 641*da0073e9SAndroid Build Coastguard Worker ) 642*da0073e9SAndroid Build Coastguard Worker self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2])) 643*da0073e9SAndroid Build Coastguard Worker 644*da0073e9SAndroid Build Coastguard Worker # test for errors 645*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3])) 646*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1])) 647*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2])) 648*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: data.rot90(1, [0])) 649*da0073e9SAndroid Build Coastguard Worker 650*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with an unknown error") 651*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.cfloat, torch.cdouble) 652*da0073e9SAndroid Build Coastguard Worker def test_complex_rot90(self, device, dtype): 653*da0073e9SAndroid Build Coastguard Worker shape = self._rand_shape(random.randint(2, 4), 5, 10) 654*da0073e9SAndroid Build Coastguard Worker for rot_times in range(4): 655*da0073e9SAndroid Build Coastguard Worker data = torch.randn(*shape, device=device, dtype=dtype) 656*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(torch.rot90, k=rot_times, dims=[0, 1]) 657*da0073e9SAndroid Build Coastguard Worker np_fn = partial(np.rot90, k=rot_times, axes=[0, 1]) 658*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, data) 659*da0073e9SAndroid Build Coastguard Worker 660*da0073e9SAndroid Build Coastguard Worker # TODO: update once warning flag is available to always trigger ONCE warnings 661*da0073e9SAndroid Build Coastguard Worker # Ensures nonzero does not throw a warning, even when the as_tuple argument 662*da0073e9SAndroid Build Coastguard Worker # is not provided 663*da0073e9SAndroid Build Coastguard Worker def test_nonzero_no_warning(self, device): 664*da0073e9SAndroid Build Coastguard Worker t = torch.randn((2, 2), device=device) 665*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 666*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("always") 667*da0073e9SAndroid Build Coastguard Worker torch.nonzero(t) 668*da0073e9SAndroid Build Coastguard Worker t.nonzero() 669*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 0) 670*da0073e9SAndroid Build Coastguard Worker 671*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and(torch.half, torch.bool, torch.bfloat16)) 672*da0073e9SAndroid Build Coastguard Worker def test_nonzero(self, device, dtype): 673*da0073e9SAndroid Build Coastguard Worker shapes = [ 674*da0073e9SAndroid Build Coastguard Worker torch.Size((12,)), 675*da0073e9SAndroid Build Coastguard Worker torch.Size((12, 1)), 676*da0073e9SAndroid Build Coastguard Worker torch.Size((1, 12)), 677*da0073e9SAndroid Build Coastguard Worker torch.Size((6, 2)), 678*da0073e9SAndroid Build Coastguard Worker torch.Size((3, 2, 2)), 679*da0073e9SAndroid Build Coastguard Worker torch.Size((5, 5, 5)), 680*da0073e9SAndroid Build Coastguard Worker ] 681*da0073e9SAndroid Build Coastguard Worker 682*da0073e9SAndroid Build Coastguard Worker def gen_nontrivial_input(shape, dtype, device): 683*da0073e9SAndroid Build Coastguard Worker if dtype != torch.bfloat16: 684*da0073e9SAndroid Build Coastguard Worker return torch.randint(2, shape, device=device, dtype=dtype) 685*da0073e9SAndroid Build Coastguard Worker else: 686*da0073e9SAndroid Build Coastguard Worker # windows does not work for bfloat16 randing 687*da0073e9SAndroid Build Coastguard Worker return torch.randint(2, shape, device=device, dtype=torch.float).to( 688*da0073e9SAndroid Build Coastguard Worker dtype 689*da0073e9SAndroid Build Coastguard Worker ) 690*da0073e9SAndroid Build Coastguard Worker 691*da0073e9SAndroid Build Coastguard Worker for shape in shapes: 692*da0073e9SAndroid Build Coastguard Worker tensor = gen_nontrivial_input(shape, dtype, device) 693*da0073e9SAndroid Build Coastguard Worker dst1 = torch.nonzero(tensor, as_tuple=False) 694*da0073e9SAndroid Build Coastguard Worker dst2 = tensor.nonzero(as_tuple=False) 695*da0073e9SAndroid Build Coastguard Worker dst3 = torch.empty([], dtype=torch.long, device=device) 696*da0073e9SAndroid Build Coastguard Worker torch.nonzero(tensor, out=dst3) 697*da0073e9SAndroid Build Coastguard Worker if self.device_type != "xla": 698*da0073e9SAndroid Build Coastguard Worker # xla does not raise runtime error 699*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 700*da0073e9SAndroid Build Coastguard Worker RuntimeError, 701*da0073e9SAndroid Build Coastguard Worker "scalar type Long", 702*da0073e9SAndroid Build Coastguard Worker lambda: torch.nonzero( 703*da0073e9SAndroid Build Coastguard Worker tensor, out=torch.empty([], dtype=torch.float, device=device) 704*da0073e9SAndroid Build Coastguard Worker ), 705*da0073e9SAndroid Build Coastguard Worker ) 706*da0073e9SAndroid Build Coastguard Worker if ( 707*da0073e9SAndroid Build Coastguard Worker self.device_type == "cuda" 708*da0073e9SAndroid Build Coastguard Worker or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE 709*da0073e9SAndroid Build Coastguard Worker ): 710*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 711*da0073e9SAndroid Build Coastguard Worker RuntimeError, 712*da0073e9SAndroid Build Coastguard Worker "on the same device", 713*da0073e9SAndroid Build Coastguard Worker lambda: torch.nonzero( 714*da0073e9SAndroid Build Coastguard Worker tensor, out=torch.empty([], dtype=torch.long) 715*da0073e9SAndroid Build Coastguard Worker ), 716*da0073e9SAndroid Build Coastguard Worker ) 717*da0073e9SAndroid Build Coastguard Worker np_array = ( 718*da0073e9SAndroid Build Coastguard Worker tensor.cpu().numpy() 719*da0073e9SAndroid Build Coastguard Worker if dtype != torch.bfloat16 720*da0073e9SAndroid Build Coastguard Worker else tensor.float().cpu().numpy() 721*da0073e9SAndroid Build Coastguard Worker ) 722*da0073e9SAndroid Build Coastguard Worker np_result = torch.from_numpy(np.stack(np_array.nonzero())).t() 723*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0) 724*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0) 725*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0) 726*da0073e9SAndroid Build Coastguard Worker tup1 = torch.nonzero(tensor, as_tuple=True) 727*da0073e9SAndroid Build Coastguard Worker tup2 = tensor.nonzero(as_tuple=True) 728*da0073e9SAndroid Build Coastguard Worker tup1 = torch.stack(tup1).t().cpu() 729*da0073e9SAndroid Build Coastguard Worker tup2 = torch.stack(tup2).t().cpu() 730*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tup1, np_result, atol=0, rtol=0) 731*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tup2, np_result, atol=0, rtol=0) 732*da0073e9SAndroid Build Coastguard Worker 733*da0073e9SAndroid Build Coastguard Worker def test_nonzero_astuple_out(self, device): 734*da0073e9SAndroid Build Coastguard Worker t = torch.randn((3, 3, 3), device=device) 735*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(t, dtype=torch.long) 736*da0073e9SAndroid Build Coastguard Worker 737*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 738*da0073e9SAndroid Build Coastguard Worker torch.nonzero(t, as_tuple=True, out=out) 739*da0073e9SAndroid Build Coastguard Worker 740*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 741*da0073e9SAndroid Build Coastguard Worker torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out) 742*da0073e9SAndroid Build Coastguard Worker ) 743*da0073e9SAndroid Build Coastguard Worker 744*da0073e9SAndroid Build Coastguard Worker # Verifies that JIT script cannot handle the as_tuple kwarg 745*da0073e9SAndroid Build Coastguard Worker # See Issue https://github.com/pytorch/pytorch/issues/45499. 746*da0073e9SAndroid Build Coastguard Worker def _foo(t): 747*da0073e9SAndroid Build Coastguard Worker tuple_result = torch.nonzero(t, as_tuple=True) 748*da0073e9SAndroid Build Coastguard Worker nontuple_result = torch.nonzero(t, as_tuple=False) 749*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(nontuple_result) 750*da0073e9SAndroid Build Coastguard Worker torch.nonzero(t, as_tuple=False, out=out) 751*da0073e9SAndroid Build Coastguard Worker return tuple_result, nontuple_result, out 752*da0073e9SAndroid Build Coastguard Worker 753*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 754*da0073e9SAndroid Build Coastguard Worker scripted_foo = torch.jit.script(_foo) 755*da0073e9SAndroid Build Coastguard Worker 756*da0073e9SAndroid Build Coastguard Worker # Verifies that JIT tracing works fine 757*da0073e9SAndroid Build Coastguard Worker traced_foo = torch.jit.trace(_foo, t) 758*da0073e9SAndroid Build Coastguard Worker traced_tuple, traced_nontuple, traced_out = traced_foo(t) 759*da0073e9SAndroid Build Coastguard Worker expected_tuple = torch.nonzero(t, as_tuple=True) 760*da0073e9SAndroid Build Coastguard Worker expected_nontuple = torch.nonzero(t) 761*da0073e9SAndroid Build Coastguard Worker 762*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced_tuple, expected_tuple) 763*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced_nontuple, expected_nontuple) 764*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced_out, expected_nontuple) 765*da0073e9SAndroid Build Coastguard Worker 766*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 767*da0073e9SAndroid Build Coastguard Worker def test_nonzero_discontiguous(self, device): 768*da0073e9SAndroid Build Coastguard Worker shape = (4, 4) 769*da0073e9SAndroid Build Coastguard Worker tensor = torch.randint(2, shape, device=device) 770*da0073e9SAndroid Build Coastguard Worker tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_( 771*da0073e9SAndroid Build Coastguard Worker tensor 772*da0073e9SAndroid Build Coastguard Worker ) 773*da0073e9SAndroid Build Coastguard Worker dst1 = tensor.nonzero(as_tuple=False) 774*da0073e9SAndroid Build Coastguard Worker dst2 = tensor_nc.nonzero(as_tuple=False) 775*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst1, dst2, atol=0, rtol=0) 776*da0073e9SAndroid Build Coastguard Worker dst3 = torch.empty_like(dst1) 777*da0073e9SAndroid Build Coastguard Worker data_ptr = dst3.data_ptr() 778*da0073e9SAndroid Build Coastguard Worker # expect dst3 storage to be reused 779*da0073e9SAndroid Build Coastguard Worker torch.nonzero(tensor, out=dst3) 780*da0073e9SAndroid Build Coastguard Worker self.assertEqual(data_ptr, dst3.data_ptr()) 781*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst1, dst3, atol=0, rtol=0) 782*da0073e9SAndroid Build Coastguard Worker # discontiguous out 783*da0073e9SAndroid Build Coastguard Worker dst4 = torch.empty( 784*da0073e9SAndroid Build Coastguard Worker dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device 785*da0073e9SAndroid Build Coastguard Worker )[:, ::2] 786*da0073e9SAndroid Build Coastguard Worker data_ptr = dst4.data_ptr() 787*da0073e9SAndroid Build Coastguard Worker strides = dst4.stride() 788*da0073e9SAndroid Build Coastguard Worker torch.nonzero(tensor, out=dst4) 789*da0073e9SAndroid Build Coastguard Worker self.assertEqual(data_ptr, dst4.data_ptr()) 790*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst1, dst4, atol=0, rtol=0) 791*da0073e9SAndroid Build Coastguard Worker self.assertEqual(strides, dst4.stride()) 792*da0073e9SAndroid Build Coastguard Worker 793*da0073e9SAndroid Build Coastguard Worker def test_nonzero_non_diff(self, device): 794*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, requires_grad=True) 795*da0073e9SAndroid Build Coastguard Worker nz = x.nonzero() 796*da0073e9SAndroid Build Coastguard Worker self.assertFalse(nz.requires_grad) 797*da0073e9SAndroid Build Coastguard Worker 798*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int64, torch.float, torch.complex128) 799*da0073e9SAndroid Build Coastguard Worker def test_sparse_dense_dim(self, device, dtype): 800*da0073e9SAndroid Build Coastguard Worker for shape in [(), (2,), (2, 3)]: 801*da0073e9SAndroid Build Coastguard Worker if dtype.is_complex or dtype.is_floating_point: 802*da0073e9SAndroid Build Coastguard Worker x = torch.rand(shape, device=device, dtype=dtype) 803*da0073e9SAndroid Build Coastguard Worker else: 804*da0073e9SAndroid Build Coastguard Worker x = torch.randint(-9, 9, shape, device=device, dtype=dtype) 805*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.sparse_dim(), 0) 806*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.dense_dim(), len(shape)) 807*da0073e9SAndroid Build Coastguard Worker 808*da0073e9SAndroid Build Coastguard Worker 809*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestShapeOps, globals()) 810*da0073e9SAndroid Build Coastguard Worker 811*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 812*da0073e9SAndroid Build Coastguard Worker run_tests() 813