# Owner(s): ["module: tests"]

import contextlib
import torch
import numpy as np

import math
from typing import Dict, List, Sequence
import random
from functools import partial
from itertools import product, combinations, permutations
import warnings

from torch import inf, nan
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and, get_all_math_dtypes, integral_types, complex_types, floating_types_and,
    integral_types_and, floating_and_complex_types_and, all_types_and, all_types,
)
from torch.testing._internal.common_utils import (
    TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict,
    parametrize,
    IS_WINDOWS)
from torch.testing._internal.common_device_type import (
    OpDTypes, expectedFailureMeta, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU,
    onlyNativeDeviceTypes, onlyCUDA, largeTensorTest, ops, precisionOverride)
from torch.testing._internal.common_methods_invocations import (
    ReductionOpInfo, ReductionPythonRefInfo, reduction_ops, reference_masked_ops)

# TODO: replace with make_tensor
def _generate_input(shape, dtype, device, with_extremal):
    """Builds a randomized input tensor for the reduction tests.

    Args:
        shape: Dimension sizes of the tensor. The empty tuple ``()`` is a
            special case that returns ``torch.tensor(())``.
        dtype: torch dtype of the generated tensor.
        device: Device the tensor is allocated on.
        with_extremal: If true, floating point and complex tensors
            additionally get nan, inf and -inf written at randomly chosen
            positions. Ignored for bool and integral dtypes.

    Returns:
        The generated tensor.
    """
    if shape == ():
        x = torch.tensor((), dtype=dtype, device=device)
    elif dtype.is_floating_point or dtype.is_complex:
        # work around torch.randn not being implemented for bfloat16
        if dtype == torch.bfloat16:
            x = torch.randn(*shape, device=device) * random.randint(30, 100)
            x = x.to(torch.bfloat16)
        else:
            x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100)
        # Every mask below is a fresh randn sample thresholded at 0.5, so
        # each assignment overwrites an independent random subset of x.
        x[torch.randn(*shape) > 0.5] = 0
        if with_extremal and dtype.is_floating_point:
            # Use extremal values
            x[torch.randn(*shape) > 0.5] = float('nan')
            x[torch.randn(*shape) > 0.5] = float('inf')
            x[torch.randn(*shape) > 0.5] = float('-inf')
        elif with_extremal and dtype.is_complex:
            x[torch.randn(*shape) > 0.5] = complex('nan')
            x[torch.randn(*shape) > 0.5] = complex('inf')
            x[torch.randn(*shape) > 0.5] = complex('-inf')
    elif dtype == torch.bool:
        # Start from all-False and flip a random subset to True.
        x = torch.zeros(shape, dtype=dtype, device=device)
        x[torch.randn(*shape) > 0.5] = True
    else:
        # Remaining (integral) dtypes are drawn with torch.randint(15, 100, ...).
        x = torch.randint(15, 100, shape, dtype=dtype, device=device)

    return x

# TODO: replace with make_tensor
def _rand_shape(dim, min_size, max_size):
    """Returns a tuple of ``dim`` sizes, each drawn with random.randint(min_size, max_size)."""
    return tuple(random.randint(min_size, max_size) for _ in range(dim))

def _reduced_shape(shape, dim=None, keepdim=False):
    """Computes the expected reduced shape given dim and keepdim

    Args:
        shape: The shape to reduce
        dim : The dimensions to reduce. May be a single int, a sequence of
            ints, or None to reduce every dimension. Negative values are
            wrapped by adding ``len(shape)``.
        keepdim: If true, reduced dimensions have size 1 in the reduced shape,
            otherwise they are removed from the reduced shape.

    Returns:
        The reduced shape, as a list
    """
    if dim is None:
        return [1] * len(shape) if keepdim else []

    # Wrap negative dims
    dim = dim if isinstance(dim, Sequence) else [dim]
    dim = {i if i >= 0 else len(shape) + i for i in dim}

    result = []
    for i, size in enumerate(shape):
        if i not in dim:
            # Dimension is not reduced: its size is kept as is.
            result.append(size)
        elif keepdim:
            # Reduced dimension is kept with size 1.
            result.append(1)

    return result

class TestReductions(TestCase):

    ###########################################################################
    # ReductionOpInfo unit tests
    ###########################################################################

    def _test_dim_keepdim(self, op: ReductionOpInfo, device, *, ndim, **dim_keepdim):
        """Tests output shape for input with ndim and dim and keepdim kwargs

        Builds a float tensor with ``ndim`` dimensions whose sizes are drawn
        with torch.randint(2, 5, ...), calls ``op`` with the ``dim``/``keepdim``
        kwargs given in ``dim_keepdim`` and compares the result's shape with
        the one computed by ``_reduced_shape``.
        """
        shape = torch.randint(2, 5, (ndim,)).tolist()
        t = make_tensor(shape, dtype=torch.float, device=device)
        # Only the first set of extra args/kwargs the op generates is used.
        args, kwargs = next(op.generate_args_kwargs(t, **dim_keepdim))
        result = op(t, *args, **dim_keepdim, **kwargs)
        expected_shape = _reduced_shape(shape, **dim_keepdim)
        self.assertEqual(result.shape, expected_shape, f"""
        expected output shape to be {expected_shape} but got {list(result.shape)}
        for input shape {shape} and {dim_keepdim}
        """)

    # TODO(@heitorschueroff) combine cases with and without keepdim once
    # there's support for a @parametrize decorator.
