1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: tests"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport itertools 4*da0073e9SAndroid Build Coastguard Workerimport math 5*da0073e9SAndroid Build Coastguard Workerimport operator 6*da0073e9SAndroid Build Coastguard Workerimport random 7*da0073e9SAndroid Build Coastguard Workerimport warnings 8*da0073e9SAndroid Build Coastguard Workerfrom functools import partial 9*da0073e9SAndroid Build Coastguard Workerfrom itertools import chain, product 10*da0073e9SAndroid Build Coastguard Workerfrom numbers import Number 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerimport numpy as np 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Workerimport torch 15*da0073e9SAndroid Build Coastguard Workerimport torch.autograd.forward_ad as fwAD 16*da0073e9SAndroid Build Coastguard Workerfrom torch import inf, nan 17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor 18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 19*da0073e9SAndroid Build Coastguard Worker deviceCountAtLeast, 20*da0073e9SAndroid Build Coastguard Worker dtypes, 21*da0073e9SAndroid Build Coastguard Worker dtypesIfCPU, 22*da0073e9SAndroid Build Coastguard Worker dtypesIfCUDA, 23*da0073e9SAndroid Build Coastguard Worker expectedFailureMeta, 24*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 25*da0073e9SAndroid Build Coastguard Worker onlyCPU, 26*da0073e9SAndroid Build Coastguard Worker onlyCUDA, 27*da0073e9SAndroid Build Coastguard Worker onlyNativeDeviceTypes, 28*da0073e9SAndroid Build Coastguard Worker OpDTypes, 29*da0073e9SAndroid Build Coastguard Worker ops, 30*da0073e9SAndroid Build Coastguard Worker precisionOverride, 31*da0073e9SAndroid Build Coastguard Worker skipIf, 32*da0073e9SAndroid Build Coastguard Worker skipMeta, 33*da0073e9SAndroid 
Build Coastguard Worker) 34*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import ( 35*da0073e9SAndroid Build Coastguard Worker all_types_and, 36*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and, 37*da0073e9SAndroid Build Coastguard Worker complex_types, 38*da0073e9SAndroid Build Coastguard Worker floating_and_complex_types, 39*da0073e9SAndroid Build Coastguard Worker floating_types_and, 40*da0073e9SAndroid Build Coastguard Worker get_all_int_dtypes, 41*da0073e9SAndroid Build Coastguard Worker get_all_math_dtypes, 42*da0073e9SAndroid Build Coastguard Worker integral_types, 43*da0073e9SAndroid Build Coastguard Worker integral_types_and, 44*da0073e9SAndroid Build Coastguard Worker) 45*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import ( 46*da0073e9SAndroid Build Coastguard Worker binary_ufuncs, 47*da0073e9SAndroid Build Coastguard Worker binary_ufuncs_and_refs, 48*da0073e9SAndroid Build Coastguard Worker generate_elementwise_binary_broadcasting_tensors, 49*da0073e9SAndroid Build Coastguard Worker generate_elementwise_binary_extremal_value_tensors, 50*da0073e9SAndroid Build Coastguard Worker generate_elementwise_binary_large_value_tensors, 51*da0073e9SAndroid Build Coastguard Worker generate_elementwise_binary_small_value_tensors, 52*da0073e9SAndroid Build Coastguard Worker generate_elementwise_binary_tensors, 53*da0073e9SAndroid Build Coastguard Worker generate_elementwise_binary_with_scalar_and_type_promotion_samples, 54*da0073e9SAndroid Build Coastguard Worker generate_elementwise_binary_with_scalar_samples, 55*da0073e9SAndroid Build Coastguard Worker) 56*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 57*da0073e9SAndroid Build Coastguard Worker gradcheck, 58*da0073e9SAndroid Build Coastguard Worker iter_indices, 59*da0073e9SAndroid Build Coastguard Worker numpy_to_torch_dtype_dict, 60*da0073e9SAndroid Build Coastguard 
Worker run_tests, 61*da0073e9SAndroid Build Coastguard Worker set_default_dtype, 62*da0073e9SAndroid Build Coastguard Worker skipIfTorchDynamo, 63*da0073e9SAndroid Build Coastguard Worker slowTest, 64*da0073e9SAndroid Build Coastguard Worker TEST_SCIPY, 65*da0073e9SAndroid Build Coastguard Worker TestCase, 66*da0073e9SAndroid Build Coastguard Worker torch_to_numpy_dtype_dict, 67*da0073e9SAndroid Build Coastguard Worker xfailIfTorchDynamo, 68*da0073e9SAndroid Build Coastguard Worker) 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Workerif TEST_SCIPY: 72*da0073e9SAndroid Build Coastguard Worker import scipy.integrate 73*da0073e9SAndroid Build Coastguard Worker import scipy.special 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker# TODO: update to use opinfos consistently 77*da0073e9SAndroid Build Coastguard Workerclass TestBinaryUfuncs(TestCase): 78*da0073e9SAndroid Build Coastguard Worker # Generic tests for elementwise binary (AKA binary universal (u) functions (funcs)) 79*da0073e9SAndroid Build Coastguard Worker # TODO: below contiguous tensor results are compared with a variety of noncontiguous results. 80*da0073e9SAndroid Build Coastguard Worker # It would be interesting to have the lhs and rhs have different discontiguities. 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker # Helper for comparing torch tensors and NumPy arrays 83*da0073e9SAndroid Build Coastguard Worker # TODO: should this or assertEqual also validate that strides are equal? 
84*da0073e9SAndroid Build Coastguard Worker def assertEqualHelper( 85*da0073e9SAndroid Build Coastguard Worker self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs 86*da0073e9SAndroid Build Coastguard Worker ): 87*da0073e9SAndroid Build Coastguard Worker assert isinstance(actual, torch.Tensor) 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker # Some NumPy functions return scalars, not arrays 90*da0073e9SAndroid Build Coastguard Worker if isinstance(expected, Number): 91*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual.item(), expected, msg=msg, **kwargs) 92*da0073e9SAndroid Build Coastguard Worker elif isinstance(expected, np.ndarray): 93*da0073e9SAndroid Build Coastguard Worker # Handles exact dtype comparisons between arrays and tensors 94*da0073e9SAndroid Build Coastguard Worker if exact_dtype: 95*da0073e9SAndroid Build Coastguard Worker # Allows array dtype to be float32 when comparing with bfloat16 tensors 96*da0073e9SAndroid Build Coastguard Worker # since NumPy doesn't support the bfloat16 dtype 97*da0073e9SAndroid Build Coastguard Worker # Also ops like scipy.special.erf, scipy.special.erfc, etc, promote float16 98*da0073e9SAndroid Build Coastguard Worker # to float32 99*da0073e9SAndroid Build Coastguard Worker if expected.dtype == np.float32: 100*da0073e9SAndroid Build Coastguard Worker assert actual.dtype in ( 101*da0073e9SAndroid Build Coastguard Worker torch.float16, 102*da0073e9SAndroid Build Coastguard Worker torch.bfloat16, 103*da0073e9SAndroid Build Coastguard Worker torch.float32, 104*da0073e9SAndroid Build Coastguard Worker ) 105*da0073e9SAndroid Build Coastguard Worker else: 106*da0073e9SAndroid Build Coastguard Worker assert expected.dtype == torch_to_numpy_dtype_dict[actual.dtype] 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 109*da0073e9SAndroid Build Coastguard Worker actual, 110*da0073e9SAndroid Build Coastguard 
Worker torch.from_numpy(expected).to(actual.dtype), 111*da0073e9SAndroid Build Coastguard Worker msg, 112*da0073e9SAndroid Build Coastguard Worker exact_device=False, 113*da0073e9SAndroid Build Coastguard Worker **kwargs, 114*da0073e9SAndroid Build Coastguard Worker ) 115*da0073e9SAndroid Build Coastguard Worker else: 116*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, msg, exact_device=False, **kwargs) 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker # Tests that the function and its (array-accepting) reference produce the same 119*da0073e9SAndroid Build Coastguard Worker # values on given tensors 120*da0073e9SAndroid Build Coastguard Worker def _test_reference_numerics(self, dtype, op, gen, equal_nan=True): 121*da0073e9SAndroid Build Coastguard Worker def _helper_reference_numerics( 122*da0073e9SAndroid Build Coastguard Worker expected, actual, msg, exact_dtype, equal_nan=True 123*da0073e9SAndroid Build Coastguard Worker ): 124*da0073e9SAndroid Build Coastguard Worker if not torch.can_cast( 125*da0073e9SAndroid Build Coastguard Worker numpy_to_torch_dtype_dict[expected.dtype.type], dtype 126*da0073e9SAndroid Build Coastguard Worker ): 127*da0073e9SAndroid Build Coastguard Worker exact_dtype = False 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker if dtype is torch.bfloat16 and expected.dtype == np.float32: 130*da0073e9SAndroid Build Coastguard Worker # Ref: https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_utils.py#L1149 131*da0073e9SAndroid Build Coastguard Worker self.assertEqualHelper( 132*da0073e9SAndroid Build Coastguard Worker actual, 133*da0073e9SAndroid Build Coastguard Worker expected, 134*da0073e9SAndroid Build Coastguard Worker msg, 135*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 136*da0073e9SAndroid Build Coastguard Worker exact_dtype=exact_dtype, 137*da0073e9SAndroid Build Coastguard Worker rtol=16e-3, 
138*da0073e9SAndroid Build Coastguard Worker atol=1e-5, 139*da0073e9SAndroid Build Coastguard Worker ) 140*da0073e9SAndroid Build Coastguard Worker else: 141*da0073e9SAndroid Build Coastguard Worker self.assertEqualHelper( 142*da0073e9SAndroid Build Coastguard Worker actual, 143*da0073e9SAndroid Build Coastguard Worker expected, 144*da0073e9SAndroid Build Coastguard Worker msg, 145*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 146*da0073e9SAndroid Build Coastguard Worker equal_nan=equal_nan, 147*da0073e9SAndroid Build Coastguard Worker exact_dtype=exact_dtype, 148*da0073e9SAndroid Build Coastguard Worker ) 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker for sample in gen: 151*da0073e9SAndroid Build Coastguard Worker # Each sample input acquired from the generator is just one lhs tensor 152*da0073e9SAndroid Build Coastguard Worker # and one rhs tensor 153*da0073e9SAndroid Build Coastguard Worker l = sample.input 154*da0073e9SAndroid Build Coastguard Worker r = sample.args[0] 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker numpy_sample = sample.numpy() 157*da0073e9SAndroid Build Coastguard Worker l_numpy = numpy_sample.input 158*da0073e9SAndroid Build Coastguard Worker r_numpy = numpy_sample.args[0] 159*da0073e9SAndroid Build Coastguard Worker actual = op(l, r) 160*da0073e9SAndroid Build Coastguard Worker expected = op.ref(l_numpy, r_numpy) 161*da0073e9SAndroid Build Coastguard Worker 162*da0073e9SAndroid Build Coastguard Worker # Crafts a custom error message for smaller, printable tensors 163*da0073e9SAndroid Build Coastguard Worker def _numel(x): 164*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.Tensor): 165*da0073e9SAndroid Build Coastguard Worker return x.numel() 166*da0073e9SAndroid Build Coastguard Worker # Assumes x is a scalar 167*da0073e9SAndroid Build Coastguard Worker return 1 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build 
Coastguard Worker if _numel(l) <= 100 and _numel(r) <= 100: 170*da0073e9SAndroid Build Coastguard Worker msg = ( 171*da0073e9SAndroid Build Coastguard Worker "Failed to produce expected results! Input lhs tensor was" 172*da0073e9SAndroid Build Coastguard Worker f" {l}, rhs tensor was {r}, torch result is {actual}, and reference result is" 173*da0073e9SAndroid Build Coastguard Worker f" {expected}." 174*da0073e9SAndroid Build Coastguard Worker ) 175*da0073e9SAndroid Build Coastguard Worker else: 176*da0073e9SAndroid Build Coastguard Worker msg = None 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker exact_dtype = True 179*da0073e9SAndroid Build Coastguard Worker if isinstance(actual, torch.Tensor): 180*da0073e9SAndroid Build Coastguard Worker _helper_reference_numerics( 181*da0073e9SAndroid Build Coastguard Worker expected, actual, msg, exact_dtype, equal_nan 182*da0073e9SAndroid Build Coastguard Worker ) 183*da0073e9SAndroid Build Coastguard Worker else: 184*da0073e9SAndroid Build Coastguard Worker for x, y in zip(expected, actual): 185*da0073e9SAndroid Build Coastguard Worker # testing multi-outputs results 186*da0073e9SAndroid Build Coastguard Worker _helper_reference_numerics(x, y, msg, exact_dtype, equal_nan) 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker # The following tests only apply to elementwise binary operators with references 189*da0073e9SAndroid Build Coastguard Worker binary_ufuncs_with_references = list( 190*da0073e9SAndroid Build Coastguard Worker filter(lambda op: op.ref is not None and op.ref is not None, binary_ufuncs) 191*da0073e9SAndroid Build Coastguard Worker ) 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker @ops(binary_ufuncs_with_references) 194*da0073e9SAndroid Build Coastguard Worker def test_reference_numerics(self, device, dtype, op): 195*da0073e9SAndroid Build Coastguard Worker gen = 
generate_elementwise_binary_tensors(op, device=device, dtype=dtype) 196*da0073e9SAndroid Build Coastguard Worker self._test_reference_numerics(dtype, op, gen, equal_nan=True) 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker @ops(binary_ufuncs_with_references) 199*da0073e9SAndroid Build Coastguard Worker def test_reference_numerics_small_values(self, device, dtype, op): 200*da0073e9SAndroid Build Coastguard Worker if dtype is torch.bool: 201*da0073e9SAndroid Build Coastguard Worker self.skipTest("Doesn't support bool!") 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker gen = generate_elementwise_binary_small_value_tensors( 204*da0073e9SAndroid Build Coastguard Worker op, device=device, dtype=dtype 205*da0073e9SAndroid Build Coastguard Worker ) 206*da0073e9SAndroid Build Coastguard Worker self._test_reference_numerics(dtype, op, gen, equal_nan=True) 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker @ops( 209*da0073e9SAndroid Build Coastguard Worker binary_ufuncs_with_references, 210*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=( 211*da0073e9SAndroid Build Coastguard Worker torch.int16, 212*da0073e9SAndroid Build Coastguard Worker torch.int32, 213*da0073e9SAndroid Build Coastguard Worker torch.int64, 214*da0073e9SAndroid Build Coastguard Worker torch.float16, 215*da0073e9SAndroid Build Coastguard Worker torch.bfloat16, 216*da0073e9SAndroid Build Coastguard Worker torch.float32, 217*da0073e9SAndroid Build Coastguard Worker torch.float64, 218*da0073e9SAndroid Build Coastguard Worker torch.complex64, 219*da0073e9SAndroid Build Coastguard Worker torch.complex128, 220*da0073e9SAndroid Build Coastguard Worker ), 221*da0073e9SAndroid Build Coastguard Worker ) 222*da0073e9SAndroid Build Coastguard Worker def test_reference_numerics_large_values(self, device, dtype, op): 223*da0073e9SAndroid Build Coastguard Worker gen = 
generate_elementwise_binary_large_value_tensors( 224*da0073e9SAndroid Build Coastguard Worker op, device=device, dtype=dtype 225*da0073e9SAndroid Build Coastguard Worker ) 226*da0073e9SAndroid Build Coastguard Worker self._test_reference_numerics(dtype, op, gen, equal_nan=True) 227*da0073e9SAndroid Build Coastguard Worker 228*da0073e9SAndroid Build Coastguard Worker @ops( 229*da0073e9SAndroid Build Coastguard Worker binary_ufuncs_with_references, 230*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=( 231*da0073e9SAndroid Build Coastguard Worker torch.float16, 232*da0073e9SAndroid Build Coastguard Worker torch.bfloat16, 233*da0073e9SAndroid Build Coastguard Worker torch.float32, 234*da0073e9SAndroid Build Coastguard Worker torch.float64, 235*da0073e9SAndroid Build Coastguard Worker torch.complex64, 236*da0073e9SAndroid Build Coastguard Worker torch.complex128, 237*da0073e9SAndroid Build Coastguard Worker ), 238*da0073e9SAndroid Build Coastguard Worker ) 239*da0073e9SAndroid Build Coastguard Worker def test_reference_numerics_extremal_values(self, device, dtype, op): 240*da0073e9SAndroid Build Coastguard Worker gen = generate_elementwise_binary_extremal_value_tensors( 241*da0073e9SAndroid Build Coastguard Worker op, device=device, dtype=dtype 242*da0073e9SAndroid Build Coastguard Worker ) 243*da0073e9SAndroid Build Coastguard Worker self._test_reference_numerics(dtype, op, gen, equal_nan=True) 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker # tests broadcasting and noncontiguous broadcasting behavior 246*da0073e9SAndroid Build Coastguard Worker @ops( 247*da0073e9SAndroid Build Coastguard Worker binary_ufuncs_with_references, 248*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=( 249*da0073e9SAndroid Build Coastguard Worker torch.long, 250*da0073e9SAndroid Build Coastguard Worker torch.float32, 251*da0073e9SAndroid Build Coastguard Worker ), 252*da0073e9SAndroid Build Coastguard Worker ) 253*da0073e9SAndroid 
Build Coastguard Worker def test_broadcasting(self, device, dtype, op): 254*da0073e9SAndroid Build Coastguard Worker gen = generate_elementwise_binary_broadcasting_tensors( 255*da0073e9SAndroid Build Coastguard Worker op, device=device, dtype=dtype 256*da0073e9SAndroid Build Coastguard Worker ) 257*da0073e9SAndroid Build Coastguard Worker self._test_reference_numerics(dtype, op, gen, equal_nan=True) 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker @ops( 260*da0073e9SAndroid Build Coastguard Worker binary_ufuncs_with_references, 261*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=(torch.long, torch.float32, torch.complex64), 262*da0073e9SAndroid Build Coastguard Worker ) 263*da0073e9SAndroid Build Coastguard Worker def test_scalar_support(self, device, dtype, op): 264*da0073e9SAndroid Build Coastguard Worker gen = generate_elementwise_binary_with_scalar_samples( 265*da0073e9SAndroid Build Coastguard Worker op, device=device, dtype=dtype 266*da0073e9SAndroid Build Coastguard Worker ) 267*da0073e9SAndroid Build Coastguard Worker self._test_reference_numerics(dtype, op, gen, equal_nan=True) 268*da0073e9SAndroid Build Coastguard Worker gen = generate_elementwise_binary_with_scalar_and_type_promotion_samples( 269*da0073e9SAndroid Build Coastguard Worker op, device=device, dtype=dtype 270*da0073e9SAndroid Build Coastguard Worker ) 271*da0073e9SAndroid Build Coastguard Worker self._test_reference_numerics(dtype, op, gen, equal_nan=True) 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker @ops(binary_ufuncs) 274*da0073e9SAndroid Build Coastguard Worker def test_contig_vs_every_other(self, device, dtype, op): 275*da0073e9SAndroid Build Coastguard Worker lhs = make_tensor( 276*da0073e9SAndroid Build Coastguard Worker (1026,), device=device, dtype=dtype, **op.lhs_make_tensor_kwargs 277*da0073e9SAndroid Build Coastguard Worker ) 278*da0073e9SAndroid Build Coastguard Worker rhs = make_tensor( 
279*da0073e9SAndroid Build Coastguard Worker (1026,), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs 280*da0073e9SAndroid Build Coastguard Worker ) 281*da0073e9SAndroid Build Coastguard Worker 282*da0073e9SAndroid Build Coastguard Worker lhs_non_contig = lhs[::2] 283*da0073e9SAndroid Build Coastguard Worker rhs_non_contig = rhs[::2] 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker self.assertTrue(lhs.is_contiguous()) 286*da0073e9SAndroid Build Coastguard Worker self.assertTrue(rhs.is_contiguous()) 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker self.assertFalse(lhs_non_contig.is_contiguous()) 289*da0073e9SAndroid Build Coastguard Worker self.assertFalse(rhs_non_contig.is_contiguous()) 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker expected = op(lhs, rhs)[::2] 292*da0073e9SAndroid Build Coastguard Worker actual = op(lhs_non_contig, rhs_non_contig) 293*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker @ops(binary_ufuncs) 296*da0073e9SAndroid Build Coastguard Worker def test_contig_vs_transposed(self, device, dtype, op): 297*da0073e9SAndroid Build Coastguard Worker lhs = make_tensor( 298*da0073e9SAndroid Build Coastguard Worker (789, 357), device=device, dtype=dtype, **op.lhs_make_tensor_kwargs 299*da0073e9SAndroid Build Coastguard Worker ) 300*da0073e9SAndroid Build Coastguard Worker rhs = make_tensor( 301*da0073e9SAndroid Build Coastguard Worker (789, 357), device=device, dtype=dtype, **op.rhs_make_tensor_kwargs 302*da0073e9SAndroid Build Coastguard Worker ) 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker lhs_non_contig = lhs.T 305*da0073e9SAndroid Build Coastguard Worker rhs_non_contig = rhs.T 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard 
Worker self.assertTrue(lhs.is_contiguous()) 308*da0073e9SAndroid Build Coastguard Worker self.assertTrue(rhs.is_contiguous()) 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker self.assertFalse(lhs_non_contig.is_contiguous()) 311*da0073e9SAndroid Build Coastguard Worker self.assertFalse(rhs_non_contig.is_contiguous()) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker expected = op(lhs, rhs).T 314*da0073e9SAndroid Build Coastguard Worker actual = op(lhs_non_contig, rhs_non_contig) 315*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker @ops(binary_ufuncs) 318*da0073e9SAndroid Build Coastguard Worker def test_non_contig(self, device, dtype, op): 319*da0073e9SAndroid Build Coastguard Worker shapes = ((5, 7), (1024,)) 320*da0073e9SAndroid Build Coastguard Worker for shape in shapes: 321*da0073e9SAndroid Build Coastguard Worker lhs = make_tensor( 322*da0073e9SAndroid Build Coastguard Worker shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs 323*da0073e9SAndroid Build Coastguard Worker ) 324*da0073e9SAndroid Build Coastguard Worker rhs = make_tensor( 325*da0073e9SAndroid Build Coastguard Worker shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs 326*da0073e9SAndroid Build Coastguard Worker ) 327*da0073e9SAndroid Build Coastguard Worker 328*da0073e9SAndroid Build Coastguard Worker lhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[ 329*da0073e9SAndroid Build Coastguard Worker ..., 0 330*da0073e9SAndroid Build Coastguard Worker ] 331*da0073e9SAndroid Build Coastguard Worker lhs_non_contig.copy_(lhs) 332*da0073e9SAndroid Build Coastguard Worker 333*da0073e9SAndroid Build Coastguard Worker rhs_non_contig = torch.empty(shape + (2,), device=device, dtype=dtype)[ 334*da0073e9SAndroid Build Coastguard Worker ..., 0 335*da0073e9SAndroid Build 
Coastguard Worker ] 336*da0073e9SAndroid Build Coastguard Worker rhs_non_contig.copy_(rhs) 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Worker self.assertTrue(lhs.is_contiguous()) 339*da0073e9SAndroid Build Coastguard Worker self.assertTrue(rhs.is_contiguous()) 340*da0073e9SAndroid Build Coastguard Worker 341*da0073e9SAndroid Build Coastguard Worker self.assertFalse(lhs_non_contig.is_contiguous()) 342*da0073e9SAndroid Build Coastguard Worker self.assertFalse(rhs_non_contig.is_contiguous()) 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker expected = op(lhs, rhs) 345*da0073e9SAndroid Build Coastguard Worker actual = op(lhs_non_contig, rhs_non_contig) 346*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 347*da0073e9SAndroid Build Coastguard Worker 348*da0073e9SAndroid Build Coastguard Worker @ops(binary_ufuncs) 349*da0073e9SAndroid Build Coastguard Worker def test_non_contig_index(self, device, dtype, op): 350*da0073e9SAndroid Build Coastguard Worker shape = (2, 2, 1, 2) 351*da0073e9SAndroid Build Coastguard Worker lhs = make_tensor( 352*da0073e9SAndroid Build Coastguard Worker shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs 353*da0073e9SAndroid Build Coastguard Worker ) 354*da0073e9SAndroid Build Coastguard Worker rhs = make_tensor( 355*da0073e9SAndroid Build Coastguard Worker shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs 356*da0073e9SAndroid Build Coastguard Worker ) 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker lhs_non_contig = lhs[:, 1, ...] 359*da0073e9SAndroid Build Coastguard Worker lhs = lhs_non_contig.contiguous() 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker rhs_non_contig = rhs[:, 1, ...] 
362*da0073e9SAndroid Build Coastguard Worker rhs = rhs_non_contig.contiguous() 363*da0073e9SAndroid Build Coastguard Worker 364*da0073e9SAndroid Build Coastguard Worker self.assertTrue(lhs.is_contiguous()) 365*da0073e9SAndroid Build Coastguard Worker self.assertTrue(rhs.is_contiguous()) 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker self.assertFalse(lhs_non_contig.is_contiguous()) 368*da0073e9SAndroid Build Coastguard Worker self.assertFalse(rhs_non_contig.is_contiguous()) 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker expected = op(lhs, rhs) 371*da0073e9SAndroid Build Coastguard Worker actual = op(lhs_non_contig, rhs_non_contig) 372*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 373*da0073e9SAndroid Build Coastguard Worker 374*da0073e9SAndroid Build Coastguard Worker @ops(binary_ufuncs) 375*da0073e9SAndroid Build Coastguard Worker def test_non_contig_expand(self, device, dtype, op): 376*da0073e9SAndroid Build Coastguard Worker shapes = [(1, 3), (1, 7), (5, 7)] 377*da0073e9SAndroid Build Coastguard Worker for shape in shapes: 378*da0073e9SAndroid Build Coastguard Worker lhs = make_tensor( 379*da0073e9SAndroid Build Coastguard Worker shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs 380*da0073e9SAndroid Build Coastguard Worker ) 381*da0073e9SAndroid Build Coastguard Worker rhs = make_tensor( 382*da0073e9SAndroid Build Coastguard Worker shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs 383*da0073e9SAndroid Build Coastguard Worker ) 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker lhs_non_contig = lhs.clone().expand(3, -1, -1) 386*da0073e9SAndroid Build Coastguard Worker rhs_non_contig = rhs.clone().expand(3, -1, -1) 387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Worker self.assertTrue(lhs.is_contiguous()) 389*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(rhs.is_contiguous()) 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker self.assertFalse(lhs_non_contig.is_contiguous()) 392*da0073e9SAndroid Build Coastguard Worker self.assertFalse(rhs_non_contig.is_contiguous()) 393*da0073e9SAndroid Build Coastguard Worker 394*da0073e9SAndroid Build Coastguard Worker expected = op(lhs, rhs) 395*da0073e9SAndroid Build Coastguard Worker actual = op(lhs_non_contig, rhs_non_contig) 396*da0073e9SAndroid Build Coastguard Worker for i in range(3): 397*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual[i]) 398*da0073e9SAndroid Build Coastguard Worker 399*da0073e9SAndroid Build Coastguard Worker @ops(binary_ufuncs) 400*da0073e9SAndroid Build Coastguard Worker def test_contig_size1(self, device, dtype, op): 401*da0073e9SAndroid Build Coastguard Worker shape = (5, 100) 402*da0073e9SAndroid Build Coastguard Worker lhs = make_tensor( 403*da0073e9SAndroid Build Coastguard Worker shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs 404*da0073e9SAndroid Build Coastguard Worker ) 405*da0073e9SAndroid Build Coastguard Worker rhs = make_tensor( 406*da0073e9SAndroid Build Coastguard Worker shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs 407*da0073e9SAndroid Build Coastguard Worker ) 408*da0073e9SAndroid Build Coastguard Worker 409*da0073e9SAndroid Build Coastguard Worker lhs = lhs[:1, :50] 410*da0073e9SAndroid Build Coastguard Worker lhs_alt = torch.empty(lhs.size(), device=device, dtype=dtype) 411*da0073e9SAndroid Build Coastguard Worker lhs_alt.copy_(lhs) 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Worker rhs = rhs[:1, :50] 414*da0073e9SAndroid Build Coastguard Worker rhs_alt = torch.empty(rhs.size(), device=device, dtype=dtype) 415*da0073e9SAndroid Build Coastguard Worker rhs_alt.copy_(rhs) 416*da0073e9SAndroid Build Coastguard Worker 417*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(lhs.is_contiguous()) 418*da0073e9SAndroid Build Coastguard Worker self.assertTrue(rhs.is_contiguous()) 419*da0073e9SAndroid Build Coastguard Worker 420*da0073e9SAndroid Build Coastguard Worker self.assertTrue(lhs_alt.is_contiguous()) 421*da0073e9SAndroid Build Coastguard Worker self.assertTrue(rhs_alt.is_contiguous()) 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Worker expected = op(lhs, rhs) 424*da0073e9SAndroid Build Coastguard Worker actual = op(lhs_alt, rhs_alt) 425*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 426*da0073e9SAndroid Build Coastguard Worker 427*da0073e9SAndroid Build Coastguard Worker @ops(binary_ufuncs) 428*da0073e9SAndroid Build Coastguard Worker def test_contig_size1_large_dim(self, device, dtype, op): 429*da0073e9SAndroid Build Coastguard Worker shape = (5, 2, 3, 1, 4, 5, 3, 2, 1, 2, 3, 4) 430*da0073e9SAndroid Build Coastguard Worker lhs = make_tensor( 431*da0073e9SAndroid Build Coastguard Worker shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs 432*da0073e9SAndroid Build Coastguard Worker ) 433*da0073e9SAndroid Build Coastguard Worker rhs = make_tensor( 434*da0073e9SAndroid Build Coastguard Worker shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs 435*da0073e9SAndroid Build Coastguard Worker ) 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Worker lhs = lhs[:1, :, :, :, :, :, :, :, :, :, :, :] 438*da0073e9SAndroid Build Coastguard Worker lhs_alt = torch.empty(lhs.size(), device=device, dtype=dtype) 439*da0073e9SAndroid Build Coastguard Worker lhs_alt.copy_(lhs) 440*da0073e9SAndroid Build Coastguard Worker 441*da0073e9SAndroid Build Coastguard Worker rhs = rhs[:1, :, :, :, :, :, :, :, :, :, :, :] 442*da0073e9SAndroid Build Coastguard Worker rhs_alt = torch.empty(rhs.size(), device=device, dtype=dtype) 443*da0073e9SAndroid Build Coastguard Worker rhs_alt.copy_(rhs) 444*da0073e9SAndroid Build Coastguard 
Worker 445*da0073e9SAndroid Build Coastguard Worker self.assertTrue(lhs.is_contiguous()) 446*da0073e9SAndroid Build Coastguard Worker self.assertTrue(rhs.is_contiguous()) 447*da0073e9SAndroid Build Coastguard Worker 448*da0073e9SAndroid Build Coastguard Worker self.assertTrue(lhs_alt.is_contiguous()) 449*da0073e9SAndroid Build Coastguard Worker self.assertTrue(rhs_alt.is_contiguous()) 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Worker expected = op(lhs, rhs) 452*da0073e9SAndroid Build Coastguard Worker actual = op(lhs_alt, rhs_alt) 453*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 454*da0073e9SAndroid Build Coastguard Worker 455*da0073e9SAndroid Build Coastguard Worker @ops(binary_ufuncs) 456*da0073e9SAndroid Build Coastguard Worker def test_batch_vs_slicing(self, device, dtype, op): 457*da0073e9SAndroid Build Coastguard Worker shape = (32, 512) 458*da0073e9SAndroid Build Coastguard Worker lhs = make_tensor( 459*da0073e9SAndroid Build Coastguard Worker shape, dtype=dtype, device=device, **op.lhs_make_tensor_kwargs 460*da0073e9SAndroid Build Coastguard Worker ) 461*da0073e9SAndroid Build Coastguard Worker rhs = make_tensor( 462*da0073e9SAndroid Build Coastguard Worker shape, dtype=dtype, device=device, **op.rhs_make_tensor_kwargs 463*da0073e9SAndroid Build Coastguard Worker ) 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker expected = op(lhs, rhs) 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Worker actual = [] 468*da0073e9SAndroid Build Coastguard Worker for idx in range(32): 469*da0073e9SAndroid Build Coastguard Worker actual.append(op(lhs[idx], rhs[idx])) 470*da0073e9SAndroid Build Coastguard Worker actual = torch.stack(actual) 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid 
    # Tests that elementwise binary operators participate in type promotion properly
    # NOTE: because the cross-product of all possible type promotion tests is huge, this
    #   just spot checks some handwritten cases.
    # NOTE: It may be possible to refactor this test into something simpler
    @ops(binary_ufuncs_and_refs, dtypes=OpDTypes.none)
    def test_type_promotion(self, device, op):
        """Spot-check dtype promotion for binary elementwise ops.

        Covers int x int, float x float, complex x complex, mixed-kind
        tensor x tensor, tensor x Python scalar, tensor x 0-dim ("scalar")
        tensor, and scalar x scalar pairings, plus ``out=`` casting rules.
        Each section only runs when the op supports the dtypes involved.
        """
        supported_dtypes = op.supported_dtypes(torch.device(device).type)

        # 1-D (5,) operands; the 0-dim "scalar tensor" rhs lives on the CPU on
        # purpose so cross-device scalar-tensor promotion is exercised too.
        make_lhs = partial(
            make_tensor, (5,), device=device, **op.lhs_make_tensor_kwargs
        )
        make_rhs = partial(
            make_tensor, (5,), device=device, **op.rhs_make_tensor_kwargs
        )

        make_rhs_scalar_tensor = partial(
            make_tensor, (), device="cpu", **op.rhs_make_tensor_kwargs
        )

        def _supported(dtypes):
            # True iff the op supports every dtype in the tuple on this device.
            return all(x in supported_dtypes for x in dtypes)

        # int x int type promotion
        if _supported((torch.int16, torch.int32, torch.int64)):
            lhs_i16 = make_lhs(dtype=torch.int16)
            lhs_i32 = make_lhs(dtype=torch.int32)
            lhs_i64 = make_lhs(dtype=torch.int64)

            rhs_i16 = make_rhs(dtype=torch.int16)
            rhs_i32 = make_rhs(dtype=torch.int32)
            rhs_i64 = make_rhs(dtype=torch.int64)

            if op.promotes_int_to_float:
                # e.g. true division: integer inputs promote to the default float dtype.
                default_dtype = torch.get_default_dtype()
                self.assertEqual(op(lhs_i16, rhs_i32).dtype, default_dtype)
                self.assertEqual(
                    op(lhs_i16, rhs_i32),
                    op(lhs_i16.to(default_dtype), rhs_i32.to(default_dtype)),
                )

                self.assertEqual(op(lhs_i32, rhs_i64).dtype, default_dtype)
                self.assertEqual(
                    op(lhs_i32, rhs_i64),
                    op(lhs_i32.to(default_dtype), rhs_i64.to(default_dtype)),
                )
            elif op.always_returns_bool:
                # e.g. comparisons: result dtype is bool regardless of inputs.
                self.assertEqual(op(lhs_i16, rhs_i32).dtype, torch.bool)
                self.assertEqual(op(lhs_i32, rhs_i64).dtype, torch.bool)
            else:  # standard type promotion
                self.assertEqual(op(lhs_i16, rhs_i32).dtype, torch.int32)
                self.assertEqual(
                    op(lhs_i16, rhs_i32), op(lhs_i16.to(torch.int32), rhs_i32)
                )

                self.assertEqual(op(lhs_i32, rhs_i64).dtype, torch.int64)
                self.assertEqual(
                    op(lhs_i32, rhs_i64), op(lhs_i32.to(torch.int64), rhs_i64)
                )

            if op.supports_out:
                if not op.promotes_int_to_float:
                    # Integers can be safely cast to other integer types
                    out = torch.empty_like(lhs_i64)
                    self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.int64)
                    self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)

                    # Even a narrowing int64 -> int16 out= cast is permitted.
                    out = torch.empty_like(lhs_i16)
                    self.assertEqual(op(lhs_i32, rhs_i64, out=out).dtype, torch.int16)
                else:
                    # Float outs cannot be safely cast to integer types
                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                        op(lhs_i16, rhs_i32, out=torch.empty_like(lhs_i64))

                if not op.always_returns_bool:
                    # Neither integer nor float outs can be cast to bool
                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                        op(
                            lhs_i16,
                            rhs_i32,
                            out=torch.empty_like(lhs_i64, dtype=torch.bool),
                        )

                # All these output types can be cast to any float or complex type
                out = torch.empty_like(lhs_i64, dtype=torch.float16)
                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.float16)

                out = torch.empty_like(lhs_i64, dtype=torch.bfloat16)
                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.bfloat16)

                out = torch.empty_like(lhs_i64, dtype=torch.float32)
                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.float32)
                self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)

                out = torch.empty_like(lhs_i64, dtype=torch.complex64)
                self.assertEqual(op(lhs_i16, rhs_i32, out=out).dtype, torch.complex64)
                self.assertEqual(op(lhs_i16, rhs_i32), out, exact_dtype=False)

        # float x float type promotion
        if _supported((torch.float32, torch.float64)):
            lhs_f32 = make_lhs(dtype=torch.float32)
            lhs_f64 = make_lhs(dtype=torch.float64)

            rhs_f32 = make_rhs(dtype=torch.float32)
            rhs_f64 = make_rhs(dtype=torch.float64)

            if op.always_returns_bool:
                self.assertEqual(op(lhs_f32, rhs_f64).dtype, torch.bool)
            else:  # normal float type promotion
                self.assertEqual(op(lhs_f32, rhs_f64).dtype, torch.float64)
                self.assertEqual(
                    op(lhs_f32, rhs_f64), op(lhs_f32.to(torch.float64), rhs_f64)
                )

            if op.supports_out:
                # All these output types can be cast to any float or complex type
                out = torch.empty_like(lhs_f64, dtype=torch.float16)
                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.float16)

                out = torch.empty_like(lhs_f64, dtype=torch.bfloat16)
                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.bfloat16)
                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

                out = torch.empty_like(lhs_f64, dtype=torch.float32)
                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.float32)
                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

                out = torch.empty_like(lhs_f64, dtype=torch.complex64)
                self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.complex64)
                self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

                if not op.always_returns_bool:
                    # float outs can't be cast to an integer dtype
                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                        op(
                            lhs_f32,
                            rhs_f64,
                            out=torch.empty_like(lhs_f64, dtype=torch.int64),
                        )
                else:
                    # bool outs can be cast to an integer dtype
                    out = torch.empty_like(lhs_f64, dtype=torch.int64)
                    self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.int64)
                    self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

        # complex x complex type promotion
        if _supported((torch.complex64, torch.complex128)):
            lhs_c64 = make_lhs(dtype=torch.complex64)
            lhs_c128 = make_lhs(dtype=torch.complex128)

            rhs_c64 = make_rhs(dtype=torch.complex64)
            rhs_c128 = make_rhs(dtype=torch.complex128)

            if op.always_returns_bool:
                # NOTE(review): second operand is lhs_c128, not rhs_c128 — looks
                # like a typo (every other section pairs lhs with rhs), though the
                # asserted dtype is the same either way; confirm intent.
                self.assertEqual(op(lhs_c64, lhs_c128).dtype, torch.bool)
            else:  # normal complex type promotion
                self.assertEqual(op(lhs_c64, rhs_c128).dtype, torch.complex128)
                self.assertEqual(
                    op(lhs_c64, rhs_c128), op(lhs_c64.to(torch.complex128), rhs_c128)
                )

            if op.supports_out:
                # All these output types can be cast to any float or complex type
                out = torch.empty_like(lhs_c64, dtype=torch.complex64)

                self.assertEqual(op(lhs_c64, rhs_c128, out=out).dtype, torch.complex64)
                result = op(lhs_c64, rhs_c128)
                self.assertEqual(result, out.to(result.dtype))

                if not op.always_returns_bool:
                    # complex outs can't be cast to float types
                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                        op(
                            lhs_c64,
                            rhs_c128,
                            out=torch.empty_like(lhs_c64, dtype=torch.float64),
                        )
                    # complex outs can't be cast to an integer dtype
                    with self.assertRaisesRegex(RuntimeError, "can't be cast"):
                        op(
                            lhs_c64,
                            rhs_c128,
                            out=torch.empty_like(lhs_c64, dtype=torch.int64),
                        )
                else:
                    # bool outs can be cast to a float type
                    out = torch.empty_like(lhs_c64, dtype=torch.float64)
                    self.assertEqual(
                        op(lhs_c64, rhs_c128, out=out).dtype, torch.float64
                    )
                    self.assertEqual(op(lhs_c64, rhs_c128), out, exact_dtype=False)

                    # bool outs can be cast to an integer dtype
                    # NOTE(review): this reuses lhs_f64/lhs_f32/rhs_f64 from the
                    # float section above; it would NameError if the float dtypes
                    # were unsupported while complex ones are — confirm intended.
                    out = torch.empty_like(lhs_f64, dtype=torch.int64)
                    self.assertEqual(op(lhs_f32, rhs_f64, out=out).dtype, torch.int64)
                    self.assertEqual(op(lhs_f32, rhs_f64), out, exact_dtype=False)

        # int x float type promotion
        # Note: float type is the result dtype
        if _supported((torch.long, torch.float32)):
            lhs_i64 = make_lhs(dtype=torch.int64)
            rhs_f32 = make_rhs(dtype=torch.float32)

            result = op(lhs_i64, rhs_f32)
            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
            self.assertEqual(result.dtype, expected_dtype)

        # float x complex type promotion
        # Note: complex type with highest "value type" is the result dtype
        if _supported((torch.float64, torch.complex64)):
            lhs_f64 = make_lhs(dtype=torch.float64)
            rhs_c64 = make_rhs(dtype=torch.complex64)

            result = op(lhs_f64, rhs_c64)
            expected_dtype = (
                torch.complex128 if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

        # int x float scalar type promotion
        # Note: default float dtype is the result dtype
        if _supported((torch.int64, torch.float32)) and op.supports_rhs_python_scalar:
            lhs_i64 = make_lhs(dtype=torch.int64)
            rhs_f_scalar = 1.0

            result = op(lhs_i64, rhs_f_scalar)
            expected_dtype = (
                torch.get_default_dtype() if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

            # repeats with a scalar float tensor, which should set the dtype
            rhs_f32_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float32)
            result = op(lhs_i64, rhs_f32_scalar_tensor)
            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
            self.assertEqual(result.dtype, expected_dtype)

            # Additional test with double
            if _supported((torch.float64,)):
                rhs_f64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float64)
                result = op(lhs_i64, rhs_f64_scalar_tensor)
                expected_dtype = (
                    torch.float64 if not op.always_returns_bool else torch.bool
                )
                self.assertEqual(result.dtype, expected_dtype)

        # float x complex scalar type promotion
        # Note: result dtype is complex with highest "value type" among all tensors
        if (
            _supported((torch.float32, torch.complex64))
            and op.supports_rhs_python_scalar
        ):
            lhs_f32 = make_lhs(dtype=torch.float32)
            rhs_c_scalar = complex(1, 1)

            result = op(lhs_f32, rhs_c_scalar)
            expected_dtype = (
                torch.complex64 if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

            # repeats with a scalar complex tensor
            rhs_c64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex64)
            result = op(lhs_f32, rhs_c64_scalar_tensor)
            expected_dtype = (
                torch.complex64 if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

            # Additional test with complexdouble
            if _supported((torch.complex128,)):
                rhs_c128_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex128)
                result = op(lhs_f32, rhs_c128_scalar_tensor)
                # Value type of 1D+ Tensor (lhs_f32) takes priority over scalar tensor (rhs_c128).
                expected_dtype = (
                    torch.complex64 if not op.always_returns_bool else torch.bool
                )
                self.assertEqual(result.dtype, expected_dtype)

        # float x float scalar tensor
        # Note: result dtype is the type of the float tensor
        if _supported((torch.float32, torch.float64)) and op.supports_rhs_python_scalar:
            lhs_f32 = make_lhs(dtype=torch.float32)
            rhs_f64_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.float64)

            result = op(lhs_f32, rhs_f64_scalar_tensor)
            expected_dtype = torch.float32 if not op.always_returns_bool else torch.bool
            self.assertEqual(result.dtype, expected_dtype)

        # complex x complex scalar tensor
        # Note: result dtype is the type of the complex tensor
        if (
            _supported((torch.complex64, torch.complex128))
            and op.supports_rhs_python_scalar
        ):
            lhs_c64 = make_lhs(dtype=torch.complex64)
            rhs_c128_scalar_tensor = make_rhs_scalar_tensor(dtype=torch.complex128)

            result = op(lhs_c64, rhs_c128_scalar_tensor)
            expected_dtype = (
                torch.complex64 if not op.always_returns_bool else torch.bool
            )
            self.assertEqual(result.dtype, expected_dtype)

        # scalar x scalar
        # Note: result dtype is default float type
        if op.supports_two_python_scalars and _supported((torch.long, torch.float32)):
            rhs_f_scalar = 2.0
            for lhs in (1, 1.0):
                result = op(lhs, rhs_f_scalar)
                expected_dtype = (
                    torch.get_default_dtype()
                    if not op.always_returns_bool
                    else torch.bool
                )
                self.assertEqual(result.dtype, expected_dtype)
