# Owner(s): ["module: tests"]

import random
import unittest
import warnings
from functools import partial
from itertools import chain, combinations, permutations, product

import numpy as np

import torch
from torch import nan
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
)
from torch.testing._internal.common_dtype import (
    all_types,
    all_types_and,
    all_types_and_complex_and,
)
from torch.testing._internal.common_utils import (
    IS_JETSON,
    run_tests,
    skipIfTorchDynamo,
    TEST_PRIVATEUSE1_DEVICE_TYPE,
    TestCase,
    torch_to_numpy_dtype_dict,
)


# TODO: replace with make_tensor
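# Returns a random tensor of the requested shape/dtype on `device`. Floating and
# complex inputs get a random subset of entries zeroed, and `with_extremal`
# additionally injects nan/inf values so callers can exercise non-finite inputs.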
def _generate_input(shape, dtype, device, with_extremal):
    if shape == ():
        x = torch.tensor((), dtype=dtype, device=device)
    else:
        if dtype.is_floating_point or dtype.is_complex:
            # work around torch.randn not being implemented for bfloat16
            if dtype == torch.bfloat16:
                x = torch.randn(*shape, device=device) * random.randint(30, 100)
                x = x.to(torch.bfloat16)
            else:
                x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(
                    30, 100
                )
            x[torch.randn(*shape) > 0.5] = 0
            if with_extremal and dtype.is_floating_point:
                # Use extremal values
                x[torch.randn(*shape) > 0.5] = float("nan")
                x[torch.randn(*shape) > 0.5] = float("inf")
                x[torch.randn(*shape) > 0.5] = float("-inf")
            elif with_extremal and dtype.is_complex:
                x[torch.randn(*shape) > 0.5] = complex("nan")
                x[torch.randn(*shape) > 0.5] = complex("inf")
                x[torch.randn(*shape) > 0.5] = complex("-inf")
        elif dtype == torch.bool:
            x = torch.zeros(shape, dtype=dtype, device=device)
            x[torch.randn(*shape) > 0.5] = True
        else:
            x = torch.randint(15, 100, shape, dtype=dtype, device=device)

    return x


class TestShapeOps(TestCase):
    # TODO: update to work on CUDA, too
    @onlyCPU
    def test_unbind(self, device):
        x = torch.rand(2, 3, 4, 5)
        for dim in range(4):
            res = torch.unbind(x, dim)
            res2 = x.unbind(dim)
            self.assertEqual(x.size(dim), len(res))
            self.assertEqual(x.size(dim), len(res2))
            for i in range(x.size(dim)):
                self.assertEqual(x.select(dim, i), res[i])
                self.assertEqual(x.select(dim, i), res2[i])

    # TODO: update to work on CUDA, too?
    @skipIfTorchDynamo("TorchDynamo fails with an unknown error")
    @onlyCPU
    def test_tolist(self, device):
        list0D = []
        tensor0D = torch.tensor(list0D)
        self.assertEqual(tensor0D.tolist(), list0D)

        table1D = [1.0, 2.0, 3.0]
        tensor1D = torch.tensor(table1D)
        storage = torch.Storage(table1D)
        self.assertEqual(tensor1D.tolist(), table1D)
        self.assertEqual(storage.tolist(), table1D)
        self.assertEqual(tensor1D.tolist(), table1D)
        self.assertEqual(storage.tolist(), table1D)

        table2D = [[1, 2], [3, 4]]
        tensor2D = torch.tensor(table2D)
        self.assertEqual(tensor2D.tolist(), table2D)

        tensor3D = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        tensorNonContig = tensor3D.select(1, 1)
        self.assertFalse(tensorNonContig.is_contiguous())
        self.assertEqual(tensorNonContig.tolist(), [[3, 4], [7, 8]])

    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_movedim_invalid(self, device, dtype):
        shape = self._rand_shape(4, min_size=5, max_size=10)
        x = _generate_input(shape, dtype, device, False)

        for fn in [torch.movedim, torch.moveaxis]:
            # Invalid `source` and `destination` dimension
            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
                fn(x, 5, 0)

            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
                fn(x, 0, 5)

            # Mismatch in size of `source` and `destination`
            with self.assertRaisesRegex(
                RuntimeError, "movedim: Invalid source or destination dims:"
            ):
                fn(x, (1, 0), (0,))

            with self.assertRaisesRegex(
                RuntimeError, "movedim: repeated dim in `source`"
            ):
                fn(x, (0, 0), (0, 1))

            with self.assertRaisesRegex(
                RuntimeError, "movedim: repeated dim in `source`"
            ):
                fn(x, (0, 1, 0), (0, 1, 2))

            with self.assertRaisesRegex(
                RuntimeError, "movedim: repeated dim in `destination`"
            ):
                fn(x, (0, 1), (1, 1))

            with self.assertRaisesRegex(
                RuntimeError, "movedim: repeated dim in `destination`"
            ):
                fn(x, (0, 1, 2), (1, 0, 1))

    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_movedim(self, device, dtype):
        for fn in [torch.moveaxis, torch.movedim]:
            for nd in range(5):
                shape = self._rand_shape(nd, min_size=5, max_size=10)
                x = _generate_input(shape, dtype, device, with_extremal=False)
                for random_negative in [True, False]:
                    for src_dim, dst_dim in permutations(range(nd), r=2):
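                        # Randomly rewrite src and/or dst as negative dims to
                        # also exercise dimension wrap-around.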
                        random_prob = random.random()

                        if random_negative and random_prob > 0.66:
                            src_dim = src_dim - nd
                        elif random_negative and random_prob > 0.33:
                            dst_dim = dst_dim - nd
                        elif random_negative:
                            src_dim = src_dim - nd
                            dst_dim = dst_dim - nd

                        # Integer `source` and `destination`
                        torch_fn = partial(fn, source=src_dim, destination=dst_dim)
                        np_fn = partial(
                            np.moveaxis, source=src_dim, destination=dst_dim
                        )
                        self.compare_with_numpy(
                            torch_fn, np_fn, x, device=None, dtype=None
                        )

                    if nd == 0:
                        continue

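                    # Helper: rewrite the idx-th entry of a dim sequence as its
                    # negative (wrap-around) equivalent.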
                    def make_index_negative(sequence, idx):
                        sequence = list(sequence)
                        sequence[idx] = sequence[idx] - nd
                        return tuple(sequence)

                    for src_sequence in permutations(
                        range(nd), r=random.randint(1, nd)
                    ):
                        # Sequence `source` and `destination`
                        dst_sequence = tuple(
                            random.sample(range(nd), len(src_sequence))
                        )

                        # Randomly change a dim to a negative dim representation of itself.
                        random_prob = random.random()
                        if random_negative and random_prob > 0.66:
                            random_idx = random.randint(0, len(src_sequence) - 1)
                            src_sequence = make_index_negative(src_sequence, random_idx)
                        elif random_negative and random_prob > 0.33:
                            random_idx = random.randint(0, len(src_sequence) - 1)
                            dst_sequence = make_index_negative(dst_sequence, random_idx)
                        elif random_negative:
                            random_idx = random.randint(0, len(src_sequence) - 1)
                            dst_sequence = make_index_negative(dst_sequence, random_idx)
                            random_idx = random.randint(0, len(src_sequence) - 1)
                            src_sequence = make_index_negative(src_sequence, random_idx)

                        torch_fn = partial(
                            fn, source=src_sequence, destination=dst_sequence
                        )
                        np_fn = partial(
                            np.moveaxis, source=src_sequence, destination=dst_sequence
                        )
                        self.compare_with_numpy(
                            torch_fn, np_fn, x, device=None, dtype=None
                        )

            # Move dim to same position
            x = torch.randn(2, 3, 5, 7, 11)
            torch_fn = partial(fn, source=(0, 1), destination=(0, 1))
            np_fn = partial(np.moveaxis, source=(0, 1), destination=(0, 1))
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

            torch_fn = partial(fn, source=1, destination=1)
            np_fn = partial(np.moveaxis, source=1, destination=1)
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

            # Empty Sequence
            torch_fn = partial(fn, source=(), destination=())
            np_fn = partial(np.moveaxis, source=(), destination=())
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

    @dtypes(torch.float, torch.bool)
    def test_diag(self, device, dtype):
        if dtype is torch.bool:
            x = torch.rand(100, 100, device=device) >= 0.5
        else:
            x = torch.rand(100, 100, dtype=dtype, device=device)

        res1 = torch.diag(x)
        res2 = torch.tensor((), dtype=dtype, device=device)
        torch.diag(x, out=res2)
        self.assertEqual(res1, res2)

    def test_diagonal(self, device):
        x = torch.randn((100, 100), device=device)
        result = torch.diagonal(x)
        expected = torch.diag(x)
        self.assertEqual(result, expected)

        x = torch.randn((100, 100), device=device)
        result = torch.diagonal(x, 17)
        expected = torch.diag(x, 17)
        self.assertEqual(result, expected)

    @onlyCPU
    @dtypes(torch.float)
    def test_diagonal_multidim(self, device, dtype):
        x = torch.randn(10, 11, 12, 13, dtype=dtype, device=device)
        xn = x.numpy()
        for args in [(2, 2, 3), (2,), (-2, 1, 2), (0, -2, -1)]:
            result = torch.diagonal(x, *args)
            expected = xn.diagonal(*args)
            self.assertEqual(expected.shape, result.shape)
            self.assertEqual(expected, result)
        # test non-contiguous
        xp = x.permute(1, 2, 3, 0)
        result = torch.diagonal(xp, 0, -2, -1)
        expected = xp.numpy().diagonal(0, -2, -1)
        self.assertEqual(expected.shape, result.shape)
        self.assertEqual(expected, result)

    @onlyNativeDeviceTypes
    @dtypes(*all_types())
    @dtypesIfCUDA(*all_types_and(torch.half))
    def test_trace(self, device, dtype):
        def test(shape):
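            # Compute the NumPy reference in the dtype that tensor.sum() would
            # produce, since trace is expected to promote the same way.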
            tensor = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
            expected_dtype = tensor.sum().dtype
            expected_dtype = torch_to_numpy_dtype_dict[expected_dtype]

            result = np.trace(tensor.cpu().numpy(), dtype=expected_dtype)
            expected = torch.tensor(result, device=device)
            self.assertEqual(tensor.trace(), expected)

        shapes = (
            [10, 1],
            [1, 10],
            [100, 100],
            [20, 100],
            [100, 20],
        )
        for shape in shapes:
            test(shape)

    def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nans):
        """
        Creates a random tensor for the given device and dtype, and computes the
        expected clamped values given min_vals and/or max_vals.
        If with_nans is True, some values are randomly set to nan.
        """
        X = torch.rand(100, device=device).mul(50).add(-25)  # uniform in [-25, 25]
        X = X.to(dtype)
        if with_nans:
            mask = torch.randint(0, 2, X.shape, dtype=torch.bool, device=device)
            X[mask] = nan

        if isinstance(min_vals, torch.Tensor):
            min_vals = min_vals.cpu().numpy()

        if isinstance(max_vals, torch.Tensor):
            max_vals = max_vals.cpu().numpy()

        # Use NumPy implementation as reference
        X_clamped = torch.tensor(
            np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device
        )
        return X, X_clamped

    # Tests clamp and its alias, clip
    @dtypes(torch.int64, torch.float32)
    def test_clamp(self, device, dtype):
        op_list = (
            torch.clamp,
            torch.Tensor.clamp,
            torch.Tensor.clamp_,
            torch.clip,
            torch.Tensor.clip,
            torch.Tensor.clip_,
        )

        # min/max argument product
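        # (None, None) is skipped below since clamp requires at least one bound.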
        args = product((-10, None), (10, None))

        for op in op_list:
            for min_val, max_val in args:
                if min_val is None and max_val is None:
                    continue

                X, Y_expected = self.generate_clamp_baseline(
                    device, dtype, min_vals=min_val, max_vals=max_val, with_nans=False
                )

                # Test op
                X1 = X.clone()  # So that the in-place ops do not change X
                Y_actual = op(X1, min_val, max_val)
                self.assertEqual(Y_expected, Y_actual)

                # Test op-out behavior (out does not exist for method versions)
                if op in (torch.clamp, torch.clip):
                    Y_out = torch.empty_like(X)
                    op(X, min=min_val, max=max_val, out=Y_out)
                    self.assertEqual(Y_expected, Y_out)

    def test_clamp_propagates_nans(self, device):
        op_list = (
            torch.clamp,
            torch.Tensor.clamp,
            torch.Tensor.clamp_,
            torch.clip,
            torch.Tensor.clip,
            torch.Tensor.clip_,
        )

        # min/max argument product
        args = product((-10, None), (10, None))

        for op in op_list:
            for min_val, max_val in args:
                if min_val is None and max_val is None:
                    continue

                X, Y_expected = self.generate_clamp_baseline(
                    device,
                    torch.float,
                    min_vals=min_val,
                    max_vals=max_val,
                    with_nans=True,
                )
                Y_expected = torch.isnan(Y_expected)

                # Test op
                X1 = X.clone()  # So that the in-place ops do not change X
                Y_actual = op(X1, min_val, max_val)
                self.assertEqual(Y_expected, torch.isnan(Y_actual))

                # Test op-out behavior (out does not exist for method versions)
                if op in (torch.clamp, torch.clip):
                    Y_out = torch.empty_like(X)
                    op(X, min_val, max_val, out=Y_out)
                    self.assertEqual(Y_expected, torch.isnan(Y_out))

    def test_clamp_raises_arg_errors(self, device):
        X = torch.randn(100, dtype=torch.float, device=device)
        error_msg = "At least one of 'min' or 'max' must not be None"
        with self.assertRaisesRegex(RuntimeError, error_msg):
            X.clamp()
        with self.assertRaisesRegex(RuntimeError, error_msg):
            X.clamp_()
        with self.assertRaisesRegex(RuntimeError, error_msg):
            torch.clamp(X)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_flip(self, device, dtype):
        make_from_data = partial(torch.tensor, device=device, dtype=dtype)
        make_from_size = partial(make_tensor, device=device, dtype=dtype)

        def test_flip_impl(input_t, dims, output_t):
            def all_t():
                yield input_t, output_t
                if dtype is torch.float:
                    # We generate quantized versions as well
                    for qdtype in (torch.quint8, torch.qint8, torch.qint32):
                        qinput_t = torch.quantize_per_tensor(input_t, 0.1, 5, qdtype)
                        qoutput_t = torch.quantize_per_tensor(output_t, 0.1, 5, qdtype)
                        yield qinput_t, qoutput_t

            for in_t, out_t in all_t():
                self.assertEqual(in_t.flip(dims), out_t)
                n = in_t.ndim
                if not isinstance(dims, tuple):
                    # Wrap dim
                    self.assertEqual(in_t.flip(-n + dims), out_t)
                else:
                    # Permute dimensions
                    for p_dims in permutations(dims):
                        self.assertEqual(in_t.flip(p_dims), out_t)
                        if len(p_dims) > 0:
                            # Wrap 1st dim
                            self.assertEqual(
                                in_t.flip((-n + p_dims[0],) + p_dims[1:]), out_t
                            )

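        # Yields (input, dims, expected) triples covering contiguous,
        # non-contiguous, expanded, transposed, channels_last, and degenerate
        # inputs.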
        def gen_data():
            # Basic tests
            data = make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2)
            nonctg = make_from_size((2, 2, 2), noncontiguous=True).copy_(data)

            dims_result = (
                (0, make_from_data([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2)),
                (1, make_from_data([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2)),
                (2, make_from_data([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2)),
                ((0, 1), make_from_data([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2)),
                ((0, 1, 2), make_from_data([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2)),
            )
            for in_tensor, (dims, out_tensor) in product((data, nonctg), dims_result):
                yield in_tensor, dims, out_tensor

            # Expanded
            in_t = make_from_data([1, 2, 3]).view(3, 1).expand(3, 2)
            dims = 0
            out_t = make_from_data([3, 3, 2, 2, 1, 1]).view(3, 2)
            yield in_t, dims, out_t
            # Noop on expanded dimension
            yield in_t, 1, in_t

            # Transposed
            in_t = (
                make_from_data([1, 2, 3, 4, 5, 6, 7, 8]).view(2, 2, 2).transpose(0, 1)
            )
            dims = (0, 1, 2)
            out_t = make_from_data([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2)
            yield in_t, dims, out_t

            # Rectangular case
            in_t = make_from_data([1, 2, 3, 4, 5, 6]).view(2, 3)
            dims = 0
            out_t = make_from_data([[4, 5, 6], [1, 2, 3]])
            yield in_t, dims, out_t
            dims = 1
            out_t = make_from_data([[3, 2, 1], [6, 5, 4]])
            yield in_t, dims, out_t

            # vectorized NCHW cases (images)
            if device == "cpu" and dtype != torch.bfloat16:
                for mf in [torch.contiguous_format, torch.channels_last]:
                    for c in [2, 3, 8, 16]:
                        in_t = make_from_size((2, c, 32, 32)).contiguous(
                            memory_format=mf
                        )
                        np_in_t = in_t.numpy()

                        np_out_t = np_in_t[:, :, :, ::-1].copy()
                        out_t = torch.from_numpy(np_out_t)
                        yield in_t, 3, out_t

                        np_out_t = np_in_t[:, :, ::-1, :].copy()
                        out_t = torch.from_numpy(np_out_t)
                        yield in_t, 2, out_t

                        # non-contig cases
                        in_tt = in_t[..., ::2, :]
                        np_in_t = in_tt.numpy()
                        np_out_t = np_in_t[:, :, :, ::-1].copy()
                        out_t = torch.from_numpy(np_out_t)
                        yield in_tt, 3, out_t

                        in_tt = in_t[..., ::2]
                        np_in_t = in_tt.numpy()
                        np_out_t = np_in_t[:, :, :, ::-1].copy()
                        out_t = torch.from_numpy(np_out_t)
                        yield in_tt, 3, out_t

            # Noops (edge cases)

            # Size 0
            in_t = make_from_data(())
            yield in_t, 0, in_t
            yield in_t, (), in_t

            # dims = ()
            in_t = make_from_size((3, 2, 1))
            yield in_t, (), in_t

            # Zero elements, non-zero size
            in_t = make_from_size((3, 0, 2))
            for i in range(in_t.ndim):
                yield in_t, i, in_t

            # Size 1
            in_t = make_from_size(())
            yield in_t, 0, in_t
            in_t = make_from_size((1,))
            yield in_t, 0, in_t

        for in_tensor, dims, out_tensor in gen_data():
            test_flip_impl(in_tensor, dims, out_tensor)

        # test for shape
        size = [2, 3, 4]
        data = make_from_size(size)
        possible_dims = range(len(size))
        test_dims = chain(
            combinations(possible_dims, 1), combinations(possible_dims, 2)
        )

        for dims in test_dims:
            self.assertEqual(size, list(data.flip(dims).size()))

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_flip_errors(self, device, dtype):
        make_arg = partial(make_tensor, dtype=dtype, device=device)
        data = make_arg((2, 2, 2))

        # flipping the same dim more than once is not allowed
        self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1))
        # calling flip with no dims is not allowed
        self.assertRaises(TypeError, lambda: data.flip())

        # dims beyond the tensor's rank are not allowed
        self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3))
        self.assertRaises(IndexError, lambda: data.flip(3))

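    # Returns a random shape tuple with `dim` dimensions, each sized in
    # [min_size, max_size].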
    def _rand_shape(self, dim, min_size, max_size):
        return tuple(torch.randint(min_size, max_size + 1, (dim,)))

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_flip_numpy(self, device, dtype):
        make_arg = partial(make_tensor, dtype=dtype, device=device)

        for ndim in [3, 4]:
            shape = self._rand_shape(ndim, 5, 10)
            data = make_arg(shape)

            # Number of axes to flip for the given shape.
            for i in range(1, ndim + 1):
                # Check all combinations of `i` axes.
                for flip_dim in combinations(range(ndim), i):
                    torch_fn = partial(torch.flip, dims=flip_dim)
                    np_fn = partial(np.flip, axis=flip_dim)
                    self.compare_with_numpy(torch_fn, np_fn, data)

    @onlyCUDA  # CPU is too slow
    @largeTensorTest("17GB")  # 4 tensors of 4GB (in, out) x (torch, numpy) + 1GB
    @largeTensorTest(
        "81GB", "cpu"
    )  # even for CUDA test, sufficient system memory is required
    @unittest.skipIf(IS_JETSON, "Too large for Jetson")
    def test_flip_large_tensor(self, device):
        t_in = torch.empty(2**32 + 1, dtype=torch.uint8).random_()
        torch_fn = partial(torch.flip, dims=(0,))
        np_fn = partial(np.flip, axis=0)
        self.compare_with_numpy(torch_fn, np_fn, t_in)
        del t_in

    def _test_fliplr_flipud(self, torch_fn, np_fn, min_dim, max_dim, device, dtype):
        for dim in range(min_dim, max_dim + 1):
            shape = self._rand_shape(dim, 5, 10)
            # Generate random input appropriate for the dtype
            if dtype.is_floating_point or dtype.is_complex:
                data = torch.randn(*shape, device=device, dtype=dtype)
            else:
                data = torch.randint(0, 10, shape, device=device, dtype=dtype)
            self.compare_with_numpy(torch_fn, np_fn, data)

    @dtypes(torch.int64, torch.double, torch.cdouble)
    def test_fliplr(self, device, dtype):
        self._test_fliplr_flipud(torch.fliplr, np.fliplr, 2, 4, device, dtype)

    @dtypes(torch.int64, torch.double, torch.cdouble)
    def test_fliplr_invalid(self, device, dtype):
        x = torch.randn(42).to(dtype)
        with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."):
            torch.fliplr(x)
        with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."):
            torch.fliplr(torch.tensor(42, device=device, dtype=dtype))

    @dtypes(torch.int64, torch.double, torch.cdouble)
    def test_flipud(self, device, dtype):
        self._test_fliplr_flipud(torch.flipud, np.flipud, 1, 4, device, dtype)

    @dtypes(torch.int64, torch.double, torch.cdouble)
    def test_flipud_invalid(self, device, dtype):
        with self.assertRaisesRegex(RuntimeError, "Input must be >= 1-d."):
            torch.flipud(torch.tensor(42, device=device, dtype=dtype))

    def test_rot90(self, device):
        data = torch.arange(1, 5, device=device).view(2, 2)
        self.assertEqual(torch.tensor([1, 2, 3, 4]).view(2, 2), data.rot90(0, [0, 1]))
        self.assertEqual(torch.tensor([2, 4, 1, 3]).view(2, 2), data.rot90(1, [0, 1]))
        self.assertEqual(torch.tensor([4, 3, 2, 1]).view(2, 2), data.rot90(2, [0, 1]))
        self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1]))

        # test for default args k=1, dims=[0, 1]
        self.assertEqual(data.rot90(), data.rot90(1, [0, 1]))

        # test for reversed order of dims
        self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0]))

        # test for modulo of k
        self.assertEqual(data.rot90(5, [0, 1]), data.rot90(1, [0, 1]))
        self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1]))
        self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1]))

        # test for dims out-of-range error
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, -3]))
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 2]))

        # test tensor with more than 2D
        data = torch.arange(1, 9, device=device).view(2, 2, 2)
        self.assertEqual(
            torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])
        )
        self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2]))

        # test for errors
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3]))
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1]))
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2]))
        self.assertRaises(RuntimeError, lambda: data.rot90(1, [0]))

    @skipIfTorchDynamo("TorchDynamo fails with an unknown error")
    @dtypes(torch.cfloat, torch.cdouble)
    def test_complex_rot90(self, device, dtype):
        shape = self._rand_shape(random.randint(2, 4), 5, 10)
        for rot_times in range(4):
            data = torch.randn(*shape, device=device, dtype=dtype)
            torch_fn = partial(torch.rot90, k=rot_times, dims=[0, 1])
            np_fn = partial(np.rot90, k=rot_times, axes=[0, 1])
            self.compare_with_numpy(torch_fn, np_fn, data)

    # TODO: update once warning flag is available to always trigger ONCE warnings
    # Ensures nonzero does not throw a warning, even when the as_tuple argument
    #   is not provided
    def test_nonzero_no_warning(self, device):
        t = torch.randn((2, 2), device=device)
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            torch.nonzero(t)
            t.nonzero()
            self.assertEqual(len(w), 0)

    @dtypes(*all_types_and(torch.half, torch.bool, torch.bfloat16))
    def test_nonzero(self, device, dtype):
        shapes = [
            torch.Size((12,)),
            torch.Size((12, 1)),
            torch.Size((1, 12)),
            torch.Size((6, 2)),
            torch.Size((3, 2, 2)),
            torch.Size((5, 5, 5)),
        ]

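        # Random 0/1 values so roughly half the entries are nonzero.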
        def gen_nontrivial_input(shape, dtype, device):
            if dtype != torch.bfloat16:
                return torch.randint(2, shape, device=device, dtype=dtype)
            else:
                # randint does not work for bfloat16 on Windows; generate as float and cast
                return torch.randint(2, shape, device=device, dtype=torch.float).to(
                    dtype
                )

        for shape in shapes:
            tensor = gen_nontrivial_input(shape, dtype, device)
            dst1 = torch.nonzero(tensor, as_tuple=False)
            dst2 = tensor.nonzero(as_tuple=False)
            dst3 = torch.empty([], dtype=torch.long, device=device)
            torch.nonzero(tensor, out=dst3)
            if self.device_type != "xla":
                # xla does not raise runtime error
                self.assertRaisesRegex(
                    RuntimeError,
                    "scalar type Long",
                    lambda: torch.nonzero(
                        tensor, out=torch.empty([], dtype=torch.float, device=device)
                    ),
                )
            if (
                self.device_type == "cuda"
                or self.device_type == TEST_PRIVATEUSE1_DEVICE_TYPE
            ):
                self.assertRaisesRegex(
                    RuntimeError,
                    "on the same device",
                    lambda: torch.nonzero(
                        tensor, out=torch.empty([], dtype=torch.long)
                    ),
                )
            np_array = (
                tensor.cpu().numpy()
                if dtype != torch.bfloat16
                else tensor.float().cpu().numpy()
            )
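            # np.nonzero returns a tuple of per-dimension index arrays; stack
            # and transpose so it matches torch.nonzero's (N, ndim) layout.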
            np_result = torch.from_numpy(np.stack(np_array.nonzero())).t()
            self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0)
            self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0)
            self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0)
            tup1 = torch.nonzero(tensor, as_tuple=True)
            tup2 = tensor.nonzero(as_tuple=True)
            tup1 = torch.stack(tup1).t().cpu()
            tup2 = torch.stack(tup2).t().cpu()
            self.assertEqual(tup1, np_result, atol=0, rtol=0)
            self.assertEqual(tup2, np_result, atol=0, rtol=0)

    def test_nonzero_astuple_out(self, device):
        t = torch.randn((3, 3, 3), device=device)
        out = torch.empty_like(t, dtype=torch.long)

        with self.assertRaises(RuntimeError):
            torch.nonzero(t, as_tuple=True, out=out)

        self.assertEqual(
            torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out)
        )

        # Verifies that JIT script cannot handle the as_tuple kwarg
        # See Issue https://github.com/pytorch/pytorch/issues/45499.
        def _foo(t):
            tuple_result = torch.nonzero(t, as_tuple=True)
            nontuple_result = torch.nonzero(t, as_tuple=False)
            out = torch.empty_like(nontuple_result)
            torch.nonzero(t, as_tuple=False, out=out)
            return tuple_result, nontuple_result, out

        with self.assertRaises(RuntimeError):
            scripted_foo = torch.jit.script(_foo)

        # Verifies that JIT tracing works fine
        traced_foo = torch.jit.trace(_foo, t)
        traced_tuple, traced_nontuple, traced_out = traced_foo(t)
        expected_tuple = torch.nonzero(t, as_tuple=True)
        expected_nontuple = torch.nonzero(t)

        self.assertEqual(traced_tuple, expected_tuple)
        self.assertEqual(traced_nontuple, expected_nontuple)
        self.assertEqual(traced_out, expected_nontuple)

    @onlyNativeDeviceTypes
    def test_nonzero_discontiguous(self, device):
        shape = (4, 4)
        tensor = torch.randint(2, shape, device=device)
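        # Build a non-contiguous copy by writing into every other column of a
        # wider buffer.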
        tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(
            tensor
        )
        dst1 = tensor.nonzero(as_tuple=False)
        dst2 = tensor_nc.nonzero(as_tuple=False)
        self.assertEqual(dst1, dst2, atol=0, rtol=0)
        dst3 = torch.empty_like(dst1)
        data_ptr = dst3.data_ptr()
        # expect dst3 storage to be reused
        torch.nonzero(tensor, out=dst3)
        self.assertEqual(data_ptr, dst3.data_ptr())
        self.assertEqual(dst1, dst3, atol=0, rtol=0)
        # discontiguous out
        dst4 = torch.empty(
            dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device
        )[:, ::2]
        data_ptr = dst4.data_ptr()
        strides = dst4.stride()
        torch.nonzero(tensor, out=dst4)
        self.assertEqual(data_ptr, dst4.data_ptr())
        self.assertEqual(dst1, dst4, atol=0, rtol=0)
        self.assertEqual(strides, dst4.stride())

    def test_nonzero_non_diff(self, device):
        x = torch.randn(10, requires_grad=True)
        nz = x.nonzero()
        self.assertFalse(nz.requires_grad)

    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_sparse_dense_dim(self, device, dtype):
        for shape in [(), (2,), (2, 3)]:
            if dtype.is_complex or dtype.is_floating_point:
                x = torch.rand(shape, device=device, dtype=dtype)
            else:
                x = torch.randint(-9, 9, shape, device=device, dtype=dtype)
            self.assertEqual(x.sparse_dim(), 0)
            self.assertEqual(x.dense_dim(), len(shape))


instantiate_device_type_tests(TestShapeOps, globals())

if __name__ == "__main__":
    run_tests()