115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops, dtypes=OpDTypes.none) 117*da0073e9SAndroid Build Coastguard Worker def test_dim_default(self, device, op: ReductionOpInfo): 118*da0073e9SAndroid Build Coastguard Worker """Tests that the default dim reduces all dimensions.""" 119*da0073e9SAndroid Build Coastguard Worker for ndim in range(3): 120*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=ndim) 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops, dtypes=OpDTypes.none) 123*da0073e9SAndroid Build Coastguard Worker def test_dim_default_keepdim(self, device, op: ReductionOpInfo): 124*da0073e9SAndroid Build Coastguard Worker """Tests that the default dim, when keepdim=True, reduces all dimensions to size 1.""" 125*da0073e9SAndroid Build Coastguard Worker for ndim in range(3): 126*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=ndim, keepdim=True) 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops, dtypes=OpDTypes.none) 129*da0073e9SAndroid Build Coastguard Worker def test_dim_none(self, device, op: ReductionOpInfo): 130*da0073e9SAndroid Build Coastguard Worker """Tests that dim=None reduces all dimensions.""" 131*da0073e9SAndroid Build Coastguard Worker for ndim in range(3): 132*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=ndim, dim=None) 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops, dtypes=OpDTypes.none) 135*da0073e9SAndroid Build Coastguard Worker def test_dim_none_keepdim(self, device, op: ReductionOpInfo): 136*da0073e9SAndroid Build Coastguard Worker """Tests that dim=None, when keepdim=True, reduces all dimensions to size 1.""" 137*da0073e9SAndroid Build Coastguard Worker for ndim in range(3): 138*da0073e9SAndroid Build 
Coastguard Worker self._test_dim_keepdim(op, device, ndim=ndim, dim=None, keepdim=True) 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops, dtypes=OpDTypes.none) 141*da0073e9SAndroid Build Coastguard Worker def test_dim_single(self, device, op: ReductionOpInfo): 142*da0073e9SAndroid Build Coastguard Worker """Tests that dim=i reduces dimension i.""" 143*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=0, dim=0) 144*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=1, dim=0) 145*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=2, dim=-1) 146*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=3, dim=1) 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops, dtypes=OpDTypes.none) 149*da0073e9SAndroid Build Coastguard Worker def test_dim_single_keepdim(self, device, op: ReductionOpInfo): 150*da0073e9SAndroid Build Coastguard Worker """Tests that dim=i, when keepdim=True, reduces dimension i to size 1.""" 151*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=0, dim=0, keepdim=True) 152*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=1, dim=0, keepdim=True) 153*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=2, dim=-1, keepdim=True) 154*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=3, dim=1, keepdim=True) 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) 157*da0073e9SAndroid Build Coastguard Worker def test_dim_empty(self, device, op: ReductionOpInfo): 158*da0073e9SAndroid Build Coastguard Worker """Tests that dim=[] is a no-op""" 159*da0073e9SAndroid Build Coastguard Worker 
self._test_dim_keepdim(op, device, ndim=0, dim=[]) 160*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=2, dim=[]) 161*da0073e9SAndroid Build Coastguard Worker 162*da0073e9SAndroid Build Coastguard Worker @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) 163*da0073e9SAndroid Build Coastguard Worker def test_dim_empty_keepdim(self, device, op: ReductionOpInfo): 164*da0073e9SAndroid Build Coastguard Worker """Tests that dim=[], when keepdim=True, is a no-op""" 165*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=0, dim=[], keepdim=True) 166*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=2, dim=[], keepdim=True) 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) 169*da0073e9SAndroid Build Coastguard Worker def test_dim_multi(self, device, op: ReductionOpInfo): 170*da0073e9SAndroid Build Coastguard Worker """Tests that dim=[i, j, ...] reduces dimensions i, j, ....""" 171*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=1, dim=[0]) 172*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2]) 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) 175*da0073e9SAndroid Build Coastguard Worker def test_dim_multi_keepdim(self, device, op: ReductionOpInfo): 176*da0073e9SAndroid Build Coastguard Worker """Tests that dim=[i, j, ...], when keepdim=True, reduces dimensions i, j, .... 
to size 1.""" 177*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=1, dim=[0], keepdim=True) 178*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2], keepdim=True) 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) 181*da0073e9SAndroid Build Coastguard Worker def test_dim_multi_unsorted(self, device, op: ReductionOpInfo): 182*da0073e9SAndroid Build Coastguard Worker """Tests that operator correctly handles unsorted dim list.""" 183*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2]) 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) 186*da0073e9SAndroid Build Coastguard Worker def test_dim_multi_unsorted_keepdim(self, device, op: ReductionOpInfo): 187*da0073e9SAndroid Build Coastguard Worker """Tests that operator correctly handles unsorted dim list when keepdim=True.""" 188*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2], keepdim=True) 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) 191*da0073e9SAndroid Build Coastguard Worker def test_dim_multi_duplicate(self, device, op: ReductionOpInfo): 192*da0073e9SAndroid Build Coastguard Worker """Tests that an error is raised if dim has duplicate entries.""" 193*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 194*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=3, dim=[0, 1, 1, 2]) 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker @ops(filter(lambda op: not 
op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) 197*da0073e9SAndroid Build Coastguard Worker def test_dim_multi_unsupported(self, device, op: ReductionOpInfo): 198*da0073e9SAndroid Build Coastguard Worker """Tests that ops claiming to not support multi dim actually don't.""" 199*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 200*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2]) 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops, dtypes=OpDTypes.none) 203*da0073e9SAndroid Build Coastguard Worker def test_dim_offbounds(self, device, op: ReductionOpInfo): 204*da0073e9SAndroid Build Coastguard Worker """Tests that passing an off-bounds dim throws""" 205*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 206*da0073e9SAndroid Build Coastguard Worker self._test_dim_keepdim(op, device, ndim=2, dim=2) 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops, dtypes=OpDTypes.none) 209*da0073e9SAndroid Build Coastguard Worker def test_dim_ndim_limit(self, device, op: ReductionOpInfo): 210*da0073e9SAndroid Build Coastguard Worker """Tests that an exception is raised when reducing a tensor with more 211*da0073e9SAndroid Build Coastguard Worker than 64 dims along some specific dimensions. 
dim=None is ok""" 212*da0073e9SAndroid Build Coastguard Worker t = make_tensor([1] * 65, dtype=torch.float, device=device) 213*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"): 214*da0073e9SAndroid Build Coastguard Worker op(t, dim=0) 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker @ops(filter(lambda op: op.identity is not None, reduction_ops), dtypes=OpDTypes.supported) 217*da0073e9SAndroid Build Coastguard Worker def test_identity(self, device, dtype, op: ReductionOpInfo): 218*da0073e9SAndroid Build Coastguard Worker """Tests that the identity value is an identity for the operator""" 219*da0073e9SAndroid Build Coastguard Worker t = make_tensor((10,), dtype=dtype, device=device) 220*da0073e9SAndroid Build Coastguard Worker t[1::2] = op.identity 221*da0073e9SAndroid Build Coastguard Worker args, kwargs = next(op.generate_args_kwargs(t)) 222*da0073e9SAndroid Build Coastguard Worker result = op(t[::2], *args, **kwargs) 223*da0073e9SAndroid Build Coastguard Worker result_with_identity = op(t, *args, **kwargs) 224*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, result_with_identity, """ 225*da0073e9SAndroid Build Coastguard Worker Adding identity value to the input tensor should not change the result. 226*da0073e9SAndroid Build Coastguard Worker """) 227*da0073e9SAndroid Build Coastguard Worker 228*da0073e9SAndroid Build Coastguard Worker # TODO(@heitorschueroff) Update these to use the nan_policy kwarg once 229*da0073e9SAndroid Build Coastguard Worker # it is added to reduction operators. 
230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker @ops(filter(lambda op: op.nan_policy == 'propagate', reduction_ops), dtypes=OpDTypes.supported, 232*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16)) 233*da0073e9SAndroid Build Coastguard Worker def test_nan_policy_propagate(self, device, dtype, op: ReductionOpInfo): 234*da0073e9SAndroid Build Coastguard Worker """Tests that nan is propagated to the output by default""" 235*da0073e9SAndroid Build Coastguard Worker t = make_tensor((5,), dtype=dtype, device=device) 236*da0073e9SAndroid Build Coastguard Worker t[2] = torch.nan 237*da0073e9SAndroid Build Coastguard Worker args, kwargs = next(op.generate_args_kwargs(t)) 238*da0073e9SAndroid Build Coastguard Worker result = op(t, *args, **kwargs) 239*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result.isnan()) 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker @ops(filter(lambda op: op.nan_policy == 'omit', reduction_ops), dtypes=OpDTypes.supported, 242*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16)) 243*da0073e9SAndroid Build Coastguard Worker def test_nan_policy_omit(self, device, dtype, op: ReductionOpInfo): 244*da0073e9SAndroid Build Coastguard Worker """Tests that NaN values do not affect the result.""" 245*da0073e9SAndroid Build Coastguard Worker t = make_tensor((10,), dtype=dtype, device=device) 246*da0073e9SAndroid Build Coastguard Worker t[1::2] = torch.nan 247*da0073e9SAndroid Build Coastguard Worker args, kwargs = next(op.generate_args_kwargs(t)) 248*da0073e9SAndroid Build Coastguard Worker result = op(t[::2], *args, **kwargs) 249*da0073e9SAndroid Build Coastguard Worker result_with_nan = op(t, *args, **kwargs) 250*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, result_with_nan) 251*da0073e9SAndroid Build Coastguard 
Worker 252*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops, dtypes=OpDTypes.supported) 253*da0073e9SAndroid Build Coastguard Worker def test_result_dtype(self, device, dtype, op: ReductionOpInfo): 254*da0073e9SAndroid Build Coastguard Worker """Tests that the result has the correct dtype""" 255*da0073e9SAndroid Build Coastguard Worker t = make_tensor((5,), dtype=dtype, device=device) 256*da0073e9SAndroid Build Coastguard Worker args, kwargs = next(op.generate_args_kwargs(t)) 257*da0073e9SAndroid Build Coastguard Worker result: torch.Tensor = op(t, *args, **kwargs) 258*da0073e9SAndroid Build Coastguard Worker is_integral = dtype in integral_types_and(torch.bool) 259*da0073e9SAndroid Build Coastguard Worker if op.promotes_int_to_float and is_integral: 260*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.is_floating_point(result)) 261*da0073e9SAndroid Build Coastguard Worker elif op.promotes_int_to_int64 and is_integral: 262*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dtype, torch.int64) 263*da0073e9SAndroid Build Coastguard Worker elif op.result_dtype is not None: 264*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dtype, op.result_dtype) 265*da0073e9SAndroid Build Coastguard Worker elif op.complex_to_real: 266*da0073e9SAndroid Build Coastguard Worker _complex_to_real_dtype_map = { 267*da0073e9SAndroid Build Coastguard Worker torch.complex128: torch.float64, 268*da0073e9SAndroid Build Coastguard Worker torch.complex64: torch.float32, 269*da0073e9SAndroid Build Coastguard Worker torch.complex32: torch.float16, 270*da0073e9SAndroid Build Coastguard Worker } 271*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dtype, _complex_to_real_dtype_map.get(dtype, dtype)) 272*da0073e9SAndroid Build Coastguard Worker else: 273*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dtype, dtype) 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker 
@ops(reduction_ops, dtypes=OpDTypes.none) 276*da0073e9SAndroid Build Coastguard Worker def test_empty_tensor_empty_slice(self, device, op: ReductionOpInfo): 277*da0073e9SAndroid Build Coastguard Worker """Tests for consistent behavior when reducing over an empty slice. 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Worker The rules for reducing over an empty slice are as follows: 280*da0073e9SAndroid Build Coastguard Worker - Return the identity value if the operator has one 281*da0073e9SAndroid Build Coastguard Worker - Otherwise, return NaN if the operator promotes integral dtype to 282*da0073e9SAndroid Build Coastguard Worker floating point dtypes. 283*da0073e9SAndroid Build Coastguard Worker - Otherwise, raise an error 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker See discussion here https://github.com/pytorch/pytorch/issues/61901 286*da0073e9SAndroid Build Coastguard Worker """ 287*da0073e9SAndroid Build Coastguard Worker t = make_tensor((0, 2, 3), dtype=torch.float, device=device) 288*da0073e9SAndroid Build Coastguard Worker for dim in [0] + [[0, 2]] if op.supports_multiple_dims else []: 289*da0073e9SAndroid Build Coastguard Worker args, kwargs = next(op.generate_args_kwargs(t, dim=dim)) 290*da0073e9SAndroid Build Coastguard Worker if op.identity is not None: 291*da0073e9SAndroid Build Coastguard Worker # Reducing along empty slice should return identity 292*da0073e9SAndroid Build Coastguard Worker result = op(t, *args, dim=dim, **kwargs) 293*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, torch.full_like(result, op.identity)) 294*da0073e9SAndroid Build Coastguard Worker elif op.promotes_int_to_float: 295*da0073e9SAndroid Build Coastguard Worker # Reducing along empty slice should return NaN 296*da0073e9SAndroid Build Coastguard Worker result = op(t, *args, dim=dim, **kwargs) 297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, 
torch.full_like(result, torch.nan)) 298*da0073e9SAndroid Build Coastguard Worker else: 299*da0073e9SAndroid Build Coastguard Worker # Reducing along empty slice should raise an error 300*da0073e9SAndroid Build Coastguard Worker if isinstance(op, ReductionPythonRefInfo): 301*da0073e9SAndroid Build Coastguard Worker # ref reductions throw RuntimeError for this 302*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 303*da0073e9SAndroid Build Coastguard Worker op(t, *args, dim=dim, **kwargs) 304*da0073e9SAndroid Build Coastguard Worker else: 305*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 306*da0073e9SAndroid Build Coastguard Worker op(t, *args, dim=dim, **kwargs) 307*da0073e9SAndroid Build Coastguard Worker 308*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops, dtypes=OpDTypes.none) 309*da0073e9SAndroid Build Coastguard Worker def test_empty_tensor_nonempty_slice(self, device, op: ReductionOpInfo): 310*da0073e9SAndroid Build Coastguard Worker """Tests that reducing a nonempty slice of an empty tensor returns an 311*da0073e9SAndroid Build Coastguard Worker empty tensor with the dimensions reduced.""" 312*da0073e9SAndroid Build Coastguard Worker t = make_tensor((0, 2, 3), dtype=torch.float, device=device) 313*da0073e9SAndroid Build Coastguard Worker for dim in [1] + [[1, 2]] if op.supports_multiple_dims else []: 314*da0073e9SAndroid Build Coastguard Worker args, kwargs = next(op.generate_args_kwargs(t, dim=dim)) 315*da0073e9SAndroid Build Coastguard Worker result = op(t, *args, dim=dim, **kwargs) 316*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.shape, _reduced_shape(t.shape, dim)) 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker def _test_noncontiguous(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs): 319*da0073e9SAndroid Build Coastguard Worker """Helper method to test noncontiguous input tensors.""" 320*da0073e9SAndroid 
Build Coastguard Worker assert not t.is_contiguous() 321*da0073e9SAndroid Build Coastguard Worker 322*da0073e9SAndroid Build Coastguard Worker t_contig = t.contiguous() 323*da0073e9SAndroid Build Coastguard Worker for args, kwargs in op.generate_args_kwargs(t_contig, **reduction_kwargs): 324*da0073e9SAndroid Build Coastguard Worker kwargs.update(reduction_kwargs) 325*da0073e9SAndroid Build Coastguard Worker result = op(t, *args, **kwargs) 326*da0073e9SAndroid Build Coastguard Worker expected = op(t_contig, *args, **kwargs) 327*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected) 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops) 330*da0073e9SAndroid Build Coastguard Worker def test_noncontiguous_innermost(self, device, dtype, op: ReductionOpInfo): 331*da0073e9SAndroid Build Coastguard Worker """Tests reducing along noncontiguous innermost dimension.""" 332*da0073e9SAndroid Build Coastguard Worker t = make_tensor((10, 10), dtype=dtype, device=device, low=-1, high=1) 333*da0073e9SAndroid Build Coastguard Worker self._test_noncontiguous(op, t[:, ::2], dim=1) 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops) 336*da0073e9SAndroid Build Coastguard Worker def test_noncontiguous_outermost(self, device, dtype, op: ReductionOpInfo): 337*da0073e9SAndroid Build Coastguard Worker """Tests reducing along noncontiguous outermost dimension.""" 338*da0073e9SAndroid Build Coastguard Worker t = make_tensor((10, 10), dtype=dtype, device=device, low=-1, high=1) 339*da0073e9SAndroid Build Coastguard Worker self._test_noncontiguous(op, t[::2, :], dim=0) 340*da0073e9SAndroid Build Coastguard Worker 341*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops) 342*da0073e9SAndroid Build Coastguard Worker def test_noncontiguous_all(self, device, dtype, op: ReductionOpInfo): 343*da0073e9SAndroid Build Coastguard Worker """Tests reducing all 
dimensions of a noncontiguous tensor.""" 344*da0073e9SAndroid Build Coastguard Worker t = make_tensor((5, 5, 5), dtype=dtype, device=device, low=-1, high=1) 345*da0073e9SAndroid Build Coastguard Worker self._test_noncontiguous(op, t[::2, ::3, 1:-1:2]) 346*da0073e9SAndroid Build Coastguard Worker 347*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops) 348*da0073e9SAndroid Build Coastguard Worker def test_noncontiguous_transposed(self, device, dtype, op: ReductionOpInfo): 349*da0073e9SAndroid Build Coastguard Worker """Tests reducing a transposed tensor.""" 350*da0073e9SAndroid Build Coastguard Worker t = make_tensor((5, 5), dtype=dtype, device=device, low=-1, high=1) 351*da0073e9SAndroid Build Coastguard Worker self._test_noncontiguous(op, t.T) 352*da0073e9SAndroid Build Coastguard Worker 353*da0073e9SAndroid Build Coastguard Worker @ops(reduction_ops) 354*da0073e9SAndroid Build Coastguard Worker def test_noncontiguous_expanded(self, device, dtype, op: ReductionOpInfo): 355*da0073e9SAndroid Build Coastguard Worker """Tests reducing a tensor with expanded singleton dimensions.""" 356*da0073e9SAndroid Build Coastguard Worker t = make_tensor((2, 3), dtype=dtype, device=device, low=-1, high=1) 357*da0073e9SAndroid Build Coastguard Worker self._test_noncontiguous(op, t.unsqueeze(1).expand(-1, 5, -1)) 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker # NumPy does not support BFloat16 so we don't test that against reference 360*da0073e9SAndroid Build Coastguard Worker # implementations. We also don't compare dtypes or test for different 361*da0073e9SAndroid Build Coastguard Worker # keepdim because we already have other tests covering those. 362*da0073e9SAndroid Build Coastguard Worker # The test_reference_testing in test_ops.py only uses the samples from 363*da0073e9SAndroid Build Coastguard Worker # sample_inputs_func which do not test as exhaustively as these tests. 
def _test_ref(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
    """Compares op against op.ref for the given input and reduction kwargs

    Args:
        op: the reduction under test; must provide ``generate_args_kwargs``
            and a ``ref`` callable.
        t: the input tensor. The reference receives a detached CPU NumPy
            copy of it.
        **reduction_kwargs: reduction arguments (e.g. ``dim``) forwarded to
            both ``op`` and ``op.ref``.
    """
    for args, kwargs in op.generate_args_kwargs(t, **reduction_kwargs):
        # The reduction kwargs are merged into every generated kwargs dict
        # so that op and op.ref are always called with them.
        kwargs.update(reduction_kwargs)
        result = op(t, *args, **kwargs)
        expected = op.ref(t.detach().cpu().numpy(), *args, **kwargs)
        # Result dtypes are not compared here, only values.
        self.assertEqual(result, expected, exact_dtype=False)

@ops(filter(lambda op: op.ref is not None, reduction_ops),
     allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
def test_ref_scalar_input(self, device, dtype, op: ReductionOpInfo):
    """Compares op against reference for scalar input tensors"""
    self._test_ref(op, make_tensor([], dtype=dtype, device=device))

@ops(filter(lambda op: op.ref is not None, reduction_ops),
     allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
def test_ref_small_input(self, device, dtype, op: ReductionOpInfo):
    """Compares op against reference for small input tensors"""
    t = make_tensor((5, 3, 4, 2), dtype=dtype, device=device, low=-2, high=2, exclude_zero=True)
    self._test_ref(op, t)
    # Multi-dim reductions are only exercised for ops that support them.
    for dim in [0, 1, 3] + ([[0, 2], [1, 3]] if op.supports_multiple_dims else []):
        self._test_ref(op, t, dim=dim)

@ops(filter(lambda op: op.ref is not None, reduction_ops),
     allowed_dtypes=[torch.float64])
def test_ref_large_input_1D(self, device, dtype, op: ReductionOpInfo):
    """Compares op against reference for a large 1D input tensor to check stability"""
    self._test_ref(op, make_tensor((2 ** 20,), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True))

@ops(filter(lambda op: op.ref is not None, reduction_ops),
     allowed_dtypes=[torch.float64])
def test_ref_large_input_2D(self, device, dtype, op: ReductionOpInfo):
    """Compares op against reference for a large 2D input tensor to test parallelism"""
    t = make_tensor((32, 2 ** 16), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True)
    self._test_ref(op, t, dim=1)

@largeTensorTest("8gb")
@ops(filter(lambda op: op.ref is not None, reduction_ops),
     allowed_dtypes=[torch.float64])
def test_ref_large_input_64bit_indexing(self, device, dtype, op: ReductionOpInfo):
    """Compares op against reference for a very large input tensor that requires 64 bit indexing"""
    self._test_ref(op, make_tensor((275000000,), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True))

@ops(filter(lambda op: op.ref is not None, reduction_ops),
     allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
def test_ref_duplicate_values(self, device, dtype, op: ReductionOpInfo):
    """Compares op against reference for input tensors with duplicate values"""
    t = make_tensor((4, 4), dtype=dtype, device=device, low=-2, high=2, exclude_zero=True)
    # Copy the odd-row/odd-column entries over the even-row/even-column
    # ones so the input is guaranteed to contain repeated values.
    t[::2, ::2] = t[1::2, 1::2]
    self._test_ref(op, t)
    self._test_ref(op, t, dim=0)
    self._test_ref(op, t, dim=1)

@ops(filter(lambda op: op.ref is not None, reduction_ops),
     allowed_dtypes=[torch.float32, torch.complex64])
def test_ref_extremal_values(self, device, dtype, op: ReductionOpInfo):
    """Compares op against reference for input tensors with extremal values"""
    t = make_tensor((5,), dtype=dtype, device=device, exclude_zero=True)
    extremals = [0, 1, nan, inf, -inf]
    for extremal in extremals:
        # Only the middle element is overwritten; the other entries keep
        # the values produced by make_tensor.
        t[2] = extremal
        self._test_ref(op, t)

###########################################################################
# TODO: Legacy tests - port to ReductionOpInfo
###########################################################################

def test_var_unbiased(self, device):
    """Checks that var/std default to unbiased=True and pins known values.

    On 1D inputs the full reduction and the dim=0 reduction must agree, and
    calling without ``unbiased`` must match ``unbiased=True``.
    """
    tensor = torch.randn(100, device=device)
    self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True))
    self.assertEqual(tensor.var(), tensor.var(unbiased=True))
    self.assertEqual(tensor.var(unbiased=False), tensor.var(0, unbiased=False))

    tensor = torch.tensor([1.0, 2.0], device=device)
    self.assertEqual(tensor.var(unbiased=True), 0.5)
    self.assertEqual(tensor.var(unbiased=False), 0.25)

    tensor = torch.tensor([1.0, 2.0, 3.0], device=device)
    self.assertEqual(tensor.var(unbiased=True), 1.0)
    self.assertEqual(tensor.var(unbiased=False), 2.0 / 3.0)

    tensor = torch.randn(100, device=device)
    self.assertEqual(tensor.std(0), tensor.std(0, unbiased=True))
    self.assertEqual(tensor.std(), tensor.std(unbiased=True))
    self.assertEqual(tensor.std(unbiased=False), tensor.std(0, unbiased=False))

def test_var_stability(self, device):
    """Checks var on two large, nearly equal values against the exact result."""
    tensor = torch.tensor([2281.5, 2281.25], device=device)
    self.assertEqual(tensor.var(dim=0), 0.03125)
    self.assertEqual(tensor.var(), 0.03125)

def test_sum_dim_reduction_uint8_overflow(self, device):
    """Checks that uint8 sums with dtype=torch.uint8 wrap around modulo 256.

    The expected values below correspond to the -1 entry being stored as
    255 in the uint8 tensor.
    """
    example = [[-1, 2, 1], [5, 3, 6]]
    x = torch.tensor(example, dtype=torch.uint8, device=device)
    self.assertEqual(x.sum(dtype=torch.uint8).item(), 16)
    self.assertEqual(x.sum(0, dtype=torch.uint8), torch.tensor([4, 5, 7], dtype=torch.uint8, device=device))
    self.assertEqual(x.sum(1, dtype=torch.uint8), torch.tensor([2, 14], dtype=torch.uint8, device=device))
    # The out= variant must produce the same values as the dtype= variant.
    y = torch.tensor(example, dtype=torch.uint8, device=device)
    torch.sum(x, 0, out=y)
    self.assertEqual(x.sum(0, dtype=torch.uint8), y)

def test_dim_reduction_less_than_64(self, device):
    """Checks that reducing a 65-dimensional tensor raises a RuntimeError."""
    sizes = [1] * 65
    x = torch.randn(sizes, device=device)
    # Each op is listed once (the list previously contained torch.std twice);
    # the name avoids shadowing the `ops` decorator used in this file.
    reduction_fns = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp, torch.var,
                     torch.norm]
    for op in reduction_fns:
        with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
            op(x, dim=64)
        with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
            op(x, dim=-1)

@onlyCPU
@dtypes(torch.float, torch.bfloat16)
def test_dim_reduction_lastdim(self, device, dtype):
    """Checks last-dim reductions on a strided view against its contiguous copy."""
    x = torch.randn(3, 5, 40, device=device, dtype=dtype)
    # Taking every other element makes the last dimension non-unit-stride.
    x = x[:, :, 0:40:2]
    x2 = x.contiguous()
    reduction_fns = [torch.norm, torch.argmax, torch.argmin]
    for op in reduction_fns:
        y = op(x, dim=-1)
        y2 = op(x2, dim=-1)
        self.assertEqual(y, y2)

@skipIfNoSciPy
@dtypes(torch.float32, torch.double, torch.complex64, torch.complex128)
def test_logsumexp(self, device, dtype):
    """Compares Tensor.logsumexp and its out= variant against scipy's logsumexp."""
    from scipy.special import logsumexp
    a = torch.randn(5, 4, device=device, dtype=dtype)
    # torch.exp(complex(inf, 0)) yields inf+nan*j instead of inf+0*j on CPU which disagrees with CUDA, C++ std::exp,
    # numpy and scipy. Skip inf testing on CPU. Related to https://github.com/pytorch/pytorch/issues/95740
    if torch.device(device) != torch.device('cpu'):
        a[0, 0] = inf
        a[1, :] = -inf
    actual = a.logsumexp(1)
    expected = logsumexp(a.cpu().numpy(), 1)
    self.assertEqual(expected.shape, actual.shape)
    self.assertEqual(expected, actual)

    # check that out is actually inplace
    b = torch.zeros(5, 2, device=device, dtype=dtype)
    c = b[:, 0]
    torch.logsumexp(a, 1, out=c)
    self.assertEqual(expected, b[:, 0])

@skipIfNoSciPy
def test_logsumexp_integral_promotion(self, device):
    """Checks logsumexp on an integer tensor against scipy's logsumexp."""
    from scipy.special import logsumexp
    # check integral inputs is promoted to floating point
    e = torch.randint(-100, 100, [5, 4], device=device)
    actual = e.logsumexp(1).to(torch.float64)
    expected = logsumexp(e.cpu().numpy(), 1)
    self.assertEqual(expected.shape, actual.shape)
    self.assertEqual(expected, actual)
Coastguard Worker @skipIfNoSciPy 521*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.complex64, torch.complex128) 522*da0073e9SAndroid Build Coastguard Worker def test_logcumsumexp_complex(self, device, dtype): 523*da0073e9SAndroid Build Coastguard Worker # logcumsumexp is a more precise way to compute than ``log(cumsum(exp(a)))`` 524*da0073e9SAndroid Build Coastguard Worker # and faster than ``[log(sum(exp(a[:i]))) for i in range(a.shape[0])]`` 525*da0073e9SAndroid Build Coastguard Worker # the for-loop above should produce similar precision as logcumsumexp (it's just slower), 526*da0073e9SAndroid Build Coastguard Worker # so it can be used as the expected values to check our computation 527*da0073e9SAndroid Build Coastguard Worker 528*da0073e9SAndroid Build Coastguard Worker # using logsumexp from scipy because by the time of writing this test code, 529*da0073e9SAndroid Build Coastguard Worker # torch.logsumexp has not been implemented for complex numbers 530*da0073e9SAndroid Build Coastguard Worker from scipy.special import logsumexp 531*da0073e9SAndroid Build Coastguard Worker 532*da0073e9SAndroid Build Coastguard Worker def zero_out_neg_inf(t): 533*da0073e9SAndroid Build Coastguard Worker t = t.clone() 534*da0073e9SAndroid Build Coastguard Worker idx = torch.logical_and(~(torch.isfinite(t)), torch.real(t) < 0) 535*da0073e9SAndroid Build Coastguard Worker t[idx] = torch.real(t[idx]).to(t.dtype) 536*da0073e9SAndroid Build Coastguard Worker return t 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker def standardize_phase(t): 539*da0073e9SAndroid Build Coastguard Worker t = torch.real(t) + 1j * (torch.imag(t) % (2 * np.pi)) 540*da0073e9SAndroid Build Coastguard Worker return t 541*da0073e9SAndroid Build Coastguard Worker 542*da0073e9SAndroid Build Coastguard Worker def logcumsumexp_slow(a, dim): 543*da0073e9SAndroid Build Coastguard Worker res_lst = [] 544*da0073e9SAndroid Build Coastguard Worker for i in 
range(a.size(dim)): 545*da0073e9SAndroid Build Coastguard Worker index = [slice(None, None, None) for _ in range(a.ndim)] 546*da0073e9SAndroid Build Coastguard Worker index[dim] = slice(None, i + 1, None) 547*da0073e9SAndroid Build Coastguard Worker a_inp = a[tuple(index)] 548*da0073e9SAndroid Build Coastguard Worker res_lst.append(logsumexp(a_inp.cpu().numpy(), axis=dim, keepdims=True)) 549*da0073e9SAndroid Build Coastguard Worker res = np.concatenate(res_lst, axis=dim) 550*da0073e9SAndroid Build Coastguard Worker return torch.as_tensor(res) 551*da0073e9SAndroid Build Coastguard Worker 552*da0073e9SAndroid Build Coastguard Worker def compare_logcumsumexp(a, expected=None): 553*da0073e9SAndroid Build Coastguard Worker for i in range(a.ndim): 554*da0073e9SAndroid Build Coastguard Worker actual = torch.logcumsumexp(a, dim=i) 555*da0073e9SAndroid Build Coastguard Worker # if the expected is not given, then revert to scipy's logsumexp 556*da0073e9SAndroid Build Coastguard Worker if expected is None: 557*da0073e9SAndroid Build Coastguard Worker expected2 = logcumsumexp_slow(a, dim=i) 558*da0073e9SAndroid Build Coastguard Worker else: 559*da0073e9SAndroid Build Coastguard Worker expected2 = expected 560*da0073e9SAndroid Build Coastguard Worker 561*da0073e9SAndroid Build Coastguard Worker # move the imaginary values to (0, 2 * pi) 562*da0073e9SAndroid Build Coastguard Worker actual = standardize_phase(actual) 563*da0073e9SAndroid Build Coastguard Worker expected2 = standardize_phase(expected2) 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker # zeroing the imaginary part of the element if the real part is -inf 566*da0073e9SAndroid Build Coastguard Worker # as the imaginary part cannot be determined exactly and it does not 567*da0073e9SAndroid Build Coastguard Worker # really matter if we take the exp of the output 568*da0073e9SAndroid Build Coastguard Worker actual = zero_out_neg_inf(actual) 569*da0073e9SAndroid Build Coastguard 
Worker expected2 = zero_out_neg_inf(expected2) 570*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected2.shape, actual.shape) 571*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected2, actual) 572*da0073e9SAndroid Build Coastguard Worker 573*da0073e9SAndroid Build Coastguard Worker # randomly specified values 574*da0073e9SAndroid Build Coastguard Worker # in this case, scipy.logsumexp should be enough 575*da0073e9SAndroid Build Coastguard Worker a1 = torch.randn((5, 10), dtype=dtype, device=device) 576*da0073e9SAndroid Build Coastguard Worker compare_logcumsumexp(a1) 577*da0073e9SAndroid Build Coastguard Worker 578*da0073e9SAndroid Build Coastguard Worker # test with some non-normal values 579*da0073e9SAndroid Build Coastguard Worker a2 = torch.tensor([1e3 + 0j, 1e-18 + 1e4j, 1e2 + 1e-8j], dtype=dtype, device=device) 580*da0073e9SAndroid Build Coastguard Worker compare_logcumsumexp(a2) 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard Worker # handle special case involving infinites and nans 583*da0073e9SAndroid Build Coastguard Worker # here we don't use scipy.logsumexp as it gives confusing answer on 584*da0073e9SAndroid Build Coastguard Worker # some inf cases 585*da0073e9SAndroid Build Coastguard Worker # see here: 586*da0073e9SAndroid Build Coastguard Worker inf = float('inf') 587*da0073e9SAndroid Build Coastguard Worker nan = float('nan') 588*da0073e9SAndroid Build Coastguard Worker a3_input = torch.tensor([ 589*da0073e9SAndroid Build Coastguard Worker -inf + 4j, 590*da0073e9SAndroid Build Coastguard Worker -inf + 1j, 591*da0073e9SAndroid Build Coastguard Worker 1.2 + 2.1j, 592*da0073e9SAndroid Build Coastguard Worker 1e10 + 1e20j, 593*da0073e9SAndroid Build Coastguard Worker inf + 0j, 594*da0073e9SAndroid Build Coastguard Worker inf + 1j, 595*da0073e9SAndroid Build Coastguard Worker inf + 3j, 596*da0073e9SAndroid Build Coastguard Worker nan + 2j, 597*da0073e9SAndroid Build Coastguard Worker ]) 
598*da0073e9SAndroid Build Coastguard Worker a3_expected = torch.tensor([ 599*da0073e9SAndroid Build Coastguard Worker -inf + 0j, 600*da0073e9SAndroid Build Coastguard Worker -inf + 0j, 601*da0073e9SAndroid Build Coastguard Worker 1.2 + 2.1j, 602*da0073e9SAndroid Build Coastguard Worker 1e10 + 1e20j, 603*da0073e9SAndroid Build Coastguard Worker inf + 0j, # scipy's logsumexp gives (inf + 0.7853982j) here, unclear why 604*da0073e9SAndroid Build Coastguard Worker inf + (np.pi / 4) * 1j, # the imaginary part thanks to some weird behaviour of log(inf + infj) 605*da0073e9SAndroid Build Coastguard Worker complex(inf, nan), 606*da0073e9SAndroid Build Coastguard Worker complex(nan, nan), 607*da0073e9SAndroid Build Coastguard Worker ]) 608*da0073e9SAndroid Build Coastguard Worker # windows give strange results on the second-to-last results where it gives inf + pi/4 j 609*da0073e9SAndroid Build Coastguard Worker # instead of inf + nan j 610*da0073e9SAndroid Build Coastguard Worker if not IS_WINDOWS: 611*da0073e9SAndroid Build Coastguard Worker compare_logcumsumexp(a3_input, a3_expected) 612*da0073e9SAndroid Build Coastguard Worker 613*da0073e9SAndroid Build Coastguard Worker a4_input = torch.tensor([ 614*da0073e9SAndroid Build Coastguard Worker complex(-inf, inf), 615*da0073e9SAndroid Build Coastguard Worker complex(-inf, inf), 616*da0073e9SAndroid Build Coastguard Worker -inf + 1j, 617*da0073e9SAndroid Build Coastguard Worker 1.2 + 2.1j, 618*da0073e9SAndroid Build Coastguard Worker complex(2.4, inf), 619*da0073e9SAndroid Build Coastguard Worker ]) 620*da0073e9SAndroid Build Coastguard Worker a4_expected = torch.tensor([ 621*da0073e9SAndroid Build Coastguard Worker -inf + 0j, 622*da0073e9SAndroid Build Coastguard Worker -inf + 0j, 623*da0073e9SAndroid Build Coastguard Worker -inf + 0j, 624*da0073e9SAndroid Build Coastguard Worker 1.2 + 2.1j, 625*da0073e9SAndroid Build Coastguard Worker complex(nan, nan), 626*da0073e9SAndroid Build Coastguard Worker ]) 627*da0073e9SAndroid 
Build Coastguard Worker if not IS_WINDOWS: 628*da0073e9SAndroid Build Coastguard Worker compare_logcumsumexp(a4_input, a4_expected) 629*da0073e9SAndroid Build Coastguard Worker 630*da0073e9SAndroid Build Coastguard Worker @onlyCPU 631*da0073e9SAndroid Build Coastguard Worker def test_sum_parallel(self, device): 632*da0073e9SAndroid Build Coastguard Worker # To use parallel branches we'll need to compare on tensors 633*da0073e9SAndroid Build Coastguard Worker # that are relatively large. Even if this is run on a single 634*da0073e9SAndroid Build Coastguard Worker # core machine these tests will still give you signal on 635*da0073e9SAndroid Build Coastguard Worker # the correctness 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker def _run_test(size): 638*da0073e9SAndroid Build Coastguard Worker for dim in range(len(size) + 1): 639*da0073e9SAndroid Build Coastguard Worker nv = np.round(np.random.rand(*size)) # 0s and 1s 640*da0073e9SAndroid Build Coastguard Worker tv = torch.from_numpy(nv) 641*da0073e9SAndroid Build Coastguard Worker # Parallelisim is only used if numel is 642*da0073e9SAndroid Build Coastguard Worker # larger than grainsize defined in Parallel.h 643*da0073e9SAndroid Build Coastguard Worker self.assertTrue(tv.numel() > 32768) 644*da0073e9SAndroid Build Coastguard Worker if dim == len(size): 645*da0073e9SAndroid Build Coastguard Worker nvs = nv.sum() 646*da0073e9SAndroid Build Coastguard Worker tvs = tv.sum() 647*da0073e9SAndroid Build Coastguard Worker else: 648*da0073e9SAndroid Build Coastguard Worker nvs = nv.sum(dim) 649*da0073e9SAndroid Build Coastguard Worker tvs = tv.sum(dim) 650*da0073e9SAndroid Build Coastguard Worker diff = np.abs(nvs - tvs.numpy()).sum() 651*da0073e9SAndroid Build Coastguard Worker self.assertEqual(diff, 0) 652*da0073e9SAndroid Build Coastguard Worker 653*da0073e9SAndroid Build Coastguard Worker _run_test([2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3]) 654*da0073e9SAndroid Build Coastguard 
Worker _run_test([4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) 655*da0073e9SAndroid Build Coastguard Worker _run_test([1, 32 * 8 * 32 * 8]) 656*da0073e9SAndroid Build Coastguard Worker _run_test([1, 32770]) 657*da0073e9SAndroid Build Coastguard Worker 658*da0073e9SAndroid Build Coastguard Worker # TODO: kill map2_ (and similar) uses and update to compare with NumPy 659*da0073e9SAndroid Build Coastguard Worker # only works on CPU since this uses map2_, which is only supported on CPU 660*da0073e9SAndroid Build Coastguard Worker def _testCSelection(self, torchfn, mathfn): 661*da0073e9SAndroid Build Coastguard Worker # Two tensors 662*da0073e9SAndroid Build Coastguard Worker size = (100, 100) 663*da0073e9SAndroid Build Coastguard Worker a = torch.rand(*size) 664*da0073e9SAndroid Build Coastguard Worker b = torch.rand(*size) 665*da0073e9SAndroid Build Coastguard Worker c = torchfn(a, b) 666*da0073e9SAndroid Build Coastguard Worker expected_c = torch.zeros(*size) 667*da0073e9SAndroid Build Coastguard Worker expected_c.map2_(a, b, lambda _, a, b: mathfn(a, b)) 668*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_c, c, atol=0, rtol=0) 669*da0073e9SAndroid Build Coastguard Worker 670*da0073e9SAndroid Build Coastguard Worker @onlyCPU 671*da0073e9SAndroid Build Coastguard Worker def test_max_elementwise(self, device): 672*da0073e9SAndroid Build Coastguard Worker self._testCSelection(torch.max, max) 673*da0073e9SAndroid Build Coastguard Worker 674*da0073e9SAndroid Build Coastguard Worker @onlyCPU 675*da0073e9SAndroid Build Coastguard Worker def test_min_elementwise(self, device): 676*da0073e9SAndroid Build Coastguard Worker self._testCSelection(torch.min, min) 677*da0073e9SAndroid Build Coastguard Worker 678*da0073e9SAndroid Build Coastguard Worker def test_all_any(self, device): 679*da0073e9SAndroid Build Coastguard Worker def test(size): 680*da0073e9SAndroid Build Coastguard Worker x = torch.ones(*size, device=device).byte() 681*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(x.all()) 682*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.any()) 683*da0073e9SAndroid Build Coastguard Worker 684*da0073e9SAndroid Build Coastguard Worker x[3] = 0 685*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x.all()) 686*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.any()) 687*da0073e9SAndroid Build Coastguard Worker 688*da0073e9SAndroid Build Coastguard Worker x.zero_() 689*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x.all()) 690*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x.any()) 691*da0073e9SAndroid Build Coastguard Worker 692*da0073e9SAndroid Build Coastguard Worker x.fill_(2) 693*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.all()) 694*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.any()) 695*da0073e9SAndroid Build Coastguard Worker 696*da0073e9SAndroid Build Coastguard Worker x = torch.ones(*size, device=device).bool() 697*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.all()) 698*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.any()) 699*da0073e9SAndroid Build Coastguard Worker 700*da0073e9SAndroid Build Coastguard Worker x[3] = False 701*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x.all()) 702*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.any()) 703*da0073e9SAndroid Build Coastguard Worker 704*da0073e9SAndroid Build Coastguard Worker test((10,)) 705*da0073e9SAndroid Build Coastguard Worker test((5, 5)) 706*da0073e9SAndroid Build Coastguard Worker 707*da0073e9SAndroid Build Coastguard Worker def test_all_any_with_dim(self, device): 708*da0073e9SAndroid Build Coastguard Worker def test(x): 709*da0073e9SAndroid Build Coastguard Worker r1 = x.prod(dim=0, keepdim=False).byte() 710*da0073e9SAndroid Build Coastguard Worker r2 = x.all(dim=0, keepdim=False) 711*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1.shape, r2.shape) 712*da0073e9SAndroid Build Coastguard Worker self.assertTrue((r1 == 
r2).all()) 713*da0073e9SAndroid Build Coastguard Worker 714*da0073e9SAndroid Build Coastguard Worker r3 = x.sum(dim=1, keepdim=True).clamp(0, 1).byte() 715*da0073e9SAndroid Build Coastguard Worker r4 = x.any(dim=1, keepdim=True) 716*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r3.shape, r4.shape) 717*da0073e9SAndroid Build Coastguard Worker self.assertTrue((r3 == r4).all()) 718*da0073e9SAndroid Build Coastguard Worker 719*da0073e9SAndroid Build Coastguard Worker test(torch.tensor([[0, 0, 0], 720*da0073e9SAndroid Build Coastguard Worker [0, 0, 1], 721*da0073e9SAndroid Build Coastguard Worker [0, 1, 1], 722*da0073e9SAndroid Build Coastguard Worker [1, 1, 1]], device=device, dtype=torch.uint8)) 723*da0073e9SAndroid Build Coastguard Worker 724*da0073e9SAndroid Build Coastguard Worker def test_numpy_named_args(self, device): 725*da0073e9SAndroid Build Coastguard Worker x1 = torch.randn(10, device=device) 726*da0073e9SAndroid Build Coastguard Worker x2 = torch.randn(10, device=device) 727*da0073e9SAndroid Build Coastguard Worker res1 = torch.add(input=x1, other=x2) 728*da0073e9SAndroid Build Coastguard Worker res2 = torch.add(x1=x1, x2=x2) 729*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 730*da0073e9SAndroid Build Coastguard Worker 731*da0073e9SAndroid Build Coastguard Worker x1 = torch.randn(10, 10, 10, device=device) 732*da0073e9SAndroid Build Coastguard Worker res1 = x1.sum(dim=(0, 2), keepdim=True) 733*da0073e9SAndroid Build Coastguard Worker res2 = x1.sum(axis=(0, 2), keepdims=True) 734*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 735*da0073e9SAndroid Build Coastguard Worker 736*da0073e9SAndroid Build Coastguard Worker # TODO: kill this ane replace with common creation ops 737*da0073e9SAndroid Build Coastguard Worker def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True, 738*da0073e9SAndroid Build Coastguard Worker use_complex=False) -> Dict[str, List[torch.Tensor]]: 
739*da0073e9SAndroid Build Coastguard Worker float_types = [torch.double, 740*da0073e9SAndroid Build Coastguard Worker torch.float] 741*da0073e9SAndroid Build Coastguard Worker int_types = [torch.int64, 742*da0073e9SAndroid Build Coastguard Worker torch.int32, 743*da0073e9SAndroid Build Coastguard Worker torch.int16] 744*da0073e9SAndroid Build Coastguard Worker 745*da0073e9SAndroid Build Coastguard Worker complex_types = [torch.complex64, 746*da0073e9SAndroid Build Coastguard Worker torch.complex128] 747*da0073e9SAndroid Build Coastguard Worker 748*da0073e9SAndroid Build Coastguard Worker def make_contiguous(shape, dtype) -> torch.Tensor: 749*da0073e9SAndroid Build Coastguard Worker if dtype in float_types: 750*da0073e9SAndroid Build Coastguard Worker val = torch.randn(shape, dtype=dtype) 751*da0073e9SAndroid Build Coastguard Worker val = val * ((val_range[1] - val_range[0]) / (math.pi * 2.0)) 752*da0073e9SAndroid Build Coastguard Worker val = val + ((val_range[1] - val_range[0]) / 2.0) 753*da0073e9SAndroid Build Coastguard Worker val = torch.clamp(val, min=val_range[0], max=val_range[1]) 754*da0073e9SAndroid Build Coastguard Worker return val 755*da0073e9SAndroid Build Coastguard Worker result = torch.zeros(shape, dtype=dtype) 756*da0073e9SAndroid Build Coastguard Worker result.apply_(lambda x: random.randint(val_range[0], val_range[1])) 757*da0073e9SAndroid Build Coastguard Worker return result 758*da0073e9SAndroid Build Coastguard Worker 759*da0073e9SAndroid Build Coastguard Worker def make_non_contiguous(shape, dtype) -> torch.Tensor: 760*da0073e9SAndroid Build Coastguard Worker contig = make_contiguous(shape, dtype) 761*da0073e9SAndroid Build Coastguard Worker non_contig = torch.empty(shape + (2, 2), dtype=dtype)[..., 0] 762*da0073e9SAndroid Build Coastguard Worker non_contig = non_contig.select(-1, -1) 763*da0073e9SAndroid Build Coastguard Worker non_contig.copy_(contig) 764*da0073e9SAndroid Build Coastguard Worker self.assertFalse(non_contig.is_contiguous()) 
765*da0073e9SAndroid Build Coastguard Worker return non_contig 766*da0073e9SAndroid Build Coastguard Worker 767*da0073e9SAndroid Build Coastguard Worker def make_contiguous_slice(size, dtype) -> torch.Tensor: 768*da0073e9SAndroid Build Coastguard Worker contig = make_contiguous((1, size), dtype) 769*da0073e9SAndroid Build Coastguard Worker non_contig = contig[:1, 1:size - 1] 770*da0073e9SAndroid Build Coastguard Worker self.assertTrue(non_contig.is_contiguous()) 771*da0073e9SAndroid Build Coastguard Worker return contig 772*da0073e9SAndroid Build Coastguard Worker 773*da0073e9SAndroid Build Coastguard Worker types = [] 774*da0073e9SAndroid Build Coastguard Worker if use_floating: 775*da0073e9SAndroid Build Coastguard Worker types += float_types 776*da0073e9SAndroid Build Coastguard Worker if use_integral: 777*da0073e9SAndroid Build Coastguard Worker types += int_types 778*da0073e9SAndroid Build Coastguard Worker if use_complex: 779*da0073e9SAndroid Build Coastguard Worker types += complex_types 780*da0073e9SAndroid Build Coastguard Worker tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []} 781*da0073e9SAndroid Build Coastguard Worker for dtype in types: 782*da0073e9SAndroid Build Coastguard Worker tensors["cont"].append(make_contiguous(shape, dtype)) 783*da0073e9SAndroid Build Coastguard Worker tensors["noncont"].append(make_non_contiguous(shape, dtype)) 784*da0073e9SAndroid Build Coastguard Worker tensors["slice"].append(make_contiguous_slice(sum(list(shape)), dtype)) 785*da0073e9SAndroid Build Coastguard Worker 786*da0073e9SAndroid Build Coastguard Worker return tensors 787*da0073e9SAndroid Build Coastguard Worker 788*da0073e9SAndroid Build Coastguard Worker # TODO: refactor this to use comparators from common_utils 789*da0073e9SAndroid Build Coastguard Worker def _assert_matches_numpy(self, t, n): 790*da0073e9SAndroid Build Coastguard Worker self.assertEqual(n.shape, t.shape) 791*da0073e9SAndroid Build Coastguard Worker if t.dtype 
== torch.float: 792*da0073e9SAndroid Build Coastguard Worker self.assertEqual(n, t, rtol=1e-03, atol=1e-05, equal_nan=True) 793*da0073e9SAndroid Build Coastguard Worker else: 794*da0073e9SAndroid Build Coastguard Worker self.assertEqual(n, t, equal_nan=True) 795*da0073e9SAndroid Build Coastguard Worker 796*da0073e9SAndroid Build Coastguard Worker # TODO: update this and tests that use it to use the device argument properly 797*da0073e9SAndroid Build Coastguard Worker def _test_dim_ops(self, pytorch_op, numpy_op, 798*da0073e9SAndroid Build Coastguard Worker use_floating=True, use_integral=True, use_complex=False): 799*da0073e9SAndroid Build Coastguard Worker def do_one(tensors_dict, dim): 800*da0073e9SAndroid Build Coastguard Worker for category, tensors in tensors_dict.items(): 801*da0073e9SAndroid Build Coastguard Worker if category == "slice": 802*da0073e9SAndroid Build Coastguard Worker dim = 0 803*da0073e9SAndroid Build Coastguard Worker for tensor in tensors: 804*da0073e9SAndroid Build Coastguard Worker # we have no control over NumPy warnings... 
805*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(): 806*da0073e9SAndroid Build Coastguard Worker warnings.simplefilter("ignore") 807*da0073e9SAndroid Build Coastguard Worker expected = numpy_op(tensor.cpu().numpy(), dim) 808*da0073e9SAndroid Build Coastguard Worker actual = pytorch_op(tensor, dim) 809*da0073e9SAndroid Build Coastguard Worker self._assert_matches_numpy(actual, expected) 810*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 811*da0073e9SAndroid Build Coastguard Worker self._assert_matches_numpy(pytorch_op(tensor.cuda(), dim).cpu(), expected) 812*da0073e9SAndroid Build Coastguard Worker do_one(self._make_tensors((5, 400000), use_floating=use_floating, 813*da0073e9SAndroid Build Coastguard Worker use_integral=use_integral, use_complex=use_complex), 1) 814*da0073e9SAndroid Build Coastguard Worker do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, 815*da0073e9SAndroid Build Coastguard Worker use_integral=use_integral, use_complex=use_complex), 0) 816*da0073e9SAndroid Build Coastguard Worker do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, 817*da0073e9SAndroid Build Coastguard Worker use_integral=use_integral, use_complex=use_complex), 1) 818*da0073e9SAndroid Build Coastguard Worker do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, 819*da0073e9SAndroid Build Coastguard Worker use_integral=use_integral, use_complex=use_complex), 2) 820*da0073e9SAndroid Build Coastguard Worker do_one(self._make_tensors((100000, ), use_floating=use_floating, 821*da0073e9SAndroid Build Coastguard Worker use_integral=use_integral, use_complex=use_complex), -1) 822*da0073e9SAndroid Build Coastguard Worker do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, 823*da0073e9SAndroid Build Coastguard Worker use_integral=use_integral, use_complex=use_complex), 0) 824*da0073e9SAndroid Build Coastguard Worker do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, 
825*da0073e9SAndroid Build Coastguard Worker use_integral=use_integral, use_complex=use_complex), 1) 826*da0073e9SAndroid Build Coastguard Worker do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, 827*da0073e9SAndroid Build Coastguard Worker use_integral=use_integral, use_complex=use_complex), 2) 828*da0073e9SAndroid Build Coastguard Worker do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, 829*da0073e9SAndroid Build Coastguard Worker use_integral=use_integral, use_complex=use_complex), (1, 2)) 830*da0073e9SAndroid Build Coastguard Worker do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, 831*da0073e9SAndroid Build Coastguard Worker use_integral=use_integral, use_complex=use_complex), (1, -1)) 832*da0073e9SAndroid Build Coastguard Worker do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, 833*da0073e9SAndroid Build Coastguard Worker use_integral=use_integral, use_complex=use_complex), (0, 2)) 834*da0073e9SAndroid Build Coastguard Worker do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, 835*da0073e9SAndroid Build Coastguard Worker use_integral=use_integral, use_complex=use_complex), (0, 2, 1)) 836*da0073e9SAndroid Build Coastguard Worker 837*da0073e9SAndroid Build Coastguard Worker @slowTest 838*da0073e9SAndroid Build Coastguard Worker @onlyCPU 839*da0073e9SAndroid Build Coastguard Worker def test_sum_dim(self, device): 840*da0073e9SAndroid Build Coastguard Worker self._test_dim_ops( 841*da0073e9SAndroid Build Coastguard Worker lambda t, d: t.sum(d), 842*da0073e9SAndroid Build Coastguard Worker lambda n, d: n.sum(d), 843*da0073e9SAndroid Build Coastguard Worker use_floating=True, use_integral=True, use_complex=True) 844*da0073e9SAndroid Build Coastguard Worker 845*da0073e9SAndroid Build Coastguard Worker @onlyCPU 846*da0073e9SAndroid Build Coastguard Worker def test_mean_dim(self, device): 847*da0073e9SAndroid Build Coastguard Worker self._test_dim_ops( 848*da0073e9SAndroid Build 
Coastguard Worker lambda t, d: t.mean(d), 849*da0073e9SAndroid Build Coastguard Worker lambda n, d: n.mean(d), 850*da0073e9SAndroid Build Coastguard Worker use_integral=False, 851*da0073e9SAndroid Build Coastguard Worker use_complex=True) 852*da0073e9SAndroid Build Coastguard Worker 853*da0073e9SAndroid Build Coastguard Worker @onlyCPU 854*da0073e9SAndroid Build Coastguard Worker def test_std_dim(self, device): 855*da0073e9SAndroid Build Coastguard Worker for unbiased in [False, True]: 856*da0073e9SAndroid Build Coastguard Worker self._test_dim_ops( 857*da0073e9SAndroid Build Coastguard Worker lambda t, d: t.std(d, unbiased=unbiased), 858*da0073e9SAndroid Build Coastguard Worker lambda n, d: n.std(d, ddof=1 if unbiased else 0), 859*da0073e9SAndroid Build Coastguard Worker use_integral=False) 860*da0073e9SAndroid Build Coastguard Worker 861*da0073e9SAndroid Build Coastguard Worker @onlyCPU 862*da0073e9SAndroid Build Coastguard Worker def test_var_dim(self, device): 863*da0073e9SAndroid Build Coastguard Worker for unbiased in [False, True]: 864*da0073e9SAndroid Build Coastguard Worker self._test_dim_ops( 865*da0073e9SAndroid Build Coastguard Worker lambda t, d: t.var(d, unbiased=unbiased), 866*da0073e9SAndroid Build Coastguard Worker lambda n, d: n.var(d, ddof=1 if unbiased else 0), 867*da0073e9SAndroid Build Coastguard Worker use_integral=False) 868*da0073e9SAndroid Build Coastguard Worker 869*da0073e9SAndroid Build Coastguard Worker @onlyCPU 870*da0073e9SAndroid Build Coastguard Worker @skipIfNoSciPy 871*da0073e9SAndroid Build Coastguard Worker def test_logsumexp_dim(self, device): 872*da0073e9SAndroid Build Coastguard Worker from scipy.special import logsumexp 873*da0073e9SAndroid Build Coastguard Worker self._test_dim_ops( 874*da0073e9SAndroid Build Coastguard Worker lambda t, d: t.logsumexp(d), 875*da0073e9SAndroid Build Coastguard Worker lambda n, d: logsumexp(n, d), 876*da0073e9SAndroid Build Coastguard Worker use_integral=False) 877*da0073e9SAndroid Build 
Coastguard Worker 878*da0073e9SAndroid Build Coastguard Worker @onlyCPU 879*da0073e9SAndroid Build Coastguard Worker def test_mean_int_with_optdtype(self, device): 880*da0073e9SAndroid Build Coastguard Worker a = make_tensor((3, 4, 5), dtype=torch.int64, device=device) 881*da0073e9SAndroid Build Coastguard Worker 882*da0073e9SAndroid Build Coastguard Worker # If the optional desired output type is given, the input 883*da0073e9SAndroid Build Coastguard Worker # is internally cast. 884*da0073e9SAndroid Build Coastguard Worker a_float = a.to(torch.float32) 885*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a_float.mean(), a.mean(dtype=torch.float32)) 886*da0073e9SAndroid Build Coastguard Worker 887*da0073e9SAndroid Build Coastguard Worker # TODO: update this and tests that use it to handle device properly 888*da0073e9SAndroid Build Coastguard Worker def _test_reduce_integer_upcast(self, fn, has_out=True, test_complex=True): 889*da0073e9SAndroid Build Coastguard Worker shape = (3, 4, 5) 890*da0073e9SAndroid Build Coastguard Worker reduced_shape = fn(torch.ones(shape)).shape 891*da0073e9SAndroid Build Coastguard Worker 892*da0073e9SAndroid Build Coastguard Worker def _test_out(dtype, other_dtype): 893*da0073e9SAndroid Build Coastguard Worker out = torch.ones(reduced_shape, dtype=dtype) 894*da0073e9SAndroid Build Coastguard Worker result = fn(x, out=out) 895*da0073e9SAndroid Build Coastguard Worker self.assertIs(out.dtype, result.dtype) 896*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x.to(dtype)), result, exact_dtype=False) 897*da0073e9SAndroid Build Coastguard Worker result = fn(x, out=out, dtype=dtype) 898*da0073e9SAndroid Build Coastguard Worker self.assertIs(out.dtype, result.dtype) 899*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x.to(dtype)), result, exact_dtype=False) 900*da0073e9SAndroid Build Coastguard Worker # 'out' is favored over dtype, check error 901*da0073e9SAndroid Build Coastguard Worker 
self.assertRaises(RuntimeError, lambda: fn(x, out=out, dtype=other_dtype)) 902*da0073e9SAndroid Build Coastguard Worker 903*da0073e9SAndroid Build Coastguard Worker for dtype in [dtype for dtype in get_all_math_dtypes('cpu') if dtype != torch.float16]: 904*da0073e9SAndroid Build Coastguard Worker x = torch.ones(shape, dtype=dtype) 905*da0073e9SAndroid Build Coastguard Worker expected_dtype = dtype if dtype.is_floating_point or dtype.is_complex else torch.int64 906*da0073e9SAndroid Build Coastguard Worker self.assertIs(expected_dtype, fn(x).dtype) 907*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x.to(expected_dtype)), fn(x)) 908*da0073e9SAndroid Build Coastguard Worker 909*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point: 910*da0073e9SAndroid Build Coastguard Worker other_dtype = torch.float32 if dtype == torch.float64 else torch.float64 911*da0073e9SAndroid Build Coastguard Worker elif dtype.is_complex: 912*da0073e9SAndroid Build Coastguard Worker other_dtype = torch.complex64 if dtype == torch.complex128 else torch.complex128 913*da0073e9SAndroid Build Coastguard Worker else: 914*da0073e9SAndroid Build Coastguard Worker other_dtype = torch.int32 if dtype != torch.int32 else torch.int16 915*da0073e9SAndroid Build Coastguard Worker self.assertIs(other_dtype, fn(x, dtype=other_dtype).dtype) 916*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x.to(other_dtype)), fn(x, dtype=other_dtype), exact_dtype=False) 917*da0073e9SAndroid Build Coastguard Worker 918*da0073e9SAndroid Build Coastguard Worker # test mixed int/float/complex 919*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point: 920*da0073e9SAndroid Build Coastguard Worker mixed_dtypes = [torch.int32, torch.complex64] 921*da0073e9SAndroid Build Coastguard Worker elif dtype.is_complex: 922*da0073e9SAndroid Build Coastguard Worker mixed_dtypes = [torch.int32, torch.float32] 923*da0073e9SAndroid Build Coastguard Worker else: 924*da0073e9SAndroid Build 
Coastguard Worker mixed_dtypes = [torch.float32, torch.complex64] 925*da0073e9SAndroid Build Coastguard Worker 926*da0073e9SAndroid Build Coastguard Worker for mixed_dtype in mixed_dtypes: 927*da0073e9SAndroid Build Coastguard Worker self.assertIs(mixed_dtype, fn(x, dtype=mixed_dtype).dtype) 928*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x.to(mixed_dtype)), fn(x, dtype=mixed_dtype), exact_dtype=False) 929*da0073e9SAndroid Build Coastguard Worker 930*da0073e9SAndroid Build Coastguard Worker if has_out: 931*da0073e9SAndroid Build Coastguard Worker _test_out(dtype, other_dtype) 932*da0073e9SAndroid Build Coastguard Worker _test_out(dtype, mixed_dtype) 933*da0073e9SAndroid Build Coastguard Worker 934*da0073e9SAndroid Build Coastguard Worker @onlyCPU 935*da0073e9SAndroid Build Coastguard Worker def test_sum_integer_upcast(self, device): 936*da0073e9SAndroid Build Coastguard Worker self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, **kwargs), False) 937*da0073e9SAndroid Build Coastguard Worker self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, 0, **kwargs)) 938*da0073e9SAndroid Build Coastguard Worker 939*da0073e9SAndroid Build Coastguard Worker @onlyCPU 940*da0073e9SAndroid Build Coastguard Worker def test_prod_integer_upcast(self, device): 941*da0073e9SAndroid Build Coastguard Worker self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, **kwargs), False) 942*da0073e9SAndroid Build Coastguard Worker self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, 0, **kwargs)) 943*da0073e9SAndroid Build Coastguard Worker 944*da0073e9SAndroid Build Coastguard Worker @onlyCPU 945*da0073e9SAndroid Build Coastguard Worker def test_cumsum_integer_upcast(self, device): 946*da0073e9SAndroid Build Coastguard Worker self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumsum(x, 0, **kwargs)) 947*da0073e9SAndroid Build Coastguard Worker 948*da0073e9SAndroid Build Coastguard Worker @onlyCPU 949*da0073e9SAndroid 
Build Coastguard Worker def test_cumprod_integer_upcast(self, device): 950*da0073e9SAndroid Build Coastguard Worker self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumprod(x, 0, **kwargs)) 951*da0073e9SAndroid Build Coastguard Worker 952*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types()) 953*da0073e9SAndroid Build Coastguard Worker def test_mode(self, device, dtype): 954*da0073e9SAndroid Build Coastguard Worker SIZE = 10 955*da0073e9SAndroid Build Coastguard Worker x = torch.arange(1., SIZE * SIZE + 1, device=device, dtype=dtype).clone().resize_(SIZE, SIZE) 956*da0073e9SAndroid Build Coastguard Worker x[:2] = 1 957*da0073e9SAndroid Build Coastguard Worker x[:, :2] = 1 958*da0073e9SAndroid Build Coastguard Worker x0 = x.clone() 959*da0073e9SAndroid Build Coastguard Worker 960*da0073e9SAndroid Build Coastguard Worker # Pre-calculated results. 961*da0073e9SAndroid Build Coastguard Worker res1val = torch.ones(SIZE, device=device, dtype=dtype) 962*da0073e9SAndroid Build Coastguard Worker # The indices are the position of the last appearance of the mode element. 
963*da0073e9SAndroid Build Coastguard Worker res1ind = torch.ones(SIZE, device=device, dtype=torch.long) 964*da0073e9SAndroid Build Coastguard Worker res1ind[0] = SIZE - 1 965*da0073e9SAndroid Build Coastguard Worker res1ind[1] = SIZE - 1 966*da0073e9SAndroid Build Coastguard Worker 967*da0073e9SAndroid Build Coastguard Worker res2val, res2ind = torch.mode(x, keepdim=False) 968*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1val, res2val, atol=0, rtol=0) 969*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1ind, res2ind, atol=0, rtol=0) 970*da0073e9SAndroid Build Coastguard Worker 971*da0073e9SAndroid Build Coastguard Worker # Test use of result tensor 972*da0073e9SAndroid Build Coastguard Worker res2val = torch.tensor((), device=device, dtype=dtype) 973*da0073e9SAndroid Build Coastguard Worker res2ind = torch.tensor((), device=device, dtype=torch.long) 974*da0073e9SAndroid Build Coastguard Worker torch.mode(x, keepdim=False, out=(res2val, res2ind)) 975*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1val, res2val, atol=0, rtol=0) 976*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1ind, res2ind, atol=0, rtol=0) 977*da0073e9SAndroid Build Coastguard Worker 978*da0073e9SAndroid Build Coastguard Worker # Test non-default dim 979*da0073e9SAndroid Build Coastguard Worker res2val, res2ind = torch.mode(x, 0, False) 980*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1val, res2val, atol=0, rtol=0) 981*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1ind, res2ind, atol=0, rtol=0) 982*da0073e9SAndroid Build Coastguard Worker 983*da0073e9SAndroid Build Coastguard Worker # input unchanged 984*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x0, atol=0, rtol=0) 985*da0073e9SAndroid Build Coastguard Worker 986*da0073e9SAndroid Build Coastguard Worker def _test_mode_intervals(self, shape, intervals, device, dtype, v=1): 987*da0073e9SAndroid Build Coastguard Worker x = torch.arange(0, 
shape[1], device=device, dtype=dtype).expand(shape) 988*da0073e9SAndroid Build Coastguard Worker x = x.contiguous() 989*da0073e9SAndroid Build Coastguard Worker x[:, v] = intervals[0][0] 990*da0073e9SAndroid Build Coastguard Worker 991*da0073e9SAndroid Build Coastguard Worker # Set the value of each interval to the mode "v" 992*da0073e9SAndroid Build Coastguard Worker for (beg, end) in intervals: 993*da0073e9SAndroid Build Coastguard Worker x[:, beg:end] = v 994*da0073e9SAndroid Build Coastguard Worker 995*da0073e9SAndroid Build Coastguard Worker values, indices = torch.mode(x, -1, False) 996*da0073e9SAndroid Build Coastguard Worker 997*da0073e9SAndroid Build Coastguard Worker # Check whether the returned indices correspond to the returned values 998*da0073e9SAndroid Build Coastguard Worker self.assertTrue((x.gather(1, indices.unsqueeze(1)).t() == values).all()) 999*da0073e9SAndroid Build Coastguard Worker # Check whether the returned values are the mode 1000*da0073e9SAndroid Build Coastguard Worker self.assertTrue((values == v).all().item()) 1001*da0073e9SAndroid Build Coastguard Worker 1002*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1003*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and(torch.half, torch.bfloat16)) 1004*da0073e9SAndroid Build Coastguard Worker def test_mode_large(self, device, dtype): 1005*da0073e9SAndroid Build Coastguard Worker # i should be less than (d - 2) / 2 1006*da0073e9SAndroid Build Coastguard Worker def testset_for_shape(shape, i): 1007*da0073e9SAndroid Build Coastguard Worker d = shape[-1] 1008*da0073e9SAndroid Build Coastguard Worker # Mode only in the middle. 1009*da0073e9SAndroid Build Coastguard Worker self._test_mode_intervals(shape, [(i, d - i)], device, dtype) 1010*da0073e9SAndroid Build Coastguard Worker # Mode in discontiguous parts of the input. 
1011*da0073e9SAndroid Build Coastguard Worker self._test_mode_intervals(shape, [(0, i), (i + 1, d - i - 1), (d - i, d)], device, dtype) 1012*da0073e9SAndroid Build Coastguard Worker 1013*da0073e9SAndroid Build Coastguard Worker # More than one line of (65535) thread blocks 1014*da0073e9SAndroid Build Coastguard Worker testset_for_shape((65536, 10), 3) 1015*da0073e9SAndroid Build Coastguard Worker 1016*da0073e9SAndroid Build Coastguard Worker # Max slice size (2048) 1017*da0073e9SAndroid Build Coastguard Worker testset_for_shape((10, 2048), 10) 1018*da0073e9SAndroid Build Coastguard Worker 1019*da0073e9SAndroid Build Coastguard Worker # Naive kernel for big slice sizes (> 2048) 1020*da0073e9SAndroid Build Coastguard Worker testset_for_shape((10, 4096), 10) 1021*da0073e9SAndroid Build Coastguard Worker 1022*da0073e9SAndroid Build Coastguard Worker def test_mode_boolean(self, device): 1023*da0073e9SAndroid Build Coastguard Worker shapes = [ 1024*da0073e9SAndroid Build Coastguard Worker (10, 10), 1025*da0073e9SAndroid Build Coastguard Worker (4, 2048), 1026*da0073e9SAndroid Build Coastguard Worker (1, 4096), 1027*da0073e9SAndroid Build Coastguard Worker ] 1028*da0073e9SAndroid Build Coastguard Worker 1029*da0073e9SAndroid Build Coastguard Worker for shape in shapes: 1030*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(shape, device=device, dtype=torch.bool) 1031*da0073e9SAndroid Build Coastguard Worker 1032*da0073e9SAndroid Build Coastguard Worker a[:, (shape[1] - 1) // 2:] = True 1033*da0073e9SAndroid Build Coastguard Worker values, indices = a.mode(-1) 1034*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, torch.ones(shape[0], dtype=torch.bool)) 1035*da0073e9SAndroid Build Coastguard Worker print(indices) 1036*da0073e9SAndroid Build Coastguard Worker indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1) 1037*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, indexed) 1038*da0073e9SAndroid Build Coastguard Worker 
1039*da0073e9SAndroid Build Coastguard Worker a.fill_(False) 1040*da0073e9SAndroid Build Coastguard Worker a[:, shape[1] // 2 + 1:] = True 1041*da0073e9SAndroid Build Coastguard Worker values, indices = a.mode(-1) 1042*da0073e9SAndroid Build Coastguard Worker print(indices) 1043*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, torch.zeros(shape[0], dtype=torch.bool)) 1044*da0073e9SAndroid Build Coastguard Worker indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1) 1045*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values, indexed) 1046*da0073e9SAndroid Build Coastguard Worker 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker @expectedFailureMeta # mode only supports CPU and CUDA device type 1049*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1050*da0073e9SAndroid Build Coastguard Worker def test_mode_wrong_dtype(self, device): 1051*da0073e9SAndroid Build Coastguard Worker def test_for_dtypes(x_ty, v_ty, i_ty, message): 1052*da0073e9SAndroid Build Coastguard Worker x = torch.ones(10, device=device, dtype=x_ty) 1053*da0073e9SAndroid Build Coastguard Worker v = torch.ones(10, device=device, dtype=v_ty) 1054*da0073e9SAndroid Build Coastguard Worker i = torch.ones(10, device=device, dtype=i_ty) 1055*da0073e9SAndroid Build Coastguard Worker 1056*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, message): 1057*da0073e9SAndroid Build Coastguard Worker torch.mode(x, -1, True, out=(v, i)) 1058*da0073e9SAndroid Build Coastguard Worker 1059*da0073e9SAndroid Build Coastguard Worker err_msg = "expected scalar type .* but got .* for " 1060*da0073e9SAndroid Build Coastguard Worker values_err = err_msg + "values" 1061*da0073e9SAndroid Build Coastguard Worker indices_err = err_msg + "indices" 1062*da0073e9SAndroid Build Coastguard Worker 1063*da0073e9SAndroid Build Coastguard Worker test_for_dtypes(torch.uint8, torch.int8, torch.long, values_err) 
1064*da0073e9SAndroid Build Coastguard Worker test_for_dtypes(torch.int8, torch.int16, torch.long, values_err) 1065*da0073e9SAndroid Build Coastguard Worker test_for_dtypes(torch.int32, torch.float32, torch.long, values_err) 1066*da0073e9SAndroid Build Coastguard Worker test_for_dtypes(torch.float32, torch.float64, torch.long, values_err) 1067*da0073e9SAndroid Build Coastguard Worker 1068*da0073e9SAndroid Build Coastguard Worker test_for_dtypes(torch.uint8, torch.uint8, torch.int8, indices_err) 1069*da0073e9SAndroid Build Coastguard Worker test_for_dtypes(torch.int8, torch.int8, torch.int16, indices_err) 1070*da0073e9SAndroid Build Coastguard Worker test_for_dtypes(torch.int32, torch.int32, torch.float32, indices_err) 1071*da0073e9SAndroid Build Coastguard Worker test_for_dtypes(torch.float32, torch.float32, torch.float64, indices_err) 1072*da0073e9SAndroid Build Coastguard Worker 1073*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1074*da0073e9SAndroid Build Coastguard Worker def test_mode_wrong_device(self, device): 1075*da0073e9SAndroid Build Coastguard Worker # CPU Input Tensor 1076*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2) 1077*da0073e9SAndroid Build Coastguard Worker 1078*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 1079*da0073e9SAndroid Build Coastguard Worker "expected device .* but got .* for values"): 1080*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([], device=device) 1081*da0073e9SAndroid Build Coastguard Worker torch.mode(x, -1, True, out=(values, torch.tensor([], dtype=torch.long))) 1082*da0073e9SAndroid Build Coastguard Worker 1083*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 1084*da0073e9SAndroid Build Coastguard Worker "expected device .* but got .* for indices"): 1085*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor([], device=device) 1086*da0073e9SAndroid Build Coastguard Worker torch.mode(x, -1, True, 
out=(torch.tensor([]), indices)) 1087*da0073e9SAndroid Build Coastguard Worker 1088*da0073e9SAndroid Build Coastguard Worker # TODO: make work on CUDA, too 1089*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1090*da0073e9SAndroid Build Coastguard Worker def test_accreal_type(self, device) -> None: 1091*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 3, 4) 1092*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x.double().sum().item(), float) 1093*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x.float().sum().item(), float) 1094*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x.long().sum().item(), int) 1095*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x.int().sum().item(), int) 1096*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x.short().sum().item(), int) 1097*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x.char().sum().item(), int) 1098*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x.byte().sum().item(), int) 1099*da0073e9SAndroid Build Coastguard Worker 1100*da0073e9SAndroid Build Coastguard Worker def test_var_mean_some_dims(self, device): 1101*da0073e9SAndroid Build Coastguard Worker sizes = (4, 6, 7, 5, 3) 1102*da0073e9SAndroid Build Coastguard Worker dims = len(sizes) 1103*da0073e9SAndroid Build Coastguard Worker 1104*da0073e9SAndroid Build Coastguard Worker x = torch.rand(sizes, device=device) 1105*da0073e9SAndroid Build Coastguard Worker for num_of_dims in range(2, dims): 1106*da0073e9SAndroid Build Coastguard Worker dim_list = list(combinations(list(range(dims)), r=num_of_dims)) 1107*da0073e9SAndroid Build Coastguard Worker for dim in dim_list: 1108*da0073e9SAndroid Build Coastguard Worker for unbiased in [False, True]: 1109*da0073e9SAndroid Build Coastguard Worker for keepdim in [False, True]: 1110*da0073e9SAndroid Build Coastguard Worker var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) 
1111*da0073e9SAndroid Build Coastguard Worker var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim) 1112*da0073e9SAndroid Build Coastguard Worker mean2 = x.mean(dim=dim, keepdim=keepdim) 1113*da0073e9SAndroid Build Coastguard Worker self.assertEqual(var1, var2) 1114*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mean1, mean2) 1115*da0073e9SAndroid Build Coastguard Worker 1116*da0073e9SAndroid Build Coastguard Worker # TODO: this should be a generic opinfo test 1117*da0073e9SAndroid Build Coastguard Worker def test_all_any_empty(self, device): 1118*da0073e9SAndroid Build Coastguard Worker x = torch.ByteTensor().to(device) 1119*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.all()) 1120*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x.any()) 1121*da0073e9SAndroid Build Coastguard Worker 1122*da0073e9SAndroid Build Coastguard Worker x = torch.BoolTensor().to(device) 1123*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.all()) 1124*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x.any()) 1125*da0073e9SAndroid Build Coastguard Worker 1126*da0073e9SAndroid Build Coastguard Worker def test_all_issue117215(self, device): 1127*da0073e9SAndroid Build Coastguard Worker info = torch.iinfo(torch.uint8) 1128*da0073e9SAndroid Build Coastguard Worker a = torch.randint(info.min, info.max, (73, 11, 3, 17), dtype=torch.uint8) 1129*da0073e9SAndroid Build Coastguard Worker b = torch.all(a, dim=0) 1130*da0073e9SAndroid Build Coastguard Worker c = a.to(torch.bool).all(dim=0) 1131*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ne(b, c).sum(), 0) 1132*da0073e9SAndroid Build Coastguard Worker 1133*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double) 1134*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.bfloat16, torch.float, torch.double) 1135*da0073e9SAndroid Build Coastguard Worker def test_max_with_inf(self, device, dtype): 
1136*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device) 1137*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.all(torch.max(a, dim=1).values == inf).item()) 1138*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.all(torch.amax(a, dim=1) == inf).item()) 1139*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.max(a).item() == inf) 1140*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.amax(a).item() == inf) 1141*da0073e9SAndroid Build Coastguard Worker 1142*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double) 1143*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.float, torch.bfloat16, torch.double) 1144*da0073e9SAndroid Build Coastguard Worker def test_min_with_inf(self, device, dtype): 1145*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device) 1146*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.all(torch.min(a, dim=1).values == (-inf)).item()) 1147*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.all(torch.amin(a, dim=1) == (-inf)).item()) 1148*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.min(a).item() == -inf) 1149*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.amin(a).item() == -inf) 1150*da0073e9SAndroid Build Coastguard Worker 1151*da0073e9SAndroid Build Coastguard Worker def _test_minmax_helper(self, torchfn, reffn, device, dtype, skip_indices=False): 1152*da0073e9SAndroid Build Coastguard Worker def create_input(shape, device, dtype): 1153*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point: 1154*da0073e9SAndroid Build Coastguard Worker return torch.randn(*shape, device=device, dtype=dtype) 1155*da0073e9SAndroid Build Coastguard Worker else: 1156*da0073e9SAndroid Build Coastguard Worker low = 0 if 
dtype == torch.bool else -1000 1157*da0073e9SAndroid Build Coastguard Worker high = 2 if dtype == torch.bool else 1000 1158*da0073e9SAndroid Build Coastguard Worker return torch.randint(low, high, shape, device=device, dtype=dtype) 1159*da0073e9SAndroid Build Coastguard Worker x = create_input((100, 100), device, dtype) 1160*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torchfn, reffn, x) 1161*da0073e9SAndroid Build Coastguard Worker # non contiguous 1162*da0073e9SAndroid Build Coastguard Worker x = create_input((10, 10, 10), device, dtype) 1163*da0073e9SAndroid Build Coastguard Worker x = x[:, 4] 1164*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torchfn, reffn, x) 1165*da0073e9SAndroid Build Coastguard Worker 1166*da0073e9SAndroid Build Coastguard Worker def get_values(x): 1167*da0073e9SAndroid Build Coastguard Worker if isinstance(x, tuple): 1168*da0073e9SAndroid Build Coastguard Worker return x[0] 1169*da0073e9SAndroid Build Coastguard Worker return x 1170*da0073e9SAndroid Build Coastguard Worker 1171*da0073e9SAndroid Build Coastguard Worker # indices 1172*da0073e9SAndroid Build Coastguard Worker if not skip_indices: 1173*da0073e9SAndroid Build Coastguard Worker size = 5 1174*da0073e9SAndroid Build Coastguard Worker x = create_input((size, size), device, dtype) 1175*da0073e9SAndroid Build Coastguard Worker inputs = (x, x.t()) 1176*da0073e9SAndroid Build Coastguard Worker dims = (0, 1) 1177*da0073e9SAndroid Build Coastguard Worker for xinp, d in product(inputs, dims): 1178*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(lambda x: get_values(torchfn(x, d, False)), lambda x: reffn(x, d, keepdims=False), xinp) 1179*da0073e9SAndroid Build Coastguard Worker result = torchfn(xinp, d, False) 1180*da0073e9SAndroid Build Coastguard Worker if isinstance(result, tuple): 1181*da0073e9SAndroid Build Coastguard Worker v, i = result 1182*da0073e9SAndroid Build Coastguard Worker if d == 1: 1183*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(xinp[torch.arange(size), i], v, atol=0, rtol=0) 1184*da0073e9SAndroid Build Coastguard Worker else: 1185*da0073e9SAndroid Build Coastguard Worker self.assertEqual(xinp[i, torch.arange(size)], v, atol=0, rtol=0) 1186*da0073e9SAndroid Build Coastguard Worker # nan 1187*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point: 1188*da0073e9SAndroid Build Coastguard Worker for index in (0, 4, 99): 1189*da0073e9SAndroid Build Coastguard Worker x = create_input((100,), device, dtype) 1190*da0073e9SAndroid Build Coastguard Worker x[index] = nan 1191*da0073e9SAndroid Build Coastguard Worker if not skip_indices: 1192*da0073e9SAndroid Build Coastguard Worker result = torchfn(x, 0) 1193*da0073e9SAndroid Build Coastguard Worker v = get_values(result) 1194*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v, nan) 1195*da0073e9SAndroid Build Coastguard Worker if isinstance(result, tuple): 1196*da0073e9SAndroid Build Coastguard Worker i = result[1] 1197*da0073e9SAndroid Build Coastguard Worker self.assertEqual(i, index) 1198*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torchfn(x), nan) 1199*da0073e9SAndroid Build Coastguard Worker 1200*da0073e9SAndroid Build Coastguard Worker @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool, torch.half) 1201*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool) 1202*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.float, torch.double) 1203*da0073e9SAndroid Build Coastguard Worker def test_max(self, device, dtype): 1204*da0073e9SAndroid Build Coastguard Worker self._test_minmax_helper(torch.max, np.amax, device, dtype) 1205*da0073e9SAndroid Build Coastguard Worker 1206*da0073e9SAndroid Build Coastguard Worker @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool, torch.half) 1207*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool) 
1208*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.float, torch.double) 1209*da0073e9SAndroid Build Coastguard Worker def test_min(self, device, dtype): 1210*da0073e9SAndroid Build Coastguard Worker self._test_minmax_helper(torch.min, np.amin, device, dtype) 1211*da0073e9SAndroid Build Coastguard Worker 1212*da0073e9SAndroid Build Coastguard Worker @dtypesIfCPU(torch.half, torch.float, torch.double, torch.int, torch.long, torch.bool) 1213*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool) 1214*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.float, torch.double) 1215*da0073e9SAndroid Build Coastguard Worker def test_amin(self, device, dtype): 1216*da0073e9SAndroid Build Coastguard Worker self._test_minmax_helper(torch.amin, np.amin, device, dtype) 1217*da0073e9SAndroid Build Coastguard Worker 1218*da0073e9SAndroid Build Coastguard Worker @dtypesIfCPU(torch.half, torch.float, torch.double, torch.int, torch.long, torch.bool) 1219*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool) 1220*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 1221*da0073e9SAndroid Build Coastguard Worker def test_amax(self, device, dtype): 1222*da0073e9SAndroid Build Coastguard Worker self._test_minmax_helper(torch.amax, np.amax, device, dtype) 1223*da0073e9SAndroid Build Coastguard Worker 1224*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1225*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double, torch.bfloat16, torch.half) 1226*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.float, torch.bfloat16) 1227*da0073e9SAndroid Build Coastguard Worker def test_aminmax(self, device, dtype): 1228*da0073e9SAndroid Build Coastguard Worker 1229*da0073e9SAndroid Build Coastguard Worker def _amin_wrapper(x, dim=None, keepdims=False): 
1230*da0073e9SAndroid Build Coastguard Worker return torch.aminmax(x, dim=dim, keepdim=keepdims)[0] 1231*da0073e9SAndroid Build Coastguard Worker 1232*da0073e9SAndroid Build Coastguard Worker def _amax_wrapper(x, dim=None, keepdims=False): 1233*da0073e9SAndroid Build Coastguard Worker return torch.aminmax(x, dim=dim, keepdim=keepdims)[1] 1234*da0073e9SAndroid Build Coastguard Worker 1235*da0073e9SAndroid Build Coastguard Worker self._test_minmax_helper(_amin_wrapper, np.amin, device, dtype) 1236*da0073e9SAndroid Build Coastguard Worker self._test_minmax_helper(_amax_wrapper, np.amax, device, dtype) 1237*da0073e9SAndroid Build Coastguard Worker 1238*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1239*da0073e9SAndroid Build Coastguard Worker @dtypes(*complex_types()) 1240*da0073e9SAndroid Build Coastguard Worker def test_invalid_0dim_aminmax(self, device, dtype): 1241*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'not implemented'): 1242*da0073e9SAndroid Build Coastguard Worker torch.aminmax(torch.tensor(1., dtype=dtype, device=device), dim=0) 1243*da0073e9SAndroid Build Coastguard Worker 1244*da0073e9SAndroid Build Coastguard Worker # TODO: bincount isn't a classic reduction -- maybe this test suite is 1245*da0073e9SAndroid Build Coastguard Worker # reductions and summary ops? 
def test_bincount(self, device):
    # Exercises torch.bincount / Tensor.bincount on `device`: rejected
    # arguments first, then empty inputs, plain counting, minlength,
    # weights, non-contiguous views and finally large sizes.
    def on_device(data, **kwargs):
        # Shorthand for building a tensor directly on the device under test.
        return torch.tensor(data, device=device, **kwargs)

    # Each entry pairs the expected error-message pattern with a call that
    # has to raise RuntimeError.
    rejected_calls = (
        # negative input throws
        ('1-d non-negative integral',
         lambda: torch.bincount(on_device([1, -1]))),
        # n-d input, with n > 1 throws
        ('1-d non-negative integral',
         lambda: torch.bincount(on_device([[1, 2], [3, 4]]))),
        # floating input type throws
        ('not implemented',
         lambda: torch.bincount(on_device([1., 0.3]))),
        # minlength < 0 throws
        ('minlength should be >= 0',
         lambda: torch.bincount(on_device([1, 3]),
                                on_device([.2, .2]),
                                minlength=-1)),
        # n-d weights, with n > 1 throws
        ('1-d',
         lambda: torch.bincount(on_device([1, 0]),
                                on_device([[1., 0.3], [1., 0.3]]))),
        # input and weights dim mismatch
        ('same length',
         lambda: torch.bincount(on_device([1, 0]),
                                on_device([1., 0.3, 0.5]))),
    )
    for message_pattern, bad_call in rejected_calls:
        with self.assertRaisesRegex(RuntimeError, message_pattern):
            bad_call()

    # 1-d input with no elements: default minlength, then minlength=10
    no_elements = on_device([], dtype=torch.long)
    self.assertEqual(torch.bincount(no_elements),
                     torch.zeros(0, dtype=torch.long, device=device))
    self.assertEqual(torch.bincount(no_elements, minlength=10),
                     torch.zeros(10, dtype=torch.long, device=device))

    # tensor method without weights; uint8 input is compared against int64 counts
    counts_from_uint8 = on_device([0, 3, 2, 1, 3], dtype=torch.uint8).bincount()
    self.assertEqual(on_device([1, 1, 1, 2], dtype=torch.int64), counts_from_uint8)

    # avoiding overflow for uint8 (#76979): same values, two input dtypes
    wide_values = [0, 1, 2, 3, 255]
    self.assertEqual(on_device(wide_values, dtype=torch.uint8).bincount(),
                     on_device(wide_values, dtype=torch.int16).bincount())

    # minlength functionality
    padded_counts = torch.bincount(on_device([1, 1, 1, 1]), minlength=5)
    self.assertEqual(on_device([0, 4, 0, 0, 0], dtype=torch.int64), padded_counts)

    # weights: floating point weights first, then int8 weights whose
    # result is compared against a float64 tensor
    bin_indices = on_device([0, 1, 1, 1, 4])
    float_weighted = torch.bincount(bin_indices, on_device([.1, .2, .3, .4, .5]))
    self.assertEqual(on_device([0.1, 0.9, 0, 0, 0.5]), float_weighted)
    int8_weighted = torch.bincount(bin_indices,
                                   on_device([1, 2, 3, 4, 5], dtype=torch.int8))
    self.assertEqual(on_device([1, 9, 0, 0, 5], dtype=torch.float64), int8_weighted)

    # non-contiguous inputs and weights (column views of 2-d tensors)
    inputs = on_device([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]])
    weights = on_device([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]])
    for column in (0, 1):
        assert not inputs[:, column].is_contiguous(), "Inputs are supposed to be non-contiguous"
        assert not weights[:, column].is_contiguous(), "Weights are supposed to be non-contiguous"
    # non-contiguous inputs, no weights
    self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
    weighted_expected = torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32)
    # inputs and weights are non-contiguous
    self.assertEqual(inputs[:, 1].bincount(weights[:, 1]), weighted_expected)
    # weights are non-contiguous but inputs are contiguous
    self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]), weighted_expected)

    # bincount on non-contiguous slices of constant tensors
    for make_constant, expected_counts in ((torch.zeros, [32]), (torch.ones, [0, 32])):
        constant = make_constant((32, 2), dtype=torch.int64, device=device)
        self.assertEqual(constant[:, 0].bincount(), torch.tensor(expected_counts))

    # large number of bins - global memory use
    expected_many_bins = torch.zeros(10000000, device=device)
    expected_many_bins[-1] = 50.0
    half_weights = on_device([.5] * 100)
    self.assertEqual(expected_many_bins,
                     on_device([9999999] * 100).bincount(half_weights))

    # large input size
    expected_many_items = torch.zeros(2, device=device, dtype=torch.int64)
    expected_many_items[1] = 1000000
    self.assertEqual(expected_many_items,
                     torch.ones(1000000, dtype=torch.int8, device=device).bincount())

