# Owner(s): ["module: tests"]
import random
import unittest
from functools import partial
from itertools import combinations, permutations, product

import numpy as np

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCPU,
    onlyNativeDeviceTypes,
    skipLazy,
    skipMeta,
    skipXLA,
)
from torch.testing._internal.common_dtype import (
    all_types_and,
    all_types_and_complex_and,
    complex_types,
    floating_and_complex_types_and,
)
from torch.testing._internal.common_utils import (
    gradcheck,
    gradgradcheck,
    IS_FBCODE,
    numpy_to_torch_dtype_dict,
    run_tests,
    skipIfTorchDynamo,
    suppress_warnings,
    TestCase,
)


# TODO: replace this with make_tensor() in common_utils.py
def _generate_input(shape, dtype, device, with_extremal):
    if shape == ():
        x = torch.tensor((), dtype=dtype, device=device)
    else:
        if dtype.is_floating_point or dtype.is_complex:
            # work around torch.randn not being implemented for bfloat16
            if dtype == torch.bfloat16:
                x = torch.randn(*shape, device=device) * random.randint(30, 100)
                x = x.to(torch.bfloat16)
            else:
                x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(
                    30, 100
                )
            x[torch.randn(*shape) > 0.5] = 0
            if with_extremal and dtype.is_floating_point:
                # Use extremal values
                x[torch.randn(*shape) > 0.5] = float("nan")
                x[torch.randn(*shape) > 0.5] = float("inf")
                x[torch.randn(*shape) > 0.5] = float("-inf")
            elif with_extremal and dtype.is_complex:
                x[torch.randn(*shape) > 0.5] = complex("nan")
                x[torch.randn(*shape) > 0.5] = complex("inf")
                x[torch.randn(*shape) > 0.5] = complex("-inf")
        elif dtype == torch.bool:
            x = torch.zeros(shape, dtype=dtype, device=device)
            x[torch.randn(*shape) > 0.5] = True
        else:
            x = torch.randint(15, 100, shape, dtype=dtype, device=device)

    return x


# TODO: replace this with make_tensor() in common_utils.py
def _rand_shape(dim, min_size, max_size):
    shape = []
    for _ in range(dim):
        shape.append(random.randint(min_size, max_size))
    return tuple(shape)


# TODO: refactor tests to avoid this function
# Converts half/bfloat16 dtype to float when device is cpu
def _convert_t(dtype, device):
    if device == "cpu" and dtype in {torch.half, torch.bfloat16}:
        return torch.float
    return dtype


# TODO: replace this with make_tensor() in common_utils.py
# Returns a tensor of the requested shape, dtype, and device
# Requesting a half CPU tensor returns a float CPU tensor with
# values representable by a half.
# Initialization uses randint for non-float types and randn for float types.
def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor:
    # Returns a tensor filled with ones
    if fill_ones:
        return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device)

    # Returns a tensor with random integer values
    if not (dtype.is_floating_point or dtype.is_complex):
        t = torch.randint(0, 10, shape, device=device)
        if dtype != torch.uint8:
            t = t - 5  # generate negative values also
        return t.to(_convert_t(dtype, device))

    # Populates the CPU tensor with floats representable as half/bfloat16
    if dtype == torch.half and device == "cpu":
        return torch.randn(*shape, dtype=torch.float, device=device).half().float()
    if dtype == torch.bfloat16 and device == "cpu":
        return torch.randn(*shape, dtype=torch.float, device=device).bfloat16().float()

    # Default: returns a tensor with random float values
    return torch.randn(shape, dtype=dtype, device=device)


# Tests ops and indexing to ensure they return views (and new tensors) as
# appropriate.
class TestViewOps(TestCase):
    exact_dtype = True

    def is_view_of(self, base, other):
        if (
            not other._is_view()
            or other is base
            or other._base is not base
            or base.device != other.device
        ):
            return False
        # Note: only validates storage on native device types
        # because some accelerators, like XLA, do not expose storage
        if base.device.type == "cpu" or base.device.type == "cuda":
            if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
                return False

        return True

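    # Illustrative usage (hedged): for t = torch.ones(2, 2),
    # is_view_of(t, t.view(-1)) is True, while is_view_of(t, t.clone())
    # is False, since a clone owns fresh storage and has no _base.
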
    # Returns true if v1 and v2 are views of the same base
    def is_view_of_same_base(self, v1, v2):
        if not v1._is_view() or v1 is v2:
            return False
        return self.is_view_of(v1._base, v2)

    # Returns the input as is if contiguous=True, else returns a transposed view
    def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
        if contiguous:
            return x
        else:
            return x.transpose(dim0, dim1)

    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_conj_self(self, device, dtype):
        t = torch.ones(5, 5, dtype=dtype, device=device)
        s = t.conj()
        self.assertTrue(s is t)

    @skipIfTorchDynamo("TorchDynamo fails for an unknown reason")
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
    def test_view_dtype_new(self, device, dtype):
        dtypes = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
        del dtypes[torch.bool]

        def generate_inputs():
            yield make_tensor((4, 4, 64), dtype=dtype, device=device, low=-5, high=5)
            yield make_tensor(
                (4, 4, 64), dtype=dtype, device=device, low=-5, high=5
            ).permute(1, 0, 2)
            yield make_tensor(
                (4, 64, 4), dtype=dtype, device=device, low=-5, high=5
            ).permute(2, 0, 1)
            yield make_tensor(
                (1, 5, 1), dtype=dtype, device=device, low=-5, high=5
            ).expand(5, 5, 64)
            yield make_tensor((2, 5, 256), dtype=dtype, device=device, low=-5, high=5)[
                1::2, 1:, ::2
            ]
            yield make_tensor((0, 5, 64), dtype=dtype, device=device, low=-5, high=5)
            yield make_tensor((), dtype=dtype, device=device, low=-5, high=5)

        def calc_expected_size_and_stride(a, view_dtype):
            dtype_size = torch._utils._element_size(a.dtype)
            view_dtype_size = torch._utils._element_size(view_dtype)

            if dtype_size == view_dtype_size:
                return a.size(), a.stride()

            elif dtype_size > view_dtype_size:
                size_ratio = dtype_size // view_dtype_size

                view_size = list(a.size())
                view_size[-1] = view_size[-1] * size_ratio

                view_stride = [stride * size_ratio for stride in a.stride()]
                view_stride[-1] = 1
                return torch.Size(view_size), tuple(view_stride)

            else:
                size_ratio = view_dtype_size // dtype_size

                view_size = list(a.size())
                view_size[-1] = view_size[-1] // size_ratio

                view_stride = [stride // size_ratio for stride in a.stride()]
                view_stride[-1] = 1
                return torch.Size(view_size), tuple(view_stride)

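        # Worked example: a contiguous float32 tensor of size (4, 4, 64) has
        # stride (256, 64, 1). Viewed as int16 (half the element size) it
        # becomes size (4, 4, 128) with stride (512, 128, 1); viewed as
        # complex64 (twice the element size) it becomes size (4, 4, 32)
        # with stride (128, 32, 1).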
        for a in generate_inputs():
            a_np = a.cpu().numpy()
            a_np_contiguous = a.cpu().contiguous().numpy()

            for view_dtype, np_view_dtype in dtypes.items():
                equal_element_size = torch._utils._element_size(
                    dtype
                ) == torch._utils._element_size(view_dtype)

                if not equal_element_size and a.dim() == 0:
                    with self.assertRaisesRegex(
                        RuntimeError, r"self.dim\(\) cannot be 0"
                    ):
                        a.view(view_dtype)
                    continue

                if not equal_element_size and a.stride(-1) != 1:
                    with self.assertRaisesRegex(
                        RuntimeError, r"self.stride\(-1\) must be 1"
                    ):
                        a.view(view_dtype)
                    continue

                a_view = a.view(view_dtype)
                self.assertEqual(a_view.dtype, view_dtype)
                self.assertEqual(a.data_ptr(), a_view.data_ptr())

                expected_size, expected_stride = calc_expected_size_and_stride(
                    a, view_dtype
                )
                self.assertEqual(a_view.size(), expected_size)
                self.assertEqual(a_view.stride(), expected_stride)

                self.assertEqual(a_view.view(dtype), a, rtol=0, atol=0)

                # NumPy's dtype view requires contiguous input if target
                # dtype is a different size
                if equal_element_size:
                    a_np_view = a_np.view(np_view_dtype)

                else:
                    a_np_view = a_np_contiguous.view(np_view_dtype)

                self.assertEqual(a_view, a_np_view)

        # Test that requires_grad is dropped for floating point casts,
        # because view(dtype) does not support backward yet
        # TODO: Remove this when autograd support is added
        if dtype.is_floating_point or dtype.is_complex:
            for view_dtype in floating_and_complex_types_and(
                torch.half, torch.bfloat16
            ):
                t = make_tensor(
                    (5, 5, 64),
                    dtype=dtype,
                    device=device,
                    low=-5,
                    high=5,
                    requires_grad=True,
                )
                self.assertFalse(t.view(view_dtype).requires_grad)

    # Test the extra error checks that happen when the view dtype
    # has a greater element size than the original dtype
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_view_dtype_upsize_errors(self, device, dtype):
        dtype_size = torch._utils._element_size(dtype)

        for view_dtype in all_types_and_complex_and(
            torch.half, torch.bfloat16, torch.bool
        ):
            view_dtype_size = torch._utils._element_size(view_dtype)
            if view_dtype_size <= dtype_size:
                continue

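            # Upsizing packs size_ratio source elements into one viewed
            # element, so the last dim size, the storage offset, and every
            # non-last stride must each be divisible by size_ratio; the
            # three checks below exercise those constraints in turn.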
            size_ratio = view_dtype_size // dtype_size
            a = make_tensor(
                (4, 4, size_ratio + 1), dtype=dtype, device=device, low=-5, high=5
            )
            with self.assertRaisesRegex(
                RuntimeError, rf"self.size\(-1\) must be divisible by {size_ratio}"
            ):
                a.view(view_dtype)

            with self.assertRaisesRegex(
                RuntimeError,
                rf"self.storage_offset\(\) must be divisible by {size_ratio}",
            ):
                a[:, :, 1:].view(view_dtype)

            a = make_tensor(
                (4, 4, size_ratio), dtype=dtype, device=device, low=-5, high=5
            )
            a = a.as_strided((4, 4, size_ratio), (size_ratio, 1, 1))
            with self.assertRaisesRegex(
                RuntimeError, rf"self.stride\(1\) must be divisible by {size_ratio}"
            ):
                a.view(view_dtype)

    @onlyNativeDeviceTypes
    def test_view_as_complex(self, device):
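        # view_as_complex pairs up the last dimension (which must have size
        # 2 and stride 1) into complex values; illustrative example:
        #   >>> torch.view_as_complex(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
        #   tensor([1.+2.j, 3.+4.j])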
        def fn(contiguous_input=True, dim0=0, dim1=1):
            t = torch.randn(3, 2, 2, device=device)
            c_t = t[:, :, 0] + 1j * t[:, :, 1]

            input = self._do_transpose(t, contiguous_input, dim0, dim1)

            if input.size()[-1] != 2:
                self.assertRaisesRegex(
                    RuntimeError,
                    "Tensor must have a last dimension of size 2",
                    lambda: torch.view_as_complex(input),
                )
                return

            if input.stride()[-1] != 1:
                self.assertRaisesRegex(
                    RuntimeError,
                    "Tensor must have a last dimension with stride 1",
                    lambda: torch.view_as_complex(input),
                )
                return

            res = torch.view_as_complex(input)
            self.assertEqual(res, self._do_transpose(c_t, contiguous_input, dim0, dim1))
            self.assertTrue(self.is_view_of(t, res))

        fn()
        fn(contiguous_input=False)
        # RuntimeError since in this case the last dim of input would not be of size 2
        fn(contiguous_input=False, dim0=0, dim1=2)
        # RuntimeError since in this case the last dim of input would not have stride 1
        fn(contiguous_input=False, dim0=1, dim1=2)

        # RuntimeError since in this case the stride of a non-last dim of input is not divisible by 2
        x = torch.randn(3, 3, device=device)
        t = torch.as_strided(x, (2, 2), (1, 1))
        self.assertRaisesRegex(
            RuntimeError,
            "Tensor must have a stride divisible by 2 for all but last dimension",
            lambda: torch.view_as_complex(t),
        )

        # tensor with zero elements
        x = torch.tensor([], device=device)  # torch.Size([0])
        self.assertRaisesRegex(
            RuntimeError,
            "Tensor must have a last dimension of size 2",
            lambda: torch.view_as_complex(x),
        )

        # zero dimension tensor
        z = torch.tensor(2.0)
        self.assertRaisesRegex(
            RuntimeError,
            "Input tensor must have one or more dimensions",
            lambda: torch.view_as_complex(z),
        )

        y = x.reshape(0, 2)  # torch.Size([0, 2])
        res = torch.view_as_complex(y)
        self.assertTrue(self.is_view_of(x, res))
        self.assertEqual(res.shape, torch.Size([0]))

    @onlyNativeDeviceTypes
    @dtypes(*complex_types(), torch.complex32)
    def test_view_as_real(self, device, dtype):
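        # view_as_real is the inverse mapping: it appends a trailing size-2
        # dimension holding (real, imag) pairs; illustrative example:
        #   >>> torch.view_as_real(torch.tensor([1 + 2j]))
        #   tensor([[1., 2.]])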
        def fn(contiguous_input=True):
            t = torch.randn(3, 4, dtype=dtype, device=device)
            input = self._do_transpose(t, contiguous_input)
            res = torch.view_as_real(input)
            self.assertEqual(res[:, :, 0], input.real)
            self.assertEqual(res[:, :, 1], input.imag)
            self.assertTrue(self.is_view_of(t, res))

        fn()
        fn(contiguous_input=False)

        # tensor with zero elements
        x = torch.tensor([], dtype=dtype, device=device)
        res = torch.view_as_real(x)
        self.assertTrue(self.is_view_of(x, res))
        self.assertEqual(res.shape, torch.Size([0, 2]))

        # tensor with zero dim
        x = torch.tensor(2 + 3j, dtype=dtype, device=device)
        res = torch.view_as_real(x)
        self.assertTrue(self.is_view_of(x, res))
        self.assertEqual(res.shape, torch.Size([2]))

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_view_tensor_split(self, device, dtype):
        a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9)
        a_split_dim0 = a.tensor_split(7, 0)
        for a_split_dim0_tensor in a_split_dim0:
            self.assertTrue(self.is_view_of(a, a_split_dim0_tensor))
        a_split_dim1 = a.tensor_split(7, 1)
        for a_split_dim1_tensor in a_split_dim1:
            self.assertTrue(self.is_view_of(a, a_split_dim1_tensor))

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_view_tensor_hsplit(self, device, dtype):
        t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
        t_hsplit = torch.hsplit(t, 2)
        for t_hsplit_tensor in t_hsplit:
            self.assertTrue(self.is_view_of(t, t_hsplit_tensor))
        t[2, 2, 2] = 7
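        # hsplit splits dim 1, so the second chunk holds columns 2:4 of t
        # and t[2, 2, 2] aliases t_hsplit[1][2, 0, 2]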
        self.assertEqual(t_hsplit[1][2, 0, 2], t[2, 2, 2])

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_view_tensor_vsplit(self, device, dtype):
        t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
        t_vsplit = torch.vsplit(t, 2)
        for t_vsplit_tensor in t_vsplit:
            self.assertTrue(self.is_view_of(t, t_vsplit_tensor))
        t[2, 2, 2] = 7
        self.assertEqual(t_vsplit[1][0, 2, 2], t[2, 2, 2])

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_view_tensor_dsplit(self, device, dtype):
        t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9)
        t_dsplit = torch.dsplit(t, 2)
        for t_dsplit_tensor in t_dsplit:
            self.assertTrue(self.is_view_of(t, t_dsplit_tensor))
        t[2, 2, 2] = 7
        self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2])

    @onlyNativeDeviceTypes
    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_imag_noncomplex(self, device, dtype):
        t = torch.ones((5, 5), dtype=dtype, device=device)

        with self.assertRaises(RuntimeError):
            torch.imag(t)

    @onlyNativeDeviceTypes
    @dtypes(*complex_types())
    def test_real_imag_view(self, device, dtype):
        def compare_with_numpy(contiguous_input=True):
            t = torch.randn(3, 3, dtype=dtype, device=device)
            if not contiguous_input:
                u = t.T
            else:
                u = t

            re = u.real
            exp = torch.from_numpy(u.cpu().numpy().real).to(device=device)
            self.assertEqual(re, exp)
            # For contiguous_input, u is t itself; otherwise u is t.T, a
            # view of t, so the base of the result is still t
            self.assertTrue(self.is_view_of(t, re))

            im = u.imag
            exp = torch.from_numpy(u.cpu().numpy().imag).to(device=device)
            self.assertEqual(im, exp)
            self.assertTrue(self.is_view_of(t, im))

        compare_with_numpy()
        compare_with_numpy(contiguous_input=False)

        # ensure storage offset is being correctly set
        a = torch.randn(10, dtype=dtype)
        self.assertEqual(a[5:].real, a.real[5:])
        self.assertEqual(a[5:].imag, a.imag[5:])

    @onlyNativeDeviceTypes
    @dtypes(*complex_types())
    def test_conj_imag_view(self, device, dtype) -> None:
        t = _make_tensor((4, 5), dtype, device)
        t_numpy_conj = torch.from_numpy(t.cpu().numpy().conj()).to(device=device)
        v = t.conj()
        self.assertTrue(self.is_view_of(t, v))
        self.assertEqual(v, t_numpy_conj)

        if t.is_complex():
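            # conj() is lazy: v aliases t's storage with the conjugate bit
            # set, so taking .imag of it yields a view with the negative
            # bit set rather than materialized data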
            v_imag = v.imag
            self.assertTrue(self.is_view_of(t, v_imag))
            self.assertEqual(v_imag, t_numpy_conj.imag)
            self.assertTrue(v_imag.is_neg())

    @onlyNativeDeviceTypes
    def test_conj_view_with_shared_memory(self, device) -> None:
        a = _make_tensor((4, 5), torch.cfloat, device)
        b = a.conj()
        c = a.conj()

        self.assertEqual(torch.add(a, b), a.add_(b))
        self.assertEqual(torch.add(b, c), torch.add(b, c, out=a))
        self.assertEqual(torch.add(b, c), b.add_(c))

    @onlyNativeDeviceTypes
    @dtypes(
        *product(
            complex_types(),
            all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
        )
    )
    @suppress_warnings
    def test_set_real_imag(self, device, dtypes):
        x = torch.randn(10, dtype=dtypes[0], device=device)

        new_real = _make_tensor((10,), dtypes[1], device)
        new_imag = _make_tensor((10,), dtypes[1], device)

        x.real = new_real
        x.imag = new_imag

        if dtypes[1].is_complex:
            self.assertEqual(x.real, new_real.real, exact_dtype=False)
            self.assertEqual(x.imag, new_imag.real, exact_dtype=False)

        else:
            self.assertEqual(x.real, new_real, exact_dtype=False)
            self.assertEqual(x.imag, new_imag, exact_dtype=False)

    def test_diagonal_view(self, device) -> None:
        t = torch.ones((5, 5), device=device)
        v = torch.diagonal(t)
        self.assertTrue(self.is_view_of(t, v))

        v[0] = 0
        self.assertEqual(t[0, 0], v[0])

        t = torch.ones((3, 3, 3), device=device)
        v = torch.diagonal(t, offset=1, dim1=1, dim2=2)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 0, 1], v[0, 0])

    def test_select_view(self, device) -> None:
        t = torch.ones((5, 5), device=device)
        v = t.select(0, 2)
        self.assertTrue(self.is_view_of(t, v))

        v[0] = 0
        self.assertEqual(t[2, 0], v[0])

    # Lazy hasn't implemented unbind yet.
    @skipLazy
    def test_unbind_view(self, device) -> None:
        t = torch.zeros((5, 5), device=device)
        tup = torch.unbind(t)

        for idx, v in enumerate(tup):
            self.assertTrue(self.is_view_of(t, v))

            v[0] = idx + 1
            self.assertEqual(t[idx, 0], v[0])

    # TODO: opinfo this or move to unbind's test suite
    def test_unbind(self):
        stacked = torch.randn(3, 10, 10, requires_grad=True)
        x, y, z = stacked.unbind()
        grad = torch.randn(3, 10, 10)
        torch.autograd.backward([x, y, z], grad.unbind())
        self.assertEqual(stacked.grad, grad)
        # check that it works with only one gradient provided (#9977)
        for i in range(3):
            stacked = torch.randn(3, 10, 10, requires_grad=True)
            outs = stacked.unbind()
            gi = grad.unbind()[i]
            (g,) = torch.autograd.grad(outs[i], stacked, gi)
            g_expected = torch.stack(
                [gi if j == i else torch.zeros_like(gi) for j in range(3)], dim=0
            )
            self.assertEqual(g, g_expected)
        # Check with gradcheck
        stacked = torch.randn(3, 10, 10, dtype=torch.double, requires_grad=True)
        gradcheck(lambda x: x.unbind(), (stacked,), check_forward_ad=True)

    # TODO: Fix this test for LTC. There is an interaction with dynamic shapes here that is broken,
    # causing asserts to trigger.
    @skipLazy
    def test_expand_view(self, device) -> None:
        t = torch.ones((5, 1), device=device)
        v = t.expand(5, 5)
        self.assertTrue(self.is_view_of(t, v))

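        # expand materializes no data: the expanded dim has stride 0, so
        # every element of row 2 of v aliases t[2, 0]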
        v[2, 2] = 0
        self.assertEqual(t[2, 0], v[2, 2])

    def test_expand_as_view(self, device):
        t = torch.ones((5, 1), device=device)
        e = torch.empty((5, 5), device=device)
        v = t.expand_as(e)
        self.assertTrue(self.is_view_of(t, v))

        v[2, 2] = 0
        self.assertEqual(t[2, 0], v[2, 2])

    def test_narrow_view(self, device):
        t = torch.ones((5, 5), device=device)
        v = torch.narrow(t, 1, 2, 2)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 2], v[0, 0])

    def test_permute_view(self, device) -> None:
        t = torch.ones((5, 5), device=device)
        v = t.permute(1, 0)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_transpose_view(self, device):
        for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
            t = torch.ones((5, 5), device=device)
            v = fn(t, 0, 1)
            self.assertTrue(self.is_view_of(t, v))

            v[0, 1] = 0
            self.assertEqual(t[1, 0], v[0, 1])

    def test_transpose_inplace_view(self, device):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.swapdims_(0, 1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.swapaxes_(0, 1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.transpose_(0, 1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_t_view(self, device):
        t = torch.ones((5, 5), device=device)
        v = t.t()
        self.assertTrue(self.is_view_of(t, v))

        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_t_inplace_view(self, device):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.t_()
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t[1, 0], v[0, 1])

    def test_T_view(self, device):
        for op in ("T", "H", "mT", "mH"):
            t = torch.ones((5, 5), device=device)
            v = getattr(t, op)
            self.assertTrue(self.is_view_of(t, v))

            v[0, 1] = 0
            self.assertEqual(t[1, 0], v[0, 1])

    def test_unfold_view(self, device):
        t = torch.ones(10, device=device)
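        # unfold(0, size=3, step=2) over 10 elements yields
        # (10 - 3) // 2 + 1 = 4 overlapping windows of length 3, each a
        # strided view, so v[1, 0] aliases t[2]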
676*da0073e9SAndroid Build Coastguard Worker        v = t.unfold(0, 3, 2)
        self.assertTrue(self.is_view_of(t, v))

        v[1, 0] = 0
        self.assertEqual(t[2], v[1, 0])

    def test_squeeze_view(self, device):
        t = torch.ones(5, 1, 5, device=device)
        v = torch.squeeze(t)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t, v._base)

    def test_squeeze_inplace_view(self, device):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.squeeze_()
        self.assertTrue(self.is_view_of(t, v))
        v[0, 1] = 0
        self.assertEqual(t, v._base)

    def test_unsqueeze_view(self, device):
        t = torch.ones(5, 5, device=device)
        v = torch.unsqueeze(t, 1)
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0, 1] = 0
        self.assertEqual(t[0, 1], v[0, 0, 1])

    def test_unsqueeze_inplace_view(self, device):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.unsqueeze_(1)
        self.assertTrue(self.is_view_of(t, v))
        v[0, 0, 1] = 0
        self.assertEqual(t[0, 1], v[0, 0, 1])

    def test_as_strided_view(self, device):
        t = torch.ones(5, 5, device=device)
        v = torch.as_strided(t, (25,), (1,))
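        # With size (25,) and stride (1,), v is a flat alias of t's storage;
        # linear index 6 maps to row 1, column 1 of the contiguous 5 x 5 layout.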
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_as_strided_inplace_view(self, device):
        t = torch.ones(5, 5, device=device)
        v = t.view_as(t)
        v = v.as_strided_((25,), (1,))
        self.assertTrue(self.is_view_of(t, v))
        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_as_strided_gradients(self):
        def test(x, prepro_fn, size, strides, offset=None):
            x = x.to(torch.double).detach().requires_grad_()

            # Check that the forward pass will **not** resize storage, because
            # resizing may produce NaNs in the output and consequently fail the
            # numerical Jacobian check
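            # The largest linear storage index touched by as_strided(size,
            # strides, offset) is offset + sum((size[i] - 1) * strides[i]);
            # e.g. size=(3, 3), strides=(6, 2), offset=2 reaches at most
            # 2 + 2 * 6 + 2 * 2 = 18, so storage must hold at least 19 elements.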
            with torch.no_grad():
                y = prepro_fn(x) if prepro_fn is not None else x
                max_offset = sum((si - 1) * st for si, st in zip(size, strides))
                max_offset += offset if offset is not None else y.storage_offset()
                assert max_offset < len(y.storage()), "test case resizes storage"

            def closure(x):
                if prepro_fn is not None:
                    x = prepro_fn(x)
                return x.as_strided(size, strides, offset)

            gradcheck(closure, [x], check_forward_ad=True)
            gradgradcheck(closure, [x])

        # test basic case
        test(torch.arange(0, 25), lambda x: x.view(5, 5), [3, 3], [6, 2], 2)

        # test crazy stride at a dim with size 1
        test(torch.randn(12), None, [1, 2, 1, 5], [0, 5, 100, 1], 2)

        # test expand case
        test(torch.randn(5), None, [3, 3, 3], [0, 1, 0], 2)
        test(torch.randn(5), None, [3, 3, 3], [0, 0, 0], 4)
        test(torch.randn(5), lambda x: x.expand(5, 5), [5, 5], [0, 1], 0)

        # test non-expand overlapping case
        test(torch.randn(35), None, [6, 6], [5, 1], 2)
        test(torch.randn(15), None, [3, 2], [3, 6], 2)

        # test transpose case
        test(torch.randn(3, 4), None, [4, 3], [1, 4])

        # test "getting things outside the input" case
        x = torch.randn(6, 2)
        test(x[3:], None, [3, 2], [2, 1], 0)  # should be all zeros
        self.assertEqual(x[3:].as_strided([3, 2], [2, 1], 0), x[:3])

        # test select on expanded input case
        test(torch.randn(2, 3), lambda x: x.expand(10, 2, 3), [2, 3], [3, 1], 0)

    def test_view_view(self, device):
        t = torch.ones(5, 5, device=device)
        v = t.view(25)
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_view_as_view(self, device):
        t = torch.ones(5, 5, device=device)
        e = torch.empty((25,))
        v = t.view_as(e)
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_contiguous_self(self, device):
        t = torch.ones(5, 5, device=device)
        s = t.contiguous()
        self.assertTrue(s is t)

    @skipMeta
    # self.is_view_of reports false positives for lazy
    @skipLazy
    def test_contiguous_nonview(self, device):
        t = torch.ones(5, 5, device=device)
        nv = t.t().contiguous()
        self.assertTrue(not self.is_view_of(t, nv))

        nv[0, 0] = 0
        self.assertNotEqual(t[0, 0], nv[0, 0])

    def test_reshape_view(self, device):
        t = torch.ones(5, 5, device=device)
        v = torch.reshape(t, (25,))
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    def test_reshape_as_view(self, device):
        t = torch.ones(5, 5, device=device)
        e = torch.empty((25,), device=device)
        v = t.reshape_as(e)
        self.assertTrue(self.is_view_of(t, v))

        v[6] = 0
        self.assertEqual(t[1, 1], v[6])

    @skipMeta
    # self.is_view_of reports false positives for lazy
    @skipLazy
    def test_reshape_nonview(self, device):
        t = torch.ones(5, 5, device=device)
        nv = torch.reshape(t.t(), (25,))
        self.assertTrue(not self.is_view_of(t, nv))

        nv[6] = 0
        self.assertNotEqual(t[1, 1], nv[6])

    # This test uses as_strided to construct a tensor with overlapping memory,
    # which is not handled by the functionalization pass.
    @skipLazy
    @skipXLA
    def test_flatten_view(self, device):
        def test_writes_propagate(t, v):
            idx_t = (0,) * t.ndim
            idx_v = (0,) * v.ndim
            v[idx_v] = 0
            self.assertEqual(t[idx_t], v[idx_v])

        t = torch.ones(1, 2, 3, 4, device=device)
        v = t.flatten()
        self.assertTrue(self.is_view_of(t, v))
        test_writes_propagate(t, v)

        # zero-dimensional tensor
        t = torch.tensor(1, device=device)
        v = t.flatten()
        test_writes_propagate(t, v)
        self.assertTrue(self.is_view_of(t, v))

        t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
        v = t.flatten(0, 1)
        test_writes_propagate(t, v)
        self.assertTrue(self.is_view_of_same_base(t, v))

        # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
        t = torch.ones(720, device=device).as_strided(
            (2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0)
        )
        #               [--1--|---2---|-3-] [--1--|----2---|-3-]
        v1 = t.flatten(0, 1)
        v2 = v1.flatten(1, 3)
        v3 = v2.flatten(2, 2)
        test_writes_propagate(t, v1)
        self.assertTrue(self.is_view_of_same_base(t, v1))
        test_writes_propagate(t, v2)
        self.assertTrue(self.is_view_of_same_base(t, v2))
        test_writes_propagate(t, v3)
        self.assertTrue(self.is_view_of_same_base(t, v3))

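    # A minimal sketch (an added illustration, not part of the original suite)
    # of the rule the cases above exercise: dims in [start, end] can be
    # flattened into a view only when stride[i] == stride[i + 1] * size[i + 1]
    # holds for each adjacent pair in that range (size-1 dims, whose strides
    # are arbitrary, are ignored here for simplicity).
    def _flatten_returns_view_sketch(self, t, start, end):
        size, stride = t.size(), t.stride()
        return all(
            stride[i] == stride[i + 1] * size[i + 1] for i in range(start, end)
        )
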
    @onlyNativeDeviceTypes
    def test_flatten_nonview(self, device):
        def assert_is_nonview(t, nv):
            idx_t = (0,) * t.ndim
            idx_nv = (0,) * nv.ndim
            self.assertTrue(not nv._is_view())
            nv[idx_nv] = 0
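            # meta tensors carry no data, so the value check is skipped there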
            if device != "meta":
                self.assertNotEqual(t[idx_t], nv[idx_nv])

        t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
        nv = t.flatten(1, 3)
        assert_is_nonview(t, nv)

        t = torch.ones(2, 2, device=device).T
        nv = t.flatten()
        assert_is_nonview(t, nv)

        # flatten returns the original object if start_dim=end_dim
        t = torch.ones(2, 2, device=device)
        nv = t.flatten(1, 1)
        self.assertTrue(t is nv)

    def test_basic_indexing_slice_view(self, device):
        t = torch.ones(5, 5, device=device)
        v = t[:2, :3]
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 0], v[0, 0])

    def test_basic_indexing_ellipses_view(self, device):
        t = torch.ones(5, 5, device=device)
        v = t[..., :2]
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 0], v[0, 0])

    def test_basic_indexing_newaxis_view(self, device):
        t = torch.ones(5, 5, device=device)
        v = t[None, :2, 3]
        self.assertTrue(self.is_view_of(t, v))

        v[0, 0] = 0
        self.assertEqual(t[0, 3], v[0, 0])

    def test_advanced_indexing_nonview(self, device):
        t = torch.ones(3, 3, device=device)
        rows = torch.tensor([[0, 0], [2, 2]], device=device)
        cols = torch.tensor([[0, 1], [2, 2]], device=device)
        nv = t[rows, cols]
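        # Unlike basic slicing, advanced (tensor) indexing always copies,
        # so nv is a new tensor rather than a view of t.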
        self.assertTrue(not self.is_view_of(t, nv))

        nv[1, 1] = 0
        self.assertNotEqual(t[2, 2], nv[1, 1])

    @unittest.skipIf(
        IS_FBCODE, "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds"
    )
    def test_advanced_indexing_assignment(self, device):
        t = torch.ones(3, 3, device=device)
        rows = torch.tensor([[0, 0], [2, 2]], device=device)
        cols = torch.tensor([[0, 1], [2, 2]], device=device)
        t[rows, cols] = 0
        self.assertEqual(t[2, 2], 0)

    @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720")
    def test_chunk_view(self, device):
        t = torch.zeros(3, 3, device=device)
        l = torch.chunk(t, 3)

        for idx, v in enumerate(l):
            self.assertTrue(self.is_view_of(t, v))

            v[0, 0] = idx + 1
            self.assertEqual(t[idx, 0], v[0, 0])

    @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720")
    def test_split_view(self, device):
        t = torch.zeros(3, 3, device=device)
        l = torch.split(t, [1, 1, 1])

        for idx, v in enumerate(l):
            self.assertTrue(self.is_view_of(t, v))

            v[0, 0] = idx + 1
            self.assertEqual(t[idx, 0], v[0, 0])

    def test_movedim_view(self, device):
        def run_test(device, op):
            t = torch.zeros(3, 3, device=device)
            out = op(t)

            self.assertTrue(self.is_view_of(t, out))

            # Randomly change values in output
            # and verify that original is changed
            # as well.
            for _ in range(3):
                idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2)
                out[idx_1, idx_2] = random.random()
                self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2])

        for fn in [torch.movedim, torch.moveaxis]:
            op = partial(fn, source=(0, 1), destination=(1, 0))
            run_test(device, op)

            op = partial(fn, source=0, destination=1)
            run_test(device, op)

    # Testing that the generated view_copy kernel and its derivative are implemented correctly
    def test_view_copy(self, device):
        a = torch.randn(4, device=device, requires_grad=True)
        a_ref = a.clone().detach().requires_grad_()
        a_view = a_ref.view(2, 2)
        a_view_copy = torch.view_copy(a, (2, 2))

        # view_copy ops don't preserve view relationship
        self.assertTrue(self.is_view_of(a_ref, a_view))
        self.assertFalse(self.is_view_of(a, a_view_copy))

        a_view_copy.sum().backward()
        a_view.sum().backward()

        # forward and backward give the same shape + result
        self.assertEqual(a_view_copy, a_view)
        self.assertEqual(a.grad, a_ref.grad)

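    # Conceptually, torch.view_copy(a, shape) matches a.view(shape).clone():
    # same values, but the result owns fresh storage. A minimal added sketch
    # of that equivalence (an illustration, not part of the original suite):
    def test_view_copy_matches_view_clone_sketch(self, device):
        a = torch.randn(4, device=device)
        self.assertEqual(torch.view_copy(a, (2, 2)), a.view(2, 2).clone())
        if device != "meta":
            # fresh storage implies a different data pointer
            self.assertNotEqual(torch.view_copy(a, (2, 2)).data_ptr(), a.data_ptr())
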
    # Testing that the output of a view_copy kernel (by default) is contiguous.
    def test_view_copy_output_contiguous(self, device):
        a = torch.randn(4, 4, 4, 4, device=device).to(memory_format=torch.channels_last)
        b = torch.ops.aten.slice_copy(a, 0, 0, 2)
        self.assertTrue(b.is_contiguous())

    def test_view_copy_out(self, device):
        a = torch.randn(2, 2, device=device)
        out = torch.empty(2, device=device)

        torch.diagonal_copy(a, out=out)
        expected = torch.diagonal_copy(a)

        self.assertEqual(expected, out)

        a = torch.randn(4, device=device)
        out1 = torch.empty(2, device=device)
        out2 = torch.empty(2, device=device)

        torch.split_copy(a, 2, out=(out1, out2))
        expected1, expected2 = torch.split_copy(a, 2)

        self.assertEqual(expected1, out1)
        self.assertEqual(expected2, out2)


class TestOldViewOps(TestCase):
    def test_ravel(self, device):
        def _test_ravel(tensors, size, nc=False):
            for src in tensors:
                # Contiguous Tensor -> View
                flat = src.ravel()
                self.assertEqual(flat.shape, torch.Size([size]))
                self.assertEqual(src.view(-1), flat)
                self.assertIs(flat._base, src)
                self.assertTrue(flat.is_contiguous())

                # Non-contiguous Tensor -> Copy
                if nc:
                    nc_src = src.t()
                    nc_flat = nc_src.ravel()
                    self.assertEqual(nc_flat.shape, torch.Size([size]))
                    self.assertEqual(nc_src.contiguous().view(-1), nc_flat)
                    self.assertIsNot(nc_flat._base, src)
                    self.assertTrue(nc_flat.is_contiguous())

        # Test that ravel returns a 1-dim tensor when given a 0-dim tensor
        zero_dim_tensor = torch.tensor(123, device=device)
        flat0 = zero_dim_tensor.ravel()
        one_dim_tensor = torch.tensor([123], device=device)
        flat1 = one_dim_tensor.ravel()
        nc_ones_tensor = torch.ones(10, device=device)[::2]
        flat2 = nc_ones_tensor.ravel()

        self.assertEqual(zero_dim_tensor.shape, torch.Size([]))
        self.assertEqual(flat0.shape, torch.Size([1]))
        self.assertEqual(one_dim_tensor.shape, torch.Size([1]))
        self.assertEqual(flat1.shape, torch.Size([1]))
        self.assertEqual(nc_ones_tensor.shape, torch.Size([5]))
        self.assertEqual(flat2.shape, torch.Size([5]))
        self.assertEqual(flat0, one_dim_tensor)
        self.assertEqual(flat0, flat1)
        self.assertEqual(flat0.shape, flat1.shape)
        self.assertTrue(flat0.is_contiguous())
        self.assertTrue(flat1.is_contiguous())
        self.assertTrue(flat2.is_contiguous())

        # Test both float tensor and quantized tensor
        tensors = [
            torch.randn(5, 5, 5, 5, device=device),
            torch._empty_affine_quantized(
                [5, 5, 5, 5], scale=2, zero_point=3, dtype=torch.quint8, device=device
            ),
        ]
        _test_ravel(tensors, 625)

        tensors = [
            torch.randn(0, 2, 3, device=device),
            torch.randn(3, 0, 2, device=device),
            torch._empty_affine_quantized(
                [0, 2, 3], scale=2, zero_point=3, dtype=torch.quint8, device=device
            ),
            torch._empty_affine_quantized(
                [3, 0, 2], scale=2, zero_point=3, dtype=torch.quint8, device=device
            ),
        ]
        _test_ravel(tensors, 0)

        tensors = [
            torch.randn(5, 5, device=device),
            torch._empty_affine_quantized(
                [5, 5], scale=2, zero_point=3, dtype=torch.quint8, device=device
            ),
        ]
        _test_ravel(tensors, 25, True)

    # TODO: this should be refactored into the view ops test suite
    def test_empty_reshape(self, device):
        x = torch.randn(0, 6, device=device)
        self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape)
        # should be viewable -- i.e. data_ptr is the same.
        self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr())

        # match NumPy semantics -- don't infer the size of a dimension with a degree of freedom
        self.assertRaises(RuntimeError, lambda: x.reshape(0, -1))
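        # NumPy parity check (an added illustration; np is imported at the top
        # of this file): the same inference is rejected there as well.
        self.assertRaises(ValueError, lambda: np.zeros((0, 6)).reshape(0, -1))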

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_expand(self, device):
        tensor = torch.rand(1, 8, 1, device=device)
        tensor2 = torch.rand(5, device=device)
        template = torch.rand(4, 8, 5, device=device)
        target = template.size()
        self.assertEqual(tensor.expand_as(template).size(), target)
        self.assertEqual(tensor.expand(4, 8, 5).size(), target)
        self.assertEqual(tensor.expand(target).size(), target)
        self.assertEqual(tensor2.expand_as(template).size(), target)
        self.assertEqual(tensor2.expand(4, 8, 5).size(), target)
        self.assertEqual(tensor2.expand(target).size(), target)

        # test double expand
        self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1))

        # test non-contiguous
        noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0]
        self.assertFalse(noncontig.is_contiguous())
        self.assertEqual(
            noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1)
        )

        # make sure it's compatible with unsqueeze
        expanded = tensor2.expand(1, 1, 5)
        unsqueezed = tensor2.unsqueeze(0).unsqueeze(1)
        self.assertEqual(expanded, unsqueezed)
        self.assertEqual(expanded.stride(), unsqueezed.stride())

        # test -1 as target size
        self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5))
        self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1))

        # test expanding empty to empty
        self.assertEqual(
            torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device)
        )

    # TODO: this should be refactored into the view ops test suite
    def test_view_empty(self, device):
        x = torch.randn(0, 6, device=device)
        self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape)

    # TODO: this should be refactored into the view ops test suite
    @onlyNativeDeviceTypes
    def test_reshape(self, device):
        x = torch.randn(3, 3, device=device)
        self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
        self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
        self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
        self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))

        y = torch.randn(4, 4, 4, device=device)[:, 0, :]
        # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
        if device != "meta":
            self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
        self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
        self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())

        s = torch.randn((), device=device)
        self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr())
        self.assertEqual(s.reshape(-1).shape, (1,))
        self.assertRaises(RuntimeError, lambda: s.reshape(2))

        empty = torch.tensor([], device=device)
        self.assertEqual(empty, empty.reshape(-1))
        self.assertEqual(empty, empty.reshape([0]))
        # TODO: fix these once we have multi-dimensional empty tensors
        self.assertEqual(empty.reshape([0, 1]).shape, (0, 1))
        self.assertEqual(empty.reshape([1, -1]).shape, (1, 0))
        self.assertRaises(RuntimeError, lambda: empty.reshape(1))

        x = torch.randn(3, 3, device=device)
        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
        self.assertRaises(
            RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device))
        )

    def test_flatten(self, device):
        # Test that flatten returns a 1-dim tensor when given a 0-dim tensor
        zero_dim_tensor = torch.tensor(123, device=device)
        flat0 = zero_dim_tensor.flatten()
        one_dim_tensor = torch.tensor([123], device=device)
        flat1 = one_dim_tensor.flatten()

        self.assertEqual(zero_dim_tensor.shape, torch.Size([]))
        self.assertEqual(flat0.shape, torch.Size([1]))
        self.assertEqual(one_dim_tensor.shape, torch.Size([1]))
        self.assertEqual(flat1.shape, torch.Size([1]))
        self.assertEqual(flat0, one_dim_tensor)
        self.assertEqual(flat0, flat1)
        self.assertEqual(flat0.shape, flat1.shape)

        # Test both float tensor and quantized tensor
        tensors = [
            torch.randn(5, 5, 5, 5, device=device),
            torch._empty_affine_quantized(
                [5, 5, 5, 5], scale=2, zero_point=3, dtype=torch.quint8, device=device
            ),
        ]
        for src in tensors:
            flat = src.flatten(0, -1)
            self.assertEqual(flat.shape, torch.Size([625]))
            self.assertEqual(src.view(-1), flat.view(-1))

            flat = src.flatten(0, 2)
            self.assertEqual(flat.shape, torch.Size([125, 5]))
            self.assertEqual(src.view(-1), flat.view(-1))

            flat = src.flatten(0, 1)
            self.assertEqual(flat.shape, torch.Size([25, 5, 5]))
            self.assertEqual(src.view(-1), flat.view(-1))

            flat = src.flatten(1, 2)
            self.assertEqual(flat.shape, torch.Size([5, 25, 5]))
            self.assertEqual(src.view(-1), flat.view(-1))

            flat = src.flatten(2, 3)
            self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
            self.assertEqual(src.view(-1), flat.view(-1))

            flat = src.flatten(-2, -1)
            self.assertEqual(flat.shape, torch.Size([5, 5, 25]))
            self.assertEqual(src.view(-1), flat.view(-1))

            flat = src.flatten(2, 2)
            self.assertEqual(flat, src)

            # out of bounds index
            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
                src.flatten(5, 10)

            # invalid start and end
            with self.assertRaisesRegex(
                RuntimeError, "start_dim cannot come after end_dim"
            ):
                src.flatten(2, 0)

    # TODO: update to work on CUDA, too
    @onlyCPU
    def test_narrow(self, device):
        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
        self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]]))
        self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]]))
        self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]]))
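        # Negative start counts from the end of the dimension, so
        # narrow(0, -1, 1) is equivalent to narrow(0, x.size(0) - 1, 1).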
1259*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]]))
1260*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]]))
1261*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1262*da0073e9SAndroid Build Coastguard Worker            x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
1263*da0073e9SAndroid Build Coastguard Worker        )
1264*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]]))
1265*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]]))
1266*da0073e9SAndroid Build Coastguard Worker
1267*da0073e9SAndroid Build Coastguard Worker    # TODO: update to work on CUDA, too
1268*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1269*da0073e9SAndroid Build Coastguard Worker    def test_narrow_tensor(self, device):
1270*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
1271*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]]))
1272*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
1273*da0073e9SAndroid Build Coastguard Worker            x.narrow(0, torch.tensor(0.0), 1)
1274*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
1275*da0073e9SAndroid Build Coastguard Worker            x.narrow(0, torch.tensor([0]), 1)
1276*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
1277*da0073e9SAndroid Build Coastguard Worker            x.narrow(0, torch.tensor([0, 1]), 1)
1278*da0073e9SAndroid Build Coastguard Worker
1279*da0073e9SAndroid Build Coastguard Worker    # TODO: make work on CUDA, too
1280*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1281*da0073e9SAndroid Build Coastguard Worker    def test_t(self, device):
1282*da0073e9SAndroid Build Coastguard Worker        # Test 0D tensors
1283*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(())
1284*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.t())
1285*da0073e9SAndroid Build Coastguard Worker        x = x.to_sparse()
1286*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.t())
1287*da0073e9SAndroid Build Coastguard Worker
1288*da0073e9SAndroid Build Coastguard Worker        # Test 1D tensors
1289*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(4)
1290*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.t())
1291*da0073e9SAndroid Build Coastguard Worker        x = x.to_sparse()
1292*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.t())
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker        # Test 2D tensors
1295*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((2, 2))
1296*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.t(), x.transpose(0, 1))
1297*da0073e9SAndroid Build Coastguard Worker        x = x.to_sparse()
1298*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.t(), x.transpose(0, 1))
1299*da0073e9SAndroid Build Coastguard Worker
1300*da0073e9SAndroid Build Coastguard Worker        # Test 3D tensor
1301*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((2, 2, 2))
1302*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1303*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "expects a tensor with <= 2 dimensions, but self is 3D"
1304*da0073e9SAndroid Build Coastguard Worker        ):
1305*da0073e9SAndroid Build Coastguard Worker            x.t()
1306*da0073e9SAndroid Build Coastguard Worker        x = x.to_sparse()
1307*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1308*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "expects a tensor with <= 2 sparse and 0 dense dimensions"
1309*da0073e9SAndroid Build Coastguard Worker        ):
1310*da0073e9SAndroid Build Coastguard Worker            x.t()
1311*da0073e9SAndroid Build Coastguard Worker
1312*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1313*da0073e9SAndroid Build Coastguard Worker    def test_split(self, device):
1314*da0073e9SAndroid Build Coastguard Worker        tensor = torch.rand(7, 4)
1315*da0073e9SAndroid Build Coastguard Worker        split_size = 3
1316*da0073e9SAndroid Build Coastguard Worker        dim = 0
1317*da0073e9SAndroid Build Coastguard Worker        target_sizes = ([3, 4], [3, 4], [1, 4])
1318*da0073e9SAndroid Build Coastguard Worker        splits = tensor.split(split_size, dim)
1319*da0073e9SAndroid Build Coastguard Worker        start = 0
1320*da0073e9SAndroid Build Coastguard Worker        for target_size, split in zip(target_sizes, splits):
1321*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(split.size(), target_size)
1322*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1323*da0073e9SAndroid Build Coastguard Worker                tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0
1324*da0073e9SAndroid Build Coastguard Worker            )
1325*da0073e9SAndroid Build Coastguard Worker            start = start + target_size[dim]
1326*da0073e9SAndroid Build Coastguard Worker
1327*da0073e9SAndroid Build Coastguard Worker        # Variable sections split
1328*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(20, 10)
1329*da0073e9SAndroid Build Coastguard Worker        dim = 0
1330*da0073e9SAndroid Build Coastguard Worker        split_sizes = [5, 5, 10]
1331*da0073e9SAndroid Build Coastguard Worker        target_sizes = [[5, 10], [5, 10], [10, 10]]
1332*da0073e9SAndroid Build Coastguard Worker        splits = tensor.split(split_sizes, dim)
1333*da0073e9SAndroid Build Coastguard Worker        start = 0
1334*da0073e9SAndroid Build Coastguard Worker        for target_size, split in zip(target_sizes, splits):
1335*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(split.size(), target_size)
1336*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1337*da0073e9SAndroid Build Coastguard Worker                tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0
1338*da0073e9SAndroid Build Coastguard Worker            )
1339*da0073e9SAndroid Build Coastguard Worker            start = start + target_size[dim]
1340*da0073e9SAndroid Build Coastguard Worker
1341*da0073e9SAndroid Build Coastguard Worker        split_sizes = [2, 2, 6]
1342*da0073e9SAndroid Build Coastguard Worker        target_sizes = ([20, 2], [20, 2], [20, 6])
1343*da0073e9SAndroid Build Coastguard Worker        dim = 1
1344*da0073e9SAndroid Build Coastguard Worker        splits = tensor.split(split_sizes, dim)
1345*da0073e9SAndroid Build Coastguard Worker        start = 0
1346*da0073e9SAndroid Build Coastguard Worker        for target_size, split in zip(target_sizes, splits):
1347*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(split.size(), target_size)
1348*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1349*da0073e9SAndroid Build Coastguard Worker                tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0
1350*da0073e9SAndroid Build Coastguard Worker            )
1351*da0073e9SAndroid Build Coastguard Worker            start = start + target_size[dim]
1352*da0073e9SAndroid Build Coastguard Worker
1353*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1354*da0073e9SAndroid Build Coastguard Worker    def test_chunk(self, device):
1355*da0073e9SAndroid Build Coastguard Worker        tensor = torch.rand(4, 7)
1356*da0073e9SAndroid Build Coastguard Worker        num_chunks = 3
1357*da0073e9SAndroid Build Coastguard Worker        dim = 1
1358*da0073e9SAndroid Build Coastguard Worker        target_sizes = ([4, 3], [4, 3], [4, 1])
1359*da0073e9SAndroid Build Coastguard Worker        splits = tensor.chunk(num_chunks, dim)
1360*da0073e9SAndroid Build Coastguard Worker        start = 0
1361*da0073e9SAndroid Build Coastguard Worker        for target_size, split in zip(target_sizes, splits):
1362*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(split.size(), target_size)
1363*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
1364*da0073e9SAndroid Build Coastguard Worker                tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0
1365*da0073e9SAndroid Build Coastguard Worker            )
1366*da0073e9SAndroid Build Coastguard Worker            start = start + target_size[dim]
1367*da0073e9SAndroid Build Coastguard Worker
1368*da0073e9SAndroid Build Coastguard Worker        # Invalid chunk sizes
1369*da0073e9SAndroid Build Coastguard Worker        error_regex = "chunk expects.*greater than 0"
1370*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, error_regex):
1371*da0073e9SAndroid Build Coastguard Worker            tensor.chunk(0)
1372*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, error_regex):
1373*da0073e9SAndroid Build Coastguard Worker            tensor.chunk(-2)
1374*da0073e9SAndroid Build Coastguard Worker
1375*da0073e9SAndroid Build Coastguard Worker    # TODO: make work on CUDA, too
1376*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
1377*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1378*da0073e9SAndroid Build Coastguard Worker    def test_unsqueeze(self, device) -> None:
1379*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 4)
1380*da0073e9SAndroid Build Coastguard Worker        y = x.unsqueeze(1)
1381*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, x.view(2, 1, 3, 4))
1382*da0073e9SAndroid Build Coastguard Worker        y = x.clone().unsqueeze_(2)
1383*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, x.view(2, 3, 1, 4))
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Worker        x = x[:, 1]
1386*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(x.is_contiguous())
1387*da0073e9SAndroid Build Coastguard Worker        y = x.unsqueeze(1)
1388*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, x.contiguous().view(2, 1, 4))
1389*da0073e9SAndroid Build Coastguard Worker        y = x.clone().unsqueeze_(2)
1390*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, x.contiguous().view(2, 4, 1))
1391*da0073e9SAndroid Build Coastguard Worker
    # unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
    def test_big_transpose(self, device):
        t = torch.rand(456, 789, device=device)
        t1 = t.t().contiguous()
        t2 = torch.from_numpy(t.cpu().numpy().transpose())
        self.assertEqual(t1, t2)

    def test_T(self, device):
        a = torch.randn(2, 3, 4, device=device)
        t1 = a.T
        t2 = a.permute(2, 1, 0)
        self.assertEqual(t2, t1)
        b = torch.randn(10, device=device)
        self.assertEqual(b, b.T)

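    # Hedged sketch (editor addition): for a matrix, .T is equivalent to
    # transpose(0, 1) and returns a view that shares the base storage.
    def _sketch_T_matrix_is_view(self):
        m = torch.arange(6.0).reshape(2, 3)
        assert torch.equal(m.T, m.transpose(0, 1))
        assert m.T.data_ptr() == m.data_ptr()  # view, no copy
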
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_transposes(self, device, dtype):
        for op in ("T", "H", "mT", "mH", "adjoint"):
            shapes = (
                ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),)
            )
            for shape in shapes:
                a = make_tensor(shape, device=device, dtype=dtype)
                t1 = getattr(a, op)
                if op == "adjoint":
                    t1 = t1()
                t2 = a
                t2 = t2.transpose(-2, -1)
                if op[-1] == "H" or op == "adjoint":
                    t2 = t2.conj()
                self.assertEqual(t2, t1)

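    # Hedged sketch (editor addition): for a complex matrix, mH composes a
    # transpose of the last two dims with conjugation, matching adjoint().
    def _sketch_mH_is_conj_transpose(self):
        a = torch.randn(2, 3, dtype=torch.complex64)
        assert torch.equal(a.mH, a.transpose(-2, -1).conj())
        assert torch.equal(a.adjoint(), a.mH)
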
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_transposes_errors(self, device, dtype):
        for op in ("H", "mT", "mH", "adjoint"):
            shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),)
            for shape in shapes:
                a = make_tensor(shape, device=device, dtype=dtype)
                with self.assertRaisesRegex(RuntimeError, "only supported on matrices"):
                    t1 = getattr(a, op)
                    if op == "adjoint":
                        t1 = t1()

    def test_python_types(self, device):
        a1 = torch.randn((1, 2), device=device, dtype=torch.float64)
        a2 = torch.randn((1, 2), device=device, dtype=float)
        self.assertEqual(a1.dtype, a2.dtype)

        b1 = torch.arange(10, 20, dtype=torch.int64, device=device)
        b2 = torch.arange(10, 20, dtype=int, device=device)
        self.assertEqual(b1.dtype, b2.dtype)

        c1 = torch.tensor([True, False], dtype=torch.bool, device=device)
        c2 = torch.tensor([True, False], dtype=bool, device=device)
        self.assertEqual(c1.dtype, c2.dtype)

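    # Hedged sketch (editor addition): torch maps the Python builtins float,
    # int, and bool onto fixed dtypes when they are passed as `dtype`.
    def _sketch_python_type_dtype_mapping(self):
        assert torch.tensor([0.0], dtype=float).dtype == torch.float64
        assert torch.tensor([0], dtype=int).dtype == torch.int64
        assert torch.tensor([True], dtype=bool).dtype == torch.bool
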
    # TODO: is resize best put in test_view_ops?
    def test_resize_as_preserves_strides(self, device):
        x = torch.empty(2, 3).t()
        old_strides = x.stride()
        x.resize_as_(x)
        self.assertEqual(x.stride(), old_strides)

    def test_memory_format_resize_as(self, device):
        def test_helper(shape, memory_format, device):
            xc = torch.randn(shape, device=device).contiguous(
                memory_format=memory_format
            )
            flat = torch.randn(xc.numel(), device=device)
            flat.resize_as_(xc, memory_format=torch.preserve_format)
            self.assertTrue(flat.is_contiguous(memory_format=memory_format))

        test_helper((10, 3, 32, 32), torch.channels_last, device)
        test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device)

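    # Hedged sketch (editor addition): channels_last keeps the logical NCHW
    # shape but reorders strides so the channel dim is the fastest-varying.
    def _sketch_channels_last_strides(self):
        x = torch.empty(2, 3, 4, 5).contiguous(memory_format=torch.channels_last)
        assert x.shape == (2, 3, 4, 5)
        assert x.stride() == (60, 1, 15, 3)  # NHWC order in memory
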
    def test_memory_format_resize_(self, device):
        def test_helper(shape, numel, memory_format, device):
            flat = torch.randn(numel, device=device)
            flat.resize_(shape, memory_format=memory_format)
            self.assertTrue(flat.is_contiguous(memory_format=memory_format))

        test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device)
        test_helper(
            (3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device
        )

    @onlyNativeDeviceTypes
    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_transpose_invalid(self, device, dtype):
        for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
            shape = _rand_shape(4, min_size=5, max_size=10)
            x = _generate_input(shape, dtype, device, False)

            # Invalid `source` and `destination` dimensions
            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
                fn(x, 5, 0)

            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
                fn(x, 0, 5)

    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_transpose_vs_numpy(self, device, dtype):
        for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
            for nd in range(5):
                shape = _rand_shape(nd, min_size=5, max_size=10)
                x = _generate_input(shape, dtype, device, with_extremal=False)
                for random_negative in [True, False]:
                    for src_dim, dst_dim in permutations(range(nd), r=2):
                        random_prob = random.random()

                        if random_negative and random_prob > 0.66:
                            src_dim = src_dim - nd
                        elif random_negative and random_prob > 0.33:
                            dst_dim = dst_dim - nd
                        elif random_negative:
                            src_dim = src_dim - nd
                            dst_dim = dst_dim - nd

                        partial_map = {
                            torch.swapdims: partial(
                                torch.swapdims, dim0=src_dim, dim1=dst_dim
                            ),
                            torch.swapaxes: partial(
                                torch.swapaxes, axis0=src_dim, axis1=dst_dim
                            ),
                            torch.transpose: partial(
                                torch.transpose, dim0=src_dim, dim1=dst_dim
                            ),
                        }

                        torch_fn = partial_map[fn]
                        np_fn = partial(np.swapaxes, axis1=src_dim, axis2=dst_dim)
                        self.compare_with_numpy(
                            torch_fn, np_fn, x, device=None, dtype=None
                        )

            # Move dim to same position
            x = torch.randn(2, 3, 5, 7, 11)
            partial_map = {
                torch.swapdims: partial(torch.swapdims, dim0=0, dim1=0),
                torch.swapaxes: partial(torch.swapaxes, axis0=0, axis1=0),
                torch.transpose: partial(torch.transpose, dim0=0, dim1=0),
            }
            torch_fn = partial_map[fn]
            np_fn = partial(np.swapaxes, axis1=0, axis2=0)
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

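    # Hedged sketch (editor addition): torch.transpose agrees with
    # np.swapaxes on the same data, including negative dim indices.
    def _sketch_transpose_matches_numpy(self):
        x = torch.arange(24).reshape(2, 3, 4)
        expected = np.swapaxes(x.numpy(), 0, -1).copy()
        assert torch.equal(torch.transpose(x, 0, -1), torch.from_numpy(expected))
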
    def _test_atleast_dim(self, torch_fn, np_fn, device, dtype):
        for ndims in range(0, 5):
            shape = _rand_shape(ndims, min_size=5, max_size=10)
            for n in range(ndims + 1):
                for with_extremal in [False, True]:
                    for contiguous in [False, True]:
                        # Generate input; transpose to exercise the
                        # non-contiguous case.
                        x = _generate_input(shape, dtype, device, with_extremal)
                        if not contiguous:
                            x = x.T
                        self.compare_with_numpy(
                            torch_fn, np_fn, x, device=None, dtype=None
                        )

                        # Compare sequence input
                        torch_sequence_x = (x,) * random.randint(3, 10)
                        np_sequence_x = tuple(
                            np.array(x.detach().cpu().numpy()) for x in torch_sequence_x
                        )
                        torch_res = torch_fn(*torch_sequence_x)
                        np_res = np_fn(*np_sequence_x)

                        torch_res = tuple(x.cpu() for x in torch_res)
                        np_res = tuple(torch.from_numpy(x) for x in np_res)
                        self.assertEqual(np_res, torch_res)

    # TODO: are these view ops?
    @dtypes(*all_types_and_complex_and(torch.half))
    def test_atleast(self, device, dtype):
        self._test_atleast_dim(torch.atleast_1d, np.atleast_1d, device, dtype)
        self._test_atleast_dim(torch.atleast_2d, np.atleast_2d, device, dtype)
        self._test_atleast_dim(torch.atleast_3d, np.atleast_3d, device, dtype)

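    # Hedged sketch (editor addition): atleast_2d promotes lower-dimensional
    # inputs by adding size-1 dims, while >=2D inputs keep aliasing the same
    # storage.
    def _sketch_atleast_2d(self):
        assert torch.atleast_2d(torch.tensor(1.0)).shape == (1, 1)
        assert torch.atleast_2d(torch.rand(3)).shape == (1, 3)
        m = torch.rand(2, 3)
        assert torch.atleast_2d(m).data_ptr() == m.data_ptr()
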
    # TODO: OpInfo this
    def _test_atleast(self, device, torch_fn):
        # 0-dim
        s = torch.tensor(0.5, dtype=torch.double, requires_grad=True)

        gradcheck(lambda x: torch_fn(x), s)
        gradgradcheck(lambda x: torch_fn(x), s)

        # 1-dim
        a = torch.rand(4, dtype=torch.double, requires_grad=True)

        gradcheck(lambda x: torch_fn(x), a)
        gradgradcheck(lambda x: torch_fn(x), a)

        # 2,3,4-dim
        b = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
        c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True)
        d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True)

        input_tuple = (s, a, b, c, d)
        gradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)
        gradgradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)

    def test_atleast_gradient(self, device):
        self._test_atleast(device, torch.atleast_1d)
        self._test_atleast(device, torch.atleast_2d)
        self._test_atleast(device, torch.atleast_3d)

    @onlyCPU
    @dtypes(torch.float)
    def test_broadcast_tensors(self, device, dtype):
        x0 = torch.randn(2, 1, 3, dtype=dtype, device=device)
        x1 = torch.randn(3, dtype=dtype, device=device)
        x2 = torch.randn(3, 1, dtype=dtype, device=device)
        expected_size = (2, 3, 3)

        y0, y1, y2 = torch.broadcast_tensors(x0, x1, x2)
        self.assertTrue(y0.size() == expected_size)
        self.assertTrue(y1.size() == expected_size)
        self.assertTrue(y2.size() == expected_size)

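    # Hedged sketch (editor addition): broadcast_tensors copies no data; the
    # broadcast dimensions are realized as stride-0 views of the inputs.
    def _sketch_broadcast_stride_zero(self):
        x = torch.randn(1, 3)
        y = torch.randn(4, 1)
        bx, by = torch.broadcast_tensors(x, y)
        assert bx.shape == (4, 3) and by.shape == (4, 3)
        assert bx.stride() == (0, 1)  # expanded dim gets stride 0
        assert bx.data_ptr() == x.data_ptr()  # still a view of x
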
    @onlyCPU
    def test_broadcast_shapes(self, device):
        examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)]
        for s0 in examples:
            x0 = torch.randn(s0)
            expected = torch.broadcast_tensors(x0)[0].shape
            actual = torch.broadcast_shapes(s0)
            self.assertEqual(expected, actual)

            for s1 in examples:
                x1 = torch.randn(s1)
                expected = torch.broadcast_tensors(x0, x1)[0].shape
                actual = torch.broadcast_shapes(s0, s1)
                self.assertEqual(expected, actual)

        inputs_list = [[1, 4], [4, 1], [1, 1, 3]]
        for integral_inputs in inputs_list:
            res1 = torch.broadcast_shapes(*integral_inputs)
            res2 = torch.broadcast_tensors(*map(torch.empty, integral_inputs))[0].shape
            self.assertEqual(res1, res2)

        inputs_with_neg_vals = [[1, 1, -12], [-1, 1], [-11]]
        for integral_inputs_with_neg_vals in inputs_with_neg_vals:
            with self.assertRaisesRegex(
                RuntimeError, "Trying to create tensor with negative dimension"
            ):
                torch.broadcast_shapes(*integral_inputs_with_neg_vals)

        integral_inputs_error_case = [(3, 5), (2, 4, 1)]
        for error_input in integral_inputs_error_case:
            with self.assertRaisesRegex(
                RuntimeError,
                "Shape mismatch: objects cannot be broadcast to a single shape",
            ):
                torch.broadcast_shapes(*error_input)

        negative_inputs = [(-1,), (1, -12), (4, -11), (-4, 1), (1, 1, -2)]
        for s0 in negative_inputs:
            with self.assertRaisesRegex(
                RuntimeError, "Trying to create tensor with negative dimension"
            ):
                torch.broadcast_shapes(s0)

            for s1 in negative_inputs:
                with self.assertRaisesRegex(
                    RuntimeError, "Trying to create tensor with negative dimension"
                ):
                    torch.broadcast_shapes(s0, s1)

        float_inputs_error_case = [(1.1, 2.0), (1.1, 1.0)]
        for error_case in float_inputs_error_case:
            for float_input in error_case:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Input shapes "
                    "should be of type ints, a tuple of ints, or a list of ints",
                ):
                    torch.broadcast_shapes(float_input)

        diff_input_types = [(1, (5,)), (3, (1,)), (1, (3, 4))]
        for s0 in diff_input_types:
            res1 = torch.broadcast_shapes(*s0)
            res2 = torch.broadcast_tensors(*map(torch.empty, s0))[0].shape
            self.assertEqual(res1, res2)

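    # Hedged sketch (editor addition): broadcast_shapes applies the usual
    # right-aligned broadcasting rule; size-1 dims stretch to match.
    def _sketch_broadcast_shapes_rule(self):
        assert torch.broadcast_shapes((2, 1, 3), (4, 3)) == (2, 4, 3)
        assert torch.broadcast_shapes((), (5,)) == (5,)
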
    # Skip BFloat16 since numpy does not support it
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
    def test_broadcast_to(self, device, dtype):
        def can_broadcast(s0, s1):
            # len(s0) <= len(s1); reverse both shapes to compare trailing dims
            s0 = tuple(reversed(s0))
            s1 = tuple(reversed(s1))
            for i in range(len(s0)):
                if s0[i] != 1 and s0[i] != s1[i]:
                    return False
            return True

        sizes = ((), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2))
        for s0, s1 in combinations(sizes, r=2):
            t = make_tensor(s0, dtype=dtype, device=device, low=-9, high=9)
            t_np = t.cpu().numpy()

            if can_broadcast(s0, s1):
                res = torch.broadcast_to(t, s1)
                np_res = np.broadcast_to(t_np, s1)
                self.assertEqual(res, np_res)
            else:
                with self.assertRaisesRegex(
                    RuntimeError,
                    r"The expanded size of the tensor \(\d\) "
                    r"must match the existing size \(\d\)",
                ):
                    torch.broadcast_to(t, s1)

    def test_view(self, device):
        tensor = torch.rand(15, device=device)
        template = torch.rand(3, 5, device=device)
        empty = torch.empty(0, device=device)
        target = template.size()
        self.assertEqual(tensor.view_as(template).size(), target)
        self.assertEqual(tensor.view(3, 5).size(), target)
        self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target)
        self.assertEqual(tensor.view(-1, 5).size(), target)
        self.assertEqual(tensor.view(3, -1).size(), target)
        tensor_view = tensor.view(5, 3)
        tensor_view.fill_(random.uniform(0, 1))
        self.assertEqual(empty.view_as(empty), empty)
        self.assertEqual(empty.view(0), empty)
        self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
        self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)

        # test size inference with empty tensors
        self.assertEqual(empty.view(-1).size(), torch.Size([0]))
        self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))

        with self.assertRaisesRegex(
            RuntimeError, r"because the unspecified dimension size -1 can be any value"
        ):
            empty.view(-1, 0)

        with self.assertRaisesRegex(
            RuntimeError, r"because the unspecified dimension size -1 can be any value"
        ):
            empty.view(3, 0, -1, 0)

        self.assertRaises(RuntimeError, lambda: tensor.view(15, 0))
        self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
        self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))

        # test view when tensor is not contiguous in every dimension, but only
        # contiguous dimensions are touched.
        tensor = (
            torch.rand(4, 2, 5, 1, 6, 2, 9, 3, device=device)
            .transpose(-1, 2)
            .transpose(-2, 3)
        )
        # size:                      [   4,    2,    3,    9,    6,    2,    1,    5]
        # stride:                    [3840, 1620,    1,    3,   54,   27,  324,  324]
        # contiguous dim chunks:     [__________, ____, ____, __________, ____, ____]
        # merging 1 to chunk after:  [__________, ____, ____, __________, __________]
        contig_tensor = tensor.clone()
        # [4, 2] => [8, 1]
        # [3] => [3]
        # [9] => [3, 3]
        # [6, 2] => [4, 1, 3]
        # [1, 5] => [5]
        view_size = [8, 1, 3, 3, 3, 4, 1, 3, 5]
        self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))
        # [4, 2] => [2, 4]
        # [3] => [3]
        # [9] => [1, 9]
        # [6, 2] => [2, 2, 3]
        # [1, 5] => [5, 1]
        view_size = [2, 4, 3, 1, 9, 2, 2, 3, 5, 1]
        self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))
        # adding size 1 dims
        view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1]
        self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size))

        # invalid views
        self.assertRaises(RuntimeError, lambda: tensor.view(-1))
        # crossing [4, 2], [3]
        self.assertRaises(RuntimeError, lambda: tensor.view(24, 9, 6, 2, 1, 5))
        # crossing [6, 2], [1, 5]
        self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 9, 6, 10))
        # crossing [9], [6, 2]
        self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 54, 2, 1, 5))

        # view with stride 0 dims
        tensor = torch.empty(1, 1, device=device).expand(
            3, 4
        )  # all dims are contiguous
        contig_tensor = tensor.clone()
        self.assertEqual(tensor.view(-1), contig_tensor.view(-1))
        self.assertEqual(tensor.view(1, -1, 1), contig_tensor.view(1, -1, 1))
        self.assertEqual(tensor.view(-1, 1), contig_tensor.view(-1, 1))
        self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1))
        self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1))

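    # Hedged sketch (editor addition): a view aliases the base tensor's
    # storage, so an in-place write through the view shows up in the base.
    def _sketch_view_aliases_storage(self):
        base = torch.zeros(6)
        v = base.view(2, 3)
        v[1, 2] = 7.0
        assert base[5].item() == 7.0
        assert v.data_ptr() == base.data_ptr()
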
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_reshape_view_semantics(self, device, dtype):
        tensor = make_tensor((15, 4), dtype=dtype, device=device)
        target = (20, 3)

        # Cases where the tensor can be returned as a view.
        view_tensor = tensor.reshape(target)
        self.assertEqual(view_tensor.size(), target)
        self.assertEqual(tensor.storage().data_ptr(), view_tensor.storage().data_ptr())

        # Cases where the tensor must be copied (transpose makes it
        # non-contiguous, forcing the copy).
        copy_tensor = tensor.transpose(0, 1).reshape(target)
        self.assertEqual(copy_tensor.size(), target)
        self.assertNotEqual(
            tensor.storage().data_ptr(), copy_tensor.storage().data_ptr()
        )

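    # Hedged sketch (editor addition): view fails on layouts it cannot alias,
    # while reshape silently falls back to a copy in those cases.
    def _sketch_view_vs_reshape_fallback(self):
        t = torch.randn(3, 4).t()  # non-contiguous transpose
        self.assertRaises(RuntimeError, lambda: t.view(12))
        r = t.reshape(12)  # reshape copies instead of failing
        assert r.data_ptr() != t.data_ptr()
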
    def test_contiguous(self, device):
        x = torch.randn(1, 16, 5, 5, device=device)
        self.assertTrue(x.is_contiguous())
        stride = list(x.stride())
        stride[0] = 20
        # Change the stride in dimension 0; the tensor is still contiguous
        # because size[0] is 1.
        x.set_(x.storage(), 0, x.size(), stride)
        self.assertTrue(x.is_contiguous())

    @onlyNativeDeviceTypes
    # Skip BFloat16 since numpy does not support it
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
    def test_tensor_split_sections(self, device, dtype):
        input_sizes = [
            (0,),
            (10,),
            (10, 0),
            (0, 10),
            (4, 10),
            (12, 3),
        ]
        for input_size in input_sizes:
            a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
            # Run tests on transposed input if it has at least 2 dims
            for a in [a_base, a_base.t()] if a_base.dim() >= 2 else [a_base]:
                a_n = a.cpu().numpy()
                for dim in range(-a.dim(), a.dim()):
                    for sections in range(1, 2 * a.size(dim)):
                        msg = f"input_size {input_size}, sections {sections}, dim {dim}"
                        result1 = torch.tensor_split(a, sections, dim)
                        result2 = torch.tensor_split(
                            a, torch.tensor(sections, dtype=torch.int64), dim
                        )
                        for r1, r2 in zip(result1, result2):
                            self.assertEqual(r1.device, torch.device(device), msg=msg)
                            self.assertEqual(r1.dtype, dtype, msg=msg)
                            self.assertEqual(r2.device, torch.device(device), msg=msg)
                            self.assertEqual(r2.dtype, dtype, msg=msg)
                        result_n = np.array_split(a_n, sections, dim)
                        self.assertEqual(result_n, result1, msg=msg)
                        self.assertEqual(result_n, result2, msg=msg)

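    # Hedged sketch (editor addition): with integer sections, tensor_split
    # mirrors np.array_split: the first (n % sections) chunks get one extra
    # element, and the pieces are views of the input.
    def _sketch_tensor_split_uneven(self):
        t = torch.arange(10)
        parts = torch.tensor_split(t, 3)
        assert [p.numel() for p in parts] == [4, 3, 3]
        assert parts[0].data_ptr() == t.data_ptr()
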
    @onlyNativeDeviceTypes
    # Skip BFloat16 since numpy does not support it
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
    def test_tensor_split_indices(self, device, dtype):
        input_sizes = [
            (0,),
            (10,),
            (10, 0),
            (0, 10),
            (4, 10),
            (12, 3),
        ]
        indices_args = [
            (),
            (0,),
            (3,),
            (10,),
            (-1,),
            (-10,),
            (2, -1),
            (3, 4, 10),
            (0, -1, 0, 10),
            (1, 5, 2, 8),
        ]
        for input_size in input_sizes:
            a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
            # Run tests on transposed input if it has at least 2 dims
            for a in [a_base, a_base.t()] if a_base.dim() >= 2 else [a_base]:
                a_n = a.cpu().numpy()
                for dim in range(-a.dim(), a.dim()):
                    for indices in indices_args:
                        result_1 = torch.tensor_split(a, indices, dim)
                        result_2 = torch.tensor_split(
                            a, torch.tensor(indices, dtype=torch.int64), dim
                        )

                        msg = f"input_size {input_size}, indices {indices}, dim {dim}"
                        for r1, r2 in zip(result_1, result_2):
                            self.assertEqual(r1.device, torch.device(device), msg=msg)
                            self.assertEqual(r1.dtype, dtype, msg=msg)
                            self.assertEqual(r2.device, torch.device(device), msg=msg)
                            self.assertEqual(r2.dtype, dtype, msg=msg)

                        result_n = np.array_split(a_n, indices, dim)
                        self.assertEqual(result_n, result_1, msg=msg)
                        self.assertEqual(result_n, result_2, msg=msg)

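    # Hedged sketch (editor addition): with a tuple of indices, tensor_split
    # slices at each boundary; negative indices count from the end.
    def _sketch_tensor_split_indices(self):
        t = torch.arange(8)
        a, b, c = torch.tensor_split(t, (2, -2))
        assert a.tolist() == [0, 1]
        assert b.tolist() == [2, 3, 4, 5]
        assert c.tolist() == [6, 7]
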
    @onlyNativeDeviceTypes
    def test_tensor_split_errors(self, device):
        S = 10
        test_cases = [
            # input size, sections or indices, dim, error type, error message, numpy error type
            [(S,), 10, 1, IndexError, r"Dimension out of range", IndexError],
            [
                (),
                10,
                0,
                RuntimeError,
                r"tensor_split expected at least a 1-dimensional tensor, "
                + "but got a tensor with 0 dims",
                IndexError,
            ],
            [(S,), (10,), 1, IndexError, r"Dimension out of range", IndexError],
            [
                (),
                (10,),
                0,
                RuntimeError,
                r"tensor_split expected at least a 1-dimensional tensor, "
                + "but got a tensor with 0 dims",
                IndexError,
            ],
            [
                (S,),
                0,
                0,
                RuntimeError,
                r"number of sections must be larger than 0, got 0",
                ValueError,
            ],
            [
                (S,),
                -1,
                0,
                RuntimeError,
                r"number of sections must be larger than 0, got -1",
                ValueError,
            ],
        ]
        for input_size, sections_or_indices, dim, err, err_msg, numpy_err in test_cases:
            a = torch.randn(input_size, device=device)
            msg = f"input_size {input_size}, sections_or_indices {sections_or_indices}, dim {dim}"
            with self.assertRaisesRegex(err, err_msg, msg=msg):
                torch.tensor_split(a, sections_or_indices, dim)
            with self.assertRaisesRegex(err, err_msg, msg=msg):
                torch.tensor_split(a, torch.tensor(sections_or_indices), dim)
            with self.assertRaises(numpy_err, msg=msg):
                np.array_split(a.cpu().numpy(), sections_or_indices, dim)

        # additional tests for tensor_split with tensor_indices_or_sections
        with self.assertRaisesRegex(
            RuntimeError,
            r"tensor_split expected tensor_indices_or_sections to have dtype of long, but got Float",
        ):
            torch.tensor_split(a, torch.tensor(1.1), dim)

        with self.assertRaisesRegex(
            RuntimeError,
            r"tensor_split expected tensor_indices_or_sections to be a"
            + " zero-dimensional or one-dimensional tensor, but got a tensor with 2 dims",
        ):
            torch.tensor_split(torch.rand(S, device=device), torch.tensor(((1,),)), 0)

    def test_resize_all_dtypes_and_devices(self, device):
        shape = (2, 2)
        for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
            x.resize_(shape)
            self.assertEqual(shape, x.shape)

    def test_resize_as_all_dtypes_and_devices(self, device):
        for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
            y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device)
            x.resize_as_(y)
            self.assertEqual(y.shape, x.shape)

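    # Hedged sketch (editor addition): growing a tensor with resize_ leaves
    # the newly exposed elements uninitialized; only the shape is guaranteed.
    def _sketch_resize_uninitialized(self):
        x = torch.tensor([1.0, 2.0])
        x.resize_(4)
        assert x.shape == (4,)  # x[2:] holds arbitrary values
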
    @onlyNativeDeviceTypes
    def test_resize_overflow(self, device):
        x = torch.empty((), dtype=torch.float64)
        with self.assertRaisesRegex(
            RuntimeError, "Storage size calculation overflowed"
        ):
            x.resize_([2, 4, 2**29, 2**29])
        with self.assertRaisesRegex(RuntimeError, "overflow"):
            x.resize_([8, 8, 2**29, 2**29])
        with self.assertRaisesRegex(RuntimeError, "Stride calculation overflowed"):
            x.resize_([0, 4, 2305843009213693952])

    def test_view_all_dtypes_and_devices(self, device):
        for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
            self.assertEqual(x.view(6).shape, [6])

    @skipIfTorchDynamo("conj bit not implemented in TensorVariable yet")
    @onlyCPU
    def test_conj_neg_view_numpy_error(self, device):
        self.assertRaisesRegex(
            RuntimeError,
            "has conjugate bit set",
            lambda: torch.tensor([1 + 2j]).conj().numpy(),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "has negative bit set",
            lambda: torch.tensor([1 + 2j]).conj().imag.numpy(),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "not supported for conjugate view tensors",
            lambda: torch.tensor([1 + 2j]).conj().view(torch.float64),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "not supported for tensors with negative bit set",
            lambda: torch.tensor([1 + 2j]).conj().imag.view(torch.int32),
        )

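    # Hedged sketch (editor addition): conj() is a lazy view that only sets a
    # bit; resolve_conj() materializes it so numpy interop works again.
    def _sketch_resolve_conj(self):
        c = torch.tensor([1 + 2j]).conj()
        assert c.is_conj()
        materialized = c.resolve_conj()
        assert not materialized.is_conj()
        _ = materialized.numpy()  # fine after materialization
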
    @onlyCPU
    def test_crow_col_indices(self, device):
        crow_indices = (0, 1, 2)
        col_indices = (1, 0)
        values = (1, 2)
        t = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2))
        # This is the test: if crow_indices were not a view op, it would
        # trigger an internal assert (use count greater than 1) in a debug
        # build.
        t.crow_indices()
        t.col_indices()


instantiate_device_type_tests(TestViewOps, globals(), include_lazy=True)
instantiate_device_type_tests(TestOldViewOps, globals())

if __name__ == "__main__":
    run_tests()