op(lhs_c64, rhs_c128_scalar_tensor) 772*da0073e9SAndroid Build Coastguard Worker expected_dtype = ( 773*da0073e9SAndroid Build Coastguard Worker torch.complex64 if not op.always_returns_bool else torch.bool 774*da0073e9SAndroid Build Coastguard Worker ) 775*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dtype, expected_dtype) 776*da0073e9SAndroid Build Coastguard Worker 777*da0073e9SAndroid Build Coastguard Worker # scalar x scalar 778*da0073e9SAndroid Build Coastguard Worker # Note: result dtype is default float type 779*da0073e9SAndroid Build Coastguard Worker if op.supports_two_python_scalars and _supported((torch.long, torch.float32)): 780*da0073e9SAndroid Build Coastguard Worker rhs_f_scalar = 2.0 781*da0073e9SAndroid Build Coastguard Worker for lhs in (1, 1.0): 782*da0073e9SAndroid Build Coastguard Worker result = op(lhs, rhs_f_scalar) 783*da0073e9SAndroid Build Coastguard Worker expected_dtype = ( 784*da0073e9SAndroid Build Coastguard Worker torch.get_default_dtype() 785*da0073e9SAndroid Build Coastguard Worker if not op.always_returns_bool 786*da0073e9SAndroid Build Coastguard Worker else torch.bool 787*da0073e9SAndroid Build Coastguard Worker ) 788*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dtype, expected_dtype) 789*da0073e9SAndroid Build Coastguard Worker 790*da0073e9SAndroid Build Coastguard Worker # TODO: move to error input test 791*da0073e9SAndroid Build Coastguard Worker @ops(binary_ufuncs, allowed_dtypes=(torch.float32,)) 792*da0073e9SAndroid Build Coastguard Worker def test_not_broadcastable(self, device, dtype, op): 793*da0073e9SAndroid Build Coastguard Worker for shape_lhs, shape_rhs in ( 794*da0073e9SAndroid Build Coastguard Worker ((2,), (3,)), 795*da0073e9SAndroid Build Coastguard Worker ((3, 1), (2, 1)), 796*da0073e9SAndroid Build Coastguard Worker ((1, 3, 2), (3,)), 797*da0073e9SAndroid Build Coastguard Worker ((3, 1, 2), (2, 1, 2)), 798*da0073e9SAndroid Build Coastguard Worker ): 799*da0073e9SAndroid 
Build Coastguard Worker lhs = make_tensor( 800*da0073e9SAndroid Build Coastguard Worker shape_lhs, device=device, dtype=dtype, **op.lhs_make_tensor_kwargs 801*da0073e9SAndroid Build Coastguard Worker ) 802*da0073e9SAndroid Build Coastguard Worker rhs = make_tensor( 803*da0073e9SAndroid Build Coastguard Worker shape_rhs, device=device, dtype=dtype, **op.rhs_make_tensor_kwargs 804*da0073e9SAndroid Build Coastguard Worker ) 805*da0073e9SAndroid Build Coastguard Worker 806*da0073e9SAndroid Build Coastguard Worker try: 807*da0073e9SAndroid Build Coastguard Worker broadcasted_shape = op(lhs, rhs).shape 808*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 809*da0073e9SAndroid Build Coastguard Worker continue 810*da0073e9SAndroid Build Coastguard Worker 811*da0073e9SAndroid Build Coastguard Worker msg = ( 812*da0073e9SAndroid Build Coastguard Worker f"On {device}, torch.{op.name} broadcasts inputs shapes {shape_lhs} and {shape_rhs} into " 813*da0073e9SAndroid Build Coastguard Worker f"{broadcasted_shape}, although they are not broadcastable." 
814*da0073e9SAndroid Build Coastguard Worker ) 815*da0073e9SAndroid Build Coastguard Worker raise AssertionError(msg) 816*da0073e9SAndroid Build Coastguard Worker 817*da0073e9SAndroid Build Coastguard Worker def test_add_broadcast_empty(self, device): 818*da0073e9SAndroid Build Coastguard Worker # empty + empty 819*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 820*da0073e9SAndroid Build Coastguard Worker RuntimeError, 821*da0073e9SAndroid Build Coastguard Worker lambda: torch.randn(5, 0, device=device) + torch.randn(0, 5, device=device), 822*da0073e9SAndroid Build Coastguard Worker ) 823*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 824*da0073e9SAndroid Build Coastguard Worker torch.randn(5, 0, device=device), 825*da0073e9SAndroid Build Coastguard Worker torch.randn(0, device=device) + torch.randn(5, 0, device=device), 826*da0073e9SAndroid Build Coastguard Worker ) 827*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 828*da0073e9SAndroid Build Coastguard Worker torch.randn(5, 0, 0, device=device), 829*da0073e9SAndroid Build Coastguard Worker torch.randn(0, device=device) + torch.randn(5, 0, 1, device=device), 830*da0073e9SAndroid Build Coastguard Worker ) 831*da0073e9SAndroid Build Coastguard Worker 832*da0073e9SAndroid Build Coastguard Worker # scalar + empty 833*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 834*da0073e9SAndroid Build Coastguard Worker torch.randn(5, 0, 6, device=device), 835*da0073e9SAndroid Build Coastguard Worker torch.randn((), device=device) + torch.randn(5, 0, 6, device=device), 836*da0073e9SAndroid Build Coastguard Worker ) 837*da0073e9SAndroid Build Coastguard Worker 838*da0073e9SAndroid Build Coastguard Worker # non-empty, empty 839*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 840*da0073e9SAndroid Build Coastguard Worker torch.randn(0, device=device), 841*da0073e9SAndroid Build Coastguard Worker torch.randn(0, device=device) + torch.randn(1, device=device), 
842*da0073e9SAndroid Build Coastguard Worker ) 843*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 844*da0073e9SAndroid Build Coastguard Worker torch.randn(0, 7, 0, 6, 5, 0, 7, device=device), 845*da0073e9SAndroid Build Coastguard Worker torch.randn(0, 7, 0, 6, 5, 0, 1, device=device) 846*da0073e9SAndroid Build Coastguard Worker + torch.randn(1, 1, 5, 1, 7, device=device), 847*da0073e9SAndroid Build Coastguard Worker ) 848*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 849*da0073e9SAndroid Build Coastguard Worker RuntimeError, 850*da0073e9SAndroid Build Coastguard Worker lambda: torch.randn(7, 0, device=device) + torch.randn(2, 1, device=device), 851*da0073e9SAndroid Build Coastguard Worker ) 852*da0073e9SAndroid Build Coastguard Worker 853*da0073e9SAndroid Build Coastguard Worker def test_addcmul_scalars_as_floats(self, device): 854*da0073e9SAndroid Build Coastguard Worker # zero-dim variables that don't require grad should bind to scalar arguments 855*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(2.0) 856*da0073e9SAndroid Build Coastguard Worker y = torch.tensor(3.0, device=device) 857*da0073e9SAndroid Build Coastguard Worker # 3 + (3 * 3) * 2 858*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.addcmul(y, y, value=x), 21) 859*da0073e9SAndroid Build Coastguard Worker 860*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(2.0, requires_grad=True) 861*da0073e9SAndroid Build Coastguard Worker self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x)) 862*da0073e9SAndroid Build Coastguard Worker 863*da0073e9SAndroid Build Coastguard Worker # Tests that the binary operators and, or, and xor (as well as their reflected and inplace versions) 864*da0073e9SAndroid Build Coastguard Worker # work properly (AKA &, ||, ^ and &=, |=, ^=) 865*da0073e9SAndroid Build Coastguard Worker @dtypes(*integral_types_and(torch.bool)) 866*da0073e9SAndroid Build Coastguard Worker def test_bitwise_ops(self, device, dtype): 
    # Tests that the binary operators and, or, and xor (as well as their reflected and inplace versions)
    # work properly (AKA &, |, ^ and &=, |=, ^=)
    @dtypes(*integral_types_and(torch.bool))
    def test_bitwise_ops(self, device, dtype):
        """Compare torch bitwise &, |, ^ (and in-place forms) against numpy
        for tensor x tensor, tensor x scalar, and scalar x tensor operands."""
        # Tensor x Tensor and Tensor x Scalar ops
        ops = (
            operator.and_,
            operator.iand,
            operator.or_,
            operator.ior,
            operator.xor,
            operator.ixor,
        )
        inplace_ops = (operator.iand, operator.ior, operator.ixor)
        # Small, medium, and large shapes to cross vectorized kernel thresholds.
        shapes = ((5,), (15, 15), (500, 500))

        for op, shape in itertools.product(ops, shapes):
            # Tests tensor x tensor case
            a = make_tensor(shape, device=device, dtype=dtype)
            b = make_tensor(shape, device=device, dtype=dtype)
            # .clone() before .numpy() so the numpy array doesn't share storage
            a_np = a.cpu().clone().numpy()
            b_np = b.cpu().clone().numpy()
            self.assertEqual(op(a, b), op(a_np, b_np))

            # Tests tensor x scalar case
            a = make_tensor(shape, device=device, dtype=dtype)
            b_scalar = make_tensor((), device="cpu", dtype=dtype).item()
            a_np = a.cpu().clone().numpy()
            self.assertEqual(op(a, b_scalar), op(a_np, b_scalar))

            # Tests scalar x tensor case
            a_scalar = make_tensor((), device="cpu", dtype=dtype).item()
            b = make_tensor(shape, device=device, dtype=dtype)
            b_np = b.cpu().clone().numpy()
            self.assertEqual(op(a_scalar, b), op(a_scalar, b_np))

            # For the in-place variants, additionally verify the lhs tensor is
            # mutated the same way numpy mutates its array.
            if op in inplace_ops:
                # Tests tensor x tensor case
                a = make_tensor(shape, device=device, dtype=dtype)
                b = make_tensor(shape, device=device, dtype=dtype)
                a_np = a.cpu().clone().numpy()
                b_np = b.cpu().clone().numpy()
                op(a, b)
                op(a_np, b_np)
                self.assertEqual(a, a_np)

                # Tests tensor x scalar case
                a = make_tensor(shape, device=device, dtype=dtype)
                b_scalar = make_tensor((), device="cpu", dtype=dtype).item()
                a_np = a.cpu().clone().numpy()
                op(a, b_scalar)
                op(a_np, b_scalar)
                self.assertEqual(a, a_np)
    def test_inplace_division(self, device):
        # `t /= 2` must be a true in-place op: the name still refers to the
        # very same Python object afterwards.
        t = torch.rand(5, 5, device=device)
        id_before = id(t)
        t /= 2
        id_after = id(t)
        self.assertEqual(id_before, id_after)

    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_div_rounding_modes(self, device, dtype):
        # Pick a sampling range that fits the dtype under test.
        if dtype.is_floating_point:
            low, high = -10.0, 10.0
        else:
            info = torch.iinfo(dtype)
            low, high = info.min, info.max

        a = make_tensor((100,), dtype=dtype, device=device, low=low, high=high)
        b = make_tensor((100,), dtype=dtype, device=device, low=low, high=high)

        # Avoid division by zero so we can test (a / b) * b == a
        if dtype.is_floating_point:
            eps = 0.1
            b[(-eps < b) & (b < eps)] = eps
        else:
            b[b == 0] = 1

        if not dtype.is_floating_point:
            # floor(a / b) * b can be < a, so fixup slightly to avoid underflow
            a = torch.where(a < 0, a + b, a)

        # rounding_mode=None is true division and always yields a float result.
        d_true = torch.divide(a, b, rounding_mode=None)
        self.assertTrue(d_true.is_floating_point())
        self.assertEqual(d_true * b, a.to(d_true.dtype))

        # "floor" must satisfy a == floor(a/b) * b + remainder(a, b); for
        # half/bfloat16 the reference remainder is computed in float to
        # limit reference error.
        d_floor = torch.divide(a, b, rounding_mode="floor")
        if dtype not in (torch.bfloat16, torch.half):
            self.assertEqual(d_floor * b + torch.remainder(a, b), a)
        else:
            self.assertEqual(
                d_floor * b + torch.remainder(a.float(), b.float()),
                a,
                exact_dtype=False,
            )

        # "trunc" must match truncating the true-division result.
        d_trunc = torch.divide(a, b, rounding_mode="trunc")
        rounding_unsupported = (
            dtype == torch.half
            and device != "cuda"
            or dtype == torch.bfloat16
            and device != "cpu"
        )
        d_ref = d_true.float() if rounding_unsupported else d_true
        self.assertEqual(d_trunc, d_ref.trunc().to(dtype))
            [torch.finfo(dtype).max * 0.7],
            [0.5, -0.5, 0.0],
            [(), (32,)],  # Scalar and vectorized
        ):
            a = torch.full(shape, num, dtype=dtype, device=device)
            b = torch.full(shape, denom, dtype=dtype, device=device)

            # NumPy computes the reference value; saturate it to +/-inf when it
            # does not fit in the (possibly lower-precision) dtype under test.
            ref = np.floor_divide(num, denom).item()
            if ref > torch.finfo(dtype).max:
                ref = np.inf
            elif ref < torch.finfo(dtype).min:
                ref = -np.inf
            expect = torch.full(shape, ref, dtype=dtype, device=device)
            actual = torch.div(a, b, rounding_mode="floor")
            self.assertEqual(expect, actual)

    @dtypes(torch.bfloat16, torch.half, torch.float32, torch.float64)
    def test_div_rounding_nonfinite(self, device, dtype):
        # Compare division of special floating point values against NumPy
        num = torch.tensor(
            [1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan],
            dtype=dtype,
            device=device,
        )
        # Divide by zero is tested separately
        denom = num[num != 0]

        # Broadcast into an all-pairs (denominator x numerator) grid.
        a, b = num[None, :].clone(), denom[:, None].clone()

        # Compare bfloat16 against NumPy float
        exact_dtype = dtype != torch.bfloat16
        if exact_dtype:
            an, bn = a.cpu().numpy(), b.cpu().numpy()
        else:
            an, bn = a.float().cpu().numpy(), b.float().cpu().numpy()

        for mode, np_ref in ((None, np.true_divide), ("floor", np.floor_divide)):
            expect = np_ref(an, bn)
            kwargs = dict(rounding_mode=mode) if mode is not None else {}
            with set_default_dtype(torch.double):
                actual = torch.divide(a, b, **kwargs)
            self.assertEqual(
                actual,
                torch.from_numpy(expect),
                exact_device=False,
                exact_dtype=exact_dtype,
            )

        # Compare contiguous (likely vectorized) against non-contiguous (not vectorized)
        a_noncontig = torch.empty([2 * i for i in a.shape], dtype=dtype, device=device)[
            ::2, ::2
        ]
        a_noncontig[:] = a
        b_noncontig = torch.empty([2 * i for i in b.shape], dtype=dtype, device=device)[
            ::2, ::2
        ]
        b_noncontig[:] = b

        for rounding_mode in (None, "trunc", "floor"):
            expect = torch.divide(a_noncontig, b_noncontig, rounding_mode=rounding_mode)
            actual = torch.divide(a, b, rounding_mode=rounding_mode)
            self.assertEqual(actual, expect)

    @dtypes(torch.bfloat16, torch.half, torch.float32, torch.float64)
    def test_divide_by_zero_rounding(self, device, dtype):
        # x / 0 for the special values below must follow IEEE semantics for
        # both rounding_mode=None and rounding_mode="floor".
        a = torch.tensor(
            [1.0, -1.0, 0, 0.1, -0.1, np.pi, -np.pi, np.inf, -np.inf, np.nan],
            dtype=dtype,
        )
        # Compare bfloat16 against NumPy float
        exact_dtype = dtype != torch.bfloat16
        if exact_dtype:
            an = a.cpu().numpy()
        else:
            an = a.float().cpu().numpy()

        zero = torch.zeros_like(a)

        # NOTE: NumPy's floor_divide rounding changed in 1.20.0 to be consistent with divide
        expect = np.divide(an, 0)
        for rounding_mode in (None, "floor"):
            # CPU scalar
            actual = torch.divide(a, 0, rounding_mode=rounding_mode)
            self.assertEqual(actual, expect, exact_dtype=exact_dtype)
            # Device tensor
            actual = torch.divide(a, zero, rounding_mode=rounding_mode)
            self.assertEqual(actual, expect, exact_dtype=exact_dtype)
    @dtypes(*all_types_and(torch.half))
    def test_div_rounding_numpy(self, device, dtype):
        info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
        low, high = info.min, info.max

        # Compare division of random values against NumPy
        a = make_tensor((4096,), dtype=dtype, device=device, low=low, high=high)
        b = make_tensor((4096,), dtype=dtype, device=device, low=low, high=high)

        # Avoid division by zero which raises for integers and, for floats,
        # NumPy 1.20 changed floor_divide to follow IEEE rules for inf/nan
        # after dividing by zero.
        b[b == 0] = 1

        # Compare bfloat16 against NumPy float
        exact_dtype = dtype != torch.bfloat16

        if exact_dtype:
            an, bn = a.cpu().numpy(), b.cpu().numpy()
        else:
            an, bn = a.float().cpu().numpy(), b.float().cpu().numpy()

        # Each rounding mode is checked against its NumPy reference; "trunc"
        # has no direct NumPy equivalent, so true-divide then truncate.
        for mode, np_ref in (
            (None, np.true_divide),
            ("floor", np.floor_divide),
            ("trunc", lambda a, b: np.trunc(np.true_divide(a, b)).astype(a.dtype)),
        ):
            expect = torch.from_numpy(np_ref(an, bn))

            kwargs = dict(rounding_mode=mode) if mode is not None else {}
            # Contiguous (likely vectorized)
            with set_default_dtype(torch.double):
                actual = torch.divide(a, b, **kwargs)
            self.assertEqual(
                actual, expect, exact_device=False, exact_dtype=exact_dtype
            )

            # Non-contiguous (not vectorized)
            expect = expect[::2]
            with set_default_dtype(torch.double):
                actual = torch.divide(a[::2], b[::2], **kwargs)

            self.assertEqual(
                actual, expect, exact_device=False, exact_dtype=exact_dtype
            )
    @dtypes(*complex_types())
    def test_complex_div_underflow_overflow(self, device, dtype):
        # test to make sure the complex division does not produce underflow or overflow
        # in the intermediate of its calculations
        # NOTE: the calculation still produces an error if the number is greater than
        # finfo.max / 2, but hopefully people realized that it's a dangerous region to work with
        finfo = torch.finfo(dtype)
        nom_lst = [
            complex(finfo.min / 2, finfo.min / 2),
            complex(finfo.max / 2, finfo.max / 2),
            complex(finfo.tiny, finfo.tiny),
            complex(finfo.tiny, 0.0),
            complex(0.0, 0.0),
        ]
        denom_lst = [
            complex(finfo.min / 2, finfo.min / 2),
            complex(finfo.max / 2, finfo.max / 2),
            complex(finfo.tiny, finfo.tiny),
            complex(0.0, finfo.tiny),
            complex(finfo.tiny, finfo.tiny),
        ]
        expected_lst = [
            complex(1.0, 0.0),
            complex(1.0, 0.0),
            complex(1.0, 0.0),
            complex(0.0, -1.0),
            complex(0.0, 0.0),
        ]
        nom = torch.tensor(nom_lst, dtype=dtype, device=device)
        denom = torch.tensor(denom_lst, dtype=dtype, device=device)
        expected = torch.tensor(expected_lst, dtype=dtype, device=device)
        res = nom / denom
        self.assertEqual(res, expected)

    # Tests that trying to add, inplace, a CUDA tensor to a CPU tensor
    # throws the correct error message
    @onlyCUDA
    def test_cross_device_inplace_error_msg(self, device):
        a = torch.tensor(2.0)
        b = torch.tensor(2.0, device=device)
        with self.assertRaisesRegex(
            RuntimeError, "Expected all tensors to be on the same device"
        ):
            a += b

    # TODO: refactor this test into a more generic one, it's parked here currently
    @onlyNativeDeviceTypes
    def test_out_resize_warning(self, device):
        # out= with a wrong-sized (non-empty) out tensor should warn exactly once.
        a = torch.tensor((1, 2, 3), device=device, dtype=torch.float32)
        b = torch.tensor((4, 5, 6), device=device, dtype=torch.float32)

        unary_inputs = (a,)
        binary_inputs = (a, b)
        unary_ops = (torch.ceil, torch.exp)
        binary_ops = (torch.add, torch.sub)
        for op in unary_ops + binary_ops:
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                inputs = unary_inputs if op in unary_ops else binary_inputs

                # No warnings
                op(*inputs, out=torch.empty(3, device=device))
                op(*inputs, out=torch.empty(0, device=device))
                self.assertEqual(len(w), 0)

                # Cases that throw warnings
                op(*inputs, out=torch.empty(2, device=device))
                self.assertEqual(len(w), 1)
        # test that multi-d out doesn't trigger segfault
        arg1 = (torch.ones(2, 1, device=device), torch.ones(1, device=device))
        arg2 = (torch.ones(2, device=device), torch.ones(1, 1, device=device))
        outs = (
            torch.ones(2, 1, 1, 1, device=device),
            torch.ones(2, 2, 2, 2, device=device),
        )

        for a1, a2, o in zip(arg1, arg2, outs):
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                torch.mul(a1, a2, out=o)
                self.assertEqual(len(w), 1)

    # Verifies that the inplace dunders (like idiv) actually are in place
    @expectedFailureMeta  # UserWarning not triggered
    @onlyNativeDeviceTypes
    def test_inplace_dunders(self, device):
        # The data pointer must be unchanged after every augmented assignment.
        t = torch.randn((1,), device=device)
        expected = t.data_ptr()
        t += 1
        t -= 1
        t *= 1
        t /= 1
        t **= 1
        t //= 1
        t %= 1
        self.assertEqual(expected, t.data_ptr())
    def check_internal_mem_overlap(
        self, inplace_op, num_inputs, dtype, device, expected_failure=False
    ):
        # Helper (not a test): verifies that an in-place op rejects a first
        # operand whose elements alias each other (an expanded tensor).
        if isinstance(inplace_op, str):
            inplace_op = getattr(torch.Tensor, inplace_op)
        input = torch.randn(1, dtype=dtype, device=device).expand(3, 3)
        inputs = [input] + [torch.randn_like(input) for i in range(num_inputs - 1)]
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, "single memory location"):
                inplace_op(*inputs)
        else:
            # expected_failure=True marks ops known to be missing the overlap
            # check, so the inner assertion itself is expected to fail.
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, "single memory location"):
                    inplace_op(*inputs)

    def unary_check_input_output_mem_overlap(
        self, data, sz, op, expected_failure=False
    ):
        # Helper (not a test): verifies out= overlap handling for a unary-style
        # callable `op(input, out=...)` over slices of `data`.
        def _test(op, output, input):
            output_exp = torch.empty_like(output)
            op(input, out=output_exp)
            self.assertEqual(op(input, out=output), output_exp, msg=op.__name__)

        # output is identical to input:
        _test(op, output=data[0:sz], input=data[0:sz])
        # output and input are independent:
        _test(op, output=data[0:sz], input=data[sz : 2 * sz])
        # output partially overlaps with input:
        if not expected_failure:
            with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
                _test(op, data[0:sz], data[1 : sz + 1])
        else:
            # Op known to be missing the check: the inner assertion must fail.
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(RuntimeError, "unsupported operation"):
                    _test(op, data[0:sz], data[1 : sz + 1])

    def binary_check_input_output_mem_overlap(self, op, device, expected_failure=False):
        # Helper (not a test): reduces the binary out= overlap check to the
        # unary helper by currying each operand position in turn.
        sz = 3
        data = torch.randn(2 * sz, device=device)
        other = torch.randn(sz, device=device)

        self.unary_check_input_output_mem_overlap(
            data,
            sz,
            lambda input, out: op(other, input, out=out),
            expected_failure=expected_failure,
        )

        self.unary_check_input_output_mem_overlap(
            data,
            sz,
1255*da0073e9SAndroid Build Coastguard Worker lambda input, out: op(input, other, out=out), 1256*da0073e9SAndroid Build Coastguard Worker expected_failure=expected_failure, 1257*da0073e9SAndroid Build Coastguard Worker ) 1258*da0073e9SAndroid Build Coastguard Worker 1259*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/126474 1260*da0073e9SAndroid Build Coastguard Worker @xfailIfTorchDynamo 1261*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1262*da0073e9SAndroid Build Coastguard Worker def test_binary_op_mem_overlap(self, device, dtype): 1263*da0073e9SAndroid Build Coastguard Worker ops = [ 1264*da0073e9SAndroid Build Coastguard Worker ("add", True, True, "cpu"), 1265*da0073e9SAndroid Build Coastguard Worker ("add", True, True, "cuda"), 1266*da0073e9SAndroid Build Coastguard Worker ("mul", True, True, "cpu"), 1267*da0073e9SAndroid Build Coastguard Worker ("mul", True, True, "cuda"), 1268*da0073e9SAndroid Build Coastguard Worker ("sub", True, True, "cpu"), 1269*da0073e9SAndroid Build Coastguard Worker ("sub", True, True, "cuda"), 1270*da0073e9SAndroid Build Coastguard Worker ("div", True, True, "cpu"), 1271*da0073e9SAndroid Build Coastguard Worker ("div", True, True, "cuda"), 1272*da0073e9SAndroid Build Coastguard Worker ("pow", True, True, "cpu"), 1273*da0073e9SAndroid Build Coastguard Worker ("pow", True, True, "cuda"), 1274*da0073e9SAndroid Build Coastguard Worker ("fmod", True, True, "cpu"), 1275*da0073e9SAndroid Build Coastguard Worker ("fmod", True, True, "cuda"), 1276*da0073e9SAndroid Build Coastguard Worker ("atan2", True, True, "cpu"), 1277*da0073e9SAndroid Build Coastguard Worker ("atan2", True, True, "cuda"), 1278*da0073e9SAndroid Build Coastguard Worker ("hypot", True, True, "cpu"), 1279*da0073e9SAndroid Build Coastguard Worker ("hypot", True, True, "cuda"), 1280*da0073e9SAndroid Build Coastguard Worker ("igamma", True, True, "cpu"), 1281*da0073e9SAndroid Build Coastguard Worker ("igamma", True, True, 
"cuda"), 1282*da0073e9SAndroid Build Coastguard Worker ("igammac", True, True, "cpu"), 1283*da0073e9SAndroid Build Coastguard Worker ("igammac", True, True, "cuda"), 1284*da0073e9SAndroid Build Coastguard Worker ("nextafter", True, True, "cpu"), 1285*da0073e9SAndroid Build Coastguard Worker ("nextafter", True, True, "cuda"), 1286*da0073e9SAndroid Build Coastguard Worker ("le", True, True, "cpu"), 1287*da0073e9SAndroid Build Coastguard Worker ("le", True, True, "cuda"), 1288*da0073e9SAndroid Build Coastguard Worker ("lt", True, True, "cpu"), 1289*da0073e9SAndroid Build Coastguard Worker ("lt", True, True, "cuda"), 1290*da0073e9SAndroid Build Coastguard Worker ("ge", True, True, "cpu"), 1291*da0073e9SAndroid Build Coastguard Worker ("ge", True, True, "cuda"), 1292*da0073e9SAndroid Build Coastguard Worker ("gt", True, True, "cpu"), 1293*da0073e9SAndroid Build Coastguard Worker ("gt", True, True, "cuda"), 1294*da0073e9SAndroid Build Coastguard Worker ("eq", True, True, "cpu"), 1295*da0073e9SAndroid Build Coastguard Worker ("eq", True, True, "cuda"), 1296*da0073e9SAndroid Build Coastguard Worker ("ne", True, True, "cpu"), 1297*da0073e9SAndroid Build Coastguard Worker ("ne", True, True, "cuda"), 1298*da0073e9SAndroid Build Coastguard Worker ("logical_and", True, True, "cpu"), 1299*da0073e9SAndroid Build Coastguard Worker ("logical_and", True, True, "cuda"), 1300*da0073e9SAndroid Build Coastguard Worker ("logical_or", True, True, "cpu"), 1301*da0073e9SAndroid Build Coastguard Worker ("logical_or", True, True, "cuda"), 1302*da0073e9SAndroid Build Coastguard Worker ("logical_xor", True, True, "cpu"), 1303*da0073e9SAndroid Build Coastguard Worker ("logical_xor", True, True, "cuda"), 1304*da0073e9SAndroid Build Coastguard Worker ] 1305*da0073e9SAndroid Build Coastguard Worker 1306*da0073e9SAndroid Build Coastguard Worker for ( 1307*da0073e9SAndroid Build Coastguard Worker fn, 1308*da0073e9SAndroid Build Coastguard Worker has_input_output_mem_overlap_check, 
            # NOTE(review): continuation of a per-op loop whose
            # `for (fn, has_input_output_mem_overlap_check, ...` header
            # precedes this chunk — left as-is.
            has_internal_mem_overlap_check,
            dev,
        ) in ops:
            if dev != device:
                continue
            out_op = getattr(torch, fn)
            inplace_op = getattr(torch.Tensor, fn + "_")
            self.check_internal_mem_overlap(
                inplace_op,
                2,
                dtype,
                device,
                expected_failure=not has_internal_mem_overlap_check,
            )

            self.binary_check_input_output_mem_overlap(
                out_op, device, expected_failure=not has_input_output_mem_overlap_check
            )

    def _do_pow_for_exponents(self, m1, exponents, pow_fn, atol):
        """Compare torch.pow against a scalar reference `pow_fn` element-wise.

        For each exponent in `exponents`, checks tensor ** scalar on a
        contiguous row and a non-contiguous column of `m1`, and checks
        scalar ** tensor for correct `__rpow__` dtype promotion.
        Negative integer exponents on integer tensors must raise.
        `atol=None` falls back to default tolerances; otherwise rtol is 0.
        """
        for num in exponents:
            if (
                isinstance(num, int)
                and num < 0
                and not m1.is_floating_point()
                and not m1.is_complex()
            ):
                # Integer tensors reject negative integer exponents.
                with self.assertRaisesRegex(
                    RuntimeError,
                    r"Integers to negative integer powers are not allowed\.",
                ):
                    torch.pow(m1[4], num)
            else:
                # base - tensor, exponent - number
                # contiguous
                res1 = torch.pow(m1[4], num)
                res2 = res1.clone().zero_()
                # `math.pow` has issues with complex exponentiation so we need to resort to normal `pow`.
                for i in range(res2.size(0)):
                    res2[i] = pow_fn(m1[4][i], num)
                rtol = 0 if atol is not None else None
                self.assertEqual(res1, res2, atol=atol, rtol=rtol)

                # non-contiguous
                res1 = torch.pow(m1[:, 4], num)
                res2 = res1.clone().zero_()
                for i in range(res2.size(0)):
                    res2[i] = pow_fn(m1[i, 4], num)
                self.assertEqual(res1, res2, atol=atol, rtol=rtol)

                # scalar ** tensor to enforce correct handling of dtypes for __rpow__().
                expected_dtype = torch.result_type(num, m1)
                res1 = num ** m1[4]
                res2 = (
                    torch.tensor(num, dtype=expected_dtype, device=m1.device) ** m1[4]
                )
                self.assertEqual(res1, res2)
                self.assertEqual(res1.dtype, expected_dtype)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_pow(self, device, dtype):
        """Exercise torch.pow with real and complex exponents across dtypes.

        Builds a 100x100 base tensor (offset from 0 for float/complex,
        small positive ints otherwise to keep math.pow from overflowing),
        then delegates to _do_pow_for_exponents and also checks
        number ** tensor on contiguous and non-contiguous slices.
        """
        m1 = torch.empty(0, dtype=dtype, device=device)
        if m1.is_floating_point() or m1.is_complex():
            m1 = (
                make_tensor((100, 100), low=0, high=1, dtype=dtype, device=device) + 0.5
            )
        else:
            # math.pow will overflow and throw exceptions for large integers
            range_high = 4 if dtype in (torch.int8, torch.uint8) else 10
            m1 = make_tensor(
                (100, 100), low=1, high=range_high, dtype=dtype, device=device
            )

        exponents = [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3, True, False]
        complex_exponents = [
            -2.5j,
            -1.0j,
            0j,
            1.0j,
            2.5j,
            1.0 + 1.0j,
            -1.0 - 1.5j,
            3.3j,
        ]
        if m1.is_complex():
            self._do_pow_for_exponents(m1, exponents + complex_exponents, pow, 10e-4)
        else:
            self._do_pow_for_exponents(m1, exponents, math.pow, None)
            will_raise_error = (
                dtype is torch.half and torch.device(device).type == "cpu"
            )
            if will_raise_error:
                # On CPU,
                # Half Tensor with complex exponents leads to computation dtype
                # of ComplexHalf for which this ops is not supported yet
                with self.assertRaisesRegex(
                    RuntimeError, "not implemented for 'ComplexHalf'"
                ):
                    self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)
            else:
                self._do_pow_for_exponents(m1, complex_exponents, pow, 10e-4)

        # base - number, exponent - tensor
        # contiguous
        res1 = torch.pow(3, m1[4])
        res2 = res1.clone().zero_()
        for i in range(res2.size(0)):
            res2[i] = pow(3, m1[4, i])
        self.assertEqual(res1, res2)

        # non-contiguous
        res1 = torch.pow(3, m1[:, 4])
        res2 = res1.clone().zero_()
        for i in range(res2.size(0)):
            res2[i] = pow(3, m1[i][4])
        self.assertEqual(res1, res2)

    # TODO: refactor all these tests using opinfos properly
    def _test_pow(self, base, exponent, np_exponent=None):
        """Check torch.pow(base, exponent) against numpy.power.

        Verifies the out-of-place, in-place (`pow_`), functional, and
        `out=` variants. If numpy raises ValueError for negative integer
        powers of integers, the torch variants must raise RuntimeError
        with the matching message. In-place is additionally checked for
        the cross-device (0-dim cpu base, cuda exponent) error and for
        the can_cast dtype restriction.
        """
        if np_exponent is None:
            np_exponent = exponent

        def to_np(value):
            # numpy can't consume torch tensors directly from GPU.
            if isinstance(value, torch.Tensor):
                return value.cpu().numpy()
            return value

        try:
            np_res = np.power(to_np(base), to_np(np_exponent))
            expected = (
                torch.from_numpy(np_res)
                if isinstance(np_res, np.ndarray)
                else torch.tensor(np_res, dtype=base.dtype)
            )
        except ValueError as e:
            err_msg = "Integers to negative integer powers are not allowed."
            self.assertEqual(str(e), err_msg)
            out = torch.empty_like(base)
            test_cases = [
                lambda: base.pow(exponent),
                lambda: base.pow_(exponent),
                lambda: torch.pow(base, exponent),
                lambda: torch.pow(base, exponent, out=out),
            ]
            for test_case in test_cases:
                self.assertRaisesRegex(RuntimeError, err_msg, test_case)
        else:
            if isinstance(base, torch.Tensor):
                actual = base.pow(exponent)
                self.assertEqual(actual, expected.to(actual))
                actual = base.clone()
                # When base is a 0-dim cpu tensor and exp is a cuda tensor, we exp `pow` to work but `pow_` to fail, since
                # `pow` will try to create the output tensor on a cuda device, but `pow_` needs to use the cpu tensor as the output
                if (
                    isinstance(exponent, torch.Tensor)
                    and base.dim() == 0
                    and base.device.type == "cpu"
                    and exponent.device.type == "cuda"
                ):
                    regex = "Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!"
                    self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent)
                elif torch.can_cast(torch.result_type(base, exponent), base.dtype):
                    actual2 = actual.pow_(exponent)
                    self.assertEqual(actual, expected)
                    self.assertEqual(actual2, expected)
                else:
                    self.assertRaisesRegex(
                        RuntimeError,
                        "Found dtype \\w+ but expected \\w+",
                        lambda: actual.pow_(exponent),
                    )

            actual = torch.pow(base, exponent)
            self.assertEqual(actual, expected.to(actual))

            actual2 = torch.pow(base, exponent, out=actual)
            self.assertEqual(actual, expected.to(actual))
            self.assertEqual(actual2, expected.to(actual))

    # We can potentially merge this into OpInfo, but one blocker is that the
    # first input must be a scalar. It is not as simple as just wrapping this in
    # a lambda that switches the inputs, because we also want to test samples inputs
    # where the second input is a scalar. The wrapper would need some more logic.
    def test_pow_scalar_base(self, device):
        """Gradcheck scalar ** tensor (exercises the scalar-base overload)."""
        a = (
            torch.arange(1, 13, dtype=torch.double, device=device)
            .view(3, 4)
            .requires_grad_()
        )
        gradcheck(lambda a: torch.pow(2, a), (a,))

    # Tests pow() for integral, floating-type tensors, with integral, floating-type
    # exponents (tensor or scalar), respectively. noncontiguous tensors are also tested.
    def test_int_and_float_pow(self, device):
        """Run _test_pow over broadcasting shape pairs for int and float dtypes.

        For each (base_shape, scalar exponent, exp_shape) case, checks both
        a scalar exponent and a tensor exponent, with contiguous and
        noncontiguous base/exponent tensors. Integer dtypes get non-negative
        tensor exponents only (negative int powers of ints raise).
        """
        def _test_int_and_float_pow(dt, low, high, dev):
            test_cases = (
                ((4, 4), 0, (4, 1)),
                ((3, 1), 4, (3, 1)),
                ((2,), 4, (1,)),
                ((1,), 2, ()),
                ((513, 513), 4, (513,)),
                ((5, 5, 5), 5, (5,)),
                ((), 2, ()),
            )
            for base_shape, exp_scalar, exp_shape in test_cases:
                base_tensor = make_tensor(
                    base_shape, dtype=dt, device=dev, low=low, high=high
                )
                # int tensors don't take negative exponents
                if dt in [
                    torch.uint8,
                    torch.int8,
                    torch.int16,
                    torch.int32,
                    torch.int64,
                ]:
                    exp_tensor = make_tensor(
                        exp_shape, dtype=dt, device=dev, low=0, high=high
                    )
                else:
                    exp_tensor = make_tensor(
                        exp_shape, dtype=dt, device=dev, low=low, high=high
                    )
                self._test_pow(base_tensor, exp_scalar)
                self._test_pow(base_tensor, exp_tensor)
                # test non-contiguous tensors as well
                base_tensor = make_tensor(
                    base_shape,
                    dtype=dt,
                    device=dev,
                    low=low,
                    high=high,
                    noncontiguous=True,
                )
                if dt in [
                    torch.uint8,
                    torch.int8,
                    torch.int16,
                    torch.int32,
                    torch.int64,
                ]:
                    exp_tensor = make_tensor(
                        exp_shape,
                        dtype=dt,
                        device=dev,
                        low=0,
                        high=high,
                        noncontiguous=True,
                    )
                else:
                    exp_tensor = make_tensor(
                        exp_shape,
                        dtype=dt,
                        device=dev,
                        low=low,
                        high=high,
                        noncontiguous=True,
                    )
                self._test_pow(base_tensor, exp_scalar)
                self._test_pow(base_tensor, exp_tensor)

        _test_int_and_float_pow(torch.int8, -2, 2, device)
        _test_int_and_float_pow(torch.uint8, 0, 3, device)
        _test_int_and_float_pow(torch.int16, -5, 5, device)
        _test_int_and_float_pow(torch.int64, -10, 10, device)
        _test_int_and_float_pow(torch.int32, -10, 10, device)
        _test_int_and_float_pow(torch.float16, 0.0, 5.0, device)
        _test_int_and_float_pow(torch.float32, 0.0, 10.0, device)
        _test_int_and_float_pow(torch.float64, 0.0, 10.0, device)
        # pow's output would have some NaNs as well
        _test_int_and_float_pow(torch.float32, -10.0, 10.0, device)
        _test_int_and_float_pow(torch.float64, -10.0, 10.0, device)

    # Tests that a Runtime error occurs when a base tensor cannot be resized
    # by pow's inplace variant due to PyTorch's broadcasting semantics.
