# Owner(s): ["module: tests"]

import random
import unittest
import warnings
from functools import partial
from itertools import chain, combinations, permutations, product

import numpy as np

import torch
from torch import nan
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
)
from torch.testing._internal.common_dtype import (
    all_types,
    all_types_and,
    all_types_and_complex_and,
)
from torch.testing._internal.common_utils import (
    IS_JETSON,
    run_tests,
    skipIfTorchDynamo,
    TEST_PRIVATEUSE1_DEVICE_TYPE,
    TestCase,
    torch_to_numpy_dtype_dict,
)


# TODO: replace with make_tensor
def _generate_input(shape, dtype, device, with_extremal):
    if shape == ():
        x = torch.tensor((), dtype=dtype, device=device)
    else:
        if dtype.is_floating_point or dtype.is_complex:
            # work around torch.randn not being implemented for bfloat16
            if dtype == torch.bfloat16:
                x = torch.randn(*shape, device=device) * random.randint(30, 100)
                x = x.to(torch.bfloat16)
            else:
                x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(
                    30, 100
                )
            x[torch.randn(*shape) > 0.5] = 0
            if with_extremal and dtype.is_floating_point:
                # Use extremal values
                x[torch.randn(*shape) > 0.5] = float("nan")
                x[torch.randn(*shape) > 0.5] = float("inf")
                x[torch.randn(*shape) > 0.5] = float("-inf")
            elif with_extremal and dtype.is_complex:
                x[torch.randn(*shape) > 0.5] = complex("nan")
                x[torch.randn(*shape) > 0.5] = complex("inf")
                x[torch.randn(*shape) > 0.5] = complex("-inf")
        elif dtype == torch.bool:
            x = torch.zeros(shape, dtype=dtype, device=device)
            x[torch.randn(*shape) > 0.5] = True
        else:
            x = torch.randint(15, 100, shape, dtype=dtype, device=device)

    return x


class TestShapeOps(TestCase):
    # TODO: update to work on CUDA, too
    @onlyCPU
    def test_unbind(self, device):
        x = torch.rand(2, 3, 4, 5)
        for dim in range(4):
            res = torch.unbind(x, dim)
            res2 = x.unbind(dim)
            self.assertEqual(x.size(dim), len(res))
            self.assertEqual(x.size(dim), len(res2))
            for i in range(x.size(dim)):
                self.assertEqual(x.select(dim, i), res[i])
                self.assertEqual(x.select(dim, i), res2[i])
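
    # For reference (illustrative comment only; not exercised by the tests):
    # torch.unbind removes the given dimension and returns a tuple of views, e.g.
    #   torch.unbind(torch.arange(6).view(2, 3), dim=0)
    #   -> (tensor([0, 1, 2]), tensor([3, 4, 5]))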

    # TODO: update to work on CUDA, too?
    @skipIfTorchDynamo("TorchDynamo fails with an unknown error")
    @onlyCPU
    def test_tolist(self, device):
        list0D = []
        tensor0D = torch.tensor(list0D)
        self.assertEqual(tensor0D.tolist(), list0D)

        table1D = [1.0, 2.0, 3.0]
        tensor1D = torch.tensor(table1D)
        storage = torch.Storage(table1D)
        self.assertEqual(tensor1D.tolist(), table1D)
        self.assertEqual(storage.tolist(), table1D)
        self.assertEqual(tensor1D.tolist(), table1D)
        self.assertEqual(storage.tolist(), table1D)

        table2D = [[1, 2], [3, 4]]
        tensor2D = torch.tensor(table2D)
        self.assertEqual(tensor2D.tolist(), table2D)

        tensor3D = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        tensorNonContig = tensor3D.select(1, 1)
        self.assertFalse(tensorNonContig.is_contiguous())
        self.assertEqual(tensorNonContig.tolist(), [[3, 4], [7, 8]])

    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_movedim_invalid(self, device, dtype):
        shape = self._rand_shape(4, min_size=5, max_size=10)
        x = _generate_input(shape, dtype, device, False)

        for fn in [torch.movedim, torch.moveaxis]:
            # Invalid `source` and `destination` dimension
            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
                fn(x, 5, 0)

            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
                fn(x, 0, 5)

            # Mismatch in size of `source` and `destination`
            with self.assertRaisesRegex(
                RuntimeError, "movedim: Invalid source or destination dims:"
            ):
                fn(x, (1, 0), (0,))

            with self.assertRaisesRegex(
                RuntimeError, "movedim: repeated dim in `source`"
            ):
                fn(x, (0, 0), (0, 1))

            with self.assertRaisesRegex(
                RuntimeError, "movedim: repeated dim in `source`"
            ):
                fn(x, (0, 1, 0), (0, 1, 2))

            with self.assertRaisesRegex(
                RuntimeError, "movedim: repeated dim in `destination`"
            ):
                fn(x, (0, 1), (1, 1))

            with self.assertRaisesRegex(
                RuntimeError, "movedim: repeated dim in `destination`"
            ):
                fn(x, (0, 1, 2), (1, 0, 1))
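
    # For reference (illustrative comment only): torch.movedim mirrors
    # np.moveaxis; e.g. for a tensor t of shape (2, 3, 4),
    #   torch.movedim(t, 1, -1).shape == (2, 4, 3)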

    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_movedim(self, device, dtype):
        for fn in [torch.moveaxis, torch.movedim]:
            for nd in range(5):
                shape = self._rand_shape(nd, min_size=5, max_size=10)
                x = _generate_input(shape, dtype, device, with_extremal=False)
                for random_negative in [True, False]:
                    for src_dim, dst_dim in permutations(range(nd), r=2):
                        random_prob = random.random()

                        if random_negative and random_prob > 0.66:
                            src_dim = src_dim - nd
                        elif random_negative and random_prob > 0.33:
                            dst_dim = dst_dim - nd
                        elif random_negative:
                            src_dim = src_dim - nd
                            dst_dim = dst_dim - nd

                        # Integer `source` and `destination`
                        torch_fn = partial(fn, source=src_dim, destination=dst_dim)
                        np_fn = partial(
                            np.moveaxis, source=src_dim, destination=dst_dim
                        )
                        self.compare_with_numpy(
                            torch_fn, np_fn, x, device=None, dtype=None
                        )

                    if nd == 0:
                        continue

                    # Replace the dim at position `idx` with its negative
                    # representation (dim - nd).
                    def make_index_negative(sequence, idx):
                        sequence = list(sequence)
                        sequence[idx] = sequence[idx] - nd
                        return tuple(sequence)

                    for src_sequence in permutations(
                        range(nd), r=random.randint(1, nd)
                    ):
                        # Sequence `source` and `destination`
                        dst_sequence = tuple(
                            random.sample(range(nd), len(src_sequence))
                        )

                        # Randomly change a dim to its negative representation.
                        random_prob = random.random()
                        if random_negative and random_prob > 0.66:
                            random_idx = random.randint(0, len(src_sequence) - 1)
                            src_sequence = make_index_negative(src_sequence, random_idx)
                        elif random_negative and random_prob > 0.33:
                            random_idx = random.randint(0, len(src_sequence) - 1)
                            dst_sequence = make_index_negative(dst_sequence, random_idx)
                        elif random_negative:
                            random_idx = random.randint(0, len(src_sequence) - 1)
                            dst_sequence = make_index_negative(dst_sequence, random_idx)
                            random_idx = random.randint(0, len(src_sequence) - 1)
                            src_sequence = make_index_negative(src_sequence, random_idx)

                        torch_fn = partial(
                            fn, source=src_sequence, destination=dst_sequence
                        )
                        np_fn = partial(
                            np.moveaxis, source=src_sequence, destination=dst_sequence
                        )
                        self.compare_with_numpy(
                            torch_fn, np_fn, x, device=None, dtype=None
                        )

            # Move dim to same position
            x = torch.randn(2, 3, 5, 7, 11)
            torch_fn = partial(fn, source=(0, 1), destination=(0, 1))
            np_fn = partial(np.moveaxis, source=(0, 1), destination=(0, 1))
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

            torch_fn = partial(fn, source=1, destination=1)
            np_fn = partial(np.moveaxis, source=1, destination=1)
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

            # Empty Sequence
            torch_fn = partial(fn, source=(), destination=())
            np_fn = partial(np.moveaxis, source=(), destination=())
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

    @dtypes(torch.float, torch.bool)
    def test_diag(self, device, dtype):
        if dtype is torch.bool:
            x = torch.rand(100, 100, device=device) >= 0.5
        else:
            x = torch.rand(100, 100, dtype=dtype, device=device)

        res1 = torch.diag(x)
        res2 = torch.tensor((), dtype=dtype, device=device)
        torch.diag(x, out=res2)
        self.assertEqual(res1, res2)

    def test_diagonal(self, device):
        x = torch.randn((100, 100), device=device)
        result = torch.diagonal(x)
        expected = torch.diag(x)
        self.assertEqual(result, expected)

        x = torch.randn((100, 100), device=device)
        result = torch.diagonal(x, 17)
        expected = torch.diag(x, 17)
        self.assertEqual(result, expected)

    @onlyCPU
    @dtypes(torch.float)
    def test_diagonal_multidim(self, device, dtype):
        x = torch.randn(10, 11, 12, 13, dtype=dtype, device=device)
        xn = x.numpy()
        for args in [(2, 2, 3), (2,), (-2, 1, 2), (0, -2, -1)]:
            result = torch.diagonal(x, *args)
            expected = xn.diagonal(*args)
            self.assertEqual(expected.shape, result.shape)
            self.assertEqual(expected, result)
        # test non-contiguous
        xp = x.permute(1, 2, 3, 0)
        result = torch.diagonal(xp, 0, -2, -1)
        expected = xp.numpy().diagonal(0, -2, -1)
        self.assertEqual(expected.shape, result.shape)
        self.assertEqual(expected, result)
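
    # For reference (illustrative comment only): torch.diagonal extracts a
    # diagonal as a view, with an optional offset, e.g.
    #   torch.diagonal(torch.arange(9).view(3, 3), 1) -> tensor([1, 5])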

    @onlyNativeDeviceTypes
    @dtypes(*all_types())
    @dtypesIfCUDA(*all_types_and(torch.half))
    def test_trace(self, device, dtype):
        def test(shape):
            tensor = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
            expected_dtype = tensor.sum().dtype
            expected_dtype = torch_to_numpy_dtype_dict[expected_dtype]

            result = np.trace(tensor.cpu().numpy(), dtype=expected_dtype)
            expected = torch.tensor(result, device=device)
            self.assertEqual(tensor.trace(), expected)

        shapes = (
            [10, 1],
            [1, 10],
            [100, 100],
            [20, 100],
            [100, 20],
        )
        for shape in shapes:
            test(shape)

    def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nans):
        """
        Creates a random tensor for a given device and dtype, and computes the
        expected clamped values given the min_vals and/or max_vals.
        If with_nans is True, then some values are randomly set to nan.
        """
        X = torch.rand(100, device=device).mul(50).add(-25)  # uniform in [-25, 25]
        X = X.to(dtype)
        if with_nans:
            mask = torch.randint(0, 2, X.shape, dtype=torch.bool, device=device)
            X[mask] = nan

        if isinstance(min_vals, torch.Tensor):
            min_vals = min_vals.cpu().numpy()

        if isinstance(max_vals, torch.Tensor):
            max_vals = max_vals.cpu().numpy()

        # Use NumPy implementation as reference
        X_clamped = torch.tensor(
            np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device
        )
        return X, X_clamped

    # Tests clamp and its alias, clip
    @dtypes(torch.int64, torch.float32)
    def test_clamp(self, device, dtype):
        op_list = (
            torch.clamp,
            torch.Tensor.clamp,
            torch.Tensor.clamp_,
            torch.clip,
            torch.Tensor.clip,
            torch.Tensor.clip_,
        )

        # min/max argument product
        args = product((-10, None), (10, None))

        for op in op_list:
            for min_val, max_val in args:
                if min_val is None and max_val is None:
                    continue

                X, Y_expected = self.generate_clamp_baseline(
                    device, dtype, min_vals=min_val, max_vals=max_val, with_nans=False
                )

                # Test op
                X1 = X.clone()  # So that the in-place ops do not change X
                Y_actual = op(X1, min_val, max_val)
                self.assertEqual(Y_expected, Y_actual)

                # Test op-out behavior (out does not exist for method versions)
                if op in (torch.clamp, torch.clip):
                    Y_out = torch.empty_like(X)
                    op(X, min=min_val, max=max_val, out=Y_out)
                    self.assertEqual(Y_expected, Y_out)

    def test_clamp_propagates_nans(self, device):
        op_list = (
            torch.clamp,
            torch.Tensor.clamp,
            torch.Tensor.clamp_,
            torch.clip,
            torch.Tensor.clip,
            torch.Tensor.clip_,
        )

        # min/max argument product
        args = product((-10, None), (10, None))

        for op in op_list:
            for min_val, max_val in args:
                if min_val is None and max_val is None:
                    continue

                X, Y_expected = self.generate_clamp_baseline(
                    device,
                    torch.float,
                    min_vals=min_val,
                    max_vals=max_val,
                    with_nans=True,
                )
                Y_expected = torch.isnan(Y_expected)

                # Test op
                X1 = X.clone()  # So that the in-place ops do not change X
                Y_actual = op(X1, min_val, max_val)
                self.assertEqual(Y_expected, torch.isnan(Y_actual))

                # Test op-out behavior (out does not exist for method versions)
                if op in (torch.clamp, torch.clip):
                    Y_out = torch.empty_like(X)
                    op(X, min_val, max_val, out=Y_out)
                    self.assertEqual(Y_expected, torch.isnan(Y_out))

    def test_clamp_raises_arg_errors(self, device):
        X = torch.randn(100, dtype=torch.float, device=device)
        error_msg = "At least one of 'min' or 'max' must not be None"
        with self.assertRaisesRegex(RuntimeError, error_msg):
            X.clamp()
        with self.assertRaisesRegex(RuntimeError, error_msg):
            X.clamp_()
        with self.assertRaisesRegex(RuntimeError, error_msg):
            torch.clamp(X)
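
    # For reference (illustrative comment only): clamp/clip limit values to the
    # [min, max] range, e.g.
    #   torch.clamp(torch.tensor([-5.0, 0.0, 5.0]), min=-1, max=1)
    #   -> tensor([-1., 0., 1.])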

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_flip(self, device, dtype):
        make_from_data = partial(torch.tensor, device=device, dtype=dtype)
        make_from_size = partial(make_tensor, device=device, dtype=dtype)

        def test_flip_impl(input_t, dims, output_t):
            def all_t():
                yield input_t, output_t
                if dtype is torch.float:
                    # We generate quantized versions as well
                    for qdtype in (torch.quint8, torch.qint8, torch.qint32):
                        qinput_t = torch.quantize_per_tensor(input_t, 0.1, 5, qdtype)
                        qoutput_t = torch.quantize_per_tensor(output_t, 0.1, 5, qdtype)
                        yield qinput_t, qoutput_t

            for in_t, out_t in all_t():
                self.assertEqual(in_t.flip(dims), out_t)
                n = in_t.ndim
                if not isinstance(dims, tuple):
                    # Wrap dim
                    self.assertEqual(in_t.flip(-n + dims), out_t)
                else:
                    # Permute dimensions
                    for p_dims in permutations(dims):
                        self.assertEqual(in_t.flip(p_dims), out_t)
                        if len(p_dims) > 0:
                            # Wrap 1st dim
                            self.assertEqual(
                                in_t.flip((-n + p_dims[0],) + p_dims[1:]), out_t
                            )

        def gen_data():
            # Basic tests
            data = make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2)
            nonctg = make_from_size((2, 2, 2), noncontiguous=True).copy_(data)

            dims_result = (
                (0, make_from_data([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2)),
                (1, make_from_data([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2)),
                (2, make_from_data([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2)),
                ((0, 1), make_from_data([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2)),
                ((0, 1, 2), make_from_data([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2)),
            )
            for in_tensor, (dims, out_tensor) in product((data, nonctg), dims_result):
                yield in_tensor, dims, out_tensor

            # Expanded
            in_t = make_from_data([1, 2, 3]).view(3, 1).expand(3, 2)
            dims = 0
            out_t = make_from_data([3, 3, 2, 2, 1, 1]).view(3, 2)
            yield in_t, dims, out_t
            # Noop on expanded dimension
            yield in_t, 1, in_t

            # Transposed
            in_t = (
                make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2).transpose(0, 1)
            )
            dims = (0, 1, 2)
            out_t = make_from_data([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2)
            yield in_t, dims, out_t

            # Rectangular case
            in_t = make_from_data([1, 2, 3, 4, 5, 6]).view(2, 3)
            dims = 0
            out_t = make_from_data([[4, 5, 6], [1, 2, 3]])
            yield in_t, dims, out_t
            dims = 1
            out_t = make_from_data([[3, 2, 1], [6, 5, 4]])
            yield in_t, dims, out_t

            # vectorized NCHW cases (images)
            if device == "cpu" and dtype != torch.bfloat16:
                for mf in [torch.contiguous_format, torch.channels_last]:
                    for c in [2, 3, 8, 16]:
                        in_t = make_from_size((2, c, 32, 32)).contiguous(
                            memory_format=mf
                        )
                        np_in_t = in_t.numpy()

                        np_out_t = np_in_t[:, :, :, ::-1].copy()
                        out_t = torch.from_numpy(np_out_t)
                        yield in_t, 3, out_t

                        np_out_t = np_in_t[:, :, ::-1, :].copy()
                        out_t = torch.from_numpy(np_out_t)
                        yield in_t, 2, out_t

                        # non-contig cases
                        in_tt = in_t[..., ::2, :]
                        np_in_t = in_tt.numpy()
                        np_out_t = np_in_t[:, :, :, ::-1].copy()
                        out_t = torch.from_numpy(np_out_t)
                        yield in_tt, 3, out_t

                        in_tt = in_t[..., ::2]
                        np_in_t = in_tt.numpy()
                        np_out_t = np_in_t[:, :, :, ::-1].copy()
                        out_t = torch.from_numpy(np_out_t)
                        yield in_tt, 3, out_t

            # Noops (edge cases)

            # Size 0
            in_t = make_from_data(())
            yield in_t, 0, in_t
            yield in_t, (), in_t

            # dims = ()
            in_t = make_from_size((3, 2, 1))
            yield in_t, (), in_t

            # Zero elements, non-zero size
            in_t = make_from_size((3, 0, 2))
            for i in range(in_t.ndim):
                yield in_t, i, in_t

            # Size 1
            in_t = make_from_size(())
            yield in_t, 0, in_t
            in_t = make_from_size((1,))
            yield in_t, 0, in_t
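
        # For reference (illustrative comment only): flipping reverses order
        # along the given dims, e.g.
        #   torch.arange(4).view(2, 2).flip(0) -> tensor([[2, 3], [0, 1]])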
        for in_tensor, dims, out_tensor in gen_data():
            test_flip_impl(in_tensor, dims, out_tensor)

        # test for shape
        size = [2, 3, 4]
        data = make_from_size(size)
        possible_dims = range(len(size))
        test_dims = chain(
            combinations(possible_dims, 1), combinations(possible_dims, 2)
        )

        for dims in test_dims:
            self.assertEqual(size, list(data.flip(dims).size()))

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_flip_errors(self, device, dtype):
        make_arg = partial(make_tensor, dtype=dtype, device=device)
        data = make_arg((2, 2, 2))

        # flipping the same dim more than once is not allowed
        self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1))
        # an empty dim list is not allowed
        self.assertRaises(TypeError, lambda: data.flip())

        # dims larger than the max dim are not allowed
        self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3))
        self.assertRaises(IndexError, lambda: data.flip(3))

    def _rand_shape(self, dim, min_size, max_size):
        return tuple(torch.randint(min_size, max_size + 1, (dim,)))

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_flip_numpy(self, device, dtype):
        make_arg = partial(make_tensor, dtype=dtype, device=device)

        for ndim in [3, 4]:
            shape = self._rand_shape(ndim, 5, 10)
            data = make_arg(shape)

            # Axes to sample for the given shape.
            for i in range(1, ndim + 1):
                # Check all combinations of `i` axes.
                for flip_dim in combinations(range(ndim), i):
                    torch_fn = partial(torch.flip, dims=flip_dim)
                    np_fn = partial(np.flip, axis=flip_dim)
                    self.compare_with_numpy(torch_fn, np_fn, data)

    @onlyCUDA  # CPU is too slow
    @largeTensorTest("17GB")  # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB
    @largeTensorTest(
        "81GB", "cpu"
    )  # even for the CUDA test, sufficient system memory is required
    @unittest.skipIf(IS_JETSON, "Too large for Jetson")
    def test_flip_large_tensor(self, device):
        t_in = torch.empty(2**32 + 1, dtype=torch.uint8).random_()
        torch_fn = partial(torch.flip, dims=(0,))
        np_fn = partial(np.flip, axis=0)
        self.compare_with_numpy(torch_fn, np_fn, t_in)
        del t_in

    def _test_fliplr_flipud(self, torch_fn, np_fn, min_dim, max_dim, device, dtype):
        for dim in range(min_dim, max_dim + 1):
            shape = self._rand_shape(dim, 5, 10)
            # Randomly scale the input
            if dtype.is_floating_point or dtype.is_complex:
                data = torch.randn(*shape, device=device, dtype=dtype)
            else:
                data = torch.randint(0, 10, shape, device=device, dtype=dtype)
            self.compare_with_numpy(torch_fn, np_fn, data)

    @dtypes(torch.int64, torch.double, torch.cdouble)
    def test_fliplr(self, device, dtype):
        self._test_fliplr_flipud(torch.fliplr, np.fliplr, 2, 4, device, dtype)

    @dtypes(torch.int64, torch.double, torch.cdouble)
    def test_fliplr_invalid(self, device, dtype):
        x = torch.randn(42).to(dtype)
        with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."):
            torch.fliplr(x)
        with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."):
            torch.fliplr(torch.tensor(42, device=device, dtype=dtype))

    @dtypes(torch.int64, torch.double, torch.cdouble)
    def test_flipud(self, device, dtype):
        self._test_fliplr_flipud(torch.flipud, np.flipud, 1, 4, device, dtype)
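
    # For reference (illustrative comment only): fliplr reverses the column
    # order and flipud the row order, e.g. for t = torch.arange(4).view(2, 2),
    #   torch.fliplr(t) -> tensor([[1, 0], [3, 2]])
    #   torch.flipud(t) -> tensor([[2, 3], [0, 1]])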

    @dtypes(torch.int64, torch.double, torch.cdouble)
    def test_flipud_invalid(self, device, dtype):
        with self.assertRaisesRegex(RuntimeError, "Input must be >= 1-d."):
            torch.flipud(torch.tensor(42, device=device, dtype=dtype))

    def test_rot90(self, device):
        data = torch.arange(1, 5, device=device).view(2, 2)
        self.assertEqual(torch.tensor([1, 2, 3, 4]).view(2, 2), data.rot90(0, [0, 1]))
        self.assertEqual(torch.tensor([2, 4, 1, 3]).view(2, 2), data.rot90(1, [0, 1]))
        self.assertEqual(torch.tensor([4, 3, 2, 1]).view(2, 2), data.rot90(2, [0, 1]))
        self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1]))

        # test for default args k=1, dims=[0, 1]
        self.assertEqual(data.rot90(), data.rot90(1, [0, 1]))

        # test for reversed order of dims
        self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0]))

        # test for modulo of k
        self.assertEqual(data.rot90(5, [0, 1]), data.rot90(1, [0, 1]))
        self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1]))
        self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1]))

        # test for dims out-of-range error
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, -3]))
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 2]))

        # test tensor with more than 2D
        data = torch.arange(1, 9, device=device).view(2, 2, 2)
        self.assertEqual(
            torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])
        )
        self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2]))

        # test for errors
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3]))
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1]))
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2]))
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0]))

    @skipIfTorchDynamo("TorchDynamo fails with an unknown error")
    @dtypes(torch.cfloat, torch.cdouble)
    def test_complex_rot90(self, device, dtype):
        shape = self._rand_shape(random.randint(2, 4), 5, 10)
        for rot_times in range(4):
            data = torch.randn(*shape, device=device, dtype=dtype)
            torch_fn = partial(torch.rot90, k=rot_times, dims=[0, 1])
            np_fn = partial(np.rot90, k=rot_times, axes=[0, 1])
            self.compare_with_numpy(torch_fn, np_fn, data)

    # TODO: update once warning flag is available to always trigger ONCE warnings
    # Ensures nonzero does not throw a warning, even when the as_tuple argument
    # is not provided
    def test_nonzero_no_warning(self, device):
        t = torch.randn((2, 2), device=device)
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            torch.nonzero(t)
            t.nonzero()
            self.assertEqual(len(w), 0)
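
    # For reference (illustrative comment only): nonzero returns the indices of
    # non-zero elements, e.g. for t = torch.tensor([0, 1, 0, 2]),
    #   torch.nonzero(t) -> tensor([[1], [3]])
    #   torch.nonzero(t, as_tuple=True) -> (tensor([1, 3]),)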

    @dtypes(*all_types_and(torch.half, torch.bool, torch.bfloat16))
    def test_nonzero(self, device, dtype):
        shapes = [
            torch.Size((12,)),
            torch.Size((12, 1)),
            torch.Size((1, 12)),
            torch.Size((6, 2)),
            torch.Size((3, 2, 2)),
            torch.Size((5, 5, 5)),
        ]

        def gen_nontrivial_input(shape, dtype, device):
            if dtype != torch.bfloat16:
                return torch.randint(2, shape, device=device, dtype=dtype)
            else:
                # bfloat16 random generation does not work on Windows;
                # generate in float and convert
                return torch.randint(2, shape, device=device, dtype=torch.float).to(
                    dtype
                )

        for shape in shapes:
            tensor = gen_nontrivial_input(shape, dtype, device)
            dst1 = torch.nonzero(tensor, as_tuple=False)
            dst2 = tensor.nonzero(as_tuple=False)
            dst3 = torch.empty([], dtype=torch.long, device=device)
            torch.nonzero(tensor, out=dst3)
            if self.device_type != "xla":
                # xla does not raise runtime error
                self.assertRaisesRegex(
                    RuntimeError,
                    "scalar type Long",
                    lambda: torch.nonzero(
                        tensor, out=torch.empty([], dtype=torch.float, device=device)
                    ),
                )
            if (
                self.device_type == "cuda"
                or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE
            ):
                self.assertRaisesRegex(
                    RuntimeError,
                    "on the same device",
                    lambda: torch.nonzero(
                        tensor, out=torch.empty([], dtype=torch.long)
                    ),
                )
            np_array = (
                tensor.cpu().numpy()
                if dtype != torch.bfloat16
                else tensor.float().cpu().numpy()
            )
            np_result = torch.from_numpy(np.stack(np_array.nonzero())).t()
            self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0)
            self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0)
            self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0)
            tup1 = torch.nonzero(tensor, as_tuple=True)
            tup2 = tensor.nonzero(as_tuple=True)
            tup1 = torch.stack(tup1).t().cpu()
            tup2 = torch.stack(tup2).t().cpu()
            self.assertEqual(tup1, np_result, atol=0, rtol=0)
            self.assertEqual(tup2, np_result, atol=0, rtol=0)

    def test_nonzero_astuple_out(self, device):
        t = torch.randn((3, 3, 3), device=device)
        out = torch.empty_like(t, dtype=torch.long)

        with self.assertRaises(RuntimeError):
            torch.nonzero(t, as_tuple=True, out=out)

        self.assertEqual(
            torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out)
        )

        # Verifies that JIT script cannot handle the as_tuple kwarg
        # See Issue https://github.com/pytorch/pytorch/issues/45499.
        def _foo(t):
            tuple_result = torch.nonzero(t, as_tuple=True)
            nontuple_result = torch.nonzero(t, as_tuple=False)
            out = torch.empty_like(nontuple_result)
            torch.nonzero(t, as_tuple=False, out=out)
            return tuple_result, nontuple_result, out

        with self.assertRaises(RuntimeError):
            scripted_foo = torch.jit.script(_foo)

        # Verifies that JIT tracing works fine
        traced_foo = torch.jit.trace(_foo, t)
        traced_tuple, traced_nontuple, traced_out = traced_foo(t)
        expected_tuple = torch.nonzero(t, as_tuple=True)
        expected_nontuple = torch.nonzero(t)

        self.assertEqual(traced_tuple, expected_tuple)
        self.assertEqual(traced_nontuple, expected_nontuple)
        self.assertEqual(traced_out, expected_nontuple)

    @onlyNativeDeviceTypes
    def test_nonzero_discontiguous(self, device):
        shape = (4, 4)
        tensor = torch.randint(2, shape, device=device)
        tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(
            tensor
        )
        dst1 = tensor.nonzero(as_tuple=False)
        dst2 = tensor_nc.nonzero(as_tuple=False)
        self.assertEqual(dst1, dst2, atol=0, rtol=0)
        dst3 = torch.empty_like(dst1)
        data_ptr = dst3.data_ptr()
        # expect dst3 storage to be reused
        torch.nonzero(tensor, out=dst3)
        self.assertEqual(data_ptr, dst3.data_ptr())
        self.assertEqual(dst1, dst3, atol=0, rtol=0)
        # discontiguous out
        dst4 = torch.empty(
            dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device
        )[:, ::2]
        data_ptr = dst4.data_ptr()
        strides = dst4.stride()
        torch.nonzero(tensor, out=dst4)
        self.assertEqual(data_ptr, dst4.data_ptr())
        self.assertEqual(dst1, dst4, atol=0, rtol=0)
        self.assertEqual(strides, dst4.stride())

    def test_nonzero_non_diff(self, device):
        x = torch.randn(10, requires_grad=True)
        nz = x.nonzero()
        self.assertFalse(nz.requires_grad)

    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_sparse_dense_dim(self, device, dtype):
        for shape in [(), (2,), (2, 3)]:
            if dtype.is_complex or dtype.is_floating_point:
                x = torch.rand(shape, device=device, dtype=dtype)
            else:
                x = torch.randint(-9, 9, shape, device=device, dtype=dtype)
            self.assertEqual(x.sparse_dim(), 0)
            self.assertEqual(x.dense_dim(), len(shape))


instantiate_device_type_tests(TestShapeOps, globals())

if __name__ == "__main__":
    run_tests()