# TODO: how many var stability tests are there?
def test_var_stability2(self, device):
    # Variance of two close float values, checked through three call forms.
    pair = torch.FloatTensor([2281.5, 2281.25]).to(device)
    expected_var = 0.03125

    # Stability for inner dim
    self.assertEqual(pair.var(0), expected_var)

    # General stability
    self.assertEqual(pair.var(), expected_var)

    # Stability for outer dimensions
    self.assertEqual(pair.unsqueeze(1).var(0), expected_var)

@onlyCPU
@dtypes(torch.bfloat16, torch.float16)
def test_sum_noncontig_lowp(self, device, dtype) -> None:
    # Sums strided, permuted low-precision tensors over every combination
    # of dims and compares each result with the same reduction performed
    # on a .float() copy and cast back to `dtype`.
    def check_all_permutations(shape, reduce_dims):
        rank = len(shape)
        # Take every other element along each dim so the input is strided.
        every_other = (slice(None, None, 2),) * rank
        for dim_order in permutations(range(rank), rank):
            strided = torch.ones(shape, device=device, dtype=dtype)[every_other]
            permuted = strided.permute(dim_order)
            lowp_sum = torch.sum(permuted, reduce_dims)
            reference_sum = torch.sum(permuted.float(), reduce_dims)
            self.assertEqual(lowp_sum, reference_sum.to(dtype=dtype))

    shapes = (
        (50, 50),
        (50, 50, 50),
        (10, 50, 30, 30),
        (10, 5, 10, 50, 7),
    )

    for shape in shapes:
        rank = len(shape)
        # Every non-empty subset of dims, grouped by subset size.
        for num_reduced in range(1, rank + 1):
            for reduce_dims in combinations(range(rank), num_reduced):
                check_all_permutations(shape, reduce_dims)


@onlyCPU
@dtypes(torch.bool, torch.double)
def test_sum_all(self, device, dtype) -> None:
    # Compares Tensor.sum() over all elements with Python's built-in sum()
    # applied to the flattened element list.
    if dtype == torch.bool:
        samples = [torch.tensor([True, False, True], dtype=torch.bool, device=device)]
    else:
        samples = [
            torch.tensor([1, 2, 3, 4, 5], dtype=dtype, device=device),
            torch.randn(200000, dtype=dtype, device=device),
            # first column of a 2-d tensor
            torch.randn(2000, 2, dtype=dtype, device=device)[:, 0],
        ]

    for sample in samples:
        flat_values = sample.reshape(-1).tolist()
        self.assertEqual(sample.sum(), sum(flat_values))

