1# Owner(s): ["module: scatter & gather ops"] 2 3import random 4 5import torch 6 7from torch.testing import make_tensor 8from torch.testing._internal.common_utils import \ 9 (parametrize, run_tests, TestCase, DeterministicGuard) 10from torch.testing._internal.common_device_type import \ 11 (instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, 12 toleranceOverride, tol,) 13from torch.testing._internal.common_dtype import \ 14 (get_all_dtypes,) 15 16# Protects against includes accidentally setting the default dtype 17assert torch.get_default_dtype() is torch.float32 18 19 20# Note: test_scatter_gather_ops.py 21# This test file tests scatter and gather operations, 22# like torch.scatter and torch.gather. 23 24class TestScatterGather(TestCase): 25 # Fills an index tensor with valid indices 26 def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o, unique_indices=True): 27 for i in range(1 if dim == 0 else m): 28 for j in range(1 if dim == 1 else n): 29 for k in range(1 if dim == 2 else o): 30 ii = [i, j, k] 31 ii[dim] = slice(0, idx.size(dim) + 1) 32 if unique_indices: 33 idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row] 34 else: 35 idx[tuple(ii)] = torch.randint(dim_size, (elems_per_row,)) 36 37 @dtypes(torch.float32, torch.complex64) 38 def test_gather(self, device, dtype): 39 m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20) 40 elems_per_row = random.randint(1, 10) 41 dim = random.randrange(3) 42 43 src = make_tensor((m, n, o), device=device, dtype=dtype) 44 idx_size = [m, n, o] 45 idx_size[dim] = elems_per_row 46 idx = make_tensor(idx_size, device=device, dtype=torch.long) 47 self._fill_indices(idx, dim, src.size(dim), elems_per_row, m, n, o) 48 49 actual = torch.gather(src, dim, idx) 50 expected = torch.zeros(idx_size, device=device, dtype=dtype) 51 for i in range(idx_size[0]): 52 for j in range(idx_size[1]): 53 for k in range(idx_size[2]): 54 ii = [i, j, k] 55 ii[dim] = idx[i, j, k] 56 expected[i, j, k] = src[tuple(ii)] 57 self.assertEqual(actual, expected, atol=0, rtol=0) 58 59 # Guarded because torch.max isn't defined for complex types 60 if not dtype.is_complex: 61 src = make_tensor((3, 4, 5), device=device, dtype=dtype) 62 expected, idx = src.max(2, True) 63 actual = torch.gather(src, 2, idx) 64 self.assertEqual(actual, expected, atol=0, rtol=0) 65 66 @dtypes(torch.bool) 67 def test_gather_bool(self, device, dtype): 68 src = torch.tensor(((False, True), (True, True)), device=device, dtype=dtype) 69 idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) 70 actual = torch.gather(src, 1, idx) 71 expected = torch.tensor(((False, False), (True, True)), device=device, dtype=dtype) 72 self.assertEqual(actual, expected, atol=0, rtol=0) 73 74 @parametrize("sparse_grad", [False, True]) 75 @dtypes(torch.float32, torch.float64) 76 def test_gather_backward_with_empty_index_tensor(self, device, dtype, sparse_grad): 77 dim = -1 78 input = torch.rand([10, 5], dtype=dtype, device=device, requires_grad=True) 79 index = torch.randint(0, 2, [3, 0], dtype=torch.int64, device=device) 80 res = torch.gather(input, dim, index, sparse_grad=sparse_grad) 81 res.sum().backward() 82 grad = input.grad.to_dense() if sparse_grad else input.grad 83 expected_grad = torch.zeros_like(input, requires_grad=False) 84 self.assertEqual(grad, expected_grad, atol=0, rtol=0) 85 86 def _test_scatter_base(self, fn, *, device, dtype, is_scalar, reduction, 87 unique_indices=True, include_self=True): 88 m, n, o = random.randint(10, 20), random.randint(10, 

class TestScatterGather(TestCase):
    # Fills an index tensor with valid indices
    def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o, unique_indices=True):
        for i in range(1 if dim == 0 else m):
            for j in range(1 if dim == 1 else n):
                for k in range(1 if dim == 2 else o):
                    ii = [i, j, k]
                    ii[dim] = slice(0, idx.size(dim) + 1)
                    if unique_indices:
                        idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
                    else:
                        idx[tuple(ii)] = torch.randint(dim_size, (elems_per_row,))

    @dtypes(torch.float32, torch.complex64)
    def test_gather(self, device, dtype):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        src = make_tensor((m, n, o), device=device, dtype=dtype)
        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = make_tensor(idx_size, device=device, dtype=torch.long)
        self._fill_indices(idx, dim, src.size(dim), elems_per_row, m, n, o)

        actual = torch.gather(src, dim, idx)
        expected = torch.zeros(idx_size, device=device, dtype=dtype)
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    expected[i, j, k] = src[tuple(ii)]
        self.assertEqual(actual, expected, atol=0, rtol=0)

        # Guarded because torch.max isn't defined for complex types
        if not dtype.is_complex:
            src = make_tensor((3, 4, 5), device=device, dtype=dtype)
            expected, idx = src.max(2, True)
            actual = torch.gather(src, 2, idx)
            self.assertEqual(actual, expected, atol=0, rtol=0)

    @dtypes(torch.bool)
    def test_gather_bool(self, device, dtype):
        src = torch.tensor(((False, True), (True, True)), device=device, dtype=dtype)
        idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
        actual = torch.gather(src, 1, idx)
        expected = torch.tensor(((False, False), (True, True)), device=device, dtype=dtype)
        self.assertEqual(actual, expected, atol=0, rtol=0)

    @parametrize("sparse_grad", [False, True])
    @dtypes(torch.float32, torch.float64)
    def test_gather_backward_with_empty_index_tensor(self, device, dtype, sparse_grad):
        dim = -1
        input = torch.rand([10, 5], dtype=dtype, device=device, requires_grad=True)
        index = torch.randint(0, 2, [3, 0], dtype=torch.int64, device=device)
        res = torch.gather(input, dim, index, sparse_grad=sparse_grad)
        res.sum().backward()
        grad = input.grad.to_dense() if sparse_grad else input.grad
        expected_grad = torch.zeros_like(input, requires_grad=False)
        self.assertEqual(grad, expected_grad, atol=0, rtol=0)

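    # Shared checker for the scatter variants below (scatter_, scatter_add_,
    # scatter_reduce_): builds random base/src/index tensors, applies `fn`,
    # and compares the result against a reference computed with naive Python loops.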
    def _test_scatter_base(self, fn, *, device, dtype, is_scalar, reduction,
                           unique_indices=True, include_self=True):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = torch.empty(tuple(idx_size), device=device, dtype=torch.long)
        self._fill_indices(idx, dim, ([m, n, o])[dim], elems_per_row, m, n, o, unique_indices)

        if is_scalar:
            src = random.random()
        else:
            src_size = [random.randint(1, 5) + s for s in idx_size]
            src = make_tensor(tuple(src_size), device=device, dtype=dtype)

        base = make_tensor((m, n, o), device=device, dtype=dtype)
        if reduction is not None:
            if fn is torch.Tensor.scatter_reduce_:
                actual = fn(base.clone(), dim, idx, src, reduce=reduction, include_self=include_self)
            else:
                actual = fn(base.clone(), dim, idx, src, reduce=reduction)
        else:
            actual = fn(base.clone(), dim, idx, src)

        expected = base.clone()
        counts = torch.zeros(base.shape, dtype=torch.long, device=device) + include_self
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    if fn is torch.Tensor.scatter_add_:
                        expected[tuple(ii)] += src[i, j, k]
                    else:
                        # The method may be 'scatter_', 'scatter', 'scatter_reduce' or
                        # 'scatter_reduce_'; the former two may take a reduction argument
                        # while the latter two always do
                        value = src if is_scalar else src[i, j, k]

                        if ((not include_self) and counts[tuple(ii)] == 0):
                            expected[tuple(ii)] = value
                        else:
                            if reduction == "add" or reduction == "sum":
                                expected[tuple(ii)] += value
                            elif reduction == "multiply" or reduction == "prod":
                                expected[tuple(ii)] *= value
                            elif reduction == "amax":
                                expected[tuple(ii)] = max(expected[tuple(ii)], value)
                            elif reduction == "amin":
                                expected[tuple(ii)] = min(expected[tuple(ii)], value)
                            elif reduction == "mean":
                                expected[tuple(ii)] += value
                            else:
                                expected[tuple(ii)] = value

                    counts[tuple(ii)] += 1

        if (reduction == "mean"):
            counts.masked_fill_(counts == 0, 1)
            if (dtype.is_floating_point or dtype.is_complex):
                expected /= counts
            else:
                expected.div_(counts, rounding_mode="floor")

        if dtype == torch.float16 or dtype == torch.bfloat16:
            # Some CUDA kernels (e.g. indexing_backward_kernel_stride_1) that are called
            # during the test use fp32 for internal accumulation for improved accuracy,
            # so with 16-bit precision types there can be small differences.
            self.assertEqual(actual, expected, atol=0.04, rtol=0.05)
        else:
            self.assertEqual(actual, expected, atol=0, rtol=0)

        # Tests empty index: scattering with an empty index tensor should leave dst unchanged
        dst = make_tensor((2, 2), device=device, dtype=dtype)
        idx = torch.tensor((), device=device, dtype=torch.long)
        src = make_tensor((2, 2), device=device, dtype=dtype)
        if reduction is not None:
            actual = fn(dst, 0, idx, src, reduce=reduction)
        else:
            actual = fn(dst, 0, idx, src)
        self.assertEqual(actual, dst, atol=0, rtol=0)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_(self, device, dtype):
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                        is_scalar=False, reduction=None)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__scalar(self, device, dtype):
        self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                is_scalar=True, reduction=None)

    # FIXME: RuntimeError: "cuda_scatter_gather_base_kernel_reduce_multiply" not implemented for 'ComplexFloat'
    @toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0)})
    @dtypesIfCUDA(torch.float16, torch.float32)
    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__reductions(self, device, dtype):
        for reduction in ("add", "multiply"):
            self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                    is_scalar=False, reduction=reduction)
            self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                    is_scalar=True, reduction=reduction)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_add_(self, device, dtype):
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                self._test_scatter_base(torch.Tensor.scatter_add_, device=device, dtype=dtype,
                                        is_scalar=False, reduction=None)

    @dtypes(torch.float32)
    def test_scatter_add_mult_index_base(self, device, dtype):
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                m, n = 30, 40
                idx = torch.zeros(m, n, device=device, dtype=torch.long)
                src = torch.ones(m, n, device=device, dtype=dtype)
                res0 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(0, idx, src)
                res1 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(1, idx, src)

                self.assertEqual(res0[0, :], m * torch.ones(n, device=device, dtype=dtype), atol=0, rtol=0)
                self.assertEqual(res1[:, 0], n * torch.ones(m, device=device, dtype=dtype), atol=0, rtol=0)

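    # The scatter_reduce_ tests below exercise each supported reduction
    # ('sum', 'prod', 'mean', 'amax', 'amin'), both with and without include_self.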
    # FIXME: discrepancy between bool ReduceAdd on CUDA and CPU (a + b on CPU and buggy a && b on CUDA)
    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
    def test_scatter_reduce_sum(self, device, dtype):
        for include_self in (True, False):
            for deterministic in [False, True]:
                with DeterministicGuard(deterministic):
                    self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                            is_scalar=False, reduction='sum', unique_indices=False,
                                            include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_prod(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='prod', unique_indices=False,
                                    include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_mean(self, device, dtype):
        for include_self in (True, False):
            for deterministic in [False, True]:
                with DeterministicGuard(deterministic):
                    self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                            is_scalar=False, reduction='mean', unique_indices=False,
                                            include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amax(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='amax', unique_indices=False,
                                    include_self=include_self)
            # simple test for nan/inf propagation
            if (dtype.is_floating_point):
                input = torch.zeros(3, device=device, dtype=dtype)
                src = torch.tensor([1, float('nan'), -float('inf'), -float('inf'), 2, float('inf')], device=device, dtype=dtype)
                idx = torch.tensor([0, 0, 1, 1, 2, 2], device=device)
                input.scatter_reduce_(0, idx, src, 'amax', include_self=include_self)
                expected_result = torch.tensor([float('nan'), -float('inf'), float('inf')], device=device, dtype=dtype)
                if (include_self):
                    expected_result[1] = 0
                self.assertEqual(input, expected_result)


    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amin(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='amin', unique_indices=False,
                                    include_self=include_self)
            # simple test for nan/inf propagation
            if (dtype.is_floating_point):
                input = torch.zeros(3, device=device, dtype=dtype)
                src = torch.tensor([1, float('nan'), -2, -float('inf'), float('inf'), float('inf')], device=device, dtype=dtype)
                idx = torch.tensor([0, 0, 1, 1, 2, 2], device=device)
                input.scatter_reduce_(0, idx, src, 'amin', include_self=include_self)
                expected_result = torch.tensor([float('nan'), -float('inf'), float('inf')], device=device, dtype=dtype)
                if (include_self):
                    expected_result[2] = 0
                self.assertEqual(input, expected_result)

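    # The expanded-index tests below compare the CPU fast path taken when the
    # index tensor is an expanded (broadcast) view against results obtained
    # with an equivalent contiguous index.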
    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
    def test_scatter_expanded_index(self, device, dtype):
        def helper(input_size, idx_size):
            input = torch.randn(input_size, device=device).to(dtype=dtype)
            input2 = input.clone()

            shape = [1] * len(input_size)
            shape[0] = idx_size
            dim_size = input_size[0]
            idx = torch.randint(0, dim_size, shape)

            # The fast path for scatter with an expanded index relies on a sorted index,
            # where the src indices collected for each row of input are mapped to row
            # pointers in a CSR format.
            # Create some empty rows by masking:
            mask = (idx > 1) * (idx < 4)
            idx[mask] = 0

            expanded_shape = input_size
            expanded_shape[0] = idx_size
            idx = idx.expand(expanded_shape)
            idx2 = idx.contiguous()
            src = torch.randn(expanded_shape, device=device).to(dtype=dtype)

            out = input.scatter_add(0, idx, src)
            out2 = input2.scatter_add(0, idx2, src)
            self.assertEqual(out, out2)

            for reduce in ["sum", "prod", "mean", "amax", "amin"]:
                for include_self in [True, False]:
                    out = input.scatter_reduce(0, idx, src, reduce=reduce, include_self=include_self)
                    out2 = input2.scatter_reduce(0, idx2, src, reduce=reduce, include_self=include_self)
                    self.assertEqual(out, out2)

        helper([50, 17], 100)
        helper([50, 1], 100)
        helper([50, 8, 7], 100)
        helper([50, 3, 4, 5], 100)

    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16)
    def test_gather_expanded_index(self, device, dtype):
        # Test when index is [N, 1], which would have stride [1, 0];
        # it should be excluded from the fast path for expanded indices
        input = torch.arange(25).view(5, 5)
        input2 = input.to(dtype=dtype)

        idx = torch.arange(5).view(5, 1)
        out = torch.gather(input, 0, idx)
        out2 = torch.gather(input2, 0, idx)

        self.assertEqual(out.to(dtype=dtype), out2)

        def helper(input_size, idx_size):
            input = torch.randn(input_size, device=device).to(dtype=dtype)
            input2 = input.clone()

            shape = [1] * len(input_size)
            shape[0] = idx_size
            dim_size = input_size[0]
            idx = torch.randint(0, dim_size, shape)

            # Test the fast path on gather when index is expanded
            expanded_shape = input_size
            expanded_shape[0] = idx_size
            idx = idx.expand(expanded_shape)
            idx2 = idx.contiguous()

            out = torch.gather(input, 0, idx)
            out2 = torch.gather(input2, 0, idx2)

            self.assertEqual(out, out2)

        helper([50, 17], 100)
        helper([50, 1], 100)
        helper([50, 8, 7], 100)
        helper([50, 3, 4, 5], 100)

# Generic Device Test Framework instantiation, see
# https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
# for details.
instantiate_device_type_tests(TestScatterGather, globals())

if __name__ == '__main__':
    run_tests()