xref: /aosp_15_r20/external/pytorch/test/test_view_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: tests"]
2import random
3import unittest
4from functools import partial
5from itertools import combinations, permutations, product
6
7import numpy as np
8
9import torch
10from torch.testing import make_tensor
11from torch.testing._internal.common_device_type import (
12    dtypes,
13    instantiate_device_type_tests,
14    onlyCPU,
15    onlyNativeDeviceTypes,
16    skipLazy,
17    skipMeta,
18    skipXLA,
19)
20from torch.testing._internal.common_dtype import (
21    all_types_and,
22    all_types_and_complex_and,
23    complex_types,
24    floating_and_complex_types_and,
25)
26from torch.testing._internal.common_utils import (
27    gradcheck,
28    gradgradcheck,
29    IS_FBCODE,
30    numpy_to_torch_dtype_dict,
31    run_tests,
32    skipIfTorchDynamo,
33    suppress_warnings,
34    TestCase,
35)
36
37
# TODO: replace this with make_tensor() in common_utils.py
def _generate_input(shape, dtype, device, with_extremal):
    """Build a random test tensor of ``shape`` and ``dtype`` on ``device``.

    An empty ``shape`` (``()``) yields ``torch.tensor(())``: a 1-D tensor with
    zero elements, not a 0-dim scalar.

    Floating-point/complex dtypes get ``randn`` samples scaled by a random
    integer in [30, 100], and entries where a fresh ``randn`` draw exceeds 0.5
    are zeroed.  With ``with_extremal`` set, further random subsets are then
    overwritten with NaN, +inf and -inf.  ``torch.bool`` gets random True
    entries over a False background; every other dtype gets integers drawn
    from [15, 100).
    """
    if shape == ():
        return torch.tensor((), dtype=dtype, device=device)

    def random_mask():
        # A fresh selection mask; drawn on the default device.
        return torch.randn(*shape) > 0.5

    if not (dtype.is_floating_point or dtype.is_complex):
        if dtype != torch.bool:
            return torch.randint(15, 100, shape, dtype=dtype, device=device)
        flags = torch.zeros(shape, dtype=dtype, device=device)
        flags[random_mask()] = True
        return flags

    if dtype == torch.bfloat16:
        # work around torch.randn not being implemented for bfloat16:
        # sample in the default dtype, then cast.
        values = torch.randn(*shape, device=device) * random.randint(30, 100)
        values = values.to(torch.bfloat16)
    else:
        values = torch.randn(*shape, dtype=dtype, device=device) * random.randint(
            30, 100
        )
    values[random_mask()] = 0

    if with_extremal:
        # Only floating-point or complex dtypes reach this point.
        if dtype.is_floating_point:
            specials = (float("nan"), float("inf"), float("-inf"))
        else:
            specials = (complex("nan"), complex("inf"), complex("-inf"))
        for special in specials:
            values[random_mask()] = special

    return values
69
70
# TODO: replace this with make_tensor() in common_utils.py
def _rand_shape(dim, min_size, max_size):
    """Return a ``dim``-long tuple whose sizes are drawn from [min_size, max_size] inclusive."""
    return tuple(random.randint(min_size, max_size) for _ in range(dim))
77
78
# TODO: refactor tests to avoid this function
def _convert_t(dtype, device):
    """Return float32 for half/bfloat16 on the ``"cpu"`` device, else ``dtype`` unchanged.

    ``device`` is compared with the literal string ``"cpu"``, so for example
    ``"cpu:0"`` or ``"cuda"`` leave ``dtype`` untouched.
    """
    if device != "cpu":
        return dtype
    return torch.float if dtype in {torch.half, torch.bfloat16} else dtype
85
86
# TODO: replace this with make_tensor() in common_utils.py
def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor:
    """Return a test tensor of ``shape`` with ``dtype`` on ``device``.

    Args:
        shape: a size sequence (tuple, list or ``torch.Size``); ``()``
            produces a 0-dim tensor.
        dtype: requested element type.  On the ``"cpu"`` device, half and
            bfloat16 requests are returned as float32 tensors (via
            ``_convert_t``); random float values are rounded through the
            reduced dtype so they stay exactly representable in it.
        device: device to allocate on.
        fill_ones: if True, return a tensor of ones instead of random values.

    Returns:
        Ones when ``fill_ones`` is set.  Otherwise: for non-floating,
        non-complex dtypes, integers from ``randint(0, 10)`` shifted by -5
        (except for uint8, which stays non-negative); for floating/complex
        dtypes, ``randn`` samples.
    """
    # `shape` is passed whole to every factory (never unpacked with `*`), so
    # all branches treat tuples, lists, torch.Size and `()` the same way.
    if fill_ones:
        return torch.ones(shape, dtype=_convert_t(dtype, device), device=device)

    # Returns a tensor with random integer values
    if not (dtype.is_floating_point or dtype.is_complex):
        t = torch.randint(0, 10, shape, device=device)
        if dtype != torch.uint8:
            t = t - 5  # generate negative values also
        return t.to(_convert_t(dtype, device))

    # Populates the CPU tensor with floats representable as half/bfloat16
    if device == "cpu" and dtype in (torch.half, torch.bfloat16):
        return torch.randn(shape, dtype=torch.float, device=device).to(dtype).float()

    # Default: random float/complex values already in the requested dtype.
    return torch.randn(shape, dtype=dtype, device=device)