def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
                                        memory_format, compare_data=True, default_is_preserve=False):
    # Shared checker for a transformation that accepts `memory_format=`.
    #   input_generator_fn(device) -> input tensor (presumably already laid
    #       out in `memory_format`; that is up to the caller)
    #   transformation_fn(tensor, memory_format=...) -> transformed tensor
    #   memory_format: torch.channels_last or torch.channels_last_3d
    #   compare_data: also compare values of input and output
    #   default_is_preserve: whether calling without memory_format= is
    #       expected to keep the input layout instead of producing a
    #       default-contiguous result

    assert memory_format in (torch.channels_last, torch.channels_last_3d)

    def expect_contiguity(tensor, plain, in_format):
        # `plain` / `in_format` are the expected answers of is_contiguous()
        # without arguments and with memory_format=memory_format.
        check_plain = self.assertTrue if plain else self.assertFalse
        check_format = self.assertTrue if in_format else self.assertFalse
        check_plain(tensor.is_contiguous())
        check_format(tensor.is_contiguous(memory_format=memory_format))

    def compare_values(source, result):
        if compare_data:
            self.assertEqual(source, result.to(source))

    # preserve_format on an input that is not memory dense: slice the
    # trailing dims (two for channels_last, three otherwise) with step 2.
    num_sliced = 2 if memory_format == torch.channels_last else 3
    sliced = input_generator_fn(device)[(Ellipsis,) + (slice(None, None, 2),) * num_sliced]
    preserved = transformation_fn(sliced, memory_format=torch.preserve_format)
    expect_contiguity(preserved, plain=False, in_format=True)
    expect_contiguity(sliced, plain=False, in_format=False)
    compare_values(sliced, preserved)

    # contiguous_format requested explicitly
    fresh = input_generator_fn(device)
    made_contiguous = transformation_fn(fresh, memory_format=torch.contiguous_format)
    expect_contiguity(made_contiguous, plain=True, in_format=False)
    compare_values(fresh, made_contiguous)

    # no memory_format argument: expectation depends on default_is_preserve
    fresh = input_generator_fn(device)
    by_default = transformation_fn(fresh)
    if not default_is_preserve:
        expect_contiguity(by_default, plain=True, in_format=False)
    else:
        expect_contiguity(by_default, plain=False, in_format=True)
    compare_values(fresh, by_default)

    # preserve_format must reproduce the strides of repeatedly permuted input
    x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device)
    for _ in range(10):
        dim_order = list(range(x.dim()))
        random.shuffle(dim_order)
        x = x.permute(dim_order)
        restrided = transformation_fn(x, memory_format=torch.preserve_format)
        self.assertEqual(x.stride(), restrided.stride())

@onlyCPU
@dtypes(torch.double)
def test_sum_out(self, device, dtype: torch.dtype) -> None:
    # torch.sum(..., out=...) must agree with the value-returning forms.
    # Each case: (input shape, dims given to the out= call, reference).
    cases = (
        ((100, 100), 1, lambda t: torch.sum(t, 1)),
        ((100, 100, 100), (2, 1), lambda t: t.sum(2).sum(1)),
    )
    for shape, dims, reference in cases:
        source = torch.rand(*shape, dtype=dtype, device=device)
        expected = reference(source)
        # The out tensor starts with no elements and has to hold the result.
        target = torch.tensor((), dtype=dtype, device=device)
        torch.sum(source, dims, out=target)
        self.assertEqual(expected, target)

@onlyCUDA
@dtypes(torch.float16, torch.float32)
def test_prod_gpu(self, device, dtype):
    factors = torch.tensor([2, 3, 6, 9, 8], dtype=dtype, device=device)

    # Check all combinations: fp16 input - fp16 output, fp16 input - fp32
    # output, fp32 input - fp16 output, fp32 input - fp32 output
    for out_dtype in (torch.float16, torch.float32):
        expected = torch.tensor(2592, dtype=out_dtype, device=device)
        # function form, then method form
        self.assertEqual(torch.prod(factors, dtype=out_dtype), expected)
        self.assertEqual(factors.prod(dtype=out_dtype), expected)
