# Owner(s): ["module: tests"]

import contextlib
import torch
import numpy as np

import math
from typing import Dict, List, Sequence
import random
from functools import partial
from itertools import product, combinations, permutations
import warnings

from torch import inf, nan
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and, get_all_math_dtypes, integral_types, complex_types, floating_types_and,
    integral_types_and, floating_and_complex_types_and, all_types_and, all_types,
)
from torch.testing._internal.common_utils import (
    TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict,
    parametrize,
    IS_WINDOWS)
from torch.testing._internal.common_device_type import (
    OpDTypes, expectedFailureMeta, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU,
    onlyNativeDeviceTypes, onlyCUDA, largeTensorTest, ops, precisionOverride)
from torch.testing._internal.common_methods_invocations import (
    ReductionOpInfo, ReductionPythonRefInfo, reduction_ops, reference_masked_ops)

# TODO: replace with make_tensor
def _generate_input(shape, dtype, device, with_extremal):
    if shape == ():
        x = torch.tensor((), dtype=dtype, device=device)
    else:
        if dtype.is_floating_point or dtype.is_complex:
            # work around torch.randn not being implemented for bfloat16
            if dtype == torch.bfloat16:
                x = torch.randn(*shape, device=device) * random.randint(30, 100)
                x = x.to(torch.bfloat16)
            else:
                x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100)
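            # randomly zero out roughly a third of the entries
            # (P(N(0, 1) > 0.5) is about 0.31)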
            x[torch.randn(*shape) > 0.5] = 0
            if with_extremal and dtype.is_floating_point:
                # Use extremal values
                x[torch.randn(*shape) > 0.5] = float('nan')
                x[torch.randn(*shape) > 0.5] = float('inf')
                x[torch.randn(*shape) > 0.5] = float('-inf')
            elif with_extremal and dtype.is_complex:
                x[torch.randn(*shape) > 0.5] = complex('nan')
                x[torch.randn(*shape) > 0.5] = complex('inf')
                x[torch.randn(*shape) > 0.5] = complex('-inf')
        elif dtype == torch.bool:
            x = torch.zeros(shape, dtype=dtype, device=device)
            x[torch.randn(*shape) > 0.5] = True
        else:
            x = torch.randint(15, 100, shape, dtype=dtype, device=device)

    return x

# TODO: replace with make_tensor
def _rand_shape(dim, min_size, max_size):
    shape = []
    for i in range(dim):
        shape.append(random.randint(min_size, max_size))
    return tuple(shape)

def _reduced_shape(shape, dim=None, keepdim=False):
    """Computes the expected reduced shape given dim and keepdim

    Args:
        shape: The shape to reduce
        dim: The dimensions to reduce
        keepdim: If true, reduced dimensions have size 1 in the reduced shape,
            otherwise they are removed from the reduced shape.

    Returns:
        The reduced shape
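
    Example (illustrative values):
        >>> _reduced_shape((2, 3, 4), dim=1)
        [2, 4]
        >>> _reduced_shape((2, 3, 4), dim=(0, 2), keepdim=True)
        [1, 3, 1]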
    """
    if dim is None:
        return [1] * len(shape) if keepdim else []

    # Wrap negative dims
    dim = dim if isinstance(dim, Sequence) else [dim]
    dim = {i if i >= 0 else len(shape) + i for i in dim}

    result = []
    for i, size in enumerate(shape):
        if i not in dim:
            result.append(size)
        elif keepdim:
            result.append(1)

    return result

class TestReductions(TestCase):

    ###########################################################################
    # ReductionOpInfo unit tests
    ###########################################################################

    def _test_dim_keepdim(self, op: ReductionOpInfo, device, *, ndim, **dim_keepdim):
        """Tests output shape for input with ndim and dim and keepdim kwargs"""
        shape = torch.randint(2, 5, (ndim,)).tolist()
        t = make_tensor(shape, dtype=torch.float, device=device)
        args, kwargs = next(op.generate_args_kwargs(t, **dim_keepdim))
        result = op(t, *args, **dim_keepdim, **kwargs)
        expected_shape = _reduced_shape(shape, **dim_keepdim)
        self.assertEqual(result.shape, expected_shape, f"""
        expected output shape to be {expected_shape} but got {list(result.shape)}
        for input shape {shape} and {dim_keepdim}
        """)

    # TODO(@heitorschueroff) combine cases with and without keepdim once
    # there's support for a @parametrize decorator.

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_default(self, device, op: ReductionOpInfo):
        """Tests that the default dim reduces all dimensions."""
        for ndim in range(3):
            self._test_dim_keepdim(op, device, ndim=ndim)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_default_keepdim(self, device, op: ReductionOpInfo):
        """Tests that the default dim, when keepdim=True, reduces all dimensions to size 1."""
        for ndim in range(3):
            self._test_dim_keepdim(op, device, ndim=ndim, keepdim=True)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_none(self, device, op: ReductionOpInfo):
        """Tests that dim=None reduces all dimensions."""
        for ndim in range(3):
            self._test_dim_keepdim(op, device, ndim=ndim, dim=None)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_none_keepdim(self, device, op: ReductionOpInfo):
        """Tests that dim=None, when keepdim=True, reduces all dimensions to size 1."""
        for ndim in range(3):
            self._test_dim_keepdim(op, device, ndim=ndim, dim=None, keepdim=True)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_single(self, device, op: ReductionOpInfo):
        """Tests that dim=i reduces dimension i."""
        self._test_dim_keepdim(op, device, ndim=0, dim=0)
        self._test_dim_keepdim(op, device, ndim=1, dim=0)
        self._test_dim_keepdim(op, device, ndim=2, dim=-1)
        self._test_dim_keepdim(op, device, ndim=3, dim=1)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_single_keepdim(self, device, op: ReductionOpInfo):
        """Tests that dim=i, when keepdim=True, reduces dimension i to size 1."""
        self._test_dim_keepdim(op, device, ndim=0, dim=0, keepdim=True)
        self._test_dim_keepdim(op, device, ndim=1, dim=0, keepdim=True)
        self._test_dim_keepdim(op, device, ndim=2, dim=-1, keepdim=True)
        self._test_dim_keepdim(op, device, ndim=3, dim=1, keepdim=True)

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_empty(self, device, op: ReductionOpInfo):
        """Tests that dim=[] is a no-op"""
        self._test_dim_keepdim(op, device, ndim=0, dim=[])
        self._test_dim_keepdim(op, device, ndim=2, dim=[])

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_empty_keepdim(self, device, op: ReductionOpInfo):
        """Tests that dim=[], when keepdim=True, is a no-op"""
        self._test_dim_keepdim(op, device, ndim=0, dim=[], keepdim=True)
        self._test_dim_keepdim(op, device, ndim=2, dim=[], keepdim=True)

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi(self, device, op: ReductionOpInfo):
        """Tests that dim=[i, j, ...] reduces dimensions i, j, ..."""
        self._test_dim_keepdim(op, device, ndim=1, dim=[0])
        self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2])

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_keepdim(self, device, op: ReductionOpInfo):
        """Tests that dim=[i, j, ...], when keepdim=True, reduces dimensions i, j, ... to size 1."""
        self._test_dim_keepdim(op, device, ndim=1, dim=[0], keepdim=True)
        self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2], keepdim=True)

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_unsorted(self, device, op: ReductionOpInfo):
        """Tests that the operator correctly handles an unsorted dim list."""
        self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2])

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_unsorted_keepdim(self, device, op: ReductionOpInfo):
        """Tests that the operator correctly handles an unsorted dim list when keepdim=True."""
        self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2], keepdim=True)

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_duplicate(self, device, op: ReductionOpInfo):
        """Tests that an error is raised if dim has duplicate entries."""
        with self.assertRaises(RuntimeError):
            self._test_dim_keepdim(op, device, ndim=3, dim=[0, 1, 1, 2])

    @ops(filter(lambda op: not op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_unsupported(self, device, op: ReductionOpInfo):
        """Tests that ops claiming to not support multi dim actually don't."""
        with self.assertRaises(TypeError):
            self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2])

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_offbounds(self, device, op: ReductionOpInfo):
        """Tests that passing an out-of-bounds dim raises an error"""
        with self.assertRaises(IndexError):
            self._test_dim_keepdim(op, device, ndim=2, dim=2)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_ndim_limit(self, device, op: ReductionOpInfo):
        """Tests that an exception is raised when reducing a tensor with more
        than 64 dims along a specific dimension; dim=None is allowed."""
        t = make_tensor([1] * 65, dtype=torch.float, device=device)
        with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
            op(t, dim=0)

    @ops(filter(lambda op: op.identity is not None, reduction_ops), dtypes=OpDTypes.supported)
    def test_identity(self, device, dtype, op: ReductionOpInfo):
        """Tests that the identity value is an identity for the operator"""
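        # e.g. sum has identity 0 and prod has identity 1: interleaving the
        # identity value into the input must not change the reduction result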
        t = make_tensor((10,), dtype=dtype, device=device)
        t[1::2] = op.identity
        args, kwargs = next(op.generate_args_kwargs(t))
        result = op(t[::2], *args, **kwargs)
        result_with_identity = op(t, *args, **kwargs)
        self.assertEqual(result, result_with_identity, """
        Adding identity value to the input tensor should not change the result.
        """)

    # TODO(@heitorschueroff) Update these to use the nan_policy kwarg once
    # it is added to reduction operators.

    @ops(filter(lambda op: op.nan_policy == 'propagate', reduction_ops), dtypes=OpDTypes.supported,
         allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16))
    def test_nan_policy_propagate(self, device, dtype, op: ReductionOpInfo):
        """Tests that nan is propagated to the output by default"""
        t = make_tensor((5,), dtype=dtype, device=device)
        t[2] = torch.nan
        args, kwargs = next(op.generate_args_kwargs(t))
        result = op(t, *args, **kwargs)
        self.assertTrue(result.isnan())

    @ops(filter(lambda op: op.nan_policy == 'omit', reduction_ops), dtypes=OpDTypes.supported,
         allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16))
    def test_nan_policy_omit(self, device, dtype, op: ReductionOpInfo):
        """Tests that NaN values do not affect the result."""
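        # e.g. nansum and nanmean compute their result as if the NaN entries
        # had been removed from the input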
        t = make_tensor((10,), dtype=dtype, device=device)
        t[1::2] = torch.nan
        args, kwargs = next(op.generate_args_kwargs(t))
        result = op(t[::2], *args, **kwargs)
        result_with_nan = op(t, *args, **kwargs)
        self.assertEqual(result, result_with_nan)

    @ops(reduction_ops, dtypes=OpDTypes.supported)
    def test_result_dtype(self, device, dtype, op: ReductionOpInfo):
        """Tests that the result has the correct dtype"""
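        # e.g. sum promotes integral inputs to int64, mean promotes them to a
        # floating point dtype, and norms of complex inputs have real results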
        t = make_tensor((5,), dtype=dtype, device=device)
        args, kwargs = next(op.generate_args_kwargs(t))
        result: torch.Tensor = op(t, *args, **kwargs)
        is_integral = dtype in integral_types_and(torch.bool)
        if op.promotes_int_to_float and is_integral:
            self.assertTrue(torch.is_floating_point(result))
        elif op.promotes_int_to_int64 and is_integral:
            self.assertEqual(result.dtype, torch.int64)
        elif op.result_dtype is not None:
            self.assertEqual(result.dtype, op.result_dtype)
        elif op.complex_to_real:
            _complex_to_real_dtype_map = {
                torch.complex128: torch.float64,
                torch.complex64: torch.float32,
                torch.complex32: torch.float16,
            }
            self.assertEqual(result.dtype, _complex_to_real_dtype_map.get(dtype, dtype))
        else:
            self.assertEqual(result.dtype, dtype)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_empty_tensor_empty_slice(self, device, op: ReductionOpInfo):
        """Tests for consistent behavior when reducing over an empty slice.

        The rules for reducing over an empty slice are as follows:
            - Return the identity value if the operator has one
            - Otherwise, return NaN if the operator promotes integral dtypes
              to floating point
            - Otherwise, raise an error

        See discussion here https://github.com/pytorch/pytorch/issues/61901
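
        For example, sum returns its identity (0), mean returns NaN because it
        promotes integral inputs to floating point, and amax, which has no
        identity, raises an error.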
        """
        t = make_tensor((0, 2, 3), dtype=torch.float, device=device)
        for dim in [0] + ([[0, 2]] if op.supports_multiple_dims else []):
            args, kwargs = next(op.generate_args_kwargs(t, dim=dim))
            if op.identity is not None:
                # Reducing along empty slice should return identity
                result = op(t, *args, dim=dim, **kwargs)
                self.assertEqual(result, torch.full_like(result, op.identity))
            elif op.promotes_int_to_float:
                # Reducing along empty slice should return NaN
                result = op(t, *args, dim=dim, **kwargs)
                self.assertEqual(result, torch.full_like(result, torch.nan))
            else:
                # Reducing along empty slice should raise an error
                if isinstance(op, ReductionPythonRefInfo):
                    # ref reductions throw RuntimeError for this
                    with self.assertRaises(RuntimeError):
                        op(t, *args, dim=dim, **kwargs)
                else:
                    with self.assertRaises(IndexError):
                        op(t, *args, dim=dim, **kwargs)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_empty_tensor_nonempty_slice(self, device, op: ReductionOpInfo):
        """Tests that reducing a nonempty slice of an empty tensor returns an
        empty tensor with the dimensions reduced."""
        t = make_tensor((0, 2, 3), dtype=torch.float, device=device)
        for dim in [1] + ([[1, 2]] if op.supports_multiple_dims else []):
            args, kwargs = next(op.generate_args_kwargs(t, dim=dim))
            result = op(t, *args, dim=dim, **kwargs)
            self.assertEqual(result.shape, _reduced_shape(t.shape, dim))

    def _test_noncontiguous(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
        """Helper method to test noncontiguous input tensors."""
        assert not t.is_contiguous()

        t_contig = t.contiguous()
        for args, kwargs in op.generate_args_kwargs(t_contig, **reduction_kwargs):
            kwargs.update(reduction_kwargs)
            result = op(t, *args, **kwargs)
            expected = op(t_contig, *args, **kwargs)
            self.assertEqual(result, expected)

    @ops(reduction_ops)
    def test_noncontiguous_innermost(self, device, dtype, op: ReductionOpInfo):
        """Tests reducing along noncontiguous innermost dimension."""
        t = make_tensor((10, 10), dtype=dtype, device=device, low=-1, high=1)
        self._test_noncontiguous(op, t[:, ::2], dim=1)

    @ops(reduction_ops)
    def test_noncontiguous_outermost(self, device, dtype, op: ReductionOpInfo):
        """Tests reducing along noncontiguous outermost dimension."""
        t = make_tensor((10, 10), dtype=dtype, device=device, low=-1, high=1)
        self._test_noncontiguous(op, t[::2, :], dim=0)

    @ops(reduction_ops)
    def test_noncontiguous_all(self, device, dtype, op: ReductionOpInfo):
        """Tests reducing all dimensions of a noncontiguous tensor."""
        t = make_tensor((5, 5, 5), dtype=dtype, device=device, low=-1, high=1)
        self._test_noncontiguous(op, t[::2, ::3, 1:-1:2])

    @ops(reduction_ops)
    def test_noncontiguous_transposed(self, device, dtype, op: ReductionOpInfo):
        """Tests reducing a transposed tensor."""
        t = make_tensor((5, 5), dtype=dtype, device=device, low=-1, high=1)
        self._test_noncontiguous(op, t.T)

    @ops(reduction_ops)
    def test_noncontiguous_expanded(self, device, dtype, op: ReductionOpInfo):
        """Tests reducing a tensor with expanded singleton dimensions."""
        t = make_tensor((2, 3), dtype=dtype, device=device, low=-1, high=1)
        self._test_noncontiguous(op, t.unsqueeze(1).expand(-1, 5, -1))

    # NumPy does not support BFloat16 so we don't test that against reference
    # implementations. We also don't compare dtypes or test for different
    # keepdim because we already have other tests covering those.
    # The test_reference_testing in test_ops.py only uses the samples from
    # sample_inputs_func which do not test as exhaustively as these tests.

    def _test_ref(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
        """Compares op against op.ref for the given input and reduction kwargs"""
        for args, kwargs in op.generate_args_kwargs(t, **reduction_kwargs):
            kwargs.update(reduction_kwargs)
            result = op(t, *args, **kwargs)
            expected = op.ref(t.detach().cpu().numpy(), *args, **kwargs)
            self.assertEqual(result, expected, exact_dtype=False)

    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
    def test_ref_scalar_input(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for scalar input tensors"""
        self._test_ref(op, make_tensor([], dtype=dtype, device=device))

    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
    def test_ref_small_input(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for small input tensors"""
        t = make_tensor((5, 3, 4, 2), dtype=dtype, device=device, low=-2, high=2, exclude_zero=True)
        self._test_ref(op, t)
        for dim in [0, 1, 3] + ([[0, 2], [1, 3]] if op.supports_multiple_dims else []):
            self._test_ref(op, t, dim=dim)

    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=[torch.float64])
    def test_ref_large_input_1D(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for a large 1D input tensor to check stability"""
        self._test_ref(op, make_tensor((2 ** 20,), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True))

    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=[torch.float64])
    def test_ref_large_input_2D(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for a large 2D input tensor to test parallelism"""
        t = make_tensor((32, 2 ** 16), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True)
        self._test_ref(op, t, dim=1)

    @largeTensorTest("8gb")
    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=[torch.float64])
    def test_ref_large_input_64bit_indexing(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for a very large input tensor that requires 64 bit indexing"""
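        # 275000000 float64 elements occupy ~2.2GB, so byte offsets no longer
        # fit in a signed 32-bit integer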
        self._test_ref(op, make_tensor((275000000,), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True))

    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
    def test_ref_duplicate_values(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for input tensors with duplicate values"""
        t = make_tensor((4, 4), dtype=dtype, device=device, low=-2, high=2, exclude_zero=True)
        t[::2, ::2] = t[1::2, 1::2]
        self._test_ref(op, t)
        self._test_ref(op, t, dim=0)
        self._test_ref(op, t, dim=1)

    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=[torch.float32, torch.complex64])
    def test_ref_extremal_values(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for input tensors with extremal values"""
        t = make_tensor((5,), dtype=dtype, device=device, exclude_zero=True)
        extremals = [0, 1, nan, inf, -inf]
        for extremal in extremals:
            t[2] = extremal
            self._test_ref(op, t)

    ###########################################################################
    # TODO: Legacy tests - port to ReductionOpInfo
    ###########################################################################

    def test_var_unbiased(self, device):
        tensor = torch.randn(100, device=device)
        self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True))
        self.assertEqual(tensor.var(), tensor.var(unbiased=True))
        self.assertEqual(tensor.var(unbiased=False), tensor.var(0, unbiased=False))

        tensor = torch.tensor([1.0, 2.0], device=device)
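        # unbiased (sample) variance divides by N - 1, biased variance by N:
        # for [1.0, 2.0] the squared deviations from the mean 1.5 sum to 0.5,
        # giving 0.5 unbiased and 0.25 biased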
        self.assertEqual(tensor.var(unbiased=True), 0.5)
        self.assertEqual(tensor.var(unbiased=False), 0.25)

        tensor = torch.tensor([1.0, 2.0, 3.0], device=device)
        self.assertEqual(tensor.var(unbiased=True), 1.0)
        self.assertEqual(tensor.var(unbiased=False), 2.0 / 3.0)

        tensor = torch.randn(100, device=device)
        self.assertEqual(tensor.std(0), tensor.std(0, unbiased=True))
        self.assertEqual(tensor.std(), tensor.std(unbiased=True))
        self.assertEqual(tensor.std(unbiased=False), tensor.std(0, unbiased=False))

    def test_var_stability(self, device):
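        # values that are large relative to their spread would expose
        # catastrophic cancellation in a naive E[x^2] - E[x]^2 implementation;
        # the exact unbiased variance here is 2 * 0.125^2 / (2 - 1) = 0.03125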
        tensor = torch.tensor([2281.5, 2281.25], device=device)
        self.assertEqual(tensor.var(dim=0), 0.03125)
        self.assertEqual(tensor.var(), 0.03125)

    def test_sum_dim_reduction_uint8_overflow(self, device):
        example = [[-1, 2, 1], [5, 3, 6]]
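        # -1 wraps to 255 in uint8, so the full sum is 272, which overflows
        # to 272 % 256 = 16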
        x = torch.tensor(example, dtype=torch.uint8, device=device)
        self.assertEqual(x.sum(dtype=torch.uint8).item(), 16)
        self.assertEqual(x.sum(0, dtype=torch.uint8), torch.tensor([4, 5, 7], dtype=torch.uint8, device=device))
        self.assertEqual(x.sum(1, dtype=torch.uint8), torch.tensor([2, 14], dtype=torch.uint8, device=device))
        y = torch.tensor(example, dtype=torch.uint8, device=device)
        torch.sum(x, 0, out=y)
        self.assertEqual(x.sum(0, dtype=torch.uint8), y)

    def test_dim_reduction_less_than_64(self, device):
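        # reducing along an explicit dim is only supported for tensors with
        # up to 64 dims, so a 65-dim input must raise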
        sizes = [1] * 65
        x = torch.randn(sizes, device=device)
        ops = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp,
               torch.var, torch.norm]
        for op in ops:
            with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
                op(x, dim=64)
            with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
                op(x, dim=-1)

    @onlyCPU
    @dtypes(torch.float, torch.bfloat16)
    def test_dim_reduction_lastdim(self, device, dtype):
        x = torch.randn(3, 5, 40, device=device, dtype=dtype)
        x = x[:, :, 0:40:2]
        x2 = x.contiguous()
        ops = [torch.norm, torch.argmax, torch.argmin]
        for op in ops:
            y = op(x, dim=-1)
            y2 = op(x2, dim=-1)
            self.assertEqual(y, y2)

    @skipIfNoSciPy
    @dtypes(torch.float32, torch.double, torch.complex64, torch.complex128)
    def test_logsumexp(self, device, dtype):
        from scipy.special import logsumexp
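        # logsumexp(a, dim) computes log(sum(exp(a), dim)) stably by factoring
        # out the maximum: logsumexp(a) = m + log(sum(exp(a - m))), m = max(a)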
        a = torch.randn(5, 4, device=device, dtype=dtype)
        # torch.exp(complex(inf, 0)) yields inf+nan*j instead of inf+0*j on CPU which disagrees with CUDA, C++ std::exp,
        # numpy and scipy. Skip inf testing on CPU. Related to https://github.com/pytorch/pytorch/issues/95740
        if torch.device(device) != torch.device('cpu'):
            a[0, 0] = inf
        a[1, :] = -inf
        actual = a.logsumexp(1)
        expected = logsumexp(a.cpu().numpy(), 1)
        self.assertEqual(expected.shape, actual.shape)
        self.assertEqual(expected, actual)

        # check that out is actually inplace
        b = torch.zeros(5, 2, device=device, dtype=dtype)
        c = b[:, 0]
        torch.logsumexp(a, 1, out=c)
        self.assertEqual(expected, b[:, 0])

    @skipIfNoSciPy
    def test_logsumexp_integral_promotion(self, device):
        from scipy.special import logsumexp
        # check that integral inputs are promoted to floating point
        e = torch.randint(-100, 100, [5, 4], device=device)
        actual = e.logsumexp(1).to(torch.float64)
        expected = logsumexp(e.cpu().numpy(), 1)
        self.assertEqual(expected.shape, actual.shape)
        self.assertEqual(expected, actual)

    @skipIfNoSciPy
    @dtypes(torch.complex64, torch.complex128)
    def test_logcumsumexp_complex(self, device, dtype):
        # logcumsumexp is a more precise way to compute ``log(cumsum(exp(a)))``
        # and is faster than ``[log(sum(exp(a[:i]))) for i in range(a.shape[0])]``.
        # The for-loop above has similar precision to logcumsumexp (it is just
        # slower), so it can be used to produce the expected values for the check.

        # using logsumexp from scipy because at the time of writing this test code,
        # torch.logsumexp had not been implemented for complex numbers
        from scipy.special import logsumexp

        def zero_out_neg_inf(t):
            t = t.clone()
            idx = torch.logical_and(~(torch.isfinite(t)), torch.real(t) < 0)
            t[idx] = torch.real(t[idx]).to(t.dtype)
            return t

        def standardize_phase(t):
            t = torch.real(t) + 1j * (torch.imag(t) % (2 * np.pi))
            return t

        def logcumsumexp_slow(a, dim):
            res_lst = []
            for i in range(a.size(dim)):
                index = [slice(None, None, None) for _ in range(a.ndim)]
                index[dim] = slice(None, i + 1, None)
                a_inp = a[tuple(index)]
                res_lst.append(logsumexp(a_inp.cpu().numpy(), axis=dim, keepdims=True))
            res = np.concatenate(res_lst, axis=dim)
            return torch.as_tensor(res)

        def compare_logcumsumexp(a, expected=None):
            for i in range(a.ndim):
                actual = torch.logcumsumexp(a, dim=i)
                # if expected is not given, fall back to scipy's logsumexp
                if expected is None:
                    expected2 = logcumsumexp_slow(a, dim=i)
                else:
                    expected2 = expected

                # move the imaginary values to (0, 2 * pi)
                actual = standardize_phase(actual)
                expected2 = standardize_phase(expected2)

                # zeroing the imaginary part of the element if the real part is -inf
                # as the imaginary part cannot be determined exactly and it does not
                # really matter if we take the exp of the output
                actual = zero_out_neg_inf(actual)
                expected2 = zero_out_neg_inf(expected2)
                self.assertEqual(expected2.shape, actual.shape)
                self.assertEqual(expected2, actual)

        # randomly specified values
        # in this case, scipy.logsumexp should be enough
        a1 = torch.randn((5, 10), dtype=dtype, device=device)
        compare_logcumsumexp(a1)

        # test with some non-normal values
        a2 = torch.tensor([1e3 + 0j, 1e-18 + 1e4j, 1e2 + 1e-8j], dtype=dtype, device=device)
        compare_logcumsumexp(a2)

        # handle special cases involving infinities and nans
        # here we don't use scipy.logsumexp as it gives confusing answers on
        # some inf cases
        # see here:
        inf = float('inf')
        nan = float('nan')
        a3_input = torch.tensor([
            -inf + 4j,
            -inf + 1j,
            1.2 + 2.1j,
            1e10 + 1e20j,
            inf + 0j,
            inf + 1j,
            inf + 3j,
            nan + 2j,
        ])
        a3_expected = torch.tensor([
            -inf + 0j,
            -inf + 0j,
            1.2 + 2.1j,
            1e10 + 1e20j,
            inf + 0j,  # scipy's logsumexp gives (inf + 0.7853982j) here, unclear why
            inf + (np.pi / 4) * 1j,  # the imaginary part thanks to some weird behaviour of log(inf + infj)
            complex(inf, nan),
            complex(nan, nan),
        ])
        # Windows gives a strange result for the second-to-last entry,
        # producing inf + pi/4 j instead of inf + nan j
        if not IS_WINDOWS:
            compare_logcumsumexp(a3_input, a3_expected)

        a4_input = torch.tensor([
            complex(-inf, inf),
            complex(-inf, inf),
            -inf + 1j,
            1.2 + 2.1j,
            complex(2.4, inf),
        ])
        a4_expected = torch.tensor([
            -inf + 0j,
            -inf + 0j,
            -inf + 0j,
            1.2 + 2.1j,
            complex(nan, nan),
        ])
        if not IS_WINDOWS:
            compare_logcumsumexp(a4_input, a4_expected)

    @onlyCPU
    def test_sum_parallel(self, device):
632*da0073e9SAndroid Build Coastguard Worker        # To exercise the parallel branches we need to compare on
633*da0073e9SAndroid Build Coastguard Worker        # tensors that are relatively large. Even when run on a
634*da0073e9SAndroid Build Coastguard Worker        # single-core machine these tests still give a useful signal
635*da0073e9SAndroid Build Coastguard Worker        # on correctness; see the sizing note below.
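        # Sizing sketch (assuming ATen's default grain size of 32768 from
        # Parallel.h): every shape below satisfies numel() > 32768, e.g.
        #   [1, 32770]   -> 32770 elements
        #   [4] * 10     -> 4 ** 10 == 1048576 elements
        # so the parallel code path is taken whenever multiple threads exist.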
636*da0073e9SAndroid Build Coastguard Worker
637*da0073e9SAndroid Build Coastguard Worker        def _run_test(size):
638*da0073e9SAndroid Build Coastguard Worker            for dim in range(len(size) + 1):
639*da0073e9SAndroid Build Coastguard Worker                nv = np.round(np.random.rand(*size))  # 0s and 1s
640*da0073e9SAndroid Build Coastguard Worker                tv = torch.from_numpy(nv)
641*da0073e9SAndroid Build Coastguard Worker                # Parallelism is only used if numel is
642*da0073e9SAndroid Build Coastguard Worker                # larger than the grain size defined in Parallel.h
643*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(tv.numel() > 32768)
644*da0073e9SAndroid Build Coastguard Worker                if dim == len(size):
645*da0073e9SAndroid Build Coastguard Worker                    nvs = nv.sum()
646*da0073e9SAndroid Build Coastguard Worker                    tvs = tv.sum()
647*da0073e9SAndroid Build Coastguard Worker                else:
648*da0073e9SAndroid Build Coastguard Worker                    nvs = nv.sum(dim)
649*da0073e9SAndroid Build Coastguard Worker                    tvs = tv.sum(dim)
650*da0073e9SAndroid Build Coastguard Worker                diff = np.abs(nvs - tvs.numpy()).sum()
651*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(diff, 0)
652*da0073e9SAndroid Build Coastguard Worker
653*da0073e9SAndroid Build Coastguard Worker        _run_test([2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3])
654*da0073e9SAndroid Build Coastguard Worker        _run_test([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
655*da0073e9SAndroid Build Coastguard Worker        _run_test([1, 32 * 8 * 32 * 8])
656*da0073e9SAndroid Build Coastguard Worker        _run_test([1, 32770])
657*da0073e9SAndroid Build Coastguard Worker
658*da0073e9SAndroid Build Coastguard Worker    # TODO: kill map2_ (and similar) uses and update to compare with NumPy
659*da0073e9SAndroid Build Coastguard Worker    # only works on CPU since this uses map2_, which is only supported on CPU
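    # For reference, map2_ applies a Python callable elementwise on CPU,
    # writing the results into self; a minimal sketch (illustrative values):
    #   c = torch.zeros(2)
    #   c.map2_(torch.tensor([1., 4.]), torch.tensor([3., 2.]),
    #           lambda c_val, a_val, b_val: max(a_val, b_val))
    #   # c is now tensor([3., 4.])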
660*da0073e9SAndroid Build Coastguard Worker    def _testCSelection(self, torchfn, mathfn):
661*da0073e9SAndroid Build Coastguard Worker        # Two tensors
662*da0073e9SAndroid Build Coastguard Worker        size = (100, 100)
663*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(*size)
664*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(*size)
665*da0073e9SAndroid Build Coastguard Worker        c = torchfn(a, b)
666*da0073e9SAndroid Build Coastguard Worker        expected_c = torch.zeros(*size)
667*da0073e9SAndroid Build Coastguard Worker        expected_c.map2_(a, b, lambda _, a, b: mathfn(a, b))
668*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(expected_c, c, atol=0, rtol=0)
669*da0073e9SAndroid Build Coastguard Worker
670*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
671*da0073e9SAndroid Build Coastguard Worker    def test_max_elementwise(self, device):
672*da0073e9SAndroid Build Coastguard Worker        self._testCSelection(torch.max, max)
673*da0073e9SAndroid Build Coastguard Worker
674*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
675*da0073e9SAndroid Build Coastguard Worker    def test_min_elementwise(self, device):
676*da0073e9SAndroid Build Coastguard Worker        self._testCSelection(torch.min, min)
677*da0073e9SAndroid Build Coastguard Worker
678*da0073e9SAndroid Build Coastguard Worker    def test_all_any(self, device):
679*da0073e9SAndroid Build Coastguard Worker        def test(size):
680*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(*size, device=device).byte()
681*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.all())
682*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.any())
683*da0073e9SAndroid Build Coastguard Worker
684*da0073e9SAndroid Build Coastguard Worker            x[3] = 0
685*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.all())
686*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.any())
687*da0073e9SAndroid Build Coastguard Worker
688*da0073e9SAndroid Build Coastguard Worker            x.zero_()
689*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.all())
690*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.any())
691*da0073e9SAndroid Build Coastguard Worker
692*da0073e9SAndroid Build Coastguard Worker            x.fill_(2)
693*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.all())
694*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.any())
695*da0073e9SAndroid Build Coastguard Worker
696*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(*size, device=device).bool()
697*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.all())
698*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.any())
699*da0073e9SAndroid Build Coastguard Worker
700*da0073e9SAndroid Build Coastguard Worker            x[3] = False
701*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.all())
702*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.any())
703*da0073e9SAndroid Build Coastguard Worker
704*da0073e9SAndroid Build Coastguard Worker        test((10,))
705*da0073e9SAndroid Build Coastguard Worker        test((5, 5))
706*da0073e9SAndroid Build Coastguard Worker
707*da0073e9SAndroid Build Coastguard Worker    def test_all_any_with_dim(self, device):
708*da0073e9SAndroid Build Coastguard Worker        def test(x):
709*da0073e9SAndroid Build Coastguard Worker            r1 = x.prod(dim=0, keepdim=False).byte()
710*da0073e9SAndroid Build Coastguard Worker            r2 = x.all(dim=0, keepdim=False)
711*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r1.shape, r2.shape)
712*da0073e9SAndroid Build Coastguard Worker            self.assertTrue((r1 == r2).all())
713*da0073e9SAndroid Build Coastguard Worker
714*da0073e9SAndroid Build Coastguard Worker            r3 = x.sum(dim=1, keepdim=True).clamp(0, 1).byte()
715*da0073e9SAndroid Build Coastguard Worker            r4 = x.any(dim=1, keepdim=True)
716*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(r3.shape, r4.shape)
717*da0073e9SAndroid Build Coastguard Worker            self.assertTrue((r3 == r4).all())
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker        test(torch.tensor([[0, 0, 0],
720*da0073e9SAndroid Build Coastguard Worker                           [0, 0, 1],
721*da0073e9SAndroid Build Coastguard Worker                           [0, 1, 1],
722*da0073e9SAndroid Build Coastguard Worker                           [1, 1, 1]], device=device, dtype=torch.uint8))
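        # Worked example for the tensor above: along dim=0 the columns are
        # [0, 0, 0, 1], [0, 0, 1, 1] and [0, 1, 1, 1], so prod() gives
        # [0, 0, 0] == all(); along dim=1 the row sums are [0, 1, 2, 3],
        # so clamp(0, 1) gives [0, 1, 1, 1] == any().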
723*da0073e9SAndroid Build Coastguard Worker
724*da0073e9SAndroid Build Coastguard Worker    def test_numpy_named_args(self, device):
725*da0073e9SAndroid Build Coastguard Worker        x1 = torch.randn(10, device=device)
726*da0073e9SAndroid Build Coastguard Worker        x2 = torch.randn(10, device=device)
727*da0073e9SAndroid Build Coastguard Worker        res1 = torch.add(input=x1, other=x2)
728*da0073e9SAndroid Build Coastguard Worker        res2 = torch.add(x1=x1, x2=x2)
729*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
730*da0073e9SAndroid Build Coastguard Worker
731*da0073e9SAndroid Build Coastguard Worker        x1 = torch.randn(10, 10, 10, device=device)
732*da0073e9SAndroid Build Coastguard Worker        res1 = x1.sum(dim=(0, 2), keepdim=True)
733*da0073e9SAndroid Build Coastguard Worker        res2 = x1.sum(axis=(0, 2), keepdims=True)
734*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
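        # The NumPy-style names above are accepted as aliases for the native
        # argument names (axis -> dim, keepdims -> keepdim), e.g.:
        #   t = torch.ones(2, 3)
        #   assert torch.equal(t.sum(axis=0), t.sum(dim=0))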
735*da0073e9SAndroid Build Coastguard Worker
736*da0073e9SAndroid Build Coastguard Worker    # TODO: kill this and replace with common creation ops
737*da0073e9SAndroid Build Coastguard Worker    def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True,
738*da0073e9SAndroid Build Coastguard Worker                      use_complex=False) -> Dict[str, List[torch.Tensor]]:
739*da0073e9SAndroid Build Coastguard Worker        float_types = [torch.double,
740*da0073e9SAndroid Build Coastguard Worker                       torch.float]
741*da0073e9SAndroid Build Coastguard Worker        int_types = [torch.int64,
742*da0073e9SAndroid Build Coastguard Worker                     torch.int32,
743*da0073e9SAndroid Build Coastguard Worker                     torch.int16]
744*da0073e9SAndroid Build Coastguard Worker
745*da0073e9SAndroid Build Coastguard Worker        complex_types = [torch.complex64,
746*da0073e9SAndroid Build Coastguard Worker                         torch.complex128]
747*da0073e9SAndroid Build Coastguard Worker
748*da0073e9SAndroid Build Coastguard Worker        def make_contiguous(shape, dtype) -> torch.Tensor:
749*da0073e9SAndroid Build Coastguard Worker            if dtype in float_types:
750*da0073e9SAndroid Build Coastguard Worker                val = torch.randn(shape, dtype=dtype)
751*da0073e9SAndroid Build Coastguard Worker                val = val * ((val_range[1] - val_range[0]) / (math.pi * 2.0))
752*da0073e9SAndroid Build Coastguard Worker                val = val + ((val_range[1] - val_range[0]) / 2.0)
753*da0073e9SAndroid Build Coastguard Worker                val = torch.clamp(val, min=val_range[0], max=val_range[1])
754*da0073e9SAndroid Build Coastguard Worker                return val
755*da0073e9SAndroid Build Coastguard Worker            result = torch.zeros(shape, dtype=dtype)
756*da0073e9SAndroid Build Coastguard Worker            result.apply_(lambda x: random.randint(val_range[0], val_range[1]))
757*da0073e9SAndroid Build Coastguard Worker            return result
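        # Note on the arithmetic above, for the default val_range=(-100, 100):
        # randn samples are scaled by 200 / (2 * pi) ~= 31.8 and shifted by
        # half the range width (+100), not to the midpoint, so floating-point
        # values centre around +100 and roughly half of them clamp to the
        # upper bound.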
758*da0073e9SAndroid Build Coastguard Worker
759*da0073e9SAndroid Build Coastguard Worker        def make_non_contiguous(shape, dtype) -> torch.Tensor:
760*da0073e9SAndroid Build Coastguard Worker            contig = make_contiguous(shape, dtype)
761*da0073e9SAndroid Build Coastguard Worker            non_contig = torch.empty(shape + (2, 2), dtype=dtype)[..., 0]
762*da0073e9SAndroid Build Coastguard Worker            non_contig = non_contig.select(-1, -1)
763*da0073e9SAndroid Build Coastguard Worker            non_contig.copy_(contig)
764*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(non_contig.is_contiguous())
765*da0073e9SAndroid Build Coastguard Worker            return non_contig
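        # Shape/stride sketch for shape == (3, 5): empty((3, 5, 2, 2))[..., 0]
        # is a (3, 5, 2) view with strides (20, 4, 2); select(-1, -1) then
        # drops the trailing dim, leaving a (3, 5) view with strides (20, 4)
        # rather than the contiguous (5, 1), so is_contiguous() is False.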
766*da0073e9SAndroid Build Coastguard Worker
767*da0073e9SAndroid Build Coastguard Worker        def make_contiguous_slice(size, dtype) -> torch.Tensor:
768*da0073e9SAndroid Build Coastguard Worker            contig = make_contiguous((1, size), dtype)
769*da0073e9SAndroid Build Coastguard Worker            sliced = contig[:1, 1:size - 1]  # a slice that is itself still contiguous
770*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(sliced.is_contiguous())
771*da0073e9SAndroid Build Coastguard Worker            return sliced
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker        types = []
774*da0073e9SAndroid Build Coastguard Worker        if use_floating:
775*da0073e9SAndroid Build Coastguard Worker            types += float_types
776*da0073e9SAndroid Build Coastguard Worker        if use_integral:
777*da0073e9SAndroid Build Coastguard Worker            types += int_types
778*da0073e9SAndroid Build Coastguard Worker        if use_complex:
779*da0073e9SAndroid Build Coastguard Worker            types += complex_types
780*da0073e9SAndroid Build Coastguard Worker        tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
781*da0073e9SAndroid Build Coastguard Worker        for dtype in types:
782*da0073e9SAndroid Build Coastguard Worker            tensors["cont"].append(make_contiguous(shape, dtype))
783*da0073e9SAndroid Build Coastguard Worker            tensors["noncont"].append(make_non_contiguous(shape, dtype))
784*da0073e9SAndroid Build Coastguard Worker            tensors["slice"].append(make_contiguous_slice(sum(list(shape)), dtype))
785*da0073e9SAndroid Build Coastguard Worker
786*da0073e9SAndroid Build Coastguard Worker        return tensors
787*da0073e9SAndroid Build Coastguard Worker
788*da0073e9SAndroid Build Coastguard Worker    # TODO: refactor this to use comparators from common_utils
789*da0073e9SAndroid Build Coastguard Worker    def _assert_matches_numpy(self, t, n):
790*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(n.shape, t.shape)
791*da0073e9SAndroid Build Coastguard Worker        if t.dtype == torch.float:
792*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(n, t, rtol=1e-03, atol=1e-05, equal_nan=True)
793*da0073e9SAndroid Build Coastguard Worker        else:
794*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(n, t, equal_nan=True)
795*da0073e9SAndroid Build Coastguard Worker
796*da0073e9SAndroid Build Coastguard Worker    # TODO: update this and tests that use it to use the device argument properly
797*da0073e9SAndroid Build Coastguard Worker    def _test_dim_ops(self, pytorch_op, numpy_op,
798*da0073e9SAndroid Build Coastguard Worker                      use_floating=True, use_integral=True, use_complex=False):
799*da0073e9SAndroid Build Coastguard Worker        def do_one(tensors_dict, dim):
800*da0073e9SAndroid Build Coastguard Worker            for category, tensors in tensors_dict.items():
801*da0073e9SAndroid Build Coastguard Worker                if category == "slice":
802*da0073e9SAndroid Build Coastguard Worker                    dim = 0
803*da0073e9SAndroid Build Coastguard Worker                for tensor in tensors:
804*da0073e9SAndroid Build Coastguard Worker                    # we have no control over NumPy warnings...
805*da0073e9SAndroid Build Coastguard Worker                    with warnings.catch_warnings():
806*da0073e9SAndroid Build Coastguard Worker                        warnings.simplefilter("ignore")
807*da0073e9SAndroid Build Coastguard Worker                        expected = numpy_op(tensor.cpu().numpy(), dim)
808*da0073e9SAndroid Build Coastguard Worker                    actual = pytorch_op(tensor, dim)
809*da0073e9SAndroid Build Coastguard Worker                    self._assert_matches_numpy(actual, expected)
810*da0073e9SAndroid Build Coastguard Worker                    if torch.cuda.is_available():
811*da0073e9SAndroid Build Coastguard Worker                        self._assert_matches_numpy(pytorch_op(tensor.cuda(), dim).cpu(), expected)
812*da0073e9SAndroid Build Coastguard Worker        do_one(self._make_tensors((5, 400000), use_floating=use_floating,
813*da0073e9SAndroid Build Coastguard Worker                                  use_integral=use_integral, use_complex=use_complex), 1)
814*da0073e9SAndroid Build Coastguard Worker        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
815*da0073e9SAndroid Build Coastguard Worker                                  use_integral=use_integral, use_complex=use_complex), 0)
816*da0073e9SAndroid Build Coastguard Worker        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
817*da0073e9SAndroid Build Coastguard Worker                                  use_integral=use_integral, use_complex=use_complex), 1)
818*da0073e9SAndroid Build Coastguard Worker        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
819*da0073e9SAndroid Build Coastguard Worker                                  use_integral=use_integral, use_complex=use_complex), 2)
820*da0073e9SAndroid Build Coastguard Worker        do_one(self._make_tensors((100000, ), use_floating=use_floating,
821*da0073e9SAndroid Build Coastguard Worker                                  use_integral=use_integral, use_complex=use_complex), -1)
822*da0073e9SAndroid Build Coastguard Worker        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
823*da0073e9SAndroid Build Coastguard Worker                                  use_integral=use_integral, use_complex=use_complex), 0)
824*da0073e9SAndroid Build Coastguard Worker        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
825*da0073e9SAndroid Build Coastguard Worker                                  use_integral=use_integral, use_complex=use_complex), 1)
826*da0073e9SAndroid Build Coastguard Worker        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
827*da0073e9SAndroid Build Coastguard Worker                                  use_integral=use_integral, use_complex=use_complex), 2)
828*da0073e9SAndroid Build Coastguard Worker        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
829*da0073e9SAndroid Build Coastguard Worker                                  use_integral=use_integral, use_complex=use_complex), (1, 2))
830*da0073e9SAndroid Build Coastguard Worker        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
831*da0073e9SAndroid Build Coastguard Worker                                  use_integral=use_integral, use_complex=use_complex), (1, -1))
832*da0073e9SAndroid Build Coastguard Worker        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
833*da0073e9SAndroid Build Coastguard Worker                                  use_integral=use_integral, use_complex=use_complex), (0, 2))
834*da0073e9SAndroid Build Coastguard Worker        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
835*da0073e9SAndroid Build Coastguard Worker                                  use_integral=use_integral, use_complex=use_complex), (0, 2, 1))
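        # For the tuple dims above, reducing over several dims at once should
        # match reducing them one at a time (up to floating-point rounding),
        # e.g. for a (50, 50, 50) tensor t:
        #   t.sum(dim=(1, 2)) ~= t.sum(dim=2).sum(dim=1)   # shape (50,)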
836*da0073e9SAndroid Build Coastguard Worker
837*da0073e9SAndroid Build Coastguard Worker    @slowTest
838*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
839*da0073e9SAndroid Build Coastguard Worker    def test_sum_dim(self, device):
840*da0073e9SAndroid Build Coastguard Worker        self._test_dim_ops(
841*da0073e9SAndroid Build Coastguard Worker            lambda t, d: t.sum(d),
842*da0073e9SAndroid Build Coastguard Worker            lambda n, d: n.sum(d),
843*da0073e9SAndroid Build Coastguard Worker            use_floating=True, use_integral=True, use_complex=True)
844*da0073e9SAndroid Build Coastguard Worker
845*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
846*da0073e9SAndroid Build Coastguard Worker    def test_mean_dim(self, device):
847*da0073e9SAndroid Build Coastguard Worker        self._test_dim_ops(
848*da0073e9SAndroid Build Coastguard Worker            lambda t, d: t.mean(d),
849*da0073e9SAndroid Build Coastguard Worker            lambda n, d: n.mean(d),
850*da0073e9SAndroid Build Coastguard Worker            use_integral=False,
851*da0073e9SAndroid Build Coastguard Worker            use_complex=True)
852*da0073e9SAndroid Build Coastguard Worker
853*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
854*da0073e9SAndroid Build Coastguard Worker    def test_std_dim(self, device):
855*da0073e9SAndroid Build Coastguard Worker        for unbiased in [False, True]:
856*da0073e9SAndroid Build Coastguard Worker            self._test_dim_ops(
857*da0073e9SAndroid Build Coastguard Worker                lambda t, d: t.std(d, unbiased=unbiased),
858*da0073e9SAndroid Build Coastguard Worker                lambda n, d: n.std(d, ddof=1 if unbiased else 0),
859*da0073e9SAndroid Build Coastguard Worker                use_integral=False)
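        # The unbiased flag maps onto NumPy's ddof: both libraries compute
        #   var = sum((x - mean) ** 2) / (N - ddof)
        # with ddof == 1 when unbiased (Bessel's correction) and 0 otherwise;
        # std is the square root of that.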
860*da0073e9SAndroid Build Coastguard Worker
861*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
862*da0073e9SAndroid Build Coastguard Worker    def test_var_dim(self, device):
863*da0073e9SAndroid Build Coastguard Worker        for unbiased in [False, True]:
864*da0073e9SAndroid Build Coastguard Worker            self._test_dim_ops(
865*da0073e9SAndroid Build Coastguard Worker                lambda t, d: t.var(d, unbiased=unbiased),
866*da0073e9SAndroid Build Coastguard Worker                lambda n, d: n.var(d, ddof=1 if unbiased else 0),
867*da0073e9SAndroid Build Coastguard Worker                use_integral=False)
868*da0073e9SAndroid Build Coastguard Worker
869*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
870*da0073e9SAndroid Build Coastguard Worker    @skipIfNoSciPy
871*da0073e9SAndroid Build Coastguard Worker    def test_logsumexp_dim(self, device):
872*da0073e9SAndroid Build Coastguard Worker        from scipy.special import logsumexp
873*da0073e9SAndroid Build Coastguard Worker        self._test_dim_ops(
874*da0073e9SAndroid Build Coastguard Worker            lambda t, d: t.logsumexp(d),
875*da0073e9SAndroid Build Coastguard Worker            lambda n, d: logsumexp(n, d),
876*da0073e9SAndroid Build Coastguard Worker            use_integral=False)
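        # logsumexp(x, d) == log(sum(exp(x), d)), computed stably by factoring
        # out the max m along d:
        #   m + log(sum(exp(x - m), d))
        # which avoids overflow of exp() for large inputs.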
877*da0073e9SAndroid Build Coastguard Worker
878*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
879*da0073e9SAndroid Build Coastguard Worker    def test_mean_int_with_optdtype(self, device):
880*da0073e9SAndroid Build Coastguard Worker        a = make_tensor((3, 4, 5), dtype=torch.int64, device=device)
881*da0073e9SAndroid Build Coastguard Worker
882*da0073e9SAndroid Build Coastguard Worker        # If the optional output dtype is given, the input
883*da0073e9SAndroid Build Coastguard Worker        # is cast to it internally before the reduction.
884*da0073e9SAndroid Build Coastguard Worker        a_float = a.to(torch.float32)
885*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a_float.mean(), a.mean(dtype=torch.float32))
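        # e.g. torch.tensor([1, 2]).mean(dtype=torch.float32) gives
        # tensor(1.5000), while mean() on an integer tensor without an
        # explicit floating (or complex) dtype raises an error.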
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker    # TODO: update this and tests that use it to handle device properly
888*da0073e9SAndroid Build Coastguard Worker    def _test_reduce_integer_upcast(self, fn, has_out=True, test_complex=True):
889*da0073e9SAndroid Build Coastguard Worker        shape = (3, 4, 5)
890*da0073e9SAndroid Build Coastguard Worker        reduced_shape = fn(torch.ones(shape)).shape
891*da0073e9SAndroid Build Coastguard Worker
892*da0073e9SAndroid Build Coastguard Worker        def _test_out(dtype, other_dtype):
893*da0073e9SAndroid Build Coastguard Worker            out = torch.ones(reduced_shape, dtype=dtype)
894*da0073e9SAndroid Build Coastguard Worker            result = fn(x, out=out)
895*da0073e9SAndroid Build Coastguard Worker            self.assertIs(out.dtype, result.dtype)
896*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fn(x.to(dtype)), result, exact_dtype=False)
897*da0073e9SAndroid Build Coastguard Worker            result = fn(x, out=out, dtype=dtype)
898*da0073e9SAndroid Build Coastguard Worker            self.assertIs(out.dtype, result.dtype)
899*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fn(x.to(dtype)), result, exact_dtype=False)
900*da0073e9SAndroid Build Coastguard Worker            # 'out' and 'dtype' must agree; a conflicting dtype raises an error
901*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: fn(x, out=out, dtype=other_dtype))
902*da0073e9SAndroid Build Coastguard Worker
903*da0073e9SAndroid Build Coastguard Worker        for dtype in [dtype for dtype in get_all_math_dtypes('cpu') if dtype != torch.float16]:
904*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(shape, dtype=dtype)
905*da0073e9SAndroid Build Coastguard Worker            expected_dtype = dtype if dtype.is_floating_point or dtype.is_complex else torch.int64
906*da0073e9SAndroid Build Coastguard Worker            self.assertIs(expected_dtype, fn(x).dtype)
907*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fn(x.to(expected_dtype)), fn(x))
908*da0073e9SAndroid Build Coastguard Worker
909*da0073e9SAndroid Build Coastguard Worker            if dtype.is_floating_point:
910*da0073e9SAndroid Build Coastguard Worker                other_dtype = torch.float32 if dtype == torch.float64 else torch.float64
911*da0073e9SAndroid Build Coastguard Worker            elif dtype.is_complex:
912*da0073e9SAndroid Build Coastguard Worker                other_dtype = torch.complex64 if dtype == torch.complex128 else torch.complex128
913*da0073e9SAndroid Build Coastguard Worker            else:
914*da0073e9SAndroid Build Coastguard Worker                other_dtype = torch.int32 if dtype != torch.int32 else torch.int16
915*da0073e9SAndroid Build Coastguard Worker            self.assertIs(other_dtype, fn(x, dtype=other_dtype).dtype)
916*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fn(x.to(other_dtype)), fn(x, dtype=other_dtype), exact_dtype=False)
917*da0073e9SAndroid Build Coastguard Worker
918*da0073e9SAndroid Build Coastguard Worker            # test mixed int/float/complex
919*da0073e9SAndroid Build Coastguard Worker            if dtype.is_floating_point:
920*da0073e9SAndroid Build Coastguard Worker                mixed_dtypes = [torch.int32, torch.complex64]
921*da0073e9SAndroid Build Coastguard Worker            elif dtype.is_complex:
922*da0073e9SAndroid Build Coastguard Worker                mixed_dtypes = [torch.int32, torch.float32]
923*da0073e9SAndroid Build Coastguard Worker            else:
924*da0073e9SAndroid Build Coastguard Worker                mixed_dtypes = [torch.float32, torch.complex64]
925*da0073e9SAndroid Build Coastguard Worker
926*da0073e9SAndroid Build Coastguard Worker            for mixed_dtype in mixed_dtypes:
927*da0073e9SAndroid Build Coastguard Worker                self.assertIs(mixed_dtype, fn(x, dtype=mixed_dtype).dtype)
928*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(fn(x.to(mixed_dtype)), fn(x, dtype=mixed_dtype), exact_dtype=False)
929*da0073e9SAndroid Build Coastguard Worker
930*da0073e9SAndroid Build Coastguard Worker                if has_out:
931*da0073e9SAndroid Build Coastguard Worker                    _test_out(dtype, other_dtype)
932*da0073e9SAndroid Build Coastguard Worker                    _test_out(dtype, mixed_dtype)
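        # The upcast under test: sums and prods of integral inputs accumulate
        # in int64 by default, e.g.
        #   torch.ones(3, dtype=torch.uint8).sum().dtype  # torch.int64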
933*da0073e9SAndroid Build Coastguard Worker
934*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
935*da0073e9SAndroid Build Coastguard Worker    def test_sum_integer_upcast(self, device):
936*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, **kwargs), False)
937*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, 0, **kwargs))
938*da0073e9SAndroid Build Coastguard Worker
939*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
940*da0073e9SAndroid Build Coastguard Worker    def test_prod_integer_upcast(self, device):
941*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, **kwargs), False)
942*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, 0, **kwargs))
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
945*da0073e9SAndroid Build Coastguard Worker    def test_cumsum_integer_upcast(self, device):
946*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumsum(x, 0, **kwargs))
947*da0073e9SAndroid Build Coastguard Worker
948*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
949*da0073e9SAndroid Build Coastguard Worker    def test_cumprod_integer_upcast(self, device):
950*da0073e9SAndroid Build Coastguard Worker        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumprod(x, 0, **kwargs))
951*da0073e9SAndroid Build Coastguard Worker
952*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types())
953*da0073e9SAndroid Build Coastguard Worker    def test_mode(self, device, dtype):
954*da0073e9SAndroid Build Coastguard Worker        SIZE = 10
955*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(1., SIZE * SIZE + 1, device=device, dtype=dtype).clone().resize_(SIZE, SIZE)
956*da0073e9SAndroid Build Coastguard Worker        x[:2] = 1
957*da0073e9SAndroid Build Coastguard Worker        x[:, :2] = 1
958*da0073e9SAndroid Build Coastguard Worker        x0 = x.clone()
959*da0073e9SAndroid Build Coastguard Worker
960*da0073e9SAndroid Build Coastguard Worker        # Pre-calculated results.
961*da0073e9SAndroid Build Coastguard Worker        res1val = torch.ones(SIZE, device=device, dtype=dtype)
962*da0073e9SAndroid Build Coastguard Worker        # The indices are the positions of the last appearance of the mode element.
963*da0073e9SAndroid Build Coastguard Worker        res1ind = torch.ones(SIZE, device=device, dtype=torch.long)
964*da0073e9SAndroid Build Coastguard Worker        res1ind[0] = SIZE - 1
965*da0073e9SAndroid Build Coastguard Worker        res1ind[1] = SIZE - 1
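        # Why: rows 0 and 1 are all ones, so their mode 1 last appears at
        # column SIZE - 1; every other row contains 1 only at columns 0 and 1
        # (its remaining entries are distinct), so the mode 1 last appears at
        # column 1.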
966*da0073e9SAndroid Build Coastguard Worker
967*da0073e9SAndroid Build Coastguard Worker        res2val, res2ind = torch.mode(x, keepdim=False)
968*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1val, res2val, atol=0, rtol=0)
969*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
970*da0073e9SAndroid Build Coastguard Worker
971*da0073e9SAndroid Build Coastguard Worker        # Test use of result tensor
972*da0073e9SAndroid Build Coastguard Worker        res2val = torch.tensor((), device=device, dtype=dtype)
973*da0073e9SAndroid Build Coastguard Worker        res2ind = torch.tensor((), device=device, dtype=torch.long)
974*da0073e9SAndroid Build Coastguard Worker        torch.mode(x, keepdim=False, out=(res2val, res2ind))
975*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1val, res2val, atol=0, rtol=0)
976*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
977*da0073e9SAndroid Build Coastguard Worker
978*da0073e9SAndroid Build Coastguard Worker        # Test non-default dim
979*da0073e9SAndroid Build Coastguard Worker        res2val, res2ind = torch.mode(x, 0, False)
980*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1val, res2val, atol=0, rtol=0)
981*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
982*da0073e9SAndroid Build Coastguard Worker
983*da0073e9SAndroid Build Coastguard Worker        # input unchanged
984*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x0, atol=0, rtol=0)
985*da0073e9SAndroid Build Coastguard Worker
986*da0073e9SAndroid Build Coastguard Worker    def _test_mode_intervals(self, shape, intervals, device, dtype, v=1):
987*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(0, shape[1], device=device, dtype=dtype).expand(shape)
988*da0073e9SAndroid Build Coastguard Worker        x = x.contiguous()
989*da0073e9SAndroid Build Coastguard Worker        x[:, v] = intervals[0][0]
990*da0073e9SAndroid Build Coastguard Worker
991*da0073e9SAndroid Build Coastguard Worker        # Set the value of each interval to the mode "v"
992*da0073e9SAndroid Build Coastguard Worker        for (beg, end) in intervals:
993*da0073e9SAndroid Build Coastguard Worker            x[:, beg:end] = v
994*da0073e9SAndroid Build Coastguard Worker
995*da0073e9SAndroid Build Coastguard Worker        values, indices = torch.mode(x, -1, False)
996*da0073e9SAndroid Build Coastguard Worker
997*da0073e9SAndroid Build Coastguard Worker        # Check whether the returned indices correspond to the returned values
998*da0073e9SAndroid Build Coastguard Worker        self.assertTrue((x.gather(1, indices.unsqueeze(1)).t() == values).all())
999*da0073e9SAndroid Build Coastguard Worker        # Check whether the returned values are the mode
1000*da0073e9SAndroid Build Coastguard Worker        self.assertTrue((values == v).all().item())
1001*da0073e9SAndroid Build Coastguard Worker
1002*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1003*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and(torch.half, torch.bfloat16))
1004*da0073e9SAndroid Build Coastguard Worker    def test_mode_large(self, device, dtype):
1005*da0073e9SAndroid Build Coastguard Worker        # i should be less than (d - 2) / 2
1006*da0073e9SAndroid Build Coastguard Worker        def testset_for_shape(shape, i):
1007*da0073e9SAndroid Build Coastguard Worker            d = shape[-1]
1008*da0073e9SAndroid Build Coastguard Worker            # Mode only in the middle.
1009*da0073e9SAndroid Build Coastguard Worker            self._test_mode_intervals(shape, [(i, d - i)], device, dtype)
1010*da0073e9SAndroid Build Coastguard Worker            # Mode in discontiguous parts of the input.
1011*da0073e9SAndroid Build Coastguard Worker            self._test_mode_intervals(shape, [(0, i), (i + 1, d - i - 1), (d - i, d)], device, dtype)
1012*da0073e9SAndroid Build Coastguard Worker
1013*da0073e9SAndroid Build Coastguard Worker        # More slices (65536) than fit in one line of thread blocks (max grid dim 65535)
1014*da0073e9SAndroid Build Coastguard Worker        testset_for_shape((65536, 10), 3)
1015*da0073e9SAndroid Build Coastguard Worker
1016*da0073e9SAndroid Build Coastguard Worker        # Max slice size (2048)
1017*da0073e9SAndroid Build Coastguard Worker        testset_for_shape((10, 2048), 10)
1018*da0073e9SAndroid Build Coastguard Worker
1019*da0073e9SAndroid Build Coastguard Worker        # Naive kernel for big slice sizes (> 2048)
1020*da0073e9SAndroid Build Coastguard Worker        testset_for_shape((10, 4096), 10)
1021*da0073e9SAndroid Build Coastguard Worker
1022*da0073e9SAndroid Build Coastguard Worker    def test_mode_boolean(self, device):
1023*da0073e9SAndroid Build Coastguard Worker        shapes = [
1024*da0073e9SAndroid Build Coastguard Worker            (10, 10),
1025*da0073e9SAndroid Build Coastguard Worker            (4, 2048),
1026*da0073e9SAndroid Build Coastguard Worker            (1, 4096),
1027*da0073e9SAndroid Build Coastguard Worker        ]
1028*da0073e9SAndroid Build Coastguard Worker
1029*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
1030*da0073e9SAndroid Build Coastguard Worker            a = torch.zeros(shape, device=device, dtype=torch.bool)
1031*da0073e9SAndroid Build Coastguard Worker
1032*da0073e9SAndroid Build Coastguard Worker            a[:, (shape[1] - 1) // 2:] = True
1033*da0073e9SAndroid Build Coastguard Worker            values, indices = a.mode(-1)
1034*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values, torch.ones(shape[0], dtype=torch.bool))
1036*da0073e9SAndroid Build Coastguard Worker            indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1)
1037*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values, indexed)
1038*da0073e9SAndroid Build Coastguard Worker
1039*da0073e9SAndroid Build Coastguard Worker            a.fill_(False)
1040*da0073e9SAndroid Build Coastguard Worker            a[:, shape[1] // 2 + 1:] = True
1041*da0073e9SAndroid Build Coastguard Worker            values, indices = a.mode(-1)
1043*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values, torch.zeros(shape[0], dtype=torch.bool))
1044*da0073e9SAndroid Build Coastguard Worker            indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1)
1045*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values, indexed)
1046*da0073e9SAndroid Build Coastguard Worker
1047*da0073e9SAndroid Build Coastguard Worker
1048*da0073e9SAndroid Build Coastguard Worker    @expectedFailureMeta  # mode only supports CPU and CUDA device type
1049*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1050*da0073e9SAndroid Build Coastguard Worker    def test_mode_wrong_dtype(self, device):
1051*da0073e9SAndroid Build Coastguard Worker        def test_for_dtypes(x_ty, v_ty, i_ty, message):
1052*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(10, device=device, dtype=x_ty)
1053*da0073e9SAndroid Build Coastguard Worker            v = torch.ones(10, device=device, dtype=v_ty)
1054*da0073e9SAndroid Build Coastguard Worker            i = torch.ones(10, device=device, dtype=i_ty)
1055*da0073e9SAndroid Build Coastguard Worker
1056*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, message):
1057*da0073e9SAndroid Build Coastguard Worker                torch.mode(x, -1, True, out=(v, i))
1058*da0073e9SAndroid Build Coastguard Worker
1059*da0073e9SAndroid Build Coastguard Worker        err_msg = "expected scalar type .* but got .* for "
1060*da0073e9SAndroid Build Coastguard Worker        values_err = err_msg + "values"
1061*da0073e9SAndroid Build Coastguard Worker        indices_err = err_msg + "indices"
1062*da0073e9SAndroid Build Coastguard Worker
1063*da0073e9SAndroid Build Coastguard Worker        test_for_dtypes(torch.uint8, torch.int8, torch.long, values_err)
1064*da0073e9SAndroid Build Coastguard Worker        test_for_dtypes(torch.int8, torch.int16, torch.long, values_err)
1065*da0073e9SAndroid Build Coastguard Worker        test_for_dtypes(torch.int32, torch.float32, torch.long, values_err)
1066*da0073e9SAndroid Build Coastguard Worker        test_for_dtypes(torch.float32, torch.float64, torch.long, values_err)
1067*da0073e9SAndroid Build Coastguard Worker
1068*da0073e9SAndroid Build Coastguard Worker        test_for_dtypes(torch.uint8, torch.uint8, torch.int8, indices_err)
1069*da0073e9SAndroid Build Coastguard Worker        test_for_dtypes(torch.int8, torch.int8, torch.int16, indices_err)
1070*da0073e9SAndroid Build Coastguard Worker        test_for_dtypes(torch.int32, torch.int32, torch.float32, indices_err)
1071*da0073e9SAndroid Build Coastguard Worker        test_for_dtypes(torch.float32, torch.float32, torch.float64, indices_err)
1072*da0073e9SAndroid Build Coastguard Worker
1073*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1074*da0073e9SAndroid Build Coastguard Worker    def test_mode_wrong_device(self, device):
1075*da0073e9SAndroid Build Coastguard Worker        # CPU Input Tensor
1076*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2)
1077*da0073e9SAndroid Build Coastguard Worker
1078*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
1079*da0073e9SAndroid Build Coastguard Worker                                    "expected device .* but got .* for values"):
1080*da0073e9SAndroid Build Coastguard Worker            values = torch.tensor([], device=device)
1081*da0073e9SAndroid Build Coastguard Worker            torch.mode(x, -1, True, out=(values, torch.tensor([], dtype=torch.long)))
1082*da0073e9SAndroid Build Coastguard Worker
1083*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError,
1084*da0073e9SAndroid Build Coastguard Worker                                    "expected device .* but got .* for indices"):
1085*da0073e9SAndroid Build Coastguard Worker            indices = torch.tensor([], device=device)
1086*da0073e9SAndroid Build Coastguard Worker            torch.mode(x, -1, True, out=(torch.tensor([]), indices))
1087*da0073e9SAndroid Build Coastguard Worker
1088*da0073e9SAndroid Build Coastguard Worker    # TODO: make work on CUDA, too
1089*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1090*da0073e9SAndroid Build Coastguard Worker    def test_accreal_type(self, device) -> None:
1091*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 3, 4)
1092*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x.double().sum().item(), float)
1093*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x.float().sum().item(), float)
1094*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x.long().sum().item(), int)
1095*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x.int().sum().item(), int)
1096*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x.short().sum().item(), int)
1097*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x.char().sum().item(), int)
1098*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x.byte().sum().item(), int)
1099*da0073e9SAndroid Build Coastguard Worker
1100*da0073e9SAndroid Build Coastguard Worker    def test_var_mean_some_dims(self, device):
1101*da0073e9SAndroid Build Coastguard Worker        sizes = (4, 6, 7, 5, 3)
1102*da0073e9SAndroid Build Coastguard Worker        dims = len(sizes)
1103*da0073e9SAndroid Build Coastguard Worker
1104*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(sizes, device=device)
1105*da0073e9SAndroid Build Coastguard Worker        for num_of_dims in range(2, dims):
1106*da0073e9SAndroid Build Coastguard Worker            dim_list = list(combinations(list(range(dims)), r=num_of_dims))
1107*da0073e9SAndroid Build Coastguard Worker            for dim in dim_list:
1108*da0073e9SAndroid Build Coastguard Worker                for unbiased in [False, True]:
1109*da0073e9SAndroid Build Coastguard Worker                    for keepdim in [False, True]:
1110*da0073e9SAndroid Build Coastguard Worker                        var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
1111*da0073e9SAndroid Build Coastguard Worker                        var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim)
1112*da0073e9SAndroid Build Coastguard Worker                        mean2 = x.mean(dim=dim, keepdim=keepdim)
1113*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(var1, var2)
1114*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(mean1, mean2)
1115*da0073e9SAndroid Build Coastguard Worker
1116*da0073e9SAndroid Build Coastguard Worker    # TODO: this should be a generic opinfo test
1117*da0073e9SAndroid Build Coastguard Worker    def test_all_any_empty(self, device):
1118*da0073e9SAndroid Build Coastguard Worker        x = torch.ByteTensor().to(device)
1119*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.all())
1120*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(x.any())
1121*da0073e9SAndroid Build Coastguard Worker
1122*da0073e9SAndroid Build Coastguard Worker        x = torch.BoolTensor().to(device)
1123*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.all())
1124*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(x.any())
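        # This follows the vacuous-truth convention, matching Python's
        # builtins on empty input: all([]) is True and any([]) is False.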
1125*da0073e9SAndroid Build Coastguard Worker
1126*da0073e9SAndroid Build Coastguard Worker    def test_all_issue117215(self, device):
1127*da0073e9SAndroid Build Coastguard Worker        info = torch.iinfo(torch.uint8)
1128*da0073e9SAndroid Build Coastguard Worker        a = torch.randint(info.min, info.max, (73, 11, 3, 17), dtype=torch.uint8)
1129*da0073e9SAndroid Build Coastguard Worker        b = torch.all(a, dim=0)
1130*da0073e9SAndroid Build Coastguard Worker        c = a.to(torch.bool).all(dim=0)
1131*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.ne(b, c).sum(), 0)
1132*da0073e9SAndroid Build Coastguard Worker
1133*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
1134*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
1135*da0073e9SAndroid Build Coastguard Worker    def test_max_with_inf(self, device, dtype):
1136*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
1137*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.all(torch.max(a, dim=1).values == inf).item())
1138*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.all(torch.amax(a, dim=1) == inf).item())
1139*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.max(a).item() == inf)
1140*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.amax(a).item() == inf)
1141*da0073e9SAndroid Build Coastguard Worker
1142*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
1143*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.float, torch.bfloat16, torch.double)
1144*da0073e9SAndroid Build Coastguard Worker    def test_min_with_inf(self, device, dtype):
1145*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
1146*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.all(torch.min(a, dim=1).values == (-inf)).item())
1147*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.all(torch.amin(a, dim=1) == (-inf)).item())
1148*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.min(a).item() == -inf)
1149*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.amin(a).item() == -inf)
1150*da0073e9SAndroid Build Coastguard Worker
1151*da0073e9SAndroid Build Coastguard Worker    def _test_minmax_helper(self, torchfn, reffn, device, dtype, skip_indices=False):
1152*da0073e9SAndroid Build Coastguard Worker        def create_input(shape, device, dtype):
1153*da0073e9SAndroid Build Coastguard Worker            if dtype.is_floating_point:
1154*da0073e9SAndroid Build Coastguard Worker                return torch.randn(*shape, device=device, dtype=dtype)
1155*da0073e9SAndroid Build Coastguard Worker            else:
1156*da0073e9SAndroid Build Coastguard Worker                low = 0 if dtype == torch.bool else -1000
1157*da0073e9SAndroid Build Coastguard Worker                high = 2 if dtype == torch.bool else 1000
1158*da0073e9SAndroid Build Coastguard Worker                return torch.randint(low, high, shape, device=device, dtype=dtype)
1159*da0073e9SAndroid Build Coastguard Worker        x = create_input((100, 100), device, dtype)
1160*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(torchfn, reffn, x)
1161*da0073e9SAndroid Build Coastguard Worker        # non contiguous
1162*da0073e9SAndroid Build Coastguard Worker        x = create_input((10, 10, 10), device, dtype)
1163*da0073e9SAndroid Build Coastguard Worker        x = x[:, 4]
1164*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(torchfn, reffn, x)
1165*da0073e9SAndroid Build Coastguard Worker
1166*da0073e9SAndroid Build Coastguard Worker        def get_values(x):
1167*da0073e9SAndroid Build Coastguard Worker            if isinstance(x, tuple):
1168*da0073e9SAndroid Build Coastguard Worker                return x[0]
1169*da0073e9SAndroid Build Coastguard Worker            return x
1170*da0073e9SAndroid Build Coastguard Worker
1171*da0073e9SAndroid Build Coastguard Worker        # indices
1172*da0073e9SAndroid Build Coastguard Worker        if not skip_indices:
1173*da0073e9SAndroid Build Coastguard Worker            size = 5
1174*da0073e9SAndroid Build Coastguard Worker            x = create_input((size, size), device, dtype)
1175*da0073e9SAndroid Build Coastguard Worker            inputs = (x, x.t())
1176*da0073e9SAndroid Build Coastguard Worker            dims = (0, 1)
1177*da0073e9SAndroid Build Coastguard Worker            for xinp, d in product(inputs, dims):
1178*da0073e9SAndroid Build Coastguard Worker                self.compare_with_numpy(lambda x: get_values(torchfn(x, d, False)), lambda x: reffn(x, d, keepdims=False), xinp)
1179*da0073e9SAndroid Build Coastguard Worker                result = torchfn(xinp, d, False)
1180*da0073e9SAndroid Build Coastguard Worker                if isinstance(result, tuple):
1181*da0073e9SAndroid Build Coastguard Worker                    v, i = result
1182*da0073e9SAndroid Build Coastguard Worker                    if d == 1:
1183*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(xinp[torch.arange(size), i], v, atol=0, rtol=0)
1184*da0073e9SAndroid Build Coastguard Worker                    else:
1185*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(xinp[i, torch.arange(size)], v, atol=0, rtol=0)
1186*da0073e9SAndroid Build Coastguard Worker        # nan
1187*da0073e9SAndroid Build Coastguard Worker        if dtype.is_floating_point:
1188*da0073e9SAndroid Build Coastguard Worker            for index in (0, 4, 99):
1189*da0073e9SAndroid Build Coastguard Worker                x = create_input((100,), device, dtype)
1190*da0073e9SAndroid Build Coastguard Worker                x[index] = nan
1191*da0073e9SAndroid Build Coastguard Worker                if not skip_indices:
1192*da0073e9SAndroid Build Coastguard Worker                    result = torchfn(x, 0)
1193*da0073e9SAndroid Build Coastguard Worker                    v = get_values(result)
1194*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(v, nan)
1195*da0073e9SAndroid Build Coastguard Worker                    if isinstance(result, tuple):
1196*da0073e9SAndroid Build Coastguard Worker                        i = result[1]
1197*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(i, index)
1198*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torchfn(x), nan)
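        # Both torch and NumPy propagate nans here: min/max/amin/amax return
        # nan whenever one is present, and torch's index-returning overloads
        # point at the nan's position, as checked above.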
1199*da0073e9SAndroid Build Coastguard Worker
1200*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool, torch.half)
1201*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool)
1202*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.float, torch.double)
1203*da0073e9SAndroid Build Coastguard Worker    def test_max(self, device, dtype):
1204*da0073e9SAndroid Build Coastguard Worker        self._test_minmax_helper(torch.max, np.amax, device, dtype)
1205*da0073e9SAndroid Build Coastguard Worker
1206*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool, torch.half)
1207*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool)
1208*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.float, torch.double)
1209*da0073e9SAndroid Build Coastguard Worker    def test_min(self, device, dtype):
1210*da0073e9SAndroid Build Coastguard Worker        self._test_minmax_helper(torch.min, np.amin, device, dtype)
1211*da0073e9SAndroid Build Coastguard Worker
1212*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(torch.half, torch.float, torch.double, torch.int, torch.long, torch.bool)
1213*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool)
1214*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.half, torch.float, torch.double)
1215*da0073e9SAndroid Build Coastguard Worker    def test_amin(self, device, dtype):
1216*da0073e9SAndroid Build Coastguard Worker        self._test_minmax_helper(torch.amin, np.amin, device, dtype)
1217*da0073e9SAndroid Build Coastguard Worker
1218*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCPU(torch.half, torch.float, torch.double, torch.int, torch.long, torch.bool)
1219*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool)
1220*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
1221*da0073e9SAndroid Build Coastguard Worker    def test_amax(self, device, dtype):
1222*da0073e9SAndroid Build Coastguard Worker        self._test_minmax_helper(torch.amax, np.amax, device, dtype)
1223*da0073e9SAndroid Build Coastguard Worker
1224*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1225*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
1226*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float, torch.bfloat16)
1227*da0073e9SAndroid Build Coastguard Worker    def test_aminmax(self, device, dtype):
1228*da0073e9SAndroid Build Coastguard Worker
1229*da0073e9SAndroid Build Coastguard Worker        def _amin_wrapper(x, dim=None, keepdims=False):
1230*da0073e9SAndroid Build Coastguard Worker            return torch.aminmax(x, dim=dim, keepdim=keepdims)[0]
1231*da0073e9SAndroid Build Coastguard Worker
1232*da0073e9SAndroid Build Coastguard Worker        def _amax_wrapper(x, dim=None, keepdims=False):
1233*da0073e9SAndroid Build Coastguard Worker            return torch.aminmax(x, dim=dim, keepdim=keepdims)[1]
1234*da0073e9SAndroid Build Coastguard Worker
1235*da0073e9SAndroid Build Coastguard Worker        self._test_minmax_helper(_amin_wrapper, np.amin, device, dtype)
1236*da0073e9SAndroid Build Coastguard Worker        self._test_minmax_helper(_amax_wrapper, np.amax, device, dtype)
1237*da0073e9SAndroid Build Coastguard Worker
1238*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1239*da0073e9SAndroid Build Coastguard Worker    @dtypes(*complex_types())
1240*da0073e9SAndroid Build Coastguard Worker    def test_invalid_0dim_aminmax(self, device, dtype):
1241*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'not implemented'):
1242*da0073e9SAndroid Build Coastguard Worker            torch.aminmax(torch.tensor(1., dtype=dtype, device=device), dim=0)
1243*da0073e9SAndroid Build Coastguard Worker
1244*da0073e9SAndroid Build Coastguard Worker    # TODO: bincount isn't a classic reduction -- maybe this test suite
1245*da0073e9SAndroid Build Coastguard Worker    #   should cover reductions and summary ops?
1246*da0073e9SAndroid Build Coastguard Worker    def test_bincount(self, device):
1247*da0073e9SAndroid Build Coastguard Worker        # negative input throws
1248*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
1249*da0073e9SAndroid Build Coastguard Worker            torch.bincount(torch.tensor([1, -1], device=device))
1250*da0073e9SAndroid Build Coastguard Worker        # n-d input, with n > 1 throws
1251*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
1252*da0073e9SAndroid Build Coastguard Worker            torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device))
1253*da0073e9SAndroid Build Coastguard Worker        # floating input type throws
1254*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'not implemented'):
1255*da0073e9SAndroid Build Coastguard Worker            torch.bincount(torch.tensor([1., 0.3], device=device))
1256*da0073e9SAndroid Build Coastguard Worker        # minlength < 0 throws
1257*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'):
1258*da0073e9SAndroid Build Coastguard Worker            torch.bincount(torch.tensor([1, 3], device=device),
1259*da0073e9SAndroid Build Coastguard Worker                           torch.tensor([.2, .2], device=device),
1260*da0073e9SAndroid Build Coastguard Worker                           minlength=-1)
1261*da0073e9SAndroid Build Coastguard Worker        # n-d weights, with n > 1 throws
1262*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, '1-d'):
1263*da0073e9SAndroid Build Coastguard Worker            torch.bincount(torch.tensor([1, 0], device=device),
1264*da0073e9SAndroid Build Coastguard Worker                           torch.tensor([[1., 0.3], [1., 0.3]], device=device))
1265*da0073e9SAndroid Build Coastguard Worker        # input and weights dim mismatch
1266*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, 'same length'):
1267*da0073e9SAndroid Build Coastguard Worker            torch.bincount(torch.tensor([1, 0], device=device),
1268*da0073e9SAndroid Build Coastguard Worker                           torch.tensor([1., 0.3, 0.5], device=device))
1269*da0073e9SAndroid Build Coastguard Worker        # 1-d input with no elements and default minlength
1270*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)),
1271*da0073e9SAndroid Build Coastguard Worker                         torch.zeros(0, dtype=torch.long, device=device))
1272*da0073e9SAndroid Build Coastguard Worker        # 1-d input with no elements and specified minlength
1273*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10),
1274*da0073e9SAndroid Build Coastguard Worker                         torch.zeros(10, dtype=torch.long, device=device))
1275*da0073e9SAndroid Build Coastguard Worker
1276*da0073e9SAndroid Build Coastguard Worker        # test tensor method without weights
1277*da0073e9SAndroid Build Coastguard Worker        long_counts = torch.tensor(
1278*da0073e9SAndroid Build Coastguard Worker            [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount()
1279*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1280*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device),
1281*da0073e9SAndroid Build Coastguard Worker            long_counts)
1282*da0073e9SAndroid Build Coastguard Worker        # test avoiding overflow for uint8 (#76979)
1283*da0073e9SAndroid Build Coastguard Worker        count_uint8 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.uint8, device=device).bincount()
1284*da0073e9SAndroid Build Coastguard Worker        count_int16 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.int16, device=device).bincount()
1285*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count_uint8, count_int16)
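        # Hedged note on the regression check above: #76979 appears to have been
        # a uint8 wraparound when sizing the output (255 + 1 overflows in uint8);
        # matching the int16 result guards that bin-count computation.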
1286*da0073e9SAndroid Build Coastguard Worker        # test minlength functionality
1287*da0073e9SAndroid Build Coastguard Worker        int_counts = torch.bincount(
1288*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 1, 1, 1], device=device), minlength=5)
1289*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1290*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device),
1291*da0073e9SAndroid Build Coastguard Worker            int_counts)
1292*da0073e9SAndroid Build Coastguard Worker        # test weights
1293*da0073e9SAndroid Build Coastguard Worker        byte_counts = torch.bincount(
1294*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, 1, 1, 1, 4], device=device),
1295*da0073e9SAndroid Build Coastguard Worker            torch.tensor([.1, .2, .3, .4, .5], device=device))
1296*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1297*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts)
1298*da0073e9SAndroid Build Coastguard Worker        byte_counts = torch.bincount(
1299*da0073e9SAndroid Build Coastguard Worker            torch.tensor([0, 1, 1, 1, 4], device=device),
1300*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
1301*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1302*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.float64), byte_counts)
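        # Weighted bincount sums the weight of each occurrence into its bin:
        # bin 1 collects .2 + .3 + .4 = 0.9 above. Note that integer weights
        # come back accumulated in float64, matching the expected dtype above.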
1303*da0073e9SAndroid Build Coastguard Worker        # test non-contiguous inputs and weights
1304*da0073e9SAndroid Build Coastguard Worker        inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device)
1305*da0073e9SAndroid Build Coastguard Worker        weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
1306*da0073e9SAndroid Build Coastguard Worker        for i in [0, 1]:
1307*da0073e9SAndroid Build Coastguard Worker            assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous"
1308*da0073e9SAndroid Build Coastguard Worker            assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous"
1309*da0073e9SAndroid Build Coastguard Worker        # inputs are non-contiguous but weights are contiguous
1310*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
1311*da0073e9SAndroid Build Coastguard Worker        # inputs and weights are non-contiguous
1312*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1313*da0073e9SAndroid Build Coastguard Worker            inputs[:, 1].bincount(weights[:, 1]),
1314*da0073e9SAndroid Build Coastguard Worker            torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
1315*da0073e9SAndroid Build Coastguard Worker        # weights are non-contiguous but inputs are contiguous
1316*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
1317*da0073e9SAndroid Build Coastguard Worker                         torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
1318*da0073e9SAndroid Build Coastguard Worker
1319*da0073e9SAndroid Build Coastguard Worker        # test bincount on non-contiguous slices
1320*da0073e9SAndroid Build Coastguard Worker        all0s = torch.zeros((32, 2), dtype=torch.int64, device=device)
1321*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32]))
1322*da0073e9SAndroid Build Coastguard Worker
1323*da0073e9SAndroid Build Coastguard Worker        all1s = torch.ones((32, 2), dtype=torch.int64, device=device)
1324*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32]))
1325*da0073e9SAndroid Build Coastguard Worker
1326*da0073e9SAndroid Build Coastguard Worker        # test large number of bins - global memory use
1327*da0073e9SAndroid Build Coastguard Worker        big_exp = torch.zeros(10000000, device=device)
1328*da0073e9SAndroid Build Coastguard Worker        big_exp[-1] = 50.0
1329*da0073e9SAndroid Build Coastguard Worker        big_w = torch.tensor([.5] * 100, device=device)
1330*da0073e9SAndroid Build Coastguard Worker        big_out = torch.tensor([9999999] * 100, device=device).bincount(big_w)
1331*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(big_exp, big_out)
1332*da0073e9SAndroid Build Coastguard Worker        # test large input size
1333*da0073e9SAndroid Build Coastguard Worker        big_exp = torch.zeros(2, device=device, dtype=torch.int64)
1334*da0073e9SAndroid Build Coastguard Worker        big_exp[1] = 1000000
1335*da0073e9SAndroid Build Coastguard Worker        big_out = torch.ones(1000000, dtype=torch.int8, device=device).bincount()
1336*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(big_exp, big_out)
1337*da0073e9SAndroid Build Coastguard Worker
1338*da0073e9SAndroid Build Coastguard Worker    # TODO: how many var stability tests are there?
1339*da0073e9SAndroid Build Coastguard Worker    def test_var_stability2(self, device):
1340*da0073e9SAndroid Build Coastguard Worker        tensor = torch.tensor([2281.5, 2281.25], dtype=torch.float32, device=device)
1341*da0073e9SAndroid Build Coastguard Worker
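        # Worked expectation: mean = 2281.375, deviations = ±0.125, so the
        # unbiased variance is (2 * 0.125**2) / (2 - 1) = 0.03125. A naive
        # single-pass sum-of-squares at this magnitude would drown in float32
        # rounding (2281.5**2 ≈ 5.2e6, where float32 spacing is ~0.5).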
1342*da0073e9SAndroid Build Coastguard Worker        # Stability for inner dim
1343*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.var(0), 0.03125)
1344*da0073e9SAndroid Build Coastguard Worker
1345*da0073e9SAndroid Build Coastguard Worker        # General stability
1346*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.var(), 0.03125)
1347*da0073e9SAndroid Build Coastguard Worker
1348*da0073e9SAndroid Build Coastguard Worker        # Stability for outer dimensions
1349*da0073e9SAndroid Build Coastguard Worker        tensor = tensor.unsqueeze(1)
1350*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.var(0), 0.03125)
1351*da0073e9SAndroid Build Coastguard Worker
1352*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1353*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bfloat16, torch.float16)
1354*da0073e9SAndroid Build Coastguard Worker    def test_sum_noncontig_lowp(self, device, dtype) -> None:
1355*da0073e9SAndroid Build Coastguard Worker        dim_sequences = {
1356*da0073e9SAndroid Build Coastguard Worker            2: [0, 1],
1357*da0073e9SAndroid Build Coastguard Worker            3: [0, 1, 2],
1358*da0073e9SAndroid Build Coastguard Worker            4: [0, 1, 2, 3],
1359*da0073e9SAndroid Build Coastguard Worker            5: [0, 1, 2, 3, 4],
1360*da0073e9SAndroid Build Coastguard Worker        }
1361*da0073e9SAndroid Build Coastguard Worker
1362*da0073e9SAndroid Build Coastguard Worker        def create_noncontig_inputs(x, ndim):
1363*da0073e9SAndroid Build Coastguard Worker            if ndim == 2:
1364*da0073e9SAndroid Build Coastguard Worker                return x[::2, ::2]
1365*da0073e9SAndroid Build Coastguard Worker            elif ndim == 3:
1366*da0073e9SAndroid Build Coastguard Worker                return x[::2, ::2, ::2]
1367*da0073e9SAndroid Build Coastguard Worker            elif ndim == 4:
1368*da0073e9SAndroid Build Coastguard Worker                return x[::2, ::2, ::2, ::2]
1369*da0073e9SAndroid Build Coastguard Worker            elif ndim == 5:
1370*da0073e9SAndroid Build Coastguard Worker                return x[::2, ::2, ::2, ::2, ::2]
1371*da0073e9SAndroid Build Coastguard Worker
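        # helper compares a low-precision sum against an fp32 reference cast
        # back to `dtype`; the kernels are expected to accumulate in fp32, so
        # the two results should match exactly.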
1372*da0073e9SAndroid Build Coastguard Worker        def helper(self, shape, reduce_dims, device, dtype):
1373*da0073e9SAndroid Build Coastguard Worker            for permute_list in list(permutations(dim_sequences[len(shape)], len(shape))):
1374*da0073e9SAndroid Build Coastguard Worker                x = torch.ones(shape, device=device, dtype=dtype)
1375*da0073e9SAndroid Build Coastguard Worker                x = create_noncontig_inputs(x, len(shape))
1376*da0073e9SAndroid Build Coastguard Worker                x_trans = x.permute(permute_list)
1377*da0073e9SAndroid Build Coastguard Worker                x_sum = torch.sum(x_trans, reduce_dims)
1378*da0073e9SAndroid Build Coastguard Worker                x_trans_ref = x_trans.float()
1379*da0073e9SAndroid Build Coastguard Worker                x_sum_ref = torch.sum(x_trans_ref, reduce_dims)
1380*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x_sum, x_sum_ref.to(dtype=dtype))
1381*da0073e9SAndroid Build Coastguard Worker
1382*da0073e9SAndroid Build Coastguard Worker        shapes = [
1383*da0073e9SAndroid Build Coastguard Worker            (50, 50),
1384*da0073e9SAndroid Build Coastguard Worker            (50, 50, 50),
1385*da0073e9SAndroid Build Coastguard Worker            (10, 50, 30, 30),
1386*da0073e9SAndroid Build Coastguard Worker            (10, 5, 10, 50, 7),
1387*da0073e9SAndroid Build Coastguard Worker        ]
1388*da0073e9SAndroid Build Coastguard Worker
1389*da0073e9SAndroid Build Coastguard Worker        for shape in shapes:
1390*da0073e9SAndroid Build Coastguard Worker            for i in range(1, len(shape) + 1):
1391*da0073e9SAndroid Build Coastguard Worker                reduce_dims = list(combinations(dim_sequences[len(shape)], i))
1392*da0073e9SAndroid Build Coastguard Worker                for reduce_dim in reduce_dims:
1393*da0073e9SAndroid Build Coastguard Worker                    helper(self, shape, reduce_dim, device, dtype)
1394*da0073e9SAndroid Build Coastguard Worker
1395*da0073e9SAndroid Build Coastguard Worker
1396*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1397*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.bool, torch.double)
1398*da0073e9SAndroid Build Coastguard Worker    def test_sum_all(self, device, dtype) -> None:
1399*da0073e9SAndroid Build Coastguard Worker        def check_sum_all(tensor: torch.Tensor) -> None:
1400*da0073e9SAndroid Build Coastguard Worker            pylist = tensor.reshape(-1).tolist()
1401*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tensor.sum(), sum(pylist))
1402*da0073e9SAndroid Build Coastguard Worker
1403*da0073e9SAndroid Build Coastguard Worker        if dtype != torch.bool:
1404*da0073e9SAndroid Build Coastguard Worker            check_sum_all(torch.tensor([1, 2, 3, 4, 5], dtype=dtype, device=device))
1405*da0073e9SAndroid Build Coastguard Worker            check_sum_all(torch.randn(200000, dtype=dtype, device=device))
1406*da0073e9SAndroid Build Coastguard Worker            check_sum_all(torch.randn(2000, 2, dtype=dtype, device=device)[:, 0])
1407*da0073e9SAndroid Build Coastguard Worker        else:
1408*da0073e9SAndroid Build Coastguard Worker            check_sum_all(torch.tensor([True, False, True], dtype=torch.bool, device=device))
1409*da0073e9SAndroid Build Coastguard Worker
1410*da0073e9SAndroid Build Coastguard Worker    def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
1411*da0073e9SAndroid Build Coastguard Worker                                            memory_format, compare_data=True, default_is_preserve=False):
1412*da0073e9SAndroid Build Coastguard Worker
1413*da0073e9SAndroid Build Coastguard Worker        assert memory_format in (torch.channels_last, torch.channels_last_3d)
1414*da0073e9SAndroid Build Coastguard Worker
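        # The checks below exercise three cases: preserve_format must keep the
        # channels-last layout, contiguous_format must densify, and the bare
        # default follows default_is_preserve. A final loop checks that
        # preserve_format also keeps arbitrary permuted strides.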
1415*da0073e9SAndroid Build Coastguard Worker        # xc is a channels last tensor
1416*da0073e9SAndroid Build Coastguard Worker        xc = input_generator_fn(device)
1417*da0073e9SAndroid Build Coastguard Worker        # xc is not memory dense, but looks like channels last
1418*da0073e9SAndroid Build Coastguard Worker        if memory_format == torch.channels_last:
1419*da0073e9SAndroid Build Coastguard Worker            xc = xc[..., ::2, ::2]
1420*da0073e9SAndroid Build Coastguard Worker        else:
1421*da0073e9SAndroid Build Coastguard Worker            xc = xc[..., ::2, ::2, ::2]
1422*da0073e9SAndroid Build Coastguard Worker
1423*da0073e9SAndroid Build Coastguard Worker        clone = transformation_fn(xc, memory_format=torch.preserve_format)
1424*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(clone.is_contiguous())
1425*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(clone.is_contiguous(memory_format=memory_format))
1426*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(xc.is_contiguous())
1427*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(xc.is_contiguous(memory_format=memory_format))
1428*da0073e9SAndroid Build Coastguard Worker        if compare_data:
1429*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(xc, clone.to(xc))
1430*da0073e9SAndroid Build Coastguard Worker
1431*da0073e9SAndroid Build Coastguard Worker        xc = input_generator_fn(device)
1432*da0073e9SAndroid Build Coastguard Worker        clone = transformation_fn(xc, memory_format=torch.contiguous_format)
1433*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(clone.is_contiguous())
1434*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(clone.is_contiguous(memory_format=memory_format))
1435*da0073e9SAndroid Build Coastguard Worker        if compare_data:
1436*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(xc, clone.to(xc))
1437*da0073e9SAndroid Build Coastguard Worker
1438*da0073e9SAndroid Build Coastguard Worker        xc = input_generator_fn(device)
1439*da0073e9SAndroid Build Coastguard Worker        clone = transformation_fn(xc)
1440*da0073e9SAndroid Build Coastguard Worker
1441*da0073e9SAndroid Build Coastguard Worker        if default_is_preserve:
1442*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(clone.is_contiguous())
1443*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(clone.is_contiguous(memory_format=memory_format))
1444*da0073e9SAndroid Build Coastguard Worker        else:
1445*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(clone.is_contiguous())
1446*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(clone.is_contiguous(memory_format=memory_format))
1447*da0073e9SAndroid Build Coastguard Worker        if compare_data:
1448*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(xc, clone.to(xc))
1449*da0073e9SAndroid Build Coastguard Worker
1450*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device)
1451*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
1452*da0073e9SAndroid Build Coastguard Worker            permutation = list(range(len(x.shape)))
1453*da0073e9SAndroid Build Coastguard Worker            random.shuffle(permutation)
1454*da0073e9SAndroid Build Coastguard Worker            x = x.permute(permutation)
1455*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.stride(), transformation_fn(x, memory_format=torch.preserve_format).stride())
1456*da0073e9SAndroid Build Coastguard Worker
1457*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1458*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double)
1459*da0073e9SAndroid Build Coastguard Worker    def test_sum_out(self, device, dtype: torch.dtype) -> None:
1460*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(100, 100, dtype=dtype, device=device)
1461*da0073e9SAndroid Build Coastguard Worker        res1 = torch.sum(x, 1)
1462*da0073e9SAndroid Build Coastguard Worker        res2 = torch.tensor((), dtype=dtype, device=device)
1463*da0073e9SAndroid Build Coastguard Worker        torch.sum(x, 1, out=res2)
1464*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
1465*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(100, 100, 100, dtype=dtype, device=device)
1466*da0073e9SAndroid Build Coastguard Worker        res1 = x.sum(2).sum(1)
1467*da0073e9SAndroid Build Coastguard Worker        res2 = torch.tensor((), dtype=dtype, device=device)
1468*da0073e9SAndroid Build Coastguard Worker        torch.sum(x, (2, 1), out=res2)
1469*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
1470*da0073e9SAndroid Build Coastguard Worker
1471*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
1472*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float16, torch.float32)
1473*da0073e9SAndroid Build Coastguard Worker    def test_prod_gpu(self, device, dtype):
1474*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([2, 3, 6, 9, 8], dtype=dtype, device=device)
1475*da0073e9SAndroid Build Coastguard Worker
1476*da0073e9SAndroid Build Coastguard Worker        # Check all combinations: fp16 input - fp16 output, fp16 input - fp32
1477*da0073e9SAndroid Build Coastguard Worker        # output, fp32 input - fp16 output, fp32 input - fp32 output
1478*da0073e9SAndroid Build Coastguard Worker        for dtype_output in [torch.float16, torch.float32]:
1479*da0073e9SAndroid Build Coastguard Worker            result_expected = torch.tensor(2592, dtype=dtype_output, device=device)
1480*da0073e9SAndroid Build Coastguard Worker            output = torch.prod(x, dtype=dtype_output)
1481*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output, result_expected)
1482*da0073e9SAndroid Build Coastguard Worker
1483*da0073e9SAndroid Build Coastguard Worker            output = x.prod(dtype=dtype_output)
1484*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(output, result_expected)
1485*da0073e9SAndroid Build Coastguard Worker
1486*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1487*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float)
1488*da0073e9SAndroid Build Coastguard Worker    def test_prod(self, device, dtype):
1489*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(100, 100, dtype=dtype, device=device)
1490*da0073e9SAndroid Build Coastguard Worker        res1 = torch.prod(x, 1)
1491*da0073e9SAndroid Build Coastguard Worker        res2 = torch.tensor((), dtype=dtype, device=device)
1492*da0073e9SAndroid Build Coastguard Worker        torch.prod(x, 1, out=res2)
1493*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2)
1494*da0073e9SAndroid Build Coastguard Worker
1495*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1496*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float16, torch.bfloat16)
1497*da0073e9SAndroid Build Coastguard Worker    def test_prod_lowp(self, device, dtype):
1498*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(100, 100, dtype=dtype, device=device)
1499*da0073e9SAndroid Build Coastguard Worker        x_ref = x.float()
1500*da0073e9SAndroid Build Coastguard Worker        res1 = torch.prod(x, 1)
1501*da0073e9SAndroid Build Coastguard Worker        res2 = torch.prod(x_ref, 1)
1502*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2.to(dtype=dtype))
1503*da0073e9SAndroid Build Coastguard Worker        res1 = torch.prod(x, 0)
1504*da0073e9SAndroid Build Coastguard Worker        res2 = torch.prod(x_ref, 0)
1505*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res1, res2.to(dtype=dtype))
1506*da0073e9SAndroid Build Coastguard Worker
1507*da0073e9SAndroid Build Coastguard Worker    def test_prod_bool(self, device):
1508*da0073e9SAndroid Build Coastguard Worker        vals = [
1509*da0073e9SAndroid Build Coastguard Worker            [True, True],
1510*da0073e9SAndroid Build Coastguard Worker            [True, False],
1511*da0073e9SAndroid Build Coastguard Worker            [False, False],
1512*da0073e9SAndroid Build Coastguard Worker            [],
1513*da0073e9SAndroid Build Coastguard Worker            [False] * 256,  # https://github.com/pytorch/pytorch/issues/127866
1514*da0073e9SAndroid Build Coastguard Worker        ]
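        # The empty case relies on the empty product being the multiplicative
        # identity: both torch and numpy reduce [] to True under bool.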
1515*da0073e9SAndroid Build Coastguard Worker        for val in vals:
1516*da0073e9SAndroid Build Coastguard Worker            result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item()
1517*da0073e9SAndroid Build Coastguard Worker            expect = np.prod(np.array(val), dtype=bool)
1518*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, expect)
1519*da0073e9SAndroid Build Coastguard Worker
1520*da0073e9SAndroid Build Coastguard Worker            result = torch.prod(torch.tensor(val, device=device)).item()
1521*da0073e9SAndroid Build Coastguard Worker            expect = np.prod(np.array(val))
1522*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, expect)
1523*da0073e9SAndroid Build Coastguard Worker
1524*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1525*da0073e9SAndroid Build Coastguard Worker    def test_max_mixed_devices(self, device):
1526*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(10, device=device)
1527*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
1528*da0073e9SAndroid Build Coastguard Worker            values = torch.randn(10, device='cuda')
1529*da0073e9SAndroid Build Coastguard Worker            indices = torch.empty(0, dtype=torch.long, device='cuda')
1530*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError,
1531*da0073e9SAndroid Build Coastguard Worker                              lambda: torch.max(a, 0, out=(values, indices)))
1532*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError,
1533*da0073e9SAndroid Build Coastguard Worker                              lambda: torch.amax(a, 0, out=values))
1534*da0073e9SAndroid Build Coastguard Worker
1535*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1536*da0073e9SAndroid Build Coastguard Worker    def test_min_mixed_devices(self, device):
1537*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(10, device=device)
1538*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
1539*da0073e9SAndroid Build Coastguard Worker            values = torch.randn(10, device='cuda')
1540*da0073e9SAndroid Build Coastguard Worker            indices = torch.empty(0, dtype=torch.long, device='cuda')
1541*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError,
1542*da0073e9SAndroid Build Coastguard Worker                              lambda: torch.min(a, 0, out=(values, indices)))
1543*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError,
1544*da0073e9SAndroid Build Coastguard Worker                              lambda: torch.amin(a, 0, out=values))
1545*da0073e9SAndroid Build Coastguard Worker
1546*da0073e9SAndroid Build Coastguard Worker    # TODO: consider refactoring with bincount test
1547*da0073e9SAndroid Build Coastguard Worker    def test_bucketization(self, device):
1548*da0073e9SAndroid Build Coastguard Worker        values_1d = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9], device=device)
1549*da0073e9SAndroid Build Coastguard Worker        values_3d = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device)
1550*da0073e9SAndroid Build Coastguard Worker
1551*da0073e9SAndroid Build Coastguard Worker        # simple 1d boundary and 3d input value
1552*da0073e9SAndroid Build Coastguard Worker        boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device)
1553*da0073e9SAndroid Build Coastguard Worker        expected_result = torch.tensor([[[0, 2, 4], [1, 3, 5]], [[0, 1, 2], [3, 4, 5]]], device=device)
1554*da0073e9SAndroid Build Coastguard Worker        output = torch.empty(2, 2, 3, device=device, dtype=torch.int64)
1555*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bucketize(values_3d, boundaries), expected_result)
1556*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bucketize(values_3d, boundaries, out=output), expected_result)
1557*da0073e9SAndroid Build Coastguard Worker        expected_result = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device)
1558*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bucketize(values_3d, boundaries, right=True), expected_result)
1559*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bucketize(values_3d, boundaries, out=output, right=True), expected_result)
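        # With right=False, a value equal to a boundary maps to that boundary's
        # index (first i with boundaries[i] >= v); with right=True, exact
        # matches shift one bucket right (first i with boundaries[i] > v).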
1560*da0073e9SAndroid Build Coastguard Worker
1561*da0073e9SAndroid Build Coastguard Worker        # simple float 1d boundary and 1d input with output int32 type
1562*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float32, torch.float16]:
1563*da0073e9SAndroid Build Coastguard Worker            values_1d_float = values_1d.to(dtype)
1564*da0073e9SAndroid Build Coastguard Worker            boundaries = torch.tensor([0.9, 1, 2, 2, 3, 3, 4, 4.1, 9, 9], device=device, dtype=dtype)
1565*da0073e9SAndroid Build Coastguard Worker            expected_result = torch.tensor([1, 2, 4, 6, 8, 8, 8, 8, 8], device=device, dtype=torch.int32)
1566*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.searchsorted(boundaries, values_1d_float, out_int32=True), expected_result)
1567*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.bucketize(values_1d_float, boundaries, out_int32=True), expected_result)
1568*da0073e9SAndroid Build Coastguard Worker
1569*da0073e9SAndroid Build Coastguard Worker        # multiple dimension input with 0 elements
1570*da0073e9SAndroid Build Coastguard Worker        boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device, dtype=torch.int64)
1571*da0073e9SAndroid Build Coastguard Worker        values_0_el = torch.tensor([[[]]], device=device, dtype=torch.int64)
1572*da0073e9SAndroid Build Coastguard Worker        expected_result = values_0_el.to(torch.int64)
1573*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.searchsorted(boundaries, values_0_el), expected_result)
1574*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bucketize(values_0_el, boundaries), expected_result)
1575*da0073e9SAndroid Build Coastguard Worker
1576*da0073e9SAndroid Build Coastguard Worker        # nan input
1577*da0073e9SAndroid Build Coastguard Worker        values_nan = torch.tensor([1.0, float('nan'), 2.0, float('nan')], device=device, dtype=torch.float64)
1578*da0073e9SAndroid Build Coastguard Worker        boundaries = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device, dtype=torch.float64)
1579*da0073e9SAndroid Build Coastguard Worker        expected_result = torch.tensor([1, 4, 2, 4], device=device)
1580*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.searchsorted(boundaries, values_nan), expected_result)
1581*da0073e9SAndroid Build Coastguard Worker        expected_result = torch.tensor([2, 4, 3, 4], device=device)
1582*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.searchsorted(boundaries, values_nan, right=True), expected_result)
1583*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.searchsorted(boundaries, values_nan, side='right'), expected_result)
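        # NaN sorts after every boundary here, so it lands past the end
        # (index len(boundaries) == 4) for both the left and right variants.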
1584*da0073e9SAndroid Build Coastguard Worker
1585*da0073e9SAndroid Build Coastguard Worker        # type promotion and non-contiguous tensors
1586*da0073e9SAndroid Build Coastguard Worker        values_3d_permute = values_3d.permute(2, 1, 0).to(torch.int32)
1587*da0073e9SAndroid Build Coastguard Worker        boundaries_permute = values_3d.permute(2, 1, 0).to(torch.float64)
1588*da0073e9SAndroid Build Coastguard Worker        expected_result = torch.tensor([[[0, 0], [0, 1]], [[2, 0], [0, 1]], [[2, 0], [0, 0]]], device=device)
1589*da0073e9SAndroid Build Coastguard Worker        if self.device_type != 'xla':
1590*da0073e9SAndroid Build Coastguard Worker            self.assertWarnsRegex(
1591*da0073e9SAndroid Build Coastguard Worker                UserWarning, "tensor is non-contiguous",
1592*da0073e9SAndroid Build Coastguard Worker                lambda: self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result))
1593*da0073e9SAndroid Build Coastguard Worker        else:
1594*da0073e9SAndroid Build Coastguard Worker            # All tensors in XLA are contiguous even after a permute, so no warning is generated on XLA
1595*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result)
1596*da0073e9SAndroid Build Coastguard Worker
1597*da0073e9SAndroid Build Coastguard Worker        # scalar type
1598*da0073e9SAndroid Build Coastguard Worker        boundaries = torch.tensor([1.5, 2.5, 3.5], device=device)
1599*da0073e9SAndroid Build Coastguard Worker        expected_result = torch.tensor(1, device=device)
1600*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.searchsorted(boundaries, 2), expected_result)
1601*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bucketize(torch.tensor(2, device=device), boundaries), expected_result)
1602*da0073e9SAndroid Build Coastguard Worker        expected_result = torch.tensor(3, device=device)
1603*da0073e9SAndroid Build Coastguard Worker        scalar_tensor_nan = torch.tensor(float('nan'), device=device)
1604*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.searchsorted(boundaries, scalar_tensor_nan), expected_result)
1605*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.bucketize(float('nan'), boundaries, right=True), expected_result)
1606*da0073e9SAndroid Build Coastguard Worker
1607*da0073e9SAndroid Build Coastguard Worker        # invalid input dimensions
1608*da0073e9SAndroid Build Coastguard Worker        boundaries = torch.tensor([[1, 2, 3], [4, 5, 6]], device=device)
1609*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1610*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "first N-1 dimensions of boundaries tensor and input value tensor must match"):
1611*da0073e9SAndroid Build Coastguard Worker            torch.searchsorted(boundaries, values_3d)
1612*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1613*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "boundaries tensor must be 1 dimension"):
1614*da0073e9SAndroid Build Coastguard Worker            torch.bucketize(values_3d, boundaries)
1615*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1616*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "only when boundaries tensor dimension is 1"):
1617*da0073e9SAndroid Build Coastguard Worker            torch.searchsorted(boundaries, 1)
1618*da0073e9SAndroid Build Coastguard Worker
1619*da0073e9SAndroid Build Coastguard Worker        # incompatible output tensor dtype
1620*da0073e9SAndroid Build Coastguard Worker        def test_output_dtype(dtype, is_int32):
1621*da0073e9SAndroid Build Coastguard Worker            output = values_1d.to(dtype)
1622*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
1623*da0073e9SAndroid Build Coastguard Worker                    RuntimeError, "output tensor's dtype is wrong"):
1624*da0073e9SAndroid Build Coastguard Worker                torch.searchsorted(values_1d, values_1d, out=output, out_int32=is_int32)
1625*da0073e9SAndroid Build Coastguard Worker
1626*da0073e9SAndroid Build Coastguard Worker        test_output_dtype(torch.float32, False)
1627*da0073e9SAndroid Build Coastguard Worker        test_output_dtype(torch.int32, False)
1628*da0073e9SAndroid Build Coastguard Worker        test_output_dtype(torch.int64, True)
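        # Each call above pairs out_int32 with the wrong out dtype:
        # out_int32=False requires an int64 out tensor, out_int32=True an int32 one.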
1629*da0073e9SAndroid Build Coastguard Worker
1630*da0073e9SAndroid Build Coastguard Worker        # invalid side argument
1631*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "side can only be 'left' or 'right'"):
1632*da0073e9SAndroid Build Coastguard Worker            torch.searchsorted(values_1d, values_1d, side='bad')
1633*da0073e9SAndroid Build Coastguard Worker
1634*da0073e9SAndroid Build Coastguard Worker        # invalid sorter argument, wrong size
1635*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "boundary and sorter must have the same size"):
1636*da0073e9SAndroid Build Coastguard Worker            sequence = torch.rand_like(values_1d, dtype=torch.float)
1637*da0073e9SAndroid Build Coastguard Worker            _, sorted_idx = torch.sort(sequence)
1638*da0073e9SAndroid Build Coastguard Worker            torch.searchsorted(sequence, values_1d, sorter=sorted_idx[:-1])
1639*da0073e9SAndroid Build Coastguard Worker
1640*da0073e9SAndroid Build Coastguard Worker        # invalid sorter argument, not of dtype long
1641*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "sorter must be a tensor of long dtype"):
1642*da0073e9SAndroid Build Coastguard Worker            sequence = torch.rand_like(values_1d, dtype=torch.float)
1643*da0073e9SAndroid Build Coastguard Worker            _, sorted_idx = torch.sort(sequence)
1644*da0073e9SAndroid Build Coastguard Worker            torch.searchsorted(sequence, values_1d, sorter=sorted_idx.to(torch.float32))
1645*da0073e9SAndroid Build Coastguard Worker
1646*da0073e9SAndroid Build Coastguard Worker        # invalid sorter value, out of bounds (>= innermost size)
1647*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "sorter index out of range"):
1648*da0073e9SAndroid Build Coastguard Worker            torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([0, 1, 3]))
1649*da0073e9SAndroid Build Coastguard Worker
1650*da0073e9SAndroid Build Coastguard Worker        # invalid sorter value, out of bounds (< 0)
1651*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "sorter index out of range"):
1652*da0073e9SAndroid Build Coastguard Worker            torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([-1, 1, 2]))
1653*da0073e9SAndroid Build Coastguard Worker
1654*da0073e9SAndroid Build Coastguard Worker        # scalar type bfloat16
1655*da0073e9SAndroid Build Coastguard Worker        if self.device_type == 'cpu':
1656*da0073e9SAndroid Build Coastguard Worker            def test_dtype_bfloat16(values_bf16=False, boundaries_bf16=False):
1657*da0073e9SAndroid Build Coastguard Worker                values_1d_float = values_1d.to(torch.float32)
1658*da0073e9SAndroid Build Coastguard Worker                boundaries = torch.tensor([0.9, 1, 2, 2, 3, 3, 4, 4.1, 9, 9], device=device, dtype=torch.float32)
1659*da0073e9SAndroid Build Coastguard Worker                if values_bf16:
1660*da0073e9SAndroid Build Coastguard Worker                    values_1d_float = values_1d_float.to(torch.bfloat16)
1661*da0073e9SAndroid Build Coastguard Worker                if boundaries_bf16:
1662*da0073e9SAndroid Build Coastguard Worker                    boundaries = boundaries.to(torch.bfloat16)
1663*da0073e9SAndroid Build Coastguard Worker                expected_result = torch.tensor([1, 2, 4, 6, 8, 8, 8, 8, 8], device=device, dtype=torch.int32)
1664*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.bucketize(values_1d_float, boundaries, out_int32=True), expected_result)
1665*da0073e9SAndroid Build Coastguard Worker
1666*da0073e9SAndroid Build Coastguard Worker            test_dtype_bfloat16(True, False)
1667*da0073e9SAndroid Build Coastguard Worker            test_dtype_bfloat16(False, True)
1668*da0073e9SAndroid Build Coastguard Worker            test_dtype_bfloat16(True, True)
1669*da0073e9SAndroid Build Coastguard Worker
1670*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and(torch.half, torch.bfloat16))
1671*da0073e9SAndroid Build Coastguard Worker    def test_nansum(self, device, dtype):
1672*da0073e9SAndroid Build Coastguard Worker        args = product(
1673*da0073e9SAndroid Build Coastguard Worker            (True, False),  # noncontiguous
1674*da0073e9SAndroid Build Coastguard Worker            (0, 1, None),   # dim
1675*da0073e9SAndroid Build Coastguard Worker        )
1676*da0073e9SAndroid Build Coastguard Worker        zero = torch.zeros((), device=device, dtype=dtype)
1677*da0073e9SAndroid Build Coastguard Worker
1678*da0073e9SAndroid Build Coastguard Worker        for noncontiguous, dim in args:
1679*da0073e9SAndroid Build Coastguard Worker            # Randomly scale the values
1680*da0073e9SAndroid Build Coastguard Worker            scale = random.randint(10, 100)
1681*da0073e9SAndroid Build Coastguard Worker            x = make_tensor((17, 17), device=device, dtype=dtype,
1682*da0073e9SAndroid Build Coastguard Worker                            low=-scale, high=scale, noncontiguous=noncontiguous)
1683*da0073e9SAndroid Build Coastguard Worker
1684*da0073e9SAndroid Build Coastguard Worker            if dtype.is_floating_point:
1685*da0073e9SAndroid Build Coastguard Worker                nan_mask = x < 0.2 * scale
1686*da0073e9SAndroid Build Coastguard Worker                x_nonan = torch.where(nan_mask, zero, x)
1687*da0073e9SAndroid Build Coastguard Worker                x[nan_mask] = np.nan
1688*da0073e9SAndroid Build Coastguard Worker            else:
1689*da0073e9SAndroid Build Coastguard Worker                x_nonan = x
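            # Entries replaced by NaN in x are zeroed in x_nonan, so a plain
            # sum over x_nonan is the expected value of nansum over x.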
1690*da0073e9SAndroid Build Coastguard Worker
1691*da0073e9SAndroid Build Coastguard Worker            dim_kwargs = {} if dim is None else {"dim": dim}
1692*da0073e9SAndroid Build Coastguard Worker            expect = torch.sum(x_nonan, **dim_kwargs)
1693*da0073e9SAndroid Build Coastguard Worker            actual = torch.nansum(x, **dim_kwargs)
1694*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(expect, actual)
1695*da0073e9SAndroid Build Coastguard Worker
1696*da0073e9SAndroid Build Coastguard Worker    def _test_reduction_function_with_numpy(self, torch_func, np_func, device, dtype,
1697*da0073e9SAndroid Build Coastguard Worker                                            with_extremal=False, atol=None, rtol=None,
1698*da0073e9SAndroid Build Coastguard Worker                                            exact_dtype=True, with_keepdim=False):
1699*da0073e9SAndroid Build Coastguard Worker        # Test 0-d to 3-d tensors.
1700*da0073e9SAndroid Build Coastguard Worker        for ndims in range(0, 4):
1701*da0073e9SAndroid Build Coastguard Worker            shape = _rand_shape(ndims, min_size=5, max_size=10)
1702*da0073e9SAndroid Build Coastguard Worker            for n in range(ndims + 1):
1703*da0073e9SAndroid Build Coastguard Worker                for c in combinations(list(range(ndims)), n):
1704*da0073e9SAndroid Build Coastguard Worker                    for count_dim in permutations(c):
1705*da0073e9SAndroid Build Coastguard Worker                        # Generate Input.
1706*da0073e9SAndroid Build Coastguard Worker                        x = _generate_input(shape, dtype, device, with_extremal)
1707*da0073e9SAndroid Build Coastguard Worker
1708*da0073e9SAndroid Build Coastguard Worker                        if count_dim == ():
1709*da0073e9SAndroid Build Coastguard Worker                            # Default `dims=None` case
1710*da0073e9SAndroid Build Coastguard Worker                            self.compare_with_numpy(torch_func, np_func, x, device=None, dtype=None,
1711*da0073e9SAndroid Build Coastguard Worker                                                    atol=atol, rtol=rtol, exact_dtype=exact_dtype)
1712*da0073e9SAndroid Build Coastguard Worker                        else:
1713*da0073e9SAndroid Build Coastguard Worker                            # With `dims: tuple of ints` case
1714*da0073e9SAndroid Build Coastguard Worker                            if with_keepdim:
1715*da0073e9SAndroid Build Coastguard Worker                                torch_func_partial = partial(torch_func, keepdim=True, dim=count_dim)
1716*da0073e9SAndroid Build Coastguard Worker                                np_func_partial = partial(np_func, keepdims=True, axis=count_dim)
1717*da0073e9SAndroid Build Coastguard Worker                            else:
1718*da0073e9SAndroid Build Coastguard Worker                                torch_func_partial = partial(torch_func, dim=count_dim)
1719*da0073e9SAndroid Build Coastguard Worker                                np_func_partial = partial(np_func, axis=count_dim)
1720*da0073e9SAndroid Build Coastguard Worker                            self.compare_with_numpy(torch_func_partial, np_func_partial, x, device=None, dtype=None,
1721*da0073e9SAndroid Build Coastguard Worker                                                    atol=atol, rtol=rtol, exact_dtype=exact_dtype)
1722*da0073e9SAndroid Build Coastguard Worker
1723*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and_complex_and(torch.half))
1724*da0073e9SAndroid Build Coastguard Worker    def test_count_nonzero(self, device, dtype):
1725*da0073e9SAndroid Build Coastguard Worker        self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype)
1726*da0073e9SAndroid Build Coastguard Worker        self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype, True)
1727*da0073e9SAndroid Build Coastguard Worker
1728*da0073e9SAndroid Build Coastguard Worker    # TODO: Investigate why the output is not close to numpy.
1729*da0073e9SAndroid Build Coastguard Worker    def _get_relaxed_tolerances_for(self, dtype):
1730*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.float16:
1731*da0073e9SAndroid Build Coastguard Worker            atol = 0.4
1732*da0073e9SAndroid Build Coastguard Worker            rtol = 1e-2
1733*da0073e9SAndroid Build Coastguard Worker        elif dtype == torch.float32:
1734*da0073e9SAndroid Build Coastguard Worker            atol = 7e-05
1735*da0073e9SAndroid Build Coastguard Worker            rtol = 3e-06
1736*da0073e9SAndroid Build Coastguard Worker        else:
1737*da0073e9SAndroid Build Coastguard Worker            # Default values
1738*da0073e9SAndroid Build Coastguard Worker            atol = None
1739*da0073e9SAndroid Build Coastguard Worker            rtol = None
1740*da0073e9SAndroid Build Coastguard Worker        return atol, rtol
1741*da0073e9SAndroid Build Coastguard Worker
1742*da0073e9SAndroid Build Coastguard Worker    def _test_sum_reduction_vs_numpy(self, torch_fn, np_fn, device, dtype, with_keepdim=False, with_extremal=False):
1743*da0073e9SAndroid Build Coastguard Worker        def is_integral(dtype):
1744*da0073e9SAndroid Build Coastguard Worker            return dtype in integral_types()
1745*da0073e9SAndroid Build Coastguard Worker
1746*da0073e9SAndroid Build Coastguard Worker        exact_dtype = True
1747*da0073e9SAndroid Build Coastguard Worker        # On Windows CI, the current version of `numpy` promotes all lower integer
1748*da0073e9SAndroid Build Coastguard Worker        # dtypes to int32 while `torch` promotes them to int64. Hence we skip
1749*da0073e9SAndroid Build Coastguard Worker        # checking the exact dtype.
1750*da0073e9SAndroid Build Coastguard Worker        # Reference: https://dr.pytorch.org/api/view-log-full?build_id=122051580
1751*da0073e9SAndroid Build Coastguard Worker        # PR : https://github.com/pytorch/pytorch/pull/38628#issuecomment-655905370
1752*da0073e9SAndroid Build Coastguard Worker        if IS_WINDOWS and is_integral(dtype):
1753*da0073e9SAndroid Build Coastguard Worker            exact_dtype = False
1754*da0073e9SAndroid Build Coastguard Worker        # For uint8, numpy promotes to uint64 while torch promotes to int64.
1755*da0073e9SAndroid Build Coastguard Worker        # So we must skip this as well.
1756*da0073e9SAndroid Build Coastguard Worker        if dtype == torch.uint8:
1757*da0073e9SAndroid Build Coastguard Worker            exact_dtype = False
1758*da0073e9SAndroid Build Coastguard Worker
1759*da0073e9SAndroid Build Coastguard Worker        # TODO: Investigate why the output is not close to numpy.
1760*da0073e9SAndroid Build Coastguard Worker        atol, rtol = self._get_relaxed_tolerances_for(dtype)
1761*da0073e9SAndroid Build Coastguard Worker        self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype,
1762*da0073e9SAndroid Build Coastguard Worker                                                 atol=atol, rtol=rtol, exact_dtype=exact_dtype,
1763*da0073e9SAndroid Build Coastguard Worker                                                 with_keepdim=with_keepdim, with_extremal=with_extremal)
1764*da0073e9SAndroid Build Coastguard Worker
1765*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1766*da0073e9SAndroid Build Coastguard Worker    @dtypes(*set(all_types_and(torch.half)) - {torch.uint8})
1767*da0073e9SAndroid Build Coastguard Worker    def test_sum_vs_numpy(self, device, dtype):
1768*da0073e9SAndroid Build Coastguard Worker        self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype)
1769*da0073e9SAndroid Build Coastguard Worker        self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_extremal=True)
1770*da0073e9SAndroid Build Coastguard Worker        self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_keepdim=True)
1771*da0073e9SAndroid Build Coastguard Worker
1772*da0073e9SAndroid Build Coastguard Worker    @onlyNativeDeviceTypes
1773*da0073e9SAndroid Build Coastguard Worker    @dtypes(*set(all_types_and(torch.half)) - {torch.uint8})
1774*da0073e9SAndroid Build Coastguard Worker    def test_nansum_vs_numpy(self, device, dtype):
1775*da0073e9SAndroid Build Coastguard Worker        self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype)
1776*da0073e9SAndroid Build Coastguard Worker        self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_extremal=True)
1777*da0073e9SAndroid Build Coastguard Worker        self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_keepdim=True)
1778*da0073e9SAndroid Build Coastguard Worker
1779*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
1780*da0073e9SAndroid Build Coastguard Worker    @dtypes(*complex_types())
1781*da0073e9SAndroid Build Coastguard Worker    def test_nansum_complex(self, device, dtype):
1782*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((3, 3, 3), device=device, dtype=dtype)
1783*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "nansum does not support complex inputs"):
1784*da0073e9SAndroid Build Coastguard Worker            torch.nansum(x)
1785*da0073e9SAndroid Build Coastguard Worker
1786*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and(torch.half))
1787*da0073e9SAndroid Build Coastguard Worker    def test_nansum_out_dtype(self, device, dtype):
1788*da0073e9SAndroid Build Coastguard Worker        out_dtype = dtype
1789*da0073e9SAndroid Build Coastguard Worker        inp_dtypes = all_types_and(torch.half) if out_dtype.is_floating_point else integral_types()
1790*da0073e9SAndroid Build Coastguard Worker        for inp_dtype in inp_dtypes:
1791*da0073e9SAndroid Build Coastguard Worker            # TODO: Investigate why the output is not close to numpy.
1792*da0073e9SAndroid Build Coastguard Worker            atol, rtol = self._get_relaxed_tolerances_for(dtype)
1793*da0073e9SAndroid Build Coastguard Worker            shape = _rand_shape(random.randint(2, 5), min_size=5, max_size=10)
1794*da0073e9SAndroid Build Coastguard Worker            x = _generate_input(shape, inp_dtype, device, with_extremal=False)
1795*da0073e9SAndroid Build Coastguard Worker            torch_fn = partial(torch.nansum, dtype=out_dtype)
1796*da0073e9SAndroid Build Coastguard Worker            np_out_dtype = torch_to_numpy_dtype_dict[out_dtype]
1797*da0073e9SAndroid Build Coastguard Worker            np_fn = partial(np.nansum, dtype=np_out_dtype)
1798*da0073e9SAndroid Build Coastguard Worker            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None, atol=atol, rtol=rtol)
1799*da0073e9SAndroid Build Coastguard Worker
1800*da0073e9SAndroid Build Coastguard Worker    @dtypes(*all_types_and(torch.half))
1801*da0073e9SAndroid Build Coastguard Worker    def test_argminmax_multiple(self, device, dtype):
1802*da0073e9SAndroid Build Coastguard Worker        # Case: All Ones
1803*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(3, 3, device=device, dtype=dtype)
1804*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(torch.argmax, np.argmax, t)
1805*da0073e9SAndroid Build Coastguard Worker        self.compare_with_numpy(torch.argmin, np.argmin, t)
1806*da0073e9SAndroid Build Coastguard Worker
1807*da0073e9SAndroid Build Coastguard Worker        # Case: With single `nan` present.
1808*da0073e9SAndroid Build Coastguard Worker        if dtype in floating_types_and(torch.half, torch.bfloat16):
1809*da0073e9SAndroid Build Coastguard Worker            t[2, 2] = float('nan')
1810*da0073e9SAndroid Build Coastguard Worker            self.compare_with_numpy(torch.argmax, np.argmax, t)
1811*da0073e9SAndroid Build Coastguard Worker            self.compare_with_numpy(torch.argmin, np.argmin, t)
1812*da0073e9SAndroid Build Coastguard Worker
1813*da0073e9SAndroid Build Coastguard Worker        # Case: Randomly Generated Tensors
1814*da0073e9SAndroid Build Coastguard Worker        for ndims in range(1, 5):
1815*da0073e9SAndroid Build Coastguard Worker            shape = _rand_shape(ndims, min_size=5, max_size=10)
1816*da0073e9SAndroid Build Coastguard Worker            for with_extremal in [False, True]:
1817*da0073e9SAndroid Build Coastguard Worker                for contiguous in [False, True]:
1818*da0073e9SAndroid Build Coastguard Worker                    # Generate Input.
1819*da0073e9SAndroid Build Coastguard Worker                    x = _generate_input(shape, dtype, device, with_extremal)
1820*da0073e9SAndroid Build Coastguard Worker
1821*da0073e9SAndroid Build Coastguard Worker                    if dtype == torch.half:
1822*da0073e9SAndroid Build Coastguard Worker                        max_val = torch.max(x.to(torch.float))
1823*da0073e9SAndroid Build Coastguard Worker                        min_val = torch.min(x.to(torch.float))
1824*da0073e9SAndroid Build Coastguard Worker                    else:
1825*da0073e9SAndroid Build Coastguard Worker                        max_val = torch.max(x)
1826*da0073e9SAndroid Build Coastguard Worker                        min_val = torch.min(x)
1827*da0073e9SAndroid Build Coastguard Worker
1828*da0073e9SAndroid Build Coastguard Worker                    mask = torch.randn(x.shape) > 0.5
1829*da0073e9SAndroid Build Coastguard Worker                    x[mask] = torch.tensor(max_val + 1, dtype=dtype)
1830*da0073e9SAndroid Build Coastguard Worker
1831*da0073e9SAndroid Build Coastguard Worker                    mask = torch.randn(x.shape) > 0.5
1832*da0073e9SAndroid Build Coastguard Worker                    x[mask] = torch.tensor(min_val - 1, dtype=dtype)
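                    # Planting repeated extremal values creates ties; both
                    # torch and numpy are expected to return the first index.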
1833*da0073e9SAndroid Build Coastguard Worker
1834*da0073e9SAndroid Build Coastguard Worker                    if not contiguous:
1835*da0073e9SAndroid Build Coastguard Worker                        x = x.T
1836*da0073e9SAndroid Build Coastguard Worker
1837*da0073e9SAndroid Build Coastguard Worker                    self.compare_with_numpy(torch.argmax, np.argmax, x, device=None, dtype=None)
1838*da0073e9SAndroid Build Coastguard Worker                    self.compare_with_numpy(torch.argmin, np.argmin, x, device=None, dtype=None)
1839*da0073e9SAndroid Build Coastguard Worker
1840*da0073e9SAndroid Build Coastguard Worker                    # Verify indices returned by max and min.
1841*da0073e9SAndroid Build Coastguard Worker                    if dtype != torch.half:
1842*da0073e9SAndroid Build Coastguard Worker                        rand_dim = random.randint(0, ndims - 1)
1843*da0073e9SAndroid Build Coastguard Worker                        self.compare_with_numpy(lambda x: torch.max(x, dim=rand_dim)[1],
1844*da0073e9SAndroid Build Coastguard Worker                                                lambda x: np.argmax(x, axis=rand_dim), x, device=None, dtype=None)
1845*da0073e9SAndroid Build Coastguard Worker                        self.compare_with_numpy(lambda x: torch.min(x, dim=rand_dim)[1],
1846*da0073e9SAndroid Build Coastguard Worker                                                lambda x: np.argmin(x, axis=rand_dim), x, device=None, dtype=None)
1847*da0073e9SAndroid Build Coastguard Worker
1848*da0073e9SAndroid Build Coastguard Worker        def verify_against_numpy(t):
1849*da0073e9SAndroid Build Coastguard Worker            # Argmax
1850*da0073e9SAndroid Build Coastguard Worker            torch_fn = partial(torch.argmax, dim=1)
1851*da0073e9SAndroid Build Coastguard Worker            np_fn = partial(np.argmax, axis=1)
1852*da0073e9SAndroid Build Coastguard Worker            self.compare_with_numpy(torch_fn, np_fn, t)
1853*da0073e9SAndroid Build Coastguard Worker            # Non-contiguous input
1854*da0073e9SAndroid Build Coastguard Worker            self.compare_with_numpy(torch_fn, np_fn, t.T)
1855*da0073e9SAndroid Build Coastguard Worker
1856*da0073e9SAndroid Build Coastguard Worker            # Verify indices returned by max.
1857*da0073e9SAndroid Build Coastguard Worker            if dtype != torch.half:
1858*da0073e9SAndroid Build Coastguard Worker                self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, t, device=None, dtype=None)
1859*da0073e9SAndroid Build Coastguard Worker                self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, t.T, device=None, dtype=None)
1860*da0073e9SAndroid Build Coastguard Worker
1861*da0073e9SAndroid Build Coastguard Worker            # Argmin
1862*da0073e9SAndroid Build Coastguard Worker            torch_fn = partial(torch.argmin, dim=1)
1863*da0073e9SAndroid Build Coastguard Worker            np_fn = partial(np.argmin, axis=1)
1864*da0073e9SAndroid Build Coastguard Worker            self.compare_with_numpy(torch_fn, np_fn, t)
1865*da0073e9SAndroid Build Coastguard Worker            # Non-contiguous input
1866*da0073e9SAndroid Build Coastguard Worker            self.compare_with_numpy(torch_fn, np_fn, t.T)
1867*da0073e9SAndroid Build Coastguard Worker
1868*da0073e9SAndroid Build Coastguard Worker            # Verify indices returned by min.
1869*da0073e9SAndroid Build Coastguard Worker            if dtype != torch.half:
1870*da0073e9SAndroid Build Coastguard Worker                self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x, device=None, dtype=None)
1871*da0073e9SAndroid Build Coastguard Worker                self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x.T, device=None, dtype=None)
1872*da0073e9SAndroid Build Coastguard Worker
        # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998
        t = torch.tensor([[1, 5],
                          [2, 10],
                          [3, 3]], device=device, dtype=dtype)
        verify_against_numpy(t)

        # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998
        t = torch.tensor([[1, 5],
                          [2, 10],
                          [0, 0]], device=device, dtype=dtype)
        verify_against_numpy(t)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
    def test_all_any_vs_numpy(self, device, dtype):
        # Note [all, any uint8 compatibility]: all and any normally return bool
        # tensors, but for compatibility reasons they return a tensor of the
        # same dtype for `uint8` input.
        # Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
        exact_dtype = dtype != torch.uint8

        def _test_all_any(x):
            self.compare_with_numpy(torch.all, np.all, x)
            self.compare_with_numpy(torch.any, np.any, x)

        def _test_all_any_with_dim(x, dim):
            torch_fn = partial(torch.all, dim=dim)
            np_fn = partial(np.all, axis=dim)
            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)

            torch_fn = partial(torch.any, dim=dim)
            np_fn = partial(np.any, axis=dim)
            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)

        def _test_out_variant(x, dim):
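            # `out` here has the same dtype as the input; all/any accept it
            # only for bool and uint8 inputs (the legacy uint8 behavior noted
            # above) and raise for every other dtype.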
            out = torch.empty_like(x)
            if dtype == torch.bool or dtype == torch.uint8:
                expected = torch.all(x, dim)
                torch.all(x, dim, out=out)
                self.assertEqual(expected, out)

                expected = torch.any(x, dim)
                torch.any(x, dim, out=out)
                self.assertEqual(expected, out)
            else:
                with self.assertRaisesRegex(RuntimeError, "all only supports bool tensor for result, got"):
                    torch.all(x, dim, out=out)

                with self.assertRaisesRegex(RuntimeError, "any only supports bool tensor for result, got"):
                    torch.any(x, dim, out=out)

        def _test_all_any_with_dim_keepdim(x, dim, keepdim):
            torch_fn = partial(torch.all, dim=dim, keepdim=keepdim)
            np_fn = partial(np.all, axis=dim, keepdims=keepdim)
            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)

            torch_fn = partial(torch.any, dim=dim, keepdim=keepdim)
            np_fn = partial(np.any, axis=dim, keepdims=keepdim)
            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)

        def _test_output_dtype(x):
            # This test will fail once the functions return bool output
            # for uint8 input.
            expected_dtype = torch.uint8 if dtype == torch.uint8 else torch.bool
            self.assertEqual(torch.all(x).dtype, expected_dtype)
            self.assertEqual(torch.any(x).dtype, expected_dtype)

            self.assertEqual(torch.all(x, dim=0).dtype, expected_dtype)
            self.assertEqual(torch.any(x, dim=0).dtype, expected_dtype)

        for ndim in range(5):
            shape = _rand_shape(ndim, 1, 5)
            x = _generate_input(shape, dtype, device, with_extremal=False)
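            # Exercise contiguous, transposed, and strided (step-2 slice) layouts.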
            _test_all_any(x)
            _test_all_any(x.T)
            _test_all_any(x[..., ::2])

            x = _generate_input(shape, dtype, device, with_extremal=True)
            _test_all_any(x)
            _test_all_any(x.T)
            _test_all_any(x[..., ::2])

            x = torch.zeros_like(x)
            _test_all_any(x)
            _test_all_any(x.T)
            _test_all_any(x[..., ::2])

            x = torch.ones_like(x)
            _test_all_any(x)
            _test_all_any(x.T)
            _test_all_any(x[..., ::2])
            _test_output_dtype(x)
            for dim in range(ndim):
                x = _generate_input(shape, dtype, device, with_extremal=False)
                _test_all_any_with_dim(x, dim)
                _test_all_any_with_dim(x.T, dim)
                _test_all_any_with_dim(x[..., ::2], dim)
                _test_out_variant(x, dim)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)

                x = _generate_input(shape, dtype, device, with_extremal=True)
                _test_all_any_with_dim(x, dim)
                _test_all_any_with_dim(x.T, dim)
                _test_all_any_with_dim(x[..., ::2], dim)
                _test_out_variant(x, dim)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)

                x = torch.zeros_like(x)
                _test_all_any_with_dim(x, dim)
                _test_all_any_with_dim(x.T, dim)
                _test_all_any_with_dim(x[..., ::2], dim)
                _test_out_variant(x, dim)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)

                x = torch.ones_like(x)
                _test_all_any_with_dim(x, dim)
                _test_all_any_with_dim(x.T, dim)
                _test_all_any_with_dim(x[..., ::2], dim)
                _test_out_variant(x, dim)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)

    # TODO: part of this test covers torch.norm, which should be covered by test_linalg
    @onlyNativeDeviceTypes
    def test_repeated_dim(self, device):
        ops = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp, torch.var,
               torch.norm]
        x = torch.randn(3, 3, 3, 3, device=device)

        error_msg = r'appears multiple times in the list of dims'
        for op in ops:
            for dim in [(0, 0), (0, -4)]:
                with self.assertRaisesRegex(RuntimeError, error_msg):
                    op(x, dim=dim)

    # TODO: update this test to compare against NumPy
    @onlyCUDA
    def test_var(self, device):
        cpu_tensor = torch.randn(2, 3, 3)
        device_tensor = cpu_tensor.to(device)
        self.assertEqual(device_tensor.var(), cpu_tensor.var())
        self.assertEqual(device_tensor.var(1), cpu_tensor.var(1))
        self.assertEqual(device_tensor.var(2), cpu_tensor.var(2))
        self.assertEqual(device_tensor.std(), cpu_tensor.std())
        self.assertEqual(device_tensor.std(1), cpu_tensor.std(1))
        self.assertEqual(device_tensor.std(2), cpu_tensor.std(2))

        cpu_tensor = torch.randn(100)
        device_tensor = cpu_tensor.to(device)
        self.assertEqual(device_tensor.var(), cpu_tensor.var())

    # TODO: update this test to compare against NumPy
    @onlyCUDA
    def test_var_large_input(self, device):
        # Large input with awkward (non-round) sizes
        cpu_tensor = torch.randn(2 * 32 * 1024 + 1, 2, 67)
        device_tensor = cpu_tensor.to(device)

        self.assertEqual(cpu_tensor.var(2), device_tensor.var(2))

    # TODO: update this to compare against NumPy instead of CPU
    @onlyCUDA
    @dtypes(torch.double)
    def test_sum_noncontig(self, device, dtype):
        x = torch.randn(1, 75, 57, 20, dtype=dtype, device=device).permute(0, 3, 1, 2)
        y = x.cpu()
        self.assertEqual(x.sum().cpu(), y.sum())
        self.assertEqual(x.sum(dim=(-1, -2)).cpu(), y.sum(dim=(-1, -2)))
        self.assertEqual(x.sum(dim=(1, 3)).cpu(), y.sum(dim=(1, 3)))

    # TODO: update this to compare against NumPy instead of CPU
    @onlyCUDA
    def test_min_max_nan(self, device):
        tests = [(lambda x: x.min(), 'min'),
                 (lambda x: x.max(), 'max'),
                 (lambda x: x.amin(), 'amin'),
                 (lambda x: x.amax(), 'amax'),
                 (lambda x: x.min(0).values, 'min_dim'),
                 (lambda x: x.max(0).values, 'max_dim'),
                 (lambda x: x.amin(0), 'amin_dim'),
                 (lambda x: x.amax(0), 'amax_dim')]
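        # min/max and amin/amax propagate NaN, so the device results should
        # have NaNs in exactly the same positions as the CPU results.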
        for f, name in tests:
            a = torch.arange(25.0).view(5, 5)
            a[2, 2] = nan
            actual = f(a.to(device)).cpu()
            expected = f(a).cpu()
            self.assertEqual(torch.isnan(actual), torch.isnan(expected), msg=f'nans for {name}')
            self.assertEqual(actual[~torch.isnan(actual)],
                             expected[~torch.isnan(expected)], msg=f'nans for {name}')

    # TODO: make this test generic using OpInfos
    @onlyCUDA
    def test_sum_cpu_device_mismatch(self, device):
        x = torch.randn(20, dtype=torch.float32, device=device)
        y = torch.randn(1, dtype=torch.float32)

        err_string = f"Expected out tensor to have device {device}, but got cpu instead"

        with self.assertRaisesRegex(RuntimeError, err_string):
            torch.sum(x, dim=[0], dtype=torch.float32, out=y)

        # tests half to float promotion
        if self.device_type == 'cuda':
            x = x.half()
            with self.assertRaisesRegex(RuntimeError, err_string):
                torch.sum(x, dim=[0], dtype=torch.float32, out=y)

    # The assert for an illegal dtype is not raised on XLA
    @onlyNativeDeviceTypes
    def test_minmax_illegal_dtype(self, device):
        x = torch.randn(5, 5, dtype=torch.float32, device=device)
        valid_values = torch.empty(5, dtype=torch.float32, device=device)
        valid_indices = torch.empty(5, dtype=torch.long, device=device)
        illegal_values = torch.empty(5, dtype=torch.int, device=device)
        illegal_indices = torch.empty(5, dtype=torch.double, device=device)
        torch.max(x, dim=0, out=(valid_values, valid_indices))
        torch.min(x, dim=0, out=(valid_values, valid_indices))
        torch.amax(x, dim=0, out=valid_values)
        torch.amin(x, dim=0, out=valid_values)
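        # The exact error wording can differ across backends, so match either
        # "scalar type" or "dtype".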
        rmsg = r'scalar type|dtype'
        with self.assertRaisesRegex(RuntimeError, rmsg):
            torch.max(x, dim=0, out=(illegal_values, valid_indices))
        with self.assertRaisesRegex(RuntimeError, rmsg):
            torch.min(x, dim=0, out=(illegal_values, valid_indices))
        with self.assertRaisesRegex(RuntimeError, rmsg):
            torch.max(x, dim=0, out=(valid_values, illegal_indices))
        with self.assertRaisesRegex(RuntimeError, rmsg):
            torch.min(x, dim=0, out=(valid_values, illegal_indices))
        with self.assertRaisesRegex(RuntimeError, rmsg):
            torch.max(x, dim=0, out=(illegal_values, illegal_indices))
        with self.assertRaisesRegex(RuntimeError, rmsg):
            torch.min(x, dim=0, out=(illegal_values, illegal_indices))

    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_dim_arg_reduction_scalar(self, device, dtype):
        example = 4.0

        x = torch.tensor(example, device=device, dtype=dtype)
        self.assertEqual(x.argmax().item(), 0)
        self.assertEqual(x.argmax(dim=None).item(), 0)
        self.assertEqual(x.argmax(dim=0).item(), 0)
        self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64))

        x = torch.tensor(example, device=device, dtype=dtype)
        self.assertEqual(x.argmin().item(), 0)
        self.assertEqual(x.argmin(dim=None).item(), 0)
        self.assertEqual(x.argmin(dim=0).item(), 0)
        self.assertEqual(x.argmin(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64))


    @precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2})
    @dtypes(*set(all_types_and(torch.half, torch.bfloat16)) - {torch.uint8})
    def test_dim_reduction(self, device, dtype):
        example = [[-1, 2, 1], [5, 3, 6]]

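        # sum promotes integral inputs to int64 and keeps floating dtypes; the
        # mapping below gives the expected result dtype for each input dtype.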
        sum_dtype = {
            torch.bfloat16: torch.bfloat16,
            torch.double: torch.double,
            torch.float: torch.float,
            torch.half: torch.half,
            torch.int64: torch.int64,
            torch.int32: torch.int64,
            torch.int16: torch.int64,
            torch.int8: torch.int64,
        }

        # This won't test for 256bit instructions, since we usually
        # only work on 1 cacheline (512bit) at a time and these
        # examples aren't big enough to trigger that.
        x = torch.tensor(example, device=device, dtype=dtype)
        self.assertEqual(x.sum().item(), 16)
        self.assertEqual(x.sum(0), torch.tensor([4, 5, 7], dtype=sum_dtype[dtype]))
        self.assertEqual(x.sum(1), torch.tensor([2, 14], dtype=sum_dtype[dtype]))
        y = torch.tensor(example, device=device, dtype=sum_dtype[dtype])
        torch.sum(x, 0, out=y)
        self.assertEqual(x.sum(0), y)

        # Mean not supported for Int types
        if dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64]:
            x = torch.tensor(example, device=device, dtype=dtype)
            self.assertEqual(x.mean().item(), 16.0 / 6)
            self.assertEqual(x.mean(0), torch.tensor([2.0, 2.5, 7.0 / 2], dtype=dtype))
            self.assertEqual(x.mean(1), torch.tensor([2.0 / 3, 14.0 / 3], dtype=dtype))
            self.assertEqual(x.mean(), x.mean((0, 1)))

        prod_dtype = {
            torch.bfloat16: torch.bfloat16,
            torch.double: torch.double,
            torch.float: torch.float,
            torch.float16: torch.float16,
            torch.int64: torch.int64,
            torch.int32: torch.int64,
            torch.int16: torch.int64,
            torch.int8: torch.int64,
        }

        # prod is not supported for float16 & bfloat16 on CPU
        if not (self.device_type == 'cpu' and dtype in [torch.float16, torch.bfloat16]):
            x = torch.tensor(example, device=device, dtype=dtype)
            self.assertEqual(x.prod().item(), -180)
            self.assertEqual(x.prod(0), torch.tensor([-5, 6, 6], dtype=prod_dtype[dtype]))
            self.assertEqual(x.prod(1), torch.tensor([-2, 90], dtype=prod_dtype[dtype]))

        x = torch.tensor(example, device=device, dtype=dtype)

        self.assertEqual(x.min().item(), -1)
        self.assertEqual(x.argmin().item(), 0)

        # TODO: torch.min does not support the same operation as argmin
        # for the same case, should we enable it?
        self.assertEqual(x.argmin(dim=None).item(), 0)

        self.assertEqual(x.min(0), (torch.tensor([-1, 2, 1], dtype=dtype),
                                    torch.tensor([0, 0, 0], dtype=torch.int64)))
        self.assertEqual(x.amin(0), torch.tensor([-1, 2, 1], dtype=dtype))
        self.assertEqual(x.argmin(0), torch.tensor([0, 0, 0], dtype=torch.int64))

        self.assertEqual(x.min(dim=0, keepdim=True), (torch.tensor([[-1, 2, 1]], dtype=dtype),
                         torch.tensor([[0, 0, 0]], dtype=torch.int64)))
        self.assertEqual(x.amin(dim=0, keepdim=True), torch.tensor([[-1, 2, 1]], dtype=dtype))
        self.assertEqual(x.argmin(dim=0, keepdim=True), torch.tensor([[0, 0, 0]], dtype=torch.int64))

        self.assertEqual(x.min(1), (torch.tensor([-1, 3], dtype=dtype),
                         torch.tensor([0, 1], dtype=torch.int64)))
        self.assertEqual(x.amin(1), torch.tensor([-1, 3], dtype=dtype))
        self.assertEqual(x.argmin(1), torch.tensor([0, 1], dtype=torch.int64))

        self.assertEqual(x.min(dim=1, keepdim=True), (torch.tensor([[-1], [3]], dtype=dtype),
                         torch.tensor([[0], [1]], dtype=torch.int64)))
        self.assertEqual(x.amin(dim=1, keepdim=True), torch.tensor([[-1], [3]], dtype=dtype))
        self.assertEqual(x.argmin(dim=1, keepdim=True), torch.tensor([[0], [1]], dtype=torch.int64))

        # test that non-contiguous tensors work
        self.assertEqual(x[:, :2].min().item(), -1)
        self.assertEqual(x[:, :2].amin().item(), -1)
        self.assertEqual(x[:, :2].argmin().item(), 0)

        x = torch.tensor(example, device=device, dtype=dtype)

        self.assertEqual(x.max().item(), 6)
        self.assertEqual(x.amax().item(), 6)
        self.assertEqual(x.argmax().item(), 5)

        self.assertEqual(x.max(0), (torch.tensor([5, 3, 6], dtype=dtype),
                                    torch.tensor([1, 1, 1], dtype=torch.int64)))
        self.assertEqual(x.amax(0), torch.tensor([5, 3, 6], dtype=dtype))
        self.assertEqual(x.argmax(dim=0), torch.tensor([1, 1, 1], dtype=torch.int64))

        self.assertEqual(x.max(dim=0, keepdim=True), (torch.tensor([[5, 3, 6]], dtype=dtype),
                                                      torch.tensor([[1, 1, 1]], dtype=torch.int64)))
        self.assertEqual(x.amax(dim=0, keepdim=True), torch.tensor([[5, 3, 6]], dtype=dtype))
        self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor([[1, 1, 1]], dtype=torch.int64))

        self.assertEqual(x.max(1), (torch.tensor([2, 6], dtype=dtype),
                                    torch.tensor([1, 2], dtype=torch.int64)))
        self.assertEqual(x.amax(1), torch.tensor([2, 6], dtype=dtype))
        self.assertEqual(x.argmax(dim=1), torch.tensor([1, 2], dtype=torch.int64))

        self.assertEqual(x.max(1, keepdim=True), (torch.tensor([[2], [6]], dtype=dtype),
                                                  torch.tensor([[1], [2]], dtype=torch.int64)))
        self.assertEqual(x.amax(1, keepdim=True), torch.tensor([[2], [6]], dtype=dtype))
        self.assertEqual(x.argmax(dim=1, keepdim=True), torch.tensor([[1], [2]], dtype=torch.int64))

        # test that non-contiguous tensors work
        self.assertEqual(x[:, :2].max().item(), 5)
        self.assertEqual(x[:, :2].amax().item(), 5)
        self.assertEqual(x[:, :2].argmax().item(), 2)


    @precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2})
    @dtypes(*set(all_types_and(torch.half, torch.bfloat16)) - {torch.uint8})
    @parametrize("fn_name", [
        "mean", "median", "nanmedian", "mode", "norm", "prod",
        "std", "sum", "var", "max", "min", "amax", "amin"])
    def test_dim_reduction_fns(self, device, dtype, fn_name):
        def normfn_attr(t, dim, keepdim=False, out=None):
            attr = torch.norm
            return attr(t, 2, dim, keepdim, out=out)

        fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr

        def fn(x, dim, keepdim=False, out=None):
            ans = fn_attr(x, dim, keepdim=keepdim, out=out)
            return ans if not isinstance(ans, tuple) else ans[0]

        def fn_tuple(x, dim, keepdim=False, out=None):
            return fn_attr(x, dim, keepdim=keepdim, out=out)

        def test_multidim(x, dim):
            self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True))
            self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension())
            self.assertEqual(x.ndimension(), fn(x, dim, keepdim=True).ndimension())

        # general case
        x = torch.randn(3, 4, 5, device=device)
        dim = random.randint(0, 2)
        test_multidim(x, dim)

        # check 1-d behavior
        x = torch.randn(1, device=device)
        dim = 0
        self.assertEqual(fn(x, dim).shape, ())
        self.assertEqual(fn(x, dim, keepdim=True).shape, (1,))

        # check reducing of a singleton dimension
        dims = [3, 4, 5]
        singleton_dim = random.randint(0, 2)
        dims[singleton_dim] = 1
        x = torch.randn(dims, device=device)
        test_multidim(x, singleton_dim)

        # check reducing with output kwargs
        if fn_name in ['median', 'nanmedian', 'mode', 'max', 'min']:
            y = torch.randn(5, 3, device=device)
            values = torch.randn(5, 3, device=device)
            indices = torch.zeros(5, 3, device=device).long() - 1
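            # -1 is never a valid index, so any entry the op fails to write
            # stays visibly wrong.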
            fn_tuple(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1]))
            values_expected, indices_expected = fn_tuple(y, 1, keepdim=False)
            self.assertEqual(values[:, 1], values_expected,
                             msg=f'{fn_name} values with out= kwarg')
            self.assertEqual(indices[:, 1], indices_expected,
                             msg=f'{fn_name} indices with out= kwarg')
            return

        x = torch.randn(5, 3, device=device)
        y = torch.randn(5, 3, device=device)
        fn(y, 1, keepdim=False, out=x[:, 1])
        expected = fn(y, 1, keepdim=False)
        self.assertEqual(x[:, 1], expected, msg=f'{fn_name} with out= kwarg')

    @onlyCUDA
    @largeTensorTest('10GB')
    def test_reduction_split(self, device):
        # Test reduction when there is a 32bit-indexing split
        # https://github.com/pytorch/pytorch/issues/37583
        input_ = torch.randn(5, 14400, 14400, device=device)
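        # 5 * 14400 * 14400 = 1,036,800,000 float32 elements (~3.9 GiB), large
        # enough that the CUDA reduction gets split into 32-bit-indexable chunks.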
        result = input_.sum(dim=0)
        expect = input_[0] + input_[1] + input_[2] + input_[3] + input_[4]
        self.assertEqual(result, expect)

    @onlyCUDA
    @dtypes(torch.half, torch.float, torch.double, torch.bfloat16)
    def test_reduction_vectorize_along_input_corner(self, device, dtype):
        # 1D case: sum
        size = 1024 * 1024 * 64 + 3
        shift = 1
        x = torch.zeros(size, dtype=dtype, device=device)
        y = x[shift:]
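        # y is a view offset by one element, so its data pointer is misaligned
        # and the kernel cannot use fully aligned vectorized loads; this
        # exercises the scalar head/tail handling.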
        for i in range(100):
            x.zero_()
            x[i] = 1
            self.assertEqual(x.sum(), 1.0)
            if i < shift:
                self.assertEqual(y.sum(), 0.0)
            else:
                self.assertEqual(y.sum(), 1.0)
        for i in range(1, 100):
            x.zero_()
            x[-i] = 1
            self.assertEqual(x.sum(), 1.0)
            self.assertEqual(y.sum(), 1.0)
        # 1D case: argmax
        size = 1024 * 1024 * 64 + 3
        shift = 1
        ysize = size - shift
        x = torch.zeros(size, dtype=dtype, device=device)
        y = x[shift:]
        for i in range(100):
            x.zero_()
            x[i] = 1
            self.assertEqual(x.argmax().item(), i)
            if i >= shift:
                self.assertEqual(y.argmax().item(), i - shift)
        for i in range(1, 100):
            x.zero_()
            x[-i] = 1
            self.assertEqual(x.argmax().item(), size - i)
            self.assertEqual(y.argmax().item(), ysize - i)
        # 2D case: sum
        size = (7, 1024 * 1024 + 3)
        x = torch.zeros(size, dtype=dtype, device=device)
        for i in range(100):
            x.zero_()
            for j in range(7):
                x[j][i] = j
            xs = x.sum(dim=-1)
            for j in range(7):
                self.assertEqual(xs[j].item(), float(j))
        for i in range(100):
            x.zero_()
            for j in range(7):
                x[j][-i] = j
            xs = x.sum(dim=-1)
            for j in range(7):
                self.assertEqual(xs[j].item(), float(j))
        # 2D case: max/argmax
        size = (7, 1024 * 1024 + 3)
        x = torch.zeros(size, dtype=dtype, device=device)
        for i in range(100):
            x.zero_()
            for j in range(7):
                x[j][i] = j + 1
            xs1 = x.argmax(dim=-1)
            xs2 = x.max(dim=-1).indices
            for j in range(7):
                self.assertEqual(xs1[j].item(), i)
                self.assertEqual(xs2[j].item(), i)
        for i in range(1, 100):
            x.zero_()
            for j in range(7):
                x[j][-i] = j + 1
            xs1 = x.argmax(dim=-1)
            xs2 = x.max(dim=-1).indices
            for j in range(7):
                self.assertEqual(xs1[j].item(), size[1] - i)
                self.assertEqual(xs2[j].item(), size[1] - i)
        # 2D case: min/argmin
        size = (7, 1024 * 1024 + 3)
        x = torch.zeros(size, dtype=dtype, device=device)
        for i in range(100):
            x.zero_()
            for j in range(7):
                x[j][i] = -(j + 1)
            xs1 = x.argmin(dim=-1)
            xs2 = x.min(dim=-1).indices
            for j in range(7):
                self.assertEqual(xs1[j].item(), i)
                self.assertEqual(xs2[j].item(), i)
        for i in range(1, 100):
            x.zero_()
            for j in range(7):
                x[j][-i] = -(j + 1)
            xs1 = x.argmin(dim=-1)
            xs2 = x.min(dim=-1).indices
            for j in range(7):
                self.assertEqual(xs1[j].item(), size[1] - i)
                self.assertEqual(xs2[j].item(), size[1] - i)

    @onlyCUDA
    @dtypes(torch.half, torch.float, torch.double, torch.bfloat16)
    def test_reduction_vectorize_along_output(self, device, dtype):
        def run_test(input_):
            M, N = input_.shape
            input_.zero_()
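            # Place a single 1 on the diagonal so that, for i < min(M, N),
            # column i's argmax is row i and its sum is 1.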
            for i in range(min(M, N)):
                input_[i][i] = 1
            output1 = input_.argmax(dim=0)
            output2 = input_.sum(dim=0)
            for i in range(min(M, N)):
                self.assertEqual(output1[i], i)
                self.assertEqual(output2[i], 1)
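        # The cases below vary output alignment and row width, which should
        # select vector widths of 4, 2, and 1 respectively for the output writes.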
        # vec 4
        run_test(torch.zeros(64, 64, dtype=dtype, device=device))
        # vec 2
        run_test(torch.zeros(64 * 64 + 2, dtype=dtype, device=device)[2:].view(64, 64))
        run_test(torch.zeros(64, 62, dtype=dtype, device=device))
        run_test(torch.zeros(64, 2, dtype=dtype, device=device))
        # vec 1
        run_test(torch.zeros(64 * 64 + 1, dtype=dtype, device=device)[1:].view(64, 64))
        run_test(torch.zeros(64, 61, dtype=dtype, device=device))
        run_test(torch.zeros(64, 1, dtype=dtype, device=device))

    @onlyCUDA
    def test_argminmax_large_axis(self, device):
        # Regression test for gh-32863
        x = torch.zeros(2**31, device=device, dtype=torch.int8)
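        # 2**31 elements is one more than INT32_MAX, so the kernel must use
        # 64-bit indexing along the reduced axis.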
        x[-1] = 1
        self.assertEqual(x.argmax(0), x.shape[0] - 1)
        self.assertEqual(x.max(0).indices, x.shape[0] - 1)
        x[-1] = -1
        self.assertEqual(x.argmin(0), x.shape[0] - 1)
        self.assertEqual(x.min(0).indices, x.shape[0] - 1)

    def test_argminmax_axis_with_dim_one(self, device):
        # See: https://github.com/pytorch/pytorch/issues/38922
        n = 32768
        x = torch.zeros(1, n)
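        # Reducing over a dimension of size 1 must return index 0 in every one
        # of the n output lanes, for both positive and negative dim.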
        self.assertEqual(x.argmax(dim=0), torch.zeros(n, dtype=torch.int64))
        self.assertEqual(x.argmin(dim=0), torch.zeros(n, dtype=torch.int64))

        self.assertEqual(x.argmax(dim=-2), torch.zeros(n, dtype=torch.int64))
        self.assertEqual(x.argmin(dim=-2), torch.zeros(n, dtype=torch.int64))

        self.assertEqual(x.argmax(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
        self.assertEqual(x.argmin(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64))

        self.assertEqual(x.argmax(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
        self.assertEqual(x.argmin(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64))

    @dtypes(torch.int, torch.long, torch.float, torch.double)
    @dtypesIfCUDA(torch.int, torch.long, torch.half, torch.float, torch.double)
    def test_median_real_values(self, device, dtype):
        # Generate random 0-3D sizes
        sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)]
        for size in sizes:
            # Create random input tensor
            t = torch.randn(size, device=device).to(dtype)
            t_numpy = t.cpu().numpy()
            res = t.median()
            self.assertEqual(res, t.nanmedian())
            k = int((t.numel() - 1) / 2)
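            # torch.median returns the lower of the two middle values, i.e. the
            # element at index (numel - 1) // 2 of the sorted, flattened input.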
            self.assertEqual(res, t.view(-1).sort()[0][k])
            if t.numel() % 2 == 1:
                # We can only test against numpy for odd-sized reductions, because
                # numpy returns the mean of the two middle values while torch returns the lower
                self.assertEqual(res.cpu().numpy(), np.median(t_numpy))
            for dim in range(t.ndim):
                res = t.median(dim, True)
                self.assertEqual(res, t.nanmedian(dim, True))
                size = t.size(dim) if t.ndim > 0 else 1
                k = int((size - 1) / 2)
                self.assertEqual(res[0], (t.sort(dim)[0]).select(dim, k).unsqueeze_(dim))
                self.assertEqual(res[0], t.gather(dim, res[1]))
                if size % 2 == 1:
                    # We can only test against numpy for odd-sized reductions, because
                    # numpy returns the mean of the two middle values while torch returns the lower
                    self.assertEqual(res[0].cpu().numpy(), np.median(t_numpy, dim, keepdims=True), exact_dtype=False)

2492*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
2493*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(torch.half, torch.float, torch.double)
2494*da0073e9SAndroid Build Coastguard Worker    def test_median_nan_values(self, device, dtype):
2495*da0073e9SAndroid Build Coastguard Worker        # Generate random 0-3D sizes
2496*da0073e9SAndroid Build Coastguard Worker        sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)]
2497*da0073e9SAndroid Build Coastguard Worker        for size in sizes:
2498*da0073e9SAndroid Build Coastguard Worker            # Create random input tensor with nan values
2499*da0073e9SAndroid Build Coastguard Worker            t = torch.rand(size, device=device, dtype=dtype)
2500*da0073e9SAndroid Build Coastguard Worker            t.masked_fill_(t < 0.1, float('nan'))
2501*da0073e9SAndroid Build Coastguard Worker            t_numpy = t.cpu().numpy()
2502*da0073e9SAndroid Build Coastguard Worker            for op in [torch.median, torch.nanmedian]:
2503*da0073e9SAndroid Build Coastguard Worker                numpy_op = np.median if op == torch.median else np.nanmedian
2504*da0073e9SAndroid Build Coastguard Worker                res = op(t)
2505*da0073e9SAndroid Build Coastguard Worker                num_nan = t.isnan().sum()
2506*da0073e9SAndroid Build Coastguard Worker                if op == torch.median and num_nan > 0:
2507*da0073e9SAndroid Build Coastguard Worker                    k = t.numel() - 1
2508*da0073e9SAndroid Build Coastguard Worker                else:
2509*da0073e9SAndroid Build Coastguard Worker                    k = int((t.numel() - num_nan - 1) / 2)
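                # torch.sort places NaNs at the end, so with any NaN present
                # torch.median lands on a NaN (the last sorted element), while
                # nanmedian indexes the lower median of the non-NaN values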
                self.assertEqual(res, t.view(-1).sort()[0][k])
                if (t.numel() - num_nan) % 2 == 1:
                    # We can only test against numpy for odd reductions because numpy
                    # returns the mean of the two medians and torch returns the lower
                    self.assertEqual(res.item(), numpy_op(t.cpu().numpy()))
                for dim in range(t.ndim):
                    res = op(t, dim, True)
                    size = t.size(dim) if t.ndim > 0 else 1
                    num_nan = t.isnan().sum(dim, True)
                    if op == torch.median:
                        k = torch.where(num_nan > 0, size - 1, int((size - 1) / 2))
                    else:
                        k = ((size - num_nan - 1) / 2).type(torch.long)
                    self.assertEqual(res[0], (t.sort(dim)[0]).gather(dim, k))
                    self.assertEqual(res[0], t.gather(dim, res[1]))
                    # We can only test against numpy for odd reductions because numpy
                    # returns the mean of the two medians and torch returns the lower
                    mask = (size - num_nan) % 2 == 1
                    res = res[0].masked_select(mask).cpu()
                    ref = numpy_op(t_numpy, dim, keepdims=True)[mask.cpu().numpy()]
                    self.assertEqual(res, torch.from_numpy(ref))

    def test_median_corner_cases(self, device):
        def check(op, a, args, key):
            t = torch.tensor(a, device=device)
            res = op(t, *args)
            if not args:
                key = torch.tensor(key, device=device)
            else:
                if len(key) == 1:
                    key = torch.tensor(key[0], device=device)
                    res = res[0]
                else:
                    key = (torch.tensor(key[0], device=device), torch.tensor(key[1], device=device))
            self.assertEqual(res, key)

        nan = float('nan')
        check(torch.median, nan, [], nan)
        check(torch.median, [], [], nan)
        check(torch.nanmedian, nan, [], nan)
        check(torch.median, nan, [0], [nan, 0])
        check(torch.nanmedian, nan, [0], [nan, 0])
        check(torch.median, [nan], [0, True], [[nan], [0]])
        check(torch.nanmedian, [nan], [0, True], [[nan], [0]])

        # Indices are not deterministic here, so we can only check values
        check(torch.median, [[nan, nan], [1, 2]], [0], [[nan, nan]])
        check(torch.nanmedian, [[nan, nan], [1, 2]], [0], [[1, 2.]])
        check(torch.median, [[nan, nan], [1, 2]], [1], [[nan, 1]])
        check(torch.nanmedian, [[nan, nan], [1, 2]], [1], [[nan, 1.]])

        # Discontiguous and strided tensors
        a = torch.arange(12, device=device)
        self.assertEqual(a[::2].median(), torch.tensor(4, device=device))
        self.assertEqual(a[::2].nanmedian(), torch.tensor(4, device=device))

        a.resize_(3, 4)
        self.assertEqual(a.T.median(), torch.tensor(5, device=device))
        self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
        self.assertEqual(a[::2, ::2].median(-1)[0], torch.tensor([0, 8], device=device))
        self.assertEqual(a[::2, ::2].nanmedian(-1)[0], torch.tensor([0, 8], device=device))

        a.resize_(2, 3, 2)
        self.assertEqual(a.T.median(), torch.tensor(5, device=device))
        self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
        self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
        self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))


    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double)
    def test_quantile(self, device, dtype):
        # Generate some random test cases
        ops = ['quantile', 'nanquantile']
        inputs = [tuple(np.random.randint(2, 10, size=i)) for i in range(1, 4)]
        quantiles = [tuple(np.random.rand(i)) for i in range(0, 5)]
        keepdims = [True, False]

        # Add corner cases
        inputs.extend([0.75, (1,), (1, 1), (1, 2, 1)])
        inputs.extend([[float('nan')], [[float('nan'), float('nan')], [1, 2]]])
        inputs.extend([[[float('nan'), float('nan')], [float('nan'), 2]]])
        quantiles.extend([0.5, [0., 1.], np.random.rand(10)])

        # Enumerate all input combinations
        for op, x, q, keepdim in product(ops, inputs, quantiles, keepdims):
            if type(x) is tuple:
                a = torch.randn(x, dtype=dtype, device=device)
                # Make some random elements NaN
                a.masked_fill_(torch.randint_like(a, 20) == 0, float('nan'))
            else:
                a = torch.tensor(x, dtype=dtype, device=device)

            q = torch.tensor(q, dtype=dtype, device=device)

            torch_op = getattr(torch, op)
            numpy_op = getattr(np, op)

            # Compute quantile along every dimension and flattened tensor
            interpolations = ('linear', 'lower', 'higher', 'midpoint', 'nearest')
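            # 'linear' interpolates between the two nearest sorted values,
            # 'lower'/'higher' pick one of them directly, 'midpoint' averages
            # them, and 'nearest' rounds to the closer one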
            for interpolation, dim in product(interpolations,
                                              [None] + list(range(a.ndim))):
                result = torch_op(a, q, dim=dim, keepdim=keepdim, interpolation=interpolation)
                expected = numpy_op(a.cpu().numpy(), q.cpu().numpy(), dim,
                                    interpolation=interpolation, keepdims=keepdim)
                self.assertEqual(result.cpu(), torch.from_numpy(np.array(expected)).type(result.type()))

                # Test out variation
                out = torch.empty_like(result)
                torch_op(a, q, dim=dim, keepdim=keepdim, interpolation=interpolation, out=out)
                self.assertEqual(out.cpu(), result.cpu())

    def test_quantile_backward(self, device):
        def check(a, q, dim, expected_grad, ops=(torch.quantile, torch.nanquantile)):
            for op in ops:
                t = torch.tensor(a, device=device, requires_grad=True)
                op(t, torch.tensor(q, device=device), dim).sum().backward()
                self.assertEqual(t.grad, expected_grad)

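        # The gradient flows to the sorted element(s) the quantile interpolates
        # between, e.g. q=0.5 on [1., 2, 3, 4] interpolates halfway between 2
        # and 3, so each of those elements receives a gradient of 0.5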
        check([1., 2, 3], 0.5, 0, [0, 1, 0])
        check([1., 2, 3, 4], 0.5, 0, [0, 0.5, 0.5, 0])
        check([3., 1, 4, 2], 0.5, 0, [0.5, 0, 0, 0.5])
        check([1., 2, 3, 4], [0.25, 0.5, 0.75], 0, [0.25, 1.25, 1.25, 0.25])
        check([[1., 2], [2, 1]], 0., 0, [[1, 0], [0, 1]])
        check([[1., 2], [4, 3]], 1., 1, [[0, 1], [1, 0]])
        check([1, float('nan'), 2], 0.5, 0, [0, 1, 0], [torch.quantile])
        check([1, float('nan'), 2], 0.5, 0, [0.5, 0, 0.5], [torch.nanquantile])

    def test_quantile_error(self, device):
        def check(a, q, args, kwargs, message):
            with self.assertRaisesRegex(RuntimeError, r'quantile\(\) ' + message):
                at = torch.tensor(a, device=device)
                qt = torch.tensor(q, device=device) if isinstance(q, list) else q
                torch.quantile(at, qt, *args, **kwargs)

        check([], 0.5, [], {}, r'input tensor must be non-empty')
        check([1.], [[1.]], [], {}, r'q must be a scalar or 1D tensor')
        check([1], 0.5, [], {}, r'input tensor must be either float or double dtype')
        check([1.], [1], [], {}, r'q tensor must be same dtype as the input tensor')
        check([1.], -1., [], {}, r'q must be in the range \[0, 1\] but got -1')
        check([1.], 1.1, [], {}, r'q must be in the range \[0, 1\] but got 1.1')
        check([1.], 0.5, [], {'out': torch.empty([], dtype=torch.int32, device=device)},
              r'out tensor must be same dtype as the input tensor')
        check([1.], [1.], [None, False], {'interpolation': 'random_mode'},
              r"interpolation must be one of linear, lower, higher, midpoint or nearest, but got random_mode")

        if self.device_type == "cpu":
            check([1.], [0.5, 1.1, -1], [], {}, r'q values must be in the range \[0, 1\]')

        if self.device_type == "cuda":
            with self.assertRaisesRegex(
                    RuntimeError, r'quantile\(\) q tensor must be on the same device as the input tensor'):
                torch.randn(1, device=device).quantile(torch.tensor(0.5))
            with self.assertRaisesRegex(
                    RuntimeError, r'quantile\(\) out tensor must be on the same device as the input tensor'):
                torch.quantile(torch.randn(1, device=device), 0.5, out=torch.scalar_tensor(1))

    def test_std_mean(self, device):
        x = torch.rand(100, 50, 20, device=device)
        for dim in range(x.dim()):
            for unbiased in [False, True]:
                for keepdim in [False, True]:
                    std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
                    std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim)
                    mean2 = x.mean(dim=dim, keepdim=keepdim)
                    self.assertEqual(std1, std2)
                    self.assertEqual(mean1, mean2)

    def test_std_mean_all_dims(self, device):
        x = torch.rand(100, 50, 20, device=device)
        for unbiased in [False, True]:
            std1, mean1 = torch.std_mean(x, unbiased=unbiased)
            std2 = x.std(unbiased=unbiased)
            mean2 = x.mean()
            self.assertEqual(std1, std2)
            self.assertEqual(mean1, mean2)

    def test_var_mean(self, device):
        x = torch.rand(100, 300, 50, device=device)
        for dim in range(x.dim()):
            for unbiased in [False, True]:
                for keepdim in [False, True]:
                    var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
                    var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim)
                    mean2 = x.mean(dim=dim, keepdim=keepdim)
                    self.assertEqual(var1, var2)
                    self.assertEqual(mean1, mean2)

    def test_var_mean_all_dims(self, device):
        x = torch.rand(100, 50, 20, device=device)
        for unbiased in [False, True]:
            var1, mean1 = torch.var_mean(x, unbiased=unbiased)
            var2 = x.var(unbiased=unbiased)
            mean2 = x.mean()
            self.assertEqual(var1, var2)
            self.assertEqual(mean1, mean2)

    def test_std_mean_some_dims(self, device):
        sizes = (4, 6, 7, 5, 3)
        dims = len(sizes)
        x = torch.rand(sizes, device=device)
        for num_of_dims in range(2, dims):
            dim_list = list(combinations(list(range(dims)), r=num_of_dims))
            for dim in dim_list:
                for unbiased in [False, True]:
                    for keepdim in [False, True]:
                        std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
                        std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim)
                        mean2 = x.mean(dim=dim, keepdim=keepdim)
                        self.assertEqual(std1, std2)
                        self.assertEqual(mean1, mean2)

    def _compare_std_var_with_numpy(self, op, device, dtype, input, dim,
                                    keepdim, unbiased, use_out):
        a = input.cpu().numpy() if input.dtype is not torch.bfloat16 else input.float().cpu().numpy()
        numpy_kwargs = {
            'axis' : dim,
            'keepdims' : keepdim,
            'ddof' : 1 if unbiased else 0,
        }
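        # NumPy's ddof plays the role of torch's unbiased flag:
        # var = sum((x - mean)**2) / (N - ddof), so ddof=1 is Bessel's correction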

        if dim is None:
            del numpy_kwargs['axis']
            del numpy_kwargs['keepdims']

        if op == 'var':
            torch_op = torch.var
            numpy_op = np.var
        elif op == 'std':
            torch_op = torch.std
            numpy_op = np.std
        else:
            self.fail("Unknown op!")

        numpy_result = numpy_op(a, **numpy_kwargs)

        if dim is None and use_out is False:
            torch_result = torch_op(input, unbiased)
        elif dim is not None and use_out is False:
            torch_result = torch_op(input, dim, unbiased, keepdim)
        elif dim is not None and use_out is True:
            out = torch.empty(0, device=device, dtype=dtype)
            torch_result = torch_op(input, dim, unbiased, keepdim, out=out)
        else:
            out = torch.empty(0, device=device, dtype=dtype)
            torch_result = torch_op(input, dim, unbiased, keepdim, out=out)

        exact_dtype = input.dtype not in (torch.bfloat16, torch.complex32, torch.complex64, torch.complex128)
        self.assertEqual(torch_result, numpy_result, exact_dtype=exact_dtype)
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_var_vs_numpy(self, device, dtype):
        _size = (20, 20)

        for test_case in product((torch.randn(_size, device=device, dtype=dtype),),
                                 (None, 0, 1),
                                 (False, True),
                                 (False, True),
                                 (False, True),):
            self._compare_std_var_with_numpy('var', device, dtype, *test_case)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_std_vs_numpy(self, device, dtype):
        _size = (20, 20)

        for test_case in product((torch.randn(_size, device=device, dtype=dtype),),
                                 (None, 0, 1),
                                 (False, True),
                                 (False, True),
                                 (False, True),):
            self._compare_std_var_with_numpy('std', device, dtype, *test_case)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_var_correction_vs_numpy(self, device, dtype):
        _size = (20, 20)
        test_args = [
            *product(
                # dim
                (None, 0, 1),
                # correction
                (None, 0, 10, 30),
                # keepdim
                (False, True),
            ),
            [None, -100, True],  # Negative correction
        ]

        tensor = make_tensor(_size, device=device, dtype=dtype)
        array = tensor.cpu().numpy()

        for dim, correction, keepdim in test_args:
            numpy_kwargs = dict(axis=dim, ddof=correction, keepdims=keepdim)
            if correction is None:
                # NumPy default is not compatible with torch.var (gh-50010)
                numpy_kwargs['ddof'] = 1
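                # torch.var/torch.std default to correction=1 (sample statistics),
                # whereas NumPy defaults to ddof=0 (population statistics)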

            numpy_res = np.asarray(np.var(array, **numpy_kwargs))
            torch_res = torch.var(tensor, dim=dim, correction=correction, keepdim=keepdim)

            # inf vs. nan results are sensitive to machine precision,
            # just treat them as equivalent
            numpy_res[np.isinf(numpy_res)] = np.nan
            torch_res[torch_res.isinf()] = np.nan

            self.assertEqual(torch_res, numpy_res)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_std_correction_vs_numpy(self, device, dtype):
        _size = (20, 20)
        test_args = [
            *product(
                # dim
                (None, 0, 1),
                # correction
                (None, 0, 10, 30),
                # keepdim
                (False, True),
            ),
            [None, -100, True],  # Negative correction
        ]

        tensor = make_tensor(_size, device=device, dtype=dtype)
        array = tensor.cpu().numpy()

        for dim, correction, keepdim in test_args:
            numpy_kwargs = dict(axis=dim, ddof=correction, keepdims=keepdim)
            if correction is None:
                # NumPy default is incompatible with torch.std (gh-50010)
                numpy_kwargs['ddof'] = 1

            numpy_res = np.asarray(np.std(array, **numpy_kwargs))
            torch_res = torch.std(tensor, dim=dim, correction=correction, keepdim=keepdim)

            # inf vs. nan results are sensitive to machine precision,
            # just treat them as equivalent
            numpy_res[np.isinf(numpy_res)] = np.nan
            torch_res[torch_res.isinf()] = np.nan

            self.assertEqual(torch_res, numpy_res)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_std_mean_correction(self, device, dtype):
        _size = (20, 20)
        test_args = [
            *product(
                # dim
                (None, 0, 1),
                # correction
                (None, 0, 10, 30),
                # keepdim
                (False, True),
            ),
            [None, -100, True],  # Negative correction
        ]

        tensor = make_tensor(_size, device=device, dtype=dtype)

        for dim, correction, keepdim in test_args:
            kwargs = dict(dim=dim, correction=correction, keepdim=keepdim)
            std1 = torch.std(tensor, **kwargs)
            if dim is not None:
                mean1 = torch.mean(tensor, dim=dim, keepdim=keepdim)
            else:
                mean1 = torch.mean(tensor)
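                # std_mean with dim=None and keepdim=True keeps every dim as
                # size 1, so broadcast the scalar mean to a matching shape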
                if keepdim:
                    mean1 = mean1.reshape((1,) * tensor.ndim)
            std2, mean2 = torch.std_mean(tensor, **kwargs)

            self.assertEqual(std1, std2)
            self.assertEqual(mean1, mean2)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_var_mean_correction(self, device, dtype):
        _size = (20, 20)
        test_args = [
            *product(
                # dim
                (None, 0, 1),
                # correction
                (None, 0, 10, 30),
                # keepdim
                (False, True),
            ),
            [None, -100, True],  # Negative correction
        ]

        tensor = make_tensor(_size, device=device, dtype=dtype)

        for dim, correction, keepdim in test_args:
            kwargs = dict(dim=dim, correction=correction, keepdim=keepdim)
            var1 = torch.var(tensor, **kwargs)
            if dim is not None:
                mean1 = torch.mean(tensor, dim=dim, keepdim=keepdim)
            else:
                mean1 = torch.mean(tensor)
                if keepdim:
                    mean1 = mean1.reshape((1,) * tensor.ndim)
            var2, mean2 = torch.var_mean(tensor, **kwargs)

            self.assertEqual(var1, var2)
            self.assertEqual(mean1, mean2)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_warn_invalid_degrees_of_freedom(self, device, dtype):
        def _assert_warning(_func, _tensor, _correction):
            with warnings.catch_warnings(record=True) as w:
                _func(_tensor, dim=-1, correction=_correction)
            self.assertIn('degrees of freedom is <= 0', str(w[0].message))

        correction = 20
        size = (10, correction)
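        # Reducing over the last dim (correction elements) leaves
        # N - correction == 0 degrees of freedom, which should warn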
        tensor = make_tensor(size, dtype=dtype, device=device)
        for f in [torch.std, torch.var, torch.var_mean, torch.std_mean]:
            _assert_warning(f, tensor, correction)

    def test_amin_amax_some_dims(self, device):
        sizes = (4, 6, 7, 5, 3)
        dims = len(sizes)
        x = torch.rand(sizes, device=device)
        for num_of_dims in range(2, dims):
            dim_list = list(combinations(list(range(dims)), r=num_of_dims))
            for dim in dim_list:
                for keepdim in [False, True]:
                    amin1 = torch.amin(x, dim=dim, keepdim=keepdim)
                    amax1 = torch.amax(x, dim=dim, keepdim=keepdim)
                    amin2 = x
                    amax2 = x
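                    # Reducing several dims at once should match reducing one
                    # dim at a time; without keepdim each reduction removes a
                    # dim, shifting the remaining dim indices down (d -= i)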
                    for i, d in enumerate(dim):
                        if not keepdim:
                            d -= i
                        amin2 = torch.amin(amin2, dim=d, keepdim=keepdim)
                        amax2 = torch.amax(amax2, dim=d, keepdim=keepdim)
                    self.assertEqual(amin1, amin2)
                    self.assertEqual(amax1, amax2)

    def test_histc(self, device):
        # negative nbins throws
        with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
            torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1)
        # empty tensor
        actual = torch.histc(torch.tensor([], device=device), min=0, max=3)
        expected = torch.zeros(100, dtype=torch.float, device=device)
        self.assertEqual(expected, actual)

        # without nbins
        actual = torch.histc(
            torch.tensor([2, 5], dtype=torch.float, device=device))
        expected = torch.zeros(100, dtype=torch.float, device=device)
        expected[0] = 1
        expected[99] = 1
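        # histc defaults to bins=100 and, when min == max == 0, takes the range
        # from the data ([2, 5] here), so 2 lands in the first bin and the
        # maximum value 5 is counted in the last bin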
        self.assertEqual(expected, actual)
        # tensor with the same element
        actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5)
        self.assertEqual(
            torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device),
            actual)
        # no element falls between [min, max]
        actual = torch.histc(
            torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3)
        self.assertEqual(
            torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device),
            actual)
        # element falls below min + integral bin size
        actual = torch.histc(
            torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device),
            bins=5, min=1, max=5)
        self.assertEqual(
            torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device),
            actual)
        # non-integral bin size
        actual = torch.histc(
            torch.tensor([1, 2, 1], dtype=torch.float, device=device),
            bins=4, min=0, max=3)
        self.assertEqual(
            torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
            actual)
        # double input
        actual = torch.histc(
            torch.tensor([1, 2, 1], dtype=torch.double, device=device), bins=4, min=0, max=3)
        self.assertEqual(
            torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device),
            actual)
        self.assertEqual(actual.dtype, torch.double)
        # mixed input
        actual = torch.histc(
            torch.tensor([1., 2, 1], dtype=torch.float, device=device),
            bins=4, min=0, max=3)
        self.assertEqual(
            torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
            actual)
        self.assertEqual(actual.dtype, torch.float)
        # scalar input and 1 bin -- should return a 1-dimensional tensor, not a scalar.
        actual = torch.histc(
            torch.tensor(0, dtype=torch.float, device=device),
            bins=1, min=0, max=3)
        self.assertEqual(
            torch.tensor([1], dtype=torch.float, device=device),
            actual)
        # tensors with inf; min, max not provided -- should throw a RuntimeError
        with self.assertRaisesRegex(RuntimeError, r'range of \[inf, inf\] is not finite'):
            torch.histc(torch.tensor([float("inf")], dtype=torch.float, device=device))
        with self.assertRaisesRegex(RuntimeError, r'range of \[1, inf\] is not finite'):
            torch.histc(torch.tensor([1., 2., float("inf")], dtype=torch.float, device=device))
        # tensors with inf; min, max provided
        self.assertEqual(
            torch.histc(torch.tensor([float("inf")], dtype=torch.float, device=device),
                        bins=1, min=0, max=3),
            torch.tensor([0], dtype=torch.float, device=device))
        self.assertEqual(
            torch.histc(torch.tensor([1., 2., float("inf")], dtype=torch.float, device=device),
                        bins=4, max=3),
            torch.tensor([0, 1, 1, 0], dtype=torch.float, device=device))
        # tensor with nan; min, max not provided -- should throw a RuntimeError
        with self.assertRaisesRegex(RuntimeError, r'range of \[nan, nan\] is not finite'):
            torch.histc(torch.tensor([float("nan")], dtype=torch.float, device=device))
        # tensor with nan; min, max provided -- nan is ignored
        self.assertEqual(
            torch.histc(torch.tensor([1., 2., float("nan")], dtype=torch.float, device=device),
                        bins=4, max=3),
            torch.tensor([0, 1, 1, 0], dtype=torch.float, device=device))
        # tensors with min > max -- should throw a RuntimeError
        with self.assertRaisesRegex(RuntimeError, "max must be larger than min"):
            torch.histc(torch.tensor([1., 2., 3.], dtype=torch.float, device=device),
                        bins=4, min=5, max=1)

        # test against numpy.histogram()
        def test_against_np(tensor, bins=100, min=0, max=0):
            if min == 0 and max == 0:
                min = tensor.min().item()
                max = tensor.max().item()
            nparr = tensor.cpu().numpy()
            actual = torch.histc(tensor, bins=bins, min=min, max=max)
            expected = torch.from_numpy(np.histogram(nparr, bins=bins, range=(min, max))[0])
            actual_cpu = actual.cpu()
            # NB: NumPy returns an int64 tensor, like normal people...
            self.assertEqual(actual, expected.to(actual_cpu))

        test_against_np(torch.tensor([1., 2, 1], device=device))
        test_against_np(torch.randn(5000, device=device))

        # Test bins arg
        test_against_np(torch.randn(301, device=device), bins=10)

        # Test truncated range
        test_against_np(torch.randn(201, device=device), min=0.1, max=1)

        noncontig = torch.randn(100, 3, device=device)[:, 2]
        test_against_np(noncontig)

        multidim = torch.randn(3, 5, 7, 2, device=device)
        test_against_np(multidim)

        expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2)
        test_against_np(expanded)

        linear = torch.linspace(0, 0.99 - 5.0e-7, 101).to(device)
        test_against_np(linear, bins=20, min=0, max=0.99)

    @onlyCPU
    @dtypes(torch.bfloat16, torch.half)
    def test_histc_lowp(self, device, dtype):
        actual = torch.histc(
            torch.tensor([1, 2, 1], dtype=dtype, device=device), bins=4, min=0, max=3)
        self.assertEqual(
            torch.tensor([0, 2, 1, 0], dtype=dtype, device=device),
            actual)
        self.assertEqual(actual.dtype, dtype)

    """
    Runs torch.histogram and numpy.histogram on the specified input parameters
    and asserts that their output is equal.
    """
    def _test_histogram_numpy(self, t, bins, bin_range, weights, density):
        def to_np(t):
            if not torch.is_tensor(t):
                return t
            else:
                return t.cpu().numpy()

        # Wrapper around numpy.histogram performing conversions between torch tensors and numpy arrays.
        def reference_histogram(self, t, bins, bin_range, weights, density, dtype):
            (np_t, np_bins, np_weights) = map(to_np, [t, bins, weights])
            (np_hist, np_bin_edges) = np.histogram(np_t, np_bins, range=bin_range, weights=np_weights, density=density)
            return (torch.from_numpy(np_hist).to(dtype), torch.from_numpy(np_bin_edges).to(dtype))

        # Doesn't pass a 'range' kwarg unless necessary because the override of histogram with Tensor bins doesn't accept one
        if bin_range:
            (actual_hist, actual_bin_edges) = torch.histogram(t, bins, range=bin_range, weight=weights, density=density)
        else:
            (actual_hist, actual_bin_edges) = torch.histogram(t, bins, weight=weights, density=density)

        (expected_hist, expected_bin_edges) = reference_histogram(self, t, bins, bin_range, weights, density, actual_hist.dtype)

        """
        Works around linspace discrepancies by passing torch's constructed bin_edges to numpy.
        When bin edges are not explicitly defined, histogram uses the linspace operator internally
        to construct the sequence of bin edges. In some cases, torch.linspace output differs slightly
        from numpy.linspace output.
        Issue: https://github.com/pytorch/pytorch/issues/58758
        """
        if not torch.is_tensor(bins):
            self.assertEqual(actual_bin_edges, expected_bin_edges, atol=1e-5, rtol=1e-5)
            # Calls numpy.histogram again, passing torch's actual_bin_edges as the bins argument
            (expected_hist, expected_bin_edges) = reference_histogram(
                self, t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype)

        self.assertEqual(actual_hist, expected_hist)
        self.assertEqual(actual_bin_edges, expected_bin_edges)

        # Test passing non-contiguous output tensors
        hist_out = make_tensor(expected_hist.shape, device=expected_hist.device, dtype=expected_hist.dtype,
                               noncontiguous=True)
        bin_edges_out = make_tensor(expected_bin_edges.shape, device=expected_bin_edges.device, dtype=expected_bin_edges.dtype,
                                    noncontiguous=True)

        # Doesn't pass a 'range' kwarg unless necessary because the override of histogram with Tensor bins doesn't accept one
        if bin_range:
            torch.histogram(t, bins, range=bin_range, weight=weights, density=density, out=(hist_out, bin_edges_out))
        else:
            torch.histogram(t, bins, weight=weights, density=density, out=(hist_out, bin_edges_out))

        self.assertEqual(hist_out, expected_hist)
        self.assertEqual(bin_edges_out, expected_bin_edges)
3135*da0073e9SAndroid Build Coastguard Worker
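    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the linspace workaround above in miniature. When only a
    # bin count is given, torch builds its own bin edges, and handing those
    # edges back to numpy sidesteps the small torch.linspace/numpy.linspace
    # differences tracked in issue 58758.
    def _sketch_linspace_workaround(self):
        t = torch.rand(100)
        hist, edges = torch.histogram(t, 4)
        np_hist, _ = np.histogram(t.numpy(), bins=edges.numpy())
        assert np.allclose(hist.numpy(), np_hist)
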
    @onlyCPU
    @dtypes(torch.float32)
    def test_histogram(self, device, dtype):
        shapes = (
            (),
            (0,),
            (1,),
            (1, 5),
            (3, 5),
            (1, 5, 1),
            (2, 3, 5))

        for contig, bins_contig, bin_ct, weighted, density, shape in \
                product([True, False], [True, False], range(1, 10), [True, False], [True, False], shapes):
            values = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9, noncontiguous=not contig)
            weights = make_tensor(shape, dtype=dtype, device=device, low=0, high=9, noncontiguous=not contig) if weighted else None

            # Tests passing just the bin_ct
            self._test_histogram_numpy(values, bin_ct, None, weights, density)

            # Tests with caller-specified histogram range
            bin_range = sorted((random.uniform(-9, 9), random.uniform(-9, 9)))
            self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)

            # Tests with range min=max
            bin_range[1] = bin_range[0]
            self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)

            # Tests with caller-specified bin edges
            bin_edges = make_tensor(bin_ct + 1, dtype=dtype, device=device, low=-9, high=9).msort()
            if not bins_contig:
                # Necessary because msort always produces contiguous output
                bin_edges_noncontig = make_tensor(bin_ct + 1, dtype=dtype, device=device, noncontiguous=not bins_contig)
                bin_edges_noncontig.copy_(bin_edges)
                bin_edges = bin_edges_noncontig
            self.assertEqual(bin_edges.is_contiguous(), bins_contig)
            self._test_histogram_numpy(values, bin_edges, None, weights, density)

            # Tests with input tensor in which all elements are equal
            elt = random.uniform(-9, 9)
            values = make_tensor(shape, dtype=dtype, device=device, low=elt, high=elt, noncontiguous=not contig)
            self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)
            self._test_histogram_numpy(values, bin_edges, None, weights, density)

            # Tests with input equal to bin_edges
            weights = (
                make_tensor(bin_ct + 1, dtype=dtype, device=device, low=0, high=9, noncontiguous=not contig)
                if weighted
                else None
            )
            self._test_histogram_numpy(bin_edges, bin_edges, None, weights, density)

        # Tests values of default args
        for bin_ct, shape in product(range(1, 10), shapes):
            values = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
            (actual_hist, actual_bin_edges) = torch.histogram(values, bin_ct)
            (expected_hist, expected_bin_edges) = torch.histogram(
                values, bin_ct, range=None, weight=None, density=False)
            self.assertEqual(actual_hist, expected_hist)
            self.assertEqual(actual_bin_edges, expected_bin_edges)

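    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the density=True behavior exercised above. torch.histogram
    # normalizes the counts so that the histogram integrates to one over the
    # bin widths.
    def _sketch_histogram_density(self):
        t = torch.rand(1000)
        hist, edges = torch.histogram(t, 10, density=True)
        widths = edges[1:] - edges[:-1]
        assert torch.allclose((hist * widths).sum(), torch.tensor(1.0))
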
    """
    Runs torch.histogramdd and numpy.histogramdd on the specified input parameters
    and asserts that their output is equal.
    """
    def _test_histogramdd_numpy(self, t, bins, bin_range, weights, density):
        def to_np(t):
            if isinstance(t, list):
                return list(map(to_np, t))
            if not torch.is_tensor(t):
                return t
            return t.cpu().numpy()

        # Wrapper around numpy.histogramdd performing conversions between torch tensors and numpy arrays.
        def reference_histogramdd(t, bins, bin_range, weights, density, dtype):
            (np_t, np_bins, np_weights) = map(to_np, [t, bins, weights])

            # numpy.histogramdd accepts only (N, D) shapes
            D = np_t.shape[-1]
            N = np.prod(np_t.shape[:-1])
            reshaped_t = np.reshape(np_t, (N, D))
            reshaped_wt = np.reshape(np_weights, (N,)) if np_weights is not None else None

            # numpy.histogramdd throws an error for D=0
            if D == 0:
                return (torch.tensor(float('nan') if density else 0.), [])

            # numpy.histogramdd expects range to be specified as a sequence of D (lower, upper) tuples
            reshaped_range = None if not bin_range else [(bin_range[2 * i], bin_range[2 * i + 1]) for i in range(D)]

            (np_hist, np_bin_edges) = np.histogramdd(reshaped_t, np_bins,
                                                     range=reshaped_range, weights=reshaped_wt, density=density)

            return (torch.from_numpy(np_hist).to(dtype), [torch.from_numpy(t).to(dtype) for t in np_bin_edges])

        (actual_hist, actual_bin_edges) = torch.histogramdd(t, bins, range=bin_range, weight=weights, density=density)
        (expected_hist, expected_bin_edges) = reference_histogramdd(t, bins, bin_range, weights, density, actual_hist.dtype)

        D = len(actual_bin_edges)
        self.assertEqual(D, len(expected_bin_edges))

        """
        Works around linspace discrepancies by passing torch's constructed bin_edges to numpy.
        When bin edges are not explicitly defined, histogram uses the linspace operator internally
        to construct the sequence of bin edges. In some cases, torch.linspace output differs slightly
        from numpy.linspace output.
        Issue: https://github.com/pytorch/pytorch/issues/58758
        """
        if not torch.is_tensor(bins):
            for dim in range(D):
                self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim], atol=1e-5, rtol=1e-5)
            # Calls numpy.histogramdd again, passing torch's actual_bin_edges as the bins argument
            (expected_hist, expected_bin_edges) = reference_histogramdd(
                t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype)
            self.assertEqual(D, len(expected_bin_edges))

        self.assertEqual(actual_hist, expected_hist)
        for dim in range(D):
            self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim])

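    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the range conversion performed by the reference wrapper
    # above. torch.histogramdd takes one flat [lo0, hi0, lo1, hi1, ...]
    # sequence, while numpy.histogramdd wants a list of D (lower, upper) pairs.
    def _sketch_histogramdd_range_parity(self):
        t = torch.rand(20, 3)
        flat_range = [0., 1., 0., 1., 0., 1.]
        hist, _ = torch.histogramdd(t, [2, 2, 2], range=flat_range)
        pairs = [(flat_range[2 * i], flat_range[2 * i + 1]) for i in range(3)]
        np_hist, _ = np.histogramdd(t.numpy(), bins=[2, 2, 2], range=pairs)
        assert np.allclose(hist.numpy(), np_hist)
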
    @onlyCPU
    @dtypes(torch.float32)
    def test_histogramdd(self, device, dtype):
        shapes = (
            (1, 5),
            (3, 5),
            (1, 5, 1),
            (2, 3, 5),
            (7, 7, 7, 7),
            (16, 8, 4, 2),
            (10, 10, 10),
            (7, 0, 3),
            (5, 0),)

        for contig, bins_contig, weighted, density, shape in \
                product([True, False], [True, False], [True, False], [True, False], shapes):
            D = shape[-1]

            values = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9, noncontiguous=not contig)
            weights = (
                make_tensor(shape[:-1], dtype=dtype, device=device, low=0, high=9, noncontiguous=not contig)
                if weighted
                else None
            )

            # Tests passing a single bin count
            bin_ct = random.randint(1, 5)
            self._test_histogramdd_numpy(values, bin_ct, None, weights, density)

            # Tests passing a bin count for each dimension
            bin_ct = [random.randint(1, 5) for dim in range(D)]
            self._test_histogramdd_numpy(values, bin_ct, None, weights, density)

            # Tests with caller-specified histogram range
            bin_range_tuples = [sorted((random.uniform(-9, 9), random.uniform(-9, 9))) for dim in range(D)]
            bin_range = [elt for t in bin_range_tuples for elt in t]
            self._test_histogramdd_numpy(values, bin_ct, bin_range, weights, density)

            # Tests with range min=max
            for dim in range(D):
                bin_range[2 * dim + 1] = bin_range[2 * dim]
            self._test_histogramdd_numpy(values, bin_ct, bin_range, weights, density)

            # Tests with caller-specified bin edges
            bin_edges = [make_tensor(ct + 1, dtype=dtype, device=device, low=-9, high=9).msort() for ct in bin_ct]
            if not bins_contig:
                # Necessary because msort always produces contiguous output
                bin_edges_noncontig = [
                    make_tensor(ct + 1, dtype=dtype, device=device, noncontiguous=not bins_contig)
                    for ct in bin_ct
                ]
                for dim in range(D):
                    bin_edges_noncontig[dim].copy_(bin_edges[dim])
                bin_edges = bin_edges_noncontig
            for dim in range(D):
                self.assertEqual(bin_edges[dim].is_contiguous(), bins_contig)
            self._test_histogramdd_numpy(values, bin_edges, None, weights, density)

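    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the three bin specifications exercised above -- a single
    # count shared by all dimensions, one count per dimension, and explicit
    # per-dimension bin-edge tensors.
    def _sketch_histogramdd_bin_forms(self):
        t = torch.rand(50, 2)
        torch.histogramdd(t, 3)
        torch.histogramdd(t, [3, 4])
        edges = [torch.linspace(0, 1, 4), torch.linspace(0, 1, 5)]
        torch.histogramdd(t, edges)
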
    @onlyCPU
    @dtypes(torch.float32)
    def test_histogram_error_handling(self, device, dtype):
        with self.assertRaisesRegex(RuntimeError, 'not implemented for'):
            values = make_tensor((), dtype=torch.int32, device=device)
            torch.histogram(values, 1)

        inconsistent_dtype = torch.float32 if dtype != torch.float32 else torch.float64

        with self.assertRaisesRegex(RuntimeError, 'input tensor and bins tensors should have the same dtype'):
            values = make_tensor((), dtype=dtype, device=device)
            bins = make_tensor((), dtype=inconsistent_dtype, device=device)
            torch.histogram(values, bins)

        with self.assertRaisesRegex(RuntimeError, 'input tensor and weight tensor should have the same dtype'):
            values = make_tensor((), dtype=dtype, device=device)
            weight = make_tensor((), dtype=inconsistent_dtype, device=device)
            torch.histogram(values, 1, weight=weight)

        with self.assertRaisesRegex(RuntimeError, 'input tensor and hist tensor should have the same dtype'):
            values = make_tensor((), dtype=dtype, device=device)
            hist = make_tensor((), dtype=inconsistent_dtype, device=device)
            bin_edges = make_tensor((), dtype=dtype, device=device)
            torch.histogram(values, 1, out=(hist, bin_edges))

        with self.assertRaisesRegex(RuntimeError, 'input tensor and bin_edges tensor should have the same dtype'):
            values = make_tensor((), dtype=dtype, device=device)
            hist = make_tensor((), dtype=dtype, device=device)
            bin_edges = make_tensor((), dtype=inconsistent_dtype, device=device)
            torch.histogram(values, 1, out=(hist, bin_edges))

        with self.assertRaisesRegex(RuntimeError, 'bins tensor should have one dimension'):
            t = make_tensor((2, 2), dtype=dtype, device=device)
            torch.histogram(t, t)

        with self.assertRaisesRegex(RuntimeError, 'bins tensor should have at least 1 element'):
            t = make_tensor((0,), dtype=dtype, device=device)
            torch.histogram(t, t)

        with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
            values = make_tensor((), dtype=dtype, device=device)
            torch.histogram(values, -1)

        with self.assertRaisesRegex(RuntimeError, 'if weight tensor is provided it should have the same shape \
as the input tensor excluding its innermost dimension'):
            values = make_tensor((2, 2), dtype=dtype, device=device)
            weight = make_tensor((1,), dtype=dtype, device=device)
            torch.histogram(values, 1, weight=weight)

        with self.assertRaisesRegex(TypeError, 'received an invalid combination of arguments'):
            values = make_tensor((), dtype=dtype, device=device)
            bin_edges = make_tensor((), dtype=dtype, device=device)
            torch.histogram(values, bin_edges, range=(0, 1))

        with self.assertRaisesRegex(RuntimeError, 'min should not exceed max'):
            values = make_tensor((), dtype=dtype, device=device)
            torch.histogram(values, 2, range=(1, 0))

        with self.assertRaisesRegex(RuntimeError, r'range \[nan, nan\] is not finite'):
            values = torch.tensor([float("nan")], device=device, dtype=dtype)
            torch.histogram(values, 2)

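    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the dtype-consistency rule exercised above -- the input,
    # bins tensor, weight, and out tensors must all share one dtype.
    def _sketch_histogram_dtype_rule(self):
        values = torch.rand(10, dtype=torch.float32)
        edges = torch.linspace(0, 1, 5, dtype=torch.float64)
        try:
            torch.histogram(values, edges)
        except RuntimeError:
            pass  # mixed float32/float64 is rejected
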
    # Tests to ensure that reduction functions employing comparison operators are usable when the
    # tensor contains a zero-sized dimension (i.e. when the tensor is empty). These tests specifically
    # cater to functions where specifying the `dim` parameter is necessary.
    def test_tensor_compare_ops_empty(self, device):
        shape = (2, 0, 4)
        master_input = torch.randn(shape, device=device)
        np_input = np.empty(shape)
        test_functions = [
            ('amax', torch.amax, np.amax),
            ('amin', torch.amin, np.amin),
            ('max', lambda *args, **kwargs: torch.max(*args, **kwargs).values, np.max),
            ('min', lambda *args, **kwargs: torch.min(*args, **kwargs).values, np.min),
            ('median', lambda *args, **kwargs: torch.median(*args, **kwargs).values, np.median),
        ]

        for name, fn, np_function in test_functions:
            # Check if reduction happens along the specified dim with and without keepdim. Check with
            # numpy to maintain compatibility with numpy functions.
            error_msg = f"test function: {name}"
            self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=2), msg=error_msg)
            self.assertEqual(np_function(np_input, axis=2),
                             fn(master_input, dim=2).cpu().numpy(), msg=error_msg, exact_dtype=False)

            self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=-1), msg=error_msg)
            self.assertEqual(np_function(np_input, axis=-1),
                             fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, exact_dtype=False)

            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True),
                             msg=error_msg)
            self.assertEqual(np_function(np_input, axis=2, keepdims=True),
                             fn(master_input, dim=2, keepdim=True).cpu().numpy(), msg=error_msg, exact_dtype=False)

            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True),
                             msg=error_msg)
            self.assertEqual(np_function(np_input, axis=-1, keepdims=True),
                             fn(master_input, dim=-1, keepdim=True).cpu().numpy(), msg=error_msg, exact_dtype=False)

            # Check that the function raises an error when the zero-sized dimension is the reduction dim.
            self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input, dim=1))

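    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): comparison reductions over a zero-sized dim produce an
    # empty result, but reducing *along* the zero-sized dim raises, since
    # max/min have no identity element to return for an empty slice.
    def _sketch_empty_compare_reduction(self):
        x = torch.randn(2, 0, 4)
        assert torch.amax(x, dim=2).shape == (2, 0)
        try:
            torch.amax(x, dim=1)  # dim 1 has size zero
        except IndexError:
            pass
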
    # Tests to ensure that reductions over tensors with a zero-sized dimension (i.e. empty tensors)
    # using comparison operators raise an error if no `dim` parameter is specified. This exists
    # separately from tests in test_tensor_compare_ops_empty because not specifying a `dim` parameter
    # in the former tests does not throw errors. Also, checking the return type of argmax requires
    # supplying a different dtype argument than that of the input tensor. There is also variation in
    # the numpy testing.
    def test_tensor_compare_ops_argmax_argmin_kthvalue_dim_empty(self, device):
        shape = (2, 0, 4)
        master_input = torch.randn(shape, device=device)
        np_input = np.empty(shape)
        test_functions = [
            ('argmax', torch.argmax, {'dtype': torch.int64}, np.argmax),
            ('argmin', torch.argmin, {'dtype': torch.int64}, np.argmin),
            ('kthvalue', lambda *args, k=1, **kwargs: torch.kthvalue(*args, k=1, **kwargs).values,
             {}, lambda *args, k=1, axis=None, **kwargs: np.partition(*args, k, **kwargs).take(k - 1, axis=axis))
        ]

        for name, fn, dtype, np_function in test_functions:
            error_msg = f"test function: {name}"
            self.assertEqual(torch.empty((2, 0), device=device, **dtype), fn(master_input, dim=2), msg=error_msg)
            self.assertEqual(
                np_function(np_input, axis=2), fn(master_input, dim=2).cpu().numpy(), msg=error_msg, exact_dtype=False
            )

            self.assertEqual(torch.empty((2, 0), device=device, **dtype), fn(master_input, dim=-1), msg=error_msg)
            self.assertEqual(
                np_function(np_input, axis=-1), fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, exact_dtype=False
            )

            # keepdim variant does not exist for numpy
            self.assertEqual(torch.empty((2, 0, 1), device=device, **dtype), fn(master_input, dim=2, keepdim=True),
                             msg=error_msg)
            self.assertEqual(torch.empty((2, 0, 1), device=device, **dtype), fn(master_input, dim=-1, keepdim=True),
                             msg=error_msg)

            # Check that the function raises an error when the zero-sized dimension is the reduction dim.
            self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input, dim=1))
            if name != 'kthvalue':
                self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input))

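    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the numpy reference for kthvalue used above. np.partition
    # places the k-th smallest element at index k - 1 along the chosen axis.
    def _sketch_kthvalue_reference(self):
        x = torch.randn(3, 5)
        k = 2
        torch_vals = torch.kthvalue(x, k=k, dim=1).values
        np_vals = np.partition(x.numpy(), k - 1, axis=1).take(k - 1, axis=1)
        assert np.allclose(torch_vals.numpy(), np_vals)
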
    # Tests to ensure that reductions over tensors with a zero-sized dimension (i.e. empty tensors)
    # using math operators work when a dim is specified, returning the operator's identity element
    # when the reduction is taken along the zero-sized dim itself. Although there is some repetition
    # with test_tensor_compare_ops_optional_dim_empty and test_tensor_compare_ops_empty, these tests
    # are kept separate since tests for math operators also require checking the correctness of the
    # returned data using allclose() or isinf(), which does not exist in the former tests.
    @skipIfNoSciPy
    def test_tensor_reduce_ops_empty(self, device):
        from scipy.special import logsumexp
        shape = (2, 0, 4)
        master_input = torch.randn(shape, device=device)
        np_input = np.empty(shape)
        test_functions = [
            ('prod', torch.prod, 1., np.prod),
            ('sum', torch.sum, 0., np.sum),
            ('norm', torch.norm, 0., np.linalg.norm),
            ('mean', torch.mean, nan, np.mean),
            ('var', torch.var, nan, np.var),
            ('std', torch.std, nan, np.std),
            ('logsumexp', torch.logsumexp, -inf, logsumexp),
        ]

        for name, fn, return_value, np_function in test_functions:
            # Check if reduction happens along the specified dimension.
            error_msg = f"test function: {name}"
            self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=2), msg=error_msg)
            self.assertEqual(np_function(np_input, axis=2), fn(master_input, dim=2).cpu().numpy(), msg=error_msg,
                             exact_dtype=False)

            self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=-1), msg=error_msg)
            self.assertEqual(np_function(np_input, axis=-1), fn(master_input, dim=-1).cpu().numpy(), msg=error_msg,
                             exact_dtype=False)

            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True),
                             msg=error_msg)
            self.assertEqual(np_function(np_input, axis=2, keepdims=True), fn(master_input, dim=2, keepdim=True),
                             msg=error_msg, exact_dtype=False)

            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True),
                             msg=error_msg)
            self.assertEqual(np_function(np_input, axis=-1, keepdims=True), fn(master_input, dim=-1, keepdim=True),
                             msg=error_msg, exact_dtype=False)

            self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=1), msg=error_msg)
            self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=-2), msg=error_msg)
            self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=1, keepdim=True),
                             msg=error_msg)
            self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=-2, keepdim=True),
                             msg=error_msg)

            if name != 'logsumexp':
                # The scipy function does not work for reduction over the zero-sized dimension
                self.assertEqual(np.float32(np_function(np_input, axis=1)), fn(master_input, dim=1).cpu().numpy(),
                                 msg=error_msg)
                self.assertEqual(np.float32(np_function(np_input, axis=-2)), fn(master_input, dim=-2).cpu().numpy(),
                                 msg=error_msg)
                self.assertEqual(np.float32(np_function(np_input, axis=1, keepdims=True)),
                                 fn(master_input, dim=1, keepdim=True).cpu().numpy(),
                                 msg=error_msg)
                self.assertEqual(np.float32(np_function(np_input, axis=-2, keepdims=True)),
                                 fn(master_input, dim=-2, keepdim=True).cpu().numpy(),
                                 msg=error_msg)

                # logsumexp throws a type error when dim is not specified, so it is tested separately.
                self.assertEqual(torch.full((), return_value, device=device), fn(master_input), msg=error_msg)
            else:
                self.assertRaises(TypeError, lambda: fn(master_input))

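    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): reducing along a zero-sized dim with a math op returns
    # the op's identity element rather than raising -- an empty product is 1,
    # an empty sum is 0, and an empty mean is nan.
    def _sketch_empty_math_reduction(self):
        x = torch.randn(2, 0, 4)
        assert torch.equal(torch.prod(x, dim=1), torch.ones(2, 4))
        assert torch.equal(torch.sum(x, dim=1), torch.zeros(2, 4))
        assert torch.isnan(torch.mean(x, dim=1)).all()
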
    # Tests to ensure that any() and all() functions work with tensors containing a zero-sized
    # dimension. Kept separate from the other empty-tensor reduction tests because these tests
    # have significantly different testing behaviour than that used for the former tests.
    def test_reduction_empty_any_all(self, device):
        shape = (2, 0, 4)
        x = torch.randn(shape, device=device)

        for dtype in all_types_and_complex_and(torch.half, torch.bool):
            # Refer: [all, any uint8 compatibility]
            if dtype == torch.uint8:
                out_dtype = torch.uint8
            else:
                out_dtype = torch.bool  # output of all/any is bool irrespective of input dtype

            xb = x.to(dtype)
            yb = x.to(dtype)
            # any
            self.assertEqual((2, 0), xb.any(2).shape)
            self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape)
            self.assertEqual(torch.zeros((2, 4), device=device, dtype=out_dtype), xb.any(1))
            self.assertEqual(torch.zeros((2, 1, 4), device=device, dtype=out_dtype), xb.any(1, keepdim=True))
            self.assertEqual(torch.zeros((), device=device, dtype=out_dtype), xb.any())

            # all
            self.assertEqual((2, 0), xb.all(2).shape)
            self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape)
            self.assertEqual(torch.ones((2, 4), device=device, dtype=out_dtype), xb.all(1))
            self.assertEqual(torch.ones((2, 1, 4), device=device, dtype=out_dtype), xb.all(1, keepdim=True))
            self.assertEqual(torch.ones((), device=device, dtype=out_dtype), xb.all())

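    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): any() over an empty slice is False and all() over an
    # empty slice is vacuously True, with bool outputs for every input dtype
    # except uint8, which keeps uint8 for backward compatibility.
    def _sketch_empty_any_all(self):
        x = torch.zeros(2, 0, 4)
        assert not x.any()
        assert x.all()
        assert x.any(1).dtype is torch.bool
        assert x.to(torch.uint8).any(1).dtype is torch.uint8
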
    # TODO: can these be merged with their respective OpInfos?
    def test_reduce_dtype(self, device):
        def test_reduction(op, has_no_dim, takes_dtype=True):
            x = torch.randn(3, 3, dtype=torch.float, requires_grad=True, device=device)

            if has_no_dim:
                grad1, = torch.autograd.grad([op(x)], [x])
                grad2, = torch.autograd.grad([op(x, dtype=torch.double)], [x])
                self.assertEqual(grad1, grad2)
                self.assertEqual(grad2.dtype, torch.float)

            gi = torch.randn(op(x, dim=0).shape, dtype=torch.float, device=device)
            grad1, = torch.autograd.grad([op(x, dim=0)], [x], gi)
            if takes_dtype:
                grad2, = torch.autograd.grad([op(x, dim=0, dtype=torch.double)], [x], gi.double())
            else:
                grad2, = torch.autograd.grad([op(x.double(), dim=0)], [x], gi.double())
            self.assertEqual(grad1, grad2)
            self.assertEqual(grad2.dtype, torch.float)

        test_reduction(torch.sum, True)
        test_reduction(torch.prod, True)
        test_reduction(torch.cumsum, False)
        test_reduction(torch.cumprod, False)
        test_reduction(torch.logcumsumexp, False, takes_dtype=False)

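    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): passing dtype= to a reduction upcasts the computation,
    # but autograd still returns the gradient in the input's dtype, which is
    # the invariant test_reduce_dtype checks above.
    def _sketch_reduce_dtype_grad(self):
        x = torch.randn(3, 3, requires_grad=True)
        y = torch.sum(x, dtype=torch.double)
        assert y.dtype is torch.double
        grad, = torch.autograd.grad([y], [x])
        assert grad.dtype is torch.float
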
    @ops(reference_masked_ops)
    def test_reference_masked(self, device, dtype, op):
        """Test masked reduction operations on strided-only tensors using
        numpy reductions as reference.
        """

        def to_numpy(input):
            if input.dtype is torch.bfloat16:
                return input.cpu().to(torch.float32).numpy()
            else:
                return input.cpu().numpy()

        samples = op.sample_inputs_func(op, device, dtype, requires_grad=False)
        for sample_input in samples:
            t = sample_input.input
            actual = op(t, *sample_input.args, **sample_input.kwargs)
            exact_dtype = not (t.dtype is torch.bfloat16
                               or (op.promotes_int_to_float and not torch.is_floating_point(t)))
            expected = op.ref(to_numpy(t), *sample_input.args,
                              **dict(
                                  # `identity` is mapped to the numpy reduction's `initial` argument
                                  identity=torch.masked._reduction_identity(op.name, t),
                                  **sample_input.kwargs))

            # Workaround for https://github.com/pytorch/pytorch/issues/66556
            expected = np.asarray(expected)  # transform numpy scalars to numpy.ndarray instances

            # Numpy differs, producing uint32 on Windows
            if expected.dtype in [np.uint64, np.uint32]:
                exact_dtype = False

            msg = ("Failed to produce expected results! Input tensor was"
                   f" {t}, torch result is {actual}, and reference result is"
                   f" {expected}.") if t.numel() < 10 else None

            self.assertEqual(actual, expected, msg, exact_dtype=exact_dtype)

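    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the identity-to-`initial` mapping used above. numpy's
    # `initial` argument seeds a reduction with the operator's identity so
    # that empty (fully masked) slices stay well-defined.
    def _sketch_identity_as_initial(self):
        empty = np.array([], dtype=np.float32)
        assert np.sum(empty, initial=0.0) == 0.0
        assert np.max(empty, initial=-np.inf) == -np.inf
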
    @onlyCUDA
    @largeTensorTest("8GB")
    @dtypes(torch.half, torch.chalf, torch.bfloat16)
    def test_reductions_large_half_tensors(self, device, dtype):
        t = torch.ones(2**31, device=device, dtype=dtype)
        t[2**30:] = -1
        expected = torch.tensor(0, device=device, dtype=dtype)
        self.assertEqual(torch.sum(t), expected)

        # mean_cuda is not implemented for ComplexHalf
        err_msg = "not implemented for 'ComplexHalf'"
        ctx = self.assertRaisesRegex(
            RuntimeError, err_msg) if dtype is torch.chalf else contextlib.nullcontext()
        with ctx:
            self.assertEqual(torch.mean(t), expected)

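    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the conditional-context pattern used above. Swapping in
    # contextlib.nullcontext() lets one code path expect an exception only
    # for certain inputs.
    def _sketch_conditional_raise(self, should_raise=True):
        ctx = self.assertRaises(RuntimeError) if should_raise else contextlib.nullcontext()
        with ctx:
            if should_raise:
                raise RuntimeError("expected failure")
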
instantiate_device_type_tests(TestReductions, globals())

if __name__ == '__main__':
    run_tests()