# Owner(s): ["module: scatter & gather ops"]

import random

import torch

from torch.testing import make_tensor
from torch.testing._internal.common_utils import \
    (parametrize, run_tests, TestCase, DeterministicGuard)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA,
     toleranceOverride, tol,)
from torch.testing._internal.common_dtype import \
    (get_all_dtypes,)

# Protects against imports accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32


# Note: test_scatter_gather_ops.py
# This test file tests scatter and gather operations,
#   like torch.scatter and torch.gather.

class TestScatterGather(TestCase):
    # Fills an index tensor with valid indices
    def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o, unique_indices=True):
        for i in range(1 if dim == 0 else m):
            for j in range(1 if dim == 1 else n):
                for k in range(1 if dim == 2 else o):
                    ii = [i, j, k]
                    ii[dim] = slice(0, idx.size(dim) + 1)
                    if unique_indices:
                        idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
                    else:
                        idx[tuple(ii)] = torch.randint(dim_size, (elems_per_row,))

    @dtypes(torch.float32, torch.complex64)
    def test_gather(self, device, dtype):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        src = make_tensor((m, n, o), device=device, dtype=dtype)
        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = make_tensor(idx_size, device=device, dtype=torch.long)
        self._fill_indices(idx, dim, src.size(dim), elems_per_row, m, n, o)

        actual = torch.gather(src, dim, idx)
        expected = torch.zeros(idx_size, device=device, dtype=dtype)
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    expected[i, j, k] = src[tuple(ii)]
        self.assertEqual(actual, expected, atol=0, rtol=0)

        # Guarded because torch.max isn't defined for complex types
        if not dtype.is_complex:
            src = make_tensor((3, 4, 5), device=device, dtype=dtype)
            expected, idx = src.max(2, True)
            actual = torch.gather(src, 2, idx)
            self.assertEqual(actual, expected, atol=0, rtol=0)

    @dtypes(torch.bool)
    def test_gather_bool(self, device, dtype):
        src = torch.tensor(((False, True), (True, True)), device=device, dtype=dtype)
        idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
        actual = torch.gather(src, 1, idx)
        expected = torch.tensor(((False, False), (True, True)), device=device, dtype=dtype)
        self.assertEqual(actual, expected, atol=0, rtol=0)

    @parametrize("sparse_grad", [False, True])
    @dtypes(torch.float32, torch.float64)
    def test_gather_backward_with_empty_index_tensor(self, device, dtype, sparse_grad):
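        # With an empty index tensor nothing is gathered, so the gradient flowing
        # back to `input` must be all zeros (densified first when sparse_grad=True).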
        dim = -1
        input = torch.rand([10, 5], dtype=dtype, device=device, requires_grad=True)
        index = torch.randint(0, 2, [3, 0], dtype=torch.int64, device=device)
        res = torch.gather(input, dim, index, sparse_grad=sparse_grad)
        res.sum().backward()
        grad = input.grad.to_dense() if sparse_grad else input.grad
        expected_grad = torch.zeros_like(input, requires_grad=False)
        self.assertEqual(grad, expected_grad, atol=0, rtol=0)

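    # Shared helper for the scatter variants below: applies `fn` to a clone of a random
    # base tensor and checks the result against a reference computed with explicit loops.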
    def _test_scatter_base(self, fn, *, device, dtype, is_scalar, reduction,
                           unique_indices=True, include_self=True):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = torch.empty(tuple(idx_size), device=device, dtype=torch.long)
        self._fill_indices(idx, dim, ([m, n, o])[dim], elems_per_row, m, n, o, unique_indices)

        if is_scalar:
            src = random.random()
        else:
            src_size = [random.randint(1, 5) + s for s in idx_size]
            src = make_tensor(tuple(src_size), device=device, dtype=dtype)

        base = make_tensor((m, n, o), device=device, dtype=dtype)
        if reduction is not None:
            if fn is torch.Tensor.scatter_reduce_:
                actual = fn(base.clone(), dim, idx, src, reduce=reduction, include_self=include_self)
            else:
                actual = fn(base.clone(), dim, idx, src, reduce=reduction)
        else:
            actual = fn(base.clone(), dim, idx, src)

        expected = base.clone()
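        # counts tracks how many values land in each destination slot; it starts at 1
        # when include_self=True so the base value is counted for the 'mean' reduction.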
        counts = torch.zeros(base.shape, dtype=torch.long, device=device) + include_self
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    if fn is torch.Tensor.scatter_add_:
                        expected[tuple(ii)] += src[i, j, k]
                    else:
                        # the method may be 'scatter_', 'scatter', 'scatter_reduce'
                        # or 'scatter_reduce_'; the former two may take a reduction argument
                        # while the latter two always do
                        value = src if is_scalar else src[i, j, k]

                        if ((not include_self) and counts[tuple(ii)] == 0):
                            expected[tuple(ii)] = value
                        else:
                            if reduction == "add" or reduction == "sum":
                                expected[tuple(ii)] += value
                            elif reduction == "multiply" or reduction == "prod":
                                expected[tuple(ii)] *= value
                            elif reduction == "amax":
                                expected[tuple(ii)] = max(expected[tuple(ii)], value)
                            elif reduction == "amin":
                                expected[tuple(ii)] = min(expected[tuple(ii)], value)
                            elif reduction == "mean":
                                expected[tuple(ii)] += value
                            else:
                                expected[tuple(ii)] = value

                        counts[tuple(ii)] += 1

        if (reduction == "mean"):
            counts.masked_fill_(counts == 0, 1)
            if (dtype.is_floating_point or dtype.is_complex):
                expected /= counts
            else:
                expected.div_(counts, rounding_mode="floor")

        if dtype == torch.float16 or dtype == torch.bfloat16:
            # Some CUDA kernels (e.g. indexing_backward_kernel_stride_1) that are called during
            # the test use fp32 for internal accumulation for improved accuracy, so small
            # differences can appear when testing with 16-bit precision types.
            self.assertEqual(actual, expected, atol=0.04, rtol=0.05)
        else:
            self.assertEqual(actual, expected, atol=0, rtol=0)

        # Tests empty index
        dst = make_tensor((2, 2), device=device, dtype=dtype)
        idx = torch.tensor((), device=device, dtype=torch.long)
        src = make_tensor((2, 2), device=device, dtype=dtype)
        if reduction is not None:
            actual = fn(dst, 0, idx, src, reduce=reduction)
        else:
            actual = fn(dst, 0, idx, src)
        self.assertEqual(actual, dst, atol=0, rtol=0)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_(self, device, dtype):
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                        is_scalar=False, reduction=None)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__scalar(self, device, dtype):
        self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                is_scalar=True, reduction=None)

    # FIXME: RuntimeError: "cuda_scatter_gather_base_kernel_reduce_multiply" not implemented for 'ComplexFloat'
    @toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0)})
    @dtypesIfCUDA(torch.float16, torch.float32)
    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__reductions(self, device, dtype):
        for reduction in ("add", "multiply"):
            self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                    is_scalar=False, reduction=reduction)
            self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                    is_scalar=True, reduction=reduction)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_add_(self, device, dtype):
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                self._test_scatter_base(torch.Tensor.scatter_add_, device=device, dtype=dtype,
                                        is_scalar=False, reduction=None)

    @dtypes(torch.float32)
    def test_scatter_add_mult_index_base(self, device, dtype):
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
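                # Every index is zero, so along dim 0 all m rows of ones accumulate into
                # row 0, and along dim 1 all n columns accumulate into column 0.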
                m, n = 30, 40
                idx = torch.zeros(m, n, device=device, dtype=torch.long)
                src = torch.ones(m, n, device=device, dtype=dtype)
                res0 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(0, idx, src)
                res1 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(1, idx, src)

                self.assertEqual(res0[0, :], m * torch.ones(n, device=device, dtype=dtype), atol=0, rtol=0)
                self.assertEqual(res1[:, 0], n * torch.ones(m, device=device, dtype=dtype), atol=0, rtol=0)

    # FIXME: discrepancy between bool ReduceAdd on CUDA and CPU (a + b on CPU and buggy a && b on CUDA)
    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
    def test_scatter_reduce_sum(self, device, dtype):
        for include_self in (True, False):
            for deterministic in [False, True]:
                with DeterministicGuard(deterministic):
                    self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                            is_scalar=False, reduction='sum', unique_indices=False,
                                            include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_prod(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='prod', unique_indices=False,
                                    include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_mean(self, device, dtype):
        for include_self in (True, False):
            for deterministic in [False, True]:
                with DeterministicGuard(deterministic):
                    self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                            is_scalar=False, reduction='mean', unique_indices=False,
                                            include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amax(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='amax', unique_indices=False,
                                    include_self=include_self)
            # simple test for nan/inf propagation
            if (dtype.is_floating_point):
                input = torch.zeros(3, device=device, dtype=dtype)
                src = torch.tensor([1, float('nan'), -float('inf'), -float('inf'), 2, float('inf')], device=device, dtype=dtype)
                idx = torch.tensor([0, 0, 1, 1, 2, 2], device=device)
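                # idx groups src pairwise: amax(1, nan) -> nan, amax(-inf, -inf) -> -inf,
                # amax(2, inf) -> inf; with include_self the base value 0 beats -inf.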
                input.scatter_reduce_(0, idx, src, 'amax', include_self=include_self)
                expected_result = torch.tensor([float('nan'), -float('inf'), float('inf')], device=device, dtype=dtype)
                if (include_self):
                    expected_result[1] = 0
                self.assertEqual(input, expected_result)


    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amin(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='amin', unique_indices=False,
                                    include_self=include_self)
            # simple test for nan/inf propagation
            if (dtype.is_floating_point):
                input = torch.zeros(3, device=device, dtype=dtype)
                src = torch.tensor([1, float('nan'), -2, -float('inf'), float('inf'), float('inf')], device=device, dtype=dtype)
                idx = torch.tensor([0, 0, 1, 1, 2, 2], device=device)
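                # idx groups src pairwise: amin(1, nan) -> nan, amin(-2, -inf) -> -inf,
                # amin(inf, inf) -> inf; with include_self the base value 0 beats inf.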
                input.scatter_reduce_(0, idx, src, 'amin', include_self=include_self)
                expected_result = torch.tensor([float('nan'), -float('inf'), float('inf')], device=device, dtype=dtype)
                if (include_self):
                    expected_result[2] = 0
                self.assertEqual(input, expected_result)

    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
    def test_scatter_expanded_index(self, device, dtype):
        def helper(input_size, idx_size):
            input = torch.randn(input_size, device=device).to(dtype=dtype)
            input2 = input.clone()

            shape = [1] * len(input_size)
            shape[0] = idx_size
            dim_size = input_size[0]
            idx = torch.randint(0, dim_size, shape)

            # The fast path on scatter when the index is expanded relies on a sorted
            # index, where the collected src indices for each row in input are mapped
            # to row pointers in a CSR format.
            # Create some empty rows by masking:
            mask = (idx > 1) * (idx < 4)
            idx[mask] = 0

            expanded_shape = input_size
            expanded_shape[0] = idx_size
            idx = idx.expand(expanded_shape)
            idx2 = idx.contiguous()
            src = torch.randn(expanded_shape, device=device).to(dtype=dtype)

            out = input.scatter_add(0, idx, src)
            out2 = input2.scatter_add(0, idx2, src)
            self.assertEqual(out, out2)

            for reduce in ["sum", "prod", "mean", "amax", "amin"]:
                for include_self in [True, False]:
                    out = input.scatter_reduce(0, idx, src, reduce=reduce, include_self=include_self)
                    out2 = input2.scatter_reduce(0, idx2, src, reduce=reduce, include_self=include_self)
                    self.assertEqual(out, out2)

        helper([50, 17], 100)
        helper([50, 1], 100)
        helper([50, 8, 7], 100)
        helper([50, 3, 4, 5], 100)

    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16)
    def test_gather_expanded_index(self, device, dtype):
        # Test when index is [N, 1], which would have stride [1, 0];
        # it should be excluded from the fast path for expanded indices.
        input = torch.arange(25).view(5, 5)
        input2 = input.to(dtype=dtype)

        idx = torch.arange(5).view(5, 1)
        out = torch.gather(input, 0, idx)
        out2 = torch.gather(input2, 0, idx)

        self.assertEqual(out.to(dtype=dtype), out2)

        def helper(input_size, idx_size):
            input = torch.randn(input_size, device=device).to(dtype=dtype)
            input2 = input.clone()

            shape = [1] * len(input_size)
            shape[0] = idx_size
            dim_size = input_size[0]
            idx = torch.randint(0, dim_size, shape)

            # Test the fast path on gather when the index is expanded
            expanded_shape = input_size
            expanded_shape[0] = idx_size
            idx = idx.expand(expanded_shape)
            idx2 = idx.contiguous()

            out = torch.gather(input, 0, idx)
            out2 = torch.gather(input2, 0, idx2)

            self.assertEqual(out, out2)

        helper([50, 17], 100)
        helper([50, 1], 100)
        helper([50, 8, 7], 100)
        helper([50, 3, 4, 5], 100)

# Generic Device Test Framework instantiation, see
#   https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
#   for details.
instantiate_device_type_tests(TestScatterGather, globals())

if __name__ == '__main__':
    run_tests()