Build Coastguard Worker 1486*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1487*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 1488*da0073e9SAndroid Build Coastguard Worker def test_prod(self, device, dtype): 1489*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, 100, dtype=dtype, device=device) 1490*da0073e9SAndroid Build Coastguard Worker res1 = torch.prod(x, 1) 1491*da0073e9SAndroid Build Coastguard Worker res2 = torch.tensor((), dtype=dtype, device=device) 1492*da0073e9SAndroid Build Coastguard Worker torch.prod(x, 1, out=res2) 1493*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 1494*da0073e9SAndroid Build Coastguard Worker 1495*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1496*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float16, torch.bfloat16) 1497*da0073e9SAndroid Build Coastguard Worker def test_prod_lowp(self, device, dtype): 1498*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, 100, dtype=dtype, device=device) 1499*da0073e9SAndroid Build Coastguard Worker x_ref = x.float() 1500*da0073e9SAndroid Build Coastguard Worker res1 = torch.prod(x, 1) 1501*da0073e9SAndroid Build Coastguard Worker res2 = torch.prod(x_ref, 1) 1502*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2.to(dtype=dtype)) 1503*da0073e9SAndroid Build Coastguard Worker res1 = torch.prod(x, 0) 1504*da0073e9SAndroid Build Coastguard Worker res2 = torch.prod(x_ref, 0) 1505*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2.to(dtype=dtype)) 1506*da0073e9SAndroid Build Coastguard Worker 1507*da0073e9SAndroid Build Coastguard Worker def test_prod_bool(self, device): 1508*da0073e9SAndroid Build Coastguard Worker vals = [ 1509*da0073e9SAndroid Build Coastguard Worker [True, True], 1510*da0073e9SAndroid Build Coastguard Worker [True, False], 1511*da0073e9SAndroid Build Coastguard Worker [False, False], 1512*da0073e9SAndroid Build Coastguard Worker [], 1513*da0073e9SAndroid Build 
Coastguard Worker [False] * 256, # https://github.com/pytorch/pytorch/issues/127866 1514*da0073e9SAndroid Build Coastguard Worker ] 1515*da0073e9SAndroid Build Coastguard Worker for val in vals: 1516*da0073e9SAndroid Build Coastguard Worker result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item() 1517*da0073e9SAndroid Build Coastguard Worker expect = np.prod(np.array(val), dtype=bool) 1518*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expect) 1519*da0073e9SAndroid Build Coastguard Worker 1520*da0073e9SAndroid Build Coastguard Worker result = torch.prod(torch.tensor(val, device=device)).item() 1521*da0073e9SAndroid Build Coastguard Worker expect = np.prod(np.array(val)) 1522*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expect) 1523*da0073e9SAndroid Build Coastguard Worker 1524*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1525*da0073e9SAndroid Build Coastguard Worker def test_max_mixed_devices(self, device): 1526*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10, device=device) 1527*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 1528*da0073e9SAndroid Build Coastguard Worker values = torch.randn(10).cuda() 1529*da0073e9SAndroid Build Coastguard Worker indices = torch.cuda.LongTensor() 1530*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, 1531*da0073e9SAndroid Build Coastguard Worker lambda: torch.max(a, 0, out=(values, indices))) 1532*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, 1533*da0073e9SAndroid Build Coastguard Worker lambda: torch.amax(a, 0, out=values)) 1534*da0073e9SAndroid Build Coastguard Worker 1535*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1536*da0073e9SAndroid Build Coastguard Worker def test_min_mixed_devices(self, device): 1537*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10, device=device) 1538*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 
1539*da0073e9SAndroid Build Coastguard Worker values = torch.randn(10).cuda() 1540*da0073e9SAndroid Build Coastguard Worker indices = torch.cuda.LongTensor() 1541*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, 1542*da0073e9SAndroid Build Coastguard Worker lambda: torch.min(a, 0, out=(values, indices))) 1543*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, 1544*da0073e9SAndroid Build Coastguard Worker lambda: torch.amin(a, 0, out=values)) 1545*da0073e9SAndroid Build Coastguard Worker 1546*da0073e9SAndroid Build Coastguard Worker # TODO: consider refactoring with bincount test 1547*da0073e9SAndroid Build Coastguard Worker def test_bucketization(self, device): 1548*da0073e9SAndroid Build Coastguard Worker values_1d = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9], device=device) 1549*da0073e9SAndroid Build Coastguard Worker values_3d = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device) 1550*da0073e9SAndroid Build Coastguard Worker 1551*da0073e9SAndroid Build Coastguard Worker # simple 1d boundary and 3d input value 1552*da0073e9SAndroid Build Coastguard Worker boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device) 1553*da0073e9SAndroid Build Coastguard Worker expected_result = torch.tensor([[[0, 2, 4], [1, 3, 5]], [[0, 1, 2], [3, 4, 5]]], device=device) 1554*da0073e9SAndroid Build Coastguard Worker output = torch.empty(2, 2, 3, device=device, dtype=torch.int64) 1555*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.bucketize(values_3d, boundaries), expected_result) 1556*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.bucketize(values_3d, boundaries, out=output), expected_result) 1557*da0073e9SAndroid Build Coastguard Worker expected_result = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device) 1558*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.bucketize(values_3d, boundaries, right=True), expected_result) 
1559*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.bucketize(values_3d, boundaries, out=output, right=True), expected_result) 1560*da0073e9SAndroid Build Coastguard Worker 1561*da0073e9SAndroid Build Coastguard Worker # simple float 1d boundary and 1d input with output int32 type 1562*da0073e9SAndroid Build Coastguard Worker for dtype in [torch.float32, torch.float16]: 1563*da0073e9SAndroid Build Coastguard Worker values_1d_float = values_1d.to(dtype) 1564*da0073e9SAndroid Build Coastguard Worker boundaries = torch.tensor([0.9, 1, 2, 2, 3, 3, 4, 4.1, 9, 9], device=device, dtype=dtype) 1565*da0073e9SAndroid Build Coastguard Worker expected_result = torch.tensor([1, 2, 4, 6, 8, 8, 8, 8, 8], device=device, dtype=torch.int32) 1566*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.searchsorted(boundaries, values_1d_float, out_int32=True), expected_result) 1567*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.bucketize(values_1d_float, boundaries, out_int32=True), expected_result) 1568*da0073e9SAndroid Build Coastguard Worker 1569*da0073e9SAndroid Build Coastguard Worker # multiple dimension input with 0 elements 1570*da0073e9SAndroid Build Coastguard Worker boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device, dtype=torch.int64) 1571*da0073e9SAndroid Build Coastguard Worker values_0_el = torch.tensor([[[]]], device=device, dtype=torch.int64) 1572*da0073e9SAndroid Build Coastguard Worker expected_result = values_0_el.to(torch.int64) 1573*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.searchsorted(boundaries, values_0_el), expected_result) 1574*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.bucketize(values_0_el, boundaries), expected_result) 1575*da0073e9SAndroid Build Coastguard Worker 1576*da0073e9SAndroid Build Coastguard Worker # nan input 1577*da0073e9SAndroid Build Coastguard Worker values_nan = torch.tensor([1.0, float('nan'), 2.0, float('nan')], device=device, 
dtype=torch.float64) 1578*da0073e9SAndroid Build Coastguard Worker boundaries = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device, dtype=torch.float64) 1579*da0073e9SAndroid Build Coastguard Worker expected_result = torch.tensor([1, 4, 2, 4], device=device) 1580*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.searchsorted(boundaries, values_nan), expected_result) 1581*da0073e9SAndroid Build Coastguard Worker expected_result = torch.tensor([2, 4, 3, 4], device=device) 1582*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.searchsorted(boundaries, values_nan, right=True), expected_result) 1583*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.searchsorted(boundaries, values_nan, side='right'), expected_result) 1584*da0073e9SAndroid Build Coastguard Worker 1585*da0073e9SAndroid Build Coastguard Worker # type promotion and non contiguous tensors 1586*da0073e9SAndroid Build Coastguard Worker values_3d_permute = values_3d.permute(2, 1, 0).to(torch.int32) 1587*da0073e9SAndroid Build Coastguard Worker boundaries_permute = values_3d.permute(2, 1, 0).to(torch.float64) 1588*da0073e9SAndroid Build Coastguard Worker expected_result = torch.tensor([[[0, 0], [0, 1]], [[2, 0], [0, 1]], [[2, 0], [0, 0]]], device=device) 1589*da0073e9SAndroid Build Coastguard Worker if self.device_type != 'xla': 1590*da0073e9SAndroid Build Coastguard Worker self.assertWarnsRegex( 1591*da0073e9SAndroid Build Coastguard Worker UserWarning, "tensor is non-contiguous", 1592*da0073e9SAndroid Build Coastguard Worker lambda: self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result)) 1593*da0073e9SAndroid Build Coastguard Worker else: 1594*da0073e9SAndroid Build Coastguard Worker # All tensors in XLA is contiguous even doing permute, no warning msg will be generate in XLA 1595*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result) 1596*da0073e9SAndroid 
Build Coastguard Worker 1597*da0073e9SAndroid Build Coastguard Worker # scalar type 1598*da0073e9SAndroid Build Coastguard Worker boundaries = torch.tensor([1.5, 2.5, 3.5], device=device) 1599*da0073e9SAndroid Build Coastguard Worker expected_result = torch.tensor(1, device=device) 1600*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.searchsorted(boundaries, 2), expected_result) 1601*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.bucketize(torch.tensor(2, device=device), boundaries), expected_result) 1602*da0073e9SAndroid Build Coastguard Worker expected_result = torch.tensor(3, device=device) 1603*da0073e9SAndroid Build Coastguard Worker scalar_tensor_nan = torch.tensor(float('nan'), device=device) 1604*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.searchsorted(boundaries, scalar_tensor_nan), expected_result) 1605*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.bucketize(float('nan'), boundaries, right=True), expected_result) 1606*da0073e9SAndroid Build Coastguard Worker 1607*da0073e9SAndroid Build Coastguard Worker # invalid input dimensions 1608*da0073e9SAndroid Build Coastguard Worker boundaries = torch.tensor([[1, 2, 3], [4, 5, 6]], device=device) 1609*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1610*da0073e9SAndroid Build Coastguard Worker RuntimeError, "first N-1 dimensions of boundaries tensor and input value tensor must match"): 1611*da0073e9SAndroid Build Coastguard Worker torch.searchsorted(boundaries, values_3d) 1612*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1613*da0073e9SAndroid Build Coastguard Worker RuntimeError, "boundaries tensor must be 1 dimension"): 1614*da0073e9SAndroid Build Coastguard Worker torch.bucketize(values_3d, boundaries) 1615*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1616*da0073e9SAndroid Build Coastguard Worker RuntimeError, "only when boundaries tensor dimension is 1"): 
1617*da0073e9SAndroid Build Coastguard Worker torch.searchsorted(boundaries, 1) 1618*da0073e9SAndroid Build Coastguard Worker 1619*da0073e9SAndroid Build Coastguard Worker # incompatiable output tensor's dtype 1620*da0073e9SAndroid Build Coastguard Worker def test_output_dtype(dtype, is_int32): 1621*da0073e9SAndroid Build Coastguard Worker output = values_1d.to(dtype) 1622*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1623*da0073e9SAndroid Build Coastguard Worker RuntimeError, "output tensor's dtype is wrong"): 1624*da0073e9SAndroid Build Coastguard Worker torch.searchsorted(values_1d, values_1d, out=output, out_int32=is_int32) 1625*da0073e9SAndroid Build Coastguard Worker 1626*da0073e9SAndroid Build Coastguard Worker test_output_dtype(torch.float32, False) 1627*da0073e9SAndroid Build Coastguard Worker test_output_dtype(torch.int32, False) 1628*da0073e9SAndroid Build Coastguard Worker test_output_dtype(torch.int64, True) 1629*da0073e9SAndroid Build Coastguard Worker 1630*da0073e9SAndroid Build Coastguard Worker # invalid side argument 1631*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "side can only be 'left' or 'right'"): 1632*da0073e9SAndroid Build Coastguard Worker torch.searchsorted(values_1d, values_1d, side='bad') 1633*da0073e9SAndroid Build Coastguard Worker 1634*da0073e9SAndroid Build Coastguard Worker # invalid sorter argument, wrong size 1635*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "boundary and sorter must have the same size"): 1636*da0073e9SAndroid Build Coastguard Worker sequence = torch.rand_like(values_1d, dtype=torch.float) 1637*da0073e9SAndroid Build Coastguard Worker _, sorted_idx = torch.sort(sequence) 1638*da0073e9SAndroid Build Coastguard Worker torch.searchsorted(sequence, values_1d, sorter=sorted_idx[:-1]) 1639*da0073e9SAndroid Build Coastguard Worker 1640*da0073e9SAndroid Build Coastguard Worker # invalid sorter argument, is not dtype long 
1641*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "sorter must be a tensor of long dtype"): 1642*da0073e9SAndroid Build Coastguard Worker sequence = torch.rand_like(values_1d, dtype=torch.float) 1643*da0073e9SAndroid Build Coastguard Worker _, sorted_idx = torch.sort(sequence) 1644*da0073e9SAndroid Build Coastguard Worker torch.searchsorted(sequence, values_1d, sorter=sorted_idx.to(torch.float32)) 1645*da0073e9SAndroid Build Coastguard Worker 1646*da0073e9SAndroid Build Coastguard Worker # invalid sorter value, out of bound (>= innermost size) 1647*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "sorter index out of range"): 1648*da0073e9SAndroid Build Coastguard Worker torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([0, 1, 3])) 1649*da0073e9SAndroid Build Coastguard Worker 1650*da0073e9SAndroid Build Coastguard Worker # invalid sorter value, out of bound (< 0) 1651*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "sorter index out of range"): 1652*da0073e9SAndroid Build Coastguard Worker torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([-1, 1, 2])) 1653*da0073e9SAndroid Build Coastguard Worker 1654*da0073e9SAndroid Build Coastguard Worker # scalar type bfloat16 1655*da0073e9SAndroid Build Coastguard Worker if self.device_type == 'cpu': 1656*da0073e9SAndroid Build Coastguard Worker def test_dtype_bfloat16(values_bf16=False, boundaries_bf16=False): 1657*da0073e9SAndroid Build Coastguard Worker values_1d_float = values_1d.to(torch.float32) 1658*da0073e9SAndroid Build Coastguard Worker boundaries = torch.tensor([0.9, 1, 2, 2, 3, 3, 4, 4.1, 9, 9], device=device, dtype=torch.float32) 1659*da0073e9SAndroid Build Coastguard Worker if values_bf16: 1660*da0073e9SAndroid Build Coastguard Worker values_1d_float = values_1d_float.to(torch.bfloat16) 1661*da0073e9SAndroid Build Coastguard Worker if boundaries_bf16: 
1662*da0073e9SAndroid Build Coastguard Worker boundaries = boundaries.to(torch.bfloat16) 1663*da0073e9SAndroid Build Coastguard Worker expected_result = torch.tensor([1, 2, 4, 6, 8, 8, 8, 8, 8], device=device, dtype=torch.int32) 1664*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.bucketize(values_1d_float, boundaries, out_int32=True), expected_result) 1665*da0073e9SAndroid Build Coastguard Worker 1666*da0073e9SAndroid Build Coastguard Worker test_dtype_bfloat16(True, False) 1667*da0073e9SAndroid Build Coastguard Worker test_dtype_bfloat16(False, True) 1668*da0073e9SAndroid Build Coastguard Worker test_dtype_bfloat16(True, True) 1669*da0073e9SAndroid Build Coastguard Worker 1670*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and(torch.half, torch.bfloat16)) 1671*da0073e9SAndroid Build Coastguard Worker def test_nansum(self, device, dtype): 1672*da0073e9SAndroid Build Coastguard Worker args = product( 1673*da0073e9SAndroid Build Coastguard Worker (True, False), # noncontiguous 1674*da0073e9SAndroid Build Coastguard Worker (0, 1, None), # dim 1675*da0073e9SAndroid Build Coastguard Worker ) 1676*da0073e9SAndroid Build Coastguard Worker zero = torch.zeros((), device=device, dtype=dtype) 1677*da0073e9SAndroid Build Coastguard Worker 1678*da0073e9SAndroid Build Coastguard Worker for noncontiguous, dim in args: 1679*da0073e9SAndroid Build Coastguard Worker # Randomly scale the values 1680*da0073e9SAndroid Build Coastguard Worker scale = random.randint(10, 100) 1681*da0073e9SAndroid Build Coastguard Worker x = make_tensor((17, 17), device=device, dtype=dtype, 1682*da0073e9SAndroid Build Coastguard Worker low=-scale, high=scale, noncontiguous=noncontiguous) 1683*da0073e9SAndroid Build Coastguard Worker 1684*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point: 1685*da0073e9SAndroid Build Coastguard Worker nan_mask = x < 0.2 * scale 1686*da0073e9SAndroid Build Coastguard Worker x_nonan = torch.where(nan_mask, zero, x) 
1687*da0073e9SAndroid Build Coastguard Worker x[nan_mask] = np.nan 1688*da0073e9SAndroid Build Coastguard Worker else: 1689*da0073e9SAndroid Build Coastguard Worker x_nonan = x 1690*da0073e9SAndroid Build Coastguard Worker 1691*da0073e9SAndroid Build Coastguard Worker dim_kwargs = {} if dim is None else {"dim": dim} 1692*da0073e9SAndroid Build Coastguard Worker expect = torch.sum(x_nonan, **dim_kwargs) 1693*da0073e9SAndroid Build Coastguard Worker actual = torch.nansum(x, **dim_kwargs) 1694*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expect, actual) 1695*da0073e9SAndroid Build Coastguard Worker 1696*da0073e9SAndroid Build Coastguard Worker def _test_reduction_function_with_numpy(self, torch_func, np_func, device, dtype, 1697*da0073e9SAndroid Build Coastguard Worker with_extremal=False, atol=None, rtol=None, 1698*da0073e9SAndroid Build Coastguard Worker exact_dtype=True, with_keepdim=False): 1699*da0073e9SAndroid Build Coastguard Worker # Test 0-d to 3-d tensors. 1700*da0073e9SAndroid Build Coastguard Worker for ndims in range(0, 4): 1701*da0073e9SAndroid Build Coastguard Worker shape = _rand_shape(ndims, min_size=5, max_size=10) 1702*da0073e9SAndroid Build Coastguard Worker for n in range(ndims + 1): 1703*da0073e9SAndroid Build Coastguard Worker for c in combinations(list(range(ndims)), n): 1704*da0073e9SAndroid Build Coastguard Worker for count_dim in permutations(c): 1705*da0073e9SAndroid Build Coastguard Worker # Generate Input. 
1706*da0073e9SAndroid Build Coastguard Worker x = _generate_input(shape, dtype, device, with_extremal) 1707*da0073e9SAndroid Build Coastguard Worker 1708*da0073e9SAndroid Build Coastguard Worker if count_dim == (): 1709*da0073e9SAndroid Build Coastguard Worker # Default `dims=None` case 1710*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_func, np_func, x, device=None, dtype=None, 1711*da0073e9SAndroid Build Coastguard Worker atol=atol, rtol=rtol, exact_dtype=exact_dtype) 1712*da0073e9SAndroid Build Coastguard Worker else: 1713*da0073e9SAndroid Build Coastguard Worker # With `dims: tuple of ints` case 1714*da0073e9SAndroid Build Coastguard Worker if with_keepdim: 1715*da0073e9SAndroid Build Coastguard Worker torch_func_partial = partial(torch_func, keepdim=True, dim=count_dim) 1716*da0073e9SAndroid Build Coastguard Worker np_func_partial = partial(np_func, keepdims=True, axis=count_dim) 1717*da0073e9SAndroid Build Coastguard Worker else: 1718*da0073e9SAndroid Build Coastguard Worker torch_func_partial = partial(torch_func, dim=count_dim) 1719*da0073e9SAndroid Build Coastguard Worker np_func_partial = partial(np_func, axis=count_dim) 1720*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_func_partial, np_func_partial, x, device=None, dtype=None, 1721*da0073e9SAndroid Build Coastguard Worker atol=atol, rtol=rtol, exact_dtype=exact_dtype) 1722*da0073e9SAndroid Build Coastguard Worker 1723*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half)) 1724*da0073e9SAndroid Build Coastguard Worker def test_count_nonzero(self, device, dtype): 1725*da0073e9SAndroid Build Coastguard Worker self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype) 1726*da0073e9SAndroid Build Coastguard Worker self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype, True) 1727*da0073e9SAndroid Build Coastguard Worker 1728*da0073e9SAndroid Build 
Coastguard Worker # TODO: Investigate why the output is not close to numpy. 1729*da0073e9SAndroid Build Coastguard Worker def _get_relaxed_tolerances_for(self, dtype): 1730*da0073e9SAndroid Build Coastguard Worker if dtype == torch.float16: 1731*da0073e9SAndroid Build Coastguard Worker atol = 0.4 1732*da0073e9SAndroid Build Coastguard Worker rtol = 1e-2 1733*da0073e9SAndroid Build Coastguard Worker elif dtype == torch.float32: 1734*da0073e9SAndroid Build Coastguard Worker atol = 7e-05 1735*da0073e9SAndroid Build Coastguard Worker rtol = 3e-06 1736*da0073e9SAndroid Build Coastguard Worker else: 1737*da0073e9SAndroid Build Coastguard Worker # Default values 1738*da0073e9SAndroid Build Coastguard Worker atol = None 1739*da0073e9SAndroid Build Coastguard Worker rtol = None 1740*da0073e9SAndroid Build Coastguard Worker return atol, rtol 1741*da0073e9SAndroid Build Coastguard Worker 1742*da0073e9SAndroid Build Coastguard Worker def _test_sum_reduction_vs_numpy(self, torch_fn, np_fn, device, dtype, with_keepdim=False, with_extremal=False): 1743*da0073e9SAndroid Build Coastguard Worker def is_integral(dtype): 1744*da0073e9SAndroid Build Coastguard Worker return dtype in integral_types() 1745*da0073e9SAndroid Build Coastguard Worker 1746*da0073e9SAndroid Build Coastguard Worker exact_dtype = True 1747*da0073e9SAndroid Build Coastguard Worker # On Windows CI, the current version of `numpy` promotes all lower integers 1748*da0073e9SAndroid Build Coastguard Worker # dtypes to int32 while `torch` promotes them to int64. Hence we skip on checking 1749*da0073e9SAndroid Build Coastguard Worker # the exact dtype. 
1750*da0073e9SAndroid Build Coastguard Worker # Reference : https://dr.pytorch.org/api/view-log-full?build_id=122051580 1751*da0073e9SAndroid Build Coastguard Worker # PR : https://github.com/pytorch/pytorch/pull/38628#issuecomment-655905370 1752*da0073e9SAndroid Build Coastguard Worker if IS_WINDOWS and is_integral(dtype): 1753*da0073e9SAndroid Build Coastguard Worker exact_dtype = False 1754*da0073e9SAndroid Build Coastguard Worker # For uint8, numpy promotes to uint64 while torch promotes to int64. 1755*da0073e9SAndroid Build Coastguard Worker # So we must skip this as well. 1756*da0073e9SAndroid Build Coastguard Worker if dtype == torch.uint8: 1757*da0073e9SAndroid Build Coastguard Worker exact_dtype = False 1758*da0073e9SAndroid Build Coastguard Worker 1759*da0073e9SAndroid Build Coastguard Worker # TODO: Investigate why the output is not close to numpy. 1760*da0073e9SAndroid Build Coastguard Worker atol, rtol = self._get_relaxed_tolerances_for(dtype) 1761*da0073e9SAndroid Build Coastguard Worker self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype, 1762*da0073e9SAndroid Build Coastguard Worker atol=atol, rtol=rtol, exact_dtype=exact_dtype, 1763*da0073e9SAndroid Build Coastguard Worker with_keepdim=with_keepdim, with_extremal=with_extremal) 1764*da0073e9SAndroid Build Coastguard Worker 1765*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1766*da0073e9SAndroid Build Coastguard Worker @dtypes(*set(all_types_and(torch.half)) - {torch.uint8}) 1767*da0073e9SAndroid Build Coastguard Worker def test_sum_vs_numpy(self, device, dtype): 1768*da0073e9SAndroid Build Coastguard Worker self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype) 1769*da0073e9SAndroid Build Coastguard Worker self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_extremal=True) 1770*da0073e9SAndroid Build Coastguard Worker self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_keepdim=True) 1771*da0073e9SAndroid 
Build Coastguard Worker 1772*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1773*da0073e9SAndroid Build Coastguard Worker @dtypes(*set(all_types_and(torch.half)) - {torch.uint8}) 1774*da0073e9SAndroid Build Coastguard Worker def test_nansum_vs_numpy(self, device, dtype): 1775*da0073e9SAndroid Build Coastguard Worker self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype) 1776*da0073e9SAndroid Build Coastguard Worker self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_extremal=True) 1777*da0073e9SAndroid Build Coastguard Worker self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_keepdim=True) 1778*da0073e9SAndroid Build Coastguard Worker 1779*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1780*da0073e9SAndroid Build Coastguard Worker @dtypes(*complex_types()) 1781*da0073e9SAndroid Build Coastguard Worker def test_nansum_complex(self, device, dtype): 1782*da0073e9SAndroid Build Coastguard Worker x = torch.randn((3, 3, 3), device=device, dtype=dtype) 1783*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "nansum does not support complex inputs"): 1784*da0073e9SAndroid Build Coastguard Worker torch.nansum(x) 1785*da0073e9SAndroid Build Coastguard Worker 1786*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and(torch.half)) 1787*da0073e9SAndroid Build Coastguard Worker def test_nansum_out_dtype(self, device, dtype): 1788*da0073e9SAndroid Build Coastguard Worker out_dtype = dtype 1789*da0073e9SAndroid Build Coastguard Worker inp_dtypes = all_types_and(torch.half) if out_dtype.is_floating_point else integral_types() 1790*da0073e9SAndroid Build Coastguard Worker for inp_dtype in inp_dtypes: 1791*da0073e9SAndroid Build Coastguard Worker # TODO: Investigate why the output is not close to numpy. 
1792*da0073e9SAndroid Build Coastguard Worker atol, rtol = self._get_relaxed_tolerances_for(dtype) 1793*da0073e9SAndroid Build Coastguard Worker shape = _rand_shape(random.randint(2, 5), min_size=5, max_size=10) 1794*da0073e9SAndroid Build Coastguard Worker x = _generate_input(shape, inp_dtype, device, with_extremal=False) 1795*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(torch.nansum, dtype=out_dtype) 1796*da0073e9SAndroid Build Coastguard Worker np_out_dtype = torch_to_numpy_dtype_dict[out_dtype] 1797*da0073e9SAndroid Build Coastguard Worker np_fn = partial(np.nansum, dtype=np_out_dtype) 1798*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None, atol=atol, rtol=rtol) 1799*da0073e9SAndroid Build Coastguard Worker 1800*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and(torch.half)) 1801*da0073e9SAndroid Build Coastguard Worker def test_argminmax_multiple(self, device, dtype): 1802*da0073e9SAndroid Build Coastguard Worker # Case: All Ones 1803*da0073e9SAndroid Build Coastguard Worker t = torch.ones(3, 3, device=device, dtype=dtype) 1804*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.argmax, np.argmax, t) 1805*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.argmin, np.argmin, t) 1806*da0073e9SAndroid Build Coastguard Worker 1807*da0073e9SAndroid Build Coastguard Worker # Case: With single `nan` present. 
1808*da0073e9SAndroid Build Coastguard Worker if dtype in floating_types_and(torch.half, torch.bfloat16): 1809*da0073e9SAndroid Build Coastguard Worker t[2, 2] = float('nan') 1810*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.argmax, np.argmax, t) 1811*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.argmin, np.argmin, t) 1812*da0073e9SAndroid Build Coastguard Worker 1813*da0073e9SAndroid Build Coastguard Worker # Case: Randomly Generated Tensors 1814*da0073e9SAndroid Build Coastguard Worker for ndims in range(1, 5): 1815*da0073e9SAndroid Build Coastguard Worker shape = _rand_shape(ndims, min_size=5, max_size=10) 1816*da0073e9SAndroid Build Coastguard Worker for with_extremal in [False, True]: 1817*da0073e9SAndroid Build Coastguard Worker for contiguous in [False, True]: 1818*da0073e9SAndroid Build Coastguard Worker # Generate Input. 1819*da0073e9SAndroid Build Coastguard Worker x = _generate_input(shape, dtype, device, with_extremal) 1820*da0073e9SAndroid Build Coastguard Worker 1821*da0073e9SAndroid Build Coastguard Worker if dtype == torch.half: 1822*da0073e9SAndroid Build Coastguard Worker max_val = torch.max(x.to(torch.float)) 1823*da0073e9SAndroid Build Coastguard Worker min_val = torch.min(x.to(torch.float)) 1824*da0073e9SAndroid Build Coastguard Worker else: 1825*da0073e9SAndroid Build Coastguard Worker max_val = torch.max(x) 1826*da0073e9SAndroid Build Coastguard Worker min_val = torch.min(x) 1827*da0073e9SAndroid Build Coastguard Worker 1828*da0073e9SAndroid Build Coastguard Worker mask = torch.randn(x.shape) > 0.5 1829*da0073e9SAndroid Build Coastguard Worker x[mask] = torch.tensor(max_val + 1, dtype=dtype) 1830*da0073e9SAndroid Build Coastguard Worker 1831*da0073e9SAndroid Build Coastguard Worker mask = torch.randn(x.shape) > 0.5 1832*da0073e9SAndroid Build Coastguard Worker x[mask] = torch.tensor(min_val - 1, dtype=dtype) 1833*da0073e9SAndroid Build Coastguard Worker 1834*da0073e9SAndroid Build Coastguard 
Worker if not contiguous: 1835*da0073e9SAndroid Build Coastguard Worker x = x.T 1836*da0073e9SAndroid Build Coastguard Worker 1837*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.argmax, np.argmax, x, device=None, dtype=None) 1838*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.argmin, np.argmin, x, device=None, dtype=None) 1839*da0073e9SAndroid Build Coastguard Worker 1840*da0073e9SAndroid Build Coastguard Worker # Verify indices returned by max and min. 1841*da0073e9SAndroid Build Coastguard Worker if dtype != torch.half: 1842*da0073e9SAndroid Build Coastguard Worker rand_dim = random.randint(0, ndims - 1) 1843*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(lambda x: torch.max(x, dim=rand_dim)[1], 1844*da0073e9SAndroid Build Coastguard Worker lambda x: np.argmax(x, axis=rand_dim), x, device=None, dtype=None) 1845*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(lambda x: torch.min(x, dim=rand_dim)[1], 1846*da0073e9SAndroid Build Coastguard Worker lambda x: np.argmin(x, axis=rand_dim), x, device=None, dtype=None) 1847*da0073e9SAndroid Build Coastguard Worker 1848*da0073e9SAndroid Build Coastguard Worker def verify_against_numpy(t): 1849*da0073e9SAndroid Build Coastguard Worker # Argmax 1850*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(torch.argmax, dim=1) 1851*da0073e9SAndroid Build Coastguard Worker np_fn = partial(np.argmax, axis=1) 1852*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, t) 1853*da0073e9SAndroid Build Coastguard Worker # Non-contiguous input 1854*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, t.T) 1855*da0073e9SAndroid Build Coastguard Worker 1856*da0073e9SAndroid Build Coastguard Worker # Verify indices returned by max. 
1857*da0073e9SAndroid Build Coastguard Worker if dtype != torch.half: 1858*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x, device=None, dtype=None) 1859*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x.T, device=None, dtype=None) 1860*da0073e9SAndroid Build Coastguard Worker 1861*da0073e9SAndroid Build Coastguard Worker # Argmin 1862*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(torch.argmin, dim=1) 1863*da0073e9SAndroid Build Coastguard Worker np_fn = partial(np.argmin, axis=1) 1864*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, t) 1865*da0073e9SAndroid Build Coastguard Worker # Non-contiguous input 1866*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, t.T) 1867*da0073e9SAndroid Build Coastguard Worker 1868*da0073e9SAndroid Build Coastguard Worker # Verify indices returned by min. 1869*da0073e9SAndroid Build Coastguard Worker if dtype != torch.half: 1870*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x, device=None, dtype=None) 1871*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x.T, device=None, dtype=None) 1872*da0073e9SAndroid Build Coastguard Worker 1873*da0073e9SAndroid Build Coastguard Worker # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998 1874*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([[1, 5], 1875*da0073e9SAndroid Build Coastguard Worker [2, 10], 1876*da0073e9SAndroid Build Coastguard Worker [3, 3]], device=device, dtype=dtype) 1877*da0073e9SAndroid Build Coastguard Worker verify_against_numpy(t) 1878*da0073e9SAndroid Build Coastguard Worker 1879*da0073e9SAndroid Build Coastguard Worker # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998 1880*da0073e9SAndroid Build Coastguard 
Worker t = torch.tensor([[1, 5], 1881*da0073e9SAndroid Build Coastguard Worker [2, 10], 1882*da0073e9SAndroid Build Coastguard Worker [0, 0]], device=device, dtype=dtype) 1883*da0073e9SAndroid Build Coastguard Worker verify_against_numpy(t) 1884*da0073e9SAndroid Build Coastguard Worker 1885*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) 1886*da0073e9SAndroid Build Coastguard Worker def test_all_any_vs_numpy(self, device, dtype): 1887*da0073e9SAndroid Build Coastguard Worker # Note [all, any uint8 compatibility]: However for compatibility reason, 1888*da0073e9SAndroid Build Coastguard Worker # for `uint8`, they return Tensor of same dtype `uint8`. 1889*da0073e9SAndroid Build Coastguard Worker # Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561 1890*da0073e9SAndroid Build Coastguard Worker exact_dtype = True if dtype != torch.uint8 else False 1891*da0073e9SAndroid Build Coastguard Worker 1892*da0073e9SAndroid Build Coastguard Worker def _test_all_any(x): 1893*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.all, np.all, x) 1894*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch.any, np.any, x) 1895*da0073e9SAndroid Build Coastguard Worker 1896*da0073e9SAndroid Build Coastguard Worker def _test_all_any_with_dim(x, dim): 1897*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(torch.all, dim=dim) 1898*da0073e9SAndroid Build Coastguard Worker np_fn = partial(np.all, axis=dim) 1899*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) 1900*da0073e9SAndroid Build Coastguard Worker 1901*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(torch.any, dim=dim) 1902*da0073e9SAndroid Build Coastguard Worker np_fn = partial(np.any, axis=dim) 1903*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) 1904*da0073e9SAndroid Build 
Coastguard Worker 1905*da0073e9SAndroid Build Coastguard Worker def _test_out_variant(x, dim): 1906*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(x) 1907*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bool or dtype == torch.uint8: 1908*da0073e9SAndroid Build Coastguard Worker expected = torch.all(x, dim) 1909*da0073e9SAndroid Build Coastguard Worker torch.all(x, dim, out=out) 1910*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, out) 1911*da0073e9SAndroid Build Coastguard Worker 1912*da0073e9SAndroid Build Coastguard Worker expected = torch.any(x, dim) 1913*da0073e9SAndroid Build Coastguard Worker torch.any(x, dim, out=out) 1914*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, out) 1915*da0073e9SAndroid Build Coastguard Worker else: 1916*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "all only supports bool tensor for result, got"): 1917*da0073e9SAndroid Build Coastguard Worker torch.all(x, dim, out=out) 1918*da0073e9SAndroid Build Coastguard Worker 1919*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "any only supports bool tensor for result, got"): 1920*da0073e9SAndroid Build Coastguard Worker torch.any(x, dim, out=out) 1921*da0073e9SAndroid Build Coastguard Worker 1922*da0073e9SAndroid Build Coastguard Worker def _test_all_any_with_dim_keepdim(x, dim, keepdim): 1923*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(torch.all, dim=dim, keepdim=keepdim) 1924*da0073e9SAndroid Build Coastguard Worker np_fn = partial(np.all, axis=dim, keepdims=keepdim) 1925*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) 1926*da0073e9SAndroid Build Coastguard Worker 1927*da0073e9SAndroid Build Coastguard Worker torch_fn = partial(torch.any, dim=dim, keepdim=keepdim) 1928*da0073e9SAndroid Build Coastguard Worker np_fn = partial(np.any, axis=dim, keepdims=keepdim) 
1929*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype) 1930*da0073e9SAndroid Build Coastguard Worker 1931*da0073e9SAndroid Build Coastguard Worker def _test_output_dtype(x): 1932*da0073e9SAndroid Build Coastguard Worker # This test will fail once the functions return bool output 1933*da0073e9SAndroid Build Coastguard Worker # for uint8 input. 1934*da0073e9SAndroid Build Coastguard Worker expected_dtype = torch.uint8 if dtype == torch.uint8 else torch.bool 1935*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.all(x).dtype, expected_dtype) 1936*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.any(x).dtype, expected_dtype) 1937*da0073e9SAndroid Build Coastguard Worker 1938*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.all(x, dim=0).dtype, expected_dtype) 1939*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.any(x, dim=0).dtype, expected_dtype) 1940*da0073e9SAndroid Build Coastguard Worker 1941*da0073e9SAndroid Build Coastguard Worker for ndim in range(5): 1942*da0073e9SAndroid Build Coastguard Worker shape = _rand_shape(ndim, 1, 5) 1943*da0073e9SAndroid Build Coastguard Worker x = _generate_input(shape, dtype, device, with_extremal=False) 1944*da0073e9SAndroid Build Coastguard Worker _test_all_any(x) 1945*da0073e9SAndroid Build Coastguard Worker _test_all_any(x.T) 1946*da0073e9SAndroid Build Coastguard Worker _test_all_any(x[..., ::2]) 1947*da0073e9SAndroid Build Coastguard Worker 1948*da0073e9SAndroid Build Coastguard Worker x = _generate_input(shape, dtype, device, with_extremal=True) 1949*da0073e9SAndroid Build Coastguard Worker _test_all_any(x) 1950*da0073e9SAndroid Build Coastguard Worker _test_all_any(x.T) 1951*da0073e9SAndroid Build Coastguard Worker _test_all_any(x[..., ::2]) 1952*da0073e9SAndroid Build Coastguard Worker 1953*da0073e9SAndroid Build Coastguard Worker x = torch.zeros_like(x) 1954*da0073e9SAndroid Build Coastguard Worker 
_test_all_any(x) 1955*da0073e9SAndroid Build Coastguard Worker _test_all_any(x.T) 1956*da0073e9SAndroid Build Coastguard Worker _test_all_any(x[..., ::2]) 1957*da0073e9SAndroid Build Coastguard Worker 1958*da0073e9SAndroid Build Coastguard Worker x = torch.ones_like(x) 1959*da0073e9SAndroid Build Coastguard Worker _test_all_any(x) 1960*da0073e9SAndroid Build Coastguard Worker _test_all_any(x.T) 1961*da0073e9SAndroid Build Coastguard Worker _test_all_any(x[..., ::2]) 1962*da0073e9SAndroid Build Coastguard Worker _test_output_dtype(x) 1963*da0073e9SAndroid Build Coastguard Worker for dim in range(ndim): 1964*da0073e9SAndroid Build Coastguard Worker x = _generate_input(shape, dtype, device, with_extremal=False) 1965*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim(x, dim) 1966*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim(x.T, dim) 1967*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim(x[..., ::2], dim) 1968*da0073e9SAndroid Build Coastguard Worker _test_out_variant(x, dim) 1969*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim_keepdim(x, dim, keepdim=True) 1970*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim_keepdim(x, dim, keepdim=False) 1971*da0073e9SAndroid Build Coastguard Worker 1972*da0073e9SAndroid Build Coastguard Worker x = _generate_input(shape, dtype, device, with_extremal=True) 1973*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim(x, dim) 1974*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim(x.T, dim) 1975*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim(x[..., ::2], dim) 1976*da0073e9SAndroid Build Coastguard Worker _test_out_variant(x, dim) 1977*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim_keepdim(x, dim, keepdim=True) 1978*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim_keepdim(x, dim, keepdim=False) 1979*da0073e9SAndroid Build Coastguard Worker 1980*da0073e9SAndroid Build Coastguard Worker x = 
torch.zeros_like(x) 1981*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim(x, dim) 1982*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim(x.T, dim) 1983*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim(x[..., ::2], dim) 1984*da0073e9SAndroid Build Coastguard Worker _test_out_variant(x, dim) 1985*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim_keepdim(x, dim, keepdim=True) 1986*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim_keepdim(x, dim, keepdim=False) 1987*da0073e9SAndroid Build Coastguard Worker 1988*da0073e9SAndroid Build Coastguard Worker x = torch.ones_like(x) 1989*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim(x, dim) 1990*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim(x.T, dim) 1991*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim(x[..., ::2], dim) 1992*da0073e9SAndroid Build Coastguard Worker _test_out_variant(x, dim) 1993*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim_keepdim(x, dim, keepdim=True) 1994*da0073e9SAndroid Build Coastguard Worker _test_all_any_with_dim_keepdim(x, dim, keepdim=False) 1995*da0073e9SAndroid Build Coastguard Worker 1996*da0073e9SAndroid Build Coastguard Worker # TODO: part of this test covers torch.norm, which should be covered by test_linalg 1997*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1998*da0073e9SAndroid Build Coastguard Worker def test_repeated_dim(self, device): 1999*da0073e9SAndroid Build Coastguard Worker ops = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp, torch.std, torch.var, 2000*da0073e9SAndroid Build Coastguard Worker torch.norm] 2001*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3, 3, 3, device=device) 2002*da0073e9SAndroid Build Coastguard Worker 2003*da0073e9SAndroid Build Coastguard Worker error_msg = r'appears multiple times in the list of dims' 2004*da0073e9SAndroid Build Coastguard Worker for op in ops: 
2005*da0073e9SAndroid Build Coastguard Worker for dim in [(0, 0), (0, -4)]: 2006*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, error_msg): 2007*da0073e9SAndroid Build Coastguard Worker op(x, dim=dim) 2008*da0073e9SAndroid Build Coastguard Worker 2009*da0073e9SAndroid Build Coastguard Worker # TODO: update this test to compare against NumPy 2010*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2011*da0073e9SAndroid Build Coastguard Worker def test_var(self, device): 2012*da0073e9SAndroid Build Coastguard Worker cpu_tensor = torch.randn(2, 3, 3) 2013*da0073e9SAndroid Build Coastguard Worker device_tensor = cpu_tensor.to(device) 2014*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_tensor.var(), cpu_tensor.var()) 2015*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_tensor.var(1), cpu_tensor.var(1)) 2016*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_tensor.var(2), cpu_tensor.var(2)) 2017*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_tensor.std(), cpu_tensor.std()) 2018*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_tensor.std(1), cpu_tensor.std(1)) 2019*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_tensor.var(2), cpu_tensor.var(2)) 2020*da0073e9SAndroid Build Coastguard Worker 2021*da0073e9SAndroid Build Coastguard Worker cpu_tensor = torch.randn(100) 2022*da0073e9SAndroid Build Coastguard Worker device_tensor = cpu_tensor.to(device) 2023*da0073e9SAndroid Build Coastguard Worker self.assertEqual(device_tensor.var(), cpu_tensor.var()) 2024*da0073e9SAndroid Build Coastguard Worker 2025*da0073e9SAndroid Build Coastguard Worker # TODO: update this test to compare against NumPy 2026*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2027*da0073e9SAndroid Build Coastguard Worker def test_var_large_input(self, device): 2028*da0073e9SAndroid Build Coastguard Worker # Large, not-nice input 2029*da0073e9SAndroid Build Coastguard 
Worker cpu_tensor = torch.randn(2 * 32 * 1024 + 1, 2, 67) 2030*da0073e9SAndroid Build Coastguard Worker device_tensor = cpu_tensor.to(device) 2031*da0073e9SAndroid Build Coastguard Worker 2032*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cpu_tensor.var(2), device_tensor.var(2)) 2033*da0073e9SAndroid Build Coastguard Worker 2034*da0073e9SAndroid Build Coastguard Worker # TODO: update this to compare against NumPy instead of CPU 2035*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2036*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 2037*da0073e9SAndroid Build Coastguard Worker def test_sum_noncontig(self, device, dtype): 2038*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 75, 57, 20, dtype=dtype, device=device).permute(0, 3, 1, 2) 2039*da0073e9SAndroid Build Coastguard Worker y = x.cpu() 2040*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.sum().cpu(), y.sum()) 2041*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.sum(dim=(-1, -2)).cpu(), y.sum(dim=(-1, -2))) 2042*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.sum(dim=(1, 3)).cpu(), y.sum(dim=(1, 3))) 2043*da0073e9SAndroid Build Coastguard Worker 2044*da0073e9SAndroid Build Coastguard Worker # TODO: update this to compare against NumPy instead of CPU 2045*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2046*da0073e9SAndroid Build Coastguard Worker def test_min_max_nan(self, device): 2047*da0073e9SAndroid Build Coastguard Worker tests = [(lambda x: x.min(), 'min'), 2048*da0073e9SAndroid Build Coastguard Worker (lambda x: x.max(), 'max'), 2049*da0073e9SAndroid Build Coastguard Worker (lambda x: x.amin(), 'amin'), 2050*da0073e9SAndroid Build Coastguard Worker (lambda x: x.amax(), 'amax'), 2051*da0073e9SAndroid Build Coastguard Worker (lambda x: x.min(0).values, 'min_dim'), 2052*da0073e9SAndroid Build Coastguard Worker (lambda x: x.max(0).values, 'max_dim'), 2053*da0073e9SAndroid Build Coastguard Worker (lambda x: x.amin(0), 'amin_dim'), 
2054*da0073e9SAndroid Build Coastguard Worker (lambda x: x.amax(0), 'amax_dim')] 2055*da0073e9SAndroid Build Coastguard Worker for f, name in tests: 2056*da0073e9SAndroid Build Coastguard Worker a = torch.arange(25.0).view(5, 5) 2057*da0073e9SAndroid Build Coastguard Worker a[2, 2] = nan 2058*da0073e9SAndroid Build Coastguard Worker actual = f(a.to(device)).cpu() 2059*da0073e9SAndroid Build Coastguard Worker expected = f(a).cpu() 2060*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.isnan(actual), torch.isnan(expected), msg=f'nans for {name}') 2061*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual[~torch.isnan(actual)], 2062*da0073e9SAndroid Build Coastguard Worker expected[~torch.isnan(expected)], msg=f'nans for {name}') 2063*da0073e9SAndroid Build Coastguard Worker 2064*da0073e9SAndroid Build Coastguard Worker # TODO: make this test generic using OpInfos 2065*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2066*da0073e9SAndroid Build Coastguard Worker def test_sum_cpu_device_mismatch(self, device): 2067*da0073e9SAndroid Build Coastguard Worker x = torch.randn(20, dtype=torch.float32, device=device) 2068*da0073e9SAndroid Build Coastguard Worker y = torch.randn(1, dtype=torch.float32) 2069*da0073e9SAndroid Build Coastguard Worker 2070*da0073e9SAndroid Build Coastguard Worker err_string = f"Expected out tensor to have device {device}, but got cpu instead" 2071*da0073e9SAndroid Build Coastguard Worker 2072*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_string): 2073*da0073e9SAndroid Build Coastguard Worker torch.sum(x, dim=[0], dtype=torch.float32, out=y) 2074*da0073e9SAndroid Build Coastguard Worker 2075*da0073e9SAndroid Build Coastguard Worker # tests half to float promotion 2076*da0073e9SAndroid Build Coastguard Worker if self.device_type == 'cuda': 2077*da0073e9SAndroid Build Coastguard Worker x = x.half() 2078*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(RuntimeError, err_string): 2079*da0073e9SAndroid Build Coastguard Worker torch.sum(x, dim=[0], dtype=torch.float32, out=y) 2080*da0073e9SAndroid Build Coastguard Worker 2081*da0073e9SAndroid Build Coastguard Worker # Assert for illegal dtype would not be raised on XLA 2082*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 2083*da0073e9SAndroid Build Coastguard Worker def test_minmax_illegal_dtype(self, device): 2084*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5, dtype=torch.float32, device=device) 2085*da0073e9SAndroid Build Coastguard Worker valid_values = torch.empty(5, dtype=torch.float32, device=device) 2086*da0073e9SAndroid Build Coastguard Worker valid_indices = torch.empty(5, dtype=torch.long, device=device) 2087*da0073e9SAndroid Build Coastguard Worker illegal_values = torch.empty(5, dtype=torch.int, device=device) 2088*da0073e9SAndroid Build Coastguard Worker illegal_indices = torch.empty(5, dtype=torch.double, device=device) 2089*da0073e9SAndroid Build Coastguard Worker torch.max(x, dim=0, out=(valid_values, valid_indices)) 2090*da0073e9SAndroid Build Coastguard Worker torch.min(x, dim=0, out=(valid_values, valid_indices)) 2091*da0073e9SAndroid Build Coastguard Worker torch.amax(x, dim=0, out=valid_values) 2092*da0073e9SAndroid Build Coastguard Worker torch.amin(x, dim=0, out=valid_values) 2093*da0073e9SAndroid Build Coastguard Worker rmsg = r'scalar type|dtype' 2094*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, rmsg): 2095*da0073e9SAndroid Build Coastguard Worker torch.max(x, dim=0, out=(illegal_values, valid_indices)) 2096*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, rmsg): 2097*da0073e9SAndroid Build Coastguard Worker torch.min(x, dim=0, out=(illegal_values, valid_indices)) 2098*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, rmsg): 2099*da0073e9SAndroid Build Coastguard Worker torch.max(x, 
dim=0, out=(valid_values, illegal_indices)) 2100*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, rmsg): 2101*da0073e9SAndroid Build Coastguard Worker torch.min(x, dim=0, out=(valid_values, illegal_indices)) 2102*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, rmsg): 2103*da0073e9SAndroid Build Coastguard Worker torch.max(x, dim=0, out=(illegal_values, illegal_indices)) 2104*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, rmsg): 2105*da0073e9SAndroid Build Coastguard Worker torch.min(x, dim=0, out=(illegal_values, illegal_indices)) 2106*da0073e9SAndroid Build Coastguard Worker 2107*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and(torch.half, torch.bfloat16)) 2108*da0073e9SAndroid Build Coastguard Worker def test_dim_arg_reduction_scalar(self, device, dtype): 2109*da0073e9SAndroid Build Coastguard Worker example = 4.0 2110*da0073e9SAndroid Build Coastguard Worker 2111*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(example, device=device, dtype=dtype) 2112*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax().item(), 0) 2113*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax(dim=None).item(), 0) 2114*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax(dim=0).item(), 0) 2115*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64)) 2116*da0073e9SAndroid Build Coastguard Worker 2117*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(example, device=device, dtype=dtype) 2118*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin().item(), 0) 2119*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin(dim=None).item(), 0) 2120*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin(dim=0).item(), 0) 2121*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin(dim=0, keepdim=True), 
torch.tensor(0, dtype=torch.int64)) 2122*da0073e9SAndroid Build Coastguard Worker 2123*da0073e9SAndroid Build Coastguard Worker 2124*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2}) 2125*da0073e9SAndroid Build Coastguard Worker @dtypes(*set(all_types_and(torch.half, torch.bfloat16)) - {torch.uint8}) 2126*da0073e9SAndroid Build Coastguard Worker def test_dim_reduction(self, device, dtype): 2127*da0073e9SAndroid Build Coastguard Worker example = [[-1, 2, 1], [5, 3, 6]] 2128*da0073e9SAndroid Build Coastguard Worker 2129*da0073e9SAndroid Build Coastguard Worker sum_dtype = { 2130*da0073e9SAndroid Build Coastguard Worker torch.bfloat16: torch.bfloat16, 2131*da0073e9SAndroid Build Coastguard Worker torch.double: torch.double, 2132*da0073e9SAndroid Build Coastguard Worker torch.float: torch.float, 2133*da0073e9SAndroid Build Coastguard Worker torch.half: torch.half, 2134*da0073e9SAndroid Build Coastguard Worker torch.int64: torch.int64, 2135*da0073e9SAndroid Build Coastguard Worker torch.int32: torch.int64, 2136*da0073e9SAndroid Build Coastguard Worker torch.int16: torch.int64, 2137*da0073e9SAndroid Build Coastguard Worker torch.int8: torch.int64 2138*da0073e9SAndroid Build Coastguard Worker } 2139*da0073e9SAndroid Build Coastguard Worker 2140*da0073e9SAndroid Build Coastguard Worker # This won't test for 256bit instructions, since we usually 2141*da0073e9SAndroid Build Coastguard Worker # only work on 1 cacheline (512bit) at a time and these 2142*da0073e9SAndroid Build Coastguard Worker # examples aren't big enough to trigger that. 
2143*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(example, device=device, dtype=dtype) 2144*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.sum().item(), 16) 2145*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.sum(0), torch.tensor([4, 5, 7], dtype=sum_dtype[dtype])) 2146*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.sum(1), torch.tensor([2, 14], dtype=sum_dtype[dtype])) 2147*da0073e9SAndroid Build Coastguard Worker y = torch.tensor(example, device=device, dtype=sum_dtype[dtype]) 2148*da0073e9SAndroid Build Coastguard Worker torch.sum(x, 0, out=y) 2149*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.sum(0), y) 2150*da0073e9SAndroid Build Coastguard Worker 2151*da0073e9SAndroid Build Coastguard Worker # Mean not supported for Int types 2152*da0073e9SAndroid Build Coastguard Worker if dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64]: 2153*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(example, device=device, dtype=dtype) 2154*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.mean().item(), 16.0 / 6) 2155*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.mean(0), torch.tensor([2.0, 2.5, 7.0 / 2], dtype=dtype)) 2156*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.mean(1), torch.tensor([2.0 / 3, 14.0 / 3], dtype=dtype)) 2157*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.mean(), x.mean((0, 1))) 2158*da0073e9SAndroid Build Coastguard Worker 2159*da0073e9SAndroid Build Coastguard Worker prod_dtype = { 2160*da0073e9SAndroid Build Coastguard Worker torch.bfloat16: torch.bfloat16, 2161*da0073e9SAndroid Build Coastguard Worker torch.double: torch.double, 2162*da0073e9SAndroid Build Coastguard Worker torch.float: torch.float, 2163*da0073e9SAndroid Build Coastguard Worker torch.float16: torch.float16, 2164*da0073e9SAndroid Build Coastguard Worker torch.int64: torch.int64, 2165*da0073e9SAndroid Build Coastguard Worker torch.int32: 
torch.int64, 2166*da0073e9SAndroid Build Coastguard Worker torch.int16: torch.int64, 2167*da0073e9SAndroid Build Coastguard Worker torch.int8: torch.int64, 2168*da0073e9SAndroid Build Coastguard Worker } 2169*da0073e9SAndroid Build Coastguard Worker 2170*da0073e9SAndroid Build Coastguard Worker # prod is not supported for float16 & bfloat16 on CPU 2171*da0073e9SAndroid Build Coastguard Worker if not (self.device_type == 'cpu' and dtype in [torch.float16, torch.bfloat16]): 2172*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(example, device=device, dtype=dtype) 2173*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.prod().item(), -180) 2174*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.prod(0), torch.tensor([-5, 6, 6], dtype=prod_dtype[dtype])) 2175*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.prod(1), torch.tensor([-2, 90], dtype=prod_dtype[dtype])) 2176*da0073e9SAndroid Build Coastguard Worker 2177*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(example, device=device, dtype=dtype) 2178*da0073e9SAndroid Build Coastguard Worker 2179*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.min().item(), -1) 2180*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin().item(), 0) 2181*da0073e9SAndroid Build Coastguard Worker 2182*da0073e9SAndroid Build Coastguard Worker # TODO: torch.min does not support the same operation as argmin 2183*da0073e9SAndroid Build Coastguard Worker # for the same case, should we enable it? 
2184*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin(dim=None).item(), 0) 2185*da0073e9SAndroid Build Coastguard Worker 2186*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.min(0), (torch.tensor([-1, 2, 1], dtype=dtype), 2187*da0073e9SAndroid Build Coastguard Worker torch.tensor([0, 0, 0], dtype=torch.int64))) 2188*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.amin(0), torch.tensor([-1, 2, 1], dtype=dtype)) 2189*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin(0), torch.tensor([0, 0, 0], dtype=torch.int64)) 2190*da0073e9SAndroid Build Coastguard Worker 2191*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.min(dim=0, keepdim=True), (torch.tensor([[-1, 2, 1]], dtype=dtype), 2192*da0073e9SAndroid Build Coastguard Worker torch.tensor([[0, 0, 0]], dtype=torch.int64))) 2193*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.amin(dim=0, keepdim=True), torch.tensor([[-1, 2, 1]], dtype=dtype)) 2194*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin(dim=0, keepdim=True), torch.tensor([[0, 0, 0]], dtype=torch.int64)) 2195*da0073e9SAndroid Build Coastguard Worker 2196*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.min(1), (torch.tensor([-1, 3], dtype=dtype), 2197*da0073e9SAndroid Build Coastguard Worker torch.tensor([0, 1], dtype=torch.int64))) 2198*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.amin(1), torch.tensor([-1, 3], dtype=dtype)) 2199*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin(1), torch.tensor([0, 1], dtype=torch.int64)) 2200*da0073e9SAndroid Build Coastguard Worker 2201*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.min(dim=1, keepdim=True), (torch.tensor([[-1], [3]], dtype=dtype), 2202*da0073e9SAndroid Build Coastguard Worker torch.tensor([[0], [1]], dtype=torch.int64))) 2203*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.amin(dim=1, keepdim=True), torch.tensor([[-1], [3]], dtype=dtype)) 
2204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin(dim=1, keepdim=True), torch.tensor([[0], [1]], dtype=torch.int64)) 2205*da0073e9SAndroid Build Coastguard Worker 2206*da0073e9SAndroid Build Coastguard Worker # test that non-contiguous tensors work 2207*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[:, :2].min().item(), -1) 2208*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[:, :2].amin().item(), -1) 2209*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[:, :2].argmin().item(), 0) 2210*da0073e9SAndroid Build Coastguard Worker 2211*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(example, device=device, dtype=dtype) 2212*da0073e9SAndroid Build Coastguard Worker 2213*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.max().item(), 6) 2214*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.amax().item(), 6) 2215*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax().item(), 5) 2216*da0073e9SAndroid Build Coastguard Worker 2217*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.max(0), (torch.tensor([5, 3, 6], dtype=dtype), 2218*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 1, 1], dtype=torch.int64))) 2219*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.amax(0), torch.tensor([5, 3, 6], dtype=dtype)) 2220*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax(dim=0), torch.tensor([1, 1, 1], dtype=torch.int64)) 2221*da0073e9SAndroid Build Coastguard Worker 2222*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.max(dim=0, keepdim=True), (torch.tensor([[5, 3, 6]], dtype=dtype), 2223*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1, 1, 1]], dtype=torch.int64))) 2224*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.amax(dim=0, keepdim=True), torch.tensor([[5, 3, 6]], dtype=dtype)) 2225*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor([[1, 1, 1]], 
dtype=torch.int64)) 2226*da0073e9SAndroid Build Coastguard Worker 2227*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.max(1), (torch.tensor([2, 6], dtype=dtype), 2228*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 2], dtype=torch.int64))) 2229*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.amax(1), torch.tensor([2, 6], dtype=dtype)) 2230*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax(dim=1), torch.tensor([1, 2], dtype=torch.int64)) 2231*da0073e9SAndroid Build Coastguard Worker 2232*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.max(1, keepdim=True), (torch.tensor([[2], [6]], dtype=dtype), 2233*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1], [2]], dtype=torch.int64))) 2234*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.amax(1, keepdim=True), torch.tensor([[2], [6]], dtype=dtype)) 2235*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax(dim=1, keepdim=True), torch.tensor([[1], [2]], dtype=torch.int64)) 2236*da0073e9SAndroid Build Coastguard Worker 2237*da0073e9SAndroid Build Coastguard Worker # test that non-contiguous tensors work 2238*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[:, :2].max().item(), 5) 2239*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[:, :2].amax().item(), 5) 2240*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[:, :2].argmax().item(), 2) 2241*da0073e9SAndroid Build Coastguard Worker 2242*da0073e9SAndroid Build Coastguard Worker 2243*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2}) 2244*da0073e9SAndroid Build Coastguard Worker @dtypes(*set(all_types_and(torch.half, torch.bfloat16)) - {torch.uint8}) 2245*da0073e9SAndroid Build Coastguard Worker @parametrize("fn_name", [ 2246*da0073e9SAndroid Build Coastguard Worker "mean", "median", "nanmedian", "mode", "norm", "prod", 2247*da0073e9SAndroid Build Coastguard Worker "std", "sum", "var", "max", 
"min", "amax", "amin"]) 2248*da0073e9SAndroid Build Coastguard Worker def test_dim_reduction_fns(self, device, dtype, fn_name): 2249*da0073e9SAndroid Build Coastguard Worker def normfn_attr(t, dim, keepdim=False, out=None): 2250*da0073e9SAndroid Build Coastguard Worker attr = torch.norm 2251*da0073e9SAndroid Build Coastguard Worker return attr(t, 2, dim, keepdim, out=out) 2252*da0073e9SAndroid Build Coastguard Worker 2253*da0073e9SAndroid Build Coastguard Worker fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr 2254*da0073e9SAndroid Build Coastguard Worker 2255*da0073e9SAndroid Build Coastguard Worker def fn(x, dim, keepdim=False, out=None): 2256*da0073e9SAndroid Build Coastguard Worker ans = fn_attr(x, dim, keepdim=keepdim, out=out) 2257*da0073e9SAndroid Build Coastguard Worker return ans if not isinstance(ans, tuple) else ans[0] 2258*da0073e9SAndroid Build Coastguard Worker 2259*da0073e9SAndroid Build Coastguard Worker def fn_tuple(x, dim, keepdim=False, out=None): 2260*da0073e9SAndroid Build Coastguard Worker return fn_attr(x, dim, keepdim=keepdim, out=out) 2261*da0073e9SAndroid Build Coastguard Worker 2262*da0073e9SAndroid Build Coastguard Worker def test_multidim(x, dim): 2263*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True)) 2264*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension()) 2265*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.ndimension(), fn(x, dim, keepdim=True).ndimension()) 2266*da0073e9SAndroid Build Coastguard Worker 2267*da0073e9SAndroid Build Coastguard Worker # general case 2268*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 4, 5, device=device) 2269*da0073e9SAndroid Build Coastguard Worker dim = random.randint(0, 2) 2270*da0073e9SAndroid Build Coastguard Worker test_multidim(x, dim) 2271*da0073e9SAndroid Build Coastguard Worker 2272*da0073e9SAndroid Build Coastguard Worker # check 1-d 
behavior 2273*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, device=device) 2274*da0073e9SAndroid Build Coastguard Worker dim = 0 2275*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x, dim).shape, ()) 2276*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x, dim, keepdim=True).shape, (1,)) 2277*da0073e9SAndroid Build Coastguard Worker 2278*da0073e9SAndroid Build Coastguard Worker # check reducing of a singleton dimension 2279*da0073e9SAndroid Build Coastguard Worker dims = [3, 4, 5] 2280*da0073e9SAndroid Build Coastguard Worker singleton_dim = random.randint(0, 2) 2281*da0073e9SAndroid Build Coastguard Worker dims[singleton_dim] = 1 2282*da0073e9SAndroid Build Coastguard Worker x = torch.randn(dims, device=device) 2283*da0073e9SAndroid Build Coastguard Worker test_multidim(x, singleton_dim) 2284*da0073e9SAndroid Build Coastguard Worker 2285*da0073e9SAndroid Build Coastguard Worker # check reducing with output kwargs 2286*da0073e9SAndroid Build Coastguard Worker if fn_name in ['median', 'nanmedian', 'mode', 'max', 'min']: 2287*da0073e9SAndroid Build Coastguard Worker y = torch.randn(5, 3, device=device) 2288*da0073e9SAndroid Build Coastguard Worker values = torch.randn(5, 3, device=device) 2289*da0073e9SAndroid Build Coastguard Worker indices = torch.zeros(5, 3, device=device).long() - 1 2290*da0073e9SAndroid Build Coastguard Worker fn_tuple(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1])) 2291*da0073e9SAndroid Build Coastguard Worker values_expected, indices_expected = fn_tuple(y, 1, keepdim=False) 2292*da0073e9SAndroid Build Coastguard Worker self.assertEqual(values[:, 1], values_expected, 2293*da0073e9SAndroid Build Coastguard Worker msg=f'{fn_name} values with out= kwarg') 2294*da0073e9SAndroid Build Coastguard Worker self.assertEqual(indices[:, 1], indices_expected, 2295*da0073e9SAndroid Build Coastguard Worker msg=f'{fn_name} indices with out= kwarg') 2296*da0073e9SAndroid Build Coastguard Worker return 
2297*da0073e9SAndroid Build Coastguard Worker 2298*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 3, device=device) 2299*da0073e9SAndroid Build Coastguard Worker y = torch.randn(5, 3, device=device) 2300*da0073e9SAndroid Build Coastguard Worker fn(y, 1, keepdim=False, out=x[:, 1]) 2301*da0073e9SAndroid Build Coastguard Worker expected = fn(y, 1, keepdim=False) 2302*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[:, 1], expected, msg=f'{fn_name} with out= kwarg') 2303*da0073e9SAndroid Build Coastguard Worker 2304*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2305*da0073e9SAndroid Build Coastguard Worker @largeTensorTest('10GB') 2306*da0073e9SAndroid Build Coastguard Worker def test_reduction_split(self, device): 2307*da0073e9SAndroid Build Coastguard Worker # Test reduction when there is a 32bit-indexing split 2308*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/37583 2309*da0073e9SAndroid Build Coastguard Worker input_ = torch.randn(5, 14400, 14400, device=device) 2310*da0073e9SAndroid Build Coastguard Worker result = input_.sum(dim=0) 2311*da0073e9SAndroid Build Coastguard Worker expect = input_[0] + input_[1] + input_[2] + input_[3] + input_[4] 2312*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expect) 2313*da0073e9SAndroid Build Coastguard Worker 2314*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2315*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.float, torch.double, torch.bfloat16) 2316*da0073e9SAndroid Build Coastguard Worker def test_reduction_vectorize_along_input_corner(self, device, dtype): 2317*da0073e9SAndroid Build Coastguard Worker # 1D case: sum 2318*da0073e9SAndroid Build Coastguard Worker size = 1024 * 1024 * 64 + 3 2319*da0073e9SAndroid Build Coastguard Worker shift = 1 2320*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(size, dtype=dtype, device=device) 2321*da0073e9SAndroid Build Coastguard Worker y = x[shift:] 
2322*da0073e9SAndroid Build Coastguard Worker for i in range(100): 2323*da0073e9SAndroid Build Coastguard Worker x.zero_() 2324*da0073e9SAndroid Build Coastguard Worker x[i] = 1 2325*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.sum(), 1.0) 2326*da0073e9SAndroid Build Coastguard Worker if i < shift: 2327*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.sum(), 0.0) 2328*da0073e9SAndroid Build Coastguard Worker else: 2329*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.sum(), 1.0) 2330*da0073e9SAndroid Build Coastguard Worker for i in range(1, 100): 2331*da0073e9SAndroid Build Coastguard Worker x.zero_() 2332*da0073e9SAndroid Build Coastguard Worker x[-i] = 1 2333*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.sum(), 1.0) 2334*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.sum(), 1.0) 2335*da0073e9SAndroid Build Coastguard Worker # 1D case: argmax 2336*da0073e9SAndroid Build Coastguard Worker size = 1024 * 1024 * 64 + 3 2337*da0073e9SAndroid Build Coastguard Worker shift = 1 2338*da0073e9SAndroid Build Coastguard Worker ysize = size - shift 2339*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(size, dtype=dtype, device=device) 2340*da0073e9SAndroid Build Coastguard Worker y = x[shift:] 2341*da0073e9SAndroid Build Coastguard Worker for i in range(100): 2342*da0073e9SAndroid Build Coastguard Worker x.zero_() 2343*da0073e9SAndroid Build Coastguard Worker x[i] = 1 2344*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax().item(), i) 2345*da0073e9SAndroid Build Coastguard Worker if i >= shift: 2346*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.argmax().item(), i - shift) 2347*da0073e9SAndroid Build Coastguard Worker for i in range(1, 100): 2348*da0073e9SAndroid Build Coastguard Worker x.zero_() 2349*da0073e9SAndroid Build Coastguard Worker x[-i] = 1 2350*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax().item(), size - i) 2351*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(y.argmax().item(), ysize - i) 2352*da0073e9SAndroid Build Coastguard Worker # 2D case: sum 2353*da0073e9SAndroid Build Coastguard Worker size = (7, 1024 * 1024 + 3) 2354*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(size, dtype=dtype, device=device) 2355*da0073e9SAndroid Build Coastguard Worker for i in range(100): 2356*da0073e9SAndroid Build Coastguard Worker x.zero_() 2357*da0073e9SAndroid Build Coastguard Worker for j in range(7): 2358*da0073e9SAndroid Build Coastguard Worker x[j][i] = j 2359*da0073e9SAndroid Build Coastguard Worker xs = x.sum(dim=-1) 2360*da0073e9SAndroid Build Coastguard Worker for j in range(7): 2361*da0073e9SAndroid Build Coastguard Worker self.assertEqual(xs[j].item(), float(j)) 2362*da0073e9SAndroid Build Coastguard Worker for i in range(100): 2363*da0073e9SAndroid Build Coastguard Worker x.zero_() 2364*da0073e9SAndroid Build Coastguard Worker for j in range(7): 2365*da0073e9SAndroid Build Coastguard Worker x[j][-i] = j 2366*da0073e9SAndroid Build Coastguard Worker xs = x.sum(dim=-1) 2367*da0073e9SAndroid Build Coastguard Worker for j in range(7): 2368*da0073e9SAndroid Build Coastguard Worker self.assertEqual(xs[j].item(), float(j)) 2369*da0073e9SAndroid Build Coastguard Worker # 2D case: max/argmax 2370*da0073e9SAndroid Build Coastguard Worker size = (7, 1024 * 1024 + 3) 2371*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(size, dtype=dtype, device=device) 2372*da0073e9SAndroid Build Coastguard Worker for i in range(100): 2373*da0073e9SAndroid Build Coastguard Worker x.zero_() 2374*da0073e9SAndroid Build Coastguard Worker for j in range(7): 2375*da0073e9SAndroid Build Coastguard Worker x[j][i] = j + 1 2376*da0073e9SAndroid Build Coastguard Worker xs1 = x.argmax(dim=-1) 2377*da0073e9SAndroid Build Coastguard Worker xs2 = x.max(dim=-1).indices 2378*da0073e9SAndroid Build Coastguard Worker for j in range(7): 2379*da0073e9SAndroid Build Coastguard Worker self.assertEqual(xs1[j].item(), i) 
2380*da0073e9SAndroid Build Coastguard Worker self.assertEqual(xs2[j].item(), i) 2381*da0073e9SAndroid Build Coastguard Worker for i in range(1, 100): 2382*da0073e9SAndroid Build Coastguard Worker x.zero_() 2383*da0073e9SAndroid Build Coastguard Worker for j in range(7): 2384*da0073e9SAndroid Build Coastguard Worker x[j][-i] = j + 1 2385*da0073e9SAndroid Build Coastguard Worker xs1 = x.argmax(dim=-1) 2386*da0073e9SAndroid Build Coastguard Worker xs2 = x.max(dim=-1).indices 2387*da0073e9SAndroid Build Coastguard Worker for j in range(7): 2388*da0073e9SAndroid Build Coastguard Worker self.assertEqual(xs1[j].item(), size[1] - i) 2389*da0073e9SAndroid Build Coastguard Worker self.assertEqual(xs2[j].item(), size[1] - i) 2390*da0073e9SAndroid Build Coastguard Worker # 2D case: min/argmin 2391*da0073e9SAndroid Build Coastguard Worker size = (7, 1024 * 1024 + 3) 2392*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(size, dtype=dtype, device=device) 2393*da0073e9SAndroid Build Coastguard Worker for i in range(100): 2394*da0073e9SAndroid Build Coastguard Worker x.zero_() 2395*da0073e9SAndroid Build Coastguard Worker for j in range(7): 2396*da0073e9SAndroid Build Coastguard Worker x[j][i] = -(j + 1) 2397*da0073e9SAndroid Build Coastguard Worker xs1 = x.argmin(dim=-1) 2398*da0073e9SAndroid Build Coastguard Worker xs2 = x.min(dim=-1).indices 2399*da0073e9SAndroid Build Coastguard Worker for j in range(7): 2400*da0073e9SAndroid Build Coastguard Worker self.assertEqual(xs1[j].item(), i) 2401*da0073e9SAndroid Build Coastguard Worker self.assertEqual(xs2[j].item(), i) 2402*da0073e9SAndroid Build Coastguard Worker for i in range(1, 100): 2403*da0073e9SAndroid Build Coastguard Worker x.zero_() 2404*da0073e9SAndroid Build Coastguard Worker for j in range(7): 2405*da0073e9SAndroid Build Coastguard Worker x[j][-i] = -(j + 1) 2406*da0073e9SAndroid Build Coastguard Worker xs1 = x.argmin(dim=-1) 2407*da0073e9SAndroid Build Coastguard Worker xs2 = x.min(dim=-1).indices 
2408*da0073e9SAndroid Build Coastguard Worker for j in range(7): 2409*da0073e9SAndroid Build Coastguard Worker self.assertEqual(xs1[j].item(), size[1] - i) 2410*da0073e9SAndroid Build Coastguard Worker self.assertEqual(xs2[j].item(), size[1] - i) 2411*da0073e9SAndroid Build Coastguard Worker 2412*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2413*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.float, torch.double, torch.bfloat16) 2414*da0073e9SAndroid Build Coastguard Worker def test_reduction_vectorize_along_output(self, device, dtype): 2415*da0073e9SAndroid Build Coastguard Worker def run_test(input_): 2416*da0073e9SAndroid Build Coastguard Worker M, N = input_.shape 2417*da0073e9SAndroid Build Coastguard Worker input_.zero_() 2418*da0073e9SAndroid Build Coastguard Worker for i in range(min(M, N)): 2419*da0073e9SAndroid Build Coastguard Worker input_[i][i] = 1 2420*da0073e9SAndroid Build Coastguard Worker output1 = input_.argmax(dim=0) 2421*da0073e9SAndroid Build Coastguard Worker output2 = input_.sum(dim=0) 2422*da0073e9SAndroid Build Coastguard Worker for i in range(min(M, N)): 2423*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output1[i], i) 2424*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output2[i], 1) 2425*da0073e9SAndroid Build Coastguard Worker # vec 4 2426*da0073e9SAndroid Build Coastguard Worker run_test(torch.zeros(64, 64, dtype=dtype, device=device)) 2427*da0073e9SAndroid Build Coastguard Worker # vec 2 2428*da0073e9SAndroid Build Coastguard Worker run_test(torch.zeros(64 * 64 + 2, dtype=dtype, device=device)[2:].view(64, 64)) 2429*da0073e9SAndroid Build Coastguard Worker run_test(torch.zeros(64, 62, dtype=dtype, device=device)) 2430*da0073e9SAndroid Build Coastguard Worker run_test(torch.zeros(64, 2, dtype=dtype, device=device)) 2431*da0073e9SAndroid Build Coastguard Worker # vec 1 2432*da0073e9SAndroid Build Coastguard Worker run_test(torch.zeros(64 * 64 + 1, dtype=dtype, device=device)[1:].view(64, 
64)) 2433*da0073e9SAndroid Build Coastguard Worker run_test(torch.zeros(64, 61, dtype=dtype, device=device)) 2434*da0073e9SAndroid Build Coastguard Worker run_test(torch.zeros(64, 1, dtype=dtype, device=device)) 2435*da0073e9SAndroid Build Coastguard Worker 2436*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2437*da0073e9SAndroid Build Coastguard Worker def test_argminmax_large_axis(self, device): 2438*da0073e9SAndroid Build Coastguard Worker # Regression test for gh-32863 2439*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(2**31, device=device, dtype=torch.int8) 2440*da0073e9SAndroid Build Coastguard Worker x[-1] = 1 2441*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax(0), x.shape[0] - 1) 2442*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.max(0).indices, x.shape[0] - 1) 2443*da0073e9SAndroid Build Coastguard Worker x[-1] = -1 2444*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin(0), x.shape[0] - 1) 2445*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.min(0).indices, x.shape[0] - 1) 2446*da0073e9SAndroid Build Coastguard Worker 2447*da0073e9SAndroid Build Coastguard Worker def test_argminmax_axis_with_dim_one(self, device): 2448*da0073e9SAndroid Build Coastguard Worker # See: https://github.com/pytorch/pytorch/issues/38922 2449*da0073e9SAndroid Build Coastguard Worker n = 32768 2450*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(1, n) 2451*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax(dim=0), torch.zeros(n, dtype=torch.int64)) 2452*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin(dim=0), torch.zeros(n, dtype=torch.int64)) 2453*da0073e9SAndroid Build Coastguard Worker 2454*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax(dim=-2), torch.zeros(n, dtype=torch.int64)) 2455*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin(dim=-2), torch.zeros(n, dtype=torch.int64)) 2456*da0073e9SAndroid Build Coastguard Worker 
2457*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64)) 2458*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64)) 2459*da0073e9SAndroid Build Coastguard Worker 2460*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmax(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64)) 2461*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.argmin(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64)) 2462*da0073e9SAndroid Build Coastguard Worker 2463*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int, torch.long, torch.float, torch.double) 2464*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.int, torch.long, torch.half, torch.float, torch.double) 2465*da0073e9SAndroid Build Coastguard Worker def test_median_real_values(self, device, dtype): 2466*da0073e9SAndroid Build Coastguard Worker # Generate random 0-3D sizes 2467*da0073e9SAndroid Build Coastguard Worker sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)] 2468*da0073e9SAndroid Build Coastguard Worker for size in sizes: 2469*da0073e9SAndroid Build Coastguard Worker # Create random input tensor 2470*da0073e9SAndroid Build Coastguard Worker t = torch.randn(size, device=device).type(dtype) 2471*da0073e9SAndroid Build Coastguard Worker t_numpy = t.cpu().numpy() 2472*da0073e9SAndroid Build Coastguard Worker res = t.median() 2473*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, t.nanmedian()) 2474*da0073e9SAndroid Build Coastguard Worker k = int((t.numel() - 1) / 2) 2475*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, t.view(-1).sort()[0][k]) 2476*da0073e9SAndroid Build Coastguard Worker if t.numel() % 2 == 1: 2477*da0073e9SAndroid Build Coastguard Worker # We can only test agains numpy for odd reductions because numpy 2478*da0073e9SAndroid Build Coastguard Worker # returns 
the mean of the two medians and torch returns the lower 2479*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.cpu().numpy(), np.median(t_numpy)) 2480*da0073e9SAndroid Build Coastguard Worker for dim in range(t.ndim): 2481*da0073e9SAndroid Build Coastguard Worker res = t.median(dim, True) 2482*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, t.nanmedian(dim, True)) 2483*da0073e9SAndroid Build Coastguard Worker size = t.size(dim) if t.ndim > 0 else 1 2484*da0073e9SAndroid Build Coastguard Worker k = int((size - 1) / 2) 2485*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[0], (t.sort(dim)[0]).select(dim, k).unsqueeze_(dim)) 2486*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[0], t.gather(dim, res[1])) 2487*da0073e9SAndroid Build Coastguard Worker if size % 2 == 1: 2488*da0073e9SAndroid Build Coastguard Worker # We can only test agains numpy for odd reductions because numpy 2489*da0073e9SAndroid Build Coastguard Worker # returns the mean of the two medians and torch returns the lower 2490*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[0].cpu().numpy(), np.median(t_numpy, dim, keepdims=True), exact_dtype=False) 2491*da0073e9SAndroid Build Coastguard Worker 2492*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 2493*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.float, torch.double) 2494*da0073e9SAndroid Build Coastguard Worker def test_median_nan_values(self, device, dtype): 2495*da0073e9SAndroid Build Coastguard Worker # Generate random 0-3D sizes 2496*da0073e9SAndroid Build Coastguard Worker sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)] 2497*da0073e9SAndroid Build Coastguard Worker for size in sizes: 2498*da0073e9SAndroid Build Coastguard Worker # Create random input tensor with nan values 2499*da0073e9SAndroid Build Coastguard Worker t = torch.rand(size, device=device, dtype=dtype) 2500*da0073e9SAndroid Build 
Coastguard Worker t.masked_fill_(t < 0.1, float('nan')) 2501*da0073e9SAndroid Build Coastguard Worker t_numpy = t.cpu().numpy() 2502*da0073e9SAndroid Build Coastguard Worker for op in [torch.median, torch.nanmedian]: 2503*da0073e9SAndroid Build Coastguard Worker numpy_op = np.median if op == torch.median else np.nanmedian 2504*da0073e9SAndroid Build Coastguard Worker res = op(t) 2505*da0073e9SAndroid Build Coastguard Worker num_nan = t.isnan().sum() 2506*da0073e9SAndroid Build Coastguard Worker if op == torch.median and num_nan > 0: 2507*da0073e9SAndroid Build Coastguard Worker k = t.numel() - 1 2508*da0073e9SAndroid Build Coastguard Worker else: 2509*da0073e9SAndroid Build Coastguard Worker k = int((t.numel() - num_nan - 1) / 2) 2510*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, t.view(-1).sort()[0][k]) 2511*da0073e9SAndroid Build Coastguard Worker if (t.numel() - num_nan) % 2 == 1: 2512*da0073e9SAndroid Build Coastguard Worker # We can only test agains numpy for odd reductions because numpy 2513*da0073e9SAndroid Build Coastguard Worker # returns the mean of the two medians and torch returns the lower 2514*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res.item(), numpy_op(t.cpu().numpy())) 2515*da0073e9SAndroid Build Coastguard Worker for dim in range(t.ndim): 2516*da0073e9SAndroid Build Coastguard Worker res = op(t, dim, True) 2517*da0073e9SAndroid Build Coastguard Worker size = t.size(dim) if t.ndim > 0 else 1 2518*da0073e9SAndroid Build Coastguard Worker num_nan = t.isnan().sum(dim, True) 2519*da0073e9SAndroid Build Coastguard Worker if op == torch.median: 2520*da0073e9SAndroid Build Coastguard Worker k = torch.where(num_nan > 0, size - 1, int((size - 1) / 2)) 2521*da0073e9SAndroid Build Coastguard Worker else: 2522*da0073e9SAndroid Build Coastguard Worker k = ((size - num_nan - 1) / 2).type(torch.long) 2523*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res[0], (t.sort(dim)[0]).gather(dim, k)) 2524*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(res[0], t.gather(dim, res[1])) 2525*da0073e9SAndroid Build Coastguard Worker # We can only test agains numpy for odd reductions because numpy 2526*da0073e9SAndroid Build Coastguard Worker # returns the mean of the two medians and torch returns the lower 2527*da0073e9SAndroid Build Coastguard Worker mask = (size - num_nan) % 2 == 1 2528*da0073e9SAndroid Build Coastguard Worker res = res[0].masked_select(mask).cpu() 2529*da0073e9SAndroid Build Coastguard Worker ref = numpy_op(t_numpy, dim, keepdims=True)[mask.cpu().numpy()] 2530*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, torch.from_numpy(ref)) 2531*da0073e9SAndroid Build Coastguard Worker 2532*da0073e9SAndroid Build Coastguard Worker def test_median_corner_cases(self, device): 2533*da0073e9SAndroid Build Coastguard Worker def check(op, a, args, key): 2534*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(a, device=device) 2535*da0073e9SAndroid Build Coastguard Worker res = op(t, *args) 2536*da0073e9SAndroid Build Coastguard Worker if not args: 2537*da0073e9SAndroid Build Coastguard Worker key = torch.tensor(key, device=device) 2538*da0073e9SAndroid Build Coastguard Worker else: 2539*da0073e9SAndroid Build Coastguard Worker if len(key) == 1: 2540*da0073e9SAndroid Build Coastguard Worker key = torch.tensor(key[0], device=device) 2541*da0073e9SAndroid Build Coastguard Worker res = res[0] 2542*da0073e9SAndroid Build Coastguard Worker else: 2543*da0073e9SAndroid Build Coastguard Worker key = (torch.tensor(key[0], device=device), torch.tensor(key[1], device=device)) 2544*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, key) 2545*da0073e9SAndroid Build Coastguard Worker 2546*da0073e9SAndroid Build Coastguard Worker nan = float('nan') 2547*da0073e9SAndroid Build Coastguard Worker check(torch.median, nan, [], nan) 2548*da0073e9SAndroid Build Coastguard Worker check(torch.median, [], [], nan) 2549*da0073e9SAndroid Build Coastguard Worker 
check(torch.nanmedian, nan, [], nan) 2550*da0073e9SAndroid Build Coastguard Worker check(torch.median, nan, [0], [nan, 0]) 2551*da0073e9SAndroid Build Coastguard Worker check(torch.nanmedian, nan, [0], [nan, 0]) 2552*da0073e9SAndroid Build Coastguard Worker check(torch.median, [nan], [0, True], [[nan], [0]]) 2553*da0073e9SAndroid Build Coastguard Worker check(torch.nanmedian, [nan], [0, True], [[nan], [0]]) 2554*da0073e9SAndroid Build Coastguard Worker check(torch.median, [nan], [0, True], [[nan], [0]]) 2555*da0073e9SAndroid Build Coastguard Worker check(torch.nanmedian, [nan], [0, True], [[nan], [0]]) 2556*da0073e9SAndroid Build Coastguard Worker 2557*da0073e9SAndroid Build Coastguard Worker # Indices are not deterministic here so can only check values 2558*da0073e9SAndroid Build Coastguard Worker check(torch.median, [[nan, nan], [1, 2]], [0], [[nan, nan]]) 2559*da0073e9SAndroid Build Coastguard Worker check(torch.nanmedian, [[nan, nan], [1, 2]], [0], [[1, 2.]]) 2560*da0073e9SAndroid Build Coastguard Worker check(torch.median, [[nan, nan], [1, 2]], [1], [[nan, 1]]) 2561*da0073e9SAndroid Build Coastguard Worker check(torch.nanmedian, [[nan, nan], [1, 2]], [1], [[nan, 1.]]) 2562*da0073e9SAndroid Build Coastguard Worker 2563*da0073e9SAndroid Build Coastguard Worker # Discontiguous and strided tensors 2564*da0073e9SAndroid Build Coastguard Worker a = torch.arange(12, device=device) 2565*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[::2].median(), torch.tensor(4, device=device)) 2566*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[::2].nanmedian(), torch.tensor(4, device=device)) 2567*da0073e9SAndroid Build Coastguard Worker 2568*da0073e9SAndroid Build Coastguard Worker a.resize_(3, 4) 2569*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.T.median(), torch.tensor(5, device=device)) 2570*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device)) 2571*da0073e9SAndroid Build Coastguard 
Worker self.assertEqual(a[::2, ::2].median(-1)[0], torch.tensor([0, 8], device=device)) 2572*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[::2, ::2].nanmedian(-1)[0], torch.tensor([0, 8], device=device)) 2573*da0073e9SAndroid Build Coastguard Worker 2574*da0073e9SAndroid Build Coastguard Worker a.resize_(2, 3, 2) 2575*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.T.median(), torch.tensor(5, device=device)) 2576*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device)) 2577*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device)) 2578*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device)) 2579*da0073e9SAndroid Build Coastguard Worker 2580*da0073e9SAndroid Build Coastguard Worker 2581*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 2582*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 2583*da0073e9SAndroid Build Coastguard Worker def test_quantile(self, device, dtype): 2584*da0073e9SAndroid Build Coastguard Worker # Generate some random test cases 2585*da0073e9SAndroid Build Coastguard Worker ops = ['quantile', 'nanquantile'] 2586*da0073e9SAndroid Build Coastguard Worker inputs = [tuple(np.random.randint(2, 10, size=i)) for i in range(1, 4)] 2587*da0073e9SAndroid Build Coastguard Worker quantiles = [tuple(np.random.rand(i)) for i in range(0, 5)] 2588*da0073e9SAndroid Build Coastguard Worker keepdims = [True, False] 2589*da0073e9SAndroid Build Coastguard Worker 2590*da0073e9SAndroid Build Coastguard Worker # Add corner cases 2591*da0073e9SAndroid Build Coastguard Worker inputs.extend([0.75, (1,), (1, 1), (1, 2, 1)]) 2592*da0073e9SAndroid Build Coastguard Worker inputs.extend([[float('nan')], [[float('nan'), float('nan')], [1, 2]]]) 2593*da0073e9SAndroid Build Coastguard Worker 
inputs.extend([[[float('nan'), float('nan')], [float('nan'), 2]]]) 2594*da0073e9SAndroid Build Coastguard Worker quantiles.extend([0.5, [0., 1.], np.random.rand(10)]) 2595*da0073e9SAndroid Build Coastguard Worker 2596*da0073e9SAndroid Build Coastguard Worker # Enumerate all input combinations 2597*da0073e9SAndroid Build Coastguard Worker for op, x, q, keepdim in product(ops, inputs, quantiles, keepdims): 2598*da0073e9SAndroid Build Coastguard Worker if type(x) is tuple: 2599*da0073e9SAndroid Build Coastguard Worker a = torch.randn(x, dtype=dtype, device=device) 2600*da0073e9SAndroid Build Coastguard Worker # Make some random elements NaN 2601*da0073e9SAndroid Build Coastguard Worker a.masked_fill_(torch.randint_like(a, 20) == 0, float('nan')) 2602*da0073e9SAndroid Build Coastguard Worker else: 2603*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(x, dtype=dtype, device=device) 2604*da0073e9SAndroid Build Coastguard Worker 2605*da0073e9SAndroid Build Coastguard Worker q = torch.tensor(q, dtype=dtype, device=device) 2606*da0073e9SAndroid Build Coastguard Worker 2607*da0073e9SAndroid Build Coastguard Worker torch_op = getattr(torch, op) 2608*da0073e9SAndroid Build Coastguard Worker numpy_op = getattr(np, op) 2609*da0073e9SAndroid Build Coastguard Worker 2610*da0073e9SAndroid Build Coastguard Worker # Compute quantile along every dimension and flattened tensor 2611*da0073e9SAndroid Build Coastguard Worker interpolations = ('linear', 'lower', 'higher', 'midpoint', 'nearest') 2612*da0073e9SAndroid Build Coastguard Worker for interpolation, dim in product(interpolations, 2613*da0073e9SAndroid Build Coastguard Worker [None] + list(range(a.ndim))): 2614*da0073e9SAndroid Build Coastguard Worker result = torch_op(a, q, dim=dim, keepdim=keepdim, interpolation=interpolation) 2615*da0073e9SAndroid Build Coastguard Worker expected = numpy_op(a.cpu().numpy(), q.cpu().numpy(), dim, 2616*da0073e9SAndroid Build Coastguard Worker interpolation=interpolation, keepdims=keepdim) 
2617*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.cpu(), torch.from_numpy(np.array(expected)).type(result.type())) 2618*da0073e9SAndroid Build Coastguard Worker 2619*da0073e9SAndroid Build Coastguard Worker # Test out variation 2620*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(result) 2621*da0073e9SAndroid Build Coastguard Worker torch_op(a, q, dim=dim, keepdim=keepdim, interpolation=interpolation, out=out) 2622*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.cpu(), result.cpu()) 2623*da0073e9SAndroid Build Coastguard Worker 2624*da0073e9SAndroid Build Coastguard Worker def test_quantile_backward(self, device): 2625*da0073e9SAndroid Build Coastguard Worker def check(a, q, dim, expected_grad, ops=(torch.quantile, torch.nanquantile)): 2626*da0073e9SAndroid Build Coastguard Worker for op in ops: 2627*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(a, device=device, requires_grad=True) 2628*da0073e9SAndroid Build Coastguard Worker op(t, torch.tensor(q, device=device), dim).sum().backward() 2629*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.grad, expected_grad) 2630*da0073e9SAndroid Build Coastguard Worker 2631*da0073e9SAndroid Build Coastguard Worker check([1., 2, 3], 0.5, 0, [0, 1, 0]) 2632*da0073e9SAndroid Build Coastguard Worker check([1., 2, 3, 4], 0.5, 0, [0, 0.5, 0.5, 0]) 2633*da0073e9SAndroid Build Coastguard Worker check([3., 1, 4, 2], 0.5, 0, [0.5, 0, 0, 0.5]) 2634*da0073e9SAndroid Build Coastguard Worker check([1., 2, 3, 4], [0.25, 0.5, 0.75], 0, [0.25, 1.25, 1.25, 0.25]) 2635*da0073e9SAndroid Build Coastguard Worker check([[1., 2], [2, 1]], 0., 0, [[1, 0], [0, 1]]) 2636*da0073e9SAndroid Build Coastguard Worker check([[1., 2], [4, 3]], 1., 1, [[0, 1], [1, 0]]) 2637*da0073e9SAndroid Build Coastguard Worker check([1, float('nan'), 2], 0.5, 0, [0, 1, 0], [torch.quantile]) 2638*da0073e9SAndroid Build Coastguard Worker check([1, float('nan'), 2], 0.5, 0, [0.5, 0, 0.5], 
[torch.nanquantile]) 2639*da0073e9SAndroid Build Coastguard Worker 2640*da0073e9SAndroid Build Coastguard Worker def test_quantile_error(self, device): 2641*da0073e9SAndroid Build Coastguard Worker def check(a, q, args, kwargs, message): 2642*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'quantile\(\) ' + message): 2643*da0073e9SAndroid Build Coastguard Worker at = torch.tensor(a, device=device) 2644*da0073e9SAndroid Build Coastguard Worker qt = torch.tensor(q, device=device) if isinstance(q, list) else q 2645*da0073e9SAndroid Build Coastguard Worker torch.quantile(at, qt, *args, **kwargs) 2646*da0073e9SAndroid Build Coastguard Worker 2647*da0073e9SAndroid Build Coastguard Worker check([], 0.5, [], {}, r'input tensor must be non-empty') 2648*da0073e9SAndroid Build Coastguard Worker check([1.], [[1.]], [], {}, r'q must be a scalar or 1D tensor') 2649*da0073e9SAndroid Build Coastguard Worker check([1], 0.5, [], {}, r'input tensor must be either float or double dtype') 2650*da0073e9SAndroid Build Coastguard Worker check([1.], [1], [], {}, r'q tensor must be same dtype as the input tensor') 2651*da0073e9SAndroid Build Coastguard Worker check([1.], -1., [], {}, r'q must be in the range \[0, 1\] but got -1') 2652*da0073e9SAndroid Build Coastguard Worker check([1.], 1.1, [], {}, r'q must be in the range \[0, 1\] but got 1.1') 2653*da0073e9SAndroid Build Coastguard Worker check([1.], 0.5, [], {'out': torch.empty([], dtype=torch.int32, device=device)}, 2654*da0073e9SAndroid Build Coastguard Worker r'out tensor must be same dtype as the input tensor') 2655*da0073e9SAndroid Build Coastguard Worker check([1.], [1.], [None, False], {'interpolation': 'random_mode'}, 2656*da0073e9SAndroid Build Coastguard Worker r"interpolation must be one of linear, lower, higher, midpoint or nearest, but got random_mode") 2657*da0073e9SAndroid Build Coastguard Worker 2658*da0073e9SAndroid Build Coastguard Worker if self.device_type == "cpu": 
2659*da0073e9SAndroid Build Coastguard Worker check([1.], [0.5, 1.1, -1], [], {}, r'q values must be in the range \[0, 1\]') 2660*da0073e9SAndroid Build Coastguard Worker 2661*da0073e9SAndroid Build Coastguard Worker if self.device_type == "cuda": 2662*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2663*da0073e9SAndroid Build Coastguard Worker RuntimeError, r'quantile\(\) q tensor must be on the same device as the input tensor'): 2664*da0073e9SAndroid Build Coastguard Worker torch.randn(1, device=device).quantile(torch.tensor(0.5)) 2665*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2666*da0073e9SAndroid Build Coastguard Worker RuntimeError, r'quantile\(\) out tensor must be on the same device as the input tensor'): 2667*da0073e9SAndroid Build Coastguard Worker torch.quantile(torch.randn(1, device=device), 0.5, out=torch.scalar_tensor(1)) 2668*da0073e9SAndroid Build Coastguard Worker 2669*da0073e9SAndroid Build Coastguard Worker def test_std_mean(self, device): 2670*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, 50, 20, device=device) 2671*da0073e9SAndroid Build Coastguard Worker for dim in range(x.dim()): 2672*da0073e9SAndroid Build Coastguard Worker for unbiased in [False, True]: 2673*da0073e9SAndroid Build Coastguard Worker for keepdim in [False, True]: 2674*da0073e9SAndroid Build Coastguard Worker std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) 2675*da0073e9SAndroid Build Coastguard Worker std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim) 2676*da0073e9SAndroid Build Coastguard Worker mean2 = x.mean(dim=dim, keepdim=keepdim) 2677*da0073e9SAndroid Build Coastguard Worker self.assertEqual(std1, std2) 2678*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mean1, mean2) 2679*da0073e9SAndroid Build Coastguard Worker 2680*da0073e9SAndroid Build Coastguard Worker def test_std_mean_all_dims(self, device): 2681*da0073e9SAndroid Build Coastguard Worker x = 
torch.rand(100, 50, 20, device=device) 2682*da0073e9SAndroid Build Coastguard Worker for unbiased in [False, True]: 2683*da0073e9SAndroid Build Coastguard Worker std1, mean1 = torch.std_mean(x, unbiased=unbiased) 2684*da0073e9SAndroid Build Coastguard Worker std2 = x.std(unbiased=unbiased) 2685*da0073e9SAndroid Build Coastguard Worker mean2 = x.mean() 2686*da0073e9SAndroid Build Coastguard Worker self.assertEqual(std1, std2) 2687*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mean1, mean2) 2688*da0073e9SAndroid Build Coastguard Worker 2689*da0073e9SAndroid Build Coastguard Worker def test_var_mean(self, device): 2690*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, 300, 50, device=device) 2691*da0073e9SAndroid Build Coastguard Worker for dim in range(x.dim()): 2692*da0073e9SAndroid Build Coastguard Worker for unbiased in [False, True]: 2693*da0073e9SAndroid Build Coastguard Worker for keepdim in [False, True]: 2694*da0073e9SAndroid Build Coastguard Worker var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) 2695*da0073e9SAndroid Build Coastguard Worker var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim) 2696*da0073e9SAndroid Build Coastguard Worker mean2 = x.mean(dim=dim, keepdim=keepdim) 2697*da0073e9SAndroid Build Coastguard Worker self.assertEqual(var1, var2) 2698*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mean1, mean2) 2699*da0073e9SAndroid Build Coastguard Worker 2700*da0073e9SAndroid Build Coastguard Worker def test_var_mean_all_dims(self, device): 2701*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, 50, 20, device=device) 2702*da0073e9SAndroid Build Coastguard Worker for unbiased in [False, True]: 2703*da0073e9SAndroid Build Coastguard Worker var1, mean1 = torch.var_mean(x, unbiased=unbiased) 2704*da0073e9SAndroid Build Coastguard Worker var2 = x.var(unbiased=unbiased) 2705*da0073e9SAndroid Build Coastguard Worker mean2 = x.mean() 2706*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(var1, var2) 2707*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mean1, mean2) 2708*da0073e9SAndroid Build Coastguard Worker 2709*da0073e9SAndroid Build Coastguard Worker def test_std_mean_some_dims(self, device): 2710*da0073e9SAndroid Build Coastguard Worker sizes = (4, 6, 7, 5, 3) 2711*da0073e9SAndroid Build Coastguard Worker dims = len(sizes) 2712*da0073e9SAndroid Build Coastguard Worker x = torch.rand(sizes, device=device) 2713*da0073e9SAndroid Build Coastguard Worker for num_of_dims in range(2, dims): 2714*da0073e9SAndroid Build Coastguard Worker dim_list = list(combinations(list(range(dims)), r=num_of_dims)) 2715*da0073e9SAndroid Build Coastguard Worker for dim in dim_list: 2716*da0073e9SAndroid Build Coastguard Worker for unbiased in [False, True]: 2717*da0073e9SAndroid Build Coastguard Worker for keepdim in [False, True]: 2718*da0073e9SAndroid Build Coastguard Worker std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) 2719*da0073e9SAndroid Build Coastguard Worker std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim) 2720*da0073e9SAndroid Build Coastguard Worker mean2 = x.mean(dim=dim, keepdim=keepdim) 2721*da0073e9SAndroid Build Coastguard Worker self.assertEqual(std1, std2) 2722*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mean1, mean2) 2723*da0073e9SAndroid Build Coastguard Worker 2724*da0073e9SAndroid Build Coastguard Worker def _compare_std_var_with_numpy(self, op, device, dtype, input, dim, 2725*da0073e9SAndroid Build Coastguard Worker keepdim, unbiased, use_out): 2726*da0073e9SAndroid Build Coastguard Worker a = input.cpu().numpy() if input.dtype is not torch.bfloat16 else input.float().cpu().numpy() 2727*da0073e9SAndroid Build Coastguard Worker numpy_kwargs = { 2728*da0073e9SAndroid Build Coastguard Worker 'axis' : dim, 2729*da0073e9SAndroid Build Coastguard Worker 'keepdims' : keepdim, 2730*da0073e9SAndroid Build Coastguard Worker 'ddof' : 1 if unbiased else 0, 
2731*da0073e9SAndroid Build Coastguard Worker } 2732*da0073e9SAndroid Build Coastguard Worker 2733*da0073e9SAndroid Build Coastguard Worker if dim is None: 2734*da0073e9SAndroid Build Coastguard Worker del numpy_kwargs['axis'] 2735*da0073e9SAndroid Build Coastguard Worker del numpy_kwargs['keepdims'] 2736*da0073e9SAndroid Build Coastguard Worker 2737*da0073e9SAndroid Build Coastguard Worker if op == 'var': 2738*da0073e9SAndroid Build Coastguard Worker torch_op = torch.var 2739*da0073e9SAndroid Build Coastguard Worker numpy_op = np.var 2740*da0073e9SAndroid Build Coastguard Worker elif op == 'std': 2741*da0073e9SAndroid Build Coastguard Worker torch_op = torch.std 2742*da0073e9SAndroid Build Coastguard Worker numpy_op = np.std 2743*da0073e9SAndroid Build Coastguard Worker else: 2744*da0073e9SAndroid Build Coastguard Worker self.fail("Unknown op!") 2745*da0073e9SAndroid Build Coastguard Worker 2746*da0073e9SAndroid Build Coastguard Worker numpy_result = numpy_op(a, **numpy_kwargs) 2747*da0073e9SAndroid Build Coastguard Worker 2748*da0073e9SAndroid Build Coastguard Worker if dim is None and use_out is False: 2749*da0073e9SAndroid Build Coastguard Worker torch_result = torch_op(input, unbiased) 2750*da0073e9SAndroid Build Coastguard Worker elif dim is not None and use_out is False: 2751*da0073e9SAndroid Build Coastguard Worker torch_result = torch_op(input, dim, unbiased, keepdim) 2752*da0073e9SAndroid Build Coastguard Worker elif dim is not None and use_out is True: 2753*da0073e9SAndroid Build Coastguard Worker out = torch.empty(0, device=device, dtype=dtype) 2754*da0073e9SAndroid Build Coastguard Worker torch_result = torch_op(input, dim, unbiased, keepdim, out=out) 2755*da0073e9SAndroid Build Coastguard Worker else: 2756*da0073e9SAndroid Build Coastguard Worker out = torch.empty(0, device=device, dtype=dtype) 2757*da0073e9SAndroid Build Coastguard Worker torch_result = torch_op(input, dim, unbiased, keepdim, out=out) 2758*da0073e9SAndroid Build Coastguard Worker 
2759*da0073e9SAndroid Build Coastguard Worker exact_dtype = input.dtype not in (torch.bfloat16, torch.complex32, torch.complex64, torch.complex128) 2760*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch_result, numpy_result, exact_dtype=exact_dtype) 2761*da0073e9SAndroid Build Coastguard Worker 2762*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) 2763*da0073e9SAndroid Build Coastguard Worker def test_var_vs_numpy(self, device, dtype): 2764*da0073e9SAndroid Build Coastguard Worker _size = (20, 20) 2765*da0073e9SAndroid Build Coastguard Worker 2766*da0073e9SAndroid Build Coastguard Worker for test_case in product((torch.randn(_size, device=device, dtype=dtype),), 2767*da0073e9SAndroid Build Coastguard Worker (None, 0, 1), 2768*da0073e9SAndroid Build Coastguard Worker (False, True), 2769*da0073e9SAndroid Build Coastguard Worker (False, True), 2770*da0073e9SAndroid Build Coastguard Worker (False, True),): 2771*da0073e9SAndroid Build Coastguard Worker self._compare_std_var_with_numpy('var', device, dtype, *test_case) 2772*da0073e9SAndroid Build Coastguard Worker 2773*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) 2774*da0073e9SAndroid Build Coastguard Worker def test_std_vs_numpy(self, device, dtype): 2775*da0073e9SAndroid Build Coastguard Worker _size = (20, 20) 2776*da0073e9SAndroid Build Coastguard Worker 2777*da0073e9SAndroid Build Coastguard Worker for test_case in product((torch.randn(_size, device=device, dtype=dtype),), 2778*da0073e9SAndroid Build Coastguard Worker (None, 0, 1), 2779*da0073e9SAndroid Build Coastguard Worker (False, True), 2780*da0073e9SAndroid Build Coastguard Worker (False, True), 2781*da0073e9SAndroid Build Coastguard Worker (False, True),): 2782*da0073e9SAndroid Build Coastguard Worker self._compare_std_var_with_numpy('std', device, dtype, *test_case) 2783*da0073e9SAndroid Build Coastguard Worker 
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_var_correction_vs_numpy(self, device, dtype):
    """Compare ``torch.var(correction=...)`` with ``np.var(ddof=...)``.

    Every combination of dim, correction and keepdim listed below is
    checked on a single 20x20 tensor, plus one negative-correction case.
    """
    _size = (20, 20)
    test_args = [
        *product(
            # dim
            (None, 0, 1),
            # correction
            (None, 0, 10, 30),
            # keepdim
            (False, True),
        ),
        (None, -100, True),  # Negative correction
    ]

    tensor = make_tensor(_size, device=device, dtype=dtype)
    array = tensor.cpu().numpy()

    for dim, correction, keepdim in test_args:
        numpy_kwargs = dict(axis=dim, ddof=correction, keepdims=keepdim)
        if correction is None:
            # NumPy default is not compatible with torch.var (gh-50010)
            numpy_kwargs['ddof'] = 1

        numpy_res = np.asarray(np.var(array, **numpy_kwargs))
        torch_res = torch.var(tensor, dim=dim, correction=correction, keepdim=keepdim)

        # inf vs. nan results are sensitive to machine precision,
        # just treat them as equivalent
        numpy_res[np.isinf(numpy_res)] = np.nan
        torch_res[torch_res.isinf()] = np.nan

        self.assertEqual(torch_res, numpy_res)


