1# Owner(s): ["module: tests"] 2import random 3import unittest 4from functools import partial 5from itertools import combinations, permutations, product 6 7import numpy as np 8 9import torch 10from torch.testing import make_tensor 11from torch.testing._internal.common_device_type import ( 12 dtypes, 13 instantiate_device_type_tests, 14 onlyCPU, 15 onlyNativeDeviceTypes, 16 skipLazy, 17 skipMeta, 18 skipXLA, 19) 20from torch.testing._internal.common_dtype import ( 21 all_types_and, 22 all_types_and_complex_and, 23 complex_types, 24 floating_and_complex_types_and, 25) 26from torch.testing._internal.common_utils import ( 27 gradcheck, 28 gradgradcheck, 29 IS_FBCODE, 30 numpy_to_torch_dtype_dict, 31 run_tests, 32 skipIfTorchDynamo, 33 suppress_warnings, 34 TestCase, 35) 36 37 38# TODO: replace this with make_tensor() in common_utils.py 39def _generate_input(shape, dtype, device, with_extremal): 40 if shape == (): 41 x = torch.tensor((), dtype=dtype, device=device) 42 else: 43 if dtype.is_floating_point or dtype.is_complex: 44 # work around torch.randn not being implemented for bfloat16 45 if dtype == torch.bfloat16: 46 x = torch.randn(*shape, device=device) * random.randint(30, 100) 47 x = x.to(torch.bfloat16) 48 else: 49 x = torch.randn(*shape, dtype=dtype, device=device) * random.randint( 50 30, 100 51 ) 52 x[torch.randn(*shape) > 0.5] = 0 53 if with_extremal and dtype.is_floating_point: 54 # Use extremal values 55 x[torch.randn(*shape) > 0.5] = float("nan") 56 x[torch.randn(*shape) > 0.5] = float("inf") 57 x[torch.randn(*shape) > 0.5] = float("-inf") 58 elif with_extremal and dtype.is_complex: 59 x[torch.randn(*shape) > 0.5] = complex("nan") 60 x[torch.randn(*shape) > 0.5] = complex("inf") 61 x[torch.randn(*shape) > 0.5] = complex("-inf") 62 elif dtype == torch.bool: 63 x = torch.zeros(shape, dtype=dtype, device=device) 64 x[torch.randn(*shape) > 0.5] = True 65 else: 66 x = torch.randint(15, 100, shape, dtype=dtype, device=device) 67 68 return x 69 70 71# TODO: replace this with make_tensor() in common_utils.py 72def _rand_shape(dim, min_size, max_size): 73 shape = [] 74 for i in range(dim): 75 shape.append(random.randint(min_size, max_size)) 76 return tuple(shape) 77 78 79# TODO: refactor tests to avoid this function 80# Converts half/bfloat16 dtype to float when device is cpu 81def _convert_t(dtype, device): 82 if device == "cpu" and dtype in {torch.half, torch.bfloat16}: 83 return torch.float 84 return dtype 85 86 87# TODO: replace this with make_tensor() in common_utils.py 88# Returns a tensor of the requested shape, dtype, and device 89# Requesting a half CPU tensor returns a float CPU tensor with 90# values representable by a half. 91# Initialization uses randint for non-float types and randn for float types. 92def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor: 93 # Returns a tensor filled with ones 94 if fill_ones: 95 return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device) 96 97 # Returns a tensor with random integer values 98 if not (dtype.is_floating_point or dtype.is_complex): 99 t = torch.randint(0, 10, shape, device=device) 100 if dtype != torch.uint8: 101 t = t - 5 # generate negative values also 102 return t.to(_convert_t(dtype, device)) 103 104 # Populates the CPU tensor with floats representable as half/bfloat16 105 if dtype == torch.half and device == "cpu": 106 return torch.randn(*shape, dtype=torch.float, device=device).half().float() 107 if dtype == torch.bfloat16 and device == "cpu": 108 return torch.randn(*shape, dtype=torch.float, device=device).bfloat16().float() 109 110 # Default: returns a tensor with random float values 111 return torch.randn(shape, dtype=dtype, device=device).to(dtype=dtype) 112 113 114# Tests ops and indexing to ensure they return views (and new tensors) as 115# appropriate. 116class TestViewOps(TestCase): 117 exact_dtype = True 118 119 def is_view_of(self, base, other): 120 if ( 121 not other._is_view() 122 or other is base 123 or other._base is not base 124 or base.device != other.device 125 ): 126 return False 127 # Note: only validates storage on native device types 128 # because some accelerators, like XLA, do not expose storage 129 if base.device.type == "cpu" or base.device.type == "cuda": 130 if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr(): 131 return False 132 133 return True 134 135 # Returns true if v1 and v2 are views of the same base 136 def is_view_of_same_base(self, v1, v2): 137 if not v1._is_view() or v1 is v2: 138 return False 139 return self.is_view_of(v1._base, v2) 140 141 # Performs transpose if contiguous=True, else returns the input tensor as is 142 def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1): 143 if contiguous: 144 return x 145 else: 146 return x.transpose(dim0, dim1) 147 148 @dtypes(*all_types_and(torch.half, torch.bfloat16)) 149 def test_conj_self(self, device, dtype): 150 t = torch.ones(5, 5, device=device) 151 s = t.conj() 152 self.assertTrue(s is t) 153 154 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 155 @onlyNativeDeviceTypes 156 @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) 157 def test_view_dtype_new(self, device, dtype): 158 dtypes = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} 159 del dtypes[torch.bool] 160 161 def generate_inputs(): 162 yield make_tensor((4, 4, 64), dtype=dtype, device=device, low=-5, high=5) 163 yield make_tensor( 164 (4, 4, 64), dtype=dtype, device=device, low=-5, high=5 165 ).permute(1, 0, 2) 166 yield make_tensor( 167 (4, 64, 4), dtype=dtype, device=device, low=-5, high=5 168 ).permute(2, 0, 1) 169 yield make_tensor( 170 (1, 5, 1), dtype=dtype, device=device, low=-5, high=5 171 ).expand(5, 5, 64) 172 yield make_tensor((2, 5, 256), dtype=dtype, device=device, low=-5, high=5)[ 173 1::2, 1:, ::2 174 ] 175 yield make_tensor((0, 5, 64), dtype=dtype, device=device, low=-5, high=5) 176 yield make_tensor((), dtype=dtype, device=device, low=-5, high=5) 177 178 def calc_expected_size_and_stride(a, view_dtype): 179 dtype_size = torch._utils._element_size(a.dtype) 180 view_dtype_size = torch._utils._element_size(view_dtype) 181 182 if dtype_size == view_dtype_size: 183 return a.size(), a.stride() 184 185 elif dtype_size > view_dtype_size: 186 size_ratio = dtype_size // view_dtype_size 187 188 view_size = list(a.size()) 189 view_size[-1] = view_size[-1] * size_ratio 190 191 view_stride = [stride * size_ratio for stride in a.stride()] 192 view_stride[-1] = 1 193 return torch.Size(view_size), tuple(view_stride) 194 195 else: 196 size_ratio = view_dtype_size // dtype_size 197 198 view_size = list(a.size()) 199 view_size[-1] = view_size[-1] // size_ratio 200 201 view_stride = [stride // size_ratio for stride in a.stride()] 202 view_stride[-1] = 1 203 return torch.Size(view_size), tuple(view_stride) 204 205 for a in generate_inputs(): 206 a_np = a.cpu().numpy() 207 a_np_contiguous = a.cpu().contiguous().numpy() 208 209 for view_dtype, np_view_dtype in dtypes.items(): 210 equal_element_size = torch._utils._element_size( 211 dtype 212 ) == torch._utils._element_size(view_dtype) 213 214 if not equal_element_size and a.dim() == 0: 215 with self.assertRaisesRegex( 216 RuntimeError, r"self.dim\(\) cannot be 0" 217 ): 218 a.view(view_dtype) 219 continue 220 221 if not equal_element_size and a.stride(-1) != 1: 222 with self.assertRaisesRegex( 223 RuntimeError, r"self.stride\(-1\) must be 1" 224 ): 225 a.view(view_dtype) 226 continue 227 228 a_view = a.view(view_dtype) 229 self.assertEqual(a_view.dtype, view_dtype) 230 self.assertEqual(a.data_ptr(), a_view.data_ptr()) 231 232 expected_size, expected_stride = calc_expected_size_and_stride( 233 a, view_dtype 234 ) 235 self.assertEqual(a_view.size(), expected_size) 236 self.assertEqual(a_view.stride(), expected_stride) 237 238 self.assertEqual(a_view.view(dtype), a, rtol=0, atol=0) 239 240 # NumPy's dtype view requires contiguous input if target 241 # dtype is a different size 242 if equal_element_size: 243 a_np_view = a_np.view(np_view_dtype) 244 245 else: 246 a_np_view = a_np_contiguous.view(np_view_dtype) 247 248 self.assertEqual(a_view, a_np_view) 249 250 # Test that requires_grad is dropped for floating point casts, 251 # because view(dtype) does not support backward yet 252 # TODO: Remove this when autograd support is added 253 if dtype.is_floating_point or dtype.is_complex: 254 for view_dtype in floating_and_complex_types_and( 255 torch.half, torch.bfloat16 256 ): 257 t = make_tensor( 258 (5, 5, 64), 259 dtype=dtype, 260 device=device, 261 low=-5, 262 high=5, 263 requires_grad=True, 264 ) 265 self.assertFalse(t.view(view_dtype).requires_grad) 266 267 # Test the extra error checks that happen when the view dtype 268 # has a greater element size than the original dtype 269 @onlyNativeDeviceTypes 270 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 271 def test_view_dtype_upsize_errors(self, device, dtype): 272 dtype_size = torch._utils._element_size(dtype) 273 274 for view_dtype in all_types_and_complex_and( 275 torch.half, torch.bfloat16, torch.bool 276 ): 277 view_dtype_size = torch._utils._element_size(view_dtype) 278 if view_dtype_size <= dtype_size: 279 continue 280 281 size_ratio = view_dtype_size // dtype_size 282 a = make_tensor( 283 (4, 4, size_ratio + 1), dtype=dtype, device=device, low=-5, high=5 284 ) 285 with self.assertRaisesRegex( 286 RuntimeError, rf"self.size\(-1\) must be divisible by {size_ratio}" 287 ): 288 a.view(view_dtype) 289 290 with self.assertRaisesRegex( 291 RuntimeError, 292 rf"self.storage_offset\(\) must be divisible by {size_ratio}", 293 ): 294 a[:, :, 1:].view(view_dtype) 295 296 a = make_tensor( 297 (4, 4, size_ratio), dtype=dtype, device=device, low=-5, high=5 298 ) 299 a = a.as_strided((4, 4, size_ratio), (size_ratio, 1, 1)) 300 with self.assertRaisesRegex( 301 RuntimeError, rf"self.stride\(1\) must be divisible by {size_ratio}" 302 ): 303 a.view(view_dtype) 304 305 @onlyNativeDeviceTypes 306 def test_view_as_complex(self, device): 307 def fn(contiguous_input=True, dim0=0, dim1=1): 308 t = torch.randn(3, 2, 2, device=device) 309 c_t = t[:, :, 0] + 1j * t[:, :, 1] 310 311 input = self._do_transpose(t, contiguous_input, dim0, dim1) 312 313 if input.size()[-1] != 2: 314 self.assertRaisesRegex( 315 RuntimeError, 316 "Tensor must have a last dimension of size 2", 317 lambda: torch.view_as_complex(input), 318 ) 319 return 320 321 if input.stride()[-1] != 1: 322 self.assertRaisesRegex( 323 RuntimeError, 324 "Tensor must have a last dimension with stride 1", 325 lambda: torch.view_as_complex(input), 326 ) 327 return 328 329 res = torch.view_as_complex(input) 330 self.assertEqual(res, self._do_transpose(c_t, contiguous_input, dim0, dim1)) 331 self.assertTrue(self.is_view_of(t, res)) 332 333 fn() 334 fn(contiguous_input=False) 335 # RuntimeError since in this case the last dim of input would not be of size 2 336 fn(contiguous_input=False, dim0=0, dim1=2) 337 # RuntimeError since in this case the last dim of input would not have stride 1 338 fn(contiguous_input=False, dim0=1, dim1=2) 339 340 # RuntimeError since in this case the stride of non-last dim of input would not be of size 2 341 x = torch.randn(3, 3, device=device) 342 t = torch.as_strided(x, (2, 2), (1, 1)) 343 self.assertRaisesRegex( 344 RuntimeError, 345 "Tensor must have a stride divisible by 2 for all but last dimension", 346 lambda: torch.view_as_complex(t), 347 ) 348 349 # tensor with zero elements 350 x = torch.tensor([], device=device) # torch.Size([0]) 351 self.assertRaisesRegex( 352 RuntimeError, 353 "Tensor must have a last dimension of size 2", 354 lambda: torch.view_as_complex(x), 355 ) 356 357 # zero dimension tensor 358 z = torch.tensor(2.0) 359 self.assertRaisesRegex( 360 RuntimeError, 361 "Input tensor must have one or more dimensions", 362 lambda: torch.view_as_complex(z), 363 ) 364 365 y = x.reshape(0, 2) # torch.Size([0, 2]) 366 res = torch.view_as_complex(y) 367 self.assertTrue(self.is_view_of(x, res)) 368 self.assertEqual(res.shape, torch.Size([0])) 369 370 @onlyNativeDeviceTypes 371 @dtypes(*complex_types(), torch.complex32) 372 def test_view_as_real(self, device, dtype): 373 def fn(contiguous_input=True): 374 t = torch.randn(3, 4, dtype=dtype, device=device) 375 input = self._do_transpose(t, contiguous_input) 376 res = torch.view_as_real(input) 377 self.assertEqual(res[:, :, 0], input.real) 378 self.assertEqual(res[:, :, 1], input.imag) 379 self.assertTrue(self.is_view_of(t, res)) 380 381 fn() 382 fn(contiguous_input=False) 383 384 # tensor with zero elements 385 x = torch.tensor([], dtype=dtype, device=device) 386 res = torch.view_as_real(x) 387 self.assertTrue(self.is_view_of(x, res)) 388 self.assertEqual(res.shape, torch.Size([0, 2])) 389 390 # tensor with zero dim 391 x = torch.tensor(2 + 3j, dtype=dtype, device=device) 392 res = torch.view_as_real(x) 393 self.assertTrue(self.is_view_of(x, res)) 394 self.assertEqual(res.shape, torch.Size([2])) 395 396 @onlyNativeDeviceTypes 397 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 398 def test_view_tensor_split(self, device, dtype): 399 a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9) 400 a_split_dim0 = a.tensor_split(7, 0) 401 for a_split_dim0_tensor in a_split_dim0: 402 self.assertTrue(self.is_view_of(a, a_split_dim0_tensor)) 403 a_split_dim1 = a.tensor_split(7, 1) 404 for a_split_dim1_tensor in a_split_dim1: 405 self.assertTrue(self.is_view_of(a, a_split_dim1_tensor)) 406 407 @onlyNativeDeviceTypes 408 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 409 def test_view_tensor_hsplit(self, device, dtype): 410 t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) 411 t_hsplit = torch.hsplit(t, 2) 412 for t_hsplit_tensor in t_hsplit: 413 self.assertTrue(self.is_view_of(t, t_hsplit_tensor)) 414 t[2, 2, 2] = 7 415 self.assertEqual(t_hsplit[1][2, 0, 2], t[2, 2, 2]) 416 417 @onlyNativeDeviceTypes 418 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 419 def test_view_tensor_vsplit(self, device, dtype): 420 t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) 421 t_vsplit = torch.vsplit(t, 2) 422 for t_vsplit_tensor in t_vsplit: 423 self.assertTrue(self.is_view_of(t, t_vsplit_tensor)) 424 t[2, 2, 2] = 7 425 self.assertEqual(t_vsplit[1][0, 2, 2], t[2, 2, 2]) 426 427 @onlyNativeDeviceTypes 428 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 429 def test_view_tensor_dsplit(self, device, dtype): 430 t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) 431 t_dsplit = torch.dsplit(t, 2) 432 for t_dsplit_tensor in t_dsplit: 433 self.assertTrue(self.is_view_of(t, t_dsplit_tensor)) 434 t[2, 2, 2] = 7 435 self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2]) 436 437 @onlyNativeDeviceTypes 438 @dtypes(*all_types_and(torch.half, torch.bfloat16)) 439 def test_imag_noncomplex(self, device, dtype): 440 t = torch.ones((5, 5), dtype=dtype, device=device) 441 442 with self.assertRaises(RuntimeError): 443 torch.imag(t) 444 445 @onlyNativeDeviceTypes 446 @dtypes(*complex_types()) 447 def test_real_imag_view(self, device, dtype): 448 def compare_with_numpy(contiguous_input=True): 449 t = torch.randn(3, 3, dtype=dtype, device=device) 450 if not contiguous_input: 451 u = t.T 452 else: 453 u = t 454 455 re = u.real 456 exp = torch.from_numpy(u.cpu().numpy().real).to(device=device) 457 self.assertEqual(re, exp) 458 # for the case of contiguous_input, t=u 459 # for the case of non contiguous_input, the base still remains 460 # t since we are performing a view operation to make the input non-contiguous 461 self.assertTrue(self.is_view_of(t, re)) 462 463 im = u.imag 464 exp = torch.from_numpy(u.cpu().numpy().imag).to(device=device) 465 self.assertEqual(im, exp) 466 self.assertTrue(self.is_view_of(t, im)) 467 468 compare_with_numpy() 469 compare_with_numpy(contiguous_input=False) 470 471 # ensure storage offset is being correctly set 472 a = torch.randn(10, dtype=dtype) 473 self.assertEqual(a[5:].real, a.real[5:]) 474 self.assertEqual(a[5:].imag, a.imag[5:]) 475 476 @onlyNativeDeviceTypes 477 @dtypes(*complex_types()) 478 def test_conj_imag_view(self, device, dtype) -> None: 479 t = _make_tensor((4, 5), dtype, device) 480 t_numpy_conj = torch.from_numpy(t.cpu().numpy().conj()).to(device=device) 481 v = t.conj() 482 self.assertTrue(self.is_view_of(t, v)) 483 self.assertEqual(v, t_numpy_conj) 484 485 if t.is_complex(): 486 v_imag = v.imag 487 self.assertTrue(self.is_view_of(t, v_imag)) 488 self.assertEqual(v_imag, t_numpy_conj.imag) 489 self.assertTrue(v_imag.is_neg()) 490 491 @onlyNativeDeviceTypes 492 def test_conj_view_with_shared_memory(self, device) -> None: 493 a = _make_tensor((4, 5), torch.cfloat, device) 494 b = a.conj() 495 c = a.conj() 496 497 self.assertEqual(torch.add(a, b), a.add_(b)) 498 self.assertEqual(torch.add(b, c), torch.add(b, c, out=a)) 499 self.assertEqual(torch.add(b, c), b.add_(c)) 500 501 @onlyNativeDeviceTypes 502 @dtypes( 503 *product( 504 complex_types(), 505 all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 506 ) 507 ) 508 @suppress_warnings 509 def test_set_real_imag(self, device, dtypes): 510 x = torch.randn(10, dtype=dtypes[0], device=device) 511 512 new_real = _make_tensor((10,), dtypes[1], device) 513 new_imag = _make_tensor((10,), dtypes[1], device) 514 515 x.real = new_real 516 x.imag = new_imag 517 518 if dtypes[1].is_complex: 519 self.assertEqual(x.real, new_real.real, exact_dtype=False) 520 self.assertEqual(x.imag, new_imag.real, exact_dtype=False) 521 522 else: 523 self.assertEqual(x.real, new_real, exact_dtype=False) 524 self.assertEqual(x.imag, new_imag, exact_dtype=False) 525 526 def test_diagonal_view(self, device) -> None: 527 t = torch.ones((5, 5), device=device) 528 v = torch.diagonal(t) 529 self.assertTrue(self.is_view_of(t, v)) 530 531 v[0] = 0 532 self.assertEqual(t[0, 0], v[0]) 533 534 t = torch.ones((3, 3, 3), device=device) 535 v = torch.diagonal(t, offset=1, dim1=1, dim2=2) 536 self.assertTrue(self.is_view_of(t, v)) 537 538 v[0, 0] = 0 539 self.assertEqual(t[0, 0, 1], v[0, 0]) 540 541 def test_select_view(self, device) -> None: 542 t = torch.ones((5, 5), device=device) 543 v = t.select(0, 2) 544 self.assertTrue(self.is_view_of(t, v)) 545 546 v[0] = 0 547 self.assertEqual(t[2, 0], v[0]) 548 549 # Lazy hasn't implemented unbind yet. 550 @skipLazy 551 def test_unbind_view(self, device) -> None: 552 t = torch.zeros((5, 5), device=device) 553 tup = torch.unbind(t) 554 555 for idx, v in enumerate(tup): 556 self.assertTrue(self.is_view_of(t, v)) 557 558 v[0] = idx + 1 559 self.assertEqual(t[idx, 0], v[0]) 560 561 # TODO: opinfo this or move to unbind's test suite 562 def test_unbind(self): 563 stacked = torch.randn(3, 10, 10, requires_grad=True) 564 x, y, z = stacked.unbind() 565 grad = torch.randn(3, 10, 10) 566 torch.autograd.backward([x, y, z], grad.unbind()) 567 self.assertEqual(stacked.grad, grad) 568 # check that it works with only one gradient provided (#9977) 569 for i in range(3): 570 stacked = torch.randn(3, 10, 10, requires_grad=True) 571 outs = stacked.unbind() 572 gi = grad.unbind()[i] 573 (g,) = torch.autograd.grad(outs[i], stacked, gi) 574 g_expected = torch.stack( 575 [gi if j == i else torch.zeros_like(gi) for j in range(3)], dim=0 576 ) 577 self.assertEqual(g, g_expected) 578 # Check with gradcheck 579 stacked = torch.randn(3, 10, 10, dtype=torch.double, requires_grad=True) 580 gradcheck(lambda x: x.unbind(), (stacked,), check_forward_ad=True) 581 582 # TODO: Fix this test for LTC. There is an interaction with dynamic shapes here that is broken, 583 # causing asserts to trigger. 584 @skipLazy 585 def test_expand_view(self, device) -> None: 586 t = torch.ones((5, 1), device=device) 587 v = t.expand(5, 5) 588 self.assertTrue(self.is_view_of(t, v)) 589 590 v[2, 2] = 0 591 self.assertEqual(t[2, 0], v[2, 2]) 592 593 def test_expand_as_view(self, device): 594 t = torch.ones((5, 1), device=device) 595 e = torch.empty((5, 5), device=device) 596 v = t.expand_as(e) 597 self.assertTrue(self.is_view_of(t, v)) 598 599 v[2, 2] = 0 600 self.assertEqual(t[2, 0], v[2, 2]) 601 602 def test_narrow_view(self, device): 603 t = torch.ones((5, 5), device=device) 604 v = torch.narrow(t, 1, 2, 2) 605 self.assertTrue(self.is_view_of(t, v)) 606 607 v[0, 0] = 0 608 self.assertEqual(t[0, 2], v[0, 0]) 609 610 def test_permute_view(self, device) -> None: 611 t = torch.ones((5, 5), device=device) 612 v = t.permute(1, 0) 613 self.assertTrue(self.is_view_of(t, v)) 614 615 v[0, 1] = 0 616 self.assertEqual(t[1, 0], v[0, 1]) 617 618 def test_transpose_view(self, device): 619 for fn in (torch.swapdims, torch.swapaxes, torch.transpose): 620 t = torch.ones((5, 5), device=device) 621 v = fn(t, 0, 1) 622 self.assertTrue(self.is_view_of(t, v)) 623 624 v[0, 1] = 0 625 self.assertEqual(t[1, 0], v[0, 1]) 626 627 def test_transpose_inplace_view(self, device): 628 t = torch.ones(5, 5, device=device) 629 v = t.view_as(t) 630 v = v.swapdims_(0, 1) 631 self.assertTrue(self.is_view_of(t, v)) 632 v[0, 1] = 0 633 self.assertEqual(t[1, 0], v[0, 1]) 634 635 t = torch.ones(5, 5, device=device) 636 v = t.view_as(t) 637 v = v.swapaxes_(0, 1) 638 self.assertTrue(self.is_view_of(t, v)) 639 v[0, 1] = 0 640 self.assertEqual(t[1, 0], v[0, 1]) 641 642 t = torch.ones(5, 5, device=device) 643 v = t.view_as(t) 644 v = v.transpose_(0, 1) 645 self.assertTrue(self.is_view_of(t, v)) 646 v[0, 1] = 0 647 self.assertEqual(t[1, 0], v[0, 1]) 648 649 def test_t_view(self, device): 650 t = torch.ones((5, 5), device=device) 651 v = t.t() 652 self.assertTrue(self.is_view_of(t, v)) 653 654 v[0, 1] = 0 655 self.assertEqual(t[1, 0], v[0, 1]) 656 657 def test_t_inplace_view(self, device): 658 t = torch.ones(5, 5, device=device) 659 v = t.view_as(t) 660 v = v.t_() 661 self.assertTrue(self.is_view_of(t, v)) 662 v[0, 1] = 0 663 self.assertEqual(t[1, 0], v[0, 1]) 664 665 def test_T_view(self, device): 666 for op in ("T", "H", "mT", "mH"): 667 t = torch.ones((5, 5), device=device) 668 v = getattr(t, op) 669 self.assertTrue(self.is_view_of(t, v)) 670 671 v[0, 1] = 0 672 self.assertEqual(t[1, 0], v[0, 1]) 673 674 def test_unfold_view(self, device): 675 t = torch.ones(10, device=device) 676 v = t.unfold(0, 3, 2) 677 self.assertTrue(self.is_view_of(t, v)) 678 679 v[1, 0] = 0 680 self.assertEqual(t[2], v[1, 0]) 681 682 def test_squeeze_view(self, device): 683 t = torch.ones(5, 1, 5, device=device) 684 v = torch.squeeze(t) 685 self.assertTrue(self.is_view_of(t, v)) 686 v[0, 1] = 0 687 self.assertEqual(t, v._base) 688 689 def test_squeeze_inplace_view(self, device): 690 t = torch.ones(5, 5, device=device) 691 v = t.view_as(t) 692 v = v.squeeze_() 693 self.assertTrue(self.is_view_of(t, v)) 694 v[0, 1] = 0 695 self.assertEqual(t, v._base) 696 697 def test_unsqueeze_view(self, device): 698 t = torch.ones(5, 5, device=device) 699 v = torch.unsqueeze(t, 1) 700 self.assertTrue(self.is_view_of(t, v)) 701 702 v[0, 0, 1] = 0 703 self.assertEqual(t[0, 1], v[0, 0, 1]) 704 705 def test_unsqueeze_inplace_view(self, device): 706 t = torch.ones(5, 5, device=device) 707 v = t.view_as(t) 708 v = v.unsqueeze_(1) 709 self.assertTrue(self.is_view_of(t, v)) 710 v[0, 0, 1] = 0 711 self.assertEqual(t[0, 1], v[0, 0, 1]) 712 713 def test_as_strided_view(self, device): 714 t = torch.ones(5, 5, device=device) 715 v = torch.as_strided(t, (25,), (1,)) 716 self.assertTrue(self.is_view_of(t, v)) 717 718 v[6] = 0 719 self.assertEqual(t[1, 1], v[6]) 720 721 def test_as_strided_inplace_view(self, device): 722 t = torch.ones(5, 5, device=device) 723 v = t.view_as(t) 724 v = v.as_strided_((25,), (1,)) 725 self.assertTrue(self.is_view_of(t, v)) 726 v[6] = 0 727 self.assertEqual(t[1, 1], v[6]) 728 729 def test_as_strided_gradients(self): 730 def test(x, prepro_fn, size, strides, offset=None): 731 x = x.to(torch.double).detach().requires_grad_() 732 733 # Check that forward will **not** resize storage because it may 734 # cause NaN in output and fail numerical Jacobian check consequently 735 with torch.no_grad(): 736 y = prepro_fn(x) if prepro_fn is not None else x 737 max_offset = sum((si - 1) * st for si, st in zip(size, strides)) 738 max_offset += offset if offset is not None else y.storage_offset() 739 assert max_offset < len(y.storage()), "test case resizes storage" 740 741 def closure(x): 742 if prepro_fn is not None: 743 x = prepro_fn(x) 744 return x.as_strided(size, strides, offset) 745 746 gradcheck(closure, [x], check_forward_ad=True) 747 gradgradcheck(closure, [x]) 748 749 # test 750 test(torch.arange(0, 25), lambda x: x.view(5, 5), [3, 3], [6, 2], 2) 751 752 # test crazy stride at dim with size 1 case 753 test(torch.randn(12), None, [1, 2, 1, 5], [0, 5, 100, 1], 2) 754 755 # test expand case 756 test(torch.randn(5), None, [3, 3, 3], [0, 1, 0], 2) 757 test(torch.randn(5), None, [3, 3, 3], [0, 0, 0], 4) 758 test(torch.randn(5), lambda x: x.expand(5, 5), [5, 5], [0, 1], 0) 759 760 # test non-expand overlapping case 761 test(torch.randn(35), None, [6, 6], [5, 1], 2) 762 test(torch.randn(15), None, [3, 2], [3, 6], 2) 763 764 # test transpose case 765 test(torch.randn(3, 4), None, [4, 3], [1, 4]) 766 767 # test "getting things outside the input" case 768 x = torch.randn(6, 2) 769 test(x[3:], None, [3, 2], [2, 1], 0) # should be all zeros 770 self.assertEqual(x[3:].as_strided([3, 2], [2, 1], 0), x[:3]) 771 772 # test select on expanded input case 773 test(torch.randn(2, 3), lambda x: x.expand(10, 2, 3), [2, 3], [3, 1], 0) 774 775 def test_view_view(self, device): 776 t = torch.ones(5, 5, device=device) 777 v = t.view(25) 778 self.assertTrue(self.is_view_of(t, v)) 779 780 v[6] = 0 781 self.assertEqual(t[1, 1], v[6]) 782 783 def test_view_as_view(self, device): 784 t = torch.ones(5, 5, device=device) 785 e = torch.empty((25,)) 786 v = t.view_as(e) 787 self.assertTrue(self.is_view_of(t, v)) 788 789 v[6] = 0 790 self.assertEqual(t[1, 1], v[6]) 791 792 def test_contiguous_self(self, device): 793 t = torch.ones(5, 5, device=device) 794 s = t.contiguous() 795 self.assertTrue(s is t) 796 797 @skipMeta 798 # self.is_view_of reports false positives for lazy 799 @skipLazy 800 def test_contiguous_nonview(self, device): 801 t = torch.ones(5, 5, device=device) 802 nv = t.t().contiguous() 803 self.assertTrue(not self.is_view_of(t, nv)) 804 805 nv[0, 0] = 0 806 self.assertNotEqual(t[0, 0], nv[0, 0]) 807 808 def test_reshape_view(self, device): 809 t = torch.ones(5, 5, device=device) 810 v = torch.reshape(t, (25,)) 811 self.assertTrue(self.is_view_of(t, v)) 812 813 v[6] = 0 814 self.assertEqual(t[1, 1], v[6]) 815 816 def test_reshape_as_view(self, device): 817 t = torch.ones(5, 5, device=device) 818 e = torch.empty((25,), device=device) 819 v = t.reshape_as(e) 820 self.assertTrue(self.is_view_of(t, v)) 821 822 v[6] = 0 823 self.assertEqual(t[1, 1], v[6]) 824 825 @skipMeta 826 # self.is_view_of reports false positives for lazy 827 @skipLazy 828 def test_reshape_nonview(self, device): 829 t = torch.ones(5, 5, device=device) 830 nv = torch.reshape(t.t(), (25,)) 831 self.assertTrue(not self.is_view_of(t, nv)) 832 833 nv[6] = 0 834 self.assertNotEqual(t[1, 1], nv[6]) 835 836 # This test use as_strided to construct a tensor with overlapping memory, 837 # which is not handled by the functionalization pass. 838 @skipLazy 839 @skipXLA 840 def test_flatten_view(self, device): 841 def test_writes_propagate(t, v): 842 idx_t = (0,) * t.ndim 843 idx_v = (0,) * v.ndim 844 v[idx_v] = 0 845 self.assertEqual(t[idx_t], v[idx_v]) 846 847 t = torch.ones(1, 2, 3, 4, device=device) 848 v = t.flatten() 849 self.assertTrue(self.is_view_of(t, v)) 850 test_writes_propagate(t, v) 851 852 # zero-dimensional tensor 853 t = torch.tensor(1, device=device) 854 v = t.flatten() 855 test_writes_propagate(t, v) 856 self.assertTrue(self.is_view_of(t, v)) 857 858 t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3) 859 v = t.flatten(0, 1) 860 test_writes_propagate(t, v) 861 self.assertTrue(self.is_view_of_same_base(t, v)) 862 863 # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups: 864 t = torch.ones(720, device=device).as_strided( 865 (2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0) 866 ) 867 # [--1--|---2---|-3-] [--1--|----2---|-3-] 868 v1 = t.flatten(0, 1) 869 v2 = v1.flatten(1, 3) 870 v3 = v2.flatten(2, 2) 871 test_writes_propagate(t, v1) 872 self.assertTrue(self.is_view_of_same_base(t, v1)) 873 test_writes_propagate(t, v2) 874 self.assertTrue(self.is_view_of_same_base(t, v2)) 875 test_writes_propagate(t, v3) 876 self.assertTrue(self.is_view_of_same_base(t, v3)) 877 878 @onlyNativeDeviceTypes 879 def test_flatten_nonview(self, device): 880 def assert_is_nonview(t, nv): 881 idx_t = (0,) * t.ndim 882 idx_nv = (0,) * nv.ndim 883 self.assertTrue(not nv._is_view()) 884 nv[idx_nv] = 0 885 if device != "meta": 886 self.assertNotEqual(t[idx_t], nv[idx_nv]) 887 888 t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3) 889 nv = t.flatten(1, 3) 890 assert_is_nonview(t, nv) 891 892 t = torch.ones(2, 2, device=device).T 893 nv = t.flatten() 894 assert_is_nonview(t, nv) 895 896 # flatten returns the original object if start_dim=end_dim 897 t = t = torch.ones(2, 2, device=device) 898 nv = t.flatten(1, 1) 899 self.assertTrue(t is nv) 900 901 def test_basic_indexing_slice_view(self, device): 902 t = torch.ones(5, 5, device=device) 903 v = t[:2, :3] 904 self.assertTrue(self.is_view_of(t, v)) 905 906 v[0, 0] = 0 907 self.assertEqual(t[0, 0], v[0, 0]) 908 909 def test_basic_indexing_ellipses_view(self, device): 910 t = torch.ones(5, 5, device=device) 911 v = t[..., :2] 912 self.assertTrue(self.is_view_of(t, v)) 913 914 v[0, 0] = 0 915 self.assertEqual(t[0, 0], v[0, 0]) 916 917 def test_basic_indexing_newaxis_view(self, device): 918 t = torch.ones(5, 5, device=device) 919 v = t[None, :2, 3] 920 self.assertTrue(self.is_view_of(t, v)) 921 922 v[0, 0] = 0 923 self.assertEqual(t[0, 3], v[0, 0]) 924 925 def test_advanced_indexing_nonview(self, device): 926 t = torch.ones(3, 3, device=device) 927 rows = torch.tensor([[0, 0], [2, 2]], device=device) 928 cols = torch.tensor([[0, 1], [2, 2]], device=device) 929 nv = t[rows, cols] 930 self.assertTrue(not self.is_view_of(t, nv)) 931 932 nv[1, 1] = 0 933 self.assertNotEqual(t[2, 2], nv[1, 1]) 934 935 @unittest.skipIf( 936 IS_FBCODE, "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds" 937 ) 938 def test_advanced_indexing_assignment(self, device): 939 t = torch.ones(3, 3, device=device) 940 rows = torch.tensor([[0, 0], [2, 2]], device=device) 941 cols = torch.tensor([[0, 1], [2, 2]], device=device) 942 t[rows, cols] = 0 943 self.assertEqual(t[2, 2], 0) 944 945 @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720") 946 def test_chunk_view(self, device): 947 t = torch.zeros(3, 3, device=device) 948 l = torch.chunk(t, 3) 949 950 for idx, v in enumerate(l): 951 self.assertTrue(self.is_view_of(t, v)) 952 953 v[0, 0] = idx + 1 954 self.assertEqual(t[idx, 0], v[0, 0]) 955 956 @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720") 957 def test_split_view(self, device): 958 t = torch.zeros(3, 3, device=device) 959 l = torch.split(t, [1, 1, 1]) 960 961 for idx, v in enumerate(l): 962 self.assertTrue(self.is_view_of(t, v)) 963 964 v[0, 0] = idx + 1 965 self.assertEqual(t[idx, 0], v[0, 0]) 966 967 def test_movedim_view(self, device): 968 def run_test(device, op): 969 t = torch.zeros(3, 3, device=device) 970 out = op(t) 971 972 self.assertTrue(self.is_view_of(t, out)) 973 974 # Randomly change values in output 975 # and verify that original is changed 976 # as well. 977 for _ in range(3): 978 idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2) 979 out[idx_1, idx_2] = random.random() 980 self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2]) 981 982 for fn in [torch.movedim, torch.moveaxis]: 983 op = partial(fn, source=(0, 1), destination=(1, 0)) 984 run_test(device, op) 985 986 op = partial(fn, source=0, destination=1) 987 run_test(device, op) 988 989 # Testing that the generated view_copy kernel and its derivative are implemented correctly 990 def test_view_copy(self, device): 991 a = torch.randn(4, device=device, requires_grad=True) 992 a_ref = a.clone().detach().requires_grad_() 993 a_view = a_ref.view(2, 2) 994 a_view_copy = torch.view_copy(a, (2, 2)) 995 996 # view_copy ops don't preserve view relationship 997 self.assertTrue(self.is_view_of(a_ref, a_view)) 998 self.assertFalse(self.is_view_of(a, a_view_copy)) 999 1000 a_view_copy.sum().backward() 1001 a_view.sum().backward() 1002 1003 # forward and backward give the same shape + result 1004 self.assertEqual(a_view_copy, a_view) 1005 self.assertEqual(a.grad, a_ref.grad) 1006 1007 # Testing that the output of a view_copy kernel (by default) is contiguous. 1008 def test_view_copy_output_contiguous(self, device): 1009 a = torch.randn(4, 4, 4, 4, device=device).to(memory_format=torch.channels_last) 1010 b = torch.ops.aten.slice_copy(a, 0, 0, 2) 1011 self.assertTrue(b.is_contiguous()) 1012 1013 def test_view_copy_out(self, device): 1014 a = torch.randn(2, 2, device=device) 1015 out = torch.empty(2, device=device) 1016 1017 torch.diagonal_copy(a, out=out) 1018 expected = torch.diagonal_copy(a) 1019 1020 self.assertEqual(expected, out) 1021 1022 a = torch.randn(4, device=device) 1023 out1 = torch.empty(2, device=device) 1024 out2 = torch.empty(2, device=device) 1025 1026 torch.split_copy(a, 2, out=(out1, out2)) 1027 expected1, expected2 = torch.split_copy(a, 2) 1028 1029 self.assertEqual(expected1, out1) 1030 self.assertEqual(expected2, out2) 1031 1032 1033class TestOldViewOps(TestCase): 1034 def test_ravel(self, device): 1035 def _test_ravel(tensors, size, nc=False): 1036 for src in tensors: 1037 # Continuous Tensor -> View 1038 flat = src.ravel() 1039 self.assertEqual(flat.shape, torch.Size([size])) 1040 self.assertEqual(src.view(-1), flat) 1041 self.assertIs(flat._base, src) 1042 self.assertTrue(flat.is_contiguous()) 1043 1044 # Non-continuous Tensor -> Copy 1045 if nc: 1046 nc_src = src.t() 1047 nc_flat = nc_src.ravel() 1048 self.assertEqual(nc_flat.shape, torch.Size([size])) 1049 self.assertEqual(nc_src.contiguous().view(-1), nc_flat) 1050 self.assertIsNot(nc_flat._base, src) 1051 self.assertTrue(nc_flat.is_contiguous()) 1052 1053 # Test that flatten returns 1-dim tensor when given a 0-dim tensor 1054 zero_dim_tensor = torch.tensor(123, device=device) 1055 flat0 = zero_dim_tensor.ravel() 1056 one_dim_tensor = torch.tensor([123], device=device) 1057 flat1 = zero_dim_tensor.ravel() 1058 nc_ones_tensor = torch.ones(10, device=device)[::2] 1059 flat2 = nc_ones_tensor.ravel() 1060 1061 self.assertEqual(zero_dim_tensor.shape, torch.Size([])) 1062 self.assertEqual(flat0.shape, torch.Size([1])) 1063 self.assertEqual(one_dim_tensor.shape, torch.Size([1])) 1064 self.assertEqual(flat1.shape, torch.Size([1])) 1065 self.assertEqual(nc_ones_tensor.shape, torch.Size([5])) 1066 self.assertEqual(flat2.shape, torch.Size([5])) 1067 self.assertEqual(flat0, one_dim_tensor) 1068 self.assertEqual(flat0, flat1) 1069 self.assertEqual(flat0.shape, flat1.shape) 1070 self.assertTrue(flat0.is_contiguous()) 1071 self.assertTrue(flat1.is_contiguous()) 1072 self.assertTrue(flat2.is_contiguous()) 1073 1074 # Test both float tensor and quantized tensor 1075 tensors = [ 1076 torch.randn(5, 5, 5, 5, device=device), 1077 torch._empty_affine_quantized( 1078 [5, 5, 5, 5], scale=2, zero_point=3, dtype=torch.quint8, device=device 1079 ), 1080 ] 1081 _test_ravel(tensors, 625) 1082 1083 tensors = [ 1084 torch.randn(0, 2, 3, device=device), 1085 torch.randn(3, 0, 2, device=device), 1086 torch._empty_affine_quantized( 1087 [0, 2, 3], scale=2, zero_point=3, dtype=torch.quint8, device=device 1088 ), 1089 torch._empty_affine_quantized( 1090 [3, 0, 2], scale=2, zero_point=3, dtype=torch.quint8, device=device 1091 ), 1092 ] 1093 _test_ravel(tensors, 0) 1094 1095 tensors = [ 1096 torch.randn(5, 5, device=device), 1097 torch._empty_affine_quantized( 1098 [5, 5], scale=2, zero_point=3, dtype=torch.quint8, device=device 1099 ), 1100 ] 1101 _test_ravel(tensors, 25, True) 1102 1103 # TODO: this should be refactored into the view ops test suite 1104 def test_empty_reshape(self, device): 1105 x = torch.randn(0, 6, device=device) 1106 self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape) 1107 # should be viewable -- i.e. data_ptr is the same. 1108 self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr()) 1109 1110 # match NumPy semantics -- don't infer the size of dimension with a degree of freedom 1111 self.assertRaises(RuntimeError, lambda: x.reshape(0, -1)) 1112 1113 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 1114 def test_expand(self, device): 1115 tensor = torch.rand(1, 8, 1, device=device) 1116 tensor2 = torch.rand(5, device=device) 1117 template = torch.rand(4, 8, 5, device=device) 1118 target = template.size() 1119 self.assertEqual(tensor.expand_as(template).size(), target) 1120 self.assertEqual(tensor.expand(4, 8, 5).size(), target) 1121 self.assertEqual(tensor.expand(target).size(), target) 1122 self.assertEqual(tensor2.expand_as(template).size(), target) 1123 self.assertEqual(tensor2.expand(4, 8, 5).size(), target) 1124 self.assertEqual(tensor2.expand(target).size(), target) 1125 1126 # test double expand 1127 self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1)) 1128 1129 # test non-contiguous 1130 noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0] 1131 self.assertFalse(noncontig.is_contiguous()) 1132 self.assertEqual( 1133 noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1) 1134 ) 1135 1136 # make sure it's compatible with unsqueeze 1137 expanded = tensor2.expand(1, 1, 5) 1138 unsqueezed = tensor2.unsqueeze(0).unsqueeze(1) 1139 self.assertEqual(expanded, unsqueezed) 1140 self.assertEqual(expanded.stride(), unsqueezed.stride()) 1141 1142 # test -1 as target size 1143 self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5)) 1144 self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1)) 1145 1146 # test expanding empty to empty 1147 self.assertEqual( 1148 torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device) 1149 ) 1150 1151 # TODO: this should be refactored into the view ops test suite 1152 def test_view_empty(self, device): 1153 x = torch.randn(0, 6, device=device) 1154 self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape) 1155 1156 # TODO: this should be refactored into the view ops test suite 1157 @onlyNativeDeviceTypes 1158 def test_reshape(self, device): 1159 x = torch.randn(3, 3, device=device) 1160 self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr()) 1161 self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr()) 1162 self.assertEqual(torch.reshape(x, (9,)), x.reshape(9)) 1163 self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1)) 1164 1165 y = torch.randn(4, 4, 4, device=device)[:, 0, :] 1166 # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape 1167 if device != "meta": 1168 self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr()) 1169 self.assertEqual(y.contiguous().view(-1), y.reshape(-1)) 1170 self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr()) 1171 1172 s = torch.randn((), device=device) 1173 self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr()) 1174 self.assertEqual(s.reshape(-1).shape, (1,)) 1175 self.assertRaises(RuntimeError, lambda: s.reshape(2)) 1176 1177 empty = torch.tensor([], device=device) 1178 self.assertEqual(empty, empty.reshape(-1)) 1179 self.assertEqual(empty, empty.reshape([0])) 1180 # TODO: fix these once we have multi-dimensional empty tensors 1181 self.assertEqual(empty.reshape([0, 1]).shape, (0, 1)) 1182 self.assertEqual(empty.reshape([1, -1]).shape, (1, 0)) 1183 self.assertRaises(RuntimeError, lambda: empty.reshape(1)) 1184 1185 x = torch.randn(3, 3, device=device) 1186 self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr()) 1187 self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr()) 1188 self.assertRaises( 1189 RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device)) 1190 ) 1191 1192 def test_flatten(self, device): 1193 # Test that flatten returns 1-dim tensor when given a 0-dim tensor 1194 zero_dim_tensor = torch.tensor(123, device=device) 1195 flat0 = zero_dim_tensor.flatten() 1196 one_dim_tensor = torch.tensor([123], device=device) 1197 flat1 = zero_dim_tensor.flatten() 1198 1199 self.assertEqual(zero_dim_tensor.shape, torch.Size([])) 1200 self.assertEqual(flat0.shape, torch.Size([1])) 1201 self.assertEqual(one_dim_tensor.shape, torch.Size([1])) 1202 self.assertEqual(flat1.shape, torch.Size([1])) 1203 self.assertEqual(flat0, one_dim_tensor) 1204 self.assertEqual(flat0, flat1) 1205 self.assertEqual(flat0.shape, flat1.shape) 1206 1207 # Test both float tensor and quantized tensor 1208 tensors = [ 1209 torch.randn(5, 5, 5, 5, device=device), 1210 torch._empty_affine_quantized( 1211 [5, 5, 5, 5], scale=2, zero_point=3, dtype=torch.quint8, device=device 1212 ), 1213 ] 1214 for src in tensors: 1215 flat = src.flatten(0, -1) 1216 self.assertEqual(flat.shape, torch.Size([625])) 1217 self.assertEqual(src.view(-1), flat.view(-1)) 1218 1219 flat = src.flatten(0, 2) 1220 self.assertEqual(flat.shape, torch.Size([125, 5])) 1221 self.assertEqual(src.view(-1), flat.view(-1)) 1222 1223 flat = src.flatten(0, 1) 1224 self.assertEqual(flat.shape, torch.Size([25, 5, 5])) 1225 self.assertEqual(src.view(-1), flat.view(-1)) 1226 1227 flat = src.flatten(1, 2) 1228 self.assertEqual(flat.shape, torch.Size([5, 25, 5])) 1229 self.assertEqual(src.view(-1), flat.view(-1)) 1230 1231 flat = src.flatten(2, 3) 1232 self.assertEqual(flat.shape, torch.Size([5, 5, 25])) 1233 self.assertEqual(src.view(-1), flat.view(-1)) 1234 1235 flat = src.flatten(-2, -1) 1236 self.assertEqual(flat.shape, torch.Size([5, 5, 25])) 1237 self.assertEqual(src.view(-1), flat.view(-1)) 1238 1239 flat = src.flatten(2, 2) 1240 self.assertEqual(flat, src) 1241 1242 # out of bounds index 1243 with self.assertRaisesRegex(IndexError, "Dimension out of range"): 1244 src.flatten(5, 10) 1245 1246 # invalid start and end 1247 with self.assertRaisesRegex( 1248 RuntimeError, "start_dim cannot come after end_dim" 1249 ): 1250 src.flatten(2, 0) 1251 1252 # TODO: update to work on CUDA, too 1253 @onlyCPU 1254 def test_narrow(self, device): 1255 x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) 1256 self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]])) 1257 self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]])) 1258 self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]])) 1259 self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]])) 1260 self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]])) 1261 self.assertEqual( 1262 x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) 1263 ) 1264 self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]])) 1265 self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]])) 1266 1267 # TODO: update to work on CUDA, too 1268 @onlyCPU 1269 def test_narrow_tensor(self, device): 1270 x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) 1271 self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]])) 1272 with self.assertRaises(Exception): 1273 x.narrow(0, torch.tensor(0.0), 1) 1274 with self.assertRaises(Exception): 1275 x.narrow(0, torch.tensor([0]), 1) 1276 with self.assertRaises(Exception): 1277 x.narrow(0, torch.tensor([0, 1]), 1) 1278 1279 # TODO: make work on CUDA, too 1280 @onlyCPU 1281 def test_t(self, device): 1282 # Test 0D tensors 1283 x = torch.randn(()) 1284 self.assertEqual(x, x.t()) 1285 x = x.to_sparse() 1286 self.assertEqual(x, x.t()) 1287 1288 # Test 1D tensors 1289 x = torch.arange(4) 1290 self.assertEqual(x, x.t()) 1291 x = x.to_sparse() 1292 self.assertEqual(x, x.t()) 1293 1294 # Test 2D tensors 1295 x = torch.rand((2, 2)) 1296 self.assertEqual(x.t(), x.transpose(0, 1)) 1297 x = x.to_sparse() 1298 self.assertEqual(x.t(), x.transpose(0, 1)) 1299 1300 # Test 3D tensor 1301 x = torch.rand((2, 2, 2)) 1302 with self.assertRaisesRegex( 1303 RuntimeError, "expects a tensor with <= 2 dimensions, but self is 3D" 1304 ): 1305 x.t() 1306 x = x.to_sparse() 1307 with self.assertRaisesRegex( 1308 RuntimeError, "expects a tensor with <= 2 sparse and 0 dense dimensions" 1309 ): 1310 x.t() 1311 1312 @onlyCPU 1313 def test_split(self, device): 1314 tensor = torch.rand(7, 4) 1315 split_size = 3 1316 dim = 0 1317 target_sizes = ([3, 4], [3, 4], [1, 4]) 1318 splits = tensor.split(split_size, dim) 1319 start = 0 1320 for target_size, split in zip(target_sizes, splits): 1321 self.assertEqual(split.size(), target_size) 1322 self.assertEqual( 1323 tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0 1324 ) 1325 start = start + target_size[dim] 1326 1327 # Variable sections split 1328 tensor = torch.randn(20, 10) 1329 dim = 0 1330 split_sizes = [5, 5, 10] 1331 target_sizes = [[5, 10], [5, 10], [10, 10]] 1332 splits = tensor.split(split_sizes, dim) 1333 start = 0 1334 for target_size, split in zip(target_sizes, splits): 1335 self.assertEqual(split.size(), target_size) 1336 self.assertEqual( 1337 tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0 1338 ) 1339 start = start + target_size[dim] 1340 1341 split_sizes = [2, 2, 6] 1342 target_sizes = ([20, 2], [20, 2], [20, 6]) 1343 dim = 1 1344 splits = tensor.split(split_sizes, dim) 1345 start = 0 1346 for target_size, split in zip(target_sizes, splits): 1347 self.assertEqual(split.size(), target_size) 1348 self.assertEqual( 1349 tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0 1350 ) 1351 start = start + target_size[dim] 1352 1353 @onlyCPU 1354 def test_chunk(self, device): 1355 tensor = torch.rand(4, 7) 1356 num_chunks = 3 1357 dim = 1 1358 target_sizes = ([4, 3], [4, 3], [4, 1]) 1359 splits = tensor.chunk(num_chunks, dim) 1360 start = 0 1361 for target_size, split in zip(target_sizes, splits): 1362 self.assertEqual(split.size(), target_size) 1363 self.assertEqual( 1364 tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0 1365 ) 1366 start = start + target_size[dim] 1367 1368 # Invalid chunk sizes 1369 error_regex = "chunk expects.*greater than 0" 1370 with self.assertRaisesRegex(RuntimeError, error_regex): 1371 tensor.chunk(0) 1372 with self.assertRaisesRegex(RuntimeError, error_regex): 1373 tensor.chunk(-2) 1374 1375 # TODO: make work on CUDA, too 1376 @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 1377 @onlyCPU 1378 def test_unsqueeze(self, device) -> None: 1379 x = torch.randn(2, 3, 4) 1380 y = x.unsqueeze(1) 1381 self.assertEqual(y, x.view(2, 1, 3, 4)) 1382 y = x.clone().unsqueeze_(2) 1383 self.assertEqual(y, x.view(2, 3, 1, 4)) 1384 1385 x = x[:, 1] 1386 self.assertFalse(x.is_contiguous()) 1387 y = x.unsqueeze(1) 1388 self.assertEqual(y, x.contiguous().view(2, 1, 4)) 1389 y = x.clone().unsqueeze_(2) 1390 self.assertEqual(y, x.contiguous().view(2, 4, 1)) 1391 1392 # unit test for special case transposed copy (see ATen/native/Copy.cpp for details) 1393 def test_big_transpose(self, device): 1394 t = torch.rand(456, 789, device=device) 1395 t1 = t.t().contiguous() 1396 t2 = torch.from_numpy(t.cpu().numpy().transpose()) 1397 self.assertEqual(t1, t2) 1398 1399 def test_T(self, device): 1400 a = torch.randn(2, 3, 4, device=device) 1401 t1 = a.T 1402 t2 = a.permute(2, 1, 0) 1403 self.assertEqual(t2, t1) 1404 b = torch.randn(10, device=device) 1405 self.assertEqual(b, b.T) 1406 1407 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 1408 def test_transposes(self, device, dtype): 1409 for op in ("T", "H", "mT", "mH", "adjoint"): 1410 shapes = ( 1411 ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),) 1412 ) 1413 for shape in shapes: 1414 a = make_tensor(shape, device=device, dtype=dtype) 1415 t1 = getattr(a, op) 1416 if op == "adjoint": 1417 t1 = t1() 1418 t2 = a 1419 t2 = t2.transpose(-2, -1) 1420 if op[-1] == "H" or op == "adjoint": 1421 t2 = t2.conj() 1422 self.assertEqual(t2, t1) 1423 1424 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 1425 def test_transposes_errors(self, device, dtype): 1426 for op in ("H", "mT", "mH", "adjoint"): 1427 shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),) 1428 for shape in shapes: 1429 a = make_tensor(shape, device=device, dtype=dtype) 1430 with self.assertRaisesRegex(RuntimeError, "only supported on matrices"): 1431 t1 = getattr(a, op) 1432 if op == "adjoint": 1433 t1 = t1() 1434 1435 def test_python_types(self, device): 1436 a1 = torch.randn((1, 2), device=device, dtype=torch.float64) 1437 a2 = torch.randn((1, 2), device=device, dtype=float) 1438 self.assertEqual(a1.dtype, a2.dtype) 1439 1440 b1 = torch.arange(10, 20, dtype=torch.int64, device=device) 1441 b2 = torch.arange(10, 20, dtype=int, device=device) 1442 self.assertEqual(b1.dtype, b2.dtype) 1443 1444 c1 = torch.tensor([True, False], dtype=torch.bool, device=device) 1445 c2 = torch.tensor([True, False], dtype=bool, device=device) 1446 self.assertEqual(c1.dtype, c2.dtype) 1447 1448 # TODO: is resize best put in test_view_ops? 1449 def test_resize_as_preserves_strides(self, device): 1450 x = torch.empty(2, 3).t() 1451 old_strides = x.stride() 1452 x.resize_as_(x) 1453 self.assertEqual(x.stride(), old_strides) 1454 1455 def test_memory_format_resize_as(self, device): 1456 def test_helper(shape, memory_format, device): 1457 xc = torch.randn(shape, device=device).contiguous( 1458 memory_format=memory_format 1459 ) 1460 flat = torch.randn(xc.numel(), device=device) 1461 flat.resize_as_(xc, memory_format=torch.preserve_format) 1462 self.assertTrue(flat.is_contiguous(memory_format=memory_format)) 1463 1464 test_helper((10, 3, 32, 32), torch.channels_last, device) 1465 test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device) 1466 1467 def test_memory_format_resize_(self, device): 1468 def test_helper(shape, numel, memory_format, device): 1469 flat = torch.randn(numel, device=device) 1470 flat.resize_(shape, memory_format=memory_format) 1471 self.assertTrue(flat.is_contiguous(memory_format=memory_format)) 1472 1473 test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device) 1474 test_helper( 1475 (3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device 1476 ) 1477 1478 @onlyNativeDeviceTypes 1479 @dtypes(torch.int64, torch.float, torch.complex128) 1480 def test_transpose_invalid(self, device, dtype): 1481 for fn in (torch.swapdims, torch.swapaxes, torch.transpose): 1482 shape = _rand_shape(4, min_size=5, max_size=10) 1483 x = _generate_input(shape, dtype, device, False) 1484 1485 # Invalid `source` and `destination` dimension 1486 with self.assertRaisesRegex(IndexError, "Dimension out of range"): 1487 fn(x, 5, 0) 1488 1489 with self.assertRaisesRegex(IndexError, "Dimension out of range"): 1490 fn(x, 0, 5) 1491 1492 @dtypes(torch.int64, torch.float, torch.complex128) 1493 def test_transpose_vs_numpy(self, device, dtype): 1494 for fn in (torch.swapdims, torch.swapaxes, torch.transpose): 1495 for nd in range(5): 1496 shape = _rand_shape(nd, min_size=5, max_size=10) 1497 x = _generate_input(shape, dtype, device, with_extremal=False) 1498 for random_negative in [True, False]: 1499 for src_dim, dst_dim in permutations(range(nd), r=2): 1500 random_prob = random.random() 1501 1502 if random_negative and random_prob > 0.66: 1503 src_dim = src_dim - nd 1504 elif random_negative and random_prob > 0.33: 1505 dst_dim = dst_dim - nd 1506 elif random_negative: 1507 src_dim = src_dim - nd 1508 dst_dim = dst_dim - nd 1509 1510 partial_map = { 1511 torch.swapdims: partial( 1512 torch.swapdims, dim0=src_dim, dim1=dst_dim 1513 ), 1514 torch.swapaxes: partial( 1515 torch.swapaxes, axis0=src_dim, axis1=dst_dim 1516 ), 1517 torch.transpose: partial( 1518 torch.transpose, dim0=src_dim, dim1=dst_dim 1519 ), 1520 } 1521 1522 torch_fn = partial_map[fn] 1523 np_fn = partial(np.swapaxes, axis1=src_dim, axis2=dst_dim) 1524 self.compare_with_numpy( 1525 torch_fn, np_fn, x, device=None, dtype=None 1526 ) 1527 1528 # Move dim to same position 1529 x = torch.randn(2, 3, 5, 7, 11) 1530 partial_map = { 1531 torch.swapdims: partial(torch.swapdims, dim0=0, dim1=0), 1532 torch.swapaxes: partial(torch.swapaxes, axis0=0, axis1=0), 1533 torch.transpose: partial(torch.transpose, dim0=0, dim1=0), 1534 } 1535 torch_fn = partial_map[fn] 1536 np_fn = partial(np.swapaxes, axis1=0, axis2=0) 1537 self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) 1538 1539 def _test_atleast_dim(self, torch_fn, np_fn, device, dtype): 1540 for ndims in range(0, 5): 1541 shape = _rand_shape(ndims, min_size=5, max_size=10) 1542 for n in range(ndims + 1): 1543 for with_extremal in [False, True]: 1544 for contiguous in [False, True]: 1545 # Generate Input. 1546 x = _generate_input(shape, dtype, device, with_extremal) 1547 if contiguous: 1548 x = x.T 1549 self.compare_with_numpy( 1550 torch_fn, np_fn, x, device=None, dtype=None 1551 ) 1552 1553 # Compare sequence input 1554 torch_sequence_x = (x,) * random.randint(3, 10) 1555 np_sequence_x = tuple( 1556 np.array(x.detach().cpu().numpy()) for x in torch_sequence_x 1557 ) 1558 torch_res = torch_fn(*torch_sequence_x) 1559 np_res = np_fn(*np_sequence_x) 1560 1561 torch_res = tuple(x.cpu() for x in torch_res) 1562 np_res = tuple(torch.from_numpy(x) for x in np_res) 1563 self.assertEqual(np_res, torch_res) 1564 1565 # TODO: are these view ops? 1566 @dtypes(*all_types_and_complex_and(torch.half)) 1567 def test_atleast(self, device, dtype): 1568 self._test_atleast_dim(torch.atleast_1d, np.atleast_1d, device, dtype) 1569 self._test_atleast_dim(torch.atleast_2d, np.atleast_2d, device, dtype) 1570 self._test_atleast_dim(torch.atleast_3d, np.atleast_3d, device, dtype) 1571 1572 # TODO: OpInfo this 1573 def _test_atleast(self, device, torch_fn): 1574 # 0-dim 1575 s = torch.tensor(0.5, dtype=torch.double, requires_grad=True) 1576 1577 gradcheck(lambda x: torch_fn(x), s) 1578 gradgradcheck(lambda x: torch_fn(x), s) 1579 1580 # 1-dim 1581 a = torch.rand(4, dtype=torch.double, requires_grad=True) 1582 1583 gradcheck(lambda x: torch_fn(x), a) 1584 gradgradcheck(lambda x: torch_fn(x), a) 1585 1586 # 2,3,4-dim 1587 b = torch.rand(4, 3, dtype=torch.double, requires_grad=True) 1588 c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True) 1589 d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True) 1590 1591 input_tuple = (s, a, b, c, d) 1592 gradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple) 1593 gradgradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple) 1594 1595 def test_atleast_gradient(self, device): 1596 self._test_atleast(device, torch.atleast_1d) 1597 self._test_atleast(device, torch.atleast_2d) 1598 self._test_atleast(device, torch.atleast_3d) 1599 1600 @onlyCPU 1601 @dtypes(torch.float) 1602 def test_broadcast_tensors(self, device, dtype): 1603 x0 = torch.randn(2, 1, 3, dtype=dtype, device=device) 1604 x1 = torch.randn(3, dtype=dtype, device=device) 1605 x2 = torch.randn(3, 1, dtype=dtype, device=device) 1606 expected_size = (2, 3, 3) 1607 1608 y0, y1, y2 = torch.broadcast_tensors(x0, x1, x2) 1609 self.assertTrue(y0.size() == expected_size) 1610 self.assertTrue(y1.size() == expected_size) 1611 self.assertTrue(y2.size() == expected_size) 1612 1613 @onlyCPU 1614 def test_broadcast_shapes(self, device): 1615 examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)] 1616 for s0 in examples: 1617 x0 = torch.randn(s0) 1618 expected = torch.broadcast_tensors(x0)[0].shape 1619 actual = torch.broadcast_shapes(s0) 1620 self.assertEqual(expected, actual) 1621 1622 for s1 in examples: 1623 x1 = torch.randn(s1) 1624 expected = torch.broadcast_tensors(x0, x1)[0].shape 1625 actual = torch.broadcast_shapes(s0, s1) 1626 self.assertEqual(expected, actual) 1627 1628 inputs_list = [[1, 4], [4, 1], [1, 1, 3]] 1629 for integral_inputs in inputs_list: 1630 res1 = torch.broadcast_shapes(*integral_inputs) 1631 res2 = torch.broadcast_tensors(*map(torch.empty, integral_inputs))[0].shape 1632 self.assertEqual(res1, res2) 1633 1634 inputs_with_neg_vals = [[1, 1, -12], [-1, 1], [-11]] 1635 for integral_inputs_with_neg_vals in inputs_with_neg_vals: 1636 with self.assertRaisesRegex( 1637 RuntimeError, "Trying to create tensor with negative dimension" 1638 ): 1639 torch.broadcast_shapes(*integral_inputs_with_neg_vals) 1640 1641 integral_inputs_error_case = [(3, 5), (2, 4, 1)] 1642 for error_input in integral_inputs_error_case: 1643 with self.assertRaisesRegex( 1644 RuntimeError, 1645 "Shape mismatch: objects cannot be broadcast to a single shape", 1646 ): 1647 torch.broadcast_shapes(*error_input) 1648 1649 negative_inputs = [(-1,), (1, -12), (4, -11), (-4, 1), (1, 1, -2)] 1650 for s0 in negative_inputs: 1651 with self.assertRaisesRegex( 1652 RuntimeError, "Trying to create tensor with negative dimension" 1653 ): 1654 torch.broadcast_shapes(s0) 1655 1656 for s1 in negative_inputs: 1657 with self.assertRaisesRegex( 1658 RuntimeError, "Trying to create tensor with negative dimension" 1659 ): 1660 torch.broadcast_shapes(s0, s1) 1661 1662 float_inputs_error_case = [(1.1, 2.0), (1.1, 1.0)] 1663 for error_case in float_inputs_error_case: 1664 for float_input in error_case: 1665 with self.assertRaisesRegex( 1666 RuntimeError, 1667 "Input shapes " 1668 "should be of type ints, a tuple of ints, or a list of ints", 1669 ): 1670 torch.broadcast_shapes(float_input) 1671 1672 diff_input_types = [(1, (5,)), (3, (1,)), (1, (3, 4))] 1673 for s0 in diff_input_types: 1674 res1 = torch.broadcast_shapes(*s0) 1675 res2 = torch.broadcast_tensors(*map(torch.empty, s0))[0].shape 1676 self.assertEqual(res1, res2) 1677 1678 # Skip BFloat16 since numpy does not support it 1679 @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) 1680 def test_broadcast_to(self, device, dtype): 1681 def can_broadcast(s0, s1): 1682 # s0.dim() <= s1.dim(), reverse s0 and s1 to compare trailing dimension 1683 s0 = tuple(reversed(s0)) 1684 s1 = tuple(reversed(s1)) 1685 for i in range(len(s0)): 1686 if s0[i] != 1 and s0[i] != s1[i]: 1687 return False 1688 return True 1689 1690 sizes = ((), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)) 1691 for s0, s1 in combinations(sizes, r=2): 1692 t = make_tensor(s0, dtype=dtype, device=device, low=-9, high=9) 1693 t_np = t.cpu().numpy() 1694 1695 if can_broadcast(s0, s1): 1696 res = torch.broadcast_to(t, s1) 1697 np_res = np.broadcast_to(t_np, s1) 1698 self.assertEqual(res, np_res) 1699 else: 1700 with self.assertRaisesRegex( 1701 RuntimeError, 1702 r"The expanded size of the tensor \(\d\) " 1703 r"must match the existing size \(\d\)", 1704 ): 1705 torch.broadcast_to(t, s1) 1706 1707 def test_view(self, device): 1708 tensor = torch.rand(15, device=device) 1709 template = torch.rand(3, 5, device=device) 1710 empty = torch.empty(0, device=device) 1711 target = template.size() 1712 self.assertEqual(tensor.view_as(template).size(), target) 1713 self.assertEqual(tensor.view(3, 5).size(), target) 1714 self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target) 1715 self.assertEqual(tensor.view(-1, 5).size(), target) 1716 self.assertEqual(tensor.view(3, -1).size(), target) 1717 tensor_view = tensor.view(5, 3) 1718 tensor_view.fill_(random.uniform(0, 1)) 1719 self.assertEqual(empty.view_as(empty), empty) 1720 self.assertEqual(empty.view(0), empty) 1721 self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1])) 1722 self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty) 1723 1724 # test size inference with empty tensors 1725 self.assertEqual(empty.view(-1).size(), torch.Size([0])) 1726 self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0])) 1727 1728 with self.assertRaisesRegex( 1729 RuntimeError, r"because the unspecified dimension size -1 can be any value" 1730 ): 1731 empty.view(-1, 0) 1732 1733 with self.assertRaisesRegex( 1734 RuntimeError, r"because the unspecified dimension size -1 can be any value" 1735 ): 1736 empty.view(3, 0, -1, 0) 1737 1738 self.assertRaises(RuntimeError, lambda: tensor.view(15, 0)) 1739 self.assertRaises(RuntimeError, lambda: tensor.view(7, -1)) 1740 self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1)) 1741 1742 # test view when tensor is not contiguous in every dimension, but only 1743 # contiguous dimensions are touched. 1744 tensor = ( 1745 torch.rand(4, 2, 5, 1, 6, 2, 9, 3, device=device) 1746 .transpose(-1, 2) 1747 .transpose(-2, 3) 1748 ) 1749 # size: [ 4, 2, 3, 9, 6, 2, 1, 5] 1750 # stride: [3840, 1620, 1, 3, 54, 27, 324, 324] 1751 # contiguous dim chunks: [__________, ____, ____, __________, ____, ____] 1752 # merging 1 to chunk after: [__________, ____, ____, __________, __________] 1753 contig_tensor = tensor.clone() 1754 # [4, 2] => [8, 1] 1755 # [3] => [3] 1756 # [9] => [3, 3] 1757 # [6, 2] => [4, 1, 3] 1758 # [1, 5] => [5] 1759 view_size = [8, 1, 3, 3, 3, 4, 1, 3, 5] 1760 self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) 1761 # [4, 2] => [2, 4] 1762 # [3] => [3] 1763 # [9] => [1, 9] 1764 # [6, 2] => [2, 2, 3] 1765 # [1, 5] => [5, 1] 1766 view_size = [2, 4, 3, 1, 9, 2, 2, 3, 5, 1] 1767 self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) 1768 # adding size 1 dims 1769 view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1] 1770 self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) 1771 1772 # invalid views 1773 self.assertRaises(RuntimeError, lambda: tensor.view(-1)) 1774 # crossing [4, 2], [3] 1775 self.assertRaises(RuntimeError, lambda: tensor.view(24, 9, 6, 2, 1, 5)) 1776 # crossing [6, 2], [1, 5] 1777 self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 9, 6, 10)) 1778 # crossing [9], [6, 2] 1779 self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 54, 2, 1, 5)) 1780 1781 # view with stride 0 dims 1782 tensor = torch.empty(1, 1, device=device).expand( 1783 3, 4 1784 ) # all dims are contiguous 1785 contig_tensor = tensor.clone() 1786 self.assertEqual(tensor.view(-1), contig_tensor.view(-1)) 1787 self.assertEqual(tensor.view(1, -1, 1), contig_tensor.view(1, -1, 1)) 1788 self.assertEqual(tensor.view(-1, 1), contig_tensor.view(-1, 1)) 1789 self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1)) 1790 self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1)) 1791 1792 @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 1793 def test_reshape_view_semantics(self, device, dtype): 1794 tensor = make_tensor((15, 4), dtype=dtype, device=device) 1795 target = (20, 3) 1796 1797 # Cases where the tensor can be returned as a view. 1798 view_tensor = tensor.reshape(target) 1799 self.assertEqual((view_tensor.size()), target) 1800 self.assertEqual(tensor.storage().data_ptr(), view_tensor.storage().data_ptr()) 1801 1802 # Cases where the tensor must be copied (transpose makes it non-contiguous forcing 1803 # the copy). 1804 copy_tensor = tensor.transpose(0, 1).reshape(target) 1805 self.assertEqual(copy_tensor.size(), target) 1806 self.assertNotEqual( 1807 tensor.storage().data_ptr(), copy_tensor.storage().data_ptr() 1808 ) 1809 1810 def test_contiguous(self, device): 1811 x = torch.randn(1, 16, 5, 5, device=device) 1812 self.assertTrue(x.is_contiguous()) 1813 stride = list(x.stride()) 1814 stride[0] = 20 1815 # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1 1816 x.set_(x.storage(), 0, x.size(), stride) 1817 self.assertTrue(x.is_contiguous()) 1818 1819 @onlyNativeDeviceTypes 1820 # Skip BFloat16 since numpy does not support it 1821 @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) 1822 def test_tensor_split_sections(self, device, dtype): 1823 input_sizes = [ 1824 (0,), 1825 (10,), 1826 (10, 0), 1827 (0, 10), 1828 (4, 10), 1829 (12, 3), 1830 ] 1831 for input_size in input_sizes: 1832 a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9) 1833 # Run tests on transposed input if it has at least 2 dims 1834 for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]: 1835 a_n = a.cpu().numpy() 1836 for dim in range(-a.dim(), a.dim()): 1837 for sections in range(1, 2 * a.size(dim)): 1838 msg = f"input_size {input_size}, sections {sections}, dim {dim}" 1839 result1 = torch.tensor_split(a, sections, dim) 1840 result2 = torch.tensor_split( 1841 a, torch.tensor(sections, dtype=torch.int64), dim 1842 ) 1843 for r1, r2 in zip(result1, result2): 1844 self.assertEqual(r1.device, torch.device(device), msg=msg) 1845 self.assertEqual(r1.dtype, dtype, msg=msg) 1846 self.assertEqual(r2.device, torch.device(device), msg=msg) 1847 self.assertEqual(r2.dtype, dtype, msg=msg) 1848 result_n = np.array_split(a_n, sections, dim) 1849 self.assertEqual(result_n, result1, msg=msg) 1850 self.assertEqual(result_n, result2, msg=msg) 1851 1852 @onlyNativeDeviceTypes 1853 # Skip BFloat16 since numpy does not support it 1854 @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) 1855 def test_tensor_split_indices(self, device, dtype): 1856 input_sizes = [ 1857 (0,), 1858 (10,), 1859 (10, 0), 1860 (0, 10), 1861 (4, 10), 1862 (12, 3), 1863 ] 1864 indices_args = [ 1865 (), 1866 (0,), 1867 (3,), 1868 (10,), 1869 (-1,), 1870 (-10,), 1871 (2, -1), 1872 (3, 4, 10), 1873 (0, -1, 0, 10), 1874 (1, 5, 2, 8), 1875 ] 1876 for input_size in input_sizes: 1877 a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9) 1878 # Run tests on transposed input if it has at least 2 dims 1879 for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]: 1880 a_n = a.cpu().numpy() 1881 for dim in range(-a.dim(), a.dim()): 1882 for indices in indices_args: 1883 result_1 = torch.tensor_split(a, indices, dim) 1884 result_2 = torch.tensor_split( 1885 a, torch.tensor(indices, dtype=torch.int64), dim 1886 ) 1887 1888 msg = f"input_size {input_size}, indices {indices}, dim {dim}" 1889 for r1, r2 in zip(result_1, result_2): 1890 self.assertEqual(r1.device, torch.device(device), msg=msg) 1891 self.assertEqual(r1.dtype, dtype, msg=msg) 1892 self.assertEqual(r2.device, torch.device(device), msg=msg) 1893 self.assertEqual(r2.dtype, dtype, msg=msg) 1894 1895 result_n = np.array_split(a_n, indices, dim) 1896 self.assertEqual(result_n, result_1, msg=msg) 1897 self.assertEqual(result_n, result_2, msg=msg) 1898 1899 @onlyNativeDeviceTypes 1900 def test_tensor_split_errors(self, device): 1901 S = 10 1902 test_cases = [ 1903 # input size, sections or indices, dim, error type, error message, numpy error type 1904 [(S,), 10, 1, IndexError, r"Dimension out of range", IndexError], 1905 [ 1906 (), 1907 10, 1908 0, 1909 RuntimeError, 1910 r"tensor_split expected at least a 1-dimensional tensor, " 1911 + "but got a tensor with 0 dims", 1912 IndexError, 1913 ], 1914 [(S,), (10,), 1, IndexError, r"Dimension out of range", IndexError], 1915 [ 1916 (), 1917 (10,), 1918 0, 1919 RuntimeError, 1920 r"tensor_split expected at least a 1-dimensional tensor, " 1921 + "but got a tensor with 0 dims", 1922 IndexError, 1923 ], 1924 [ 1925 (S,), 1926 0, 1927 0, 1928 RuntimeError, 1929 r"number of sections must be larger than 0, got 0", 1930 ValueError, 1931 ], 1932 [ 1933 (S,), 1934 -1, 1935 0, 1936 RuntimeError, 1937 r"number of sections must be larger than 0, got -1", 1938 ValueError, 1939 ], 1940 ] 1941 for input_size, sections_or_indices, dim, err, err_msg, numpy_err in test_cases: 1942 a = torch.randn(input_size, device=device) 1943 msg = f"input_size {input_size}, sections_or_indices {sections_or_indices}, dim {dim}" 1944 with self.assertRaisesRegex(err, err_msg, msg=msg): 1945 torch.tensor_split(a, sections_or_indices, dim) 1946 with self.assertRaisesRegex(err, err_msg, msg=msg): 1947 torch.tensor_split(a, torch.tensor(sections_or_indices), dim) 1948 with self.assertRaises(numpy_err, msg=msg): 1949 np.array_split(a.cpu().numpy(), sections_or_indices, dim) 1950 1951 # addtional tests for tensor_split with tensor_indices_or_sections 1952 with self.assertRaisesRegex( 1953 RuntimeError, 1954 r"tensor_split expected tensor_indices_or_sections to have dtype of long, but got Float", 1955 ): 1956 torch.tensor_split(a, torch.tensor(1.1), dim) 1957 1958 with self.assertRaisesRegex( 1959 RuntimeError, 1960 r"tensor_split expected tensor_indices_or_sections to be a" 1961 + " zero-dimensional or one-dimensional tensor, but got a tensor with 2 dims", 1962 ): 1963 torch.tensor_split(torch.rand(S, device=device), torch.tensor(((1,),)), 0) 1964 1965 def test_resize_all_dtypes_and_devices(self, device): 1966 shape = (2, 2) 1967 for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): 1968 x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) 1969 x.resize_(shape) 1970 self.assertEqual(shape, x.shape) 1971 1972 def test_resize_as_all_dtypes_and_devices(self, device): 1973 for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): 1974 x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) 1975 y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device) 1976 x.resize_as_(y) 1977 self.assertEqual(y.shape, x.shape) 1978 1979 @onlyNativeDeviceTypes 1980 def test_resize_overflow(self, device): 1981 x = torch.empty((), dtype=torch.float64) 1982 with self.assertRaisesRegex( 1983 RuntimeError, "Storage size calculation overflowed" 1984 ): 1985 x.resize_([2, 4, 2**29, 2**29]) 1986 with self.assertRaisesRegex(RuntimeError, "overflow"): 1987 x.resize_([8, 8, 2**29, 2**29]) 1988 with self.assertRaisesRegex(RuntimeError, "Stride calculation overflowed"): 1989 x.resize_([0, 4, 2305843009213693952]) 1990 1991 def test_view_all_dtypes_and_devices(self, device): 1992 for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): 1993 x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) 1994 self.assertEqual(x.view(6).shape, [6]) 1995 1996 @skipIfTorchDynamo("conj bit not implemented in TensorVariable yet") 1997 @onlyCPU 1998 def test_conj_neg_view_numpy_error(self, device): 1999 self.assertRaisesRegex( 2000 RuntimeError, 2001 "has conjugate bit set", 2002 lambda: torch.tensor([1 + 2j]).conj().numpy(), 2003 ) 2004 self.assertRaisesRegex( 2005 RuntimeError, 2006 "has negative bit set", 2007 lambda: torch.tensor([1 + 2j]).conj().imag.numpy(), 2008 ) 2009 self.assertRaisesRegex( 2010 RuntimeError, 2011 "not supported for conjugate view tensors", 2012 lambda: torch.tensor([1 + 2j]).conj().view(torch.float64), 2013 ) 2014 self.assertRaisesRegex( 2015 RuntimeError, 2016 "not supported for tensors with negative bit set", 2017 lambda: torch.tensor([1 + 2j]).conj().imag.view(torch.int32), 2018 ) 2019 2020 @onlyCPU 2021 def test_crow_col_indices(self, device): 2022 crow_indices = (0, 1, 2) 2023 col_indices = (1, 0) 2024 values = (1, 2) 2025 t = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2)) 2026 # This is the test. If crow_indices is not a view op it'll 2027 # trigger an internal assert due to use count greater than 1 2028 # in debug build. 2029 t.crow_indices() 2030 t.col_indices() 2031 2032 2033instantiate_device_type_tests(TestViewOps, globals(), include_lazy=True) 2034instantiate_device_type_tests(TestOldViewOps, globals()) 2035 2036if __name__ == "__main__": 2037 run_tests() 2038