1*da0073e9SAndroid Build Coastguard Worker# mypy: ignore-errors 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: numpy"] 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport sys 6*da0073e9SAndroid Build Coastguard Workerfrom itertools import product 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerimport numpy as np 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Workerimport torch 11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import make_tensor 12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 13*da0073e9SAndroid Build Coastguard Worker dtypes, 14*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 15*da0073e9SAndroid Build Coastguard Worker onlyCPU, 16*da0073e9SAndroid Build Coastguard Worker skipMeta, 17*da0073e9SAndroid Build Coastguard Worker) 18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import all_types_and_complex_and 19*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker# For testing handling NumPy objects and sending tensors to / accepting 23*da0073e9SAndroid Build Coastguard Worker# arrays from NumPy. 24*da0073e9SAndroid Build Coastguard Workerclass TestNumPyInterop(TestCase): 25*da0073e9SAndroid Build Coastguard Worker # Note: the warning this tests for only appears once per program, so 26*da0073e9SAndroid Build Coastguard Worker # other instances of this warning should be addressed to avoid 27*da0073e9SAndroid Build Coastguard Worker # the tests depending on the order in which they're run. 28*da0073e9SAndroid Build Coastguard Worker @onlyCPU 29*da0073e9SAndroid Build Coastguard Worker def test_numpy_non_writeable(self, device): 30*da0073e9SAndroid Build Coastguard Worker arr = np.zeros(5) 31*da0073e9SAndroid Build Coastguard Worker arr.flags["WRITEABLE"] = False 32*da0073e9SAndroid Build Coastguard Worker self.assertWarns(UserWarning, lambda: torch.from_numpy(arr)) 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Worker @onlyCPU 35*da0073e9SAndroid Build Coastguard Worker def test_numpy_unresizable(self, device) -> None: 36*da0073e9SAndroid Build Coastguard Worker x = np.zeros((2, 2)) 37*da0073e9SAndroid Build Coastguard Worker y = torch.from_numpy(x) 38*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 39*da0073e9SAndroid Build Coastguard Worker x.resize((5, 5)) 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker z = torch.randn(5, 5) 42*da0073e9SAndroid Build Coastguard Worker w = z.numpy() 43*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 44*da0073e9SAndroid Build Coastguard Worker z.resize_(10, 10) 45*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 46*da0073e9SAndroid Build Coastguard Worker w.resize((10, 10)) 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker @onlyCPU 49*da0073e9SAndroid Build Coastguard Worker def test_to_numpy(self, device) -> None: 50*da0073e9SAndroid Build Coastguard Worker def get_castable_tensor(shape, dtype): 51*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point: 52*da0073e9SAndroid Build Coastguard Worker dtype_info = torch.finfo(dtype) 53*da0073e9SAndroid Build Coastguard Worker # can't directly use min and max, because for double, max - min 54*da0073e9SAndroid Build Coastguard Worker # is greater than double range and sampling always gives inf. 55*da0073e9SAndroid Build Coastguard Worker low = max(dtype_info.min, -1e10) 56*da0073e9SAndroid Build Coastguard Worker high = min(dtype_info.max, 1e10) 57*da0073e9SAndroid Build Coastguard Worker t = torch.empty(shape, dtype=torch.float64).uniform_(low, high) 58*da0073e9SAndroid Build Coastguard Worker else: 59*da0073e9SAndroid Build Coastguard Worker # can't directly use min and max, because for int64_t, max - min 60*da0073e9SAndroid Build Coastguard Worker # is greater than int64_t range and triggers UB. 61*da0073e9SAndroid Build Coastguard Worker low = max(torch.iinfo(dtype).min, int(-1e10)) 62*da0073e9SAndroid Build Coastguard Worker high = min(torch.iinfo(dtype).max, int(1e10)) 63*da0073e9SAndroid Build Coastguard Worker t = torch.empty(shape, dtype=torch.int64).random_(low, high) 64*da0073e9SAndroid Build Coastguard Worker return t.to(dtype) 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Worker dtypes = [ 67*da0073e9SAndroid Build Coastguard Worker torch.uint8, 68*da0073e9SAndroid Build Coastguard Worker torch.int8, 69*da0073e9SAndroid Build Coastguard Worker torch.short, 70*da0073e9SAndroid Build Coastguard Worker torch.int, 71*da0073e9SAndroid Build Coastguard Worker torch.half, 72*da0073e9SAndroid Build Coastguard Worker torch.float, 73*da0073e9SAndroid Build Coastguard Worker torch.double, 74*da0073e9SAndroid Build Coastguard Worker torch.long, 75*da0073e9SAndroid Build Coastguard Worker ] 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker for dtp in dtypes: 78*da0073e9SAndroid Build Coastguard Worker # 1D 79*da0073e9SAndroid Build Coastguard Worker sz = 10 80*da0073e9SAndroid Build Coastguard Worker x = get_castable_tensor(sz, dtp) 81*da0073e9SAndroid Build Coastguard Worker y = x.numpy() 82*da0073e9SAndroid Build Coastguard Worker for i in range(sz): 83*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[i], y[i]) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker # 1D > 0 storage offset 86*da0073e9SAndroid Build Coastguard Worker xm = get_castable_tensor(sz * 2, dtp) 87*da0073e9SAndroid Build Coastguard Worker x = xm.narrow(0, sz - 1, sz) 88*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.storage_offset() > 0) 89*da0073e9SAndroid Build Coastguard Worker y = x.numpy() 90*da0073e9SAndroid Build Coastguard Worker for i in range(sz): 91*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[i], y[i]) 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker def check2d(x, y): 94*da0073e9SAndroid Build Coastguard Worker for i in range(sz1): 95*da0073e9SAndroid Build Coastguard Worker for j in range(sz2): 96*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[i][j], y[i][j]) 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker # empty 99*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([]).to(dtp) 100*da0073e9SAndroid Build Coastguard Worker y = x.numpy() 101*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.size, 0) 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker # contiguous 2D 104*da0073e9SAndroid Build Coastguard Worker sz1 = 3 105*da0073e9SAndroid Build Coastguard Worker sz2 = 5 106*da0073e9SAndroid Build Coastguard Worker x = get_castable_tensor((sz1, sz2), dtp) 107*da0073e9SAndroid Build Coastguard Worker y = x.numpy() 108*da0073e9SAndroid Build Coastguard Worker check2d(x, y) 109*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.flags["C_CONTIGUOUS"]) 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Worker # with storage offset 112*da0073e9SAndroid Build Coastguard Worker xm = get_castable_tensor((sz1 * 2, sz2), dtp) 113*da0073e9SAndroid Build Coastguard Worker x = xm.narrow(0, sz1 - 1, sz1) 114*da0073e9SAndroid Build Coastguard Worker y = x.numpy() 115*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.storage_offset() > 0) 116*da0073e9SAndroid Build Coastguard Worker check2d(x, y) 117*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.flags["C_CONTIGUOUS"]) 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker # non-contiguous 2D 120*da0073e9SAndroid Build Coastguard Worker x = get_castable_tensor((sz2, sz1), dtp).t() 121*da0073e9SAndroid Build Coastguard Worker y = x.numpy() 122*da0073e9SAndroid Build Coastguard Worker check2d(x, y) 123*da0073e9SAndroid Build Coastguard Worker self.assertFalse(y.flags["C_CONTIGUOUS"]) 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker # with storage offset 126*da0073e9SAndroid Build Coastguard Worker xm = get_castable_tensor((sz2 * 2, sz1), dtp) 127*da0073e9SAndroid Build Coastguard Worker x = xm.narrow(0, sz2 - 1, sz2).t() 128*da0073e9SAndroid Build Coastguard Worker y = x.numpy() 129*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.storage_offset() > 0) 130*da0073e9SAndroid Build Coastguard Worker check2d(x, y) 131*da0073e9SAndroid Build Coastguard Worker 132*da0073e9SAndroid Build Coastguard Worker # non-contiguous 2D with holes 133*da0073e9SAndroid Build Coastguard Worker xm = get_castable_tensor((sz2 * 2, sz1 * 2), dtp) 134*da0073e9SAndroid Build Coastguard Worker x = xm.narrow(0, sz2 - 1, sz2).narrow(1, sz1 - 1, sz1).t() 135*da0073e9SAndroid Build Coastguard Worker y = x.numpy() 136*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.storage_offset() > 0) 137*da0073e9SAndroid Build Coastguard Worker check2d(x, y) 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Worker if dtp != torch.half: 140*da0073e9SAndroid Build Coastguard Worker # check writeable 141*da0073e9SAndroid Build Coastguard Worker x = get_castable_tensor((3, 4), dtp) 142*da0073e9SAndroid Build Coastguard Worker y = x.numpy() 143*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.flags.writeable) 144*da0073e9SAndroid Build Coastguard Worker y[0][1] = 3 145*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x[0][1] == 3) 146*da0073e9SAndroid Build Coastguard Worker y = x.t().numpy() 147*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.flags.writeable) 148*da0073e9SAndroid Build Coastguard Worker y[0][1] = 3 149*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x[0][1] == 3) 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker def test_to_numpy_bool(self, device) -> None: 152*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([True, False], dtype=torch.bool) 153*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.dtype, torch.bool) 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker y = x.numpy() 156*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.dtype, np.bool_) 157*da0073e9SAndroid Build Coastguard Worker for i in range(len(x)): 158*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[i], y[i]) 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([True], dtype=torch.bool) 161*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.dtype, torch.bool) 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker y = x.numpy() 164*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.dtype, np.bool_) 165*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[0], y[0]) 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("conj bit not implemented in TensorVariable yet") 168*da0073e9SAndroid Build Coastguard Worker def test_to_numpy_force_argument(self, device) -> None: 169*da0073e9SAndroid Build Coastguard Worker for force in [False, True]: 170*da0073e9SAndroid Build Coastguard Worker for requires_grad in [False, True]: 171*da0073e9SAndroid Build Coastguard Worker for sparse in [False, True]: 172*da0073e9SAndroid Build Coastguard Worker for conj in [False, True]: 173*da0073e9SAndroid Build Coastguard Worker data = [[1 + 2j, -2 + 3j], [-1 - 2j, 3 - 2j]] 174*da0073e9SAndroid Build Coastguard Worker x = torch.tensor( 175*da0073e9SAndroid Build Coastguard Worker data, requires_grad=requires_grad, device=device 176*da0073e9SAndroid Build Coastguard Worker ) 177*da0073e9SAndroid Build Coastguard Worker y = x 178*da0073e9SAndroid Build Coastguard Worker if sparse: 179*da0073e9SAndroid Build Coastguard Worker if requires_grad: 180*da0073e9SAndroid Build Coastguard Worker continue 181*da0073e9SAndroid Build Coastguard Worker x = x.to_sparse() 182*da0073e9SAndroid Build Coastguard Worker if conj: 183*da0073e9SAndroid Build Coastguard Worker x = x.conj() 184*da0073e9SAndroid Build Coastguard Worker y = x.resolve_conj() 185*da0073e9SAndroid Build Coastguard Worker expect_error = ( 186*da0073e9SAndroid Build Coastguard Worker requires_grad or sparse or conj or not device == "cpu" 187*da0073e9SAndroid Build Coastguard Worker ) 188*da0073e9SAndroid Build Coastguard Worker error_msg = r"Use (t|T)ensor\..*(\.numpy\(\))?" 189*da0073e9SAndroid Build Coastguard Worker if not force and expect_error: 190*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 191*da0073e9SAndroid Build Coastguard Worker (RuntimeError, TypeError), error_msg, lambda: x.numpy() 192*da0073e9SAndroid Build Coastguard Worker ) 193*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 194*da0073e9SAndroid Build Coastguard Worker (RuntimeError, TypeError), 195*da0073e9SAndroid Build Coastguard Worker error_msg, 196*da0073e9SAndroid Build Coastguard Worker lambda: x.numpy(force=False), 197*da0073e9SAndroid Build Coastguard Worker ) 198*da0073e9SAndroid Build Coastguard Worker elif force and sparse: 199*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 200*da0073e9SAndroid Build Coastguard Worker TypeError, error_msg, lambda: x.numpy(force=True) 201*da0073e9SAndroid Build Coastguard Worker ) 202*da0073e9SAndroid Build Coastguard Worker else: 203*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.numpy(force=force), y) 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker def test_from_numpy(self, device) -> None: 206*da0073e9SAndroid Build Coastguard Worker dtypes = [ 207*da0073e9SAndroid Build Coastguard Worker np.double, 208*da0073e9SAndroid Build Coastguard Worker np.float64, 209*da0073e9SAndroid Build Coastguard Worker np.float16, 210*da0073e9SAndroid Build Coastguard Worker np.complex64, 211*da0073e9SAndroid Build Coastguard Worker np.complex128, 212*da0073e9SAndroid Build Coastguard Worker np.int64, 213*da0073e9SAndroid Build Coastguard Worker np.int32, 214*da0073e9SAndroid Build Coastguard Worker np.int16, 215*da0073e9SAndroid Build Coastguard Worker np.int8, 216*da0073e9SAndroid Build Coastguard Worker np.uint8, 217*da0073e9SAndroid Build Coastguard Worker np.longlong, 218*da0073e9SAndroid Build Coastguard Worker np.bool_, 219*da0073e9SAndroid Build Coastguard Worker ] 220*da0073e9SAndroid Build Coastguard Worker complex_dtypes = [ 221*da0073e9SAndroid Build Coastguard Worker np.complex64, 222*da0073e9SAndroid Build Coastguard Worker np.complex128, 223*da0073e9SAndroid Build Coastguard Worker ] 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker for dtype in dtypes: 226*da0073e9SAndroid Build Coastguard Worker array = np.array([1, 2, 3, 4], dtype=dtype) 227*da0073e9SAndroid Build Coastguard Worker tensor_from_array = torch.from_numpy(array) 228*da0073e9SAndroid Build Coastguard Worker # TODO: change to tensor equality check once HalfTensor 229*da0073e9SAndroid Build Coastguard Worker # implements `==` 230*da0073e9SAndroid Build Coastguard Worker for i in range(len(array)): 231*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_from_array[i], array[i]) 232*da0073e9SAndroid Build Coastguard Worker # ufunc 'remainder' not supported for complex dtypes 233*da0073e9SAndroid Build Coastguard Worker if dtype not in complex_dtypes: 234*da0073e9SAndroid Build Coastguard Worker # This is a special test case for Windows 235*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/22615 236*da0073e9SAndroid Build Coastguard Worker array2 = array % 2 237*da0073e9SAndroid Build Coastguard Worker tensor_from_array2 = torch.from_numpy(array2) 238*da0073e9SAndroid Build Coastguard Worker for i in range(len(array2)): 239*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_from_array2[i], array2[i]) 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker # Test unsupported type 242*da0073e9SAndroid Build Coastguard Worker array = np.array(["foo", "bar"], dtype=np.dtype(np.str_)) 243*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 244*da0073e9SAndroid Build Coastguard Worker tensor_from_array = torch.from_numpy(array) 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker # check storage offset 247*da0073e9SAndroid Build Coastguard Worker x = np.linspace(1, 125, 125) 248*da0073e9SAndroid Build Coastguard Worker x.shape = (5, 5, 5) 249*da0073e9SAndroid Build Coastguard Worker x = x[1] 250*da0073e9SAndroid Build Coastguard Worker expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[1] 251*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.from_numpy(x), expected) 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker # check noncontiguous 254*da0073e9SAndroid Build Coastguard Worker x = np.linspace(1, 25, 25) 255*da0073e9SAndroid Build Coastguard Worker x.shape = (5, 5) 256*da0073e9SAndroid Build Coastguard Worker expected = torch.arange(1, 26, dtype=torch.float64).view(5, 5).t() 257*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.from_numpy(x.T), expected) 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker # check noncontiguous with holes 260*da0073e9SAndroid Build Coastguard Worker x = np.linspace(1, 125, 125) 261*da0073e9SAndroid Build Coastguard Worker x.shape = (5, 5, 5) 262*da0073e9SAndroid Build Coastguard Worker x = x[:, 1] 263*da0073e9SAndroid Build Coastguard Worker expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[:, 1] 264*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.from_numpy(x), expected) 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker # check zero dimensional 267*da0073e9SAndroid Build Coastguard Worker x = np.zeros((0, 2)) 268*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.from_numpy(x).shape, (0, 2)) 269*da0073e9SAndroid Build Coastguard Worker x = np.zeros((2, 0)) 270*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.from_numpy(x).shape, (2, 0)) 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker # check ill-sized strides raise exception 273*da0073e9SAndroid Build Coastguard Worker x = np.array([3.0, 5.0, 8.0]) 274*da0073e9SAndroid Build Coastguard Worker x.strides = (3,) 275*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: torch.from_numpy(x)) 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("No need to test invalid dtypes that should fail by design.") 278*da0073e9SAndroid Build Coastguard Worker def test_from_numpy_no_leak_on_invalid_dtype(self): 279*da0073e9SAndroid Build Coastguard Worker # This used to leak memory as the `from_numpy` call raised an exception and didn't decref the temporary 280*da0073e9SAndroid Build Coastguard Worker # object. See https://github.com/pytorch/pytorch/issues/121138 281*da0073e9SAndroid Build Coastguard Worker x = np.array("value".encode("ascii")) 282*da0073e9SAndroid Build Coastguard Worker for _ in range(1000): 283*da0073e9SAndroid Build Coastguard Worker try: 284*da0073e9SAndroid Build Coastguard Worker torch.from_numpy(x) 285*da0073e9SAndroid Build Coastguard Worker except TypeError: 286*da0073e9SAndroid Build Coastguard Worker pass 287*da0073e9SAndroid Build Coastguard Worker self.assertTrue(sys.getrefcount(x) == 2) 288*da0073e9SAndroid Build Coastguard Worker 289*da0073e9SAndroid Build Coastguard Worker @skipMeta 290*da0073e9SAndroid Build Coastguard Worker def test_from_list_of_ndarray_warning(self, device): 291*da0073e9SAndroid Build Coastguard Worker warning_msg = ( 292*da0073e9SAndroid Build Coastguard Worker r"Creating a tensor from a list of numpy.ndarrays is extremely slow" 293*da0073e9SAndroid Build Coastguard Worker ) 294*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, warning_msg): 295*da0073e9SAndroid Build Coastguard Worker torch.tensor([np.array([0]), np.array([1])], device=device) 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker def test_ctor_with_invalid_numpy_array_sequence(self, device): 298*da0073e9SAndroid Build Coastguard Worker # Invalid list of numpy array 299*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "expected sequence of length"): 300*da0073e9SAndroid Build Coastguard Worker torch.tensor( 301*da0073e9SAndroid Build Coastguard Worker [np.random.random(size=(3, 3)), np.random.random(size=(3, 0))], 302*da0073e9SAndroid Build Coastguard Worker device=device, 303*da0073e9SAndroid Build Coastguard Worker ) 304*da0073e9SAndroid Build Coastguard Worker 305*da0073e9SAndroid Build Coastguard Worker # Invalid list of list of numpy array 306*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "expected sequence of length"): 307*da0073e9SAndroid Build Coastguard Worker torch.tensor( 308*da0073e9SAndroid Build Coastguard Worker [[np.random.random(size=(3, 3)), np.random.random(size=(3, 2))]], 309*da0073e9SAndroid Build Coastguard Worker device=device, 310*da0073e9SAndroid Build Coastguard Worker ) 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "expected sequence of length"): 313*da0073e9SAndroid Build Coastguard Worker torch.tensor( 314*da0073e9SAndroid Build Coastguard Worker [ 315*da0073e9SAndroid Build Coastguard Worker [np.random.random(size=(3, 3)), np.random.random(size=(3, 3))], 316*da0073e9SAndroid Build Coastguard Worker [np.random.random(size=(3, 3)), np.random.random(size=(3, 2))], 317*da0073e9SAndroid Build Coastguard Worker ], 318*da0073e9SAndroid Build Coastguard Worker device=device, 319*da0073e9SAndroid Build Coastguard Worker ) 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Worker # expected shape is `[1, 2, 3]`, hence we try to iterate over 0-D array 322*da0073e9SAndroid Build Coastguard Worker # leading to type error : not a sequence. 323*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "not a sequence"): 324*da0073e9SAndroid Build Coastguard Worker torch.tensor( 325*da0073e9SAndroid Build Coastguard Worker [[np.random.random(size=(3)), np.random.random()]], device=device 326*da0073e9SAndroid Build Coastguard Worker ) 327*da0073e9SAndroid Build Coastguard Worker 328*da0073e9SAndroid Build Coastguard Worker # list of list or numpy array. 329*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "expected sequence of length"): 330*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1, 2, 3], np.random.random(size=(2,))], device=device) 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Worker @onlyCPU 333*da0073e9SAndroid Build Coastguard Worker def test_ctor_with_numpy_scalar_ctor(self, device) -> None: 334*da0073e9SAndroid Build Coastguard Worker dtypes = [ 335*da0073e9SAndroid Build Coastguard Worker np.double, 336*da0073e9SAndroid Build Coastguard Worker np.float64, 337*da0073e9SAndroid Build Coastguard Worker np.float16, 338*da0073e9SAndroid Build Coastguard Worker np.int64, 339*da0073e9SAndroid Build Coastguard Worker np.int32, 340*da0073e9SAndroid Build Coastguard Worker np.int16, 341*da0073e9SAndroid Build Coastguard Worker np.uint8, 342*da0073e9SAndroid Build Coastguard Worker np.bool_, 343*da0073e9SAndroid Build Coastguard Worker ] 344*da0073e9SAndroid Build Coastguard Worker for dtype in dtypes: 345*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dtype(42), torch.tensor(dtype(42)).item()) 346*da0073e9SAndroid Build Coastguard Worker 347*da0073e9SAndroid Build Coastguard Worker @onlyCPU 348*da0073e9SAndroid Build Coastguard Worker def test_numpy_index(self, device): 349*da0073e9SAndroid Build Coastguard Worker i = np.array([0, 1, 2], dtype=np.int32) 350*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5) 351*da0073e9SAndroid Build Coastguard Worker for idx in i: 352*da0073e9SAndroid Build Coastguard Worker self.assertFalse(isinstance(idx, int)) 353*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[idx], x[int(idx)]) 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker @onlyCPU 356*da0073e9SAndroid Build Coastguard Worker def test_numpy_index_multi(self, device): 357*da0073e9SAndroid Build Coastguard Worker for dim_sz in [2, 8, 16, 32]: 358*da0073e9SAndroid Build Coastguard Worker i = np.zeros((dim_sz, dim_sz, dim_sz), dtype=np.int32) 359*da0073e9SAndroid Build Coastguard Worker i[: dim_sz // 2, :, :] = 1 360*da0073e9SAndroid Build Coastguard Worker x = torch.randn(dim_sz, dim_sz, dim_sz) 361*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x[i == 1].numel() == np.sum(i)) 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker @onlyCPU 364*da0073e9SAndroid Build Coastguard Worker def test_numpy_array_interface(self, device): 365*da0073e9SAndroid Build Coastguard Worker types = [ 366*da0073e9SAndroid Build Coastguard Worker torch.DoubleTensor, 367*da0073e9SAndroid Build Coastguard Worker torch.FloatTensor, 368*da0073e9SAndroid Build Coastguard Worker torch.HalfTensor, 369*da0073e9SAndroid Build Coastguard Worker torch.LongTensor, 370*da0073e9SAndroid Build Coastguard Worker torch.IntTensor, 371*da0073e9SAndroid Build Coastguard Worker torch.ShortTensor, 372*da0073e9SAndroid Build Coastguard Worker torch.ByteTensor, 373*da0073e9SAndroid Build Coastguard Worker ] 374*da0073e9SAndroid Build Coastguard Worker dtypes = [ 375*da0073e9SAndroid Build Coastguard Worker np.float64, 376*da0073e9SAndroid Build Coastguard Worker np.float32, 377*da0073e9SAndroid Build Coastguard Worker np.float16, 378*da0073e9SAndroid Build Coastguard Worker np.int64, 379*da0073e9SAndroid Build Coastguard Worker np.int32, 380*da0073e9SAndroid Build Coastguard Worker np.int16, 381*da0073e9SAndroid Build Coastguard Worker np.uint8, 382*da0073e9SAndroid Build Coastguard Worker ] 383*da0073e9SAndroid Build Coastguard Worker for tp, dtype in zip(types, dtypes): 384*da0073e9SAndroid Build Coastguard Worker # Only concrete class can be given where "Type[number[_64Bit]]" is expected 385*da0073e9SAndroid Build Coastguard Worker if np.dtype(dtype).kind == "u": # type: ignore[misc] 386*da0073e9SAndroid Build Coastguard Worker # .type expects a XxxTensor, which have no type hints on 387*da0073e9SAndroid Build Coastguard Worker # purpose, so ignore during mypy type checking 388*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2, 3, 4]).type(tp) # type: ignore[call-overload] 389*da0073e9SAndroid Build Coastguard Worker array = np.array([1, 2, 3, 4], dtype=dtype) 390*da0073e9SAndroid Build Coastguard Worker else: 391*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, -2, 3, -4]).type(tp) # type: ignore[call-overload] 392*da0073e9SAndroid Build Coastguard Worker array = np.array([1, -2, 3, -4], dtype=dtype) 393*da0073e9SAndroid Build Coastguard Worker 394*da0073e9SAndroid Build Coastguard Worker # Test __array__ w/o dtype argument 395*da0073e9SAndroid Build Coastguard Worker asarray = np.asarray(x) 396*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(asarray, np.ndarray) 397*da0073e9SAndroid Build Coastguard Worker self.assertEqual(asarray.dtype, dtype) 398*da0073e9SAndroid Build Coastguard Worker for i in range(len(x)): 399*da0073e9SAndroid Build Coastguard Worker self.assertEqual(asarray[i], x[i]) 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker # Test __array_wrap__, same dtype 402*da0073e9SAndroid Build Coastguard Worker abs_x = np.abs(x) 403*da0073e9SAndroid Build Coastguard Worker abs_array = np.abs(array) 404*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(abs_x, tp) 405*da0073e9SAndroid Build Coastguard Worker for i in range(len(x)): 406*da0073e9SAndroid Build Coastguard Worker self.assertEqual(abs_x[i], abs_array[i]) 407*da0073e9SAndroid Build Coastguard Worker 408*da0073e9SAndroid Build Coastguard Worker # Test __array__ with dtype argument 409*da0073e9SAndroid Build Coastguard Worker for dtype in dtypes: 410*da0073e9SAndroid Build Coastguard Worker x = torch.IntTensor([1, -2, 3, -4]) 411*da0073e9SAndroid Build Coastguard Worker asarray = np.asarray(x, dtype=dtype) 412*da0073e9SAndroid Build Coastguard Worker self.assertEqual(asarray.dtype, dtype) 413*da0073e9SAndroid Build Coastguard Worker # Only concrete class can be given where "Type[number[_64Bit]]" is expected 414*da0073e9SAndroid Build Coastguard Worker if np.dtype(dtype).kind == "u": # type: ignore[misc] 415*da0073e9SAndroid Build Coastguard Worker wrapped_x = np.array([1, -2, 3, -4], dtype=dtype) 416*da0073e9SAndroid Build Coastguard Worker for i in range(len(x)): 417*da0073e9SAndroid Build Coastguard Worker self.assertEqual(asarray[i], wrapped_x[i]) 418*da0073e9SAndroid Build Coastguard Worker else: 419*da0073e9SAndroid Build Coastguard Worker for i in range(len(x)): 420*da0073e9SAndroid Build Coastguard Worker self.assertEqual(asarray[i], x[i]) 421*da0073e9SAndroid Build Coastguard Worker 422*da0073e9SAndroid Build Coastguard Worker # Test some math functions with float types 423*da0073e9SAndroid Build Coastguard Worker float_types = [torch.DoubleTensor, torch.FloatTensor] 424*da0073e9SAndroid Build Coastguard Worker float_dtypes = [np.float64, np.float32] 425*da0073e9SAndroid Build Coastguard Worker for tp, dtype in zip(float_types, float_dtypes): 426*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2, 3, 4]).type(tp) # type: ignore[call-overload] 427*da0073e9SAndroid Build Coastguard Worker array = np.array([1, 2, 3, 4], dtype=dtype) 428*da0073e9SAndroid Build Coastguard Worker for func in ["sin", "sqrt", "ceil"]: 429*da0073e9SAndroid Build Coastguard Worker ufunc = getattr(np, func) 430*da0073e9SAndroid Build Coastguard Worker res_x = ufunc(x) 431*da0073e9SAndroid Build Coastguard Worker res_array = ufunc(array) 432*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(res_x, tp) 433*da0073e9SAndroid Build Coastguard Worker for i in range(len(x)): 434*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_x[i], res_array[i]) 435*da0073e9SAndroid Build Coastguard Worker 436*da0073e9SAndroid Build Coastguard Worker # Test functions with boolean return value 437*da0073e9SAndroid Build Coastguard Worker for tp, dtype in zip(types, dtypes): 438*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2, 3, 4]).type(tp) # type: ignore[call-overload] 439*da0073e9SAndroid Build Coastguard Worker array = np.array([1, 2, 3, 4], dtype=dtype) 440*da0073e9SAndroid Build Coastguard Worker geq2_x = np.greater_equal(x, 2) 441*da0073e9SAndroid Build Coastguard Worker geq2_array = np.greater_equal(array, 2).astype("uint8") 442*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(geq2_x, torch.ByteTensor) 443*da0073e9SAndroid Build Coastguard Worker for i in range(len(x)): 444*da0073e9SAndroid Build Coastguard Worker self.assertEqual(geq2_x[i], geq2_array[i]) 445*da0073e9SAndroid Build Coastguard Worker 446*da0073e9SAndroid Build Coastguard Worker @onlyCPU 447*da0073e9SAndroid Build Coastguard Worker def test_multiplication_numpy_scalar(self, device) -> None: 448*da0073e9SAndroid Build Coastguard Worker for np_dtype in [ 449*da0073e9SAndroid Build Coastguard Worker np.float32, 450*da0073e9SAndroid Build Coastguard Worker np.float64, 451*da0073e9SAndroid Build Coastguard Worker np.int32, 452*da0073e9SAndroid Build Coastguard Worker np.int64, 453*da0073e9SAndroid Build Coastguard Worker np.int16, 454*da0073e9SAndroid Build Coastguard Worker np.uint8, 455*da0073e9SAndroid Build Coastguard Worker ]: 456*da0073e9SAndroid Build Coastguard Worker for t_dtype in [torch.float, torch.double]: 457*da0073e9SAndroid Build Coastguard Worker # mypy raises an error when np.floatXY(2.0) is called 458*da0073e9SAndroid Build Coastguard Worker # even though this is valid code 459*da0073e9SAndroid Build Coastguard Worker np_sc = np_dtype(2.0) # type: ignore[abstract, arg-type] 460*da0073e9SAndroid Build Coastguard Worker t = torch.ones(2, requires_grad=True, dtype=t_dtype) 461*da0073e9SAndroid Build Coastguard Worker r1 = t * np_sc 462*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(r1, torch.Tensor) 463*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r1.dtype == t_dtype) 464*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r1.requires_grad) 465*da0073e9SAndroid Build Coastguard Worker r2 = np_sc * t 466*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(r2, torch.Tensor) 467*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r2.dtype == t_dtype) 468*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r2.requires_grad) 469*da0073e9SAndroid Build Coastguard Worker 470*da0073e9SAndroid Build Coastguard Worker @onlyCPU 471*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo() 472*da0073e9SAndroid Build Coastguard Worker def test_parse_numpy_int_overflow(self, device): 473*da0073e9SAndroid Build Coastguard Worker # assertRaises uses a try-except which dynamo has issues with 474*da0073e9SAndroid Build Coastguard Worker # Only concrete class can be given where "Type[number[_64Bit]]" is expected 475*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 476*da0073e9SAndroid Build Coastguard Worker RuntimeError, 477*da0073e9SAndroid Build Coastguard Worker "(Overflow|an integer is required)", 478*da0073e9SAndroid Build Coastguard Worker lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)), 479*da0073e9SAndroid Build Coastguard Worker ) # type: ignore[call-overload] 480*da0073e9SAndroid Build Coastguard Worker 481*da0073e9SAndroid Build Coastguard Worker @onlyCPU 482*da0073e9SAndroid Build Coastguard Worker def test_parse_numpy_int(self, device): 483*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/29252 484*da0073e9SAndroid Build Coastguard Worker for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]: 485*da0073e9SAndroid Build Coastguard Worker scalar = 3 486*da0073e9SAndroid Build Coastguard Worker np_arr = np.array([scalar], dtype=nptype) 487*da0073e9SAndroid Build Coastguard Worker np_val = np_arr[0] 488*da0073e9SAndroid Build Coastguard Worker 489*da0073e9SAndroid Build Coastguard Worker # np integral type can be treated as a python int in native functions with 490*da0073e9SAndroid Build Coastguard Worker # int parameters: 491*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ones(5).diag(scalar), torch.ones(5).diag(np_val)) 492*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 493*da0073e9SAndroid Build Coastguard Worker torch.ones([2, 2, 2, 2]).mean(scalar), 494*da0073e9SAndroid Build Coastguard Worker torch.ones([2, 2, 2, 2]).mean(np_val), 495*da0073e9SAndroid Build Coastguard Worker ) 496*da0073e9SAndroid Build Coastguard Worker 497*da0073e9SAndroid Build Coastguard Worker # numpy integral type parses like a python int in custom python bindings: 498*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.Storage(np_val).size(), scalar) # type: ignore[attr-defined] 499*da0073e9SAndroid Build Coastguard Worker 500*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor([2], dtype=torch.int) 501*da0073e9SAndroid Build Coastguard Worker tensor[0] = np_val 502*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor[0], np_val) 503*da0073e9SAndroid Build Coastguard Worker 504*da0073e9SAndroid Build Coastguard Worker # Original reported issue, np integral type parses to the correct 505*da0073e9SAndroid Build Coastguard Worker # PyTorch integral type when passed for a `Scalar` parameter in 506*da0073e9SAndroid Build Coastguard Worker # arithmetic operations: 507*da0073e9SAndroid Build Coastguard Worker t = torch.from_numpy(np_arr) 508*da0073e9SAndroid Build Coastguard Worker self.assertEqual((t + np_val).dtype, t.dtype) 509*da0073e9SAndroid Build Coastguard Worker self.assertEqual((np_val + t).dtype, t.dtype) 510*da0073e9SAndroid Build Coastguard Worker 511*da0073e9SAndroid Build Coastguard Worker def test_has_storage_numpy(self, device): 512*da0073e9SAndroid Build Coastguard Worker for dtype in [np.float32, np.float64, np.int64, np.int32, np.int16, np.uint8]: 513*da0073e9SAndroid Build Coastguard Worker arr = np.array([1], dtype=dtype) 514*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone( 515*da0073e9SAndroid Build Coastguard Worker torch.tensor(arr, device=device, dtype=torch.float32).storage() 516*da0073e9SAndroid Build Coastguard Worker ) 517*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone( 518*da0073e9SAndroid Build Coastguard Worker torch.tensor(arr, device=device, dtype=torch.double).storage() 519*da0073e9SAndroid Build Coastguard Worker ) 520*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone( 521*da0073e9SAndroid Build Coastguard Worker torch.tensor(arr, device=device, dtype=torch.int).storage() 522*da0073e9SAndroid Build Coastguard Worker ) 523*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone( 524*da0073e9SAndroid Build Coastguard Worker torch.tensor(arr, device=device, dtype=torch.long).storage() 525*da0073e9SAndroid Build Coastguard Worker ) 526*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone( 527*da0073e9SAndroid Build Coastguard Worker torch.tensor(arr, device=device, dtype=torch.uint8).storage() 528*da0073e9SAndroid Build Coastguard Worker ) 529*da0073e9SAndroid Build Coastguard Worker 530*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) 531*da0073e9SAndroid Build Coastguard Worker def test_numpy_scalar_cmp(self, device, dtype): 532*da0073e9SAndroid Build Coastguard Worker if dtype.is_complex: 533*da0073e9SAndroid Build Coastguard Worker tensors = ( 534*da0073e9SAndroid Build Coastguard Worker torch.tensor(complex(1, 3), dtype=dtype, device=device), 535*da0073e9SAndroid Build Coastguard Worker torch.tensor([complex(1, 3), 0, 2j], dtype=dtype, device=device), 536*da0073e9SAndroid Build Coastguard Worker torch.tensor( 537*da0073e9SAndroid Build Coastguard Worker [[complex(3, 1), 0], [-1j, 5]], dtype=dtype, device=device 538*da0073e9SAndroid Build Coastguard Worker ), 539*da0073e9SAndroid Build Coastguard Worker ) 540*da0073e9SAndroid Build Coastguard Worker else: 541*da0073e9SAndroid Build Coastguard Worker tensors = ( 542*da0073e9SAndroid Build Coastguard Worker torch.tensor(3, dtype=dtype, device=device), 543*da0073e9SAndroid Build Coastguard Worker torch.tensor([1, 0, -3], dtype=dtype, device=device), 544*da0073e9SAndroid Build Coastguard Worker torch.tensor([[3, 0, -1], [3, 5, 4]], dtype=dtype, device=device), 545*da0073e9SAndroid Build Coastguard Worker ) 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker for tensor in tensors: 548*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 549*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 550*da0073e9SAndroid Build Coastguard Worker np_array = tensor.cpu().numpy() 551*da0073e9SAndroid Build Coastguard Worker continue 552*da0073e9SAndroid Build Coastguard Worker 553*da0073e9SAndroid Build Coastguard Worker np_array = tensor.cpu().numpy() 554*da0073e9SAndroid Build Coastguard Worker for t, a in product( 555*da0073e9SAndroid Build Coastguard Worker (tensor.flatten()[0], tensor.flatten()[0].item()), 556*da0073e9SAndroid Build Coastguard Worker (np_array.flatten()[0], np_array.flatten()[0].item()), 557*da0073e9SAndroid Build Coastguard Worker ): 558*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, a) 559*da0073e9SAndroid Build Coastguard Worker if ( 560*da0073e9SAndroid Build Coastguard Worker dtype == torch.complex64 561*da0073e9SAndroid Build Coastguard Worker and torch.is_tensor(t) 562*da0073e9SAndroid Build Coastguard Worker and type(a) == np.complex64 563*da0073e9SAndroid Build Coastguard Worker ): 564*da0073e9SAndroid Build Coastguard Worker # TODO: Imaginary part is dropped in this case. Need fix. 565*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/43579 566*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t == a) 567*da0073e9SAndroid Build Coastguard Worker else: 568*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t == a) 569*da0073e9SAndroid Build Coastguard Worker 570*da0073e9SAndroid Build Coastguard Worker @onlyCPU 571*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) 572*da0073e9SAndroid Build Coastguard Worker def test___eq__(self, device, dtype): 573*da0073e9SAndroid Build Coastguard Worker a = make_tensor((5, 7), dtype=dtype, device=device, low=-9, high=9) 574*da0073e9SAndroid Build Coastguard Worker b = a.clone().detach() 575*da0073e9SAndroid Build Coastguard Worker b_np = b.numpy() 576*da0073e9SAndroid Build Coastguard Worker 577*da0073e9SAndroid Build Coastguard Worker # Check all elements equal 578*da0073e9SAndroid Build Coastguard Worker res_check = torch.ones_like(a, dtype=torch.bool) 579*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a == b_np, res_check) 580*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b_np == a, res_check) 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard Worker # Check one element unequal 583*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bool: 584*da0073e9SAndroid Build Coastguard Worker b[1][3] = not b[1][3] 585*da0073e9SAndroid Build Coastguard Worker else: 586*da0073e9SAndroid Build Coastguard Worker b[1][3] += 1 587*da0073e9SAndroid Build Coastguard Worker res_check[1][3] = False 588*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a == b_np, res_check) 589*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b_np == a, res_check) 590*da0073e9SAndroid Build Coastguard Worker 591*da0073e9SAndroid Build Coastguard Worker # Check random elements unequal 592*da0073e9SAndroid Build Coastguard Worker rand = torch.randint(0, 2, a.shape, dtype=torch.bool) 593*da0073e9SAndroid Build Coastguard Worker res_check = rand.logical_not() 594*da0073e9SAndroid Build Coastguard Worker b.copy_(a) 595*da0073e9SAndroid Build Coastguard Worker 596*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bool: 597*da0073e9SAndroid Build Coastguard Worker b[rand] = b[rand].logical_not() 598*da0073e9SAndroid Build Coastguard Worker else: 599*da0073e9SAndroid Build Coastguard Worker b[rand] += 1 600*da0073e9SAndroid Build Coastguard Worker 601*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a == b_np, res_check) 602*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b_np == a, res_check) 603*da0073e9SAndroid Build Coastguard Worker 604*da0073e9SAndroid Build Coastguard Worker # Check all elements unequal 605*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bool: 606*da0073e9SAndroid Build Coastguard Worker b.copy_(a.logical_not()) 607*da0073e9SAndroid Build Coastguard Worker else: 608*da0073e9SAndroid Build Coastguard Worker b.copy_(a + 1) 609*da0073e9SAndroid Build Coastguard Worker res_check.fill_(False) 610*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a == b_np, res_check) 611*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b_np == a, res_check) 612*da0073e9SAndroid Build Coastguard Worker 613*da0073e9SAndroid Build Coastguard Worker @onlyCPU 614*da0073e9SAndroid Build Coastguard Worker def test_empty_tensors_interop(self, device): 615*da0073e9SAndroid Build Coastguard Worker x = torch.rand((), dtype=torch.float16) 616*da0073e9SAndroid Build Coastguard Worker y = torch.tensor(np.random.rand(0), dtype=torch.float16) 617*da0073e9SAndroid Build Coastguard Worker # Same can be achieved by running 618*da0073e9SAndroid Build Coastguard Worker # y = torch.empty_strided((0,), (0,), dtype=torch.float16) 619*da0073e9SAndroid Build Coastguard Worker 620*da0073e9SAndroid Build Coastguard Worker # Regression test for https://github.com/pytorch/pytorch/issues/115068 621*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.true_divide(x, y).shape, y.shape) 622*da0073e9SAndroid Build Coastguard Worker # Regression test for https://github.com/pytorch/pytorch/issues/115066 623*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mul(x, y).shape, y.shape) 624*da0073e9SAndroid Build Coastguard Worker # Regression test for https://github.com/pytorch/pytorch/issues/113037 625*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.div(x, y, rounding_mode="floor").shape, y.shape) 626*da0073e9SAndroid Build Coastguard Worker 627*da0073e9SAndroid Build Coastguard Worker 628*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestNumPyInterop, globals()) 629*da0073e9SAndroid Build Coastguard Worker 630*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 631*da0073e9SAndroid Build Coastguard Worker run_tests() 632