@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_std_correction_vs_numpy(self, device, dtype):
    """Compare ``torch.std(correction=...)`` with ``np.std(ddof=...)``.

    Every combination of dim, correction and keepdim listed below is
    checked on a single 20x20 tensor, plus one negative-correction case.
    """
    _size = (20, 20)
    test_args = [
        *product(
            # dim
            (None, 0, 1),
            # correction
            (None, 0, 10, 30),
            # keepdim
            (False, True),
        ),
        (None, -100, True),  # Negative correction
    ]

    tensor = make_tensor(_size, device=device, dtype=dtype)
    array = tensor.cpu().numpy()

    for dim, correction, keepdim in test_args:
        numpy_kwargs = dict(axis=dim, ddof=correction, keepdims=keepdim)
        if correction is None:
            # NumPy default is incompatible with torch.std (gh-50010)
            numpy_kwargs['ddof'] = 1

        numpy_res = np.asarray(np.std(array, **numpy_kwargs))
        torch_res = torch.std(tensor, dim=dim, correction=correction, keepdim=keepdim)

        # inf vs. nan results are sensitive to machine precision,
        # just treat them as equivalent
        numpy_res[np.isinf(numpy_res)] = np.nan
        torch_res[torch_res.isinf()] = np.nan

        self.assertEqual(torch_res, numpy_res)


