xref: /aosp_15_r20/external/pytorch/test/test_reductions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: tests"]
2
3import contextlib
4import torch
5import numpy as np
6
7import math
8from typing import Dict, List, Sequence
9import random
10from functools import partial
11from itertools import product, combinations, permutations
12import warnings
13
14from torch import inf, nan
15from torch.testing import make_tensor
16from torch.testing._internal.common_dtype import (
17    all_types_and_complex_and, get_all_math_dtypes, integral_types, complex_types, floating_types_and,
18    integral_types_and, floating_and_complex_types_and, all_types_and, all_types,
19)
20from torch.testing._internal.common_utils import (
21    TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict,
22    parametrize,
23    IS_WINDOWS)
24from torch.testing._internal.common_device_type import (
25    OpDTypes, expectedFailureMeta, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU,
26    onlyNativeDeviceTypes, onlyCUDA, largeTensorTest, ops, precisionOverride)
27from torch.testing._internal.common_methods_invocations import (
28    ReductionOpInfo, ReductionPythonRefInfo, reduction_ops, reference_masked_ops)
29
30# TODO: replace with make_tensor
31def _generate_input(shape, dtype, device, with_extremal):
32    if shape == ():
33        x = torch.tensor((), dtype=dtype, device=device)
34    else:
35        if dtype.is_floating_point or dtype.is_complex:
36            # work around torch.randn not being implemented for bfloat16
37            if dtype == torch.bfloat16:
38                x = torch.randn(*shape, device=device) * random.randint(30, 100)
39                x = x.to(torch.bfloat16)
40            else:
41                x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100)
42            x[torch.randn(*shape) > 0.5] = 0
43            if with_extremal and dtype.is_floating_point:
44                # Use extremal values
45                x[torch.randn(*shape) > 0.5] = float('nan')
46                x[torch.randn(*shape) > 0.5] = float('inf')
47                x[torch.randn(*shape) > 0.5] = float('-inf')
48            elif with_extremal and dtype.is_complex:
49                x[torch.randn(*shape) > 0.5] = complex('nan')
50                x[torch.randn(*shape) > 0.5] = complex('inf')
51                x[torch.randn(*shape) > 0.5] = complex('-inf')
52        elif dtype == torch.bool:
53            x = torch.zeros(shape, dtype=dtype, device=device)
54            x[torch.randn(*shape) > 0.5] = True
55        else:
56            x = torch.randint(15, 100, shape, dtype=dtype, device=device)
57
58    return x
59
60# TODO: replace with make_tensor
61def _rand_shape(dim, min_size, max_size):
62    shape = []
63    for i in range(dim):
64        shape.append(random.randint(min_size, max_size))
65    return tuple(shape)
66
67def _reduced_shape(shape, dim=None, keepdim=False):
68    """Computes the expected reduced shape given dim and keepdim
69
70    Args:
71        shape: The shape to reduce
72        dim : The dimensions to reduce
73        keepdim: If true, reduced dimensions have size 1 in the reduced shape,
74            otherwise they are removed from the reduced shape.
75
76    Returns:
77        The reduced shape
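
    Example (illustrative):
        _reduced_shape((2, 3, 4), dim=1)               -> [2, 4]
        _reduced_shape((2, 3, 4), dim=1, keepdim=True) -> [2, 1, 4]
        _reduced_shape((2, 3, 4))                      -> []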
78    """
79    if dim is None:
80        return [1] * len(shape) if keepdim else []
81
82    # Wrap negative dims
83    dim = dim if isinstance(dim, Sequence) else [dim]
84    dim = {i if i >= 0 else len(shape) + i for i in dim}
85
86    result = []
87    for i, size in enumerate(shape):
88        if i not in dim:
89            result.append(size)
90        elif keepdim:
91            result.append(1)
92
93    return result
94
95class TestReductions(TestCase):
96
97    ###########################################################################
98    # ReductionOpInfo unit tests
99    ###########################################################################
100
101    def _test_dim_keepdim(self, op: ReductionOpInfo, device, *, ndim, **dim_keepdim):
102        """Tests output shape for input with ndim and dim and keepdim kwargs"""
103        shape = torch.randint(2, 5, (ndim,)).tolist()
104        t = make_tensor(shape, dtype=torch.float, device=device)
105        args, kwargs = next(op.generate_args_kwargs(t, **dim_keepdim))
106        result = op(t, *args, **dim_keepdim, **kwargs)
107        expected_shape = _reduced_shape(shape, **dim_keepdim)
108        self.assertEqual(result.shape, expected_shape, f"""
109        expected output shape to be {expected_shape} but got {list(result.shape)}
110        for input shape {shape} and {dim_keepdim}
111        """)
112
113    # TODO(@heitorschueroff) combine cases with and without keepdim once
114    # there's support for a @parametrize decorator.
115
116    @ops(reduction_ops, dtypes=OpDTypes.none)
117    def test_dim_default(self, device, op: ReductionOpInfo):
118        """Tests that the default dim reduces all dimensions."""
119        for ndim in range(3):
120            self._test_dim_keepdim(op, device, ndim=ndim)
121
122    @ops(reduction_ops, dtypes=OpDTypes.none)
123    def test_dim_default_keepdim(self, device, op: ReductionOpInfo):
124        """Tests that the default dim, when keepdim=True, reduces all dimensions to size 1."""
125        for ndim in range(3):
126            self._test_dim_keepdim(op, device, ndim=ndim, keepdim=True)
127
128    @ops(reduction_ops, dtypes=OpDTypes.none)
129    def test_dim_none(self, device, op: ReductionOpInfo):
130        """Tests that dim=None reduces all dimensions."""
131        for ndim in range(3):
132            self._test_dim_keepdim(op, device, ndim=ndim, dim=None)
133
134    @ops(reduction_ops, dtypes=OpDTypes.none)
135    def test_dim_none_keepdim(self, device, op: ReductionOpInfo):
136        """Tests that dim=None, when keepdim=True, reduces all dimensions to size 1."""
137        for ndim in range(3):
138            self._test_dim_keepdim(op, device, ndim=ndim, dim=None, keepdim=True)
139
140    @ops(reduction_ops, dtypes=OpDTypes.none)
141    def test_dim_single(self, device, op: ReductionOpInfo):
142        """Tests that dim=i reduces dimension i."""
143        self._test_dim_keepdim(op, device, ndim=0, dim=0)
144        self._test_dim_keepdim(op, device, ndim=1, dim=0)
145        self._test_dim_keepdim(op, device, ndim=2, dim=-1)
146        self._test_dim_keepdim(op, device, ndim=3, dim=1)
147
148    @ops(reduction_ops, dtypes=OpDTypes.none)
149    def test_dim_single_keepdim(self, device, op: ReductionOpInfo):
150        """Tests that dim=i, when keepdim=True, reduces dimension i to size 1."""
151        self._test_dim_keepdim(op, device, ndim=0, dim=0, keepdim=True)
152        self._test_dim_keepdim(op, device, ndim=1, dim=0, keepdim=True)
153        self._test_dim_keepdim(op, device, ndim=2, dim=-1, keepdim=True)
154        self._test_dim_keepdim(op, device, ndim=3, dim=1, keepdim=True)
155
156    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
157    def test_dim_empty(self, device, op: ReductionOpInfo):
158        """Tests that dim=[] is a no-op"""
159        self._test_dim_keepdim(op, device, ndim=0, dim=[])
160        self._test_dim_keepdim(op, device, ndim=2, dim=[])
161
162    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
163    def test_dim_empty_keepdim(self, device, op: ReductionOpInfo):
164        """Tests that dim=[], when keepdim=True, is a no-op"""
165        self._test_dim_keepdim(op, device, ndim=0, dim=[], keepdim=True)
166        self._test_dim_keepdim(op, device, ndim=2, dim=[], keepdim=True)
167
168    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
169    def test_dim_multi(self, device, op: ReductionOpInfo):
170        """Tests that dim=[i, j, ...] reduces dimensions i, j, ...."""
171        self._test_dim_keepdim(op, device, ndim=1, dim=[0])
172        self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2])
173
174    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
175    def test_dim_multi_keepdim(self, device, op: ReductionOpInfo):
176        """Tests that dim=[i, j, ...], when keepdim=True, reduces dimensions i, j, .... to size 1."""
177        self._test_dim_keepdim(op, device, ndim=1, dim=[0], keepdim=True)
178        self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2], keepdim=True)
179
180    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
181    def test_dim_multi_unsorted(self, device, op: ReductionOpInfo):
182        """Tests that operator correctly handles unsorted dim list."""
183        self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2])
184
185    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
186    def test_dim_multi_unsorted_keepdim(self, device, op: ReductionOpInfo):
187        """Tests that operator correctly handles unsorted dim list when keepdim=True."""
188        self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2], keepdim=True)
189
190    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
191    def test_dim_multi_duplicate(self, device, op: ReductionOpInfo):
192        """Tests that an error is raised if dim has duplicate entries."""
193        with self.assertRaises(RuntimeError):
194            self._test_dim_keepdim(op, device, ndim=3, dim=[0, 1, 1, 2])
195
196    @ops(filter(lambda op: not op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
197    def test_dim_multi_unsupported(self, device, op: ReductionOpInfo):
198        """Tests that ops claiming to not support multi dim actually don't."""
199        with self.assertRaises(TypeError):
200            self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2])
201
202    @ops(reduction_ops, dtypes=OpDTypes.none)
203    def test_dim_offbounds(self, device, op: ReductionOpInfo):
204        """Tests that passing an off-bounds dim throws"""
205        with self.assertRaises(IndexError):
206            self._test_dim_keepdim(op, device, ndim=2, dim=2)
207
208    @ops(reduction_ops, dtypes=OpDTypes.none)
209    def test_dim_ndim_limit(self, device, op: ReductionOpInfo):
210        """Tests that an exception is raised when reducing a tensor with more
211        than 64 dims along some specific dimensions. dim=None is ok"""
212        t = make_tensor([1] * 65, dtype=torch.float, device=device)
213        with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
214            op(t, dim=0)
215
216    @ops(filter(lambda op: op.identity is not None, reduction_ops), dtypes=OpDTypes.supported)
217    def test_identity(self, device, dtype, op: ReductionOpInfo):
218        """Tests that the identity value is an identity for the operator"""
219        t = make_tensor((10,), dtype=dtype, device=device)
220        t[1::2] = op.identity
221        args, kwargs = next(op.generate_args_kwargs(t))
222        result = op(t[::2], *args, **kwargs)
223        result_with_identity = op(t, *args, **kwargs)
224        self.assertEqual(result, result_with_identity, """
225        Adding identity value to the input tensor should not change the result.
226        """)
227
228    # TODO(@heitorschueroff) Update these to use the nan_policy kwarg once
229    # it is added to reduction operators.
230
231    @ops(filter(lambda op: op.nan_policy == 'propagate', reduction_ops), dtypes=OpDTypes.supported,
232         allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16))
233    def test_nan_policy_propagate(self, device, dtype, op: ReductionOpInfo):
234        """Tests that nan is propagated to the output by default"""
235        t = make_tensor((5,), dtype=dtype, device=device)
236        t[2] = torch.nan
237        args, kwargs = next(op.generate_args_kwargs(t))
238        result = op(t, *args, **kwargs)
239        self.assertTrue(result.isnan())
240
241    @ops(filter(lambda op: op.nan_policy == 'omit', reduction_ops), dtypes=OpDTypes.supported,
242         allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16))
243    def test_nan_policy_omit(self, device, dtype, op: ReductionOpInfo):
244        """Tests that NaN values do not affect the result."""
245        t = make_tensor((10,), dtype=dtype, device=device)
246        t[1::2] = torch.nan
247        args, kwargs = next(op.generate_args_kwargs(t))
248        result = op(t[::2], *args, **kwargs)
249        result_with_nan = op(t, *args, **kwargs)
250        self.assertEqual(result, result_with_nan)
251
252    @ops(reduction_ops, dtypes=OpDTypes.supported)
253    def test_result_dtype(self, device, dtype, op: ReductionOpInfo):
254        """Tests that the result has the correct dtype"""
255        t = make_tensor((5,), dtype=dtype, device=device)
256        args, kwargs = next(op.generate_args_kwargs(t))
257        result: torch.Tensor = op(t, *args, **kwargs)
258        is_integral = dtype in integral_types_and(torch.bool)
259        if op.promotes_int_to_float and is_integral:
260            self.assertTrue(torch.is_floating_point(result))
261        elif op.promotes_int_to_int64 and is_integral:
262            self.assertEqual(result.dtype, torch.int64)
263        elif op.result_dtype is not None:
264            self.assertEqual(result.dtype, op.result_dtype)
265        elif op.complex_to_real:
266            _complex_to_real_dtype_map = {
267                torch.complex128: torch.float64,
268                torch.complex64: torch.float32,
269                torch.complex32: torch.float16,
270            }
271            self.assertEqual(result.dtype, _complex_to_real_dtype_map.get(dtype, dtype))
272        else:
273            self.assertEqual(result.dtype, dtype)
274
275    @ops(reduction_ops, dtypes=OpDTypes.none)
276    def test_empty_tensor_empty_slice(self, device, op: ReductionOpInfo):
277        """Tests for consistent behavior when reducing over an empty slice.
278
279        The rules for reducing over an empty slice are as follows:
280            - Return the identity value if the operator has one
281            - Otherwise, return NaN if the operator promotes integral dtype to
282              floating point dtypes.
283            - Otherwise, raise an error
284
285        See discussion here https://github.com/pytorch/pytorch/issues/61901
286        """
287        t = make_tensor((0, 2, 3), dtype=torch.float, device=device)
288        for dim in [0] + ([[0, 2]] if op.supports_multiple_dims else []):
289            args, kwargs = next(op.generate_args_kwargs(t, dim=dim))
290            if op.identity is not None:
291                # Reducing along empty slice should return identity
292                result = op(t, *args, dim=dim, **kwargs)
293                self.assertEqual(result, torch.full_like(result, op.identity))
294            elif op.promotes_int_to_float:
295                # Reducing along empty slice should return NaN
296                result = op(t, *args, dim=dim, **kwargs)
297                self.assertEqual(result, torch.full_like(result, torch.nan))
298            else:
299                # Reducing along empty slice should raise an error
300                if isinstance(op, ReductionPythonRefInfo):
301                    # ref reductions throw RuntimeError for this
302                    with self.assertRaises(RuntimeError):
303                        op(t, *args, dim=dim, **kwargs)
304                else:
305                    with self.assertRaises(IndexError):
306                        op(t, *args, dim=dim, **kwargs)
307
308    @ops(reduction_ops, dtypes=OpDTypes.none)
309    def test_empty_tensor_nonempty_slice(self, device, op: ReductionOpInfo):
310        """Tests that reducing a nonempty slice of an empty tensor returns an
311        empty tensor with the dimensions reduced."""
312        t = make_tensor((0, 2, 3), dtype=torch.float, device=device)
313        for dim in [1] + ([[1, 2]] if op.supports_multiple_dims else []):
314            args, kwargs = next(op.generate_args_kwargs(t, dim=dim))
315            result = op(t, *args, dim=dim, **kwargs)
316            self.assertEqual(result.shape, _reduced_shape(t.shape, dim))
317
318    def _test_noncontiguous(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
319        """Helper method to test noncontiguous input tensors."""
320        assert not t.is_contiguous()
321
322        t_contig = t.contiguous()
323        for args, kwargs in op.generate_args_kwargs(t_contig, **reduction_kwargs):
324            kwargs.update(reduction_kwargs)
325            result = op(t, *args, **kwargs)
326            expected = op(t_contig, *args, **kwargs)
327            self.assertEqual(result, expected)
328
329    @ops(reduction_ops)
330    def test_noncontiguous_innermost(self, device, dtype, op: ReductionOpInfo):
331        """Tests reducing along noncontiguous innermost dimension."""
332        t = make_tensor((10, 10), dtype=dtype, device=device, low=-1, high=1)
333        self._test_noncontiguous(op, t[:, ::2], dim=1)
334
335    @ops(reduction_ops)
336    def test_noncontiguous_outermost(self, device, dtype, op: ReductionOpInfo):
337        """Tests reducing along noncontiguous outermost dimension."""
338        t = make_tensor((10, 10), dtype=dtype, device=device, low=-1, high=1)
339        self._test_noncontiguous(op, t[::2, :], dim=0)
340
341    @ops(reduction_ops)
342    def test_noncontiguous_all(self, device, dtype, op: ReductionOpInfo):
343        """Tests reducing all dimensions of a noncontiguous tensor."""
344        t = make_tensor((5, 5, 5), dtype=dtype, device=device, low=-1, high=1)
345        self._test_noncontiguous(op, t[::2, ::3, 1:-1:2])
346
347    @ops(reduction_ops)
348    def test_noncontiguous_transposed(self, device, dtype, op: ReductionOpInfo):
349        """Tests reducing a transposed tensor."""
350        t = make_tensor((5, 5), dtype=dtype, device=device, low=-1, high=1)
351        self._test_noncontiguous(op, t.T)
352
353    @ops(reduction_ops)
354    def test_noncontiguous_expanded(self, device, dtype, op: ReductionOpInfo):
355        """Tests reducing a tensor with expanded singleton dimensions."""
356        t = make_tensor((2, 3), dtype=dtype, device=device, low=-1, high=1)
357        self._test_noncontiguous(op, t.unsqueeze(1).expand(-1, 5, -1))
358
359    # NumPy does not support BFloat16 so we don't test that against reference
360    # implementations. We also don't compare dtypes or test for different
361    # keepdim because we already have other tests covering those.
362    # The test_reference_testing in test_ops.py only uses the samples from
363    # sample_inputs_func which do not test as exhaustively as these tests.
364
365    def _test_ref(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
366        """Compares op against op.ref for the given input and reduction kwargs"""
367        for args, kwargs in op.generate_args_kwargs(t, **reduction_kwargs):
368            kwargs.update(reduction_kwargs)
369            result = op(t, *args, **kwargs)
370            expected = op.ref(t.detach().cpu().numpy(), *args, **kwargs)
371            self.assertEqual(result, expected, exact_dtype=False)
372
373    @ops(filter(lambda op: op.ref is not None, reduction_ops),
374         allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
375    def test_ref_scalar_input(self, device, dtype, op: ReductionOpInfo):
376        """Compares op against reference for scalar input tensors"""
377        self._test_ref(op, make_tensor([], dtype=dtype, device=device))
378
379    @ops(filter(lambda op: op.ref is not None, reduction_ops),
380         allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
381    def test_ref_small_input(self, device, dtype, op: ReductionOpInfo):
382        """Compares op against reference for small input tensors"""
383        t = make_tensor((5, 3, 4, 2), dtype=dtype, device=device, low=-2, high=2, exclude_zero=True)
384        self._test_ref(op, t)
385        for dim in [0, 1, 3] + ([[0, 2], [1, 3]] if op.supports_multiple_dims else []):
386            self._test_ref(op, t, dim=dim)
387
388    @ops(filter(lambda op: op.ref is not None, reduction_ops),
389         allowed_dtypes=[torch.float64])
390    def test_ref_large_input_1D(self, device, dtype, op: ReductionOpInfo):
391        """Compares op against reference for a large 1D input tensor to check stability"""
392        self._test_ref(op, make_tensor((2 ** 20,), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True))
393
394    @ops(filter(lambda op: op.ref is not None, reduction_ops),
395         allowed_dtypes=[torch.float64])
396    def test_ref_large_input_2D(self, device, dtype, op: ReductionOpInfo):
397        """Compares op against reference for a large 2D input tensor to test parallelism"""
398        t = make_tensor((32, 2 ** 16), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True)
399        self._test_ref(op, t, dim=1)
400
401    @largeTensorTest("8gb")
402    @ops(filter(lambda op: op.ref is not None, reduction_ops),
403         allowed_dtypes=[torch.float64])
404    def test_ref_large_input_64bit_indexing(self, device, dtype, op: ReductionOpInfo):
405        """Compares op against reference for a very large input tensor that requires 64 bit indexing"""
406        self._test_ref(op, make_tensor((275000000,), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True))
407
408    @ops(filter(lambda op: op.ref is not None, reduction_ops),
409         allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
410    def test_ref_duplicate_values(self, device, dtype, op: ReductionOpInfo):
411        """Compares op against reference for input tensors with duplicate values"""
412        t = make_tensor((4, 4), dtype=dtype, device=device, low=-2, high=2, exclude_zero=True)
413        t[::2, ::2] = t[1::2, 1::2]
414        self._test_ref(op, t)
415        self._test_ref(op, t, dim=0)
416        self._test_ref(op, t, dim=1)
417
418    @ops(filter(lambda op: op.ref is not None, reduction_ops),
419         allowed_dtypes=[torch.float32, torch.complex64])
420    def test_ref_extremal_values(self, device, dtype, op: ReductionOpInfo):
421        """Compares op against reference for input tensors with extremal values"""
422        t = make_tensor((5,), dtype=dtype, device=device, exclude_zero=True)
423        extremals = [0, 1, nan, inf, -inf]
424        for extremal in extremals:
425            t[2] = extremal
426            self._test_ref(op, t)
427
428    ###########################################################################
429    # TODO: Legacy tests - port to ReductionOpInfo
430    ###########################################################################
431
432    def test_var_unbiased(self, device):
433        tensor = torch.randn(100, device=device)
434        self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True))
435        self.assertEqual(tensor.var(), tensor.var(unbiased=True))
436        self.assertEqual(tensor.var(unbiased=False), tensor.var(0, unbiased=False))
437
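        # Unbiased variance divides the sum of squared deviations by (n - 1), biased by n.
        # For [1.0, 2.0]: mean = 1.5, sum of squared deviations = 0.5, so the unbiased
        # variance is 0.5 / 1 = 0.5 and the biased variance is 0.5 / 2 = 0.25.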
438        tensor = torch.tensor([1.0, 2.0], device=device)
439        self.assertEqual(tensor.var(unbiased=True), 0.5)
440        self.assertEqual(tensor.var(unbiased=False), 0.25)
441
442        tensor = torch.tensor([1.0, 2.0, 3.0], device=device)
443        self.assertEqual(tensor.var(unbiased=True), 1.0)
444        self.assertEqual(tensor.var(unbiased=False), 2.0 / 3.0)
445
446        tensor = torch.randn(100, device=device)
447        self.assertEqual(tensor.std(0), tensor.std(0, unbiased=True))
448        self.assertEqual(tensor.std(), tensor.std(unbiased=True))
449        self.assertEqual(tensor.std(unbiased=False), tensor.std(0, unbiased=False))
450
451    def test_var_stability(self, device):
452        tensor = torch.tensor([2281.5, 2281.25], device=device)
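        # Deviations from the mean (2281.375) are +/-0.125, so the unbiased variance is
        # 2 * 0.125**2 / (2 - 1) = 0.03125; a naive E[x^2] - E[x]^2 formulation would lose
        # this precision at values of this magnitude, hence the "stability" check.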
453        self.assertEqual(tensor.var(dim=0), 0.03125)
454        self.assertEqual(tensor.var(), 0.03125)
455
456    def test_sum_dim_reduction_uint8_overflow(self, device):
457        example = [[-1, 2, 1], [5, 3, 6]]
458        x = torch.tensor(example, dtype=torch.uint8, device=device)
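        # In uint8, -1 wraps to 255 and sums wrap modulo 256:
        #   total: 255 + 2 + 1 + 5 + 3 + 6 = 272 -> 272 % 256 = 16
        #   dim 0: [255 + 5, 2 + 3, 1 + 6] = [260, 5, 7] -> [4, 5, 7]
        #   dim 1: [255 + 2 + 1, 5 + 3 + 6] = [258, 14] -> [2, 14]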
459        self.assertEqual(x.sum(dtype=torch.uint8).item(), 16)
460        self.assertEqual(x.sum(0, dtype=torch.uint8), torch.tensor([4, 5, 7], dtype=torch.uint8, device=device))
461        self.assertEqual(x.sum(1, dtype=torch.uint8), torch.tensor([2, 14], dtype=torch.uint8, device=device))
462        y = torch.tensor(example, dtype=torch.uint8, device=device)
463        torch.sum(x, 0, out=y)
464        self.assertEqual(x.sum(0, dtype=torch.uint8), y)
465
466    def test_dim_reduction_less_than_64(self, device):
467        sizes = [1] * 65
468        x = torch.randn(sizes, device=device)
469        ops = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp, torch.var,
470               torch.norm]
471        for op in ops:
472            with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
473                op(x, dim=64)
474            with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
475                op(x, dim=-1)
476
477    @onlyCPU
478    @dtypes(torch.float, torch.bfloat16)
479    def test_dim_reduction_lastdim(self, device, dtype):
480        x = torch.randn(3, 5, 40, device=device, dtype=dtype)
481        x = x[:, :, 0:40:2]
482        x2 = x.contiguous()
483        ops = [torch.norm, torch.argmax, torch.argmin]
484        for op in ops:
485            y = op(x, dim=-1)
486            y2 = op(x2, dim=-1)
487            self.assertEqual(y, y2)
488
489    @skipIfNoSciPy
490    @dtypes(torch.float32, torch.double, torch.complex64, torch.complex128)
491    def test_logsumexp(self, device, dtype):
492        from scipy.special import logsumexp
493        a = torch.randn(5, 4, device=device, dtype=dtype)
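        # logsumexp(a, dim) computes log(sum(exp(a), dim)) in a numerically stable way;
        # the assertions below compare it against scipy's reference, including +/-inf entries.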
494        # torch.exp(complex(inf, 0)) yields inf+nan*j instead of inf+0*j on CPU which disagrees with CUDA, C++ std::exp,
495        # numpy and scipy. Skip inf testing on CPU. Related to https://github.com/pytorch/pytorch/issues/95740
496        if torch.device(device) != torch.device('cpu'):
497            a[0, 0] = inf
498        a[1, :] = -inf
499        actual = a.logsumexp(1)
500        expected = logsumexp(a.cpu().numpy(), 1)
501        self.assertEqual(expected.shape, actual.shape)
502        self.assertEqual(expected, actual)
503
504        # check that out= actually writes into b's storage (c is a view of b)
505        b = torch.zeros(5, 2, device=device, dtype=dtype)
506        c = b[:, 0]
507        torch.logsumexp(a, 1, out=c)
508        self.assertEqual(expected, b[:, 0])
509
510    @skipIfNoSciPy
511    def test_logsumexp_integral_promotion(self, device):
512        from scipy.special import logsumexp
513        # check that integral inputs are promoted to floating point
514        e = torch.randint(-100, 100, [5, 4], device=device)
515        actual = e.logsumexp(1).to(torch.float64)
516        expected = logsumexp(e.cpu().numpy(), 1)
517        self.assertEqual(expected.shape, actual.shape)
518        self.assertEqual(expected, actual)
519
520    @skipIfNoSciPy
521    @dtypes(torch.complex64, torch.complex128)
522    def test_logcumsumexp_complex(self, device, dtype):
523        # logcumsumexp computes ``log(cumsum(exp(a)))`` more precisely than the naive formula
524        # and faster than ``[log(sum(exp(a[:i + 1]))) for i in range(a.shape[0])]``
525        # the loop above should produce similar precision to logcumsumexp (it's just slower),
526        # so it can be used as the expected values to check our computation
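        # (recall: logcumsumexp(a, dim)[..., i] = log(sum_{j <= i} exp(a[..., j])) along dim)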
527
528        # using logsumexp from scipy because by the time of writing this test code,
529        # torch.logsumexp has not been implemented for complex numbers
530        from scipy.special import logsumexp
531
532        def zero_out_neg_inf(t):
533            t = t.clone()
534            idx = torch.logical_and(~(torch.isfinite(t)), torch.real(t) < 0)
535            t[idx] = torch.real(t[idx]).to(t.dtype)
536            return t
537
538        def standardize_phase(t):
539            t = torch.real(t) + 1j * (torch.imag(t) % (2 * np.pi))
540            return t
541
542        def logcumsumexp_slow(a, dim):
543            res_lst = []
544            for i in range(a.size(dim)):
545                index = [slice(None, None, None) for _ in range(a.ndim)]
546                index[dim] = slice(None, i + 1, None)
547                a_inp = a[tuple(index)]
548                res_lst.append(logsumexp(a_inp.cpu().numpy(), axis=dim, keepdims=True))
549            res = np.concatenate(res_lst, axis=dim)
550            return torch.as_tensor(res)
551
552        def compare_logcumsumexp(a, expected=None):
553            for i in range(a.ndim):
554                actual = torch.logcumsumexp(a, dim=i)
555                # if the expected is not given, then revert to scipy's logsumexp
556                if expected is None:
557                    expected2 = logcumsumexp_slow(a, dim=i)
558                else:
559                    expected2 = expected
560
561                # move the imaginary values to (0, 2 * pi)
562                actual = standardize_phase(actual)
563                expected2 = standardize_phase(expected2)
564
565                # zeroing the imaginary part of the element if the real part is -inf
566                # as the imaginary part cannot be determined exactly and it does not
567                # really matter if we take the exp of the output
568                actual = zero_out_neg_inf(actual)
569                expected2 = zero_out_neg_inf(expected2)
570                self.assertEqual(expected2.shape, actual.shape)
571                self.assertEqual(expected2, actual)
572
573        # randomly specified values
574        # in this case, scipy.logsumexp should be enough
575        a1 = torch.randn((5, 10), dtype=dtype, device=device)
576        compare_logcumsumexp(a1)
577
578        # test with some non-normal values
579        a2 = torch.tensor([1e3 + 0j, 1e-18 + 1e4j, 1e2 + 1e-8j], dtype=dtype, device=device)
580        compare_logcumsumexp(a2)
581
582        # handle special case involving infinites and nans
583        # here we don't use scipy.logsumexp as it gives confusing answer on
584        # some inf cases
585        # see here:
586        inf = float('inf')
587        nan = float('nan')
588        a3_input = torch.tensor([
589            -inf + 4j,
590            -inf + 1j,
591            1.2 + 2.1j,
592            1e10 + 1e20j,
593            inf + 0j,
594            inf + 1j,
595            inf + 3j,
596            nan + 2j,
597        ])
598        a3_expected = torch.tensor([
599            -inf + 0j,
600            -inf + 0j,
601            1.2 + 2.1j,
602            1e10 + 1e20j,
603            inf + 0j,  # scipy's logsumexp gives (inf + 0.7853982j) here, unclear why
604            inf + (np.pi / 4) * 1j,  # the imaginary part comes from the behaviour of log(inf + inf*j)
605            complex(inf, nan),
606            complex(nan, nan),
607        ])
608        # Windows gives a strange result for the second-to-last entry, producing inf + (pi/4)j
609        # instead of inf + nan*j
610        if not IS_WINDOWS:
611            compare_logcumsumexp(a3_input, a3_expected)
612
613        a4_input = torch.tensor([
614            complex(-inf, inf),
615            complex(-inf, inf),
616            -inf + 1j,
617            1.2 + 2.1j,
618            complex(2.4, inf),
619        ])
620        a4_expected = torch.tensor([
621            -inf + 0j,
622            -inf + 0j,
623            -inf + 0j,
624            1.2 + 2.1j,
625            complex(nan, nan),
626        ])
627        if not IS_WINDOWS:
628            compare_logcumsumexp(a4_input, a4_expected)
629
630    @onlyCPU
631    def test_sum_parallel(self, device):
632        # To exercise the parallel branches we need to compare on tensors
633        # that are relatively large. Even if this is run on a single-core
634        # machine, these tests will still give a signal about
635        # correctness.
636
637        def _run_test(size):
638            for dim in range(len(size) + 1):
639                nv = np.round(np.random.rand(*size))  # 0s and 1s
640                tv = torch.from_numpy(nv)
641                # Parallelism is only used if numel is
642                # larger than the grain size defined in Parallel.h
643                self.assertTrue(tv.numel() > 32768)
644                if dim == len(size):
645                    nvs = nv.sum()
646                    tvs = tv.sum()
647                else:
648                    nvs = nv.sum(dim)
649                    tvs = tv.sum(dim)
650                diff = np.abs(nvs - tvs.numpy()).sum()
651                self.assertEqual(diff, 0)
652
653        _run_test([2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3])
654        _run_test([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
655        _run_test([1, 32 * 8 * 32 * 8])
656        _run_test([1, 32770])
657
658    # TODO: kill map2_ (and similar) uses and update to compare with NumPy
659    # only works on CPU since this uses map2_, which is only supported on CPU
660    def _testCSelection(self, torchfn, mathfn):
661        # Two tensors
662        size = (100, 100)
663        a = torch.rand(*size)
664        b = torch.rand(*size)
665        c = torchfn(a, b)
666        expected_c = torch.zeros(*size)
667        expected_c.map2_(a, b, lambda _, a, b: mathfn(a, b))
668        self.assertEqual(expected_c, c, atol=0, rtol=0)
669
670    @onlyCPU
671    def test_max_elementwise(self, device):
672        self._testCSelection(torch.max, max)
673
674    @onlyCPU
675    def test_min_elementwise(self, device):
676        self._testCSelection(torch.min, min)
677
678    def test_all_any(self, device):
679        def test(size):
680            x = torch.ones(*size, device=device).byte()
681            self.assertTrue(x.all())
682            self.assertTrue(x.any())
683
684            x[3] = 0
685            self.assertFalse(x.all())
686            self.assertTrue(x.any())
687
688            x.zero_()
689            self.assertFalse(x.all())
690            self.assertFalse(x.any())
691
692            x.fill_(2)
693            self.assertTrue(x.all())
694            self.assertTrue(x.any())
695
696            x = torch.ones(*size, device=device).bool()
697            self.assertTrue(x.all())
698            self.assertTrue(x.any())
699
700            x[3] = False
701            self.assertFalse(x.all())
702            self.assertTrue(x.any())
703
704        test((10,))
705        test((5, 5))
706
707    def test_all_any_with_dim(self, device):
708        def test(x):
709            r1 = x.prod(dim=0, keepdim=False).byte()
710            r2 = x.all(dim=0, keepdim=False)
711            self.assertEqual(r1.shape, r2.shape)
712            self.assertTrue((r1 == r2).all())
713
714            r3 = x.sum(dim=1, keepdim=True).clamp(0, 1).byte()
715            r4 = x.any(dim=1, keepdim=True)
716            self.assertEqual(r3.shape, r4.shape)
717            self.assertTrue((r3 == r4).all())
718
719        test(torch.tensor([[0, 0, 0],
720                           [0, 0, 1],
721                           [0, 1, 1],
722                           [1, 1, 1]], device=device, dtype=torch.uint8))
723
724    def test_numpy_named_args(self, device):
725        x1 = torch.randn(10, device=device)
726        x2 = torch.randn(10, device=device)
727        res1 = torch.add(input=x1, other=x2)
728        res2 = torch.add(x1=x1, x2=x2)
729        self.assertEqual(res1, res2)
730
731        x1 = torch.randn(10, 10, 10, device=device)
732        res1 = x1.sum(dim=(0, 2), keepdim=True)
733        res2 = x1.sum(axis=(0, 2), keepdims=True)
734        self.assertEqual(res1, res2)
735
736    # TODO: kill this and replace with common creation ops
737    def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True,
738                      use_complex=False) -> Dict[str, List[torch.Tensor]]:
739        float_types = [torch.double,
740                       torch.float]
741        int_types = [torch.int64,
742                     torch.int32,
743                     torch.int16]
744
745        complex_types = [torch.complex64,
746                         torch.complex128]
747
748        def make_contiguous(shape, dtype) -> torch.Tensor:
749            if dtype in float_types:
750                val = torch.randn(shape, dtype=dtype)
751                val = val * ((val_range[1] - val_range[0]) / (math.pi * 2.0))
752                val = val + ((val_range[1] - val_range[0]) / 2.0)
753                val = torch.clamp(val, min=val_range[0], max=val_range[1])
754                return val
755            result = torch.zeros(shape, dtype=dtype)
756            result.apply_(lambda x: random.randint(val_range[0], val_range[1]))
757            return result
758
759        def make_non_contiguous(shape, dtype) -> torch.Tensor:
760            contig = make_contiguous(shape, dtype)
761            non_contig = torch.empty(shape + (2, 2), dtype=dtype)[..., 0]
762            non_contig = non_contig.select(-1, -1)
763            non_contig.copy_(contig)
764            self.assertFalse(non_contig.is_contiguous())
765            return non_contig
766
767        def make_contiguous_slice(size, dtype) -> torch.Tensor:
768            contig = make_contiguous((1, size), dtype)
769            non_contig = contig[:1, 1:size - 1]
770            self.assertTrue(non_contig.is_contiguous())
771            return contig
772
773        types = []
774        if use_floating:
775            types += float_types
776        if use_integral:
777            types += int_types
778        if use_complex:
779            types += complex_types
780        tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
781        for dtype in types:
782            tensors["cont"].append(make_contiguous(shape, dtype))
783            tensors["noncont"].append(make_non_contiguous(shape, dtype))
784            tensors["slice"].append(make_contiguous_slice(sum(list(shape)), dtype))
785
786        return tensors
787
788    # TODO: refactor this to use comparators from common_utils
789    def _assert_matches_numpy(self, t, n):
790        self.assertEqual(n.shape, t.shape)
791        if t.dtype == torch.float:
792            self.assertEqual(n, t, rtol=1e-03, atol=1e-05, equal_nan=True)
793        else:
794            self.assertEqual(n, t, equal_nan=True)
795
796    # TODO: update this and tests that use it to use the device argument properly
797    def _test_dim_ops(self, pytorch_op, numpy_op,
798                      use_floating=True, use_integral=True, use_complex=False):
799        def do_one(tensors_dict, dim):
800            for category, tensors in tensors_dict.items():
801                if category == "slice":
802                    dim = 0
803                for tensor in tensors:
804                    # we have no control over NumPy warnings...
805                    with warnings.catch_warnings():
806                        warnings.simplefilter("ignore")
807                        expected = numpy_op(tensor.cpu().numpy(), dim)
808                    actual = pytorch_op(tensor, dim)
809                    self._assert_matches_numpy(actual, expected)
810                    if torch.cuda.is_available():
811                        self._assert_matches_numpy(pytorch_op(tensor.cuda(), dim).cpu(), expected)
812        do_one(self._make_tensors((5, 400000), use_floating=use_floating,
813                                  use_integral=use_integral, use_complex=use_complex), 1)
814        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
815                                  use_integral=use_integral, use_complex=use_complex), 0)
816        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
817                                  use_integral=use_integral, use_complex=use_complex), 1)
818        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
819                                  use_integral=use_integral, use_complex=use_complex), 2)
820        do_one(self._make_tensors((100000, ), use_floating=use_floating,
821                                  use_integral=use_integral, use_complex=use_complex), -1)
822        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
823                                  use_integral=use_integral, use_complex=use_complex), 0)
824        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
825                                  use_integral=use_integral, use_complex=use_complex), 1)
826        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
827                                  use_integral=use_integral, use_complex=use_complex), 2)
828        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
829                                  use_integral=use_integral, use_complex=use_complex), (1, 2))
830        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
831                                  use_integral=use_integral, use_complex=use_complex), (1, -1))
832        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
833                                  use_integral=use_integral, use_complex=use_complex), (0, 2))
834        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
835                                  use_integral=use_integral, use_complex=use_complex), (0, 2, 1))
836
837    @slowTest
838    @onlyCPU
839    def test_sum_dim(self, device):
840        self._test_dim_ops(
841            lambda t, d: t.sum(d),
842            lambda n, d: n.sum(d),
843            use_floating=True, use_integral=True, use_complex=True)
844
845    @onlyCPU
846    def test_mean_dim(self, device):
847        self._test_dim_ops(
848            lambda t, d: t.mean(d),
849            lambda n, d: n.mean(d),
850            use_integral=False,
851            use_complex=True)
852
853    @onlyCPU
854    def test_std_dim(self, device):
855        for unbiased in [False, True]:
856            self._test_dim_ops(
857                lambda t, d: t.std(d, unbiased=unbiased),
858                lambda n, d: n.std(d, ddof=1 if unbiased else 0),
859                use_integral=False)
860
861    @onlyCPU
862    def test_var_dim(self, device):
863        for unbiased in [False, True]:
864            self._test_dim_ops(
865                lambda t, d: t.var(d, unbiased=unbiased),
866                lambda n, d: n.var(d, ddof=1 if unbiased else 0),
867                use_integral=False)
868
869    @onlyCPU
870    @skipIfNoSciPy
871    def test_logsumexp_dim(self, device):
872        from scipy.special import logsumexp
873        self._test_dim_ops(
874            lambda t, d: t.logsumexp(d),
875            lambda n, d: logsumexp(n, d),
876            use_integral=False)
877
878    @onlyCPU
879    def test_mean_int_with_optdtype(self, device):
880        a = make_tensor((3, 4, 5), dtype=torch.int64, device=device)
881
882        # If the optional output dtype is given, the input is cast
883        # to it internally before the reduction.
884        a_float = a.to(torch.float32)
885        self.assertEqual(a_float.mean(), a.mean(dtype=torch.float32))
886
887    # TODO: update this and tests that use it to handle device properly
888    def _test_reduce_integer_upcast(self, fn, has_out=True, test_complex=True):
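        # By default, integer (and bool) inputs to sum/prod-style reductions are upcast to
        # int64; e.g. torch.sum(torch.ones(3, dtype=torch.int16)).dtype is torch.int64.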
889        shape = (3, 4, 5)
890        reduced_shape = fn(torch.ones(shape)).shape
891
892        def _test_out(dtype, other_dtype):
893            out = torch.ones(reduced_shape, dtype=dtype)
894            result = fn(x, out=out)
895            self.assertIs(out.dtype, result.dtype)
896            self.assertEqual(fn(x.to(dtype)), result, exact_dtype=False)
897            result = fn(x, out=out, dtype=dtype)
898            self.assertIs(out.dtype, result.dtype)
899            self.assertEqual(fn(x.to(dtype)), result, exact_dtype=False)
900            # passing a dtype that conflicts with out's dtype should raise an error
901            self.assertRaises(RuntimeError, lambda: fn(x, out=out, dtype=other_dtype))
902
903        for dtype in [dtype for dtype in get_all_math_dtypes('cpu') if dtype != torch.float16]:
904            x = torch.ones(shape, dtype=dtype)
905            expected_dtype = dtype if dtype.is_floating_point or dtype.is_complex else torch.int64
906            self.assertIs(expected_dtype, fn(x).dtype)
907            self.assertEqual(fn(x.to(expected_dtype)), fn(x))
908
909            if dtype.is_floating_point:
910                other_dtype = torch.float32 if dtype == torch.float64 else torch.float64
911            elif dtype.is_complex:
912                other_dtype = torch.complex64 if dtype == torch.complex128 else torch.complex128
913            else:
914                other_dtype = torch.int32 if dtype != torch.int32 else torch.int16
915            self.assertIs(other_dtype, fn(x, dtype=other_dtype).dtype)
916            self.assertEqual(fn(x.to(other_dtype)), fn(x, dtype=other_dtype), exact_dtype=False)
917
918            # test mixed int/float/complex
919            if dtype.is_floating_point:
920                mixed_dtypes = [torch.int32, torch.complex64]
921            elif dtype.is_complex:
922                mixed_dtypes = [torch.int32, torch.float32]
923            else:
924                mixed_dtypes = [torch.float32, torch.complex64]
925
926            for mixed_dtype in mixed_dtypes:
927                self.assertIs(mixed_dtype, fn(x, dtype=mixed_dtype).dtype)
928                self.assertEqual(fn(x.to(mixed_dtype)), fn(x, dtype=mixed_dtype), exact_dtype=False)
929
930                if has_out:
931                    _test_out(dtype, other_dtype)
932                    _test_out(dtype, mixed_dtype)
933
934    @onlyCPU
935    def test_sum_integer_upcast(self, device):
936        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, **kwargs), False)
937        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, 0, **kwargs))
938
939    @onlyCPU
940    def test_prod_integer_upcast(self, device):
941        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, **kwargs), False)
942        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, 0, **kwargs))
943
944    @onlyCPU
945    def test_cumsum_integer_upcast(self, device):
946        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumsum(x, 0, **kwargs))
947
948    @onlyCPU
949    def test_cumprod_integer_upcast(self, device):
950        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumprod(x, 0, **kwargs))
951
952    @dtypes(*all_types())
953    def test_mode(self, device, dtype):
954        SIZE = 10
955        x = torch.arange(1., SIZE * SIZE + 1, device=device, dtype=dtype).clone().resize_(SIZE, SIZE)
956        x[:2] = 1
957        x[:, :2] = 1
958        x0 = x.clone()
959
960        # Pre-calculated results.
961        res1val = torch.ones(SIZE, device=device, dtype=dtype)
962        # The indices are the position of the last appearance of the mode element.
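        # Rows 0 and 1 are all ones, so the last appearance of the mode is at column
        # SIZE - 1; in the remaining rows only the first two entries are ones, so the
        # last appearance is at index 1.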
963        res1ind = torch.ones(SIZE, device=device, dtype=torch.long)
964        res1ind[0] = SIZE - 1
965        res1ind[1] = SIZE - 1
966
967        res2val, res2ind = torch.mode(x, keepdim=False)
968        self.assertEqual(res1val, res2val, atol=0, rtol=0)
969        self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
970
971        # Test use of result tensor
972        res2val = torch.tensor((), device=device, dtype=dtype)
973        res2ind = torch.tensor((), device=device, dtype=torch.long)
974        torch.mode(x, keepdim=False, out=(res2val, res2ind))
975        self.assertEqual(res1val, res2val, atol=0, rtol=0)
976        self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
977
978        # Test non-default dim
979        res2val, res2ind = torch.mode(x, 0, False)
980        self.assertEqual(res1val, res2val, atol=0, rtol=0)
981        self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
982
983        # input unchanged
984        self.assertEqual(x, x0, atol=0, rtol=0)
985
986    def _test_mode_intervals(self, shape, intervals, device, dtype, v=1):
987        x = torch.arange(0, shape[1], device=device, dtype=dtype).expand(shape)
988        x = x.contiguous()
989        x[:, v] = intervals[0][0]
990
991        # Set the value of each interval to the mode "v"
992        for (beg, end) in intervals:
993            x[:, beg:end] = v
994
995        values, indices = torch.mode(x, -1, False)
996
997        # Check whether the returned indices correspond to the returned values
998        self.assertTrue((x.gather(1, indices.unsqueeze(1)).t() == values).all())
999        # Check whether the returned values are the mode
1000        self.assertTrue((values == v).all().item())
1001
1002    @onlyCUDA
1003    @dtypes(*all_types_and(torch.half, torch.bfloat16))
1004    def test_mode_large(self, device, dtype):
1005        # i should be less than (d - 2) / 2
1006        def testset_for_shape(shape, i):
1007            d = shape[-1]
1008            # Mode only in the middle.
1009            self._test_mode_intervals(shape, [(i, d - i)], device, dtype)
1010            # Mode in discontiguous parts of the input.
1011            self._test_mode_intervals(shape, [(0, i), (i + 1, d - i - 1), (d - i, d)], device, dtype)
1012
1013        # More than one line of (65535) thread blocks
1014        testset_for_shape((65536, 10), 3)
1015
1016        # Max slice size (2048)
1017        testset_for_shape((10, 2048), 10)
1018
1019        # Naive kernel for big slice sizes (> 2048)
1020        testset_for_shape((10, 4096), 10)
1021
1022    def test_mode_boolean(self, device):
1023        shapes = [
1024            (10, 10),
1025            (4, 2048),
1026            (1, 4096),
1027        ]
1028
1029        for shape in shapes:
1030            a = torch.zeros(shape, device=device, dtype=torch.bool)
1031
1032            a[:, (shape[1] - 1) // 2:] = True
1033            values, indices = a.mode(-1)
1034            self.assertEqual(values, torch.ones(shape[0], dtype=torch.bool))
1036            indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1)
1037            self.assertEqual(values, indexed)
1038
1039            a.fill_(False)
1040            a[:, shape[1] // 2 + 1:] = True
1041            values, indices = a.mode(-1)
1043            self.assertEqual(values, torch.zeros(shape[0], dtype=torch.bool))
1044            indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1)
1045            self.assertEqual(values, indexed)
1046
1047
1048    @expectedFailureMeta  # mode only supports CPU and CUDA device type
1049    @onlyNativeDeviceTypes
1050    def test_mode_wrong_dtype(self, device):
1051        def test_for_dtypes(x_ty, v_ty, i_ty, message):
1052            x = torch.ones(10, device=device, dtype=x_ty)
1053            v = torch.ones(10, device=device, dtype=v_ty)
1054            i = torch.ones(10, device=device, dtype=i_ty)
1055
1056            with self.assertRaisesRegex(RuntimeError, message):
1057                torch.mode(x, -1, True, out=(v, i))
1058
1059        err_msg = "expected scalar type .* but got .* for "
1060        values_err = err_msg + "values"
1061        indices_err = err_msg + "indices"
1062
1063        test_for_dtypes(torch.uint8, torch.int8, torch.long, values_err)
1064        test_for_dtypes(torch.int8, torch.int16, torch.long, values_err)
1065        test_for_dtypes(torch.int32, torch.float32, torch.long, values_err)
1066        test_for_dtypes(torch.float32, torch.float64, torch.long, values_err)
1067
1068        test_for_dtypes(torch.uint8, torch.uint8, torch.int8, indices_err)
1069        test_for_dtypes(torch.int8, torch.int8, torch.int16, indices_err)
1070        test_for_dtypes(torch.int32, torch.int32, torch.float32, indices_err)
1071        test_for_dtypes(torch.float32, torch.float32, torch.float64, indices_err)
1072
1073    @onlyCUDA
1074    def test_mode_wrong_device(self, device):
1075        # CPU Input Tensor
1076        x = torch.ones(2)
1077
1078        with self.assertRaisesRegex(RuntimeError,
1079                                    "expected device .* but got .* for values"):
1080            values = torch.tensor([], device=device)
1081            torch.mode(x, -1, True, out=(values, torch.tensor([], dtype=torch.long)))
1082
1083        with self.assertRaisesRegex(RuntimeError,
1084                                    "expected device .* but got .* for indices"):
1085            indices = torch.tensor([], device=device)
1086            torch.mode(x, -1, True, out=(torch.tensor([]), indices))
1087
1088    # TODO: make work on CUDA, too
1089    @onlyCPU
1090    def test_accreal_type(self, device) -> None:
1091        x = torch.ones(2, 3, 4)
1092        self.assertIsInstance(x.double().sum().item(), float)
1093        self.assertIsInstance(x.float().sum().item(), float)
1094        self.assertIsInstance(x.long().sum().item(), int)
1095        self.assertIsInstance(x.int().sum().item(), int)
1096        self.assertIsInstance(x.short().sum().item(), int)
1097        self.assertIsInstance(x.char().sum().item(), int)
1098        self.assertIsInstance(x.byte().sum().item(), int)
1099
1100    def test_var_mean_some_dims(self, device):
1101        sizes = (4, 6, 7, 5, 3)
1102        dims = len(sizes)
1103
1104        x = torch.rand(sizes, device=device)
1105        for num_of_dims in range(2, dims):
1106            dim_list = list(combinations(list(range(dims)), r=num_of_dims))
1107            for dim in dim_list:
1108                for unbiased in [False, True]:
1109                    for keepdim in [False, True]:
1110                        var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
1111                        var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim)
1112                        mean2 = x.mean(dim=dim, keepdim=keepdim)
1113                        self.assertEqual(var1, var2)
1114                        self.assertEqual(mean1, mean2)
1115
1116    # TODO: this should be a generic opinfo test
1117    def test_all_any_empty(self, device):
1118        x = torch.ByteTensor().to(device)
1119        self.assertTrue(x.all())
1120        self.assertFalse(x.any())
1121
1122        x = torch.BoolTensor().to(device)
1123        self.assertTrue(x.all())
1124        self.assertFalse(x.any())
1125
1126    def test_all_issue117215(self, device):
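        # For non-bool inputs, all() treats any nonzero value as True, so reducing the
        # uint8 tensor directly must agree with casting to bool first and then reducing.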
1127        info = torch.iinfo(torch.uint8)
1128        a = torch.randint(info.min, info.max, (73, 11, 3, 17), dtype=torch.uint8)
1129        b = torch.all(a, dim=0)
1130        c = a.to(torch.bool).all(dim=0)
1131        self.assertEqual(torch.ne(b, c).sum(), 0)
1132
1133    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
1134    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
1135    def test_max_with_inf(self, device, dtype):
1136        a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
1137        self.assertTrue(torch.all(torch.max(a, dim=1).values == inf).item())
1138        self.assertTrue(torch.all(torch.amax(a, dim=1) == inf).item())
1139        self.assertTrue(torch.max(a).item() == inf)
1140        self.assertTrue(torch.amax(a).item() == inf)
1141
1142    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
1143    @dtypes(torch.half, torch.float, torch.bfloat16, torch.double)
1144    def test_min_with_inf(self, device, dtype):
1145        a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
1146        self.assertTrue(torch.all(torch.min(a, dim=1).values == (-inf)).item())
1147        self.assertTrue(torch.all(torch.amin(a, dim=1) == (-inf)).item())
1148        self.assertTrue(torch.min(a).item() == -inf)
1149        self.assertTrue(torch.amin(a).item() == -inf)
1150
1151    def _test_minmax_helper(self, torchfn, reffn, device, dtype, skip_indices=False):
1152        def create_input(shape, device, dtype):
1153            if dtype.is_floating_point:
1154                return torch.randn(*shape, device=device, dtype=dtype)
1155            else:
1156                low = 0 if dtype == torch.bool else -1000
1157                high = 2 if dtype == torch.bool else 1000
1158                return torch.randint(low, high, shape, device=device, dtype=dtype)
1159        x = create_input((100, 100), device, dtype)
1160        self.compare_with_numpy(torchfn, reffn, x)
1161        # non contiguous
1162        x = create_input((10, 10, 10), device, dtype)
1163        x = x[:, 4]
1164        self.compare_with_numpy(torchfn, reffn, x)
1165
1166        def get_values(x):
1167            if isinstance(x, tuple):
1168                return x[0]
1169            return x
1170
1171        # indices
1172        if not skip_indices:
1173            size = 5
1174            x = create_input((size, size), device, dtype)
1175            inputs = (x, x.t())
1176            dims = (0, 1)
1177            for xinp, d in product(inputs, dims):
1178                self.compare_with_numpy(lambda x: get_values(torchfn(x, d, False)), lambda x: reffn(x, d, keepdims=False), xinp)
1179                result = torchfn(xinp, d, False)
1180                if isinstance(result, tuple):
1181                    v, i = result
1182                    if d == 1:
1183                        self.assertEqual(xinp[torch.arange(size), i], v, atol=0, rtol=0)
1184                    else:
1185                        self.assertEqual(xinp[i, torch.arange(size)], v, atol=0, rtol=0)
1186        # nan
1187        if dtype.is_floating_point:
1188            for index in (0, 4, 99):
1189                x = create_input((100,), device, dtype)
1190                x[index] = nan
1191                if not skip_indices:
1192                    result = torchfn(x, 0)
1193                    v = get_values(result)
1194                    self.assertEqual(v, nan)
1195                    if isinstance(result, tuple):
1196                        i = result[1]
1197                        self.assertEqual(i, index)
1198                self.assertEqual(torchfn(x), nan)
1199
1200    @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool, torch.half)
1201    @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool)
1202    @dtypes(torch.half, torch.float, torch.double)
1203    def test_max(self, device, dtype):
1204        self._test_minmax_helper(torch.max, np.amax, device, dtype)
1205
1206    @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool, torch.half)
1207    @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool)
1208    @dtypes(torch.half, torch.float, torch.double)
1209    def test_min(self, device, dtype):
1210        self._test_minmax_helper(torch.min, np.amin, device, dtype)
1211
1212    @dtypesIfCPU(torch.half, torch.float, torch.double, torch.int, torch.long, torch.bool)
1213    @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool)
1214    @dtypes(torch.half, torch.float, torch.double)
1215    def test_amin(self, device, dtype):
1216        self._test_minmax_helper(torch.amin, np.amin, device, dtype)
1217
1218    @dtypesIfCPU(torch.half, torch.float, torch.double, torch.int, torch.long, torch.bool)
1219    @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool)
1220    @dtypes(torch.float, torch.double)
1221    def test_amax(self, device, dtype):
1222        self._test_minmax_helper(torch.amax, np.amax, device, dtype)
1223
1224    @onlyNativeDeviceTypes
1225    @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
1226    @dtypesIfCUDA(torch.half, torch.float, torch.bfloat16)
1227    def test_aminmax(self, device, dtype):
1228
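        # torch.aminmax returns a (min, max) pair; these wrappers select one
        # component so the shared min/max helper can be reused for both results.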
1229        def _amin_wrapper(x, dim=None, keepdims=False):
1230            return torch.aminmax(x, dim=dim, keepdim=keepdims)[0]
1231
1232        def _amax_wrapper(x, dim=None, keepdims=False):
1233            return torch.aminmax(x, dim=dim, keepdim=keepdims)[1]
1234
1235        self._test_minmax_helper(_amin_wrapper, np.amin, device, dtype)
1236        self._test_minmax_helper(_amax_wrapper, np.amax, device, dtype)
1237
1238    @onlyNativeDeviceTypes
1239    @dtypes(*complex_types())
1240    def test_invalid_0dim_aminmax(self, device, dtype):
1241        with self.assertRaisesRegex(RuntimeError, 'not implemented'):
1242            torch.aminmax(torch.tensor(1., dtype=dtype, device=device), dim=0)
1243
1244    # TODO: bincount isn't a classic reduction -- maybe this test suite should
1245    #   cover reductions and summary ops?
1246    def test_bincount(self, device):
1247        # negative input throws
1248        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
1249            torch.bincount(torch.tensor([1, -1], device=device))
1250        # n-d input, with n > 1 throws
1251        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
1252            torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device))
1253        # floating input type throws
1254        with self.assertRaisesRegex(RuntimeError, 'not implemented'):
1255            torch.bincount(torch.tensor([1., 0.3], device=device))
1256        # minlength < 0 throws
1257        with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'):
1258            torch.bincount(torch.tensor([1, 3], device=device),
1259                           torch.tensor([.2, .2], device=device),
1260                           minlength=-1)
1261        # n-d weights, with n > 1 throws
1262        with self.assertRaisesRegex(RuntimeError, '1-d'):
1263            torch.bincount(torch.tensor([1, 0], device=device),
1264                           torch.tensor([[1., 0.3], [1., 0.3]], device=device))
1265        # input and weights dim mismatch
1266        with self.assertRaisesRegex(RuntimeError, 'same length'):
1267            torch.bincount(torch.tensor([1, 0], device=device),
1268                           torch.tensor([1., 0.3, 0.5], device=device))
1269        # 1-d input with no elements and default minlength
1270        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)),
1271                         torch.zeros(0, dtype=torch.long, device=device))
1272        # 1-d input with no elements and specified minlength
1273        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10),
1274                         torch.zeros(10, dtype=torch.long, device=device))
1275
1276        # test tensor method without weights
1277        long_counts = torch.tensor(
1278            [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount()
1279        self.assertEqual(
1280            torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device),
1281            long_counts)
1282        # test avoiding overflow for uint8 (#76979)
1283        count_uint8 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.uint8, device=device).bincount()
1284        count_int16 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.int16, device=device).bincount()
1285        self.assertEqual(count_uint8, count_int16)
1286        # test minlength functionality
1287        int_counts = torch.bincount(
1288            torch.tensor([1, 1, 1, 1], device=device), minlength=5)
1289        self.assertEqual(
1290            torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device),
1291            int_counts)
1292        # test weights
1293        byte_counts = torch.bincount(
1294            torch.tensor([0, 1, 1, 1, 4], device=device),
1295            torch.tensor([.1, .2, .3, .4, .5], device=device))
1296        self.assertEqual(
1297            torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts)
1298        byte_counts = torch.bincount(
1299            torch.tensor([0, 1, 1, 1, 4], device=device),
1300            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
1301        self.assertEqual(
1302            torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.float64), byte_counts)
1303        # test non-contiguous inputs and weights
1304        inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device)
1305        weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
1306        for i in [0, 1]:
1307            assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous"
1308            assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous"
1309        # inputs are non-contiguous but weights are contiguous
1310        self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
1311        # inputs and weights are non-contiguous
1312        self.assertEqual(
1313            inputs[:, 1].bincount(weights[:, 1]),
1314            torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
1315        # weights are non-contiguous but inputs are contiguous
1316        self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
1317                         torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
1318
1319        # test bincount on non-contiguous slices
1320        all0s = torch.zeros((32, 2), dtype=torch.int64, device=device)
1321        self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32]))
1322
1323        all1s = torch.ones((32, 2), dtype=torch.int64, device=device)
1324        self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32]))
1325
1326        # test large number of bins - global memory use
1327        big_exp = torch.zeros(10000000, device=device)
1328        big_exp[-1] = 50.0
1329        big_w = torch.tensor([.5] * 100, device=device)
1330        big_out = torch.tensor([9999999] * 100, device=device).bincount(big_w)
1331        self.assertEqual(big_exp, big_out)
1332        # test large input size
1333        big_exp = torch.zeros(2, device=device, dtype=torch.int64)
1334        big_exp[1] = 1000000
1335        big_out = torch.ones(1000000, dtype=torch.int8, device=device).bincount()
1336        self.assertEqual(big_exp, big_out)
1337
1338    # TODO: how many var stability tests are there?
1339    def test_var_stability2(self, device):
1340        tensor = torch.FloatTensor([2281.5, 2281.25]).to(device)
1341
1342        # Stability for inner dim
1343        self.assertEqual(tensor.var(0), 0.03125)
1344
1345        # General stability
1346        self.assertEqual(tensor.var(), 0.03125)
1347
1348        # Stability for outer dimensions
1349        tensor = tensor.unsqueeze(1)
1350        self.assertEqual(tensor.var(0), 0.03125)
1351
1352    @onlyCPU
1353    @dtypes(torch.bfloat16, torch.float16)
1354    def test_sum_noncontig_lowp(self, device, dtype) -> None:
1355        dim_sequences = {
1356            2: [0, 1],
1357            3: [0, 1, 2],
1358            4: [0, 1, 2, 3],
1359            5: [0, 1, 2, 3, 4],
1360        }
1361
1362        def create_noncontig_inputs(x, ndim):
1363            if ndim == 2:
1364                return x[::2, ::2]
1365            elif ndim == 3:
1366                return x[::2, ::2, ::2]
1367            elif ndim == 4:
1368                return x[::2, ::2, ::2, ::2]
1369            elif ndim == 5:
1370                return x[::2, ::2, ::2, ::2, ::2]
1371
1372        def helper(self, shape, reduce_dims, device, dtype):
1373            for permute_list in list(permutations(dim_sequences[len(shape)], len(shape))):
1374                x = torch.ones(shape, device=device, dtype=dtype)
1375                x = create_noncontig_inputs(x, len(shape))
1376                x_trans = x.permute(permute_list)
1377                x_sum = torch.sum(x_trans, reduce_dims)
1378                x_trans_ref = x_trans.float()
1379                x_sum_ref = torch.sum(x_trans_ref, reduce_dims)
1380                self.assertEqual(x_sum, x_sum_ref.to(dtype=dtype))
1381
1382        shapes = [
1383            (50, 50),
1384            (50, 50, 50),
1385            (10, 50, 30, 30),
1386            (10, 5, 10, 50, 7),
1387        ]
1388
1389        for shape in shapes:
1390            for i in range(1, len(shape) + 1):
1391                reduce_dims = list(combinations(dim_sequences[len(shape)], i))
1392                for reduce_dim in reduce_dims:
1393                    helper(self, shape, reduce_dim, device, dtype)
1394
1395
1396    @onlyCPU
1397    @dtypes(torch.bool, torch.double)
1398    def test_sum_all(self, device, dtype) -> None:
1399        def check_sum_all(tensor: torch.Tensor) -> None:
1400            pylist = tensor.reshape(-1).tolist()
1401            self.assertEqual(tensor.sum(), sum(pylist))
1402
1403        if dtype != torch.bool:
1404            check_sum_all(torch.tensor([1, 2, 3, 4, 5], dtype=dtype, device=device))
1405            check_sum_all(torch.randn(200000, dtype=dtype, device=device))
1406            check_sum_all(torch.randn(2000, 2, dtype=dtype, device=device)[:, 0])
1407        else:
1408            check_sum_all(torch.tensor([True, False, True], dtype=torch.bool, device=device))
1409
1410    def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
1411                                            memory_format, compare_data=True, default_is_preserve=False):
1412
1413        assert memory_format == torch.channels_last or memory_format == torch.channels_last_3d
1414
1415        # xc is a channels last tensor
1416        xc = input_generator_fn(device)
1417        # xc is not memory dense, but looks like channels last
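        # (slicing every other element keeps the channels-last stride ordering
        #  but leaves gaps in memory, so xc is not contiguous in any memory format)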
1418        if memory_format == torch.channels_last:
1419            xc = xc[..., ::2, ::2]
1420        else:
1421            xc = xc[..., ::2, ::2, ::2]
1422
1423        clone = transformation_fn(xc, memory_format=torch.preserve_format)
1424        self.assertFalse(clone.is_contiguous())
1425        self.assertTrue(clone.is_contiguous(memory_format=memory_format))
1426        self.assertFalse(xc.is_contiguous())
1427        self.assertFalse(xc.is_contiguous(memory_format=memory_format))
1428        if compare_data:
1429            self.assertEqual(xc, clone.to(xc))
1430
1431        xc = input_generator_fn(device)
1432        clone = transformation_fn(xc, memory_format=torch.contiguous_format)
1433        self.assertTrue(clone.is_contiguous())
1434        self.assertFalse(clone.is_contiguous(memory_format=memory_format))
1435        if compare_data:
1436            self.assertEqual(xc, clone.to(xc))
1437
1438        xc = input_generator_fn(device)
1439        clone = transformation_fn(xc)
1440
1441        if default_is_preserve:
1442            self.assertFalse(clone.is_contiguous())
1443            self.assertTrue(clone.is_contiguous(memory_format=memory_format))
1444        else:
1445            self.assertTrue(clone.is_contiguous())
1446            self.assertFalse(clone.is_contiguous(memory_format=memory_format))
1447        if compare_data:
1448            self.assertEqual(xc, clone.to(xc))
1449
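        # For an arbitrarily permuted tensor, preserve_format should leave the
        # input's strides unchanged in the transformed output.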
1450        x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device)
1451        for _ in range(10):
1452            permutation = list(range(len(x.shape)))
1453            random.shuffle(permutation)
1454            x = x.permute(permutation)
1455            self.assertEqual(x.stride(), transformation_fn(x, memory_format=torch.preserve_format).stride())
1456
1457    @onlyCPU
1458    @dtypes(torch.double)
1459    def test_sum_out(self, device, dtype: torch.dtype) -> None:
1460        x = torch.rand(100, 100, dtype=dtype, device=device)
1461        res1 = torch.sum(x, 1)
1462        res2 = torch.tensor((), dtype=dtype, device=device)
1463        torch.sum(x, 1, out=res2)
1464        self.assertEqual(res1, res2)
1465        x = torch.rand(100, 100, 100, dtype=dtype, device=device)
1466        res1 = x.sum(2).sum(1)
1467        res2 = torch.tensor((), dtype=dtype, device=device)
1468        torch.sum(x, (2, 1), out=res2)
1469        self.assertEqual(res1, res2)
1470
1471    @onlyCUDA
1472    @dtypes(torch.float16, torch.float32)
1473    def test_prod_gpu(self, device, dtype):
1474        x = torch.tensor([2, 3, 6, 9, 8], dtype=dtype, device=device)
1475
1476        # Check all combinations: fp16 input - fp16 output, fp16 input - fp32
1477        # output, fp32 input - fp16 output, fp32 input - fp32 output
1478        for dtype_output in [torch.float16, torch.float32]:
1479            result_expected = torch.tensor(2592, dtype=dtype_output, device=device)
1480            output = torch.prod(x, dtype=dtype_output)
1481            self.assertEqual(output, result_expected)
1482
1483            output = x.prod(dtype=dtype_output)
1484            self.assertEqual(output, result_expected)
1485
1486    @onlyCPU
1487    @dtypes(torch.float)
1488    def test_prod(self, device, dtype):
1489        x = torch.rand(100, 100, dtype=dtype, device=device)
1490        res1 = torch.prod(x, 1)
1491        res2 = torch.tensor((), dtype=dtype, device=device)
1492        torch.prod(x, 1, out=res2)
1493        self.assertEqual(res1, res2)
1494
1495    @onlyCPU
1496    @dtypes(torch.float16, torch.bfloat16)
1497    def test_prod_lowp(self, device, dtype):
1498        x = torch.rand(100, 100, dtype=dtype, device=device)
1499        x_ref = x.float()
1500        res1 = torch.prod(x, 1)
1501        res2 = torch.prod(x_ref, 1)
1502        self.assertEqual(res1, res2.to(dtype=dtype))
1503        res1 = torch.prod(x, 0)
1504        res2 = torch.prod(x_ref, 0)
1505        self.assertEqual(res1, res2.to(dtype=dtype))
1506
1507    def test_prod_bool(self, device):
1508        vals = [
1509            [True, True],
1510            [True, False],
1511            [False, False],
1512            [],
1513            [False] * 256,  # https://github.com/pytorch/pytorch/issues/127866
1514        ]
1515        for val in vals:
1516            result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item()
1517            expect = np.prod(np.array(val), dtype=bool)
1518            self.assertEqual(result, expect)
1519
1520            result = torch.prod(torch.tensor(val, device=device)).item()
1521            expect = np.prod(np.array(val))
1522            self.assertEqual(result, expect)
1523
1524    @onlyCPU
1525    def test_max_mixed_devices(self, device):
1526        a = torch.randn(10, device=device)
1527        if torch.cuda.is_available():
1528            values = torch.randn(10).cuda()
1529            indices = torch.cuda.LongTensor()
1530            self.assertRaises(RuntimeError,
1531                              lambda: torch.max(a, 0, out=(values, indices)))
1532            self.assertRaises(RuntimeError,
1533                              lambda: torch.amax(a, 0, out=values))
1534
1535    @onlyCPU
1536    def test_min_mixed_devices(self, device):
1537        a = torch.randn(10, device=device)
1538        if torch.cuda.is_available():
1539            values = torch.randn(10).cuda()
1540            indices = torch.cuda.LongTensor()
1541            self.assertRaises(RuntimeError,
1542                              lambda: torch.min(a, 0, out=(values, indices)))
1543            self.assertRaises(RuntimeError,
1544                              lambda: torch.amin(a, 0, out=values))
1545
1546    # TODO: consider refactoring with bincount test
1547    def test_bucketization(self, device):
1548        values_1d = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9], device=device)
1549        values_3d = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device)
1550
1551        # simple 1d boundary and 3d input value
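        # With right=False (the default) bucketize returns the index of the first
        # boundary that is >= the value; with right=True, the first boundary that is
        # strictly greater (e.g. value 1 maps to index 0 below, but to 1 with right=True).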
1552        boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device)
1553        expected_result = torch.tensor([[[0, 2, 4], [1, 3, 5]], [[0, 1, 2], [3, 4, 5]]], device=device)
1554        output = torch.empty(2, 2, 3, device=device, dtype=torch.int64)
1555        self.assertEqual(torch.bucketize(values_3d, boundaries), expected_result)
1556        self.assertEqual(torch.bucketize(values_3d, boundaries, out=output), expected_result)
1557        expected_result = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device)
1558        self.assertEqual(torch.bucketize(values_3d, boundaries, right=True), expected_result)
1559        self.assertEqual(torch.bucketize(values_3d, boundaries, out=output, right=True), expected_result)
1560
1561        # simple float 1d boundary and 1d input with output int32 type
1562        for dtype in [torch.float32, torch.float16]:
1563            values_1d_float = values_1d.to(dtype)
1564            boundaries = torch.tensor([0.9, 1, 2, 2, 3, 3, 4, 4.1, 9, 9], device=device, dtype=dtype)
1565            expected_result = torch.tensor([1, 2, 4, 6, 8, 8, 8, 8, 8], device=device, dtype=torch.int32)
1566            self.assertEqual(torch.searchsorted(boundaries, values_1d_float, out_int32=True), expected_result)
1567            self.assertEqual(torch.bucketize(values_1d_float, boundaries, out_int32=True), expected_result)
1568
1569        # multiple dimension input with 0 elements
1570        boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device, dtype=torch.int64)
1571        values_0_el = torch.tensor([[[]]], device=device, dtype=torch.int64)
1572        expected_result = values_0_el.to(torch.int64)
1573        self.assertEqual(torch.searchsorted(boundaries, values_0_el), expected_result)
1574        self.assertEqual(torch.bucketize(values_0_el, boundaries), expected_result)
1575
1576        # nan input
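        # NaN compares greater than every boundary, so searchsorted/bucketize place
        # NaN values past the last boundary (index == boundaries.numel()).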
1577        values_nan = torch.tensor([1.0, float('nan'), 2.0, float('nan')], device=device, dtype=torch.float64)
1578        boundaries = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device, dtype=torch.float64)
1579        expected_result = torch.tensor([1, 4, 2, 4], device=device)
1580        self.assertEqual(torch.searchsorted(boundaries, values_nan), expected_result)
1581        expected_result = torch.tensor([2, 4, 3, 4], device=device)
1582        self.assertEqual(torch.searchsorted(boundaries, values_nan, right=True), expected_result)
1583        self.assertEqual(torch.searchsorted(boundaries, values_nan, side='right'), expected_result)
1584
1585        # type promotion and non-contiguous tensors
1586        values_3d_permute = values_3d.permute(2, 1, 0).to(torch.int32)
1587        boundaries_permute = values_3d.permute(2, 1, 0).to(torch.float64)
1588        expected_result = torch.tensor([[[0, 0], [0, 1]], [[2, 0], [0, 1]], [[2, 0], [0, 0]]], device=device)
1589        if self.device_type != 'xla':
1590            self.assertWarnsRegex(
1591                UserWarning, "tensor is non-contiguous",
1592                lambda: self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result))
1593        else:
1594            # All tensors in XLA are contiguous even after a permute, so no warning message is generated on XLA
1595            self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result)
1596
1597        # scalar type
1598        boundaries = torch.tensor([1.5, 2.5, 3.5], device=device)
1599        expected_result = torch.tensor(1, device=device)
1600        self.assertEqual(torch.searchsorted(boundaries, 2), expected_result)
1601        self.assertEqual(torch.bucketize(torch.tensor(2, device=device), boundaries), expected_result)
1602        expected_result = torch.tensor(3, device=device)
1603        scalar_tensor_nan = torch.tensor(float('nan'), device=device)
1604        self.assertEqual(torch.searchsorted(boundaries, scalar_tensor_nan), expected_result)
1605        self.assertEqual(torch.bucketize(float('nan'), boundaries, right=True), expected_result)
1606
1607        # invalid input dimensions
1608        boundaries = torch.tensor([[1, 2, 3], [4, 5, 6]], device=device)
1609        with self.assertRaisesRegex(
1610                RuntimeError, "first N-1 dimensions of boundaries tensor and input value tensor must match"):
1611            torch.searchsorted(boundaries, values_3d)
1612        with self.assertRaisesRegex(
1613                RuntimeError, "boundaries tensor must be 1 dimension"):
1614            torch.bucketize(values_3d, boundaries)
1615        with self.assertRaisesRegex(
1616                RuntimeError, "only when boundaries tensor dimension is 1"):
1617            torch.searchsorted(boundaries, 1)
1618
1619        # incompatible output tensor dtype
1620        def test_output_dtype(dtype, is_int32):
1621            output = values_1d.to(dtype)
1622            with self.assertRaisesRegex(
1623                    RuntimeError, "output tensor's dtype is wrong"):
1624                torch.searchsorted(values_1d, values_1d, out=output, out_int32=is_int32)
1625
1626        test_output_dtype(torch.float32, False)
1627        test_output_dtype(torch.int32, False)
1628        test_output_dtype(torch.int64, True)
1629
1630        # invalid side argument
1631        with self.assertRaisesRegex(RuntimeError, "side can only be 'left' or 'right'"):
1632            torch.searchsorted(values_1d, values_1d, side='bad')
1633
1634        # invalid sorter argument, wrong size
1635        with self.assertRaisesRegex(RuntimeError, "boundary and sorter must have the same size"):
1636            sequence = torch.rand_like(values_1d, dtype=torch.float)
1637            _, sorted_idx = torch.sort(sequence)
1638            torch.searchsorted(sequence, values_1d, sorter=sorted_idx[:-1])
1639
1640        # invalid sorter argument, is not dtype long
1641        with self.assertRaisesRegex(RuntimeError, "sorter must be a tensor of long dtype"):
1642            sequence = torch.rand_like(values_1d, dtype=torch.float)
1643            _, sorted_idx = torch.sort(sequence)
1644            torch.searchsorted(sequence, values_1d, sorter=sorted_idx.to(torch.float32))
1645
1646        # invalid sorter value, out of bound (>= innermost size)
1647        with self.assertRaisesRegex(RuntimeError, "sorter index out of range"):
1648            torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([0, 1, 3]))
1649
1650        # invalid sorter value, out of bound (< 0)
1651        with self.assertRaisesRegex(RuntimeError, "sorter index out of range"):
1652            torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([-1, 1, 2]))
1653
1654        # scalar type bfloat16
1655        if self.device_type == 'cpu':
1656            def test_dtype_bfloat16(values_bf16=False, boundaries_bf16=False):
1657                values_1d_float = values_1d.to(torch.float32)
1658                boundaries = torch.tensor([0.9, 1, 2, 2, 3, 3, 4, 4.1, 9, 9], device=device, dtype=torch.float32)
1659                if values_bf16:
1660                    values_1d_float = values_1d_float.to(torch.bfloat16)
1661                if boundaries_bf16:
1662                    boundaries = boundaries.to(torch.bfloat16)
1663                expected_result = torch.tensor([1, 2, 4, 6, 8, 8, 8, 8, 8], device=device, dtype=torch.int32)
1664                self.assertEqual(torch.bucketize(values_1d_float, boundaries, out_int32=True), expected_result)
1665
1666            test_dtype_bfloat16(True, False)
1667            test_dtype_bfloat16(False, True)
1668            test_dtype_bfloat16(True, True)
1669
1670    @dtypes(*all_types_and(torch.half, torch.bfloat16))
1671    def test_nansum(self, device, dtype):
1672        args = product(
1673            (True, False),  # noncontiguous
1674            (0, 1, None),   # dim
1675        )
1676        zero = torch.zeros((), device=device, dtype=dtype)
1677
1678        for noncontiguous, dim in args:
1679            # Randomly scale the values
1680            scale = random.randint(10, 100)
1681            x = make_tensor((17, 17), device=device, dtype=dtype,
1682                            low=-scale, high=scale, noncontiguous=noncontiguous)
1683
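            # nansum should treat NaN entries as zero, so the reference result is a
            # plain sum over a copy with the would-be NaN positions zeroed out.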
1684            if dtype.is_floating_point:
1685                nan_mask = x < 0.2 * scale
1686                x_nonan = torch.where(nan_mask, zero, x)
1687                x[nan_mask] = np.nan
1688            else:
1689                x_nonan = x
1690
1691            dim_kwargs = {} if dim is None else {"dim": dim}
1692            expect = torch.sum(x_nonan, **dim_kwargs)
1693            actual = torch.nansum(x, **dim_kwargs)
1694            self.assertEqual(expect, actual)
1695
1696    def _test_reduction_function_with_numpy(self, torch_func, np_func, device, dtype,
1697                                            with_extremal=False, atol=None, rtol=None,
1698                                            exact_dtype=True, with_keepdim=False):
1699        # Test 0-d to 3-d tensors.
1700        for ndims in range(0, 4):
1701            shape = _rand_shape(ndims, min_size=5, max_size=10)
1702            for n in range(ndims + 1):
1703                for c in combinations(list(range(ndims)), n):
1704                    for count_dim in permutations(c):
1705                        # Generate Input.
1706                        x = _generate_input(shape, dtype, device, with_extremal)
1707
1708                        if count_dim == ():
1709                            # Default `dims=None` case
1710                            self.compare_with_numpy(torch_func, np_func, x, device=None, dtype=None,
1711                                                    atol=atol, rtol=rtol, exact_dtype=exact_dtype)
1712                        else:
1713                            # With `dims: tuple of ints` case
1714                            if with_keepdim:
1715                                torch_func_partial = partial(torch_func, keepdim=True, dim=count_dim)
1716                                np_func_partial = partial(np_func, keepdims=True, axis=count_dim)
1717                            else:
1718                                torch_func_partial = partial(torch_func, dim=count_dim)
1719                                np_func_partial = partial(np_func, axis=count_dim)
1720                            self.compare_with_numpy(torch_func_partial, np_func_partial, x, device=None, dtype=None,
1721                                                    atol=atol, rtol=rtol, exact_dtype=exact_dtype)
1722
1723    @dtypes(*all_types_and_complex_and(torch.half))
1724    def test_count_nonzero(self, device, dtype):
1725        self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype)
1726        self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype, True)
1727
1728    # TODO: Investigate why the output is not close to numpy.
1729    def _get_relaxed_tolerances_for(self, dtype):
1730        if dtype == torch.float16:
1731            atol = 0.4
1732            rtol = 1e-2
1733        elif dtype == torch.float32:
1734            atol = 7e-05
1735            rtol = 3e-06
1736        else:
1737            # Default values
1738            atol = None
1739            rtol = None
1740        return atol, rtol
1741
1742    def _test_sum_reduction_vs_numpy(self, torch_fn, np_fn, device, dtype, with_keepdim=False, with_extremal=False):
1743        def is_integral(dtype):
1744            return dtype in integral_types()
1745
1746        exact_dtype = True
1747        # On Windows CI, the current version of `numpy` promotes all lower integer
1748        # dtypes to int32 while `torch` promotes them to int64. Hence we skip checking
1749        # the exact dtype.
1750        # Reference : https://dr.pytorch.org/api/view-log-full?build_id=122051580
1751        # PR : https://github.com/pytorch/pytorch/pull/38628#issuecomment-655905370
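        # (e.g. on such a numpy build, np.sum over an int8 array returns int32,
        #  whereas torch.sum over an int8 tensor returns int64)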
1752        if IS_WINDOWS and is_integral(dtype):
1753            exact_dtype = False
1754        # For uint8, numpy promotes to uint64 while torch promotes to int64.
1755        # So we must skip this as well.
1756        if dtype == torch.uint8:
1757            exact_dtype = False
1758
1759        # TODO: Investigate why the output is not close to numpy.
1760        atol, rtol = self._get_relaxed_tolerances_for(dtype)
1761        self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype,
1762                                                 atol=atol, rtol=rtol, exact_dtype=exact_dtype,
1763                                                 with_keepdim=with_keepdim, with_extremal=with_extremal)
1764
1765    @onlyNativeDeviceTypes
1766    @dtypes(*set(all_types_and(torch.half)) - {torch.uint8})
1767    def test_sum_vs_numpy(self, device, dtype):
1768        self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype)
1769        self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_extremal=True)
1770        self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_keepdim=True)
1771
1772    @onlyNativeDeviceTypes
1773    @dtypes(*set(all_types_and(torch.half)) - {torch.uint8})
1774    def test_nansum_vs_numpy(self, device, dtype):
1775        self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype)
1776        self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_extremal=True)
1777        self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_keepdim=True)
1778
1779    @onlyCPU
1780    @dtypes(*complex_types())
1781    def test_nansum_complex(self, device, dtype):
1782        x = torch.randn((3, 3, 3), device=device, dtype=dtype)
1783        with self.assertRaisesRegex(RuntimeError, "nansum does not support complex inputs"):
1784            torch.nansum(x)
1785
1786    @dtypes(*all_types_and(torch.half))
1787    def test_nansum_out_dtype(self, device, dtype):
1788        out_dtype = dtype
1789        inp_dtypes = all_types_and(torch.half) if out_dtype.is_floating_point else integral_types()
1790        for inp_dtype in inp_dtypes:
1791            # TODO: Investigate why the output is not close to numpy.
1792            atol, rtol = self._get_relaxed_tolerances_for(dtype)
1793            shape = _rand_shape(random.randint(2, 5), min_size=5, max_size=10)
1794            x = _generate_input(shape, inp_dtype, device, with_extremal=False)
1795            torch_fn = partial(torch.nansum, dtype=out_dtype)
1796            np_out_dtype = torch_to_numpy_dtype_dict[out_dtype]
1797            np_fn = partial(np.nansum, dtype=np_out_dtype)
1798            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None, atol=atol, rtol=rtol)
1799
1800    @dtypes(*all_types_and(torch.half))
1801    def test_argminmax_multiple(self, device, dtype):
1802        # Case: All Ones
1803        t = torch.ones(3, 3, device=device, dtype=dtype)
1804        self.compare_with_numpy(torch.argmax, np.argmax, t)
1805        self.compare_with_numpy(torch.argmin, np.argmin, t)
1806
1807        # Case: With single `nan` present.
1808        if dtype in floating_types_and(torch.half, torch.bfloat16):
1809            t[2, 2] = float('nan')
1810            self.compare_with_numpy(torch.argmax, np.argmax, t)
1811            self.compare_with_numpy(torch.argmin, np.argmin, t)
1812
1813        # Case: Randomly Generated Tensors
1814        for ndims in range(1, 5):
1815            shape = _rand_shape(ndims, min_size=5, max_size=10)
1816            for with_extremal in [False, True]:
1817                for contiguous in [False, True]:
1818                    # Generate Input.
1819                    x = _generate_input(shape, dtype, device, with_extremal)
1820
1821                    if dtype == torch.half:
1822                        max_val = torch.max(x.to(torch.float))
1823                        min_val = torch.min(x.to(torch.float))
1824                    else:
1825                        max_val = torch.max(x)
1826                        min_val = torch.min(x)
1827
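                    # Inject duplicated extremal values at random positions so that
                    # ties exist; argmax/argmin are then checked against numpy.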
1828                    mask = torch.randn(x.shape) > 0.5
1829                    x[mask] = torch.tensor(max_val + 1, dtype=dtype)
1830
1831                    mask = torch.randn(x.shape) > 0.5
1832                    x[mask] = torch.tensor(min_val - 1, dtype=dtype)
1833
1834                    if not contiguous:
1835                        x = x.T
1836
1837                    self.compare_with_numpy(torch.argmax, np.argmax, x, device=None, dtype=None)
1838                    self.compare_with_numpy(torch.argmin, np.argmin, x, device=None, dtype=None)
1839
1840                    # Verify indices returned by max and min.
1841                    if dtype != torch.half:
1842                        rand_dim = random.randint(0, ndims - 1)
1843                        self.compare_with_numpy(lambda x: torch.max(x, dim=rand_dim)[1],
1844                                                lambda x: np.argmax(x, axis=rand_dim), x, device=None, dtype=None)
1845                        self.compare_with_numpy(lambda x: torch.min(x, dim=rand_dim)[1],
1846                                                lambda x: np.argmin(x, axis=rand_dim), x, device=None, dtype=None)
1847
1848        def verify_against_numpy(t):
1849            # Argmax
1850            torch_fn = partial(torch.argmax, dim=1)
1851            np_fn = partial(np.argmax, axis=1)
1852            self.compare_with_numpy(torch_fn, np_fn, t)
1853            # Non-contiguous input
1854            self.compare_with_numpy(torch_fn, np_fn, t.T)
1855
1856            # Verify indices returned by max.
1857            if dtype != torch.half:
1858                self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x, device=None, dtype=None)
1859                self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x.T, device=None, dtype=None)
1860
1861            # Argmin
1862            torch_fn = partial(torch.argmin, dim=1)
1863            np_fn = partial(np.argmin, axis=1)
1864            self.compare_with_numpy(torch_fn, np_fn, t)
1865            # Non-contiguous input
1866            self.compare_with_numpy(torch_fn, np_fn, t.T)
1867
1868            # Verify indices returned by min.
1869            if dtype != torch.half:
1870                self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x, device=None, dtype=None)
1871                self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x.T, device=None, dtype=None)
1872
1873        # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998
1874        t = torch.tensor([[1, 5],
1875                          [2, 10],
1876                          [3, 3]], device=device, dtype=dtype)
1877        verify_against_numpy(t)
1878
1879        # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998
1880        t = torch.tensor([[1, 5],
1881                          [2, 10],
1882                          [0, 0]], device=device, dtype=dtype)
1883        verify_against_numpy(t)
1884
1885    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
1886    def test_all_any_vs_numpy(self, device, dtype):
1887        # Note [all, any uint8 compatibility]: for compatibility reasons, `all` and
1888        # `any` on `uint8` inputs return a Tensor of the same dtype `uint8`, not bool.
1889        # Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
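        # For example, torch.all(torch.tensor([1, 0], dtype=torch.uint8)) yields
        # tensor(0, dtype=torch.uint8) rather than tensor(False), which is why
        # exact_dtype is relaxed for uint8 below.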
1890        exact_dtype = True if dtype != torch.uint8 else False
1891
1892        def _test_all_any(x):
1893            self.compare_with_numpy(torch.all, np.all, x)
1894            self.compare_with_numpy(torch.any, np.any, x)
1895
1896        def _test_all_any_with_dim(x, dim):
1897            torch_fn = partial(torch.all, dim=dim)
1898            np_fn = partial(np.all, axis=dim)
1899            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)
1900
1901            torch_fn = partial(torch.any, dim=dim)
1902            np_fn = partial(np.any, axis=dim)
1903            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)
1904
1905        def _test_out_variant(x, dim):
1906            out = torch.empty_like(x)
1907            if dtype == torch.bool or dtype == torch.uint8:
1908                expected = torch.all(x, dim)
1909                torch.all(x, dim, out=out)
1910                self.assertEqual(expected, out)
1911
1912                expected = torch.any(x, dim)
1913                torch.any(x, dim, out=out)
1914                self.assertEqual(expected, out)
1915            else:
1916                with self.assertRaisesRegex(RuntimeError, "all only supports bool tensor for result, got"):
1917                    torch.all(x, dim, out=out)
1918
1919                with self.assertRaisesRegex(RuntimeError, "any only supports bool tensor for result, got"):
1920                    torch.any(x, dim, out=out)
1921
1922        def _test_all_any_with_dim_keepdim(x, dim, keepdim):
1923            torch_fn = partial(torch.all, dim=dim, keepdim=keepdim)
1924            np_fn = partial(np.all, axis=dim, keepdims=keepdim)
1925            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)
1926
1927            torch_fn = partial(torch.any, dim=dim, keepdim=keepdim)
1928            np_fn = partial(np.any, axis=dim, keepdims=keepdim)
1929            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)
1930
1931        def _test_output_dtype(x):
1932            # This test will fail once the functions return bool output
1933            # for uint8 input.
1934            expected_dtype = torch.uint8 if dtype == torch.uint8 else torch.bool
1935            self.assertEqual(torch.all(x).dtype, expected_dtype)
1936            self.assertEqual(torch.any(x).dtype, expected_dtype)
1937
1938            self.assertEqual(torch.all(x, dim=0).dtype, expected_dtype)
1939            self.assertEqual(torch.any(x, dim=0).dtype, expected_dtype)
1940
1941        for ndim in range(5):
1942            shape = _rand_shape(ndim, 1, 5)
1943            x = _generate_input(shape, dtype, device, with_extremal=False)
1944            _test_all_any(x)
1945            _test_all_any(x.T)
1946            _test_all_any(x[..., ::2])
1947
1948            x = _generate_input(shape, dtype, device, with_extremal=True)
1949            _test_all_any(x)
1950            _test_all_any(x.T)
1951            _test_all_any(x[..., ::2])
1952
1953            x = torch.zeros_like(x)
1954            _test_all_any(x)
1955            _test_all_any(x.T)
1956            _test_all_any(x[..., ::2])
1957
1958            x = torch.ones_like(x)
1959            _test_all_any(x)
1960            _test_all_any(x.T)
1961            _test_all_any(x[..., ::2])
1962            _test_output_dtype(x)
1963            for dim in range(ndim):
1964                x = _generate_input(shape, dtype, device, with_extremal=False)
1965                _test_all_any_with_dim(x, dim)
1966                _test_all_any_with_dim(x.T, dim)
1967                _test_all_any_with_dim(x[..., ::2], dim)
1968                _test_out_variant(x, dim)
1969                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
1970                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)
1971
1972                x = _generate_input(shape, dtype, device, with_extremal=True)
1973                _test_all_any_with_dim(x, dim)
1974                _test_all_any_with_dim(x.T, dim)
1975                _test_all_any_with_dim(x[..., ::2], dim)
1976                _test_out_variant(x, dim)
1977                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
1978                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)
1979
1980                x = torch.zeros_like(x)
1981                _test_all_any_with_dim(x, dim)
1982                _test_all_any_with_dim(x.T, dim)
1983                _test_all_any_with_dim(x[..., ::2], dim)
1984                _test_out_variant(x, dim)
1985                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
1986                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)
1987
1988                x = torch.ones_like(x)
1989                _test_all_any_with_dim(x, dim)
1990                _test_all_any_with_dim(x.T, dim)
1991                _test_all_any_with_dim(x[..., ::2], dim)
1992                _test_out_variant(x, dim)
1993                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
1994                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)
1995
1996    # TODO: part of this test covers torch.norm, which should be covered by test_linalg
1997    @onlyNativeDeviceTypes
1998    def test_repeated_dim(self, device):
1999        ops = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp, torch.var,
2000               torch.norm]
2001        x = torch.randn(3, 3, 3, 3, device=device)
2002
2003        error_msg = r'appears multiple times in the list of dims'
2004        for op in ops:
2005            for dim in [(0, 0), (0, -4)]:
2006                with self.assertRaisesRegex(RuntimeError, error_msg):
2007                    op(x, dim=dim)
2008
2009    # TODO: update this test to compare against NumPy
2010    @onlyCUDA
2011    def test_var(self, device):
2012        cpu_tensor = torch.randn(2, 3, 3)
2013        device_tensor = cpu_tensor.to(device)
2014        self.assertEqual(device_tensor.var(), cpu_tensor.var())
2015        self.assertEqual(device_tensor.var(1), cpu_tensor.var(1))
2016        self.assertEqual(device_tensor.var(2), cpu_tensor.var(2))
2017        self.assertEqual(device_tensor.std(), cpu_tensor.std())
2018        self.assertEqual(device_tensor.std(1), cpu_tensor.std(1))
2019        self.assertEqual(device_tensor.var(2), cpu_tensor.var(2))
2020
2021        cpu_tensor = torch.randn(100)
2022        device_tensor = cpu_tensor.to(device)
2023        self.assertEqual(device_tensor.var(), cpu_tensor.var())
2024
2025    # TODO: update this test to compare against NumPy
2026    @onlyCUDA
2027    def test_var_large_input(self, device):
2028        # Large, not-nice input
2029        cpu_tensor = torch.randn(2 * 32 * 1024 + 1, 2, 67)
2030        device_tensor = cpu_tensor.to(device)
2031
2032        self.assertEqual(cpu_tensor.var(2), device_tensor.var(2))
2033
2034    # TODO: update this to compare against NumPy instead of CPU
2035    @onlyCUDA
2036    @dtypes(torch.double)
2037    def test_sum_noncontig(self, device, dtype):
2038        x = torch.randn(1, 75, 57, 20, dtype=dtype, device=device).permute(0, 3, 1, 2)
2039        y = x.cpu()
2040        self.assertEqual(x.sum().cpu(), y.sum())
2041        self.assertEqual(x.sum(dim=(-1, -2)).cpu(), y.sum(dim=(-1, -2)))
2042        self.assertEqual(x.sum(dim=(1, 3)).cpu(), y.sum(dim=(1, 3)))
2043
2044    # TODO: update this to compare against NumPy instead of CPU
2045    @onlyCUDA
2046    def test_min_max_nan(self, device):
2047        tests = [(lambda x: x.min(), 'min'),
2048                 (lambda x: x.max(), 'max'),
2049                 (lambda x: x.amin(), 'amin'),
2050                 (lambda x: x.amax(), 'amax'),
2051                 (lambda x: x.min(0).values, 'min_dim'),
2052                 (lambda x: x.max(0).values, 'max_dim'),
2053                 (lambda x: x.amin(0), 'amin_dim'),
2054                 (lambda x: x.amax(0), 'amax_dim')]
2055        for f, name in tests:
2056            a = torch.arange(25.0).view(5, 5)
2057            a[2, 2] = nan
2058            actual = f(a.to(device)).cpu()
2059            expected = f(a).cpu()
2060            self.assertEqual(torch.isnan(actual), torch.isnan(expected), msg=f'nans for {name}')
2061            self.assertEqual(actual[~torch.isnan(actual)],
2062                             expected[~torch.isnan(expected)], msg=f'nans for {name}')
2063
2064    # TODO: make this test generic using OpInfos
2065    @onlyCUDA
2066    def test_sum_cpu_device_mismatch(self, device):
2067        x = torch.randn(20, dtype=torch.float32, device=device)
2068        y = torch.randn(1, dtype=torch.float32)
2069
2070        err_string = f"Expected out tensor to have device {device}, but got cpu instead"
2071
2072        with self.assertRaisesRegex(RuntimeError, err_string):
2073            torch.sum(x, dim=[0], dtype=torch.float32, out=y)
2074
2075        # tests half to float promotion
2076        if self.device_type == 'cuda':
2077            x = x.half()
2078            with self.assertRaisesRegex(RuntimeError, err_string):
2079                torch.sum(x, dim=[0], dtype=torch.float32, out=y)
2080
2081    # The assert for an illegal dtype is not raised on XLA, hence native device types only
2082    @onlyNativeDeviceTypes
2083    def test_minmax_illegal_dtype(self, device):
2084        x = torch.randn(5, 5, dtype=torch.float32, device=device)
2085        valid_values = torch.empty(5, dtype=torch.float32, device=device)
2086        valid_indices = torch.empty(5, dtype=torch.long, device=device)
2087        illegal_values = torch.empty(5, dtype=torch.int, device=device)
2088        illegal_indices = torch.empty(5, dtype=torch.double, device=device)
2089        torch.max(x, dim=0, out=(valid_values, valid_indices))
2090        torch.min(x, dim=0, out=(valid_values, valid_indices))
2091        torch.amax(x, dim=0, out=valid_values)
2092        torch.amin(x, dim=0, out=valid_values)
2093        rmsg = r'scalar type|dtype'
2094        with self.assertRaisesRegex(RuntimeError, rmsg):
2095            torch.max(x, dim=0, out=(illegal_values, valid_indices))
2096        with self.assertRaisesRegex(RuntimeError, rmsg):
2097            torch.min(x, dim=0, out=(illegal_values, valid_indices))
2098        with self.assertRaisesRegex(RuntimeError, rmsg):
2099            torch.max(x, dim=0, out=(valid_values, illegal_indices))
2100        with self.assertRaisesRegex(RuntimeError, rmsg):
2101            torch.min(x, dim=0, out=(valid_values, illegal_indices))
2102        with self.assertRaisesRegex(RuntimeError, rmsg):
2103            torch.max(x, dim=0, out=(illegal_values, illegal_indices))
2104        with self.assertRaisesRegex(RuntimeError, rmsg):
2105            torch.min(x, dim=0, out=(illegal_values, illegal_indices))
2106
2107    @dtypes(*all_types_and(torch.half, torch.bfloat16))
2108    def test_dim_arg_reduction_scalar(self, device, dtype):
2109        example = 4.0
2110
2111        x = torch.tensor(example, device=device, dtype=dtype)
2112        self.assertEqual(x.argmax().item(), 0)
2113        self.assertEqual(x.argmax(dim=None).item(), 0)
2114        self.assertEqual(x.argmax(dim=0).item(), 0)
2115        self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64))
2116
2117        x = torch.tensor(example, device=device, dtype=dtype)
2118        self.assertEqual(x.argmin().item(), 0)
2119        self.assertEqual(x.argmin(dim=None).item(), 0)
2120        self.assertEqual(x.argmin(dim=0).item(), 0)
2121        self.assertEqual(x.argmin(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64))
2122
2123
2124    @precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2})
2125    @dtypes(*set(all_types_and(torch.half, torch.bfloat16)) - {torch.uint8})
2126    def test_dim_reduction(self, device, dtype):
2127        example = [[-1, 2, 1], [5, 3, 6]]
2128
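        # Expected accumulation dtypes: sum (and prod below) accumulate integer
        # inputs in int64 by default, while floating-point inputs keep their dtype.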
2129        sum_dtype = {
2130            torch.bfloat16: torch.bfloat16,
2131            torch.double: torch.double,
2132            torch.float: torch.float,
2133            torch.half: torch.half,
2134            torch.int64: torch.int64,
2135            torch.int32: torch.int64,
2136            torch.int16: torch.int64,
2137            torch.int8: torch.int64
2138        }
2139
2140        # This won't test for 256bit instructions, since we usually
2141        # only work on 1 cacheline (512bit) at a time and these
2142        # examples aren't big enough to trigger that.
2143        x = torch.tensor(example, device=device, dtype=dtype)
2144        self.assertEqual(x.sum().item(), 16)
2145        self.assertEqual(x.sum(0), torch.tensor([4, 5, 7], dtype=sum_dtype[dtype]))
2146        self.assertEqual(x.sum(1), torch.tensor([2, 14], dtype=sum_dtype[dtype]))
2147        y = torch.tensor(example, device=device, dtype=sum_dtype[dtype])
2148        torch.sum(x, 0, out=y)
2149        self.assertEqual(x.sum(0), y)
2150
2151        # Mean not supported for Int types
2152        if dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64]:
2153            x = torch.tensor(example, device=device, dtype=dtype)
2154            self.assertEqual(x.mean().item(), 16.0 / 6)
2155            self.assertEqual(x.mean(0), torch.tensor([2.0, 2.5, 7.0 / 2], dtype=dtype))
2156            self.assertEqual(x.mean(1), torch.tensor([2.0 / 3, 14.0 / 3], dtype=dtype))
2157            self.assertEqual(x.mean(), x.mean((0, 1)))
2158
2159        prod_dtype = {
2160            torch.bfloat16: torch.bfloat16,
2161            torch.double: torch.double,
2162            torch.float: torch.float,
2163            torch.float16: torch.float16,
2164            torch.int64: torch.int64,
2165            torch.int32: torch.int64,
2166            torch.int16: torch.int64,
2167            torch.int8: torch.int64,
2168        }
2169
2170        # prod is not supported for float16 & bfloat16 on CPU
2171        if not (self.device_type == 'cpu' and dtype in [torch.float16, torch.bfloat16]):
2172            x = torch.tensor(example, device=device, dtype=dtype)
2173            self.assertEqual(x.prod().item(), -180)
2174            self.assertEqual(x.prod(0), torch.tensor([-5, 6, 6], dtype=prod_dtype[dtype]))
2175            self.assertEqual(x.prod(1), torch.tensor([-2, 90], dtype=prod_dtype[dtype]))
2176
2177        x = torch.tensor(example, device=device, dtype=dtype)
2178
2179        self.assertEqual(x.min().item(), -1)
2180        self.assertEqual(x.argmin().item(), 0)
2181
2182        # TODO: torch.min does not support the same operation as argmin
2183        # does for this case; should we enable it?
2184        self.assertEqual(x.argmin(dim=None).item(), 0)
2185
2186        self.assertEqual(x.min(0), (torch.tensor([-1, 2, 1], dtype=dtype),
2187                                    torch.tensor([0, 0, 0], dtype=torch.int64)))
2188        self.assertEqual(x.amin(0), torch.tensor([-1, 2, 1], dtype=dtype))
2189        self.assertEqual(x.argmin(0), torch.tensor([0, 0, 0], dtype=torch.int64))
2190
2191        self.assertEqual(x.min(dim=0, keepdim=True), (torch.tensor([[-1, 2, 1]], dtype=dtype),
2192                         torch.tensor([[0, 0, 0]], dtype=torch.int64)))
2193        self.assertEqual(x.amin(dim=0, keepdim=True), torch.tensor([[-1, 2, 1]], dtype=dtype))
2194        self.assertEqual(x.argmin(dim=0, keepdim=True), torch.tensor([[0, 0, 0]], dtype=torch.int64))
2195
2196        self.assertEqual(x.min(1), (torch.tensor([-1, 3], dtype=dtype),
2197                         torch.tensor([0, 1], dtype=torch.int64)))
2198        self.assertEqual(x.amin(1), torch.tensor([-1, 3], dtype=dtype))
2199        self.assertEqual(x.argmin(1), torch.tensor([0, 1], dtype=torch.int64))
2200
2201        self.assertEqual(x.min(dim=1, keepdim=True), (torch.tensor([[-1], [3]], dtype=dtype),
2202                         torch.tensor([[0], [1]], dtype=torch.int64)))
2203        self.assertEqual(x.amin(dim=1, keepdim=True), torch.tensor([[-1], [3]], dtype=dtype))
2204        self.assertEqual(x.argmin(dim=1, keepdim=True), torch.tensor([[0], [1]], dtype=torch.int64))
2205
2206        # test that non-contiguous tensors work
2207        self.assertEqual(x[:, :2].min().item(), -1)
2208        self.assertEqual(x[:, :2].amin().item(), -1)
2209        self.assertEqual(x[:, :2].argmin().item(), 0)
2210
2211        x = torch.tensor(example, device=device, dtype=dtype)
2212
2213        self.assertEqual(x.max().item(), 6)
2214        self.assertEqual(x.amax().item(), 6)
2215        self.assertEqual(x.argmax().item(), 5)
2216
2217        self.assertEqual(x.max(0), (torch.tensor([5, 3, 6], dtype=dtype),
2218                                    torch.tensor([1, 1, 1], dtype=torch.int64)))
2219        self.assertEqual(x.amax(0), torch.tensor([5, 3, 6], dtype=dtype))
2220        self.assertEqual(x.argmax(dim=0), torch.tensor([1, 1, 1], dtype=torch.int64))
2221
2222        self.assertEqual(x.max(dim=0, keepdim=True), (torch.tensor([[5, 3, 6]], dtype=dtype),
2223                                                      torch.tensor([[1, 1, 1]], dtype=torch.int64)))
2224        self.assertEqual(x.amax(dim=0, keepdim=True), torch.tensor([[5, 3, 6]], dtype=dtype))
2225        self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor([[1, 1, 1]], dtype=torch.int64))
2226
2227        self.assertEqual(x.max(1), (torch.tensor([2, 6], dtype=dtype),
2228                                    torch.tensor([1, 2], dtype=torch.int64)))
2229        self.assertEqual(x.amax(1), torch.tensor([2, 6], dtype=dtype))
2230        self.assertEqual(x.argmax(dim=1), torch.tensor([1, 2], dtype=torch.int64))
2231
2232        self.assertEqual(x.max(1, keepdim=True), (torch.tensor([[2], [6]], dtype=dtype),
2233                                                  torch.tensor([[1], [2]], dtype=torch.int64)))
2234        self.assertEqual(x.amax(1, keepdim=True), torch.tensor([[2], [6]], dtype=dtype))
2235        self.assertEqual(x.argmax(dim=1, keepdim=True), torch.tensor([[1], [2]], dtype=torch.int64))
2236
2237        # test that non-contiguous tensors work
2238        self.assertEqual(x[:, :2].max().item(), 5)
2239        self.assertEqual(x[:, :2].amax().item(), 5)
2240        self.assertEqual(x[:, :2].argmax().item(), 2)
2241
2242
2243    @precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2})
2244    @dtypes(*set(all_types_and(torch.half, torch.bfloat16)) - {torch.uint8})
2245    @parametrize("fn_name", [
2246        "mean", "median", "nanmedian", "mode", "norm", "prod",
2247        "std", "sum", "var", "max", "min", "amax", "amin"])
2248    def test_dim_reduction_fns(self, device, dtype, fn_name):
2249        def normfn_attr(t, dim, keepdim=False, out=None):
2250            attr = torch.norm
2251            return attr(t, 2, dim, keepdim, out=out)
2252
2253        fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr
2254
2255        def fn(x, dim, keepdim=False, out=None):
2256            ans = fn_attr(x, dim, keepdim=keepdim, out=out)
2257            return ans if not isinstance(ans, tuple) else ans[0]
2258
2259        def fn_tuple(x, dim, keepdim=False, out=None):
2260            return fn_attr(x, dim, keepdim=keepdim, out=out)
2261
2262        def test_multidim(x, dim):
2263            self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True))
2264            self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension())
2265            self.assertEqual(x.ndimension(), fn(x, dim, keepdim=True).ndimension())
2266
2267        # general case
2268        x = torch.randn(3, 4, 5, device=device)
2269        dim = random.randint(0, 2)
2270        test_multidim(x, dim)
2271
2272        # check 1-d behavior
2273        x = torch.randn(1, device=device)
2274        dim = 0
2275        self.assertEqual(fn(x, dim).shape, ())
2276        self.assertEqual(fn(x, dim, keepdim=True).shape, (1,))
2277
2278        # check reducing of a singleton dimension
2279        dims = [3, 4, 5]
2280        singleton_dim = random.randint(0, 2)
2281        dims[singleton_dim] = 1
2282        x = torch.randn(dims, device=device)
2283        test_multidim(x, singleton_dim)
2284
2285        # check reducing with output kwargs
2286        if fn_name in ['median', 'nanmedian', 'mode', 'max', 'min']:
2287            y = torch.randn(5, 3, device=device)
2288            values = torch.randn(5, 3, device=device)
2289            indices = torch.zeros(5, 3, device=device).long() - 1
2290            fn_tuple(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1]))
2291            values_expected, indices_expected = fn_tuple(y, 1, keepdim=False)
2292            self.assertEqual(values[:, 1], values_expected,
2293                             msg=f'{fn_name} values with out= kwarg')
2294            self.assertEqual(indices[:, 1], indices_expected,
2295                             msg=f'{fn_name} indices with out= kwarg')
2296            return
2297
2298        x = torch.randn(5, 3, device=device)
2299        y = torch.randn(5, 3, device=device)
2300        fn(y, 1, keepdim=False, out=x[:, 1])
2301        expected = fn(y, 1, keepdim=False)
2302        self.assertEqual(x[:, 1], expected, msg=f'{fn_name} with out= kwarg')
2303
2304    @onlyCUDA
2305    @largeTensorTest('10GB')
2306    def test_reduction_split(self, device):
2307        # Test reduction when there is a 32bit-indexing split
2308        # https://github.com/pytorch/pytorch/issues/37583
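        # The 5 x 14400 x 14400 float32 input holds roughly a billion elements (about 4 GB), which is
        # presumably large enough that the CUDA reduction cannot cover the whole iteration space with
        # 32-bit indices and has to split it; summing the five slices manually gives a reference that
        # does not depend on how the reduction is split.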
2309        input_ = torch.randn(5, 14400, 14400, device=device)
2310        result = input_.sum(dim=0)
2311        expect = input_[0] + input_[1] + input_[2] + input_[3] + input_[4]
2312        self.assertEqual(result, expect)
2313
2314    @onlyCUDA
2315    @dtypes(torch.half, torch.float, torch.double, torch.bfloat16)
2316    def test_reduction_vectorize_along_input_corner(self, device, dtype):
2317        # 1D case: sum
2318        size = 1024 * 1024 * 64 + 3
2319        shift = 1
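        # The odd size (64M + 3) and the 1-element shift mean that y is neither aligned to nor a
        # multiple of the vector width, which should exercise the scalar head/tail paths of the
        # vectorized CUDA reduction.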
2320        x = torch.zeros(size, dtype=dtype, device=device)
2321        y = x[shift:]
2322        for i in range(100):
2323            x.zero_()
2324            x[i] = 1
2325            self.assertEqual(x.sum(), 1.0)
2326            if i < shift:
2327                self.assertEqual(y.sum(), 0.0)
2328            else:
2329                self.assertEqual(y.sum(), 1.0)
2330        for i in range(1, 100):
2331            x.zero_()
2332            x[-i] = 1
2333            self.assertEqual(x.sum(), 1.0)
2334            self.assertEqual(y.sum(), 1.0)
2335        # 1D case: argmax
2336        size = 1024 * 1024 * 64 + 3
2337        shift = 1
2338        ysize = size - shift
2339        x = torch.zeros(size, dtype=dtype, device=device)
2340        y = x[shift:]
2341        for i in range(100):
2342            x.zero_()
2343            x[i] = 1
2344            self.assertEqual(x.argmax().item(), i)
2345            if i >= shift:
2346                self.assertEqual(y.argmax().item(), i - shift)
2347        for i in range(1, 100):
2348            x.zero_()
2349            x[-i] = 1
2350            self.assertEqual(x.argmax().item(), size - i)
2351            self.assertEqual(y.argmax().item(), ysize - i)
2352        # 2D case: sum
2353        size = (7, 1024 * 1024 + 3)
2354        x = torch.zeros(size, dtype=dtype, device=device)
2355        for i in range(100):
2356            x.zero_()
2357            for j in range(7):
2358                x[j][i] = j
2359            xs = x.sum(dim=-1)
2360            for j in range(7):
2361                self.assertEqual(xs[j].item(), float(j))
2362        for i in range(100):
2363            x.zero_()
2364            for j in range(7):
2365                x[j][-i] = j
2366            xs = x.sum(dim=-1)
2367            for j in range(7):
2368                self.assertEqual(xs[j].item(), float(j))
2369        # 2D case: max/argmax
2370        size = (7, 1024 * 1024 + 3)
2371        x = torch.zeros(size, dtype=dtype, device=device)
2372        for i in range(100):
2373            x.zero_()
2374            for j in range(7):
2375                x[j][i] = j + 1
2376            xs1 = x.argmax(dim=-1)
2377            xs2 = x.max(dim=-1).indices
2378            for j in range(7):
2379                self.assertEqual(xs1[j].item(), i)
2380                self.assertEqual(xs2[j].item(), i)
2381        for i in range(1, 100):
2382            x.zero_()
2383            for j in range(7):
2384                x[j][-i] = j + 1
2385            xs1 = x.argmax(dim=-1)
2386            xs2 = x.max(dim=-1).indices
2387            for j in range(7):
2388                self.assertEqual(xs1[j].item(), size[1] - i)
2389                self.assertEqual(xs2[j].item(), size[1] - i)
2390        # 2D case: min/argmin
2391        size = (7, 1024 * 1024 + 3)
2392        x = torch.zeros(size, dtype=dtype, device=device)
2393        for i in range(100):
2394            x.zero_()
2395            for j in range(7):
2396                x[j][i] = -(j + 1)
2397            xs1 = x.argmin(dim=-1)
2398            xs2 = x.min(dim=-1).indices
2399            for j in range(7):
2400                self.assertEqual(xs1[j].item(), i)
2401                self.assertEqual(xs2[j].item(), i)
2402        for i in range(1, 100):
2403            x.zero_()
2404            for j in range(7):
2405                x[j][-i] = -(j + 1)
2406            xs1 = x.argmin(dim=-1)
2407            xs2 = x.min(dim=-1).indices
2408            for j in range(7):
2409                self.assertEqual(xs1[j].item(), size[1] - i)
2410                self.assertEqual(xs2[j].item(), size[1] - i)
2411
2412    @onlyCUDA
2413    @dtypes(torch.half, torch.float, torch.double, torch.bfloat16)
2414    def test_reduction_vectorize_along_output(self, device, dtype):
2415        def run_test(input_):
2416            M, N = input_.shape
2417            input_.zero_()
2418            for i in range(min(M, N)):
2419                input_[i][i] = 1
2420            output1 = input_.argmax(dim=0)
2421            output2 = input_.sum(dim=0)
2422            for i in range(min(M, N)):
2423                self.assertEqual(output1[i], i)
2424                self.assertEqual(output2[i], 1)
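        # The shapes below are presumably chosen so that the output (reduced-along) dimension selects
        # different vectorization widths; the "vec N" labels indicate the vector width each group is
        # expected to hit, assuming alignment and size determine which vectorized path is taken.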
2425        # vec 4
2426        run_test(torch.zeros(64, 64, dtype=dtype, device=device))
2427        # vec 2
2428        run_test(torch.zeros(64 * 64 + 2, dtype=dtype, device=device)[2:].view(64, 64))
2429        run_test(torch.zeros(64, 62, dtype=dtype, device=device))
2430        run_test(torch.zeros(64, 2, dtype=dtype, device=device))
2431        # vec 1
2432        run_test(torch.zeros(64 * 64 + 1, dtype=dtype, device=device)[1:].view(64, 64))
2433        run_test(torch.zeros(64, 61, dtype=dtype, device=device))
2434        run_test(torch.zeros(64, 1, dtype=dtype, device=device))
2435
2436    @onlyCUDA
2437    def test_argminmax_large_axis(self, device):
2438        # Regression test for gh-32863
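        # A tensor with 2**31 elements overflows signed 32-bit indexing, so the arg-reductions are
        # expected to fall back to 64-bit index math and still report the correct position.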
2439        x = torch.zeros(2**31, device=device, dtype=torch.int8)
2440        x[-1] = 1
2441        self.assertEqual(x.argmax(0), x.shape[0] - 1)
2442        self.assertEqual(x.max(0).indices, x.shape[0] - 1)
2443        x[-1] = -1
2444        self.assertEqual(x.argmin(0), x.shape[0] - 1)
2445        self.assertEqual(x.min(0).indices, x.shape[0] - 1)
2446
2447    def test_argminmax_axis_with_dim_one(self, device):
2448        # See: https://github.com/pytorch/pytorch/issues/38922
2449        n = 32768
2450        x = torch.zeros(1, n)
2451        self.assertEqual(x.argmax(dim=0), torch.zeros(n, dtype=torch.int64))
2452        self.assertEqual(x.argmin(dim=0), torch.zeros(n, dtype=torch.int64))
2453
2454        self.assertEqual(x.argmax(dim=-2), torch.zeros(n, dtype=torch.int64))
2455        self.assertEqual(x.argmin(dim=-2), torch.zeros(n, dtype=torch.int64))
2456
2457        self.assertEqual(x.argmax(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
2458        self.assertEqual(x.argmin(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
2459
2460        self.assertEqual(x.argmax(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
2461        self.assertEqual(x.argmin(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
2462
2463    @dtypes(torch.int, torch.long, torch.float, torch.double)
2464    @dtypesIfCUDA(torch.int, torch.long, torch.half, torch.float, torch.double)
2465    def test_median_real_values(self, device, dtype):
2466        # Generate random 0-3D sizes
2467        sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)]
2468        for size in sizes:
2469            # Create random input tensor
2470            t = torch.randn(size, device=device).type(dtype)
2471            t_numpy = t.cpu().numpy()
2472            res = t.median()
2473            self.assertEqual(res, t.nanmedian())
2474            k = int((t.numel() - 1) / 2)
2475            self.assertEqual(res, t.view(-1).sort()[0][k])
2476            if t.numel() % 2 == 1:
2477                # We can only test against numpy for odd-length reductions because numpy
2478                # returns the mean of the two middle values while torch returns the lower one
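                # e.g. for [1., 2., 3., 4.]: np.median -> 2.5, torch.median -> 2.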
2479                self.assertEqual(res.cpu().numpy(), np.median(t_numpy))
2480            for dim in range(t.ndim):
2481                res = t.median(dim, True)
2482                self.assertEqual(res, t.nanmedian(dim, True))
2483                size = t.size(dim) if t.ndim > 0 else 1
2484                k = int((size - 1) / 2)
2485                self.assertEqual(res[0], (t.sort(dim)[0]).select(dim, k).unsqueeze_(dim))
2486                self.assertEqual(res[0], t.gather(dim, res[1]))
2487                if size % 2 == 1:
2488                    # We can only test against numpy for odd-length reductions because numpy
2489                    # returns the mean of the two middle values while torch returns the lower one
2490                    self.assertEqual(res[0].cpu().numpy(), np.median(t_numpy, dim, keepdims=True), exact_dtype=False)
2491
2492    @dtypes(torch.float, torch.double)
2493    @dtypesIfCUDA(torch.half, torch.float, torch.double)
2494    def test_median_nan_values(self, device, dtype):
2495        # Generate random 0-3D sizes
2496        sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)]
2497        for size in sizes:
2498            # Create random input tensor with nan values
2499            t = torch.rand(size, device=device, dtype=dtype)
2500            t.masked_fill_(t < 0.1, float('nan'))
2501            t_numpy = t.cpu().numpy()
2502            for op in [torch.median, torch.nanmedian]:
2503                numpy_op = np.median if op == torch.median else np.nanmedian
2504                res = op(t)
2505                num_nan = t.isnan().sum()
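                # torch.sort places NaNs at the end, so when NaNs are present torch.median (which
                # propagates NaN) corresponds to the last sorted element, while nanmedian targets
                # the lower median of the non-NaN prefix.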
2506                if op == torch.median and num_nan > 0:
2507                    k = t.numel() - 1
2508                else:
2509                    k = int((t.numel() - num_nan - 1) / 2)
2510                self.assertEqual(res, t.view(-1).sort()[0][k])
2511                if (t.numel() - num_nan) % 2 == 1:
2512                    # We can only test against numpy for odd-length reductions because numpy
2513                    # returns the mean of the two middle values while torch returns the lower one
2514                    self.assertEqual(res.item(), numpy_op(t.cpu().numpy()))
2515                for dim in range(t.ndim):
2516                    res = op(t, dim, True)
2517                    size = t.size(dim) if t.ndim > 0 else 1
2518                    num_nan = t.isnan().sum(dim, True)
2519                    if op == torch.median:
2520                        k = torch.where(num_nan > 0, size - 1, int((size - 1) / 2))
2521                    else:
2522                        k = ((size - num_nan - 1) / 2).type(torch.long)
2523                    self.assertEqual(res[0], (t.sort(dim)[0]).gather(dim, k))
2524                    self.assertEqual(res[0], t.gather(dim, res[1]))
2525                    # We can only test against numpy for odd-length reductions because numpy
2526                    # returns the mean of the two middle values while torch returns the lower one
2527                    mask = (size - num_nan) % 2 == 1
2528                    res = res[0].masked_select(mask).cpu()
2529                    ref = numpy_op(t_numpy, dim, keepdims=True)[mask.cpu().numpy()]
2530                    self.assertEqual(res, torch.from_numpy(ref))
2531
2532    def test_median_corner_cases(self, device):
2533        def check(op, a, args, key):
2534            t = torch.tensor(a, device=device)
2535            res = op(t, *args)
2536            if not args:
2537                key = torch.tensor(key, device=device)
2538            else:
2539                if len(key) == 1:
2540                    key = torch.tensor(key[0], device=device)
2541                    res = res[0]
2542                else:
2543                    key = (torch.tensor(key[0], device=device), torch.tensor(key[1], device=device))
2544            self.assertEqual(res, key)
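        # `key` encodes the expectation: a bare value for the no-dim overload, a one-element list
        # for checking only the values of a (values, indices) result, and a two-element list for
        # checking both values and indices.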
2545
2546        nan = float('nan')
2547        check(torch.median, nan, [], nan)
2548        check(torch.median, [], [], nan)
2549        check(torch.nanmedian, nan, [], nan)
2550        check(torch.median, nan, [0], [nan, 0])
2551        check(torch.nanmedian, nan, [0], [nan, 0])
2552        check(torch.median, [nan], [0, True], [[nan], [0]])
2553        check(torch.nanmedian, [nan], [0, True], [[nan], [0]])
2556
2557        # Indices are not deterministic here, so we can only check the values
2558        check(torch.median, [[nan, nan], [1, 2]], [0], [[nan, nan]])
2559        check(torch.nanmedian, [[nan, nan], [1, 2]], [0], [[1, 2.]])
2560        check(torch.median, [[nan, nan], [1, 2]], [1], [[nan, 1]])
2561        check(torch.nanmedian, [[nan, nan], [1, 2]], [1], [[nan, 1.]])
2562
2563        # Discontiguous and strided tensors
2564        a = torch.arange(12, device=device)
2565        self.assertEqual(a[::2].median(), torch.tensor(4, device=device))
2566        self.assertEqual(a[::2].nanmedian(), torch.tensor(4, device=device))
2567
2568        a.resize_(3, 4)
2569        self.assertEqual(a.T.median(), torch.tensor(5, device=device))
2570        self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
2571        self.assertEqual(a[::2, ::2].median(-1)[0], torch.tensor([0, 8], device=device))
2572        self.assertEqual(a[::2, ::2].nanmedian(-1)[0], torch.tensor([0, 8], device=device))
2573
2574        a.resize_(2, 3, 2)
2575        self.assertEqual(a.T.median(), torch.tensor(5, device=device))
2576        self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
2577        self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
2578        self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
2579
2580
2581    @onlyNativeDeviceTypes
2582    @dtypes(torch.float, torch.double)
2583    def test_quantile(self, device, dtype):
2584        # Generate some random test cases
2585        ops = ['quantile', 'nanquantile']
2586        inputs = [tuple(np.random.randint(2, 10, size=i)) for i in range(1, 4)]
2587        quantiles = [tuple(np.random.rand(i)) for i in range(0, 5)]
2588        keepdims = [True, False]
2589
2590        # Add corner cases
2591        inputs.extend([0.75, (1,), (1, 1), (1, 2, 1)])
2592        inputs.extend([[float('nan')], [[float('nan'), float('nan')], [1, 2]]])
2593        inputs.extend([[[float('nan'), float('nan')], [float('nan'), 2]]])
2594        quantiles.extend([0.5, [0., 1.], np.random.rand(10)])
2595
2596        # Enumerate all input combinations
2597        for op, x, q, keepdim in product(ops, inputs, quantiles, keepdims):
2598            if type(x) is tuple:
2599                a = torch.randn(x, dtype=dtype, device=device)
2600                # Make some random elements NaN
2601                a.masked_fill_(torch.randint_like(a, 20) == 0, float('nan'))
2602            else:
2603                a = torch.tensor(x, dtype=dtype, device=device)
2604
2605            q = torch.tensor(q, dtype=dtype, device=device)
2606
2607            torch_op = getattr(torch, op)
2608            numpy_op = getattr(np, op)
2609
2610            # Compute quantile along every dimension and flattened tensor
2611            interpolations = ('linear', 'lower', 'higher', 'midpoint', 'nearest')
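            # Note: the `interpolation=` keyword is forwarded to numpy as-is below; newer NumPy
            # releases rename it to `method=`, so this may emit a deprecation warning depending on
            # the installed version.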
2612            for interpolation, dim in product(interpolations,
2613                                              [None] + list(range(a.ndim))):
2614                result = torch_op(a, q, dim=dim, keepdim=keepdim, interpolation=interpolation)
2615                expected = numpy_op(a.cpu().numpy(), q.cpu().numpy(), dim,
2616                                    interpolation=interpolation, keepdims=keepdim)
2617                self.assertEqual(result.cpu(), torch.from_numpy(np.array(expected)).type(result.type()))
2618
2619                # Test out variation
2620                out = torch.empty_like(result)
2621                torch_op(a, q, dim=dim, keepdim=keepdim, interpolation=interpolation, out=out)
2622                self.assertEqual(out.cpu(), result.cpu())
2623
2624    def test_quantile_backward(self, device):
2625        def check(a, q, dim, expected_grad, ops=(torch.quantile, torch.nanquantile)):
2626            for op in ops:
2627                t = torch.tensor(a, device=device, requires_grad=True)
2628                op(t, torch.tensor(q, device=device), dim).sum().backward()
2629                self.assertEqual(t.grad, expected_grad)
2630
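        # With the default linear interpolation the gradient splits between the two neighbouring
        # sorted elements; e.g. the 0.5-quantile of [1., 2, 3, 4] is 2.5 = 0.5*2 + 0.5*3, so the
        # elements 2 and 3 each receive a gradient of 0.5 (second check below).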
2631        check([1., 2, 3], 0.5, 0, [0, 1, 0])
2632        check([1., 2, 3, 4], 0.5, 0, [0, 0.5, 0.5, 0])
2633        check([3., 1, 4, 2], 0.5, 0, [0.5, 0, 0, 0.5])
2634        check([1., 2, 3, 4], [0.25, 0.5, 0.75], 0, [0.25, 1.25, 1.25, 0.25])
2635        check([[1., 2], [2, 1]], 0., 0, [[1, 0], [0, 1]])
2636        check([[1., 2], [4, 3]], 1., 1, [[0, 1], [1, 0]])
2637        check([1, float('nan'), 2], 0.5, 0, [0, 1, 0], [torch.quantile])
2638        check([1, float('nan'), 2], 0.5, 0, [0.5, 0, 0.5], [torch.nanquantile])
2639
2640    def test_quantile_error(self, device):
2641        def check(a, q, args, kwargs, message):
2642            with self.assertRaisesRegex(RuntimeError, r'quantile\(\) ' + message):
2643                at = torch.tensor(a, device=device)
2644                qt = torch.tensor(q, device=device) if isinstance(q, list) else q
2645                torch.quantile(at, qt, *args, **kwargs)
2646
2647        check([], 0.5, [], {}, r'input tensor must be non-empty')
2648        check([1.], [[1.]], [], {}, r'q must be a scalar or 1D tensor')
2649        check([1], 0.5, [], {}, r'input tensor must be either float or double dtype')
2650        check([1.], [1], [], {}, r'q tensor must be same dtype as the input tensor')
2651        check([1.], -1., [], {}, r'q must be in the range \[0, 1\] but got -1')
2652        check([1.], 1.1, [], {}, r'q must be in the range \[0, 1\] but got 1.1')
2653        check([1.], 0.5, [], {'out': torch.empty([], dtype=torch.int32, device=device)},
2654              r'out tensor must be same dtype as the input tensor')
2655        check([1.], [1.], [None, False], {'interpolation': 'random_mode'},
2656              r"interpolation must be one of linear, lower, higher, midpoint or nearest, but got random_mode")
2657
2658        if self.device_type == "cpu":
2659            check([1.], [0.5, 1.1, -1], [], {}, r'q values must be in the range \[0, 1\]')
2660
2661        if self.device_type == "cuda":
2662            with self.assertRaisesRegex(
2663                    RuntimeError, r'quantile\(\) q tensor must be on the same device as the input tensor'):
2664                torch.randn(1, device=device).quantile(torch.tensor(0.5))
2665            with self.assertRaisesRegex(
2666                    RuntimeError, r'quantile\(\) out tensor must be on the same device as the input tensor'):
2667                torch.quantile(torch.randn(1, device=device), 0.5, out=torch.scalar_tensor(1))
2668
2669    def test_std_mean(self, device):
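        # torch.std_mean returns the standard deviation and the mean from a single call; it is
        # expected to match separately computed x.std(...) and x.mean(...) for every
        # dim/unbiased/keepdim combination.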
2670        x = torch.rand(100, 50, 20, device=device)
2671        for dim in range(x.dim()):
2672            for unbiased in [False, True]:
2673                for keepdim in [False, True]:
2674                    std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
2675                    std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim)
2676                    mean2 = x.mean(dim=dim, keepdim=keepdim)
2677                    self.assertEqual(std1, std2)
2678                    self.assertEqual(mean1, mean2)
2679
2680    def test_std_mean_all_dims(self, device):
2681        x = torch.rand(100, 50, 20, device=device)
2682        for unbiased in [False, True]:
2683            std1, mean1 = torch.std_mean(x, unbiased=unbiased)
2684            std2 = x.std(unbiased=unbiased)
2685            mean2 = x.mean()
2686            self.assertEqual(std1, std2)
2687            self.assertEqual(mean1, mean2)
2688
2689    def test_var_mean(self, device):
2690        x = torch.rand(100, 300, 50, device=device)
2691        for dim in range(x.dim()):
2692            for unbiased in [False, True]:
2693                for keepdim in [False, True]:
2694                    var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
2695                    var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim)
2696                    mean2 = x.mean(dim=dim, keepdim=keepdim)
2697                    self.assertEqual(var1, var2)
2698                    self.assertEqual(mean1, mean2)
2699
2700    def test_var_mean_all_dims(self, device):
2701        x = torch.rand(100, 50, 20, device=device)
2702        for unbiased in [False, True]:
2703            var1, mean1 = torch.var_mean(x, unbiased=unbiased)
2704            var2 = x.var(unbiased=unbiased)
2705            mean2 = x.mean()
2706            self.assertEqual(var1, var2)
2707            self.assertEqual(mean1, mean2)
2708
2709    def test_std_mean_some_dims(self, device):
2710        sizes = (4, 6, 7, 5, 3)
2711        dims = len(sizes)
2712        x = torch.rand(sizes, device=device)
2713        for num_of_dims in range(2, dims):
2714            dim_list = list(combinations(list(range(dims)), r=num_of_dims))
2715            for dim in dim_list:
2716                for unbiased in [False, True]:
2717                    for keepdim in [False, True]:
2718                        std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
2719                        std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim)
2720                        mean2 = x.mean(dim=dim, keepdim=keepdim)
2721                        self.assertEqual(std1, std2)
2722                        self.assertEqual(mean1, mean2)
2723
2724    def _compare_std_var_with_numpy(self, op, device, dtype, input, dim,
2725                                    keepdim, unbiased, use_out):
2726        a = input.cpu().numpy() if input.dtype is not torch.bfloat16 else input.float().cpu().numpy()
2727        numpy_kwargs = {
2728            'axis' : dim,
2729            'keepdims' : keepdim,
2730            'ddof' : 1 if unbiased else 0,
2731        }
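        # unbiased=True corresponds to Bessel's correction, i.e. NumPy's ddof=1 (divide by N - 1).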
2732
2733        if dim is None:
2734            del numpy_kwargs['axis']
2735            del numpy_kwargs['keepdims']
2736
2737        if op == 'var':
2738            torch_op = torch.var
2739            numpy_op = np.var
2740        elif op == 'std':
2741            torch_op = torch.std
2742            numpy_op = np.std
2743        else:
2744            self.fail("Unknown op!")
2745
2746        numpy_result = numpy_op(a, **numpy_kwargs)
2747
2748        if dim is None and use_out is False:
2749            torch_result = torch_op(input, unbiased)
2750        elif dim is not None and use_out is False:
2751            torch_result = torch_op(input, dim, unbiased, keepdim)
2752        elif dim is not None and use_out is True:
2753            out = torch.empty(0, device=device, dtype=dtype)
2754            torch_result = torch_op(input, dim, unbiased, keepdim, out=out)
2755        else:
2756            out = torch.empty(0, device=device, dtype=dtype)
2757            torch_result = torch_op(input, dim, unbiased, keepdim, out=out)
2758
2759        exact_dtype = input.dtype not in (torch.bfloat16, torch.complex32, torch.complex64, torch.complex128)
2760        self.assertEqual(torch_result, numpy_result, exact_dtype=exact_dtype)
2761
2762    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
2763    def test_var_vs_numpy(self, device, dtype):
2764        _size = (20, 20)
2765
2766        for test_case in product((torch.randn(_size, device=device, dtype=dtype),),
2767                                 (None, 0, 1),
2768                                 (False, True),
2769                                 (False, True),
2770                                 (False, True),):
2771            self._compare_std_var_with_numpy('var', device, dtype, *test_case)
2772
2773    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
2774    def test_std_vs_numpy(self, device, dtype):
2775        _size = (20, 20)
2776
2777        for test_case in product((torch.randn(_size, device=device, dtype=dtype),),
2778                                 (None, 0, 1),
2779                                 (False, True),
2780                                 (False, True),
2781                                 (False, True),):
2782            self._compare_std_var_with_numpy('std', device, dtype, *test_case)
2783
2784    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
2785    def test_var_correction_vs_numpy(self, device, dtype):
2786        _size = (20, 20)
2787        test_args = [
2788            *product(
2789                # dim
2790                (None, 0, 1),
2791                # correction
2792                (None, 0, 10, 30),
2793                # keepdim
2794                (False, True),
2795            ),
2796            [None, -100, True],  # Negative correction
2797        ]
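        # `correction` generalizes `unbiased`: the variance divisor is N - correction, so
        # correction=1 reproduces unbiased=True and correction=0 the biased estimator; NumPy's
        # ddof plays the same role, which is why it is forwarded directly below.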
2798
2799        tensor = make_tensor(_size, device=device, dtype=dtype)
2800        array = tensor.cpu().numpy()
2801
2802        for dim, correction, keepdim in test_args:
2803            numpy_kwargs = dict(axis=dim, ddof=correction, keepdims=keepdim)
2804            if correction is None:
2805                # NumPy's default ddof is not compatible with torch.var's default correction (gh-50010)
2806                numpy_kwargs['ddof'] = 1
2807
2808            numpy_res = np.asarray(np.var(array, **numpy_kwargs))
2809            torch_res = torch.var(tensor, dim=dim, correction=correction, keepdim=keepdim)
2810
2811            # inf vs. nan results are sensitive to machine precision,
2812            # just treat them as equivalent
2813            numpy_res[np.isinf(numpy_res)] = np.nan
2814            torch_res[torch_res.isinf()] = np.nan
2815
2816            self.assertEqual(torch_res, numpy_res)
2817
2818    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
2819    def test_std_correction_vs_numpy(self, device, dtype):
2820        _size = (20, 20)
2821        test_args = [
2822            *product(
2823                # dim
2824                (None, 0, 1),
2825                # correction
2826                (None, 0, 10, 30),
2827                # keepdim
2828                (False, True),
2829            ),
2830            [None, -100, True],  # Negative correction
2831        ]
2832
2833        tensor = make_tensor(_size, device=device, dtype=dtype)
2834        array = tensor.cpu().numpy()
2835
2836        for dim, correction, keepdim in test_args:
2837            numpy_kwargs = dict(axis=dim, ddof=correction, keepdims=keepdim)
2838            if correction is None:
2839                # NumPy default is incompatible with torch.std (gh-50010)
2840                numpy_kwargs['ddof'] = 1
2841
2842            numpy_res = np.asarray(np.std(array, **numpy_kwargs))
2843            torch_res = torch.std(tensor, dim=dim, correction=correction, keepdim=keepdim)
2844
2845            # inf vs. nan results are sensitive to machine precision,
2846            # just treat them as equivalent
2847            numpy_res[np.isinf(numpy_res)] = np.nan
2848            torch_res[torch_res.isinf()] = np.nan
2849
2850            self.assertEqual(torch_res, numpy_res)
2851
2852    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
2853    def test_std_mean_correction(self, device, dtype):
2854        _size = (20, 20)
2855        test_args = [
2856            *product(
2857                # dim
2858                (None, 0, 1),
2859                # correction
2860                (None, 0, 10, 30),
2861                # keepdim
2862                (False, True),
2863            ),
2864            [None, -100, True],  # Negative correction
2865        ]
2866
2867        tensor = make_tensor(_size, device=device, dtype=dtype)
2868
2869        for dim, correction, keepdim in test_args:
2870            kwargs = dict(dim=dim, correction=correction, keepdim=keepdim)
2871            std1 = torch.std(tensor, **kwargs)
2872            if dim is not None:
2873                mean1 = torch.mean(tensor, dim=dim, keepdim=keepdim)
2874            else:
2875                mean1 = torch.mean(tensor)
2876                if keepdim:
2877                    mean1 = mean1.reshape((1,) * tensor.ndim)
2878            std2, mean2 = torch.std_mean(tensor, **kwargs)
2879
2880            self.assertEqual(std1, std2)
2881            self.assertEqual(mean1, mean2)
2882
2883    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
2884    def test_var_mean_correction(self, device, dtype):
2885        _size = (20, 20)
2886        test_args = [
2887            *product(
2888                # dim
2889                (None, 0, 1),
2890                # correction
2891                (None, 0, 10, 30),
2892                # keepdim
2893                (False, True),
2894            ),
2895            [None, -100, True],  # Negative correction
2896        ]
2897
2898        tensor = make_tensor(_size, device=device, dtype=dtype)
2899
2900        for dim, correction, keepdim in test_args:
2901            kwargs = dict(dim=dim, correction=correction, keepdim=keepdim)
2902            var1 = torch.var(tensor, **kwargs)
2903            if dim is not None:
2904                mean1 = torch.mean(tensor, dim=dim, keepdim=keepdim)
2905            else:
2906                mean1 = torch.mean(tensor)
2907                if keepdim:
2908                    mean1 = mean1.reshape((1,) * tensor.ndim)
2909            var2, mean2 = torch.var_mean(tensor, **kwargs)
2910
2911            self.assertEqual(var1, var2)
2912            self.assertEqual(mean1, mean2)
2913
2914    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
2915    def test_warn_invalid_degrees_of_freedom(self, device, dtype):
2916        def _assert_warning(_func, _tensor, _correction):
2917            with warnings.catch_warnings(record=True) as w:
2918                _func(_tensor, dim=-1, correction=_correction)
2919            self.assertIn('degrees of freedom is <= 0', str(w[0].message))
2920
2921        correction = 20
2922        size = (10, correction)
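        # The reduced dimension has exactly `correction` elements, so N - correction == 0 and the
        # degrees-of-freedom warning should be emitted by every op below.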
2923        tensor = make_tensor(size, dtype=dtype, device=device)
2924        for f in [torch.std, torch.var, torch.var_mean, torch.std_mean]:
2925            _assert_warning(f, tensor, correction)
2926
2927    def test_amin_amax_some_dims(self, device):
2928        sizes = (4, 6, 7, 5, 3)
2929        dims = len(sizes)
2930        x = torch.rand(sizes, device=device)
2931        for num_of_dims in range(2, dims):
2932            dim_list = list(combinations(list(range(dims)), r=num_of_dims))
2933            for dim in dim_list:
2934                for keepdim in [False, True]:
2935                    amin1 = torch.amin(x, dim=dim, keepdim=keepdim)
2936                    amax1 = torch.amax(x, dim=dim, keepdim=keepdim)
2937                    amin2 = x
2938                    amax2 = x
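                    # Reduce one dimension at a time as a reference; without keepdim each reduction
                    # removes a dimension, so subsequent dim indices are shifted down by the number
                    # of dimensions already reduced (hence `d -= i`).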
2939                    for i, d in enumerate(dim):
2940                        if not keepdim:
2941                            d -= i
2942                        amin2 = torch.amin(amin2, dim=d, keepdim=keepdim)
2943                        amax2 = torch.amax(amax2, dim=d, keepdim=keepdim)
2944                    self.assertEqual(amin1, amin2)
2945                    self.assertEqual(amax1, amax2)
2946
2947    def test_histc(self, device):
2948        # negative nbins throws
2949        with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
2950            torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1)
2951        # empty tensor
2952        actual = torch.histc(torch.tensor([], device=device), min=0, max=3)
2953        expected = torch.zeros(100, dtype=torch.float, device=device)
2954        self.assertEqual(expected, actual)
2955
2956        # all defaults: bins=100, min/max inferred from the data
2957        actual = torch.histc(
2958            torch.tensor([2, 5], dtype=torch.float, device=device))
2959        expected = torch.zeros(100, dtype=torch.float, device=device)
2960        expected[0] = 1
2961        expected[99] = 1
2962        self.assertEqual(expected, actual)
2963        # tensor with the same element
2964        actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5)
2965        self.assertEqual(
2966            torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device),
2967            actual)
2968        # no element falls between [min, max]
2969        actual = torch.histc(
2970            torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3)
2971        self.assertEqual(
2972            torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device),
2973            actual)
2974        # elements spread across several bins of an explicit [min, max] range
2975        actual = torch.histc(
2976            torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device),
2977            bins=5, min=1, max=5)
2978        self.assertEqual(
2979            torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device),
2980            actual)
2981        # non-integral bin size
2982        actual = torch.histc(
2983            torch.tensor([1, 2, 1], dtype=torch.float, device=device),
2984            bins=4, min=0, max=3)
2985        self.assertEqual(
2986            torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
2987            actual)
2988        # double input
2989        actual = torch.histc(
2990            torch.tensor([1, 2, 1], dtype=torch.double, device=device), bins=4, min=0, max=3)
2991        self.assertEqual(
2992            torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device),
2993            actual)
2994        self.assertEqual(actual.dtype, torch.double)
2995        # mixed input
2996        actual = torch.histc(
2997            torch.tensor([1., 2, 1], dtype=torch.float, device=device),
2998            bins=4, min=0, max=3)
2999        self.assertEqual(
3000            torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
3001            actual)
3002        self.assertEqual(actual.dtype, torch.float)
3003        # scalar input and 1 bin -- should return a 1-dimensional tensor, not a scalar.
3004        actual = torch.histc(
3005            torch.tensor(0, dtype=torch.float, device=device),
3006            bins=1, min=0, max=3)
3007        self.assertEqual(
3008            torch.tensor([1], dtype=torch.float, device=device),
3009            actual)
3010        # tensors with inf; min, max not provided -- should throw a RuntimeError
3011        with self.assertRaisesRegex(RuntimeError, r'range of \[inf, inf\] is not finite'):
3012            torch.histc(torch.tensor([float("inf")], dtype=torch.float, device=device))
3013        with self.assertRaisesRegex(RuntimeError, r'range of \[1, inf\] is not finite'):
3014            torch.histc(torch.tensor([1., 2., float("inf")], dtype=torch.float, device=device))
3015        # tensors with inf; min, max provided
3016        self.assertEqual(
3017            torch.histc(torch.tensor([float("inf")], dtype=torch.float, device=device),
3018                        bins=1, min=0, max=3),
3019            torch.tensor([0], dtype=torch.float, device=device))
3020        self.assertEqual(
3021            torch.histc(torch.tensor([1., 2., float("inf")], dtype=torch.float, device=device),
3022                        bins=4, max=3),
3023            torch.tensor([0, 1, 1, 0], dtype=torch.float, device=device))
3024        # tensor with nan; min, max not provided -- should throw a RuntimeError
3025        with self.assertRaisesRegex(RuntimeError, r'range of \[nan, nan\] is not finite'):
3026            torch.histc(torch.tensor([float("nan")], dtype=torch.float, device=device))
3027        # tensor with nan; min, max provided -- nan is ignored
3028        self.assertEqual(
3029            torch.histc(torch.tensor([1., 2., float("nan")], dtype=torch.float, device=device),
3030                        bins=4, max=3),
3031            torch.tensor([0, 1, 1, 0], dtype=torch.float, device=device))
3032        # tensors with min > max -- should throw a RuntimeError
3033        with self.assertRaisesRegex(RuntimeError, "max must be larger than min"):
3034            torch.histc(torch.tensor([1., 2., 3.], dtype=torch.float, device=device),
3035                        bins=4, min=5, max=1)
3036
3037        # test against numpy.histogram()
3038        def test_against_np(tensor, bins=100, min=0, max=0):
3039            if min == 0 and max == 0:
3040                min = tensor.min().item()
3041                max = tensor.max().item()
3042            nparr = tensor.cpu().numpy()
3043            actual = torch.histc(tensor, bins=bins, min=min, max=max)
3044            expected = torch.from_numpy(np.histogram(nparr, bins=bins, range=(min, max))[0])
3045            actual_cpu = actual.cpu()
3046            # NB: NumPy returns an int64 tensor, like normal people...
3047            self.assertEqual(actual, expected.to(actual_cpu))
3048
3049        test_against_np(torch.tensor([1., 2, 1], device=device))
3050        test_against_np(torch.randn(5000, device=device))
3051
3052        # Test bins arg
3053        test_against_np(torch.randn(301, device=device), bins=10)
3054
3055        # Test truncated range
3056        test_against_np(torch.randn(201, device=device), min=0.1, max=1)
3057
3058        noncontig = torch.randn(100, 3, device=device)[:, 2]
3059        test_against_np(noncontig)
3060
3061        multidim = torch.randn(3, 5, 7, 2, device=device)
3062        test_against_np(multidim)
3063
3064        expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2)
3065        test_against_np(expanded)
3066
3067        linear = torch.linspace(0, 0.99 - 5.0e-7, 101).to(device)
3068        test_against_np(linear, bins=20, min=0, max=0.99)
3069
3070    @onlyCPU
3071    @dtypes(torch.bfloat16, torch.half)
3072    def test_histc_lowp(self, device, dtype):
3073        actual = torch.histc(
3074            torch.tensor([1, 2, 1], dtype=dtype, device=device), bins=4, min=0, max=3)
3075        self.assertEqual(
3076            torch.tensor([0, 2, 1, 0], dtype=dtype, device=device),
3077            actual)
3078        self.assertEqual(actual.dtype, dtype)
3079
3080    """
3081    Runs torch.histogram and numpy.histogram on the specified input parameters
3082    and asserts that their output is equal.
3083    """
3084    def _test_histogram_numpy(self, t, bins, bin_range, weights, density):
3085        def to_np(t):
3086            if not torch.is_tensor(t):
3087                return t
3088            else:
3089                return t.cpu().numpy()
3090
3091        # Wrapper around numpy.histogram performing conversions between torch tensors and numpy arrays.
3092        def reference_histogram(self, t, bins, bin_range, weights, density, dtype):
3093            (np_t, np_bins, np_weights) = map(to_np, [t, bins, weights])
3094            (np_hist, np_bin_edges) = np.histogram(np_t, np_bins, range=bin_range, weights=np_weights, density=density)
3095            return (torch.from_numpy(np_hist).to(dtype), torch.from_numpy(np_bin_edges).to(dtype))
3096
3097        # Doesn't pass a 'range' kwarg unless necessary because the override of histogram with Tensor bins doesn't accept one
3098        if bin_range:
3099            (actual_hist, actual_bin_edges) = torch.histogram(t, bins, range=bin_range, weight=weights, density=density)
3100        else:
3101            (actual_hist, actual_bin_edges) = torch.histogram(t, bins, weight=weights, density=density)
3102
3103        (expected_hist, expected_bin_edges) = reference_histogram(self, t, bins, bin_range, weights, density, actual_hist.dtype)
3104
3105        """
3106        Works around linspace discrepancies by passing torch's constructed bin_edges to numpy.
3107        When bin edges are not explicitly defined, histogram uses the linspace operator internally
3108        to construct the sequence of bin edges. In some cases, torch.linspace output differs slightly
3109        from numpy.linspace output.
3110        Issue: https://github.com/pytorch/pytorch/issues/58758
3111        """
3112        if not torch.is_tensor(bins):
3113            self.assertEqual(actual_bin_edges, expected_bin_edges, atol=1e-5, rtol=1e-5)
3114            # Calls numpy.histogram again, passing torch's actual_bin_edges as the bins argument
3115            (expected_hist, expected_bin_edges) = reference_histogram(
3116                self, t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype)
3117
3118        self.assertEqual(actual_hist, expected_hist)
3119        self.assertEqual(actual_bin_edges, expected_bin_edges)
3120
3121        # Test passing non-contiguous output tensors
3122        hist_out = make_tensor(expected_hist.shape, device=expected_hist.device, dtype=expected_hist.dtype,
3123                               noncontiguous=True)
3124        bin_edges_out = make_tensor(expected_bin_edges.shape, device=expected_bin_edges.device, dtype=expected_bin_edges.dtype,
3125                                    noncontiguous=True)
3126
3127        # Doesn't pass a 'range' kwarg unless necessary because the override of histogram with Tensor bins doesn't accept one
3128        if bin_range:
3129            torch.histogram(t, bins, range=bin_range, weight=weights, density=density, out=(hist_out, bin_edges_out))
3130        else:
3131            torch.histogram(t, bins, weight=weights, density=density, out=(hist_out, bin_edges_out))
3132
3133        self.assertEqual(hist_out, expected_hist)
3134        self.assertEqual(bin_edges_out, expected_bin_edges)
3135
3136    @onlyCPU
3137    @dtypes(torch.float32)
3138    def test_histogram(self, device, dtype):
3139        shapes = (
3140            (),
3141            (0,),
3142            (1,),
3143            (1, 5),
3144            (3, 5),
3145            (1, 5, 1),
3146            (2, 3, 5))
3147
3148        for contig, bins_contig, bin_ct, weighted, density, shape in \
3149                product([True, False], [True, False], range(1, 10), [True, False], [True, False], shapes):
3150            values = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9, noncontiguous=not contig)
3151            weights = make_tensor(shape, dtype=dtype, device=device, low=0, high=9, noncontiguous=not contig) if weighted else None
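            # density=True asks for a normalized histogram (counts divided by the total weight and
            # the bin widths, so the result integrates to 1), mirroring numpy.histogram's `density`.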
3152
3153            # Tests passing just the bin_ct
3154            self._test_histogram_numpy(values, bin_ct, None, weights, density)
3155
3156            # Tests with caller-specified histogram range
3157            bin_range = sorted((random.uniform(-9, 9), random.uniform(-9, 9)))
3158            self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)
3159
3160            # Tests with range min=max
3161            bin_range[1] = bin_range[0]
3162            self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)
3163
3164            # Tests with caller-specified bin edges
3165            bin_edges = make_tensor(bin_ct + 1, dtype=dtype, device=device, low=-9, high=9).msort()
3166            if not bins_contig:
3167                # Necessary because msort always produces contiguous output
3168                bin_edges_noncontig = make_tensor(bin_ct + 1, dtype=dtype, device=device, noncontiguous=not bins_contig)
3169                bin_edges_noncontig.copy_(bin_edges)
3170                bin_edges = bin_edges_noncontig
3171            self.assertEqual(bin_edges.is_contiguous(), bins_contig)
3172            self._test_histogram_numpy(values, bin_edges, None, weights, density)
3173
3174            # Tests with input tensor in which all elements are equal
3175            elt = random.uniform(-9, 9)
3176            values = make_tensor(shape, dtype=dtype, device=device, low=elt, high=elt, noncontiguous=not contig)
3177            self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)
3178            self._test_histogram_numpy(values, bin_edges, None, weights, density)
3179
3180            # Tests with input equal to bin_edges
3181            weights = (
3182                make_tensor(bin_ct + 1, dtype=dtype, device=device, low=0, high=9, noncontiguous=not contig)
3183                if weighted
3184                else None
3185            )
3186            self._test_histogram_numpy(bin_edges, bin_edges, None, weights, density)
3187
3188        # Tests values of default args
3189        for bin_ct, shape in product(range(1, 10), shapes):
3190            values = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
3191            (actual_hist, actual_bin_edges) = torch.histogram(values, bin_ct)
3192            (expected_hist, expected_bin_edges) = torch.histogram(
3193                values, bin_ct, range=None, weight=None, density=False)
3194            self.assertEqual(actual_hist, expected_hist)
3195            self.assertEqual(actual_bin_edges, expected_bin_edges)
3196
3197    """
3198    Runs torch.histogramdd and numpy.histogramdd on the specified input parameters
3199    and asserts that their output is equal.
3200    """
3201    def _test_histogramdd_numpy(self, t, bins, bin_range, weights, density):
3202        def to_np(t):
3203            if type(t) == list:
3204                return list(map(to_np, t))
3205            if not torch.is_tensor(t):
3206                return t
3207            return t.cpu().numpy()
3208
3209        # Wrapper around numpy.histogramdd performing conversions between torch tensors and numpy arrays.
3210        def reference_histogramdd(t, bins, bin_range, weights, density, dtype):
3211            (np_t, np_bins, np_weights) = map(to_np, [t, bins, weights])
3212
3213            # numpy.histogramdd accepts only (N, D) shapes
3214            D = np_t.shape[-1]
3215            N = np.prod(np_t.shape[:-1])
3216            reshaped_t = np.reshape(np_t, (N, D))
3217            reshaped_wt = np.reshape(np_weights, (N,)) if np_weights is not None else None
3218
3219            # numpy.histogramdd throws an error for D=0
3220            if D == 0:
3221                return (torch.tensor(float('nan') if density else 0.), [])
3222
3223            # numpy.histogramdd expects range to be specified as a sequence of D (lower, upper) tuples
3224            reshaped_range = None if not bin_range else [(bin_range[2 * i], bin_range[2 * i + 1]) for i in range(D)]
3225
3226            (np_hist, np_bin_edges) = np.histogramdd(reshaped_t, np_bins,
3227                                                     range=reshaped_range, weights=reshaped_wt, density=density)
3228
3229            return (torch.from_numpy(np_hist).to(dtype), [torch.from_numpy(t).to(dtype) for t in np_bin_edges])
3230
3231        (actual_hist, actual_bin_edges) = torch.histogramdd(t, bins, range=bin_range, weight=weights, density=density)
3232        (expected_hist, expected_bin_edges) = reference_histogramdd(t, bins, bin_range, weights, density, actual_hist.dtype)
3233
3234        D = len(actual_bin_edges)
3235        self.assertEqual(D, len(expected_bin_edges))
3236
3237        """
3238        Works around linspace discrepancies by passing torch's constructed bin_edges to numpy.
3239        When bin edges are not explicitly defined, histogram uses the linspace operator internally
3240        to construct the sequence of bin edges. In some cases, torch.linspace output differs slightly
3241        from numpy.linspace output.
3242        Issue: https://github.com/pytorch/pytorch/issues/58758
3243        """
3244        if not torch.is_tensor(bins):
3245            for dim in range(D):
3246                self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim], atol=1e-5, rtol=1e-5)
3247            # Calls numpy.histogramdd again, passing torch's actual_bin_edges as the bins argument
3248            (expected_hist, expected_bin_edges) = reference_histogramdd(
3249                t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype)
3250            self.assertEqual(D, len(expected_bin_edges))
3251
3252        self.assertEqual(actual_hist, expected_hist)
3253        for dim in range(D):
3254            self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim])
3255
3256    @onlyCPU
3257    @dtypes(torch.float32)
3258    def test_histogramdd(self, device, dtype):
3259        shapes = (
3260            (1, 5),
3261            (3, 5),
3262            (1, 5, 1),
3263            (2, 3, 5),
3264            (7, 7, 7, 7),
3265            (16, 8, 4, 2),
3266            (10, 10, 10),
3267            (7, 0, 3),
3268            (5, 0),)
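        # (7, 0, 3) exercises an empty input (a zero-sized dimension) and (5, 0) the degenerate
        # D == 0 case that the reference wrapper above handles specially.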
3269
3270        for contig, bins_contig, weighted, density, shape in \
3271                product([True, False], [True, False], [True, False], [True, False], shapes):
3272            D = shape[-1]
3273
3274            values = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9, noncontiguous=not contig)
3275            weights = (
3276                make_tensor(shape[:-1], dtype=dtype, device=device, low=0, high=9, noncontiguous=not contig)
3277                if weighted
3278                else None
3279            )
3280
3281            # Tests passing a single bin count
3282            bin_ct = random.randint(1, 5)
3283            self._test_histogramdd_numpy(values, bin_ct, None, weights, density)
3284
3285            # Tests passing a bin count for each dimension
3286            bin_ct = [random.randint(1, 5) for dim in range(D)]
3287            self._test_histogramdd_numpy(values, bin_ct, None, weights, density)
3288
3289            # Tests with caller-specified histogram range
3290            bin_range_tuples = [sorted((random.uniform(-9, 9), random.uniform(-9, 9))) for dim in range(D)]
3291            bin_range = [elt for t in bin_range_tuples for elt in t]
3292            self._test_histogramdd_numpy(values, bin_ct, bin_range, weights, density)
3293
3294            # Tests with range min=max
3295            for dim in range(D):
3296                bin_range[2 * dim + 1] = bin_range[2 * dim]
3297            self._test_histogramdd_numpy(values, bin_ct, bin_range, weights, density)
3298
3299            # Tests with caller-specified bin edges
3300            bin_edges = [make_tensor(ct + 1, dtype=dtype, device=device, low=-9, high=9).msort() for ct in bin_ct]
3301            if not bins_contig:
3302                # Necessary because msort always produces contiguous output
3303                bin_edges_noncontig = [
3304                    make_tensor(ct + 1, dtype=dtype, device=device, noncontiguous=not bins_contig)
3305                    for ct in bin_ct
3306                ]
3307                for dim in range(D):
3308                    bin_edges_noncontig[dim].copy_(bin_edges[dim])
3309                bin_edges = bin_edges_noncontig
3310            for dim in range(D):
3311                self.assertEqual(bin_edges[dim].is_contiguous(), bins_contig)
3312            self._test_histogramdd_numpy(values, bin_edges, None, weights, density)
3313
3314    @onlyCPU
3315    @dtypes(torch.float32)
3316    def test_histogram_error_handling(self, device, dtype):
3317        with self.assertRaisesRegex(RuntimeError, 'not implemented for'):
3318            values = make_tensor((), dtype=torch.int32, device=device)
3319            torch.histogram(values, 1)
3320
3321        inconsistent_dtype = torch.float32 if dtype != torch.float32 else torch.float64
3322
3323        with self.assertRaisesRegex(RuntimeError, 'input tensor and bins tensors should have the same dtype'):
3324            values = make_tensor((), dtype=dtype, device=device)
3325            bins = make_tensor((), dtype=inconsistent_dtype, device=device)
3326            torch.histogram(values, bins)
3327
3328        with self.assertRaisesRegex(RuntimeError, 'input tensor and weight tensor should have the same dtype'):
3329            values = make_tensor((), dtype=dtype, device=device)
3330            weight = make_tensor((), dtype=inconsistent_dtype, device=device)
3331            torch.histogram(values, 1, weight=weight)
3332
3333        with self.assertRaisesRegex(RuntimeError, 'input tensor and hist tensor should have the same dtype'):
3334            values = make_tensor((), dtype=dtype, device=device)
3335            hist = make_tensor((), dtype=inconsistent_dtype, device=device)
3336            bin_edges = make_tensor((), dtype=dtype, device=device)
3337            torch.histogram(values, 1, out=(hist, bin_edges))
3338
3339        with self.assertRaisesRegex(RuntimeError, 'input tensor and bin_edges tensor should have the same dtype'):
3340            values = make_tensor((), dtype=dtype, device=device)
3341            hist = make_tensor((), dtype=dtype, device=device)
3342            bin_edges = make_tensor((), dtype=inconsistent_dtype, device=device)
3343            torch.histogram(values, 1, out=(hist, bin_edges))
3344
3345        with self.assertRaisesRegex(RuntimeError, 'bins tensor should have one dimension'):
3346            t = make_tensor((2, 2), dtype=dtype, device=device)
3347            torch.histogram(t, t)
3348
3349        with self.assertRaisesRegex(RuntimeError, 'bins tensor should have at least 1 element'):
3350            t = make_tensor((0), dtype=dtype, device=device)
3351            torch.histogram(t, t)
3352
3353        with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
3354            values = make_tensor((), dtype=dtype, device=device)
3355            torch.histogram(values, -1)
3356
3357        with self.assertRaisesRegex(RuntimeError, 'if weight tensor is provided it should have the same shape \
3358as the input tensor excluding its innermost dimension'):
3359            values = make_tensor((2, 2), dtype=dtype, device=device)
3360            weight = make_tensor((1), dtype=dtype, device=device)
3361            torch.histogram(values, 1, weight=weight)
3362
3363        with self.assertRaisesRegex(TypeError, 'received an invalid combination of arguments'):
3364            values = make_tensor((), dtype=dtype, device=device)
3365            bin_edges = make_tensor((), dtype=dtype, device=device)
3366            torch.histogram(values, bin_edges, range=(0, 1))
3367
3368        with self.assertRaisesRegex(RuntimeError, 'min should not exceed max'):
3369            values = make_tensor((), dtype=dtype, device=device)
3370            torch.histogram(values, 2, range=(1, 0))
3371
3372        with self.assertRaisesRegex(RuntimeError, r'range \[nan, nan\] is not finite'):
3373            values = torch.tensor([float("nan")], device=device, dtype=dtype)
3374            torch.histogram(values, 2)
3375
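    # For reference, an illustrative sketch (not part of the test) of a well-formed torch.histogram call,
    # assuming the default uniform binning over the given range:
    #   >>> torch.histogram(torch.tensor([1., 2., 1.]), bins=4, range=(0., 3.))
    #   hist == tensor([0., 2., 1., 0.]), bin_edges == tensor([0.0000, 0.7500, 1.5000, 2.2500, 3.0000])
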
3376    # Tests to ensure that reduction functions employing comparison operators are usable when the
3377    # tensor has a zero-sized dimension (i.e. the tensor is empty). These tests specifically cover
3378    # functions for which specifying the `dim` parameter is necessary.
3379    def test_tensor_compare_ops_empty(self, device):
3380        shape = (2, 0, 4)
3381        master_input = torch.randn(shape, device=device)
3382        np_input = np.empty(shape)
3383        test_functions = [
3384            ('amax', torch.amax, np.amax),
3385            ('amin', torch.amin, np.amin),
3386            ('max', lambda *args, **kwargs: torch.max(*args, **kwargs).values, np.max),
3387            ('min', lambda *args, **kwargs: torch.min(*args, **kwargs).values, np.min),
3388            ('median', lambda *args, **kwargs: torch.median(*args, **kwargs).values, np.median),
3389        ]
3390
3391        for name, fn, np_function in test_functions:
3392            # Check that reduction happens along the specified dim, with and without keepdim, and compare
3393            # against numpy to maintain compatibility with numpy's behavior.
3394            error_msg = f"test function: {name}"
3395            self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=2), msg=error_msg)
3396            self.assertEqual(np_function(np_input, axis=2),
3397                             fn(master_input, dim=2).cpu().numpy(), msg=error_msg, exact_dtype=False)
3398
3399            self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=-1), msg=error_msg)
3400            self.assertEqual(np_function(np_input, axis=-1),
3401                             fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, exact_dtype=False)
3402
3403            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True),
3404                             msg=error_msg)
3405            self.assertEqual(np_function(np_input, axis=2, keepdims=True),
3406                             fn(master_input, dim=2, keepdim=True).cpu().numpy(), msg=error_msg, exact_dtype=False)
3407
3408            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True),
3409                             msg=error_msg)
3410            self.assertEqual(np_function(np_input, axis=-1, keepdims=True),
3411                             fn(master_input, dim=-1, keepdim=True).cpu().numpy(), msg=error_msg, exact_dtype=False)
3412
3413            # Check that the function raises an error when the zero-sized dimension is the reduction dim.
3414            self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input, dim=1))
3415
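    # Illustrative sketch of the behavior exercised above (not part of the test):
    #   >>> t = torch.empty(2, 0, 4)
    #   >>> torch.amax(t, dim=2).shape   # reducing a non-zero-sized dim of an empty tensor is fine
    #   torch.Size([2, 0])
    #   >>> torch.amax(t, dim=1)         # the zero-sized dim has no elements to compare
    #   IndexError: ...Expected reduction dim...
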
3416    # Tests to ensure that reducing tensors with a zero-sized dimension (i.e. empty tensors) using
3417    # comparison operators raises an error if no `dim` parameter is specified. This exists separately
3418    # from test_tensor_compare_ops_empty because not specifying a `dim` parameter in those tests does
3419    # not throw errors. Also, checking the return type of argmax requires supplying a different dtype
3420    # argument than that of the input tensor. There is also variation in the numpy testing.
3421    def test_tensor_compare_ops_argmax_argmin_kthvalue_dim_empty(self, device):
3422        shape = (2, 0, 4)
3423        master_input = torch.randn(shape, device=device)
3424        np_input = np.empty(shape)
3425        test_functions = [
3426            ('argmax', torch.argmax, {'dtype': torch.int64}, np.argmax),
3427            ('argmin', torch.argmin, {'dtype': torch.int64}, np.argmin),
3428            ('kthvalue', lambda *args, k=1, **kwargs: torch.kthvalue(*args, k=1, **kwargs).values,
3429             {}, lambda *args, k=1, axis=None, **kwargs: np.partition(*args, k, **kwargs).take(k - 1, axis=axis))
3430        ]
3431
3432        for name, fn, dtype, np_function in test_functions:
3433            error_msg = f"test function: {name}"
3434            self.assertEqual(torch.empty((2, 0), device=device, **dtype), fn(master_input, dim=2), msg=error_msg)
3435            self.assertEqual(
3436                np_function(np_input, axis=2), fn(master_input, dim=2).cpu().numpy(), msg=error_msg, exact_dtype=False
3437            )
3438
3439            self.assertEqual(torch.empty((2, 0), device=device, **dtype), fn(master_input, dim=-1), msg=error_msg)
3440            self.assertEqual(
3441                np_function(np_input, axis=-1), fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, exact_dtype=False
3442            )
3443
3444            # keepdim variant does not exist for numpy
3445            self.assertEqual(torch.empty((2, 0, 1), device=device, **dtype), fn(master_input, dim=2, keepdim=True),
3446                             msg=error_msg)
3447            self.assertEqual(torch.empty((2, 0, 1), device=device, **dtype), fn(master_input, dim=-1, keepdim=True),
3448                             msg=error_msg)
3449
3450            # Check that the function raises an error when the zero-sized dimension is the reduction dim.
3451            self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input, dim=1))
3452            if name != 'kthvalue':
3453                self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input))
3454
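    # Illustrative sketch (not part of the test): argmax/argmin return int64 indices, and without `dim`
    # they reduce over all elements, which is ill-defined for an empty tensor:
    #   >>> t = torch.empty(2, 0, 4)
    #   >>> torch.argmax(t, dim=2).dtype, torch.argmax(t, dim=2).shape
    #   (torch.int64, torch.Size([2, 0]))
    #   >>> torch.argmax(t)
    #   IndexError: ...Expected reduction dim...
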
3455    # Tests to ensure that reducing tensors with a zero-sized dimension (i.e. empty tensors) using math operators
3456    # works when a dim is specified: reducing over a non-zero-sized dim yields an empty result, while reducing over
3457    # the zero-sized dim yields the operator's identity value (or nan/-inf). Although there is some repetition with
3458    # test_tensor_compare_ops_argmax_argmin_kthvalue_dim_empty and test_tensor_compare_ops_empty, these tests are
3459    # kept separate since tests for math operators also require checking the correctness of the returned values.
3460    @skipIfNoSciPy
3461    def test_tensor_reduce_ops_empty(self, device):
3462        from scipy.special import logsumexp
3463        shape = (2, 0, 4)
3464        master_input = torch.randn(shape, device=device)
3465        np_input = np.empty(shape)
3466        test_functions = [
3467            ('prod', torch.prod, 1., np.prod),
3468            ('sum', torch.sum, 0., np.sum),
3469            ('norm', torch.norm, 0., np.linalg.norm),
3470            ('mean', torch.mean, nan, np.mean),
3471            ('var', torch.var, nan, np.var),
3472            ('std', torch.std, nan, np.std),
3473            ('logsumexp', torch.logsumexp, -inf, logsumexp),
3474        ]
3475
3476        for name, fn, return_value, np_function in test_functions:
3477            # Check if reduction happens along the specified dimension.
3478            error_msg = f"test function: {name}"
3479            self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=2), msg=error_msg)
3480            self.assertEqual(np_function(np_input, axis=2), fn(master_input, dim=2).cpu().numpy(), msg=error_msg,
3481                             exact_dtype=False)
3482
3483            self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=-1), msg=error_msg)
3484            self.assertEqual(np_function(np_input, axis=-1), fn(master_input, dim=-1).cpu().numpy(), msg=error_msg,
3485                             exact_dtype=False)
3486
3487            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True),
3488                             msg=error_msg)
3489            self.assertEqual(np_function(np_input, axis=2, keepdims=True), fn(master_input, dim=2, keepdim=True),
3490                             msg=error_msg, exact_dtype=False)
3491
3492            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True),
3493                             msg=error_msg)
3494            self.assertEqual(np_function(np_input, axis=-1, keepdims=True), fn(master_input, dim=-1, keepdim=True),
3495                             msg=error_msg, exact_dtype=False)
3496
3497            self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=1), msg=error_msg)
3498            self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=-2), msg=error_msg)
3499            self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=1, keepdim=True),
3500                             msg=error_msg)
3501            self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=-2, keepdim=True),
3502                             msg=error_msg)
3503
3504            if name != 'logsumexp':
3505                # The scipy logsumexp function does not work when reducing over the zero-sized dimension.
3506                self.assertEqual(np.float32(np_function(np_input, axis=1)), fn(master_input, dim=1).cpu().numpy(),
3507                                 msg=error_msg)
3508                self.assertEqual(np.float32(np_function(np_input, axis=-2)), fn(master_input, dim=-2).cpu().numpy(),
3509                                 msg=error_msg)
3510                self.assertEqual(np.float32(np_function(np_input, axis=1, keepdims=True)),
3511                                 fn(master_input, dim=1, keepdim=True).cpu().numpy(),
3512                                 msg=error_msg)
3513                self.assertEqual(np.float32(np_function(np_input, axis=-2, keepdims=True)),
3514                                 fn(master_input, dim=-2, keepdim=True).cpu().numpy(),
3515                                 msg=error_msg)
3516
3517                # logsumexp throws a TypeError when `dim` is not specified, so it is tested separately below.
3518                self.assertEqual(torch.full((), return_value, device=device), fn(master_input), msg=error_msg)
3519            else:
3520                self.assertRaises(TypeError, lambda: fn(master_input))
3521
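    # Illustrative sketch (not part of the test): reducing over the zero-sized dim returns the reduction's
    # identity value, matching the `return_value` entries in the table above:
    #   >>> t = torch.empty(2, 0, 4)
    #   >>> torch.sum(t, dim=1)    # identity of sum is 0., so the result is a (2, 4) tensor of zeros
    #   >>> torch.prod(t, dim=1)   # identity of prod is 1., so the result is a (2, 4) tensor of ones
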
3522    # Tests to ensure that the any() and all() functions work with tensors that have a zero-sized
3523    # dimension. Kept separate from the other zero-sized-dimension reduction tests because these tests
3524    # follow a significantly different testing pattern from the former tests.
3525    def test_reduction_empty_any_all(self, device):
3526        shape = (2, 0, 4)
3527        x = torch.randn(shape, device=device)
3528
3529        for dtype in all_types_and_complex_and(torch.half, torch.bool):
3530            # Refer: [all, any uint8 compatibility]
3531            if dtype == torch.uint8:
3532                out_dtype = torch.uint8
3533            else:
3534                out_dtype = torch.bool  # for all other dtypes, the output of all/any is bool
3535
3536            xb = x.to(dtype)
3538            # any
3539            self.assertEqual((2, 0), xb.any(2).shape)
3540            self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape)
3541            self.assertEqual(torch.zeros((2, 4), device=device, dtype=out_dtype), xb.any(1))
3542            self.assertEqual(torch.zeros((2, 1, 4), device=device, dtype=out_dtype), xb.any(1, keepdim=True))
3543            self.assertEqual(torch.zeros((), device=device, dtype=out_dtype), xb.any())
3544
3545            # all
3546            self.assertEqual((2, 0), xb.all(2).shape)
3547            self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape)
3548            self.assertEqual(torch.ones((2, 4), device=device, dtype=out_dtype), xb.all(1))
3549            self.assertEqual(torch.ones((2, 1, 4), device=device, dtype=out_dtype), xb.all(1, keepdim=True))
3550            self.assertEqual(torch.ones((), device=device, dtype=out_dtype), xb.all())
3551
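    # Illustrative sketch (not part of the test): over an empty tensor, any() is vacuously False and
    # all() is vacuously True:
    #   >>> t = torch.empty(2, 0, 4, dtype=torch.bool)
    #   >>> t.any(), t.all()
    #   (tensor(False), tensor(True))
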
3552    # TODO: can these be merged with their respective OpInfos?
3553    def test_reduce_dtype(self, device):
3554        def test_reduction(op, has_no_dim, takes_dtype=True):
3555            x = torch.randn(3, 3, dtype=torch.float, requires_grad=True, device=device)
3556
3557            if has_no_dim:
3558                grad1, = torch.autograd.grad([op(x)], [x])
3559                grad2, = torch.autograd.grad([op(x, dtype=torch.double)], [x])
3560                self.assertEqual(grad1, grad2)
3561                self.assertEqual(grad2.dtype, torch.float)
3562
3563            gi = torch.randn(op(x, dim=0).shape, dtype=torch.float, device=device)
3564            grad1, = torch.autograd.grad([op(x, dim=0)], [x], gi)
3565            if takes_dtype:
3566                grad2, = torch.autograd.grad([op(x, dim=0, dtype=torch.double)], [x], gi.double())
3567            else:
3568                grad2, = torch.autograd.grad([op(x.double(), dim=0)], [x], gi.double())
3569            self.assertEqual(grad1, grad2)
3570            self.assertEqual(grad2.dtype, torch.float)
3571
3572        test_reduction(torch.sum, True)
3573        test_reduction(torch.prod, True)
3574        test_reduction(torch.cumsum, False)
3575        test_reduction(torch.cumprod, False)
3576        test_reduction(torch.logcumsumexp, False, takes_dtype=False)
3577
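    # Illustrative sketch of the dtype behavior exercised above (not part of the test): the `dtype`
    # argument upcasts the reduction result, while the gradient flowing back keeps the input's dtype:
    #   >>> x = torch.randn(3, 3, requires_grad=True)
    #   >>> torch.sum(x, dtype=torch.double).dtype
    #   torch.float64
    #   >>> torch.autograd.grad([torch.sum(x, dtype=torch.double)], [x])[0].dtype
    #   torch.float32
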
3578    @ops(reference_masked_ops)
3579    def test_reference_masked(self, device, dtype, op):
3580        """Test masked reduction operations on strided-only tensors using
3581        numpy reductions as reference.
3582        """
3583
3584        def to_numpy(input):
3585            if input.dtype is torch.bfloat16:
3586                return input.cpu().to(torch.float32).numpy()
3587            else:
3588                return input.cpu().numpy()
3589
3590        samples = op.sample_inputs_func(op, device, dtype, requires_grad=False)
3591        for sample_input in samples:
3592            t = sample_input.input
3593            actual = op(t, *sample_input.args, **sample_input.kwargs)
3594            exact_dtype = not (t.dtype is torch.bfloat16
3595                               or (op.promotes_int_to_float and not torch.is_floating_point(t)))
3596            expected = op.ref(to_numpy(t), *sample_input.args,
3597                              **dict(
3598                                  # `identity` is mapped to numpy reduction `initial` argument
3599                                  identity=torch.masked._reduction_identity(op.name, t),
3600                                  **sample_input.kwargs))
3601
3602            # Workaround https://github.com/pytorch/pytorch/issues/66556
3603            expected = np.asarray(expected)  # transform numpy scalars to numpy.ndarray instances
3604
3605            # Numpy differs, producing uint32 on Windows
3606            if expected.dtype in [np.uint64, np.uint32]:
3607                exact_dtype = False
3608
3609            msg = ("Failed to produce expected results! Input tensor was"
3610                   f" {t}, torch result is {actual}, and reference result is"
3611                   f" {expected}.") if t.numel() < 10 else None
3612
3613            self.assertEqual(actual, expected, msg, exact_dtype=exact_dtype)
3614
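    # Illustrative sketch (not part of the test; exact identity values depend on the op and dtype):
    # torch.masked._reduction_identity maps an op name to the value forwarded as numpy's `initial`
    # argument above, e.g. 0 for sum and 1 for prod:
    #   >>> t = torch.tensor([1., 2., 3.])
    #   >>> torch.masked._reduction_identity('sum', t), torch.masked._reduction_identity('prod', t)
    #   (tensor(0.), tensor(1.))
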
3615    @onlyCUDA
3616    @largeTensorTest("8GB")
3617    @dtypes(torch.half, torch.chalf, torch.bfloat16)
3618    def test_reductions_large_half_tensors(self, device, dtype):
3619        t = torch.ones(2**31, device=device, dtype=dtype)
3620        t[2**30:] = -1
3621        expected = torch.tensor(0, device=device, dtype=dtype)
3622        self.assertEqual(torch.sum(t), expected)
3623
3624        # mean_cuda is not implemented for ComplexHalf
3625        err_msg = "not implemented for 'ComplexHalf'"
3626        ctx = self.assertRaisesRegex(
3627            RuntimeError, err_msg) if dtype is torch.chalf else contextlib.nullcontext()
3628        with ctx:
3629            self.assertEqual(torch.mean(t), expected)
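
    # Context note (an illustrative sketch, not part of the test): sums of this size in half precision
    # rely on higher-precision accumulation internally; naive fp16 accumulation would stall long before
    # 2**31 elements, since fp16 cannot represent consecutive integers above 2048:
    #   >>> torch.tensor(2048., dtype=torch.half) + 1
    #   tensor(2048., dtype=torch.float16)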
3630
3631instantiate_device_type_tests(TestReductions, globals())
3632
3633if __name__ == '__main__':
3634    run_tests()
3635