1584*da0073e9SAndroid Build Coastguard Worker def test_pow_inplace_resizing_exception(self, device): 1585*da0073e9SAndroid Build Coastguard Worker test_cases = ( 1586*da0073e9SAndroid Build Coastguard Worker ((), (3,)), 1587*da0073e9SAndroid Build Coastguard Worker ((2,), (2, 1)), 1588*da0073e9SAndroid Build Coastguard Worker ((2, 1), (2, 2)), 1589*da0073e9SAndroid Build Coastguard Worker ((2, 2), (2, 1, 1)), 1590*da0073e9SAndroid Build Coastguard Worker ) 1591*da0073e9SAndroid Build Coastguard Worker test_inputs = [ 1592*da0073e9SAndroid Build Coastguard Worker ( 1593*da0073e9SAndroid Build Coastguard Worker make_tensor( 1594*da0073e9SAndroid Build Coastguard Worker base_size, dtype=torch.float64, device=device, high=10.0, low=0.0 1595*da0073e9SAndroid Build Coastguard Worker ), 1596*da0073e9SAndroid Build Coastguard Worker make_tensor( 1597*da0073e9SAndroid Build Coastguard Worker exp_size, dtype=torch.float64, device=device, high=10.0, low=0.0 1598*da0073e9SAndroid Build Coastguard Worker ), 1599*da0073e9SAndroid Build Coastguard Worker ) 1600*da0073e9SAndroid Build Coastguard Worker for base_size, exp_size in test_cases 1601*da0073e9SAndroid Build Coastguard Worker ] 1602*da0073e9SAndroid Build Coastguard Worker for base, exponent in test_inputs: 1603*da0073e9SAndroid Build Coastguard Worker regex = "doesn't match the broadcast shape" 1604*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent) 1605*da0073e9SAndroid Build Coastguard Worker 1606*da0073e9SAndroid Build Coastguard Worker def test_int_tensor_pow_neg_ints(self, device): 1607*da0073e9SAndroid Build Coastguard Worker ints = [ 1608*da0073e9SAndroid Build Coastguard Worker torch.iinfo(torch.int32).min, 1609*da0073e9SAndroid Build Coastguard Worker -3, 1610*da0073e9SAndroid Build Coastguard Worker -2, 1611*da0073e9SAndroid Build Coastguard Worker -1, 1612*da0073e9SAndroid Build Coastguard Worker 0, 1613*da0073e9SAndroid Build Coastguard Worker 1, 
1614*da0073e9SAndroid Build Coastguard Worker 2, 1615*da0073e9SAndroid Build Coastguard Worker 3, 1616*da0073e9SAndroid Build Coastguard Worker torch.iinfo(torch.int32).max, 1617*da0073e9SAndroid Build Coastguard Worker ] 1618*da0073e9SAndroid Build Coastguard Worker neg_ints = [torch.iinfo(torch.int32).min, -3, -2, -1] 1619*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor(ints, dtype=torch.int32, device=device) 1620*da0073e9SAndroid Build Coastguard Worker for pow in neg_ints: 1621*da0073e9SAndroid Build Coastguard Worker self._test_pow(tensor, pow) 1622*da0073e9SAndroid Build Coastguard Worker 1623*da0073e9SAndroid Build Coastguard Worker def test_long_tensor_pow_floats(self, device): 1624*da0073e9SAndroid Build Coastguard Worker ints = [0, 1, 23, 4567] 1625*da0073e9SAndroid Build Coastguard Worker floats = [0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] 1626*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor(ints, dtype=torch.int64, device=device) 1627*da0073e9SAndroid Build Coastguard Worker for pow in floats: 1628*da0073e9SAndroid Build Coastguard Worker self._test_pow(tensor, pow) 1629*da0073e9SAndroid Build Coastguard Worker 1630*da0073e9SAndroid Build Coastguard Worker @dtypes(*[torch.float32, torch.float64]) 1631*da0073e9SAndroid Build Coastguard Worker def test_float_scalar_pow_float_tensor(self, device, dtype): 1632*da0073e9SAndroid Build Coastguard Worker floats = [2.0, -3 / 2, -1.0, -1 / 2, -1 / 3, 0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] 1633*da0073e9SAndroid Build Coastguard Worker exponent_shapes = ( 1634*da0073e9SAndroid Build Coastguard Worker (1,), 1635*da0073e9SAndroid Build Coastguard Worker (2, 2), 1636*da0073e9SAndroid Build Coastguard Worker (2, 1), 1637*da0073e9SAndroid Build Coastguard Worker (2, 2, 2), 1638*da0073e9SAndroid Build Coastguard Worker ) 1639*da0073e9SAndroid Build Coastguard Worker tensors = [ 1640*da0073e9SAndroid Build Coastguard Worker make_tensor(shape, dtype=dtype, device=device, low=0) 1641*da0073e9SAndroid 
Build Coastguard Worker for shape in exponent_shapes 1642*da0073e9SAndroid Build Coastguard Worker ] 1643*da0073e9SAndroid Build Coastguard Worker floats_tensor = torch.tensor(floats, dtype=dtype, device=device) 1644*da0073e9SAndroid Build Coastguard Worker for base in floats: 1645*da0073e9SAndroid Build Coastguard Worker self._test_pow(base, floats_tensor) 1646*da0073e9SAndroid Build Coastguard Worker for tensor in tensors: 1647*da0073e9SAndroid Build Coastguard Worker self._test_pow(base, tensor) 1648*da0073e9SAndroid Build Coastguard Worker 1649*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1650*da0073e9SAndroid Build Coastguard Worker def test_cuda_tensor_pow_scalar_tensor(self, device): 1651*da0073e9SAndroid Build Coastguard Worker cuda_tensors = [ 1652*da0073e9SAndroid Build Coastguard Worker torch.randn((3, 3), device=device), 1653*da0073e9SAndroid Build Coastguard Worker torch.tensor(3.0, device=device), 1654*da0073e9SAndroid Build Coastguard Worker ] 1655*da0073e9SAndroid Build Coastguard Worker scalar_tensors = [ 1656*da0073e9SAndroid Build Coastguard Worker torch.tensor(5.0, device="cpu"), 1657*da0073e9SAndroid Build Coastguard Worker torch.tensor(-3), 1658*da0073e9SAndroid Build Coastguard Worker torch.tensor(1), 1659*da0073e9SAndroid Build Coastguard Worker ] 1660*da0073e9SAndroid Build Coastguard Worker for base, exp in product(cuda_tensors, scalar_tensors): 1661*da0073e9SAndroid Build Coastguard Worker self._test_pow(base, exp) 1662*da0073e9SAndroid Build Coastguard Worker 1663*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1664*da0073e9SAndroid Build Coastguard Worker def test_cpu_tensor_pow_cuda_scalar_tensor(self, device): 1665*da0073e9SAndroid Build Coastguard Worker cuda_tensors = [ 1666*da0073e9SAndroid Build Coastguard Worker torch.tensor(5.0, device="cuda"), 1667*da0073e9SAndroid Build Coastguard Worker torch.tensor(-3, device="cuda"), 1668*da0073e9SAndroid Build Coastguard Worker ] 1669*da0073e9SAndroid Build Coastguard Worker for 
exp in cuda_tensors: 1670*da0073e9SAndroid Build Coastguard Worker base = torch.randn((3, 3), device="cpu") 1671*da0073e9SAndroid Build Coastguard Worker regex = "Expected all tensors to be on the same device, but found at least two devices, cuda.* and cpu!" 1672*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, regex, torch.pow, base, exp) 1673*da0073e9SAndroid Build Coastguard Worker for exp in cuda_tensors: 1674*da0073e9SAndroid Build Coastguard Worker # Binary ops with a cpu + cuda tensor are allowed if the cpu tensor has 0 dimension 1675*da0073e9SAndroid Build Coastguard Worker base = torch.tensor(3.0, device="cpu") 1676*da0073e9SAndroid Build Coastguard Worker self._test_pow(base, exp) 1677*da0073e9SAndroid Build Coastguard Worker 1678*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1679*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.complex64, torch.complex128) 1680*da0073e9SAndroid Build Coastguard Worker def test_pow_cuda_complex_extremal_failing(self, device, dtype): 1681*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(complex(-1.0, float("inf")), dtype=dtype, device=device) 1682*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 1683*da0073e9SAndroid Build Coastguard Worker cuda_out = t.pow(2) 1684*da0073e9SAndroid Build Coastguard Worker cpu_out = t.cpu().pow(2) 1685*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cpu_out, cuda_out) 1686*da0073e9SAndroid Build Coastguard Worker 1687*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo() 1688*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1689*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half)) 1690*da0073e9SAndroid Build Coastguard Worker def test_complex_scalar_pow_tensor(self, device, dtype): 1691*da0073e9SAndroid Build Coastguard Worker complexes = [0.5j, 1.0 + 1.0j, -1.5j, 2.2 - 1.6j, 1 + 0j] 1692*da0073e9SAndroid Build Coastguard Worker first_exp = 
make_tensor((100,), dtype=dtype, device=device, low=-2, high=2) 1693*da0073e9SAndroid Build Coastguard Worker second_exp = make_tensor( 1694*da0073e9SAndroid Build Coastguard Worker (100,), dtype=dtype, device=device, low=-2, high=2, noncontiguous=True 1695*da0073e9SAndroid Build Coastguard Worker ) 1696*da0073e9SAndroid Build Coastguard Worker first_exp[0] = first_exp[10] = first_exp[20] = 0 1697*da0073e9SAndroid Build Coastguard Worker second_exp[0] = second_exp[10] = second_exp[20] = 0 1698*da0073e9SAndroid Build Coastguard Worker for base in complexes: 1699*da0073e9SAndroid Build Coastguard Worker # On CPU, 1700*da0073e9SAndroid Build Coastguard Worker # Half Tensor with complex base leads to computation dtype 1701*da0073e9SAndroid Build Coastguard Worker # of ComplexHalf for which this ops is not supported yet 1702*da0073e9SAndroid Build Coastguard Worker # NOTE: pow has fast-path when base is 1 which supports 1703*da0073e9SAndroid Build Coastguard Worker # ComplexHalf 1704*da0073e9SAndroid Build Coastguard Worker will_raise_error = ( 1705*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == "cpu" 1706*da0073e9SAndroid Build Coastguard Worker and dtype is torch.half 1707*da0073e9SAndroid Build Coastguard Worker and base != (1 + 0j) 1708*da0073e9SAndroid Build Coastguard Worker ) 1709*da0073e9SAndroid Build Coastguard Worker if will_raise_error: 1710*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1711*da0073e9SAndroid Build Coastguard Worker RuntimeError, "not implemented for 'ComplexHalf'" 1712*da0073e9SAndroid Build Coastguard Worker ): 1713*da0073e9SAndroid Build Coastguard Worker self._test_pow(base, first_exp) 1714*da0073e9SAndroid Build Coastguard Worker self._test_pow(base, second_exp) 1715*da0073e9SAndroid Build Coastguard Worker else: 1716*da0073e9SAndroid Build Coastguard Worker self._test_pow(base, first_exp) 1717*da0073e9SAndroid Build Coastguard Worker self._test_pow(base, second_exp) 1718*da0073e9SAndroid 
    @onlyNativeDeviceTypes
    @skipMeta
    def test_pow_scalar_type_promotion(self, device):
        """With an explicit int64 `out=`, 2 ** uint8 computes in uint8
        (wrapping to 0 for exponent 17) before the cast, while 2 ** int64
        does not overflow — so the raw results differ but agree mod 2**8.
        """
        # Test against a scalar and non-scalar input
        inputs = [17, [17]]
        for input in inputs:
            # We expect the computation to be performed in uint8 (overflowing to 0), and then cast to int64
            input_tensor_uint8 = torch.tensor(input, dtype=torch.uint8, device=device)
            out_uint8_computation = torch.pow(
                2,
                input_tensor_uint8,
                out=torch.tensor(0, dtype=torch.int64, device=device),
            )

            # Computation should run in int64, and not overflow
            input_tensor_int64 = torch.tensor(input, dtype=torch.int64, device=device)
            out_int64_computation = torch.pow(
                2,
                input_tensor_int64,
                out=torch.tensor(0, dtype=torch.int64, device=device),
            )

            self.assertNotEqual(out_uint8_computation, out_int64_computation)
            self.assertEqual(
                out_uint8_computation.to(dtype=torch.uint8),
                out_int64_computation.to(dtype=torch.uint8),
            )

    def test_tensor_pow_tensor(self, device):
        """tensor ** tensor over rotations of a value list, per dtype,
        compared against numpy via _test_pow."""
        def rotate(l, n):
            # Rotate list right by n positions.
            return l[-n:] + l[:-n]

        def test_tensor_pow_tensor(values, torch_type, numpy_type):
            vals_tensor = torch.tensor(values, dtype=torch_type, device=device)
            for i in range(len(values)):
                pows = rotate(values, i)
                pows_tensor = torch.tensor(pows, dtype=torch_type, device=device)
                self._test_pow(vals_tensor, pows_tensor)

        ints = [0, 1, 2, 3]
        test_tensor_pow_tensor(ints, torch.uint8, np.uint8)
        test_tensor_pow_tensor(ints, torch.int8, np.int8)
        test_tensor_pow_tensor(ints, torch.int16, np.int16)
        test_tensor_pow_tensor(ints, torch.int32, np.int32)
        test_tensor_pow_tensor(ints, torch.int64, np.int64)

        floats = [-3.0, -2.0, -1.0, -1 / 2, -1 / 3, 0.0, 1 / 3, 1 / 2, 1.0, 2.0, 3.0]
        test_tensor_pow_tensor(floats, torch.float16, np.float16)
        test_tensor_pow_tensor(floats, torch.float32, np.float32)
        test_tensor_pow_tensor(floats, torch.float64, np.float64)

    def test_logical_xor_with_nontrivial_alignment(self, device):
        """logical_xor on slices whose start offsets break 16-byte alignment."""
        # test tensor that is not aligned to multiple of 16 bytes
        size = 128
        a = torch.randn(size, device=device) > 0
        b = torch.randn(size, device=device) > 0
        c = torch.randn(size, device=device) > 0
        non_trivial_alignment = [1, 2, 4, 8, 15]
        for i in non_trivial_alignment:
            for j in non_trivial_alignment:
                for k in non_trivial_alignment:
                    a_ = a[i : 100 + i]
                    b_ = b[j : 100 + j]
                    c_ = c[k : 100 + k]
                    torch.logical_xor(a_, b_, out=c_)
                    for x, y, z in zip(a_.tolist(), b_.tolist(), c_.tolist()):
                        self.assertEqual(x ^ y, z)

    @dtypes(torch.float)
    def test_add_with_tail(self, device, dtype):
        """Elementwise add on sizes with a tail past a multiple of 4096."""
        # test tensor where there is a tail which is not a multiple
        # of GPU warp size
        for tail_size in [1, 63, 67, 130]:
            size = 4096 + tail_size
            a = torch.randn(size, device=device, dtype=dtype)
            b = torch.randn(size, device=device, dtype=dtype)
            c = a + b
            for x, y, z in zip(a.tolist(), b.tolist(), c.tolist()):
                self.assertEqual(x + y, z)

    # Tests that CUDA tensors on different devices cannot be used in the same
    # binary operation, and that CUDA "scalars" cannot be used in the same
    # binary operation as non-scalar CPU tensors.
    @deviceCountAtLeast(2)
    @onlyCUDA
    def test_cross_device_binary_ops(self, devices):
        # NOTE(review): this method continues past the end of this chunk.
        vals = (1.0, (2.0,))
        cpu_tensor = torch.randn(2, 2)

        def do_test(op, a, b):
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
                op(a, b)
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
                op(b, a)
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
                op(a, cpu_tensor)
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors.+"):
                op(cpu_tensor, a)