@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_std_mean_correction(self, device, dtype):
    """Check that ``torch.std_mean`` agrees with separate ``torch.std`` and ``torch.mean`` calls."""
    _size = (20, 20)
    test_args = [
        *product(
            # dim
            (None, 0, 1),
            # correction
            (None, 0, 10, 30),
            # keepdim
            (False, True),
        ),
        (None, -100, True),  # Negative correction
    ]

    tensor = make_tensor(_size, device=device, dtype=dtype)

    for dim, correction, keepdim in test_args:
        kwargs = dict(dim=dim, correction=correction, keepdim=keepdim)
        std1 = torch.std(tensor, **kwargs)
        if dim is not None:
            mean1 = torch.mean(tensor, dim=dim, keepdim=keepdim)
        else:
            # The full reduction is called without keepdim here, so the
            # kept dimensions are restored by hand before comparing.
            mean1 = torch.mean(tensor)
            if keepdim:
                mean1 = mean1.reshape((1,) * tensor.ndim)
        std2, mean2 = torch.std_mean(tensor, **kwargs)

        self.assertEqual(std1, std2)
        self.assertEqual(mean1, mean2)


@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_var_mean_correction(self, device, dtype):
    """Check that ``torch.var_mean`` agrees with separate ``torch.var`` and ``torch.mean`` calls."""
    _size = (20, 20)
    test_args = [
        *product(
            # dim
            (None, 0, 1),
            # correction
            (None, 0, 10, 30),
            # keepdim
            (False, True),
        ),
        (None, -100, True),  # Negative correction
    ]

    tensor = make_tensor(_size, device=device, dtype=dtype)

    for dim, correction, keepdim in test_args:
        kwargs = dict(dim=dim, correction=correction, keepdim=keepdim)
        var1 = torch.var(tensor, **kwargs)
        if dim is not None:
            mean1 = torch.mean(tensor, dim=dim, keepdim=keepdim)
        else:
            # The full reduction is called without keepdim here, so the
            # kept dimensions are restored by hand before comparing.
            mean1 = torch.mean(tensor)
            if keepdim:
                mean1 = mean1.reshape((1,) * tensor.ndim)
        var2, mean2 = torch.var_mean(tensor, **kwargs)

        self.assertEqual(var1, var2)
        self.assertEqual(mean1, mean2)


@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_warn_invalid_degrees_of_freedom(self, device, dtype):
    """Check that std/var/var_mean/std_mean warn when the correction equals the reduced size."""
    def _assert_warning(_func, _tensor, _correction):
        with warnings.catch_warnings(record=True) as w:
            # Record every warning regardless of the ambient filters, so
            # the check does not depend on warnings seen earlier.
            warnings.simplefilter("always")
            _func(_tensor, dim=-1, correction=_correction)
        messages = [str(item.message) for item in w]
        # Search all recorded warnings instead of indexing w[0]: a missing
        # warning then fails the assertion rather than raising IndexError,
        # and an unrelated warning recorded first cannot mask the target.
        self.assertTrue(
            any('degrees of freedom is <= 0' in message for message in messages),
            f"expected a 'degrees of freedom is <= 0' warning, got: {messages}")

    correction = 20
    # The reduced (last) dimension has exactly `correction` elements.
    size = (10, correction)
    tensor = make_tensor(size, dtype=dtype, device=device)
    for f in [torch.std, torch.var, torch.var_mean, torch.std_mean]:
        _assert_warning(f, tensor, correction)


def test_amin_amax_some_dims(self, device):
    """Check multi-dim ``amin``/``amax`` against reducing one dimension at a time."""
    sizes = (4, 6, 7, 5, 3)
    dims = len(sizes)
    x = torch.rand(sizes, device=device)
    for num_of_dims in range(2, dims):
        for dim in combinations(range(dims), num_of_dims):
            for keepdim in [False, True]:
                amin1 = torch.amin(x, dim=dim, keepdim=keepdim)
                amax1 = torch.amax(x, dim=dim, keepdim=keepdim)
                amin2 = x
                amax2 = x
                for i, d in enumerate(dim):
                    # Without keepdim each earlier reduction removed one
                    # dimension, shifting the remaining indices down by one.
                    step_dim = d if keepdim else d - i
                    amin2 = torch.amin(amin2, dim=step_dim, keepdim=keepdim)
                    amax2 = torch.amax(amax2, dim=step_dim, keepdim=keepdim)
                self.assertEqual(amin1, amin2)
                self.assertEqual(amax1, amax2)

def test_histc(self, device):
    """Exercise ``torch.histc`` on hand-written cases and against ``numpy.histogram``."""
    # negative nbins throws
    with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
        torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1)
    # empty tensor
    actual = torch.histc(torch.tensor([], device=device), min=0, max=3)
    expected = torch.zeros(100, dtype=torch.float, device=device)
    self.assertEqual(expected, actual)

    # without nbins
    actual = torch.histc(
        torch.tensor([2, 5], dtype=torch.float, device=device))
    expected = torch.zeros(100, dtype=torch.float, device=device)
    expected[0] = 1
    expected[99] = 1
    self.assertEqual(expected, actual)
    # tensor with the same element
    actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5)
    self.assertEqual(
        torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device),
        actual)
    # no element falls between [min, max]
    actual = torch.histc(
        torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3)
    self.assertEqual(
        torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device),
        actual)
    # explicit range with an empty leading bin (every element lies above min + bin size)
    actual = torch.histc(
        torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device),
        bins=5, min=1, max=5)
    self.assertEqual(
        torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device),
        actual)
    # non-integral bin size
    actual = torch.histc(
        torch.tensor([1, 2, 1], dtype=torch.float, device=device),
        bins=4, min=0, max=3)
    self.assertEqual(
        torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
        actual)
    # double input
    actual = torch.histc(
        torch.tensor([1, 2, 1], dtype=torch.double, device=device), bins=4, min=0, max=3)
    self.assertEqual(
        torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device),
        actual)
    self.assertEqual(actual.dtype, torch.double)
    # mixed input
    actual = torch.histc(
        torch.tensor([1., 2, 1], dtype=torch.float, device=device),
        bins=4, min=0, max=3)
    self.assertEqual(
        torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
        actual)
    self.assertEqual(actual.dtype, torch.float)
    # scalar input and 1 bin -- should return a 1-dimensional tensor, not a scalar.
    actual = torch.histc(
        torch.tensor(0, dtype=torch.float, device=device),
        bins=1, min=0, max=3)
    self.assertEqual(
        torch.tensor([1], dtype=torch.float, device=device),
        actual)
    # tensors with inf; min, max not provided -- should throw a RuntimeError
    with self.assertRaisesRegex(RuntimeError, r'range of \[inf, inf\] is not finite'):
        torch.histc(torch.tensor([float("inf")], dtype=torch.float, device=device))
    with self.assertRaisesRegex(RuntimeError, r'range of \[1, inf\] is not finite'):
        torch.histc(torch.tensor([1., 2., float("inf")], dtype=torch.float, device=device))
    # tensors with inf; min, max provided
    self.assertEqual(
        torch.histc(torch.tensor([float("inf")], dtype=torch.float, device=device),
                    bins=1, min=0, max=3),
        torch.tensor([0], dtype=torch.float, device=device))
    self.assertEqual(
        torch.histc(torch.tensor([1., 2., float("inf")], dtype=torch.float, device=device),
                    bins=4, max=3),
        torch.tensor([0, 1, 1, 0], dtype=torch.float, device=device))
    # tensor with nan; min, max not provided -- should throw a RuntimeError
    with self.assertRaisesRegex(RuntimeError, r'range of \[nan, nan\] is not finite'):
        torch.histc(torch.tensor([float("nan")], dtype=torch.float, device=device))
    # tensor with nan; min, max provided -- nan is ignored
    self.assertEqual(
        torch.histc(torch.tensor([1., 2., float("nan")], dtype=torch.float, device=device),
                    bins=4, max=3),
        torch.tensor([0, 1, 1, 0], dtype=torch.float, device=device))
    # tensors with min > max -- should throw a RuntimeError
    with self.assertRaisesRegex(RuntimeError, "max must be larger than min"):
        torch.histc(torch.tensor([1., 2., 3.], dtype=torch.float, device=device),
                    bins=4, min=5, max=1)

    # test against numpy.histogram()
    def test_against_np(tensor, bins=100, min=0, max=0):
        # min == max == 0 means "use the data range"; resolve it here so
        # numpy receives the same explicit range as torch.
        if min == 0 and max == 0:
            min = tensor.min().item()
            max = tensor.max().item()
        nparr = tensor.cpu().numpy()
        actual = torch.histc(tensor, bins=bins, min=min, max=max)
        expected = torch.from_numpy(np.histogram(nparr, bins=bins, range=(min, max))[0])
        actual_cpu = actual.cpu()
        # NB: Numpy returns a int64 tensor, like normal people...
        self.assertEqual(actual_cpu, expected.to(actual_cpu))

    test_against_np(torch.tensor([1., 2, 1], device=device))
    test_against_np(torch.randn(5000, device=device))

    # Test bins arg
    test_against_np(torch.randn(301, device=device), bins=10)

    # Test truncated range
    test_against_np(torch.randn(201, device=device), min=0.1, max=1)

    noncontig = torch.randn(100, 3, device=device)[:, 2]
    test_against_np(noncontig)

    multidim = torch.randn(3, 5, 7, 2, device=device)
    test_against_np(multidim)

    expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2)
    test_against_np(expanded)

    linear = torch.linspace(0, 0.99 - 5.0e-7, 101).to(device)
    test_against_np(linear, bins=20, min=0, max=0.99)


@onlyCPU
@dtypes(torch.bfloat16, torch.half)
def test_histc_lowp(self, device, dtype):
    """Check ``torch.histc`` on bfloat16/half input, including the result dtype."""
    actual = torch.histc(
        torch.tensor([1, 2, 1], dtype=dtype, device=device), bins=4, min=0, max=3)
    self.assertEqual(
        torch.tensor([0, 2, 1, 0], dtype=dtype, device=device),
        actual)
    self.assertEqual(actual.dtype, dtype)


def _test_histogram_numpy(self, t, bins, bin_range, weights, density):
    """
    Runs torch.histogram and numpy.histogram on the specified input parameters
    and asserts that their output is equal.

    ``bins`` may be a bin count or a tensor of bin edges; ``bin_range`` and
    ``weights`` may be None. The ``out=`` overload is checked as well, using
    non-contiguous output tensors.
    """
    def to_np(value):
        # Non-tensor arguments (ints, None) are passed to numpy unchanged.
        return value.cpu().numpy() if torch.is_tensor(value) else value

    # Wrapper around numpy.histogram performing conversions between torch tensors and numpy arrays.
    def reference_histogram(t, bins, bin_range, weights, density, dtype):
        (np_t, np_bins, np_weights) = map(to_np, [t, bins, weights])
        (np_hist, np_bin_edges) = np.histogram(np_t, np_bins, range=bin_range, weights=np_weights, density=density)
        return (torch.from_numpy(np_hist).to(dtype), torch.from_numpy(np_bin_edges).to(dtype))

    # Doesn't pass a 'range' kwarg unless necessary because the override of histogram with Tensor bins doesn't accept one
    range_kwargs = {'range': bin_range} if bin_range else {}

    (actual_hist, actual_bin_edges) = torch.histogram(t, bins, weight=weights, density=density, **range_kwargs)

    (expected_hist, expected_bin_edges) = reference_histogram(t, bins, bin_range, weights, density, actual_hist.dtype)

    # Works around linspace discrepancies by passing torch's constructed bin_edges to numpy.
    # When bin edges are not explicitly defined, histogram uses the linspace operator internally
    # to construct the sequence of bin edges. In some cases, torch.linspace output differs slightly
    # from numpy.linspace output.
    # Issue: https://github.com/pytorch/pytorch/issues/58758
    if not torch.is_tensor(bins):
        self.assertEqual(actual_bin_edges, expected_bin_edges, atol=1e-5, rtol=1e-5)
        # Calls numpy.histogram again, passing torch's actual_bin_edges as the bins argument
        (expected_hist, expected_bin_edges) = reference_histogram(
            t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype)

    self.assertEqual(actual_hist, expected_hist)
    self.assertEqual(actual_bin_edges, expected_bin_edges)

    # Test passing non-contiguous output tensors
    hist_out = make_tensor(expected_hist.shape, device=expected_hist.device, dtype=expected_hist.dtype,
                           noncontiguous=True)
    bin_edges_out = make_tensor(expected_bin_edges.shape, device=expected_bin_edges.device, dtype=expected_bin_edges.dtype,
                                noncontiguous=True)

    torch.histogram(t, bins, weight=weights, density=density, out=(hist_out, bin_edges_out), **range_kwargs)

    self.assertEqual(hist_out, expected_hist)
    self.assertEqual(bin_edges_out, expected_bin_edges)


@onlyCPU
@dtypes(torch.float32)
def test_histogram(self, device, dtype):
    """Compare ``torch.histogram`` with numpy across contiguity, bin count, weighting, density and shape."""
    shapes = (
        (),
        (0,),
        (1,),
        (1, 5),
        (3, 5),
        (1, 5, 1),
        (2, 3, 5))

    for contig, bins_contig, bin_ct, weighted, density, shape in \
            product([True, False], [True, False], range(1, 10), [True, False], [True, False], shapes):
        values = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9, noncontiguous=not contig)
        weights = make_tensor(shape, dtype=dtype, device=device, low=0, high=9, noncontiguous=not contig) if weighted else None

        # Tests passing just the bin_ct
        self._test_histogram_numpy(values, bin_ct, None, weights, density)

        # Tests with caller-specified histogram range
        bin_range = sorted((random.uniform(-9, 9), random.uniform(-9, 9)))
        self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)

        # Tests with range min=max
        bin_range[1] = bin_range[0]
        self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)

        # Tests with caller-specified bin edges
        bin_edges = make_tensor(bin_ct + 1, dtype=dtype, device=device, low=-9, high=9).msort()
        if not bins_contig:
            # Necessary because msort always produces contiguous output
            bin_edges_noncontig = make_tensor(bin_ct + 1, dtype=dtype, device=device, noncontiguous=not bins_contig)
            bin_edges_noncontig.copy_(bin_edges)
            bin_edges = bin_edges_noncontig
        self.assertEqual(bin_edges.is_contiguous(), bins_contig)
        self._test_histogram_numpy(values, bin_edges, None, weights, density)

        # Tests with input tensor in which all elements are equal
        elt = random.uniform(-9, 9)
        values = make_tensor(shape, dtype=dtype, device=device, low=elt, high=elt, noncontiguous=not contig)
        self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)
        self._test_histogram_numpy(values, bin_edges, None, weights, density)

        # Tests with input equal to bin_edges
        weights = (
            make_tensor(bin_ct + 1, dtype=dtype, device=device, low=0, high=9, noncontiguous=not contig)
            if weighted
            else None
        )
        self._test_histogram_numpy(bin_edges, bin_edges, None, weights, density)

    # Tests values of default args
    for bin_ct, shape in product(range(1, 10), shapes):
        values = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
        (actual_hist, actual_bin_edges) = torch.histogram(values, bin_ct)
        (expected_hist, expected_bin_edges) = torch.histogram(
            values, bin_ct, range=None, weight=None, density=False)
        self.assertEqual(actual_hist, expected_hist)
        self.assertEqual(actual_bin_edges, expected_bin_edges)


"""
Runs torch.histogramdd and numpy.histogramdd on the specified input parameters
and asserts that their output is equal.
"""
def _test_histogramdd_numpy(self, t, bins, bin_range, weights, density):
    def to_np(t):
        if type(t) == list:
            return list(map(to_np, t))
        if not torch.is_tensor(t):
            return t
        return t.cpu().numpy()

    # Wrapper around numpy.histogram performing conversions between torch tensors and numpy arrays.
