# Owner(s): ["module: tests"]

import random
import unittest
import warnings
from functools import partial
from itertools import chain, combinations, permutations, product

import numpy as np

import torch
from torch import nan
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
)
from torch.testing._internal.common_dtype import (
    all_types,
    all_types_and,
    all_types_and_complex_and,
)
from torch.testing._internal.common_utils import (
    IS_JETSON,
    run_tests,
    skipIfTorchDynamo,
    TEST_PRIVATEUSE1_DEVICE_TYPE,
    TestCase,
    torch_to_numpy_dtype_dict,
)


# TODO: replace with make_tensor
def _generate_input(shape, dtype, device, with_extremal):
    if shape == ():
        x = torch.tensor((), dtype=dtype, device=device)
    else:
        if dtype.is_floating_point or dtype.is_complex:
            # work around torch.randn not being implemented for bfloat16
            if dtype == torch.bfloat16:
                x = torch.randn(*shape, device=device) * random.randint(30, 100)
                x = x.to(torch.bfloat16)
            else:
                x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(
                    30, 100
                )
            x[torch.randn(*shape) > 0.5] = 0
            if with_extremal and dtype.is_floating_point:
                # Use extremal values
                x[torch.randn(*shape) > 0.5] = float("nan")
                x[torch.randn(*shape) > 0.5] = float("inf")
                x[torch.randn(*shape) > 0.5] = float("-inf")
            elif with_extremal and dtype.is_complex:
                x[torch.randn(*shape) > 0.5] = complex("nan")
                x[torch.randn(*shape) > 0.5] = complex("inf")
                x[torch.randn(*shape) > 0.5] = complex("-inf")
        elif dtype == torch.bool:
            x = torch.zeros(shape, dtype=dtype, device=device)
            x[torch.randn(*shape) > 0.5] = True
        else:
            x = torch.randint(15, 100, shape, dtype=dtype, device=device)

    return x

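# Example (illustrative sketch, not executed at import): a call such as
#   _generate_input((2, 3), torch.float32, "cpu", with_extremal=True)
# yields a 2x3 tensor mixing scaled normal values and zeros, plus nan/inf/-inf
# entries (extremal values apply only to floating-point and complex dtypes).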

class TestShapeOps(TestCase):
    # TODO: update to work on CUDA, too
    @onlyCPU
    def test_unbind(self, device):
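        # For reference: torch.unbind removes `dim` and returns a tuple of
        # views, e.g. torch.unbind(torch.arange(4).view(2, 2), 0) gives
        # (tensor([0, 1]), tensor([2, 3])).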
        x = torch.rand(2, 3, 4, 5)
        for dim in range(4):
            res = torch.unbind(x, dim)
            res2 = x.unbind(dim)
            self.assertEqual(x.size(dim), len(res))
            self.assertEqual(x.size(dim), len(res2))
            for i in range(x.size(dim)):
                self.assertEqual(x.select(dim, i), res[i])
                self.assertEqual(x.select(dim, i), res2[i])

    # TODO: update to work on CUDA, too?
    @skipIfTorchDynamo("TorchDynamo fails with an unknown error")
    @onlyCPU
    def test_tolist(self, device):
        list0D = []
        tensor0D = torch.tensor(list0D)
        self.assertEqual(tensor0D.tolist(), list0D)

        table1D = [1.0, 2.0, 3.0]
        tensor1D = torch.tensor(table1D)
        storage = torch.Storage(table1D)
        self.assertEqual(tensor1D.tolist(), table1D)
        self.assertEqual(storage.tolist(), table1D)

        table2D = [[1, 2], [3, 4]]
        tensor2D = torch.tensor(table2D)
        self.assertEqual(tensor2D.tolist(), table2D)

        tensor3D = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        tensorNonContig = tensor3D.select(1, 1)
        self.assertFalse(tensorNonContig.is_contiguous())
        self.assertEqual(tensorNonContig.tolist(), [[3, 4], [7, 8]])

    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_movedim_invalid(self, device, dtype):
        shape = self._rand_shape(4, min_size=5, max_size=10)
        x = _generate_input(shape, dtype, device, False)

        for fn in [torch.movedim, torch.moveaxis]:
            # Invalid `source` and `destination` dimension
            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
                fn(x, 5, 0)

            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
                fn(x, 0, 5)

            # Mismatch in size of `source` and `destination`
            with self.assertRaisesRegex(
                RuntimeError, "movedim: Invalid source or destination dims:"
            ):
                fn(x, (1, 0), (0,))

            with self.assertRaisesRegex(
                RuntimeError, "movedim: repeated dim in `source`"
            ):
                fn(x, (0, 0), (0, 1))

            with self.assertRaisesRegex(
                RuntimeError, "movedim: repeated dim in `source`"
            ):
                fn(x, (0, 1, 0), (0, 1, 2))

            with self.assertRaisesRegex(
                RuntimeError, "movedim: repeated dim in `destination`"
            ):
                fn(x, (0, 1), (1, 1))

            with self.assertRaisesRegex(
                RuntimeError, "movedim: repeated dim in `destination`"
            ):
                fn(x, (0, 1, 2), (1, 0, 1))

    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_movedim(self, device, dtype):
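        # For reference: movedim/moveaxis relocate dimensions. For a tensor of
        # shape (2, 3, 5), movedim(t, 0, -1) yields shape (3, 5, 2), and the
        # sequence form movedim(t, (0, 1), (-1, -2)) yields shape (5, 3, 2).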
        for fn in [torch.moveaxis, torch.movedim]:
            for nd in range(5):
                shape = self._rand_shape(nd, min_size=5, max_size=10)
                x = _generate_input(shape, dtype, device, with_extremal=False)
                for random_negative in [True, False]:
                    for src_dim, dst_dim in permutations(range(nd), r=2):
                        random_prob = random.random()

                        if random_negative and random_prob > 0.66:
                            src_dim = src_dim - nd
                        elif random_negative and random_prob > 0.33:
                            dst_dim = dst_dim - nd
                        elif random_negative:
                            src_dim = src_dim - nd
                            dst_dim = dst_dim - nd

                        # Integer `source` and `destination`
                        torch_fn = partial(fn, source=src_dim, destination=dst_dim)
                        np_fn = partial(
                            np.moveaxis, source=src_dim, destination=dst_dim
                        )
                        self.compare_with_numpy(
                            torch_fn, np_fn, x, device=None, dtype=None
                        )

                    if nd == 0:
                        continue

                    def make_index_negative(sequence, idx):
                        sequence = list(sequence)
                        sequence[idx] = sequence[idx] - nd
                        return tuple(sequence)

                    for src_sequence in permutations(
                        range(nd), r=random.randint(1, nd)
                    ):
                        # Sequence `source` and `destination`
                        dst_sequence = tuple(
                            random.sample(range(nd), len(src_sequence))
                        )

                        # Randomly change a dim to a negative dim representation of itself.
                        random_prob = random.random()
                        if random_negative and random_prob > 0.66:
                            random_idx = random.randint(0, len(src_sequence) - 1)
                            src_sequence = make_index_negative(src_sequence, random_idx)
                        elif random_negative and random_prob > 0.33:
                            random_idx = random.randint(0, len(src_sequence) - 1)
                            dst_sequence = make_index_negative(dst_sequence, random_idx)
                        elif random_negative:
                            random_idx = random.randint(0, len(src_sequence) - 1)
                            dst_sequence = make_index_negative(dst_sequence, random_idx)
                            random_idx = random.randint(0, len(src_sequence) - 1)
                            src_sequence = make_index_negative(src_sequence, random_idx)

                        torch_fn = partial(
                            fn, source=src_sequence, destination=dst_sequence
                        )
                        np_fn = partial(
                            np.moveaxis, source=src_sequence, destination=dst_sequence
                        )
                        self.compare_with_numpy(
                            torch_fn, np_fn, x, device=None, dtype=None
                        )

            # Move dim to same position
            x = torch.randn(2, 3, 5, 7, 11)
            torch_fn = partial(fn, source=(0, 1), destination=(0, 1))
            np_fn = partial(np.moveaxis, source=(0, 1), destination=(0, 1))
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

            torch_fn = partial(fn, source=1, destination=1)
            np_fn = partial(np.moveaxis, source=1, destination=1)
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

            # Empty Sequence
            torch_fn = partial(fn, source=(), destination=())
            np_fn = partial(np.moveaxis, source=(), destination=())
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

    @dtypes(torch.float, torch.bool)
    def test_diag(self, device, dtype):
        if dtype is torch.bool:
            x = torch.rand(100, 100, device=device) >= 0.5
        else:
            x = torch.rand(100, 100, dtype=dtype, device=device)

        res1 = torch.diag(x)
        res2 = torch.tensor((), dtype=dtype, device=device)
        torch.diag(x, out=res2)
        self.assertEqual(res1, res2)

    def test_diagonal(self, device):
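        # For reference: for 2-D inputs, torch.diagonal(x, offset) contains the
        # same values as torch.diag(x, offset), but diagonal returns a view
        # while diag returns a new tensor.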
        x = torch.randn((100, 100), device=device)
        result = torch.diagonal(x)
        expected = torch.diag(x)
        self.assertEqual(result, expected)

        x = torch.randn((100, 100), device=device)
        result = torch.diagonal(x, 17)
        expected = torch.diag(x, 17)
        self.assertEqual(result, expected)

    @onlyCPU
    @dtypes(torch.float)
    def test_diagonal_multidim(self, device, dtype):
        x = torch.randn(10, 11, 12, 13, dtype=dtype, device=device)
        xn = x.numpy()
        for args in [(2, 2, 3), (2,), (-2, 1, 2), (0, -2, -1)]:
            result = torch.diagonal(x, *args)
            expected = xn.diagonal(*args)
            self.assertEqual(expected.shape, result.shape)
            self.assertEqual(expected, result)
        # test non-contiguous
        xp = x.permute(1, 2, 3, 0)
        result = torch.diagonal(xp, 0, -2, -1)
        expected = xp.numpy().diagonal(0, -2, -1)
        self.assertEqual(expected.shape, result.shape)
        self.assertEqual(expected, result)

    @onlyNativeDeviceTypes
    @dtypes(*all_types())
    @dtypesIfCUDA(*all_types_and(torch.half))
    def test_trace(self, device, dtype):
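        # For reference: trace sums the main diagonal. Like tensor.sum(),
        # integral inputs accumulate in int64, which is why the expected NumPy
        # dtype below is derived from tensor.sum().dtype.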
        def test(shape):
            tensor = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
            expected_dtype = tensor.sum().dtype
            expected_dtype = torch_to_numpy_dtype_dict[expected_dtype]

            result = np.trace(tensor.cpu().numpy(), dtype=expected_dtype)
            expected = torch.tensor(result, device=device)
            self.assertEqual(tensor.trace(), expected)

        shapes = (
            [10, 1],
            [1, 10],
            [100, 100],
            [20, 100],
            [100, 20],
        )
        for shape in shapes:
            test(shape)

    def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nans):
        """
        Creates a random tensor for the given device and dtype, and computes the
        expected clamped values given min_vals and/or max_vals.
        If with_nans is true, some values are randomly set to nan.
        """
        X = torch.rand(100, device=device).mul(50).add(-25)  # uniform in [-25, 25]
        X = X.to(dtype)
        if with_nans:
            mask = torch.randint(0, 2, X.shape, dtype=torch.bool, device=device)
            X[mask] = nan

        if isinstance(min_vals, torch.Tensor):
            min_vals = min_vals.cpu().numpy()

        if isinstance(max_vals, torch.Tensor):
            max_vals = max_vals.cpu().numpy()

        # Use NumPy implementation as reference
        X_clamped = torch.tensor(
            np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device
        )
        return X, X_clamped

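    # Example (illustrative sketch): X, expected = self.generate_clamp_baseline(
    #     "cpu", torch.float32, min_vals=-5, max_vals=5, with_nans=False)
    # after which torch.clamp(X, -5, 5) should equal `expected` elementwise.
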
    # Tests clamp and its alias, clip
    @dtypes(torch.int64, torch.float32)
    def test_clamp(self, device, dtype):
        op_list = (
            torch.clamp,
            torch.Tensor.clamp,
            torch.Tensor.clamp_,
            torch.clip,
            torch.Tensor.clip,
            torch.Tensor.clip_,
        )

        # min/max argument product
        args = list(product((-10, None), (10, None)))  # list: reused for every op

        for op in op_list:
            for min_val, max_val in args:
                if min_val is None and max_val is None:
                    continue

                X, Y_expected = self.generate_clamp_baseline(
                    device, dtype, min_vals=min_val, max_vals=max_val, with_nans=False
                )

                # Test op
                X1 = X.clone()  # So that the in-place ops do not change X
                Y_actual = op(X1, min_val, max_val)
                self.assertEqual(Y_expected, Y_actual)

                # Test op-out behavior (out does not exist for method versions)
                if op in (torch.clamp, torch.clip):
                    Y_out = torch.empty_like(X)
                    op(X, min=min_val, max=max_val, out=Y_out)
                    self.assertEqual(Y_expected, Y_out)

    def test_clamp_propagates_nans(self, device):
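        # For reference: clamp propagates NaNs (a NaN input stays NaN regardless
        # of the min/max bounds), so the assertions below compare isnan masks.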
        op_list = (
            torch.clamp,
            torch.Tensor.clamp,
            torch.Tensor.clamp_,
            torch.clip,
            torch.Tensor.clip,
            torch.Tensor.clip_,
        )

        # min/max argument product
        args = list(product((-10, None), (10, None)))  # list: reused for every op

        for op in op_list:
            for min_val, max_val in args:
                if min_val is None and max_val is None:
                    continue

                X, Y_expected = self.generate_clamp_baseline(
                    device,
                    torch.float,
                    min_vals=min_val,
                    max_vals=max_val,
                    with_nans=True,
                )
                Y_expected = torch.isnan(Y_expected)

                # Test op
                X1 = X.clone()  # So that the in-place ops do not change X
                Y_actual = op(X1, min_val, max_val)
                self.assertEqual(Y_expected, torch.isnan(Y_actual))

                # Test op-out behavior (out does not exist for method versions)
                if op in (torch.clamp, torch.clip):
                    Y_out = torch.empty_like(X)
                    op(X, min_val, max_val, out=Y_out)
                    self.assertEqual(Y_expected, torch.isnan(Y_out))

    def test_clamp_raises_arg_errors(self, device):
        X = torch.randn(100, dtype=torch.float, device=device)
        error_msg = "At least one of 'min' or 'max' must not be None"
        with self.assertRaisesRegex(RuntimeError, error_msg):
            X.clamp()
        with self.assertRaisesRegex(RuntimeError, error_msg):
            X.clamp_()
        with self.assertRaisesRegex(RuntimeError, error_msg):
            torch.clamp(X)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_flip(self, device, dtype):
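        # For reference: torch.flip reverses the order of elements along the
        # given dims and, unlike np.flip (which returns a view), always copies
        # the input's data.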
        make_from_data = partial(torch.tensor, device=device, dtype=dtype)
        make_from_size = partial(make_tensor, device=device, dtype=dtype)

        def test_flip_impl(input_t, dims, output_t):
            def all_t():
                yield input_t, output_t
                if dtype is torch.float:
                    # We generate quantized versions as well
                    for qdtype in (torch.quint8, torch.qint8, torch.qint32):
                        qinput_t = torch.quantize_per_tensor(input_t, 0.1, 5, qdtype)
                        qoutput_t = torch.quantize_per_tensor(output_t, 0.1, 5, qdtype)
                        yield qinput_t, qoutput_t

            for in_t, out_t in all_t():
                self.assertEqual(in_t.flip(dims), out_t)
                n = in_t.ndim
                if not isinstance(dims, tuple):
                    # Wrap dim
                    self.assertEqual(in_t.flip(-n + dims), out_t)
                else:
                    # Permute dimensions
                    for p_dims in permutations(dims):
                        self.assertEqual(in_t.flip(p_dims), out_t)
                        if len(p_dims) > 0:
                            # Wrap 1st dim
                            self.assertEqual(
                                in_t.flip((-n + p_dims[0],) + p_dims[1:]), out_t
                            )

        def gen_data():
            # Basic tests
            data = make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2)
            nonctg = make_from_size((2, 2, 2), noncontiguous=True).copy_(data)

            dims_result = (
                (0, make_from_data([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2)),
                (1, make_from_data([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2)),
                (2, make_from_data([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2)),
                ((0, 1), make_from_data([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2)),
                ((0, 1, 2), make_from_data([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2)),
            )
            for in_tensor, (dims, out_tensor) in product((data, nonctg), dims_result):
                yield in_tensor, dims, out_tensor

            # Expanded
            in_t = make_from_data([1, 2, 3]).view(3, 1).expand(3, 2)
            dims = 0
            out_t = make_from_data([3, 3, 2, 2, 1, 1]).view(3, 2)
            yield in_t, dims, out_t
            # Noop on expanded dimension
            yield in_t, 1, in_t

            # Transposed
            in_t = (
                make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2).transpose(0, 1)
            )
            dims = (0, 1, 2)
            out_t = make_from_data([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2)
            yield in_t, dims, out_t

            # Rectangular case
            in_t = make_from_data([1, 2, 3, 4, 5, 6]).view(2, 3)
            dims = 0
            out_t = make_from_data([[4, 5, 6], [1, 2, 3]])
            yield in_t, dims, out_t
            dims = 1
            out_t = make_from_data([[3, 2, 1], [6, 5, 4]])
            yield in_t, dims, out_t

            # vectorized NCHW cases (images)
            if device == "cpu" and dtype != torch.bfloat16:
                for mf in [torch.contiguous_format, torch.channels_last]:
                    for c in [2, 3, 8, 16]:
                        in_t = make_from_size((2, c, 32, 32)).contiguous(
                            memory_format=mf
                        )
                        np_in_t = in_t.numpy()

                        np_out_t = np_in_t[:, :, :, ::-1].copy()
                        out_t = torch.from_numpy(np_out_t)
                        yield in_t, 3, out_t

                        np_out_t = np_in_t[:, :, ::-1, :].copy()
                        out_t = torch.from_numpy(np_out_t)
                        yield in_t, 2, out_t

                        # non-contig cases
                        in_tt = in_t[..., ::2, :]
                        np_in_t = in_tt.numpy()
                        np_out_t = np_in_t[:, :, :, ::-1].copy()
                        out_t = torch.from_numpy(np_out_t)
                        yield in_tt, 3, out_t

                        in_tt = in_t[..., ::2]
                        np_in_t = in_tt.numpy()
                        np_out_t = np_in_t[:, :, :, ::-1].copy()
                        out_t = torch.from_numpy(np_out_t)
                        yield in_tt, 3, out_t

            # Noops (edge cases)

            # Size 0
            in_t = make_from_data(())
            yield in_t, 0, in_t
            yield in_t, (), in_t

            # dims = ()
            in_t = make_from_size((3, 2, 1))
            yield in_t, (), in_t

            # Zero elements, non-zero size
            in_t = make_from_size((3, 0, 2))
            for i in range(in_t.ndim):
                yield in_t, i, in_t

            # Size 1
            in_t = make_from_size(())
            yield in_t, 0, in_t
            in_t = make_from_size((1,))
            yield in_t, 0, in_t

        for in_tensor, dims, out_tensor in gen_data():
            test_flip_impl(in_tensor, dims, out_tensor)

        # test for shape
        size = [2, 3, 4]
        data = make_from_size(size)
        possible_dims = range(len(size))
        test_dims = chain(
            combinations(possible_dims, 1), combinations(possible_dims, 2)
        )

        for dims in test_dims:
            self.assertEqual(size, list(data.flip(dims).size()))

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_flip_errors(self, device, dtype):
        make_arg = partial(make_tensor, dtype=dtype, device=device)
        data = make_arg((2, 2, 2))

        # flipping the same dim more than once is not allowed
        self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1))
        # calling flip with no dims is not allowed
        self.assertRaises(TypeError, lambda: data.flip())

        # dims beyond the tensor's number of dimensions are not allowed
        self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3))
        self.assertRaises(IndexError, lambda: data.flip(3))

    def _rand_shape(self, dim, min_size, max_size):
        return tuple(torch.randint(min_size, max_size + 1, (dim,)))

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_flip_numpy(self, device, dtype):
        make_arg = partial(make_tensor, dtype=dtype, device=device)

        for ndim in [3, 4]:
            shape = self._rand_shape(ndim, 5, 10)
            data = make_arg(shape)

            # Number of axes to sample for the given shape.
            for i in range(1, ndim + 1):
                # Check all combinations of `i` axes.
                for flip_dim in combinations(range(ndim), i):
                    torch_fn = partial(torch.flip, dims=flip_dim)
                    np_fn = partial(np.flip, axis=flip_dim)
                    self.compare_with_numpy(torch_fn, np_fn, data)

    @onlyCUDA  # CPU is too slow
    @largeTensorTest("17GB")  # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB
    @largeTensorTest(
        "81GB", "cpu"
    )  # even for CUDA test, sufficient system memory is required
    @unittest.skipIf(IS_JETSON, "Too large for Jetson")
    def test_flip_large_tensor(self, device):
        t_in = torch.empty(2**32 + 1, dtype=torch.uint8).random_()
        torch_fn = partial(torch.flip, dims=(0,))
        np_fn = partial(np.flip, axis=0)
        self.compare_with_numpy(torch_fn, np_fn, t_in)
        del t_in

    def _test_fliplr_flipud(self, torch_fn, np_fn, min_dim, max_dim, device, dtype):
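        # For reference: torch.fliplr(x) is equivalent to x.flip(1) and requires
        # at least a 2-D input; torch.flipud(x) is equivalent to x.flip(0) and
        # requires at least a 1-D input, matching np.fliplr/np.flipud.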
        for dim in range(min_dim, max_dim + 1):
            shape = self._rand_shape(dim, 5, 10)
            # Generate random input data
            if dtype.is_floating_point or dtype.is_complex:
                data = torch.randn(*shape, device=device, dtype=dtype)
            else:
                data = torch.randint(0, 10, shape, device=device, dtype=dtype)
            self.compare_with_numpy(torch_fn, np_fn, data)

    @dtypes(torch.int64, torch.double, torch.cdouble)
    def test_fliplr(self, device, dtype):
        self._test_fliplr_flipud(torch.fliplr, np.fliplr, 2, 4, device, dtype)

    @dtypes(torch.int64, torch.double, torch.cdouble)
    def test_fliplr_invalid(self, device, dtype):
        x = torch.randn(42).to(dtype)
        with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."):
            torch.fliplr(x)
        with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."):
            torch.fliplr(torch.tensor(42, device=device, dtype=dtype))

    @dtypes(torch.int64, torch.double, torch.cdouble)
    def test_flipud(self, device, dtype):
        self._test_fliplr_flipud(torch.flipud, np.flipud, 1, 4, device, dtype)

    @dtypes(torch.int64, torch.double, torch.cdouble)
    def test_flipud_invalid(self, device, dtype):
        with self.assertRaisesRegex(RuntimeError, "Input must be >= 1-d."):
            torch.flipud(torch.tensor(42, device=device, dtype=dtype))

    def test_rot90(self, device):
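        # For reference: rot90 rotates by k * 90 degrees in the plane spanned by
        # dims, from the first dim towards the second for k > 0; e.g. k=1,
        # dims=[0, 1] maps [[1, 2], [3, 4]] to [[2, 4], [1, 3]], matching np.rot90.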
        data = torch.arange(1, 5, device=device).view(2, 2)
        self.assertEqual(torch.tensor([1, 2, 3, 4]).view(2, 2), data.rot90(0, [0, 1]))
        self.assertEqual(torch.tensor([2, 4, 1, 3]).view(2, 2), data.rot90(1, [0, 1]))
        self.assertEqual(torch.tensor([4, 3, 2, 1]).view(2, 2), data.rot90(2, [0, 1]))
        self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1]))

        # test for default args k=1, dims=[0, 1]
        self.assertEqual(data.rot90(), data.rot90(1, [0, 1]))

        # test for reversed order of dims
        self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0]))

        # test for modulo of k
        self.assertEqual(data.rot90(5, [0, 1]), data.rot90(1, [0, 1]))
        self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1]))
        self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1]))

        # test for dims out-of-range error
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, -3]))
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 2]))

        # test a tensor with more than 2 dimensions
638*da0073e9SAndroid Build Coastguard Worker        data = torch.arange(1, 9, device=device).view(2, 2, 2)
639*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
640*da0073e9SAndroid Build Coastguard Worker            torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])
641*da0073e9SAndroid Build Coastguard Worker        )
642*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2]))
643*da0073e9SAndroid Build Coastguard Worker
644*da0073e9SAndroid Build Coastguard Worker        # test for errors
645*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3]))
646*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1]))
647*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2]))
648*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0]))
649*da0073e9SAndroid Build Coastguard Worker
650*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("TorchDynamo fails with an unknown error")
651*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.cfloat, torch.cdouble)
652*da0073e9SAndroid Build Coastguard Worker    def test_complex_rot90(self, device, dtype):
        shape = self._rand_shape(random.randint(2, 4), 5, 10)
        for rot_times in range(4):
            data = torch.randn(*shape, device=device, dtype=dtype)
            torch_fn = partial(torch.rot90, k=rot_times, dims=[0, 1])
            np_fn = partial(np.rot90, k=rot_times, axes=[0, 1])
            self.compare_with_numpy(torch_fn, np_fn, data)

    # TODO: update once warning flag is available to always trigger ONCE warnings
    # Ensures nonzero does not emit a warning, even when the as_tuple argument
    #   is not provided
    def test_nonzero_no_warning(self, device):
        t = torch.randn((2, 2), device=device)
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            torch.nonzero(t)
            t.nonzero()
            self.assertEqual(len(w), 0)

    @dtypes(*all_types_and(torch.half, torch.bool, torch.bfloat16))
    def test_nonzero(self, device, dtype):
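        # Checks torch.nonzero against NumPy across several shapes, for both
        # the matrix form (as_tuple=False) and the tuple form (as_tuple=True),
        # along with out= dtype and device error handling.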
        shapes = [
            torch.Size((12,)),
            torch.Size((12, 1)),
            torch.Size((1, 12)),
            torch.Size((6, 2)),
            torch.Size((3, 2, 2)),
            torch.Size((5, 5, 5)),
        ]

        def gen_nontrivial_input(shape, dtype, device):
            if dtype != torch.bfloat16:
                return torch.randint(2, shape, device=device, dtype=dtype)
            else:
                # torch.randint does not support bfloat16 on Windows, so
                # generate floats and cast
                return torch.randint(2, shape, device=device, dtype=torch.float).to(
                    dtype
                )

        for shape in shapes:
            tensor = gen_nontrivial_input(shape, dtype, device)
            dst1 = torch.nonzero(tensor, as_tuple=False)
            dst2 = tensor.nonzero(as_tuple=False)
            dst3 = torch.empty([], dtype=torch.long, device=device)
            torch.nonzero(tensor, out=dst3)
            if self.device_type != "xla":
                # XLA does not raise a RuntimeError for a mismatched out dtype
                self.assertRaisesRegex(
                    RuntimeError,
                    "scalar type Long",
                    lambda: torch.nonzero(
                        tensor, out=torch.empty([], dtype=torch.float, device=device)
                    ),
                )
            if (
                self.device_type == "cuda"
                or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE
            ):
                self.assertRaisesRegex(
                    RuntimeError,
                    "on the same device",
                    lambda: torch.nonzero(
                        tensor, out=torch.empty([], dtype=torch.long)
                    ),
                )
            np_array = (
                tensor.cpu().numpy()
                if dtype != torch.bfloat16
                else tensor.float().cpu().numpy()
            )
            np_result = torch.from_numpy(np.stack(np_array.nonzero())).t()
            self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0)
            self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0)
            self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0)
            tup1 = torch.nonzero(tensor, as_tuple=True)
            tup2 = tensor.nonzero(as_tuple=True)
            tup1 = torch.stack(tup1).t().cpu()
            tup2 = torch.stack(tup2).t().cpu()
            self.assertEqual(tup1, np_result, atol=0, rtol=0)
            self.assertEqual(tup2, np_result, atol=0, rtol=0)
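            # Editor's sketch (assumed equivalence, not in the original test):
            # the tuple form is just the matrix form transposed and split per
            # dimension.
            self.assertEqual(
                torch.nonzero(tensor, as_tuple=True),
                tuple(torch.nonzero(tensor, as_tuple=False).t().unbind()),
            )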

    def test_nonzero_astuple_out(self, device):
        t = torch.randn((3, 3, 3), device=device)
        out = torch.empty_like(t, dtype=torch.long)

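        # out= is only supported with as_tuple=False; combining it with
        # as_tuple=True must raise.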
        with self.assertRaises(RuntimeError):
            torch.nonzero(t, as_tuple=True, out=out)

        self.assertEqual(
            torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out)
        )

        # Verifies that JIT scripting cannot handle the as_tuple kwarg
        # See Issue https://github.com/pytorch/pytorch/issues/45499.
        def _foo(t):
            tuple_result = torch.nonzero(t, as_tuple=True)
            nontuple_result = torch.nonzero(t, as_tuple=False)
            out = torch.empty_like(nontuple_result)
            torch.nonzero(t, as_tuple=False, out=out)
            return tuple_result, nontuple_result, out

        with self.assertRaises(RuntimeError):
            torch.jit.script(_foo)

        # Verifies that JIT tracing works fine
        traced_foo = torch.jit.trace(_foo, t)
        traced_tuple, traced_nontuple, traced_out = traced_foo(t)
        expected_tuple = torch.nonzero(t, as_tuple=True)
        expected_nontuple = torch.nonzero(t)

        self.assertEqual(traced_tuple, expected_tuple)
        self.assertEqual(traced_nontuple, expected_nontuple)
        self.assertEqual(traced_out, expected_nontuple)

    @onlyNativeDeviceTypes
    def test_nonzero_discontiguous(self, device):
        shape = (4, 4)
        tensor = torch.randint(2, shape, device=device)
        tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(
            tensor
        )
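        # nonzero reads values, not layout: the contiguous tensor and its
        # strided copy must produce identical results.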
        dst1 = tensor.nonzero(as_tuple=False)
        dst2 = tensor_nc.nonzero(as_tuple=False)
        self.assertEqual(dst1, dst2, atol=0, rtol=0)
        dst3 = torch.empty_like(dst1)
        data_ptr = dst3.data_ptr()
        # expect dst3 storage to be reused
        torch.nonzero(tensor, out=dst3)
        self.assertEqual(data_ptr, dst3.data_ptr())
        self.assertEqual(dst1, dst3, atol=0, rtol=0)
        # discontiguous out: storage, values, and strides should all be
        # preserved
        dst4 = torch.empty(
            dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device
        )[:, ::2]
        data_ptr = dst4.data_ptr()
        strides = dst4.stride()
        torch.nonzero(tensor, out=dst4)
        self.assertEqual(data_ptr, dst4.data_ptr())
        self.assertEqual(dst1, dst4, atol=0, rtol=0)
        self.assertEqual(strides, dst4.stride())

    def test_nonzero_non_diff(self, device):
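        # nonzero returns integer indices, so its output is marked
        # non-differentiable and must not inherit requires_grad.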
        x = torch.randn(10, device=device, requires_grad=True)
        nz = x.nonzero()
        self.assertFalse(nz.requires_grad)

    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_sparse_dense_dim(self, device, dtype):
        for shape in [(), (2,), (2, 3)]:
            if dtype.is_complex or dtype.is_floating_point:
                x = torch.rand(shape, device=device, dtype=dtype)
            else:
                x = torch.randint(-9, 9, shape, device=device, dtype=dtype)
            self.assertEqual(x.sparse_dim(), 0)
            self.assertEqual(x.dense_dim(), len(shape))
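            # Editor's sketch (an added check, not in the original test): for
            # a sparse COO tensor the dims split between sparse_dim() and
            # dense_dim(), so to_sparse() moves every dim to the sparse side.
            if len(shape) > 0:
                self.assertEqual(x.to_sparse().sparse_dim(), len(shape))
                self.assertEqual(x.to_sparse().dense_dim(), 0)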


instantiate_device_type_tests(TestShapeOps, globals())

if __name__ == "__main__":
    run_tests()