Build Coastguard Worker 1818*da0073e9SAndroid Build Coastguard Worker for op in ( 1819*da0073e9SAndroid Build Coastguard Worker operator.add, 1820*da0073e9SAndroid Build Coastguard Worker torch.add, 1821*da0073e9SAndroid Build Coastguard Worker operator.sub, 1822*da0073e9SAndroid Build Coastguard Worker torch.sub, 1823*da0073e9SAndroid Build Coastguard Worker operator.mul, 1824*da0073e9SAndroid Build Coastguard Worker torch.mul, 1825*da0073e9SAndroid Build Coastguard Worker operator.truediv, 1826*da0073e9SAndroid Build Coastguard Worker torch.true_divide, 1827*da0073e9SAndroid Build Coastguard Worker operator.floordiv, 1828*da0073e9SAndroid Build Coastguard Worker torch.floor_divide, 1829*da0073e9SAndroid Build Coastguard Worker ): 1830*da0073e9SAndroid Build Coastguard Worker for a, b in product(vals, vals): 1831*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(a, device=devices[0]) 1832*da0073e9SAndroid Build Coastguard Worker b = torch.tensor(b, device=devices[1]) 1833*da0073e9SAndroid Build Coastguard Worker 1834*da0073e9SAndroid Build Coastguard Worker do_test(op, a, b) 1835*da0073e9SAndroid Build Coastguard Worker 1836*da0073e9SAndroid Build Coastguard Worker # This test ensures that a scalar Tensor can be safely used 1837*da0073e9SAndroid Build Coastguard Worker # in a binary operation in conjunction with a Tensor on all 1838*da0073e9SAndroid Build Coastguard Worker # available CUDA devices 1839*da0073e9SAndroid Build Coastguard Worker @deviceCountAtLeast(2) 1840*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1841*da0073e9SAndroid Build Coastguard Worker def test_binary_op_scalar_device_unspecified(self, devices): 1842*da0073e9SAndroid Build Coastguard Worker scalar_val = torch.tensor(1.0) 1843*da0073e9SAndroid Build Coastguard Worker for default_device in devices: 1844*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(default_device): 1845*da0073e9SAndroid Build Coastguard Worker for device in devices: 1846*da0073e9SAndroid Build 
Coastguard Worker device_obj = torch.device(device) 1847*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, device=device) 1848*da0073e9SAndroid Build Coastguard Worker y0 = x * scalar_val 1849*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y0.device, device_obj) 1850*da0073e9SAndroid Build Coastguard Worker y1 = scalar_val * x 1851*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y1.device, device_obj) 1852*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y0, y1) 1853*da0073e9SAndroid Build Coastguard Worker 1854*da0073e9SAndroid Build Coastguard Worker def test_div_and_floordiv_vs_python(self, device): 1855*da0073e9SAndroid Build Coastguard Worker # Tests torch division ops which can handle both arguments being 1856*da0073e9SAndroid Build Coastguard Worker # scalars. 1857*da0073e9SAndroid Build Coastguard Worker def _scalar_helper(python_op, torch_op): 1858*da0073e9SAndroid Build Coastguard Worker for a, b in product(range(-10, 10), range(-10, 10)): 1859*da0073e9SAndroid Build Coastguard Worker for op in (lambda x: x * 0.5, lambda x: math.floor(x)): 1860*da0073e9SAndroid Build Coastguard Worker a = op(a) 1861*da0073e9SAndroid Build Coastguard Worker b = op(b) 1862*da0073e9SAndroid Build Coastguard Worker 1863*da0073e9SAndroid Build Coastguard Worker # Skips zero divisors 1864*da0073e9SAndroid Build Coastguard Worker if b == 0: 1865*da0073e9SAndroid Build Coastguard Worker continue 1866*da0073e9SAndroid Build Coastguard Worker 1867*da0073e9SAndroid Build Coastguard Worker expected = python_op(a, b) 1868*da0073e9SAndroid Build Coastguard Worker 1869*da0073e9SAndroid Build Coastguard Worker for op in (operator.truediv, torch.true_divide): 1870*da0073e9SAndroid Build Coastguard Worker actual_scalar = torch_op(a, b) 1871*da0073e9SAndroid Build Coastguard Worker 1872*da0073e9SAndroid Build Coastguard Worker a_t = torch.tensor(a, device=device) 1873*da0073e9SAndroid Build Coastguard Worker b_t = torch.tensor(b, device=device) 
1874*da0073e9SAndroid Build Coastguard Worker 1875*da0073e9SAndroid Build Coastguard Worker actual_tensor = torch_op(a_t, b_t) 1876*da0073e9SAndroid Build Coastguard Worker actual_first_tensor = torch_op(a_t, b) 1877*da0073e9SAndroid Build Coastguard Worker actual_second_tensor = torch_op(a, b_t) 1878*da0073e9SAndroid Build Coastguard Worker 1879*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_scalar, expected) 1880*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_tensor.item(), expected) 1881*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_first_tensor, actual_tensor) 1882*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_second_tensor, actual_tensor) 1883*da0073e9SAndroid Build Coastguard Worker 1884*da0073e9SAndroid Build Coastguard Worker _scalar_helper(operator.truediv, operator.truediv) 1885*da0073e9SAndroid Build Coastguard Worker _scalar_helper(operator.truediv, torch.true_divide) 1886*da0073e9SAndroid Build Coastguard Worker _scalar_helper(lambda a, b: math.floor(a / b), operator.floordiv) 1887*da0073e9SAndroid Build Coastguard Worker _scalar_helper(lambda a, b: math.floor(a / b), torch.floor_divide) 1888*da0073e9SAndroid Build Coastguard Worker 1889*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1890*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 1891*da0073e9SAndroid Build Coastguard Worker def test_div_and_floordiv_script_vs_python(self, device): 1892*da0073e9SAndroid Build Coastguard Worker # Creates jitted functions of two tensors 1893*da0073e9SAndroid Build Coastguard Worker def _wrapped_div(a, b): 1894*da0073e9SAndroid Build Coastguard Worker return a / b 1895*da0073e9SAndroid Build Coastguard Worker 1896*da0073e9SAndroid Build Coastguard Worker def _wrapped_floordiv(a, b): 1897*da0073e9SAndroid Build Coastguard Worker return a // b 1898*da0073e9SAndroid Build Coastguard Worker 1899*da0073e9SAndroid Build Coastguard 
    @onlyNativeDeviceTypes
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_div_and_floordiv_script_vs_python(self, device):
        """Check that TorchScript-compiled division and floor division agree
        with the corresponding eager Python results, for tensor/tensor and
        tensor/scalar (both orders) forms."""

        # Creates jitted functions of two tensors
        def _wrapped_div(a, b):
            return a / b

        def _wrapped_floordiv(a, b):
            return a // b

        scripted_div = torch.jit.script(_wrapped_div)
        scripted_floordiv = torch.jit.script(_wrapped_floordiv)
        for a, b in product(range(-10, 10), range(-10, 10)):
            # The transforms apply cumulatively: the second iteration floors
            # the already-halved values from the first.
            for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
                a = op(a)
                b = op(b)

                # Skips zero divisors
                if b == 0:
                    continue

                expected_div = a / b
                expected_floordiv = math.floor(a / b)
                a_t = torch.tensor(a, device=device)
                b_t = torch.tensor(b, device=device)

                self.assertEqual(scripted_div(a_t, b_t), expected_div)
                self.assertEqual(scripted_floordiv(a_t, b_t), expected_floordiv)

        # Creates jitted functions of one tensor
        def _wrapped_div_scalar(a):
            return a / 5

        # NOTE: the JIT implements division as torch.reciprocal(a) * 5
        def _wrapped_rdiv_scalar(a):
            return 5 / a

        def _wrapped_floordiv_scalar(a):
            return a // 5

        # NOTE: this fails if the input is not an integer tensor
        # See https://github.com/pytorch/pytorch/issues/45199
        def _wrapped_rfloordiv_scalar(a):
            return 5 // a

        scripted_div_scalar = torch.jit.script(_wrapped_div_scalar)
        scripted_rdiv_scalar = torch.jit.script(_wrapped_rdiv_scalar)
        scripted_floordiv_scalar = torch.jit.script(_wrapped_floordiv_scalar)
        scripted_rfloordiv_scalar = torch.jit.script(_wrapped_rfloordiv_scalar)

        for a in range(-10, 10):
            for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
                a = op(a)

                a_t = torch.tensor(a, device=device)

                self.assertEqual(a / 5, scripted_div_scalar(a_t))

                # Skips zero divisors
                if a == 0:
                    continue

                self.assertEqual(5 / a, scripted_rdiv_scalar(a_t))

                # Handles Issue 45199 (see comment above)
                if a_t.is_floating_point():
                    # Scripted `5 // a` on a float tensor raises (gh-45199).
                    with self.assertRaises(RuntimeError):
                        scripted_rfloordiv_scalar(a_t)
                else:
                    # This should emit a UserWarning, why doesn't it?
                    # See issue gh-52387
                    self.assertEqual(5 // a, scripted_rfloordiv_scalar(a_t))
    @onlyNativeDeviceTypes
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_idiv_and_ifloordiv_vs_python(self, device):
        """Check in-place division (`/=`, `true_divide_`) and in-place floor
        division (`//=`, `floor_divide_`) — eager and TorchScript — against
        Python results, including the error cases where an integer tensor
        cannot hold a floating-point result in place."""

        def _wrapped_idiv_tensor(a, b):
            a /= b
            return a

        def _wrapped_idiv_scalar(a):
            a /= 5
            return a

        def _wrapped_true_divide__tensor(a, b):
            a.true_divide_(b)
            return a

        def _wrapped_true_divide__scalar(a):
            a.true_divide_(5)
            return a

        def _wrapped_floor_divide__tensor(a, b):
            a.floor_divide_(b)
            return a

        def _wrapped_floor_divide__scalar(a):
            a.floor_divide_(5)
            return a

        # The following functions are unsupported by the JIT
        def _wrapped_ifloordiv_tensor(a, b):
            a //= b
            return a

        def _wrapped_ifloordiv_scalar(a):
            a //= 5
            return a

        # `//=` is rejected at script-compilation time, not at run time.
        with self.assertRaises(torch.jit.frontend.NotSupportedError):
            scripted_ifloordiv_tensor = torch.jit.script(_wrapped_ifloordiv_tensor)

        with self.assertRaises(torch.jit.frontend.NotSupportedError):
            scripted_ifloordiv_scalar = torch.jit.script(_wrapped_ifloordiv_scalar)

        scripted_idiv_tensor = torch.jit.script(_wrapped_idiv_tensor)
        scripted_idiv_scalar = torch.jit.script(_wrapped_idiv_scalar)
        scripted_true_divide__tensor = torch.jit.script(_wrapped_true_divide__tensor)
        scripted_true_divide__scalar = torch.jit.script(_wrapped_true_divide__scalar)
        scripted_floor_divide__tensor = torch.jit.script(_wrapped_floor_divide__tensor)
        scripted_floor_divide__scalar = torch.jit.script(_wrapped_floor_divide__scalar)

        for a, b in product(range(-10, 10), range(-10, 10)):
            # The transforms apply cumulatively: the second iteration floors
            # the already-halved values from the first.
            for op in (lambda x: x * 0.5, lambda x: math.floor(x)):
                a = op(a)
                b = op(b)

                # Skips zero divisors
                if b == 0:
                    continue

                expected_idiv = a / b
                expected_ifloordiv = a // b

                a_t = torch.tensor(a, device=device)
                b_t = torch.tensor(b, device=device)

                if a_t.is_floating_point():
                    # Float tensors can hold a true-division result in place.
                    tmp0 = a_t.clone()
                    tmp0 /= b

                    tmp1 = a_t.clone()
                    tmp1 /= b_t

                    self.assertEqual(tmp0.item(), expected_idiv)
                    self.assertEqual(tmp1.item(), expected_idiv)
                    self.assertEqual(
                        scripted_true_divide__tensor(a_t.clone(), b_t).item(),
                        expected_idiv,
                    )
                    self.assertEqual(
                        scripted_true_divide__scalar(a_t.clone()).item(), a / 5
                    )
                else:
                    # Integer tensors reject in-place true division: the float
                    # result cannot be stored back into an integer tensor.
                    tmp = a_t.clone()
                    with self.assertRaises(RuntimeError):
                        tmp /= b
                    with self.assertRaises(RuntimeError):
                        tmp /= b_t
                    with self.assertRaises(RuntimeError):
                        scripted_true_divide__tensor(tmp, b_t)
                    with self.assertRaises(RuntimeError):
                        scripted_true_divide__scalar(tmp)

                if not a_t.is_floating_point() and b_t.is_floating_point():
                    # Inplace modification fails because a float tensor is required
                    # if the divisor is a float tensor
                    # NOTE(review): despite the comment above, no exception is
                    # asserted here -- these calls are merely executed.
                    # Confirm whether assertRaises was intended.
                    a_t.clone().floor_divide_(b_t)
                    scripted_floor_divide__tensor(a_t.clone(), b_t)
                    tmp = a_t.clone()
                    tmp //= b_t
                else:
                    # Inplace modification is OK when both or neither tensor is
                    # a float tensor
                    self.assertEqual(
                        a_t.clone().floor_divide_(b_t).item(), expected_ifloordiv
                    )
                    self.assertEqual(
                        scripted_floor_divide__tensor(a_t.clone(), b_t).item(),
                        expected_ifloordiv,
                    )
                    tmp = a_t.clone()
                    tmp //= b_t
                    self.assertEqual(tmp.item(), expected_ifloordiv)

                self.assertEqual(scripted_floor_divide__scalar(a_t), math.floor(a / 5))
Build Coastguard Worker a_t.clone().floor_divide_(b_t).item(), expected_ifloordiv 2065*da0073e9SAndroid Build Coastguard Worker ) 2066*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2067*da0073e9SAndroid Build Coastguard Worker scripted_floor_divide__tensor(a_t.clone(), b_t).item(), 2068*da0073e9SAndroid Build Coastguard Worker expected_ifloordiv, 2069*da0073e9SAndroid Build Coastguard Worker ) 2070*da0073e9SAndroid Build Coastguard Worker tmp = a_t.clone() 2071*da0073e9SAndroid Build Coastguard Worker tmp //= b_t 2072*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tmp.item(), expected_ifloordiv) 2073*da0073e9SAndroid Build Coastguard Worker 2074*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_floor_divide__scalar(a_t), math.floor(a / 5)) 2075*da0073e9SAndroid Build Coastguard Worker 2076*da0073e9SAndroid Build Coastguard Worker # Tests binary op equivalence with Python builtin ops 2077*da0073e9SAndroid Build Coastguard Worker # Also tests that reverse operations are equivalent to forward ops 2078*da0073e9SAndroid Build Coastguard Worker # NOTE: division ops are tested separately above 2079*da0073e9SAndroid Build Coastguard Worker def test_binary_ops_with_scalars(self, device): 2080*da0073e9SAndroid Build Coastguard Worker for python_op, torch_op in ( 2081*da0073e9SAndroid Build Coastguard Worker (operator.add, torch.add), 2082*da0073e9SAndroid Build Coastguard Worker (operator.sub, torch.sub), 2083*da0073e9SAndroid Build Coastguard Worker (operator.mul, torch.mul), 2084*da0073e9SAndroid Build Coastguard Worker (operator.truediv, torch.div), 2085*da0073e9SAndroid Build Coastguard Worker ): 2086*da0073e9SAndroid Build Coastguard Worker for a, b in product(range(-10, 10), range(-10, 10)): 2087*da0073e9SAndroid Build Coastguard Worker for op in (lambda x: x * 0.5, lambda x: math.floor(x)): 2088*da0073e9SAndroid Build Coastguard Worker a = op(a) 2089*da0073e9SAndroid Build Coastguard Worker b = op(b) 2090*da0073e9SAndroid Build 
Coastguard Worker 2091*da0073e9SAndroid Build Coastguard Worker # Skips zero divisors 2092*da0073e9SAndroid Build Coastguard Worker if b == 0 or a == 0: 2093*da0073e9SAndroid Build Coastguard Worker continue 2094*da0073e9SAndroid Build Coastguard Worker 2095*da0073e9SAndroid Build Coastguard Worker a_tensor = torch.tensor(a, device=device) 2096*da0073e9SAndroid Build Coastguard Worker b_tensor = torch.tensor(b, device=device) 2097*da0073e9SAndroid Build Coastguard Worker a_tensor_cpu = a_tensor.cpu() 2098*da0073e9SAndroid Build Coastguard Worker b_tensor_cpu = b_tensor.cpu() 2099*da0073e9SAndroid Build Coastguard Worker vals = (a, b, a_tensor, b_tensor, a_tensor_cpu, b_tensor_cpu) 2100*da0073e9SAndroid Build Coastguard Worker 2101*da0073e9SAndroid Build Coastguard Worker for args in product(vals, vals): 2102*da0073e9SAndroid Build Coastguard Worker first, second = args 2103*da0073e9SAndroid Build Coastguard Worker 2104*da0073e9SAndroid Build Coastguard Worker first_scalar = ( 2105*da0073e9SAndroid Build Coastguard Worker first 2106*da0073e9SAndroid Build Coastguard Worker if not isinstance(first, torch.Tensor) 2107*da0073e9SAndroid Build Coastguard Worker else first.item() 2108*da0073e9SAndroid Build Coastguard Worker ) 2109*da0073e9SAndroid Build Coastguard Worker second_scalar = ( 2110*da0073e9SAndroid Build Coastguard Worker second 2111*da0073e9SAndroid Build Coastguard Worker if not isinstance(second, torch.Tensor) 2112*da0073e9SAndroid Build Coastguard Worker else second.item() 2113*da0073e9SAndroid Build Coastguard Worker ) 2114*da0073e9SAndroid Build Coastguard Worker expected = python_op(first_scalar, second_scalar) 2115*da0073e9SAndroid Build Coastguard Worker 2116*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, python_op(first, second)) 2117*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, torch_op(first, second)) 2118*da0073e9SAndroid Build Coastguard Worker 2119*da0073e9SAndroid Build Coastguard Worker @dtypes( 
2120*da0073e9SAndroid Build Coastguard Worker *product( 2121*da0073e9SAndroid Build Coastguard Worker all_types_and(torch.half, torch.bfloat16, torch.bool), 2122*da0073e9SAndroid Build Coastguard Worker all_types_and(torch.half, torch.bfloat16, torch.bool), 2123*da0073e9SAndroid Build Coastguard Worker ) 2124*da0073e9SAndroid Build Coastguard Worker ) 2125*da0073e9SAndroid Build Coastguard Worker def test_maximum_minimum_type_promotion(self, device, dtypes): 2126*da0073e9SAndroid Build Coastguard Worker a = torch.tensor((0, 1), device=device, dtype=dtypes[0]) 2127*da0073e9SAndroid Build Coastguard Worker b = torch.tensor((1, 0), device=device, dtype=dtypes[1]) 2128*da0073e9SAndroid Build Coastguard Worker for op in ( 2129*da0073e9SAndroid Build Coastguard Worker torch.maximum, 2130*da0073e9SAndroid Build Coastguard Worker torch.max, 2131*da0073e9SAndroid Build Coastguard Worker torch.fmax, 2132*da0073e9SAndroid Build Coastguard Worker torch.minimum, 2133*da0073e9SAndroid Build Coastguard Worker torch.min, 2134*da0073e9SAndroid Build Coastguard Worker torch.fmin, 2135*da0073e9SAndroid Build Coastguard Worker ): 2136*da0073e9SAndroid Build Coastguard Worker result = op(a, b) 2137*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result.dtype, torch.result_type(a, b)) 2138*da0073e9SAndroid Build Coastguard Worker 2139*da0073e9SAndroid Build Coastguard Worker @dtypes(*integral_types_and(torch.bool)) 2140*da0073e9SAndroid Build Coastguard Worker def test_maximum_minimum_int_and_bool(self, device, dtype): 2141*da0073e9SAndroid Build Coastguard Worker ops = ( 2142*da0073e9SAndroid Build Coastguard Worker (torch.maximum, torch.max, np.maximum), 2143*da0073e9SAndroid Build Coastguard Worker (torch.minimum, torch.min, np.minimum), 2144*da0073e9SAndroid Build Coastguard Worker (torch.fmax, None, np.fmax), 2145*da0073e9SAndroid Build Coastguard Worker (torch.fmin, None, np.fmin), 2146*da0073e9SAndroid Build Coastguard Worker ) 2147*da0073e9SAndroid Build Coastguard 
Worker rng = np.random.default_rng() 2148*da0073e9SAndroid Build Coastguard Worker a_np = np.array( 2149*da0073e9SAndroid Build Coastguard Worker rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype] 2150*da0073e9SAndroid Build Coastguard Worker ) 2151*da0073e9SAndroid Build Coastguard Worker b_np = np.array( 2152*da0073e9SAndroid Build Coastguard Worker rng.integers(-100, 100, size=10), dtype=torch_to_numpy_dtype_dict[dtype] 2153*da0073e9SAndroid Build Coastguard Worker ) 2154*da0073e9SAndroid Build Coastguard Worker 2155*da0073e9SAndroid Build Coastguard Worker for torch_op, alias, numpy_op in ops: 2156*da0073e9SAndroid Build Coastguard Worker a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) 2157*da0073e9SAndroid Build Coastguard Worker b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) 2158*da0073e9SAndroid Build Coastguard Worker tensor_result = torch_op(a_tensor, b_tensor) 2159*da0073e9SAndroid Build Coastguard Worker 2160*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(a_tensor) 2161*da0073e9SAndroid Build Coastguard Worker torch_op(a_tensor, b_tensor, out=out) 2162*da0073e9SAndroid Build Coastguard Worker 2163*da0073e9SAndroid Build Coastguard Worker numpy_result = numpy_op(a_np, b_np) 2164*da0073e9SAndroid Build Coastguard Worker 2165*da0073e9SAndroid Build Coastguard Worker if alias is not None: 2166*da0073e9SAndroid Build Coastguard Worker alias_result = alias(a_tensor, b_tensor) 2167*da0073e9SAndroid Build Coastguard Worker self.assertEqual(alias_result, tensor_result) 2168*da0073e9SAndroid Build Coastguard Worker 2169*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_result, numpy_result) 2170*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, numpy_result) 2171*da0073e9SAndroid Build Coastguard Worker 2172*da0073e9SAndroid Build Coastguard Worker @precisionOverride({torch.bfloat16: 1e-2}) 2173*da0073e9SAndroid Build Coastguard Worker 
@dtypes(*(floating_types_and(torch.half, torch.bfloat16))) 2174*da0073e9SAndroid Build Coastguard Worker def test_maximum_minimum_float(self, device, dtype): 2175*da0073e9SAndroid Build Coastguard Worker ops = ( 2176*da0073e9SAndroid Build Coastguard Worker (torch.maximum, torch.max, np.maximum), 2177*da0073e9SAndroid Build Coastguard Worker (torch.minimum, torch.min, np.minimum), 2178*da0073e9SAndroid Build Coastguard Worker (torch.fmax, None, np.fmax), 2179*da0073e9SAndroid Build Coastguard Worker (torch.fmin, None, np.fmin), 2180*da0073e9SAndroid Build Coastguard Worker ) 2181*da0073e9SAndroid Build Coastguard Worker 2182*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 2183*da0073e9SAndroid Build Coastguard Worker a_np = np.random.randn(10).astype(np.float64) 2184*da0073e9SAndroid Build Coastguard Worker b_np = np.random.randn(10).astype(np.float64) 2185*da0073e9SAndroid Build Coastguard Worker else: 2186*da0073e9SAndroid Build Coastguard Worker a_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype]) 2187*da0073e9SAndroid Build Coastguard Worker b_np = np.random.randn(10).astype(torch_to_numpy_dtype_dict[dtype]) 2188*da0073e9SAndroid Build Coastguard Worker 2189*da0073e9SAndroid Build Coastguard Worker for torch_op, alias, numpy_op in ops: 2190*da0073e9SAndroid Build Coastguard Worker numpy_result = numpy_op(a_np, b_np) 2191*da0073e9SAndroid Build Coastguard Worker 2192*da0073e9SAndroid Build Coastguard Worker a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) 2193*da0073e9SAndroid Build Coastguard Worker b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) 2194*da0073e9SAndroid Build Coastguard Worker tensor_result = torch_op(a_tensor, b_tensor) 2195*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(a_tensor) 2196*da0073e9SAndroid Build Coastguard Worker torch_op(a_tensor, b_tensor, out=out) 2197*da0073e9SAndroid Build Coastguard Worker 2198*da0073e9SAndroid Build Coastguard Worker if 
alias is not None: 2199*da0073e9SAndroid Build Coastguard Worker alias_result = alias(a_tensor, b_tensor) 2200*da0073e9SAndroid Build Coastguard Worker self.assertEqual(alias_result, tensor_result, exact_dtype=False) 2201*da0073e9SAndroid Build Coastguard Worker 2202*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_result, numpy_result, exact_dtype=False) 2203*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, numpy_result, exact_dtype=False) 2204*da0073e9SAndroid Build Coastguard Worker 2205*da0073e9SAndroid Build Coastguard Worker @dtypes(*(floating_types_and(torch.half, torch.bfloat16))) 2206*da0073e9SAndroid Build Coastguard Worker def test_maximum_minimum_float_nan_and_inf(self, device, dtype): 2207*da0073e9SAndroid Build Coastguard Worker # np.maximum and np.minimum functions compare input arrays element-wisely. 2208*da0073e9SAndroid Build Coastguard Worker # if one of the elements being compared is a NaN, then that element is returned. 2209*da0073e9SAndroid Build Coastguard Worker ops = ( 2210*da0073e9SAndroid Build Coastguard Worker (torch.maximum, torch.max, np.maximum), 2211*da0073e9SAndroid Build Coastguard Worker (torch.minimum, torch.min, np.minimum), 2212*da0073e9SAndroid Build Coastguard Worker (torch.fmax, None, np.fmax), 2213*da0073e9SAndroid Build Coastguard Worker (torch.fmin, None, np.fmin), 2214*da0073e9SAndroid Build Coastguard Worker ) 2215*da0073e9SAndroid Build Coastguard Worker a_vals = ( 2216*da0073e9SAndroid Build Coastguard Worker float("inf"), 2217*da0073e9SAndroid Build Coastguard Worker -float("inf"), 2218*da0073e9SAndroid Build Coastguard Worker float("nan"), 2219*da0073e9SAndroid Build Coastguard Worker float("inf"), 2220*da0073e9SAndroid Build Coastguard Worker float("nan"), 2221*da0073e9SAndroid Build Coastguard Worker float("nan"), 2222*da0073e9SAndroid Build Coastguard Worker 1, 2223*da0073e9SAndroid Build Coastguard Worker float("nan"), 2224*da0073e9SAndroid Build Coastguard Worker ) 
2225*da0073e9SAndroid Build Coastguard Worker b_vals = ( 2226*da0073e9SAndroid Build Coastguard Worker -float("inf"), 2227*da0073e9SAndroid Build Coastguard Worker float("inf"), 2228*da0073e9SAndroid Build Coastguard Worker float("inf"), 2229*da0073e9SAndroid Build Coastguard Worker float("nan"), 2230*da0073e9SAndroid Build Coastguard Worker float("nan"), 2231*da0073e9SAndroid Build Coastguard Worker 0, 2232*da0073e9SAndroid Build Coastguard Worker float("nan"), 2233*da0073e9SAndroid Build Coastguard Worker -5, 2234*da0073e9SAndroid Build Coastguard Worker ) 2235*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 2236*da0073e9SAndroid Build Coastguard Worker a_np = np.array(a_vals, dtype=np.float64) 2237*da0073e9SAndroid Build Coastguard Worker b_np = np.array(b_vals, dtype=np.float64) 2238*da0073e9SAndroid Build Coastguard Worker else: 2239*da0073e9SAndroid Build Coastguard Worker a_np = np.array(a_vals, dtype=torch_to_numpy_dtype_dict[dtype]) 2240*da0073e9SAndroid Build Coastguard Worker b_np = np.array(b_vals, dtype=torch_to_numpy_dtype_dict[dtype]) 2241*da0073e9SAndroid Build Coastguard Worker 2242*da0073e9SAndroid Build Coastguard Worker for torch_op, alias, numpy_op in ops: 2243*da0073e9SAndroid Build Coastguard Worker numpy_result = numpy_op(a_np, b_np) 2244*da0073e9SAndroid Build Coastguard Worker 2245*da0073e9SAndroid Build Coastguard Worker a_tensor = torch.from_numpy(a_np).to(device=device, dtype=dtype) 2246*da0073e9SAndroid Build Coastguard Worker b_tensor = torch.from_numpy(b_np).to(device=device, dtype=dtype) 2247*da0073e9SAndroid Build Coastguard Worker tensor_result = torch_op(a_tensor, b_tensor) 2248*da0073e9SAndroid Build Coastguard Worker 2249*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(a_tensor) 2250*da0073e9SAndroid Build Coastguard Worker torch_op(a_tensor, b_tensor, out=out) 2251*da0073e9SAndroid Build Coastguard Worker 2252*da0073e9SAndroid Build Coastguard Worker if alias is not None: 
2253*da0073e9SAndroid Build Coastguard Worker alias_result = alias(a_tensor, b_tensor) 2254*da0073e9SAndroid Build Coastguard Worker self.assertEqual(alias_result, tensor_result) 2255*da0073e9SAndroid Build Coastguard Worker 2256*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 2257*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_result, numpy_result, exact_dtype=False) 2258*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, numpy_result, exact_dtype=False) 2259*da0073e9SAndroid Build Coastguard Worker else: 2260*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_result, numpy_result) 2261*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, numpy_result) 2262*da0073e9SAndroid Build Coastguard Worker 2263*da0073e9SAndroid Build Coastguard Worker @dtypes( 2264*da0073e9SAndroid Build Coastguard Worker *product( 2265*da0073e9SAndroid Build Coastguard Worker complex_types(), 2266*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 2267*da0073e9SAndroid Build Coastguard Worker ) 2268*da0073e9SAndroid Build Coastguard Worker ) 2269*da0073e9SAndroid Build Coastguard Worker def test_maximum_minimum_complex(self, device, dtypes): 2270*da0073e9SAndroid Build Coastguard Worker for torch_op in ( 2271*da0073e9SAndroid Build Coastguard Worker torch.maximum, 2272*da0073e9SAndroid Build Coastguard Worker torch.minimum, 2273*da0073e9SAndroid Build Coastguard Worker torch.max, 2274*da0073e9SAndroid Build Coastguard Worker torch.min, 2275*da0073e9SAndroid Build Coastguard Worker torch.fmax, 2276*da0073e9SAndroid Build Coastguard Worker torch.fmin, 2277*da0073e9SAndroid Build Coastguard Worker ): 2278*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, ".+not implemented for.+"): 2279*da0073e9SAndroid Build Coastguard Worker torch_op( 2280*da0073e9SAndroid Build Coastguard Worker torch.ones(1, device=device, dtype=dtypes[0]), 
2281*da0073e9SAndroid Build Coastguard Worker torch.ones(1, device=device, dtype=dtypes[1]), 2282*da0073e9SAndroid Build Coastguard Worker ) 2283*da0073e9SAndroid Build Coastguard Worker 2284*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, ".+not implemented for.+"): 2285*da0073e9SAndroid Build Coastguard Worker torch_op( 2286*da0073e9SAndroid Build Coastguard Worker torch.ones(1, device=device, dtype=dtypes[1]), 2287*da0073e9SAndroid Build Coastguard Worker torch.ones(1, device=device, dtype=dtypes[0]), 2288*da0073e9SAndroid Build Coastguard Worker ) 2289*da0073e9SAndroid Build Coastguard Worker 2290*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2291*da0073e9SAndroid Build Coastguard Worker def test_maximum_minimum_cross_device(self, device): 2292*da0073e9SAndroid Build Coastguard Worker a = torch.tensor((1, 2, -1)) 2293*da0073e9SAndroid Build Coastguard Worker b = torch.tensor((3, 0, 4), device=device) 2294*da0073e9SAndroid Build Coastguard Worker ops = (torch.maximum, torch.minimum) 2295*da0073e9SAndroid Build Coastguard Worker 2296*da0073e9SAndroid Build Coastguard Worker for torch_op in ops: 2297*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2298*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Expected all tensors to be on the same device" 2299*da0073e9SAndroid Build Coastguard Worker ): 2300*da0073e9SAndroid Build Coastguard Worker torch_op(a, b) 2301*da0073e9SAndroid Build Coastguard Worker 2302*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2303*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Expected all tensors to be on the same device" 2304*da0073e9SAndroid Build Coastguard Worker ): 2305*da0073e9SAndroid Build Coastguard Worker torch_op(b, a) 2306*da0073e9SAndroid Build Coastguard Worker 2307*da0073e9SAndroid Build Coastguard Worker # test cuda tensor and cpu scalar 2308*da0073e9SAndroid Build Coastguard Worker ops = ((torch.maximum, np.maximum), 
(torch.minimum, np.minimum)) 2309*da0073e9SAndroid Build Coastguard Worker a_np = np.array(1) 2310*da0073e9SAndroid Build Coastguard Worker b_np = np.array([3, 0, 4]) 2311*da0073e9SAndroid Build Coastguard Worker 2312*da0073e9SAndroid Build Coastguard Worker for torch_op, numpy_op in ops: 2313*da0073e9SAndroid Build Coastguard Worker a_tensor = torch.from_numpy(a_np) 2314*da0073e9SAndroid Build Coastguard Worker b_tensor = torch.from_numpy(b_np).to(device=device) 2315*da0073e9SAndroid Build Coastguard Worker tensor_result_1 = torch_op(a_tensor, b_tensor) 2316*da0073e9SAndroid Build Coastguard Worker numpy_result_1 = numpy_op(a_np, b_np) 2317*da0073e9SAndroid Build Coastguard Worker tensor_result_2 = torch_op(b_tensor, a_tensor) 2318*da0073e9SAndroid Build Coastguard Worker numpy_result_2 = numpy_op(b_np, a_np) 2319*da0073e9SAndroid Build Coastguard Worker 2320*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_result_1, numpy_result_1) 2321*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_result_2, numpy_result_2) 2322*da0073e9SAndroid Build Coastguard Worker 2323*da0073e9SAndroid Build Coastguard Worker @dtypes( 2324*da0073e9SAndroid Build Coastguard Worker *product( 2325*da0073e9SAndroid Build Coastguard Worker floating_types_and(torch.half, torch.bfloat16), 2326*da0073e9SAndroid Build Coastguard Worker floating_types_and(torch.half, torch.bfloat16), 2327*da0073e9SAndroid Build Coastguard Worker ) 2328*da0073e9SAndroid Build Coastguard Worker ) 2329*da0073e9SAndroid Build Coastguard Worker def test_maximum_and_minimum_subgradient(self, device, dtypes): 2330*da0073e9SAndroid Build Coastguard Worker def run_test(f, a, b, expected_a_grad, expected_b_grad): 2331*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(a, requires_grad=True, device=device, dtype=dtypes[0]) 2332*da0073e9SAndroid Build Coastguard Worker b = torch.tensor(b, requires_grad=True, device=device, dtype=dtypes[1]) 2333*da0073e9SAndroid Build Coastguard Worker z = 
f(a, b) 2334*da0073e9SAndroid Build Coastguard Worker z.sum().backward() 2335*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.grad, expected_a_grad) 2336*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.grad, expected_b_grad) 2337*da0073e9SAndroid Build Coastguard Worker 2338*da0073e9SAndroid Build Coastguard Worker run_test( 2339*da0073e9SAndroid Build Coastguard Worker torch.maximum, 2340*da0073e9SAndroid Build Coastguard Worker [0.0, 1.0, 2.0], 2341*da0073e9SAndroid Build Coastguard Worker [1.0, 1.0, 1.0], 2342*da0073e9SAndroid Build Coastguard Worker [0.0, 0.5, 1.0], 2343*da0073e9SAndroid Build Coastguard Worker [1.0, 0.5, 0.0], 2344*da0073e9SAndroid Build Coastguard Worker ) 2345*da0073e9SAndroid Build Coastguard Worker run_test( 2346*da0073e9SAndroid Build Coastguard Worker torch.minimum, 2347*da0073e9SAndroid Build Coastguard Worker [0.0, 1.0, 2.0], 2348*da0073e9SAndroid Build Coastguard Worker [1.0, 1.0, 1.0], 2349*da0073e9SAndroid Build Coastguard Worker [1.0, 0.5, 0.0], 2350*da0073e9SAndroid Build Coastguard Worker [0.0, 0.5, 1.0], 2351*da0073e9SAndroid Build Coastguard Worker ) 2352*da0073e9SAndroid Build Coastguard Worker 2353*da0073e9SAndroid Build Coastguard Worker def test_maximum_minimum_forward_ad_float32(self, device): 2354*da0073e9SAndroid Build Coastguard Worker # TODO: This should really be covered by OpInfo but it isn't. 
The problem 2355*da0073e9SAndroid Build Coastguard Worker # is that our gradient tests test using float64 but it should also test 2356*da0073e9SAndroid Build Coastguard Worker # float32 2357*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, device=device, dtype=torch.float32) 2358*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, device=device, dtype=torch.float32) 2359*da0073e9SAndroid Build Coastguard Worker tx = torch.randn(3, device=device, dtype=torch.float32) 2360*da0073e9SAndroid Build Coastguard Worker ty = torch.randn(3, device=device, dtype=torch.float32) 2361*da0073e9SAndroid Build Coastguard Worker 2362*da0073e9SAndroid Build Coastguard Worker with fwAD.dual_level(): 2363*da0073e9SAndroid Build Coastguard Worker x_dual = fwAD.make_dual(x, tx) 2364*da0073e9SAndroid Build Coastguard Worker y_dual = fwAD.make_dual(y, ty) 2365*da0073e9SAndroid Build Coastguard Worker result = torch.maximum(x_dual, y_dual) 2366*da0073e9SAndroid Build Coastguard Worker _, result_tangent = fwAD.unpack_dual(result) 2367*da0073e9SAndroid Build Coastguard Worker 2368*da0073e9SAndroid Build Coastguard Worker expected = torch.where(x > y, tx, ty) 2369*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_tangent, expected) 2370*da0073e9SAndroid Build Coastguard Worker 2371*da0073e9SAndroid Build Coastguard Worker with fwAD.dual_level(): 2372*da0073e9SAndroid Build Coastguard Worker x_dual = fwAD.make_dual(x, tx) 2373*da0073e9SAndroid Build Coastguard Worker y_dual = fwAD.make_dual(y, ty) 2374*da0073e9SAndroid Build Coastguard Worker result = torch.minimum(x_dual, y_dual) 2375*da0073e9SAndroid Build Coastguard Worker _, result_tangent = fwAD.unpack_dual(result) 2376*da0073e9SAndroid Build Coastguard Worker 2377*da0073e9SAndroid Build Coastguard Worker expected = torch.where(x < y, tx, ty) 2378*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result_tangent, expected) 2379*da0073e9SAndroid Build Coastguard Worker 2380*da0073e9SAndroid Build 
Coastguard Worker # TODO: tests like this should be generic 2381*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.float, torch.double) 2382*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 2383*da0073e9SAndroid Build Coastguard Worker def test_mul_intertype_scalar(self, device, dtype): 2384*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(1.5, dtype=dtype, device=device) 2385*da0073e9SAndroid Build Coastguard Worker y = torch.tensor(3, dtype=torch.int32, device=device) 2386*da0073e9SAndroid Build Coastguard Worker 2387*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x * y, 4.5) 2388*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y * x, 4.5) 2389*da0073e9SAndroid Build Coastguard Worker 2390*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2391*da0073e9SAndroid Build Coastguard Worker RuntimeError, "can't be cast to the desired output type" 2392*da0073e9SAndroid Build Coastguard Worker ): 2393*da0073e9SAndroid Build Coastguard Worker y *= x 2394*da0073e9SAndroid Build Coastguard Worker x *= y 2395*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, 4.5) 2396*da0073e9SAndroid Build Coastguard Worker 2397*da0073e9SAndroid Build Coastguard Worker @onlyCPU 2398*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 2399*da0073e9SAndroid Build Coastguard Worker def test_sub(self, device, dtype): 2400*da0073e9SAndroid Build Coastguard Worker if dtype in integral_types(): 2401*da0073e9SAndroid Build Coastguard Worker # Before Python 3.10, floats were implicitly converted to ints, but with 2402*da0073e9SAndroid Build Coastguard Worker # DeprecationWarning: an integer is required (got type float). 2403*da0073e9SAndroid Build Coastguard Worker # Implicit conversion to integers using __int__ is deprecated, 2404*da0073e9SAndroid Build Coastguard Worker # and may be removed in a future version of Python. 
2405*da0073e9SAndroid Build Coastguard Worker # Since Python 3.10, that attempt gives an error. 2406*da0073e9SAndroid Build Coastguard Worker m1 = torch.tensor([2, 4], dtype=dtype, device=device) 2407*da0073e9SAndroid Build Coastguard Worker m2 = torch.tensor([1, 2], dtype=dtype, device=device) 2408*da0073e9SAndroid Build Coastguard Worker diff = torch.tensor([1, 2], dtype=dtype) 2409*da0073e9SAndroid Build Coastguard Worker else: 2410*da0073e9SAndroid Build Coastguard Worker m1 = torch.tensor([2.34, 4.44], dtype=dtype, device=device) 2411*da0073e9SAndroid Build Coastguard Worker m2 = torch.tensor([1.23, 2.33], dtype=dtype, device=device) 2412*da0073e9SAndroid Build Coastguard Worker diff = torch.tensor([1.11, 2.11], dtype=dtype) 2413*da0073e9SAndroid Build Coastguard Worker 2414*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bool: 2415*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: m1 - m2) 2416*da0073e9SAndroid Build Coastguard Worker elif dtype == torch.bfloat16 or dtype == torch.half: 2417*da0073e9SAndroid Build Coastguard Worker # bfloat16 has a lower precision so we have to have a separate check for it 2418*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m1 - m2, diff, atol=0.01, rtol=0) 2419*da0073e9SAndroid Build Coastguard Worker else: 2420*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m1 - m2, diff) 2421*da0073e9SAndroid Build Coastguard Worker 2422*da0073e9SAndroid Build Coastguard Worker # TODO: what is this test testing? 
2423*da0073e9SAndroid Build Coastguard Worker @onlyCPU 2424*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 2425*da0073e9SAndroid Build Coastguard Worker def test_csub(self, device, dtype): 2426*da0073e9SAndroid Build Coastguard Worker # with a tensor 2427*da0073e9SAndroid Build Coastguard Worker a = torch.randn(100, 90, dtype=dtype, device=device) 2428*da0073e9SAndroid Build Coastguard Worker b = a.clone().normal_() 2429*da0073e9SAndroid Build Coastguard Worker 2430*da0073e9SAndroid Build Coastguard Worker res_add = torch.add(a, b, alpha=-1) 2431*da0073e9SAndroid Build Coastguard Worker res_csub = a.clone() 2432*da0073e9SAndroid Build Coastguard Worker res_csub.sub_(b) 2433*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_add, res_csub) 2434*da0073e9SAndroid Build Coastguard Worker 2435*da0073e9SAndroid Build Coastguard Worker # with a scalar 2436*da0073e9SAndroid Build Coastguard Worker a = torch.randn(100, 100, dtype=dtype, device=device) 2437*da0073e9SAndroid Build Coastguard Worker 2438*da0073e9SAndroid Build Coastguard Worker scalar = 123.5 2439*da0073e9SAndroid Build Coastguard Worker res_add = torch.add(a, -scalar) 2440*da0073e9SAndroid Build Coastguard Worker res_csub = a.clone() 2441*da0073e9SAndroid Build Coastguard Worker res_csub.sub_(scalar) 2442*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_add, res_csub) 2443*da0073e9SAndroid Build Coastguard Worker 2444*da0073e9SAndroid Build Coastguard Worker # TODO: reconcile with minimum/maximum tests 2445*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.half, torch.float, torch.double) 2446*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 2447*da0073e9SAndroid Build Coastguard Worker def test_min_max_binary_op_nan(self, device, dtype): 2448*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1000, dtype=dtype, device=device) 2449*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1000, dtype=dtype, device=device) 
2450*da0073e9SAndroid Build Coastguard Worker 2451*da0073e9SAndroid Build Coastguard Worker # 0:250: a -- nan, b -- not nan 2452*da0073e9SAndroid Build Coastguard Worker a[:250] = float("nan") 2453*da0073e9SAndroid Build Coastguard Worker # 250:500: a -- not nan, b -- nan 2454*da0073e9SAndroid Build Coastguard Worker b[250:500] = float("nan") 2455*da0073e9SAndroid Build Coastguard Worker # 500:750: a and b both nan 2456*da0073e9SAndroid Build Coastguard Worker a[500:750] = float("nan") 2457*da0073e9SAndroid Build Coastguard Worker b[500:750] = float("nan") 2458*da0073e9SAndroid Build Coastguard Worker # 750:1000: neither nan 2459*da0073e9SAndroid Build Coastguard Worker 2460*da0073e9SAndroid Build Coastguard Worker ma = torch.max(a, b) 2461*da0073e9SAndroid Build Coastguard Worker mi = torch.min(a, b) 2462*da0073e9SAndroid Build Coastguard Worker 2463*da0073e9SAndroid Build Coastguard Worker for i in range(750): 2464*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2465*da0073e9SAndroid Build Coastguard Worker torch.isnan(ma[i]), 2466*da0073e9SAndroid Build Coastguard Worker f"max(a, b): {ma[i]}, a: {a[i]}, b: {b[i]}", 2467*da0073e9SAndroid Build Coastguard Worker ) 2468*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2469*da0073e9SAndroid Build Coastguard Worker torch.isnan(mi[i]), 2470*da0073e9SAndroid Build Coastguard Worker f"min(a, b): {mi[i]}, a: {a[i]}, b: {b[i]}", 2471*da0073e9SAndroid Build Coastguard Worker ) 2472*da0073e9SAndroid Build Coastguard Worker 2473*da0073e9SAndroid Build Coastguard Worker for i in range(750, 1000): 2474*da0073e9SAndroid Build Coastguard Worker self.assertFalse( 2475*da0073e9SAndroid Build Coastguard Worker torch.isnan(ma[i]), 2476*da0073e9SAndroid Build Coastguard Worker f"max(a, b): {ma[i]}, a: {a[i]}, b: {b[i]}", 2477*da0073e9SAndroid Build Coastguard Worker ) 2478*da0073e9SAndroid Build Coastguard Worker self.assertFalse( 2479*da0073e9SAndroid Build Coastguard Worker torch.isnan(mi[i]), 
2480*da0073e9SAndroid Build Coastguard Worker f"min(a, b): {mi[i]}, a: {a[i]}, b: {b[i]}", 2481*da0073e9SAndroid Build Coastguard Worker ) 2482*da0073e9SAndroid Build Coastguard Worker 2483*da0073e9SAndroid Build Coastguard Worker @dtypes( 2484*da0073e9SAndroid Build Coastguard Worker *product( 2485*da0073e9SAndroid Build Coastguard Worker all_types_and(torch.half, torch.bfloat16, torch.bool), 2486*da0073e9SAndroid Build Coastguard Worker all_types_and(torch.half, torch.bfloat16, torch.bool), 2487*da0073e9SAndroid Build Coastguard Worker ) 2488*da0073e9SAndroid Build Coastguard Worker ) 2489*da0073e9SAndroid Build Coastguard Worker def test_copysign(self, device, dtypes): 2490*da0073e9SAndroid Build Coastguard Worker def _test_copysign_numpy(a, b): 2491*da0073e9SAndroid Build Coastguard Worker torch_result = torch.copysign(a, b) 2492*da0073e9SAndroid Build Coastguard Worker 2493*da0073e9SAndroid Build Coastguard Worker if a.dtype == torch.bfloat16: 2494*da0073e9SAndroid Build Coastguard Worker np_a = a.to(torch.float).cpu().numpy() 2495*da0073e9SAndroid Build Coastguard Worker else: 2496*da0073e9SAndroid Build Coastguard Worker np_a = a.cpu().numpy() 2497*da0073e9SAndroid Build Coastguard Worker 2498*da0073e9SAndroid Build Coastguard Worker if b.dtype == torch.bfloat16: 2499*da0073e9SAndroid Build Coastguard Worker np_b = b.to(torch.float).cpu().numpy() 2500*da0073e9SAndroid Build Coastguard Worker else: 2501*da0073e9SAndroid Build Coastguard Worker np_b = b.cpu().numpy() 2502*da0073e9SAndroid Build Coastguard Worker expected = torch.from_numpy(np.copysign(np_a, np_b)) 2503*da0073e9SAndroid Build Coastguard Worker # To handle inconsistencies of type promotion between PyTorch and Numpy 2504*da0073e9SAndroid Build Coastguard Worker # Applied for both arguments having integral precision and bfloat16 2505*da0073e9SAndroid Build Coastguard Worker types = integral_types_and(torch.bool, torch.bfloat16) 2506*da0073e9SAndroid Build Coastguard Worker if a.dtype in types or 
b.dtype in types: 2507*da0073e9SAndroid Build Coastguard Worker promoted_type = torch.promote_types(torch_result.dtype, expected.dtype) 2508*da0073e9SAndroid Build Coastguard Worker torch_result = torch_result.to(promoted_type) 2509*da0073e9SAndroid Build Coastguard Worker expected = expected.to(promoted_type) 2510*da0073e9SAndroid Build Coastguard Worker 2511*da0073e9SAndroid Build Coastguard Worker # Verify Value 2512*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch_result, expected) 2513*da0073e9SAndroid Build Coastguard Worker # Verify Sign 2514*da0073e9SAndroid Build Coastguard Worker # Use double copysign to verify the correctnes of 0.0 and -0.0, since 2515*da0073e9SAndroid Build Coastguard Worker # it always True for self.assertEqual(0.0 == -0.0). So, we use 1 as the 2516*da0073e9SAndroid Build Coastguard Worker # magnitude to verify the sign between torch and numpy results, elementwise. 2517*da0073e9SAndroid Build Coastguard Worker # Special case: NaN conversions between FP32 and FP16 is not bitwise 2518*da0073e9SAndroid Build Coastguard Worker # equivalent to pass this assertion. 
2519*da0073e9SAndroid Build Coastguard Worker if a.dtype != torch.float16 and b.dtype != torch.float16: 2520*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2521*da0073e9SAndroid Build Coastguard Worker torch.copysign(torch.tensor(1.0), torch_result), 2522*da0073e9SAndroid Build Coastguard Worker torch.copysign(torch.tensor(1.0), expected), 2523*da0073e9SAndroid Build Coastguard Worker ) 2524*da0073e9SAndroid Build Coastguard Worker 2525*da0073e9SAndroid Build Coastguard Worker # Compare Result with NumPy 2526*da0073e9SAndroid Build Coastguard Worker # Type promotion 2527*da0073e9SAndroid Build Coastguard Worker a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9) 2528*da0073e9SAndroid Build Coastguard Worker b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9) 2529*da0073e9SAndroid Build Coastguard Worker _test_copysign_numpy(a, b) 2530*da0073e9SAndroid Build Coastguard Worker 2531*da0073e9SAndroid Build Coastguard Worker # Broadcast 2532*da0073e9SAndroid Build Coastguard Worker a = make_tensor((10, 1, 10), device=device, dtype=dtypes[0], low=-9, high=9) 2533*da0073e9SAndroid Build Coastguard Worker b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9) 2534*da0073e9SAndroid Build Coastguard Worker _test_copysign_numpy(a, b) 2535*da0073e9SAndroid Build Coastguard Worker 2536*da0073e9SAndroid Build Coastguard Worker a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9) 2537*da0073e9SAndroid Build Coastguard Worker b = make_tensor((10, 1, 10), device=device, dtype=dtypes[1], low=-9, high=9) 2538*da0073e9SAndroid Build Coastguard Worker _test_copysign_numpy(a, b) 2539*da0073e9SAndroid Build Coastguard Worker 2540*da0073e9SAndroid Build Coastguard Worker # 0.0/-0.0/inf/-inf/nan 2541*da0073e9SAndroid Build Coastguard Worker cases = [0.0, -0.0, float("inf"), float("-inf"), float("nan")] 2542*da0073e9SAndroid Build Coastguard Worker # torch.bfloat16 can not hold '-nan' 
2543*da0073e9SAndroid Build Coastguard Worker # torch.half can not hold '-nan' on CUDA 2544*da0073e9SAndroid Build Coastguard Worker types = [torch.float32, torch.float64] 2545*da0073e9SAndroid Build Coastguard Worker if device == "cpu": 2546*da0073e9SAndroid Build Coastguard Worker types.append(torch.float16) 2547*da0073e9SAndroid Build Coastguard Worker if dtypes[0] in types: 2548*da0073e9SAndroid Build Coastguard Worker b = make_tensor((10, 10), device=device, dtype=dtypes[1], low=-9, high=9) 2549*da0073e9SAndroid Build Coastguard Worker for case in cases: 2550*da0073e9SAndroid Build Coastguard Worker _test_copysign_numpy( 2551*da0073e9SAndroid Build Coastguard Worker torch.tensor([case], device=device, dtype=dtypes[0]), b 2552*da0073e9SAndroid Build Coastguard Worker ) 2553*da0073e9SAndroid Build Coastguard Worker 2554*da0073e9SAndroid Build Coastguard Worker if dtypes[1] in floating_types_and(torch.half, torch.bfloat16): 2555*da0073e9SAndroid Build Coastguard Worker a = make_tensor((10, 10), device=device, dtype=dtypes[0], low=-9, high=9) 2556*da0073e9SAndroid Build Coastguard Worker for case in cases: 2557*da0073e9SAndroid Build Coastguard Worker _test_copysign_numpy( 2558*da0073e9SAndroid Build Coastguard Worker a, torch.tensor([case], device=device, dtype=dtypes[1]) 2559*da0073e9SAndroid Build Coastguard Worker ) 2560*da0073e9SAndroid Build Coastguard Worker 2561*da0073e9SAndroid Build Coastguard Worker @dtypes( 2562*da0073e9SAndroid Build Coastguard Worker *product( 2563*da0073e9SAndroid Build Coastguard Worker floating_types_and(torch.half, torch.bfloat16), 2564*da0073e9SAndroid Build Coastguard Worker floating_types_and(torch.half, torch.bfloat16), 2565*da0073e9SAndroid Build Coastguard Worker ) 2566*da0073e9SAndroid Build Coastguard Worker ) 2567*da0073e9SAndroid Build Coastguard Worker def test_copysign_subgradient(self, device, dtypes): 2568*da0073e9SAndroid Build Coastguard Worker # Input is 0.0 2569*da0073e9SAndroid Build Coastguard Worker x = 
torch.tensor( 2570*da0073e9SAndroid Build Coastguard Worker [0.0, 0.0, 0.0], dtype=dtypes[0], device=device, requires_grad=True 2571*da0073e9SAndroid Build Coastguard Worker ) 2572*da0073e9SAndroid Build Coastguard Worker y = torch.tensor( 2573*da0073e9SAndroid Build Coastguard Worker [-1.0, 0.0, 1.0], dtype=dtypes[1], device=device, requires_grad=True 2574*da0073e9SAndroid Build Coastguard Worker ) 2575*da0073e9SAndroid Build Coastguard Worker out = torch.copysign(x, y) 2576*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 2577*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0]) 2578*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.grad.tolist(), [0.0] * 3) 2579*da0073e9SAndroid Build Coastguard Worker 2580*da0073e9SAndroid Build Coastguard Worker # Input is -0.0 2581*da0073e9SAndroid Build Coastguard Worker x = torch.tensor( 2582*da0073e9SAndroid Build Coastguard Worker [-0.0, -0.0, -0.0], dtype=dtypes[0], device=device, requires_grad=True 2583*da0073e9SAndroid Build Coastguard Worker ) 2584*da0073e9SAndroid Build Coastguard Worker y = torch.tensor( 2585*da0073e9SAndroid Build Coastguard Worker [-1.0, 0.0, 1.0], dtype=dtypes[1], device=device, requires_grad=True 2586*da0073e9SAndroid Build Coastguard Worker ) 2587*da0073e9SAndroid Build Coastguard Worker out = torch.copysign(x, y) 2588*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 2589*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad.tolist(), [0.0, 0.0, 0.0]) 2590*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.grad.tolist(), [0.0] * 3) 2591*da0073e9SAndroid Build Coastguard Worker 2592*da0073e9SAndroid Build Coastguard Worker # Other is 0.0 2593*da0073e9SAndroid Build Coastguard Worker x = torch.tensor( 2594*da0073e9SAndroid Build Coastguard Worker [-1.0, 0.0, 1.0], dtype=dtypes[0], device=device, requires_grad=True 2595*da0073e9SAndroid Build Coastguard Worker ) 2596*da0073e9SAndroid Build Coastguard 
Worker y = torch.tensor( 2597*da0073e9SAndroid Build Coastguard Worker [0.0, 0.0, 0.0], dtype=dtypes[1], device=device, requires_grad=True 2598*da0073e9SAndroid Build Coastguard Worker ) 2599*da0073e9SAndroid Build Coastguard Worker out = torch.copysign(x, y) 2600*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 2601*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad.tolist(), [-1.0, 0.0, 1.0]) 2602*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.grad.tolist(), [0.0] * 3) 2603*da0073e9SAndroid Build Coastguard Worker 2604*da0073e9SAndroid Build Coastguard Worker # Other is -0.0 2605*da0073e9SAndroid Build Coastguard Worker x = torch.tensor( 2606*da0073e9SAndroid Build Coastguard Worker [-1.0, 0.0, 1.0], dtype=dtypes[0], device=device, requires_grad=True 2607*da0073e9SAndroid Build Coastguard Worker ) 2608*da0073e9SAndroid Build Coastguard Worker y = torch.tensor( 2609*da0073e9SAndroid Build Coastguard Worker [-0.0, -0.0, -0.0], dtype=dtypes[1], device=device, requires_grad=True 2610*da0073e9SAndroid Build Coastguard Worker ) 2611*da0073e9SAndroid Build Coastguard Worker out = torch.copysign(x, y) 2612*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 2613*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad.tolist(), [1.0, 0.0, -1.0]) 2614*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.grad.tolist(), [0.0] * 3) 2615*da0073e9SAndroid Build Coastguard Worker 2616*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.bfloat16, torch.float) 2617*da0073e9SAndroid Build Coastguard Worker def test_div(self, device, dtype): 2618*da0073e9SAndroid Build Coastguard Worker for op, method, inplace in ( 2619*da0073e9SAndroid Build Coastguard Worker (torch.div, torch.Tensor.div, torch.Tensor.div_), 2620*da0073e9SAndroid Build Coastguard Worker (torch.true_divide, torch.Tensor.true_divide, torch.Tensor.true_divide_), 2621*da0073e9SAndroid Build Coastguard Worker ): 2622*da0073e9SAndroid Build Coastguard 
Worker m1 = torch.randn(10, 10, dtype=torch.float, device=device).to(dtype=dtype) 2623*da0073e9SAndroid Build Coastguard Worker res1 = m1.clone() 2624*da0073e9SAndroid Build Coastguard Worker inplace(res1[:, 3], 2) 2625*da0073e9SAndroid Build Coastguard Worker res2 = m1.clone() 2626*da0073e9SAndroid Build Coastguard Worker for i in range(m1.size(0)): 2627*da0073e9SAndroid Build Coastguard Worker res2[i, 3] = res2[i, 3] / 2 2628*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 2629*da0073e9SAndroid Build Coastguard Worker 2630*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 2631*da0073e9SAndroid Build Coastguard Worker a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device) 2632*da0073e9SAndroid Build Coastguard Worker a2 = torch.tensor([2.0, 2.0], dtype=dtype, device=device) 2633*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2634*da0073e9SAndroid Build Coastguard Worker op(a1, a2), 2635*da0073e9SAndroid Build Coastguard Worker torch.tensor([2.1, 3.1], dtype=dtype, device=device), 2636*da0073e9SAndroid Build Coastguard Worker atol=0.01, 2637*da0073e9SAndroid Build Coastguard Worker rtol=0, 2638*da0073e9SAndroid Build Coastguard Worker ) 2639*da0073e9SAndroid Build Coastguard Worker self.assertEqual(method(a1, a2), op(a1, a2)) 2640*da0073e9SAndroid Build Coastguard Worker 2641*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.bfloat16, torch.float) 2642*da0073e9SAndroid Build Coastguard Worker def test_true_divide_out(self, device, dtype): 2643*da0073e9SAndroid Build Coastguard Worker a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device) 2644*da0073e9SAndroid Build Coastguard Worker a2 = torch.tensor([2.0, 2.0], dtype=dtype, device=device) 2645*da0073e9SAndroid Build Coastguard Worker res = torch.empty_like(a1) 2646*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2647*da0073e9SAndroid Build Coastguard Worker torch.true_divide(a1, a2, out=res), 2648*da0073e9SAndroid Build Coastguard Worker 
torch.tensor([2.1, 3.1], dtype=dtype, device=device), 2649*da0073e9SAndroid Build Coastguard Worker atol=0.01, 2650*da0073e9SAndroid Build Coastguard Worker rtol=0, 2651*da0073e9SAndroid Build Coastguard Worker ) 2652*da0073e9SAndroid Build Coastguard Worker 2653*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half) 2654*da0073e9SAndroid Build Coastguard Worker def test_divmul_scalar(self, device, dtype): 2655*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(100.0, device=device, dtype=dtype) 2656*da0073e9SAndroid Build Coastguard Worker x_ref = x.float() 2657*da0073e9SAndroid Build Coastguard Worker scale = 1e5 2658*da0073e9SAndroid Build Coastguard Worker res = x.div(scale) 2659*da0073e9SAndroid Build Coastguard Worker expected = x_ref.div(scale) 2660*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0) 2661*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(1e-5, device=device, dtype=dtype) 2662*da0073e9SAndroid Build Coastguard Worker x_ref = x.float() 2663*da0073e9SAndroid Build Coastguard Worker res = x.mul(scale) 2664*da0073e9SAndroid Build Coastguard Worker expected = x_ref.mul(scale) 2665*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0) 2666*da0073e9SAndroid Build Coastguard Worker res = scale * x 2667*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected.to(dtype), atol=0.0, rtol=0.0) 2668*da0073e9SAndroid Build Coastguard Worker 2669*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA( 2670*da0073e9SAndroid Build Coastguard Worker *set(get_all_math_dtypes("cuda")) - {torch.complex64, torch.complex128} 2671*da0073e9SAndroid Build Coastguard Worker ) 2672*da0073e9SAndroid Build Coastguard Worker @dtypes(*set(get_all_math_dtypes("cpu")) - {torch.complex64, torch.complex128}) 2673*da0073e9SAndroid Build Coastguard Worker def test_floor_divide_tensor(self, device, dtype): 2674*da0073e9SAndroid Build Coastguard 
Worker x = torch.randn(10, device=device).mul(30).to(dtype) 2675*da0073e9SAndroid Build Coastguard Worker y = torch.arange(1, 11, dtype=dtype, device=device) 2676*da0073e9SAndroid Build Coastguard Worker 2677*da0073e9SAndroid Build Coastguard Worker z = x // y 2678*da0073e9SAndroid Build Coastguard Worker z_alt = torch.floor(x.double() / y.double()).to(dtype) 2679*da0073e9SAndroid Build Coastguard Worker 2680*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z.dtype, x.dtype) 2681*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, z_alt) 2682*da0073e9SAndroid Build Coastguard Worker 2683*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA( 2684*da0073e9SAndroid Build Coastguard Worker *set(get_all_math_dtypes("cuda")) - {torch.complex64, torch.complex128} 2685*da0073e9SAndroid Build Coastguard Worker ) 2686*da0073e9SAndroid Build Coastguard Worker @dtypes(*set(get_all_math_dtypes("cpu")) - {torch.complex64, torch.complex128}) 2687*da0073e9SAndroid Build Coastguard Worker def test_floor_divide_scalar(self, device, dtype): 2688*da0073e9SAndroid Build Coastguard Worker x = torch.randn(100, device=device).mul(10).to(dtype) 2689*da0073e9SAndroid Build Coastguard Worker 2690*da0073e9SAndroid Build Coastguard Worker z = x // 3 2691*da0073e9SAndroid Build Coastguard Worker z_alt = torch.tensor( 2692*da0073e9SAndroid Build Coastguard Worker [math.floor(v.item() / 3.0) for v in x], dtype=x.dtype, device=device 2693*da0073e9SAndroid Build Coastguard Worker ) 2694*da0073e9SAndroid Build Coastguard Worker 2695*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z.dtype, x.dtype) 2696*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, z_alt) 2697*da0073e9SAndroid Build Coastguard Worker 2698*da0073e9SAndroid Build Coastguard Worker @onlyCPU 2699*da0073e9SAndroid Build Coastguard Worker @dtypes(*get_all_math_dtypes("cpu")) 2700*da0073e9SAndroid Build Coastguard Worker def test_rdiv(self, device, dtype): 2701*da0073e9SAndroid Build Coastguard 
Worker if dtype is torch.float16: 2702*da0073e9SAndroid Build Coastguard Worker return 2703*da0073e9SAndroid Build Coastguard Worker elif dtype.is_complex: 2704*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, dtype=dtype, device=device).add(1).mul(4) 2705*da0073e9SAndroid Build Coastguard Worker else: 2706*da0073e9SAndroid Build Coastguard Worker x = torch.rand(100, device=device).add(1).mul(4).to(dtype) 2707*da0073e9SAndroid Build Coastguard Worker y = 30 / x 2708*da0073e9SAndroid Build Coastguard Worker z = torch.tensor([30 / v.item() for v in x], device=device) 2709*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, z, exact_dtype=False) 2710*da0073e9SAndroid Build Coastguard Worker 2711*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_types_and(torch.half)) 2712*da0073e9SAndroid Build Coastguard Worker def test_fmod_remainder_by_zero_float(self, device, dtype): 2713*da0073e9SAndroid Build Coastguard Worker fn_list = (torch.fmod, torch.remainder) 2714*da0073e9SAndroid Build Coastguard Worker for fn in fn_list: 2715*da0073e9SAndroid Build Coastguard Worker # check floating-point tensor fmod/remainder to zero is nan on both CPU and GPU 2716*da0073e9SAndroid Build Coastguard Worker x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9) 2717*da0073e9SAndroid Build Coastguard Worker zero = torch.zeros_like(x) 2718*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.all(fn(x, 0.0).isnan())) 2719*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.all(fn(x, zero).isnan())) 2720*da0073e9SAndroid Build Coastguard Worker 2721*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes # Check Issue https://github.com/pytorch/pytorch/issues/48130 2722*da0073e9SAndroid Build Coastguard Worker @dtypes(*integral_types()) 2723*da0073e9SAndroid Build Coastguard Worker def test_fmod_remainder_by_zero_integral(self, device, dtype): 2724*da0073e9SAndroid Build Coastguard Worker fn_list = (torch.fmod, 
torch.remainder) 2725*da0073e9SAndroid Build Coastguard Worker for fn in fn_list: 2726*da0073e9SAndroid Build Coastguard Worker # check integral tensor fmod/remainder to zero 2727*da0073e9SAndroid Build Coastguard Worker x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9) 2728*da0073e9SAndroid Build Coastguard Worker zero = torch.zeros_like(x) 2729*da0073e9SAndroid Build Coastguard Worker # RuntimeError on CPU 2730*da0073e9SAndroid Build Coastguard Worker if self.device_type == "cpu": 2731*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"): 2732*da0073e9SAndroid Build Coastguard Worker fn(x, zero) 2733*da0073e9SAndroid Build Coastguard Worker elif torch.version.hip is not None: 2734*da0073e9SAndroid Build Coastguard Worker # ROCm behavior: x % 0 is a no-op; x is returned 2735*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x, zero), x) 2736*da0073e9SAndroid Build Coastguard Worker else: 2737*da0073e9SAndroid Build Coastguard Worker # CUDA behavior: Different value for different dtype 2738*da0073e9SAndroid Build Coastguard Worker # Due to it's an undefined behavior, CUDA returns a pattern of all 1s 2739*da0073e9SAndroid Build Coastguard Worker # for integral dividend (other than int64) divided by zero. For int64, 2740*da0073e9SAndroid Build Coastguard Worker # CUDA returns all 1s for negative dividend, half 1s for positive dividend. 
2741*da0073e9SAndroid Build Coastguard Worker # uint8: 0xff -> 255 2742*da0073e9SAndroid Build Coastguard Worker # int32: 0xffffffff -> -1 2743*da0073e9SAndroid Build Coastguard Worker if dtype == torch.int64: 2744*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x, zero) == 4294967295, x >= 0) 2745*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x, zero) == -1, x < 0) 2746*da0073e9SAndroid Build Coastguard Worker else: 2747*da0073e9SAndroid Build Coastguard Worker value = 255 if dtype == torch.uint8 else -1 2748*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.all(fn(x, zero) == value)) 2749*da0073e9SAndroid Build Coastguard Worker 2750*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and(torch.half)) 2751*da0073e9SAndroid Build Coastguard Worker def test_fmod_remainder(self, device, dtype): 2752*da0073e9SAndroid Build Coastguard Worker # Use numpy as reference 2753*da0073e9SAndroid Build Coastguard Worker def _helper(x, mod, fns_list): 2754*da0073e9SAndroid Build Coastguard Worker for fn, inplace_fn, ref_fn in fns_list: 2755*da0073e9SAndroid Build Coastguard Worker np_x = x.cpu().numpy() if torch.is_tensor(x) else x 2756*da0073e9SAndroid Build Coastguard Worker np_mod = mod.cpu().numpy() if torch.is_tensor(mod) else mod 2757*da0073e9SAndroid Build Coastguard Worker exp = ref_fn(np_x, np_mod) 2758*da0073e9SAndroid Build Coastguard Worker exp = torch.from_numpy(exp) 2759*da0073e9SAndroid Build Coastguard Worker res = fn(x, mod) 2760*da0073e9SAndroid Build Coastguard Worker 2761*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, exp, exact_dtype=False) 2762*da0073e9SAndroid Build Coastguard Worker 2763*da0073e9SAndroid Build Coastguard Worker if torch.is_tensor(x): 2764*da0073e9SAndroid Build Coastguard Worker # out 2765*da0073e9SAndroid Build Coastguard Worker out = torch.empty(0, device=device, dtype=res.dtype) 2766*da0073e9SAndroid Build Coastguard Worker fn(x, mod, out=out) 2767*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(out, exp, exact_dtype=False) 2768*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.size(), torch.Size([10, 10])) 2769*da0073e9SAndroid Build Coastguard Worker # in-place (Type cast runtime error) 2770*da0073e9SAndroid Build Coastguard Worker try: 2771*da0073e9SAndroid Build Coastguard Worker inplace_fn(x, mod) 2772*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, exp, exact_dtype=False) 2773*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 2774*da0073e9SAndroid Build Coastguard Worker self.assertRegex( 2775*da0073e9SAndroid Build Coastguard Worker str(e), 2776*da0073e9SAndroid Build Coastguard Worker "result type (Half|Float|Double) " 2777*da0073e9SAndroid Build Coastguard Worker "can't be cast to the desired output " 2778*da0073e9SAndroid Build Coastguard Worker "type (Byte|Char|Short|Int|Long)", 2779*da0073e9SAndroid Build Coastguard Worker ) 2780*da0073e9SAndroid Build Coastguard Worker 2781*da0073e9SAndroid Build Coastguard Worker x = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9) 2782*da0073e9SAndroid Build Coastguard Worker # mod with same dtype as x 2783*da0073e9SAndroid Build Coastguard Worker mod = make_tensor((10, 10), device=device, dtype=dtype, low=-9, high=9) 2784*da0073e9SAndroid Build Coastguard Worker # Exclude 0 2785*da0073e9SAndroid Build Coastguard Worker mod[mod == 0] = 1 2786*da0073e9SAndroid Build Coastguard Worker 2787*da0073e9SAndroid Build Coastguard Worker # Mods: Integer, Float, Tensor, Non-contiguous Tensor 2788*da0073e9SAndroid Build Coastguard Worker mods = [3, 2.3, mod, mod.t()] 2789*da0073e9SAndroid Build Coastguard Worker # mod with floating-point dtype 2790*da0073e9SAndroid Build Coastguard Worker if dtype in integral_types(): 2791*da0073e9SAndroid Build Coastguard Worker mod_float = make_tensor( 2792*da0073e9SAndroid Build Coastguard Worker (10, 10), device=device, dtype=torch.float, low=-9, high=9 2793*da0073e9SAndroid Build 
Coastguard Worker ) 2794*da0073e9SAndroid Build Coastguard Worker mod[mod == 0] = 1 2795*da0073e9SAndroid Build Coastguard Worker mods.append(mod_float) 2796*da0073e9SAndroid Build Coastguard Worker 2797*da0073e9SAndroid Build Coastguard Worker for dividend, mod in product([x, x.t()], mods): 2798*da0073e9SAndroid Build Coastguard Worker _helper( 2799*da0073e9SAndroid Build Coastguard Worker dividend, 2800*da0073e9SAndroid Build Coastguard Worker mod, 2801*da0073e9SAndroid Build Coastguard Worker ( 2802*da0073e9SAndroid Build Coastguard Worker (torch.fmod, torch.Tensor.fmod_, np.fmod), 2803*da0073e9SAndroid Build Coastguard Worker (torch.remainder, torch.Tensor.remainder_, np.remainder), 2804*da0073e9SAndroid Build Coastguard Worker ), 2805*da0073e9SAndroid Build Coastguard Worker ) 2806*da0073e9SAndroid Build Coastguard Worker 2807*da0073e9SAndroid Build Coastguard Worker # Tests for torch.remainder(scalar, tensor) 2808*da0073e9SAndroid Build Coastguard Worker for dividend, mod in product([5, 3.14], mods): 2809*da0073e9SAndroid Build Coastguard Worker if torch.is_tensor(mod): 2810*da0073e9SAndroid Build Coastguard Worker _helper( 2811*da0073e9SAndroid Build Coastguard Worker dividend, 2812*da0073e9SAndroid Build Coastguard Worker mod, 2813*da0073e9SAndroid Build Coastguard Worker ((torch.remainder, torch.Tensor.remainder_, np.remainder),), 2814*da0073e9SAndroid Build Coastguard Worker ) 2815*da0073e9SAndroid Build Coastguard Worker 2816*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 2817*da0073e9SAndroid Build Coastguard Worker def test_remainder_fmod_large_dividend(self, device, dtype): 2818*da0073e9SAndroid Build Coastguard Worker alarge = 1e9 2819*da0073e9SAndroid Build Coastguard Worker pi = 3.14159265358979 2820*da0073e9SAndroid Build Coastguard Worker for avalue in [alarge, -alarge]: 2821*da0073e9SAndroid Build Coastguard Worker for bvalue in [pi, -pi]: 2822*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([avalue], 
dtype=dtype, device=device) 2823*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([bvalue], dtype=dtype, device=device) 2824*da0073e9SAndroid Build Coastguard Worker c = torch.remainder(a, b) 2825*da0073e9SAndroid Build Coastguard Worker d = torch.fmod(a, b) 2826*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2827*da0073e9SAndroid Build Coastguard Worker (b[0] > 0) == (c[0] > 0) 2828*da0073e9SAndroid Build Coastguard Worker ) # remainder has same sign as divisor 2829*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2830*da0073e9SAndroid Build Coastguard Worker (a[0] > 0) == (d[0] > 0) 2831*da0073e9SAndroid Build Coastguard Worker ) # fmod has same sign as dividend 2832*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2833*da0073e9SAndroid Build Coastguard Worker abs(c[0]) < abs(b[0]) 2834*da0073e9SAndroid Build Coastguard Worker ) # remainder is within range of divisor 2835*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2836*da0073e9SAndroid Build Coastguard Worker abs(d[0]) < abs(b[0]) 2837*da0073e9SAndroid Build Coastguard Worker ) # fmod is within range of divisor 2838*da0073e9SAndroid Build Coastguard Worker if (a[0] > 0) == (b[0] > 0): 2839*da0073e9SAndroid Build Coastguard Worker self.assertTrue(c[0] == d[0]) # remainder is same as fmod 2840*da0073e9SAndroid Build Coastguard Worker else: 2841*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2842*da0073e9SAndroid Build Coastguard Worker abs(c[0] - d[0]) == abs(b[0]) 2843*da0073e9SAndroid Build Coastguard Worker ) # differ by one divisor 2844*da0073e9SAndroid Build Coastguard Worker 2845*da0073e9SAndroid Build Coastguard Worker @dtypesIfCPU(torch.bfloat16, torch.half, torch.float32, torch.float64) 2846*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.float64) 2847*da0073e9SAndroid Build Coastguard Worker def test_hypot(self, device, dtype): 2848*da0073e9SAndroid Build Coastguard Worker inputs = [ 2849*da0073e9SAndroid Build 
Coastguard Worker ( 2850*da0073e9SAndroid Build Coastguard Worker torch.randn(10, device=device).to(dtype), 2851*da0073e9SAndroid Build Coastguard Worker torch.randn(10, device=device).to(dtype), 2852*da0073e9SAndroid Build Coastguard Worker ), 2853*da0073e9SAndroid Build Coastguard Worker ( 2854*da0073e9SAndroid Build Coastguard Worker torch.randn((3, 3, 3), device=device).to(dtype), 2855*da0073e9SAndroid Build Coastguard Worker torch.randn((3, 3, 3), device=device).to(dtype), 2856*da0073e9SAndroid Build Coastguard Worker ), 2857*da0073e9SAndroid Build Coastguard Worker ( 2858*da0073e9SAndroid Build Coastguard Worker torch.randn((10, 1), device=device).to(dtype), 2859*da0073e9SAndroid Build Coastguard Worker torch.randn((10, 1), device=device).to(dtype).transpose(0, 1), 2860*da0073e9SAndroid Build Coastguard Worker ), 2861*da0073e9SAndroid Build Coastguard Worker ( 2862*da0073e9SAndroid Build Coastguard Worker torch.randint(100, (10,), device=device, dtype=torch.long), 2863*da0073e9SAndroid Build Coastguard Worker torch.randn(10, device=device).to(dtype), 2864*da0073e9SAndroid Build Coastguard Worker ), 2865*da0073e9SAndroid Build Coastguard Worker ] 2866*da0073e9SAndroid Build Coastguard Worker for input in inputs: 2867*da0073e9SAndroid Build Coastguard Worker actual = torch.hypot(input[0], input[1]) 2868*da0073e9SAndroid Build Coastguard Worker if dtype in [torch.bfloat16, torch.half]: 2869*da0073e9SAndroid Build Coastguard Worker expected = torch.sqrt(input[0] * input[0] + input[1] * input[1]) 2870*da0073e9SAndroid Build Coastguard Worker else: 2871*da0073e9SAndroid Build Coastguard Worker expected = np.hypot(input[0].cpu().numpy(), input[1].cpu().numpy()) 2872*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, exact_dtype=False) 2873*da0073e9SAndroid Build Coastguard Worker 2874*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 2875*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.uint8, torch.int8, torch.int16, 
torch.int32, torch.int64) 2876*da0073e9SAndroid Build Coastguard Worker def test_gcd(self, device, dtype): 2877*da0073e9SAndroid Build Coastguard Worker # Tests gcd(0, 0), gcd(0, a) cases 2878*da0073e9SAndroid Build Coastguard Worker t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device) 2879*da0073e9SAndroid Build Coastguard Worker t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device) 2880*da0073e9SAndroid Build Coastguard Worker actual = torch.gcd(t1, t2) 2881*da0073e9SAndroid Build Coastguard Worker expected = np.gcd([0, 10, 0], [0, 0, 10]) 2882*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, exact_dtype=False) 2883*da0073e9SAndroid Build Coastguard Worker 2884*da0073e9SAndroid Build Coastguard Worker if dtype == torch.uint8: 2885*da0073e9SAndroid Build Coastguard Worker # Test unsigned integers with potential sign issues (i.e., uint8 with value >= 128) 2886*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([190, 210], device=device, dtype=dtype) 2887*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([190, 220], device=device, dtype=dtype) 2888*da0073e9SAndroid Build Coastguard Worker actual = torch.gcd(a, b) 2889*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor([190, 10], device=device, dtype=dtype) 2890*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 2891*da0073e9SAndroid Build Coastguard Worker else: 2892*da0073e9SAndroid Build Coastguard Worker # Compares with NumPy 2893*da0073e9SAndroid Build Coastguard Worker a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) 2894*da0073e9SAndroid Build Coastguard Worker b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) 2895*da0073e9SAndroid Build Coastguard Worker actual = torch.gcd(a, b) 2896*da0073e9SAndroid Build Coastguard Worker expected = np.gcd(a.cpu().numpy(), b.cpu().numpy()) 2897*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 2898*da0073e9SAndroid Build Coastguard 
Worker 2899*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 2900*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int16, torch.int32, torch.int64) 2901*da0073e9SAndroid Build Coastguard Worker def test_lcm(self, device, dtype): 2902*da0073e9SAndroid Build Coastguard Worker # Tests lcm(0, 0), lcm(0, a) cases 2903*da0073e9SAndroid Build Coastguard Worker t1 = torch.tensor([0, 10, 0], dtype=dtype, device=device) 2904*da0073e9SAndroid Build Coastguard Worker t2 = torch.tensor([0, 0, 10], dtype=dtype, device=device) 2905*da0073e9SAndroid Build Coastguard Worker actual = torch.lcm(t1, t2) 2906*da0073e9SAndroid Build Coastguard Worker expected = np.lcm([0, 10, 0], [0, 0, 10]) 2907*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, exact_dtype=False) 2908*da0073e9SAndroid Build Coastguard Worker 2909*da0073e9SAndroid Build Coastguard Worker # Compares with NumPy 2910*da0073e9SAndroid Build Coastguard Worker a = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) 2911*da0073e9SAndroid Build Coastguard Worker b = torch.randint(-20, 20, (1024,), device=device, dtype=dtype) 2912*da0073e9SAndroid Build Coastguard Worker actual = torch.lcm(a, b) 2913*da0073e9SAndroid Build Coastguard Worker expected = np.lcm(a.cpu().numpy(), b.cpu().numpy()) 2914*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, exact_dtype=False) 2915*da0073e9SAndroid Build Coastguard Worker 2916*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 2917*da0073e9SAndroid Build Coastguard Worker @dtypesIfCPU(torch.float32, torch.float64, torch.float16) 2918*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.float64) 2919*da0073e9SAndroid Build Coastguard Worker def test_nextafter(self, device, dtype): 2920*da0073e9SAndroid Build Coastguard Worker # Test special cases 2921*da0073e9SAndroid Build Coastguard Worker t1 = torch.tensor([0, 0, 10], device=device, dtype=dtype) 2922*da0073e9SAndroid Build 
Coastguard Worker t2 = torch.tensor([inf, -inf, 10], device=device, dtype=dtype) 2923*da0073e9SAndroid Build Coastguard Worker actual = torch.nextafter(t1, t2) 2924*da0073e9SAndroid Build Coastguard Worker expected = np.nextafter(t1.cpu().numpy(), t2.cpu().numpy()) 2925*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, atol=0, rtol=0) 2926*da0073e9SAndroid Build Coastguard Worker 2927*da0073e9SAndroid Build Coastguard Worker actual = torch.nextafter(t2, t1) 2928*da0073e9SAndroid Build Coastguard Worker expected = np.nextafter(t2.cpu().numpy(), t1.cpu().numpy()) 2929*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, atol=0, rtol=0) 2930*da0073e9SAndroid Build Coastguard Worker 2931*da0073e9SAndroid Build Coastguard Worker t1 = torch.tensor([0, nan], device=device, dtype=dtype) 2932*da0073e9SAndroid Build Coastguard Worker t2 = torch.tensor([nan, 0], device=device, dtype=dtype) 2933*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.nextafter(t1, t2).isnan().all()) 2934*da0073e9SAndroid Build Coastguard Worker 2935*da0073e9SAndroid Build Coastguard Worker a = torch.randn(100, device=device, dtype=dtype) 2936*da0073e9SAndroid Build Coastguard Worker b = torch.randn(100, device=device, dtype=dtype) 2937*da0073e9SAndroid Build Coastguard Worker actual = torch.nextafter(a, b) 2938*da0073e9SAndroid Build Coastguard Worker expected = np.nextafter(a.cpu().numpy(), b.cpu().numpy()) 2939*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, atol=0, rtol=0) 2940*da0073e9SAndroid Build Coastguard Worker 2941*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 2942*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.bfloat16) 2943*da0073e9SAndroid Build Coastguard Worker def test_nextafter_bfloat16(self, device, dtype): 2944*da0073e9SAndroid Build Coastguard Worker nan = float("nan") 2945*da0073e9SAndroid Build Coastguard Worker inf = float("inf") 2946*da0073e9SAndroid Build 
Coastguard Worker cases = ( 2947*da0073e9SAndroid Build Coastguard Worker # (from, to, expected) 2948*da0073e9SAndroid Build Coastguard Worker (0, 1, 9.183549615799121e-41), 2949*da0073e9SAndroid Build Coastguard Worker (0, -1, -9.183549615799121e-41), 2950*da0073e9SAndroid Build Coastguard Worker (1, -2, 0.99609375), 2951*da0073e9SAndroid Build Coastguard Worker (1, 0, 0.99609375), 2952*da0073e9SAndroid Build Coastguard Worker (1, 2, 1.0078125), 2953*da0073e9SAndroid Build Coastguard Worker (-1, -2, -1.0078125), 2954*da0073e9SAndroid Build Coastguard Worker (-1, 0, -0.99609375), 2955*da0073e9SAndroid Build Coastguard Worker (2, -1, 1.9921875), 2956*da0073e9SAndroid Build Coastguard Worker (2, 1, 1.9921875), 2957*da0073e9SAndroid Build Coastguard Worker (20, 3000, 20.125), 2958*da0073e9SAndroid Build Coastguard Worker (20, -3000, 19.875), 2959*da0073e9SAndroid Build Coastguard Worker (3000, -20, 2992.0), 2960*da0073e9SAndroid Build Coastguard Worker (-3000, 20, -2992.0), 2961*da0073e9SAndroid Build Coastguard Worker (65536, 0, 65280.0), 2962*da0073e9SAndroid Build Coastguard Worker (65536, inf, 66048.0), 2963*da0073e9SAndroid Build Coastguard Worker (-65536, 0, -65280.0), 2964*da0073e9SAndroid Build Coastguard Worker (-65536, -inf, -66048.0), 2965*da0073e9SAndroid Build Coastguard Worker (nan, 0, nan), 2966*da0073e9SAndroid Build Coastguard Worker (0, nan, nan), 2967*da0073e9SAndroid Build Coastguard Worker (nan, nan, nan), 2968*da0073e9SAndroid Build Coastguard Worker (nan, inf, nan), 2969*da0073e9SAndroid Build Coastguard Worker (inf, nan, nan), 2970*da0073e9SAndroid Build Coastguard Worker (inf, -inf, 3.3895313892515355e38), 2971*da0073e9SAndroid Build Coastguard Worker (-inf, inf, -3.3895313892515355e38), 2972*da0073e9SAndroid Build Coastguard Worker (inf, 0, 3.3895313892515355e38), 2973*da0073e9SAndroid Build Coastguard Worker (0, inf, 9.183549615799121e-41), 2974*da0073e9SAndroid Build Coastguard Worker (-inf, 0, -3.3895313892515355e38), 2975*da0073e9SAndroid 
Build Coastguard Worker (0, -inf, -9.183549615799121e-41), 2976*da0073e9SAndroid Build Coastguard Worker ) 2977*da0073e9SAndroid Build Coastguard Worker 2978*da0073e9SAndroid Build Coastguard Worker for from_v, to_v, expected in cases: 2979*da0073e9SAndroid Build Coastguard Worker from_t = torch.tensor([from_v], device=device, dtype=dtype) 2980*da0073e9SAndroid Build Coastguard Worker to_t = torch.tensor([to_v], device=device, dtype=dtype) 2981*da0073e9SAndroid Build Coastguard Worker actual = torch.nextafter(from_t, to_t).item() 2982*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, atol=0, rtol=0) 2983*da0073e9SAndroid Build Coastguard Worker 2984*da0073e9SAndroid Build Coastguard Worker def _test_cop(self, torchfn, mathfn, dtype, device): 2985*da0073e9SAndroid Build Coastguard Worker def reference_implementation(res2): 2986*da0073e9SAndroid Build Coastguard Worker for i, j in iter_indices(sm1): 2987*da0073e9SAndroid Build Coastguard Worker idx1d = i * sm1.size(0) + j 2988*da0073e9SAndroid Build Coastguard Worker res2[i, j] = mathfn(sm1[i, j], sm2[idx1d]) 2989*da0073e9SAndroid Build Coastguard Worker return res2 2990*da0073e9SAndroid Build Coastguard Worker 2991*da0073e9SAndroid Build Coastguard Worker # contiguous 2992*da0073e9SAndroid Build Coastguard Worker m1 = torch.randn(10, 10, 10, dtype=dtype, device=device) 2993*da0073e9SAndroid Build Coastguard Worker m2 = torch.randn(10, 10 * 10, dtype=dtype, device=device) 2994*da0073e9SAndroid Build Coastguard Worker sm1 = m1[4] 2995*da0073e9SAndroid Build Coastguard Worker sm2 = m2[4] 2996*da0073e9SAndroid Build Coastguard Worker 2997*da0073e9SAndroid Build Coastguard Worker res1 = torchfn(sm1, sm2.view(10, 10)) 2998*da0073e9SAndroid Build Coastguard Worker res2 = reference_implementation(res1.clone()) 2999*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 3000*da0073e9SAndroid Build Coastguard Worker 3001*da0073e9SAndroid Build Coastguard Worker # non-contiguous 
3002*da0073e9SAndroid Build Coastguard Worker m1 = torch.randn(10, 10, 10, dtype=dtype, device=device) 3003*da0073e9SAndroid Build Coastguard Worker m2 = torch.randn(10 * 10, 10 * 10, dtype=dtype, device=device) 3004*da0073e9SAndroid Build Coastguard Worker sm1 = m1[:, 4] 3005*da0073e9SAndroid Build Coastguard Worker sm2 = m2[:, 4] 3006*da0073e9SAndroid Build Coastguard Worker # view as sm1.size() 3007*da0073e9SAndroid Build Coastguard Worker sm2.set_( 3008*da0073e9SAndroid Build Coastguard Worker sm2.storage(), 3009*da0073e9SAndroid Build Coastguard Worker sm2.storage_offset(), 3010*da0073e9SAndroid Build Coastguard Worker sm1.size(), 3011*da0073e9SAndroid Build Coastguard Worker (sm2.stride()[0] * 10, sm2.stride()[0]), 3012*da0073e9SAndroid Build Coastguard Worker ) 3013*da0073e9SAndroid Build Coastguard Worker res1 = torchfn(sm1, sm2) 3014*da0073e9SAndroid Build Coastguard Worker # reference_implementation assumes 1-d sm2 3015*da0073e9SAndroid Build Coastguard Worker sm2.set_( 3016*da0073e9SAndroid Build Coastguard Worker sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride() 3017*da0073e9SAndroid Build Coastguard Worker ) 3018*da0073e9SAndroid Build Coastguard Worker res2 = reference_implementation(res1.clone()) 3019*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 3020*da0073e9SAndroid Build Coastguard Worker 3021*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3022*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 3023*da0073e9SAndroid Build Coastguard Worker def test_cdiv(self, device, dtype): 3024*da0073e9SAndroid Build Coastguard Worker self._test_cop(torch.div, operator.truediv, dtype, device) 3025*da0073e9SAndroid Build Coastguard Worker 3026*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3027*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 3028*da0073e9SAndroid Build Coastguard Worker def test_cremainder(self, device, dtype): 3029*da0073e9SAndroid Build Coastguard Worker 
self._test_cop(torch.remainder, operator.mod, dtype, device) 3030*da0073e9SAndroid Build Coastguard Worker 3031*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3032*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 3033*da0073e9SAndroid Build Coastguard Worker def test_cmul(self, device, dtype): 3034*da0073e9SAndroid Build Coastguard Worker self._test_cop(torch.mul, operator.mul, dtype, device) 3035*da0073e9SAndroid Build Coastguard Worker 3036*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3037*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 3038*da0073e9SAndroid Build Coastguard Worker def test_cpow(self, device, dtype): 3039*da0073e9SAndroid Build Coastguard Worker self._test_cop( 3040*da0073e9SAndroid Build Coastguard Worker torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y), dtype, device 3041*da0073e9SAndroid Build Coastguard Worker ) 3042*da0073e9SAndroid Build Coastguard Worker 3043*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3044*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) 3045*da0073e9SAndroid Build Coastguard Worker def test_floor_divide_zero(self, device, dtype): 3046*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([0, 1], dtype=dtype, device=device) 3047*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([0, 1], dtype=dtype, device=device) 3048*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"): 3049*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, "floor_divide"): 3050*da0073e9SAndroid Build Coastguard Worker a // b 3051*da0073e9SAndroid Build Coastguard Worker 3052*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 3053*da0073e9SAndroid Build Coastguard Worker def test_muldiv_scalar(self, device, dtype): 3054*da0073e9SAndroid Build Coastguard Worker x = make_tensor((10, 
3), dtype=dtype, device=device, low=None, high=None) 3055*da0073e9SAndroid Build Coastguard Worker s = make_tensor((1,), dtype=dtype, device="cpu", low=None, high=None).item() 3056*da0073e9SAndroid Build Coastguard Worker y = torch.full_like(x, s) 3057*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x * s, x * y) 3058*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s * x, y * x) 3059*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x / s, x / y) 3060*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s / x, y / x) 3061*da0073e9SAndroid Build Coastguard Worker 3062*da0073e9SAndroid Build Coastguard Worker # TODO: update make_tensor to support extremal additions and remove this in favor of make_tensor 3063*da0073e9SAndroid Build Coastguard Worker def _generate_input(self, shape, dtype, device, with_extremal): 3064*da0073e9SAndroid Build Coastguard Worker if shape == (): 3065*da0073e9SAndroid Build Coastguard Worker x = torch.tensor((), dtype=dtype, device=device) 3066*da0073e9SAndroid Build Coastguard Worker else: 3067*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point or dtype.is_complex: 3068*da0073e9SAndroid Build Coastguard Worker # work around torch.randn not being implemented for bfloat16 3069*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 3070*da0073e9SAndroid Build Coastguard Worker x = torch.randn(*shape, device=device) * random.randint(30, 100) 3071*da0073e9SAndroid Build Coastguard Worker x = x.to(torch.bfloat16) 3072*da0073e9SAndroid Build Coastguard Worker else: 3073*da0073e9SAndroid Build Coastguard Worker x = torch.randn( 3074*da0073e9SAndroid Build Coastguard Worker *shape, dtype=dtype, device=device 3075*da0073e9SAndroid Build Coastguard Worker ) * random.randint(30, 100) 3076*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = 0 3077*da0073e9SAndroid Build Coastguard Worker if with_extremal and dtype.is_floating_point: 3078*da0073e9SAndroid Build 
Coastguard Worker # Use extremal values 3079*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = float("nan") 3080*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = float("inf") 3081*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = float("-inf") 3082*da0073e9SAndroid Build Coastguard Worker elif with_extremal and dtype.is_complex: 3083*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = complex("nan") 3084*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = complex("inf") 3085*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = complex("-inf") 3086*da0073e9SAndroid Build Coastguard Worker elif dtype == torch.bool: 3087*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(shape, dtype=dtype, device=device) 3088*da0073e9SAndroid Build Coastguard Worker x[torch.randn(*shape) > 0.5] = True 3089*da0073e9SAndroid Build Coastguard Worker else: 3090*da0073e9SAndroid Build Coastguard Worker x = torch.randint(15, 100, shape, dtype=dtype, device=device) 3091*da0073e9SAndroid Build Coastguard Worker 3092*da0073e9SAndroid Build Coastguard Worker return x 3093*da0073e9SAndroid Build Coastguard Worker 3094*da0073e9SAndroid Build Coastguard Worker @dtypes( 3095*da0073e9SAndroid Build Coastguard Worker *tuple( 3096*da0073e9SAndroid Build Coastguard Worker itertools.combinations_with_replacement( 3097*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 2 3098*da0073e9SAndroid Build Coastguard Worker ) 3099*da0073e9SAndroid Build Coastguard Worker ) 3100*da0073e9SAndroid Build Coastguard Worker ) 3101*da0073e9SAndroid Build Coastguard Worker def test_comparison_ops_type_promotion_and_broadcasting(self, device, dtypes): 3102*da0073e9SAndroid Build Coastguard Worker # issue #42660 3103*da0073e9SAndroid Build Coastguard Worker # testing all combinations of broadcasting and type promotion 
3104*da0073e9SAndroid Build Coastguard Worker # with a range of dtypes and input shapes, and with extremal values 3105*da0073e9SAndroid Build Coastguard Worker def compare_with_numpy_bin_op(torch_fn, np_fn, x, y, out=None): 3106*da0073e9SAndroid Build Coastguard Worker # working around the fact that numpy doesn't support bfloat16 3107*da0073e9SAndroid Build Coastguard Worker # by letting numpy treat them as float32's 3108*da0073e9SAndroid Build Coastguard Worker x_np = x if x.dtype != torch.bfloat16 else x.to(torch.float32) 3109*da0073e9SAndroid Build Coastguard Worker y_np = ( 3110*da0073e9SAndroid Build Coastguard Worker y.cpu().numpy() 3111*da0073e9SAndroid Build Coastguard Worker if y.dtype != torch.bfloat16 3112*da0073e9SAndroid Build Coastguard Worker else y.to(torch.float32).cpu().numpy() 3113*da0073e9SAndroid Build Coastguard Worker ) 3114*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy( 3115*da0073e9SAndroid Build Coastguard Worker lambda inp: torch_fn(inp, y, out=out) if out else torch_fn(inp, y), 3116*da0073e9SAndroid Build Coastguard Worker lambda inp: np_fn(inp, y_np, out=out) if out else np_fn(inp, y_np), 3117*da0073e9SAndroid Build Coastguard Worker x_np, 3118*da0073e9SAndroid Build Coastguard Worker ) 3119*da0073e9SAndroid Build Coastguard Worker 3120*da0073e9SAndroid Build Coastguard Worker complex_op_denylist = [ 3121*da0073e9SAndroid Build Coastguard Worker torch.lt, 3122*da0073e9SAndroid Build Coastguard Worker torch.le, 3123*da0073e9SAndroid Build Coastguard Worker torch.gt, 3124*da0073e9SAndroid Build Coastguard Worker torch.ge, 3125*da0073e9SAndroid Build Coastguard Worker ] # complex not supported 3126*da0073e9SAndroid Build Coastguard Worker input_sizes = [(1,), (10,), (10, 1), (1, 10), (4, 10), (64, 10), (12, 3)] 3127*da0073e9SAndroid Build Coastguard Worker op_pairs = [ 3128*da0073e9SAndroid Build Coastguard Worker (torch.lt, np.less), 3129*da0073e9SAndroid Build Coastguard Worker (torch.le, np.less_equal), 
3130*da0073e9SAndroid Build Coastguard Worker (torch.gt, np.greater), 3131*da0073e9SAndroid Build Coastguard Worker (torch.ge, np.greater_equal), 3132*da0073e9SAndroid Build Coastguard Worker (torch.eq, np.equal), 3133*da0073e9SAndroid Build Coastguard Worker (torch.ne, np.not_equal), 3134*da0073e9SAndroid Build Coastguard Worker (torch.logical_and, np.logical_and), 3135*da0073e9SAndroid Build Coastguard Worker (torch.logical_or, np.logical_or), 3136*da0073e9SAndroid Build Coastguard Worker (torch.logical_xor, np.logical_xor), 3137*da0073e9SAndroid Build Coastguard Worker ] 3138*da0073e9SAndroid Build Coastguard Worker 3139*da0073e9SAndroid Build Coastguard Worker for size1 in input_sizes: 3140*da0073e9SAndroid Build Coastguard Worker size2 = (2,) + size1 # perform broadcasting 3141*da0073e9SAndroid Build Coastguard Worker for with_extremal in [False, True]: 3142*da0073e9SAndroid Build Coastguard Worker a = self._generate_input(size1, dtypes[0], device, with_extremal) 3143*da0073e9SAndroid Build Coastguard Worker b = self._generate_input(size2, dtypes[1], device, with_extremal) 3144*da0073e9SAndroid Build Coastguard Worker for torch_op, numpy_op in op_pairs: 3145*da0073e9SAndroid Build Coastguard Worker if ( 3146*da0073e9SAndroid Build Coastguard Worker dtypes[0].is_complex or dtypes[1].is_complex 3147*da0073e9SAndroid Build Coastguard Worker ) and torch_op in complex_op_denylist: 3148*da0073e9SAndroid Build Coastguard Worker continue 3149*da0073e9SAndroid Build Coastguard Worker # functional version of op 3150*da0073e9SAndroid Build Coastguard Worker compare_with_numpy_bin_op(torch_op, numpy_op, a, b) 3151*da0073e9SAndroid Build Coastguard Worker 3152*da0073e9SAndroid Build Coastguard Worker # functional comparison ops always return bool tensors 3153*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch_op(a, b).dtype, torch.bool) 3154*da0073e9SAndroid Build Coastguard Worker 3155*da0073e9SAndroid Build Coastguard Worker # out version of op 
3156*da0073e9SAndroid Build Coastguard Worker out = torch.zeros( 3157*da0073e9SAndroid Build Coastguard Worker 1, dtype=torch.complex128 3158*da0073e9SAndroid Build Coastguard Worker ) # all casts to complex128 are safe 3159*da0073e9SAndroid Build Coastguard Worker compare_with_numpy_bin_op(torch_op, numpy_op, a, b, out=out) 3160*da0073e9SAndroid Build Coastguard Worker 3161*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3162*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int8, torch.int16, torch.int32, torch.int64) 3163*da0073e9SAndroid Build Coastguard Worker def test_signed_shift(self, device, dtype): 3164*da0073e9SAndroid Build Coastguard Worker "Ensure that signed integer bit shifting works as expected." 3165*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([-10, 10], device=device, dtype=dtype) # [11...1110110, 1010] 3166*da0073e9SAndroid Build Coastguard Worker expected_l = torch.tensor( 3167*da0073e9SAndroid Build Coastguard Worker [-40, 40], device=device, dtype=dtype 3168*da0073e9SAndroid Build Coastguard Worker ) # [11...11011000, 101000] 3169*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a << 2, expected_l) 3170*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(lambda x: x << 2, lambda x: np.left_shift(x, 2), a) 3171*da0073e9SAndroid Build Coastguard Worker expected_r = torch.tensor( 3172*da0073e9SAndroid Build Coastguard Worker [-5, 5], device=device, dtype=dtype 3173*da0073e9SAndroid Build Coastguard Worker ) # [1111...111011, 101] 3174*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a >> 1, expected_r) 3175*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy(lambda x: x >> 1, lambda x: np.right_shift(x, 1), a) 3176*da0073e9SAndroid Build Coastguard Worker 3177*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3178*da0073e9SAndroid Build Coastguard Worker @dtypes(*get_all_int_dtypes()) 3179*da0073e9SAndroid Build Coastguard Worker def 
test_shift_limits(self, device, dtype): 3180*da0073e9SAndroid Build Coastguard Worker "Ensure that integer bit shifting works as expected with out-of-limits shift values." 3181*da0073e9SAndroid Build Coastguard Worker # Issue #70904 3182*da0073e9SAndroid Build Coastguard Worker iinfo = torch.iinfo(dtype) 3183*da0073e9SAndroid Build Coastguard Worker bits = iinfo.bits 3184*da0073e9SAndroid Build Coastguard Worker low = iinfo.min 3185*da0073e9SAndroid Build Coastguard Worker high = iinfo.max 3186*da0073e9SAndroid Build Coastguard Worker exact_dtype = ( 3187*da0073e9SAndroid Build Coastguard Worker dtype != torch.uint8 3188*da0073e9SAndroid Build Coastguard Worker ) # numpy changes dtype from uint8 to int16 for some out-of-limits shift values 3189*da0073e9SAndroid Build Coastguard Worker for input in ( 3190*da0073e9SAndroid Build Coastguard Worker torch.tensor( 3191*da0073e9SAndroid Build Coastguard Worker [-1, 0, 1], device=device, dtype=dtype 3192*da0073e9SAndroid Build Coastguard Worker ), # small for non-vectorized operation 3193*da0073e9SAndroid Build Coastguard Worker torch.tensor( 3194*da0073e9SAndroid Build Coastguard Worker [low, high], device=device, dtype=dtype 3195*da0073e9SAndroid Build Coastguard Worker ), # small for non-vectorized operation 3196*da0073e9SAndroid Build Coastguard Worker make_tensor( 3197*da0073e9SAndroid Build Coastguard Worker (64, 64, 64), low=low, high=high, device=device, dtype=dtype 3198*da0073e9SAndroid Build Coastguard Worker ), # large for vectorized operation 3199*da0073e9SAndroid Build Coastguard Worker ): 3200*da0073e9SAndroid Build Coastguard Worker shift_left_expected = torch.zeros_like(input) 3201*da0073e9SAndroid Build Coastguard Worker shift_right_expected = torch.clamp(input, -1, 0) 3202*da0073e9SAndroid Build Coastguard Worker for shift in chain(range(-100, -1), range(bits, 100)): 3203*da0073e9SAndroid Build Coastguard Worker shift_left = input << shift 3204*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(shift_left, shift_left_expected, msg=f"<< {shift}") 3205*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy( 3206*da0073e9SAndroid Build Coastguard Worker lambda x: x << shift, 3207*da0073e9SAndroid Build Coastguard Worker lambda x: np.left_shift(x, shift), 3208*da0073e9SAndroid Build Coastguard Worker input, 3209*da0073e9SAndroid Build Coastguard Worker exact_dtype=exact_dtype, 3210*da0073e9SAndroid Build Coastguard Worker msg=f"<< {shift}", 3211*da0073e9SAndroid Build Coastguard Worker ) 3212*da0073e9SAndroid Build Coastguard Worker shift_right = input >> shift 3213*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shift_right, shift_right_expected, msg=f">> {shift}") 3214*da0073e9SAndroid Build Coastguard Worker self.compare_with_numpy( 3215*da0073e9SAndroid Build Coastguard Worker lambda x: x >> shift, 3216*da0073e9SAndroid Build Coastguard Worker lambda x: np.right_shift(x, shift), 3217*da0073e9SAndroid Build Coastguard Worker input, 3218*da0073e9SAndroid Build Coastguard Worker exact_dtype=exact_dtype, 3219*da0073e9SAndroid Build Coastguard Worker msg=f">> {shift}", 3220*da0073e9SAndroid Build Coastguard Worker ) 3221*da0073e9SAndroid Build Coastguard Worker 3222*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3223*da0073e9SAndroid Build Coastguard Worker @dtypes( 3224*da0073e9SAndroid Build Coastguard Worker *list( 3225*da0073e9SAndroid Build Coastguard Worker product( 3226*da0073e9SAndroid Build Coastguard Worker all_types_and(torch.half, torch.bfloat16, torch.bool), 3227*da0073e9SAndroid Build Coastguard Worker all_types_and(torch.half, torch.bfloat16, torch.bool), 3228*da0073e9SAndroid Build Coastguard Worker ) 3229*da0073e9SAndroid Build Coastguard Worker ) 3230*da0073e9SAndroid Build Coastguard Worker ) 3231*da0073e9SAndroid Build Coastguard Worker def test_heaviside(self, device, dtypes): 3232*da0073e9SAndroid Build Coastguard Worker input_dtype = dtypes[0] 3233*da0073e9SAndroid Build Coastguard 
Worker values_dtype = dtypes[1] 3234*da0073e9SAndroid Build Coastguard Worker 3235*da0073e9SAndroid Build Coastguard Worker rng = np.random.default_rng() 3236*da0073e9SAndroid Build Coastguard Worker input = np.array( 3237*da0073e9SAndroid Build Coastguard Worker rng.integers(-10, 10, size=10), 3238*da0073e9SAndroid Build Coastguard Worker dtype=torch_to_numpy_dtype_dict[ 3239*da0073e9SAndroid Build Coastguard Worker input_dtype if (input_dtype != torch.bfloat16) else torch.float64 3240*da0073e9SAndroid Build Coastguard Worker ], 3241*da0073e9SAndroid Build Coastguard Worker ) 3242*da0073e9SAndroid Build Coastguard Worker input[0] = input[3] = input[7] = 0 3243*da0073e9SAndroid Build Coastguard Worker values = np.array( 3244*da0073e9SAndroid Build Coastguard Worker rng.integers(-10, 10, size=10), 3245*da0073e9SAndroid Build Coastguard Worker dtype=torch_to_numpy_dtype_dict[ 3246*da0073e9SAndroid Build Coastguard Worker values_dtype if (values_dtype != torch.bfloat16) else torch.float64 3247*da0073e9SAndroid Build Coastguard Worker ], 3248*da0073e9SAndroid Build Coastguard Worker ) 3249*da0073e9SAndroid Build Coastguard Worker np_result = torch.from_numpy(np.heaviside(input, values)).to( 3250*da0073e9SAndroid Build Coastguard Worker device=device, dtype=input_dtype 3251*da0073e9SAndroid Build Coastguard Worker ) 3252*da0073e9SAndroid Build Coastguard Worker 3253*da0073e9SAndroid Build Coastguard Worker input = torch.from_numpy(input).to(device=device, dtype=input_dtype) 3254*da0073e9SAndroid Build Coastguard Worker values = torch.from_numpy(values).to(device=device, dtype=values_dtype) 3255*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(input) 3256*da0073e9SAndroid Build Coastguard Worker 3257*da0073e9SAndroid Build Coastguard Worker if input_dtype == values_dtype: 3258*da0073e9SAndroid Build Coastguard Worker torch_result = torch.heaviside(input, values) 3259*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_result, torch_result) 
3260*da0073e9SAndroid Build Coastguard Worker 3261*da0073e9SAndroid Build Coastguard Worker torch_result = input.heaviside(values) 3262*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_result, torch_result) 3263*da0073e9SAndroid Build Coastguard Worker 3264*da0073e9SAndroid Build Coastguard Worker torch.heaviside(input, values, out=out) 3265*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_result, out) 3266*da0073e9SAndroid Build Coastguard Worker 3267*da0073e9SAndroid Build Coastguard Worker input.heaviside_(values) 3268*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_result, input) 3269*da0073e9SAndroid Build Coastguard Worker else: 3270*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3271*da0073e9SAndroid Build Coastguard Worker RuntimeError, 3272*da0073e9SAndroid Build Coastguard Worker "heaviside is not yet implemented for tensors with different dtypes.", 3273*da0073e9SAndroid Build Coastguard Worker ): 3274*da0073e9SAndroid Build Coastguard Worker torch.heaviside(input, values) 3275*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3276*da0073e9SAndroid Build Coastguard Worker RuntimeError, 3277*da0073e9SAndroid Build Coastguard Worker "heaviside is not yet implemented for tensors with different dtypes.", 3278*da0073e9SAndroid Build Coastguard Worker ): 3279*da0073e9SAndroid Build Coastguard Worker input.heaviside(values) 3280*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3281*da0073e9SAndroid Build Coastguard Worker RuntimeError, 3282*da0073e9SAndroid Build Coastguard Worker "heaviside is not yet implemented for tensors with different dtypes.", 3283*da0073e9SAndroid Build Coastguard Worker ): 3284*da0073e9SAndroid Build Coastguard Worker torch.heaviside(input, values, out=out) 3285*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3286*da0073e9SAndroid Build Coastguard Worker RuntimeError, 3287*da0073e9SAndroid Build Coastguard Worker 
"heaviside is not yet implemented for tensors with different dtypes.", 3288*da0073e9SAndroid Build Coastguard Worker ): 3289*da0073e9SAndroid Build Coastguard Worker input.heaviside_(values) 3290*da0073e9SAndroid Build Coastguard Worker 3291*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 3292*da0073e9SAndroid Build Coastguard Worker def test_heaviside_cross_device(self, device): 3293*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([-9, 5, 0, 6, -2, 2], device=device) 3294*da0073e9SAndroid Build Coastguard Worker y = torch.tensor(0) 3295*da0073e9SAndroid Build Coastguard Worker result = torch.heaviside(x, y) 3296*da0073e9SAndroid Build Coastguard Worker expect = torch.tensor([0, 1, 0, 1, 0, 1], device=device) 3297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expect) 3298*da0073e9SAndroid Build Coastguard Worker 3299*da0073e9SAndroid Build Coastguard Worker result = torch.heaviside(y, x) 3300*da0073e9SAndroid Build Coastguard Worker expect = torch.tensor([-9, 5, 0, 6, -2, 2], device=device) 3301*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expect) 3302*da0073e9SAndroid Build Coastguard Worker 3303*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([-9, 5, 0, 6, -2, 2]) 3304*da0073e9SAndroid Build Coastguard Worker y = torch.tensor(0, device=device) 3305*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3306*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Expected all tensors to be on the same device" 3307*da0073e9SAndroid Build Coastguard Worker ): 3308*da0073e9SAndroid Build Coastguard Worker torch.heaviside(x, y) 3309*da0073e9SAndroid Build Coastguard Worker 3310*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3311*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Expected all tensors to be on the same device" 3312*da0073e9SAndroid Build Coastguard Worker ): 3313*da0073e9SAndroid Build Coastguard Worker torch.heaviside(y, x) 3314*da0073e9SAndroid 
Build Coastguard Worker 3315*da0073e9SAndroid Build Coastguard Worker @dtypes(*list(product(complex_types(), complex_types()))) 3316*da0073e9SAndroid Build Coastguard Worker def test_heaviside_complex(self, device, dtypes): 3317*da0073e9SAndroid Build Coastguard Worker input_dtype = dtypes[0] 3318*da0073e9SAndroid Build Coastguard Worker values_dtype = dtypes[1] 3319*da0073e9SAndroid Build Coastguard Worker 3320*da0073e9SAndroid Build Coastguard Worker data = (complex(0, -6), complex(-1, 3), complex(1, 1)) 3321*da0073e9SAndroid Build Coastguard Worker input = torch.tensor(data, device=device, dtype=input_dtype) 3322*da0073e9SAndroid Build Coastguard Worker values = torch.tensor(data, device=device, dtype=values_dtype) 3323*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(input) 3324*da0073e9SAndroid Build Coastguard Worker real = input.real 3325*da0073e9SAndroid Build Coastguard Worker 3326*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3327*da0073e9SAndroid Build Coastguard Worker RuntimeError, "heaviside is not yet implemented for complex tensors." 3328*da0073e9SAndroid Build Coastguard Worker ): 3329*da0073e9SAndroid Build Coastguard Worker torch.heaviside(input, real) 3330*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3331*da0073e9SAndroid Build Coastguard Worker RuntimeError, "heaviside is not yet implemented for complex tensors." 3332*da0073e9SAndroid Build Coastguard Worker ): 3333*da0073e9SAndroid Build Coastguard Worker real.heaviside(values) 3334*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3335*da0073e9SAndroid Build Coastguard Worker RuntimeError, "heaviside is not yet implemented for complex tensors." 
3336*da0073e9SAndroid Build Coastguard Worker ): 3337*da0073e9SAndroid Build Coastguard Worker input.heaviside_(values) 3338*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 3339*da0073e9SAndroid Build Coastguard Worker RuntimeError, "heaviside is not yet implemented for complex tensors." 3340*da0073e9SAndroid Build Coastguard Worker ): 3341*da0073e9SAndroid Build Coastguard Worker torch.heaviside(real, real, out=out) 3342*da0073e9SAndroid Build Coastguard Worker 3343*da0073e9SAndroid Build Coastguard Worker def _test_logical(self, device, dtypes, op, a_, b_, expected_res_): 3344*da0073e9SAndroid Build Coastguard Worker expected_res = torch.tensor(expected_res_, dtype=dtypes[0], device=device) 3345*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(a_, dtype=dtypes[0], device=device) 3346*da0073e9SAndroid Build Coastguard Worker b = torch.tensor(b_, dtype=dtypes[1], device=device) 3347*da0073e9SAndroid Build Coastguard Worker 3348*da0073e9SAndroid Build Coastguard Worker # new tensor 3349*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_res.bool(), getattr(a, op)(b)) 3350*da0073e9SAndroid Build Coastguard Worker # out 3351*da0073e9SAndroid Build Coastguard Worker c = torch.empty(0, dtype=torch.bool, device=device) 3352*da0073e9SAndroid Build Coastguard Worker getattr(torch, op)(a, b, out=c) 3353*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_res.bool(), c) 3354*da0073e9SAndroid Build Coastguard Worker 3355*da0073e9SAndroid Build Coastguard Worker getattr(a, op + "_")(b) 3356*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_res, a) 3357*da0073e9SAndroid Build Coastguard Worker 3358*da0073e9SAndroid Build Coastguard Worker @dtypes( 3359*da0073e9SAndroid Build Coastguard Worker *product( 3360*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 3361*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and(torch.half, 
torch.bfloat16, torch.bool), 3362*da0073e9SAndroid Build Coastguard Worker ) 3363*da0073e9SAndroid Build Coastguard Worker ) 3364*da0073e9SAndroid Build Coastguard Worker def test_logical_xor(self, device, dtypes): 3365*da0073e9SAndroid Build Coastguard Worker self._test_logical( 3366*da0073e9SAndroid Build Coastguard Worker device, dtypes, "logical_xor", [10, 0, 1, 0], [1, 0, 0, 10], [0, 0, 1, 1] 3367*da0073e9SAndroid Build Coastguard Worker ) 3368*da0073e9SAndroid Build Coastguard Worker 3369*da0073e9SAndroid Build Coastguard Worker @dtypes( 3370*da0073e9SAndroid Build Coastguard Worker *product( 3371*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 3372*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 3373*da0073e9SAndroid Build Coastguard Worker ) 3374*da0073e9SAndroid Build Coastguard Worker ) 3375*da0073e9SAndroid Build Coastguard Worker def test_logical_and(self, device, dtypes): 3376*da0073e9SAndroid Build Coastguard Worker self._test_logical( 3377*da0073e9SAndroid Build Coastguard Worker device, dtypes, "logical_and", [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 0, 0] 3378*da0073e9SAndroid Build Coastguard Worker ) 3379*da0073e9SAndroid Build Coastguard Worker 3380*da0073e9SAndroid Build Coastguard Worker @dtypes( 3381*da0073e9SAndroid Build Coastguard Worker *product( 3382*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 3383*da0073e9SAndroid Build Coastguard Worker all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 3384*da0073e9SAndroid Build Coastguard Worker ) 3385*da0073e9SAndroid Build Coastguard Worker ) 3386*da0073e9SAndroid Build Coastguard Worker def test_logical_or(self, device, dtypes): 3387*da0073e9SAndroid Build Coastguard Worker self._test_logical( 3388*da0073e9SAndroid Build Coastguard Worker device, dtypes, "logical_or", [10, 0, 1, 0], [1, 0, 0, 10], [1, 0, 
1, 1] 3389*da0073e9SAndroid Build Coastguard Worker ) 3390*da0073e9SAndroid Build Coastguard Worker 3391*da0073e9SAndroid Build Coastguard Worker def test_remainder_overflow(self, device): 3392*da0073e9SAndroid Build Coastguard Worker # Check Integer Overflows 3393*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(23500, dtype=torch.int64, device=device) 3394*da0073e9SAndroid Build Coastguard Worker q = 392486996410368 3395*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x % q, x) 3396*da0073e9SAndroid Build Coastguard Worker self.assertEqual(-x % q, q - x) 3397*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x % -q, x - q) 3398*da0073e9SAndroid Build Coastguard Worker self.assertEqual(-x % -q, -x) 3399*da0073e9SAndroid Build Coastguard Worker 3400*da0073e9SAndroid Build Coastguard Worker def test_rpow(self, device): 3401*da0073e9SAndroid Build Coastguard Worker m = torch.randn(10, 10, device=device) 3402*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.pow(2, m), 2**m) 3403*da0073e9SAndroid Build Coastguard Worker 3404*da0073e9SAndroid Build Coastguard Worker # test with scalar 3405*da0073e9SAndroid Build Coastguard Worker m = torch.randn(1, device=device).squeeze() 3406*da0073e9SAndroid Build Coastguard Worker assert m.dim() == 0, "m is intentionally a scalar" 3407*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.pow(2, m), 2**m) 3408*da0073e9SAndroid Build Coastguard Worker 3409*da0073e9SAndroid Build Coastguard Worker def test_ldexp(self, device): 3410*da0073e9SAndroid Build Coastguard Worker # random values 3411*da0073e9SAndroid Build Coastguard Worker mantissas = torch.randn(64, device=device) 3412*da0073e9SAndroid Build Coastguard Worker exponents = torch.randint(-31, 31, (64,), device=device, dtype=torch.int32) 3413*da0073e9SAndroid Build Coastguard Worker 3414*da0073e9SAndroid Build Coastguard Worker # basic test 3415*da0073e9SAndroid Build Coastguard Worker np_outcome = 
np.ldexp(mantissas.cpu().numpy(), exponents.cpu().numpy()) 3416*da0073e9SAndroid Build Coastguard Worker pt_outcome_1 = torch.ldexp(mantissas, exponents) 3417*da0073e9SAndroid Build Coastguard Worker pt_outcome_2 = mantissas.ldexp(exponents) 3418*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_outcome, pt_outcome_1.cpu()) 3419*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_outcome, pt_outcome_2.cpu()) 3420*da0073e9SAndroid Build Coastguard Worker mantissas.ldexp_(exponents) 3421*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_outcome, mantissas.cpu()) 3422*da0073e9SAndroid Build Coastguard Worker 3423*da0073e9SAndroid Build Coastguard Worker # test bounds 3424*da0073e9SAndroid Build Coastguard Worker mantissas = torch.tensor( 3425*da0073e9SAndroid Build Coastguard Worker [float("inf"), float("-inf"), float("inf"), float("nan")], device=device 3426*da0073e9SAndroid Build Coastguard Worker ) 3427*da0073e9SAndroid Build Coastguard Worker exponents = torch.randint(0, 31, (4,), device=device, dtype=torch.int32) 3428*da0073e9SAndroid Build Coastguard Worker np_outcome = np.ldexp(mantissas.cpu().numpy(), exponents.cpu().numpy()) 3429*da0073e9SAndroid Build Coastguard Worker pt_outcome = torch.ldexp(mantissas, exponents) 3430*da0073e9SAndroid Build Coastguard Worker self.assertEqual(np_outcome, pt_outcome.cpu()) 3431*da0073e9SAndroid Build Coastguard Worker 3432*da0073e9SAndroid Build Coastguard Worker # test half dtype behavior 3433*da0073e9SAndroid Build Coastguard Worker mantissas = torch.randn(64, device=device, dtype=torch.half) 3434*da0073e9SAndroid Build Coastguard Worker exponents = torch.randint(-5, 5, (64,), device=device) 3435*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ldexp(mantissas, exponents).dtype, torch.half) 3436*da0073e9SAndroid Build Coastguard Worker 3437*da0073e9SAndroid Build Coastguard Worker # test float64 computation 3438*da0073e9SAndroid Build Coastguard Worker mantissas = 
torch.tensor([1], dtype=torch.float64, device=device) 3439*da0073e9SAndroid Build Coastguard Worker exponents = torch.tensor([128], dtype=torch.int64, device=device) 3440*da0073e9SAndroid Build Coastguard Worker expected = torch.pow( 3441*da0073e9SAndroid Build Coastguard Worker torch.full((1,), 2, device=device, dtype=torch.float64), 128 3442*da0073e9SAndroid Build Coastguard Worker ) 3443*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ldexp(mantissas, exponents), expected) 3444*da0073e9SAndroid Build Coastguard Worker 3445*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) 3446*da0073e9SAndroid Build Coastguard Worker def test_lerp(self, device, dtype): 3447*da0073e9SAndroid Build Coastguard Worker start_end_weight_shapes = [(), (5,), (5, 5)] 3448*da0073e9SAndroid Build Coastguard Worker for shapes in product( 3449*da0073e9SAndroid Build Coastguard Worker start_end_weight_shapes, start_end_weight_shapes, start_end_weight_shapes 3450*da0073e9SAndroid Build Coastguard Worker ): 3451*da0073e9SAndroid Build Coastguard Worker start = torch.randn(shapes[0], device=device, dtype=dtype) 3452*da0073e9SAndroid Build Coastguard Worker end = torch.randn(shapes[1], device=device, dtype=dtype) 3453*da0073e9SAndroid Build Coastguard Worker 3454*da0073e9SAndroid Build Coastguard Worker # Tensor weights 3455*da0073e9SAndroid Build Coastguard Worker weights = [ 3456*da0073e9SAndroid Build Coastguard Worker torch.randn(shapes[2], device=device, dtype=dtype), 3457*da0073e9SAndroid Build Coastguard Worker random.random(), 3458*da0073e9SAndroid Build Coastguard Worker ] 3459*da0073e9SAndroid Build Coastguard Worker if dtype.is_complex: 3460*da0073e9SAndroid Build Coastguard Worker weights += [complex(0, 1), complex(0.4, 1.2)] 3461*da0073e9SAndroid Build Coastguard Worker 3462*da0073e9SAndroid Build Coastguard Worker for weight in weights: 3463*da0073e9SAndroid Build Coastguard Worker actual = torch.lerp(start, end, 
weight) 3464*da0073e9SAndroid Build Coastguard Worker actual_method = start.lerp(end, weight) 3465*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, actual_method) 3466*da0073e9SAndroid Build Coastguard Worker actual_out = torch.tensor(1.0, dtype=dtype, device=device) 3467*da0073e9SAndroid Build Coastguard Worker torch.lerp(start, end, weight, out=actual_out) 3468*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, actual_out) 3469*da0073e9SAndroid Build Coastguard Worker expected = start + weight * (end - start) 3470*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 3471*da0073e9SAndroid Build Coastguard Worker 3472*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 3473*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.bfloat16) 3474*da0073e9SAndroid Build Coastguard Worker def test_lerp_lowp(self, device, dtype): 3475*da0073e9SAndroid Build Coastguard Worker xvals = (0.0, -30000.0) 3476*da0073e9SAndroid Build Coastguard Worker yvals = (0.1, -20000.0) 3477*da0073e9SAndroid Build Coastguard Worker xs = [torch.full((4,), xval, device=device, dtype=dtype) for xval in xvals] 3478*da0073e9SAndroid Build Coastguard Worker ys = [torch.full((4,), yval, device=device, dtype=dtype) for yval in yvals] 3479*da0073e9SAndroid Build Coastguard Worker weights = [70000, torch.full((4,), 8, device=device, dtype=dtype)] 3480*da0073e9SAndroid Build Coastguard Worker for x, y, w in zip(xs, ys, weights): 3481*da0073e9SAndroid Build Coastguard Worker xref = x.float() 3482*da0073e9SAndroid Build Coastguard Worker yref = y.float() 3483*da0073e9SAndroid Build Coastguard Worker wref = w.float() if isinstance(w, torch.Tensor) else w 3484*da0073e9SAndroid Build Coastguard Worker actual = torch.lerp(x, y, w) 3485*da0073e9SAndroid Build Coastguard Worker expected = torch.lerp(xref, yref, wref).to(dtype) 3486*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, atol=0.0, rtol=0.0) 
3487*da0073e9SAndroid Build Coastguard Worker 3488*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3489*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.bfloat16) 3490*da0073e9SAndroid Build Coastguard Worker def test_lerp_lowp_cpu(self, device, dtype): 3491*da0073e9SAndroid Build Coastguard Worker xvals = (0.0, -30000.0) 3492*da0073e9SAndroid Build Coastguard Worker yvals = (0.1, -20000.0) 3493*da0073e9SAndroid Build Coastguard Worker for shape in [(4,), (20,), (3, 10, 10)]: 3494*da0073e9SAndroid Build Coastguard Worker xs = [torch.full(shape, xval, device=device, dtype=dtype) for xval in xvals] 3495*da0073e9SAndroid Build Coastguard Worker ys = [torch.full(shape, yval, device=device, dtype=dtype) for yval in yvals] 3496*da0073e9SAndroid Build Coastguard Worker weights = [70000, torch.full(shape, 8, device=device, dtype=dtype)] 3497*da0073e9SAndroid Build Coastguard Worker for x, y, w in zip(xs, ys, weights): 3498*da0073e9SAndroid Build Coastguard Worker xref = x.float() 3499*da0073e9SAndroid Build Coastguard Worker yref = y.float() 3500*da0073e9SAndroid Build Coastguard Worker wref = w.float() if isinstance(w, torch.Tensor) else w 3501*da0073e9SAndroid Build Coastguard Worker actual = torch.lerp(x, y, w) 3502*da0073e9SAndroid Build Coastguard Worker expected = torch.lerp(xref, yref, wref).to(dtype) 3503*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, atol=0.0, rtol=0.0) 3504*da0073e9SAndroid Build Coastguard Worker 3505*da0073e9SAndroid Build Coastguard Worker def _test_logaddexp(self, device, dtype, base2): 3506*da0073e9SAndroid Build Coastguard Worker if base2: 3507*da0073e9SAndroid Build Coastguard Worker ref_func = np.logaddexp2 3508*da0073e9SAndroid Build Coastguard Worker our_func = torch.logaddexp2 3509*da0073e9SAndroid Build Coastguard Worker elif dtype in (torch.complex64, torch.complex128): 3510*da0073e9SAndroid Build Coastguard Worker # numpy has not implemented logaddexp for complex 3511*da0073e9SAndroid 
Build Coastguard Worker def _ref_func(x, y): 3512*da0073e9SAndroid Build Coastguard Worker return scipy.special.logsumexp(np.stack((x, y), axis=0), axis=0) 3513*da0073e9SAndroid Build Coastguard Worker 3514*da0073e9SAndroid Build Coastguard Worker ref_func = _ref_func 3515*da0073e9SAndroid Build Coastguard Worker our_func = torch.logaddexp 3516*da0073e9SAndroid Build Coastguard Worker else: 3517*da0073e9SAndroid Build Coastguard Worker ref_func = np.logaddexp 3518*da0073e9SAndroid Build Coastguard Worker our_func = torch.logaddexp 3519*da0073e9SAndroid Build Coastguard Worker 3520*da0073e9SAndroid Build Coastguard Worker def _test_helper(a, b): 3521*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 3522*da0073e9SAndroid Build Coastguard Worker ref = ref_func(a.cpu().float().numpy(), b.cpu().float().numpy()) 3523*da0073e9SAndroid Build Coastguard Worker v = our_func(a, b) 3524*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, v.float(), atol=0.01, rtol=0.01) 3525*da0073e9SAndroid Build Coastguard Worker else: 3526*da0073e9SAndroid Build Coastguard Worker ref = ref_func(a.cpu().numpy(), b.cpu().numpy()) 3527*da0073e9SAndroid Build Coastguard Worker v = our_func(a, b) 3528*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, v) 3529*da0073e9SAndroid Build Coastguard Worker 3530*da0073e9SAndroid Build Coastguard Worker # simple test 3531*da0073e9SAndroid Build Coastguard Worker a = torch.randn(64, 2, dtype=dtype, device=device) - 0.5 3532*da0073e9SAndroid Build Coastguard Worker b = torch.randn(64, 2, dtype=dtype, device=device) - 0.5 3533*da0073e9SAndroid Build Coastguard Worker _test_helper(a, b) 3534*da0073e9SAndroid Build Coastguard Worker _test_helper(a[:3], b[:3]) 3535*da0073e9SAndroid Build Coastguard Worker 3536*da0073e9SAndroid Build Coastguard Worker # large value test for numerical stability 3537*da0073e9SAndroid Build Coastguard Worker a *= 10000 3538*da0073e9SAndroid Build Coastguard Worker b *= 10000 
3539*da0073e9SAndroid Build Coastguard Worker _test_helper(a, b) 3540*da0073e9SAndroid Build Coastguard Worker _test_helper(a[:3], b[:3]) 3541*da0073e9SAndroid Build Coastguard Worker 3542*da0073e9SAndroid Build Coastguard Worker a = torch.tensor( 3543*da0073e9SAndroid Build Coastguard Worker [float("inf"), float("-inf"), float("inf"), float("nan")], 3544*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 3545*da0073e9SAndroid Build Coastguard Worker device=device, 3546*da0073e9SAndroid Build Coastguard Worker ) 3547*da0073e9SAndroid Build Coastguard Worker b = torch.tensor( 3548*da0073e9SAndroid Build Coastguard Worker [float("inf"), float("-inf"), float("-inf"), float("nan")], 3549*da0073e9SAndroid Build Coastguard Worker dtype=dtype, 3550*da0073e9SAndroid Build Coastguard Worker device=device, 3551*da0073e9SAndroid Build Coastguard Worker ) 3552*da0073e9SAndroid Build Coastguard Worker _test_helper(a, b) 3553*da0073e9SAndroid Build Coastguard Worker 3554*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo() # complex infs/nans differ under Dynamo/Inductor 3555*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.float32, torch.float64, torch.bfloat16) 3556*da0073e9SAndroid Build Coastguard Worker @dtypes( 3557*da0073e9SAndroid Build Coastguard Worker torch.float32, torch.float64, torch.bfloat16, torch.complex64, torch.complex128 3558*da0073e9SAndroid Build Coastguard Worker ) 3559*da0073e9SAndroid Build Coastguard Worker def test_logaddexp(self, device, dtype): 3560*da0073e9SAndroid Build Coastguard Worker self._test_logaddexp(device, dtype, base2=False) 3561*da0073e9SAndroid Build Coastguard Worker 3562*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.float64, torch.bfloat16) 3563*da0073e9SAndroid Build Coastguard Worker def test_logaddexp2(self, device, dtype): 3564*da0073e9SAndroid Build Coastguard Worker self._test_logaddexp(device, dtype, base2=True) 3565*da0073e9SAndroid Build Coastguard Worker 3566*da0073e9SAndroid 
    def test_add(self, device):
        """Broad smoke test of torch.add / `+`: tensor+tensor (contiguous and
        non-contiguous), in-place add_, scalar promotion, uint8/bool/bfloat16
        dtypes, complex alpha handling, and the error cases for mismatched
        alpha / out dtypes.

        Everything runs once per floating/complex dtype; the dtype-independent
        sections are repeated per iteration (matching the historical layout).
        """
        dtypes = floating_and_complex_types()
        for dtype in dtypes:
            # [res] torch.add([res,] tensor1, tensor2)
            m1 = torch.randn(100, 100, dtype=dtype, device=device)
            v1 = torch.randn(100, dtype=dtype, device=device)

            # contiguous: row slice + vector, checked element by element
            res1 = torch.add(m1[4], v1)
            res2 = res1.clone().zero_()
            for i in range(m1.size(1)):
                res2[i] = m1[4, i] + v1[i]
            self.assertEqual(res1, res2)

            # NOTE(review): default dtype here (no dtype= arg) — presumably
            # intentional to also cover the default-float path; confirm.
            m1 = torch.randn(100, 100, device=device)
            v1 = torch.randn(100, device=device)

            # non-contiguous: column slice (stride != 1) + vector
            res1 = torch.add(m1[:, 4], v1)
            res2 = res1.clone().zero_()
            for i in range(m1.size(0)):
                res2[i] = m1[i, 4] + v1[i]
            self.assertEqual(res1, res2)

            # [res] torch.add([res,] tensor, value)
            m1 = torch.randn(10, 10, device=device)

            # contiguous: in-place scalar add on a row view
            res1 = m1.clone()
            res1[3].add_(2)
            res2 = m1.clone()
            for i in range(m1.size(1)):
                res2[3, i] = res2[3, i] + 2
            self.assertEqual(res1, res2)

            # non-contiguous: in-place scalar add on a column view
            m1 = torch.randn(10, 10, device=device)
            res1 = m1.clone()
            res1[:, 3].add_(2)
            res2 = m1.clone()
            for i in range(m1.size(0)):
                res2[i, 3] = res2[i, 3] + 2
            self.assertEqual(res1, res2)

            # inter-type: python int scalar vs 0-d tensor must agree
            m1 = torch.randn(10, 10, dtype=dtype, device=device)
            self.assertEqual(m1 + 3, m1 + torch.tensor(3))
            self.assertEqual(3 + m1, torch.tensor(3) + m1)

            # contiguous + non-contiguous: result should come out contiguous
            m1 = torch.randn(10, 10, dtype=dtype, device=device)
            m2 = torch.randn(10, 10, dtype=dtype, device=device).t()
            res = m1 + m2
            self.assertTrue(res.is_contiguous())
            self.assertEqual(res, m1 + m2.contiguous())

            # 1d + empty: (1,) + (0,) broadcasts to an empty result
            m1 = torch.tensor([1.0], dtype=dtype, device=device)
            m2 = torch.tensor([], dtype=dtype, device=device)
            self.assertEqual(m1 + m2, [])

            # inter-type uint8: result keeps the uint8 dtype
            one = torch.tensor(1, dtype=torch.uint8, device=device)
            self.assertEqual(torch.add(one, 1), 2)
            self.assertEqual(torch.add(one, 1).dtype, torch.uint8)

            # bool: `+` acts as logical OR
            m1 = torch.tensor(
                [True, False, False, True, False, False], dtype=torch.bool, device=device
            )
            m2 = torch.tensor(
                [True, True, False, False, False, True], dtype=torch.bool, device=device
            )
            expected = torch.tensor(
                [True, True, False, True, False, True], dtype=torch.bool, device=device
            )
            self.assertEqual(m1 + m2, expected)

            # fused multiply add: alpha=0 on bool must still yield all-False
            a = torch.zeros(2, 3, dtype=torch.bool, device=device)
            res = torch.add(a, a, alpha=0)
            expected = torch.zeros(2, 3, device=device).bool()
            self.assertEqual(res, expected)

            # bfloat16
            # NOTE(review): no device= here, so these are CPU tensors even on
            # CUDA runs — confirm that is intended.
            m1 = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
            m2 = torch.tensor([3.0, 4.0], dtype=torch.bfloat16)
            self.assertEqual(m1 + m2, torch.tensor([4.0, 6.0], dtype=torch.bfloat16))

            # different alpha types
            m1 = torch.tensor([2 + 3j, 4 + 5j], dtype=torch.complex64, device=device)
            m2 = torch.tensor([4 + 5j, 2 + 3j], dtype=torch.complex64, device=device)
            # add complex numbers with float alpha
            res = torch.add(m1, m2, alpha=0.1)
            expected = torch.tensor(
                [2.4000 + 3.5000j, 4.2000 + 5.3000j], dtype=torch.complex64, device=device
            )
            self.assertEqual(res, expected)

            # add complex numbers with complex alpha
            res = torch.add(m1, m2, alpha=complex(0.1, 0.2))
            expected = torch.tensor(
                [1.4000 + 4.3000j, 3.6000 + 5.7000j], dtype=torch.complex64, device=device
            )
            self.assertEqual(res, expected)

            # add complex numbers with integer alpha
            res = torch.add(m1, m2, alpha=2)
            expected = torch.tensor(
                [10.0 + 13.0j, 8.0 + 11.0j], dtype=torch.complex64, device=device
            )
            self.assertEqual(res, expected)

            # mismatched alpha: bool alpha only valid for bool results,
            # float alpha only valid for floating results
            m1 = torch.tensor([1], dtype=torch.int8, device=device)
            m2 = torch.tensor([2], dtype=torch.int8, device=device)
            self.assertRaisesRegex(
                RuntimeError,
                r"Boolean alpha only supported for Boolean results\.",
                lambda: torch.add(m1, m2, alpha=True),
            )
            self.assertRaisesRegex(
                RuntimeError,
                r"For integral input tensors, argument alpha must not be a floating point number\.",
                lambda: torch.add(m1, m2, alpha=1.0),
            )

            # mismatched alpha, float / double tensor and complex alpha
            msg = r"For non-complex input tensors, argument alpha must not be a complex number\."
            m1 = torch.tensor([3.0, 4.0], device=device)
            m2 = torch.tensor([4.0, 3.0], device=device)
            self.assertRaisesRegex(
                RuntimeError, msg, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))
            )

            m1 = torch.tensor([3.0, 4.0], dtype=torch.double, device=device)
            m2 = torch.tensor([4.0, 3.0], dtype=torch.double, device=device)
            self.assertRaisesRegex(
                RuntimeError, msg, lambda: torch.add(m1, m2, alpha=complex(0.1, 0.2))
            )

            # complex result cannot be written into a real out= tensor
            m1 = torch.tensor((4.0000 + 4.0000j), dtype=torch.complex64)
            m2 = torch.tensor(4.0, dtype=torch.float64)
            self.assertRaisesRegex(
                RuntimeError,
                r"result type ComplexFloat can't be cast to the desired output type Double",
                lambda: torch.add(m1, m1, out=m2),
            )
