# Owner(s): ["module: tests"]

import random
from itertools import permutations, product

import numpy as np

import torch
from torch import nan
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCPU,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    largeTensorTest,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
)
from torch.testing._internal.common_dtype import (
    all_types,
    all_types_and,
    floating_types_and,
    integral_types,
)
from torch.testing._internal.common_utils import (
    run_tests,
    skipIfTorchDynamo,
    slowTest,
    TestCase,
)


class TestSortAndSelect(TestCase):
    def assertIsOrdered(self, order, x, mxx, ixx, task):
        SIZE = x.size(1)
        if order == "descending":

            def check_order(a, b):
                # `a != a` because we put NaNs
                # at the end of ascending sorted lists,
                # and the beginning of descending ones.
                return ((a != a) | (a >= b)).all().item()

        elif order == "ascending":

            def check_order(a, b):
                # see above
                return ((b != b) | (a <= b)).all().item()

        else:
            raise ValueError(
                f'unknown order "{order}", must be "ascending" or "descending"'
            )

        for k in range(1, SIZE):
            self.assertTrue(
                check_order(mxx[:, k - 1], mxx[:, k]),
                f"torch.sort ({order}) values unordered for {task}",
            )

        seen = set()
        size0 = x.size(0)
        size = x.size(x.dim() - 1)
        x = x.tolist()
        mxx = mxx.tolist()
        ixx = ixx.tolist()
        for k in range(size0):
            seen.clear()
            for j in range(size):
                self.assertEqual(
                    x[k][ixx[k][j]],
                    mxx[k][j],
                    msg=f"torch.sort ({order}) indices wrong for {task}",
                )
                seen.add(ixx[k][j])
            self.assertEqual(len(seen), size)
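        # A minimal illustration (not executed here) of the NaN convention this
        # helper checks: torch.sort treats NaN as greater than any number, so
        # NaNs land at the end of an ascending sort and at the front of a
        # descending one.
        #   >>> torch.sort(torch.tensor([1.0, float("nan"), 0.0])).values
        #   tensor([0., 1., nan])
        #   >>> torch.sort(torch.tensor([1.0, float("nan"), 0.0]), descending=True).values
        #   tensor([nan, 1., 0.])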

    def test_sort(self, device):
        # On CUDA, sizes <= 2048 and > 2048 take different code paths for the
        # dim being sorted, so test both.
        for SIZE in (4, 2049):
            x = torch.rand(4, SIZE, device=device)
            res1val, res1ind = torch.sort(x)

            # Test inplace
            y = x.clone()
            y_inds = torch.tensor((), dtype=torch.int64, device=device)
            torch.sort(y, out=(y, y_inds))
            x_vals, x_inds = torch.sort(x)
            self.assertEqual(x_vals, y)
            self.assertEqual(x_inds, y_inds)

            # Test use of result tensor
            res2val = torch.tensor((), device=device)
            res2ind = torch.tensor((), device=device, dtype=torch.long)
            torch.sort(x, out=(res2val, res2ind))
            self.assertEqual(res1val, res2val, atol=0, rtol=0)
            self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
            self.assertEqual(torch.argsort(x), res1ind)
            self.assertEqual(x.argsort(), res1ind)

            # Test sorting of random numbers
            self.assertIsOrdered("ascending", x, res2val, res2ind, "random")

            # Test simple sort
            self.assertEqual(
                torch.sort(torch.tensor((50, 40, 30, 20, 10), device=device))[0],
                torch.tensor((10, 20, 30, 40, 50), device=device),
                atol=0,
                rtol=0,
            )

            # Test that we still have proper sorting with duplicate keys
            x = torch.floor(torch.rand(4, SIZE, device=device) * 10)
            torch.sort(x, out=(res2val, res2ind))
            self.assertIsOrdered(
                "ascending", x, res2val, res2ind, "random with duplicate keys"
            )

            # DESCENDING SORT
            x = torch.rand(4, SIZE, device=device)
            res1val, res1ind = torch.sort(x, x.dim() - 1, True)

            # Test use of result tensor
            res2val = torch.tensor((), device=device)
            res2ind = torch.tensor((), device=device, dtype=torch.long)
            torch.sort(x, x.dim() - 1, True, out=(res2val, res2ind))
            self.assertEqual(res1val, res2val, atol=0, rtol=0)
            self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
            self.assertEqual(torch.argsort(x, x.dim() - 1, True), res1ind)
            self.assertEqual(x.argsort(x.dim() - 1, True), res1ind)

            # Test sorting of random numbers
            self.assertIsOrdered("descending", x, res2val, res2ind, "random")

            # Test simple sort task
            self.assertEqual(
                torch.sort(torch.tensor((10, 20, 30, 40, 50), device=device), 0, True)[
                    0
                ],
                torch.tensor((50, 40, 30, 20, 10), device=device),
                atol=0,
                rtol=0,
            )

            # Test that we still have proper sorting with duplicate keys
            self.assertIsOrdered(
                "descending", x, res2val, res2ind, "random with duplicate keys"
            )

            # Test argument sorting with and without stable
            x = torch.tensor([1, 10, 2, 2, 3, 7, 7, 8, 9, 9] * 3)
            self.assertEqual(
                torch.argsort(x, stable=True), torch.sort(x, stable=True).indices
            )
            self.assertEqual(
                torch.argsort(x, stable=False), torch.sort(x, stable=False).indices
            )
            self.assertEqual(torch.argsort(x), torch.sort(x).indices)
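            # A small hedged illustration of the stability contract exercised
            # above: with stable=True, equal keys keep their input order, so
            # for [1, 10, 2, 2] the two 2s (positions 2 and 3) stay in order.
            #   >>> torch.sort(torch.tensor([1, 10, 2, 2]), stable=True).indices
            #   tensor([0, 2, 3, 1])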

            # Test sorting with NaNs
            x = torch.rand(4, SIZE, device=device)
            x[1][2] = float("NaN")
            x[3][0] = float("NaN")
            torch.sort(x, out=(res2val, res2ind))
            self.assertIsOrdered("ascending", x, res2val, res2ind, "random with NaNs")
            torch.sort(x, out=(res2val, res2ind), descending=True)
            self.assertIsOrdered("descending", x, res2val, res2ind, "random with NaNs")

    def test_sort_stable_none(self):
        # Calling sort with stable=None used to trigger an assertion.
        # See https://github.com/pytorch/pytorch/issues/117255
        x = torch.ones(10)
        y = x.sort(stable=None).values
        self.assertTrue(torch.all(y == torch.ones(10)).item())

    @onlyCUDA
    def test_sort_large_slice(self, device):
        # tests the direct cub path
        x = torch.randn(4, 1024000, device=device)
        res1val, res1ind = torch.sort(x, stable=True)
        torch.cuda.synchronize()
        # assertIsOrdered is too slow, so just compare against the CPU result
        res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), stable=True)
        self.assertEqual(res1val, res1val_cpu.cuda())
        self.assertEqual(res1ind, res1ind_cpu.cuda())
        res1val, res1ind = torch.sort(x, descending=True, stable=True)
        torch.cuda.synchronize()
        res1val_cpu, res1ind_cpu = torch.sort(x.cpu(), descending=True, stable=True)
        self.assertEqual(res1val, res1val_cpu.cuda())
        self.assertEqual(res1ind, res1ind_cpu.cuda())

    # FIXME: remove torch.bool from unsupported types once support is added for cub sort
    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_stable_sort(self, device, dtype):
        sizes = (100, 1000, 10000)
        for ncopies in sizes:
            x = torch.tensor([0, 1] * ncopies, dtype=dtype, device=device)
            _, idx = x.sort(stable=True)
            self.assertEqual(
                idx[:ncopies],
                torch.arange(start=0, end=2 * ncopies, step=2, device=device),
            )
            self.assertEqual(
                idx[ncopies:],
                torch.arange(start=1, end=2 * ncopies, step=2, device=device),
            )
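        # Why the expected indices above: the input interleaves 0s at even
        # positions and 1s at odd positions, so a stable sort must list all
        # even indices first, in order, then all odd indices. A tiny sketch:
        #   >>> torch.tensor([0, 1, 0, 1, 0, 1]).sort(stable=True).indices
        #   tensor([0, 2, 4, 1, 3, 5])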

    @onlyCUDA
    @dtypes(torch.uint8)
    @largeTensorTest("200GB")  # Unfortunately an 80GB A100 is not large enough
    def test_sort_large(self, device, dtype):
        t0 = torch.randperm(8192, device=device).to(dtype)
        t = t0.view(1, 8192).expand(2**18 + 1, -1).contiguous()
        v, i = t.sort()
        del t
        iv, im = i.var_mean(dim=0)
        del i
        vv, vm = v.var_mean(dim=0)
        del v
        self.assertEqual(vv, torch.zeros_like(vv))
        self.assertEqual(iv, torch.zeros_like(iv))
        self.assertEqual(vm, torch.arange(255, dtype=dtype, device=device))
        self.assertEqual(im, t0.sort().indices)

    @dtypes(torch.float32)
    def test_sort_restride(self, device, dtype):
        # Input: non-contiguous (stride: 5) 3-element array
        tensor = torch.randn((3, 5), dtype=dtype, device=device)[:, 0]
        # Outputs: 0-dim tensors
        # They will need to be resized, which means they will also be
        # restrided with the input tensor's strides as base.
        values = torch.tensor(0, dtype=dtype, device=device)
        indices = torch.tensor(0, dtype=torch.long, device=device)
        torch.sort(tensor, out=(values, indices))
        # Check: outputs were restrided to dense strides
        self.assertEqual(values.stride(), (1,))
        self.assertEqual(indices.stride(), (1,))
        # Check: 'tensor' indexed by 'indices' is equal to 'values'
        self.assertEqual(tensor[indices], values)
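        # A hedged sketch of the restriding behavior asserted above: out=
        # tensors are resized to the result shape and given dense strides,
        # regardless of the input's strides.
        #   >>> t = torch.randn(3, 5)[:, 0]          # stride (5,)
        #   >>> v = torch.tensor(0.0)
        #   >>> i = torch.tensor(0, dtype=torch.long)
        #   >>> _ = torch.sort(t, out=(v, i))
        #   >>> v.stride()
        #   (1,)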

    def _test_sort_discontiguous(self, device, dtype):
        # On CUDA, sizes <= 2048 and > 2048 take different code paths for the
        # dim being sorted.
        sizes = (5, 7, 2049)
        for shape in permutations(sizes):
            for perm in permutations((0, 1, 2)):
                for dim in range(3):
                    t = torch.randn(shape, device=device, dtype=dtype).permute(perm)
                    r1 = t.sort(dim=dim)
                    r2 = t.contiguous().sort(dim=dim)
                    self.assertEqual(r1, r2)
                    n = t.size(dim)

                    # assert ordered
                    self.assertTrue(
                        (
                            r1.values.narrow(dim, 1, n - 1)
                            >= r1.values.narrow(dim, 0, n - 1)
                        ).all()
                    )

                    # assert that different segments do not mix, which can easily
                    # happen if the stride is not handled correctly
                    self.assertTrue(
                        (t.unsqueeze(-1).transpose(dim, -1) == r1.values.unsqueeze(-1))
                        .any(dim=dim)
                        .any(dim=-1)
                        .all()
                    )

                    # assert stride is preserved
                    if self.device_type == "cuda":
                        # FIXME: this behavior should hold for all cases, not
                        # just the one in the if condition above
                        self.assertEqual(r1.values.stride(), t.stride())
                        self.assertEqual(r1.indices.stride(), t.stride())

    @onlyCUDA
    @dtypes(torch.float32)
    def test_sort_discontiguous(self, device, dtype):
        self._test_sort_discontiguous(device, dtype)

    @slowTest  # this test is slow on CPU, but not on CUDA
    @onlyCPU
    @dtypes(torch.float32)
    def test_sort_discontiguous_slow(self, device, dtype):
        self._test_sort_discontiguous(device, dtype)

    @dtypes(torch.float32)
    def test_sort_1d_output_discontiguous(self, device, dtype):
        tensor = torch.randn(12, device=device, dtype=dtype)[:6]
        values = torch.empty_like(tensor)[::2]
        indices = torch.empty(18, device=device, dtype=torch.long)[::3]
        torch.sort(tensor, out=(values, indices))
        values_cont, indices_cont = tensor.sort()
        self.assertEqual(indices, indices_cont)
        self.assertEqual(values, values_cont)

    @slowTest
    @onlyCPU
    @dtypes(*integral_types())
    def test_sort_1d_parallel(self, device, dtype):
        low = 0 if dtype == torch.uint8 else -128
        tensor = torch.randint(
            low=low, high=127, size=(100000,), device=device, dtype=dtype
        )
        vals, _ = torch.sort(tensor, stable=True)
        self.assertEqual(True, torch.all(vals[:-1] <= vals[1:]))

    @dtypes(torch.float32)
    def test_topk_1d_output_discontiguous(self, device, dtype):
        tensor = torch.randn(12, device=device, dtype=dtype)
        values = torch.empty_like(tensor)[::2]
        indices = torch.empty(18, device=device, dtype=torch.long)[::3]
        for sorted in (True, False):
            # outputs of the `sorted=False` case are not guaranteed to be the
            # same, but with the current implementation they are
            torch.topk(tensor, 6, sorted=sorted, out=(values, indices))
            values_cont, indices_cont = tensor.topk(6, sorted=sorted)
            self.assertEqual(indices, indices_cont)
            self.assertEqual(values, values_cont)

    # FIXME: remove torch.bool from unsupported types once support is added for cub sort
    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_stable_sort_against_numpy(self, device, dtype):
        if dtype in floating_types_and(torch.float16, torch.bfloat16):
            inf = float("inf")
            neg_inf = -float("inf")
            nan = float("nan")
        else:
            if dtype != torch.bool:
                # no torch.iinfo support for torch.bool
                inf = torch.iinfo(dtype).max
                neg_inf = torch.iinfo(dtype).min
            else:
                inf = True
                neg_inf = ~inf
            # no nan for integral types; use inf instead for simplicity
            nan = inf

        def generate_samples():
            from itertools import chain, combinations

            for sizes in [(1025,), (10000,)]:
                size = sizes[0]
                # binary strings
                yield (torch.tensor([0, 1] * size, dtype=dtype, device=device), 0)

            if self.device_type == "cuda":
                return

            yield (torch.tensor([0, 1] * 100, dtype=dtype, device=device), 0)

            def repeated_index_fill(t, dim, idxs, vals):
                res = t
                for idx, val in zip(idxs, vals):
                    res = res.index_fill(dim, idx, val)
                return res

            for sizes in [(1, 10), (10, 1), (10, 10), (10, 10, 10)]:
                size = min(*sizes)
                x = (torch.randn(*sizes, device=device) * size).to(dtype)
                yield (x, 0)

                # Generate tensors which are being filled at random locations
                # with values from the non-empty subsets of the set (inf, neg_inf, nan)
                # for each dimension.
                n_fill_vals = 3  # cardinality of (inf, neg_inf, nan)
                for dim in range(len(sizes)):
                    idxs = (
                        torch.randint(high=size, size=(size // 10,))
                        for i in range(n_fill_vals)
                    )
                    vals = (inf, neg_inf, nan)
                    subsets = chain.from_iterable(
                        combinations(list(zip(idxs, vals)), r)
                        for r in range(1, n_fill_vals + 1)
                    )
                    for subset in subsets:
                        idxs_subset, vals_subset = zip(*subset)
                        yield (
                            repeated_index_fill(x, dim, idxs_subset, vals_subset),
                            dim,
                        )

        for sample, dim in generate_samples():
            _, idx_torch = sample.sort(dim=dim, stable=True)
            if dtype is torch.bfloat16:
                sample_numpy = sample.float().cpu().numpy()
            else:
                sample_numpy = sample.cpu().numpy()
            idx_numpy = np.argsort(sample_numpy, axis=dim, kind="stable")
            self.assertEqual(idx_torch, idx_numpy)
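        # A minimal sketch of the NumPy reference used above: kind="stable"
        # argsort keeps ties in input order, matching torch's stable=True.
        #   >>> np.argsort(np.array([0, 1, 0, 1]), kind="stable")
        #   array([0, 2, 1, 3])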

    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_msort(self, device, dtype):
        def test(shape):
            tensor = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
            if tensor.size() != torch.Size([]):
                if dtype is torch.bfloat16:
                    expected = torch.from_numpy(
                        np.msort(tensor.float().cpu().numpy())
                    ).bfloat16()
                else:
                    expected = torch.from_numpy(np.msort(tensor.cpu().numpy()))
            else:
                expected = tensor  # numpy.msort() does not support 0-dim tensors

            result = torch.msort(tensor)
            self.assertEqual(result, expected)

            out = torch.empty_like(result)
            torch.msort(tensor, out=out)
            self.assertEqual(out, expected)

        shapes = (
            [],
            [0],
            [20],
            [1, 20],
            [30, 30],
            [10, 20, 30],
        )
        for shape in shapes:
            test(shape)
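        # msort sorts along the first dimension; a hedged equivalent is
        # torch.sort(t, dim=0).values. Tiny sketch:
        #   >>> torch.msort(torch.tensor([[3, 1], [2, 4]]))
        #   tensor([[2, 1],
        #           [3, 4]])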

    @skipIfTorchDynamo("Fails on python 3.11")
    @dtypes(torch.float)
    def test_sort_expanded_tensor(self, device, dtype):
        # https://github.com/pytorch/pytorch/issues/91420
        data = torch.scalar_tensor(True, device=device, dtype=dtype)
        data = data.expand([1, 1, 1])
        ref = torch.Tensor([[[True]]])
        out = torch.sort(data, stable=True, dim=1, descending=True)
        expected = torch.sort(ref, stable=True, dim=1, descending=True)
        self.assertEqual(out, expected)

        data = torch.randn(4, 1, 10, device=device, dtype=dtype)
        data = data.expand([4, 8, 10])
        ref = data.contiguous()
        out = torch.sort(data, stable=True, dim=1, descending=True)
        expected = torch.sort(ref, stable=True, dim=1, descending=True)
        self.assertEqual(out, expected)

    def test_topk(self, device):
        def topKViaSort(t, k, dim, dir):
            sorted, indices = t.sort(dim, dir)
            return sorted.narrow(dim, 0, k), indices.narrow(dim, 0, k)

        def compareTensors(t, res1, ind1, res2, ind2, dim):
            # Values should be exactly equivalent
            self.assertEqual(res1, res2, atol=0, rtol=0)

            # Indices might differ based on the implementation, since there is
            # no guarantee of the relative order of selection
            if not ind1.eq(ind2).all():
                # To verify that the indices represent equivalent elements,
                # gather from the input using the topk indices and compare against
                # the sort indices
                vals = t.gather(dim, ind2)
                self.assertEqual(res1, vals, atol=0, rtol=0)

        def compare(t, k, dim, dir):
            topKVal, topKInd = t.topk(k, dim, dir, True)
            sortKVal, sortKInd = topKViaSort(t, k, dim, dir)
            compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim)

        SIZE = 100
        t = torch.rand(
            random.randint(1, SIZE),
            random.randint(1, SIZE),
            random.randint(1, SIZE),
            device=device,
        )

        for _kTries in range(3):
            for _dimTries in range(3):
                for transpose in (True, False):
                    for dir in (True, False):
                        testTensor = t
                        if transpose:
                            dim1 = random.randrange(t.ndimension())
                            dim2 = dim1
                            while dim1 == dim2:
                                dim2 = random.randrange(t.ndimension())

                            testTensor = t.transpose(dim1, dim2)

                        dim = random.randrange(testTensor.ndimension())
                        k = random.randint(1, testTensor.size(dim))
                        compare(testTensor, k, dim, dir)

        # This tests the code path where, on CUDA, topk is implemented with sort.
        t = torch.randn((2, 100000), device=device)
        compare(t, 2000, 1, True)
        compare(t, 2000, 1, False)

        # This tests the code path where, on CUDA, topk is implemented with multiblock.
        t = torch.randn((2, 10000), device=device)
        compare(t, 2000, 1, True)
        compare(t, 2000, 1, False)
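        # The oracle above in a nutshell (hedged sketch): topk(k, largest=True)
        # must match a descending sort narrowed to its first k entries.
        #   >>> t = torch.tensor([1.0, 5.0, 3.0, 4.0])
        #   >>> t.topk(2).values
        #   tensor([5., 4.])
        #   >>> t.sort(descending=True).values[:2]
        #   tensor([5., 4.])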

    def test_topk_quantized_scalar_input(self):
        # Calling topk on a quantized scalar input used to segfault,
        # see https://github.com/pytorch/pytorch/issues/116324
        x = torch.quantize_per_tensor(torch.randn(()), 0.1, 10, torch.qint8)
        x.topk(1)

    def test_topk_arguments(self, device):
        q = torch.randn(10, 2, 10, device=device)
        # Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1)
        self.assertRaises(TypeError, lambda: q.topk(4, True))

    def test_unique_dim(self, device):
        self.assertFalse(hasattr(torch, "unique_dim"))

        def run_test(device, dtype):
            x = torch.tensor(
                [
                    [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]],
                    [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]],
                ],
                dtype=dtype,
                device=device,
            )
            x_empty = torch.empty(5, 0, dtype=dtype, device=device)
            x_ill_formed_empty = torch.empty(5, 0, 0, dtype=dtype, device=device)
            x_ill_formed_empty_another = torch.empty(
                5, 0, 5, dtype=dtype, device=device
            )
            if dtype in floating_types_and(torch.float16, torch.bfloat16):
                x_nan = torch.tensor(
                    [float("nan"), 0, 0, float("nan"), float("nan"), 1],
                    dtype=dtype,
                    device=device,
                )
            expected_unique_dim0 = torch.tensor(
                [[[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]]],
                dtype=dtype,
                device=device,
            )
            expected_inverse_dim0 = torch.tensor([0, 0])
            expected_counts_dim0 = torch.tensor([2])
            expected_unique_dim1 = torch.tensor(
                [
                    [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]],
                    [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]],
                ],
                dtype=dtype,
                device=device,
            )
            expected_unique_dim1_bool = torch.tensor(
                [[[False, True], [True, True]], [[False, True], [True, True]]],
                dtype=torch.bool,
                device=device,
            )
            expected_inverse_dim1 = torch.tensor([1, 0, 2, 0])
            expected_inverse_dim1_bool = torch.tensor([1, 0, 1, 0])
            expected_counts_dim1 = torch.tensor([2, 1, 1])
            expected_counts_dim1_bool = torch.tensor([2, 2])
            expected_unique_dim2 = torch.tensor(
                [
                    [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]],
                    [[1.0, 1.0], [0.0, 1.0], [2.0, 1.0], [0.0, 1.0]],
                ],
                dtype=dtype,
                device=device,
            )
            expected_inverse_dim2 = torch.tensor([0, 1])
            expected_counts_dim2 = torch.tensor([1, 1])
            expected_unique_empty = torch.empty(5, 0, dtype=dtype, device=device)
            expected_inverse_empty = torch.tensor([], dtype=torch.long, device=device)
            expected_counts_empty = torch.tensor([], dtype=torch.long, device=device)
            if dtype in floating_types_and(torch.float16, torch.bfloat16):
                expected_unique_nan = torch.tensor(
                    [float("nan"), 0, float("nan"), float("nan"), 1],
                    dtype=dtype,
                    device=device,
                )
                expected_inverse_nan = torch.tensor(
                    [0, 1, 1, 2, 3, 4], dtype=torch.long, device=device
                )
                expected_counts_nan = torch.tensor(
                    [1, 2, 1, 1, 1], dtype=torch.long, device=device
                )
            # dim0
            x_unique = torch.unique(x, dim=0)
            self.assertEqual(expected_unique_dim0, x_unique)

            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0)
            self.assertEqual(expected_unique_dim0, x_unique)
            self.assertEqual(expected_inverse_dim0, x_inverse)

            x_unique, x_counts = torch.unique(
                x, return_inverse=False, return_counts=True, dim=0
            )
            self.assertEqual(expected_unique_dim0, x_unique)
            self.assertEqual(expected_counts_dim0, x_counts)

            x_unique, x_inverse, x_counts = torch.unique(
                x, return_inverse=True, return_counts=True, dim=0
            )
            self.assertEqual(expected_unique_dim0, x_unique)
            self.assertEqual(expected_inverse_dim0, x_inverse)
            self.assertEqual(expected_counts_dim0, x_counts)

            # dim1
            x_unique = torch.unique(x, dim=1)
            if x.dtype == torch.bool:
                self.assertEqual(expected_unique_dim1_bool, x_unique)
            else:
                self.assertEqual(expected_unique_dim1, x_unique)

            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1)
            if x.dtype == torch.bool:
                self.assertEqual(expected_unique_dim1_bool, x_unique)
                self.assertEqual(expected_inverse_dim1_bool, x_inverse)
            else:
                self.assertEqual(expected_unique_dim1, x_unique)
                self.assertEqual(expected_inverse_dim1, x_inverse)

            x_unique, x_counts = torch.unique(
                x, return_inverse=False, return_counts=True, dim=1
            )
            if x.dtype == torch.bool:
                self.assertEqual(expected_unique_dim1_bool, x_unique)
                self.assertEqual(expected_counts_dim1_bool, x_counts)
            else:
                self.assertEqual(expected_unique_dim1, x_unique)
                self.assertEqual(expected_counts_dim1, x_counts)

            x_unique, x_inverse, x_counts = torch.unique(
                x, return_inverse=True, return_counts=True, dim=1
            )
            if x.dtype == torch.bool:
                self.assertEqual(expected_unique_dim1_bool, x_unique)
                self.assertEqual(expected_inverse_dim1_bool, x_inverse)
                self.assertEqual(expected_counts_dim1_bool, x_counts)
            else:
                self.assertEqual(expected_unique_dim1, x_unique)
                self.assertEqual(expected_inverse_dim1, x_inverse)
                self.assertEqual(expected_counts_dim1, x_counts)

            # dim2
            x_unique = torch.unique(x, dim=2)
            self.assertEqual(expected_unique_dim2, x_unique)

            x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2)
            self.assertEqual(expected_unique_dim2, x_unique)
            self.assertEqual(expected_inverse_dim2, x_inverse)

            x_unique, x_counts = torch.unique(
                x, return_inverse=False, return_counts=True, dim=2
            )
            self.assertEqual(expected_unique_dim2, x_unique)
            self.assertEqual(expected_counts_dim2, x_counts)

            x_unique, x_inverse, x_counts = torch.unique(
                x, return_inverse=True, return_counts=True, dim=2
            )
            self.assertEqual(expected_unique_dim2, x_unique)
            self.assertEqual(expected_inverse_dim2, x_inverse)
            self.assertEqual(expected_counts_dim2, x_counts)

            # test empty tensor
            x_unique, x_inverse, x_counts = torch.unique(
                x_empty, return_inverse=True, return_counts=True, dim=1
            )
            self.assertEqual(expected_unique_empty, x_unique)
            self.assertEqual(expected_inverse_empty, x_inverse)
            self.assertEqual(expected_counts_empty, x_counts)

            # test tensor with nan
            if dtype in floating_types_and(torch.float16, torch.bfloat16):
                x_unique, x_inverse, x_counts = torch.unique(
                    x_nan, return_inverse=True, return_counts=True, dim=0
                )
                self.assertEqual(expected_unique_nan, x_unique)
                self.assertEqual(expected_inverse_nan, x_inverse)
                self.assertEqual(expected_counts_nan, x_counts)

            # test a tensor that is not well formed
            # Checking for a runtime error, as this is the expected behaviour
            with self.assertRaises(RuntimeError):
                torch.unique(
                    x_ill_formed_empty, return_inverse=True, return_counts=True, dim=1
                )

            # test along dim2
            with self.assertRaises(RuntimeError):
                torch.unique(
                    x_ill_formed_empty_another,
                    return_inverse=True,
                    return_counts=True,
                    dim=2,
                )

            # test consecutive version
            y = torch.tensor(
                [
                    [0, 1],
                    [0, 1],
                    [0, 1],
                    [1, 2],
                    [1, 2],
                    [3, 4],
                    [0, 1],
                    [0, 1],
                    [3, 4],
                    [1, 2],
                ],
                dtype=dtype,
                device=device,
            )
            # test tensor with nan
            if dtype in floating_types_and(torch.float16, torch.bfloat16):
                y_nan = torch.tensor(
                    [float("nan"), 0, 0, float("nan"), float("nan"), 1],
                    dtype=dtype,
                    device=device,
                )
            expected_y_unique = torch.tensor(
                [[0, 1], [1, 2], [3, 4], [0, 1], [3, 4], [1, 2]],
                dtype=dtype,
                device=device,
            )
            expected_y_inverse = torch.tensor(
                [0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=torch.int64, device=device
            )
            expected_y_counts = torch.tensor(
                [3, 2, 1, 2, 1, 1], dtype=torch.int64, device=device
            )
            expected_y_inverse_bool = torch.tensor(
                [0, 0, 0, 1, 1, 1, 2, 2, 3, 3], dtype=torch.int64, device=device
            )
            expected_y_counts_bool = torch.tensor(
                [3, 3, 2, 2], dtype=torch.int64, device=device
            )
            if dtype in floating_types_and(torch.float16, torch.bfloat16):
                expected_y_unique_nan = torch.tensor(
                    [float("nan"), 0, float("nan"), float("nan"), 1],
                    dtype=dtype,
                    device=device,
                )
                expected_y_inverse_nan = torch.tensor(
                    [0, 1, 1, 2, 3, 4], dtype=torch.long, device=device
                )
                expected_y_counts_nan = torch.tensor(
                    [1, 2, 1, 1, 1], dtype=torch.long, device=device
                )

            y_unique, y_inverse, y_counts = torch.unique_consecutive(
                y, return_inverse=True, return_counts=True, dim=0
            )
            if x.dtype == torch.bool:
                self.assertEqual(expected_y_inverse_bool, y_inverse)
                self.assertEqual(expected_y_counts_bool, y_counts)
            else:
                self.assertEqual(expected_y_inverse, y_inverse)
                self.assertEqual(expected_y_counts, y_counts)

            # test tensor with nan
            if dtype in floating_types_and(torch.float16, torch.bfloat16):
                y_unique, y_inverse, y_counts = torch.unique_consecutive(
                    y_nan, return_inverse=True, return_counts=True, dim=0
                )
                self.assertEqual(expected_y_unique_nan, y_unique)
                self.assertEqual(expected_y_inverse_nan, y_inverse)
                self.assertEqual(expected_y_counts_nan, y_counts)

            # Test that dim is sorted the same as NumPy for dims >= 3
            x = torch.tensor(
                [
                    [
                        [[1, 0, 1, 0, 1, 1], [0, 1, 1, 0, 1, 1]],
                        [[0, 1, 1, 0, 0, 1], [0, 0, 0, 1, 0, 0]],
                    ],
                    [
                        [[0, 1, 0, 1, 1, 1], [0, 1, 1, 0, 1, 1]],
                        [[0, 0, 1, 1, 0, 1], [1, 1, 0, 0, 0, 0]],
                    ],
                ],
                dtype=dtype,
                device=device,
            )
            xn = x.cpu().numpy()
            for d in range(x.dim()):
                t = torch.unique(x, dim=d)
                n = np.unique(xn, axis=d)
                self.assertEqual(t.cpu().numpy(), n)

        run_test(device, torch.float)
        run_test(device, torch.double)
        run_test(device, torch.long)
        run_test(device, torch.uint8)
        run_test(device, torch.bool)
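        # A hedged sketch of the dim semantics exercised above: with dim=d,
        # torch.unique treats each slice along d as one atomic element.
        #   >>> x = torch.tensor([[1, 2], [1, 2], [3, 4]])
        #   >>> torch.unique(x, dim=0)
        #   tensor([[1, 2],
        #           [3, 4]])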

    @onlyCUDA
    def test_topk_noncontiguous_gpu(self, device):
        # test different topk paths on cuda
        single_block_t = torch.randn(20, device=device)[::2]
        multi_block_t = torch.randn(20000, device=device)[::2]
        sort_t = torch.randn(200000, device=device)[::2]
        for t in (single_block_t, multi_block_t, sort_t):
            for k in (5, 2000, 10000):
                if k >= t.shape[0]:
                    continue
                top1, idx1 = t.topk(k)
                top2, idx2 = t.contiguous().topk(k)
                self.assertEqual(top1, top2)
                self.assertEqual(idx1, idx2)

    def _test_topk_dtype(self, device, dtype, integral, size):
        if integral:
            a = torch.randint(
                torch.iinfo(dtype).min,
                torch.iinfo(dtype).max,
                size=(size,),
                dtype=dtype,
                device=device,
            )
        else:
            a = torch.randn(size=(size,), dtype=dtype, device=device)

        sort_topk = a.sort()[0][-(size // 2) :].flip(0)
        topk = a.topk(size // 2)
        self.assertEqual(sort_topk, topk[0])  # check values
        self.assertEqual(sort_topk, a[topk[1]])  # check indices

    @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)
    def test_topk_integral(self, device, dtype):
        small = 10
        large = 4096
        verylarge = 8192  # multi_block topk on cuda
        for curr_size in (small, large, verylarge):
            self._test_topk_dtype(device, dtype, True, curr_size)

    @dtypes(torch.bfloat16, torch.half)
    def test_topk_lower_precision(self, device, dtype):
        small = 10
        large = 4096
        verylarge = 8192  # multi_block topk on cuda
        for curr_size in (small, large, verylarge):
            self._test_topk_dtype(device, dtype, False, curr_size)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
    def test_topk_nonfinite(self, device, dtype):
        x = torch.tensor(
            [float("nan"), float("inf"), 1e4, 0, -1e4, -float("inf")],
            device=device,
            dtype=dtype,
        )
        val, idx = x.topk(4)
        expect = torch.tensor(
            [float("nan"), float("inf"), 1e4, 0], device=device, dtype=dtype
        )
        self.assertEqual(val, expect)
        self.assertEqual(idx, [0, 1, 2, 3])

        val, idx = x.topk(4, largest=False)
        expect = torch.tensor([-float("inf"), -1e4, 0, 1e4], device=device, dtype=dtype)
        self.assertEqual(val, expect)
        self.assertEqual(idx, [5, 4, 3, 2])
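        # The total order assumed above, in a hedged sketch: NaN compares
        # greatest, then +inf, the finite values, and -inf last.
        #   >>> torch.tensor([0.0, float("-inf"), float("nan"), float("inf")]).topk(2).values
        #   tensor([nan, inf])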

    def test_topk_4d(self, device):
        small = 128
        large = 8192
        for size in (small, large):
            x = torch.ones(2, size, 2, 2, device=device)
            x[:, 1, :, :] *= 2.0
            x[:, 10, :, :] *= 1.5
            val, ind = torch.topk(x, k=2, dim=1)
            expected_ind = torch.ones(2, 2, 2, 2, dtype=torch.long, device=device)
            expected_ind[:, 1, :, :] = 10
            expected_val = torch.ones(2, 2, 2, 2, device=device)
            expected_val[:, 0, :, :] *= 2.0
            expected_val[:, 1, :, :] *= 1.5
            self.assertEqual(val, expected_val, atol=0, rtol=0)
            self.assertEqual(ind, expected_ind, atol=0, rtol=0)

    @onlyNativeDeviceTypes
    @dtypesIfCUDA(*all_types_and(torch.bfloat16))
    @dtypes(*all_types_and(torch.bfloat16, torch.half))
    def test_topk_zero(self, device, dtype):
        # https://github.com/pytorch/pytorch/issues/49205
        t = torch.rand(2, 2, device=device).to(dtype=dtype)
        val, idx = torch.topk(t, k=0, largest=False)
        self.assertEqual(val.size(), torch.Size([2, 0]))
        self.assertEqual(idx.size(), torch.Size([2, 0]))

    def _test_unique_scalar_empty(self, dtype, device, f):
        # test scalar
        x = torch.tensor(0, dtype=dtype, device=device)
        unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
        expected_unique = torch.tensor([0], dtype=dtype, device=device)
        expected_inverse = torch.tensor(0, device=device)
        expected_counts = torch.tensor([1], device=device)
        self.assertEqual(unique, expected_unique)
        self.assertEqual(inverse, expected_inverse)
        self.assertEqual(counts, expected_counts)

        # test zero-sized tensor
        x = torch.zeros((0, 0, 3), dtype=dtype, device=device)
        unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
        expected_unique = torch.tensor([], dtype=dtype, device=device)
        expected_inverse = torch.empty((0, 0, 3), dtype=torch.long, device=device)
        expected_counts = torch.tensor([], dtype=torch.long, device=device)
        self.assertEqual(unique, expected_unique)
        self.assertEqual(inverse, expected_inverse)
        self.assertEqual(counts, expected_counts)

    def _test_unique_with_expects(
        self,
        device,
        dtype,
        f,
        x,
        expected_unique,
        expected_inverse,
        expected_counts,
        additional_shape,
    ):
        def ensure_tuple(x):
            if isinstance(x, torch.Tensor):
                return (x,)
            return x

        for return_inverse in [True, False]:
            for return_counts in [True, False]:
                # test with expected
                ret = ensure_tuple(
                    f(x, return_inverse=return_inverse, return_counts=return_counts)
                )
                self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
                self.assertEqual(expected_unique, ret[0])
                if return_inverse:
                    self.assertEqual(expected_inverse, ret[1])
                if return_counts:
                    count_index = 1 + int(return_inverse)
                    self.assertEqual(expected_counts, ret[count_index])

                # tests per-element unique on a higher-rank tensor
                y = x.view(additional_shape)
                y_unique, y_inverse, y_counts = f(
                    y, return_inverse=True, return_counts=True
                )
                self.assertEqual(expected_unique, y_unique)
                self.assertEqual(expected_inverse.view(additional_shape), y_inverse)
                self.assertEqual(expected_counts, y_counts)

    @dtypesIfCPU(*all_types_and(torch.bool, torch.float16, torch.bfloat16))
    @dtypes(*all_types_and(torch.half, torch.bool))
    def test_unique(self, device, dtype):
        def ensure_tuple(x):
            if isinstance(x, torch.Tensor):
                return (x,)
            return x

        if dtype is torch.bool:
            x = torch.tensor(
                [True, False, False, False, True, False, True, False],
                dtype=torch.bool,
                device=device,
            )
            expected_unique = torch.tensor(
                [False, True], dtype=torch.bool, device=device
            )
            expected_inverse = torch.tensor(
                [1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device
            )
            expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device)
        else:
            x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device)
            expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device)
            expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device)
            expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device)

        # test sorted unique
        fs = (
            lambda x, **kwargs: torch.unique(x, sorted=True, **kwargs),
            lambda x, **kwargs: x.unique(sorted=True, **kwargs),
        )
        x_sliced = torch.empty(x.size(0) * 2, dtype=dtype, device=device)[::2].copy_(x)
        xs = (x, x_sliced)
        for f, x in product(fs, xs):
            self._test_unique_with_expects(
                device,
                dtype,
                f,
                x,
                expected_unique,
                expected_inverse,
                expected_counts,
                (2, 2, 2),
            )
            self._test_unique_scalar_empty(dtype, device, f)

        # test unsorted unique
        fs = (
            lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs),
            lambda x, **kwargs: x.unique(sorted=False, **kwargs),
        )
        for f, x in product(fs, xs):
            self._test_unique_scalar_empty(dtype, device, f)
            for return_inverse, return_counts in product((True, False), repeat=2):
                ret = ensure_tuple(
                    f(x, return_inverse=return_inverse, return_counts=return_counts)
                )
                self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
                x_list = x.tolist()
                x_unique_list = ret[0].tolist()
                self.assertEqual(expected_unique.tolist(), sorted(x_unique_list))
                if return_inverse:
                    x_inverse_list = ret[1].tolist()
                    for i, j in enumerate(x_inverse_list):
                        self.assertEqual(x_list[i], x_unique_list[j])
                if return_counts:
                    count_index = 1 + int(return_inverse)
                    x_counts_list = ret[count_index].tolist()
                    for i, j in zip(x_unique_list, x_counts_list):
                        count = 0
                        for k in x_list:
                            if k == i:
                                count += 1
                        self.assertEqual(j, count)
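        # The inverse-index property relied on above, in a hedged sketch:
        # unique[inverse] reconstructs the input element-wise.
        #   >>> u, inv = torch.unique(torch.tensor([1, 2, 3, 2]), return_inverse=True)
        #   >>> u, inv
        #   (tensor([1, 2, 3]), tensor([0, 1, 2, 1]))
        #   >>> bool((u[inv] == torch.tensor([1, 2, 3, 2])).all())
        #   True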

    @dtypesIfCPU(*all_types_and(torch.bool, torch.float16, torch.bfloat16))
    @dtypes(*all_types_and(torch.half, torch.bool))
    def test_unique_consecutive(self, device, dtype):
        if dtype is torch.bool:
            x = torch.tensor(
                [True, False, False, False, True, True, False, False, False],
                dtype=torch.bool,
                device=device,
            )
            expected_unique = torch.tensor(
                [True, False, True, False], dtype=torch.bool, device=device
            )
            expected_inverse = torch.tensor(
                [0, 1, 1, 1, 2, 2, 3, 3, 3], dtype=torch.long, device=device
            )
            expected_counts = torch.tensor(
                [1, 3, 2, 3], dtype=torch.long, device=device
            )
        else:
            x = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], dtype=dtype, device=device)
            expected_unique = torch.tensor([1, 2, 5, 2, 3], dtype=dtype, device=device)
            expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device)
            expected_counts = torch.tensor([1, 3, 2, 2, 1], device=device)

        for f in [
            torch.unique_consecutive,
            lambda x, **kwargs: x.unique_consecutive(**kwargs),
        ]:
            self._test_unique_with_expects(
                device,
                dtype,
                f,
                x,
                expected_unique,
                expected_inverse,
                expected_counts,
                (3, 3),
            )
            self._test_unique_scalar_empty(dtype, device, f)
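        # Unlike torch.unique, unique_consecutive only merges adjacent runs of
        # equal values, which is why 2 appears twice in expected_unique above.
        # A tiny hedged sketch:
        #   >>> torch.unique_consecutive(torch.tensor([1, 2, 2, 1]))
        #   tensor([1, 2, 1])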

    @dtypes(torch.double)
    def test_kthvalue(self, device, dtype):
        SIZE = 50
        x = torch.rand(SIZE, SIZE, SIZE, dtype=dtype, device=device)
        x0 = x.clone()

        k = random.randint(1, SIZE)
        res1val, res1ind = torch.kthvalue(x, k, keepdim=False)
        res2val, res2ind = torch.sort(x)

        self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0)
        self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0)
        # test use of result tensors
        k = random.randint(1, SIZE)
        res1val = torch.tensor([], dtype=dtype, device=device)
        res1ind = torch.tensor([], dtype=torch.long, device=device)
        torch.kthvalue(x, k, keepdim=False, out=(res1val, res1ind))
        res2val, res2ind = torch.sort(x)
        self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0)
        self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0)

        # test non-default dim
        k = random.randint(1, SIZE)
        res1val, res1ind = torch.kthvalue(x, k, 0, keepdim=False)
        res2val, res2ind = torch.sort(x, 0)
        self.assertEqual(res1val, res2val[k - 1], atol=0, rtol=0)
        self.assertEqual(res1ind, res2ind[k - 1], atol=0, rtol=0)

        # non-contiguous
        y = x.narrow(1, 0, 1)
        y0 = y.contiguous()
        k = random.randint(1, SIZE)
        res1val, res1ind = torch.kthvalue(y, k)
        res2val, res2ind = torch.kthvalue(y0, k)
        self.assertEqual(res1val, res2val, atol=0, rtol=0)
        self.assertEqual(res1ind, res2ind, atol=0, rtol=0)

        # non-contiguous [Reference: https://github.com/pytorch/pytorch/issues/45721]
        non_contig_t = torch.tensor([0, -1, 1, -2, 2], dtype=dtype, device=device)[::2]
        expected_val, expected_ind = non_contig_t.contiguous().kthvalue(2)
        non_contig_cpu_t = non_contig_t.cpu()
        expected_val_cpu, expected_ind_cpu = non_contig_cpu_t.kthvalue(2)

        out_val, out_ind = non_contig_t.kthvalue(2)
        self.assertEqual(expected_val, out_val, atol=0, rtol=0)
        self.assertEqual(expected_ind, out_ind, atol=0, rtol=0)
        self.assertEqual(expected_val_cpu, out_val, atol=0, rtol=0)
        self.assertEqual(expected_ind_cpu, out_ind, atol=0, rtol=0)

        # check that the input wasn't modified
        self.assertEqual(x, x0, atol=0, rtol=0)

        # simple test case (with repetitions)
        y = torch.tensor((3.0, 5, 4, 1, 1, 5), dtype=dtype, device=device)
        self.assertEqual(torch.kthvalue(y, 3)[0], 3, atol=0, rtol=0)
        self.assertEqual(torch.kthvalue(y, 2)[0], 1, atol=0, rtol=0)

        # simple test case (with NaN)
        SIZE = 50
        x = torch.rand(SIZE, SIZE, SIZE, dtype=dtype, device=device)
        x[torch.arange(SIZE), :, torch.randint(50, (50,))] = nan
        ks = [random.randint(1, SIZE), 1, SIZE, SIZE - 1]
        res2val, res2ind = torch.sort(x)
        for k in ks:
            res1val, res1ind = torch.kthvalue(x, k, keepdim=False)
            self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0)
            self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0)
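        # The invariant checked throughout, as a hedged sketch: kthvalue(x, k)
        # along a dim equals the (k-1)-th entry of an ascending sort.
        #   >>> x = torch.tensor([3.0, 1.0, 2.0])
        #   >>> torch.kthvalue(x, 2).values
        #   tensor(2.)
        #   >>> torch.sort(x).values[1]
        #   tensor(2.)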

    @dtypes(torch.float)
    @onlyNativeDeviceTypes  # Fails on XLA
    def test_kthvalue_scalar(self, device, dtype):
        # Test scalar input (test case from https://github.com/pytorch/pytorch/issues/30818)
        # Tests that passing a scalar tensor or a 1D tensor with 1 element works either way
        res = torch.tensor(2, device=device, dtype=dtype).kthvalue(1)
        ref = torch.tensor([2], device=device, dtype=dtype).kthvalue(1)
        self.assertEqual(res[0], ref[0].squeeze())
        self.assertEqual(res[1], ref[1].squeeze())

    @dtypes(*all_types())
    @dtypesIfCUDA(*all_types_and(torch.half))
    def test_isin(self, device, dtype):
        def assert_isin_equal(a, b):
            # Compare to the numpy reference implementation.
            x = torch.isin(a, b)
            a = a.cpu().numpy() if torch.is_tensor(a) else np.array(a)
            b = b.cpu().numpy() if torch.is_tensor(b) else np.array(b)
            y = np.isin(a, b)
            self.assertEqual(x, y)

        # multi-dim tensor, multi-dim tensor
        a = torch.arange(24, device=device, dtype=dtype).reshape([2, 3, 4])
        b = torch.tensor(
            [[10, 20, 30], [0, 1, 3], [11, 22, 33]], device=device, dtype=dtype
        )
        assert_isin_equal(a, b)

        # zero-dim tensor
        zero_d = torch.tensor(3, device=device, dtype=dtype)
        assert_isin_equal(zero_d, b)
        assert_isin_equal(a, zero_d)
        assert_isin_equal(zero_d, zero_d)

        # empty tensor
        empty = torch.tensor([], device=device, dtype=dtype)
        assert_isin_equal(empty, b)
        assert_isin_equal(a, empty)
        assert_isin_equal(empty, empty)

        # scalar
        assert_isin_equal(a, 6)
        assert_isin_equal(5, b)

        def define_expected(lst, invert=False):
            expected = torch.tensor(lst, device=device)
            if invert:
                expected = expected.logical_not()
            return expected

        # Adapted from numpy's in1d tests
        for mult in [1, 10]:
            for invert in [False, True]:
                a = torch.tensor([5, 7, 1, 2], device=device, dtype=dtype)
                b = torch.tensor([2, 4, 3, 1, 5] * mult, device=device, dtype=dtype)
                ec = define_expected([True, False, True, True], invert=invert)
                c = torch.isin(a, b, assume_unique=True, invert=invert)
                self.assertEqual(c, ec)

                a[0] = 8
                ec = define_expected([False, False, True, True], invert=invert)
                c = torch.isin(a, b, assume_unique=True, invert=invert)
                self.assertEqual(c, ec)

                a[0], a[3] = 4, 8
                ec = define_expected([True, False, True, False], invert=invert)
                c = torch.isin(a, b, assume_unique=True, invert=invert)
                self.assertEqual(c, ec)

                a = torch.tensor(
                    [5, 4, 5, 3, 4, 4, 3, 4, 3, 5, 2, 1, 5, 5],
                    device=device,
                    dtype=dtype,
                )
                b = torch.tensor([2, 3, 4] * mult, device=device, dtype=dtype)
                ec = define_expected(
                    [
                        False,
                        True,
                        False,
                        True,
                        True,
                        True,
                        True,
                        True,
                        True,
                        False,
                        True,
                        False,
                        False,
                        False,
                    ],
                    invert=invert,
                )
                c = torch.isin(a, b, invert=invert)
                self.assertEqual(c, ec)

                b = torch.tensor(
                    [2, 3, 4] * mult + [5, 5, 4] * mult, device=device, dtype=dtype
                )
                ec = define_expected(
                    [
                        True,
                        True,
                        True,
                        True,
                        True,
                        True,
                        True,
                        True,
                        True,
                        True,
                        True,
                        False,
                        True,
                        True,
                    ],
                    invert=invert,
                )
                c = torch.isin(a, b, invert=invert)
                self.assertEqual(c, ec)

                a = torch.tensor([5, 7, 1, 2], device=device, dtype=dtype)
                b = torch.tensor([2, 4, 3, 1, 5] * mult, device=device, dtype=dtype)
                ec = define_expected([True, False, True, True], invert=invert)
                c = torch.isin(a, b, invert=invert)
                self.assertEqual(c, ec)

                a = torch.tensor([5, 7, 1, 1, 2], device=device, dtype=dtype)
                b = torch.tensor([2, 4, 3, 3, 1, 5] * mult, device=device, dtype=dtype)
                ec = define_expected([True, False, True, True, True], invert=invert)
                c = torch.isin(a, b, invert=invert)
                self.assertEqual(c, ec)

                a = torch.tensor([5, 5], device=device, dtype=dtype)
                b = torch.tensor([2, 2] * mult, device=device, dtype=dtype)
                ec = define_expected([False, False], invert=invert)
                c = torch.isin(a, b, invert=invert)
                self.assertEqual(c, ec)

                # multi-dimensional input case using the sort-based algo
                for assume_unique in [False, True]:
                    a = torch.arange(6, device=device, dtype=dtype).reshape([2, 3])
                    b = torch.arange(3, 30, device=device, dtype=dtype)
                    ec = define_expected(
                        [[False, False, False], [True, True, True]], invert=invert
                    )
                    c = torch.isin(a, b, invert=invert, assume_unique=assume_unique)
                    self.assertEqual(c, ec)
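        # torch.isin in one line (hedged sketch): an element-wise membership
        # test of the first argument against the second; invert flips the
        # result, and assume_unique only enables a faster path.
        #   >>> torch.isin(torch.tensor([1, 2, 3]), torch.tensor([3, 4, 5]))
        #   tensor([False, False,  True])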

    def test_isin_different_dtypes(self, device):
        supported_types = all_types() if device == "cpu" else all_types_and(torch.half)
        for mult in [1, 10]:
            for assume_unique in [False, True]:
                for dtype1, dtype2 in product(supported_types, supported_types):
                    a = torch.tensor([1, 2, 3], device=device, dtype=dtype1)
                    b = torch.tensor([3, 4, 5] * mult, device=device, dtype=dtype2)
                    ec = torch.tensor([False, False, True], device=device)
                    c = torch.isin(a, b, assume_unique=assume_unique)
                    self.assertEqual(c, ec)

    @onlyCUDA
    @dtypes(*all_types())
    def test_isin_different_devices(self, device, dtype):
        a = torch.arange(6, device=device, dtype=dtype).reshape([2, 3])
        b = torch.arange(3, 30, device="cpu", dtype=dtype)
        with self.assertRaises(RuntimeError):
            torch.isin(a, b)

        c = torch.arange(6, device="cpu", dtype=dtype).reshape([2, 3])
        d = torch.arange(3, 30, device=device, dtype=dtype)
        with self.assertRaises(RuntimeError):
            torch.isin(c, d)

    @dtypes(*integral_types())
    def test_sort_overflow(self, device, dtype):
        "Regression test for https://github.com/pytorch/pytorch/issues/111189"
        prev_num_threads = torch.get_num_threads()
        try:
            low = 0 if dtype == torch.uint8 else -1
            x = torch.full((32768,), low, dtype=dtype, device=device)
            x[:100] = torch.iinfo(x.dtype).max
            torch.set_num_threads(1)
            uv = x.sort().values.unique()
            self.assertEqual(uv.size(0), 2)
        finally:
            torch.set_num_threads(prev_num_threads)


instantiate_device_type_tests(TestSortAndSelect, globals())

if __name__ == "__main__":
    run_tests()