112
113
114# Tests ops and indexing to ensure they return views (and new tensors) as
115# appropriate.
116class TestViewOps(TestCase):
117    exact_dtype = True
118
119    def is_view_of(self, base, other):
120        if (
121            not other._is_view()
122            or other is base
123            or other._base is not base
124            or base.device != other.device
125        ):
126            return False
127        # Note: only validates storage on native device types
128        # because some accelerators, like XLA, do not expose storage
129        if base.device.type == "cpu" or base.device.type == "cuda":
130            if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
131                return False
132
133        return True
134
135    # Returns true if v1 and v2 are views of the same base
136    def is_view_of_same_base(self, v1, v2):
137        if not v1._is_view() or v1 is v2:
138            return False
139        return self.is_view_of(v1._base, v2)
140
141    # Performs transpose if contiguous=True, else returns the input tensor as is
142    def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
143        if contiguous:
144            return x
145        else:
146            return x.transpose(dim0, dim1)
147
148    @dtypes(*all_types_and(torch.half, torch.bfloat16))
149    def test_conj_self(self, device, dtype):
150        t = torch.ones(5, 5, device=device)
151        s = t.conj()
152        self.assertTrue(s is t)
153
    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
    def test_view_dtype_new(self, device, dtype):
        """Check ``Tensor.view(dtype)`` against NumPy for several memory layouts.

        For each input layout and each target dtype that has a NumPy
        counterpart (bool excluded), the dtype view must either raise the
        expected error (0-dim input, or last-dim stride != 1, when element
        sizes differ) or: share the input's data pointer, report the expected
        size and stride, round-trip back to the original values exactly, and
        match NumPy's ``ndarray.view``.  Floating/complex inputs also check
        that the dtype view drops ``requires_grad``.
        """
        # Invert numpy_to_torch_dtype_dict into {torch dtype: numpy dtype}.
        # Note: this local name shadows the module-level `dtypes` decorator.
        dtypes = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
        # bool is not used as a view target.
        # NOTE(review): presumably because arbitrary bytes are not valid bools — confirm.
        del dtypes[torch.bool]

        # Inputs covering several memory layouts.
        def generate_inputs():
            # contiguous
            yield make_tensor((4, 4, 64), dtype=dtype, device=device, low=-5, high=5)
            # permuted, but the last dim still has stride 1
            yield make_tensor(
                (4, 4, 64), dtype=dtype, device=device, low=-5, high=5
            ).permute(1, 0, 2)
            # permuted so that the last dim has stride 4
            yield make_tensor(
                (4, 64, 4), dtype=dtype, device=device, low=-5, high=5
            ).permute(2, 0, 1)
            # expanded: the last dim has stride 0
            yield make_tensor(
                (1, 5, 1), dtype=dtype, device=device, low=-5, high=5
            ).expand(5, 5, 64)
            # strided slice: nonzero storage offset and last-dim stride 2
            yield make_tensor((2, 5, 256), dtype=dtype, device=device, low=-5, high=5)[
                1::2, 1:, ::2
            ]
            # zero elements
            yield make_tensor((0, 5, 64), dtype=dtype, device=device, low=-5, high=5)
            # 0-dim
            yield make_tensor((), dtype=dtype, device=device, low=-5, high=5)

        # Size/stride that a.view(view_dtype) is expected to report.  When the
        # element sizes differ, the last dim is scaled by the size ratio and
        # gets stride 1, and every other stride is rescaled to the new
        # element size.
        def calc_expected_size_and_stride(a, view_dtype):
            dtype_size = torch._utils._element_size(a.dtype)
            view_dtype_size = torch._utils._element_size(view_dtype)

            if dtype_size == view_dtype_size:
                return a.size(), a.stride()

            # Narrower target dtype: more elements along the last dim.
            elif dtype_size > view_dtype_size:
                size_ratio = dtype_size // view_dtype_size

                view_size = list(a.size())
                view_size[-1] = view_size[-1] * size_ratio

                view_stride = [stride * size_ratio for stride in a.stride()]
                view_stride[-1] = 1
                return torch.Size(view_size), tuple(view_stride)

            # Wider target dtype: fewer elements along the last dim.
            else:
                size_ratio = view_dtype_size // dtype_size

                view_size = list(a.size())
                view_size[-1] = view_size[-1] // size_ratio

                view_stride = [stride // size_ratio for stride in a.stride()]
                view_stride[-1] = 1
                return torch.Size(view_size), tuple(view_stride)

        for a in generate_inputs():
            # NumPy copies of the input, used as the reference.
            a_np = a.cpu().numpy()
            a_np_contiguous = a.cpu().contiguous().numpy()

            for view_dtype, np_view_dtype in dtypes.items():
                equal_element_size = torch._utils._element_size(
                    dtype
                ) == torch._utils._element_size(view_dtype)

                # Changing the element size is rejected for 0-dim tensors...
                if not equal_element_size and a.dim() == 0:
                    with self.assertRaisesRegex(
                        RuntimeError, r"self.dim\(\) cannot be 0"
                    ):
                        a.view(view_dtype)
                    continue

                # ...and when the last dim does not have stride 1.
                if not equal_element_size and a.stride(-1) != 1:
                    with self.assertRaisesRegex(
                        RuntimeError, r"self.stride\(-1\) must be 1"
                    ):
                        a.view(view_dtype)
                    continue

                # The dtype view aliases the input's memory.
                a_view = a.view(view_dtype)
                self.assertEqual(a_view.dtype, view_dtype)
                self.assertEqual(a.data_ptr(), a_view.data_ptr())

                expected_size, expected_stride = calc_expected_size_and_stride(
                    a, view_dtype
                )
                self.assertEqual(a_view.size(), expected_size)
                self.assertEqual(a_view.stride(), expected_stride)

                # Viewing back to the original dtype must reproduce the input
                # exactly (zero tolerance).
                self.assertEqual(a_view.view(dtype), a, rtol=0, atol=0)

                # NumPy's dtype view requires contiguous input if target
                # dtype is a different size
                if equal_element_size:
                    a_np_view = a_np.view(np_view_dtype)

                else:
                    a_np_view = a_np_contiguous.view(np_view_dtype)

                self.assertEqual(a_view, a_np_view)

        # Test that requires_grad is dropped for floating point casts,
        # because view(dtype) does not support backward yet
        # TODO: Remove this when autograd support is added
        if dtype.is_floating_point or dtype.is_complex:
            for view_dtype in floating_and_complex_types_and(
                torch.half, torch.bfloat16
            ):
                t = make_tensor(
                    (5, 5, 64),
                    dtype=dtype,
                    device=device,
                    low=-5,
                    high=5,
                    requires_grad=True,
                )
                self.assertFalse(t.view(view_dtype).requires_grad)
266
267    # Test the extra error checks that happen when the view dtype
268    # has a greater element size than the original dtype
269    @onlyNativeDeviceTypes
270    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
271    def test_view_dtype_upsize_errors(self, device, dtype):
272        dtype_size = torch._utils._element_size(dtype)
273
274        for view_dtype in all_types_and_complex_and(
275            torch.half, torch.bfloat16, torch.bool
276        ):
277            view_dtype_size = torch._utils._element_size(view_dtype)
278            if view_dtype_size <= dtype_size:
279                continue
280
281            size_ratio = view_dtype_size // dtype_size
282            a = make_tensor(
283                (4, 4, size_ratio + 1), dtype=dtype, device=device, low=-5, high=5
284            )
285            with self.assertRaisesRegex(
286                RuntimeError, rf"self.size\(-1\) must be divisible by {size_ratio}"
287            ):
288                a.view(view_dtype)
289
290            with self.assertRaisesRegex(
291                RuntimeError,
292                rf"self.storage_offset\(\) must be divisible by {size_ratio}",
293            ):
294                a[:, :, 1:].view(view_dtype)
295
296            a = make_tensor(
297                (4, 4, size_ratio), dtype=dtype, device=device, low=-5, high=5
298            )
299            a = a.as_strided((4, 4, size_ratio), (size_ratio, 1, 1))
300            with self.assertRaisesRegex(
301                RuntimeError, rf"self.stride\(1\) must be divisible by {size_ratio}"
302            ):
303                a.view(view_dtype)
304
305    @onlyNativeDeviceTypes
306    def test_view_as_complex(self, device):
307        def fn(contiguous_input=True, dim0=0, dim1=1):
308            t = torch.randn(3, 2, 2, device=device)
309            c_t = t[:, :, 0] + 1j * t[:, :, 1]
310
311            input = self._do_transpose(t, contiguous_input, dim0, dim1)
312
313            if input.size()[-1] != 2:
314                self.assertRaisesRegex(
315                    RuntimeError,
316                    "Tensor must have a last dimension of size 2",
317                    lambda: torch.view_as_complex(input),
318                )
319                return
320
321            if input.stride()[-1] != 1:
322                self.assertRaisesRegex(
323                    RuntimeError,
324                    "Tensor must have a last dimension with stride 1",
325                    lambda: torch.view_as_complex(input),
326                )
327                return
328
329            res = torch.view_as_complex(input)
330            self.assertEqual(res, self._do_transpose(c_t, contiguous_input, dim0, dim1))
331            self.assertTrue(self.is_view_of(t, res))
332
333        fn()
334        fn(contiguous_input=False)
335        # RuntimeError since in this case the last dim of input would not be of size 2
336        fn(contiguous_input=False, dim0=0, dim1=2)
337        # RuntimeError since in this case the last dim of input would not have stride 1
338        fn(contiguous_input=False, dim0=1, dim1=2)
339
340        # RuntimeError since in this case the stride of non-last dim of input would not be of size 2
341        x = torch.randn(3, 3, device=device)
342        t = torch.as_strided(x, (2, 2), (1, 1))
343        self.assertRaisesRegex(
344            RuntimeError,
345            "Tensor must have a stride divisible by 2 for all but last dimension",
346            lambda: torch.view_as_complex(t),
347        )
348
349        # tensor with zero elements
350        x = torch.tensor([], device=device)  # torch.Size([0])
351        self.assertRaisesRegex(
352            RuntimeError,
353            "Tensor must have a last dimension of size 2",
354            lambda: torch.view_as_complex(x),
355        )
356
357        # zero dimension tensor
358        z = torch.tensor(2.0)
359        self.assertRaisesRegex(
360            RuntimeError,
361            "Input tensor must have one or more dimensions",
362            lambda: torch.view_as_complex(z),
363        )
364
365        y = x.reshape(0, 2)  # torch.Size([0, 2])
366        res = torch.view_as_complex(y)
367        self.assertTrue(self.is_view_of(x, res))
368        self.assertEqual(res.shape, torch.Size([0]))
369
370    @onlyNativeDeviceTypes
371    @dtypes(*complex_types(), torch.complex32)
372    def test_view_as_real(self, device, dtype):
373        def fn(contiguous_input=True):
374            t = torch.randn(3, 4, dtype=dtype, device=device)
375            input = self._do_transpose(t, contiguous_input)
376            res = torch.view_as_real(input)
377            self.assertEqual(res[:, :, 0], input.real)
378            self.assertEqual(res[:, :, 1], input.imag)
379            self.assertTrue(self.is_view_of(t, res))
380
381        fn()
382        fn(contiguous_input=False)
383
384        # tensor with zero elements
385        x = torch.tensor([], dtype=dtype, device=device)
386        res = torch.view_as_real(x)
387        self.assertTrue(self.is_view_of(x, res))
388        self.assertEqual(res.shape, torch.Size([0, 2]))
389
390        # tensor with zero dim
391        x = torch.tensor(2 + 3j, dtype=dtype, device=device)
392        res = torch.view_as_real(x)
393        self.assertTrue(self.is_view_of(x, res))
394        self.assertEqual(res.shape, torch.Size([2]))
395
396    @onlyNativeDeviceTypes
397    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
398    def test_view_tensor_split(self, device, dtype):
399        a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9)
400        a_split_dim0 = a.tensor_split(7, 0)
401        for a_split_dim0_tensor in a_split_dim0:
402            self.assertTrue(self.is_view_of(a, a_split_dim0_tensor))
403        a_split_dim1 = a.tensor_split(7, 1)
404        for a_split_dim1_tensor in a_split_dim1:
405            self.assertTrue(self.is_view_of(a, a_split_dim1_tensor))
406
407    @onlyNativeDeviceTypes
408    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
409    def test_view_tensor_hsplit(self, device, dtype):
410        t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
411        t_hsplit = torch.hsplit(t, 2)
412        for t_hsplit_tensor in t_hsplit:
413            self.assertTrue(self.is_view_of(t, t_hsplit_tensor))
414        t[2, 2, 2] = 7
415        self.assertEqual(t_hsplit[1][2, 0, 2], t[2, 2, 2])
416
417    @onlyNativeDeviceTypes
418    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
419    def test_view_tensor_vsplit(self, device, dtype):
420        t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
421        t_vsplit = torch.vsplit(t, 2)
422        for t_vsplit_tensor in t_vsplit:
423            self.assertTrue(self.is_view_of(t, t_vsplit_tensor))
424        t[2, 2, 2] = 7
425        self.assertEqual(t_vsplit[1][0, 2, 2], t[2, 2, 2])
426
427    @onlyNativeDeviceTypes
428    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
429    def test_view_tensor_dsplit(self, device, dtype):
430        t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
431        t_dsplit = torch.dsplit(t, 2)
432        for t_dsplit_tensor in t_dsplit:
433            self.assertTrue(self.is_view_of(t, t_dsplit_tensor))
434        t[2, 2, 2] = 7
435        self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2])
436
437    @onlyNativeDeviceTypes
438    @dtypes(*all_types_and(torch.half, torch.bfloat16))
439    def test_imag_noncomplex(self, device, dtype):
440        t = torch.ones((5, 5), dtype=dtype, device=device)
441
442        with self.assertRaises(RuntimeError):
443            torch.imag(t)
444
445    @onlyNativeDeviceTypes
446    @dtypes(*complex_types())
447    def test_real_imag_view(self, device, dtype):
448        def compare_with_numpy(contiguous_input=True):
449            t = torch.randn(3, 3, dtype=dtype, device=device)
450            if not contiguous_input:
451                u = t.T
452            else:
453                u = t
454
455            re = u.real
456            exp = torch.from_numpy(u.cpu().numpy().real).to(device=device)
457            self.assertEqual(re, exp)
458            # for the case of contiguous_input, t=u
459            # for the case of non contiguous_input, the base still remains
460            # t since we are performing a view operation to make the input non-contiguous
461            self.assertTrue(self.is_view_of(t, re))
462
463            im = u.imag
464            exp = torch.from_numpy(u.cpu().numpy().imag).to(device=device)
465            self.assertEqual(im, exp)
466            self.assertTrue(self.is_view_of(t, im))
467
468        compare_with_numpy()
469        compare_with_numpy(contiguous_input=False)
470
471        # ensure storage offset is being correctly set
472        a = torch.randn(10, dtype=dtype)
473        self.assertEqual(a[5:].real, a.real[5:])
474        self.assertEqual(a[5:].imag, a.imag[5:])
475
476    @onlyNativeDeviceTypes
477    @dtypes(*complex_types())
478    def test_conj_imag_view(self, device, dtype) -> None:
479        t = _make_tensor((4, 5), dtype, device)
480        t_numpy_conj = torch.from_numpy(t.cpu().numpy().conj()).to(device=device)
481        v = t.conj()
482        self.assertTrue(self.is_view_of(t, v))
483        self.assertEqual(v, t_numpy_conj)
484
485        if t.is_complex():
486            v_imag = v.imag
487            self.assertTrue(self.is_view_of(t, v_imag))
488            self.assertEqual(v_imag, t_numpy_conj.imag)
489            self.assertTrue(v_imag.is_neg())
490
491    @onlyNativeDeviceTypes
492    def test_conj_view_with_shared_memory(self, device) -> None:
493        a = _make_tensor((4, 5), torch.cfloat, device)
494        b = a.conj()
495        c = a.conj()
496
497        self.assertEqual(torch.add(a, b), a.add_(b))
498        self.assertEqual(torch.add(b, c), torch.add(b, c, out=a))
499        self.assertEqual(torch.add(b, c), b.add_(c))
500
501    @onlyNativeDeviceTypes
502    @dtypes(
503        *product(
504            complex_types(),
505            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
506        )
507    )
508    @suppress_warnings
509    def test_set_real_imag(self, device, dtypes):
510        x = torch.randn(10, dtype=dtypes[0], device=device)
511
512        new_real = _make_tensor((10,), dtypes[1], device)
513        new_imag = _make_tensor((10,), dtypes[1], device)
514
515        x.real = new_real
516        x.imag = new_imag
517
518        if dtypes[1].is_complex:
519            self.assertEqual(x.real, new_real.real, exact_dtype=False)
520            self.assertEqual(x.imag, new_imag.real, exact_dtype=False)
521
522        else:
523            self.assertEqual(x.real, new_real, exact_dtype=False)
524            self.assertEqual(x.imag, new_imag, exact_dtype=False)
525
526    def test_diagonal_view(self, device) -> None:
527        t = torch.ones((5, 5), device=device)
528        v = torch.diagonal(t)
529        self.assertTrue(self.is_view_of(t, v))
530
531        v[0] = 0
532        self.assertEqual(t[0, 0], v[0])
533
534        t = torch.ones((3, 3, 3), device=device)
535        v = torch.diagonal(t, offset=1, dim1=1, dim2=2)
536        self.assertTrue(self.is_view_of(t, v))
537
538        v[0, 0] = 0
539        self.assertEqual(t[0, 0, 1], v[0, 0])
540
541    def test_select_view(self, device) -> None:
542        t = torch.ones((5, 5), device=device)
543        v = t.select(0, 2)
544        self.assertTrue(self.is_view_of(t, v))
545
546        v[0] = 0
547        self.assertEqual(t[2, 0], v[0])
548
549    # Lazy hasn't implemented unbind yet.
550    @skipLazy
551    def test_unbind_view(self, device) -> None:
552        t = torch.zeros((5, 5), device=device)
553        tup = torch.unbind(t)
554
555        for idx, v in enumerate(tup):
556            self.assertTrue(self.is_view_of(t, v))
557
558            v[0] = idx + 1
559            self.assertEqual(t[idx, 0], v[0])
560
561    # TODO: opinfo this or move to unbind's test suite
562    def test_unbind(self):
563        stacked = torch.randn(3, 10, 10, requires_grad=True)
564        x, y, z = stacked.unbind()
565        grad = torch.randn(3, 10, 10)
566        torch.autograd.backward([x, y, z], grad.unbind())
567        self.assertEqual(stacked.grad, grad)
568        # check that it works with only one gradient provided (#9977)
569        for i in range(3):
570            stacked = torch.randn(3, 10, 10, requires_grad=True)
571            outs = stacked.unbind()
572            gi = grad.unbind()[i]
573            (g,) = torch.autograd.grad(outs[i], stacked, gi)
574            g_expected = torch.stack(
575                [gi if j == i else torch.zeros_like(gi) for j in range(3)], dim=0
576            )
577            self.assertEqual(g, g_expected)
578        # Check with gradcheck
579        stacked = torch.randn(3, 10, 10, dtype=torch.double, requires_grad=True)
580        gradcheck(lambda x: x.unbind(), (stacked,), check_forward_ad=True)
581
582    # TODO: Fix this test for LTC. There is an interaction with dynamic shapes here that is broken,
583    # causing asserts to trigger.
584    @skipLazy
585    def test_expand_view(self, device) -> None:
586        t = torch.ones((5, 1), device=device)
587        v = t.expand(5, 5)
588        self.assertTrue(self.is_view_of(t, v))
589
590        v[2, 2] = 0
591        self.assertEqual(t[2, 0], v[2, 2])
592
593    def test_expand_as_view(self, device):
594        t = torch.ones((5, 1), device=device)
595        e = torch.empty((5, 5), device=device)
596        v = t.expand_as(e)
597        self.assertTrue(self.is_view_of(t, v))
598
599        v[2, 2] = 0
600        self.assertEqual(t[2, 0], v[2, 2])
601
602    def test_narrow_view(self, device):
603        t = torch.ones((5, 5), device=device)
604        v = torch.narrow(t, 1, 2, 2)
605        self.assertTrue(self.is_view_of(t, v))
606
607        v[0, 0] = 0
608        self.assertEqual(t[0, 2], v[0, 0])
609
610    def test_permute_view(self, device) -> None:
611        t = torch.ones((5, 5), device=device)
612        v = t.permute(1, 0)
613        self.assertTrue(self.is_view_of(t, v))
614
615        v[0, 1] = 0
616        self.assertEqual(t[1, 0], v[0, 1])
617
618    def test_transpose_view(self, device):
619        for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
620            t = torch.ones((5, 5), device=device)
621            v = fn(t, 0, 1)
622            self.assertTrue(self.is_view_of(t, v))
623
624            v[0, 1] = 0
625            self.assertEqual(t[1, 0], v[0, 1])
626
627    def test_transpose_inplace_view(self, device):
628        t = torch.ones(5, 5, device=device)
629        v = t.view_as(t)
630        v = v.swapdims_(0, 1)
631        self.assertTrue(self.is_view_of(t, v))
632        v[0, 1] = 0
633        self.assertEqual(t[1, 0], v[0, 1])
634
635        t = torch.ones(5, 5, device=device)
636        v = t.view_as(t)
637        v = v.swapaxes_(0, 1)
638        self.assertTrue(self.is_view_of(t, v))
639        v[0, 1] = 0
640        self.assertEqual(t[1, 0], v[0, 1])
641
642        t = torch.ones(5, 5, device=device)
643        v = t.view_as(t)
644        v = v.transpose_(0, 1)
645        self.assertTrue(self.is_view_of(t, v))
646        v[0, 1] = 0
647        self.assertEqual(t[1, 0], v[0, 1])
648
649    def test_t_view(self, device):
650        t = torch.ones((5, 5), device=device)
651        v = t.t()
652        self.assertTrue(self.is_view_of(t, v))
653
654        v[0, 1] = 0
655        self.assertEqual(t[1, 0], v[0, 1])
656
657    def test_t_inplace_view(self, device):
658        t = torch.ones(5, 5, device=device)
659        v = t.view_as(t)
660        v = v.t_()
661        self.assertTrue(self.is_view_of(t, v))
662        v[0, 1] = 0
663        self.assertEqual(t[1, 0], v[0, 1])
664
665    def test_T_view(self, device):
666        for op in ("T", "H", "mT", "mH"):
667            t = torch.ones((5, 5), device=device)
668            v = getattr(t, op)
669            self.assertTrue(self.is_view_of(t, v))
670
671            v[0, 1] = 0
672            self.assertEqual(t[1, 0], v[0, 1])
673
674    def test_unfold_view(self, device):
675        t = torch.ones(10, device=device)
676        v = t.unfold(0, 3, 2)
677        self.assertTrue(self.is_view_of(t, v))
678
679        v[1, 0] = 0
680        self.assertEqual(t[2], v[1, 0])
681
682    def test_squeeze_view(self, device):
683        t = torch.ones(5, 1, 5, device=device)
684        v = torch.squeeze(t)
685        self.assertTrue(self.is_view_of(t, v))
686        v[0, 1] = 0
687        self.assertEqual(t, v._base)
688
689    def test_squeeze_inplace_view(self, device):
690        t = torch.ones(5, 5, device=device)
691        v = t.view_as(t)
692        v = v.squeeze_()
693        self.assertTrue(self.is_view_of(t, v))
694        v[0, 1] = 0
695        self.assertEqual(t, v._base)
696
697    def test_unsqueeze_view(self, device):
698        t = torch.ones(5, 5, device=device)
699        v = torch.unsqueeze(t, 1)
700        self.assertTrue(self.is_view_of(t, v))
701
702        v[0, 0, 1] = 0
703        self.assertEqual(t[0, 1], v[0, 0, 1])
704
705    def test_unsqueeze_inplace_view(self, device):
706        t = torch.ones(5, 5, device=device)
707        v = t.view_as(t)
708        v = v.unsqueeze_(1)
709        self.assertTrue(self.is_view_of(t, v))
710        v[0, 0, 1] = 0
711        self.assertEqual(t[0, 1], v[0, 0, 1])
712
713    def test_as_strided_view(self, device):
714        t = torch.ones(5, 5, device=device)
715        v = torch.as_strided(t, (25,), (1,))
716        self.assertTrue(self.is_view_of(t, v))
717
718        v[6] = 0
719        self.assertEqual(t[1, 1], v[6])
720
721    def test_as_strided_inplace_view(self, device):
722        t = torch.ones(5, 5, device=device)
723        v = t.view_as(t)
724        v = v.as_strided_((25,), (1,))
725        self.assertTrue(self.is_view_of(t, v))
726        v[6] = 0
727        self.assertEqual(t[1, 1], v[6])
728
729    def test_as_strided_gradients(self):
730        def test(x, prepro_fn, size, strides, offset=None):
731            x = x.to(torch.double).detach().requires_grad_()
732
733            # Check that forward will **not** resize storage because it may
734            # cause NaN in output and fail numerical Jacobian check consequently
735            with torch.no_grad():
736                y = prepro_fn(x) if prepro_fn is not None else x
737                max_offset = sum((si - 1) * st for si, st in zip(size, strides))
738                max_offset += offset if offset is not None else y.storage_offset()
739                assert max_offset < len(y.storage()), "test case resizes storage"
740
741            def closure(x):
742                if prepro_fn is not None:
743                    x = prepro_fn(x)
744                return x.as_strided(size, strides, offset)
745
746            gradcheck(closure, [x], check_forward_ad=True)
747            gradgradcheck(closure, [x])
748
749        # test
750        test(torch.arange(0, 25), lambda x: x.view(5, 5), [3, 3], [6, 2], 2)
751
752        # test crazy stride at dim with size 1 case
753        test(torch.randn(12), None, [1, 2, 1, 5], [0, 5, 100, 1], 2)
754
755        # test expand case
756        test(torch.randn(5), None, [3, 3, 3], [0, 1, 0], 2)
757        test(torch.randn(5), None, [3, 3, 3], [0, 0, 0], 4)
758        test(torch.randn(5), lambda x: x.expand(5, 5), [5, 5], [0, 1], 0)
759
760        # test non-expand overlapping case
761        test(torch.randn(35), None, [6, 6], [5, 1], 2)
762        test(torch.randn(15), None, [3, 2], [3, 6], 2)
763
764        # test transpose case
765        test(torch.randn(3, 4), None, [4, 3], [1, 4])
766
767        # test "getting things outside the input" case
768        x = torch.randn(6, 2)
769        test(x[3:], None, [3, 2], [2, 1], 0)  # should be all zeros
770        self.assertEqual(x[3:].as_strided([3, 2], [2, 1], 0), x[:3])
771
772        # test select on expanded input case
773        test(torch.randn(2, 3), lambda x: x.expand(10, 2, 3), [2, 3], [3, 1], 0)
774
775    def test_view_view(self, device):
776        t = torch.ones(5, 5, device=device)
777        v = t.view(25)
778        self.assertTrue(self.is_view_of(t, v))
779
780        v[6] = 0
781        self.assertEqual(t[1, 1], v[6])
782
783    def test_view_as_view(self, device):
784        t = torch.ones(5, 5, device=device)
785        e = torch.empty((25,))
786        v = t.view_as(e)
787        self.assertTrue(self.is_view_of(t, v))
788
789        v[6] = 0
790        self.assertEqual(t[1, 1], v[6])
791
792    def test_contiguous_self(self, device):
793        t = torch.ones(5, 5, device=device)
794        s = t.contiguous()
795        self.assertTrue(s is t)
796
797    @skipMeta
798    # self.is_view_of reports false positives for lazy
799    @skipLazy
800    def test_contiguous_nonview(self, device):
801        t = torch.ones(5, 5, device=device)
802        nv = t.t().contiguous()
803        self.assertTrue(not self.is_view_of(t, nv))
804
805        nv[0, 0] = 0
806        self.assertNotEqual(t[0, 0], nv[0, 0])
807
808    def test_reshape_view(self, device):
809        t = torch.ones(5, 5, device=device)
810        v = torch.reshape(t, (25,))
811        self.assertTrue(self.is_view_of(t, v))
812
813        v[6] = 0
814        self.assertEqual(t[1, 1], v[6])
815
816    def test_reshape_as_view(self, device):
817        t = torch.ones(5, 5, device=device)
818        e = torch.empty((25,), device=device)
819        v = t.reshape_as(e)
820        self.assertTrue(self.is_view_of(t, v))
821
822        v[6] = 0
823        self.assertEqual(t[1, 1], v[6])
824
825    @skipMeta
826    # self.is_view_of reports false positives for lazy
827    @skipLazy
828    def test_reshape_nonview(self, device):
829        t = torch.ones(5, 5, device=device)
830        nv = torch.reshape(t.t(), (25,))
831        self.assertTrue(not self.is_view_of(t, nv))
832
833        nv[6] = 0
834        self.assertNotEqual(t[1, 1], nv[6])
835
836    # This test use as_strided to construct a tensor with overlapping memory,
837    # which is not handled by the functionalization pass.
838    @skipLazy
839    @skipXLA
840    def test_flatten_view(self, device):
841        def test_writes_propagate(t, v):
842            idx_t = (0,) * t.ndim
843            idx_v = (0,) * v.ndim
844            v[idx_v] = 0
845            self.assertEqual(t[idx_t], v[idx_v])
846
847        t = torch.ones(1, 2, 3, 4, device=device)
848        v = t.flatten()
849        self.assertTrue(self.is_view_of(t, v))
850        test_writes_propagate(t, v)
851
852        # zero-dimensional tensor
853        t = torch.tensor(1, device=device)
854        v = t.flatten()
855        test_writes_propagate(t, v)
856        self.assertTrue(self.is_view_of(t, v))
857
858        t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
859        v = t.flatten(0, 1)
860        test_writes_propagate(t, v)
861        self.assertTrue(self.is_view_of_same_base(t, v))
862
863        # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
864        t = torch.ones(720, device=device).as_strided(
865            (2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0)
866        )
867        #               [--1--|---2---|-3-] [--1--|----2---|-3-]
868        v1 = t.flatten(0, 1)
869        v2 = v1.flatten(1, 3)
870        v3 = v2.flatten(2, 2)
871        test_writes_propagate(t, v1)
872        self.assertTrue(self.is_view_of_same_base(t, v1))
873        test_writes_propagate(t, v2)
874        self.assertTrue(self.is_view_of_same_base(t, v2))
875        test_writes_propagate(t, v3)
876        self.assertTrue(self.is_view_of_same_base(t, v3))
877
878    @onlyNativeDeviceTypes
879    def test_flatten_nonview(self, device):
880        def assert_is_nonview(t, nv):
881            idx_t = (0,) * t.ndim
882            idx_nv = (0,) * nv.ndim
883            self.assertTrue(not nv._is_view())
884            nv[idx_nv] = 0
885            if device != "meta":
886                self.assertNotEqual(t[idx_t], nv[idx_nv])
887
888        t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
889        nv = t.flatten(1, 3)
890        assert_is_nonview(t, nv)
891
892        t = torch.ones(2, 2, device=device).T
893        nv = t.flatten()
894        assert_is_nonview(t, nv)
895
896        # flatten returns the original object if start_dim=end_dim
897        t = t = torch.ones(2, 2, device=device)
898        nv = t.flatten(1, 1)
899        self.assertTrue(t is nv)
900
901    def test_basic_indexing_slice_view(self, device):
902        t = torch.ones(5, 5, device=device)
903        v = t[:2, :3]
904        self.assertTrue(self.is_view_of(t, v))
905
906        v[0, 0] = 0
907        self.assertEqual(t[0, 0], v[0, 0])
908
909    def test_basic_indexing_ellipses_view(self, device):
910        t = torch.ones(5, 5, device=device)
911        v = t[..., :2]
912        self.assertTrue(self.is_view_of(t, v))
913
914        v[0, 0] = 0
915        self.assertEqual(t[0, 0], v[0, 0])
916
917    def test_basic_indexing_newaxis_view(self, device):
918        t = torch.ones(5, 5, device=device)
919        v = t[None, :2, 3]
920        self.assertTrue(self.is_view_of(t, v))
921
922        v[0, 0] = 0
923        self.assertEqual(t[0, 3], v[0, 0])
924
925    def test_advanced_indexing_nonview(self, device):
926        t = torch.ones(3, 3, device=device)
927        rows = torch.tensor([[0, 0], [2, 2]], device=device)
928        cols = torch.tensor([[0, 1], [2, 2]], device=device)
929        nv = t[rows, cols]
930        self.assertTrue(not self.is_view_of(t, nv))
931
932        nv[1, 1] = 0
933        self.assertNotEqual(t[2, 2], nv[1, 1])
934
935    @unittest.skipIf(
936        IS_FBCODE, "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds"
937    )
938    def test_advanced_indexing_assignment(self, device):
939        t = torch.ones(3, 3, device=device)
940        rows = torch.tensor([[0, 0], [2, 2]], device=device)
941        cols = torch.tensor([[0, 1], [2, 2]], device=device)
942        t[rows, cols] = 0
943        self.assertEqual(t[2, 2], 0)
944
945    @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720")
946    def test_chunk_view(self, device):
947        t = torch.zeros(3, 3, device=device)
948        l = torch.chunk(t, 3)
949
950        for idx, v in enumerate(l):
951            self.assertTrue(self.is_view_of(t, v))
952
953            v[0, 0] = idx + 1
954            self.assertEqual(t[idx, 0], v[0, 0])
955
956    @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720")
957    def test_split_view(self, device):
958        t = torch.zeros(3, 3, device=device)
959        l = torch.split(t, [1, 1, 1])
960
961        for idx, v in enumerate(l):
962            self.assertTrue(self.is_view_of(t, v))
963
964            v[0, 0] = idx + 1
965            self.assertEqual(t[idx, 0], v[0, 0])
966
967    def test_movedim_view(self, device):
968        def run_test(device, op):
969            t = torch.zeros(3, 3, device=device)
970            out = op(t)
971
972            self.assertTrue(self.is_view_of(t, out))
973
974            # Randomly change values in output
975            # and verify that original is changed
976            # as well.
977            for _ in range(3):
978                idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2)
979                out[idx_1, idx_2] = random.random()
980                self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2])
981
982        for fn in [torch.movedim, torch.moveaxis]:
983            op = partial(fn, source=(0, 1), destination=(1, 0))
984            run_test(device, op)
985
986            op = partial(fn, source=0, destination=1)
987            run_test(device, op)
988
989    # Testing that the generated view_copy kernel and its derivative are implemented correctly
990    def test_view_copy(self, device):
991        a = torch.randn(4, device=device, requires_grad=True)
992        a_ref = a.clone().detach().requires_grad_()
993        a_view = a_ref.view(2, 2)
994        a_view_copy = torch.view_copy(a, (2, 2))
995
996        # view_copy ops don't preserve view relationship
997        self.assertTrue(self.is_view_of(a_ref, a_view))
998        self.assertFalse(self.is_view_of(a, a_view_copy))
999
1000        a_view_copy.sum().backward()
1001        a_view.sum().backward()
1002
1003        # forward and backward give the same shape + result
1004        self.assertEqual(a_view_copy, a_view)
1005        self.assertEqual(a.grad, a_ref.grad)
1006
1007    # Testing that the output of a view_copy kernel (by default) is contiguous.
1008    def test_view_copy_output_contiguous(self, device):
1009        a = torch.randn(4, 4, 4, 4, device=device).to(memory_format=torch.channels_last)
1010        b = torch.ops.aten.slice_copy(a, 0, 0, 2)
1011        self.assertTrue(b.is_contiguous())
1012
1013    def test_view_copy_out(self, device):
1014        a = torch.randn(2, 2, device=device)
1015        out = torch.empty(2, device=device)
1016
1017        torch.diagonal_copy(a, out=out)
1018        expected = torch.diagonal_copy(a)
1019
1020        self.assertEqual(expected, out)
1021
1022        a = torch.randn(4, device=device)
1023        out1 = torch.empty(2, device=device)
1024        out2 = torch.empty(2, device=device)
1025
1026        torch.split_copy(a, 2, out=(out1, out2))
1027        expected1, expected2 = torch.split_copy(a, 2)
1028
1029        self.assertEqual(expected1, out1)
1030        self.assertEqual(expected2, out2)
1031
1032
1033class TestOldViewOps(TestCase):
1034    def test_ravel(self, device):
1035        def _test_ravel(tensors, size, nc=False):
1036            for src in tensors:
1037                # Continuous Tensor -> View
1038                flat = src.ravel()
1039                self.assertEqual(flat.shape, torch.Size([size]))
1040                self.assertEqual(src.view(-1), flat)
1041                self.assertIs(flat._base, src)
1042                self.assertTrue(flat.is_contiguous())
1043
1044                # Non-continuous Tensor -> Copy
1045                if nc:
1046                    nc_src = src.t()
1047                    nc_flat = nc_src.ravel()
1048                    self.assertEqual(nc_flat.shape, torch.Size([size]))
1049                    self.assertEqual(nc_src.contiguous().view(-1), nc_flat)
1050                    self.assertIsNot(nc_flat._base, src)
1051                    self.assertTrue(nc_flat.is_contiguous())
1052
1053        # Test that flatten returns 1-dim tensor when given a 0-dim tensor
1054        zero_dim_tensor = torch.tensor(123, device=device)
1055        flat0 = zero_dim_tensor.ravel()
1056        one_dim_tensor = torch.tensor([123], device=device)
1057        flat1 = zero_dim_tensor.ravel()
1058        nc_ones_tensor = torch.ones(10, device=device)[::2]
1059        flat2 = nc_ones_tensor.ravel()
1060
1061        self.assertEqual(zero_dim_tensor.shape, torch.Size([]))
1062        self.assertEqual(flat0.shape, torch.Size([1]))
1063        self.assertEqual(one_dim_tensor.shape, torch.Size([1]))
1064        self.assertEqual(flat1.shape, torch.Size([1]))
1065        self.assertEqual(nc_ones_tensor.shape, torch.Size([5]))
1066        self.assertEqual(flat2.shape, torch.Size([5]))
1067        self.assertEqual(flat0, one_dim_tensor)
1068        self.assertEqual(flat0, flat1)
1069        self.assertEqual(flat0.shape, flat1.shape)
1070        self.assertTrue(flat0.is_contiguous())
1071        self.assertTrue(flat1.is_contiguous())
1072        self.assertTrue(flat2.is_contiguous())
1073
1074        # Test both float tensor and quantized tensor
1075        tensors = [
1076            torch.randn(5, 5, 5, 5, device=device),
1077            torch._empty_affine_quantized(
1078                [5, 5, 5, 5], scale=2, zero_point=3, dtype=torch.quint8, device=device
1079            ),
1080        ]
1081        _test_ravel(tensors, 625)
1082
1083        tensors = [
1084            torch.randn(0, 2, 3, device=device),
1085            torch.randn(3, 0, 2, device=device),
1086            torch._empty_affine_quantized(
1087                [0, 2, 3], scale=2, zero_point=3, dtype=torch.quint8, device=device
1088            ),
1089            torch._empty_affine_quantized(
1090                [3, 0, 2], scale=2, zero_point=3, dtype=torch.quint8, device=device
1091            ),
1092        ]
1093        _test_ravel(tensors, 0)
1094
1095        tensors = [
1096            torch.randn(5, 5, device=device),
1097            torch._empty_affine_quantized(
1098                [5, 5], scale=2, zero_point=3, dtype=torch.quint8, device=device
1099            ),
1100        ]
1101        _test_ravel(tensors, 25, True)
1102
1103    # TODO: this should be refactored into the view ops test suite
1104    def test_empty_reshape(self, device):
1105        x = torch.randn(0, 6, device=device)
1106        self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape)
1107        # should be viewable -- i.e. data_ptr is the same.
1108        self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr())
1109
1110        # match NumPy semantics -- don't infer the size of dimension with a degree of freedom
1111        self.assertRaises(RuntimeError, lambda: x.reshape(0, -1))
1112
1113    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
1114    def test_expand(self, device):
1115        tensor = torch.rand(1, 8, 1, device=device)
1116        tensor2 = torch.rand(5, device=device)
1117        template = torch.rand(4, 8, 5, device=device)
1118        target = template.size()
1119        self.assertEqual(tensor.expand_as(template).size(), target)
1120        self.assertEqual(tensor.expand(4, 8, 5).size(), target)
1121        self.assertEqual(tensor.expand(target).size(), target)
1122        self.assertEqual(tensor2.expand_as(template).size(), target)
1123        self.assertEqual(tensor2.expand(4, 8, 5).size(), target)
1124        self.assertEqual(tensor2.expand(target).size(), target)
1125
1126        # test double expand
1127        self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1))
1128
1129        # test non-contiguous
1130        noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0]
1131        self.assertFalse(noncontig.is_contiguous())
1132        self.assertEqual(
1133            noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1)
1134        )
1135
1136        # make sure it's compatible with unsqueeze
1137        expanded = tensor2.expand(1, 1, 5)
1138        unsqueezed = tensor2.unsqueeze(0).unsqueeze(1)
1139        self.assertEqual(expanded, unsqueezed)
1140        self.assertEqual(expanded.stride(), unsqueezed.stride())
1141
1142        # test -1 as target size
1143        self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5))
1144        self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1))
1145
1146        # test expanding empty to empty
1147        self.assertEqual(
1148            torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device)
1149        )
1150
1151    # TODO: this should be refactored into the view ops test suite
1152    def test_view_empty(self, device):
1153        x = torch.randn(0, 6, device=device)
1154        self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape)
1155
1156    # TODO: this should be refactored into the view ops test suite
1157    @onlyNativeDeviceTypes
1158    def test_reshape(self, device):
1159        x = torch.randn(3, 3, device=device)
1160        self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
1161        self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
1162        self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
1163        self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))
1164
1165        y = torch.randn(4, 4, 4, device=device)[:, 0, :]
1166        # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
1167        if device != "meta":
1168            self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
1169        self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
1170        self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())
1171
1172        s = torch.randn((), device=device)
1173        self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr())
1174        self.assertEqual(s.reshape(-1).shape, (1,))
1175        self.assertRaises(RuntimeError, lambda: s.reshape(2))
1176
1177        empty = torch.tensor([], device=device)
1178        self.assertEqual(empty, empty.reshape(-1))
1179        self.assertEqual(empty, empty.reshape([0]))
1180        # TODO: fix these once we have multi-dimensional empty tensors
1181        self.assertEqual(empty.reshape([0, 1]).shape, (0, 1))
1182        self.assertEqual(empty.reshape([1, -1]).shape, (1, 0))
1183        self.assertRaises(RuntimeError, lambda: empty.reshape(1))
1184
1185        x = torch.randn(3, 3, device=device)
1186        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
1187        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
1188        self.assertRaises(
1189            RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device))
1190        )
1191
1192    def test_flatten(self, device):
1193        # Test that flatten returns 1-dim tensor when given a 0-dim tensor
1194        zero_dim_tensor = torch.tensor(123, device=device)
1195        flat0 = zero_dim_tensor.flatten()
1196        one_dim_tensor = torch.tensor([123], device=device)
1197        flat1 = zero_dim_tensor.flatten()
1198
1199        self.assertEqual(zero_dim_tensor.shape, torch.Size([]))
1200        self.assertEqual(flat0.shape, torch.Size([1]))
1201        self.assertEqual(one_dim_tensor.shape, torch.Size([1]))
1202        self.assertEqual(flat1.shape, torch.Size([1]))
1203        self.assertEqual(flat0, one_dim_tensor)
1204        self.assertEqual(flat0, flat1)
1205        self.assertEqual(flat0.shape, flat1.shape)
1206
1207        # Test both float tensor and quantized tensor
1208        tensors = [
1209            torch.randn(5, 5, 5, 5, device=device),
1210            torch._empty_affine_quantized(
1211                [5, 5, 5, 5], scale=2, zero_point=3, dtype=torch.quint8, device=device
1212            ),
1213        ]
1214        for src in tensors:
1215            flat = src.flatten(0, -1)
1216            self.assertEqual(flat.shape, torch.Size([625]))
1217            self.assertEqual(src.view(-1), flat.view(-1))
1218
1219            flat = src.flatten(0, 2)
1220            self.assertEqual(flat.shape, torch.Size([125, 5]))
1221            self.assertEqual(src.view(-1), flat.view(-1))
1222
1223            flat = src.flatten(0, 1)
1224            self.assertEqual(flat.shape, torch.Size([25, 5, 5]))
1225            self.assertEqual(src.view(-1), flat.view(-1))
1226
1227            flat = src.flatten(1, 2)
1228            self.assertEqual(flat.shape, torch.Size([5, 25, 5]))
1229            self.assertEqual(src.view(-1), flat.view(-1))
1230
1231            flat = src.flatten(2, 3)
1232            self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
1233            self.assertEqual(src.view(-1), flat.view(-1))
1234
1235            flat = src.flatten(-2, -1)
1236            self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
1237            self.assertEqual(src.view(-1), flat.view(-1))
1238
1239            flat = src.flatten(2, 2)
1240            self.assertEqual(flat, src)
1241
1242            # out of bounds index
1243            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
1244                src.flatten(5, 10)
1245
1246            # invalid start and end
1247            with self.assertRaisesRegex(
1248                RuntimeError, "start_dim cannot come after end_dim"
1249            ):
1250                src.flatten(2, 0)
1251
1252    # TODO: update to work on CUDA, too
1253    @onlyCPU
1254    def test_narrow(self, device):
1255        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
1256        self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]]))
1257        self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]]))
1258        self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]]))
1259        self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]]))
1260        self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]]))
1261        self.assertEqual(
1262            x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
1263        )
1264        self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]]))
1265        self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]]))
1266
1267    # TODO: update to work on CUDA, too
1268    @onlyCPU
1269    def test_narrow_tensor(self, device):
1270        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
1271        self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]]))
1272        with self.assertRaises(Exception):
1273            x.narrow(0, torch.tensor(0.0), 1)
1274        with self.assertRaises(Exception):
1275            x.narrow(0, torch.tensor([0]), 1)
1276        with self.assertRaises(Exception):
1277            x.narrow(0, torch.tensor([0, 1]), 1)
1278
1279    # TODO: make work on CUDA, too
1280    @onlyCPU
1281    def test_t(self, device):
1282        # Test 0D tensors
1283        x = torch.randn(())
1284        self.assertEqual(x, x.t())
1285        x = x.to_sparse()
1286        self.assertEqual(x, x.t())
1287
1288        # Test 1D tensors
1289        x = torch.arange(4)
1290        self.assertEqual(x, x.t())
1291        x = x.to_sparse()
1292        self.assertEqual(x, x.t())
1293
1294        # Test 2D tensors
1295        x = torch.rand((2, 2))
1296        self.assertEqual(x.t(), x.transpose(0, 1))
1297        x = x.to_sparse()
1298        self.assertEqual(x.t(), x.transpose(0, 1))
1299
1300        # Test 3D tensor
1301        x = torch.rand((2, 2, 2))
1302        with self.assertRaisesRegex(
1303            RuntimeError, "expects a tensor with <= 2 dimensions, but self is 3D"
1304        ):
1305            x.t()
1306        x = x.to_sparse()
1307        with self.assertRaisesRegex(
1308            RuntimeError, "expects a tensor with <= 2 sparse and 0 dense dimensions"
1309        ):
1310            x.t()
1311
1312    @onlyCPU
1313    def test_split(self, device):
1314        tensor = torch.rand(7, 4)
1315        split_size = 3
1316        dim = 0
1317        target_sizes = ([3, 4], [3, 4], [1, 4])
1318        splits = tensor.split(split_size, dim)
1319        start = 0
1320        for target_size, split in zip(target_sizes, splits):
1321            self.assertEqual(split.size(), target_size)
1322            self.assertEqual(
1323                tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0
1324            )
1325            start = start + target_size[dim]
1326
1327        # Variable sections split
1328        tensor = torch.randn(20, 10)
1329        dim = 0
1330        split_sizes = [5, 5, 10]
1331        target_sizes = [[5, 10], [5, 10], [10, 10]]
1332        splits = tensor.split(split_sizes, dim)
1333        start = 0
1334        for target_size, split in zip(target_sizes, splits):
1335            self.assertEqual(split.size(), target_size)
1336            self.assertEqual(
1337                tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0
1338            )
1339            start = start + target_size[dim]
1340
1341        split_sizes = [2, 2, 6]
1342        target_sizes = ([20, 2], [20, 2], [20, 6])
1343        dim = 1
1344        splits = tensor.split(split_sizes, dim)
1345        start = 0
1346        for target_size, split in zip(target_sizes, splits):
1347            self.assertEqual(split.size(), target_size)
1348            self.assertEqual(
1349                tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0
1350            )
1351            start = start + target_size[dim]
1352
1353    @onlyCPU
1354    def test_chunk(self, device):
1355        tensor = torch.rand(4, 7)
1356        num_chunks = 3
1357        dim = 1
1358        target_sizes = ([4, 3], [4, 3], [4, 1])
1359        splits = tensor.chunk(num_chunks, dim)
1360        start = 0
1361        for target_size, split in zip(target_sizes, splits):
1362            self.assertEqual(split.size(), target_size)
1363            self.assertEqual(
1364                tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0
1365            )
1366            start = start + target_size[dim]
1367
1368        # Invalid chunk sizes
1369        error_regex = "chunk expects.*greater than 0"
1370        with self.assertRaisesRegex(RuntimeError, error_regex):
1371            tensor.chunk(0)
1372        with self.assertRaisesRegex(RuntimeError, error_regex):
1373            tensor.chunk(-2)
1374
1375    # TODO: make work on CUDA, too
1376    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
1377    @onlyCPU
1378    def test_unsqueeze(self, device) -> None:
1379        x = torch.randn(2, 3, 4)
1380        y = x.unsqueeze(1)
1381        self.assertEqual(y, x.view(2, 1, 3, 4))
1382        y = x.clone().unsqueeze_(2)
1383        self.assertEqual(y, x.view(2, 3, 1, 4))
1384
1385        x = x[:, 1]
1386        self.assertFalse(x.is_contiguous())
1387        y = x.unsqueeze(1)
1388        self.assertEqual(y, x.contiguous().view(2, 1, 4))
1389        y = x.clone().unsqueeze_(2)
1390        self.assertEqual(y, x.contiguous().view(2, 4, 1))
1391
1392    # unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
1393    def test_big_transpose(self, device):
1394        t = torch.rand(456, 789, device=device)
1395        t1 = t.t().contiguous()
1396        t2 = torch.from_numpy(t.cpu().numpy().transpose())
1397        self.assertEqual(t1, t2)
1398
1399    def test_T(self, device):
1400        a = torch.randn(2, 3, 4, device=device)
1401        t1 = a.T
1402        t2 = a.permute(2, 1, 0)
1403        self.assertEqual(t2, t1)
1404        b = torch.randn(10, device=device)
1405        self.assertEqual(b, b.T)
1406
1407    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
1408    def test_transposes(self, device, dtype):
1409        for op in ("T", "H", "mT", "mH", "adjoint"):
1410            shapes = (
1411                ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),)
1412            )
1413            for shape in shapes:
1414                a = make_tensor(shape, device=device, dtype=dtype)
1415                t1 = getattr(a, op)
1416                if op == "adjoint":
1417                    t1 = t1()
1418                t2 = a
1419                t2 = t2.transpose(-2, -1)
1420                if op[-1] == "H" or op == "adjoint":
1421                    t2 = t2.conj()
1422                self.assertEqual(t2, t1)
1423
1424    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
1425    def test_transposes_errors(self, device, dtype):
1426        for op in ("H", "mT", "mH", "adjoint"):
1427            shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),)
1428            for shape in shapes:
1429                a = make_tensor(shape, device=device, dtype=dtype)
1430                with self.assertRaisesRegex(RuntimeError, "only supported on matrices"):
1431                    t1 = getattr(a, op)
1432                    if op == "adjoint":
1433                        t1 = t1()
1434
1435    def test_python_types(self, device):
1436        a1 = torch.randn((1, 2), device=device, dtype=torch.float64)
1437        a2 = torch.randn((1, 2), device=device, dtype=float)
1438        self.assertEqual(a1.dtype, a2.dtype)
1439
1440        b1 = torch.arange(10, 20, dtype=torch.int64, device=device)
1441        b2 = torch.arange(10, 20, dtype=int, device=device)
1442        self.assertEqual(b1.dtype, b2.dtype)
1443
1444        c1 = torch.tensor([True, False], dtype=torch.bool, device=device)
1445        c2 = torch.tensor([True, False], dtype=bool, device=device)
1446        self.assertEqual(c1.dtype, c2.dtype)
1447
1448    # TODO: is resize best put in test_view_ops?
1449    def test_resize_as_preserves_strides(self, device):
1450        x = torch.empty(2, 3).t()
1451        old_strides = x.stride()
1452        x.resize_as_(x)
1453        self.assertEqual(x.stride(), old_strides)
1454
1455    def test_memory_format_resize_as(self, device):
1456        def test_helper(shape, memory_format, device):
1457            xc = torch.randn(shape, device=device).contiguous(
1458                memory_format=memory_format
1459            )
1460            flat = torch.randn(xc.numel(), device=device)
1461            flat.resize_as_(xc, memory_format=torch.preserve_format)
1462            self.assertTrue(flat.is_contiguous(memory_format=memory_format))
1463
1464        test_helper((10, 3, 32, 32), torch.channels_last, device)
1465        test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device)
1466
1467    def test_memory_format_resize_(self, device):
1468        def test_helper(shape, numel, memory_format, device):
1469            flat = torch.randn(numel, device=device)
1470            flat.resize_(shape, memory_format=memory_format)
1471            self.assertTrue(flat.is_contiguous(memory_format=memory_format))
1472
1473        test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device)
1474        test_helper(
1475            (3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device
1476        )
1477
1478    @onlyNativeDeviceTypes
1479    @dtypes(torch.int64, torch.float, torch.complex128)
1480    def test_transpose_invalid(self, device, dtype):
1481        for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
1482            shape = _rand_shape(4, min_size=5, max_size=10)
1483            x = _generate_input(shape, dtype, device, False)
1484
1485            # Invalid `source` and `destination` dimension
1486            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
1487                fn(x, 5, 0)
1488
1489            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
1490                fn(x, 0, 5)
1491
    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_transpose_vs_numpy(self, device, dtype):
        """Compare swapdims/swapaxes/transpose against np.swapaxes on random inputs."""
        for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
            for nd in range(5):
                shape = _rand_shape(nd, min_size=5, max_size=10)
                x = _generate_input(shape, dtype, device, with_extremal=False)
                for random_negative in [True, False]:
                    # Every ordered pair of distinct dims.
                    for src_dim, dst_dim in permutations(range(nd), r=2):
                        random_prob = random.random()

                        # When random_negative is set, rewrite the source dim,
                        # the destination dim, or both (chosen at random) in
                        # negative form (d - nd) to exercise negative indexing.
                        if random_negative and random_prob > 0.66:
                            src_dim = src_dim - nd
                        elif random_negative and random_prob > 0.33:
                            dst_dim = dst_dim - nd
                        elif random_negative:
                            src_dim = src_dim - nd
                            dst_dim = dst_dim - nd

                        # Bind the dims using each function's own keyword names
                        # (swapaxes uses axis0/axis1, the others dim0/dim1).
                        partial_map = {
                            torch.swapdims: partial(
                                torch.swapdims, dim0=src_dim, dim1=dst_dim
                            ),
                            torch.swapaxes: partial(
                                torch.swapaxes, axis0=src_dim, axis1=dst_dim
                            ),
                            torch.transpose: partial(
                                torch.transpose, dim0=src_dim, dim1=dst_dim
                            ),
                        }

                        torch_fn = partial_map[fn]
                        np_fn = partial(np.swapaxes, axis1=src_dim, axis2=dst_dim)
                        self.compare_with_numpy(
                            torch_fn, np_fn, x, device=None, dtype=None
                        )

            # Move dim to same position
            # NOTE(review): this input ignores `device` and `dtype` (default
            # device and dtype) — confirm that is intentional.
            x = torch.randn(2, 3, 5, 7, 11)
            partial_map = {
                torch.swapdims: partial(torch.swapdims, dim0=0, dim1=0),
                torch.swapaxes: partial(torch.swapaxes, axis0=0, axis1=0),
                torch.transpose: partial(torch.transpose, dim0=0, dim1=0),
            }
            torch_fn = partial_map[fn]
            np_fn = partial(np.swapaxes, axis1=0, axis2=0)
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
1538
    def _test_atleast_dim(self, torch_fn, np_fn, device, dtype):
        """Compare `torch_fn` with `np_fn` on single and repeated-sequence inputs.

        For random shapes with 0 to 4 dims, with and without extremal values,
        checks `torch_fn(x)` against `np_fn` via `self.compare_with_numpy`. It
        then calls both functions with the same input repeated 3-10 times and
        compares the resulting tuples.
        """
        for ndims in range(0, 5):
            shape = _rand_shape(ndims, min_size=5, max_size=10)
            # NOTE(review): `n` is unused, so the body below just repeats
            # ndims + 1 times per shape with fresh random inputs — confirm
            # this repetition is intended.
            for n in range(ndims + 1):
                for with_extremal in [False, True]:
                    # NOTE(review): despite its name, contiguous=True applies
                    # x.T below, which generally makes inputs with >= 2 dims
                    # non-contiguous; the flag effectively toggles a
                    # transposed input.
                    for contiguous in [False, True]:
                        # Generate Input.
                        x = _generate_input(shape, dtype, device, with_extremal)
                        if contiguous:
                            x = x.T
                        self.compare_with_numpy(
                            torch_fn, np_fn, x, device=None, dtype=None
                        )

                        # Compare sequence input: the same tensor passed as
                        # several positional arguments. (The generator's `x`
                        # is local to the generator and does not rebind the
                        # outer `x`.)
                        torch_sequence_x = (x,) * random.randint(3, 10)
                        np_sequence_x = tuple(
                            np.array(x.detach().cpu().numpy()) for x in torch_sequence_x
                        )
                        torch_res = torch_fn(*torch_sequence_x)
                        np_res = np_fn(*np_sequence_x)

                        # Bring both result tuples to CPU torch tensors before comparing.
                        torch_res = tuple(x.cpu() for x in torch_res)
                        np_res = tuple(torch.from_numpy(x) for x in np_res)
                        self.assertEqual(np_res, torch_res)
1564
1565    # TODO: are these view ops?
1566    @dtypes(*all_types_and_complex_and(torch.half))
1567    def test_atleast(self, device, dtype):
1568        self._test_atleast_dim(torch.atleast_1d, np.atleast_1d, device, dtype)
1569        self._test_atleast_dim(torch.atleast_2d, np.atleast_2d, device, dtype)
1570        self._test_atleast_dim(torch.atleast_3d, np.atleast_3d, device, dtype)
1571
1572    # TODO: OpInfo this
1573    def _test_atleast(self, device, torch_fn):
1574        # 0-dim
1575        s = torch.tensor(0.5, dtype=torch.double, requires_grad=True)
1576
1577        gradcheck(lambda x: torch_fn(x), s)
1578        gradgradcheck(lambda x: torch_fn(x), s)
1579
1580        # 1-dim
1581        a = torch.rand(4, dtype=torch.double, requires_grad=True)
1582
1583        gradcheck(lambda x: torch_fn(x), a)
1584        gradgradcheck(lambda x: torch_fn(x), a)
1585
1586        # 2,3,4-dim
1587        b = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
1588        c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True)
1589        d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True)
1590
1591        input_tuple = (s, a, b, c, d)
1592        gradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)
1593        gradgradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)
1594
1595    def test_atleast_gradient(self, device):
1596        self._test_atleast(device, torch.atleast_1d)
1597        self._test_atleast(device, torch.atleast_2d)
1598        self._test_atleast(device, torch.atleast_3d)
1599
1600    @onlyCPU
1601    @dtypes(torch.float)
1602    def test_broadcast_tensors(self, device, dtype):
1603        x0 = torch.randn(2, 1, 3, dtype=dtype, device=device)
1604        x1 = torch.randn(3, dtype=dtype, device=device)
1605        x2 = torch.randn(3, 1, dtype=dtype, device=device)
1606        expected_size = (2, 3, 3)
1607
1608        y0, y1, y2 = torch.broadcast_tensors(x0, x1, x2)
1609        self.assertTrue(y0.size() == expected_size)
1610        self.assertTrue(y1.size() == expected_size)
1611        self.assertTrue(y2.size() == expected_size)
1612
1613    @onlyCPU
1614    def test_broadcast_shapes(self, device):
1615        examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)]
1616        for s0 in examples:
1617            x0 = torch.randn(s0)
1618            expected = torch.broadcast_tensors(x0)[0].shape
1619            actual = torch.broadcast_shapes(s0)
1620            self.assertEqual(expected, actual)
1621
1622            for s1 in examples:
1623                x1 = torch.randn(s1)
1624                expected = torch.broadcast_tensors(x0, x1)[0].shape
1625                actual = torch.broadcast_shapes(s0, s1)
1626                self.assertEqual(expected, actual)
1627
1628        inputs_list = [[1, 4], [4, 1], [1, 1, 3]]
1629        for integral_inputs in inputs_list:
1630            res1 = torch.broadcast_shapes(*integral_inputs)
1631            res2 = torch.broadcast_tensors(*map(torch.empty, integral_inputs))[0].shape
1632            self.assertEqual(res1, res2)
1633
1634        inputs_with_neg_vals = [[1, 1, -12], [-1, 1], [-11]]
1635        for integral_inputs_with_neg_vals in inputs_with_neg_vals:
1636            with self.assertRaisesRegex(
1637                RuntimeError, "Trying to create tensor with negative dimension"
1638            ):
1639                torch.broadcast_shapes(*integral_inputs_with_neg_vals)
1640
1641        integral_inputs_error_case = [(3, 5), (2, 4, 1)]
1642        for error_input in integral_inputs_error_case:
1643            with self.assertRaisesRegex(
1644                RuntimeError,
1645                "Shape mismatch: objects cannot be broadcast to a single shape",
1646            ):
1647                torch.broadcast_shapes(*error_input)
1648
1649        negative_inputs = [(-1,), (1, -12), (4, -11), (-4, 1), (1, 1, -2)]
1650        for s0 in negative_inputs:
1651            with self.assertRaisesRegex(
1652                RuntimeError, "Trying to create tensor with negative dimension"
1653            ):
1654                torch.broadcast_shapes(s0)
1655
1656            for s1 in negative_inputs:
1657                with self.assertRaisesRegex(
1658                    RuntimeError, "Trying to create tensor with negative dimension"
1659                ):
1660                    torch.broadcast_shapes(s0, s1)
1661
1662        float_inputs_error_case = [(1.1, 2.0), (1.1, 1.0)]
1663        for error_case in float_inputs_error_case:
1664            for float_input in error_case:
1665                with self.assertRaisesRegex(
1666                    RuntimeError,
1667                    "Input shapes "
1668                    "should be of type ints, a tuple of ints, or a list of ints",
1669                ):
1670                    torch.broadcast_shapes(float_input)
1671
1672        diff_input_types = [(1, (5,)), (3, (1,)), (1, (3, 4))]
1673        for s0 in diff_input_types:
1674            res1 = torch.broadcast_shapes(*s0)
1675            res2 = torch.broadcast_tensors(*map(torch.empty, s0))[0].shape
1676            self.assertEqual(res1, res2)
1677
1678    # Skip BFloat16 since numpy does not support it
1679    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
1680    def test_broadcast_to(self, device, dtype):
1681        def can_broadcast(s0, s1):
1682            # s0.dim() <= s1.dim(), reverse s0 and s1 to compare trailing dimension
1683            s0 = tuple(reversed(s0))
1684            s1 = tuple(reversed(s1))
1685            for i in range(len(s0)):
1686                if s0[i] != 1 and s0[i] != s1[i]:
1687                    return False
1688            return True
1689
1690        sizes = ((), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2))
1691        for s0, s1 in combinations(sizes, r=2):
1692            t = make_tensor(s0, dtype=dtype, device=device, low=-9, high=9)
1693            t_np = t.cpu().numpy()
1694
1695            if can_broadcast(s0, s1):
1696                res = torch.broadcast_to(t, s1)
1697                np_res = np.broadcast_to(t_np, s1)
1698                self.assertEqual(res, np_res)
1699            else:
1700                with self.assertRaisesRegex(
1701                    RuntimeError,
1702                    r"The expanded size of the tensor \(\d\) "
1703                    r"must match the existing size \(\d\)",
1704                ):
1705                    torch.broadcast_to(t, s1)
1706
1707    def test_view(self, device):
1708        tensor = torch.rand(15, device=device)
1709        template = torch.rand(3, 5, device=device)
1710        empty = torch.empty(0, device=device)
1711        target = template.size()
1712        self.assertEqual(tensor.view_as(template).size(), target)
1713        self.assertEqual(tensor.view(3, 5).size(), target)
1714        self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target)
1715        self.assertEqual(tensor.view(-1, 5).size(), target)
1716        self.assertEqual(tensor.view(3, -1).size(), target)
1717        tensor_view = tensor.view(5, 3)
1718        tensor_view.fill_(random.uniform(0, 1))
1719        self.assertEqual(empty.view_as(empty), empty)
1720        self.assertEqual(empty.view(0), empty)
1721        self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
1722        self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)
1723
1724        # test size inference with empty tensors
1725        self.assertEqual(empty.view(-1).size(), torch.Size([0]))
1726        self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))
1727
1728        with self.assertRaisesRegex(
1729            RuntimeError, r"because the unspecified dimension size -1 can be any value"
1730        ):
1731            empty.view(-1, 0)
1732
1733        with self.assertRaisesRegex(
1734            RuntimeError, r"because the unspecified dimension size -1 can be any value"
1735        ):
1736            empty.view(3, 0, -1, 0)
1737
1738        self.assertRaises(RuntimeError, lambda: tensor.view(15, 0))
1739        self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
1740        self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
1741
1742        # test view when tensor is not contiguous in every dimension, but only
1743        # contiguous dimensions are touched.
1744        tensor = (
1745            torch.rand(4, 2, 5, 1, 6, 2, 9, 3, device=device)
1746            .transpose(-1, 2)
1747            .transpose(-2, 3)
1748        )
1749        # size:                      [   4,    2,    3,    9,    6,    2,    1,    5]
1750        # stride:                    [3840, 1620,    1,    3,   54,   27,  324,  324]
1751        # contiguous dim chunks:     [__________, ____, ____, __________, ____, ____]
1752        # merging 1 to chunk after:  [__________, ____, ____, __________, __________]
1753        contig_tensor = tensor.clone()
1754        # [4, 2] => [8, 1]
1755        # [3] => [3]
1756        # [9] => [3, 3]
1757        # [6, 2] => [4, 1, 3]
1758        # [1, 5] => [5]
1759        view_size = [8, 1, 3, 3, 3, 4, 1, 3, 5]
1760        self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))
1761        # [4, 2] => [2, 4]
1762        # [3] => [3]
1763        # [9] => [1, 9]
1764        # [6, 2] => [2, 2, 3]
1765        # [1, 5] => [5, 1]
1766        view_size = [2, 4, 3, 1, 9, 2, 2, 3, 5, 1]
1767        self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))
1768        # adding size 1 dims
1769        view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1]
1770        self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))
1771
1772        # invalid views
1773        self.assertRaises(RuntimeError, lambda: tensor.view(-1))
1774        # crossing [4, 2], [3]
1775        self.assertRaises(RuntimeError, lambda: tensor.view(24, 9, 6, 2, 1, 5))
1776        # crossing [6, 2], [1, 5]
1777        self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 9, 6, 10))
1778        # crossing [9], [6, 2]
1779        self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 54, 2, 1, 5))
1780
1781        # view with stride 0 dims
1782        tensor = torch.empty(1, 1, device=device).expand(
1783            3, 4
1784        )  # all dims are contiguous
1785        contig_tensor = tensor.clone()
1786        self.assertEqual(tensor.view(-1), contig_tensor.view(-1))
1787        self.assertEqual(tensor.view(1, -1, 1), contig_tensor.view(1, -1, 1))
1788        self.assertEqual(tensor.view(-1, 1), contig_tensor.view(-1, 1))
1789        self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1))
1790        self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1))
1791
1792    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
1793    def test_reshape_view_semantics(self, device, dtype):
1794        tensor = make_tensor((15, 4), dtype=dtype, device=device)
1795        target = (20, 3)
1796
1797        # Cases where the tensor can be returned as a view.
1798        view_tensor = tensor.reshape(target)
1799        self.assertEqual((view_tensor.size()), target)
1800        self.assertEqual(tensor.storage().data_ptr(), view_tensor.storage().data_ptr())
1801
1802        # Cases where the tensor must be copied (transpose makes it non-contiguous forcing
1803        # the copy).
1804        copy_tensor = tensor.transpose(0, 1).reshape(target)
1805        self.assertEqual(copy_tensor.size(), target)
1806        self.assertNotEqual(
1807            tensor.storage().data_ptr(), copy_tensor.storage().data_ptr()
1808        )
1809
1810    def test_contiguous(self, device):
1811        x = torch.randn(1, 16, 5, 5, device=device)
1812        self.assertTrue(x.is_contiguous())
1813        stride = list(x.stride())
1814        stride[0] = 20
1815        # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
1816        x.set_(x.storage(), 0, x.size(), stride)
1817        self.assertTrue(x.is_contiguous())
1818
1819    @onlyNativeDeviceTypes
1820    # Skip BFloat16 since numpy does not support it
1821    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
1822    def test_tensor_split_sections(self, device, dtype):
1823        input_sizes = [
1824            (0,),
1825            (10,),
1826            (10, 0),
1827            (0, 10),
1828            (4, 10),
1829            (12, 3),
1830        ]
1831        for input_size in input_sizes:
1832            a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
1833            # Run tests on transposed input if it has at least 2 dims
1834            for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]:
1835                a_n = a.cpu().numpy()
1836                for dim in range(-a.dim(), a.dim()):
1837                    for sections in range(1, 2 * a.size(dim)):
1838                        msg = f"input_size {input_size}, sections {sections}, dim {dim}"
1839                        result1 = torch.tensor_split(a, sections, dim)
1840                        result2 = torch.tensor_split(
1841                            a, torch.tensor(sections, dtype=torch.int64), dim
1842                        )
1843                        for r1, r2 in zip(result1, result2):
1844                            self.assertEqual(r1.device, torch.device(device), msg=msg)
1845                            self.assertEqual(r1.dtype, dtype, msg=msg)
1846                            self.assertEqual(r2.device, torch.device(device), msg=msg)
1847                            self.assertEqual(r2.dtype, dtype, msg=msg)
1848                        result_n = np.array_split(a_n, sections, dim)
1849                        self.assertEqual(result_n, result1, msg=msg)
1850                        self.assertEqual(result_n, result2, msg=msg)
1851
1852    @onlyNativeDeviceTypes
1853    # Skip BFloat16 since numpy does not support it
1854    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
1855    def test_tensor_split_indices(self, device, dtype):
1856        input_sizes = [
1857            (0,),
1858            (10,),
1859            (10, 0),
1860            (0, 10),
1861            (4, 10),
1862            (12, 3),
1863        ]
1864        indices_args = [
1865            (),
1866            (0,),
1867            (3,),
1868            (10,),
1869            (-1,),
1870            (-10,),
1871            (2, -1),
1872            (3, 4, 10),
1873            (0, -1, 0, 10),
1874            (1, 5, 2, 8),
1875        ]
1876        for input_size in input_sizes:
1877            a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
1878            # Run tests on transposed input if it has at least 2 dims
1879            for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]:
1880                a_n = a.cpu().numpy()
1881                for dim in range(-a.dim(), a.dim()):
1882                    for indices in indices_args:
1883                        result_1 = torch.tensor_split(a, indices, dim)
1884                        result_2 = torch.tensor_split(
1885                            a, torch.tensor(indices, dtype=torch.int64), dim
1886                        )
1887
1888                        msg = f"input_size {input_size}, indices {indices}, dim {dim}"
1889                        for r1, r2 in zip(result_1, result_2):
1890                            self.assertEqual(r1.device, torch.device(device), msg=msg)
1891                            self.assertEqual(r1.dtype, dtype, msg=msg)
1892                            self.assertEqual(r2.device, torch.device(device), msg=msg)
1893                            self.assertEqual(r2.dtype, dtype, msg=msg)
1894
1895                        result_n = np.array_split(a_n, indices, dim)
1896                        self.assertEqual(result_n, result_1, msg=msg)
1897                        self.assertEqual(result_n, result_2, msg=msg)
1898
1899    @onlyNativeDeviceTypes
1900    def test_tensor_split_errors(self, device):
1901        S = 10
1902        test_cases = [
1903            # input size, sections or indices, dim, error type, error message, numpy error type
1904            [(S,), 10, 1, IndexError, r"Dimension out of range", IndexError],
1905            [
1906                (),
1907                10,
1908                0,
1909                RuntimeError,
1910                r"tensor_split expected at least a 1-dimensional tensor, "
1911                + "but got a tensor with 0 dims",
1912                IndexError,
1913            ],
1914            [(S,), (10,), 1, IndexError, r"Dimension out of range", IndexError],
1915            [
1916                (),
1917                (10,),
1918                0,
1919                RuntimeError,
1920                r"tensor_split expected at least a 1-dimensional tensor, "
1921                + "but got a tensor with 0 dims",
1922                IndexError,
1923            ],
1924            [
1925                (S,),
1926                0,
1927                0,
1928                RuntimeError,
1929                r"number of sections must be larger than 0, got 0",
1930                ValueError,
1931            ],
1932            [
1933                (S,),
1934                -1,
1935                0,
1936                RuntimeError,
1937                r"number of sections must be larger than 0, got -1",
1938                ValueError,
1939            ],
1940        ]
1941        for input_size, sections_or_indices, dim, err, err_msg, numpy_err in test_cases:
1942            a = torch.randn(input_size, device=device)
1943            msg = f"input_size {input_size}, sections_or_indices {sections_or_indices}, dim {dim}"
1944            with self.assertRaisesRegex(err, err_msg, msg=msg):
1945                torch.tensor_split(a, sections_or_indices, dim)
1946            with self.assertRaisesRegex(err, err_msg, msg=msg):
1947                torch.tensor_split(a, torch.tensor(sections_or_indices), dim)
1948            with self.assertRaises(numpy_err, msg=msg):
1949                np.array_split(a.cpu().numpy(), sections_or_indices, dim)
1950
1951        # addtional tests for tensor_split with tensor_indices_or_sections
1952        with self.assertRaisesRegex(
1953            RuntimeError,
1954            r"tensor_split expected tensor_indices_or_sections to have dtype of long, but got Float",
1955        ):
1956            torch.tensor_split(a, torch.tensor(1.1), dim)
1957
1958        with self.assertRaisesRegex(
1959            RuntimeError,
1960            r"tensor_split expected tensor_indices_or_sections to be a"
1961            + " zero-dimensional or one-dimensional tensor, but got a tensor with 2 dims",
1962        ):
1963            torch.tensor_split(torch.rand(S, device=device), torch.tensor(((1,),)), 0)
1964
1965    def test_resize_all_dtypes_and_devices(self, device):
1966        shape = (2, 2)
1967        for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
1968            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
1969            x.resize_(shape)
1970            self.assertEqual(shape, x.shape)
1971
1972    def test_resize_as_all_dtypes_and_devices(self, device):
1973        for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
1974            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
1975            y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device)
1976            x.resize_as_(y)
1977            self.assertEqual(y.shape, x.shape)
1978
1979    @onlyNativeDeviceTypes
1980    def test_resize_overflow(self, device):
1981        x = torch.empty((), dtype=torch.float64)
1982        with self.assertRaisesRegex(
1983            RuntimeError, "Storage size calculation overflowed"
1984        ):
1985            x.resize_([2, 4, 2**29, 2**29])
1986        with self.assertRaisesRegex(RuntimeError, "overflow"):
1987            x.resize_([8, 8, 2**29, 2**29])
1988        with self.assertRaisesRegex(RuntimeError, "Stride calculation overflowed"):
1989            x.resize_([0, 4, 2305843009213693952])
1990
1991    def test_view_all_dtypes_and_devices(self, device):
1992        for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
1993            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
1994            self.assertEqual(x.view(6).shape, [6])
1995
1996    @skipIfTorchDynamo("conj bit not implemented in TensorVariable yet")
1997    @onlyCPU
1998    def test_conj_neg_view_numpy_error(self, device):
1999        self.assertRaisesRegex(
2000            RuntimeError,
2001            "has conjugate bit set",
2002            lambda: torch.tensor([1 + 2j]).conj().numpy(),
2003        )
2004        self.assertRaisesRegex(
2005            RuntimeError,
2006            "has negative bit set",
2007            lambda: torch.tensor([1 + 2j]).conj().imag.numpy(),
2008        )
2009        self.assertRaisesRegex(
2010            RuntimeError,
2011            "not supported for conjugate view tensors",
2012            lambda: torch.tensor([1 + 2j]).conj().view(torch.float64),
2013        )
2014        self.assertRaisesRegex(
2015            RuntimeError,
2016            "not supported for tensors with negative bit set",
2017            lambda: torch.tensor([1 + 2j]).conj().imag.view(torch.int32),
2018        )
2019
2020    @onlyCPU
2021    def test_crow_col_indices(self, device):
2022        crow_indices = (0, 1, 2)
2023        col_indices = (1, 0)
2024        values = (1, 2)
2025        t = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))
2026        # This is the test. If crow_indices is not a view op it'll
2027        # trigger an internal assert due to use count greater than 1
2028        # in debug build.
2029        t.crow_indices()
2030        t.col_indices()
2031
2032
# Register device-specific instantiations of the generic test classes in this
# module's globals(). TestViewOps additionally passes include_lazy=True
# (presumably opting it into the lazy-tensor backend as well — confirm in
# common_device_type).
instantiate_device_type_tests(TestViewOps, globals(), include_lazy=True)
instantiate_device_type_tests(TestOldViewOps, globals())

# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()
2038