Coastguard Worker (torch.add, torch.tensor([-60000.0], dtype=torch.half, device=device), 2), 3721*da0073e9SAndroid Build Coastguard Worker (torch.sub, torch.tensor([60000.0], dtype=torch.half, device=device), 2), 3722*da0073e9SAndroid Build Coastguard Worker (torch.add, -70000.0, 1), 3723*da0073e9SAndroid Build Coastguard Worker (torch.sub, 70000.0, 1), 3724*da0073e9SAndroid Build Coastguard Worker ): 3725*da0073e9SAndroid Build Coastguard Worker actual = op(x, y, alpha=alpha) 3726*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not (actual.isnan() or actual.isinf())) 3727*da0073e9SAndroid Build Coastguard Worker 3728*da0073e9SAndroid Build Coastguard Worker def test_sub_typing(self, device): 3729*da0073e9SAndroid Build Coastguard Worker m1 = torch.tensor( 3730*da0073e9SAndroid Build Coastguard Worker [True, False, False, True, False, False], dtype=torch.bool, device=device 3731*da0073e9SAndroid Build Coastguard Worker ) 3732*da0073e9SAndroid Build Coastguard Worker m2 = torch.tensor( 3733*da0073e9SAndroid Build Coastguard Worker [True, True, False, False, False, True], dtype=torch.bool, device=device 3734*da0073e9SAndroid Build Coastguard Worker ) 3735*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 3736*da0073e9SAndroid Build Coastguard Worker RuntimeError, 3737*da0073e9SAndroid Build Coastguard Worker r"Subtraction, the `\-` operator, with two bool tensors is not supported. " 3738*da0073e9SAndroid Build Coastguard Worker r"Use the `\^` or `logical_xor\(\)` operator instead.", 3739*da0073e9SAndroid Build Coastguard Worker lambda: m1 - m2, 3740*da0073e9SAndroid Build Coastguard Worker ) 3741*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 3742*da0073e9SAndroid Build Coastguard Worker RuntimeError, 3743*da0073e9SAndroid Build Coastguard Worker r"Subtraction, the `\-` operator, with a bool tensor is not supported. 
" 3744*da0073e9SAndroid Build Coastguard Worker r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", 3745*da0073e9SAndroid Build Coastguard Worker lambda: 1 - m1, 3746*da0073e9SAndroid Build Coastguard Worker ) 3747*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 3748*da0073e9SAndroid Build Coastguard Worker RuntimeError, 3749*da0073e9SAndroid Build Coastguard Worker r"Subtraction, the `\-` operator, with a bool tensor is not supported. " 3750*da0073e9SAndroid Build Coastguard Worker r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", 3751*da0073e9SAndroid Build Coastguard Worker lambda: m2 - 1, 3752*da0073e9SAndroid Build Coastguard Worker ) 3753*da0073e9SAndroid Build Coastguard Worker 3754*da0073e9SAndroid Build Coastguard Worker # mismatched alpha 3755*da0073e9SAndroid Build Coastguard Worker m1 = torch.tensor([1], dtype=torch.int8, device=device) 3756*da0073e9SAndroid Build Coastguard Worker m2 = torch.tensor([2], dtype=torch.int8, device=device) 3757*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 3758*da0073e9SAndroid Build Coastguard Worker RuntimeError, 3759*da0073e9SAndroid Build Coastguard Worker r"Boolean alpha only supported for Boolean results\.", 3760*da0073e9SAndroid Build Coastguard Worker lambda: torch.sub(m1, m2, alpha=True), 3761*da0073e9SAndroid Build Coastguard Worker ) 3762*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 3763*da0073e9SAndroid Build Coastguard Worker RuntimeError, 3764*da0073e9SAndroid Build Coastguard Worker r"For integral input tensors, argument alpha must not be a floating point number\.", 3765*da0073e9SAndroid Build Coastguard Worker lambda: torch.sub(m1, m2, alpha=1.0), 3766*da0073e9SAndroid Build Coastguard Worker ) 3767*da0073e9SAndroid Build Coastguard Worker 3768*da0073e9SAndroid Build Coastguard Worker def test_mul(self, device): 3769*da0073e9SAndroid Build Coastguard Worker m1 = 
torch.randn(10, 10, device=device) 3770*da0073e9SAndroid Build Coastguard Worker res1 = m1.clone() 3771*da0073e9SAndroid Build Coastguard Worker res1[:, 3].mul_(2) 3772*da0073e9SAndroid Build Coastguard Worker res2 = m1.clone() 3773*da0073e9SAndroid Build Coastguard Worker for i in range(res1.size(0)): 3774*da0073e9SAndroid Build Coastguard Worker res2[i, 3] = res2[i, 3] * 2 3775*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2) 3776*da0073e9SAndroid Build Coastguard Worker 3777*da0073e9SAndroid Build Coastguard Worker a1 = torch.tensor([True, False, False, True], dtype=torch.bool, device=device) 3778*da0073e9SAndroid Build Coastguard Worker a2 = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) 3779*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3780*da0073e9SAndroid Build Coastguard Worker a1 * a2, 3781*da0073e9SAndroid Build Coastguard Worker torch.tensor([True, False, False, False], dtype=torch.bool, device=device), 3782*da0073e9SAndroid Build Coastguard Worker ) 3783*da0073e9SAndroid Build Coastguard Worker 3784*da0073e9SAndroid Build Coastguard Worker if device == "cpu": 3785*da0073e9SAndroid Build Coastguard Worker a1 = torch.tensor([0.1, 0.1], dtype=torch.bfloat16, device=device) 3786*da0073e9SAndroid Build Coastguard Worker a2 = torch.tensor([1.1, 0.1], dtype=torch.bfloat16, device=device) 3787*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3788*da0073e9SAndroid Build Coastguard Worker a1 * a2, 3789*da0073e9SAndroid Build Coastguard Worker torch.tensor([0.11, 0.01], dtype=torch.bfloat16, device=device), 3790*da0073e9SAndroid Build Coastguard Worker atol=0.01, 3791*da0073e9SAndroid Build Coastguard Worker rtol=0, 3792*da0073e9SAndroid Build Coastguard Worker ) 3793*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a1.mul(a2), a1 * a2) 3794*da0073e9SAndroid Build Coastguard Worker 3795*da0073e9SAndroid Build Coastguard Worker def test_bool_tensor_comparison_ops(self, device): 
3796*da0073e9SAndroid Build Coastguard Worker a = torch.tensor( 3797*da0073e9SAndroid Build Coastguard Worker [True, False, True, False, True, False], dtype=torch.bool, device=device 3798*da0073e9SAndroid Build Coastguard Worker ) 3799*da0073e9SAndroid Build Coastguard Worker b = torch.tensor( 3800*da0073e9SAndroid Build Coastguard Worker [True, False, True, True, True, True], dtype=torch.bool, device=device 3801*da0073e9SAndroid Build Coastguard Worker ) 3802*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3803*da0073e9SAndroid Build Coastguard Worker a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device) 3804*da0073e9SAndroid Build Coastguard Worker ) 3805*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3806*da0073e9SAndroid Build Coastguard Worker a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device) 3807*da0073e9SAndroid Build Coastguard Worker ) 3808*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3809*da0073e9SAndroid Build Coastguard Worker a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device) 3810*da0073e9SAndroid Build Coastguard Worker ) 3811*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3812*da0073e9SAndroid Build Coastguard Worker a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.bool, device=device) 3813*da0073e9SAndroid Build Coastguard Worker ) 3814*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3815*da0073e9SAndroid Build Coastguard Worker a >= b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device) 3816*da0073e9SAndroid Build Coastguard Worker ) 3817*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3818*da0073e9SAndroid Build Coastguard Worker a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.bool, device=device) 3819*da0073e9SAndroid Build Coastguard Worker ) 3820*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3821*da0073e9SAndroid Build Coastguard Worker a > False, torch.tensor([1, 0, 1, 0, 1, 
0], dtype=torch.bool, device=device) 3822*da0073e9SAndroid Build Coastguard Worker ) 3823*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3824*da0073e9SAndroid Build Coastguard Worker a == torch.tensor(True, dtype=torch.bool, device=device), 3825*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device), 3826*da0073e9SAndroid Build Coastguard Worker ) 3827*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3828*da0073e9SAndroid Build Coastguard Worker a == torch.tensor(0, dtype=torch.bool, device=device), 3829*da0073e9SAndroid Build Coastguard Worker torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool, device=device), 3830*da0073e9SAndroid Build Coastguard Worker ) 3831*da0073e9SAndroid Build Coastguard Worker self.assertFalse(a.equal(b)) 3832*da0073e9SAndroid Build Coastguard Worker 3833*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and(torch.half, torch.bfloat16, torch.bool)) 3834*da0073e9SAndroid Build Coastguard Worker def test_logical(self, device, dtype): 3835*da0073e9SAndroid Build Coastguard Worker if dtype != torch.bool: 3836*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype) 3837*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([2], device=device, dtype=dtype) 3838*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.lt(2), torch.tensor([True, False, False, False])) 3839*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.le(2), torch.tensor([True, True, False, False])) 3840*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.ge(2), torch.tensor([False, True, True, True])) 3841*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.gt(2), torch.tensor([False, False, True, True])) 3842*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.eq(2), torch.tensor([False, True, False, False])) 3843*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.ne(2), torch.tensor([True, False, 
True, True])) 3844*da0073e9SAndroid Build Coastguard Worker 3845*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.lt(b), torch.tensor([True, False, False, False])) 3846*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.le(b), torch.tensor([True, True, False, False])) 3847*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.ge(b), torch.tensor([False, True, True, True])) 3848*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.gt(b), torch.tensor([False, False, True, True])) 3849*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.eq(b), torch.tensor([False, True, False, False])) 3850*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.ne(b), torch.tensor([True, False, True, True])) 3851*da0073e9SAndroid Build Coastguard Worker else: 3852*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([True, False, True, False], device=device) 3853*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.lt(True), torch.tensor([False, True, False, True])) 3854*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.le(True), torch.tensor([True, True, True, True])) 3855*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.ge(True), torch.tensor([True, False, True, False])) 3856*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.gt(True), torch.tensor([False, False, False, False])) 3857*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.eq(True), torch.tensor([True, False, True, False])) 3858*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.ne(True), torch.tensor([False, True, False, True])) 3859*da0073e9SAndroid Build Coastguard Worker 3860*da0073e9SAndroid Build Coastguard Worker def test_atan2(self, device): 3861*da0073e9SAndroid Build Coastguard Worker def _test_atan2_with_size(size, device): 3862*da0073e9SAndroid Build Coastguard Worker a = torch.rand(size=size, device=device, dtype=torch.double) 3863*da0073e9SAndroid Build Coastguard Worker b = torch.rand(size=size, 
device=device, dtype=torch.double) 3864*da0073e9SAndroid Build Coastguard Worker actual = a.atan2(b) 3865*da0073e9SAndroid Build Coastguard Worker x = a.view(-1) 3866*da0073e9SAndroid Build Coastguard Worker y = b.view(-1) 3867*da0073e9SAndroid Build Coastguard Worker expected = torch.tensor( 3868*da0073e9SAndroid Build Coastguard Worker [math.atan2(x[i].item(), y[i].item()) for i in range(x.numel())], 3869*da0073e9SAndroid Build Coastguard Worker device=device, 3870*da0073e9SAndroid Build Coastguard Worker dtype=torch.double, 3871*da0073e9SAndroid Build Coastguard Worker ) 3872*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual.view(-1), rtol=0, atol=0.02) 3873*da0073e9SAndroid Build Coastguard Worker 3874*da0073e9SAndroid Build Coastguard Worker # bfloat16/float16 3875*da0073e9SAndroid Build Coastguard Worker for lowp_dtype in [torch.bfloat16, torch.float16]: 3876*da0073e9SAndroid Build Coastguard Worker if lowp_dtype == torch.bfloat16: 3877*da0073e9SAndroid Build Coastguard Worker rtol = 0 3878*da0073e9SAndroid Build Coastguard Worker atol = 0.02 3879*da0073e9SAndroid Build Coastguard Worker else: 3880*da0073e9SAndroid Build Coastguard Worker rtol = 0 3881*da0073e9SAndroid Build Coastguard Worker atol = 0.001 3882*da0073e9SAndroid Build Coastguard Worker a_16 = a.to(dtype=lowp_dtype) 3883*da0073e9SAndroid Build Coastguard Worker b_16 = b.to(dtype=lowp_dtype) 3884*da0073e9SAndroid Build Coastguard Worker actual_16 = a_16.atan2(b_16) 3885*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_16, actual.to(dtype=lowp_dtype)) 3886*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 3887*da0073e9SAndroid Build Coastguard Worker expected, 3888*da0073e9SAndroid Build Coastguard Worker actual_16.view(-1), 3889*da0073e9SAndroid Build Coastguard Worker exact_dtype=False, 3890*da0073e9SAndroid Build Coastguard Worker rtol=rtol, 3891*da0073e9SAndroid Build Coastguard Worker atol=atol, 3892*da0073e9SAndroid Build Coastguard Worker ) 
3893*da0073e9SAndroid Build Coastguard Worker 3894*da0073e9SAndroid Build Coastguard Worker _test_atan2_with_size((2, 2), device) 3895*da0073e9SAndroid Build Coastguard Worker _test_atan2_with_size((3, 3), device) 3896*da0073e9SAndroid Build Coastguard Worker _test_atan2_with_size((5, 5), device) 3897*da0073e9SAndroid Build Coastguard Worker 3898*da0073e9SAndroid Build Coastguard Worker def test_atan2_edgecases(self, device): 3899*da0073e9SAndroid Build Coastguard Worker def _test_atan2(x, y, expected, device, dtype): 3900*da0073e9SAndroid Build Coastguard Worker expected_tensor = torch.tensor([expected], dtype=dtype, device=device) 3901*da0073e9SAndroid Build Coastguard Worker x_tensor = torch.tensor([x], dtype=dtype, device=device) 3902*da0073e9SAndroid Build Coastguard Worker y_tensor = torch.tensor([y], dtype=dtype, device=device) 3903*da0073e9SAndroid Build Coastguard Worker actual = torch.atan2(y_tensor, x_tensor) 3904*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_tensor, actual, rtol=0, atol=0.02) 3905*da0073e9SAndroid Build Coastguard Worker 3906*da0073e9SAndroid Build Coastguard Worker for dtype in [torch.float, torch.double]: 3907*da0073e9SAndroid Build Coastguard Worker _test_atan2(0, 0, 0, device, dtype) 3908*da0073e9SAndroid Build Coastguard Worker _test_atan2(0, 1, math.pi / 2, device, dtype) 3909*da0073e9SAndroid Build Coastguard Worker _test_atan2(0, -1, math.pi / -2, device, dtype) 3910*da0073e9SAndroid Build Coastguard Worker _test_atan2(-1, 0, math.pi, device, dtype) 3911*da0073e9SAndroid Build Coastguard Worker _test_atan2(1, 0, 0, device, dtype) 3912*da0073e9SAndroid Build Coastguard Worker _test_atan2(-1, -1, math.pi * -3 / 4, device, dtype) 3913*da0073e9SAndroid Build Coastguard Worker _test_atan2(1, 1, math.pi / 4, device, dtype) 3914*da0073e9SAndroid Build Coastguard Worker _test_atan2(1, -1, math.pi / -4, device, dtype) 3915*da0073e9SAndroid Build Coastguard Worker _test_atan2(-1, 1, math.pi * 3 / 4, device, dtype) 

    def test_trapezoid(self, device):
        """Compare torch.trapezoid against np.trapz for both the uniform-spacing
        (dx) and explicit-coordinate (x) call forms, plus error cases."""

        def test_dx(sizes, dim, dx, device):
            # Uniform spacing: NumPy is the reference implementation.
            t = torch.randn(sizes, device=device)
            actual = torch.trapezoid(t, dx=dx, dim=dim)
            expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim)  # noqa: NPY201
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(expected, actual, exact_dtype=False)

        def test_x(sizes, dim, x, device):
            # Explicit sample coordinates along `dim`.
            t = torch.randn(sizes, device=device)
            actual = torch.trapezoid(t, x=torch.tensor(x, device=device), dim=dim)
            expected = np.trapz(t.cpu().numpy(), x=x, axis=dim)  # noqa: NPY201
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(expected, actual.cpu(), exact_dtype=False)

        # Includes empty (0-length) leading dimensions and negative dims.
        test_dx((2, 3, 4), 1, 1, device)
        test_dx((10, 2), 0, 0.1, device)
        test_dx((1, 10), 0, 2.3, device)
        test_dx((0, 2), 0, 1.0, device)
        test_dx((0, 2), 1, 1.0, device)
        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
        test_x(
            (10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device
        )
        test_x((1, 10), 0, [1.0], device)
        test_x((0, 2), 0, [], device)
        test_x((0, 2), 1, [1.0, 2.0], device)
        test_x((2, 3, 4), -1, [1.0, 2.0, 3.0, 4.0], device)
        test_x((2, 3, 4), 0, [1.0, 2.0], device)
        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
        test_x((2, 3, 4), 2, [1.0, 2.0, 3.0, 4.0], device)
        test_x((2, 2, 4), -1, [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]], device)
        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            test_x((2, 3), 2, [], device)
            # NOTE(review): only the first call runs — assertRaisesRegex exits
            # the block at the first raise, so this test_dx is never reached.
            test_dx((2, 3), 2, 1.0, device)
        with self.assertRaisesRegex(
            RuntimeError, "There must be one `x` value for each sample point"
        ):
            test_x((2, 3), 1, [1.0, 2.0], device)
            # NOTE(review): unreached for the same reason as above.
            test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device)

    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
    def test_cumulative_trapezoid(self, device):
        """Compare torch.cumulative_trapezoid against SciPy's implementation
        for the dx and x call forms, including empty inputs and error cases."""
        import scipy.integrate

        # SciPy renamed cumtrapz -> cumulative_trapezoid; support both.
        if hasattr(scipy.integrate, "cumulative_trapezoid"):
            _scipy_cumulative_trapezoid = scipy.integrate.cumulative_trapezoid
        else:  # Older version of SciPy uses a different name
            _scipy_cumulative_trapezoid = scipy.integrate.cumtrapz

        def scipy_cumulative_trapezoid(y, x=None, dx=1.0, axis=-1, initial=None):
            # SciPy raises on zero-length integration axes; mirror PyTorch's
            # behavior of returning an empty result instead.
            if y.shape[axis] == 0:
                return np.empty_like(y)
            else:
                return _scipy_cumulative_trapezoid(y, x, dx, axis, initial)

        def test_dx(sizes, dim, dx, device):
            t = torch.randn(sizes, device=device)
            y = t.cpu().numpy()  # NOTE(review): unused local
            actual = torch.cumulative_trapezoid(t, dx=dx, dim=dim)
            expected = scipy_cumulative_trapezoid(t.cpu().numpy(), dx=dx, axis=dim)
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(expected, actual, exact_dtype=False, atol=1e-4, rtol=1e-4)

        def test_x(sizes, dim, x, device):
            t = torch.randn(sizes, device=device)
            actual = torch.cumulative_trapezoid(
                t, x=torch.tensor(x, device=device), dim=dim
            )
            expected = scipy_cumulative_trapezoid(t.cpu().numpy(), x=x, axis=dim)
            self.assertEqual(expected.shape, actual.shape)
            self.assertEqual(
                expected, actual.cpu(), exact_dtype=False, atol=1e-4, rtol=1e-4
            )

        def test_empty_x(sizes, dim, x, device):
            # Only checks the result shape/emptiness — SciPy cannot serve as a
            # reference here (see call site below).
            t = torch.randn(sizes, device=device)
            actual = torch.cumulative_trapezoid(
                t, x=torch.tensor(x, device=device), dim=dim
            )
            self.assertEqual(torch.empty(actual.shape), actual)

        test_dx((2,), -1, 1, device)
        test_dx((3, 3), -1, 1, device)
        test_dx((4, 2), 0, 1, device)
        test_dx((2, 3, 4), 1, 1, device)
        test_dx((10, 2), 0, 0.1, device)
        test_dx((1, 10), 0, 2.3, device)
        test_dx((0, 2), 0, 1.0, device)
        test_dx((0, 2), 1, 1.0, device)
        test_dx((512, 512), 1, 1.0, device)
        test_dx((100, 100, 100), 1, 1.0, device)

        test_x((2,), -1, [100, 50], device)
        test_x((4, 2), 0, [2, 3, 4, 5], device)
        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
        test_x(
            (10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device
        )
        test_x((1, 10), 0, [1.0], device)
        test_x((0, 2), 1, [1, 2], device)
        test_x((2, 3, 4), -1, [1.0, 2.0, 3.0, 4.0], device)
        test_x((2, 3, 4), 0, [1.0, 2.0], device)
        test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device)
        test_x((2, 3, 4), 2, [1.0, 2.0, 3.0, 4.0], device)

        test_empty_x(
            (0, 2), 0, [], device
        )  # SciPy failing when x == [], but our version returns empty

        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            test_x((2, 3), 2, [], device)
            # NOTE(review): only the first call in each assertRaisesRegex block
            # runs — the block exits at the first raise.
            test_dx((2, 3), 2, 1.0, device)
        with self.assertRaisesRegex(
            RuntimeError, "There must be one `x` value for each sample point"
        ):
            test_x((2, 3), 1, [1.0, 2.0], device)
            test_x((0, 2), 0, [1.0, 2.0], device)
            test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device)
        with self.assertRaisesRegex(
            RuntimeError, "Currently, we only support dx as a real number"
        ):
            test_dx((2, 2), -1, complex(1, 1), device)
        with self.assertRaisesRegex(
            TypeError, "received an invalid combination of arguments"
        ):
            # x and dx are mutually exclusive.
            actual = torch.cumulative_trapezoid(
                torch.randn((3, 3)), x=torch.randn((3, 3)), dx=3
            )

    @skipMeta
    @dtypes(torch.double)
    def test_pow_scalar_overloads_mem_overlap(self, device, dtype):
        """Check the scalar pow overloads (tensor**42 and 42**tensor) for
        internal and input/output memory-overlap handling."""
        sz = 3
        doubles = torch.randn(2 * sz, dtype=dtype, device=device)
        self.check_internal_mem_overlap(lambda t: t.pow_(42), 1, dtype, device)
        self.unary_check_input_output_mem_overlap(
            doubles, sz, lambda input, out: torch.pow(input, 42, out=out)
        )
        self.unary_check_input_output_mem_overlap(
            doubles, sz, lambda input, out: torch.pow(42, input, out=out)
        )

    # All (base dtype, exponent dtype) pairs over every real/complex dtype
    # plus half and bfloat16. (Decorator continues on the next chunk line.)
    @dtypes(
        *list(
            product(
                all_types_and_complex_and(torch.half, torch.bfloat16),
                all_types_and_complex_and(torch.half, torch.bfloat16),
            )
        )
    )
    def test_float_power(self, device, dtypes):
        """Check torch.float_power (functional, method, and in-place variants)
        against np.float_power: results always promote to float64/complex128,
        and the in-place variant raises when the input can't hold that dtype."""

        def to_np(value):
            # NumPy has no bfloat16; round-trip through float32 for comparison.
            if isinstance(value, torch.Tensor) and value.dtype == torch.bfloat16:
                return value.to(torch.float).cpu().numpy()
            return value.cpu().numpy() if isinstance(value, torch.Tensor) else value

        base_dtype = dtypes[0]
        exp_dtype = dtypes[1]
        # float_power always computes in double precision (real or complex).
        out_dtype = (
            torch.complex128
            if base_dtype.is_complex or exp_dtype.is_complex
            else torch.float64
        )

        base = make_tensor((30,), dtype=base_dtype, device=device, low=1, high=100)
        # Complex and real results do not agree between PyTorch and NumPy when computing negative and zero power of 0
        # Related: https://github.com/pytorch/pytorch/issues/48000
        # base[0] = base[3] = base[7] = 0
        exp = make_tensor((30,), dtype=exp_dtype, device=device, low=-2, high=2)
        exp[0] = exp[4] = exp[6] = 0

        expected = torch.from_numpy(np.float_power(to_np(base), to_np(exp)))

        exponents = [-2.8, -2, -1, -0.5, 0.5, 1, 2]
        complex_exponents = exponents + [
            -2.5j,
            -1.0j,
            1.0j,
            2.5j,
            1.0 + 1.0j,
            -1.0 - 1.5j,
            3.3j,
        ]

        for op in (
            torch.float_power,
            torch.Tensor.float_power,
            torch.Tensor.float_power_,
        ):
            # Case of Tensor x Tensor
            if op is torch.Tensor.float_power_ and base_dtype != out_dtype:
                # In-place can't change the tensor's dtype to the promoted one.
                with self.assertRaisesRegex(
                    RuntimeError, "operation's result requires dtype"
                ):
                    op(base.clone(), exp)
            else:
                result = op(base.clone(), exp)
                self.assertEqual(expected, result)

            if op is torch.float_power:
                # out= variant with a correctly promoted output tensor.
                out = torch.empty_like(base).to(device=device, dtype=out_dtype)
                op(base, exp, out=out)
                self.assertEqual(expected, out)

            # Case of Tensor x Scalar
            for i in complex_exponents if exp_dtype.is_complex else exponents:
                out_dtype_scalar_exp = (
                    torch.complex128
                    if base_dtype.is_complex or type(i) == complex
                    else torch.float64
                )
                expected_scalar_exp = torch.from_numpy(np.float_power(to_np(base), i))

                if (
                    op is torch.Tensor.float_power_
                    and base_dtype != out_dtype_scalar_exp
                ):
                    with self.assertRaisesRegex(
                        RuntimeError, "operation's result requires dtype"
                    ):
                        op(base.clone(), i)
                else:
                    result = op(base.clone(), i)
                    self.assertEqual(expected_scalar_exp, result)

                if op is torch.float_power:
                    out = torch.empty_like(base).to(
                        device=device, dtype=out_dtype_scalar_exp
                    )
                    op(base, i, out=out)
                    self.assertEqual(expected_scalar_exp, out)

        # Case of Scalar x Tensor — only the functional form accepts a scalar base.
        for i in complex_exponents if base_dtype.is_complex else exponents:
            out_dtype_scalar_base = (
                torch.complex128
                if exp_dtype.is_complex or type(i) == complex
                else torch.float64
            )
            expected_scalar_base = torch.from_numpy(np.float_power(i, to_np(exp)))

            result = torch.float_power(i, exp)
            self.assertEqual(expected_scalar_base, result)

            out = torch.empty_like(exp).to(device=device, dtype=out_dtype_scalar_base)
            torch.float_power(i, exp, out=out)
            self.assertEqual(expected_scalar_base, out)

    def test_float_power_exceptions(self, device):
        """float_power must reject out=/in-place targets whose dtype differs
        from the promoted (double or cdouble) result dtype."""

        def _promo_helper(x, y):
            # Promoted result dtype: complex128 if either operand is complex,
            # float64 otherwise.
            for i in (x, y):
                if type(i) == complex:
                    return torch.complex128
                elif type(i) == torch.Tensor and i.is_complex():
                    return torch.complex128
            return torch.double

        test_cases = (
            (torch.tensor([-2, -1, 0, 1, 2], device=device), -0.25),
            (
                torch.tensor([-1.0j, 0j, 1.0j, 1.0 + 1.0j, -1.0 - 1.5j], device=device),
                2.0,
            ),
        )
        for base, exp in test_cases:
            for out_dtype in (torch.long, torch.float, torch.double, torch.cdouble):
                out = torch.empty(1, device=device, dtype=out_dtype)
                required_dtype = _promo_helper(base, exp)

                # out= variant: only the exactly promoted dtype is accepted.
                if out.dtype == required_dtype:
                    torch.float_power(base, exp, out=out)
                else:
                    with self.assertRaisesRegex(
                        RuntimeError, "operation's result requires dtype"
                    ):
                        torch.float_power(base, exp, out=out)

                # In-place variant: same rule, applied to the base tensor.
                if base.dtype == required_dtype:
                    torch.Tensor.float_power_(base.clone(), exp)
                else:
                    with self.assertRaisesRegex(
                        RuntimeError, "operation's result requires dtype"
                    ):
                        torch.Tensor.float_power_(base.clone(), exp)

    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @dtypes(
        *product(
            all_types_and(torch.half, torch.bool), all_types_and(torch.half, torch.bool)
        )
    )
    def test_xlogy_xlog1py(self, device, dtypes):
        """Compare torch.xlogy / torch.special.xlog1py against SciPy for
        tensor-tensor (same and broadcast shapes), scalar-tensor, in-place,
        out=, and x == 0 special-value cases."""
        x_dtype, y_dtype = dtypes

        def out_variant_helper(torch_fn, x, y):
            # out= must match the functional result.
            expected = torch_fn(x, y)
            out = torch.empty_like(expected)
            torch_fn(x, y, out=out)
            self.assertEqual(expected, out)

        def xlogy_inplace_variant_helper(x, y):
            # In-place on integral/bool inputs can't hold the float result.
            if x.dtype in integral_types_and(torch.bool):
                with self.assertRaisesRegex(
                    RuntimeError, "can't be cast to the desired output type"
                ):
                    x.clone().xlogy_(y)
            else:
                expected = torch.empty_like(x)
                torch.xlogy(x, y, out=expected)
                inplace_out = x.clone().xlogy_(y)
                self.assertEqual(expected, inplace_out)

        def test_helper(torch_fn, reference_fn, inputs, scalar=None):
            # Compare against SciPy with x fixed and each of x/y/z as the
            # second argument, then exercise the out= variant.
            x, y, z = inputs
            torch_fn_partial = partial(torch_fn, x)
            reference_fn_partial = partial(reference_fn, x.cpu().numpy())
            self.compare_with_numpy(
                torch_fn_partial, reference_fn_partial, x, exact_dtype=False
            )
            self.compare_with_numpy(
                torch_fn_partial, reference_fn_partial, y, exact_dtype=False
            )
            self.compare_with_numpy(
                torch_fn_partial, reference_fn_partial, z, exact_dtype=False
            )

            val = scalar if scalar is not None else x
            out_variant_helper(torch_fn, val, x)
            out_variant_helper(torch_fn, val, y)
            out_variant_helper(torch_fn, val, z)

        # Tensor-Tensor Test (tensor of same and different shape)
        x = make_tensor((3, 2, 4, 5), dtype=x_dtype, device=device, low=0.5, high=1000)
        y = make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
        z = make_tensor((4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)

        # xlog1py admits arguments down to -1; use low=-0.5 for its inputs.
        x_1p = make_tensor(
            (3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.5, high=1000
        )
        y_1p = make_tensor(
            (3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.5, high=1000
        )
        z_1p = make_tensor((4, 5), dtype=y_dtype, device=device, low=-0.5, high=1000)

        xlogy_fns = torch.xlogy, scipy.special.xlogy
        xlog1py_fns = torch.special.xlog1py, scipy.special.xlog1py

        test_helper(*xlogy_fns, (x, y, z))
        xlogy_inplace_variant_helper(x, x)
        xlogy_inplace_variant_helper(x, y)
        xlogy_inplace_variant_helper(x, z)
        test_helper(*xlog1py_fns, (x_1p, y_1p, z_1p))

        # Scalar-Tensor Test
        test_helper(*xlogy_fns, (x, y, z), 3.14)
        test_helper(*xlog1py_fns, (x_1p, y_1p, z_1p), 3.14)

        # Special Values Tensor-Tensor
        t = torch.tensor(
            [-1.0, 0.0, 1.0, 2.0,
             float("inf"), -float("inf"), float("nan")],
            device=device,
        )
        zeros = torch.zeros(7, dtype=y_dtype, device=device)

        def test_zeros_special_helper(torch_fn, reference_fn, scalar=False):
            # x == 0 must dominate: xlogy(0, y) is 0 even for inf/nan y.
            zeros_t = 0 if scalar else zeros
            zeros_np = 0 if scalar else zeros.cpu().numpy()
            torch_fn_partial = partial(torch_fn, zeros_t)
            reference_fn_partial = partial(reference_fn, zeros_np)
            self.compare_with_numpy(
                torch_fn_partial, reference_fn_partial, t, exact_dtype=False
            )
            out_variant_helper(torch_fn, zeros_t, t)

        test_zeros_special_helper(*xlogy_fns)
        xlogy_inplace_variant_helper(zeros, t)
        test_zeros_special_helper(*xlog1py_fns)

        # Special Values Scalar-Tensor
        test_zeros_special_helper(*xlogy_fns, scalar=True)
        test_zeros_special_helper(*xlog1py_fns, scalar=True)

    @dtypes(torch.float64)
    def test_xlogy_xlog1py_gradients(self, device, dtype):
        """The gradient w.r.t. x of xlogy/xlog1py at x == 0 must be 0, even
        where log(y) / log1p(y) is nan or -inf."""
        make_arg = partial(torch.tensor, dtype=dtype, device=device, requires_grad=True)

        zeros = torch.zeros((2,), dtype=dtype, device=device)

        x = make_arg([0.0, 0.0])
        y = make_arg([-1.5, 0.0])
        torch.special.xlogy(x, y).sum().backward()
        self.assertEqual(x.grad, zeros)

        x = make_arg([0.0, 0.0])
        y = make_arg([-2.5, -1.0])
        torch.special.xlog1py(x, y).sum().backward()
        self.assertEqual(x.grad, zeros)

    def test_xlogy_xlog1py_scalar_type_promotion(self, device):
        # Test that python numbers don't participate in type promotion at the same
        # priority level as 0-dim tensors
        t = torch.randn((), dtype=torch.float32, device=device)

        # Python scalar as second operand must not upcast the float32 tensor.
        self.assertEqual(t.dtype, torch.xlogy(t, 5).dtype)
        self.assertEqual(t.dtype, torch.xlogy(t, 5.0).dtype)
        self.assertEqual(t.dtype, torch.special.xlog1py(t, 5).dtype)
        self.assertEqual(t.dtype, torch.special.xlog1py(t, 5.0).dtype)

        # Same with the scalar as the first operand.
        self.assertEqual(t.dtype, torch.xlogy(5, t).dtype)
        self.assertEqual(t.dtype, torch.xlogy(5.0, t).dtype)
        self.assertEqual(t.dtype, torch.special.xlog1py(5, t).dtype)
        self.assertEqual(t.dtype, torch.special.xlog1py(5.0, t).dtype)

    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
    def test_xlogy_xlog1py_bfloat16(self, device):
        """bfloat16 variant of the xlogy/xlog1py-vs-SciPy comparison (SciPy has
        no bfloat16, so inputs are upcast to float32 for the reference)."""

        def _compare_helper(x, y, torch_fn, reference_fn):
            x_np = x if isinstance(x, float) else x.cpu().to(torch.float).numpy()
            y_np = y if isinstance(y, float) else y.cpu().to(torch.float).numpy()
            expected = torch.from_numpy(reference_fn(x_np, y_np))
            actual = torch_fn(x, y)
            self.assertEqual(expected, actual, exact_dtype=False)

        x_dtype, y_dtype = torch.bfloat16, torch.bfloat16

        # Tensor-Tensor Test (tensor of same and different shape)
        x = make_tensor((3, 2, 4, 5), dtype=x_dtype, device=device, low=0.5, high=1000)
        y = make_tensor((3, 2, 4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)
        z = make_tensor((4, 5), dtype=y_dtype, device=device, low=0.5, high=1000)

        x_1p = make_tensor(
            (3, 2, 4, 5), dtype=x_dtype, device=device, low=-0.8, high=1000
        )
        y_1p = make_tensor(
            (3, 2, 4, 5), dtype=y_dtype, device=device, low=-0.8, high=1000
        )
        z_1p = make_tensor((4, 5), dtype=y_dtype, device=device, low=-0.8, high=1000)

        xlogy_fns = torch.xlogy, scipy.special.xlogy
        xlog1py_fns = torch.special.xlog1py, scipy.special.xlog1py

        _compare_helper(x, x, *xlogy_fns)
        _compare_helper(x, y, *xlogy_fns)
        _compare_helper(x, z, *xlogy_fns)
        _compare_helper(x, 3.14, *xlogy_fns)
        _compare_helper(y, 3.14, *xlogy_fns)
        _compare_helper(z, 3.14, *xlogy_fns)

        _compare_helper(x_1p, x_1p, *xlog1py_fns)
        _compare_helper(x_1p, y_1p, *xlog1py_fns)
        _compare_helper(x_1p, z_1p, *xlog1py_fns)
        _compare_helper(x_1p, 3.14, *xlog1py_fns)
        _compare_helper(y_1p, 3.14, *xlog1py_fns)
        _compare_helper(z_1p, 3.14, *xlog1py_fns)

        # Special Values Tensor-Tensor
        t = torch.tensor(
            [-1.0, 0.0, 1.0, 2.0, float("inf"), -float("inf"), float("nan")],
            device=device,
        )
        # NOTE(review): despite the name, this is a 0-dim tensor holding the
        # scalar 7, not a length-7 zero tensor (contrast torch.zeros(7, ...)
        # in test_xlogy_xlog1py) — possibly intended torch.zeros; kept as-is.
        zeros = torch.tensor(7, dtype=y_dtype, device=device)

        _compare_helper(t, zeros, *xlogy_fns)
        _compare_helper(t, 0.0, *xlogy_fns)

        _compare_helper(t, zeros, *xlog1py_fns)
        _compare_helper(t, 0.0, *xlog1py_fns)

    @dtypes(*product(all_types_and(torch.bool), all_types_and(torch.bool)))
    @skipIf(not TEST_SCIPY, "Scipy required for the test.")
    @slowTest
    def test_zeta(self, device, dtypes):
        """Compare torch.special.zeta against scipy.special.zeta over dtype
        pairs. (Definition continues past this chunk.)"""
        x_dtype, q_dtype = dtypes

        def test_helper(x, q):
            x_np = x if isinstance(x, float) else x.cpu().numpy()
            q_np = q if isinstance(q, float) else q.cpu().numpy()
            expected = torch.from_numpy(scipy.special.zeta(x_np, q_np))
            actual = torch.special.zeta(x, q)

            # CPU gets explicit tolerances; other devices use the defaults.
            rtol, atol = None, None
            if self.device_type == "cpu":
                rtol, atol = 1e-6, 1e-6
4400*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual, rtol=rtol, atol=atol, exact_dtype=False) 4401*da0073e9SAndroid Build Coastguard Worker 4402*da0073e9SAndroid Build Coastguard Worker # x tensor - q tensor same size 4403*da0073e9SAndroid Build Coastguard Worker x = make_tensor((2, 3, 4), dtype=x_dtype, device=device) 4404*da0073e9SAndroid Build Coastguard Worker q = make_tensor((2, 3, 4), dtype=q_dtype, device=device) 4405*da0073e9SAndroid Build Coastguard Worker test_helper(x, q) 4406*da0073e9SAndroid Build Coastguard Worker 4407*da0073e9SAndroid Build Coastguard Worker # x tensor - q tensor broadcast lhs 4408*da0073e9SAndroid Build Coastguard Worker x = make_tensor((2, 1, 4), dtype=x_dtype, device=device) 4409*da0073e9SAndroid Build Coastguard Worker q = make_tensor((2, 3, 4), dtype=q_dtype, device=device) 4410*da0073e9SAndroid Build Coastguard Worker test_helper(x, q) 4411*da0073e9SAndroid Build Coastguard Worker 4412*da0073e9SAndroid Build Coastguard Worker # x tensor - q tensor broadcast rhs 4413*da0073e9SAndroid Build Coastguard Worker x = make_tensor((2, 3, 4), dtype=x_dtype, device=device) 4414*da0073e9SAndroid Build Coastguard Worker q = make_tensor((2, 1, 4), dtype=q_dtype, device=device) 4415*da0073e9SAndroid Build Coastguard Worker test_helper(x, q) 4416*da0073e9SAndroid Build Coastguard Worker 4417*da0073e9SAndroid Build Coastguard Worker # x tensor - q tensor broadcast all 4418*da0073e9SAndroid Build Coastguard Worker x = make_tensor((2, 3, 1), dtype=x_dtype, device=device) 4419*da0073e9SAndroid Build Coastguard Worker q = make_tensor((2, 1, 4), dtype=q_dtype, device=device) 4420*da0073e9SAndroid Build Coastguard Worker test_helper(x, q) 4421*da0073e9SAndroid Build Coastguard Worker 4422*da0073e9SAndroid Build Coastguard Worker # x scalar - q tensor 4423*da0073e9SAndroid Build Coastguard Worker for x in np.linspace(-5, 5, num=10).tolist(): 4424*da0073e9SAndroid Build Coastguard Worker if not q_dtype.is_floating_point: 
4425*da0073e9SAndroid Build Coastguard Worker q_dtype = torch.get_default_dtype() 4426*da0073e9SAndroid Build Coastguard Worker q = make_tensor((2, 3, 4), dtype=q_dtype, device=device) 4427*da0073e9SAndroid Build Coastguard Worker test_helper(x, q) 4428*da0073e9SAndroid Build Coastguard Worker 4429*da0073e9SAndroid Build Coastguard Worker # x tensor - q scalar 4430*da0073e9SAndroid Build Coastguard Worker for q in np.linspace(-5, 5, num=10).tolist(): 4431*da0073e9SAndroid Build Coastguard Worker if not x_dtype.is_floating_point: 4432*da0073e9SAndroid Build Coastguard Worker x_dtype = torch.get_default_dtype() 4433*da0073e9SAndroid Build Coastguard Worker x = make_tensor((2, 3, 4), dtype=x_dtype, device=device) 4434*da0073e9SAndroid Build Coastguard Worker test_helper(x, q) 4435*da0073e9SAndroid Build Coastguard Worker 4436*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 4437*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.chalf) 4438*da0073e9SAndroid Build Coastguard Worker def test_mul_chalf_tensor_and_cpu_scalar(self, device, dtype): 4439*da0073e9SAndroid Build Coastguard Worker # Tests that Tensor and CPU Scalar work for `mul` for chalf. 4440*da0073e9SAndroid Build Coastguard Worker # Ideally, this should be covered by `test_complex_half_reference_testing` 4441*da0073e9SAndroid Build Coastguard Worker # from test_ops.py by checking reference_samples from the OpInfo. 4442*da0073e9SAndroid Build Coastguard Worker # But currently that doesn't work as sample generation requires support of 4443*da0073e9SAndroid Build Coastguard Worker # `index_select` which is not implemented for `complex32` at the 4444*da0073e9SAndroid Build Coastguard Worker # time of writing this test. 4445*da0073e9SAndroid Build Coastguard Worker # TODO: Remove this test once above issue is fixed. 
4446*da0073e9SAndroid Build Coastguard Worker # Ref: https://github.com/pytorch/pytorch/pull/76364 4447*da0073e9SAndroid Build Coastguard Worker x = make_tensor((2, 2), device=device, dtype=dtype) 4448*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x * 2.5, x * torch.tensor(2.5, device=device, dtype=dtype)) 4449*da0073e9SAndroid Build Coastguard Worker 4450*da0073e9SAndroid Build Coastguard Worker 4451*da0073e9SAndroid Build Coastguard Workertensor_binary_ops = [ 4452*da0073e9SAndroid Build Coastguard Worker "__lt__", 4453*da0073e9SAndroid Build Coastguard Worker "__le__", 4454*da0073e9SAndroid Build Coastguard Worker "__gt__", 4455*da0073e9SAndroid Build Coastguard Worker "__ge__", 4456*da0073e9SAndroid Build Coastguard Worker "__eq__", 4457*da0073e9SAndroid Build Coastguard Worker "__ne__", 4458*da0073e9SAndroid Build Coastguard Worker "__add__", 4459*da0073e9SAndroid Build Coastguard Worker "__radd__", 4460*da0073e9SAndroid Build Coastguard Worker "__iadd__", 4461*da0073e9SAndroid Build Coastguard Worker "__sub__", 4462*da0073e9SAndroid Build Coastguard Worker "__rsub__", 4463*da0073e9SAndroid Build Coastguard Worker "__isub__", 4464*da0073e9SAndroid Build Coastguard Worker "__mul__", 4465*da0073e9SAndroid Build Coastguard Worker "__rmul__", 4466*da0073e9SAndroid Build Coastguard Worker "__imul__", 4467*da0073e9SAndroid Build Coastguard Worker "__matmul__", 4468*da0073e9SAndroid Build Coastguard Worker "__rmatmul__", 4469*da0073e9SAndroid Build Coastguard Worker "__truediv__", 4470*da0073e9SAndroid Build Coastguard Worker "__rtruediv__", 4471*da0073e9SAndroid Build Coastguard Worker "__itruediv__", 4472*da0073e9SAndroid Build Coastguard Worker "__floordiv__", 4473*da0073e9SAndroid Build Coastguard Worker "__rfloordiv__", 4474*da0073e9SAndroid Build Coastguard Worker "__ifloordiv__", 4475*da0073e9SAndroid Build Coastguard Worker "__mod__", 4476*da0073e9SAndroid Build Coastguard Worker "__rmod__", 4477*da0073e9SAndroid Build Coastguard Worker 
"__imod__", 4478*da0073e9SAndroid Build Coastguard Worker "__pow__", 4479*da0073e9SAndroid Build Coastguard Worker "__rpow__", 4480*da0073e9SAndroid Build Coastguard Worker "__ipow__", 4481*da0073e9SAndroid Build Coastguard Worker "__lshift__", 4482*da0073e9SAndroid Build Coastguard Worker "__rlshift__", 4483*da0073e9SAndroid Build Coastguard Worker "__ilshift__", 4484*da0073e9SAndroid Build Coastguard Worker "__rshift__", 4485*da0073e9SAndroid Build Coastguard Worker "__rrshift__", 4486*da0073e9SAndroid Build Coastguard Worker "__irshift__", 4487*da0073e9SAndroid Build Coastguard Worker "__and__", 4488*da0073e9SAndroid Build Coastguard Worker "__rand__", 4489*da0073e9SAndroid Build Coastguard Worker "__iand__", 4490*da0073e9SAndroid Build Coastguard Worker "__xor__", 4491*da0073e9SAndroid Build Coastguard Worker "__rxor__", 4492*da0073e9SAndroid Build Coastguard Worker "__ixor__", 4493*da0073e9SAndroid Build Coastguard Worker "__or__", 4494*da0073e9SAndroid Build Coastguard Worker "__ror__", 4495*da0073e9SAndroid Build Coastguard Worker "__ior__", 4496*da0073e9SAndroid Build Coastguard Worker # Unsupported operators 4497*da0073e9SAndroid Build Coastguard Worker # '__imatmul__', 4498*da0073e9SAndroid Build Coastguard Worker # '__divmod__', '__rdivmod__', '__idivmod__', 4499*da0073e9SAndroid Build Coastguard Worker] 4500*da0073e9SAndroid Build Coastguard Worker 4501*da0073e9SAndroid Build Coastguard Worker 4502*da0073e9SAndroid Build Coastguard Worker# Test that binary math operations return NotImplemented for unknown types. 
def generate_not_implemented_tests(cls):
    """Attach one test per entry in ``tensor_binary_ops`` to *cls*, each
    asserting that the corresponding Tensor dunder returns ``NotImplemented``
    when handed an operand of an unrecognized type."""

    class UnknownType:
        pass

    # TODO: refactor to inline these
    _types = [
        torch.half,
        torch.float,
        torch.double,
        torch.int8,
        torch.short,
        torch.int,
        torch.long,
        torch.uint8,
    ]

    def _build_test(op):
        @dtypes(*_types)
        def test(self, device, dtype):
            # A 0-dim tensor suffices: only the dispatch path is under test.
            operand = torch.empty((), device=device, dtype=dtype)

            # The dunder must bow out with NotImplemented rather than raise.
            outcome = getattr(operand, op)(UnknownType())
            self.assertEqual(outcome, NotImplemented)

        return test

    for op in tensor_binary_ops:
        test_name = f"test_{op}_not_implemented"
        assert not hasattr(cls, test_name), f"{test_name} already in {cls.__name__}"
        setattr(cls, test_name, _build_test(op))


generate_not_implemented_tests(TestBinaryUfuncs)
instantiate_device_type_tests(TestBinaryUfuncs, globals())

if __name__ == "__main__":
    run_tests()