3210*da0073e9SAndroid Build Coastguard Worker def reference_histogramdd(t, bins, bin_range, weights, density, dtype): 3211*da0073e9SAndroid Build Coastguard Worker (np_t, np_bins, np_weights) = map(to_np, [t, bins, weights]) 3212*da0073e9SAndroid Build Coastguard Worker 3213*da0073e9SAndroid Build Coastguard Worker # numpy.histogramdd accepts only (N, D) shapes 3214*da0073e9SAndroid Build Coastguard Worker D = np_t.shape[-1] 3215*da0073e9SAndroid Build Coastguard Worker N = np.prod(np_t.shape[:-1]) 3216*da0073e9SAndroid Build Coastguard Worker reshaped_t = np.reshape(np_t, (N, D)) 3217*da0073e9SAndroid Build Coastguard Worker reshaped_wt = np.reshape(np_weights, (N,)) if np_weights is not None else None 3218*da0073e9SAndroid Build Coastguard Worker 3219*da0073e9SAndroid Build Coastguard Worker # numpy.histogramdd throws an error for D=0 3220*da0073e9SAndroid Build Coastguard Worker if D == 0: 3221*da0073e9SAndroid Build Coastguard Worker return (torch.tensor(float('nan') if density else 0.), []) 3222*da0073e9SAndroid Build Coastguard Worker 3223*da0073e9SAndroid Build Coastguard Worker # numpy.histogramdd expects range to be specified as a sequence of D (lower, upper) tuples 3224*da0073e9SAndroid Build Coastguard Worker reshaped_range = None if not bin_range else [(bin_range[2 * i], bin_range[2 * i + 1]) for i in range(D)] 3225*da0073e9SAndroid Build Coastguard Worker 3226*da0073e9SAndroid Build Coastguard Worker (np_hist, np_bin_edges) = np.histogramdd(reshaped_t, np_bins, 3227*da0073e9SAndroid Build Coastguard Worker range=reshaped_range, weights=reshaped_wt, density=density) 3228*da0073e9SAndroid Build Coastguard Worker 3229*da0073e9SAndroid Build Coastguard Worker return (torch.from_numpy(np_hist).to(dtype), [torch.from_numpy(t).to(dtype) for t in np_bin_edges]) 3230*da0073e9SAndroid Build Coastguard Worker 3231*da0073e9SAndroid Build Coastguard Worker (actual_hist, actual_bin_edges) = torch.histogramdd(t, bins, range=bin_range, weight=weights, density=density) 
3232*da0073e9SAndroid Build Coastguard Worker (expected_hist, expected_bin_edges) = reference_histogramdd(t, bins, bin_range, weights, density, actual_hist.dtype) 3233*da0073e9SAndroid Build Coastguard Worker 3234*da0073e9SAndroid Build Coastguard Worker D = len(actual_bin_edges) 3235*da0073e9SAndroid Build Coastguard Worker self.assertEqual(D, len(expected_bin_edges)) 3236*da0073e9SAndroid Build Coastguard Worker 3237*da0073e9SAndroid Build Coastguard Worker """ 3238*da0073e9SAndroid Build Coastguard Worker Works around linspace discrepancies by passing torch's constructed bin_edges to numpy. 3239*da0073e9SAndroid Build Coastguard Worker When bin edges are not explicitly defined, histogram uses the linspace operator internally 3240*da0073e9SAndroid Build Coastguard Worker to construct the sequence of bin edges. In some cases, torch.linspace output differs slightly 3241*da0073e9SAndroid Build Coastguard Worker from numpy.linspace output. 3242*da0073e9SAndroid Build Coastguard Worker Issue: https://github.com/pytorch/pytorch/issues/58758 3243*da0073e9SAndroid Build Coastguard Worker """ 3244*da0073e9SAndroid Build Coastguard Worker if not torch.is_tensor(bins): 3245*da0073e9SAndroid Build Coastguard Worker for dim in range(D): 3246*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim], atol=1e-5, rtol=1e-5) 3247*da0073e9SAndroid Build Coastguard Worker # Calls numpy.histogram again, passing torch's actual_bin_edges as the bins argument 3248*da0073e9SAndroid Build Coastguard Worker (expected_hist, expected_bin_edges) = reference_histogramdd( 3249*da0073e9SAndroid Build Coastguard Worker t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype) 3250*da0073e9SAndroid Build Coastguard Worker self.assertEqual(D, len(expected_bin_edges)) 3251*da0073e9SAndroid Build Coastguard Worker 3252*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_hist, expected_hist) 3253*da0073e9SAndroid Build Coastguard 
Worker for dim in range(D): 3254*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim]) 3255*da0073e9SAndroid Build Coastguard Worker 3256*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3257*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32) 3258*da0073e9SAndroid Build Coastguard Worker def test_histogramdd(self, device, dtype): 3259*da0073e9SAndroid Build Coastguard Worker shapes = ( 3260*da0073e9SAndroid Build Coastguard Worker (1, 5), 3261*da0073e9SAndroid Build Coastguard Worker (3, 5), 3262*da0073e9SAndroid Build Coastguard Worker (1, 5, 1), 3263*da0073e9SAndroid Build Coastguard Worker (2, 3, 5), 3264*da0073e9SAndroid Build Coastguard Worker (7, 7, 7, 7), 3265*da0073e9SAndroid Build Coastguard Worker (16, 8, 4, 2), 3266*da0073e9SAndroid Build Coastguard Worker (10, 10, 10), 3267*da0073e9SAndroid Build Coastguard Worker (7, 0, 3), 3268*da0073e9SAndroid Build Coastguard Worker (5, 0),) 3269*da0073e9SAndroid Build Coastguard Worker 3270*da0073e9SAndroid Build Coastguard Worker for contig, bins_contig, weighted, density, shape in \ 3271*da0073e9SAndroid Build Coastguard Worker product([True, False], [True, False], [True, False], [True, False], shapes): 3272*da0073e9SAndroid Build Coastguard Worker D = shape[-1] 3273*da0073e9SAndroid Build Coastguard Worker 3274*da0073e9SAndroid Build Coastguard Worker values = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9, noncontiguous=not contig) 3275*da0073e9SAndroid Build Coastguard Worker weights = ( 3276*da0073e9SAndroid Build Coastguard Worker make_tensor(shape[:-1], dtype=dtype, device=device, low=0, high=9, noncontiguous=not contig) 3277*da0073e9SAndroid Build Coastguard Worker if weighted 3278*da0073e9SAndroid Build Coastguard Worker else None 3279*da0073e9SAndroid Build Coastguard Worker ) 3280*da0073e9SAndroid Build Coastguard Worker 3281*da0073e9SAndroid Build Coastguard Worker # Tests passing a single bin count 3282*da0073e9SAndroid 
Build Coastguard Worker bin_ct = random.randint(1, 5) 3283*da0073e9SAndroid Build Coastguard Worker self._test_histogramdd_numpy(values, bin_ct, None, weights, density) 3284*da0073e9SAndroid Build Coastguard Worker 3285*da0073e9SAndroid Build Coastguard Worker # Tests passing a bin count for each dimension 3286*da0073e9SAndroid Build Coastguard Worker bin_ct = [random.randint(1, 5) for dim in range(D)] 3287*da0073e9SAndroid Build Coastguard Worker self._test_histogramdd_numpy(values, bin_ct, None, weights, density) 3288*da0073e9SAndroid Build Coastguard Worker 3289*da0073e9SAndroid Build Coastguard Worker # Tests with caller-specified histogram range 3290*da0073e9SAndroid Build Coastguard Worker bin_range_tuples = [sorted((random.uniform(-9, 9), random.uniform(-9, 9))) for dim in range(D)] 3291*da0073e9SAndroid Build Coastguard Worker bin_range = [elt for t in bin_range_tuples for elt in t] 3292*da0073e9SAndroid Build Coastguard Worker self._test_histogramdd_numpy(values, bin_ct, bin_range, weights, density) 3293*da0073e9SAndroid Build Coastguard Worker 3294*da0073e9SAndroid Build Coastguard Worker # Tests with range min=max 3295*da0073e9SAndroid Build Coastguard Worker for dim in range(D): 3296*da0073e9SAndroid Build Coastguard Worker bin_range[2 * dim + 1] = bin_range[2 * dim] 3297*da0073e9SAndroid Build Coastguard Worker self._test_histogramdd_numpy(values, bin_ct, bin_range, weights, density) 3298*da0073e9SAndroid Build Coastguard Worker 3299*da0073e9SAndroid Build Coastguard Worker # Tests with caller-specified bin edges 3300*da0073e9SAndroid Build Coastguard Worker bin_edges = [make_tensor(ct + 1, dtype=dtype, device=device, low=-9, high=9).msort() for ct in bin_ct] 3301*da0073e9SAndroid Build Coastguard Worker if not bins_contig: 3302*da0073e9SAndroid Build Coastguard Worker # Necessary because msort always produces contiguous output 3303*da0073e9SAndroid Build Coastguard Worker bin_edges_noncontig = [ 3304*da0073e9SAndroid Build Coastguard Worker 
make_tensor(ct + 1, dtype=dtype, device=device, noncontiguous=not bins_contig) 3305*da0073e9SAndroid Build Coastguard Worker for ct in bin_ct 3306*da0073e9SAndroid Build Coastguard Worker ] 3307*da0073e9SAndroid Build Coastguard Worker for dim in range(D): 3308*da0073e9SAndroid Build Coastguard Worker bin_edges_noncontig[dim].copy_(bin_edges[dim]) 3309*da0073e9SAndroid Build Coastguard Worker bin_edges = bin_edges_noncontig 3310*da0073e9SAndroid Build Coastguard Worker for dim in range(D): 3311*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bin_edges[dim].is_contiguous(), bins_contig) 3312*da0073e9SAndroid Build Coastguard Worker self._test_histogramdd_numpy(values, bin_edges, None, weights, density) 3313*da0073e9SAndroid Build Coastguard Worker 3314*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3315*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32) 3316*da0073e9SAndroid Build Coastguard Worker def test_histogram_error_handling(self, device, dtype): 3317*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'not implemented for'): 3318*da0073e9SAndroid Build Coastguard Worker values = make_tensor((), dtype=torch.int32, device=device) 3319*da0073e9SAndroid Build Coastguard Worker torch.histogram(values, 1) 3320*da0073e9SAndroid Build Coastguard Worker 3321*da0073e9SAndroid Build Coastguard Worker inconsistent_dtype = torch.float32 if dtype != torch.float32 else torch.float64 3322*da0073e9SAndroid Build Coastguard Worker 3323*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'input tensor and bins tensors should have the same dtype'): 3324*da0073e9SAndroid Build Coastguard Worker values = make_tensor((), dtype=dtype, device=device) 3325*da0073e9SAndroid Build Coastguard Worker bins = make_tensor((), dtype=inconsistent_dtype, device=device) 3326*da0073e9SAndroid Build Coastguard Worker torch.histogram(values, bins) 3327*da0073e9SAndroid Build Coastguard Worker 
3328*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'input tensor and weight tensor should have the same dtype'): 3329*da0073e9SAndroid Build Coastguard Worker values = make_tensor((), dtype=dtype, device=device) 3330*da0073e9SAndroid Build Coastguard Worker weight = make_tensor((), dtype=inconsistent_dtype, device=device) 3331*da0073e9SAndroid Build Coastguard Worker torch.histogram(values, 1, weight=weight) 3332*da0073e9SAndroid Build Coastguard Worker 3333*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'input tensor and hist tensor should have the same dtype'): 3334*da0073e9SAndroid Build Coastguard Worker values = make_tensor((), dtype=dtype, device=device) 3335*da0073e9SAndroid Build Coastguard Worker hist = make_tensor((), dtype=inconsistent_dtype, device=device) 3336*da0073e9SAndroid Build Coastguard Worker bin_edges = make_tensor((), dtype=dtype, device=device) 3337*da0073e9SAndroid Build Coastguard Worker torch.histogram(values, 1, out=(hist, bin_edges)) 3338*da0073e9SAndroid Build Coastguard Worker 3339*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'input tensor and bin_edges tensor should have the same dtype'): 3340*da0073e9SAndroid Build Coastguard Worker values = make_tensor((), dtype=dtype, device=device) 3341*da0073e9SAndroid Build Coastguard Worker hist = make_tensor((), dtype=dtype, device=device) 3342*da0073e9SAndroid Build Coastguard Worker bin_edges = make_tensor((), dtype=inconsistent_dtype, device=device) 3343*da0073e9SAndroid Build Coastguard Worker torch.histogram(values, 1, out=(hist, bin_edges)) 3344*da0073e9SAndroid Build Coastguard Worker 3345*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'bins tensor should have one dimension'): 3346*da0073e9SAndroid Build Coastguard Worker t = make_tensor((2, 2), dtype=dtype, device=device) 3347*da0073e9SAndroid Build Coastguard Worker torch.histogram(t, t) 
3348*da0073e9SAndroid Build Coastguard Worker 3349*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'bins tensor should have at least 1 element'): 3350*da0073e9SAndroid Build Coastguard Worker t = make_tensor((0), dtype=dtype, device=device) 3351*da0073e9SAndroid Build Coastguard Worker torch.histogram(t, t) 3352*da0073e9SAndroid Build Coastguard Worker 3353*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'): 3354*da0073e9SAndroid Build Coastguard Worker values = make_tensor((), dtype=dtype, device=device) 3355*da0073e9SAndroid Build Coastguard Worker torch.histogram(values, -1) 3356*da0073e9SAndroid Build Coastguard Worker 3357*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'if weight tensor is provided it should have the same shape \ 3358*da0073e9SAndroid Build Coastguard Workeras the input tensor excluding its innermost dimension'): 3359*da0073e9SAndroid Build Coastguard Worker values = make_tensor((2, 2), dtype=dtype, device=device) 3360*da0073e9SAndroid Build Coastguard Worker weight = make_tensor((1), dtype=dtype, device=device) 3361*da0073e9SAndroid Build Coastguard Worker torch.histogram(values, 1, weight=weight) 3362*da0073e9SAndroid Build Coastguard Worker 3363*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, 'received an invalid combination of arguments'): 3364*da0073e9SAndroid Build Coastguard Worker values = make_tensor((), dtype=dtype, device=device) 3365*da0073e9SAndroid Build Coastguard Worker bin_edges = make_tensor((), dtype=dtype, device=device) 3366*da0073e9SAndroid Build Coastguard Worker torch.histogram(values, bin_edges, range=(0, 1)) 3367*da0073e9SAndroid Build Coastguard Worker 3368*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'min should not exceed max'): 3369*da0073e9SAndroid Build Coastguard Worker values = make_tensor((), dtype=dtype, device=device) 
3370*da0073e9SAndroid Build Coastguard Worker torch.histogram(values, 2, range=(1, 0)) 3371*da0073e9SAndroid Build Coastguard Worker 3372*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'range \[nan, nan\] is not finite'): 3373*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([float("nan")], device=device, dtype=dtype) 3374*da0073e9SAndroid Build Coastguard Worker torch.histogram(values, 2) 3375*da0073e9SAndroid Build Coastguard Worker 3376*da0073e9SAndroid Build Coastguard Worker # Tests to ensure that reduction functions employing comparison operators are usable when there 3377*da0073e9SAndroid Build Coastguard Worker # exists a zero dimension (i.e. when the tensors are empty) in the tensor. These tests specifically 3378*da0073e9SAndroid Build Coastguard Worker # cater to functions where specifying the `dim` parameter is necessary. 3379*da0073e9SAndroid Build Coastguard Worker def test_tensor_compare_ops_empty(self, device): 3380*da0073e9SAndroid Build Coastguard Worker shape = (2, 0, 4) 3381*da0073e9SAndroid Build Coastguard Worker master_input = torch.randn(shape, device=device) 3382*da0073e9SAndroid Build Coastguard Worker np_input = np.empty(shape) 3383*da0073e9SAndroid Build Coastguard Worker test_functions = [ 3384*da0073e9SAndroid Build Coastguard Worker ('amax', torch.amax, np.amax), 3385*da0073e9SAndroid Build Coastguard Worker ('amin', torch.amin, np.amin), 3386*da0073e9SAndroid Build Coastguard Worker ('max', lambda *args, **kwargs: torch.max(*args, **kwargs).values, np.max), 3387*da0073e9SAndroid Build Coastguard Worker ('min', lambda *args, **kwargs: torch.min(*args, **kwargs).values, np.min), 3388*da0073e9SAndroid Build Coastguard Worker ('median', lambda *args, **kwargs: torch.median(*args, **kwargs).values, np.median), 3389*da0073e9SAndroid Build Coastguard Worker ] 3390*da0073e9SAndroid Build Coastguard Worker 3391*da0073e9SAndroid Build Coastguard Worker for name, fn, np_function in 
test_functions: 3392*da0073e9SAndroid Build Coastguard Worker # Check if reduction happens along the specified dim with and without keepdim. Check with 3393*da0073e9SAndroid Build Coastguard Worker # numpy to maintain compatibility with numpy functions. 3394*da0073e9SAndroid Build Coastguard Worker error_msg = f"test function: {name}" 3395*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=2), msg=error_msg) 3396*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_function(np_input, axis=2), 3397*da0073e9SAndroid Build Coastguard Worker fn(master_input, dim=2).cpu().numpy(), msg=error_msg, exact_dtype=False) 3398*da0073e9SAndroid Build Coastguard Worker 3399*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=-1), msg=error_msg) 3400*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_function(np_input, axis=-1), 3401*da0073e9SAndroid Build Coastguard Worker fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, exact_dtype=False) 3402*da0073e9SAndroid Build Coastguard Worker 3403*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True), 3404*da0073e9SAndroid Build Coastguard Worker msg=error_msg) 3405*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_function(np_input, axis=2, keepdims=True), 3406*da0073e9SAndroid Build Coastguard Worker fn(master_input, dim=2, keepdim=True).cpu().numpy(), msg=error_msg, exact_dtype=False) 3407*da0073e9SAndroid Build Coastguard Worker 3408*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True), 3409*da0073e9SAndroid Build Coastguard Worker msg=error_msg) 3410*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_function(np_input, axis=-1, keepdims=True), 3411*da0073e9SAndroid Build Coastguard Worker 
fn(master_input, dim=-1, keepdim=True).cpu().numpy(), msg=error_msg, exact_dtype=False) 3412*da0073e9SAndroid Build Coastguard Worker 3413*da0073e9SAndroid Build Coastguard Worker # Check if function raises error on specified zero'd dimension as reduction dim. 3414*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input, dim=1)) 3415*da0073e9SAndroid Build Coastguard Worker 3416*da0073e9SAndroid Build Coastguard Worker # Tests to ensure that reduction of zero-dim tensors (i.e. empty tensors) using comparison operators 3417*da0073e9SAndroid Build Coastguard Worker # raises an error if no `dim` parameter is specified. This exists separately from tests in 3418*da0073e9SAndroid Build Coastguard Worker # test_tensot_compare_ops_empty because not specifying a `dim` parameter in the former tests does 3419*da0073e9SAndroid Build Coastguard Worker # not throw errors. Also, checking the return type of argmax requires supplying a different dtype 3420*da0073e9SAndroid Build Coastguard Worker # argument than that for the input tensor. There is also variantion in numpy testing. 
3421*da0073e9SAndroid Build Coastguard Worker def test_tensor_compare_ops_argmax_argmix_kthvalue_dim_empty(self, device): 3422*da0073e9SAndroid Build Coastguard Worker shape = (2, 0, 4) 3423*da0073e9SAndroid Build Coastguard Worker master_input = torch.randn(shape, device=device) 3424*da0073e9SAndroid Build Coastguard Worker np_input = np.empty(shape) 3425*da0073e9SAndroid Build Coastguard Worker test_functions = [ 3426*da0073e9SAndroid Build Coastguard Worker ('argmax', torch.argmax, {'dtype': torch.int64}, np.argmax), 3427*da0073e9SAndroid Build Coastguard Worker ('argmin', torch.argmin, {'dtype': torch.int64}, np.argmin), 3428*da0073e9SAndroid Build Coastguard Worker ('kthvalue', lambda *args, k=1, **kwargs: torch.kthvalue(*args, k=1, **kwargs).values, 3429*da0073e9SAndroid Build Coastguard Worker {}, lambda *args, k=1, axis=None, **kwargs: np.partition(*args, k, **kwargs).take(k - 1, axis=axis)) 3430*da0073e9SAndroid Build Coastguard Worker ] 3431*da0073e9SAndroid Build Coastguard Worker 3432*da0073e9SAndroid Build Coastguard Worker for name, fn, dtype, np_function in test_functions: 3433*da0073e9SAndroid Build Coastguard Worker error_msg = f"test function: {name}" 3434*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty((2, 0), device=device, **dtype), fn(master_input, dim=2), msg=error_msg) 3435*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3436*da0073e9SAndroid Build Coastguard Worker np_function(np_input, axis=2), fn(master_input, dim=2).cpu().numpy(), msg=error_msg, exact_dtype=False 3437*da0073e9SAndroid Build Coastguard Worker ) 3438*da0073e9SAndroid Build Coastguard Worker 3439*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty((2, 0), device=device, **dtype), fn(master_input, dim=-1), msg=error_msg) 3440*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3441*da0073e9SAndroid Build Coastguard Worker np_function(np_input, axis=-1), fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, 
exact_dtype=False 3442*da0073e9SAndroid Build Coastguard Worker ) 3443*da0073e9SAndroid Build Coastguard Worker 3444*da0073e9SAndroid Build Coastguard Worker # keepdim variant does not exist for numpy 3445*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty((2, 0, 1), device=device, **dtype), fn(master_input, dim=2, keepdim=True), 3446*da0073e9SAndroid Build Coastguard Worker msg=error_msg) 3447*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty((2, 0, 1), device=device, **dtype), fn(master_input, dim=-1, keepdim=True), 3448*da0073e9SAndroid Build Coastguard Worker msg=error_msg) 3449*da0073e9SAndroid Build Coastguard Worker 3450*da0073e9SAndroid Build Coastguard Worker # Check if function raises error on specified zero'd dimension as reduction dim. 3451*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input, dim=1)) 3452*da0073e9SAndroid Build Coastguard Worker if name != 'kthvalue': 3453*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input)) 3454*da0073e9SAndroid Build Coastguard Worker 3455*da0073e9SAndroid Build Coastguard Worker # Tests to ensure that reduction of zero-dim tensors (i.e. empty tensors) using math operators works when a 3456*da0073e9SAndroid Build Coastguard Worker # non-zero dim is specified for the reduction and throws an error when the dim specified is 0. Although 3457*da0073e9SAndroid Build Coastguard Worker # there is some repetition with test_tensor_compare_ops_optional_dim_empty and test_tensor_compare_ops_empty, 3458*da0073e9SAndroid Build Coastguard Worker # these tests are kept separate since tests for math operators also require checking for correctness of the 3459*da0073e9SAndroid Build Coastguard Worker # returned data using allclose() or isinf() which does not exists in the former tests. 
3460*da0073e9SAndroid Build Coastguard Worker @skipIfNoSciPy 3461*da0073e9SAndroid Build Coastguard Worker def test_tensor_reduce_ops_empty(self, device): 3462*da0073e9SAndroid Build Coastguard Worker from scipy.special import logsumexp 3463*da0073e9SAndroid Build Coastguard Worker shape = (2, 0, 4) 3464*da0073e9SAndroid Build Coastguard Worker master_input = torch.randn(shape, device=device) 3465*da0073e9SAndroid Build Coastguard Worker np_input = np.empty(shape) 3466*da0073e9SAndroid Build Coastguard Worker test_functions = [ 3467*da0073e9SAndroid Build Coastguard Worker ('prod', torch.prod, 1., np.prod), 3468*da0073e9SAndroid Build Coastguard Worker ('sum', torch.sum, 0., np.sum), 3469*da0073e9SAndroid Build Coastguard Worker ('norm', torch.norm, 0., np.linalg.norm), 3470*da0073e9SAndroid Build Coastguard Worker ('mean', torch.mean, nan, np.mean), 3471*da0073e9SAndroid Build Coastguard Worker ('var', torch.var, nan, np.var), 3472*da0073e9SAndroid Build Coastguard Worker ('std', torch.std, nan, np.std), 3473*da0073e9SAndroid Build Coastguard Worker ('logsumexp', torch.logsumexp, -inf, logsumexp), 3474*da0073e9SAndroid Build Coastguard Worker ] 3475*da0073e9SAndroid Build Coastguard Worker 3476*da0073e9SAndroid Build Coastguard Worker for name, fn, return_value, np_function in test_functions: 3477*da0073e9SAndroid Build Coastguard Worker # Check if reduction happens along the specified dimension. 
3478*da0073e9SAndroid Build Coastguard Worker error_msg = f"test function: {name}" 3479*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=2), msg=error_msg) 3480*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_function(np_input, axis=2), fn(master_input, dim=2).cpu().numpy(), msg=error_msg, 3481*da0073e9SAndroid Build Coastguard Worker exact_dtype=False) 3482*da0073e9SAndroid Build Coastguard Worker 3483*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=-1), msg=error_msg) 3484*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_function(np_input, axis=-1), fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, 3485*da0073e9SAndroid Build Coastguard Worker exact_dtype=False) 3486*da0073e9SAndroid Build Coastguard Worker 3487*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True), 3488*da0073e9SAndroid Build Coastguard Worker msg=error_msg) 3489*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_function(np_input, axis=2, keepdims=True), fn(master_input, dim=2, keepdim=True), 3490*da0073e9SAndroid Build Coastguard Worker msg=error_msg, exact_dtype=False) 3491*da0073e9SAndroid Build Coastguard Worker 3492*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True), 3493*da0073e9SAndroid Build Coastguard Worker msg=error_msg) 3494*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_function(np_input, axis=-1, keepdims=True), fn(master_input, dim=-1, keepdim=True), 3495*da0073e9SAndroid Build Coastguard Worker msg=error_msg, exact_dtype=False) 3496*da0073e9SAndroid Build Coastguard Worker 3497*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=1), msg=error_msg) 
3498*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=-2), msg=error_msg) 3499*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=1, keepdim=True), 3500*da0073e9SAndroid Build Coastguard Worker msg=error_msg) 3501*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=-2, keepdim=True), 3502*da0073e9SAndroid Build Coastguard Worker msg=error_msg) 3503*da0073e9SAndroid Build Coastguard Worker 3504*da0073e9SAndroid Build Coastguard Worker if name != 'logsumexp': 3505*da0073e9SAndroid Build Coastguard Worker # The scipy function does not work for reduction the zero dimension 3506*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np.float32(np_function(np_input, axis=1)), fn(master_input, dim=1).cpu().numpy(), 3507*da0073e9SAndroid Build Coastguard Worker msg=error_msg) 3508*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np.float32(np_function(np_input, axis=-2)), fn(master_input, dim=-2).cpu().numpy(), 3509*da0073e9SAndroid Build Coastguard Worker msg=error_msg) 3510*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np.float32(np_function(np_input, axis=1, keepdims=True)), 3511*da0073e9SAndroid Build Coastguard Worker fn(master_input, dim=1, keepdim=True).cpu().numpy(), 3512*da0073e9SAndroid Build Coastguard Worker msg=error_msg) 3513*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np.float32(np_function(np_input, axis=-2, keepdims=True)), 3514*da0073e9SAndroid Build Coastguard Worker fn(master_input, dim=-2, keepdim=True).cpu().numpy(), 3515*da0073e9SAndroid Build Coastguard Worker msg=error_msg) 3516*da0073e9SAndroid Build Coastguard Worker 3517*da0073e9SAndroid Build Coastguard Worker # logsumexp throws a type error when not specifying dim so test separately. 
3518*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.full((), return_value, device=device), fn(master_input), msg=error_msg) 3519*da0073e9SAndroid Build Coastguard Worker else: 3520*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: fn(master_input)) 3521*da0073e9SAndroid Build Coastguard Worker 3522*da0073e9SAndroid Build Coastguard Worker # Tests to ensure that any() and all() functions work with zero-dim tensors. Kept separate from 3523*da0073e9SAndroid Build Coastguard Worker # other tests for checking reduction with zero-dim tensors because these tests have significantly 3524*da0073e9SAndroid Build Coastguard Worker # different testing behaviour than that used for the former tests. 3525*da0073e9SAndroid Build Coastguard Worker def test_reduction_empty_any_all(self, device): 3526*da0073e9SAndroid Build Coastguard Worker shape = (2, 0, 4) 3527*da0073e9SAndroid Build Coastguard Worker x = torch.randn(shape, device=device) 3528*da0073e9SAndroid Build Coastguard Worker 3529*da0073e9SAndroid Build Coastguard Worker for dtype in all_types_and_complex_and(torch.half, torch.bool): 3530*da0073e9SAndroid Build Coastguard Worker # Refer: [all, any uint8 compatibility] 3531*da0073e9SAndroid Build Coastguard Worker if dtype == torch.uint8: 3532*da0073e9SAndroid Build Coastguard Worker out_dtype = torch.uint8 3533*da0073e9SAndroid Build Coastguard Worker else: 3534*da0073e9SAndroid Build Coastguard Worker out_dtype = torch.bool # output of all/any is bool irrespective of input dtype 3535*da0073e9SAndroid Build Coastguard Worker 3536*da0073e9SAndroid Build Coastguard Worker xb = x.to(dtype) 3537*da0073e9SAndroid Build Coastguard Worker yb = x.to(dtype) 3538*da0073e9SAndroid Build Coastguard Worker # any 3539*da0073e9SAndroid Build Coastguard Worker self.assertEqual((2, 0), xb.any(2).shape) 3540*da0073e9SAndroid Build Coastguard Worker self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape) 3541*da0073e9SAndroid Build Coastguard 
Worker self.assertEqual(torch.zeros((2, 4), device=device, dtype=out_dtype), xb.any(1)) 3542*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros((2, 1, 4), device=device, dtype=out_dtype), xb.any(1, keepdim=True)) 3543*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros((), device=device, dtype=out_dtype), xb.any()) 3544*da0073e9SAndroid Build Coastguard Worker 3545*da0073e9SAndroid Build Coastguard Worker # all 3546*da0073e9SAndroid Build Coastguard Worker self.assertEqual((2, 0), xb.all(2).shape) 3547*da0073e9SAndroid Build Coastguard Worker self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape) 3548*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ones((2, 4), device=device, dtype=out_dtype), xb.all(1)) 3549*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ones((2, 1, 4), device=device, dtype=out_dtype), xb.all(1, keepdim=True)) 3550*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ones((), device=device, dtype=out_dtype), xb.all()) 3551*da0073e9SAndroid Build Coastguard Worker 3552*da0073e9SAndroid Build Coastguard Worker # TODO: can these be merged with their respective OpInfos? 
3553*da0073e9SAndroid Build Coastguard Worker def test_reduce_dtype(self, device): 3554*da0073e9SAndroid Build Coastguard Worker def test_reduction(op, has_no_dim, takes_dtype=True): 3555*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3, dtype=torch.float, requires_grad=True, device=device) 3556*da0073e9SAndroid Build Coastguard Worker 3557*da0073e9SAndroid Build Coastguard Worker if has_no_dim: 3558*da0073e9SAndroid Build Coastguard Worker grad1, = torch.autograd.grad([op(x)], [x]) 3559*da0073e9SAndroid Build Coastguard Worker grad2, = torch.autograd.grad([op(x, dtype=torch.double)], [x]) 3560*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad1, grad2) 3561*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad2.dtype, torch.float) 3562*da0073e9SAndroid Build Coastguard Worker 3563*da0073e9SAndroid Build Coastguard Worker gi = torch.randn(op(x, dim=0).shape, dtype=torch.float, device=device) 3564*da0073e9SAndroid Build Coastguard Worker grad1, = torch.autograd.grad([op(x, dim=0)], [x], gi) 3565*da0073e9SAndroid Build Coastguard Worker if takes_dtype: 3566*da0073e9SAndroid Build Coastguard Worker grad2, = torch.autograd.grad([op(x, dim=0, dtype=torch.double)], [x], gi.double()) 3567*da0073e9SAndroid Build Coastguard Worker else: 3568*da0073e9SAndroid Build Coastguard Worker grad2, = torch.autograd.grad([op(x.double(), dim=0)], [x], gi.double()) 3569*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad1, grad2) 3570*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad2.dtype, torch.float) 3571*da0073e9SAndroid Build Coastguard Worker 3572*da0073e9SAndroid Build Coastguard Worker test_reduction(torch.sum, True) 3573*da0073e9SAndroid Build Coastguard Worker test_reduction(torch.prod, True) 3574*da0073e9SAndroid Build Coastguard Worker test_reduction(torch.cumsum, False) 3575*da0073e9SAndroid Build Coastguard Worker test_reduction(torch.cumprod, False) 3576*da0073e9SAndroid Build Coastguard Worker 
test_reduction(torch.logcumsumexp, False, takes_dtype=False) 3577*da0073e9SAndroid Build Coastguard Worker 3578*da0073e9SAndroid Build Coastguard Worker @ops(reference_masked_ops) 3579*da0073e9SAndroid Build Coastguard Worker def test_reference_masked(self, device, dtype, op): 3580*da0073e9SAndroid Build Coastguard Worker """Test masked reduction operations on strided-only tensors using 3581*da0073e9SAndroid Build Coastguard Worker numpy reductions as reference. 3582*da0073e9SAndroid Build Coastguard Worker """ 3583*da0073e9SAndroid Build Coastguard Worker 3584*da0073e9SAndroid Build Coastguard Worker def to_numpy(input): 3585*da0073e9SAndroid Build Coastguard Worker if input.dtype is torch.bfloat16: 3586*da0073e9SAndroid Build Coastguard Worker return input.cpu().to(torch.float32).numpy() 3587*da0073e9SAndroid Build Coastguard Worker else: 3588*da0073e9SAndroid Build Coastguard Worker return input.cpu().numpy() 3589*da0073e9SAndroid Build Coastguard Worker 3590*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs_func(op, device, dtype, requires_grad=False) 3591*da0073e9SAndroid Build Coastguard Worker for sample_input in samples: 3592*da0073e9SAndroid Build Coastguard Worker t = sample_input.input 3593*da0073e9SAndroid Build Coastguard Worker actual = op(t, *sample_input.args, **sample_input.kwargs) 3594*da0073e9SAndroid Build Coastguard Worker exact_dtype = not (t.dtype is torch.bfloat16 3595*da0073e9SAndroid Build Coastguard Worker or (op.promotes_int_to_float and not torch.is_floating_point(t))) 3596*da0073e9SAndroid Build Coastguard Worker expected = op.ref(to_numpy(t), *sample_input.args, 3597*da0073e9SAndroid Build Coastguard Worker **dict( 3598*da0073e9SAndroid Build Coastguard Worker # `identity` is mapped to numpy reduction `initial` argument 3599*da0073e9SAndroid Build Coastguard Worker identity=torch.masked._reduction_identity(op.name, t), 3600*da0073e9SAndroid Build Coastguard Worker **sample_input.kwargs)) 3601*da0073e9SAndroid Build 
Coastguard Worker 3602*da0073e9SAndroid Build Coastguard Worker # Workaround https://github.com/pytorch/pytorch/issues/66556 3603*da0073e9SAndroid Build Coastguard Worker expected = np.asarray(expected) # transform numpy scalars to numpy.ndarray instances 3604*da0073e9SAndroid Build Coastguard Worker 3605*da0073e9SAndroid Build Coastguard Worker # Numpy differs, producing uint32 on Windows 3606*da0073e9SAndroid Build Coastguard Worker if expected.dtype in [np.uint64, np.uint32]: 3607*da0073e9SAndroid Build Coastguard Worker exact_dtype = False 3608*da0073e9SAndroid Build Coastguard Worker 3609*da0073e9SAndroid Build Coastguard Worker msg = ("Failed to produce expected results! Input tensor was" 3610*da0073e9SAndroid Build Coastguard Worker f" {t}, torch result is {actual}, and reference result is" 3611*da0073e9SAndroid Build Coastguard Worker f" {expected}.") if t.numel() < 10 else None 3612*da0073e9SAndroid Build Coastguard Worker 3613*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, msg, exact_dtype=exact_dtype) 3614*da0073e9SAndroid Build Coastguard Worker 3615*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 3616*da0073e9SAndroid Build Coastguard Worker @largeTensorTest("8GB") 3617*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.chalf, torch.bfloat16) 3618*da0073e9SAndroid Build Coastguard Worker def test_reductions_large_half_tensors(self, device, dtype): 3619*da0073e9SAndroid Build Coastguard Worker t = torch.ones(2**31, device=device, dtype=dtype) 3620*da0073e9SAndroid Build Coastguard Worker t[2**30:] = -1 3621*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor(0, device=device, dtype=dtype) 3622*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sum(t), expected) 3623*da0073e9SAndroid Build Coastguard Worker 3624*da0073e9SAndroid Build Coastguard Worker # mean_cuda is not implemented for ComplexHalf 3625*da0073e9SAndroid Build Coastguard Worker err_msg = "not implemented for 
'ComplexHalf'" 3626*da0073e9SAndroid Build Coastguard Worker ctx = self.assertRaisesRegex( 3627*da0073e9SAndroid Build Coastguard Worker RuntimeError, err_msg) if dtype is torch.chalf else contextlib.nullcontext() 3628*da0073e9SAndroid Build Coastguard Worker with ctx: 3629*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mean(t), expected) 3630*da0073e9SAndroid Build Coastguard Worker 3631*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestReductions, globals()) 3632*da0073e9SAndroid Build Coastguard Worker 3633*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 3634*da0073e9SAndroid Build Coastguard Worker run_tests() 3635