1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-decorators 2*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 3*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: tests"] 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport torch 6*da0073e9SAndroid Build Coastguard Workerimport torch.utils.data 7*da0073e9SAndroid Build Coastguard Workerimport numpy as np 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport contextlib 10*da0073e9SAndroid Build Coastguard Workerimport gc 11*da0073e9SAndroid Build Coastguard Workerimport io 12*da0073e9SAndroid Build Coastguard Workerimport inspect 13*da0073e9SAndroid Build Coastguard Workerimport itertools 14*da0073e9SAndroid Build Coastguard Workerimport math 15*da0073e9SAndroid Build Coastguard Workerimport random 16*da0073e9SAndroid Build Coastguard Workerimport re 17*da0073e9SAndroid Build Coastguard Workerimport copy 18*da0073e9SAndroid Build Coastguard Workerimport os 19*da0073e9SAndroid Build Coastguard Workerimport tempfile 20*da0073e9SAndroid Build Coastguard Workerimport unittest 21*da0073e9SAndroid Build Coastguard Workerimport warnings 22*da0073e9SAndroid Build Coastguard Workerimport types 23*da0073e9SAndroid Build Coastguard Workerimport pickle 24*da0073e9SAndroid Build Coastguard Workerimport textwrap 25*da0073e9SAndroid Build Coastguard Workerimport subprocess 26*da0073e9SAndroid Build Coastguard Workerimport weakref 27*da0073e9SAndroid Build Coastguard Workerimport sys 28*da0073e9SAndroid Build Coastguard Workerimport copyreg 29*da0073e9SAndroid Build Coastguard Workerfrom torch import inf, nan 30*da0073e9SAndroid Build Coastguard Workerfrom itertools import product, combinations, permutations, chain 31*da0073e9SAndroid Build Coastguard Workerfrom functools import partial 32*da0073e9SAndroid Build Coastguard Workerfrom torch import multiprocessing as mp 33*da0073e9SAndroid Build Coastguard 
Workerfrom torch.testing import make_tensor 34*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_optimizers import ( 35*da0073e9SAndroid Build Coastguard Worker optim_db, optims, _get_optim_inputs_including_global_cliquey_kwargs) 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( # type: ignore[attr-defined] 38*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, run_tests, IS_JETSON, 39*da0073e9SAndroid Build Coastguard Worker IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN, 40*da0073e9SAndroid Build Coastguard Worker IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest, slowTestIf, 41*da0073e9SAndroid Build Coastguard Worker TEST_WITH_CROSSREF, skipIfTorchDynamo, skipRocmIfTorchInductor, set_default_dtype, 42*da0073e9SAndroid Build Coastguard Worker skipCUDAMemoryLeakCheckIf, BytesIOContext, 43*da0073e9SAndroid Build Coastguard Worker skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName, 44*da0073e9SAndroid Build Coastguard Worker wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, 45*da0073e9SAndroid Build Coastguard Worker bytes_to_scalar, parametrize, skipIfMps, noncontiguous_like, 46*da0073e9SAndroid Build Coastguard Worker AlwaysWarnTypedStorageRemoval, TEST_WITH_TORCHDYNAMO, xfailIfTorchDynamo) 47*da0073e9SAndroid Build Coastguard Workerfrom multiprocessing.reduction import ForkingPickler 48*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 49*da0073e9SAndroid Build Coastguard Worker expectedFailureMeta, 50*da0073e9SAndroid Build Coastguard Worker expectedFailureXLA, 51*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 52*da0073e9SAndroid Build Coastguard Worker onlyCUDA, onlyCPU, 53*da0073e9SAndroid Build Coastguard Worker dtypes, dtypesIfCUDA, dtypesIfCPU, 
deviceCountAtLeast, 54*da0073e9SAndroid Build Coastguard Worker skipMeta, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes, 55*da0073e9SAndroid Build Coastguard Worker get_all_device_types, skipXLA) 56*da0073e9SAndroid Build Coastguard Workerfrom typing import Tuple 57*da0073e9SAndroid Build Coastguard Workerimport torch.backends.quantized 58*da0073e9SAndroid Build Coastguard Workerimport torch.testing._internal.data 59*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import ( 60*da0073e9SAndroid Build Coastguard Worker tf32_on_and_off, tf32_is_not_fp32, TEST_CUDNN, TEST_MULTIGPU, 61*da0073e9SAndroid Build Coastguard Worker _create_scaling_case, _create_scaling_models_optimizers) 62*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_mkldnn import bf32_on_and_off 63*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_dtype import ( 64*da0073e9SAndroid Build Coastguard Worker floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types, 65*da0073e9SAndroid Build Coastguard Worker all_types_and, floating_types, floating_and_complex_types, integral_types_and, 66*da0073e9SAndroid Build Coastguard Worker get_all_qint_dtypes, 67*da0073e9SAndroid Build Coastguard Worker) 68*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.two_tensor import TwoTensor 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Workerif TEST_WITH_TORCHINDUCTOR: 71*da0073e9SAndroid Build Coastguard Worker from torch._inductor.test_case import TestCase 72*da0073e9SAndroid Build Coastguard Workerelse: 73*da0073e9SAndroid Build Coastguard Worker from torch.testing._internal.common_utils import TestCase # type: ignore[assignment] 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker# Protects against includes accidentally setting the default dtype 77*da0073e9SAndroid 
Build Coastguard Workerassert torch.get_default_dtype() is torch.float32 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for 80*da0073e9SAndroid Build Coastguard Worker# sharding on sandcastle. This line silences flake warnings 81*da0073e9SAndroid Build Coastguard Workerload_tests = load_tests 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard WorkerAMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32() 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager 86*da0073e9SAndroid Build Coastguard Workerdef torch_vital_set(value): 87*da0073e9SAndroid Build Coastguard Worker stash = None 88*da0073e9SAndroid Build Coastguard Worker if 'TORCH_VITAL' in os.environ: 89*da0073e9SAndroid Build Coastguard Worker stash = os.environ['TORCH_VITAL'] 90*da0073e9SAndroid Build Coastguard Worker os.environ['TORCH_VITAL'] = value 91*da0073e9SAndroid Build Coastguard Worker try: 92*da0073e9SAndroid Build Coastguard Worker yield 93*da0073e9SAndroid Build Coastguard Worker finally: 94*da0073e9SAndroid Build Coastguard Worker if stash: 95*da0073e9SAndroid Build Coastguard Worker os.environ['TORCH_VITAL'] = stash 96*da0073e9SAndroid Build Coastguard Worker else: 97*da0073e9SAndroid Build Coastguard Worker del os.environ['TORCH_VITAL'] 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker# Tests Vital Signs for Torch 100*da0073e9SAndroid Build Coastguard Worker# FIXME: document or deprecate whatever this is 101*da0073e9SAndroid Build Coastguard Workerclass TestBasicVitalSigns(TestCase): 102*da0073e9SAndroid Build Coastguard Worker def test_basic_vitals(self): 103*da0073e9SAndroid Build Coastguard Worker with torch_vital_set(''): 104*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.vitals_enabled()) 105*da0073e9SAndroid Build 
Coastguard Worker with torch_vital_set('ON'): 106*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.vitals_enabled()) 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker def test_basic_vitals_read_write(self): 109*da0073e9SAndroid Build Coastguard Worker with torch_vital_set('ON'): 110*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.vitals_enabled()) 111*da0073e9SAndroid Build Coastguard Worker # This tests the code path of setting a vital 112*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.set_vital('Dataloader', 'basic_unit_test', 'TEST_VALUE_STRING')) 113*da0073e9SAndroid Build Coastguard Worker self.assertIn('TEST_VALUE_STRING', torch.read_vitals()) 114*da0073e9SAndroid Build Coastguard Worker self.assertIn('CUDA.used', torch.read_vitals()) 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker def test_dataloader_vitals(self): 117*da0073e9SAndroid Build Coastguard Worker with torch_vital_set('ON'): 118*da0073e9SAndroid Build Coastguard Worker inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) 119*da0073e9SAndroid Build Coastguard Worker tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) 120*da0073e9SAndroid Build Coastguard Worker dataset = torch.utils.data.TensorDataset(inps, tgts) 121*da0073e9SAndroid Build Coastguard Worker loader = torch.utils.data.DataLoader(dataset, batch_size=2) 122*da0073e9SAndroid Build Coastguard Worker self.assertIn('Dataloader.enabled\t\t True', torch.read_vitals()) 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker# FIXME: document or deprecate whatever this is 125*da0073e9SAndroid Build Coastguard Workerclass TestVitalSignsCuda(TestCase): 126*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 127*da0073e9SAndroid Build Coastguard Worker def test_cuda_vitals_gpu_only(self, device): 128*da0073e9SAndroid Build Coastguard Worker with torch_vital_set('ON'): 
129*da0073e9SAndroid Build Coastguard Worker self.assertIn('CUDA.used\t\t true', torch.read_vitals()) 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker 132*da0073e9SAndroid Build Coastguard Workeris_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(0) == (8, 6) 133*da0073e9SAndroid Build Coastguard Worker 134*da0073e9SAndroid Build Coastguard Workerclass TestTorchDeviceType(TestCase): 135*da0073e9SAndroid Build Coastguard Worker exact_dtype = True 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker # TODO: move all tensor creation to common ops 138*da0073e9SAndroid Build Coastguard Worker def _rand_shape(self, dim, min_size, max_size): 139*da0073e9SAndroid Build Coastguard Worker shape = [] 140*da0073e9SAndroid Build Coastguard Worker for i in range(dim): 141*da0073e9SAndroid Build Coastguard Worker shape.append(random.randint(min_size, max_size)) 142*da0073e9SAndroid Build Coastguard Worker return tuple(shape) 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker # Validates that mathematical constants are defined properly, as required by 145*da0073e9SAndroid Build Coastguard Worker # the Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) 146*da0073e9SAndroid Build Coastguard Worker @onlyCPU 147*da0073e9SAndroid Build Coastguard Worker def test_constants(self, device): 148*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(torch.e, float) 149*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.e, math.e, atol=0, rtol=0) 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(torch.pi, float) 152*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.pi, math.pi, atol=0, rtol=0) 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(torch.nan, 
float) 155*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.nan, math.nan, equal_nan=True) 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(torch.inf, float) 158*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.inf, math.inf) 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 161*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, 162*da0073e9SAndroid Build Coastguard Worker torch.bool, torch.float32, torch.complex64, torch.float64, 163*da0073e9SAndroid Build Coastguard Worker torch.complex128, torch.uint16, torch.uint32, torch.uint64) 164*da0073e9SAndroid Build Coastguard Worker def test_bytes_to_scalar(self, device, dtype): 165*da0073e9SAndroid Build Coastguard Worker def rand_byte(): 166*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bool: 167*da0073e9SAndroid Build Coastguard Worker return torch.randint(0, 2, ()).item() 168*da0073e9SAndroid Build Coastguard Worker else: 169*da0073e9SAndroid Build Coastguard Worker return torch.randint(0, 256, ()).item() 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker element_size = torch._utils._element_size(dtype) 172*da0073e9SAndroid Build Coastguard Worker 173*da0073e9SAndroid Build Coastguard Worker for i in range(10): 174*da0073e9SAndroid Build Coastguard Worker bytes_list = [rand_byte() for _ in range(element_size)] 175*da0073e9SAndroid Build Coastguard Worker scalar = bytes_to_scalar(bytes_list, dtype, device) 176*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scalar.storage().untyped().tolist(), bytes_list) 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, 179*da0073e9SAndroid Build Coastguard Worker torch.bool, torch.float32, 
torch.complex64, torch.float64, 180*da0073e9SAndroid Build Coastguard Worker torch.complex128, torch.uint16, torch.uint32, torch.uint64) 181*da0073e9SAndroid Build Coastguard Worker def test_storage(self, device, dtype): 182*da0073e9SAndroid Build Coastguard Worker v = make_tensor((3, 5), dtype=dtype, device=device, low=-9, high=9) 183*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.storage()[0], v[0][0]) 184*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.storage()[14], v[2][4]) 185*da0073e9SAndroid Build Coastguard Worker v_s = v.storage() 186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Worker for el_num in range(v.numel()): 188*da0073e9SAndroid Build Coastguard Worker dim0 = el_num // v.size(1) 189*da0073e9SAndroid Build Coastguard Worker dim1 = el_num % v.size(1) 190*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 191*da0073e9SAndroid Build Coastguard Worker v_s[el_num], 192*da0073e9SAndroid Build Coastguard Worker v[dim0][dim1]) 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker v_s_byte = v.storage().untyped() 195*da0073e9SAndroid Build Coastguard Worker el_size = v.element_size() 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker for el_num in range(v.numel()): 198*da0073e9SAndroid Build Coastguard Worker start = el_num * el_size 199*da0073e9SAndroid Build Coastguard Worker end = start + el_size 200*da0073e9SAndroid Build Coastguard Worker dim0 = el_num // v.size(1) 201*da0073e9SAndroid Build Coastguard Worker dim1 = el_num % v.size(1) 202*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 203*da0073e9SAndroid Build Coastguard Worker bytes_to_scalar(v_s_byte[start:end], dtype, device), 204*da0073e9SAndroid Build Coastguard Worker v[dim0][dim1]) 205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 207*da0073e9SAndroid Build Coastguard Worker 
@dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64, 208*da0073e9SAndroid Build Coastguard Worker torch.bool, torch.float32, torch.complex64, torch.float64, 209*da0073e9SAndroid Build Coastguard Worker torch.complex128, torch.quint8, torch.qint8, torch.qint32, 210*da0073e9SAndroid Build Coastguard Worker torch.quint4x2) 211*da0073e9SAndroid Build Coastguard Worker def test_storage_setitem(self, device, dtype): 212*da0073e9SAndroid Build Coastguard Worker # Skip quantized dtypes for CUDA, since they're not supported 213*da0073e9SAndroid Build Coastguard Worker if torch.device(device).type == 'cuda': 214*da0073e9SAndroid Build Coastguard Worker if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.quint4x2]: 215*da0073e9SAndroid Build Coastguard Worker return 216*da0073e9SAndroid Build Coastguard Worker 217*da0073e9SAndroid Build Coastguard Worker storage_type_name = torch.storage._dtype_to_storage_type_map()[dtype] 218*da0073e9SAndroid Build Coastguard Worker if torch.device(device).type == 'cuda': 219*da0073e9SAndroid Build Coastguard Worker storage_type = eval('torch.cuda.' + storage_type_name) 220*da0073e9SAndroid Build Coastguard Worker else: 221*da0073e9SAndroid Build Coastguard Worker storage_type = eval('torch.' 
+ storage_type_name) 222*da0073e9SAndroid Build Coastguard Worker 223*da0073e9SAndroid Build Coastguard Worker N = 10 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker s = storage_type(N) 226*da0073e9SAndroid Build Coastguard Worker s[:] = 0 227*da0073e9SAndroid Build Coastguard Worker l = [0] * N 228*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s, storage_type(l)) 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker for i in range(N): 231*da0073e9SAndroid Build Coastguard Worker s[i] = i 232*da0073e9SAndroid Build Coastguard Worker l[i] = i 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s, storage_type(l)) 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker l[2:7] = [1] * 5 237*da0073e9SAndroid Build Coastguard Worker s[2:7] = 1 238*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s, storage_type(l)) 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") 241*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 242*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 243*da0073e9SAndroid Build Coastguard Worker def test_tensor_storage_type(self, device, dtype): 244*da0073e9SAndroid Build Coastguard Worker a = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9) 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker module = torch.cuda if (torch.device(device).type == 'cuda') else torch 247*da0073e9SAndroid Build Coastguard Worker expected_storage_type = getattr(module, torch.storage._dtype_to_storage_type_map()[dtype]) 248*da0073e9SAndroid Build Coastguard Worker 249*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.storage_type(), 
expected_storage_type) 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 252*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64)) 253*da0073e9SAndroid Build Coastguard Worker def test_tensor_from_storage(self, device, dtype): 254*da0073e9SAndroid Build Coastguard Worker a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9) 255*da0073e9SAndroid Build Coastguard Worker a_s = a.storage() 256*da0073e9SAndroid Build Coastguard Worker b = torch.tensor(a_s, device=device, dtype=dtype).reshape(a.size()) 257*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 258*da0073e9SAndroid Build Coastguard Worker c = torch.tensor(a_s.untyped(), device=device, dtype=dtype).reshape(a.size()) 259*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, c) 260*da0073e9SAndroid Build Coastguard Worker 261*da0073e9SAndroid Build Coastguard Worker for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16): 262*da0073e9SAndroid Build Coastguard Worker if error_dtype == dtype: 263*da0073e9SAndroid Build Coastguard Worker continue 264*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'Expected a Storage of type'): 265*da0073e9SAndroid Build Coastguard Worker error_storage = a.to(error_dtype).storage() 266*da0073e9SAndroid Build Coastguard Worker torch.tensor(error_storage, device=device, dtype=dtype) 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 269*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 270*da0073e9SAndroid Build Coastguard Worker def test_set_storage(self, device, dtype): 271*da0073e9SAndroid Build Coastguard Worker a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9) 
272*da0073e9SAndroid Build Coastguard Worker a_s = a.storage() 273*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([], device=device, dtype=dtype).set_(a_s).reshape(a.size()) 274*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 275*da0073e9SAndroid Build Coastguard Worker c = torch.tensor([], device=device, dtype=dtype).set_(a_s.untyped()).reshape(a.size()) 276*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, c) 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Worker for error_dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16): 279*da0073e9SAndroid Build Coastguard Worker if error_dtype == dtype: 280*da0073e9SAndroid Build Coastguard Worker continue 281*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'Expected a Storage of type'): 282*da0073e9SAndroid Build Coastguard Worker error_storage = a.to(error_dtype).storage() 283*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([], device=device, dtype=dtype).set_(error_storage) 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker def _check_storage_meta(self, s, s_check): 286*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 287*da0073e9SAndroid Build Coastguard Worker isinstance(s, (torch.UntypedStorage, torch.TypedStorage)) and 288*da0073e9SAndroid Build Coastguard Worker isinstance(s_check, type(s)), 289*da0073e9SAndroid Build Coastguard Worker ( 290*da0073e9SAndroid Build Coastguard Worker 's and s_check must both be one of UntypedStorage or ' 291*da0073e9SAndroid Build Coastguard Worker 'TypedStorage, but got' 292*da0073e9SAndroid Build Coastguard Worker f' {type(s).__name__} and {type(s_check).__name__}')) 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.device.type, 'meta') 295*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.nbytes(), s_check.nbytes()) 
296*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.size(), s_check.size()) 297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.data_ptr(), 0) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(NotImplementedError, r'Not available'): 300*da0073e9SAndroid Build Coastguard Worker s[0] 301*da0073e9SAndroid Build Coastguard Worker 302*da0073e9SAndroid Build Coastguard Worker if isinstance(s, torch.TypedStorage): 303*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.dtype, s_check.dtype) 304*da0073e9SAndroid Build Coastguard Worker self._check_storage_meta(s.untyped(), s_check.untyped()) 305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 307*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 308*da0073e9SAndroid Build Coastguard Worker def test_typed_storage_meta(self, device, dtype): 309*da0073e9SAndroid Build Coastguard Worker args_list = [ 310*da0073e9SAndroid Build Coastguard Worker [], 311*da0073e9SAndroid Build Coastguard Worker [0], 312*da0073e9SAndroid Build Coastguard Worker [100], 313*da0073e9SAndroid Build Coastguard Worker [[1, 2, 3, 4, 5, 6]], 314*da0073e9SAndroid Build Coastguard Worker ] 315*da0073e9SAndroid Build Coastguard Worker for args in args_list: 316*da0073e9SAndroid Build Coastguard Worker s_check = torch.TypedStorage(*args, dtype=dtype, device=device) 317*da0073e9SAndroid Build Coastguard Worker s = torch.TypedStorage(*args, dtype=dtype, device='meta') 318*da0073e9SAndroid Build Coastguard Worker self._check_storage_meta(s, s_check) 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 321*da0073e9SAndroid Build Coastguard Worker def test_untyped_storage_meta(self, device): 322*da0073e9SAndroid Build Coastguard Worker args_list = [ 323*da0073e9SAndroid Build 
Coastguard Worker [], 324*da0073e9SAndroid Build Coastguard Worker [0], 325*da0073e9SAndroid Build Coastguard Worker [100], 326*da0073e9SAndroid Build Coastguard Worker [[1, 2, 3, 4, 5, 6]], 327*da0073e9SAndroid Build Coastguard Worker ] 328*da0073e9SAndroid Build Coastguard Worker for args in args_list: 329*da0073e9SAndroid Build Coastguard Worker s_check = torch.UntypedStorage(*args, device=device) 330*da0073e9SAndroid Build Coastguard Worker s = torch.UntypedStorage(*args, device='meta') 331*da0073e9SAndroid Build Coastguard Worker self._check_storage_meta(s, s_check) 332*da0073e9SAndroid Build Coastguard Worker 333*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 334*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 335*da0073e9SAndroid Build Coastguard Worker def test_storage_meta_from_tensor(self, device, dtype): 336*da0073e9SAndroid Build Coastguard Worker t_check = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9) 337*da0073e9SAndroid Build Coastguard Worker t = t_check.to('meta') 338*da0073e9SAndroid Build Coastguard Worker 339*da0073e9SAndroid Build Coastguard Worker s_check = t_check.storage() 340*da0073e9SAndroid Build Coastguard Worker s = t.storage() 341*da0073e9SAndroid Build Coastguard Worker self._check_storage_meta(s, s_check) 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 344*da0073e9SAndroid Build Coastguard Worker def test_storage_meta_errors(self, device, dtype): 345*da0073e9SAndroid Build Coastguard Worker s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype) 346*da0073e9SAndroid Build Coastguard Worker 347*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'): 348*da0073e9SAndroid Build Coastguard Worker s0.cpu() 349*da0073e9SAndroid Build Coastguard Worker 
350*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'only available on CPU'): 351*da0073e9SAndroid Build Coastguard Worker s0._share_fd_cpu_() 352*da0073e9SAndroid Build Coastguard Worker 353*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'only available on CPU'): 354*da0073e9SAndroid Build Coastguard Worker s0._share_filename_cpu_() 355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 357*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'): 358*da0073e9SAndroid Build Coastguard Worker s0.cuda() 359*da0073e9SAndroid Build Coastguard Worker 360*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'only available on CUDA'): 361*da0073e9SAndroid Build Coastguard Worker s0._share_cuda_() 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"cannot pin 'torch.storage.UntypedStorage' only CPU memory can be pinned"): 364*da0073e9SAndroid Build Coastguard Worker s0.pin_memory() 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'only available on CPU'): 367*da0073e9SAndroid Build Coastguard Worker s0.share_memory_() 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(NotImplementedError, r'Not available'): 370*da0073e9SAndroid Build Coastguard Worker s0.tolist() 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker with tempfile.NamedTemporaryFile() as f: 373*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'): 374*da0073e9SAndroid Build Coastguard Worker s0._write_file(f, True, True, s0.element_size()) 375*da0073e9SAndroid Build 
Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']: 377*da0073e9SAndroid Build Coastguard Worker s1 = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype) 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(NotImplementedError, r'Cannot copy out'): 380*da0073e9SAndroid Build Coastguard Worker s1.copy_(s0) 381*da0073e9SAndroid Build Coastguard Worker 382*da0073e9SAndroid Build Coastguard Worker @onlyCPU 383*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 384*da0073e9SAndroid Build Coastguard Worker def test_storage_meta_ok(self, device, dtype): 385*da0073e9SAndroid Build Coastguard Worker s0 = torch.TypedStorage([1, 2, 3, 4], device='meta', dtype=dtype) 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker # This is OK, it changes the meta storage size without allocating 388*da0073e9SAndroid Build Coastguard Worker s0.resize_(10) 389*da0073e9SAndroid Build Coastguard Worker 390*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 391*da0073e9SAndroid Build Coastguard Worker def test_module_share_memory(self): 392*da0073e9SAndroid Build Coastguard Worker # Test fix for issue #80733 393*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/80733 394*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Linear(3, 1) 395*da0073e9SAndroid Build Coastguard Worker model_cuda = model.to('cuda') 396*da0073e9SAndroid Build Coastguard Worker model.share_memory() 397*da0073e9SAndroid Build Coastguard Worker 398*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.complex64) 399*da0073e9SAndroid Build Coastguard Worker def test_deepcopy(self, device, dtype): 400*da0073e9SAndroid Build Coastguard Worker from copy import deepcopy 401*da0073e9SAndroid Build Coastguard 
Worker a = torch.randn(5, 5, dtype=dtype, device=device) 402*da0073e9SAndroid Build Coastguard Worker b = torch.randn(5, 5, dtype=dtype, device=device) 403*da0073e9SAndroid Build Coastguard Worker c = a.view(25) 404*da0073e9SAndroid Build Coastguard Worker q = [a, [a.storage(), b.storage()], b, c] 405*da0073e9SAndroid Build Coastguard Worker w = deepcopy(q) 406*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w[0], q[0], atol=0, rtol=0) 407*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w[1][0], q[1][0], atol=0, rtol=0) 408*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w[1][1], q[1][1], atol=0, rtol=0) 409*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w[1], q[1], atol=0, rtol=0) 410*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w[2], q[2], atol=0, rtol=0) 411*da0073e9SAndroid Build Coastguard Worker 412*da0073e9SAndroid Build Coastguard Worker # Check that deepcopy preserves sharing 413*da0073e9SAndroid Build Coastguard Worker w[0].add_(1) 414*da0073e9SAndroid Build Coastguard Worker for i in range(a.numel()): 415*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w[1][0][i], q[1][0][i] + 1) 416*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w[3], c + 1) 417*da0073e9SAndroid Build Coastguard Worker w[2].sub_(1) 418*da0073e9SAndroid Build Coastguard Worker for i in range(a.numel()): 419*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w[1][1][i], q[1][1][i] - 1) 420*da0073e9SAndroid Build Coastguard Worker 421*da0073e9SAndroid Build Coastguard Worker # Check that deepcopy preserves attributes 422*da0073e9SAndroid Build Coastguard Worker a.foo = 3 423*da0073e9SAndroid Build Coastguard Worker self.assertEqual(deepcopy(a).foo, 3) 424*da0073e9SAndroid Build Coastguard Worker 425*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float32, torch.complex64) 426*da0073e9SAndroid Build Coastguard Worker def test_deepcopy_scalar(self, device, dtype): 427*da0073e9SAndroid Build Coastguard 
Worker from copy import deepcopy 428*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(5, dtype=dtype, device=device) 429*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.size(), deepcopy(a).size()) 430*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, deepcopy(a)) 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker def check_internal_mem_overlap(self, inplace_op, num_inputs, 433*da0073e9SAndroid Build Coastguard Worker dtype, device, 434*da0073e9SAndroid Build Coastguard Worker expected_failure=False): 435*da0073e9SAndroid Build Coastguard Worker if isinstance(inplace_op, str): 436*da0073e9SAndroid Build Coastguard Worker inplace_op = getattr(torch.Tensor, inplace_op) 437*da0073e9SAndroid Build Coastguard Worker input = torch.randn(1, dtype=dtype, device=device).expand(3, 3) 438*da0073e9SAndroid Build Coastguard Worker inputs = [input] + [torch.randn_like(input) 439*da0073e9SAndroid Build Coastguard Worker for i in range(num_inputs - 1)] 440*da0073e9SAndroid Build Coastguard Worker if not expected_failure: 441*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'single memory location'): 442*da0073e9SAndroid Build Coastguard Worker inplace_op(*inputs) 443*da0073e9SAndroid Build Coastguard Worker else: 444*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 445*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'single memory location'): 446*da0073e9SAndroid Build Coastguard Worker inplace_op(*inputs) 447*da0073e9SAndroid Build Coastguard Worker 448*da0073e9SAndroid Build Coastguard Worker def unary_check_input_output_mem_overlap(self, data, sz, op, 449*da0073e9SAndroid Build Coastguard Worker expected_failure=False): 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Worker def _test(op, output, input): 452*da0073e9SAndroid Build Coastguard Worker output_exp = 
torch.empty_like(output) 453*da0073e9SAndroid Build Coastguard Worker op(input, out=output_exp) 454*da0073e9SAndroid Build Coastguard Worker self.assertEqual(op(input, out=output), output_exp, msg=op.__name__) 455*da0073e9SAndroid Build Coastguard Worker 456*da0073e9SAndroid Build Coastguard Worker # output is identical to input: 457*da0073e9SAndroid Build Coastguard Worker _test(op, output=data[0:sz], input=data[0:sz]) 458*da0073e9SAndroid Build Coastguard Worker # output and input are independent: 459*da0073e9SAndroid Build Coastguard Worker _test(op, output=data[0:sz], input=data[sz:2 * sz]) 460*da0073e9SAndroid Build Coastguard Worker # output partially overlaps with input: 461*da0073e9SAndroid Build Coastguard Worker if not expected_failure: 462*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): 463*da0073e9SAndroid Build Coastguard Worker _test(op, data[0:sz], data[1:sz + 1]) 464*da0073e9SAndroid Build Coastguard Worker else: 465*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 466*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): 467*da0073e9SAndroid Build Coastguard Worker _test(op, data[0:sz], data[1:sz + 1]) 468*da0073e9SAndroid Build Coastguard Worker # output is transpose of input: 469*da0073e9SAndroid Build Coastguard Worker length = int(math.sqrt(sz)) 470*da0073e9SAndroid Build Coastguard Worker input = data[:length**2].view([length, length]) 471*da0073e9SAndroid Build Coastguard Worker out = input.t() 472*da0073e9SAndroid Build Coastguard Worker if not expected_failure: 473*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): 474*da0073e9SAndroid Build Coastguard Worker _test(op, out, input) 475*da0073e9SAndroid Build Coastguard Worker else: 476*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 477*da0073e9SAndroid 
Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): 478*da0073e9SAndroid Build Coastguard Worker _test(op, out, input) 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker def ternary_check_input_output_mem_overlap(self, op, device, 481*da0073e9SAndroid Build Coastguard Worker expected_failure=False): 482*da0073e9SAndroid Build Coastguard Worker sz = 9 483*da0073e9SAndroid Build Coastguard Worker data = torch.randn(2 * sz, device=device) 484*da0073e9SAndroid Build Coastguard Worker other1 = torch.randn(sz, device=device) 485*da0073e9SAndroid Build Coastguard Worker other2 = torch.randn(sz, device=device) 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker self.unary_check_input_output_mem_overlap( 488*da0073e9SAndroid Build Coastguard Worker data, sz, lambda input, out: 489*da0073e9SAndroid Build Coastguard Worker op(input, other1.view(input.shape), other2.view(input.shape), out=out), 490*da0073e9SAndroid Build Coastguard Worker expected_failure=expected_failure) 491*da0073e9SAndroid Build Coastguard Worker 492*da0073e9SAndroid Build Coastguard Worker self.unary_check_input_output_mem_overlap( 493*da0073e9SAndroid Build Coastguard Worker data, sz, lambda input, out: 494*da0073e9SAndroid Build Coastguard Worker op(other1.view(input.shape), input, other2.view(input.shape), out=out), 495*da0073e9SAndroid Build Coastguard Worker expected_failure=expected_failure) 496*da0073e9SAndroid Build Coastguard Worker 497*da0073e9SAndroid Build Coastguard Worker self.unary_check_input_output_mem_overlap( 498*da0073e9SAndroid Build Coastguard Worker data, sz, lambda input, out: 499*da0073e9SAndroid Build Coastguard Worker op(other1.view(input.shape), other2.view(input.shape), input, out=out), 500*da0073e9SAndroid Build Coastguard Worker expected_failure=expected_failure) 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid Build Coastguard Worker def 
_select_broadcastable_dims(self, dims_full=None): 503*da0073e9SAndroid Build Coastguard Worker # select full dimensionality 504*da0073e9SAndroid Build Coastguard Worker if dims_full is None: 505*da0073e9SAndroid Build Coastguard Worker dims_full = [] 506*da0073e9SAndroid Build Coastguard Worker ndims = random.randint(1, 4) 507*da0073e9SAndroid Build Coastguard Worker dims_full = [random.randint(1, 8) for _ in range(ndims)] 508*da0073e9SAndroid Build Coastguard Worker else: 509*da0073e9SAndroid Build Coastguard Worker ndims = len(dims_full) 510*da0073e9SAndroid Build Coastguard Worker 511*da0073e9SAndroid Build Coastguard Worker # select actual dimensions for ops: 512*da0073e9SAndroid Build Coastguard Worker # larger: full ndims, individual sizes may be reduced 513*da0073e9SAndroid Build Coastguard Worker # smaller: possibly reduced ndims, sizes may be reduced 514*da0073e9SAndroid Build Coastguard Worker smaller_ndims = random.randint(1, ndims) 515*da0073e9SAndroid Build Coastguard Worker dims_small = [] 516*da0073e9SAndroid Build Coastguard Worker dims_large = [] 517*da0073e9SAndroid Build Coastguard Worker for i in range(ndims - 1, -1, -1): 518*da0073e9SAndroid Build Coastguard Worker j = random.randint(1, 3) 519*da0073e9SAndroid Build Coastguard Worker if j == 1: # no reduced singleton dimension 520*da0073e9SAndroid Build Coastguard Worker ds = dims_full[i] 521*da0073e9SAndroid Build Coastguard Worker dl = dims_full[i] 522*da0073e9SAndroid Build Coastguard Worker elif j == 2: # larger may have reduced singleton dimension 523*da0073e9SAndroid Build Coastguard Worker ds = dims_full[i] 524*da0073e9SAndroid Build Coastguard Worker dl = 1 if len(dims_small) < smaller_ndims else dims_full[i] 525*da0073e9SAndroid Build Coastguard Worker elif j == 3: # smaller may have reduced singleton dimension 526*da0073e9SAndroid Build Coastguard Worker ds = 1 527*da0073e9SAndroid Build Coastguard Worker dl = dims_full[i] 528*da0073e9SAndroid Build Coastguard Worker dims_large = [dl] 
+ dims_large 529*da0073e9SAndroid Build Coastguard Worker if len(dims_small) < smaller_ndims: 530*da0073e9SAndroid Build Coastguard Worker dims_small = [ds] + dims_small 531*da0073e9SAndroid Build Coastguard Worker return (dims_small, dims_large, dims_full) 532*da0073e9SAndroid Build Coastguard Worker 533*da0073e9SAndroid Build Coastguard Worker # collected tests of ops that used scalar_check in Declarations.cwrap for 534*da0073e9SAndroid Build Coastguard Worker # correctness 535*da0073e9SAndroid Build Coastguard Worker def test_scalar_check(self, device): 536*da0073e9SAndroid Build Coastguard Worker zero_d = torch.randn((), device=device) 537*da0073e9SAndroid Build Coastguard Worker one_d = torch.randn((1,), device=device) 538*da0073e9SAndroid Build Coastguard Worker 539*da0073e9SAndroid Build Coastguard Worker # remainder 540*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.remainder(zero_d, zero_d).shape) 541*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.remainder(zero_d, 2).shape) 542*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.remainder(zero_d, one_d).shape) 543*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.remainder(one_d, zero_d).shape) 544*da0073e9SAndroid Build Coastguard Worker 545*da0073e9SAndroid Build Coastguard Worker # fmod 546*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.fmod(zero_d, zero_d).shape) 547*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.fmod(zero_d, 2).shape) 548*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.fmod(zero_d, one_d).shape) 549*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.fmod(one_d, zero_d).shape) 550*da0073e9SAndroid Build Coastguard Worker 551*da0073e9SAndroid Build Coastguard Worker # exp, cos, cosh, tan, atan, tanh, erf, erfc, reciprocal 552*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.exp(zero_d).shape) 
553*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.cos(zero_d).shape) 554*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.cosh(zero_d).shape) 555*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.tan(zero_d).shape) 556*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.atan(zero_d).shape) 557*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.acosh(zero_d).shape) 558*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.asinh(zero_d).shape) 559*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.atanh(zero_d).shape) 560*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.tanh(zero_d).shape) 561*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.erf(zero_d).shape) 562*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.erfc(zero_d).shape) 563*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.reciprocal(zero_d).shape) 564*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.exp(one_d).shape) 565*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.cos(one_d).shape) 566*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.cosh(one_d).shape) 567*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.tan(one_d).shape) 568*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.atan(one_d).shape) 569*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.acosh(one_d).shape) 570*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.asinh(one_d).shape) 571*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.atanh(one_d).shape) 572*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.tanh(one_d).shape) 573*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.erf(one_d).shape) 574*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual((1,), torch.erfc(one_d).shape) 575*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.reciprocal(one_d).shape) 576*da0073e9SAndroid Build Coastguard Worker 577*da0073e9SAndroid Build Coastguard Worker # clamp 578*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.clamp(zero_d, min=0, max=1).shape) 579*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.clamp(zero_d, min=0).shape) 580*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.clamp(zero_d, max=1).shape) 581*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.clamp(one_d, min=0, max=1).shape) 582*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.clamp(one_d, min=0).shape) 583*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.clamp(one_d, max=1).shape) 584*da0073e9SAndroid Build Coastguard Worker 585*da0073e9SAndroid Build Coastguard Worker # cumsum, cumprod, cummax, cummin 586*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.logcumsumexp(zero_d, 0).shape) 587*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.cumsum(zero_d, 0).shape) 588*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.cumprod(zero_d, 0).shape) 589*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.cummax(zero_d, 0)[0].shape) 590*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.cummin(zero_d, 0)[0].shape) 591*da0073e9SAndroid Build Coastguard Worker 592*da0073e9SAndroid Build Coastguard Worker # sort, topk 593*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(), ()], [x.shape for x in torch.sort(zero_d, 0, False)]) 594*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(), ()], [x.shape for x in torch.sort(zero_d, 0, True)]) 595*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(), ()], [x.shape for x in torch.topk(zero_d, 1, 0, False)]) 596*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual([(), ()], [x.shape for x in torch.topk(zero_d, 1, 0, True)]) 597*da0073e9SAndroid Build Coastguard Worker 598*da0073e9SAndroid Build Coastguard Worker # max, min 599*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.max(zero_d, zero_d).shape) 600*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.max(one_d, zero_d).shape) 601*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.max(zero_d, one_d).shape) 602*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.min(zero_d, zero_d).shape) 603*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.min(one_d, zero_d).shape) 604*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.min(zero_d, one_d).shape) 605*da0073e9SAndroid Build Coastguard Worker 606*da0073e9SAndroid Build Coastguard Worker zero_d_int = torch.tensor(1, device=device) 607*da0073e9SAndroid Build Coastguard Worker one_d_int = torch.tensor([1], device=device) 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker # lshift, rshift 610*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), (zero_d_int >> zero_d_int).shape) 611*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), (zero_d_int >> 1).shape) 612*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), (one_d_int >> zero_d_int).shape) 613*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), (zero_d_int >> one_d_int).shape) 614*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), (one_d_int >> 1).shape) 615*da0073e9SAndroid Build Coastguard Worker 616*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), (zero_d_int << zero_d_int).shape) 617*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), (zero_d_int << 1).shape) 618*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), (one_d_int << zero_d_int).shape) 619*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), 
(zero_d_int << one_d_int).shape) 620*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), (one_d_int << 1).shape) 621*da0073e9SAndroid Build Coastguard Worker 622*da0073e9SAndroid Build Coastguard Worker # or 623*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), (zero_d_int | zero_d_int).shape) 624*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), (zero_d_int | 1).shape) 625*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), (one_d_int | zero_d_int).shape) 626*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), (zero_d_int | one_d_int).shape) 627*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), (one_d_int | 1).shape) 628*da0073e9SAndroid Build Coastguard Worker 629*da0073e9SAndroid Build Coastguard Worker # and 630*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), (zero_d_int & zero_d_int).shape) 631*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), (zero_d_int & 1).shape) 632*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), (one_d_int & zero_d_int).shape) 633*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), (zero_d_int & one_d_int).shape) 634*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), (one_d_int & 1).shape) 635*da0073e9SAndroid Build Coastguard Worker 636*da0073e9SAndroid Build Coastguard Worker # clone 637*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), zero_d.clone().shape) 638*da0073e9SAndroid Build Coastguard Worker 639*da0073e9SAndroid Build Coastguard Worker zero_d_bool = torch.tensor(True, device=device) 640*da0073e9SAndroid Build Coastguard Worker one_d_bool = torch.tensor([True], device=device) 641*da0073e9SAndroid Build Coastguard Worker 642*da0073e9SAndroid Build Coastguard Worker # masked_select 643*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.masked_select(zero_d_bool, zero_d_bool).shape) 644*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), 
torch.masked_select(zero_d_bool, one_d_bool).shape) 645*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.masked_select(one_d_bool, zero_d_bool).shape) 646*da0073e9SAndroid Build Coastguard Worker 647*da0073e9SAndroid Build Coastguard Worker zero_d_uint8 = torch.tensor(1, dtype=torch.uint8, device=device) 648*da0073e9SAndroid Build Coastguard Worker one_d_uint8 = torch.tensor([1], dtype=torch.uint8, device=device) 649*da0073e9SAndroid Build Coastguard Worker 650*da0073e9SAndroid Build Coastguard Worker # mode 651*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=True)]) 652*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(), ()], [x.shape for x in torch.mode(zero_d, dim=0, keepdim=False)]) 653*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(1,), (1,)], [x.shape for x in torch.mode(one_d, dim=0, keepdim=True)]) 654*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(), ()], [x.shape for x in torch.mode(one_d, dim=0, keepdim=False)]) 655*da0073e9SAndroid Build Coastguard Worker 656*da0073e9SAndroid Build Coastguard Worker # max 657*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(), ()], [x.shape for x in torch.max(zero_d, dim=0, keepdim=True)]) 658*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(), ()], [x.shape for x in torch.max(zero_d, dim=0, keepdim=False)]) 659*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(1,), (1,)], [x.shape for x in torch.max(one_d, dim=0, keepdim=True)]) 660*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(), ()], [x.shape for x in torch.max(one_d, dim=0, keepdim=False)]) 661*da0073e9SAndroid Build Coastguard Worker 662*da0073e9SAndroid Build Coastguard Worker # amax 663*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.amax(zero_d, dim=0, keepdim=True).shape) 664*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.amax(zero_d, dim=0, 
keepdim=False).shape) 665*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.amax(one_d, dim=0, keepdim=True).shape) 666*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.amax(one_d, dim=0, keepdim=False).shape) 667*da0073e9SAndroid Build Coastguard Worker 668*da0073e9SAndroid Build Coastguard Worker # min 669*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(), ()], [x.shape for x in torch.min(zero_d, dim=0, keepdim=True)]) 670*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(), ()], [x.shape for x in torch.min(zero_d, dim=0, keepdim=False)]) 671*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(1,), (1,)], [x.shape for x in torch.min(one_d, dim=0, keepdim=True)]) 672*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(), ()], [x.shape for x in torch.min(one_d, dim=0, keepdim=False)]) 673*da0073e9SAndroid Build Coastguard Worker 674*da0073e9SAndroid Build Coastguard Worker # amin 675*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.amin(zero_d, dim=0, keepdim=True).shape) 676*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.amin(zero_d, dim=0, keepdim=False).shape) 677*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.amin(one_d, dim=0, keepdim=True).shape) 678*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.amin(one_d, dim=0, keepdim=False).shape) 679*da0073e9SAndroid Build Coastguard Worker 680*da0073e9SAndroid Build Coastguard Worker # set_ 681*da0073e9SAndroid Build Coastguard Worker zero_d_clone = zero_d.clone() 682*da0073e9SAndroid Build Coastguard Worker one_d_clone = one_d.clone() 683*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), zero_d_clone.set_(one_d.storage(), 0, (), ()).shape) 684*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), zero_d_clone.set_(one_d.storage(), 0, (1,), (1,)).shape) 685*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), 
one_d_clone.set_(one_d.storage(), 0, (), ()).shape) 686*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), one_d_clone.set_(one_d.storage(), 0, (1,), (1,)).shape) 687*da0073e9SAndroid Build Coastguard Worker 688*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), zero_d.clone().set_(zero_d).shape) 689*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), one_d.clone().set_(zero_d).shape) 690*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), zero_d.clone().set_(one_d).shape) 691*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), one_d.clone().set_(one_d).shape) 692*da0073e9SAndroid Build Coastguard Worker 693*da0073e9SAndroid Build Coastguard Worker # take 694*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.randn((2, 3), device=device).take(zero_d_int).shape) 695*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.randn((2, 3), device=device).take(one_d_int).shape) 696*da0073e9SAndroid Build Coastguard Worker 697*da0073e9SAndroid Build Coastguard Worker # gather 698*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.gather(zero_d, 0, torch.zeros((), dtype=torch.int64, device=device)).shape) 699*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.gather(zero_d, 0, torch.zeros((1,), dtype=torch.int64, device=device)).shape) 700*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.gather(one_d, 0, torch.zeros((), dtype=torch.int64, device=device)).shape) 701*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.gather(one_d, 0, torch.zeros((1,), dtype=torch.int64, device=device)).shape) 702*da0073e9SAndroid Build Coastguard Worker 703*da0073e9SAndroid Build Coastguard Worker # normal 704*da0073e9SAndroid Build Coastguard Worker # std must be >= 0 705*da0073e9SAndroid Build Coastguard Worker zero_d_ge_0 = torch.rand((), device=device) 706*da0073e9SAndroid Build Coastguard Worker # documentation says out shape 
matches shape of mean 707*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.normal(zero_d, zero_d_ge_0).shape) 708*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.normal(one_d, zero_d_ge_0).shape) 709*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.normal(1, zero_d_ge_0).shape) 710*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.normal(zero_d, 1).shape) 711*da0073e9SAndroid Build Coastguard Worker self.assertEqual((1,), torch.normal(one_d, 1).shape) 712*da0073e9SAndroid Build Coastguard Worker # TODO: this behavior differs on CPU and GPU, see https://github.com/pytorch/pytorch/issues/30480. 713*da0073e9SAndroid Build Coastguard Worker # self.assertEqual((), torch.normal(zero_d, one_d).shape) 714*da0073e9SAndroid Build Coastguard Worker # self.assertEqual((), torch.normal(1, one_d).shape) 715*da0073e9SAndroid Build Coastguard Worker 716*da0073e9SAndroid Build Coastguard Worker # convolutions. Yes, we are testing nn.functional here; seems justified 717*da0073e9SAndroid Build Coastguard Worker # given its similar to the other tests 718*da0073e9SAndroid Build Coastguard Worker w = torch.randn(2, 1, 3, 3, device=device).div_(2).requires_grad_() 719*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.nn.functional.conv2d(zero_d, w, groups=1)) 720*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.nn.functional.conv2d(zero_d, w, groups=2)) 721*da0073e9SAndroid Build Coastguard Worker 722*da0073e9SAndroid Build Coastguard Worker # nll_loss -- verify input can't be 0-dimensional. 
723*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: torch.nn.functional.nll_loss(zero_d, zero_d, reduction='none')) 724*da0073e9SAndroid Build Coastguard Worker self.assertRaises(ValueError, lambda: torch.nn.functional.nll_loss(zero_d, one_d, reduction='none')) 725*da0073e9SAndroid Build Coastguard Worker # verify output is 0-dimensional when reduction != 'none' 726*da0073e9SAndroid Build Coastguard Worker for (input, target) in ((torch.randn(1, 1, device=device), torch.tensor([0], device=device)), 727*da0073e9SAndroid Build Coastguard Worker (torch.randn(1, 1, 1, 1, device=device), torch.tensor([[[0]]], device=device))): 728*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.nn.functional.nll_loss(input, target, reduction='mean').shape) 729*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), torch.nn.functional.nll_loss(input, target, reduction='sum').shape) 730*da0073e9SAndroid Build Coastguard Worker 731*da0073e9SAndroid Build Coastguard Worker # Test that `torch._check_tensor_all` raises errors in the correct cases 732*da0073e9SAndroid Build Coastguard Worker def test_check_tensor_all(self, device): 733*da0073e9SAndroid Build Coastguard Worker default_message = 'Expected cond to be True' 734*da0073e9SAndroid Build Coastguard Worker check_fn = torch._check_tensor_all 735*da0073e9SAndroid Build Coastguard Worker expected_error = RuntimeError 736*da0073e9SAndroid Build Coastguard Worker 737*da0073e9SAndroid Build Coastguard Worker # cond must be a tensor 738*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, 'cond must be a tensor'): 739*da0073e9SAndroid Build Coastguard Worker check_fn(True) 740*da0073e9SAndroid Build Coastguard Worker 741*da0073e9SAndroid Build Coastguard Worker # cond tensor must be boolean 742*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, 'cond tensor must have dtype torch.bool'): 743*da0073e9SAndroid Build Coastguard Worker 
check_fn(torch.ones(1, device=device)) 744*da0073e9SAndroid Build Coastguard Worker 745*da0073e9SAndroid Build Coastguard Worker test_sizes = [ 746*da0073e9SAndroid Build Coastguard Worker (), 747*da0073e9SAndroid Build Coastguard Worker (1,), 748*da0073e9SAndroid Build Coastguard Worker (10,), 749*da0073e9SAndroid Build Coastguard Worker (1, 1), 750*da0073e9SAndroid Build Coastguard Worker (1, 10), 751*da0073e9SAndroid Build Coastguard Worker (10, 1), 752*da0073e9SAndroid Build Coastguard Worker (10, 10), 753*da0073e9SAndroid Build Coastguard Worker (1, 1, 1), 754*da0073e9SAndroid Build Coastguard Worker (10, 1, 1), 755*da0073e9SAndroid Build Coastguard Worker (1, 10, 1), 756*da0073e9SAndroid Build Coastguard Worker (10, 10, 10), 757*da0073e9SAndroid Build Coastguard Worker ] 758*da0073e9SAndroid Build Coastguard Worker for size in test_sizes: 759*da0073e9SAndroid Build Coastguard Worker t_all_true = torch.ones(size, dtype=torch.bool, device=device) 760*da0073e9SAndroid Build Coastguard Worker t_all_false = torch.zeros(size, dtype=torch.bool, device=device) 761*da0073e9SAndroid Build Coastguard Worker 762*da0073e9SAndroid Build Coastguard Worker # Should not raise error 763*da0073e9SAndroid Build Coastguard Worker check_fn(t_all_true) 764*da0073e9SAndroid Build Coastguard Worker 765*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(expected_error, default_message): 766*da0073e9SAndroid Build Coastguard Worker check_fn(t_all_false) 767*da0073e9SAndroid Build Coastguard Worker 768*da0073e9SAndroid Build Coastguard Worker if t_all_true.numel() > 1: 769*da0073e9SAndroid Build Coastguard Worker t_all_true_but_one = t_all_true.clone() 770*da0073e9SAndroid Build Coastguard Worker # Choose a random element to set to false 771*da0073e9SAndroid Build Coastguard Worker idx = (random.choice(range(dim_size)) for dim_size in size) 772*da0073e9SAndroid Build Coastguard Worker t_all_true_but_one[(..., *idx)] = False 773*da0073e9SAndroid Build Coastguard Worker 
774*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(expected_error, default_message): 775*da0073e9SAndroid Build Coastguard Worker check_fn(t_all_true_but_one) 776*da0073e9SAndroid Build Coastguard Worker 777*da0073e9SAndroid Build Coastguard Worker # Test a simple failure message 778*da0073e9SAndroid Build Coastguard Worker message = 'message' 779*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(expected_error, message): 780*da0073e9SAndroid Build Coastguard Worker check_fn(t_all_false, lambda: message) 781*da0073e9SAndroid Build Coastguard Worker 782*da0073e9SAndroid Build Coastguard Worker # Test message with tensor 783*da0073e9SAndroid Build Coastguard Worker def message(): 784*da0073e9SAndroid Build Coastguard Worker return torch.arange(4) 785*da0073e9SAndroid Build Coastguard Worker 786*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(expected_error, re.escape(str(message()))): 787*da0073e9SAndroid Build Coastguard Worker check_fn(t_all_false, message) 788*da0073e9SAndroid Build Coastguard Worker 789*da0073e9SAndroid Build Coastguard Worker # Test format string message 790*da0073e9SAndroid Build Coastguard Worker def message(): 791*da0073e9SAndroid Build Coastguard Worker return f"{'test'} {[1, 2, 'a', True]} {True} {100} {torch.arange(4)}" 792*da0073e9SAndroid Build Coastguard Worker 793*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(expected_error, re.escape(str(message()))): 794*da0073e9SAndroid Build Coastguard Worker check_fn(t_all_false, message) 795*da0073e9SAndroid Build Coastguard Worker 796*da0073e9SAndroid Build Coastguard Worker # Test that `TORCH_CHECK_TENSOR_ALL` raises errors that propagate from C++ to Python 797*da0073e9SAndroid Build Coastguard Worker def test_check_tensor_internal(self, device): 798*da0073e9SAndroid Build Coastguard Worker test_sizes = [ 799*da0073e9SAndroid Build Coastguard Worker (), 800*da0073e9SAndroid Build Coastguard Worker (1,), 
801*da0073e9SAndroid Build Coastguard Worker (10,), 802*da0073e9SAndroid Build Coastguard Worker (1, 1), 803*da0073e9SAndroid Build Coastguard Worker (1, 10), 804*da0073e9SAndroid Build Coastguard Worker (10, 1), 805*da0073e9SAndroid Build Coastguard Worker (10, 10), 806*da0073e9SAndroid Build Coastguard Worker (1, 1, 1), 807*da0073e9SAndroid Build Coastguard Worker (10, 1, 1), 808*da0073e9SAndroid Build Coastguard Worker (1, 10, 1), 809*da0073e9SAndroid Build Coastguard Worker (10, 10, 10), 810*da0073e9SAndroid Build Coastguard Worker ] 811*da0073e9SAndroid Build Coastguard Worker for size in test_sizes: 812*da0073e9SAndroid Build Coastguard Worker t_all_true = torch.ones(size, dtype=torch.bool, device=device) 813*da0073e9SAndroid Build Coastguard Worker t_all_false = torch.zeros(size, dtype=torch.bool, device=device) 814*da0073e9SAndroid Build Coastguard Worker 815*da0073e9SAndroid Build Coastguard Worker # Should not raise error 816*da0073e9SAndroid Build Coastguard Worker torch._test_check_tensor(t_all_true) 817*da0073e9SAndroid Build Coastguard Worker 818*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Test message for TORCH_CHECK_TENSOR_ALL"): 819*da0073e9SAndroid Build Coastguard Worker torch._test_check_tensor(t_all_false) 820*da0073e9SAndroid Build Coastguard Worker 821*da0073e9SAndroid Build Coastguard Worker if t_all_true.numel() > 1: 822*da0073e9SAndroid Build Coastguard Worker t_all_true_but_one = t_all_true.clone() 823*da0073e9SAndroid Build Coastguard Worker # Choose a random element to set to false 824*da0073e9SAndroid Build Coastguard Worker idx = (random.choice(range(dim_size)) for dim_size in size) 825*da0073e9SAndroid Build Coastguard Worker t_all_true_but_one[(..., *idx)] = False 826*da0073e9SAndroid Build Coastguard Worker 827*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Test message for TORCH_CHECK_TENSOR_ALL"): 828*da0073e9SAndroid Build Coastguard Worker 
torch._test_check_tensor(t_all_true_but_one) 829*da0073e9SAndroid Build Coastguard Worker 830*da0073e9SAndroid Build Coastguard Worker # Uses mismatched arange out size to trigger a warning 831*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 832*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_CROSSREF, "crossref perturbs line numbering") 833*da0073e9SAndroid Build Coastguard Worker def test_cpp_warnings_have_python_context(self, device): 834*da0073e9SAndroid Build Coastguard Worker # Creates long string in advance to avoid a too-long Python line 835*da0073e9SAndroid Build Coastguard Worker s = ".+Triggered internally at.+RangeFactories.+" 836*da0073e9SAndroid Build Coastguard Worker # nvfuser deprecation warning filter 837*da0073e9SAndroid Build Coastguard Worker warnings.filterwarnings("ignore", "torch::jit::fuser::cuda", UserWarning) 838*da0073e9SAndroid Build Coastguard Worker 839*da0073e9SAndroid Build Coastguard Worker def cpp_warn_fn(): 840*da0073e9SAndroid Build Coastguard Worker out = torch.empty((5,)) 841*da0073e9SAndroid Build Coastguard Worker torch.arange(0, 3, out=out) 842*da0073e9SAndroid Build Coastguard Worker return out 843*da0073e9SAndroid Build Coastguard Worker 844*da0073e9SAndroid Build Coastguard Worker # Checks eager-mode cpp warning 845*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 846*da0073e9SAndroid Build Coastguard Worker cpp_warn_fn() 847*da0073e9SAndroid Build Coastguard Worker frameinfo = inspect.getframeinfo(inspect.currentframe()) 848*da0073e9SAndroid Build Coastguard Worker warning = w[0] 849*da0073e9SAndroid Build Coastguard Worker 850*da0073e9SAndroid Build Coastguard Worker # Checks for cpp context in the warning message 851*da0073e9SAndroid Build Coastguard Worker escaped_warning_message = str(warning.message).encode('unicode_escape') 852*da0073e9SAndroid Build Coastguard Worker self.assertTrue(re.search(s, 
repr(escaped_warning_message), re.IGNORECASE) is not None) 853*da0073e9SAndroid Build Coastguard Worker 854*da0073e9SAndroid Build Coastguard Worker # Checks the Python features of the warning 855*da0073e9SAndroid Build Coastguard Worker # Note: the eager mode warning refers to the line in the function 856*da0073e9SAndroid Build Coastguard Worker # that throws the warning. 857*da0073e9SAndroid Build Coastguard Worker self.assertEqual(frameinfo.lineno - 6, warning.lineno) 858*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1) 859*da0073e9SAndroid Build Coastguard Worker 860*da0073e9SAndroid Build Coastguard Worker # Checks jitted cpp warning 861*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 862*da0073e9SAndroid Build Coastguard Worker scripted_cpp_warn_fn = torch.jit.script(cpp_warn_fn) 863*da0073e9SAndroid Build Coastguard Worker scripted_cpp_warn_fn() 864*da0073e9SAndroid Build Coastguard Worker warning = w[0] 865*da0073e9SAndroid Build Coastguard Worker 866*da0073e9SAndroid Build Coastguard Worker # Checks for cpp context in the warning message 867*da0073e9SAndroid Build Coastguard Worker escaped_warning_message = str(warning.message).encode('unicode_escape') 868*da0073e9SAndroid Build Coastguard Worker self.assertTrue(re.search(s, repr(escaped_warning_message), re.IGNORECASE) is not None) 869*da0073e9SAndroid Build Coastguard Worker 870*da0073e9SAndroid Build Coastguard Worker # Checks the Python features of the warning 871*da0073e9SAndroid Build Coastguard Worker # Note: the jitted warning's lineno refers to the call to the jitted 872*da0073e9SAndroid Build Coastguard Worker # function, which in our test suite has a layer of indirection 873*da0073e9SAndroid Build Coastguard Worker # that makes checking the Python lineno fragile 874*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1) 875*da0073e9SAndroid Build Coastguard Worker 876*da0073e9SAndroid Build Coastguard Worker # Checks 
jitted Python warning 877*da0073e9SAndroid Build Coastguard Worker def warn_fn(): 878*da0073e9SAndroid Build Coastguard Worker warnings.warn("Warning!") 879*da0073e9SAndroid Build Coastguard Worker 880*da0073e9SAndroid Build Coastguard Worker # The jit mimics an eager-mode Python warning in this case 881*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 882*da0073e9SAndroid Build Coastguard Worker scripted_warn_fn = torch.jit.script(warn_fn) 883*da0073e9SAndroid Build Coastguard Worker scripted_warn_fn() 884*da0073e9SAndroid Build Coastguard Worker frameinfo = inspect.getframeinfo(inspect.currentframe()) 885*da0073e9SAndroid Build Coastguard Worker warning = w[0] 886*da0073e9SAndroid Build Coastguard Worker 887*da0073e9SAndroid Build Coastguard Worker self.assertTrue(re.search('Warning!', str(warning.message)) is not None) 888*da0073e9SAndroid Build Coastguard Worker 889*da0073e9SAndroid Build Coastguard Worker # Checks the Python features of the warning 890*da0073e9SAndroid Build Coastguard Worker self.assertEqual(frameinfo.lineno - 6, warning.lineno) 891*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1) 892*da0073e9SAndroid Build Coastguard Worker 893*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test_testing 894*da0073e9SAndroid Build Coastguard Worker @onlyCPU 895*da0073e9SAndroid Build Coastguard Worker def test_warn_always_caught(self, device): 896*da0073e9SAndroid Build Coastguard Worker # Check that we can catch a TORCH_WARN_ONCE warning twice 897*da0073e9SAndroid Build Coastguard Worker # since assertWarnsOnceRegex uses set_warn_always(True) which changes 898*da0073e9SAndroid Build Coastguard Worker # TORCH_WARN_ONCE to TORCH_WARN 899*da0073e9SAndroid Build Coastguard Worker a = np.arange(10) 900*da0073e9SAndroid Build Coastguard Worker a.flags.writeable = False 901*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, '.*non-writable.*'): 
902*da0073e9SAndroid Build Coastguard Worker torch.from_numpy(a) 903*da0073e9SAndroid Build Coastguard Worker 904*da0073e9SAndroid Build Coastguard Worker # OK, got it once, now try again 905*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, '.*non-writable.*'): 906*da0073e9SAndroid Build Coastguard Worker torch.from_numpy(a) 907*da0073e9SAndroid Build Coastguard Worker 908*da0073e9SAndroid Build Coastguard Worker # Make sure emitting two warnings will pass the assertWarnsOnceRegex 909*da0073e9SAndroid Build Coastguard Worker # context manager 910*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, '.*non-writable.*'): 911*da0073e9SAndroid Build Coastguard Worker torch.from_numpy(a) 912*da0073e9SAndroid Build Coastguard Worker torch.from_numpy(a) 913*da0073e9SAndroid Build Coastguard Worker 914*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 915*da0073e9SAndroid Build Coastguard Worker def test_complex_half_experimental_warning(self, device): 916*da0073e9SAndroid Build Coastguard Worker msg = 'ComplexHalf support is experimental' 917*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 918*da0073e9SAndroid Build Coastguard Worker t = torch.randn(3, dtype=torch.chalf, device=device) 919*da0073e9SAndroid Build Coastguard Worker 920*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 921*da0073e9SAndroid Build Coastguard Worker torch.rand(3, dtype=torch.chalf, device=device) 922*da0073e9SAndroid Build Coastguard Worker 923*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 924*da0073e9SAndroid Build Coastguard Worker torch.empty(3, dtype=torch.chalf, device=device) 925*da0073e9SAndroid Build Coastguard Worker 926*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 927*da0073e9SAndroid Build Coastguard Worker torch.ones(3, 
dtype=torch.chalf, device=device) 928*da0073e9SAndroid Build Coastguard Worker 929*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 930*da0073e9SAndroid Build Coastguard Worker torch.zeros(3, dtype=torch.chalf, device=device) 931*da0073e9SAndroid Build Coastguard Worker 932*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 933*da0073e9SAndroid Build Coastguard Worker torch.randn_like(t) 934*da0073e9SAndroid Build Coastguard Worker 935*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 936*da0073e9SAndroid Build Coastguard Worker torch.rand_like(t) 937*da0073e9SAndroid Build Coastguard Worker 938*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 939*da0073e9SAndroid Build Coastguard Worker torch.empty_like(t) 940*da0073e9SAndroid Build Coastguard Worker 941*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 942*da0073e9SAndroid Build Coastguard Worker torch.ones_like(t) 943*da0073e9SAndroid Build Coastguard Worker 944*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 945*da0073e9SAndroid Build Coastguard Worker torch.zeros_like(t) 946*da0073e9SAndroid Build Coastguard Worker 947*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 948*da0073e9SAndroid Build Coastguard Worker # t + 1 allocates a new tensor for result using empty 949*da0073e9SAndroid Build Coastguard Worker t + 1 950*da0073e9SAndroid Build Coastguard Worker 951*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 952*da0073e9SAndroid Build Coastguard Worker def test_dtypetensor_warnings(self, device): 953*da0073e9SAndroid Build Coastguard Worker msg = 'The torch.cuda.*DtypeTensor constructors are no longer recommended' 954*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 
955*da0073e9SAndroid Build Coastguard Worker t = torch.cuda.FloatTensor([0]) 956*da0073e9SAndroid Build Coastguard Worker 957*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 958*da0073e9SAndroid Build Coastguard Worker t = torch.cuda.DoubleTensor([0]) 959*da0073e9SAndroid Build Coastguard Worker 960*da0073e9SAndroid Build Coastguard Worker def test_set_default_tensor_type_warnings(self, device): 961*da0073e9SAndroid Build Coastguard Worker msg = '.*is deprecated as of PyTorch 2.1, please use torch.set_default_dtype().*' 962*da0073e9SAndroid Build Coastguard Worker default_type = torch.tensor([]).type() 963*da0073e9SAndroid Build Coastguard Worker try: 964*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 965*da0073e9SAndroid Build Coastguard Worker torch.set_default_tensor_type(torch.FloatTensor) 966*da0073e9SAndroid Build Coastguard Worker 967*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 968*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex(UserWarning, msg): 969*da0073e9SAndroid Build Coastguard Worker torch.set_default_tensor_type(torch.cuda.FloatTensor) 970*da0073e9SAndroid Build Coastguard Worker finally: 971*da0073e9SAndroid Build Coastguard Worker torch.set_default_tensor_type(default_type) 972*da0073e9SAndroid Build Coastguard Worker 973*da0073e9SAndroid Build Coastguard Worker # TODO: this test should be in test_nn.py 974*da0073e9SAndroid Build Coastguard Worker def test_conv_transposed_backward_agnostic_to_memory_format(self, device): 975*da0073e9SAndroid Build Coastguard Worker in_channels = 64 976*da0073e9SAndroid Build Coastguard Worker out_channels = 128 977*da0073e9SAndroid Build Coastguard Worker scale_factor = 8 978*da0073e9SAndroid Build Coastguard Worker batch_size = 8 979*da0073e9SAndroid Build Coastguard Worker length = 16 980*da0073e9SAndroid Build Coastguard Worker 981*da0073e9SAndroid Build Coastguard Worker 
conv = torch.nn.ConvTranspose1d( 982*da0073e9SAndroid Build Coastguard Worker in_channels, out_channels, kernel_size=scale_factor * 2, stride=scale_factor).to(device) 983*da0073e9SAndroid Build Coastguard Worker layer_norm = torch.nn.LayerNorm(out_channels).to(device) 984*da0073e9SAndroid Build Coastguard Worker 985*da0073e9SAndroid Build Coastguard Worker input_ = torch.randn(batch_size, in_channels, length).to(device).contiguous() 986*da0073e9SAndroid Build Coastguard Worker input_ = conv(input_).contiguous() 987*da0073e9SAndroid Build Coastguard Worker input_ = layer_norm(input_.transpose(1, 2).contiguous()).contiguous() 988*da0073e9SAndroid Build Coastguard Worker input_.sum().backward() 989*da0073e9SAndroid Build Coastguard Worker 990*da0073e9SAndroid Build Coastguard Worker # 3d 991*da0073e9SAndroid Build Coastguard Worker conv = torch.nn.ConvTranspose3d(3, 3, kernel_size=3).to(device) 992*da0073e9SAndroid Build Coastguard Worker input = torch.randn(batch_size, 3, length, length, length, device=device) 993*da0073e9SAndroid Build Coastguard Worker out = conv(input) 994*da0073e9SAndroid Build Coastguard Worker out.backward(torch.ones_like(out).transpose(-2, -1)) 995*da0073e9SAndroid Build Coastguard Worker 996*da0073e9SAndroid Build Coastguard Worker # TODO: this test should be in test_nn.py 997*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 998*da0073e9SAndroid Build Coastguard Worker @largeTensorTest('12GB') 999*da0073e9SAndroid Build Coastguard Worker def test_conv_transposed_large(self, device): 1000*da0073e9SAndroid Build Coastguard Worker # ConvTranspose3d works for large input tensors (gh-32866) 1001*da0073e9SAndroid Build Coastguard Worker in_channels = 64 1002*da0073e9SAndroid Build Coastguard Worker out_channels = 128 1003*da0073e9SAndroid Build Coastguard Worker kernel_size = 5 1004*da0073e9SAndroid Build Coastguard Worker 1005*da0073e9SAndroid Build Coastguard Worker conv = torch.nn.ConvTranspose3d( 1006*da0073e9SAndroid Build Coastguard Worker 
in_channels, out_channels, kernel_size=kernel_size, 1007*da0073e9SAndroid Build Coastguard Worker stride=2, padding=2, output_padding=1).to(device) 1008*da0073e9SAndroid Build Coastguard Worker 1009*da0073e9SAndroid Build Coastguard Worker x = torch.rand([1, 64, 8, 128, 172]).to(device) 1010*da0073e9SAndroid Build Coastguard Worker y = conv(x) 1011*da0073e9SAndroid Build Coastguard Worker 1012*da0073e9SAndroid Build Coastguard Worker def test_is_set_to(self, device): 1013*da0073e9SAndroid Build Coastguard Worker t1 = torch.empty(3, 4, 9, 10, device=device) 1014*da0073e9SAndroid Build Coastguard Worker t2 = torch.empty(3, 4, 9, 10, device=device) 1015*da0073e9SAndroid Build Coastguard Worker t3 = torch.tensor([], device=device).set_(t1) 1016*da0073e9SAndroid Build Coastguard Worker t4 = t3.clone().resize_(12, 90) 1017*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t1.is_set_to(t2)) 1018*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t1.is_set_to(t3)) 1019*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t3.is_set_to(t1), "is_set_to should be symmetric") 1020*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t1.is_set_to(t4)) 1021*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.tensor([]).is_set_to(torch.tensor([])), 1022*da0073e9SAndroid Build Coastguard Worker "Tensors with no storages should not appear to be set " 1023*da0073e9SAndroid Build Coastguard Worker "to each other") 1024*da0073e9SAndroid Build Coastguard Worker 1025*da0073e9SAndroid Build Coastguard Worker t1 = torch.tensor([True, True], dtype=torch.bool, device=device) 1026*da0073e9SAndroid Build Coastguard Worker t2 = torch.tensor([0], dtype=torch.bool, device=device).set_(t1) 1027*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t1.is_set_to(t2)) 1028*da0073e9SAndroid Build Coastguard Worker 1029*da0073e9SAndroid Build Coastguard Worker # test that sizes must match 1030*da0073e9SAndroid Build Coastguard Worker t1 = torch.empty([2, 3, 4], 
device=device) 1031*da0073e9SAndroid Build Coastguard Worker t2 = t1.view(4, 3, 2) 1032*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t1.is_set_to(t2)) 1033*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t2.is_set_to(t1)) 1034*da0073e9SAndroid Build Coastguard Worker 1035*da0073e9SAndroid Build Coastguard Worker # test that legacy empty size behavior used to be respected (i.e. all 1036*da0073e9SAndroid Build Coastguard Worker # empty tensors were logically collapsed to size [0]). 1037*da0073e9SAndroid Build Coastguard Worker t1 = torch.empty([2, 5, 0], device=device) 1038*da0073e9SAndroid Build Coastguard Worker t2 = t1.view([0]) 1039*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t1.is_set_to(t2)) 1040*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t2.is_set_to(t1)) 1041*da0073e9SAndroid Build Coastguard Worker 1042*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/72650 1043*da0073e9SAndroid Build Coastguard Worker @skipIfMps 1044*da0073e9SAndroid Build Coastguard Worker @skipMeta 1045*da0073e9SAndroid Build Coastguard Worker @parametrize( 1046*da0073e9SAndroid Build Coastguard Worker "fn", 1047*da0073e9SAndroid Build Coastguard Worker [ 1048*da0073e9SAndroid Build Coastguard Worker "dist", "atan2", "pow", "lerp", "add", "sub", "mul", "div", "fmod", "remainder", "eq", "ge", "gt", "le", 1049*da0073e9SAndroid Build Coastguard Worker "lt", "max", "min", "ne", "addcdiv", "addcmul", "masked_scatter", "masked_select", "masked_fill", "map", 1050*da0073e9SAndroid Build Coastguard Worker "map2", "copy", 1051*da0073e9SAndroid Build Coastguard Worker ], 1052*da0073e9SAndroid Build Coastguard Worker ) 1053*da0073e9SAndroid Build Coastguard Worker def test_broadcast(self, fn, device): 1054*da0073e9SAndroid Build Coastguard Worker # functions with three tensor arguments 1055*da0073e9SAndroid Build Coastguard Worker fns_3_args = {"map2"} 1056*da0073e9SAndroid Build Coastguard Worker 
fns_value_kwarg = {"addcdiv", "addcmul"} 1057*da0073e9SAndroid Build Coastguard Worker 1058*da0073e9SAndroid Build Coastguard Worker (dims_small, dims_large, dims_full) = self._select_broadcastable_dims() 1059*da0073e9SAndroid Build Coastguard Worker full1d = torch.randn(*dims_full, device=device).flatten().float() 1060*da0073e9SAndroid Build Coastguard Worker small = torch.randn(*dims_small, device=device).float() 1061*da0073e9SAndroid Build Coastguard Worker large = torch.randn(*dims_large, device=device).float() 1062*da0073e9SAndroid Build Coastguard Worker small_expanded = small.expand(*dims_full) 1063*da0073e9SAndroid Build Coastguard Worker large_expanded = large.expand(*dims_full) 1064*da0073e9SAndroid Build Coastguard Worker small2 = None 1065*da0073e9SAndroid Build Coastguard Worker small2_expanded = None 1066*da0073e9SAndroid Build Coastguard Worker if fn in fns_3_args or fn in fns_value_kwarg: 1067*da0073e9SAndroid Build Coastguard Worker # create another smaller tensor 1068*da0073e9SAndroid Build Coastguard Worker (dims_small2, _, _) = self._select_broadcastable_dims(dims_full) 1069*da0073e9SAndroid Build Coastguard Worker small2 = torch.randn(*dims_small2, device=device).float() 1070*da0073e9SAndroid Build Coastguard Worker small2_expanded = small2.expand(*dims_full) 1071*da0073e9SAndroid Build Coastguard Worker 1072*da0073e9SAndroid Build Coastguard Worker if small.is_cuda and fn in ['map', 'map2']: 1073*da0073e9SAndroid Build Coastguard Worker # map and map2 are not implementd on CUDA tensors 1074*da0073e9SAndroid Build Coastguard Worker return 1075*da0073e9SAndroid Build Coastguard Worker 1076*da0073e9SAndroid Build Coastguard Worker if hasattr(large_expanded, fn): 1077*da0073e9SAndroid Build Coastguard Worker # run through tensor versions of functions 1078*da0073e9SAndroid Build Coastguard Worker # and verify fully expanded inputs give same results 1079*da0073e9SAndroid Build Coastguard Worker expanded = {large: large_expanded, small: 
small_expanded, small2: small2_expanded} 1080*da0073e9SAndroid Build Coastguard Worker 1081*da0073e9SAndroid Build Coastguard Worker def tensorfn(myfn, t1, t2): 1082*da0073e9SAndroid Build Coastguard Worker if fn == "lerp": 1083*da0073e9SAndroid Build Coastguard Worker return myfn(t1, 0.5) 1084*da0073e9SAndroid Build Coastguard Worker elif fn == "masked_select": 1085*da0073e9SAndroid Build Coastguard Worker return myfn(t1 < 0) 1086*da0073e9SAndroid Build Coastguard Worker elif fn == "masked_scatter": 1087*da0073e9SAndroid Build Coastguard Worker return myfn(t1 < 0.5, full1d) 1088*da0073e9SAndroid Build Coastguard Worker elif fn == "masked_fill": 1089*da0073e9SAndroid Build Coastguard Worker return myfn(t1 < 0.5, 1.0) 1090*da0073e9SAndroid Build Coastguard Worker elif fn in fns_3_args: 1091*da0073e9SAndroid Build Coastguard Worker return myfn(1, t1, t2) 1092*da0073e9SAndroid Build Coastguard Worker elif fn in fns_value_kwarg: 1093*da0073e9SAndroid Build Coastguard Worker return myfn(t1, t2, value=1) 1094*da0073e9SAndroid Build Coastguard Worker else: 1095*da0073e9SAndroid Build Coastguard Worker return myfn(t1) 1096*da0073e9SAndroid Build Coastguard Worker 1097*da0073e9SAndroid Build Coastguard Worker # test various orders 1098*da0073e9SAndroid Build Coastguard Worker for first, second, third in [(large, small, small2), (small, large, small2), 1099*da0073e9SAndroid Build Coastguard Worker (small2, small, large), (small2, large, small)]: 1100*da0073e9SAndroid Build Coastguard Worker if first is None: 1101*da0073e9SAndroid Build Coastguard Worker break # ignore last iter when small2 is None 1102*da0073e9SAndroid Build Coastguard Worker method_expanded = getattr(expanded[first], fn) 1103*da0073e9SAndroid Build Coastguard Worker method = getattr(first, fn) 1104*da0073e9SAndroid Build Coastguard Worker r1 = tensorfn(method_expanded, expanded[second], expanded[third]) 1105*da0073e9SAndroid Build Coastguard Worker r2 = tensorfn(method, second, third) 1106*da0073e9SAndroid 
Build Coastguard Worker self.assertEqual(r1, r2) 1107*da0073e9SAndroid Build Coastguard Worker 1108*da0073e9SAndroid Build Coastguard Worker # now for torch. versions of functions 1109*da0073e9SAndroid Build Coastguard Worker if hasattr(torch, fn): 1110*da0073e9SAndroid Build Coastguard Worker fntorch = getattr(torch, fn) 1111*da0073e9SAndroid Build Coastguard Worker expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded} 1112*da0073e9SAndroid Build Coastguard Worker 1113*da0073e9SAndroid Build Coastguard Worker def torchfn(t1, t2, t3): 1114*da0073e9SAndroid Build Coastguard Worker if fn == "lerp": 1115*da0073e9SAndroid Build Coastguard Worker return fntorch(t1, t2, 0.5) 1116*da0073e9SAndroid Build Coastguard Worker elif fn == "masked_select": 1117*da0073e9SAndroid Build Coastguard Worker return fntorch(t1, t2 < 0) 1118*da0073e9SAndroid Build Coastguard Worker elif fn == "masked_scatter": 1119*da0073e9SAndroid Build Coastguard Worker return fntorch(t1, t2 < 0.5, full1d) 1120*da0073e9SAndroid Build Coastguard Worker elif fn == "masked_fill": 1121*da0073e9SAndroid Build Coastguard Worker return fntorch(t1, t2 < 0.5, 1.0) 1122*da0073e9SAndroid Build Coastguard Worker elif fn in fns_3_args: 1123*da0073e9SAndroid Build Coastguard Worker return fntorch(t1, 1.0, t2, t3) 1124*da0073e9SAndroid Build Coastguard Worker elif fn in fns_value_kwarg: 1125*da0073e9SAndroid Build Coastguard Worker return fntorch(t1, t2, t3, value=1.0) 1126*da0073e9SAndroid Build Coastguard Worker else: 1127*da0073e9SAndroid Build Coastguard Worker return fntorch(t1, t2) 1128*da0073e9SAndroid Build Coastguard Worker 1129*da0073e9SAndroid Build Coastguard Worker # test various orders 1130*da0073e9SAndroid Build Coastguard Worker for first, second, third in [(large, small, small2), (small, large, small2), 1131*da0073e9SAndroid Build Coastguard Worker (small2, small, large), (small2, large, small)]: 1132*da0073e9SAndroid Build Coastguard Worker if first is None: 
1133*da0073e9SAndroid Build Coastguard Worker break # ignore last iter when small2 is None 1134*da0073e9SAndroid Build Coastguard Worker r1 = torchfn(expanded[first], expanded[second], expanded[third]) 1135*da0073e9SAndroid Build Coastguard Worker r2 = torchfn(first, second, third) 1136*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1, r2) 1137*da0073e9SAndroid Build Coastguard Worker 1138*da0073e9SAndroid Build Coastguard Worker # now for in place functions 1139*da0073e9SAndroid Build Coastguard Worker # in-place tensor is not broadcastable; test only guaranteed 1140*da0073e9SAndroid Build Coastguard Worker # to work by broadcasting other argument(s) 1141*da0073e9SAndroid Build Coastguard Worker if not hasattr(large_expanded, fn + "_"): 1142*da0073e9SAndroid Build Coastguard Worker return 1143*da0073e9SAndroid Build Coastguard Worker 1144*da0073e9SAndroid Build Coastguard Worker # need to clone largeExpanded so we can reuse, since functions are in-place 1145*da0073e9SAndroid Build Coastguard Worker large_expanded_clone = large_expanded.clone() 1146*da0073e9SAndroid Build Coastguard Worker 1147*da0073e9SAndroid Build Coastguard Worker def tensorfn_inplace(t0, t1, t2=None): 1148*da0073e9SAndroid Build Coastguard Worker t0_fn = getattr(t0, fn + "_") 1149*da0073e9SAndroid Build Coastguard Worker if fn == "lerp": 1150*da0073e9SAndroid Build Coastguard Worker return t0_fn(t1, 0.5) 1151*da0073e9SAndroid Build Coastguard Worker elif fn == "masked_scatter": 1152*da0073e9SAndroid Build Coastguard Worker return t0_fn(t1 < 0.5, full1d) 1153*da0073e9SAndroid Build Coastguard Worker elif fn == "masked_fill": 1154*da0073e9SAndroid Build Coastguard Worker return t0_fn(t1 < 0.5, 1.0) 1155*da0073e9SAndroid Build Coastguard Worker elif fn == "map": 1156*da0073e9SAndroid Build Coastguard Worker return t0_fn(t1, lambda x, y: x + y) 1157*da0073e9SAndroid Build Coastguard Worker elif fn == "map2": 1158*da0073e9SAndroid Build Coastguard Worker return t0_fn(t1, t2, lambda x, 
y, z: x + y + z) 1159*da0073e9SAndroid Build Coastguard Worker elif fn in fns_3_args: 1160*da0073e9SAndroid Build Coastguard Worker return t0_fn(1.0, t1, t2) 1161*da0073e9SAndroid Build Coastguard Worker elif fn in fns_value_kwarg: 1162*da0073e9SAndroid Build Coastguard Worker return t0_fn(t1, t2, value=1.0) 1163*da0073e9SAndroid Build Coastguard Worker else: 1164*da0073e9SAndroid Build Coastguard Worker return t0_fn(t1) 1165*da0073e9SAndroid Build Coastguard Worker # in-place pointwise operations don't actually work if the in-place 1166*da0073e9SAndroid Build Coastguard Worker # tensor is 0-strided (numpy has the same issue) 1167*da0073e9SAndroid Build Coastguard Worker if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()): 1168*da0073e9SAndroid Build Coastguard Worker r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded) 1169*da0073e9SAndroid Build Coastguard Worker r2 = tensorfn_inplace(large_expanded_clone, small, small2) 1170*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r1, r2) 1171*da0073e9SAndroid Build Coastguard Worker 1172*da0073e9SAndroid Build Coastguard Worker def broadcastable(t0, t1, t2=None): 1173*da0073e9SAndroid Build Coastguard Worker try: 1174*da0073e9SAndroid Build Coastguard Worker t1.expand_as(t0) 1175*da0073e9SAndroid Build Coastguard Worker if t2 is not None: 1176*da0073e9SAndroid Build Coastguard Worker t2.expand_as(t0) 1177*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 1178*da0073e9SAndroid Build Coastguard Worker return False 1179*da0073e9SAndroid Build Coastguard Worker return True 1180*da0073e9SAndroid Build Coastguard Worker 1181*da0073e9SAndroid Build Coastguard Worker def _test_in_place_broadcastable(t0, t1, t2=None): 1182*da0073e9SAndroid Build Coastguard Worker if not broadcastable(t0, t1, t2): 1183*da0073e9SAndroid Build Coastguard Worker same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True) 1184*da0073e9SAndroid 
Build Coastguard Worker if not same_size: 1185*da0073e9SAndroid Build Coastguard Worker # Functionalization converts the inplace to an out-of-place, which causes us to error. 1186*da0073e9SAndroid Build Coastguard Worker # We should fix this, but "error probably on bad inputs" isn't a hi-pri PT2 item. 1187*da0073e9SAndroid Build Coastguard Worker if not TEST_WITH_TORCHINDUCTOR: 1188*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2)) 1189*da0073e9SAndroid Build Coastguard Worker else: 1190*da0073e9SAndroid Build Coastguard Worker tensorfn_inplace(t0, t1, t2) 1191*da0073e9SAndroid Build Coastguard Worker 1192*da0073e9SAndroid Build Coastguard Worker if fn not in fns_3_args and fn not in fns_value_kwarg: 1193*da0073e9SAndroid Build Coastguard Worker _test_in_place_broadcastable(small, large_expanded) 1194*da0073e9SAndroid Build Coastguard Worker _test_in_place_broadcastable(small, large) 1195*da0073e9SAndroid Build Coastguard Worker else: 1196*da0073e9SAndroid Build Coastguard Worker _test_in_place_broadcastable(small2, small_expanded, large_expanded) 1197*da0073e9SAndroid Build Coastguard Worker _test_in_place_broadcastable(small2, small, large) 1198*da0073e9SAndroid Build Coastguard Worker 1199*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") 1200*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 1201*da0073e9SAndroid Build Coastguard Worker @wrapDeterministicFlagAPITest 1202*da0073e9SAndroid Build Coastguard Worker def test_cublas_config_nondeterministic_alert(self, device): 1203*da0073e9SAndroid Build Coastguard Worker test_cases = [ 1204*da0073e9SAndroid Build Coastguard Worker # (function, (tensor sizes)) 1205*da0073e9SAndroid Build Coastguard Worker ('mm', ((2, 2), (2, 2),)), 1206*da0073e9SAndroid Build Coastguard Worker ('mv', ((2, 2), (2,),)), 1207*da0073e9SAndroid Build Coastguard Worker ('bmm', ((1, 2, 2), (1, 2, 2),))] 
1208*da0073e9SAndroid Build Coastguard Worker 1209*da0073e9SAndroid Build Coastguard Worker test_configs = [ 1210*da0073e9SAndroid Build Coastguard Worker # (CuBLAS workspace config, is deterministic) 1211*da0073e9SAndroid Build Coastguard Worker ('garbage', False), 1212*da0073e9SAndroid Build Coastguard Worker (None, False), 1213*da0073e9SAndroid Build Coastguard Worker (':4096:8', True), 1214*da0073e9SAndroid Build Coastguard Worker (':16:8', True)] 1215*da0073e9SAndroid Build Coastguard Worker 1216*da0073e9SAndroid Build Coastguard Worker cublas_var_name = 'CUBLAS_WORKSPACE_CONFIG' 1217*da0073e9SAndroid Build Coastguard Worker is_cuda10_2_or_higher = ( 1218*da0073e9SAndroid Build Coastguard Worker (torch.version.cuda is not None) 1219*da0073e9SAndroid Build Coastguard Worker and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2])) 1220*da0073e9SAndroid Build Coastguard Worker 1221*da0073e9SAndroid Build Coastguard Worker def test_case_info(fn_name, config): 1222*da0073e9SAndroid Build Coastguard Worker return f'function "{fn_name}" with config "{"" if config is None else config}"' 1223*da0073e9SAndroid Build Coastguard Worker 1224*da0073e9SAndroid Build Coastguard Worker # Create processes to test each combination of test cases and config settings 1225*da0073e9SAndroid Build Coastguard Worker processes = [] 1226*da0073e9SAndroid Build Coastguard Worker for fn_name, arg_sizes in test_cases: 1227*da0073e9SAndroid Build Coastguard Worker for config, is_config_deterministic in test_configs: 1228*da0073e9SAndroid Build Coastguard Worker env = os.environ.copy() 1229*da0073e9SAndroid Build Coastguard Worker if config is None: 1230*da0073e9SAndroid Build Coastguard Worker if env.get(cublas_var_name) is not None: 1231*da0073e9SAndroid Build Coastguard Worker del env[cublas_var_name] 1232*da0073e9SAndroid Build Coastguard Worker else: 1233*da0073e9SAndroid Build Coastguard Worker env[cublas_var_name] = config 1234*da0073e9SAndroid Build Coastguard Worker 
should_throw_error = is_cuda10_2_or_higher and not is_config_deterministic 1235*da0073e9SAndroid Build Coastguard Worker script = f""" 1236*da0073e9SAndroid Build Coastguard Workerimport torch 1237*da0073e9SAndroid Build Coastguard Workertorch.use_deterministic_algorithms(True) 1238*da0073e9SAndroid Build Coastguard Workerfn = torch.{fn_name} 1239*da0073e9SAndroid Build Coastguard Workerarg_sizes = {arg_sizes} 1240*da0073e9SAndroid Build Coastguard Workerdevice = '{device}' 1241*da0073e9SAndroid Build Coastguard Workershould_throw_error = {should_throw_error} 1242*da0073e9SAndroid Build Coastguard Workerargs = [] 1243*da0073e9SAndroid Build Coastguard Workerfor arg_size in arg_sizes: 1244*da0073e9SAndroid Build Coastguard Worker args.append(torch.randn(*arg_size, device=device)) 1245*da0073e9SAndroid Build Coastguard Workertry: 1246*da0073e9SAndroid Build Coastguard Worker fn(*args) 1247*da0073e9SAndroid Build Coastguard Workerexcept RuntimeError as e: 1248*da0073e9SAndroid Build Coastguard Worker if not should_throw_error: 1249*da0073e9SAndroid Build Coastguard Worker raise RuntimeError('Did not expect any error to be raised') 1250*da0073e9SAndroid Build Coastguard Worker elif 'Deterministic behavior was enabled with either' not in str(e): 1251*da0073e9SAndroid Build Coastguard Worker raise RuntimeError('Expected a CuBLAS nondeterministic error, but got a different error') 1252*da0073e9SAndroid Build Coastguard Workerelse: 1253*da0073e9SAndroid Build Coastguard Worker if should_throw_error: 1254*da0073e9SAndroid Build Coastguard Worker raise RuntimeError('Expected a CuBLAS nondeterministic error, but it was not raised') 1255*da0073e9SAndroid Build Coastguard Worker 1256*da0073e9SAndroid Build Coastguard Worker""" 1257*da0073e9SAndroid Build Coastguard Worker try: 1258*da0073e9SAndroid Build Coastguard Worker subprocess.check_output( 1259*da0073e9SAndroid Build Coastguard Worker [sys.executable, '-c', script], 1260*da0073e9SAndroid Build Coastguard Worker 
stderr=subprocess.STDOUT, 1261*da0073e9SAndroid Build Coastguard Worker # On Windows, opening the subprocess with the default CWD makes `import torch` 1262*da0073e9SAndroid Build Coastguard Worker # fail, so just set CWD to this script's directory 1263*da0073e9SAndroid Build Coastguard Worker cwd=os.path.dirname(os.path.realpath(__file__)), 1264*da0073e9SAndroid Build Coastguard Worker env=env) 1265*da0073e9SAndroid Build Coastguard Worker except subprocess.CalledProcessError as e: 1266*da0073e9SAndroid Build Coastguard Worker self.fail(msg=( 1267*da0073e9SAndroid Build Coastguard Worker f'Subprocess exception while attempting to run {test_case_info(fn_name, config)}:\n' 1268*da0073e9SAndroid Build Coastguard Worker + e.output.decode("utf-8"))) 1269*da0073e9SAndroid Build Coastguard Worker 1270*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1271*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") 1272*da0073e9SAndroid Build Coastguard Worker @dtypes(*get_all_qint_dtypes()) 1273*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_resize_quantized(self, device, dtype): 1274*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([-1, 0, 1, 2, 3], dtype=torch.float, device=device) 1275*da0073e9SAndroid Build Coastguard Worker b = torch.quantize_per_tensor(a, 0.1, 10, dtype) 1276*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1277*da0073e9SAndroid Build Coastguard Worker lambda: b.resize_((10,)), 1278*da0073e9SAndroid Build Coastguard Worker 'quantized_resize_cpu_') 1279*da0073e9SAndroid Build Coastguard Worker 1280*da0073e9SAndroid Build Coastguard Worker @skipXLA 1281*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") 1282*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64)) 
def test_deterministic_resize(self, device, dtype):
    """Check that ``resize_`` fills newly grown storage deterministically.

    With deterministic algorithms and ``fill_uninitialized_memory`` enabled, the
    part of the storage that ``resize_`` adds must hold NaN (floating/complex
    dtypes), ``True`` (bool) or ``torch.iinfo(dtype).max`` (other integer
    dtypes). The pre-existing storage contents must be left untouched; when the
    storage does not grow, it must compare equal to the original.
    """
    # Each entry: (initial size, initial stride or None for contiguous,
    # size passed to resize_).
    resize_cases = [
        ((10,), (1,), (5,)),
        ((10,), (0,), (10,)),
        ((10,), (1,), (20,)),
        ((2, 3, 4), None, (2, 3, 4)),
        ((2, 3, 4), None, (6, 3, 4)),
        ((2, 3, 4), None, (2, 5, 4)),
        ((2, 3, 4), None, (2, 3, 6)),
        ((2, 3, 4), None, (3, 4, 5)),
        ((2, 3, 4), (1, 4, 12), (2, 3, 4)),
        ((2, 3, 4), (1, 4, 12), (4, 3, 4)),
        ((2, 3, 4), (1, 4, 12), (2, 4, 4)),
        ((2, 3, 4), (1, 4, 12), (2, 3, 5)),
        ((2, 3, 4), (1, 4, 12), (3, 4, 5)),
        ((2, 3, 4), (1, 0, 1), (2, 4, 5)),
    ]
    nan_filled = dtype.is_floating_point or dtype.is_complex

    for init_size, init_stride, target_size in resize_cases:
        if init_stride is not None:
            t = torch.empty_strided(init_size, init_stride, dtype=dtype, device=device).fill_(0)
        else:
            t = torch.zeros(init_size, dtype=dtype, device=device)

        storage_before = t.untyped_storage().clone()
        with DeterministicGuard(True, fill_uninitialized_memory=True):
            t.resize_(target_size)
        storage_after = t.untyped_storage()

        # Materialize both storages as tensors of the test dtype so they can
        # be compared element-wise.
        before = torch.tensor(storage_before, dtype=dtype)
        after = torch.tensor(storage_after, dtype=dtype)
        n_before, n_after = before.numel(), after.numel()

        if n_after <= n_before:
            # Storage did not grow: it must be unchanged.
            self.assertEqual(before, after)
            continue

        # Storage grew: the old prefix is preserved and the new tail is filled.
        self.assertEqual(after[:n_before], before)
        grown_tail = after[n_before:]
        if nan_filled:
            self.assertTrue(grown_tail.isnan().all())
        else:
            expected_fill = True if dtype == torch.bool else torch.iinfo(dtype).max
            self.assertTrue(grown_tail.eq(expected_fill).all())
# With deterministic algorithms enabled, `torch.empty` and its variants are
# expected to fill floating point tensors with NaN and integer tensors with
# their maximum value (True for bool).
@skipXLA
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64))
def test_deterministic_empty(self, device, dtype):
    """Check that uninitialized-memory factories fill their output deterministically."""
    factories = (
        lambda: torch.empty(10, 9, device=device, dtype=dtype),
        lambda: torch.empty(10, 9, out=torch.zeros(1, device=device, dtype=dtype)),
        lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype)),
        lambda: torch.empty_like(torch.zeros(10, 9, device=device, dtype=dtype), memory_format=torch.contiguous_format),
        lambda: torch.empty_strided((10, 9), (1, 5), device=device, dtype=dtype),
        lambda: torch.empty_permuted((2, 3, 5), (1, 0, 2), device=device, dtype=dtype),
    )

    if dtype.is_floating_point or dtype.is_complex:
        def is_filled(t):
            return t.isnan().all()
    else:
        fill = True if dtype == torch.bool else torch.iinfo(dtype).max

        def is_filled(t):
            return t.eq(fill).all()

    for make in factories:
        with DeterministicGuard(True, fill_uninitialized_memory=True):
            result = make()
        self.assertTrue(is_filled(result))

# FIXME: update OpInfos to support "nondeterministic samples" and port these
# tests to that architecture
@skipIfMps
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_AvgPool3d(self, device):
    """Expect AvgPool3d backward to trigger the nondeterministic alert on CUDA only."""
    on_cuda = torch.device(device).type == 'cuda'
    pool = torch.nn.AvgPool3d(3)
    inp = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
    out = pool(inp)
    out_grad = torch.ones_like(out)

    self.check_nondeterministic_alert(
        lambda: out.backward(out_grad, retain_graph=True),
        'avg_pool3d_backward_cuda',
        on_cuda)

@skipIfMps
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_AdaptiveAvgPool2d(self, device):
    """Expect AdaptiveAvgPool2d backward to trigger the nondeterministic alert on CUDA only."""
    on_cuda = torch.device(device).type == 'cuda'
    pool = torch.nn.AdaptiveAvgPool2d(3)
    inp = torch.randn(2, 3, 3, requires_grad=True, device=device)
    out = pool(inp)
    out_grad = torch.ones_like(out)

    self.check_nondeterministic_alert(
        lambda: out.backward(out_grad, retain_graph=True),
        'adaptive_avg_pool2d_backward_cuda',
        on_cuda)

@skipIfMps
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_AdaptiveAvgPool3d(self, device):
    """Expect AdaptiveAvgPool3d backward to trigger the nondeterministic alert on CUDA only."""
    on_cuda = torch.device(device).type == 'cuda'
    pool = torch.nn.AdaptiveAvgPool3d(3)
    inp = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
    out = pool(inp)
    out_grad = torch.ones_like(out)

    self.check_nondeterministic_alert(
        lambda: out.backward(out_grad, retain_graph=True),
        'adaptive_avg_pool3d_backward_cuda',
        on_cuda)

@skipIfMps
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_MaxPool3d(self, device):
    """Expect MaxPool3d backward to trigger the nondeterministic alert on CUDA only."""
    on_cuda = torch.device(device).type == 'cuda'
    pool = torch.nn.MaxPool3d(3)
    inp = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
    out = pool(inp)
    out_grad = torch.ones_like(out)

    self.check_nondeterministic_alert(
        lambda: out.backward(out_grad, retain_graph=True),
        'max_pool3d_with_indices_backward_cuda',
        on_cuda)

@skipIfMps
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_AdaptiveMaxPool2d(self, device):
    """Expect AdaptiveMaxPool2d backward to trigger the nondeterministic alert on CUDA only."""
    on_cuda = torch.device(device).type == 'cuda'
    pool = torch.nn.AdaptiveMaxPool2d(3)
    inp = torch.randn(2, 3, 3, requires_grad=True, device=device)
    out = pool(inp)
    out_grad = torch.ones_like(out)

    self.check_nondeterministic_alert(
        lambda: out.backward(out_grad, retain_graph=True),
        'adaptive_max_pool2d_backward_cuda',
        on_cuda)
@skipIfMps
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_FractionalMaxPool2d(self, device):
    """Expect FractionalMaxPool2d backward to trigger the nondeterministic alert on CUDA only."""
    on_cuda = torch.device(device).type == 'cuda'
    pool = torch.nn.FractionalMaxPool2d(2, output_ratio=0.5)
    inp = torch.randn(2, 3, 3, 3, requires_grad=True, device=device)
    out = pool(inp)
    out_grad = torch.ones_like(out)

    self.check_nondeterministic_alert(
        lambda: out.backward(out_grad, retain_graph=True),
        'fractional_max_pool2d_backward_cuda',
        on_cuda)

@skipIfMps
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_FractionalMaxPool3d(self, device):
    """Expect FractionalMaxPool3d backward to trigger the nondeterministic alert on CUDA only."""
    on_cuda = torch.device(device).type == 'cuda'
    pool = torch.nn.FractionalMaxPool3d(2, output_ratio=0.5)
    inp = torch.randn(2, 3, 3, 3, 3, requires_grad=True, device=device)
    out = pool(inp)
    out_grad = torch.ones_like(out)

    self.check_nondeterministic_alert(
        lambda: out.backward(out_grad, retain_graph=True),
        'fractional_max_pool3d_backward_cuda',
        on_cuda)

@dtypes(*floating_types_and(torch.half))
@onlyNativeDeviceTypes
def test_nondeterministic_alert_MaxUnpool1d(self, device, dtype):
    """Expect the MaxUnpool1d forward pass to trigger the nondeterministic alert."""
    if dtype == torch.half and torch.device(device).type == 'cpu':
        self.skipTest('float16 not implemented on CPU')

    unpool = torch.nn.MaxUnpool1d(3, 1)
    values = torch.randn(1, 1, 7, dtype=dtype, device=device)
    # Every index is 0, so all values target the same output position
    # (presumably what makes the write order matter).
    idx = torch.zeros_like(values, dtype=torch.long, device=device)

    # The expected alert name for the 1d module is the 2d kernel's name.
    self.check_nondeterministic_alert(
        lambda: unpool(values, idx),
        'max_unpooling2d_forward_out')

@dtypes(*floating_types_and(torch.half))
@onlyNativeDeviceTypes
def test_nondeterministic_alert_MaxUnpool2d(self, device, dtype):
    """Expect the MaxUnpool2d forward pass to trigger the nondeterministic alert."""
    if dtype == torch.half and torch.device(device).type == 'cpu':
        self.skipTest('float16 not implemented on CPU')

    unpool = torch.nn.MaxUnpool2d(3, 1)
    values = torch.randn(1, 1, 7, 7, dtype=dtype, device=device)
    # Every index is 0, so all values target the same output position
    # (presumably what makes the write order matter).
    idx = torch.zeros_like(values, dtype=torch.long, device=device)

    self.check_nondeterministic_alert(
        lambda: unpool(values, idx),
        'max_unpooling2d_forward_out')

@dtypes(*floating_types_and(torch.half))
@onlyNativeDeviceTypes
def test_nondeterministic_alert_MaxUnpool3d(self, device, dtype):
    """Expect the MaxUnpool3d forward pass to trigger the nondeterministic alert."""
    if dtype == torch.half and torch.device(device).type == 'cpu':
        self.skipTest('float16 not implemented on CPU')

    unpool = torch.nn.MaxUnpool3d(3, 1)
    values = torch.randn(1, 1, 7, 7, 7, dtype=dtype, device=device)
    # Every index is 0, so all values target the same output position
    # (presumably what makes the write order matter).
    idx = torch.zeros_like(values, dtype=torch.long, device=device)

    self.check_nondeterministic_alert(
        lambda: unpool(values, idx),
        'max_unpooling3d_forward_out')

@skipIfMps
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_interpolate_linear(self, device):
    """Expect linear-upsampling backward to trigger the nondeterministic alert on CUDA only."""
    on_cuda = torch.device(device).type == 'cuda'
    inp = torch.randn(1, 2, 4, device=device, requires_grad=True)
    out = torch.nn.functional.interpolate(inp, size=12, mode='linear', align_corners=False)
    out_grad = torch.ones_like(out)

    self.check_nondeterministic_alert(
        lambda: out.backward(out_grad),
        'upsample_linear1d_backward_out_cuda',
        on_cuda)

@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_interpolate_bilinear(self, device):
    """Expect bilinear-upsampling backward to trigger the nondeterministic alert on CUDA only."""
    on_cuda = torch.device(device).type == 'cuda'
    inp = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)
    out = torch.nn.functional.interpolate(inp, size=12, mode='bilinear', align_corners=False)
    out_grad = torch.ones_like(out)

    self.check_nondeterministic_alert(
        lambda: out.backward(out_grad),
        'upsample_bilinear2d_backward_out_cuda',
        on_cuda)
@skipIfTorchInductor("aot-autograd issue")
def test_deterministic_replication_pad2d(self, device):
    """Check that replicate-mode padding backward is bitwise reproducible.

    For every (input size, padding) case, the pad + backward pass is run
    several times under ``DeterministicGuard(True)``. Each resulting input
    gradient must exactly match (atol=rtol=0) the gradient from the first pass.
    """
    test_cases = [
        # size, padding
        [(1, 2, 4, 4), (0, 0, 0, 0)],
        [(1, 2, 4, 4), (3, 4, 5, 6)],
        [(3, 8, 7), (0, 0, 0, 0)],
        [(3, 8, 7), (4, 3, 2, 7)],
    ]

    # Cases with negative padding are not exercised on XLA.
    if torch.device(device).type != 'xla':
        test_cases += [
            [(4, 3, 5, 10), (-9, 4, 5, 6)],
            [(3, 8, 7), (-4, -2, -2, -3)],
        ]

    for size, padding in test_cases:
        inp = torch.randn(*size, device=device, requires_grad=True)
        grad = None
        with DeterministicGuard(True):
            # The pass must be repeated: with a single pass the first gradient
            # is only recorded as the reference, and nothing is ever compared.
            for _ in range(5):
                res = torch.nn.functional.pad(
                    inp,
                    padding,
                    mode='replicate')
                res.backward(torch.ones_like(res))
                if grad is None:
                    grad = inp.grad
                else:
                    self.assertEqual(grad, inp.grad, atol=0, rtol=0)
                # Reset so the next backward does not accumulate into the
                # previous gradient.
                inp.grad = None

@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_deterministic_interpolate_bilinear(self, device):
    """Check that bilinear-upsampling backward is bitwise reproducible.

    The pass is repeated five times under ``DeterministicGuard(True)``, and
    every input gradient must exactly match the first one.
    """
    input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)
    grad = None
    with DeterministicGuard(True):
        for _ in range(5):
            res = torch.nn.functional.interpolate(
                input,
                size=12,
                mode='bilinear',
                align_corners=False)
            res.backward(torch.ones_like(res))
            if grad is None:
                grad = input.grad
            else:
                self.assertEqual(grad, input.grad, atol=0, rtol=0)
            # Reset so the next backward does not accumulate.
            input.grad = None

@skipIfMps
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_interpolate_bicubic(self, device):
    """Expect bicubic-upsampling backward to trigger the nondeterministic alert on CUDA only."""
    input = torch.randn(1, 2, 4, 4, device=device, requires_grad=True)
    res = torch.nn.functional.interpolate(
        input,
        size=12,
        mode='bicubic',
        align_corners=False)
    grad = torch.ones_like(res)

    self.check_nondeterministic_alert(
        lambda: res.backward(grad),
        'upsample_bicubic2d_backward_out_cuda',
        torch.device(device).type == 'cuda')

@skipIfMps
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_interpolate_trilinear(self, device):
    """Expect trilinear-upsampling backward to trigger the nondeterministic alert on CUDA only."""
    input = torch.randn(1, 2, 4, 4, 4, device=device, requires_grad=True)
    res = torch.nn.functional.interpolate(
        input,
        size=12,
        mode='trilinear',
        align_corners=False)
    grad = torch.ones_like(res)

    self.check_nondeterministic_alert(
        lambda: res.backward(grad),
        'upsample_trilinear3d_backward_out_cuda',
        torch.device(device).type == 'cuda')

@skipIfMps
@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_ReflectionPad1d(self, device):
    """Expect ReflectionPad1d backward to trigger the nondeterministic alert on CUDA only."""
    module = torch.nn.ReflectionPad1d((1, 2))
    input = torch.randn(2, 3, 8, device=device, requires_grad=True)
    res = module(input)
    grad = torch.ones_like(res)

    self.check_nondeterministic_alert(
        lambda: res.backward(grad, retain_graph=True),
        'reflection_pad1d_backward_out_cuda',
        torch.device(device).type == 'cuda')

@skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707")
def test_nondeterministic_alert_ReflectionPad2d(self, device):
    """Expect ReflectionPad2d backward to trigger the nondeterministic alert on CUDA only."""
    module = torch.nn.ReflectionPad2d((1, 2, 3, 4))
    input = torch.randn(2, 3, 8, 8, device=device, requires_grad=True)
    res = module(input)
    grad = torch.ones_like(res)

    self.check_nondeterministic_alert(
        lambda: res.backward(grad, retain_graph=True),
        'reflection_pad2d_backward_cuda',
        torch.device(device).type == 'cuda')
Worker 1636*da0073e9SAndroid Build Coastguard Worker @skipIfMps 1637*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") 1638*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_ReflectionPad3d(self, device): 1639*da0073e9SAndroid Build Coastguard Worker module = torch.nn.ReflectionPad3d((1, 2, 3, 4, 5, 6)) 1640*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 3, 8, 8, 8, device=device, requires_grad=True) 1641*da0073e9SAndroid Build Coastguard Worker res = module(input) 1642*da0073e9SAndroid Build Coastguard Worker grad = torch.ones_like(res) 1643*da0073e9SAndroid Build Coastguard Worker 1644*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1645*da0073e9SAndroid Build Coastguard Worker lambda: res.backward(grad, retain_graph=True), 1646*da0073e9SAndroid Build Coastguard Worker 'reflection_pad3d_backward_out_cuda', 1647*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == 'cuda') 1648*da0073e9SAndroid Build Coastguard Worker 1649*da0073e9SAndroid Build Coastguard Worker @skipIfMps 1650*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") 1651*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_ReplicationPad1d(self, device): 1652*da0073e9SAndroid Build Coastguard Worker module = torch.nn.ReplicationPad1d((1, 2)) 1653*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 3, 4, device=device, requires_grad=True) 1654*da0073e9SAndroid Build Coastguard Worker res = module(input) 1655*da0073e9SAndroid Build Coastguard Worker grad = torch.ones_like(res) 1656*da0073e9SAndroid Build Coastguard Worker 1657*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1658*da0073e9SAndroid Build Coastguard Worker lambda: res.backward(grad, retain_graph=True), 1659*da0073e9SAndroid Build Coastguard Worker 
'replication_pad1d_backward_cuda', 1660*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == 'cuda') 1661*da0073e9SAndroid Build Coastguard Worker 1662*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") 1663*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_ReplicationPad2d(self, device): 1664*da0073e9SAndroid Build Coastguard Worker module = torch.nn.ReplicationPad2d((1, 2, 3, 4)) 1665*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 3, 4, 4, device=device, requires_grad=True) 1666*da0073e9SAndroid Build Coastguard Worker res = module(input) 1667*da0073e9SAndroid Build Coastguard Worker grad = torch.ones_like(res) 1668*da0073e9SAndroid Build Coastguard Worker 1669*da0073e9SAndroid Build Coastguard Worker # Nondeterministic alert should only be raised if the forward call was 1670*da0073e9SAndroid Build Coastguard Worker # nondeterministic 1671*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1672*da0073e9SAndroid Build Coastguard Worker lambda: res.backward(grad, retain_graph=True), 1673*da0073e9SAndroid Build Coastguard Worker 'replication_pad2d_backward_cuda', 1674*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == 'cuda') 1675*da0073e9SAndroid Build Coastguard Worker 1676*da0073e9SAndroid Build Coastguard Worker with DeterministicGuard(True): 1677*da0073e9SAndroid Build Coastguard Worker res = module(input) 1678*da0073e9SAndroid Build Coastguard Worker 1679*da0073e9SAndroid Build Coastguard Worker grad = torch.ones_like(res) 1680*da0073e9SAndroid Build Coastguard Worker 1681*da0073e9SAndroid Build Coastguard Worker # If the forward call was deterministic, nondeterministic alert should 1682*da0073e9SAndroid Build Coastguard Worker # not be raised 1683*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1684*da0073e9SAndroid Build Coastguard Worker lambda: res.backward(grad, 
retain_graph=True), 1685*da0073e9SAndroid Build Coastguard Worker 'replication_pad2d_backward_cuda', 1686*da0073e9SAndroid Build Coastguard Worker False) 1687*da0073e9SAndroid Build Coastguard Worker 1688*da0073e9SAndroid Build Coastguard Worker @skipIfMps 1689*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") 1690*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_ReplicationPad3d(self, device): 1691*da0073e9SAndroid Build Coastguard Worker module = torch.nn.ReplicationPad3d((1, 2, 3, 4, 5, 6)) 1692*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 3, 4, 4, 4, device=device, requires_grad=True) 1693*da0073e9SAndroid Build Coastguard Worker res = module(input) 1694*da0073e9SAndroid Build Coastguard Worker grad = torch.ones_like(res) 1695*da0073e9SAndroid Build Coastguard Worker 1696*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1697*da0073e9SAndroid Build Coastguard Worker lambda: res.backward(grad, retain_graph=True), 1698*da0073e9SAndroid Build Coastguard Worker 'replication_pad3d_backward_cuda', 1699*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == 'cuda') 1700*da0073e9SAndroid Build Coastguard Worker 1701*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Warning is not raised.") 1702*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_NLLLoss(self, device): 1703*da0073e9SAndroid Build Coastguard Worker module = torch.nn.NLLLoss() 1704*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 3, 5, 5, device=device) 1705*da0073e9SAndroid Build Coastguard Worker target = torch.rand(2, 5, 5, device=device).mul(3).floor().long() 1706*da0073e9SAndroid Build Coastguard Worker 1707*da0073e9SAndroid Build Coastguard Worker 1708*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1709*da0073e9SAndroid Build Coastguard Worker lambda: module(input, target), 
1710*da0073e9SAndroid Build Coastguard Worker 'nll_loss2d_forward_out_cuda_template', 1711*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == 'cuda') 1712*da0073e9SAndroid Build Coastguard Worker 1713*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") 1714*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_CTCLoss(self, device): 1715*da0073e9SAndroid Build Coastguard Worker module = torch.nn.CTCLoss() 1716*da0073e9SAndroid Build Coastguard Worker input = torch.randn(50, 3, 15, device=device, requires_grad=True) 1717*da0073e9SAndroid Build Coastguard Worker target = torch.randint(0, 14, (3, 30), device=device) 1718*da0073e9SAndroid Build Coastguard Worker input_lengths = [50, 50, 50] 1719*da0073e9SAndroid Build Coastguard Worker target_lengths = [30, 25, 20] 1720*da0073e9SAndroid Build Coastguard Worker res = module(input, target, input_lengths, target_lengths) 1721*da0073e9SAndroid Build Coastguard Worker grad = torch.ones_like(res) 1722*da0073e9SAndroid Build Coastguard Worker 1723*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1724*da0073e9SAndroid Build Coastguard Worker lambda: res.backward(grad, retain_graph=True), 1725*da0073e9SAndroid Build Coastguard Worker 'ctc_loss_backward_gpu', 1726*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == 'cuda') 1727*da0073e9SAndroid Build Coastguard Worker 1728*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") 1729*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_EmbeddingBag_max(self, device): 1730*da0073e9SAndroid Build Coastguard Worker module = torch.nn.EmbeddingBag( 1731*da0073e9SAndroid Build Coastguard Worker 4, 3, None, 2., False, 'max', 1732*da0073e9SAndroid Build Coastguard Worker _weight=torch.randn(4, 3, device=device, requires_grad=True)) 1733*da0073e9SAndroid Build 
Coastguard Worker input = torch.randint(0, 3, (4, 3), device=device) 1734*da0073e9SAndroid Build Coastguard Worker res = module(input) 1735*da0073e9SAndroid Build Coastguard Worker grad = torch.ones_like(res) 1736*da0073e9SAndroid Build Coastguard Worker 1737*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1738*da0073e9SAndroid Build Coastguard Worker lambda: res.backward(grad, retain_graph=True), 1739*da0073e9SAndroid Build Coastguard Worker 'embedding_bag_backward_cuda_max', 1740*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == 'cuda') 1741*da0073e9SAndroid Build Coastguard Worker 1742*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.bool)) 1743*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") 1744*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_cumsum(self, device, dtype): 1745*da0073e9SAndroid Build Coastguard Worker input = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9) 1746*da0073e9SAndroid Build Coastguard Worker should_alert = torch.device(device).type == 'cuda' and (dtype.is_floating_point or dtype.is_complex) 1747*da0073e9SAndroid Build Coastguard Worker 1748*da0073e9SAndroid Build Coastguard Worker for op_call in [torch.Tensor.cumsum, torch.cumsum]: 1749*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1750*da0073e9SAndroid Build Coastguard Worker lambda: op_call(input, 0), 1751*da0073e9SAndroid Build Coastguard Worker 'cumsum_cuda_kernel', 1752*da0073e9SAndroid Build Coastguard Worker should_alert) 1753*da0073e9SAndroid Build Coastguard Worker 1754*da0073e9SAndroid Build Coastguard Worker @expectedFailureMeta # expected a non-determinitic error, but it was not raised 1755*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1756*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_put(self, device): 
1757*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10, device=device) 1758*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor([0, 0], device=device) 1759*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([0., 1.], device=device) 1760*da0073e9SAndroid Build Coastguard Worker 1761*da0073e9SAndroid Build Coastguard Worker for op_call in [torch.Tensor.put, torch.Tensor.put_]: 1762*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1763*da0073e9SAndroid Build Coastguard Worker lambda: op_call(a, indices, values, accumulate=False), 1764*da0073e9SAndroid Build Coastguard Worker 'put_') 1765*da0073e9SAndroid Build Coastguard Worker 1766*da0073e9SAndroid Build Coastguard Worker # warn_only=False correctly raises RuntimeError: put_ does not have a deterministic implementation 1767*da0073e9SAndroid Build Coastguard Worker # warn_only=True logs warning from the FallbackKernel: torch.ops.aten.put_.default, instead of as UserWarning: 1768*da0073e9SAndroid Build Coastguard Worker # [W Context.cpp:%(lineno)] Warning: put_ does not have a deterministic implementation 1769*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("warning is logged from the FallbackKernel: torch.ops.aten.put_.default when warn_only=True") 1770*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_put_accumulate(self, device): 1771*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10, device=device) 1772*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor([0, 0], device=device) 1773*da0073e9SAndroid Build Coastguard Worker values = torch.tensor([0., 1.], device=device) 1774*da0073e9SAndroid Build Coastguard Worker 1775*da0073e9SAndroid Build Coastguard Worker for op_call in [torch.Tensor.put, torch.Tensor.put_]: 1776*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1777*da0073e9SAndroid Build Coastguard Worker lambda: op_call(a, indices, values, accumulate=True), 
1778*da0073e9SAndroid Build Coastguard Worker 'put_', 1779*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == 'cuda') 1780*da0073e9SAndroid Build Coastguard Worker 1781*da0073e9SAndroid Build Coastguard Worker @skipIfMps 1782*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_histc(self, device): 1783*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([], device=device) 1784*da0073e9SAndroid Build Coastguard Worker for op_call in [torch.histc, torch.Tensor.histc]: 1785*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1786*da0073e9SAndroid Build Coastguard Worker lambda: op_call(a, min=0, max=3), 1787*da0073e9SAndroid Build Coastguard Worker '_histc_cuda', 1788*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == 'cuda') 1789*da0073e9SAndroid Build Coastguard Worker 1790*da0073e9SAndroid Build Coastguard Worker @skipIfMps 1791*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_bincount(self, device): 1792*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([], device=device, dtype=torch.long) 1793*da0073e9SAndroid Build Coastguard Worker weights = torch.tensor([], device=device) 1794*da0073e9SAndroid Build Coastguard Worker 1795*da0073e9SAndroid Build Coastguard Worker for op_call in [torch.bincount, torch.Tensor.bincount]: 1796*da0073e9SAndroid Build Coastguard Worker # Error should only be raised when device is CUDA and weights are 1797*da0073e9SAndroid Build Coastguard Worker # given 1798*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1799*da0073e9SAndroid Build Coastguard Worker lambda: op_call(a, weights), 1800*da0073e9SAndroid Build Coastguard Worker '_bincount_cuda', 1801*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == 'cuda') 1802*da0073e9SAndroid Build Coastguard Worker 1803*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1804*da0073e9SAndroid Build Coastguard 
Worker lambda: op_call(a), 1805*da0073e9SAndroid Build Coastguard Worker '_bincount_cuda', 1806*da0073e9SAndroid Build Coastguard Worker False) 1807*da0073e9SAndroid Build Coastguard Worker 1808*da0073e9SAndroid Build Coastguard Worker # Ensures that kthvalue throws nondeterministic alerts in the correct cases 1809*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1810*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_kthvalue(self, device, dtype): 1811*da0073e9SAndroid Build Coastguard Worker def test_func(call_type): 1812*da0073e9SAndroid Build Coastguard Worker S = 10 1813*da0073e9SAndroid Build Coastguard Worker k = 5 1814*da0073e9SAndroid Build Coastguard Worker a = torch.randn(S, device=device) 1815*da0073e9SAndroid Build Coastguard Worker if call_type == 'function': 1816*da0073e9SAndroid Build Coastguard Worker torch.kthvalue(a, k) 1817*da0073e9SAndroid Build Coastguard Worker elif call_type == 'method': 1818*da0073e9SAndroid Build Coastguard Worker a.kthvalue(k) 1819*da0073e9SAndroid Build Coastguard Worker elif call_type == 'out': 1820*da0073e9SAndroid Build Coastguard Worker values = torch.empty_like(a) 1821*da0073e9SAndroid Build Coastguard Worker indices = torch.empty((), device=device, dtype=torch.long) 1822*da0073e9SAndroid Build Coastguard Worker torch.kthvalue(a, k, out=(values, indices)) 1823*da0073e9SAndroid Build Coastguard Worker else: 1824*da0073e9SAndroid Build Coastguard Worker self.fail(f"'{call_type}' is not a valid call type") 1825*da0073e9SAndroid Build Coastguard Worker 1826*da0073e9SAndroid Build Coastguard Worker for call_type in ['function', 'method', 'out']: 1827*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1828*da0073e9SAndroid Build Coastguard Worker lambda: test_func('function'), 1829*da0073e9SAndroid Build Coastguard Worker 'kthvalue CUDA', 1830*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == 'cuda') 1831*da0073e9SAndroid Build Coastguard 
Worker 1832*da0073e9SAndroid Build Coastguard Worker @skipIfMps 1833*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") 1834*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_grid_sample_2d(self, device): 1835*da0073e9SAndroid Build Coastguard Worker input = torch.empty(1, 1, 2, 2, device=device, requires_grad=True) 1836*da0073e9SAndroid Build Coastguard Worker grid = torch.empty(1, 1, 1, 2, device=device) 1837*da0073e9SAndroid Build Coastguard Worker res = torch.nn.functional.grid_sample(input, grid, align_corners=False) 1838*da0073e9SAndroid Build Coastguard Worker grad = torch.ones_like(res) 1839*da0073e9SAndroid Build Coastguard Worker 1840*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1841*da0073e9SAndroid Build Coastguard Worker lambda: res.backward(grad, retain_graph=True), 1842*da0073e9SAndroid Build Coastguard Worker 'grid_sampler_2d_backward_cuda', 1843*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == 'cuda') 1844*da0073e9SAndroid Build Coastguard Worker 1845*da0073e9SAndroid Build Coastguard Worker @skipIfMps 1846*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") 1847*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_grid_sample_3d(self, device): 1848*da0073e9SAndroid Build Coastguard Worker input = torch.empty(1, 1, 2, 2, 2, device=device, requires_grad=True) 1849*da0073e9SAndroid Build Coastguard Worker grid = torch.empty(1, 1, 1, 2, 3, device=device) 1850*da0073e9SAndroid Build Coastguard Worker res = torch.nn.functional.grid_sample(input, grid, align_corners=False) 1851*da0073e9SAndroid Build Coastguard Worker grad = torch.ones_like(res) 1852*da0073e9SAndroid Build Coastguard Worker 1853*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1854*da0073e9SAndroid Build Coastguard Worker lambda: 
res.backward(grad, retain_graph=True), 1855*da0073e9SAndroid Build Coastguard Worker 'grid_sampler_3d_backward_cuda', 1856*da0073e9SAndroid Build Coastguard Worker torch.device(device).type == 'cuda') 1857*da0073e9SAndroid Build Coastguard Worker 1858*da0073e9SAndroid Build Coastguard Worker def test_invalid_shapes_grid_sampler(self, device): 1859*da0073e9SAndroid Build Coastguard Worker make_arg = partial( 1860*da0073e9SAndroid Build Coastguard Worker make_tensor, device=device, dtype=torch.float64, requires_grad=True) 1861*da0073e9SAndroid Build Coastguard Worker 1862*da0073e9SAndroid Build Coastguard Worker inputs = ( 1863*da0073e9SAndroid Build Coastguard Worker # input, grid 1864*da0073e9SAndroid Build Coastguard Worker ((5, 5, 5, 5, 5,), (1, 1, 1, 4, 4,)), # 3d 1865*da0073e9SAndroid Build Coastguard Worker ((5, 5, 5, 5,), (1, 1, 4, 4,)), # 2d 1866*da0073e9SAndroid Build Coastguard Worker ) 1867*da0073e9SAndroid Build Coastguard Worker 1868*da0073e9SAndroid Build Coastguard Worker interpolation_mode = 0 1869*da0073e9SAndroid Build Coastguard Worker padding_mode = 0 1870*da0073e9SAndroid Build Coastguard Worker align_corners = True 1871*da0073e9SAndroid Build Coastguard Worker 1872*da0073e9SAndroid Build Coastguard Worker err = "expected grid and input to have same batch size" 1873*da0073e9SAndroid Build Coastguard Worker 1874*da0073e9SAndroid Build Coastguard Worker for input, grid in inputs: 1875*da0073e9SAndroid Build Coastguard Worker input = make_arg(input) 1876*da0073e9SAndroid Build Coastguard Worker grid = make_arg(grid, low=-1, high=1) 1877*da0073e9SAndroid Build Coastguard Worker 1878*da0073e9SAndroid Build Coastguard Worker # Wrapper for the 2d, 3d, and cuDNN functions listed below. 
1879*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err): 1880*da0073e9SAndroid Build Coastguard Worker torch.grid_sampler( 1881*da0073e9SAndroid Build Coastguard Worker input, grid, interpolation_mode, padding_mode, 1882*da0073e9SAndroid Build Coastguard Worker align_corners) 1883*da0073e9SAndroid Build Coastguard Worker 1884*da0073e9SAndroid Build Coastguard Worker # Expects 2d input. 1885*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err): 1886*da0073e9SAndroid Build Coastguard Worker torch.grid_sampler_2d( 1887*da0073e9SAndroid Build Coastguard Worker input, grid, interpolation_mode, padding_mode, 1888*da0073e9SAndroid Build Coastguard Worker align_corners) 1889*da0073e9SAndroid Build Coastguard Worker 1890*da0073e9SAndroid Build Coastguard Worker # Expects 3d input. 1891*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err): 1892*da0073e9SAndroid Build Coastguard Worker torch.grid_sampler_3d( 1893*da0073e9SAndroid Build Coastguard Worker input, grid, interpolation_mode, padding_mode, 1894*da0073e9SAndroid Build Coastguard Worker align_corners) 1895*da0073e9SAndroid Build Coastguard Worker 1896*da0073e9SAndroid Build Coastguard Worker # Expects 2d input. 1897*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err): 1898*da0073e9SAndroid Build Coastguard Worker torch._grid_sampler_2d_cpu_fallback( 1899*da0073e9SAndroid Build Coastguard Worker input, grid, interpolation_mode, padding_mode, 1900*da0073e9SAndroid Build Coastguard Worker align_corners) 1901*da0073e9SAndroid Build Coastguard Worker 1902*da0073e9SAndroid Build Coastguard Worker # Expects 2d input, on CUDA. 1903*da0073e9SAndroid Build Coastguard Worker # Doesn't work on CPU and ROCm. 
1904*da0073e9SAndroid Build Coastguard Worker if device != 'cpu' and TEST_CUDNN and not TEST_WITH_ROCM: 1905*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err): 1906*da0073e9SAndroid Build Coastguard Worker torch.cudnn_grid_sampler(input, grid) 1907*da0073e9SAndroid Build Coastguard Worker 1908*da0073e9SAndroid Build Coastguard Worker def test_dist(self, device): 1909*da0073e9SAndroid Build Coastguard Worker def run_test(x, y): 1910*da0073e9SAndroid Build Coastguard Worker for p in [0, 1, 2, 3, 4, inf, -inf]: 1911*da0073e9SAndroid Build Coastguard Worker dist_xy = torch.dist(x, y, p) 1912*da0073e9SAndroid Build Coastguard Worker dist_xy_norm = torch.norm(x - y, p) 1913*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dist_xy, dist_xy_norm) 1914*da0073e9SAndroid Build Coastguard Worker 1915*da0073e9SAndroid Build Coastguard Worker run_test(torch.randn(5, device=device), torch.randn(5, device=device)) 1916*da0073e9SAndroid Build Coastguard Worker 1917*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(3, device=device) 1918*da0073e9SAndroid Build Coastguard Worker y = torch.zeros(3, device=device) 1919*da0073e9SAndroid Build Coastguard Worker y[1] = 1. 
1920*da0073e9SAndroid Build Coastguard Worker run_test(x, y) 1921*da0073e9SAndroid Build Coastguard Worker 1922*da0073e9SAndroid Build Coastguard Worker # Ensures that median throws nondeterministic alerts in the correct cases 1923*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 1924*da0073e9SAndroid Build Coastguard Worker def test_nondeterministic_alert_median(self, device, dtype): 1925*da0073e9SAndroid Build Coastguard Worker def test_func(call_type): 1926*da0073e9SAndroid Build Coastguard Worker S = 10 1927*da0073e9SAndroid Build Coastguard Worker a = torch.randn(S, device=device) 1928*da0073e9SAndroid Build Coastguard Worker if call_type == 'function': 1929*da0073e9SAndroid Build Coastguard Worker torch.median(a) 1930*da0073e9SAndroid Build Coastguard Worker elif call_type == 'function with indices': 1931*da0073e9SAndroid Build Coastguard Worker torch.median(a, 0) 1932*da0073e9SAndroid Build Coastguard Worker elif call_type == 'method': 1933*da0073e9SAndroid Build Coastguard Worker a.median() 1934*da0073e9SAndroid Build Coastguard Worker elif call_type == 'method with indices': 1935*da0073e9SAndroid Build Coastguard Worker a.median(0) 1936*da0073e9SAndroid Build Coastguard Worker elif call_type == 'out with indices': 1937*da0073e9SAndroid Build Coastguard Worker result = torch.empty_like(a) 1938*da0073e9SAndroid Build Coastguard Worker indices = torch.empty((), dtype=torch.long, device=device) 1939*da0073e9SAndroid Build Coastguard Worker torch.median(a, 0, out=(result, indices)) 1940*da0073e9SAndroid Build Coastguard Worker else: 1941*da0073e9SAndroid Build Coastguard Worker self.fail(f"'{call_type}' is not a valid call type") 1942*da0073e9SAndroid Build Coastguard Worker 1943*da0073e9SAndroid Build Coastguard Worker def test_func_expect_error(call_type, should_error): 1944*da0073e9SAndroid Build Coastguard Worker self.check_nondeterministic_alert( 1945*da0073e9SAndroid Build Coastguard Worker lambda: test_func(call_type), 1946*da0073e9SAndroid 
Build Coastguard Worker 'median CUDA with indices output', 1947*da0073e9SAndroid Build Coastguard Worker should_error) 1948*da0073e9SAndroid Build Coastguard Worker 1949*da0073e9SAndroid Build Coastguard Worker is_cuda = torch.device(device).type == 'cuda' 1950*da0073e9SAndroid Build Coastguard Worker 1951*da0073e9SAndroid Build Coastguard Worker test_func_expect_error('function', False) 1952*da0073e9SAndroid Build Coastguard Worker test_func_expect_error('function with indices', is_cuda) 1953*da0073e9SAndroid Build Coastguard Worker test_func_expect_error('method', False) 1954*da0073e9SAndroid Build Coastguard Worker test_func_expect_error('method with indices', is_cuda) 1955*da0073e9SAndroid Build Coastguard Worker test_func_expect_error('out with indices', is_cuda) 1956*da0073e9SAndroid Build Coastguard Worker 1957*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test_scatter_gather_ops 1958*da0073e9SAndroid Build Coastguard Worker def _test_gather_backward_one_dim(self, device, deterministic: bool = False) -> None: 1959*da0073e9SAndroid Build Coastguard Worker with DeterministicGuard(deterministic): 1960*da0073e9SAndroid Build Coastguard Worker m = random.randint(2000, 3000) 1961*da0073e9SAndroid Build Coastguard Worker elems = random.randint(10 * m, 20 * m) 1962*da0073e9SAndroid Build Coastguard Worker dim = 0 1963*da0073e9SAndroid Build Coastguard Worker src = torch.randn(m, device=device, requires_grad=True) 1964*da0073e9SAndroid Build Coastguard Worker idx = torch.randint(m, (elems,), device=device) 1965*da0073e9SAndroid Build Coastguard Worker res = torch.gather(src, dim, idx) 1966*da0073e9SAndroid Build Coastguard Worker weight = torch.rand_like(res, device=device) * 10 ** 6 1967*da0073e9SAndroid Build Coastguard Worker res.backward(weight) 1968*da0073e9SAndroid Build Coastguard Worker assert src.grad is not None 1969*da0073e9SAndroid Build Coastguard Worker grad = src.grad.detach().clone() 1970*da0073e9SAndroid Build Coastguard Worker 
1971*da0073e9SAndroid Build Coastguard Worker if torch.device(device).type == 'cuda': 1972*da0073e9SAndroid Build Coastguard Worker for _ in range(2): 1973*da0073e9SAndroid Build Coastguard Worker src.grad.data.zero_() 1974*da0073e9SAndroid Build Coastguard Worker res = torch.gather(src, dim, idx) 1975*da0073e9SAndroid Build Coastguard Worker res.backward(weight) 1976*da0073e9SAndroid Build Coastguard Worker self.assertEqual(src.grad, grad, atol=0, rtol=0) 1977*da0073e9SAndroid Build Coastguard Worker else: 1978*da0073e9SAndroid Build Coastguard Worker expected = torch.zeros_like(src, device=device) 1979*da0073e9SAndroid Build Coastguard Worker for i in range(elems): 1980*da0073e9SAndroid Build Coastguard Worker expected[idx[i]] += weight[i] 1981*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad, expected, atol=0, rtol=0) 1982*da0073e9SAndroid Build Coastguard Worker 1983*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test_scatter_gather_ops 1984*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1985*da0073e9SAndroid Build Coastguard Worker def test_gather_backward_deterministic_path(self, device) -> None: 1986*da0073e9SAndroid Build Coastguard Worker self._test_gather_backward_one_dim(device, True) 1987*da0073e9SAndroid Build Coastguard Worker 1988*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test_scatter_gather_ops 1989*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1990*da0073e9SAndroid Build Coastguard Worker def test_gather_backward_one_dim(self, device) -> None: 1991*da0073e9SAndroid Build Coastguard Worker self._test_gather_backward_one_dim(device, False) 1992*da0073e9SAndroid Build Coastguard Worker 1993*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test_scatter_gather_ops 1994*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 1995*da0073e9SAndroid Build Coastguard Worker def test_scatter_add_one_dim_deterministic(self, device) -> None: 1996*da0073e9SAndroid Build Coastguard 
Worker with DeterministicGuard(True): 1997*da0073e9SAndroid Build Coastguard Worker m = random.randint(20, 30) 1998*da0073e9SAndroid Build Coastguard Worker elems = random.randint(2000 * m, 3000 * m) 1999*da0073e9SAndroid Build Coastguard Worker dim = 0 2000*da0073e9SAndroid Build Coastguard Worker src = torch.randn(elems, device=device) 2001*da0073e9SAndroid Build Coastguard Worker idx = torch.randint(m, (elems,), device=device) 2002*da0073e9SAndroid Build Coastguard Worker 2003*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(m, device=device) 2004*da0073e9SAndroid Build Coastguard Worker res = x.scatter_add(dim, idx, src) 2005*da0073e9SAndroid Build Coastguard Worker 2006*da0073e9SAndroid Build Coastguard Worker # Checking if scatter_add is deterministic 2007*da0073e9SAndroid Build Coastguard Worker for i in range(5): 2008*da0073e9SAndroid Build Coastguard Worker res_next = x.scatter_add(dim, idx, src) 2009*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, res_next, atol=0, rtol=0) 2010*da0073e9SAndroid Build Coastguard Worker res = res_next 2011*da0073e9SAndroid Build Coastguard Worker 2012*da0073e9SAndroid Build Coastguard Worker expected = torch.zeros(m, device=device) 2013*da0073e9SAndroid Build Coastguard Worker for i in range(elems): 2014*da0073e9SAndroid Build Coastguard Worker expected[idx[i]] += src[i] 2015*da0073e9SAndroid Build Coastguard Worker 2016*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, expected, atol=1e-4, rtol=1e-5) 2017*da0073e9SAndroid Build Coastguard Worker 2018*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test_scatter_gather_ops 2019*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 2020*da0073e9SAndroid Build Coastguard Worker def test_scatter_zero_size_index(self, device) -> None: 2021*da0073e9SAndroid Build Coastguard Worker null_index = torch.zeros((0, 4), dtype=torch.int64) 2022*da0073e9SAndroid Build Coastguard Worker null_arr = torch.zeros((0, 4)) 
2023*da0073e9SAndroid Build Coastguard Worker original = torch.arange(4, dtype=torch.float32) 2024*da0073e9SAndroid Build Coastguard Worker result = original.scatter(0, null_index, null_arr) 2025*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, original, atol=0, rtol=0) 2026*da0073e9SAndroid Build Coastguard Worker 2027*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2028*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("FIXME") 2029*da0073e9SAndroid Build Coastguard Worker def test_sync_warning(self, device): 2030*da0073e9SAndroid Build Coastguard Worker 2031*da0073e9SAndroid Build Coastguard Worker def _sync_raises_helper(f, level): 2032*da0073e9SAndroid Build Coastguard Worker with CudaSyncGuard(level): 2033*da0073e9SAndroid Build Coastguard Worker if level == 1: 2034*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "called a synchronizing "): 2035*da0073e9SAndroid Build Coastguard Worker f() 2036*da0073e9SAndroid Build Coastguard Worker elif level == 2: 2037*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "called a synchronizing "): 2038*da0073e9SAndroid Build Coastguard Worker f() 2039*da0073e9SAndroid Build Coastguard Worker 2040*da0073e9SAndroid Build Coastguard Worker def _no_sync_helper(f, level): 2041*da0073e9SAndroid Build Coastguard Worker with CudaSyncGuard(level): 2042*da0073e9SAndroid Build Coastguard Worker f() 2043*da0073e9SAndroid Build Coastguard Worker 2044*da0073e9SAndroid Build Coastguard Worker def _ind_put_fn(x, ind, val): 2045*da0073e9SAndroid Build Coastguard Worker x[ind] = val 2046*da0073e9SAndroid Build Coastguard Worker return x 2047*da0073e9SAndroid Build Coastguard Worker 2048*da0073e9SAndroid Build Coastguard Worker def _ind_get_fn(x, ind): 2049*da0073e9SAndroid Build Coastguard Worker return x[ind] 2050*da0073e9SAndroid Build Coastguard Worker 2051*da0073e9SAndroid Build Coastguard Worker def _cond_fn(x): 2052*da0073e9SAndroid Build 
Coastguard Worker if x: # taking boolean value of a tensor synchronizes 2053*da0073e9SAndroid Build Coastguard Worker return x 2054*da0073e9SAndroid Build Coastguard Worker else: 2055*da0073e9SAndroid Build Coastguard Worker return 2 * x 2056*da0073e9SAndroid Build Coastguard Worker 2057*da0073e9SAndroid Build Coastguard Worker # prepare inputs for subsequent ops 2058*da0073e9SAndroid Build Coastguard Worker size = 4 2059*da0073e9SAndroid Build Coastguard Worker x = torch.rand(size, device=device) 2060*da0073e9SAndroid Build Coastguard Worker y = torch.rand((), device=device) 2061*da0073e9SAndroid Build Coastguard Worker ind = torch.randint(size, (3,), device=device) 2062*da0073e9SAndroid Build Coastguard Worker ind_cpu = ind.cpu() 2063*da0073e9SAndroid Build Coastguard Worker repeats = torch.full((1,), 2, device=device) 2064*da0073e9SAndroid Build Coastguard Worker mask = torch.randint(2, (size,), device=device, dtype=bool) 2065*da0073e9SAndroid Build Coastguard Worker expect_no_sync = (lambda: _ind_put_fn(x, mask, 1.), 2066*da0073e9SAndroid Build Coastguard Worker lambda: _ind_put_fn(x, ind, y), 2067*da0073e9SAndroid Build Coastguard Worker lambda: _ind_get_fn(x, ind), 2068*da0073e9SAndroid Build Coastguard Worker lambda: torch.nn.functional.one_hot(ind, num_classes=size), 2069*da0073e9SAndroid Build Coastguard Worker lambda: torch.randperm(20000, device=device), 2070*da0073e9SAndroid Build Coastguard Worker lambda: torch.repeat_interleave(x, 2, output_size=2 * size), 2071*da0073e9SAndroid Build Coastguard Worker lambda: torch.repeat_interleave(x, repeats, output_size=2 * size), 2072*da0073e9SAndroid Build Coastguard Worker lambda: torch.any(y)) 2073*da0073e9SAndroid Build Coastguard Worker expect_sync = (lambda: _ind_put_fn(x, mask, y), 2074*da0073e9SAndroid Build Coastguard Worker lambda: _ind_put_fn(x, ind_cpu, y), 2075*da0073e9SAndroid Build Coastguard Worker lambda: _ind_get_fn(x, mask), 2076*da0073e9SAndroid Build Coastguard Worker lambda: _ind_get_fn(x, 
ind_cpu), 2077*da0073e9SAndroid Build Coastguard Worker lambda: x.nonzero(), 2078*da0073e9SAndroid Build Coastguard Worker lambda: _cond_fn(y), 2079*da0073e9SAndroid Build Coastguard Worker lambda: torch.nn.functional.one_hot(ind), 2080*da0073e9SAndroid Build Coastguard Worker lambda: torch.repeat_interleave(x, repeats)) 2081*da0073e9SAndroid Build Coastguard Worker for f, level in product(expect_no_sync, (1, 2)): 2082*da0073e9SAndroid Build Coastguard Worker _no_sync_helper(f, level) 2083*da0073e9SAndroid Build Coastguard Worker for f, level in product(expect_sync, (1, 2)): 2084*da0073e9SAndroid Build Coastguard Worker _sync_raises_helper(f, level) 2085*da0073e9SAndroid Build Coastguard Worker 2086*da0073e9SAndroid Build Coastguard Worker 2087*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_types_and(torch.half, torch.bfloat16)) 2088*da0073e9SAndroid Build Coastguard Worker @skipIfMps 2089*da0073e9SAndroid Build Coastguard Worker def test_log_normal(self, device, dtype): 2090*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([10], dtype=dtype, device=device).log_normal_() 2091*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.dtype, dtype) 2092*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.size(), torch.Size([1])) 2093*da0073e9SAndroid Build Coastguard Worker 2094*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and(torch.half, torch.bfloat16)) 2095*da0073e9SAndroid Build Coastguard Worker @skipIfMps 2096*da0073e9SAndroid Build Coastguard Worker def test_geometric(self, device, dtype): 2097*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([10], dtype=dtype, device=device).geometric_(0.5) 2098*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.dtype, dtype) 2099*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.size(), torch.Size([1])) 2100*da0073e9SAndroid Build Coastguard Worker 2101*da0073e9SAndroid Build Coastguard Worker @skipIfMps 2102*da0073e9SAndroid Build Coastguard Worker 
def test_repeat_interleave(self, device): 2103*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([[1, 2], [3, 4]], device=device) 2104*da0073e9SAndroid Build Coastguard Worker # exercise single argument function signature 2105*da0073e9SAndroid Build Coastguard Worker temp = y.repeat_interleave(2) 2106*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.Size([8]), temp.size()) 2107*da0073e9SAndroid Build Coastguard Worker 2108*da0073e9SAndroid Build Coastguard Worker for dtype in [torch.int, torch.long]: 2109*da0073e9SAndroid Build Coastguard Worker lengths = torch.tensor([1, 2], dtype=dtype, device=device) 2110*da0073e9SAndroid Build Coastguard Worker output_size = torch.sum(lengths) 2111*da0073e9SAndroid Build Coastguard Worker a = torch.repeat_interleave( 2112*da0073e9SAndroid Build Coastguard Worker y, 2113*da0073e9SAndroid Build Coastguard Worker lengths, 2114*da0073e9SAndroid Build Coastguard Worker dim=0, 2115*da0073e9SAndroid Build Coastguard Worker ) 2116*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.dtype, y.dtype) 2117*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.size(), torch.Size([3, 2])) 2118*da0073e9SAndroid Build Coastguard Worker 2119*da0073e9SAndroid Build Coastguard Worker a_with_output = torch.repeat_interleave( 2120*da0073e9SAndroid Build Coastguard Worker y, 2121*da0073e9SAndroid Build Coastguard Worker lengths, 2122*da0073e9SAndroid Build Coastguard Worker dim=0, 2123*da0073e9SAndroid Build Coastguard Worker output_size=output_size, 2124*da0073e9SAndroid Build Coastguard Worker ) 2125*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a_with_output.dtype, y.dtype) 2126*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a_with_output.size(), torch.Size([3, 2])) 2127*da0073e9SAndroid Build Coastguard Worker 2128*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_types()) 2129*da0073e9SAndroid Build Coastguard Worker @dtypesIfCPU(*floating_types_and(torch.bfloat16, 
torch.half)) 2130*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*floating_types_and(torch.half)) 2131*da0073e9SAndroid Build Coastguard Worker def test_bernoulli_p(self, device, dtype): 2132*da0073e9SAndroid Build Coastguard Worker for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]): 2133*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(trivial_p, dtype=dtype, device=device) 2134*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.bernoulli().tolist(), trivial_p) 2135*da0073e9SAndroid Build Coastguard Worker 2136*da0073e9SAndroid Build Coastguard Worker def isBinary(t): 2137*da0073e9SAndroid Build Coastguard Worker return torch.ne(t, 0).mul_(torch.ne(t, 1)).sum().item() == 0 2138*da0073e9SAndroid Build Coastguard Worker 2139*da0073e9SAndroid Build Coastguard Worker p = torch.rand(5, 5, dtype=dtype, device=device) 2140*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isBinary(p.bernoulli())) 2141*da0073e9SAndroid Build Coastguard Worker 2142*da0073e9SAndroid Build Coastguard Worker p = torch.rand(5, dtype=dtype, device=device).expand(5, 5) 2143*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isBinary(p.bernoulli())) 2144*da0073e9SAndroid Build Coastguard Worker 2145*da0073e9SAndroid Build Coastguard Worker p = torch.rand(5, 5, dtype=dtype, device=device) 2146*da0073e9SAndroid Build Coastguard Worker torch.bernoulli(torch.rand_like(p), out=p) 2147*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isBinary(p)) 2148*da0073e9SAndroid Build Coastguard Worker 2149*da0073e9SAndroid Build Coastguard Worker # RngUniform not implemented for Integral type in XLA test 2150*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_types()) 2151*da0073e9SAndroid Build Coastguard Worker @dtypesIfCPU(*all_types_and(torch.bool, torch.half)) 2152*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*all_types_and(torch.bool, torch.half)) 2153*da0073e9SAndroid Build Coastguard Worker def test_bernoulli_self(self, device, dtype): 
2154*da0073e9SAndroid Build Coastguard Worker 2155*da0073e9SAndroid Build Coastguard Worker def isBinary(t): 2156*da0073e9SAndroid Build Coastguard Worker return torch.ne(t, 0).mul_(torch.ne(t, 1)).sum().item() == 0 2157*da0073e9SAndroid Build Coastguard Worker 2158*da0073e9SAndroid Build Coastguard Worker t = torch.empty(10, 10, dtype=dtype, device=device) 2159*da0073e9SAndroid Build Coastguard Worker 2160*da0073e9SAndroid Build Coastguard Worker t.fill_(2) 2161*da0073e9SAndroid Build Coastguard Worker t.bernoulli_(0.5) 2162*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isBinary(t)) 2163*da0073e9SAndroid Build Coastguard Worker 2164*da0073e9SAndroid Build Coastguard Worker for p_dtype in floating_types_and(*[torch.half] if device.startswith('cuda') else []): 2165*da0073e9SAndroid Build Coastguard Worker p = torch.rand(10, dtype=p_dtype, device=device).expand(10, 10) 2166*da0073e9SAndroid Build Coastguard Worker t.fill_(2) 2167*da0073e9SAndroid Build Coastguard Worker t.bernoulli_(p) 2168*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isBinary(t)) 2169*da0073e9SAndroid Build Coastguard Worker 2170*da0073e9SAndroid Build Coastguard Worker t.fill_(2) 2171*da0073e9SAndroid Build Coastguard Worker torch.bernoulli(torch.rand_like(t, dtype=p_dtype), out=t) 2172*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isBinary(t)) 2173*da0073e9SAndroid Build Coastguard Worker 2174*da0073e9SAndroid Build Coastguard Worker t.fill_(2) 2175*da0073e9SAndroid Build Coastguard Worker t.bernoulli_(torch.rand_like(t, dtype=p_dtype)) 2176*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isBinary(t)) 2177*da0073e9SAndroid Build Coastguard Worker 2178*da0073e9SAndroid Build Coastguard Worker @slowTest 2179*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_types_and(torch.half)) 2180*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*floating_types_and(torch.half)) 2181*da0073e9SAndroid Build Coastguard Worker def 
test_bernoulli_edge_cases(self, device, dtype): 2182*da0073e9SAndroid Build Coastguard Worker # Need to draw a lot of samples to cover every random floating point number. 2183*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(10000, 10000, dtype=dtype, device=device) # probability of drawing "1" is 0 2184*da0073e9SAndroid Build Coastguard Worker num_ones = (torch.bernoulli(a) == 1).sum() 2185*da0073e9SAndroid Build Coastguard Worker self.assertEqual(num_ones, 0) 2186*da0073e9SAndroid Build Coastguard Worker 2187*da0073e9SAndroid Build Coastguard Worker b = torch.ones(10000, 10000, dtype=dtype, device=device) # probability of drawing "1" is 1 2188*da0073e9SAndroid Build Coastguard Worker num_zeros = (torch.bernoulli(b) == 0).sum() 2189*da0073e9SAndroid Build Coastguard Worker self.assertEqual(num_zeros, 0) 2190*da0073e9SAndroid Build Coastguard Worker 2191*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_types_and(torch.half, torch.bfloat16)) 2192*da0073e9SAndroid Build Coastguard Worker @skipIfMps 2193*da0073e9SAndroid Build Coastguard Worker def test_exponential(self, device, dtype): 2194*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([10], dtype=dtype, device=device).exponential_(0.5) 2195*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.dtype, dtype) 2196*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.size(), torch.Size([1])) 2197*da0073e9SAndroid Build Coastguard Worker 2198*da0073e9SAndroid Build Coastguard Worker # Tests extremal behavior 2199*da0073e9SAndroid Build Coastguard Worker t = torch.empty((1,), device=device, dtype=dtype).exponential_(float('inf')) 2200*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.item() == 0) 2201*da0073e9SAndroid Build Coastguard Worker 2202*da0073e9SAndroid Build Coastguard Worker # Tests that negative lambda fails 2203*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 2204*da0073e9SAndroid Build Coastguard Worker torch.empty((1,), 
device=device, dtype=dtype).exponential_(-0.5) 2205*da0073e9SAndroid Build Coastguard Worker 2206*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 2207*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half, torch.float) 2208*da0073e9SAndroid Build Coastguard Worker def test_exponential_no_zero(self, device, dtype): 2209*da0073e9SAndroid Build Coastguard Worker # naively, 0 in exponential can be generated with probability 2^-24 2210*da0073e9SAndroid Build Coastguard Worker # so we need more samples to check if it's not generated 2211*da0073e9SAndroid Build Coastguard Worker # instead of doing one 2212*da0073e9SAndroid Build Coastguard Worker # don't test CPU, that would be a long test 2213*da0073e9SAndroid Build Coastguard Worker x = torch.empty(50000000, device=device, dtype=dtype).exponential_() 2214*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.min() > 0) 2215*da0073e9SAndroid Build Coastguard Worker 2216*da0073e9SAndroid Build Coastguard Worker def _generate_correlation_tensors(self, device, dtype): 2217*da0073e9SAndroid Build Coastguard Worker yield make_tensor((0, 0), dtype=dtype, device=device) 2218*da0073e9SAndroid Build Coastguard Worker yield make_tensor((1, 0), dtype=dtype, device=device) 2219*da0073e9SAndroid Build Coastguard Worker yield make_tensor((0, 1), dtype=dtype, device=device) 2220*da0073e9SAndroid Build Coastguard Worker yield make_tensor((2,), dtype=dtype, device=device) 2221*da0073e9SAndroid Build Coastguard Worker yield make_tensor((2, 1), dtype=dtype, device=device) 2222*da0073e9SAndroid Build Coastguard Worker yield make_tensor((2, 2), dtype=dtype, device=device) 2223*da0073e9SAndroid Build Coastguard Worker yield make_tensor((2, 3), dtype=dtype, device=device) 2224*da0073e9SAndroid Build Coastguard Worker yield make_tensor((5, 10), dtype=dtype, device=device) 2225*da0073e9SAndroid Build Coastguard Worker yield make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True) 2226*da0073e9SAndroid Build Coastguard 
Worker if dtype != torch.int: 2227*da0073e9SAndroid Build Coastguard Worker yield torch.tensor([0, -2, nan, 10.2, inf], dtype=dtype, device=device) 2228*da0073e9SAndroid Build Coastguard Worker 2229*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 2230*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int, torch.float, torch.cfloat) 2231*da0073e9SAndroid Build Coastguard Worker def test_corrcoef(self, device, dtype): 2232*da0073e9SAndroid Build Coastguard Worker for x in self._generate_correlation_tensors(device, dtype): 2233*da0073e9SAndroid Build Coastguard Worker res = torch.corrcoef(x) 2234*da0073e9SAndroid Build Coastguard Worker ref = np.corrcoef(x.cpu().numpy()) 2235*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, ref, exact_dtype=False) 2236*da0073e9SAndroid Build Coastguard Worker 2237*da0073e9SAndroid Build Coastguard Worker @skipRocmIfTorchInductor 2238*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int, torch.float, torch.cfloat) 2239*da0073e9SAndroid Build Coastguard Worker def test_cov(self, device, dtype): 2240*da0073e9SAndroid Build Coastguard Worker def check(t, correction=1, fweights=None, aweights=None): 2241*da0073e9SAndroid Build Coastguard Worker res = torch.cov(t, correction=correction, fweights=fweights, aweights=aweights) 2242*da0073e9SAndroid Build Coastguard Worker t = t.cpu().numpy() 2243*da0073e9SAndroid Build Coastguard Worker fweights = fweights.cpu().numpy() if fweights is not None else None 2244*da0073e9SAndroid Build Coastguard Worker aweights = aweights.cpu().numpy() if aweights is not None else None 2245*da0073e9SAndroid Build Coastguard Worker ref = np.cov(t, ddof=correction, fweights=fweights, aweights=aweights) 2246*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, ref, atol=1e-05, rtol=1e-05, exact_dtype=False) 2247*da0073e9SAndroid Build Coastguard Worker 2248*da0073e9SAndroid Build Coastguard Worker for x in self._generate_correlation_tensors(device, dtype): 
2249*da0073e9SAndroid Build Coastguard Worker check(x) 2250*da0073e9SAndroid Build Coastguard Worker num_observations = x.numel() if x.ndim < 2 else x.size(1) 2251*da0073e9SAndroid Build Coastguard Worker if num_observations > 0: 2252*da0073e9SAndroid Build Coastguard Worker fweights = torch.randint(1, 10, (num_observations,), device=device) 2253*da0073e9SAndroid Build Coastguard Worker aweights = make_tensor((num_observations,), dtype=torch.float, device=device, low=1) 2254*da0073e9SAndroid Build Coastguard Worker for correction, fw, aw in product([0, 1, 2], [None, fweights], [None, aweights]): 2255*da0073e9SAndroid Build Coastguard Worker check(x, correction, fweights, aweights) 2256*da0073e9SAndroid Build Coastguard Worker 2257*da0073e9SAndroid Build Coastguard Worker @skipIfNoSciPy 2258*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_types_and(torch.half, torch.bfloat16)) 2259*da0073e9SAndroid Build Coastguard Worker def test_uniform_kstest(self, device, dtype): 2260*da0073e9SAndroid Build Coastguard Worker from scipy import stats 2261*da0073e9SAndroid Build Coastguard Worker size = 1000 2262*da0073e9SAndroid Build Coastguard Worker for from_ in [-42, 0, 4.2]: 2263*da0073e9SAndroid Build Coastguard Worker for to_ in [-4.2, 0, 42]: 2264*da0073e9SAndroid Build Coastguard Worker if to_ > from_: 2265*da0073e9SAndroid Build Coastguard Worker t = torch.empty(size, dtype=dtype, device=device).uniform_(from_, to_) 2266*da0073e9SAndroid Build Coastguard Worker res = stats.kstest(t.cpu().to(torch.double), 'uniform', args=(from_, (to_ - from_))) 2267*da0073e9SAndroid Build Coastguard Worker self.assertTrue(res.statistic < 0.1) 2268*da0073e9SAndroid Build Coastguard Worker 2269*da0073e9SAndroid Build Coastguard Worker @skipIfNoSciPy 2270*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_types_and(torch.half)) 2271*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16)) 2272*da0073e9SAndroid Build 
def test_normal_kstest(self, device, dtype):
    """Kolmogorov-Smirnov check of ``normal_`` samples against scipy's normal distribution."""
    from scipy import stats
    n_samples = 1000
    # Outer loop over means, inner loop over standard deviations.
    for mu, sigma in product([-10, 0, 50], [1, 5, 10]):
        samples = torch.empty(n_samples, dtype=dtype, device=device).normal_(mean=mu, std=sigma)
        as_double = samples.cpu().to(torch.double)
        ks = stats.kstest(as_double, 'norm', args=(mu, sigma))
        self.assertTrue(ks.statistic < 0.1)

@skipIfMps
@skipIfNoSciPy
@skipRocmIfTorchInductor
@dtypes(*floating_types_and(torch.half, torch.bfloat16))
def test_lognormal_kstest(self, device, dtype):
    """Kolmogorov-Smirnov check of ``log_normal_`` samples against scipy's lognorm."""
    from scipy import stats
    n_samples = 1000
    # Half precision gets a looser bound on the KS statistic.
    bound = 0.3 if dtype == torch.half else 0.1
    for mu, sigma in product([-3, 0, 7], [1, 5, 7]):
        samples = torch.empty(n_samples, dtype=dtype, device=device).log_normal_(mean=mu, std=sigma)
        as_double = samples.cpu().to(torch.double)
        # scipy parameterizes lognorm as (shape=sigma, loc=0, scale=exp(mu)).
        ks = stats.kstest(as_double, 'lognorm', args=(sigma, 0, math.exp(mu)))
        self.assertTrue(ks.statistic < bound)
@skipIfMps 2298*da0073e9SAndroid Build Coastguard Worker @skipIfNoSciPy 2299*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_types_and(torch.half, torch.bfloat16)) 2300*da0073e9SAndroid Build Coastguard Worker def test_exponential_kstest(self, device, dtype): 2301*da0073e9SAndroid Build Coastguard Worker from scipy import stats 2302*da0073e9SAndroid Build Coastguard Worker size = 1000 2303*da0073e9SAndroid Build Coastguard Worker for lambd in [0.5, 1.0, 5.0]: 2304*da0073e9SAndroid Build Coastguard Worker t = torch.empty(size, dtype=dtype, device=device).exponential_(lambd=lambd) 2305*da0073e9SAndroid Build Coastguard Worker res = stats.kstest(t.cpu().to(torch.double), 'expon', args=(0, 1 / lambd,)) 2306*da0073e9SAndroid Build Coastguard Worker self.assertTrue(res.statistic < 0.1) 2307*da0073e9SAndroid Build Coastguard Worker 2308*da0073e9SAndroid Build Coastguard Worker @skipIfMps 2309*da0073e9SAndroid Build Coastguard Worker @skipIfNoSciPy 2310*da0073e9SAndroid Build Coastguard Worker @skipRocmIfTorchInductor 2311*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_types_and(torch.half, torch.bfloat16)) 2312*da0073e9SAndroid Build Coastguard Worker def test_cauchy_kstest(self, device, dtype): 2313*da0073e9SAndroid Build Coastguard Worker from scipy import stats 2314*da0073e9SAndroid Build Coastguard Worker size = 1000 2315*da0073e9SAndroid Build Coastguard Worker for median in [-10, 0, 50]: 2316*da0073e9SAndroid Build Coastguard Worker for sigma in [0.5, 1.0, 10.0]: 2317*da0073e9SAndroid Build Coastguard Worker t = torch.empty(size, dtype=dtype, device=device).cauchy_(median=median, sigma=sigma) 2318*da0073e9SAndroid Build Coastguard Worker res = stats.kstest(t.cpu().to(torch.double), 'cauchy', args=(median, sigma)) 2319*da0073e9SAndroid Build Coastguard Worker self.assertTrue(res.statistic < 0.1) 2320*da0073e9SAndroid Build Coastguard Worker 2321*da0073e9SAndroid Build Coastguard Worker @slowTest 2322*da0073e9SAndroid Build Coastguard Worker 
@onlyCUDA 2323*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.bfloat16, torch.float32) 2324*da0073e9SAndroid Build Coastguard Worker def test_cauchy_no_inf(self, device, dtype): 2325*da0073e9SAndroid Build Coastguard Worker # torch.float16 will have `inf` because of its smaller range. 2326*da0073e9SAndroid Build Coastguard Worker for _ in range((2**16) * 2): 2327*da0073e9SAndroid Build Coastguard Worker x = torch.empty((2**16), dtype=dtype, device=device) 2328*da0073e9SAndroid Build Coastguard Worker x.cauchy_() 2329*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x.isinf().sum()) 2330*da0073e9SAndroid Build Coastguard Worker 2331*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_types_and(torch.half, torch.bfloat16)) 2332*da0073e9SAndroid Build Coastguard Worker def test_cauchy(self, device, dtype): 2333*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([10], dtype=dtype, device=device).cauchy_(0.0, 0.5) 2334*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.dtype, dtype) 2335*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.size(), torch.Size([1])) 2336*da0073e9SAndroid Build Coastguard Worker 2337*da0073e9SAndroid Build Coastguard Worker # Tests extremal behavior 2338*da0073e9SAndroid Build Coastguard Worker t = torch.empty((1,), device=device, dtype=dtype).cauchy_(float('inf'), 0.5) 2339*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.item() == float('inf')) 2340*da0073e9SAndroid Build Coastguard Worker 2341*da0073e9SAndroid Build Coastguard Worker # Tests non-positive rate fails 2342*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 2343*da0073e9SAndroid Build Coastguard Worker torch.empty((1,), device=device, dtype=dtype).cauchy_(0.0, 0.0) 2344*da0073e9SAndroid Build Coastguard Worker 2345*da0073e9SAndroid Build Coastguard Worker @skipIfMps 2346*da0073e9SAndroid Build Coastguard Worker @skipIfNoSciPy 2347*da0073e9SAndroid Build Coastguard Worker 
@skipRocmIfTorchInductor
@dtypes(*all_types_and(torch.half, torch.bfloat16))
def test_geometric_kstest(self, device, dtype):
    """Chi-square goodness of fit of ``geometric_`` samples against scipy's geom pmf."""
    from scipy import stats
    size = 1000
    for p in [0.2, 0.5, 0.8]:
        t = torch.empty(size, dtype=dtype, device=device).geometric_(p=p)
        # 98 unit-width bins starting at 1, matching pmf evaluated at 1..98 below.
        actual = np.histogram(t.cpu().to(torch.double), np.arange(1, 100))[0]
        expected = stats.geom(p).pmf(np.arange(1, 99)) * size
        res = stats.chisquare(actual, expected)
        self.assertEqual(res.pvalue, 1.0, atol=0.1, rtol=0)

# FIXME: find test suite for pdist and cdist
def test_pairwise_distance_empty(self, device):
    """pairwise_distance on empty inputs returns zeros of the expected shape."""
    # Zero-width rows: each distance is over an empty vector, hence 0.
    shape = (2, 0)
    x = torch.randn(shape, device=device)
    y = torch.randn(shape, device=device)

    self.assertEqual(torch.zeros(2, device=device), torch.pairwise_distance(x, y))
    self.assertEqual(torch.zeros((2, 1), device=device), torch.pairwise_distance(x, y, keepdim=True))

    # No rows at all: empty result.
    shape = (0, 2)
    x = torch.randn(shape, device=device)
    y = torch.randn(shape, device=device)
    self.assertEqual(torch.zeros(0, device=device), torch.pairwise_distance(x, y))
    self.assertEqual(torch.zeros((0, 1), device=device), torch.pairwise_distance(x, y, keepdim=True))

def test_pdist_empty(self, device):
    """pdist on inputs with fewer than two rows or zero columns."""
    shape = (0, 2)
    x = torch.randn(shape, device=device)
    self.assertEqual(torch.empty(0, device=device), torch.pdist(x))

    shape = (1, 2)
    x = torch.randn(shape, device=device)
    self.assertEqual(torch.empty(0, device=device), torch.pdist(x))

    # Three zero-width rows: 3 pairs, each at distance 0.
    shape = (3, 0)
    x = torch.randn(shape, device=device)
    self.assertEqual(torch.zeros(3, device=device), torch.pdist(x))

def test_cdist_empty(self, device):
    """cdist with empty row and/or column dimensions."""
    x = torch.randn((0, 5), device=device)
    y = torch.randn((4, 5), device=device)
    self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y))

    x = torch.randn((2, 5), device=device)
    y = torch.randn((0, 5), device=device)
    self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))

    # Zero-width rows: every pairwise distance is 0.
    x = torch.randn((2, 0), device=device)
    y = torch.randn((3, 0), device=device)
    self.assertEqual(torch.zeros(2, 3, device=device), torch.cdist(x, y))

    x = torch.randn((2, 0), device=device)
    y = torch.randn((0, 0), device=device)
    self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))

def _brute_cdist(self, x, y, p=2):
    """Reference p-norm distance between the rows of ``x`` and ``y``.

    Broadcasts ``x[..., i, :] - y[..., j, :]`` and takes the p-norm over the
    last dim. If either operand has no rows, it returns an empty ``(r1, r2)``
    tensor on ``x``'s device; batch dims are not preserved in that case.
    """
    r1 = x.shape[-2]
    r2 = y.shape[-2]
    if r1 == 0 or r2 == 0:
        return torch.empty(r1, r2, device=x.device)
    return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)

@skipIfMps
def test_cdist_norm(self, device):
    """cdist matches the brute-force reference for several p and 2-D shapes.

    p=2 is checked in both compute modes with an absolute tolerance of 0.02.
    """
    for r1 in [3, 4, 5, 6]:
        for m in [2, 3, 4, 10]:
            for r2 in [4, 6, 7, 8]:
                for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
                    x = torch.randn(r1, m, device=device)
                    y = torch.randn(r2, m, device=device)
                    if p == 2:
                        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
                            actual = torch.cdist(x, y, p=2, compute_mode=cm)
                            expected = self._brute_cdist(x, y, p=2)
                            self.assertEqual(expected, actual, rtol=0, atol=0.02)
                    else:
                        actual = torch.cdist(x, y, p=p)
                        expected = self._brute_cdist(x, y, p=p)
                        self.assertEqual(expected, actual)

@skipIfMps
def test_cdist_norm_batch(self, device):
    """Batched version of test_cdist_norm with leading batch dims (2, 3, 6)."""
    for r1 in [3, 4, 5, 6]:
        for m in [2, 3, 4, 10]:
            for r2 in [4, 6, 7, 8]:
                for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
                    x = torch.randn(2, 3, 6, r1, m, device=device)
                    y = torch.randn(2, 3, 6, r2, m, device=device)
                    if p == 2:
                        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
                            actual = torch.cdist(x, y, p=2, compute_mode=cm)
                            expected = self._brute_cdist(x, y, p=2)
                            self.assertEqual(expected, actual, rtol=0, atol=0.02)
                    else:
                        actual = torch.cdist(x, y, p=p)
                        expected = self._brute_cdist(x, y, p=p)
                        self.assertEqual(expected, actual)

@onlyCUDA
def test_cdist_cuda_backward(self, device):
    """Gradients of ``cdist(...).mean()`` match the brute-force reference on CUDA."""
    def grad_or_zeros(t):
        # A leaf's .grad stays None if no gradient reaches it. The p=0 reference
        # norm is piecewise constant and may deliver no gradient, so treat None as zeros.
        return t.grad if t.grad is not None else torch.zeros_like(t)

    for l1 in [1, 511, 513]:
        for l2 in [1, 511, 513]:
            for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
                x1 = torch.randn(4, l1, 32, device=device, requires_grad=True)
                x2 = x1.clone().detach_().requires_grad_()
                y1 = torch.randn(4, l2, 32, device=device, requires_grad=True)
                y2 = y1.clone().detach_().requires_grad_()
                if p == 2:
                    for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
                        z1 = torch.cdist(x1, y1, p=2, compute_mode=cm).mean()
                        z2 = self._brute_cdist(x2, y2, p=2).mean()
                        z1.backward()
                        z2.backward()
                        # Grads accumulate across compute modes identically on both sides.
                        self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001)
                        self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)
                else:
                    z1 = torch.cdist(x1, y1, p=p).mean()
                    z2 = self._brute_cdist(x2, y2, p=p).mean()
                    # backward() must run here; without it every .grad is None
                    # and the comparisons below would be vacuous.
                    z1.backward()
                    z2.backward()
                    self.assertEqual(grad_or_zeros(x1), grad_or_zeros(x2), rtol=0, atol=0.001)
                    self.assertEqual(grad_or_zeros(y1), grad_or_zeros(y2), rtol=0, atol=0.001)
@tf32_on_and_off(0.005)
@bf32_on_and_off(0.005)
def test_cdist_large(self, device):
    """Compare cdist (p=2) on 1000x10 inputs with the brute-force reference for every compute mode."""
    compute_modes = (
        'use_mm_for_euclid_dist_if_necessary',
        'use_mm_for_euclid_dist',
        'donot_use_mm_for_euclid_dist',
    )
    for mode in compute_modes:
        lhs = torch.randn(1000, 10, device=device)
        rhs = torch.randn(1000, 10, device=device)
        result = torch.cdist(lhs, rhs, p=2, compute_mode=mode)
        reference = self._brute_cdist(lhs, rhs, p=2)
        self.assertEqual(reference, result)

@slowTest
@tf32_on_and_off(0.01)
@bf32_on_and_off(0.01)
def test_cdist_large_batch(self, device):
    """Batched (4, 3, 1000, 10) variant of test_cdist_large, over every compute mode."""
    compute_modes = (
        'use_mm_for_euclid_dist_if_necessary',
        'use_mm_for_euclid_dist',
        'donot_use_mm_for_euclid_dist',
    )
    for mode in compute_modes:
        lhs = torch.randn(4, 3, 1000, 10, device=device)
        rhs = torch.randn(4, 3, 1000, 10, device=device)
        result = torch.cdist(lhs, rhs, p=2, compute_mode=mode)
        reference = self._brute_cdist(lhs, rhs, p=2)
        self.assertEqual(reference, result)

@tf32_on_and_off(0.005)
@bf32_on_and_off(0.005)
def test_cdist_non_contiguous(self, device):
    """Check cdist (p=2) when either or both 2-D inputs are transposed (non-contiguous) views."""
    # Each case: (build x, build y, expected x.is_contiguous(), expected y.is_contiguous()).
    # The builders are invoked in order (x first, then y) so random draws match case by case.
    cases = (
        (lambda: torch.randn(5, 7, device=device).mT,
         lambda: torch.randn(5, 3, device=device).mT,
         False, False),
        (lambda: torch.randn(7, 5, device=device),
         lambda: torch.randn(5, 3, device=device).t(),
         True, False),
        (lambda: torch.randn(5, 7, device=device).t(),
         lambda: torch.randn(3, 5, device=device),
         False, True),
    )
    for mode in ('use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist'):
        for build_x, build_y, x_contig, y_contig in cases:
            lhs = build_x()
            rhs = build_y()
            result = torch.cdist(lhs, rhs, p=2, compute_mode=mode)
            reference = self._brute_cdist(lhs, rhs, p=2)
            self.assertEqual(lhs.is_contiguous(), x_contig)
            self.assertEqual(rhs.is_contiguous(), y_contig)
            self.assertEqual(reference, result)

@tf32_on_and_off(0.005)
@bf32_on_and_off(0.005)
def test_cdist_non_contiguous_batch(self, device):
    """Batched variant of test_cdist_non_contiguous using .mT views on 3-D to 5-D inputs."""
    # Each case: (build x, build y, expected x.is_contiguous(), expected y.is_contiguous()).
    cases = (
        (lambda: torch.randn(4, 3, 2, 5, 7, device=device).mT,
         lambda: torch.randn(4, 3, 2, 5, 3, device=device).mT,
         False, False),
        (lambda: torch.randn(7, 2, 7, 5, device=device),
         lambda: torch.randn(7, 2, 5, 3, device=device).mT,
         True, False),
        (lambda: torch.randn(4, 5, 7, device=device).mT,
         lambda: torch.randn(4, 3, 5, device=device),
         False, True),
    )
    for mode in ('use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist'):
        for build_x, build_y, x_contig, y_contig in cases:
            lhs = build_x()
            rhs = build_y()
            result = torch.cdist(lhs, rhs, p=2, compute_mode=mode)
            reference = self._brute_cdist(lhs, rhs, p=2)
            self.assertEqual(lhs.is_contiguous(), x_contig)
            self.assertEqual(rhs.is_contiguous(), y_contig)
            self.assertEqual(reference, result)

# Maybe merge into OpInfo?
def test_cdist_euclidean_large(self, device):
    """Smoke-test that cdist (p=2) forward and backward run on 2000x5 inputs."""
    shape = (2000, 5)
    lhs = torch.randn(shape, device=device, dtype=torch.float)
    rhs = torch.randn(shape, device=device, dtype=torch.float)
    eps = 1e-6
    # Shift lhs down by 2*eps wherever lhs - rhs < eps, to avoid extremum.
    lhs = lhs - ((lhs - rhs) < eps).float() * 2 * eps
    lhs.requires_grad_(True)
    rhs.requires_grad_(True)
    # Only checks that the backward pass runs for large matrices; no value assertions.
    torch.cdist(lhs, rhs, p=2).sum().backward()

# Ensure that cdist backward with p<1 does not produce NaNs
@skipIfMps
def test_cdist_grad_p_lt_1_no_nan(self, device):
    """cdist backward with fractional 0 < p < 1 must not produce NaN gradients."""
    for p in [0.99, 0.7, 0.5, 0.1, 0.01]:
        x = torch.randn(1, 2, device=device)
        # Only the first coordinate is shifted, so the second coordinate's difference is
        # exactly zero, where |diff| ** (p - 1) is unbounded for p < 1.
        y = x.clone().detach() + torch.tensor([[1., 0.]], device=device)
        x.requires_grad = True
        y.requires_grad = True
        result = torch.cdist(x, y, p=p)
        result.backward(torch.ones_like(result))
        self.assertFalse(torch.isnan(x.grad).any())
        self.assertFalse(torch.isnan(y.grad).any())

def test_cdist_same_inputs(self, device):
    """Check that cdist gradients stay finite when distances are exactly 0.

    y is an exact copy of x, so every x[b, i] vs. y[b, i] distance is zero.
    Each p in the loop is actually passed to cdist. (Previously the loop
    variable was ignored and only the default p=2 was ever exercised.)
    """
    sizex = (1, 27, 32)
    p_values = [0, 1, 2, 3, 1.5, 2.5, float('inf')]
    if torch.device(device).type == 'mps':
        # NOTE(review): the sibling non-Euclidean cdist gradient test above is
        # @skipIfMps; assume only p=2 is supported on MPS here -- confirm.
        p_values = [2]
    for p in p_values:
        x = torch.randn(sizex, device=device, dtype=torch.float)
        dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
        y = x.clone()
        x.requires_grad = True
        d = torch.cdist(x, y, p=p)
        d.backward(dist_grad)
        # The backward pass must not contain invalid values such as nan or inf.
        self.assertTrue(torch.isfinite(x.grad).all(),
                        msg=f"non-finite cdist gradient for p={p}")

def _check_cumulative_op_backward(self, device, op):
    """Check that backward through a cumulative op runs and keeps the input's shape.

    ``op(tensor, dim)`` must return a single tensor. For ops returning
    (values, indices), pass a wrapper that selects the values. Covers
    reductions over zero-length dims, over other dims of tensors that have a
    zero-length dim, a few ordinary shapes, and a 0-dim scalar. All inputs
    are created on ``device``.
    """
    shapes = [[2, 0], [2, 1, 4], [0, 2, 3], [1], [5]]
    for shape in shapes:
        for dim in range(len(shape)):
            raw_tensor = torch.zeros(*shape, device=device, requires_grad=True)
            integrated = op(raw_tensor, dim)
            # Check that backward does not crash
            integrated.sum().backward()
            # Check that the gradient keeps the input's shape
            self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)

    # Check a scalar (0-dim) example: the op is the identity on it.
    raw_tensor = torch.tensor(3., device=device, requires_grad=True)
    integrated = op(raw_tensor, -1)
    self.assertEqual(raw_tensor, integrated)
    # Check that backward does not crash
    integrated.sum().backward()
    # Check that the gradient keeps the input's shape
    self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape)

@skipIfMps
def test_cumsum(self, device):
    """Check torch.cumsum: out= / in-place parity, bool vs. uint8 inputs, and backward on edge shapes."""
    x = torch.rand(100, 100, device=device)
    res1 = torch.cumsum(x, 1)
    res2 = torch.empty(0, device=device)
    torch.cumsum(x, 1, out=res2)
    self.assertEqual(res1, res2)
    x.cumsum_(1)
    self.assertEqual(res1, x)

    a = torch.tensor([[True, False, True],
                      [False, False, False],
                      [True, True, True]], device=device)
    b = a.byte()
    aRes = torch.cumsum(a, 0)
    bRes = torch.cumsum(b, 0)
    self.assertEqual(aRes, bRes)
    self.assertEqual(aRes, torch.tensor([[1, 0, 1],
                                         [1, 0, 1],
                                         [2, 1, 2]]))

    aRes = torch.cumsum(a, 1)
    bRes = torch.cumsum(b, 1)
    self.assertEqual(aRes, bRes)
    self.assertEqual(aRes, torch.tensor([[1, 1, 2],
                                         [0, 0, 0],
                                         [1, 2, 3]]))

    # Check that cumulative sum over a zero-length dimension doesn't crash on backprop,
    # that cumsum over other dimensions of a tensor with a zero-length dimension works,
    # and a basic suite of similar base cases -- all on the device under test.
    self._check_cumulative_op_backward(device, lambda tensor, dim: tensor.cumsum(dim=dim))

@skipIfMps
def test_cumprod(self, device):
    """Check torch.cumprod: out= / in-place parity, bool vs. uint8 inputs, and backward on edge shapes."""
    x = torch.rand(100, 100, device=device)
    res1 = torch.cumprod(x, 1)
    res2 = torch.empty(0, device=device)
    if not TEST_WITH_TORCHINDUCTOR:
        torch.cumprod(x, 1, out=res2)
        self.assertEqual(res1, res2)
    x.cumprod_(1)
    self.assertEqual(res1, x)

    a = torch.tensor([[True, False, True],
                      [False, False, False],
                      [True, True, True]], dtype=torch.bool, device=device)
    b = a.byte()
    aRes = torch.cumprod(a, 0)
    bRes = torch.cumprod(b, 0)
    self.assertEqual(aRes, bRes)
    self.assertEqual(aRes, torch.tensor([[1, 0, 1],
                                         [0, 0, 0],
                                         [0, 0, 0]]))

    aRes = torch.cumprod(a, 1)
    bRes = torch.cumprod(b, 1)
    self.assertEqual(aRes, bRes)
    self.assertEqual(aRes, torch.tensor([[1, 0, 0],
                                         [0, 0, 0],
                                         [1, 1, 1]]))

    # Check that cumulative product over a zero-length dimension doesn't crash on backprop,
    # that cumprod over other dimensions of a tensor with a zero-length dimension works,
    # and a basic suite of similar base cases -- all on the device under test.
    self._check_cumulative_op_backward(device, lambda tensor, dim: tensor.cumprod(dim=dim))

@skipIfMps
def test_cummax_cummin(self, device):
    """Check torch.cummax / torch.cummin values, out= handling, inf/nan propagation, and backward."""
    def test_ops(op, string_of_function_name, expected_output1, expected_output2):
        # out= variant must match the functional result (values and indices).
        x = torch.rand(100, 100, device=device)
        out1 = op(x, 1)
        res2 = torch.empty(0, device=device)
        indices2 = torch.empty(0, dtype=torch.int64, device=device)
        op(x, 1, out=(res2, indices2))
        self.assertEqual(out1[0], res2)
        self.assertEqual(out1[1], indices2)

        a = torch.tensor([[True, False, True],
                          [False, False, False],
                          [True, True, True]], dtype=torch.bool, device=device)
        b = a.byte()
        aRes = op(a, 0)
        bRes = op(b, 0)
        self.assertEqual(aRes[0], bRes[0].bool())
        self.assertEqual(aRes[0], expected_output1.bool())

        # test inf and nan input
        x = torch.tensor([4, inf, 1.5, -inf, 0, nan, 1], device=device)
        xRes = op(x, 0)[0]
        self.assertEqual(xRes, expected_output2)

        # op shouldn't support an out= values tensor whose dtype differs from the input's
        t = torch.randn(10, device=device)
        values = torch.empty(0, dtype=torch.int16, device=device)
        indices = torch.empty(0, dtype=torch.int64, device=device)
        with self.assertRaisesRegex(
                RuntimeError,
                'expected scalar_type Float but found Short'):
            op(t, 0, out=(values, indices))

        # Check that op over a zero-length dimension doesn't crash on backprop, that op over
        # other dimensions of a tensor with a zero-length dimension works, and a basic suite
        # of similar base cases. Only the values output (index 0) is differentiated.
        self._check_cumulative_op_backward(
            device,
            lambda tensor, dim: getattr(tensor, string_of_function_name)(dim=dim)[0])

    # nan propagates once seen; inf dominates cummax, -inf dominates cummin.
    expected_out = torch.tensor([4, inf, inf, inf, inf, nan, nan])
    test_ops(torch.cummax, "cummax", torch.tensor([[1, 0, 1],
                                                   [1, 0, 1],
                                                   [1, 1, 1]]), expected_out)

    expected_out = torch.tensor([4, 4, 1.5, -inf, -inf, nan, nan])
    test_ops(torch.cummin, "cummin", torch.tensor([[1, 0, 1],
                                                   [0, 0, 0],
                                                   [0, 0, 0]]), expected_out)

@skipIfMps
def test_logcumsumexp(self, device):
    """Compare torch.logcumsumexp with a naive log(cumsum(exp(a))) reference, incl. inf/nan and out=."""
    def logcumsumexp(a, axis):
        # Naive reference: log(cumsum(exp(a))) along `axis`.
        return torch.cumsum(a.exp(), axis=axis).log_()

    axis = -1
    a = torch.randn(100, 100, device=device)

    actual = a.logcumsumexp(axis)
    expected = logcumsumexp(a, axis)
    self.assertEqual(a.dtype, actual.dtype)
    self.assertEqual(expected.shape, actual.shape)
    self.assertEqual(expected, actual)

    # check -inf and nan handling (1-D, and a 2-D expanded view of the same data)
    x = torch.tensor([-float('inf'), -float('inf'), 1.0, 1.0, float('inf'),
                     float('inf'), float('nan'), 1.0, 1.0], device=device)
    x2d = x.unsqueeze(0).expand(2, -1)

    for inp in (x, x2d):
        actual = inp.logcumsumexp(axis)
        expected = logcumsumexp(inp, axis)
        self.assertEqual(expected, actual)

    # Check that out is actually inplace
    b = torch.randn(5, 2, device=device)
    inplace_out = torch.zeros(5, 2, device=device)

    expected = logcumsumexp(b, axis)
    torch.logcumsumexp(b, axis=axis, out=inplace_out)

    self.assertEqual(inplace_out, expected)

    # Check input and inplace_output type mismatch
    b = torch.randn(5, 2, device=device, dtype=torch.float64)
    inplace_out = torch.zeros(5, 2, device=device, dtype=torch.float32)
    with self.assertRaisesRegex(
            RuntimeError,
            'expected scalar_type Double but found Float'):
        torch.logcumsumexp(b, axis, out=inplace_out)

def _test_diff_numpy(self, t, dims=None):
    """Compare torch.diff on `t` with np.diff along each dim in `dims` (default: every dim of `t`).

    For each dim it checks every order n with no prepend/append. It also checks
    prepend/append equal to the first slice along dim (size 1), and prepend/append
    equal to `t` itself (size != 1 along dim).
    """
    # Helper for test_diff to compare with NumPy reference implementation
    def to_np(t):
        # NumPy has no bfloat16, so bfloat16 is upcast to float32 on the CPU first.
        if t.dtype == torch.bfloat16:
            return t.to(dtype=torch.float, device="cpu").numpy()
        else:
            return t.cpu().numpy()

    for dim in dims if dims else range(t.dim()):
        # Both prepend and append are the first size-1 slice of t along dim.
        prepend = t.narrow(dim, 0, 1)
        append = t.narrow(dim, 0, 1)
        np_t = to_np(t)

        # test when no prepend and append
        for n in range(t.size(dim)):
            actual = torch.diff(t, dim=dim, n=n)
            expected = torch.from_numpy(np.diff(np_t, axis=dim, n=n))
            self.assertEqual(actual, expected.to(t.dtype))

        # test when prepend and append's size along dim is 1
        for n in range(1, t.size(dim) + 4):
            actual = torch.diff(t, dim=dim, n=n, prepend=prepend, append=append)
            expected = torch.from_numpy(np.diff(np_t, axis=dim, n=n, prepend=to_np(prepend), append=to_np(append)))
            self.assertEqual(actual, expected.to(t.dtype))

        # test when prepend and append's size along dim != 1
        for n in range(1, t.size(dim) * 3):
            actual = torch.diff(t, dim=dim, n=n, prepend=t, append=t)
            expected = torch.from_numpy(np.diff(np_t, axis=dim, n=n, prepend=np_t, append=np_t))
            self.assertEqual(actual, expected.to(t.dtype))
# All tensors appear contiguous on XLA
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
def test_diff_noncontig(self, device, dtype):
    """Run the NumPy diff comparison on non-contiguous views of several shapes."""
    shapes = (
        (1,),
        (1, 5),
        (3, 5),
        (1, 5, 1),
        (2, 3, 5))

    for shape in shapes:
        contig = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)

        # Carve a strided view of `shape` out of a larger buffer, then fill it from `contig`.
        non_contig = torch.empty(shape + (2, 2), device=device, dtype=dtype)[..., 0]
        non_contig = non_contig.select(-1, -1)
        non_contig.copy_(contig)
        # (1,) is exempt: a one-element view is still reported as contiguous.
        self.assertTrue(not non_contig.is_contiguous() or shape == (1,))

        self._test_diff_numpy(non_contig)

# RngNormal not implemented for type f16 for XLA
@dtypes(*all_types_and_complex_and(torch.bool))
@dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool))
@dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool))
def test_diff(self, device, dtype):
    """Compare torch.diff with NumPy on contiguous inputs and check its argument-validation errors."""
    shapes = (
        (1,),
        (1, 5),
        (3, 5),
        (1, 5, 1),
        (2, 3, 5))

    for shape in shapes:
        contig = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
        self._test_diff_numpy(contig)

    # Use the device/dtype under test so the error paths run on the same backend
    # as the invalid prepend tensors below.
    t = torch.ones(2, 3, device=device, dtype=dtype)

    with self.assertRaisesRegex(
            RuntimeError, 'diff expects prepend or append to be the same dimension as input'):
        invalid_prepend = torch.tensor([1, 2, 3], device=device, dtype=dtype)
        t.diff(dim=0, prepend=invalid_prepend)

    with self.assertRaisesRegex(
            RuntimeError, 'diff expects the shape of tensor to prepend or append to match that of input'):
        invalid_prepend = torch.tensor([[0, 1]], device=device, dtype=dtype)
        t.diff(dim=0, prepend=invalid_prepend)

    with self.assertRaisesRegex(
            RuntimeError, 'diff expects input to be at least one-dimensional'):
        scalar = torch.tensor(2, device=device, dtype=dtype)
        torch.diff(scalar)

# Return input_array unchanged if it is a list; otherwise wrap it as [input_array].
def _wrap_to_list(self, input_array):
    return input_array if isinstance(input_array, list) else [input_array]

# Ensure inf, -inf, and nan values do not cause divergence between NumPy and PyTorch.
# There are two possible kinds of divergence:
# 1. When a and b are both real numbers with very small absolute values (i.e. very
#    close to 0.0), a/b can be inf, -inf or nan, and the two libraries may disagree.
# 2. When dividing complex numbers by zero. For example, with a = torch.tensor(3+5j),
#    a/0 is nan + nan*j in PyTorch but inf + inf*j in NumPy.
def _inf_nan_preprocess(self, actual, expected):
    for i in range(len(expected)):
        expected[i] = np.nan_to_num(expected[i], nan=nan, posinf=nan, neginf=nan)
        # nan_to_num is not defined for complex tensors in PyTorch.
2906*da0073e9SAndroid Build Coastguard Worker if actual[i].dtype == torch.complex64 : 2907*da0073e9SAndroid Build Coastguard Worker actual[i].real = torch.nan_to_num(actual[i].real, nan=nan, posinf=nan, neginf=nan) 2908*da0073e9SAndroid Build Coastguard Worker actual[i].imag = torch.nan_to_num(actual[i].imag, nan=nan, posinf=nan, neginf=nan) 2909*da0073e9SAndroid Build Coastguard Worker else: 2910*da0073e9SAndroid Build Coastguard Worker actual[i] = torch.nan_to_num(actual[i], nan=nan, posinf=nan, neginf=nan) 2911*da0073e9SAndroid Build Coastguard Worker 2912*da0073e9SAndroid Build Coastguard Worker return actual, expected 2913*da0073e9SAndroid Build Coastguard Worker 2914*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 2915*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.long, torch.float32, torch.complex64) 2916*da0073e9SAndroid Build Coastguard Worker def test_gradient_all(self, device, dtype): 2917*da0073e9SAndroid Build Coastguard Worker def create_scalar(shape): 2918*da0073e9SAndroid Build Coastguard Worker return make_tensor((1,), device='cpu', dtype=dtype, low=1.).item() 2919*da0073e9SAndroid Build Coastguard Worker 2920*da0073e9SAndroid Build Coastguard Worker def create_list(shape): 2921*da0073e9SAndroid Build Coastguard Worker return make_tensor((len(shape),), device='cpu', dtype=dtype, low=1.).tolist() 2922*da0073e9SAndroid Build Coastguard Worker 2923*da0073e9SAndroid Build Coastguard Worker def create_coordinate_tensors(shape): 2924*da0073e9SAndroid Build Coastguard Worker tensor_list = [] 2925*da0073e9SAndroid Build Coastguard Worker for i in range(len(shape)): 2926*da0073e9SAndroid Build Coastguard Worker tensor_list.append(make_tensor((shape[i],), device=device, dtype=dtype)) 2927*da0073e9SAndroid Build Coastguard Worker return tensor_list 2928*da0073e9SAndroid Build Coastguard Worker 2929*da0073e9SAndroid Build Coastguard Worker def filter_shape(shape, dim): 2930*da0073e9SAndroid Build Coastguard Worker filtered_shape = [] 
2931*da0073e9SAndroid Build Coastguard Worker for i in range(len(dim)): 2932*da0073e9SAndroid Build Coastguard Worker filtered_shape.append(shape[dim[i]]) 2933*da0073e9SAndroid Build Coastguard Worker return filtered_shape 2934*da0073e9SAndroid Build Coastguard Worker 2935*da0073e9SAndroid Build Coastguard Worker # shape, dims format 2936*da0073e9SAndroid Build Coastguard Worker test_cases = ( 2937*da0073e9SAndroid Build Coastguard Worker ((5,), (0,)), 2938*da0073e9SAndroid Build Coastguard Worker ((4, 4), (0, 1)), 2939*da0073e9SAndroid Build Coastguard Worker ((3, 3, 3), (-1, 0)), 2940*da0073e9SAndroid Build Coastguard Worker ((4, 4, 4), (2,)), 2941*da0073e9SAndroid Build Coastguard Worker ((4, 4, 4), (0, 1)), 2942*da0073e9SAndroid Build Coastguard Worker ((4, 4, 4, 3), (0, 2, 3)), 2943*da0073e9SAndroid Build Coastguard Worker ((4, 5, 3, 4, 3), (1, 2)), 2944*da0073e9SAndroid Build Coastguard Worker ((4, 3, 6, 5, 3), (2, 4)), 2945*da0073e9SAndroid Build Coastguard Worker ((4, 3, 3, 5, 3), (0, 1, 2, 3, 4)), 2946*da0073e9SAndroid Build Coastguard Worker ((1, 3, 3), (1, 2)), 2947*da0073e9SAndroid Build Coastguard Worker ((1, 5), (1,)), 2948*da0073e9SAndroid Build Coastguard Worker ) 2949*da0073e9SAndroid Build Coastguard Worker 2950*da0073e9SAndroid Build Coastguard Worker for case, contig, edge_order, space_fn in product(test_cases, [True, False], [1, 2], 2951*da0073e9SAndroid Build Coastguard Worker (create_scalar, create_list, create_coordinate_tensors)): 2952*da0073e9SAndroid Build Coastguard Worker shape, dims = case 2953*da0073e9SAndroid Build Coastguard Worker # filter shape by dims before passing filtered shape to create_* functions 2954*da0073e9SAndroid Build Coastguard Worker filtered_shape = filter_shape(shape, dims) 2955*da0073e9SAndroid Build Coastguard Worker 2956*da0073e9SAndroid Build Coastguard Worker spacing = space_fn(filtered_shape) 2957*da0073e9SAndroid Build Coastguard Worker t = make_tensor(shape, device=device, dtype=dtype, noncontiguous=not 
contig) 2958*da0073e9SAndroid Build Coastguard Worker t_np = t.cpu().numpy() 2959*da0073e9SAndroid Build Coastguard Worker 2960*da0073e9SAndroid Build Coastguard Worker actual = torch.gradient(t, spacing=spacing, dim=dims, edge_order=edge_order) 2961*da0073e9SAndroid Build Coastguard Worker if space_fn == create_coordinate_tensors and spacing[0].device != 'cpu': 2962*da0073e9SAndroid Build Coastguard Worker spacing = [space.cpu().detach().numpy() for space in spacing] 2963*da0073e9SAndroid Build Coastguard Worker expected = np.gradient(t_np, *self._wrap_to_list(spacing), axis=dims, edge_order=edge_order) 2964*da0073e9SAndroid Build Coastguard Worker actual, expected = self._inf_nan_preprocess(list(actual), self._wrap_to_list(expected)) 2965*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, equal_nan=True, atol=1e-4, rtol=0, exact_dtype=False) 2966*da0073e9SAndroid Build Coastguard Worker 2967*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 2968*da0073e9SAndroid Build Coastguard Worker @slowTestIf(TEST_WITH_TORCHINDUCTOR) 2969*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.long, torch.float32, torch.complex64) 2970*da0073e9SAndroid Build Coastguard Worker def test_gradient_extreme_cases(self, device, dtype): 2971*da0073e9SAndroid Build Coastguard Worker # Test behaviour for inf and nan values 2972*da0073e9SAndroid Build Coastguard Worker actual = torch.gradient(torch.tensor([2, -2, inf, inf, -inf, -inf, inf, 3, -inf, 2, nan, nan, 3, inf, nan])) 2973*da0073e9SAndroid Build Coastguard Worker expected = np.gradient(np.array([2, -2, inf, inf, -inf, -inf, inf, 3, -inf, 2, nan, nan, 3, inf, nan])) 2974*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, self._wrap_to_list(expected), exact_dtype=False) 2975*da0073e9SAndroid Build Coastguard Worker 2976*da0073e9SAndroid Build Coastguard Worker # Test behaviour in very big tensors 2977*da0073e9SAndroid Build Coastguard Worker large_size = 100000 
2978*da0073e9SAndroid Build Coastguard Worker t = make_tensor((large_size,), dtype=dtype, device=device) 2979*da0073e9SAndroid Build Coastguard Worker t_np = t.cpu().numpy() 2980*da0073e9SAndroid Build Coastguard Worker coordinates_np = np.random.randn(large_size) 2981*da0073e9SAndroid Build Coastguard Worker coordinates = [torch.tensor(coordinates_np, device=device)] 2982*da0073e9SAndroid Build Coastguard Worker actual = torch.gradient(t, spacing=coordinates, dim=0, edge_order=1) 2983*da0073e9SAndroid Build Coastguard Worker expected = [np.gradient(t_np, coordinates_np, axis=0, edge_order=1)] 2984*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, exact_dtype=False) 2985*da0073e9SAndroid Build Coastguard Worker 2986*da0073e9SAndroid Build Coastguard Worker actual = torch.gradient(t, spacing=coordinates, dim=0, edge_order=2) 2987*da0073e9SAndroid Build Coastguard Worker expected = [np.gradient(t_np, coordinates_np, axis=0, edge_order=2)] 2988*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, exact_dtype=False) 2989*da0073e9SAndroid Build Coastguard Worker 2990*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 2991*da0073e9SAndroid Build Coastguard Worker def test_gradient_type_promotion(self, device): 2992*da0073e9SAndroid Build Coastguard Worker inputs = ( 2993*da0073e9SAndroid Build Coastguard Worker make_tensor((4, 4), device=device, dtype=torch.float32), 2994*da0073e9SAndroid Build Coastguard Worker make_tensor((4, 4), device=device, dtype=torch.complex64), 2995*da0073e9SAndroid Build Coastguard Worker make_tensor((4, 4), device=device, dtype=torch.int64), 2996*da0073e9SAndroid Build Coastguard Worker ) 2997*da0073e9SAndroid Build Coastguard Worker 2998*da0073e9SAndroid Build Coastguard Worker spacing = ( 2999*da0073e9SAndroid Build Coastguard Worker make_tensor((1,), device='cpu', dtype=torch.float32).item(), 3000*da0073e9SAndroid Build Coastguard Worker make_tensor((1,), device='cpu', 
dtype=torch.int64).item(), 3001*da0073e9SAndroid Build Coastguard Worker make_tensor((1,), device='cpu', dtype=torch.complex64).item(), 3002*da0073e9SAndroid Build Coastguard Worker make_tensor((2,), device='cpu', dtype=torch.float32, low=0.1).tolist(), 3003*da0073e9SAndroid Build Coastguard Worker make_tensor((2,), device='cpu', dtype=torch.int64, low=1).tolist(), 3004*da0073e9SAndroid Build Coastguard Worker make_tensor((2,), device='cpu', dtype=torch.complex64).tolist(), 3005*da0073e9SAndroid Build Coastguard Worker [make_tensor((4,), device=device, dtype=torch.float32), 3006*da0073e9SAndroid Build Coastguard Worker make_tensor((4,), device=device, dtype=torch.float32)], 3007*da0073e9SAndroid Build Coastguard Worker [make_tensor((4,), device=device, dtype=torch.int64), 3008*da0073e9SAndroid Build Coastguard Worker make_tensor((4,), device=device, dtype=torch.int64)], 3009*da0073e9SAndroid Build Coastguard Worker [make_tensor((4,), device=device, dtype=torch.complex64), 3010*da0073e9SAndroid Build Coastguard Worker make_tensor((4,), device=device, dtype=torch.complex64)], 3011*da0073e9SAndroid Build Coastguard Worker ) 3012*da0073e9SAndroid Build Coastguard Worker 3013*da0073e9SAndroid Build Coastguard Worker for input, spacing_or_coord, edge_order in product(inputs, spacing, [1, 2]): 3014*da0073e9SAndroid Build Coastguard Worker input_np = input.cpu().numpy() 3015*da0073e9SAndroid Build Coastguard Worker input_np = input.cpu().numpy() 3016*da0073e9SAndroid Build Coastguard Worker actual = torch.gradient(input, spacing=spacing_or_coord, dim=(0, 1), edge_order=edge_order) 3017*da0073e9SAndroid Build Coastguard Worker spacing_or_coord_wrapped = self._wrap_to_list(spacing_or_coord) 3018*da0073e9SAndroid Build Coastguard Worker spacing_or_coord_np = [] 3019*da0073e9SAndroid Build Coastguard Worker if torch.is_tensor(spacing_or_coord_wrapped[0]) and torch.device(spacing_or_coord_wrapped[0].device).type != 'cpu': 3020*da0073e9SAndroid Build Coastguard Worker for i in 
range(len(spacing_or_coord_wrapped)): 3021*da0073e9SAndroid Build Coastguard Worker spacing_or_coord_np.append(spacing_or_coord_wrapped[i].detach().clone().cpu().numpy()) 3022*da0073e9SAndroid Build Coastguard Worker else: 3023*da0073e9SAndroid Build Coastguard Worker spacing_or_coord_np = spacing_or_coord_wrapped 3024*da0073e9SAndroid Build Coastguard Worker expected = np.gradient(input_np, *spacing_or_coord_np, axis=(0, 1), edge_order=edge_order) 3025*da0073e9SAndroid Build Coastguard Worker if actual[0].dtype == torch.complex64 and input.dtype != torch.complex64: 3026*da0073e9SAndroid Build Coastguard Worker for i in range(len(actual)): 3027*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual[i].real, expected[i].real, exact_dtype=False) 3028*da0073e9SAndroid Build Coastguard Worker # Type promotion fails on Numpy when spacing is given as complex number and input is given as real. 3029*da0073e9SAndroid Build Coastguard Worker # Result is given just as real number and all the imaginary parts to be equal to zero. 
3030*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected[i].imag, torch.zeros(actual[i].shape), exact_dtype=False) 3031*da0073e9SAndroid Build Coastguard Worker else: 3032*da0073e9SAndroid Build Coastguard Worker actual, expected = self._inf_nan_preprocess(list(actual), expected) 3033*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected, equal_nan=True, exact_dtype=False) 3034*da0073e9SAndroid Build Coastguard Worker 3035*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3036*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.long, torch.float32, torch.complex64) 3037*da0073e9SAndroid Build Coastguard Worker def test_gradient_spacing_list_length_error(self, device, dtype): 3038*da0073e9SAndroid Build Coastguard Worker t = make_tensor((2, 2), device=device, dtype=dtype) 3039*da0073e9SAndroid Build Coastguard Worker 3040*da0073e9SAndroid Build Coastguard Worker spacing = (make_tensor((2,), device=device, dtype=dtype),) 3041*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'expected spacing to be'): 3042*da0073e9SAndroid Build Coastguard Worker torch.gradient(t, spacing=spacing) 3043*da0073e9SAndroid Build Coastguard Worker 3044*da0073e9SAndroid Build Coastguard Worker spacing = (make_tensor((2,), device=device, dtype=dtype),) * 2 3045*da0073e9SAndroid Build Coastguard Worker torch.gradient(t, spacing=spacing) 3046*da0073e9SAndroid Build Coastguard Worker 3047*da0073e9SAndroid Build Coastguard Worker spacing = (make_tensor((2,), device=device, dtype=dtype),) * 3 3048*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'expected spacing to be'): 3049*da0073e9SAndroid Build Coastguard Worker torch.gradient(t, spacing=spacing) 3050*da0073e9SAndroid Build Coastguard Worker 3051*da0073e9SAndroid Build Coastguard Worker spacing = (2,) 3052*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'expected spacing to be'): 
3053*da0073e9SAndroid Build Coastguard Worker torch.gradient(t, spacing=spacing) 3054*da0073e9SAndroid Build Coastguard Worker 3055*da0073e9SAndroid Build Coastguard Worker spacing = (2, 2) 3056*da0073e9SAndroid Build Coastguard Worker torch.gradient(t, spacing=spacing) 3057*da0073e9SAndroid Build Coastguard Worker 3058*da0073e9SAndroid Build Coastguard Worker spacing = (2, 2, 2) 3059*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'expected spacing to be'): 3060*da0073e9SAndroid Build Coastguard Worker torch.gradient(t, spacing=spacing) 3061*da0073e9SAndroid Build Coastguard Worker 3062*da0073e9SAndroid Build Coastguard Worker def _test_large_cum_fn_helper(self, x, fn): 3063*da0073e9SAndroid Build Coastguard Worker expected = fn(x.cpu().float()) 3064*da0073e9SAndroid Build Coastguard Worker actual = fn(x).cpu().float() 3065*da0073e9SAndroid Build Coastguard Worker # Avoid self.assertEqual to save memory. 3066*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(expected, actual) 3067*da0073e9SAndroid Build Coastguard Worker 3068*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "sandcastle OOM with current tpx gpu/re configuration") 3069*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_JETSON, "psutil issue for largeTensorTest. 
Too large for Jetson.") 3070*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 3071*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half) # only small dtype not to get oom 3072*da0073e9SAndroid Build Coastguard Worker @largeTensorTest('25GB', device='cpu') 3073*da0073e9SAndroid Build Coastguard Worker @largeTensorTest('4GB', device='cuda') 3074*da0073e9SAndroid Build Coastguard Worker def test_large_cumsum(self, device, dtype): 3075*da0073e9SAndroid Build Coastguard Worker # initialization to avoid overflow and half caveats 3076*da0073e9SAndroid Build Coastguard Worker x = torch.empty(2**30 + 200, device=device, dtype=dtype) 3077*da0073e9SAndroid Build Coastguard Worker x[::3] = -3 3078*da0073e9SAndroid Build Coastguard Worker x[1::3] = 2 3079*da0073e9SAndroid Build Coastguard Worker x[2::3] = 1 3080*da0073e9SAndroid Build Coastguard Worker self._test_large_cum_fn_helper(x, lambda x: torch.cumsum(x, 0)) 3081*da0073e9SAndroid Build Coastguard Worker 3082*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 3083*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.half) # only small dtype not to get oom 3084*da0073e9SAndroid Build Coastguard Worker @largeTensorTest('25GB', device='cpu') 3085*da0073e9SAndroid Build Coastguard Worker @largeTensorTest('4GB', device='cuda') 3086*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_JETSON, "psutil issue for largeTensorTest. 
Too large for Jetson.") 3087*da0073e9SAndroid Build Coastguard Worker def test_large_cumprod(self, device, dtype): 3088*da0073e9SAndroid Build Coastguard Worker # initialization to avoid overflow and half caveats 3089*da0073e9SAndroid Build Coastguard Worker x = torch.empty(2**30 + 200, device=device, dtype=dtype) 3090*da0073e9SAndroid Build Coastguard Worker x[::3] = 8 3091*da0073e9SAndroid Build Coastguard Worker x[1::3] = .25 3092*da0073e9SAndroid Build Coastguard Worker x[2::3] = .5 3093*da0073e9SAndroid Build Coastguard Worker self._test_large_cum_fn_helper(x, lambda x: torch.cumprod(x, 0)) 3094*da0073e9SAndroid Build Coastguard Worker 3095*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Torchdynamo fails with unknown reason") 3096*da0073e9SAndroid Build Coastguard Worker @skipIfMps 3097*da0073e9SAndroid Build Coastguard Worker def test_discontiguous_out_cumsum(self, device): 3098*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 8, device=device) 3099*da0073e9SAndroid Build Coastguard Worker y = torch.empty(4, 16, device=device)[:, ::2] 3100*da0073e9SAndroid Build Coastguard Worker out = torch.cumsum(x, 0) 3101*da0073e9SAndroid Build Coastguard Worker torch.cumsum(x, 0, out=y) 3102*da0073e9SAndroid Build Coastguard Worker self.assertFalse(y.is_contiguous()) 3103*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, y, atol=0., rtol=0.) 
3104*da0073e9SAndroid Build Coastguard Worker 3105*da0073e9SAndroid Build Coastguard Worker def _test_cumminmax_helper(self, x, fn, expected_val, expected_ind): 3106*da0073e9SAndroid Build Coastguard Worker val, ind = fn(x, -1) 3107*da0073e9SAndroid Build Coastguard Worker self.assertEqual(val, expected_val, atol=0, rtol=0) 3108*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ind, expected_ind, atol=0, rtol=0) 3109*da0073e9SAndroid Build Coastguard Worker out_val = torch.empty_like(val).t().contiguous().t() 3110*da0073e9SAndroid Build Coastguard Worker out_ind = torch.empty_like(ind).t().contiguous().t() 3111*da0073e9SAndroid Build Coastguard Worker fn(x, -1, out=(out_val, out_ind)) 3112*da0073e9SAndroid Build Coastguard Worker # TODO: Fix this. It reproduces with aot_eager too, and looks like a functionalization bug. 3113*da0073e9SAndroid Build Coastguard Worker # (the problematic case seems rare, as we're calling an out= op directly from user code, 3114*da0073e9SAndroid Build Coastguard Worker # where the passed-in out tensors are non-contiguous). 
3115*da0073e9SAndroid Build Coastguard Worker if not TEST_WITH_TORCHINDUCTOR: 3116*da0073e9SAndroid Build Coastguard Worker self.assertFalse(out_val.is_contiguous()) 3117*da0073e9SAndroid Build Coastguard Worker self.assertFalse(out_ind.is_contiguous()) 3118*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_val, expected_val, atol=0, rtol=0) 3119*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ind, expected_ind, atol=0, rtol=0) 3120*da0073e9SAndroid Build Coastguard Worker 3121*da0073e9SAndroid Build Coastguard Worker @skipIfMps 3122*da0073e9SAndroid Build Coastguard Worker def test_cummax_discontiguous(self, device): 3123*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[0, 1, 2, 3, 2, 1], [4, 5, 6, 5, 6, 7]], device=device, dtype=torch.float).t().contiguous().t() 3124*da0073e9SAndroid Build Coastguard Worker expected_val = torch.tensor([[0, 1, 2, 3, 3, 3], [4, 5, 6, 6, 6, 7]], device=device, dtype=torch.float) 3125*da0073e9SAndroid Build Coastguard Worker expected_ind = torch.tensor([[0, 1, 2, 3, 3, 3], [0, 1, 2, 2, 4, 5]], device=device, dtype=torch.long) 3126*da0073e9SAndroid Build Coastguard Worker self._test_cumminmax_helper(x, torch.cummax, expected_val, expected_ind) 3127*da0073e9SAndroid Build Coastguard Worker 3128*da0073e9SAndroid Build Coastguard Worker @skipIfMps 3129*da0073e9SAndroid Build Coastguard Worker def test_cummin_discontiguous(self, device): 3130*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[3, 2, 1, 0, 1, 2], [7, 6, 5, 4, 5, 2]], device=device, dtype=torch.float).t().contiguous().t() 3131*da0073e9SAndroid Build Coastguard Worker expected_val = torch.tensor([[3, 2, 1, 0, 0, 0], [7, 6, 5, 4, 4, 2]], device=device, dtype=torch.float) 3132*da0073e9SAndroid Build Coastguard Worker expected_ind = torch.tensor([[0, 1, 2, 3, 3, 3], [0, 1, 2, 3, 3, 5]], device=device, dtype=torch.long) 3133*da0073e9SAndroid Build Coastguard Worker self._test_cumminmax_helper(x, torch.cummin, expected_val, 
expected_ind) 3134*da0073e9SAndroid Build Coastguard Worker 3135*da0073e9SAndroid Build Coastguard Worker def test_bool_tensor_value_change(self, device): 3136*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([True, False], dtype=torch.bool, device=device) 3137*da0073e9SAndroid Build Coastguard Worker x[0] = False 3138*da0073e9SAndroid Build Coastguard Worker x[1] = True 3139*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.tensor([False, True], dtype=torch.bool, device=device)) 3140*da0073e9SAndroid Build Coastguard Worker 3141*da0073e9SAndroid Build Coastguard Worker # FIXME: move to shape ops test suite 3142*da0073e9SAndroid Build Coastguard Worker def test_unfold_all_devices_and_dtypes(self, device): 3143*da0073e9SAndroid Build Coastguard Worker for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16): 3144*da0073e9SAndroid Build Coastguard Worker 3145*da0073e9SAndroid Build Coastguard Worker if dt == torch.bool: 3146*da0073e9SAndroid Build Coastguard Worker x = torch.empty((0, 1, 3, 0), dtype=dt, device=device) 3147*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape) 3148*da0073e9SAndroid Build Coastguard Worker else: 3149*da0073e9SAndroid Build Coastguard Worker x = torch.empty((0, 1, 3, 0), dtype=dt, device=device) 3150*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape) 3151*da0073e9SAndroid Build Coastguard Worker 3152*da0073e9SAndroid Build Coastguard Worker # FIXME: move to shape ops test suite 3153*da0073e9SAndroid Build Coastguard Worker def test_unfold_scalars(self, device): 3154*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(0.5, device=device) 3155*da0073e9SAndroid Build Coastguard Worker # unfold on a 0-dimensional tensor should always return a 1-d dimensional 3156*da0073e9SAndroid Build Coastguard Worker # tensor of shape [size] (i.e., the second parameter to unfold) 3157*da0073e9SAndroid Build 
Coastguard Worker 3158*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 1)) 3159*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 2)) 3160*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([0.5], device=device), x.unfold(0, 1, 1)) 3161*da0073e9SAndroid Build Coastguard Worker 3162*da0073e9SAndroid Build Coastguard Worker # FIXME: move to data movement test suite 3163*da0073e9SAndroid Build Coastguard Worker def test_copy_all_dtypes_and_devices(self, device): 3164*da0073e9SAndroid Build Coastguard Worker from copy import copy 3165*da0073e9SAndroid Build Coastguard Worker for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16): 3166*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2, 3, 4], dtype=dt, device=device) 3167*da0073e9SAndroid Build Coastguard Worker x_clone = x.clone() 3168*da0073e9SAndroid Build Coastguard Worker y = copy(x) 3169*da0073e9SAndroid Build Coastguard Worker y.fill_(1) 3170*da0073e9SAndroid Build Coastguard Worker # copy is a shallow copy, only copies the tensor view, 3171*da0073e9SAndroid Build Coastguard Worker # not the data 3172*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, y) 3173*da0073e9SAndroid Build Coastguard Worker 3174*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3175*da0073e9SAndroid Build Coastguard Worker def test_bfloat16_neg_abs(self, device): 3176*da0073e9SAndroid Build Coastguard Worker src = torch.randn(256) 3177*da0073e9SAndroid Build Coastguard Worker src[0] = torch.nan 3178*da0073e9SAndroid Build Coastguard Worker src[1] = -torch.nan 3179*da0073e9SAndroid Build Coastguard Worker src[2] = torch.inf 3180*da0073e9SAndroid Build Coastguard Worker src[3] = -torch.inf 3181*da0073e9SAndroid Build Coastguard Worker src_bf16 = src.bfloat16() 3182*da0073e9SAndroid Build Coastguard Worker self.assertEqual(src.neg().bfloat16(), src_bf16.neg()) 
3183*da0073e9SAndroid Build Coastguard Worker self.assertEqual(src.abs().bfloat16(), src_bf16.abs()) 3184*da0073e9SAndroid Build Coastguard Worker 3185*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3186*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.bfloat16, torch.half) 3187*da0073e9SAndroid Build Coastguard Worker def test_reduced_type_float_copy(self, device, dtype): 3188*da0073e9SAndroid Build Coastguard Worker for shape in [(20, 7), (249, 137), (1029, 917), (1, 7, 19, 17), (3, 77, 1091)]: 3189*da0073e9SAndroid Build Coastguard Worker input = torch.randn(shape, dtype=torch.float, device=device) 3190*da0073e9SAndroid Build Coastguard Worker out1 = input.to(dtype=dtype) 3191*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, out1, atol=None, rtol=None, exact_dtype=False) 3192*da0073e9SAndroid Build Coastguard Worker out2 = out1.to(torch.float) 3193*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out2, out1, atol=0, rtol=0, exact_dtype=False) 3194*da0073e9SAndroid Build Coastguard Worker 3195*da0073e9SAndroid Build Coastguard Worker input_s = input[..., ::2, :] 3196*da0073e9SAndroid Build Coastguard Worker out1 = input_s.to(dtype=dtype) 3197*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input_s, out1, atol=None, rtol=None, exact_dtype=False) 3198*da0073e9SAndroid Build Coastguard Worker out2 = out1.to(torch.float) 3199*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out2, out1, atol=0, rtol=0, exact_dtype=False) 3200*da0073e9SAndroid Build Coastguard Worker 3201*da0073e9SAndroid Build Coastguard Worker # FIXME: move to data movement test suite 3202*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3203*da0073e9SAndroid Build Coastguard Worker def test_copy_math_view(self, device): 3204*da0073e9SAndroid Build Coastguard Worker for dst_dtype, src_dtype in [ 3205*da0073e9SAndroid Build Coastguard Worker (torch.float32, torch.float32), 3206*da0073e9SAndroid Build Coastguard Worker (torch.float64, 
torch.float32), 3207*da0073e9SAndroid Build Coastguard Worker (torch.int64, torch.int32), 3208*da0073e9SAndroid Build Coastguard Worker (torch.complex128, torch.complex64), 3209*da0073e9SAndroid Build Coastguard Worker ]: 3210*da0073e9SAndroid Build Coastguard Worker src = make_tensor((100,), dtype=src_dtype, device=device) 3211*da0073e9SAndroid Build Coastguard Worker dst = torch.empty(100, dtype=dst_dtype, device=device) 3212*da0073e9SAndroid Build Coastguard Worker 3213*da0073e9SAndroid Build Coastguard Worker dst.copy_(src) 3214*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, src, exact_dtype=False) 3215*da0073e9SAndroid Build Coastguard Worker 3216*da0073e9SAndroid Build Coastguard Worker dst.copy_(src._neg_view()) 3217*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, src.neg(), exact_dtype=False) 3218*da0073e9SAndroid Build Coastguard Worker 3219*da0073e9SAndroid Build Coastguard Worker dst._neg_view().copy_(torch._neg_view(src)) 3220*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, src, exact_dtype=False) 3221*da0073e9SAndroid Build Coastguard Worker 3222*da0073e9SAndroid Build Coastguard Worker dst._neg_view().copy_(src) 3223*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, src.neg(), exact_dtype=False) 3224*da0073e9SAndroid Build Coastguard Worker 3225*da0073e9SAndroid Build Coastguard Worker # issue: https://github.com/pytorch/pytorch/issues/106051 3226*da0073e9SAndroid Build Coastguard Worker dst._neg_view().copy_(dst) 3227*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, src, exact_dtype=False) 3228*da0073e9SAndroid Build Coastguard Worker 3229*da0073e9SAndroid Build Coastguard Worker for dst_dtype, src_dtype in [ 3230*da0073e9SAndroid Build Coastguard Worker (torch.complex64, torch.complex64), 3231*da0073e9SAndroid Build Coastguard Worker (torch.complex128, torch.complex64), 3232*da0073e9SAndroid Build Coastguard Worker ]: 3233*da0073e9SAndroid Build Coastguard Worker src = 
make_tensor((100,), dtype=src_dtype, device=device) 3234*da0073e9SAndroid Build Coastguard Worker dst = torch.empty(100, dtype=dst_dtype, device=device) 3235*da0073e9SAndroid Build Coastguard Worker 3236*da0073e9SAndroid Build Coastguard Worker dst.conj().copy_(src) 3237*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, src.conj_physical(), exact_dtype=False) 3238*da0073e9SAndroid Build Coastguard Worker 3239*da0073e9SAndroid Build Coastguard Worker dst.conj().copy_(src._neg_view()) 3240*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, src.neg().conj_physical(), exact_dtype=False) 3241*da0073e9SAndroid Build Coastguard Worker 3242*da0073e9SAndroid Build Coastguard Worker # FIXME: move to data movement test suite 3243*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3244*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.int64, torch.float32, torch.complex64) 3245*da0073e9SAndroid Build Coastguard Worker def test_copy_transpose_math_view(self, device, dtype): 3246*da0073e9SAndroid Build Coastguard Worker src = make_tensor((100, 100), dtype=dtype, device=device).transpose(0, 1) 3247*da0073e9SAndroid Build Coastguard Worker dst = torch.empty((100, 100), dtype=dtype, device=device) 3248*da0073e9SAndroid Build Coastguard Worker 3249*da0073e9SAndroid Build Coastguard Worker dst._neg_view().copy_(src) 3250*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, -src) 3251*da0073e9SAndroid Build Coastguard Worker dst._neg_view().copy_(src._neg_view()) 3252*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, src) 3253*da0073e9SAndroid Build Coastguard Worker dst.copy_(src._neg_view()) 3254*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, -src) 3255*da0073e9SAndroid Build Coastguard Worker 3256*da0073e9SAndroid Build Coastguard Worker if dtype.is_complex: 3257*da0073e9SAndroid Build Coastguard Worker dst.conj().copy_(src) 3258*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, 
src.conj_physical()) 3259*da0073e9SAndroid Build Coastguard Worker dst.conj().copy_(src.conj()) 3260*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, src) 3261*da0073e9SAndroid Build Coastguard Worker dst.copy_(src.conj()) 3262*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, src.conj_physical()) 3263*da0073e9SAndroid Build Coastguard Worker 3264*da0073e9SAndroid Build Coastguard Worker def test_clone_all_dtypes_and_devices(self, device): 3265*da0073e9SAndroid Build Coastguard Worker for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16): 3266*da0073e9SAndroid Build Coastguard Worker x = torch.tensor((1, 1), dtype=dt, device=device) 3267*da0073e9SAndroid Build Coastguard Worker y = x.clone() 3268*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, y) 3269*da0073e9SAndroid Build Coastguard Worker 3270*da0073e9SAndroid Build Coastguard Worker def test_clone_zero_stride_dim(self, device): 3271*da0073e9SAndroid Build Coastguard Worker # stride zero, size 1 axis, not contiguous 3272*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10) 3273*da0073e9SAndroid Build Coastguard Worker y = x.as_strided([2, 1, 5], [1, 0, 2]) 3274*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, y.clone()) 3275*da0073e9SAndroid Build Coastguard Worker 3276*da0073e9SAndroid Build Coastguard Worker def test_clone_not_memory_dense(self): 3277*da0073e9SAndroid Build Coastguard Worker # github issue: https://github.com/pytorch/pytorch/issues/64176 3278*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 8).t()[::2, ::2] 3279*da0073e9SAndroid Build Coastguard Worker y = x.clone() 3280*da0073e9SAndroid Build Coastguard Worker # should retain permutation after densification 3281*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.stride() == (1, 4)) 3282*da0073e9SAndroid Build Coastguard Worker 3283*da0073e9SAndroid Build Coastguard Worker # FIXME: move to elementwise ternary test suite 3284*da0073e9SAndroid 
Build Coastguard Worker @dtypesIfCUDA(*set(get_all_math_dtypes('cuda'))) 3285*da0073e9SAndroid Build Coastguard Worker @dtypes(*set(get_all_math_dtypes('cpu'))) 3286*da0073e9SAndroid Build Coastguard Worker def test_addcmul(self, device, dtype): 3287*da0073e9SAndroid Build Coastguard Worker # Returns floating or integral scalar corresponding to dtype 3288*da0073e9SAndroid Build Coastguard Worker def _number(floating, integer, dtype): 3289*da0073e9SAndroid Build Coastguard Worker if dtype in [torch.half, torch.float, torch.double, torch.bfloat16]: 3290*da0073e9SAndroid Build Coastguard Worker return floating 3291*da0073e9SAndroid Build Coastguard Worker elif dtype in [torch.cfloat, torch.cdouble]: 3292*da0073e9SAndroid Build Coastguard Worker return floating * (1 + 1j) 3293*da0073e9SAndroid Build Coastguard Worker else: 3294*da0073e9SAndroid Build Coastguard Worker return integer 3295*da0073e9SAndroid Build Coastguard Worker 3296*da0073e9SAndroid Build Coastguard Worker def rand_tensor(size, dtype, device): 3297*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point or dtype.is_complex: 3298*da0073e9SAndroid Build Coastguard Worker return torch.rand(size=size, dtype=dtype, device=device) 3299*da0073e9SAndroid Build Coastguard Worker if dtype == torch.uint8: 3300*da0073e9SAndroid Build Coastguard Worker return torch.randint(1, 5, size=size, dtype=dtype, device=device) 3301*da0073e9SAndroid Build Coastguard Worker else: 3302*da0073e9SAndroid Build Coastguard Worker return torch.randint(-5, 5, size=size, dtype=dtype, device=device) 3303*da0073e9SAndroid Build Coastguard Worker 3304*da0073e9SAndroid Build Coastguard Worker a = rand_tensor((2, 2), dtype=dtype, device=device) 3305*da0073e9SAndroid Build Coastguard Worker b = rand_tensor((2, 2), dtype=dtype, device=device) 3306*da0073e9SAndroid Build Coastguard Worker c = rand_tensor((2, 2), dtype=dtype, device=device) 3307*da0073e9SAndroid Build Coastguard Worker 3308*da0073e9SAndroid Build Coastguard Worker 
alpha = _number(0.5, 3, dtype) 3309*da0073e9SAndroid Build Coastguard Worker 3310*da0073e9SAndroid Build Coastguard Worker actual = torch.addcmul(a, b, c, value=alpha) 3311*da0073e9SAndroid Build Coastguard Worker expected = a + alpha * b * c 3312*da0073e9SAndroid Build Coastguard Worker 3313*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 3314*da0073e9SAndroid Build Coastguard Worker 3315*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex( 3316*da0073e9SAndroid Build Coastguard Worker UserWarning, "This overload of addcmul is deprecated"): 3317*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, torch.addcmul(a, alpha, b, c)) 3318*da0073e9SAndroid Build Coastguard Worker 3319*da0073e9SAndroid Build Coastguard Worker if self.device_type == 'cuda' and dtype == torch.half: 3320*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([60000.0], device=device, dtype=dtype) 3321*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([60000.0], device=device, dtype=dtype) 3322*da0073e9SAndroid Build Coastguard Worker c = torch.tensor([2.0], device=device, dtype=dtype) 3323*da0073e9SAndroid Build Coastguard Worker out = torch.addcmul(a, b, c, value=-1) 3324*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not (out.isnan() or out.isinf())) 3325*da0073e9SAndroid Build Coastguard Worker 3326*da0073e9SAndroid Build Coastguard Worker # FIXME: move to shape ops test suite 3327*da0073e9SAndroid Build Coastguard Worker def test_narrow_empty(self, device): 3328*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, 4, device=device) 3329*da0073e9SAndroid Build Coastguard Worker for d in range(x.dim()): 3330*da0073e9SAndroid Build Coastguard Worker y = x.narrow(d, x.size(d), 0) 3331*da0073e9SAndroid Build Coastguard Worker sz = list(x.size()) 3332*da0073e9SAndroid Build Coastguard Worker sz[d] = 0 3333*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sz, y.size()) 3334*da0073e9SAndroid 
Build Coastguard Worker 3335*da0073e9SAndroid Build Coastguard Worker def test_narrow_copy_non_contiguous(self, device): 3336*da0073e9SAndroid Build Coastguard Worker # see https://github.com/pytorch/pytorch/issues/91690. 3337*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(10, 2, device=device).movedim(-1, 0) 3338*da0073e9SAndroid Build Coastguard Worker expected = torch.narrow_copy(inp.contiguous(), 1, 0, 10) 3339*da0073e9SAndroid Build Coastguard Worker actual = torch.narrow_copy(inp, 1, 0, 10) 3340*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 3341*da0073e9SAndroid Build Coastguard Worker 3342*da0073e9SAndroid Build Coastguard Worker # FIXME: move to indexing test suite 3343*da0073e9SAndroid Build Coastguard Worker @parametrize("reduce", ['prod', 'amin', 'amax', 'mean']) 3344*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and(torch.half, torch.bfloat16)) 3345*da0073e9SAndroid Build Coastguard Worker def test_index_reduce(self, device, dtype, reduce): 3346*da0073e9SAndroid Build Coastguard Worker size = (3, 4, 5) 3347*da0073e9SAndroid Build Coastguard Worker index_dtypes = [torch.int, torch.long] 3348*da0073e9SAndroid Build Coastguard Worker include_selfs = [True, False] 3349*da0073e9SAndroid Build Coastguard Worker amin_init = float('inf') if dtype.is_floating_point else torch.iinfo(dtype).max 3350*da0073e9SAndroid Build Coastguard Worker amax_init = -float('inf') if dtype.is_floating_point else torch.iinfo(dtype).min 3351*da0073e9SAndroid Build Coastguard Worker reduction_init = {'prod': 1, 'mean': 0, 'amin': amin_init, 'amax': amax_init} 3352*da0073e9SAndroid Build Coastguard Worker 3353*da0073e9SAndroid Build Coastguard Worker for dest_noncontig, src_noncontig, index_noncontig in product([True, False], repeat=3): 3354*da0073e9SAndroid Build Coastguard Worker for idx_dtype, include_self in product(index_dtypes, include_selfs): 3355*da0073e9SAndroid Build Coastguard Worker for dim in range(len(size)): 
3356*da0073e9SAndroid Build Coastguard Worker num_src = np.random.randint(10) 3357*da0073e9SAndroid Build Coastguard Worker num_dest = size[dim] 3358*da0073e9SAndroid Build Coastguard Worker dest = make_tensor(size, device=device, dtype=dtype, noncontiguous=dest_noncontig) 3359*da0073e9SAndroid Build Coastguard Worker src_size = size[:dim] + (num_src,) + size[dim + 1:] 3360*da0073e9SAndroid Build Coastguard Worker src = make_tensor(src_size, device=device, dtype=dtype, noncontiguous=src_noncontig) 3361*da0073e9SAndroid Build Coastguard Worker idx = torch.testing.make_tensor( 3362*da0073e9SAndroid Build Coastguard Worker num_src, low=0, high=num_dest, dtype=idx_dtype, device=device, noncontiguous=index_noncontig 3363*da0073e9SAndroid Build Coastguard Worker ) 3364*da0073e9SAndroid Build Coastguard Worker expected = dest.clone() 3365*da0073e9SAndroid Build Coastguard Worker dest.index_reduce_(dim, idx, src, reduce, include_self=include_self) 3366*da0073e9SAndroid Build Coastguard Worker # fill rows in idx with reduction inits if include_self=False 3367*da0073e9SAndroid Build Coastguard Worker if (not include_self): 3368*da0073e9SAndroid Build Coastguard Worker expected.index_fill_(dim, idx.long(), reduction_init[reduce]) 3369*da0073e9SAndroid Build Coastguard Worker expected = expected.transpose(0, dim) 3370*da0073e9SAndroid Build Coastguard Worker src = src.transpose(0, dim) 3371*da0073e9SAndroid Build Coastguard Worker for i in range(num_src): 3372*da0073e9SAndroid Build Coastguard Worker if reduce == 'prod': 3373*da0073e9SAndroid Build Coastguard Worker expected[idx[i]] *= src[i] 3374*da0073e9SAndroid Build Coastguard Worker elif reduce == 'amin': 3375*da0073e9SAndroid Build Coastguard Worker torch.minimum(expected[idx[i]], src[i], out=expected[idx[i]]) 3376*da0073e9SAndroid Build Coastguard Worker elif reduce == 'amax': 3377*da0073e9SAndroid Build Coastguard Worker torch.maximum(expected[idx[i]], src[i], out=expected[idx[i]]) 3378*da0073e9SAndroid Build 
Coastguard Worker else: 3379*da0073e9SAndroid Build Coastguard Worker expected[idx[i]] += src[i] 3380*da0073e9SAndroid Build Coastguard Worker if reduce == 'mean': 3381*da0073e9SAndroid Build Coastguard Worker counts = torch.ones_like(expected) if include_self else torch.zeros_like(expected) 3382*da0073e9SAndroid Build Coastguard Worker counts.index_add_(0, idx, torch.ones_like(src)) 3383*da0073e9SAndroid Build Coastguard Worker counts.masked_fill_(counts == 0, 1) 3384*da0073e9SAndroid Build Coastguard Worker if (dtype.is_floating_point): 3385*da0073e9SAndroid Build Coastguard Worker expected.div_(counts) 3386*da0073e9SAndroid Build Coastguard Worker else: 3387*da0073e9SAndroid Build Coastguard Worker expected.div_(counts, rounding_mode="floor") 3388*da0073e9SAndroid Build Coastguard Worker expected = expected.transpose(0, dim) 3389*da0073e9SAndroid Build Coastguard Worker 3390*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dest, expected) 3391*da0073e9SAndroid Build Coastguard Worker 3392*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test indexing 3393*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 3394*da0073e9SAndroid Build Coastguard Worker def test_index_copy(self, device, dtype): 3395*da0073e9SAndroid Build Coastguard Worker # We just test for num_copy <= num_dest, as otherwise there are repeated indices 3396*da0073e9SAndroid Build Coastguard Worker # and the behavior is undefined 3397*da0073e9SAndroid Build Coastguard Worker num_copy, num_dest = 3, 5 3398*da0073e9SAndroid Build Coastguard Worker 3399*da0073e9SAndroid Build Coastguard Worker def make_arg(batch_sizes, n, dim, contig): 3400*da0073e9SAndroid Build Coastguard Worker size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:] 3401*da0073e9SAndroid Build Coastguard Worker return make_tensor(size_arg, dtype=dtype, device=device, low=None, high=None, noncontiguous=not contig) 3402*da0073e9SAndroid Build 
Coastguard Worker 3403*da0073e9SAndroid Build Coastguard Worker def ref_index_copy(tgt, dim, idx, src): 3404*da0073e9SAndroid Build Coastguard Worker for i in range(idx.size(0)): 3405*da0073e9SAndroid Build Coastguard Worker idx_dest = dim * (slice(None),) + (idx[i],) 3406*da0073e9SAndroid Build Coastguard Worker idx_src = dim * (slice(None),) + (i,) 3407*da0073e9SAndroid Build Coastguard Worker tgt[idx_dest] = src[idx_src] 3408*da0073e9SAndroid Build Coastguard Worker 3409*da0073e9SAndroid Build Coastguard Worker # More thorough testing as in index_add 3410*da0073e9SAndroid Build Coastguard Worker for dest_contig, src_contig, index_contig in product([True, False], repeat=3): 3411*da0073e9SAndroid Build Coastguard Worker for other_sizes in ((), (4, 5)): 3412*da0073e9SAndroid Build Coastguard Worker for dim in range(len(other_sizes)): 3413*da0073e9SAndroid Build Coastguard Worker dest = make_arg(other_sizes, num_dest, dim, dest_contig) 3414*da0073e9SAndroid Build Coastguard Worker src = make_arg(other_sizes, num_copy, dim, src_contig) 3415*da0073e9SAndroid Build Coastguard Worker idx = torch.randperm(num_dest, dtype=torch.int64, device=device)[:num_copy] 3416*da0073e9SAndroid Build Coastguard Worker if not index_contig: 3417*da0073e9SAndroid Build Coastguard Worker idx = torch.repeat_interleave(idx, 2, dim=-1) 3418*da0073e9SAndroid Build Coastguard Worker idx = idx[..., ::2] 3419*da0073e9SAndroid Build Coastguard Worker dest2 = dest.clone() 3420*da0073e9SAndroid Build Coastguard Worker dest.index_copy_(dim, idx, src) 3421*da0073e9SAndroid Build Coastguard Worker ref_index_copy(dest2, dim, idx, src) 3422*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dest, dest2) 3423*da0073e9SAndroid Build Coastguard Worker 3424*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test indexing 3425*da0073e9SAndroid Build Coastguard Worker # onlyNativeDeviceTypes due to an XLA error: 3426*da0073e9SAndroid Build Coastguard Worker # 
https://github.com/pytorch/pytorch/issues/53256 3427*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3428*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 3429*da0073e9SAndroid Build Coastguard Worker def test_index_copy_scalars(self, device, dtype): 3430*da0073e9SAndroid Build Coastguard Worker # Create the 8 possible combinations of scalar sizes for target / index / source 3431*da0073e9SAndroid Build Coastguard Worker scalars = ((make_tensor(size_t, dtype=dtype, device=device, low=None, high=None), 3432*da0073e9SAndroid Build Coastguard Worker make_tensor(size_i, dtype=torch.int64, device=device, low=0, high=1), 3433*da0073e9SAndroid Build Coastguard Worker make_tensor(size_s, dtype=dtype, device=device, low=None, high=None)) 3434*da0073e9SAndroid Build Coastguard Worker for size_t, size_i, size_s in product([(), (1,)], repeat=3)) 3435*da0073e9SAndroid Build Coastguard Worker for target, idx, source in scalars: 3436*da0073e9SAndroid Build Coastguard Worker target.index_copy_(0, idx, source) 3437*da0073e9SAndroid Build Coastguard Worker self.assertEqual(target.item(), source.item()) 3438*da0073e9SAndroid Build Coastguard Worker 3439*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test indexing 3440*da0073e9SAndroid Build Coastguard Worker @onlyCPU 3441*da0073e9SAndroid Build Coastguard Worker def test_errors_index_copy(self, device): 3442*da0073e9SAndroid Build Coastguard Worker # We do not test the GPU as the CUDA_ASSERT would break the CUDA context 3443*da0073e9SAndroid Build Coastguard Worker idx_dim = 8 3444*da0073e9SAndroid Build Coastguard Worker tgt_dim = 5 3445*da0073e9SAndroid Build Coastguard Worker batch_dim = 3 3446*da0073e9SAndroid Build Coastguard Worker 3447*da0073e9SAndroid Build Coastguard Worker # Too large of an index 3448*da0073e9SAndroid Build Coastguard Worker a = torch.randn(batch_dim, tgt_dim, device=device) 3449*da0073e9SAndroid Build Coastguard 
Worker idx = torch.full((idx_dim,), tgt_dim, device=device) 3450*da0073e9SAndroid Build Coastguard Worker c = torch.zeros(batch_dim, idx_dim, device=device) 3451*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 3452*da0073e9SAndroid Build Coastguard Worker a.index_copy_(1, idx, c) 3453*da0073e9SAndroid Build Coastguard Worker 3454*da0073e9SAndroid Build Coastguard Worker # Too small (negative indices) 3455*da0073e9SAndroid Build Coastguard Worker idx = torch.full((idx_dim,), -1, device=device) 3456*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 3457*da0073e9SAndroid Build Coastguard Worker a.index_copy_(1, idx, c) 3458*da0073e9SAndroid Build Coastguard Worker 3459*da0073e9SAndroid Build Coastguard Worker # Too small (very negative indices) - they should be unsupported even 3460*da0073e9SAndroid Build Coastguard Worker # when support for negative indices is implemented for index_copy_ 3461*da0073e9SAndroid Build Coastguard Worker idx = torch.full((idx_dim,), -tgt_dim - 1, device=device) 3462*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 3463*da0073e9SAndroid Build Coastguard Worker a.index_copy_(1, idx, c) 3464*da0073e9SAndroid Build Coastguard Worker 3465*da0073e9SAndroid Build Coastguard Worker def _prepare_data_for_index_copy_and_add_deterministic( 3466*da0073e9SAndroid Build Coastguard Worker self, dim: int, device: torch.device 3467*da0073e9SAndroid Build Coastguard Worker ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 3468*da0073e9SAndroid Build Coastguard Worker assert (dim >= 0 and dim < 3) 3469*da0073e9SAndroid Build Coastguard Worker a = [5, 4, 3] 3470*da0073e9SAndroid Build Coastguard Worker a[dim] = 2000 3471*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(a, device=device) 3472*da0073e9SAndroid Build Coastguard Worker b = a.copy() 3473*da0073e9SAndroid Build Coastguard Worker elems = a[dim] * 20 3474*da0073e9SAndroid Build Coastguard Worker b[dim] = 
elems 3475*da0073e9SAndroid Build Coastguard Worker src = torch.rand(b, device=device) 3476*da0073e9SAndroid Build Coastguard Worker index = torch.randint(a[dim], (elems,), device=device) 3477*da0073e9SAndroid Build Coastguard Worker return (x, index, src) 3478*da0073e9SAndroid Build Coastguard Worker 3479*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test indexing 3480*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3481*da0073e9SAndroid Build Coastguard Worker def test_index_copy_deterministic(self, device: torch.device) -> None: 3482*da0073e9SAndroid Build Coastguard Worker for dim in range(3): 3483*da0073e9SAndroid Build Coastguard Worker x, index, src = self._prepare_data_for_index_copy_and_add_deterministic(dim, device) 3484*da0073e9SAndroid Build Coastguard Worker with DeterministicGuard(True): 3485*da0073e9SAndroid Build Coastguard Worker y0 = torch.index_copy(x, dim, index, src) 3486*da0073e9SAndroid Build Coastguard Worker 3487*da0073e9SAndroid Build Coastguard Worker x0 = x.clone().detach() 3488*da0073e9SAndroid Build Coastguard Worker index_list = index.tolist() 3489*da0073e9SAndroid Build Coastguard Worker for i in range(len(index_list)): 3490*da0073e9SAndroid Build Coastguard Worker if dim == 0: 3491*da0073e9SAndroid Build Coastguard Worker x0[index_list[i], :, :] = src[i, :, :] 3492*da0073e9SAndroid Build Coastguard Worker elif dim == 1: 3493*da0073e9SAndroid Build Coastguard Worker x0[:, index_list[i], :] = src[:, i, :] 3494*da0073e9SAndroid Build Coastguard Worker elif dim == 2: 3495*da0073e9SAndroid Build Coastguard Worker x0[:, :, index_list[i]] = src[:, :, i] 3496*da0073e9SAndroid Build Coastguard Worker 3497*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x0, y0, atol=0, rtol=0) 3498*da0073e9SAndroid Build Coastguard Worker 3499*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test indexing 3500*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3501*da0073e9SAndroid Build Coastguard 
Worker def test_index_add_deterministic(self, device: torch.device) -> None: 3502*da0073e9SAndroid Build Coastguard Worker for dim in range(3): 3503*da0073e9SAndroid Build Coastguard Worker x, index, src = self._prepare_data_for_index_copy_and_add_deterministic(dim, device) 3504*da0073e9SAndroid Build Coastguard Worker alpha = random.random() + 1 3505*da0073e9SAndroid Build Coastguard Worker # on CPU it should be deterministic regardless of the deterministic mode 3506*da0073e9SAndroid Build Coastguard Worker with DeterministicGuard(True): 3507*da0073e9SAndroid Build Coastguard Worker y0 = torch.index_add(x, dim, index, src, alpha=alpha) 3508*da0073e9SAndroid Build Coastguard Worker for _ in range(3): 3509*da0073e9SAndroid Build Coastguard Worker y = torch.index_add(x, dim, index, src, alpha=alpha) 3510*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, y0, atol=0, rtol=0) 3511*da0073e9SAndroid Build Coastguard Worker 3512*da0073e9SAndroid Build Coastguard Worker with DeterministicGuard(False): 3513*da0073e9SAndroid Build Coastguard Worker for _ in range(3): 3514*da0073e9SAndroid Build Coastguard Worker y_nd = torch.index_add(x, dim, index, src, alpha=alpha) 3515*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y_nd, y0, atol=1e-3, rtol=1e-5) 3516*da0073e9SAndroid Build Coastguard Worker 3517*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the put operator 3518*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3519*da0073e9SAndroid Build Coastguard Worker def test_index_put_non_accumulate_deterministic(self, device) -> None: 3520*da0073e9SAndroid Build Coastguard Worker with DeterministicGuard(True): 3521*da0073e9SAndroid Build Coastguard Worker for i in range(3): 3522*da0073e9SAndroid Build Coastguard Worker m = random.randint(10, 20) 3523*da0073e9SAndroid Build Coastguard Worker elems = random.randint(20000, 30000) 3524*da0073e9SAndroid Build Coastguard Worker values = torch.rand(elems, device=device) 
3525*da0073e9SAndroid Build Coastguard Worker indices = torch.randint(m, (elems,), device=device) 3526*da0073e9SAndroid Build Coastguard Worker input = torch.rand(m, device=device) 3527*da0073e9SAndroid Build Coastguard Worker output = input.index_put((indices,), values, accumulate=False) 3528*da0073e9SAndroid Build Coastguard Worker 3529*da0073e9SAndroid Build Coastguard Worker input_list = input.tolist() 3530*da0073e9SAndroid Build Coastguard Worker indices_list = indices.tolist() 3531*da0073e9SAndroid Build Coastguard Worker values_list = values.tolist() 3532*da0073e9SAndroid Build Coastguard Worker for i, v in zip(indices_list, values_list): 3533*da0073e9SAndroid Build Coastguard Worker input_list[i] = v 3534*da0073e9SAndroid Build Coastguard Worker 3535*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, input_list) 3536*da0073e9SAndroid Build Coastguard Worker 3537*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test indexing 3538*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 3539*da0073e9SAndroid Build Coastguard Worker @skipIfMps 3540*da0073e9SAndroid Build Coastguard Worker def test_index_fill(self, device, dtype): 3541*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device) 3542*da0073e9SAndroid Build Coastguard Worker index = torch.tensor([0], device=device) 3543*da0073e9SAndroid Build Coastguard Worker x.index_fill_(1, index, 0) 3544*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dtype, device=device)) 3545*da0073e9SAndroid Build Coastguard Worker if not x.is_complex() and not device == "meta": 3546*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Scalar"): 3547*da0073e9SAndroid Build Coastguard Worker x.index_fill_(1, index, 1 + 1j) 3548*da0073e9SAndroid Build Coastguard Worker # Make sure that the result stays 0-dim while 
applied to 3549*da0073e9SAndroid Build Coastguard Worker # a 0-dim input 3550*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(1, dtype=dtype, device=device) 3551*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, x.index_fill(0, index, -1).dim()) 3552*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, x.index_fill_(0, index, -1).dim()) 3553*da0073e9SAndroid Build Coastguard Worker 3554*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test indexing 3555*da0073e9SAndroid Build Coastguard Worker # The test fails for zero-dimensional tensors on XLA 3556*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3557*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 3558*da0073e9SAndroid Build Coastguard Worker def test_index_select(self, device, dtype): 3559*da0073e9SAndroid Build Coastguard Worker num_src, num_out = 3, 5 3560*da0073e9SAndroid Build Coastguard Worker 3561*da0073e9SAndroid Build Coastguard Worker def make_arg(batch_sizes, n, dim, contig): 3562*da0073e9SAndroid Build Coastguard Worker size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:] 3563*da0073e9SAndroid Build Coastguard Worker return make_tensor(size_arg, dtype=dtype, device=device, low=None, high=None, noncontiguous=not contig) 3564*da0073e9SAndroid Build Coastguard Worker 3565*da0073e9SAndroid Build Coastguard Worker def ref_index_select(src, dim, idx): 3566*da0073e9SAndroid Build Coastguard Worker # bfloat16 is just used on GPU, so it's not supported on numpy 3567*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 3568*da0073e9SAndroid Build Coastguard Worker src = src.float() 3569*da0073e9SAndroid Build Coastguard Worker out = torch.from_numpy(np.take(src.cpu().numpy(), idx.cpu().numpy(), axis=dim)) 3570*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 3571*da0073e9SAndroid Build Coastguard Worker out = out.to(device=device, dtype=dtype) 
3572*da0073e9SAndroid Build Coastguard Worker return out 3573*da0073e9SAndroid Build Coastguard Worker 3574*da0073e9SAndroid Build Coastguard Worker for src_contig, idx_contig in product([True, False], repeat=2): 3575*da0073e9SAndroid Build Coastguard Worker for other_sizes in ((), (4, 5)): 3576*da0073e9SAndroid Build Coastguard Worker for dim in range(len(other_sizes)): 3577*da0073e9SAndroid Build Coastguard Worker src = make_arg(other_sizes, num_src, dim, src_contig) 3578*da0073e9SAndroid Build Coastguard Worker idx = make_tensor( 3579*da0073e9SAndroid Build Coastguard Worker (num_out,), dtype=torch.int64, device=device, low=0, high=num_src, noncontiguous=not idx_contig 3580*da0073e9SAndroid Build Coastguard Worker ) 3581*da0073e9SAndroid Build Coastguard Worker out = torch.index_select(src, dim, idx) 3582*da0073e9SAndroid Build Coastguard Worker out2 = ref_index_select(src, dim, idx) 3583*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, out2) 3584*da0073e9SAndroid Build Coastguard Worker 3585*da0073e9SAndroid Build Coastguard Worker for idx_type in (torch.int32, torch.int64): 3586*da0073e9SAndroid Build Coastguard Worker other_sizes = (3, 2) 3587*da0073e9SAndroid Build Coastguard Worker dim = 1 3588*da0073e9SAndroid Build Coastguard Worker src = make_arg(other_sizes, num_src, dim, True) 3589*da0073e9SAndroid Build Coastguard Worker idx = make_tensor((num_out,), dtype=idx_type, device=device, low=0, high=num_src, noncontiguous=False) 3590*da0073e9SAndroid Build Coastguard Worker out = torch.index_select(src, dim, idx) 3591*da0073e9SAndroid Build Coastguard Worker out2 = ref_index_select(src, dim, idx) 3592*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, out2) 3593*da0073e9SAndroid Build Coastguard Worker 3594*da0073e9SAndroid Build Coastguard Worker # Create the 4 possible combinations of scalar sizes for index / source 3595*da0073e9SAndroid Build Coastguard Worker scalars = ((make_tensor(size_s, dtype=dtype, device=device), 
3596*da0073e9SAndroid Build Coastguard Worker torch.zeros(size_i, dtype=torch.int64, device=device)) 3597*da0073e9SAndroid Build Coastguard Worker for size_s, size_i in product([(), (1,)], repeat=2)) 3598*da0073e9SAndroid Build Coastguard Worker for source, idx in scalars: 3599*da0073e9SAndroid Build Coastguard Worker out = source.index_select(0, idx) 3600*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.item(), source.item()) 3601*da0073e9SAndroid Build Coastguard Worker 3602*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the take operator 3603*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 3604*da0073e9SAndroid Build Coastguard Worker def test_take(self, device, dtype): 3605*da0073e9SAndroid Build Coastguard Worker idx_size = (4,) 3606*da0073e9SAndroid Build Coastguard Worker 3607*da0073e9SAndroid Build Coastguard Worker make_arg = partial(make_tensor, device=device, dtype=dtype) 3608*da0073e9SAndroid Build Coastguard Worker make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64) 3609*da0073e9SAndroid Build Coastguard Worker 3610*da0073e9SAndroid Build Coastguard Worker def ref_take(src, idx): 3611*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bfloat16: 3612*da0073e9SAndroid Build Coastguard Worker src = src.half() 3613*da0073e9SAndroid Build Coastguard Worker src = src.cpu().numpy() 3614*da0073e9SAndroid Build Coastguard Worker idx = idx.cpu().numpy() 3615*da0073e9SAndroid Build Coastguard Worker out = torch.from_numpy(np.take(src, idx)).to(device=device, dtype=dtype) 3616*da0073e9SAndroid Build Coastguard Worker return out 3617*da0073e9SAndroid Build Coastguard Worker 3618*da0073e9SAndroid Build Coastguard Worker for src_contig, idx_contig, idx_reshape in product([True, False], repeat=3): 3619*da0073e9SAndroid Build Coastguard Worker for src_size in ((5,), (4, 5)): 3620*da0073e9SAndroid Build Coastguard Worker src = 
make_arg(src_size, noncontiguous=not src_contig) 3621*da0073e9SAndroid Build Coastguard Worker idx = make_idx(idx_size, high=src.numel(), noncontiguous=not idx_contig) 3622*da0073e9SAndroid Build Coastguard Worker if idx_reshape: 3623*da0073e9SAndroid Build Coastguard Worker idx = idx.reshape(2, 2) 3624*da0073e9SAndroid Build Coastguard Worker out = torch.take(src, idx) 3625*da0073e9SAndroid Build Coastguard Worker out2 = ref_take(src, idx) 3626*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, out2) 3627*da0073e9SAndroid Build Coastguard Worker 3628*da0073e9SAndroid Build Coastguard Worker # Create the 4 possible combinations of scalar sizes for source / index 3629*da0073e9SAndroid Build Coastguard Worker for size_s, size_i in product([(), (1,)], repeat=2): 3630*da0073e9SAndroid Build Coastguard Worker source = make_arg(size_s) 3631*da0073e9SAndroid Build Coastguard Worker idx = make_idx(size_i, high=1) 3632*da0073e9SAndroid Build Coastguard Worker out = source.take(idx) 3633*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.item(), source.item()) 3634*da0073e9SAndroid Build Coastguard Worker 3635*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the put operator 3636*da0073e9SAndroid Build Coastguard Worker # The bool instance does not work on GPU. 
See 3637*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/54317 3638*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) 3639*da0073e9SAndroid Build Coastguard Worker def test_put(self, device, dtype): 3640*da0073e9SAndroid Build Coastguard Worker src_size = (4,) 3641*da0073e9SAndroid Build Coastguard Worker 3642*da0073e9SAndroid Build Coastguard Worker make_arg = partial(make_tensor, device=device, dtype=dtype) 3643*da0073e9SAndroid Build Coastguard Worker make_idx = partial(make_tensor, low=0, device=device, dtype=torch.int64) 3644*da0073e9SAndroid Build Coastguard Worker 3645*da0073e9SAndroid Build Coastguard Worker def ref_put(dst, idx, src, accumulate): 3646*da0073e9SAndroid Build Coastguard Worker new_dst = dst.clone(memory_format=torch.contiguous_format).view(-1) 3647*da0073e9SAndroid Build Coastguard Worker new_idx = idx.contiguous().view(-1) 3648*da0073e9SAndroid Build Coastguard Worker new_src = src.contiguous().view(-1) 3649*da0073e9SAndroid Build Coastguard Worker method = new_dst.index_add_ if accumulate else new_dst.index_copy_ 3650*da0073e9SAndroid Build Coastguard Worker return method(0, new_idx, new_src).view_as(dst) 3651*da0073e9SAndroid Build Coastguard Worker 3652*da0073e9SAndroid Build Coastguard Worker for dst_contig, src_contig, idx_contig, idx_reshape, accumulate in product([True, False], repeat=5): 3653*da0073e9SAndroid Build Coastguard Worker for dst_size in ((5,), (4, 5)): 3654*da0073e9SAndroid Build Coastguard Worker dst = make_arg(dst_size, noncontiguous=not dst_contig) 3655*da0073e9SAndroid Build Coastguard Worker src = make_arg(src_size, noncontiguous=not src_contig) 3656*da0073e9SAndroid Build Coastguard Worker 3657*da0073e9SAndroid Build Coastguard Worker # If accumulate=True, `put_` should be deterministic regardless of the inputs on CPU 3658*da0073e9SAndroid Build Coastguard Worker # On CUDA it may not be, but the test has enough tolerance to 
account for this 3659*da0073e9SAndroid Build Coastguard Worker if accumulate: 3660*da0073e9SAndroid Build Coastguard Worker idx = make_idx(src_size, high=dst.numel()) 3661*da0073e9SAndroid Build Coastguard Worker else: 3662*da0073e9SAndroid Build Coastguard Worker idx = torch.randperm(dst.numel(), dtype=torch.int64, device=device)[:src_size[0]] 3663*da0073e9SAndroid Build Coastguard Worker if not idx_contig: 3664*da0073e9SAndroid Build Coastguard Worker idx = torch.repeat_interleave(idx, 2, dim=-1)[..., ::2] 3665*da0073e9SAndroid Build Coastguard Worker if idx_reshape: 3666*da0073e9SAndroid Build Coastguard Worker idx = idx.reshape(2, 2) 3667*da0073e9SAndroid Build Coastguard Worker out = torch.put(dst, idx, src, accumulate) 3668*da0073e9SAndroid Build Coastguard Worker # out-place 3669*da0073e9SAndroid Build Coastguard Worker reference = ref_put(dst, idx, src, accumulate) 3670*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, reference) 3671*da0073e9SAndroid Build Coastguard Worker 3672*da0073e9SAndroid Build Coastguard Worker # in-place 3673*da0073e9SAndroid Build Coastguard Worker dst.put_(idx, src, accumulate) 3674*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, reference) 3675*da0073e9SAndroid Build Coastguard Worker 3676*da0073e9SAndroid Build Coastguard Worker 3677*da0073e9SAndroid Build Coastguard Worker # Create the 8 possible combinations of scalar sizes for target / index / source 3678*da0073e9SAndroid Build Coastguard Worker scalars = ((make_arg(size_t), 3679*da0073e9SAndroid Build Coastguard Worker make_idx(size_i, high=1), 3680*da0073e9SAndroid Build Coastguard Worker make_arg(size_s)) 3681*da0073e9SAndroid Build Coastguard Worker for size_t, size_i, size_s in product([(), (1,)], repeat=3)) 3682*da0073e9SAndroid Build Coastguard Worker for (dest, idx, source), accumulate in product(scalars, [True, False]): 3683*da0073e9SAndroid Build Coastguard Worker dest_init = dest.clone() 3684*da0073e9SAndroid Build Coastguard Worker # 
out-place 3685*da0073e9SAndroid Build Coastguard Worker out = torch.put(dest, idx, source, accumulate=accumulate) 3686*da0073e9SAndroid Build Coastguard Worker # in-place 3687*da0073e9SAndroid Build Coastguard Worker dest1 = dest.clone() 3688*da0073e9SAndroid Build Coastguard Worker dest1.put_(idx, source, accumulate=accumulate) 3689*da0073e9SAndroid Build Coastguard Worker for d in [out, dest1]: 3690*da0073e9SAndroid Build Coastguard Worker if accumulate: 3691*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d.item(), (dest_init + source).item()) 3692*da0073e9SAndroid Build Coastguard Worker else: 3693*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d.item(), source.item()) 3694*da0073e9SAndroid Build Coastguard Worker 3695*da0073e9SAndroid Build Coastguard Worker # Empty case 3696*da0073e9SAndroid Build Coastguard Worker dest = make_arg((3, 2)) 3697*da0073e9SAndroid Build Coastguard Worker reference = dest.clone() 3698*da0073e9SAndroid Build Coastguard Worker idx = make_idx((0,), high=1) 3699*da0073e9SAndroid Build Coastguard Worker source = make_arg((0,)) 3700*da0073e9SAndroid Build Coastguard Worker for accumulate in [True, False]: 3701*da0073e9SAndroid Build Coastguard Worker out = torch.put(dest, idx, source, accumulate=accumulate) 3702*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, reference) 3703*da0073e9SAndroid Build Coastguard Worker dest.put_(idx, source, accumulate=accumulate) 3704*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dest, reference) 3705*da0073e9SAndroid Build Coastguard Worker 3706*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the put operator 3707*da0073e9SAndroid Build Coastguard Worker # The bool instance does not work on GPU. 
See 3708*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/54317 3709*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) 3710*da0073e9SAndroid Build Coastguard Worker def test_put_accumulate(self, device, dtype): 3711*da0073e9SAndroid Build Coastguard Worker # Test for parallel adds with accumulate == True 3712*da0073e9SAndroid Build Coastguard Worker low_precision = dtype == torch.half or dtype == torch.bfloat16 3713*da0073e9SAndroid Build Coastguard Worker # Less numbers to avoid overflow with low_precision 3714*da0073e9SAndroid Build Coastguard Worker # Grainsize is 3000 for the for_loop to be parallized on CPU 3715*da0073e9SAndroid Build Coastguard Worker sizes = ((100,)) if low_precision else ((200,), (3002,)) 3716*da0073e9SAndroid Build Coastguard Worker # Bfloat16 has a particularly bad performance here 3717*da0073e9SAndroid Build Coastguard Worker # This operation is nondeterministic on GPU, so we are generous with the rtol 3718*da0073e9SAndroid Build Coastguard Worker rtol, atol = (1e-1, 1e-2) if low_precision else (1e-3, 1e-4) 3719*da0073e9SAndroid Build Coastguard Worker 3720*da0073e9SAndroid Build Coastguard Worker make_arg = partial(make_tensor, low=-2, high=3, device=device, dtype=dtype) 3721*da0073e9SAndroid Build Coastguard Worker # Dump everything into the 0-th position 3722*da0073e9SAndroid Build Coastguard Worker make_idx = partial(torch.zeros, device=device, dtype=torch.int64) 3723*da0073e9SAndroid Build Coastguard Worker args = ((make_idx(size), make_arg(size)) for size in sizes) 3724*da0073e9SAndroid Build Coastguard Worker 3725*da0073e9SAndroid Build Coastguard Worker for idx, source in args: 3726*da0073e9SAndroid Build Coastguard Worker orig = make_arg((1,)) 3727*da0073e9SAndroid Build Coastguard Worker out = orig.put(idx, source, accumulate=True) 3728*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, orig + source.sum(), rtol=rtol, 
atol=atol) 3729*da0073e9SAndroid Build Coastguard Worker 3730*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the take operator 3731*da0073e9SAndroid Build Coastguard Worker @skipIfMps 3732*da0073e9SAndroid Build Coastguard Worker def test_take_empty(self, device): 3733*da0073e9SAndroid Build Coastguard Worker for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]: 3734*da0073e9SAndroid Build Coastguard Worker for indices_shape in [(0,), (0, 1, 2, 0)]: 3735*da0073e9SAndroid Build Coastguard Worker input = torch.empty(input_shape, device=device) 3736*da0073e9SAndroid Build Coastguard Worker indices = torch.empty(indices_shape, dtype=torch.int64, device=device) 3737*da0073e9SAndroid Build Coastguard Worker self.assertEqual(indices, torch.take(input, indices), exact_dtype=False) 3738*da0073e9SAndroid Build Coastguard Worker 3739*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the put operator 3740*da0073e9SAndroid Build Coastguard Worker def test_put_empty(self, device): 3741*da0073e9SAndroid Build Coastguard Worker for dst_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]: 3742*da0073e9SAndroid Build Coastguard Worker for indices_shape in [(0,), (0, 1, 2, 0)]: 3743*da0073e9SAndroid Build Coastguard Worker for accumulate in [False, True]: 3744*da0073e9SAndroid Build Coastguard Worker dst = torch.randn(dst_shape, device=device) 3745*da0073e9SAndroid Build Coastguard Worker indices = torch.empty(indices_shape, dtype=torch.int64, device=device) 3746*da0073e9SAndroid Build Coastguard Worker src = torch.randn(indices_shape, device=device) 3747*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, dst.put_(indices, src, accumulate=accumulate)) 3748*da0073e9SAndroid Build Coastguard Worker 3749*da0073e9SAndroid Build Coastguard Worker # FIXME: port to test_scatter_gather_ops.py 3750*da0073e9SAndroid Build Coastguard Worker def scatter_allow_reduce(self, device, dtype, reduceop): 3751*da0073e9SAndroid Build Coastguard Worker 
device_type = torch.device(device).type 3752*da0073e9SAndroid Build Coastguard Worker return device_type != 'cuda' or (reduceop == 'multiply' and dtype.is_floating_point) 3753*da0073e9SAndroid Build Coastguard Worker 3754*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 3755*da0073e9SAndroid Build Coastguard Worker @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 3756*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 3757*da0073e9SAndroid Build Coastguard Worker def test_scatter_reduce_operations_to_large_input(self, device, dtype): 3758*da0073e9SAndroid Build Coastguard Worker index = torch.tensor([[1], [2]], device=device, dtype=torch.long) 3759*da0073e9SAndroid Build Coastguard Worker test_data = [ 3760*da0073e9SAndroid Build Coastguard Worker (torch.zeros(4, 4, device=device, dtype=dtype), 3761*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 2, device=device, dtype=dtype), 3762*da0073e9SAndroid Build Coastguard Worker torch.tensor([[0, 0, 0, 0], 3763*da0073e9SAndroid Build Coastguard Worker [1, 0, 0, 0], 3764*da0073e9SAndroid Build Coastguard Worker [1, 0, 0, 0], 3765*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0]], 3766*da0073e9SAndroid Build Coastguard Worker device=device, dtype=dtype), "add"), 3767*da0073e9SAndroid Build Coastguard Worker (torch.tensor([2], device=device, dtype=dtype).repeat(4, 4), 3768*da0073e9SAndroid Build Coastguard Worker torch.tensor([6], device=device, dtype=dtype).repeat(2, 2), 3769*da0073e9SAndroid Build Coastguard Worker torch.tensor([[2, 2, 2, 2], 3770*da0073e9SAndroid Build Coastguard Worker [12, 2, 2, 2], 3771*da0073e9SAndroid Build Coastguard Worker [12, 2, 2, 2], 3772*da0073e9SAndroid Build Coastguard Worker [2, 2, 2, 2]], device=device, dtype=dtype), "multiply"), 3773*da0073e9SAndroid Build Coastguard Worker ] 3774*da0073e9SAndroid Build Coastguard Worker 3775*da0073e9SAndroid 
Build Coastguard Worker for input, src, result, operation in test_data: 3776*da0073e9SAndroid Build Coastguard Worker if not self.scatter_allow_reduce(device, dtype, operation): 3777*da0073e9SAndroid Build Coastguard Worker continue 3778*da0073e9SAndroid Build Coastguard Worker input.scatter_(0, index, src, reduce=operation) 3779*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, result) 3780*da0073e9SAndroid Build Coastguard Worker 3781*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 3782*da0073e9SAndroid Build Coastguard Worker @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 3783*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 3784*da0073e9SAndroid Build Coastguard Worker def test_scatter_reduce_scalar(self, device, dtype): 3785*da0073e9SAndroid Build Coastguard Worker index = torch.tensor([[1], [2]], device=device, dtype=torch.long) 3786*da0073e9SAndroid Build Coastguard Worker test_data = [ 3787*da0073e9SAndroid Build Coastguard Worker (torch.zeros(4, 4, device=device, dtype=dtype), 1, 3788*da0073e9SAndroid Build Coastguard Worker torch.tensor([[0, 0, 0, 0], 3789*da0073e9SAndroid Build Coastguard Worker [1, 0, 0, 0], 3790*da0073e9SAndroid Build Coastguard Worker [1, 0, 0, 0], 3791*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0]], 3792*da0073e9SAndroid Build Coastguard Worker device=device, dtype=dtype), "add"), 3793*da0073e9SAndroid Build Coastguard Worker (torch.tensor([2], device=device, dtype=dtype).repeat(4, 4), 2, 3794*da0073e9SAndroid Build Coastguard Worker torch.tensor([[2, 2, 2, 2], 3795*da0073e9SAndroid Build Coastguard Worker [4, 2, 2, 2], 3796*da0073e9SAndroid Build Coastguard Worker [4, 2, 2, 2], 3797*da0073e9SAndroid Build Coastguard Worker [2, 2, 2, 2]], device=device, dtype=dtype), "multiply"), 3798*da0073e9SAndroid Build Coastguard Worker ] 3799*da0073e9SAndroid Build Coastguard Worker 
3800*da0073e9SAndroid Build Coastguard Worker for input, src, result, operation in test_data: 3801*da0073e9SAndroid Build Coastguard Worker if not self.scatter_allow_reduce(device, dtype, operation): 3802*da0073e9SAndroid Build Coastguard Worker continue 3803*da0073e9SAndroid Build Coastguard Worker input.scatter_(0, index, src, reduce=operation) 3804*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, result) 3805*da0073e9SAndroid Build Coastguard Worker 3806*da0073e9SAndroid Build Coastguard Worker # FIXME: port to test_scatter_gather_ops.py 3807*da0073e9SAndroid Build Coastguard Worker # TODO: remove this after scatter_add_ is deprecated. 3808*da0073e9SAndroid Build Coastguard Worker def test_scatter_add_non_unique_index(self, device): 3809*da0073e9SAndroid Build Coastguard Worker height = 2 3810*da0073e9SAndroid Build Coastguard Worker width = 65536 3811*da0073e9SAndroid Build Coastguard Worker input = torch.ones(height, width, device=device) 3812*da0073e9SAndroid Build Coastguard Worker index = torch.zeros(height, width, dtype=torch.long, device=device) 3813*da0073e9SAndroid Build Coastguard Worker src = torch.ones(height, width, device=device) 3814*da0073e9SAndroid Build Coastguard Worker input.scatter_add_(0, index, src) 3815*da0073e9SAndroid Build Coastguard Worker 3816*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, 3817*da0073e9SAndroid Build Coastguard Worker torch.tensor([[3], [1]], device=device, 3818*da0073e9SAndroid Build Coastguard Worker dtype=torch.float32).repeat(1, width)) 3819*da0073e9SAndroid Build Coastguard Worker 3820*da0073e9SAndroid Build Coastguard Worker @dtypes(*floating_and_complex_types()) 3821*da0073e9SAndroid Build Coastguard Worker @dtypesIfCPU(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 3822*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 3823*da0073e9SAndroid Build Coastguard Worker def 
test_scatter_reduce_non_unique_index(self, device, dtype): 3824*da0073e9SAndroid Build Coastguard Worker height = 2 3825*da0073e9SAndroid Build Coastguard Worker width = 2 3826*da0073e9SAndroid Build Coastguard Worker index = torch.zeros(height, width, dtype=torch.long, device=device) 3827*da0073e9SAndroid Build Coastguard Worker test_data = [ 3828*da0073e9SAndroid Build Coastguard Worker (torch.ones(height, width, device=device, dtype=dtype), 3829*da0073e9SAndroid Build Coastguard Worker torch.ones(height, width, device=device, dtype=dtype), 3830*da0073e9SAndroid Build Coastguard Worker torch.tensor([[3], [1]], device=device, dtype=dtype).repeat(1, width), "add"), 3831*da0073e9SAndroid Build Coastguard Worker (torch.tensor([2], device=device, dtype=dtype).repeat(height, width), 3832*da0073e9SAndroid Build Coastguard Worker torch.tensor([2], device=device, dtype=dtype).repeat(height, width), 3833*da0073e9SAndroid Build Coastguard Worker torch.tensor([[8], [2]], device=device, 3834*da0073e9SAndroid Build Coastguard Worker dtype=dtype).repeat(1, width), "multiply"), 3835*da0073e9SAndroid Build Coastguard Worker ] 3836*da0073e9SAndroid Build Coastguard Worker 3837*da0073e9SAndroid Build Coastguard Worker for input, src, result, operation in test_data: 3838*da0073e9SAndroid Build Coastguard Worker if not self.scatter_allow_reduce(device, dtype, operation): 3839*da0073e9SAndroid Build Coastguard Worker continue 3840*da0073e9SAndroid Build Coastguard Worker input.scatter_(0, index, src, reduce=operation) 3841*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, result, msg=f"result: {result} input: {input} method: {str(operation)}") 3842*da0073e9SAndroid Build Coastguard Worker 3843*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 3844*da0073e9SAndroid Build Coastguard Worker @dtypes(*complex_types()) 3845*da0073e9SAndroid Build Coastguard Worker def test_scatter_reduce_multiply_unsupported_dtypes(self, device, dtype): 3846*da0073e9SAndroid Build Coastguard 
Worker height = 2 3847*da0073e9SAndroid Build Coastguard Worker width = 2 3848*da0073e9SAndroid Build Coastguard Worker index = torch.zeros(height, width, dtype=torch.long, device=device) 3849*da0073e9SAndroid Build Coastguard Worker input = torch.ones(height, width, device=device, dtype=dtype) 3850*da0073e9SAndroid Build Coastguard Worker src = torch.ones(height, width, device=device, dtype=dtype) 3851*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 3852*da0073e9SAndroid Build Coastguard Worker input.scatter_(0, index, src, reduce="multiply") 3853*da0073e9SAndroid Build Coastguard Worker 3854*da0073e9SAndroid Build Coastguard Worker # FIXME: port to test_scatter_gather_ops.py 3855*da0073e9SAndroid Build Coastguard Worker def test_scatter_to_large_input(self, device): 3856*da0073e9SAndroid Build Coastguard Worker input = torch.zeros(4, 4, device=device) 3857*da0073e9SAndroid Build Coastguard Worker src = torch.ones(2, 2, device=device) 3858*da0073e9SAndroid Build Coastguard Worker index = torch.tensor([[1], [2]], device=device, dtype=torch.long) 3859*da0073e9SAndroid Build Coastguard Worker input.scatter_(0, index, src) 3860*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, torch.tensor([[0, 0, 0, 0], 3861*da0073e9SAndroid Build Coastguard Worker [1, 0, 0, 0], 3862*da0073e9SAndroid Build Coastguard Worker [1, 0, 0, 0], 3863*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0]], device=device, dtype=torch.float32)) 3864*da0073e9SAndroid Build Coastguard Worker 3865*da0073e9SAndroid Build Coastguard Worker # FIXME: port to test_scatter_gather_ops.py 3866*da0073e9SAndroid Build Coastguard Worker def test_scatter_add_to_large_input(self, device): 3867*da0073e9SAndroid Build Coastguard Worker input = torch.zeros(4, 4, device=device) 3868*da0073e9SAndroid Build Coastguard Worker src = torch.ones(2, 2, device=device) 3869*da0073e9SAndroid Build Coastguard Worker index = torch.tensor([[1], [2]], device=device, 
dtype=torch.long) 3870*da0073e9SAndroid Build Coastguard Worker input.scatter_add_(0, index, src) 3871*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, torch.tensor([[0, 0, 0, 0], 3872*da0073e9SAndroid Build Coastguard Worker [1, 0, 0, 0], 3873*da0073e9SAndroid Build Coastguard Worker [1, 0, 0, 0], 3874*da0073e9SAndroid Build Coastguard Worker [0, 0, 0, 0]], device=device, dtype=torch.float32)) 3875*da0073e9SAndroid Build Coastguard Worker 3876*da0073e9SAndroid Build Coastguard Worker # FIXME: port to test_scatter_gather_ops.py 3877*da0073e9SAndroid Build Coastguard Worker def test_scatter_bool(self, device): 3878*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[True, True, True], [True, True, True]], device=device) 3879*da0073e9SAndroid Build Coastguard Worker res = torch.zeros(3, 3, dtype=torch.bool, device=device) 3880*da0073e9SAndroid Build Coastguard Worker res = res.scatter_(0, torch.tensor([[0, 1, 2], [0, 1, 2]], device=device), x) 3881*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, torch.tensor([[True, False, False], 3882*da0073e9SAndroid Build Coastguard Worker [False, True, False], 3883*da0073e9SAndroid Build Coastguard Worker [False, False, True]], device=device)) 3884*da0073e9SAndroid Build Coastguard Worker 3885*da0073e9SAndroid Build Coastguard Worker # FIXME: port to test_scatter_gather_ops.py 3886*da0073e9SAndroid Build Coastguard Worker def test_scatter_add_bool(self, device): 3887*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([[True, True, True, True, True], [True, True, True, True, True]], device=device) 3888*da0073e9SAndroid Build Coastguard Worker res = torch.zeros(3, 5, dtype=torch.bool, device=device) 3889*da0073e9SAndroid Build Coastguard Worker res = res.scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], device=device), x) 3890*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, torch.tensor([[True, True, True, True, True], 3891*da0073e9SAndroid Build Coastguard 
Worker [False, True, False, True, False], 3892*da0073e9SAndroid Build Coastguard Worker [True, False, True, False, True]], device=device)) 3893*da0073e9SAndroid Build Coastguard Worker 3894*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the masked scatter operator 3895*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 3896*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) 3897*da0073e9SAndroid Build Coastguard Worker def test_masked_scatter(self, device, dtype): 3898*da0073e9SAndroid Build Coastguard Worker dt = dtype 3899*da0073e9SAndroid Build Coastguard Worker num_copy, num_dest = 3, 10 3900*da0073e9SAndroid Build Coastguard Worker dest = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dt, device=device) 3901*da0073e9SAndroid Build Coastguard Worker dest2 = dest.clone() 3902*da0073e9SAndroid Build Coastguard Worker dest_ones = dest.clone() 3903*da0073e9SAndroid Build Coastguard Worker dest_ones_expected = dest.clone() 3904*da0073e9SAndroid Build Coastguard Worker src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt, device=device) 3905*da0073e9SAndroid Build Coastguard Worker src_ones = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=dt, device=device) 3906*da0073e9SAndroid Build Coastguard Worker mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=torch.bool, device=device) 3907*da0073e9SAndroid Build Coastguard Worker 3908*da0073e9SAndroid Build Coastguard Worker dest.masked_scatter_(mask, src) 3909*da0073e9SAndroid Build Coastguard Worker j = 0 3910*da0073e9SAndroid Build Coastguard Worker for i in range(num_dest): 3911*da0073e9SAndroid Build Coastguard Worker if mask[i]: 3912*da0073e9SAndroid Build Coastguard Worker dest2[i] = src[j] 3913*da0073e9SAndroid Build Coastguard Worker dest_ones_expected[i] = src_ones[j] 3914*da0073e9SAndroid Build Coastguard Worker j += 1 3915*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dest, dest2, 
atol=0, rtol=0) 3916*da0073e9SAndroid Build Coastguard Worker 3917*da0073e9SAndroid Build Coastguard Worker dest_ones.masked_scatter_(mask, src_ones) 3918*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dest_ones, dest_ones_expected, atol=0, rtol=0) 3919*da0073e9SAndroid Build Coastguard Worker 3920*da0073e9SAndroid Build Coastguard Worker # Bound checking in CUDA is done inside a kernel 3921*da0073e9SAndroid Build Coastguard Worker # in order to avoid synchronization, but this means 3922*da0073e9SAndroid Build Coastguard Worker # we can not clear the failures. So there is no way 3923*da0073e9SAndroid Build Coastguard Worker # to test it then recover. 3924*da0073e9SAndroid Build Coastguard Worker if self.device_type != 'cuda': 3925*da0073e9SAndroid Build Coastguard Worker # make src smaller. this should fail 3926*da0073e9SAndroid Build Coastguard Worker src = torch.zeros(num_copy - 1, dtype=dt, device=device) 3927*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 3928*da0073e9SAndroid Build Coastguard Worker dest.masked_scatter_(mask, src) 3929*da0073e9SAndroid Build Coastguard Worker 3930*da0073e9SAndroid Build Coastguard Worker # empty tensor 3931*da0073e9SAndroid Build Coastguard Worker dest = torch.empty((5, 0, 5), dtype=dt, device=device) 3932*da0073e9SAndroid Build Coastguard Worker mask = torch.ones_like(dest, dtype=torch.bool, device=device) 3933*da0073e9SAndroid Build Coastguard Worker src = torch.empty((0,), dtype=dt, device=device) 3934*da0073e9SAndroid Build Coastguard Worker dest.masked_scatter_(mask, src) 3935*da0073e9SAndroid Build Coastguard Worker 3936*da0073e9SAndroid Build Coastguard Worker dest = torch.empty((5, 0, 5), dtype=dt, device=device) 3937*da0073e9SAndroid Build Coastguard Worker mask = torch.ones((5, 1, 5), dtype=torch.bool, device=device) 3938*da0073e9SAndroid Build Coastguard Worker src = torch.empty((0,), dtype=dt, device=device) 3939*da0073e9SAndroid Build Coastguard Worker 
dest.masked_scatter_(mask, src) 3940*da0073e9SAndroid Build Coastguard Worker 3941*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the masked scatter operator 3942*da0073e9SAndroid Build Coastguard Worker @skipIfMps 3943*da0073e9SAndroid Build Coastguard Worker def test_masked_scatter_bool_tensor(self, device): 3944*da0073e9SAndroid Build Coastguard Worker src = torch.tensor([True, True, True], device=device) 3945*da0073e9SAndroid Build Coastguard Worker dst = torch.tensor([False, False, False], device=device) 3946*da0073e9SAndroid Build Coastguard Worker mask = torch.tensor([False, True, False], device=device) 3947*da0073e9SAndroid Build Coastguard Worker 3948*da0073e9SAndroid Build Coastguard Worker dst.masked_scatter_(mask, src) 3949*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, torch.tensor([False, True, False], device=device)) 3950*da0073e9SAndroid Build Coastguard Worker 3951*da0073e9SAndroid Build Coastguard Worker mask = torch.tensor([True, False, True], device=device) 3952*da0073e9SAndroid Build Coastguard Worker dst = dst.masked_scatter(mask, src) 3953*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, torch.tensor([True, True, True], device=device)) 3954*da0073e9SAndroid Build Coastguard Worker 3955*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the masked scatter operator 3956*da0073e9SAndroid Build Coastguard Worker # test_scatter_gather_ops or test_masked_ops? 
3957*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 3958*da0073e9SAndroid Build Coastguard Worker @largeTensorTest('30GB') 3959*da0073e9SAndroid Build Coastguard Worker def test_masked_scatter_large_tensor(self, device): 3960*da0073e9SAndroid Build Coastguard Worker t_cpu = torch.empty(2**31 + 1, dtype=torch.bool).random_() 3961*da0073e9SAndroid Build Coastguard Worker t = t_cpu.to(device) 3962*da0073e9SAndroid Build Coastguard Worker result_cpu = t_cpu.masked_scatter(t_cpu, t_cpu) 3963*da0073e9SAndroid Build Coastguard Worker result = t.masked_scatter(t, t) 3964*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, result_cpu) 3965*da0073e9SAndroid Build Coastguard Worker 3966*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the masked select operator 3967*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 3968*da0073e9SAndroid Build Coastguard Worker def test_masked_select(self, device, dtype): 3969*da0073e9SAndroid Build Coastguard Worker if device == 'cpu': 3970*da0073e9SAndroid Build Coastguard Worker warn = 'masked_select received a mask with dtype torch.uint8,' 3971*da0073e9SAndroid Build Coastguard Worker else: 3972*da0073e9SAndroid Build Coastguard Worker warn = 'indexing with dtype torch.uint8 is now deprecated, pl' 3973*da0073e9SAndroid Build Coastguard Worker for maskType in integral_types_and(torch.bool): 3974*da0073e9SAndroid Build Coastguard Worker num_src = 10 3975*da0073e9SAndroid Build Coastguard Worker src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dtype, device=device) 3976*da0073e9SAndroid Build Coastguard Worker mask = torch.randint(2, (num_src,), device=device, dtype=maskType) 3977*da0073e9SAndroid Build Coastguard Worker 3978*da0073e9SAndroid Build Coastguard Worker if maskType is not torch.bool: 3979*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'expected BoolTensor for mask'): 
3980*da0073e9SAndroid Build Coastguard Worker dst = src.masked_select(mask) 3981*da0073e9SAndroid Build Coastguard Worker continue 3982*da0073e9SAndroid Build Coastguard Worker else: 3983*da0073e9SAndroid Build Coastguard Worker dst = src.masked_select(mask) 3984*da0073e9SAndroid Build Coastguard Worker dst2 = [] 3985*da0073e9SAndroid Build Coastguard Worker for i in range(num_src): 3986*da0073e9SAndroid Build Coastguard Worker if mask[i]: 3987*da0073e9SAndroid Build Coastguard Worker dst2 += [src[i]] 3988*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, torch.tensor(dst2), atol=0, rtol=0) 3989*da0073e9SAndroid Build Coastguard Worker 3990*da0073e9SAndroid Build Coastguard Worker dst3 = torch.empty(0, device=device, dtype=dtype) 3991*da0073e9SAndroid Build Coastguard Worker torch.masked_select(src, mask, out=dst3) 3992*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst3, torch.tensor(dst2, dtype=dst3.dtype), atol=0, rtol=0) 3993*da0073e9SAndroid Build Coastguard Worker 3994*da0073e9SAndroid Build Coastguard Worker # Since half on CPU is not supported, need to skip the remaining test cases 3995*da0073e9SAndroid Build Coastguard Worker if dtype == torch.half and torch.device(device).type == 'cpu': 3996*da0073e9SAndroid Build Coastguard Worker return 3997*da0073e9SAndroid Build Coastguard Worker 3998*da0073e9SAndroid Build Coastguard Worker # Ensure that masks are expanded to match tensor properly 3999*da0073e9SAndroid Build Coastguard Worker a = torch.rand(100, 100, device=device).mul(100).to(dtype) 4000*da0073e9SAndroid Build Coastguard Worker mask_first_el_each_row = torch.zeros(100, device=device, dtype=torch.bool) 4001*da0073e9SAndroid Build Coastguard Worker mask_first_el_each_row[0] = True 4002*da0073e9SAndroid Build Coastguard Worker a_masked = a.masked_select(mask_first_el_each_row) 4003*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a_masked, a[:, 0]) 4004*da0073e9SAndroid Build Coastguard Worker 4005*da0073e9SAndroid 
Build Coastguard Worker mask_first_row = torch.zeros(100, 1, device=device, dtype=torch.bool) 4006*da0073e9SAndroid Build Coastguard Worker mask_first_row[0][0] = True 4007*da0073e9SAndroid Build Coastguard Worker a_masked = a.masked_select(mask_first_row) 4008*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a_masked, a[0, :]) 4009*da0073e9SAndroid Build Coastguard Worker 4010*da0073e9SAndroid Build Coastguard Worker # Ensure that tensor is expanded to match mask properly 4011*da0073e9SAndroid Build Coastguard Worker a = torch.rand(100, device=device).mul(100).to(dtype) 4012*da0073e9SAndroid Build Coastguard Worker mask_copy_3_times = torch.tensor([[True], [True], [False], [True]], device=device) 4013*da0073e9SAndroid Build Coastguard Worker a_masked = a.masked_select(mask_copy_3_times) 4014*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a_masked, a.unsqueeze(0).expand(3, 100).flatten()) 4015*da0073e9SAndroid Build Coastguard Worker 4016*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the masked select operator 4017*da0073e9SAndroid Build Coastguard Worker def test_masked_select_discontiguous(self, device): 4018*da0073e9SAndroid Build Coastguard Worker for size in (10, 200): 4019*da0073e9SAndroid Build Coastguard Worker vals = torch.rand(size, size, device=device) 4020*da0073e9SAndroid Build Coastguard Worker mask = torch.full((size, size), False, dtype=torch.bool, device=device) 4021*da0073e9SAndroid Build Coastguard Worker mask[:, ::2] = True 4022*da0073e9SAndroid Build Coastguard Worker vals_list = (vals, vals.t()) 4023*da0073e9SAndroid Build Coastguard Worker mask_list = (mask, mask.t()) 4024*da0073e9SAndroid Build Coastguard Worker out_dc = torch.empty(size * size, device=device)[::2] 4025*da0073e9SAndroid Build Coastguard Worker for v, m in product(vals_list, mask_list): 4026*da0073e9SAndroid Build Coastguard Worker if m.is_contiguous(): 4027*da0073e9SAndroid Build Coastguard Worker expected = v[:, 
::2].clone().reshape((-1, )) 4028*da0073e9SAndroid Build Coastguard Worker else: 4029*da0073e9SAndroid Build Coastguard Worker expected = v[::2].clone().reshape((-1, )) 4030*da0073e9SAndroid Build Coastguard Worker out = torch.masked_select(v, m) 4031*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, expected, atol=0, rtol=0) 4032*da0073e9SAndroid Build Coastguard Worker torch.masked_select(v, m, out=out_dc) 4033*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_dc, expected, atol=0, rtol=0) 4034*da0073e9SAndroid Build Coastguard Worker 4035*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the masked fill operator 4036*da0073e9SAndroid Build Coastguard Worker @dtypes(*product(all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16), (torch.uint8, torch.bool))) 4037*da0073e9SAndroid Build Coastguard Worker def test_masked_fill(self, device, dtypes): 4038*da0073e9SAndroid Build Coastguard Worker dtype = dtypes[0] 4039*da0073e9SAndroid Build Coastguard Worker mask_dtype = dtypes[1] 4040*da0073e9SAndroid Build Coastguard Worker 4041*da0073e9SAndroid Build Coastguard Worker num_dest = 10 4042*da0073e9SAndroid Build Coastguard Worker dst = torch.zeros(num_dest, dtype=dtype) 4043*da0073e9SAndroid Build Coastguard Worker mask = torch.randint(2, (num_dest,), dtype=mask_dtype) 4044*da0073e9SAndroid Build Coastguard Worker val = random.random() 4045*da0073e9SAndroid Build Coastguard Worker dst2 = dst.clone() 4046*da0073e9SAndroid Build Coastguard Worker 4047*da0073e9SAndroid Build Coastguard Worker if mask_dtype is not torch.bool: 4048*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'only supports boolean masks'): 4049*da0073e9SAndroid Build Coastguard Worker dst.masked_fill_(mask, val) 4050*da0073e9SAndroid Build Coastguard Worker return 4051*da0073e9SAndroid Build Coastguard Worker 4052*da0073e9SAndroid Build Coastguard Worker dst.masked_fill_(mask, val) 4053*da0073e9SAndroid Build 
Coastguard Worker for i in range(num_dest): 4054*da0073e9SAndroid Build Coastguard Worker if mask[i]: 4055*da0073e9SAndroid Build Coastguard Worker dst2[i] = val 4056*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, dst2, atol=0, rtol=0) 4057*da0073e9SAndroid Build Coastguard Worker 4058*da0073e9SAndroid Build Coastguard Worker # test non-contiguous case 4059*da0073e9SAndroid Build Coastguard Worker dst = ((torch.randn(num_dest, num_dest, num_dest) * 10).to(dtype)).permute((2, 0, 1)) 4060*da0073e9SAndroid Build Coastguard Worker dst2 = dst.contiguous() 4061*da0073e9SAndroid Build Coastguard Worker if dtype.is_complex: 4062*da0073e9SAndroid Build Coastguard Worker mask = dst.abs() > 0 4063*da0073e9SAndroid Build Coastguard Worker else: 4064*da0073e9SAndroid Build Coastguard Worker mask = dst > 0 4065*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not dst.is_contiguous()) 4066*da0073e9SAndroid Build Coastguard Worker self.assertTrue(dst2.is_contiguous()) 4067*da0073e9SAndroid Build Coastguard Worker dst.masked_fill_(mask.to(mask_dtype), val) 4068*da0073e9SAndroid Build Coastguard Worker dst2.masked_fill_(mask.to(mask_dtype), val) 4069*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, dst2, atol=0, rtol=0) 4070*da0073e9SAndroid Build Coastguard Worker 4071*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the masked fill operator 4072*da0073e9SAndroid Build Coastguard Worker def test_masked_fill_bool_tensor(self, device): 4073*da0073e9SAndroid Build Coastguard Worker dst = torch.tensor([True, False, True], device=device) 4074*da0073e9SAndroid Build Coastguard Worker mask = torch.tensor([False, True, False], device=device) 4075*da0073e9SAndroid Build Coastguard Worker 4076*da0073e9SAndroid Build Coastguard Worker dst.masked_fill_(mask, True) 4077*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, torch.tensor([True, True, True], device=device)) 4078*da0073e9SAndroid Build Coastguard Worker 
4079*da0073e9SAndroid Build Coastguard Worker dst = dst.masked_fill(mask, False) 4080*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dst, torch.tensor([True, False, True], device=device)) 4081*da0073e9SAndroid Build Coastguard Worker 4082*da0073e9SAndroid Build Coastguard Worker def test_tensor_shape_empty(self, device): 4083*da0073e9SAndroid Build Coastguard Worker x = torch.randn((0, 1, 3, 0), device=device) 4084*da0073e9SAndroid Build Coastguard Worker # flatten 4085*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0,), torch.flatten(x, 0, 3).shape) 4086*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 0), torch.flatten(x, 0, 2).shape) 4087*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 3, 0), torch.flatten(x, 1, 2).shape) 4088*da0073e9SAndroid Build Coastguard Worker 4089*da0073e9SAndroid Build Coastguard Worker # squeeze, unsqueeze 4090*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 1, 1, 3, 0), torch.unsqueeze(x, 1).shape) 4091*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 3, 0), torch.squeeze(x, 1).shape) 4092*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 3, 0), torch.squeeze(x).shape) 4093*da0073e9SAndroid Build Coastguard Worker 4094*da0073e9SAndroid Build Coastguard Worker # transpose, t 4095*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 0, 3, 1), torch.transpose(x, 1, 3).shape) 4096*da0073e9SAndroid Build Coastguard Worker y = torch.randn((5, 0), device=device) 4097*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 5), y.t().shape) 4098*da0073e9SAndroid Build Coastguard Worker 4099*da0073e9SAndroid Build Coastguard Worker # select 4100*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 1, 0), torch.select(x, 2, 2).shape) 4101*da0073e9SAndroid Build Coastguard Worker 4102*da0073e9SAndroid Build Coastguard Worker # repeat, permute 4103*da0073e9SAndroid Build Coastguard Worker self.assertEqual((9, 0, 5, 6, 0), x.repeat(9, 7, 
5, 2, 3).shape) 4104*da0073e9SAndroid Build Coastguard Worker self.assertEqual((3, 0, 0, 1), x.permute(2, 3, 0, 1).shape) 4105*da0073e9SAndroid Build Coastguard Worker 4106*da0073e9SAndroid Build Coastguard Worker # diagonal, diagflat 4107*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device)).shape) 4108*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device)).shape) 4109*da0073e9SAndroid Build Coastguard Worker # off the end offsets are valid 4110*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device), offset=1).shape) 4111*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device), offset=1).shape) 4112*da0073e9SAndroid Build Coastguard Worker # check non-zero sized offsets off the end 4113*da0073e9SAndroid Build Coastguard Worker self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=45252).shape) 4114*da0073e9SAndroid Build Coastguard Worker self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=-45252).shape) 4115*da0073e9SAndroid Build Coastguard Worker 4116*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 0), torch.diagflat(torch.tensor([], device=device)).shape) 4117*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros(1, 1), torch.diagflat(torch.tensor([], device=device), offset=1)) 4118*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 0), torch.diagflat(torch.tensor([[]], device=device)).shape) 4119*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros(1, 1), torch.diagflat(torch.tensor([[]], device=device), offset=1)) 4120*da0073e9SAndroid Build Coastguard Worker 4121*da0073e9SAndroid Build Coastguard Worker # stack, split, chunk 4122*da0073e9SAndroid Build Coastguard Worker self.assertEqual((4, 
0, 1, 3, 0), torch.stack((x, x, x, x)).shape) 4123*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(0, 1, 3, 0)], 4124*da0073e9SAndroid Build Coastguard Worker [z.shape for z in torch.chunk(x, 1, dim=0)]) 4125*da0073e9SAndroid Build Coastguard Worker 4126*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(0, 1, 3, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=0)]) 4127*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(0, 1, 1, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=2)]) 4128*da0073e9SAndroid Build Coastguard Worker 4129*da0073e9SAndroid Build Coastguard Worker # NOTE: split_with_sizes behaves differently than NumPy in that it 4130*da0073e9SAndroid Build Coastguard Worker # takes sizes rather than offsets 4131*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(0, 1, 0, 0), (0, 1, 1, 0), (0, 1, 2, 0)], 4132*da0073e9SAndroid Build Coastguard Worker [z.shape for z in torch.split(x, (0, 1, 2), dim=2)]) 4133*da0073e9SAndroid Build Coastguard Worker 4134*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.split(x, 0, dim=1)) 4135*da0073e9SAndroid Build Coastguard Worker # This is strange because the split size is larger than the dim size, but consistent with 4136*da0073e9SAndroid Build Coastguard Worker # how split handles that case generally (when no 0s are involved). 4137*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 1, dim=0)]) 4138*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 0, dim=0)]) 4139*da0073e9SAndroid Build Coastguard Worker 4140*da0073e9SAndroid Build Coastguard Worker # functions that operate over a dimension but don't reduce. 
4141*da0073e9SAndroid Build Coastguard Worker def test_dim_function_empty(self, device): 4142*da0073e9SAndroid Build Coastguard Worker shape = (0, 1, 2, 0) 4143*da0073e9SAndroid Build Coastguard Worker x = torch.randn(shape, device=device) 4144*da0073e9SAndroid Build Coastguard Worker 4145*da0073e9SAndroid Build Coastguard Worker # size stride 4146*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, x.size(3)) 4147*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, x.size(2)) 4148*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, x.stride(0)) 4149*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, x.stride(2)) 4150*da0073e9SAndroid Build Coastguard Worker 4151*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.nn.functional.glu(x, 0)) 4152*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 1, 1, 0), torch.nn.functional.glu(x, 2).shape) 4153*da0073e9SAndroid Build Coastguard Worker 4154*da0073e9SAndroid Build Coastguard Worker # softmax, logsoftmax 4155*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.nn.functional.softmax(x, 0)) 4156*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.nn.functional.softmax(x, 2)) 4157*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.nn.functional.softmax(x, 3)) 4158*da0073e9SAndroid Build Coastguard Worker 4159*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.nn.functional.log_softmax(x, 0)) 4160*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.nn.functional.log_softmax(x, 2)) 4161*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.nn.functional.log_softmax(x, 3)) 4162*da0073e9SAndroid Build Coastguard Worker 4163*da0073e9SAndroid Build Coastguard Worker # cumsum, cumprod, cummax, cummin 4164*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, torch.cumsum(x, 0).shape) 4165*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, 
torch.cumsum(x, 2).shape) 4166*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, torch.cumprod(x, 0).shape) 4167*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, torch.cumprod(x, 2).shape) 4168*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, torch.cummax(x, 0)[0].shape) 4169*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, torch.cummax(x, 2)[0].shape) 4170*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, torch.cummin(x, 0)[0].shape) 4171*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, torch.cummin(x, 2)[0].shape) 4172*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, torch.logcumsumexp(x, 0).shape) 4173*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, torch.logcumsumexp(x, 2).shape) 4174*da0073e9SAndroid Build Coastguard Worker 4175*da0073e9SAndroid Build Coastguard Worker # flip 4176*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x.flip(0)) 4177*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x.flip(2)) 4178*da0073e9SAndroid Build Coastguard Worker 4179*da0073e9SAndroid Build Coastguard Worker # roll 4180*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x.roll(0, 1).roll(0, -1)) 4181*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x.roll(1, x.size(1))) 4182*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x.roll(1)) 4183*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x.roll((1, 1), (3, 1))) 4184*da0073e9SAndroid Build Coastguard Worker 4185*da0073e9SAndroid Build Coastguard Worker # unbind 4186*da0073e9SAndroid Build Coastguard Worker self.assertEqual((), x.unbind(0)) 4187*da0073e9SAndroid Build Coastguard Worker self.assertEqual((torch.empty((0, 1, 0), device=device), torch.empty((0, 1, 0), device=device)), 4188*da0073e9SAndroid Build Coastguard Worker x.unbind(2)) 4189*da0073e9SAndroid Build Coastguard Worker 4190*da0073e9SAndroid Build Coastguard Worker # 
cross 4191*da0073e9SAndroid Build Coastguard Worker y = torch.randn((0, 1, 3, 0), device=device) 4192*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.shape, torch.cross(y, y).shape) 4193*da0073e9SAndroid Build Coastguard Worker 4194*da0073e9SAndroid Build Coastguard Worker # renorm 4195*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, torch.renorm(x, 1, 0, 5).shape) 4196*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, torch.renorm(x, 1, 2, 5).shape) 4197*da0073e9SAndroid Build Coastguard Worker 4198*da0073e9SAndroid Build Coastguard Worker # sort 4199*da0073e9SAndroid Build Coastguard Worker self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=0)]) 4200*da0073e9SAndroid Build Coastguard Worker self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=2)]) 4201*da0073e9SAndroid Build Coastguard Worker 4202*da0073e9SAndroid Build Coastguard Worker # topk 4203*da0073e9SAndroid Build Coastguard Worker self.assertEqual([shape, shape], [z.shape for z in torch.topk(x, 0, dim=0)]) 4204*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(0, 1, 1, 0), (0, 1, 1, 0)], [z.shape for z in torch.topk(x, 1, dim=2)]) 4205*da0073e9SAndroid Build Coastguard Worker 4206*da0073e9SAndroid Build Coastguard Worker y = torch.randn((2, 3, 4), device=device) 4207*da0073e9SAndroid Build Coastguard Worker self.assertEqual([(2, 3, 0), (2, 3, 0)], [z.shape for z in torch.topk(y, 0)]) 4208*da0073e9SAndroid Build Coastguard Worker 4209*da0073e9SAndroid Build Coastguard Worker # gather 4210*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, torch.gather(x, 0, torch.empty(shape, dtype=torch.int64, device=device)).shape) 4211*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, torch.gather(x, 2, torch.empty(shape, dtype=torch.int64, device=device)).shape) 4212*da0073e9SAndroid Build Coastguard Worker larger_shape = torch.empty((0, 1, 3, 0), dtype=torch.int64, device=device) 4213*da0073e9SAndroid 
Build Coastguard Worker self.assertEqual(larger_shape.shape, torch.gather(x, 2, larger_shape).shape) 4214*da0073e9SAndroid Build Coastguard Worker smaller_shape = torch.empty((0, 1, 0, 0), dtype=torch.int64, device=device) 4215*da0073e9SAndroid Build Coastguard Worker self.assertEqual(smaller_shape.shape, torch.gather(x, 2, smaller_shape).shape) 4216*da0073e9SAndroid Build Coastguard Worker y = torch.randn((2, 3, 4), device=device) 4217*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 3, 4), 4218*da0073e9SAndroid Build Coastguard Worker torch.gather(y, 0, torch.empty((0, 3, 4), dtype=torch.int64, device=device)).shape) 4219*da0073e9SAndroid Build Coastguard Worker 4220*da0073e9SAndroid Build Coastguard Worker # scatter, scatter_add 4221*da0073e9SAndroid Build Coastguard Worker for dim in [0, 2]: 4222*da0073e9SAndroid Build Coastguard Worker y = torch.randn(shape, device=device) 4223*da0073e9SAndroid Build Coastguard Worker y_src = torch.randn(shape, device=device) 4224*da0073e9SAndroid Build Coastguard Worker ind = torch.empty(shape, dtype=torch.int64, device=device) 4225*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, y.scatter_(dim, ind, y_src).shape) 4226*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shape, y.scatter_add_(dim, ind, y_src).shape) 4227*da0073e9SAndroid Build Coastguard Worker 4228*da0073e9SAndroid Build Coastguard Worker z = torch.randn((2, 3, 4), device=device) 4229*da0073e9SAndroid Build Coastguard Worker z_src = torch.randn((2, 3, 4), device=device) 4230*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, z.scatter_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src)) 4231*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, z.scatter_add_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src)) 4232*da0073e9SAndroid Build Coastguard Worker 4233*da0073e9SAndroid Build Coastguard Worker # index_fill, index_copy, index_add 4234*da0073e9SAndroid Build Coastguard 
Worker c = x.clone() 4235*da0073e9SAndroid Build Coastguard Worker c_clone = c.clone() 4236*da0073e9SAndroid Build Coastguard Worker ind_empty = torch.tensor([], dtype=torch.int64, device=device) 4237*da0073e9SAndroid Build Coastguard Worker ind_01 = torch.tensor([0, 1], dtype=torch.int64, device=device) 4238*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1)) 4239*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_fill_(2, ind_empty, -1)) 4240*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_fill_(2, ind_01, -1)) 4241*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device))) 4242*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_copy_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device))) 4243*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_copy_(2, ind_01, torch.empty((0, 1, 2, 0), device=device))) 4244*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device))) 4245*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_add_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device))) 4246*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_add_(2, ind_01, torch.empty((0, 1, 2, 0), device=device))) 4247*da0073e9SAndroid Build Coastguard Worker 4248*da0073e9SAndroid Build Coastguard Worker c = torch.randn((0, 1, 2), device=device) 4249*da0073e9SAndroid Build Coastguard Worker c_clone = c.clone() 4250*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1)) 4251*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device))) 4252*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_add_(0, 
ind_empty, torch.empty((0, 1, 2), device=device))) 4253*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1)) 4254*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device))) 4255*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device))) 4256*da0073e9SAndroid Build Coastguard Worker 4257*da0073e9SAndroid Build Coastguard Worker # index fill/copy/add non-empty 4258*da0073e9SAndroid Build Coastguard Worker z = torch.randn((2, 3, 4), device=device) 4259*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, z.index_fill_(0, ind_empty, -1)) 4260*da0073e9SAndroid Build Coastguard Worker z = torch.randn((2, 3, 4), device=device) 4261*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, z.index_copy_(0, ind_empty, torch.empty((0, 3, 4), device=device))) 4262*da0073e9SAndroid Build Coastguard Worker z = torch.randn((2, 3, 4), device=device) 4263*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, z.index_add_(0, ind_empty, torch.empty((0, 3, 4), device=device))) 4264*da0073e9SAndroid Build Coastguard Worker 4265*da0073e9SAndroid Build Coastguard Worker # index_select 4266*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x.index_select(0, ind_empty)) 4267*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 1, 0, 0), x.index_select(2, ind_empty).shape) 4268*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x.index_select(2, ind_01)) 4269*da0073e9SAndroid Build Coastguard Worker z = torch.randn((2, 3, 4), device=device) # non-empty 4270*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 3, 4), z.index_select(0, ind_empty).shape) 4271*da0073e9SAndroid Build Coastguard Worker c = torch.randn((0, 1, 2), device=device) 4272*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c, c.index_select(0, ind_empty)) 
4273*da0073e9SAndroid Build Coastguard Worker c = torch.randn((0, 1, 2), device=device) 4274*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c, c.index_select(0, ind_empty)) 4275*da0073e9SAndroid Build Coastguard Worker w = torch.randn((0, 3), device=device) 4276*da0073e9SAndroid Build Coastguard Worker self.assertEqual((0, 2), w.index_select(1, ind_01).shape) 4277*da0073e9SAndroid Build Coastguard Worker w = torch.randn((3, 0), device=device) 4278*da0073e9SAndroid Build Coastguard Worker self.assertEqual((2, 0), w.index_select(0, ind_01).shape) 4279*da0073e9SAndroid Build Coastguard Worker ind_01_int32 = torch.tensor([0, 1], dtype=torch.int32, device=device) 4280*da0073e9SAndroid Build Coastguard Worker self.assertEqual((2, 0), w.index_select(0, ind_01_int32).shape) 4281*da0073e9SAndroid Build Coastguard Worker s = torch.randn([], device=device) 4282*da0073e9SAndroid Build Coastguard Worker ind_0 = torch.tensor([0], dtype=torch.int32, device=device) 4283*da0073e9SAndroid Build Coastguard Worker self.assertEqual([], s.index_select(0, ind_0).shape) 4284*da0073e9SAndroid Build Coastguard Worker if device == 'cpu': 4285*da0073e9SAndroid Build Coastguard Worker w = torch.randn((0, 3), device=device) 4286*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "self indexing axis dim should be positive"): 4287*da0073e9SAndroid Build Coastguard Worker torch.index_select(w, 0, ind_01) 4288*da0073e9SAndroid Build Coastguard Worker ind_05 = torch.tensor([0, 5], dtype=torch.int64, device=device) 4289*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "INDICES element is out of DATA bounds"): 4290*da0073e9SAndroid Build Coastguard Worker torch.index_select(w, 1, ind_05) 4291*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"): 4292*da0073e9SAndroid Build Coastguard Worker torch.index_select(s, 0, ind_empty) 4293*da0073e9SAndroid Build 
Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Index to scalar can have only 1 value"): 4294*da0073e9SAndroid Build Coastguard Worker torch.ones([]).index_select(0, torch.Tensor([0, 0]).int()) 4295*da0073e9SAndroid Build Coastguard Worker 4296*da0073e9SAndroid Build Coastguard Worker # FIXME: find a test suite for the pdist operator 4297*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "sandcastle OOM with current tpx gpu/re configuration") 4298*da0073e9SAndroid Build Coastguard Worker @skipIfRocm 4299*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 4300*da0073e9SAndroid Build Coastguard Worker @largeTensorTest('32GB', device='cpu') 4301*da0073e9SAndroid Build Coastguard Worker @largeTensorTest('5GB', device='cuda') 4302*da0073e9SAndroid Build Coastguard Worker def test_pdist_norm_large(self, device): 4303*da0073e9SAndroid Build Coastguard Worker # use dim0>=46342 for forward, see: 4304*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/30583 4305*da0073e9SAndroid Build Coastguard Worker # Compare output using GPU with the CPU implementation 4306*da0073e9SAndroid Build Coastguard Worker x = torch.randn(50000, 1, dtype=torch.float32) # 50k * 4 bytes = 200 KB 4307*da0073e9SAndroid Build Coastguard Worker # Will require 1249975000 float32s 4308*da0073e9SAndroid Build Coastguard Worker expected_cpu = torch.pdist(x, p=2) # ~1250M * 4 bytes = 5 GB on CPU 4309*da0073e9SAndroid Build Coastguard Worker actual_cpu = torch.pdist(x.to(device), p=2).cpu() # 5 GB on GPU + 5GB on CPU 4310*da0073e9SAndroid Build Coastguard Worker # Workaround for large memory overhead of self.assertTrue (see #84944) 4311*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(expected_cpu, actual_cpu)) # ~20GB in allclose 4312*da0073e9SAndroid Build Coastguard Worker 4313*da0073e9SAndroid Build Coastguard Worker # FIXME: move to elementwise ternary test suite 4314*da0073e9SAndroid Build 
Coastguard Worker @onlyNativeDeviceTypes 4315*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(*set(get_all_math_dtypes('cuda'))) 4316*da0073e9SAndroid Build Coastguard Worker @dtypes(*set(get_all_math_dtypes('cpu'))) 4317*da0073e9SAndroid Build Coastguard Worker def test_addcdiv(self, device, dtype): 4318*da0073e9SAndroid Build Coastguard Worker # Returns floating or integral scalar corresponding to dtype 4319*da0073e9SAndroid Build Coastguard Worker def _number(floating, integer, dtype): 4320*da0073e9SAndroid Build Coastguard Worker if dtype in [torch.half, torch.float, torch.double, torch.bfloat16]: 4321*da0073e9SAndroid Build Coastguard Worker return floating 4322*da0073e9SAndroid Build Coastguard Worker elif dtype in [torch.cfloat, torch.cdouble]: 4323*da0073e9SAndroid Build Coastguard Worker return floating * (1 + 1j) 4324*da0073e9SAndroid Build Coastguard Worker else: 4325*da0073e9SAndroid Build Coastguard Worker return integer 4326*da0073e9SAndroid Build Coastguard Worker 4327*da0073e9SAndroid Build Coastguard Worker def non_zero_rand(size, dtype, device): 4328*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point or dtype.is_complex: 4329*da0073e9SAndroid Build Coastguard Worker a = torch.rand(size=size, dtype=dtype, device=device) 4330*da0073e9SAndroid Build Coastguard Worker elif dtype == torch.uint8: 4331*da0073e9SAndroid Build Coastguard Worker a = torch.randint(1, 5, size=size, dtype=dtype, device=device) 4332*da0073e9SAndroid Build Coastguard Worker else: 4333*da0073e9SAndroid Build Coastguard Worker a = torch.randint(-5, 5, size=size, dtype=dtype, device=device) 4334*da0073e9SAndroid Build Coastguard Worker return a + (a == 0).to(dtype) 4335*da0073e9SAndroid Build Coastguard Worker 4336*da0073e9SAndroid Build Coastguard Worker def _test_addcdiv(): 4337*da0073e9SAndroid Build Coastguard Worker a = non_zero_rand((2, 2), dtype=dtype, device=device) 4338*da0073e9SAndroid Build Coastguard Worker b = non_zero_rand((2, 2), dtype=dtype, 
device=device) 4339*da0073e9SAndroid Build Coastguard Worker c = non_zero_rand((2, 2), dtype=dtype, device=device) 4340*da0073e9SAndroid Build Coastguard Worker alpha = _number(0.5, 3, dtype) 4341*da0073e9SAndroid Build Coastguard Worker 4342*da0073e9SAndroid Build Coastguard Worker expected = a + (alpha * b) / c 4343*da0073e9SAndroid Build Coastguard Worker actual = torch.addcdiv(a, b, c, value=alpha) 4344*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, actual) 4345*da0073e9SAndroid Build Coastguard Worker 4346*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsOnceRegex( 4347*da0073e9SAndroid Build Coastguard Worker UserWarning, "This overload of addcdiv is deprecated"): 4348*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, torch.addcdiv(a, alpha, b, c)) 4349*da0073e9SAndroid Build Coastguard Worker 4350*da0073e9SAndroid Build Coastguard Worker if not (dtype.is_floating_point or dtype.is_complex): 4351*da0073e9SAndroid Build Coastguard Worker # Integer division with addcdiv is prohibited 4352*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 4353*da0073e9SAndroid Build Coastguard Worker _test_addcdiv() 4354*da0073e9SAndroid Build Coastguard Worker else: 4355*da0073e9SAndroid Build Coastguard Worker _test_addcdiv() 4356*da0073e9SAndroid Build Coastguard Worker 4357*da0073e9SAndroid Build Coastguard Worker if self.device_type == 'cuda' and dtype == torch.half: 4358*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([60000.0], device=device, dtype=dtype) 4359*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([60000.0], device=device, dtype=dtype) 4360*da0073e9SAndroid Build Coastguard Worker c = torch.tensor([1.0], device=device, dtype=dtype) 4361*da0073e9SAndroid Build Coastguard Worker out = torch.addcmul(a, b, c, value=-2) 4362*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not (out.isnan() or out.isinf())) 4363*da0073e9SAndroid Build Coastguard Worker 
4364*da0073e9SAndroid Build Coastguard Worker def test_nullary_op_mem_overlap(self, device): 4365*da0073e9SAndroid Build Coastguard Worker ops = ( 4366*da0073e9SAndroid Build Coastguard Worker ("random_", ()), 4367*da0073e9SAndroid Build Coastguard Worker ("uniform_", ()), 4368*da0073e9SAndroid Build Coastguard Worker ("cauchy_", ()), 4369*da0073e9SAndroid Build Coastguard Worker ("log_normal_", ()), 4370*da0073e9SAndroid Build Coastguard Worker ("exponential_", ()), 4371*da0073e9SAndroid Build Coastguard Worker ("geometric_", (0.5,)), 4372*da0073e9SAndroid Build Coastguard Worker ("normal_", ()), 4373*da0073e9SAndroid Build Coastguard Worker ) 4374*da0073e9SAndroid Build Coastguard Worker 4375*da0073e9SAndroid Build Coastguard Worker x = torch.rand((1, 3)).expand((3, 3)) 4376*da0073e9SAndroid Build Coastguard Worker for op, args in ops: 4377*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): 4378*da0073e9SAndroid Build Coastguard Worker getattr(x, op)(*args) 4379*da0073e9SAndroid Build Coastguard Worker 4380*da0073e9SAndroid Build Coastguard Worker # FIXME: move to an elementwise ternary test suite and make this an OpInfo test 4381*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/126474 4382*da0073e9SAndroid Build Coastguard Worker @xfailIfTorchDynamo 4383*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/126474") 4384*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 4385*da0073e9SAndroid Build Coastguard Worker def test_ternary_op_mem_overlap(self, device, dtype): 4386*da0073e9SAndroid Build Coastguard Worker if device == "cpu" and TEST_WITH_TORCHINDUCTOR: 4387*da0073e9SAndroid Build Coastguard Worker self.skipTest("Failing on cpu") 4388*da0073e9SAndroid Build Coastguard Worker 4389*da0073e9SAndroid Build Coastguard Worker ops = [ 4390*da0073e9SAndroid Build Coastguard Worker ("addcmul", True, True, 
'cpu'), 4391*da0073e9SAndroid Build Coastguard Worker ("addcmul", True, True, 'cuda'), 4392*da0073e9SAndroid Build Coastguard Worker ("addcdiv", True, True, 'cpu'), 4393*da0073e9SAndroid Build Coastguard Worker ("addcdiv", True, True, 'cuda'), 4394*da0073e9SAndroid Build Coastguard Worker ("lerp", True, True, 'cpu'), 4395*da0073e9SAndroid Build Coastguard Worker ("lerp", True, True, 'cuda') 4396*da0073e9SAndroid Build Coastguard Worker ] 4397*da0073e9SAndroid Build Coastguard Worker 4398*da0073e9SAndroid Build Coastguard Worker for (fn, has_input_output_mem_overlap_check, 4399*da0073e9SAndroid Build Coastguard Worker has_internal_mem_overlap_check, dev) in ops: 4400*da0073e9SAndroid Build Coastguard Worker if dev != device: 4401*da0073e9SAndroid Build Coastguard Worker continue 4402*da0073e9SAndroid Build Coastguard Worker out_op = getattr(torch, fn) 4403*da0073e9SAndroid Build Coastguard Worker inplace_op = getattr(torch.Tensor, fn + '_') 4404*da0073e9SAndroid Build Coastguard Worker self.check_internal_mem_overlap( 4405*da0073e9SAndroid Build Coastguard Worker inplace_op, 3, dtype, device, 4406*da0073e9SAndroid Build Coastguard Worker expected_failure=not has_internal_mem_overlap_check) 4407*da0073e9SAndroid Build Coastguard Worker self.ternary_check_input_output_mem_overlap(out_op, dev, 4408*da0073e9SAndroid Build Coastguard Worker expected_failure=not has_input_output_mem_overlap_check) 4409*da0073e9SAndroid Build Coastguard Worker 4410*da0073e9SAndroid Build Coastguard Worker @expectedFailureMeta # RuntimeError not raised 4411*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.double) 4412*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 4413*da0073e9SAndroid Build Coastguard Worker def test_copy_mem_overlap(self, device, dtype): 4414*da0073e9SAndroid Build Coastguard Worker self.check_internal_mem_overlap( 4415*da0073e9SAndroid Build Coastguard Worker torch.Tensor.copy_, num_inputs=2, dtype=dtype, device=device) 4416*da0073e9SAndroid Build 
Coastguard Worker sz = 9 4417*da0073e9SAndroid Build Coastguard Worker doubles = torch.randn(2 * sz, dtype=dtype, device=device) 4418*da0073e9SAndroid Build Coastguard Worker self.unary_check_input_output_mem_overlap( 4419*da0073e9SAndroid Build Coastguard Worker doubles, sz, lambda input, out: out.copy_(input)) 4420*da0073e9SAndroid Build Coastguard Worker 4421*da0073e9SAndroid Build Coastguard Worker # FIXME: convert to ErrorInputs 4422*da0073e9SAndroid Build Coastguard Worker # (but have to extend ErrorInputs to handle inplace-only errors!) 4423*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 4424*da0073e9SAndroid Build Coastguard Worker def test_index_add_mem_overlap(self, device): 4425*da0073e9SAndroid Build Coastguard Worker x = torch.rand((1,), device=device).expand((6,)) 4426*da0073e9SAndroid Build Coastguard Worker y = torch.rand((6,), device=device) 4427*da0073e9SAndroid Build Coastguard Worker ind = torch.tensor([2, 1, 0], device=device) 4428*da0073e9SAndroid Build Coastguard Worker value = torch.rand((3,), device=device) 4429*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): 4430*da0073e9SAndroid Build Coastguard Worker x.index_add_(0, ind, value) 4431*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): 4432*da0073e9SAndroid Build Coastguard Worker y.index_add_(0, ind, y[:3]) 4433*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): 4434*da0073e9SAndroid Build Coastguard Worker ind.index_add_(0, ind, ind.clone()) 4435*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): 4436*da0073e9SAndroid Build Coastguard Worker ind.index_add_(0, ind.clone(), ind) 4437*da0073e9SAndroid Build Coastguard Worker 4438*da0073e9SAndroid Build Coastguard Worker # FIXME: convert to ErrorInputs 4439*da0073e9SAndroid Build Coastguard 
# (but have to extend ErrorInputs to handle inplace-only errors!)
@onlyNativeDeviceTypes
def test_index_copy_mem_overlap(self, device):
    """Check that ``index_copy_`` rejects overlapping memory.

    Every call below is expected to raise ``RuntimeError`` matching
    'unsupported operation'. The cases cover an expanded ``self``, a
    source that is a slice of ``self``, and ``self`` aliasing either the
    index or the source argument.
    """
    expanded = torch.rand((1,), device=device).expand((6,))
    dense = torch.rand((6,), device=device)
    idx = torch.tensor([2, 1, 0], device=device)
    src = torch.rand((3,), device=device)

    overlapping_calls = (
        lambda: expanded.index_copy_(0, idx, src),        # self is expanded
        lambda: dense.index_copy_(0, idx, dense[:3]),     # source slices self
        lambda: idx.index_copy_(0, idx, idx.clone()),     # self is the index
        lambda: idx.index_copy_(0, idx.clone(), idx),     # self is the source
    )
    for call in overlapping_calls:
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            call()

# FIXME: convert to ErrorInputs
# (but have to extend ErrorInputs to handle inplace-only errors!)
@expectedFailureMeta  # Warning not triggered
@onlyNativeDeviceTypes
def test_index_fill_mem_overlap(self, device):
    """Check ``index_fill_`` behavior on overlapping memory.

    Filling an expanded ``self`` is expected to emit a ``UserWarning``
    mentioning 'index_fill_ on expanded tensors'. Using ``self`` as its
    own index is expected to raise ``RuntimeError`` ('unsupported
    operation').
    """
    # expand() of a size-1 tensor produces a view whose elements all share
    # a single storage location.
    x = torch.rand((1,), device=device).expand((6,))
    ind = torch.tensor([2, 1, 0], device=device)

    with self.assertWarnsRegex(UserWarning, "index_fill_ on expanded tensors"):
        x.index_fill_(0, ind, 1.0)
    # `ind` is both the tensor being written and the index being read.
    with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
        ind.index_fill_(0, ind, 0)

# FIXME: convert to ErrorInputs
@expectedFailureMeta  # RuntimeError not raised
@onlyNativeDeviceTypes
def test_shift_mem_overlap(self, device):
    """In-place shifts between partially overlapping slices must raise."""
    x = torch.rand(3, device=device)
    # x[:-1] and x[1:] are views of the same storage offset by one element.
    with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
        x[:-1] <<= x[1:]
    with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
        x[:-1] >>= x[1:]

# FIXME: convert to ErrorInputs
# (but have to extend ErrorInputs to handle inplace-only errors)
@expectedFailureMeta  # RuntimeError not raised
@onlyNativeDeviceTypes
def test_bernoulli_mem_overlap(self, device):
    """Check that ``bernoulli_`` on an expanded ``self`` raises 'unsupported operation'.

    Three forms of ``p`` are exercised: no ``p`` argument, a float ``p``,
    and a tensor ``p``.
    """
    expanded = torch.rand((1,), device=device).expand((6,))

    # No p argument, then a float p.
    for draw in (lambda: expanded.bernoulli_(), lambda: expanded.bernoulli_(p=0.1)):
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            draw()

    # Tensor p, created only after the scalar cases (same order as before).
    probs = torch.rand(6, device=device)
    with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
        expanded.bernoulli_(p=probs)

# FIXME: convert to ErrorInputs
# (but have to extend ErrorInputs to handle inplace-only errors!)
@expectedFailureMeta  # RuntimeError not raised
@onlyNativeDeviceTypes
def test_put_mem_overlap(self, device):
    """Check that ``put_`` rejects every overlapping-memory configuration below.

    Each call is expected to raise ``RuntimeError`` matching 'unsupported
    operation'. The cases cover an expanded ``self``, and ``self`` sharing
    memory with the index or the source argument.
    """
    expanded = torch.rand((1,), device=device).expand((6,))
    dense = torch.rand((6,), device=device)
    idx = torch.tensor([2, 1, 0], device=device)
    src = torch.rand((3,), device=device)

    overlapping_calls = (
        lambda: expanded.put_(idx, src),           # self is expanded
        lambda: dense.put_(idx[0], dense[0]),      # source is a view of self
        lambda: idx.put_(idx, idx),                # self is index and source
        lambda: dense.put_(idx, dense[:3]),        # source slices self
        lambda: idx.put_(idx, idx.clone()),        # self is the index
        lambda: idx.put_(idx.clone(), idx),        # self is the source
    )
    for call in overlapping_calls:
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            call()

# FIXME: convert to ErrorInputs
# (but have to extend ErrorInputs to handle inplace-only errors!)
@expectedFailureMeta  # UserWarning not triggered
@onlyNativeDeviceTypes
def test_index_put_mem_overlap(self, device):
    """Check ``index_put_`` behavior on overlapping memory.

    Writing into an expanded ``self`` is expected to warn ('expanded
    tensors'). Every case where ``self`` shares memory with an index or
    with the values argument is expected to raise ``RuntimeError``
    ('unsupported operation').
    """
    expanded = torch.rand((1,), device=device).expand((6,))
    dense = torch.rand((6,), device=device)
    idx = torch.tensor([2, 1, 0], device=device)
    vals = torch.rand((3,), device=device)

    # The expanded case is expected to warn rather than raise.
    with self.assertWarnsRegex(UserWarning, 'expanded tensors'):
        expanded.index_put_((idx,), vals)

    overlapping_calls = (
        lambda: dense.index_put_((idx,), dense[0]),      # values view self
        lambda: idx.index_put_((idx,), idx),             # self is index and values
        lambda: dense.index_put_((idx,), dense[:3]),     # values slice self
        lambda: idx.index_put_((idx,), idx.clone()),     # self is the index
        lambda: idx.index_put_((idx.clone(),), idx),     # self is the values
    )
    for call in overlapping_calls:
        with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
            call()

# FIXME: convert to ErrorInputs
# (but have to extend ErrorInputs to handle inplace-only errors!)
@expectedFailureMeta  # UserWarning not triggered
@onlyNativeDeviceTypes
def test_masked_fill_mem_overlap(self, device):
    """Check ``masked_fill_`` behavior on overlapping memory.

    Filling an expanded ``self`` is expected to warn ('expanded tensors')
    for both a Python scalar and a 0-dim tensor fill value. A mask that
    partially overlaps ``self`` is expected to raise ``RuntimeError``
    ('unsupported operation').
    """
    expanded = torch.rand((1,), device=device).expand((6,))
    mask = torch.tensor([True, False, True, True, False, False], device=device)

    for fill in (0., torch.tensor(0., device=device)):
        with self.assertWarnsRegex(UserWarning, 'expanded tensors'):
            expanded.masked_fill_(mask, fill)

    # mask[1:] and mask[:-1] are views of the same storage offset by one.
    with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
        mask[1:].masked_fill_(mask[:-1], False)

# FIXME: convert to ErrorInputs
# (but have to extend ErrorInputs to handle inplace-only errors!)
@expectedFailureMeta  # RuntimeError not raised
@onlyNativeDeviceTypes
def test_masked_scatter_mem_overlap(self, device):
    """Check that ``masked_scatter_`` into an expanded ``self`` raises 'unsupported operation'."""
    target = torch.rand((1,), device=device).expand((6,))
    source = torch.rand((3,), device=device)
    selector = torch.tensor([True, False, True, True, False, False], device=device)

    with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
        target.masked_scatter_(selector, source)

# FIXME: convert to ErrorInputs
# (but have to extend ErrorInputs to handle inplace-only errors!)
4570*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 4571*da0073e9SAndroid Build Coastguard Worker def test_scatter_mem_overlap(self, device): 4572*da0073e9SAndroid Build Coastguard Worker x = torch.rand((1,), device=device).expand((6,)) 4573*da0073e9SAndroid Build Coastguard Worker src = torch.rand((3,), device=device) 4574*da0073e9SAndroid Build Coastguard Worker ind = torch.tensor([2, 1, 0], device=device, dtype=torch.int64) 4575*da0073e9SAndroid Build Coastguard Worker 4576*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): 4577*da0073e9SAndroid Build Coastguard Worker x.scatter_(0, ind, src) 4578*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): 4579*da0073e9SAndroid Build Coastguard Worker src.scatter_(0, ind, src) 4580*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): 4581*da0073e9SAndroid Build Coastguard Worker ind.scatter_(0, ind, ind.clone()) 4582*da0073e9SAndroid Build Coastguard Worker 4583*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test distributions 4584*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 4585*da0073e9SAndroid Build Coastguard Worker def test_multinomial_device_constrain(self, device): 4586*da0073e9SAndroid Build Coastguard Worker x = torch.empty(3, device="cpu") 4587*da0073e9SAndroid Build Coastguard Worker y = torch.empty(3, device=device) 4588*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 4589*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Expected all tensors to be on the same device", 4590*da0073e9SAndroid Build Coastguard Worker lambda: torch.multinomial(x, 2, out=y)) 4591*da0073e9SAndroid Build Coastguard Worker 4592*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test distributions 4593*da0073e9SAndroid Build Coastguard Worker @deviceCountAtLeast(2) 4594*da0073e9SAndroid Build 
Coastguard Worker @onlyCUDA 4595*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("FIXME: error not thrown") 4596*da0073e9SAndroid Build Coastguard Worker def test_multinomial_gpu_device_constrain(self, devices): 4597*da0073e9SAndroid Build Coastguard Worker x = torch.empty(3, device=devices[0]) 4598*da0073e9SAndroid Build Coastguard Worker y = torch.empty(3, device=devices[1], dtype=torch.long) 4599*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 4600*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Expected all tensors to be on the same device", 4601*da0073e9SAndroid Build Coastguard Worker lambda: torch.multinomial(x, 2, out=y)) 4602*da0073e9SAndroid Build Coastguard Worker 4603*da0073e9SAndroid Build Coastguard Worker # FIXME: convert this to an automated OpInfo test 4604*da0073e9SAndroid Build Coastguard Worker @deviceCountAtLeast(2) 4605*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 4606*da0073e9SAndroid Build Coastguard Worker def test_device_guard(self, devices): 4607*da0073e9SAndroid Build Coastguard Worker # verify that all operators with `device_guard: False` behave properly with multiple devices. 4608*da0073e9SAndroid Build Coastguard Worker # TODO: if we had operator introspection we could figure out this set of operators automatically... 
4609*da0073e9SAndroid Build Coastguard Worker x = torch.randn((1, 2, 3), device=devices[1]) 4610*da0073e9SAndroid Build Coastguard Worker y = torch.zeros((1, 3, 2), device=devices[1]) 4611*da0073e9SAndroid Build Coastguard Worker scalar = torch.tensor(5, device=devices[1]) 4612*da0073e9SAndroid Build Coastguard Worker 4613*da0073e9SAndroid Build Coastguard Worker # property ops 4614*da0073e9SAndroid Build Coastguard Worker torch.cudnn_is_acceptable(x) 4615*da0073e9SAndroid Build Coastguard Worker x.is_distributed() 4616*da0073e9SAndroid Build Coastguard Worker x.is_floating_point() 4617*da0073e9SAndroid Build Coastguard Worker x.is_complex() 4618*da0073e9SAndroid Build Coastguard Worker x.is_same_size(y) 4619*da0073e9SAndroid Build Coastguard Worker x.is_signed() 4620*da0073e9SAndroid Build Coastguard Worker x.size(0) 4621*da0073e9SAndroid Build Coastguard Worker x.stride(0) 4622*da0073e9SAndroid Build Coastguard Worker x.numel() 4623*da0073e9SAndroid Build Coastguard Worker x.is_set_to(y) 4624*da0073e9SAndroid Build Coastguard Worker x.data_ptr() 4625*da0073e9SAndroid Build Coastguard Worker scalar.is_nonzero() 4626*da0073e9SAndroid Build Coastguard Worker 4627*da0073e9SAndroid Build Coastguard Worker # sparse property ops 4628*da0073e9SAndroid Build Coastguard Worker y[0][1] = 5 4629*da0073e9SAndroid Build Coastguard Worker y_sparse = y.to_sparse() 4630*da0073e9SAndroid Build Coastguard Worker y_sparse.sparse_dim() 4631*da0073e9SAndroid Build Coastguard Worker y_sparse._dimI() 4632*da0073e9SAndroid Build Coastguard Worker y_sparse.dense_dim() 4633*da0073e9SAndroid Build Coastguard Worker y_sparse._dimV() 4634*da0073e9SAndroid Build Coastguard Worker y_sparse._nnz() 4635*da0073e9SAndroid Build Coastguard Worker y_sparse.is_coalesced() 4636*da0073e9SAndroid Build Coastguard Worker y_sparse._indices() 4637*da0073e9SAndroid Build Coastguard Worker y_sparse._values() 4638*da0073e9SAndroid Build Coastguard Worker y_sparse.indices() 4639*da0073e9SAndroid Build 
Coastguard Worker y_sparse.values() 4640*da0073e9SAndroid Build Coastguard Worker 4641*da0073e9SAndroid Build Coastguard Worker # in-place ops 4642*da0073e9SAndroid Build Coastguard Worker def inplace(): 4643*da0073e9SAndroid Build Coastguard Worker return torch.randn((1, 2, 3), device=devices[1]) 4644*da0073e9SAndroid Build Coastguard Worker inplace().as_strided_(y.size(), y.stride()) 4645*da0073e9SAndroid Build Coastguard Worker inplace().resize_(y.size()) 4646*da0073e9SAndroid Build Coastguard Worker inplace().squeeze_() 4647*da0073e9SAndroid Build Coastguard Worker inplace().squeeze_(0) 4648*da0073e9SAndroid Build Coastguard Worker inplace().unsqueeze_(2) 4649*da0073e9SAndroid Build Coastguard Worker inplace().transpose_(1, 2) 4650*da0073e9SAndroid Build Coastguard Worker inplace().squeeze_().t_() 4651*da0073e9SAndroid Build Coastguard Worker inplace().set_(x.storage()) 4652*da0073e9SAndroid Build Coastguard Worker inplace().set_(x.storage(), x.storage_offset(), x.size(), x.stride()) 4653*da0073e9SAndroid Build Coastguard Worker inplace().set_(x) 4654*da0073e9SAndroid Build Coastguard Worker inplace().set_() 4655*da0073e9SAndroid Build Coastguard Worker y_sparse._coalesced_(True) 4656*da0073e9SAndroid Build Coastguard Worker 4657*da0073e9SAndroid Build Coastguard Worker # shape modification 4658*da0073e9SAndroid Build Coastguard Worker x.as_strided(y.size(), y.stride()) 4659*da0073e9SAndroid Build Coastguard Worker x.expand((5, 2, 3)) 4660*da0073e9SAndroid Build Coastguard Worker x.expand_as(x) 4661*da0073e9SAndroid Build Coastguard Worker x.sum_to_size((1,)) 4662*da0073e9SAndroid Build Coastguard Worker torch.broadcast_tensors(x , x) 4663*da0073e9SAndroid Build Coastguard Worker x.reshape((1, 3, 2)) 4664*da0073e9SAndroid Build Coastguard Worker x.reshape_as(y) 4665*da0073e9SAndroid Build Coastguard Worker x.squeeze() 4666*da0073e9SAndroid Build Coastguard Worker x.squeeze(0) 4667*da0073e9SAndroid Build Coastguard Worker x.squeeze().t() 4668*da0073e9SAndroid 
Build Coastguard Worker x.transpose(1, 2) 4669*da0073e9SAndroid Build Coastguard Worker x.unsqueeze(2) 4670*da0073e9SAndroid Build Coastguard Worker x.view((1, 3, 2)) 4671*da0073e9SAndroid Build Coastguard Worker x.view_as(y) 4672*da0073e9SAndroid Build Coastguard Worker 4673*da0073e9SAndroid Build Coastguard Worker # chunk, split, etc. 4674*da0073e9SAndroid Build Coastguard Worker x.chunk(2, dim=1) 4675*da0073e9SAndroid Build Coastguard Worker x.split(1, dim=2) 4676*da0073e9SAndroid Build Coastguard Worker x.split_with_sizes([1, 2], dim=2) 4677*da0073e9SAndroid Build Coastguard Worker x.unfold(dimension=2, size=1, step=1) 4678*da0073e9SAndroid Build Coastguard Worker 4679*da0073e9SAndroid Build Coastguard Worker x.narrow(1, 1, 1) 4680*da0073e9SAndroid Build Coastguard Worker x.select(1, 1) 4681*da0073e9SAndroid Build Coastguard Worker torch.isnan(x) 4682*da0073e9SAndroid Build Coastguard Worker 4683*da0073e9SAndroid Build Coastguard Worker torch.empty((1, 3, 2), out=y) 4684*da0073e9SAndroid Build Coastguard Worker torch.empty_like(x) 4685*da0073e9SAndroid Build Coastguard Worker torch.empty_like(x, dtype=torch.int64) 4686*da0073e9SAndroid Build Coastguard Worker 4687*da0073e9SAndroid Build Coastguard Worker # to 4688*da0073e9SAndroid Build Coastguard Worker x.to(x) 4689*da0073e9SAndroid Build Coastguard Worker x.to(y) 4690*da0073e9SAndroid Build Coastguard Worker x.to(x, copy=True) 4691*da0073e9SAndroid Build Coastguard Worker 4692*da0073e9SAndroid Build Coastguard Worker def test_is_signed(self, device): 4693*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.IntTensor(5).to(device).is_signed(), True) 4694*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ByteTensor(5).to(device).is_signed(), False) 4695*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.CharTensor(5).to(device).is_signed(), True) 4696*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.FloatTensor(5).to(device).is_signed(), True) 
4697*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.HalfTensor(10).to(device).is_signed(), True) 4698*da0073e9SAndroid Build Coastguard Worker 4699*da0073e9SAndroid Build Coastguard Worker def test_tensor_type(self): 4700*da0073e9SAndroid Build Coastguard Worker for t in torch._tensor_classes: 4701*da0073e9SAndroid Build Coastguard Worker if 'cuda' in t.__module__: 4702*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.is_cuda, True) 4703*da0073e9SAndroid Build Coastguard Worker else: 4704*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.is_cuda, False) 4705*da0073e9SAndroid Build Coastguard Worker if 'xpu' in t.__module__: 4706*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.is_xpu, True) 4707*da0073e9SAndroid Build Coastguard Worker else: 4708*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.is_xpu, False) 4709*da0073e9SAndroid Build Coastguard Worker 4710*da0073e9SAndroid Build Coastguard Worker # Note - reports a leak of 512 bytes on CUDA device 1 4711*da0073e9SAndroid Build Coastguard Worker @deviceCountAtLeast(2) 4712*da0073e9SAndroid Build Coastguard Worker @skipCUDAMemoryLeakCheckIf(True) 4713*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 4714*da0073e9SAndroid Build Coastguard Worker def test_tensor_set_errors_multigpu(self, devices): 4715*da0073e9SAndroid Build Coastguard Worker f_cuda0 = torch.randn((2, 3), dtype=torch.float32, device=devices[0]) 4716*da0073e9SAndroid Build Coastguard Worker f_cuda1 = torch.randn((2, 3), dtype=torch.float32, device=devices[1]) 4717*da0073e9SAndroid Build Coastguard Worker 4718*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1.storage())) 4719*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, 4720*da0073e9SAndroid Build Coastguard Worker lambda: f_cuda0.set_(f_cuda1.storage(), 0, f_cuda1.size(), f_cuda1.stride())) 4721*da0073e9SAndroid Build Coastguard Worker 
self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1)) 4722*da0073e9SAndroid Build Coastguard Worker 4723*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test_serialization 4724*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 4725*da0073e9SAndroid Build Coastguard Worker @deviceCountAtLeast(1) # Note: Tests works with one but prefers more devices 4726*da0073e9SAndroid Build Coastguard Worker def test_serialization(self, devices): 4727*da0073e9SAndroid Build Coastguard Worker def _test_serialization(filecontext_lambda): 4728*da0073e9SAndroid Build Coastguard Worker t0 = torch.cuda.FloatTensor(5).fill_(1) 4729*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(devices[-1]): 4730*da0073e9SAndroid Build Coastguard Worker tn = torch.cuda.FloatTensor(3).fill_(2) 4731*da0073e9SAndroid Build Coastguard Worker torch.cuda.set_device(devices[0]) 4732*da0073e9SAndroid Build Coastguard Worker b = (t0, tn) 4733*da0073e9SAndroid Build Coastguard Worker with filecontext_lambda() as f: 4734*da0073e9SAndroid Build Coastguard Worker torch.save(b, f) 4735*da0073e9SAndroid Build Coastguard Worker f.seek(0) 4736*da0073e9SAndroid Build Coastguard Worker c = torch.load(f) 4737*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b, c, atol=0, rtol=0) 4738*da0073e9SAndroid Build Coastguard Worker u0, un = c 4739*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(u0.device), devices[0]) 4740*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(un.device), devices[-1]) 4741*da0073e9SAndroid Build Coastguard Worker 4742*da0073e9SAndroid Build Coastguard Worker _test_serialization(tempfile.NamedTemporaryFile) 4743*da0073e9SAndroid Build Coastguard Worker _test_serialization(BytesIOContext) 4744*da0073e9SAndroid Build Coastguard Worker 4745*da0073e9SAndroid Build Coastguard Worker # FIXME: move memory format tests to their own test class/suite 4746*da0073e9SAndroid Build Coastguard Worker def 
test_memory_format_preserved_after_permute(self, device): 4747*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 3, 8, 8, device=device) 4748*da0073e9SAndroid Build Coastguard Worker nhwc = x.contiguous(memory_format=torch.channels_last) 4749*da0073e9SAndroid Build Coastguard Worker y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2) 4750*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.is_contiguous(memory_format=torch.channels_last)) 4751*da0073e9SAndroid Build Coastguard Worker 4752*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 3, 8, 8, 8, device=device) 4753*da0073e9SAndroid Build Coastguard Worker ndhwc = x.contiguous(memory_format=torch.channels_last_3d) 4754*da0073e9SAndroid Build Coastguard Worker y = ndhwc.permute(0, 1, 4, 3, 2).permute(0, 1, 4, 3, 2) 4755*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.is_contiguous(memory_format=torch.channels_last_3d)) 4756*da0073e9SAndroid Build Coastguard Worker 4757*da0073e9SAndroid Build Coastguard Worker def test_memory_format_propagation_rules(self, device): 4758*da0073e9SAndroid Build Coastguard Worker 4759*da0073e9SAndroid Build Coastguard Worker contiguous = torch.rand(10, 3, 5, 5, device=device) 4760*da0073e9SAndroid Build Coastguard Worker cl = torch.rand(10, 3, 5, 5, device=device).contiguous(memory_format=torch.channels_last) 4761*da0073e9SAndroid Build Coastguard Worker ambiguous = torch.rand(10, 3, 1, 1, device=device).contiguous(memory_format=torch.channels_last) 4762*da0073e9SAndroid Build Coastguard Worker self.assertTrue(ambiguous.is_contiguous(memory_format=torch.channels_last)) 4763*da0073e9SAndroid Build Coastguard Worker self.assertTrue(ambiguous.is_contiguous(memory_format=torch.contiguous_format)) 4764*da0073e9SAndroid Build Coastguard Worker bias = torch.rand(1, 1, 1, 1, device=device).contiguous(memory_format=torch.channels_last) 4765*da0073e9SAndroid Build Coastguard Worker 4766*da0073e9SAndroid Build Coastguard Worker def 
_test_propagation_rules(self, contiguous, cl, ambiguous, bias): 4767*da0073e9SAndroid Build Coastguard Worker options = ((ambiguous, contiguous, torch.contiguous_format), 4768*da0073e9SAndroid Build Coastguard Worker (ambiguous, cl, torch.channels_last), 4769*da0073e9SAndroid Build Coastguard Worker (contiguous, ambiguous, torch.contiguous_format), 4770*da0073e9SAndroid Build Coastguard Worker (contiguous, cl, torch.contiguous_format), 4771*da0073e9SAndroid Build Coastguard Worker (cl, ambiguous, torch.channels_last), 4772*da0073e9SAndroid Build Coastguard Worker (cl, contiguous, torch.channels_last), 4773*da0073e9SAndroid Build Coastguard Worker (bias, cl, torch.channels_last), 4774*da0073e9SAndroid Build Coastguard Worker (cl, bias, torch.channels_last),) 4775*da0073e9SAndroid Build Coastguard Worker 4776*da0073e9SAndroid Build Coastguard Worker for a, b, mf in options: 4777*da0073e9SAndroid Build Coastguard Worker result = a + b 4778*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result.is_contiguous(memory_format=mf)) 4779*da0073e9SAndroid Build Coastguard Worker 4780*da0073e9SAndroid Build Coastguard Worker _test_propagation_rules(self, contiguous, cl, ambiguous, bias) 4781*da0073e9SAndroid Build Coastguard Worker 4782*da0073e9SAndroid Build Coastguard Worker cl = cl.to(memory_format=torch.channels_last) 4783*da0073e9SAndroid Build Coastguard Worker ambiguous = ambiguous.to(memory_format=torch.channels_last) 4784*da0073e9SAndroid Build Coastguard Worker bias = bias.to(memory_format=torch.channels_last) 4785*da0073e9SAndroid Build Coastguard Worker 4786*da0073e9SAndroid Build Coastguard Worker _test_propagation_rules(self, contiguous, cl, ambiguous, bias) 4787*da0073e9SAndroid Build Coastguard Worker 4788*da0073e9SAndroid Build Coastguard Worker # test cases when strides matter in ambiguous tensors 4789*da0073e9SAndroid Build Coastguard Worker for mf in (torch.channels_last, torch.contiguous_format): 4790*da0073e9SAndroid Build Coastguard Worker 
ambiguous = torch.rand(10, 3, 1, 1, device=device).to(memory_format=mf) 4791*da0073e9SAndroid Build Coastguard Worker bias = torch.rand(3, 1, 1, device=device) 4792*da0073e9SAndroid Build Coastguard Worker result = ambiguous + bias 4793*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ambiguous.stride(), result.stride()) 4794*da0073e9SAndroid Build Coastguard Worker result = bias + ambiguous 4795*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ambiguous.stride(), result.stride()) 4796*da0073e9SAndroid Build Coastguard Worker result = ambiguous * 5 4797*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ambiguous.stride(), result.stride()) 4798*da0073e9SAndroid Build Coastguard Worker 4799*da0073e9SAndroid Build Coastguard Worker @skipIfMps 4800*da0073e9SAndroid Build Coastguard Worker def test_memory_format_empty_like(self, device): 4801*da0073e9SAndroid Build Coastguard Worker def test_helper(x, memory_format): 4802*da0073e9SAndroid Build Coastguard Worker xc = x.contiguous(memory_format=memory_format) 4803*da0073e9SAndroid Build Coastguard Worker 4804*da0073e9SAndroid Build Coastguard Worker like = torch.empty_like(xc, memory_format=torch.preserve_format) 4805*da0073e9SAndroid Build Coastguard Worker self.assertFalse(like.is_contiguous()) 4806*da0073e9SAndroid Build Coastguard Worker self.assertTrue(like.is_contiguous(memory_format=memory_format)) 4807*da0073e9SAndroid Build Coastguard Worker 4808*da0073e9SAndroid Build Coastguard Worker like_x = torch.empty_like(x, memory_format=torch.preserve_format) 4809*da0073e9SAndroid Build Coastguard Worker self.assertTrue(like_x.is_contiguous()) 4810*da0073e9SAndroid Build Coastguard Worker self.assertFalse(like_x.is_contiguous(memory_format=memory_format)) 4811*da0073e9SAndroid Build Coastguard Worker 4812*da0073e9SAndroid Build Coastguard Worker like = torch.empty_like(x, memory_format=memory_format) 4813*da0073e9SAndroid Build Coastguard Worker self.assertFalse(like.is_contiguous()) 
4814*da0073e9SAndroid Build Coastguard Worker self.assertTrue(like.is_contiguous(memory_format=memory_format)) 4815*da0073e9SAndroid Build Coastguard Worker 4816*da0073e9SAndroid Build Coastguard Worker like = torch.empty_like(xc, memory_format=torch.contiguous_format) 4817*da0073e9SAndroid Build Coastguard Worker self.assertTrue(like.is_contiguous()) 4818*da0073e9SAndroid Build Coastguard Worker self.assertFalse(like.is_contiguous(memory_format=memory_format)) 4819*da0073e9SAndroid Build Coastguard Worker 4820*da0073e9SAndroid Build Coastguard Worker like = torch.empty_like(xc) 4821*da0073e9SAndroid Build Coastguard Worker self.assertFalse(like.is_contiguous()) 4822*da0073e9SAndroid Build Coastguard Worker self.assertTrue(like.is_contiguous(memory_format=memory_format)) 4823*da0073e9SAndroid Build Coastguard Worker 4824*da0073e9SAndroid Build Coastguard Worker sparse = x.to_sparse() 4825*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 4826*da0073e9SAndroid Build Coastguard Worker z = torch.empty_like(sparse, memory_format=torch.preserve_format) 4827*da0073e9SAndroid Build Coastguard Worker 4828*da0073e9SAndroid Build Coastguard Worker test_helper(torch.randn(4, 3, 8, 8, device=device), torch.channels_last) 4829*da0073e9SAndroid Build Coastguard Worker test_helper(torch.randn(4, 3, 8, 8, 8, device=device), torch.channels_last_3d) 4830*da0073e9SAndroid Build Coastguard Worker 4831*da0073e9SAndroid Build Coastguard Worker def test_memory_format_consistency(self, device): 4832*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, 3, 1, 1, device=device) 4833*da0073e9SAndroid Build Coastguard Worker x_rep = x.as_strided(x.size(), x.stride()) 4834*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.size(), x_rep.size()) 4835*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.stride(), x_rep.stride()) 4836*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.is_contiguous(), x_rep.is_contiguous()) 
# -- test_memory_format_consistency, continued: the channels_last and
# -- channels_last_3d contiguity flags must also survive the as_strided round trip.
        self.assertEqual(x.is_contiguous(memory_format=torch.channels_last), x_rep.is_contiguous(memory_format=torch.channels_last))
        self.assertEqual(
            x.is_contiguous(memory_format=torch.channels_last_3d), x_rep.is_contiguous(memory_format=torch.channels_last_3d))

    # FIXME: make this an elementwise unary and elementwise binary OpInfo test
    # Checks that many unary, binary and in-place ops applied to a channels_last
    # (or channels_last_3d) input match the results computed on contiguous copies,
    # and that their outputs come back in the expected memory format.
    def test_memory_format_operators(self, device):
        # The three helpers below are entries of `fns`, whose source text is read with
        # inspect.getsource to detect in-place ops. Keep their bodies free of an
        # underscore directly followed by an opening parenthesis.
        def _chunk_op(x, y):
            x1, x2 = x.chunk(2, dim=1)
            return x1 + x2

        def _unsqueeze_op_add(x, y):
            return x[0].unsqueeze(0) + 3

        def _unsqueeze_op_clone(x, y):
            return x[0].unsqueeze(0).clone()

        def _test_helper(x, y, bias, memory_format):
            # Ops whose leading operand is the contiguous `y`: their result is expected
            # in the default contiguous layout rather than in `memory_format`.
            return_contig_fns = [
                lambda x, y: y + x,
                lambda x, y: y * x,
                lambda x, y: y.addcdiv(x, y, value=2),
                lambda x, y: y.addcmul(x, y, value=2),
            ]
            # Adding the broadcast `bias` (itself stored in `memory_format`) in either
            # operand order must keep `memory_format`.
            bias_fns = [
                lambda x, b: x + b,
                lambda x, b: b + x,
            ]
            # Ops whose result must stay in `memory_format`. Keep exactly one entry per
            # line: the in-place check and the failure messages below use each entry's
            # source text, and no comments belong on these lines.
            fns = [
                lambda x, y: x.clone(),
                lambda x, y: x + 3,
                lambda x, y: 3 * x,
                lambda x, y: x + y,
                lambda x, y: x * y,
                lambda x, y: abs(x),
                lambda x, y: x.abs(),
                lambda x, y: x.abs_(),
                lambda x, y: x.acos(),
                lambda x, y: x.acos_(),
                lambda x, y: x.add(y, alpha=3),
                lambda x, y: x.add_(y, alpha=3),
                lambda x, y: x.addcdiv(y, y, value=2),
                lambda x, y: x.addcdiv_(y, y, value=2),
                lambda x, y: x.addcmul(y, y, value=2),
                lambda x, y: x.addcmul_(y, y, value=2),
                lambda x, y: x.acosh(),
                lambda x, y: x.acosh_(),
                lambda x, y: x.asinh(),
                lambda x, y: x.asinh_(),
                lambda x, y: x.atanh(),
                lambda x, y: x.atanh_(),
                lambda x, y: x.asin(),
                lambda x, y: x.asin_(),
                lambda x, y: x.atan(),
                lambda x, y: x.atan2(y),
                lambda x, y: x.atan2_(y),
                lambda x, y: x.ceil(),
                lambda x, y: x.ceil_(),
                lambda x, y: x.clamp(-1, 1),
                lambda x, y: x.cos(),
                lambda x, y: x.cosh(),
                lambda x, y: x.div(0.5),
                lambda x, y: x.div_(0.5),
                lambda x, y: x.div(y),
                lambda x, y: x.div_(y),
                lambda x, y: x.digamma(),
                lambda x, y: x.digamma_(),
                lambda x, y: x.erf(),
                lambda x, y: x.erfc(),
                lambda x, y: x.erfinv(),
                lambda x, y: x.erfinv_(),
                lambda x, y: x.exp(),
                lambda x, y: x.expm1(),
                lambda x, y: x.expm1_(),
                lambda x, y: x.floor(),
                lambda x, y: x.floor_(),
                lambda x, y: x.fmod(2),
                lambda x, y: x.frac(),
                lambda x, y: x.hypot(y),
                lambda x, y: x.hypot_(y),
                lambda x, y: x.i0(),
                lambda x, y: x.i0_(),
                lambda x, y: x.lerp(y, 0.5),
                lambda x, y: x.log(),
                lambda x, y: x.log_(),
                lambda x, y: x.log10(),
                lambda x, y: x.log10_(),
                lambda x, y: x.log1p(),
                lambda x, y: x.log1p_(),
                lambda x, y: x.log2(),
                lambda x, y: x.log2_(),
                lambda x, y: x.mul(3),
                lambda x, y: x.mul_(3),
                lambda x, y: x.neg(),
                lambda x, y: x.neg_(),
                lambda x, y: x.pow(3),
                lambda x, y: x.pow_(3),
                lambda x, y: x.pow(0.0),
                lambda x, y: x.pow(1.0),
                lambda x, y: x.reciprocal(),
                lambda x, y: x.remainder(2),
                lambda x, y: x.round(),
                lambda x, y: x.round_(),
                lambda x, y: x.rsqrt(),
                lambda x, y: x.rsqrt_(),
                lambda x, y: x.sigmoid(),
                lambda x, y: x.sigmoid_(),
                lambda x, y: x.logit(),
                lambda x, y: x.logit_(),
                lambda x, y: x.logit(1e-6),
                lambda x, y: x.logit_(1e-6),
                lambda x, y: x.sign(),
                lambda x, y: x.sign_(),
                lambda x, y: x.sgn(),
                lambda x, y: x.sgn_(),
                lambda x, y: x.sin(),
                lambda x, y: x.sin_(),
                lambda x, y: x.sinh(),
                lambda x, y: x.sinh_(),
                lambda x, y: x.sqrt(),
                lambda x, y: x.sqrt_(),
                lambda x, y: x.tan(),
                lambda x, y: x.tanh(),
                lambda x, y: x.trunc(),
                lambda x, y: x.trunc_(),
                _chunk_op,
                _unsqueeze_op_add,
                _unsqueeze_op_clone,
            ]
            # Default-layout copies used to compute the reference results.
            x_c = x.contiguous()
            y_c = y.contiguous()
            b_c = bias.contiguous()
            for fn in fns:
                # In-place ops are recognised textually (a method name ending in an
                # underscore right before the call parenthesis). They run on clones, so
                # `x` and `x_c` stay unmodified for the remaining entries.
                is_inplace = '_(' in inspect.getsource(fn)
                x_clone = x.clone() if is_inplace else x
                x_c_clone = x_c.clone() if is_inplace else x_c
                result_c = fn(x_c_clone, y_c)
                result = fn(x_clone, y)
                self.assertEqual(result, result_c, f"Failed for '{inspect.getsource(fn).strip()}'")
                self.assertTrue(
                    result.is_contiguous(memory_format=memory_format),
                    f"result of the '{inspect.getsource(fn).strip()}' is not in '{memory_format}' format")

            for fn in bias_fns:
                result_c = fn(x_c, b_c)
                result = fn(x, bias)
                self.assertEqual(result, result_c, f"Failed for '{inspect.getsource(fn).strip()}'")
                self.assertTrue(
                    result.is_contiguous(memory_format=memory_format),
                    f"result of the '{inspect.getsource(fn).strip()}' is not in '{memory_format}' format")

            for fn in return_contig_fns:
                result_c = fn(x_c, y_c)
                result = fn(x, y)
                self.assertEqual(result, result_c, f"Failed for '{inspect.getsource(fn).strip()}'")
                self.assertTrue(
                    result.is_contiguous(memory_format=torch.contiguous_format),
                    f"result of the '{inspect.getsource(fn).strip()}' is not in '{torch.contiguous_format}' format")

        # `y` is strictly positive (|randn| + 1), so it is safe as a divisor in div /
        # addcdiv, and it is left in the default contiguous layout.
        _test_helper(
            torch.randn((4, 3, 8, 8), device=device).contiguous(memory_format=torch.channels_last),
            abs(torch.randn((4, 3, 8, 8), device=device)) + 1,
            torch.randn((1, 3, 1, 1), device=device).contiguous(memory_format=torch.channels_last),
            torch.channels_last)
        _test_helper(
            torch.randn((4, 3, 8, 8, 8), device=device).contiguous(memory_format=torch.channels_last_3d),
            abs(torch.randn((4, 3, 8, 8, 8), device=device)) + 1,
            torch.randn((1, 3, 1, 1, 1), device=device).contiguous(memory_format=torch.channels_last_3d),
            torch.channels_last_3d)

    # FIXME: make this an elementwise unary and elementwise binary OpInfo test
    def test_strides_propagation(self, device):
        # For every permutation of each input, the op's output, both returned and
        # written into an `out=` tensor, must keep the input's shape and its strides
        # (up to the `div` scale below).
        def _test_helper(x, op, unary=False):
            def compare_strides(s1, s2, div):
                # Input strides divided by `div` must equal the output strides exactly.
                sdiv = [s // div for s in s1]
                self.assertEqual(sdiv, s2)

            dim = x.dim()
            # we produce memory dense outputs, so when input is strided on the last dimension
            # we need to divide by that dimension stride to compare input and result strides
            div = x.stride(-1)
            for p in permutations(range(dim)):
                xp = x.permute(p)
                if not unary:
                    # 1-D operand sized like the last dim of `xp`, so it broadcasts against it.
                    y = torch.randn(xp.size(-1), device=x.device, dtype=x.dtype)
# -- test_strides_propagation._test_helper, continued: binary case, pairing `xp`
# -- with itself and with the broadcast 1-D `y` in both operand orders.
                    for inputs in ((xp, xp), (xp, y), (y, xp)):
                        res = op(*inputs)
                        compare_strides(xp.stride(), res.stride(), div)
                        self.assertEqual(xp.size(), res.size())
                        # Same checks when the op writes into an empty `out`, which must
                        # end up with the input's shape and strides.
                        out = torch.empty(0, device=xp.device, dtype=res.dtype)
                        res = op(*inputs, out=out)
                        compare_strides(xp.stride(), res.stride(), div)
                        self.assertEqual(xp.size(), res.size())
                else:
                    res = op(xp)
                    compare_strides(xp.stride(), res.stride(), div)
                    self.assertEqual(xp.size(), res.size())
                    out = torch.empty(0, device=xp.device, dtype=res.dtype)
                    res = op(xp, out=out)
                    compare_strides(xp.stride(), res.stride(), div)
                    self.assertEqual(xp.size(), res.size())

        # torch.eq by default calls TensorIterator with defined output, torch.add with undefined
        binary_ops = (torch.eq, torch.add)
        unary_ops = (torch.exp,)
        # memory dense, sliced and ambiguous sliced (ambiguous dense loses permutation information)
        xs = (torch.randn(2, 3, 4, device=device), torch.randn(2, 3, 8, device=device)[:, :, ::2],
              torch.randn(1, 1, 4, 12, device=device)[:, :, :, ::2])
        for op in binary_ops:
            for x in xs:
                _test_helper(x, op)
        for op in unary_ops:
            for x in xs:
                _test_helper(x, op, unary=True)

    # Tensors built by factory and *_like functions with pin_memory=True must be
    # pinned; without that argument they must not be, even when the *_like source
    # itself is pinned (pinning is not inherited).
    @onlyCUDA
    @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
    @skipIfTorchDynamo("NotImplementedError: PrimTorch does not support pinned memory")
    def test_pin_memory_from_constructor(self, device):
        def _get_like(t, **kwargs):
            return [
                torch.rand_like(t, **kwargs),
                torch.randn_like(t, **kwargs),
                torch.empty_like(t, **kwargs),
                torch.full_like(t, 4, **kwargs),
                torch.zeros_like(t, **kwargs),
                torch.ones_like(t, **kwargs),
            ]

        def _get_tensors(**kwargs):
            return [
                torch.tensor([10, 11], **kwargs),
                torch.randn(3, 5, **kwargs),
                torch.rand(3, **kwargs),
                # torch.randint(3, 5, **kwargs), // unsupported
                torch.zeros(3, **kwargs),
                torch.randperm(3, **kwargs),
                torch.empty(6, **kwargs),
                torch.ones(6, **kwargs),
                torch.eye(6, **kwargs),
                torch.arange(3, 5, **kwargs)]

        pinned_tensors = _get_tensors(pin_memory=True) + _get_like(torch.empty(5, dtype=torch.float64), pin_memory=True)
        for x in pinned_tensors:
            self.assertTrue(x.is_pinned())

        # The *_like source is pinned here, but no pin_memory argument is passed.
        tensors = _get_tensors() + _get_like(torch.empty(5, dtype=torch.float64, pin_memory=True))
        for x in tensors:
            self.assertFalse(x.is_pinned())

    # For each device in `devices` (a list, given the loop), an empty tensor's
    # storage must report the same dtype as the tensor itself.
    @deviceCountAtLeast(1)
    @onlyCUDA
    def test_storage_all_devices(self, devices):
        for device in devices:
            t = torch.tensor((), device=device)
            self.assertEqual(t.dtype, t.storage().dtype)

    # Note [lazy_clone_ tests with inductor enabled]
    # These `lazy_clone_` tests are written in a way that makes them pass in
    # both eager mode and compiled mode (`PYTORCH_TEST_WITH_INDUCTOR=1`). There
    # are cases where COW tensors can materialize at different times and in
    # different ways in compiled mode versus eager mode, and those cases need to
    # be avoided. There are two main wrinkles to be aware of.
    #
    # The first wrinkle is that these tests have to check the internal
    # properties of tensors to make sure they materialize in the expected way,
    # and those checks cause dynamo graph breaks. Depending on the situation, a
    # graph break in-between two compiled graphs that operate on the same COW
    # tensor can make the tensor materialize when it would not materialize in
    # eager mode, causing the checks to fail. The strategy for avoiding this is
    # to make all the operations on COW tensors get compiled into the same
    # graph, by not doing any checks between the operations, and just doing all
    # the checks at the end of the test. If we really do want to perform checks
    # between two operations, `op1` and `op2`, the solution is to create two
    # different tests. One test performs just `op1` and then checks. The other
    # test performs `op1` followed immediately by `op2` and then checks.
    #
    # The second wrinkle is that in eager mode, if we perform writes on two COW
    # tensors where one is a lazy clone of the other, the first tensor to be
    # written will be materialized with a new data pointer, and the second
    # tensor will just reuse the original data pointer when it is materialized.
    # But in compiled mode, if these writes happen in the same graph, the order
    # in which the tensors materialize can be different than in eager mode. So
    # in this case the strategy is to purposefully cause a graph break to happen
    # in-between the two write operations, by adding checks between them, so
    # that they have to materialize in the expected order.
    @skipXLA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_lazy_clone(self, device, dtype):
        # Record the original storage and data addresses before cloning; every
        # operation runs before the checks (first wrinkle in the Note above).
        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
        t_orig_storage_addr = torch._C._storage_address(t)
        orig_data_ptr = torch._C._data_address(t)
        clone = t._lazy_clone()

        # Lazy cloning a tensor should cause both it and its clone to become COW
        # tensors. They should have different storages, but the same data
        # pointer.

        self.assertTrue(torch._C._is_cow_tensor(clone))
        self.assertTrue(torch._C._is_cow_tensor(t))

        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)

        self.assertTrue(torch._C._data_address(t) == orig_data_ptr)
        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)

    # See Note [lazy_clone_ tests with inductor enabled]
    @skipXLA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_lazy_clone_view(self, device, dtype):
        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
        t_orig_storage_addr = torch._C._storage_address(t)
        orig_data_ptr = torch._C._data_address(t)
        clone = t._lazy_clone()
        view = t.view([4])

        # Viewing `t` should not cause a copy (materialize) to happen. All the
        # tensors should still be COW and have the same data pointer. `view` and
        # `t` should have the same storage, and `clone` should have a different
        # storage.

        self.assertTrue(torch._C._is_cow_tensor(t))
        self.assertTrue(torch._C._is_cow_tensor(view))
        self.assertTrue(torch._C._is_cow_tensor(clone))

        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)

        self.assertTrue(torch._C._data_address(t) == orig_data_ptr)
        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
        self.assertTrue(torch._C._data_address(view) == orig_data_ptr)

    # See Note [lazy_clone_ tests with inductor enabled]
    @skipXLA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_lazy_clone_view_materialize(self, device, dtype):
        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
        t_orig_storage_addr = torch._C._storage_address(t)
        orig_data_ptr = torch._C._data_address(t)
# -- test_lazy_clone_view_materialize, continued: lazily clone `t`, then write
# -- through a view that shares `t`'s storage.
        clone = t._lazy_clone()
        view = t.view([4])
        view += torch.ones(1, device=device, dtype=dtype)

        # Writing to `t` (here through `view`) should cause the storage under `t`
        # and `view` to be copied (materialized), but should not affect `clone`.

        self.assertFalse(torch._C._is_cow_tensor(t))
        self.assertFalse(torch._C._is_cow_tensor(view))
        self.assertTrue(torch._C._is_cow_tensor(clone))

        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)

        t_new_data_addr = torch._C._data_address(t)
        self.assertTrue(t_new_data_addr != orig_data_ptr)
        self.assertTrue(torch._C._data_address(view) == t_new_data_addr)
        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)

        # The checks above deliberately sit between the two writes. This is the
        # second wrinkle in Note [lazy_clone_ tests with inductor enabled]: it keeps
        # the materialization order the same in compiled mode as in eager mode.
        clone += torch.ones(1, device=device, dtype=dtype)

        # Writing to `clone` should materialize it, so it should no longer
        # be COW. However, since `clone`'s storage is the only COW storage
        # left that holds a reference to the original data pointer, this
        # materialization should not actually cause a copy--it should
        # just reuse the original data pointer.

        self.assertFalse(torch._C._is_cow_tensor(t))
        self.assertFalse(torch._C._is_cow_tensor(view))
        self.assertFalse(torch._C._is_cow_tensor(clone))

        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)

        self.assertTrue(torch._C._data_address(t) == t_new_data_addr)
        self.assertTrue(torch._C._data_address(view) == t_new_data_addr)
        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)

    # Reading two COW tensors (a tensor and its lazy clone) in a binary op must not
    # materialize either of them.
    @skipXLA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_lazy_clone_binary_op_no_materialize(self, device, dtype):
        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
        clone = t._lazy_clone()
        # The result is unused; only the op's effect on its inputs is checked.
        res = t + clone
        self.assertTrue(torch._C._is_cow_tensor(t))
        self.assertTrue(torch._C._is_cow_tensor(clone))

    # This tests that if a COW materialization is attempted inside an
    # `at::parallel_for` loop function, then an error is raised. This test is
    # implemented in Python rather than C++ because the C++ tests are built
    # without multithreading support in `at::parallel_for`.
    @skipXLA
    @skipIfTorchDynamo("Torchdynamo fails and we do not need to test it here anyway")
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_parallel_cow_materialize_error(self, device, dtype):

        def run(num_threads, num_parallel, skip_first, should_error):
            # Runs one scenario with `num_threads` intra-op threads and restores the
            # previous thread count afterwards, even on failure. `num_parallel` and
            # `skip_first` are passed straight to torch._test_parallel_materialize,
            # whose exact semantics live in C++ and are not visible here.
            orig_num_threads = torch.get_num_threads()

            try:
                torch.set_num_threads(num_threads)

                # A lazy clone, i.e. a COW tensor.
                a = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)._lazy_clone()

                if should_error:
                    with self.assertRaisesRegex(RuntimeError, r'Materializing a storage'):
5245*da0073e9SAndroid Build Coastguard Worker torch._test_parallel_materialize( 5246*da0073e9SAndroid Build Coastguard Worker a, num_parallel, skip_first) 5247*da0073e9SAndroid Build Coastguard Worker else: 5248*da0073e9SAndroid Build Coastguard Worker torch._test_parallel_materialize(a, num_parallel, skip_first) 5249*da0073e9SAndroid Build Coastguard Worker 5250*da0073e9SAndroid Build Coastguard Worker # Error should not raise in any case if the tensor is not COW 5251*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype) 5252*da0073e9SAndroid Build Coastguard Worker torch._test_parallel_materialize(b, num_parallel, skip_first) 5253*da0073e9SAndroid Build Coastguard Worker 5254*da0073e9SAndroid Build Coastguard Worker finally: 5255*da0073e9SAndroid Build Coastguard Worker torch.set_num_threads(orig_num_threads) 5256*da0073e9SAndroid Build Coastguard Worker 5257*da0073e9SAndroid Build Coastguard Worker run(1, 1, False, True) 5258*da0073e9SAndroid Build Coastguard Worker run(1, 1, True, False) 5259*da0073e9SAndroid Build Coastguard Worker run(1, 10, False, True) 5260*da0073e9SAndroid Build Coastguard Worker run(1, 10, True, True) 5261*da0073e9SAndroid Build Coastguard Worker run(10, 1, False, True) 5262*da0073e9SAndroid Build Coastguard Worker run(10, 1, True, False) 5263*da0073e9SAndroid Build Coastguard Worker run(10, 10, False, True) 5264*da0073e9SAndroid Build Coastguard Worker run(10, 10, True, True) 5265*da0073e9SAndroid Build Coastguard Worker run(10, 2, False, True) 5266*da0073e9SAndroid Build Coastguard Worker run(10, 2, True, True) 5267*da0073e9SAndroid Build Coastguard Worker 5268*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test distributions 5269*da0073e9SAndroid Build Coastguard Worker @skipIfMps 5270*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.float, torch.double, torch.half) 5271*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double, torch.half) 
def test_multinomial(self, device, dtype):
    """Check basic invariants of ``torch.multinomial``.

    For contiguous and non-contiguous probability tensors this verifies that
    indices whose probability was set to zero are never drawn, that sampling
    without replacement never yields the same index twice within a row, and
    that output ranks/sizes are as expected for 2-D and 1-D inputs. It also
    includes a small-shape regression case (#46702).
    """
    def make_prob_dist(shape, is_contiguous):
        # For torch.half the values are drawn in the default dtype and then
        # cast, instead of calling uniform_() on a half tensor directly.
        if is_contiguous:
            if dtype == torch.half:
                return torch.zeros(shape, device=device).uniform_().to(dtype=torch.half)
            return torch.zeros(shape, device=device, dtype=dtype).uniform_()
        elif len(shape) == 1:
            # Non-contiguous 1-D input: a strided column of an (n, 5) tensor.
            if dtype == torch.half:
                return torch.zeros((shape + [5]), device=device).uniform_().to(dtype=torch.half)[:, 2]
            return torch.zeros((shape + [5]), device=device, dtype=dtype).uniform_()[:, 2]
        else:
            # num dim = 2: carve a non-contiguous (shape[0], shape[1]) view
            # out of a larger 7-D tensor via transpose + indexing.
            new_shape = [2, shape[1], 7, 1, shape[0], 1, 10]
            if dtype == torch.half:
                prob_dist = torch.zeros(new_shape, device=device).uniform_().to(dtype=torch.half)
            else:
                prob_dist = torch.zeros(new_shape, device=device, dtype=dtype).uniform_()
            prob_dist = prob_dist.transpose(1, 4)
            prob_dist = prob_dist[1, :, 5, 0, :, 0, 4]
            assert not prob_dist.is_contiguous()  # sanity check
            return prob_dist

    for is_contiguous in (True, False):
        # with replacement
        n_row = 3
        for n_col in range(4, 5 + 1):
            prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
            # indices that shouldn't be sampled (<0 means none)
            zero_prob_indices = torch.LongTensor(n_row).random_(-2, n_col).tolist()
            for i, j in enumerate(zero_prob_indices):
                if j >= 0:
                    prob_dist[i, j] = 0
            n_sample = n_col * 3
            sample_indices = torch.multinomial(prob_dist, n_sample, True)
            self.assertEqual(prob_dist.dim(), 2)
            self.assertEqual(sample_indices.size(1), n_sample)
            for i, zero_prob_idx in enumerate(zero_prob_indices):
                if zero_prob_idx < 0:
                    continue
                for sample_idx in sample_indices[i].tolist():
                    self.assertNotEqual(sample_idx, zero_prob_idx,
                                        msg="sampled an index with zero probability")

        # without replacement
        n_row = 3
        for n_col in range(2, 10 + 1, 2):
            prob_dist = make_prob_dist([n_row, n_col], is_contiguous)
            # indices that shouldn't be sampled (<0 means none)
            zero_prob_indices = torch.LongTensor(n_row).random_(-1, n_col).tolist()
            for i, j in enumerate(zero_prob_indices):
                if j >= 0:
                    prob_dist[i, j] = 0
            n_sample = max(1, n_col - 2)
            sample_indices = torch.multinomial(prob_dist, n_sample, False)
            self.assertEqual(prob_dist.dim(), 2)
            self.assertEqual(sample_indices.size(1), n_sample)
            for i, zero_prob_idx in enumerate(zero_prob_indices):
                # Collect Python ints, not 0-d tensors: tensors hash by object
                # identity, so a container of tensor elements would never
                # report a repeated index and the duplicate check would be
                # vacuous.
                row_samples = set()
                for sample_idx in sample_indices[i].tolist():
                    if zero_prob_idx >= 0:
                        self.assertNotEqual(sample_idx, zero_prob_idx,
                                            msg="sampled an index with zero probability")
                    self.assertNotIn(sample_idx, row_samples, "sampled an index twice")
                    row_samples.add(sample_idx)

        # vector
        n_col = 4
        prob_dist = make_prob_dist([n_col], is_contiguous).fill_(1)
        zero_prob_idx = 1  # index that shouldn't be sampled
        prob_dist[zero_prob_idx] = 0
        n_sample = 20
        sample_indices = torch.multinomial(prob_dist, n_sample, True)
        for sample_index in sample_indices.tolist():
            self.assertNotEqual(sample_index, zero_prob_idx, msg="sampled an index with zero probability")
        self.assertEqual(sample_indices.dim(), 1, msg="wrong number of dimensions")
        self.assertEqual(prob_dist.dim(), 1, msg="wrong number of prob_dist dimensions")
        self.assertEqual(sample_indices.size(0), n_sample, msg="wrong number of samples")

    # CUDA misalignment issue (#46702)
    n_row, n_col = 2, 3
    prob_dist = make_prob_dist([n_row, n_col], True)
    n_sample = 1
    sample_indices = torch.multinomial(prob_dist, n_sample, True)
    self.assertEqual(sample_indices.dim(), 2, msg="wrong number of dimensions")
    self.assertEqual(sample_indices.size(1), n_sample, msg="wrong number of samples")

# FIXME: move to test distributions
@onlyCUDA
@dtypes(torch.float, torch.double, torch.half)
def test_multinomial_deterministic(self, device, dtype):
    """Re-seeding the same generator must reproduce identical samples.

    The same seeded comparison is repeated ``trials`` times on a large
    (10000, 1000) input.
    """
    gen = torch.Generator(device=device)

    trials = 5
    seed = 0
    prob_dist = torch.rand(10000, 1000, device=device, dtype=dtype)
    n_sample = 1

    for _ in range(trials):
        gen.manual_seed(seed)
        samples_1 = torch.multinomial(prob_dist, n_sample, True, generator=gen)

        gen.manual_seed(seed)
        samples_2 = torch.multinomial(prob_dist, n_sample, True, generator=gen)

        self.assertEqual(samples_1, samples_2)
        self.assertEqual(samples_1.dim(), 2, msg="wrong number of dimensions")
        self.assertEqual(samples_1.size(1), n_sample, msg="wrong number of samples")

# FIXME: move to test distributions
@slowTest
@dtypes(torch.float)
def test_multinomial_rng_state_advance(self, device, dtype):
    """Two consecutive multinomial draws must not repeat each other.

    Both draws use the default generator over a uniform distribution on
    ``corpus_size`` categories, so after concatenation almost all sampled
    indices should be unique. This only holds if the RNG state advances
    between calls.
    """
    corpus_size = 100000
    freqs = torch.ones(corpus_size, dtype=torch.float, device=device)
    n_sample = 100
    samples1 = torch.multinomial(freqs, n_sample, replacement=True)
    samples2 = torch.multinomial(freqs, n_sample, replacement=True)
    samples = torch.cat([samples1, samples2])
    # Expect no more than 2 repeated elements across the 2 draws; the
    # probability of at least one element being repeated is surprisingly
    # large, about 18% (birthday problem: 200 draws from 100000 values).
    self.assertLessEqual(2 * n_sample - samples.unique().size(0), 2)
    samples1 = torch.multinomial(freqs, n_sample, replacement=False)
    samples2 = torch.multinomial(freqs, n_sample, replacement=False)
    samples = torch.cat([samples1, samples2])
    # expect no more than 1 repeated element across the 2 draws
    self.assertLessEqual(2 * n_sample - samples.unique().size(0), 1)

def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
                                        memory_format, compare_data=True, default_is_preserve=False):
    """Check how ``transformation_fn`` honors its ``memory_format`` kwarg.

    Args:
        device: device passed to ``input_generator_fn``.
        input_generator_fn: callable ``device -> Tensor`` producing a tensor
            laid out in ``memory_format``.
        transformation_fn: callable ``(tensor, **kwargs) -> Tensor`` under
            test; it is called with and without a ``memory_format=`` kwarg.
        memory_format: ``torch.channels_last`` or ``torch.channels_last_3d``.
        compare_data: when True, also check the result's values equal the
            input's values.
        default_is_preserve: when True, calling ``transformation_fn`` without
            ``memory_format`` is expected to keep ``memory_format``; otherwise
            it is expected to produce a contiguous tensor.
    """
    assert memory_format in (torch.channels_last, torch.channels_last_3d)

    # xc is a channels last tensor
    xc = input_generator_fn(device)
    # xc is not memory dense, but looks like channels last
    # We don't preserve non-dense striding
    if not TEST_WITH_TORCHINDUCTOR:
        if memory_format == torch.channels_last:
            xc = xc[..., ::2, ::2]
        else:
            xc = xc[..., ::2, ::2, ::2]

    clone = transformation_fn(xc, memory_format=torch.preserve_format)

    self.assertFalse(clone.is_contiguous())
    self.assertTrue(clone.is_contiguous(memory_format=memory_format))
    if not TEST_WITH_TORCHINDUCTOR:
        self.assertFalse(xc.is_contiguous())
        self.assertFalse(xc.is_contiguous(memory_format=memory_format))
    if compare_data:
        self.assertEqual(xc, clone.to(xc))

    # An explicit contiguous_format request must win over the input layout.
    xc = input_generator_fn(device)
    clone = transformation_fn(xc, memory_format=torch.contiguous_format)
    self.assertTrue(clone.is_contiguous())
    self.assertFalse(clone.is_contiguous(memory_format=memory_format))
    if compare_data:
        self.assertEqual(xc, clone.to(xc))

    # No memory_format given: behavior depends on default_is_preserve.
    xc = input_generator_fn(device)
    clone = transformation_fn(xc)

    if default_is_preserve:
        self.assertFalse(clone.is_contiguous())
        self.assertTrue(clone.is_contiguous(memory_format=memory_format))
    else:
        self.assertTrue(clone.is_contiguous())
        self.assertFalse(clone.is_contiguous(memory_format=memory_format))
    if compare_data:
        self.assertEqual(xc, clone.to(xc))

    # TODO copy _like constructors to stride permutation instead of just layout
    if not TEST_WITH_TORCHINDUCTOR:
        x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device)
        for _ in range(10):
            permutation = list(range(len(x.shape)))
            random.shuffle(permutation)
            x = x.permute(permutation)
            self.assertEqual(x.stride(), transformation_fn(x, memory_format=torch.preserve_format).stride())

def test_memory_format_to(self, device):
    """``Tensor.to(dtype=...)`` preserves channels-last layouts by default."""
    def get_generator(memory_format, shape):
        def input_generator_fn(device):
            return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
        return input_generator_fn

    def transformation_fn(tensor, **kwargs):
        return tensor.to(dtype=torch.float64, **kwargs)

    formats_shapes = (
        (torch.channels_last, (4, 3, 8, 8)),
        (torch.channels_last_3d, (4, 3, 8, 8, 8)))

    for mf, shape in formats_shapes:
        self._test_memory_format_transformations(
            device, get_generator(mf, shape), transformation_fn, mf, default_is_preserve=True)

def test_memory_format_type(self, device):
    """``Tensor.to(dtype)`` (positional dtype) preserves channels-last layouts by default."""
    def get_generator(memory_format, shape):
        def input_generator_fn(device):
            return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
        return input_generator_fn

    def transformation_fn(tensor, **kwargs):
        return tensor.to(torch.float64, **kwargs)

    formats_shapes = (
        (torch.channels_last, (4, 3, 8, 8)),
        (torch.channels_last_3d, (4, 3, 8, 8, 8)))

    for mf, shape in formats_shapes:
        self._test_memory_format_transformations(
            device, get_generator(mf, shape), transformation_fn, mf, default_is_preserve=True)

def test_memory_format_clone(self, device):
    """``Tensor.clone`` preserves channels-last layouts by default and copies data."""
    def get_generator(memory_format, shape):
        def input_generator_fn(device):
            return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
        return input_generator_fn

    def transformation_fn(tensor, **kwargs):
        return tensor.clone(**kwargs)

    formats_shapes = (
        (torch.channels_last, (4, 3, 8, 8)),
        (torch.channels_last_3d, (4, 3, 8, 8, 8)))

    for mf, shape in formats_shapes:
        self._test_memory_format_transformations(
            device, get_generator(mf, shape), transformation_fn, mf, True, default_is_preserve=True)

def test_memory_format_factory_like_functions_preserve(self, device):
    """``*_like`` factory functions preserve channels-last layouts by default.

    Values are not compared (``compare_data=False``) because these functions
    produce new contents rather than copying the input.
    """
    def get_generator(memory_format, shape):
        def input_generator_fn(device):
            return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
        return input_generator_fn

    transformation_fns = [
        lambda t, **kwargs: torch.zeros_like(t, **kwargs),
        lambda t, **kwargs: torch.ones_like(t, **kwargs),
        lambda t, **kwargs: torch.randint_like(t, 10, 100, **kwargs),
        lambda t, **kwargs: torch.randint_like(t, 100, **kwargs),
        lambda t, **kwargs: torch.randn_like(t, **kwargs),
        lambda t, **kwargs: torch.rand_like(t, **kwargs),
        lambda t, **kwargs: torch.full_like(t, 7, **kwargs),
        lambda t, **kwargs: torch.empty_like(t, **kwargs)]

    formats_shapes = (
        (torch.channels_last, (4, 3, 8, 8)),
        (torch.channels_last_3d, (4, 3, 8, 8, 8)))

    for mf, shape in formats_shapes:
        for transformation_fn in transformation_fns:
            self._test_memory_format_transformations(
                device, get_generator(mf, shape), transformation_fn, mf, compare_data=False, default_is_preserve=True)

def test_memory_format_type_shortcuts(self, device):
    """Dtype shortcut methods (``.byte()``, ``.half()``, ...) preserve channels-last layouts."""
    def get_generator(memory_format, shape, dtype):
        def input_generator_fn(device):
            # Values are clamped and rounded to 0/1 so that every target dtype
            # (including bool and the integer types) can represent them.
            return (torch.randn(shape, device=device, dtype=dtype)
                    .clamp(0, 1)
                    .round()
                    .contiguous(memory_format=memory_format))
        return input_generator_fn

    def get_fn(fn_name):
        def transformation_fn(tensor, **kwargs):
            fn = getattr(tensor, fn_name)
            return fn(**kwargs)
        return transformation_fn

    shortcuts = ['byte', 'char', 'double', 'bool', 'half', 'int', 'long', 'short']
    # The bfloat16 shortcut is only exercised on CPU.
    if device == 'cpu':
        shortcuts += ['bfloat16']

    formats_shapes = (
        (torch.channels_last, (4, 3, 8, 8)),
        (torch.channels_last_3d, (4, 3, 8, 8, 8)))

    for mf, shape in formats_shapes:
        for fn_name in shortcuts:
            self._test_memory_format_transformations(
                device, get_generator(mf, shape, torch.float32), get_fn(fn_name), mf, default_is_preserve=True)

    # Test 'float' separately to avoid float->float no-op.
    for mf, shape in formats_shapes:
        self._test_memory_format_transformations(
            device, get_generator(mf, shape, torch.float64), get_fn('float'), mf, default_is_preserve=True)

@onlyCUDA
def test_memory_format_cpu_and_cuda_ops(self, device):
    """``Tensor.cpu()`` and ``Tensor.cuda()`` preserve channels-last layouts by default."""
    def get_generator(memory_format, shape):
        def input_generator_fn(device):
            return torch.randn(shape, device=device, dtype=torch.float32).contiguous(memory_format=memory_format)
        return input_generator_fn

    def transformation_cpu_fn(tensor, **kwargs):
        return tensor.cpu(**kwargs)

    def transformation_cuda_fn(tensor, **kwargs):
        return tensor.cuda(**kwargs)

    formats_shapes = (
        (torch.channels_last, (4, 3, 8, 8)),
        (torch.channels_last_3d, (4, 3, 8, 8, 8)))

    for mf, shape in formats_shapes:
        # Inputs are created on the opposite device from the one each
        # transformation moves them to.
        self._test_memory_format_transformations(
            'cuda', get_generator(mf, shape), transformation_cpu_fn, mf, default_is_preserve=True)
        self._test_memory_format_transformations(
            'cpu', get_generator(mf, shape), transformation_cuda_fn, mf, default_is_preserve=True)

# FIXME: move to test_serialization
@onlyNativeDeviceTypes
def test_pickle_gradscaler(self, device):
    """A ``GradScaler`` survives a pickle round trip with its settings intact."""
    # This test should pass in 3 cases for cuda:
    #  1. cuda is not available.
    #  2. cuda is available but device is not cuda.
    #  3. cuda is available and device is cuda.
    # In case 1, a and b disable themselves on construction and shouldn't try to pickle workhorse attributes.
    # In case 2, a and b are enabled.  Workhorse attributes participate in pickling, but none are lazy-inited
    # to cuda Tensors, because I don't want to do cuda things if device is not cuda.
    # In case 3, a and b are enabled and we may also try lazy-initing _scale to a cuda tensor.
    device = torch.device(device)
    try_lazy_inits = (True, False)
    GradScaler = partial(torch.GradScaler, device=device.type)
    for lazy_init_scale in try_lazy_inits:
        a = GradScaler(init_scale=3., growth_factor=4., backoff_factor=.5, growth_interval=2)
        if device.type == "cuda":
            # Enabled exactly when AMP is not definitely unavailable.
            self.assertEqual(a.is_enabled(), not torch.cuda.amp.common.amp_definitely_not_available())
        else:
            self.assertTrue(a.is_enabled())
        if lazy_init_scale:
            # Dummy a.scale() call lazy-inits a._scale Tensor.
            a.scale(torch.tensor([4.0], dtype=torch.float32, device=device))
            self.assertEqual(a._scale.device.type, device.type)
        # The following three lines should work whether or not cuda is available.
        serialized = pickle.dumps(a)
        b = pickle.loads(serialized)
        self.assertEqual(b.is_enabled(), a.is_enabled())
        if a.is_enabled():
            self.assertEqual(b.get_scale(), 3.)
            self.assertEqual(b.get_growth_factor(), 4.)
            self.assertEqual(b.get_backoff_factor(), .5)
            self.assertEqual(b.get_growth_interval(), 2)
            self.assertEqual(b._init_growth_tracker, 0)
            # supplies a dummy key to test the defaultdict's default_factory
            self.assertEqual(b._per_optimizer_states["fdsa"],
                             torch.amp.grad_scaler._refresh_per_optimizer_state())
            if lazy_init_scale:
                self.assertEqual(b.scale(torch.tensor([4.0], dtype=torch.float32, device=device)), 12.0)

# FIXME: move to test distributions
def _test_multinomial_empty(self, device, replacement, num_samples):
    """Sampling from zero distributions returns an empty ``(0, num_samples)`` int64 tensor.

    Args:
        device: device on which the (0, 3) probability tensor is created.
        replacement: forwarded to ``torch.multinomial``.
        num_samples: number of samples requested per (nonexistent) row.
    """
    probs = torch.ones(0, 3, device=device)
    expected = torch.empty(0, num_samples, dtype=torch.int64)
    out = torch.multinomial(probs, num_samples=num_samples, replacement=replacement)
    self.assertEqual(out, expected)

# FIXME: move to test distributions
def test_multinomial_empty_w_replacement(self, device):
    """Empty-input multinomial with replacement, for 1 and 2 samples."""
    self._test_multinomial_empty(device, True, 1)
    self._test_multinomial_empty(device, True, 2)

# FIXME: move to test distributions
Worker def test_multinomial_empty_wo_replacement(self, device): 5642*da0073e9SAndroid Build Coastguard Worker self._test_multinomial_empty(device, False, 1) 5643*da0073e9SAndroid Build Coastguard Worker self._test_multinomial_empty(device, False, 2) 5644*da0073e9SAndroid Build Coastguard Worker 5645*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5646*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 5647*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_unscale(self, device, dtype): 5648*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 5649*da0073e9SAndroid Build Coastguard Worker device0 = "cuda:0" if device.type == "cuda" else "cpu" 5650*da0073e9SAndroid Build Coastguard Worker inv_scale = torch.full((1,), 0.25, dtype=torch.float, device=device0) 5651*da0073e9SAndroid Build Coastguard Worker found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device0) 5652*da0073e9SAndroid Build Coastguard Worker 5653*da0073e9SAndroid Build Coastguard Worker size = 20 5654*da0073e9SAndroid Build Coastguard Worker g = torch.full((size, size), 4.0, dtype=dtype, device=device0) 5655*da0073e9SAndroid Build Coastguard Worker ginf = g.clone() 5656*da0073e9SAndroid Build Coastguard Worker ginf[2, 2] = float('inf') 5657*da0073e9SAndroid Build Coastguard Worker gnan = g.clone() 5658*da0073e9SAndroid Build Coastguard Worker gnan[2, 2] = float('nan') 5659*da0073e9SAndroid Build Coastguard Worker 5660*da0073e9SAndroid Build Coastguard Worker # Tries selected combinations of 5661*da0073e9SAndroid Build Coastguard Worker # - contiguous grads 5662*da0073e9SAndroid Build Coastguard Worker # - g.clone().t() which is not contiguous but still non overlapping and dense 5663*da0073e9SAndroid Build Coastguard Worker # - variants of g.clone()[:, :5] which are not non overlapping and dense 5664*da0073e9SAndroid Build Coastguard Worker # Non overlapping and dense grads route into a multi tensor apply kernel, 
5665*da0073e9SAndroid Build Coastguard Worker # others use a fallback per-tensor kernel, so we should try both. 5666*da0073e9SAndroid Build Coastguard Worker cases = ( 5667*da0073e9SAndroid Build Coastguard Worker ([g.clone(), g.clone()], False), 5668*da0073e9SAndroid Build Coastguard Worker ([g.clone(), g.clone().t()], False), 5669*da0073e9SAndroid Build Coastguard Worker ([g.clone(), g.clone()[:, :5]], False), 5670*da0073e9SAndroid Build Coastguard Worker ([g.clone()[:, :5], g.clone()[:, :5]], False), 5671*da0073e9SAndroid Build Coastguard Worker ([g.clone(), ginf.clone()], True), 5672*da0073e9SAndroid Build Coastguard Worker ([g.clone(), gnan.clone()], True), 5673*da0073e9SAndroid Build Coastguard Worker ([g.clone(), ginf.clone()[:, :5]], True), 5674*da0073e9SAndroid Build Coastguard Worker ([g.clone(), gnan.clone()[:, :5]], True), 5675*da0073e9SAndroid Build Coastguard Worker ([ginf.clone(), g.clone()[:, :5]], True), 5676*da0073e9SAndroid Build Coastguard Worker ([ginf.clone()[:, :5], g.clone()[:, :5]], True), 5677*da0073e9SAndroid Build Coastguard Worker ) 5678*da0073e9SAndroid Build Coastguard Worker 5679*da0073e9SAndroid Build Coastguard Worker for grads, has_inf in cases: 5680*da0073e9SAndroid Build Coastguard Worker found_inf.zero_() 5681*da0073e9SAndroid Build Coastguard Worker torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale) 5682*da0073e9SAndroid Build Coastguard Worker if has_inf: 5683*da0073e9SAndroid Build Coastguard Worker self.assertEqual(found_inf, 1.0) 5684*da0073e9SAndroid Build Coastguard Worker else: 5685*da0073e9SAndroid Build Coastguard Worker self.assertEqual(found_inf, 0.0) 5686*da0073e9SAndroid Build Coastguard Worker for grad in grads: 5687*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7) 5688*da0073e9SAndroid Build Coastguard Worker 5689*da0073e9SAndroid Build Coastguard Worker # When passing lists with mismatched dtypes to a raw 
5690*da0073e9SAndroid Build Coastguard Worker # _amp_foreach_non_finite_check_and_unscale_ call on CUDA, 5691*da0073e9SAndroid Build Coastguard Worker # it's expected to fall back to single-tensor TensorIterator kernel. 5692*da0073e9SAndroid Build Coastguard Worker grads = [g.clone(), g.to(dtype=torch.float16)] 5693*da0073e9SAndroid Build Coastguard Worker torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale) 5694*da0073e9SAndroid Build Coastguard Worker for grad in grads: 5695*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7) 5696*da0073e9SAndroid Build Coastguard Worker 5697*da0073e9SAndroid Build Coastguard Worker # Passing lists with mismatched devices to a raw 5698*da0073e9SAndroid Build Coastguard Worker # _amp_foreach_non_finite_check_and_unscale_ call should raise errors. 5699*da0073e9SAndroid Build Coastguard Worker if device.type == "cuda" and TEST_MULTIGPU: 5700*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Expected all tensors to be on the same device"): 5701*da0073e9SAndroid Build Coastguard Worker torch._amp_foreach_non_finite_check_and_unscale_([g.clone(), g.to(device="cuda:1")], 5702*da0073e9SAndroid Build Coastguard Worker found_inf, 5703*da0073e9SAndroid Build Coastguard Worker inv_scale) 5704*da0073e9SAndroid Build Coastguard Worker 5705*da0073e9SAndroid Build Coastguard Worker # Creates a list of grads with mismatched dtypes and devices, to ensure 5706*da0073e9SAndroid Build Coastguard Worker # scaler._unscale_grads_ organizes grads by dtype and device before calling 5707*da0073e9SAndroid Build Coastguard Worker # _amp_foreach_non_finite_check_and_unscale_ on each set. 5708*da0073e9SAndroid Build Coastguard Worker # If inject_inf >= 0, writes an inf into one grad for _unscale_grads_ to find. 
5709*da0073e9SAndroid Build Coastguard Worker def perfect_storm_grads(inject_inf): 5710*da0073e9SAndroid Build Coastguard Worker grads = [g.clone(), g.clone()[:, :5], g.to(dtype=torch.float16), g.to(dtype=torch.float16)] 5711*da0073e9SAndroid Build Coastguard Worker if device.type == "cuda" and TEST_MULTIGPU: 5712*da0073e9SAndroid Build Coastguard Worker grads += [g.to(device="cuda:1"), 5713*da0073e9SAndroid Build Coastguard Worker g.to(device="cuda:1")[:, :5], 5714*da0073e9SAndroid Build Coastguard Worker g.to(device="cuda:1", dtype=torch.float16), 5715*da0073e9SAndroid Build Coastguard Worker g.to(device="cuda:1", dtype=torch.float16)] 5716*da0073e9SAndroid Build Coastguard Worker if inject_inf >= 0: 5717*da0073e9SAndroid Build Coastguard Worker grads[inject_inf][2, 2] = float('inf') 5718*da0073e9SAndroid Build Coastguard Worker return grads 5719*da0073e9SAndroid Build Coastguard Worker 5720*da0073e9SAndroid Build Coastguard Worker GradScaler = partial(torch.GradScaler, device=device.type) 5721*da0073e9SAndroid Build Coastguard Worker scaler = GradScaler() 5722*da0073e9SAndroid Build Coastguard Worker dummy_params = [torch.empty_like(g) for g in perfect_storm_grads(-1)] 5723*da0073e9SAndroid Build Coastguard Worker dummy_opt = torch.optim.SGD(dummy_params, lr=1.) 5724*da0073e9SAndroid Build Coastguard Worker 5725*da0073e9SAndroid Build Coastguard Worker # Ensures the inf/nan checking can find an inf injected onto any grad in the perfect storm. 
5726*da0073e9SAndroid Build Coastguard Worker for inject_inf in range(-1, len(dummy_params)): 5727*da0073e9SAndroid Build Coastguard Worker found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device0) 5728*da0073e9SAndroid Build Coastguard Worker grads = perfect_storm_grads(inject_inf) 5729*da0073e9SAndroid Build Coastguard Worker for i, p in enumerate(dummy_params): 5730*da0073e9SAndroid Build Coastguard Worker p.grad = grads[i] 5731*da0073e9SAndroid Build Coastguard Worker found_inf_per_device = scaler._unscale_grads_(dummy_opt, inv_scale, found_inf, True) 5732*da0073e9SAndroid Build Coastguard Worker if inject_inf < 0: 5733*da0073e9SAndroid Build Coastguard Worker # No inf was injected, ensures unscaling worked normally. 5734*da0073e9SAndroid Build Coastguard Worker self.assertTrue(sum(v.item() for v in found_inf_per_device.values()) == 0) 5735*da0073e9SAndroid Build Coastguard Worker for grad in grads: 5736*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad, torch.ones_like(grad), rtol=1e-5, atol=1e-7) 5737*da0073e9SAndroid Build Coastguard Worker else: 5738*da0073e9SAndroid Build Coastguard Worker # inf was injected, ensures inf was found. 
5739*da0073e9SAndroid Build Coastguard Worker self.assertTrue(sum(v.item() for v in found_inf_per_device.values()) == 1) 5740*da0073e9SAndroid Build Coastguard Worker 5741*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5742*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 5743*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_update_scale(self, device, dtype): 5744*da0073e9SAndroid Build Coastguard Worker growth = 2.0 5745*da0073e9SAndroid Build Coastguard Worker backoff = 0.25 5746*da0073e9SAndroid Build Coastguard Worker growth_interval = 2 5747*da0073e9SAndroid Build Coastguard Worker scale = torch.full((1,), 4.0, dtype=dtype, device=device) 5748*da0073e9SAndroid Build Coastguard Worker growth_tracker = torch.full((1,), 0.0, dtype=torch.int32, device=device) 5749*da0073e9SAndroid Build Coastguard Worker found_inf = torch.full((1,), 0.0, dtype=torch.float, device=device) 5750*da0073e9SAndroid Build Coastguard Worker 5751*da0073e9SAndroid Build Coastguard Worker # Simulates 2 consecutive unskipped iterations 5752*da0073e9SAndroid Build Coastguard Worker torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval) 5753*da0073e9SAndroid Build Coastguard Worker self.assertEqual(growth_tracker, 1) 5754*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scale, 4.0) 5755*da0073e9SAndroid Build Coastguard Worker torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval) 5756*da0073e9SAndroid Build Coastguard Worker self.assertEqual(growth_tracker, 0) 5757*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scale, 8.0) 5758*da0073e9SAndroid Build Coastguard Worker 5759*da0073e9SAndroid Build Coastguard Worker # Simulates a skipped iteration 5760*da0073e9SAndroid Build Coastguard Worker found_inf.fill_(1.0) 5761*da0073e9SAndroid Build Coastguard Worker torch._amp_update_scale_(scale, growth_tracker, found_inf, growth, backoff, growth_interval) 
5762*da0073e9SAndroid Build Coastguard Worker self.assertEqual(growth_tracker, 0) 5763*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scale, 2.0) 5764*da0073e9SAndroid Build Coastguard Worker 5765*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Failed running call_function for sparse_coo_tensor. See https://github.com/pytorch/pytorch/issues/118856") 5766*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5767*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float) 5768*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_unscale_sparse(self, device, dtype): 5769*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 5770*da0073e9SAndroid Build Coastguard Worker scaler = torch.GradScaler(device=device.type) 5771*da0073e9SAndroid Build Coastguard Worker 5772*da0073e9SAndroid Build Coastguard Worker inv_scale = torch.full((1,), 0.25, dtype=dtype, device=device) 5773*da0073e9SAndroid Build Coastguard Worker found_inf = torch.empty((1,), dtype=dtype, device=device) 5774*da0073e9SAndroid Build Coastguard Worker cur = found_inf.device 5775*da0073e9SAndroid Build Coastguard Worker 5776*da0073e9SAndroid Build Coastguard Worker i = torch.tensor([[0, 1, 1], 5777*da0073e9SAndroid Build Coastguard Worker [2, 0, 2]], device=device, dtype=torch.int64) 5778*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([16., 32., 64.], device=device, dtype=torch.float) 5779*da0073e9SAndroid Build Coastguard Worker s = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device=device, dtype=dtype) 5780*da0073e9SAndroid Build Coastguard Worker 5781*da0073e9SAndroid Build Coastguard Worker p = s.clone() 5782*da0073e9SAndroid Build Coastguard Worker assert p.is_sparse 5783*da0073e9SAndroid Build Coastguard Worker opt = torch.optim.SGD([p], lr=1.) 
5784*da0073e9SAndroid Build Coastguard Worker 5785*da0073e9SAndroid Build Coastguard Worker p.grad = s.clone() 5786*da0073e9SAndroid Build Coastguard Worker found_inf.zero_() 5787*da0073e9SAndroid Build Coastguard Worker found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, False)[cur] 5788*da0073e9SAndroid Build Coastguard Worker self.assertEqual(found_inf, 0.0) 5789*da0073e9SAndroid Build Coastguard Worker self.assertEqual(p.grad.to_dense(), (s / 4).to_dense()) 5790*da0073e9SAndroid Build Coastguard Worker 5791*da0073e9SAndroid Build Coastguard Worker v = torch.FloatTensor([16., 32., float('inf')]) 5792*da0073e9SAndroid Build Coastguard Worker p.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device=device, dtype=dtype) 5793*da0073e9SAndroid Build Coastguard Worker found_inf.zero_() 5794*da0073e9SAndroid Build Coastguard Worker found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, False)[cur] 5795*da0073e9SAndroid Build Coastguard Worker self.assertEqual(found_inf, 1.0) 5796*da0073e9SAndroid Build Coastguard Worker 5797*da0073e9SAndroid Build Coastguard Worker v = torch.FloatTensor([16., 32., float('nan')]) 5798*da0073e9SAndroid Build Coastguard Worker p.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device=device, dtype=dtype) 5799*da0073e9SAndroid Build Coastguard Worker found_inf.zero_() 5800*da0073e9SAndroid Build Coastguard Worker found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, False)[cur] 5801*da0073e9SAndroid Build Coastguard Worker self.assertEqual(found_inf, 1.0) 5802*da0073e9SAndroid Build Coastguard Worker 5803*da0073e9SAndroid Build Coastguard Worker p = s.clone().half() 5804*da0073e9SAndroid Build Coastguard Worker assert p.is_sparse 5805*da0073e9SAndroid Build Coastguard Worker opt = torch.optim.SGD([p], lr=1.) 
5806*da0073e9SAndroid Build Coastguard Worker 5807*da0073e9SAndroid Build Coastguard Worker p.grad = s.clone().half() 5808*da0073e9SAndroid Build Coastguard Worker found_inf.zero_() 5809*da0073e9SAndroid Build Coastguard Worker found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, True)[cur] 5810*da0073e9SAndroid Build Coastguard Worker self.assertEqual(found_inf, 0.0) 5811*da0073e9SAndroid Build Coastguard Worker self.assertEqual(p.grad.to_dense(), (s.half() / 4).to_dense()) 5812*da0073e9SAndroid Build Coastguard Worker 5813*da0073e9SAndroid Build Coastguard Worker # Creates fp16 sparse tensor with duplicated indices (uncoalesced). The uncoalesced representation 5814*da0073e9SAndroid Build Coastguard Worker # does not overflow in fp16, but the coalesced representation would, because 64000 + 64000 > fp16 max. 5815*da0073e9SAndroid Build Coastguard Worker # _amp_non_finite_check_and_unscale_ should report an overflow here. 5816*da0073e9SAndroid Build Coastguard Worker i = torch.LongTensor([[0, 1, 0], 5817*da0073e9SAndroid Build Coastguard Worker [2, 0, 2]]) 5818*da0073e9SAndroid Build Coastguard Worker v = torch.FloatTensor([64000., 32., 64000.]) 5819*da0073e9SAndroid Build Coastguard Worker p.grad = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), device=device, dtype=torch.float16) 5820*da0073e9SAndroid Build Coastguard Worker found_inf.zero_() 5821*da0073e9SAndroid Build Coastguard Worker found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf, True)[cur] 5822*da0073e9SAndroid Build Coastguard Worker self.assertEqual(found_inf, 1.0) 5823*da0073e9SAndroid Build Coastguard Worker 5824*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5825*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_state_dict(self, device): 5826*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 5827*da0073e9SAndroid Build Coastguard Worker GradScaler = partial(torch.GradScaler, device=device.type) 5828*da0073e9SAndroid Build 
Coastguard Worker for lazy_init_scale in True, False: 5829*da0073e9SAndroid Build Coastguard Worker s0 = GradScaler(init_scale=3., growth_factor=4., backoff_factor=.5, growth_interval=2) 5830*da0073e9SAndroid Build Coastguard Worker s1 = GradScaler(init_scale=6., growth_factor=7., backoff_factor=.8, growth_interval=1) 5831*da0073e9SAndroid Build Coastguard Worker 5832*da0073e9SAndroid Build Coastguard Worker # sets a random value for load_state_dict to overwrite 5833*da0073e9SAndroid Build Coastguard Worker s1._init_growth_tracker = 7 5834*da0073e9SAndroid Build Coastguard Worker 5835*da0073e9SAndroid Build Coastguard Worker if lazy_init_scale: 5836*da0073e9SAndroid Build Coastguard Worker # Dummy scale() call to ensure the scale tensor is lazily initialized. 5837*da0073e9SAndroid Build Coastguard Worker s1.scale(torch.full((1,), 4.0, dtype=torch.float32, device=device)) 5838*da0073e9SAndroid Build Coastguard Worker if "cuda" == device.type: 5839*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s1._scale, torch.cuda.FloatTensor)) 5840*da0073e9SAndroid Build Coastguard Worker else: 5841*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s1._scale, torch.FloatTensor)) 5842*da0073e9SAndroid Build Coastguard Worker 5843*da0073e9SAndroid Build Coastguard Worker s1.load_state_dict(s0.state_dict()) 5844*da0073e9SAndroid Build Coastguard Worker 5845*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s1.get_scale(), 3.) 5846*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s1.get_growth_factor(), 4.) 
5847*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s1.get_backoff_factor(), .5) 5848*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s1.get_growth_interval(), 2) 5849*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s1._init_growth_tracker, 0) 5850*da0073e9SAndroid Build Coastguard Worker 5851*da0073e9SAndroid Build Coastguard Worker # _run_scaling_case generalizes some single-optimizer test logic to avoid too much copy-pasting below. 5852*da0073e9SAndroid Build Coastguard Worker def _run_scaling_case(self, device, run, unskipped, skipped, atol=1e-7, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None): 5853*da0073e9SAndroid Build Coastguard Worker # Ensure scaling can be disabled without changing user control flow. 5854*da0073e9SAndroid Build Coastguard Worker for enabled in True, False: 5855*da0073e9SAndroid Build Coastguard Worker ( 5856*da0073e9SAndroid Build Coastguard Worker mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, skip_iter, 5857*da0073e9SAndroid Build Coastguard Worker ) = _create_scaling_case(device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs) 5858*da0073e9SAndroid Build Coastguard Worker 5859*da0073e9SAndroid Build Coastguard Worker # For functionality, test with a modest initial scale, and an unrealistically-large growth factor 5860*da0073e9SAndroid Build Coastguard Worker # so any potential errors with the growth factor handling will be magnified. 
5861*da0073e9SAndroid Build Coastguard Worker GradScaler = partial(torch.GradScaler, device=device) 5862*da0073e9SAndroid Build Coastguard Worker scaler = GradScaler(init_scale=128., growth_factor=2.0, enabled=enabled, growth_interval=1) 5863*da0073e9SAndroid Build Coastguard Worker 5864*da0073e9SAndroid Build Coastguard Worker _ = run(device, data, mod_control, opt_control, scaler, loss_fn, skip_iter, False) 5865*da0073e9SAndroid Build Coastguard Worker ret = run(device, data, mod_scaling, opt_scaling, scaler, loss_fn, skip_iter, True) 5866*da0073e9SAndroid Build Coastguard Worker 5867*da0073e9SAndroid Build Coastguard Worker # Allows run() to optionally return a different scaler instance. 5868*da0073e9SAndroid Build Coastguard Worker scaler = ret if ret else scaler 5869*da0073e9SAndroid Build Coastguard Worker 5870*da0073e9SAndroid Build Coastguard Worker # If scaling was enabled, the scale factor should have been multiplied by the growth factor 5871*da0073e9SAndroid Build Coastguard Worker # len(data) - skipped times and the backoff factor "skipped" times. 5872*da0073e9SAndroid Build Coastguard Worker if enabled: 5873*da0073e9SAndroid Build Coastguard Worker net_growth = scaler.get_growth_factor()**unskipped if unskipped > 0 else 1.0 5874*da0073e9SAndroid Build Coastguard Worker net_backoff = scaler.get_backoff_factor()**skipped if skipped > 0 else 1.0 5875*da0073e9SAndroid Build Coastguard Worker self.assertTrue(scaler.get_scale() == (128. 
* net_growth * net_backoff)) 5876*da0073e9SAndroid Build Coastguard Worker else: 5877*da0073e9SAndroid Build Coastguard Worker self.assertTrue(scaler.get_scale() == 1.0) 5878*da0073e9SAndroid Build Coastguard Worker 5879*da0073e9SAndroid Build Coastguard Worker for c, s in zip(mod_control.parameters(), mod_scaling.parameters()): 5880*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c.grad, s.grad, atol=atol, rtol=1e-05) 5881*da0073e9SAndroid Build Coastguard Worker 5882*da0073e9SAndroid Build Coastguard Worker c_state, s_state = opt_control.state[c], opt_scaling.state[s] 5883*da0073e9SAndroid Build Coastguard Worker for k in c_state: 5884*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_state[k], s_state[k], atol=atol, rtol=1e-05, msg=k) 5885*da0073e9SAndroid Build Coastguard Worker 5886*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c, s, atol=atol, rtol=1e-05) 5887*da0073e9SAndroid Build Coastguard Worker 5888*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5889*da0073e9SAndroid Build Coastguard Worker @parametrize("foreach, fused", [(None, None), (True, None), (None, True)]) 5890*da0073e9SAndroid Build Coastguard Worker @optims( 5891*da0073e9SAndroid Build Coastguard Worker [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD]], 5892*da0073e9SAndroid Build Coastguard Worker dtypes=[torch.float32] 5893*da0073e9SAndroid Build Coastguard Worker ) 5894*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_autocast(self, device, dtype, optim_info, foreach, fused): 5895*da0073e9SAndroid Build Coastguard Worker try_pickle = False 5896*da0073e9SAndroid Build Coastguard Worker 5897*da0073e9SAndroid Build Coastguard Worker def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api): 5898*da0073e9SAndroid Build Coastguard Worker for i, (input, target) in enumerate(data): 5899*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad() 
5900*da0073e9SAndroid Build Coastguard Worker with torch.autocast(device_type=device, dtype=torch.half, enabled=try_scaling_api): 5901*da0073e9SAndroid Build Coastguard Worker output = model(input) 5902*da0073e9SAndroid Build Coastguard Worker loss = loss_fn(output, target) 5903*da0073e9SAndroid Build Coastguard Worker if try_scaling_api: 5904*da0073e9SAndroid Build Coastguard Worker scaler.scale(loss).backward() 5905*da0073e9SAndroid Build Coastguard Worker if i == skip_iter and scaler.is_enabled(): 5906*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 5907*da0073e9SAndroid Build Coastguard Worker model[1].weight.grad.fill_(float('inf')) 5908*da0073e9SAndroid Build Coastguard Worker scaler.step(optimizer) 5909*da0073e9SAndroid Build Coastguard Worker scaler.update() 5910*da0073e9SAndroid Build Coastguard Worker if try_pickle: 5911*da0073e9SAndroid Build Coastguard Worker scaler = pickle.loads(pickle.dumps(scaler)) 5912*da0073e9SAndroid Build Coastguard Worker else: 5913*da0073e9SAndroid Build Coastguard Worker loss.backward() 5914*da0073e9SAndroid Build Coastguard Worker if (not scaler.is_enabled()) or (i != skip_iter): 5915*da0073e9SAndroid Build Coastguard Worker optimizer.step() 5916*da0073e9SAndroid Build Coastguard Worker return scaler 5917*da0073e9SAndroid Build Coastguard Worker 5918*da0073e9SAndroid Build Coastguard Worker optimizer_ctor = optim_info.optim_cls 5919*da0073e9SAndroid Build Coastguard Worker 5920*da0073e9SAndroid Build Coastguard Worker # Compares no scaling + no autocasting against scaling + autocasting. 5921*da0073e9SAndroid Build Coastguard Worker # NOTE(mkozuki): With current way of testing, `torch.optim.Adam` is failing in spite of `foreach` and `fused`. 5922*da0073e9SAndroid Build Coastguard Worker # Giving some flexibility to this test might help. 
5923*da0073e9SAndroid Build Coastguard Worker context = contextlib.nullcontext 5924*da0073e9SAndroid Build Coastguard Worker if optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW): 5925*da0073e9SAndroid Build Coastguard Worker from functools import partial 5926*da0073e9SAndroid Build Coastguard Worker context = partial(self.assertRaises, AssertionError) 5927*da0073e9SAndroid Build Coastguard Worker with context(): 5928*da0073e9SAndroid Build Coastguard Worker # sets atol=1e-3 because we're comparing pure fp32 arithmetic vs a mixture of fp16 and fp32 5929*da0073e9SAndroid Build Coastguard Worker self._run_scaling_case( 5930*da0073e9SAndroid Build Coastguard Worker device, run, unskipped=3, skipped=1, atol=1e-3, 5931*da0073e9SAndroid Build Coastguard Worker optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": foreach, "fused": fused}, 5932*da0073e9SAndroid Build Coastguard Worker ) 5933*da0073e9SAndroid Build Coastguard Worker # this will be picked up by try_pickle within run(): 5934*da0073e9SAndroid Build Coastguard Worker try_pickle = True 5935*da0073e9SAndroid Build Coastguard Worker self._run_scaling_case( 5936*da0073e9SAndroid Build Coastguard Worker device, run, unskipped=3, skipped=1, atol=1e-3, 5937*da0073e9SAndroid Build Coastguard Worker optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": foreach, "fused": fused}, 5938*da0073e9SAndroid Build Coastguard Worker ) 5939*da0073e9SAndroid Build Coastguard Worker 5940*da0073e9SAndroid Build Coastguard Worker # Make sure that the parameters become nonsense when scaled gradients are finite 5941*da0073e9SAndroid Build Coastguard Worker # but they get invalidated before `optimizer.step`, after `GradScaler.unscale_` 5942*da0073e9SAndroid Build Coastguard Worker 5943*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5944*da0073e9SAndroid Build Coastguard Worker @optims( 5945*da0073e9SAndroid Build Coastguard Worker [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, 
torch.optim.Adam, torch.optim.SGD]], 5946*da0073e9SAndroid Build Coastguard Worker dtypes=[torch.float32] 5947*da0073e9SAndroid Build Coastguard Worker ) 5948*da0073e9SAndroid Build Coastguard Worker def test_params_invalidated_with_grads_invalidated_between_unscale_and_step(self, device, dtype, optim_info): 5949*da0073e9SAndroid Build Coastguard Worker optimizer_ctor = optim_info.optim_cls 5950*da0073e9SAndroid Build Coastguard Worker all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( 5951*da0073e9SAndroid Build Coastguard Worker device, dtype, optim_info, skip=("differentiable",)) 5952*da0073e9SAndroid Build Coastguard Worker 5953*da0073e9SAndroid Build Coastguard Worker for optim_input in all_optim_inputs: 5954*da0073e9SAndroid Build Coastguard Worker model, _, optimizer, _, data, loss_fn, _ = _create_scaling_case( 5955*da0073e9SAndroid Build Coastguard Worker device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optim_input.kwargs, 5956*da0073e9SAndroid Build Coastguard Worker ) 5957*da0073e9SAndroid Build Coastguard Worker scaler = torch.GradScaler(device=device, init_scale=128.0) 5958*da0073e9SAndroid Build Coastguard Worker 5959*da0073e9SAndroid Build Coastguard Worker for input, target in data: 5960*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad() 5961*da0073e9SAndroid Build Coastguard Worker with torch.autocast(device_type=device, dtype=torch.half): 5962*da0073e9SAndroid Build Coastguard Worker output = model(input) 5963*da0073e9SAndroid Build Coastguard Worker loss = loss_fn(output, target) 5964*da0073e9SAndroid Build Coastguard Worker scaler.scale(loss).backward() 5965*da0073e9SAndroid Build Coastguard Worker scaler.unscale_(optimizer) 5966*da0073e9SAndroid Build Coastguard Worker 5967*da0073e9SAndroid Build Coastguard Worker # deliberately break grads 5968*da0073e9SAndroid Build Coastguard Worker for j, param in enumerate(model.parameters()): 5969*da0073e9SAndroid Build Coastguard Worker param.grad.copy_(torch.inf if j 
% 2 else torch.nan) 5970*da0073e9SAndroid Build Coastguard Worker 5971*da0073e9SAndroid Build Coastguard Worker scaler.step(optimizer) 5972*da0073e9SAndroid Build Coastguard Worker scaler.update() 5973*da0073e9SAndroid Build Coastguard Worker 5974*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all((p.isnan().any() or p.isinf().any()) for p in model.parameters())) 5975*da0073e9SAndroid Build Coastguard Worker 5976*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5977*da0073e9SAndroid Build Coastguard Worker def test_grad_scale_will_not_overflow(self, device): 5978*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 5979*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Linear(5, 1).to(device) 5980*da0073e9SAndroid Build Coastguard Worker optimizer = torch.optim.Adam(model.parameters()) 5981*da0073e9SAndroid Build Coastguard Worker scaler = torch.GradScaler(device=device.type, growth_interval=1, growth_factor=2**4, init_scale=1e38) 5982*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad() 5983*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, 5).to(device) 5984*da0073e9SAndroid Build Coastguard Worker y = 1e-30 * torch.randn(1, 1).to(device) 5985*da0073e9SAndroid Build Coastguard Worker l = ((model(x) - y) ** 2).mean() 5986*da0073e9SAndroid Build Coastguard Worker scaler.scale(l).backward() 5987*da0073e9SAndroid Build Coastguard Worker scaler.step(optimizer) 5988*da0073e9SAndroid Build Coastguard Worker scaler.update() 5989*da0073e9SAndroid Build Coastguard Worker assert scaler._scale != float("inf") and scaler._scale != float("nan") 5990*da0073e9SAndroid Build Coastguard Worker 5991*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 5992*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_clipping(self, device): 5993*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 5994*da0073e9SAndroid Build Coastguard Worker 5995*da0073e9SAndroid Build Coastguard 
Worker def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api): 5996*da0073e9SAndroid Build Coastguard Worker max_norm = 0.2 # A reasonable value that actually has an effect, based on printouts of grads 5997*da0073e9SAndroid Build Coastguard Worker for i, (input, target) in enumerate(data): 5998*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad() 5999*da0073e9SAndroid Build Coastguard Worker output = model(input) 6000*da0073e9SAndroid Build Coastguard Worker loss = loss_fn(output, target) 6001*da0073e9SAndroid Build Coastguard Worker if try_scaling_api: 6002*da0073e9SAndroid Build Coastguard Worker scaler.scale(loss).backward() 6003*da0073e9SAndroid Build Coastguard Worker torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm * scaler.get_scale()) 6004*da0073e9SAndroid Build Coastguard Worker if i == skip_iter and scaler.is_enabled(): 6005*da0073e9SAndroid Build Coastguard Worker model[1].weight.grad.data.fill_(float('inf')) 6006*da0073e9SAndroid Build Coastguard Worker scaler.step(optimizer) 6007*da0073e9SAndroid Build Coastguard Worker scaler.update() 6008*da0073e9SAndroid Build Coastguard Worker else: 6009*da0073e9SAndroid Build Coastguard Worker loss.backward() 6010*da0073e9SAndroid Build Coastguard Worker torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 6011*da0073e9SAndroid Build Coastguard Worker if (not scaler.is_enabled()) or (i != skip_iter): 6012*da0073e9SAndroid Build Coastguard Worker optimizer.step() 6013*da0073e9SAndroid Build Coastguard Worker 6014*da0073e9SAndroid Build Coastguard Worker self._run_scaling_case(device.type, run, unskipped=3, skipped=1, atol=1e-5) 6015*da0073e9SAndroid Build Coastguard Worker 6016*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 6017*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_clipping_separate_unscale(self, device): 6018*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 6019*da0073e9SAndroid Build 
Coastguard Worker 6020*da0073e9SAndroid Build Coastguard Worker def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api): 6021*da0073e9SAndroid Build Coastguard Worker max_norm = 0.2 # A reasonable value that actually has an effect, based on printouts of grads 6022*da0073e9SAndroid Build Coastguard Worker for i, (input, target) in enumerate(data): 6023*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad() 6024*da0073e9SAndroid Build Coastguard Worker output = model(input) 6025*da0073e9SAndroid Build Coastguard Worker loss = loss_fn(output, target) 6026*da0073e9SAndroid Build Coastguard Worker if try_scaling_api: 6027*da0073e9SAndroid Build Coastguard Worker scaler.scale(loss).backward() 6028*da0073e9SAndroid Build Coastguard Worker if i == skip_iter and scaler.is_enabled(): 6029*da0073e9SAndroid Build Coastguard Worker model[1].weight.grad.data.fill_(float('inf')) 6030*da0073e9SAndroid Build Coastguard Worker scaler.unscale_(optimizer) 6031*da0073e9SAndroid Build Coastguard Worker torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, error_if_nonfinite=False) 6032*da0073e9SAndroid Build Coastguard Worker scaler.step(optimizer) 6033*da0073e9SAndroid Build Coastguard Worker scaler.update() 6034*da0073e9SAndroid Build Coastguard Worker else: 6035*da0073e9SAndroid Build Coastguard Worker loss.backward() 6036*da0073e9SAndroid Build Coastguard Worker torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 6037*da0073e9SAndroid Build Coastguard Worker if (not scaler.is_enabled()) or (i != skip_iter): 6038*da0073e9SAndroid Build Coastguard Worker optimizer.step() 6039*da0073e9SAndroid Build Coastguard Worker 6040*da0073e9SAndroid Build Coastguard Worker self._run_scaling_case(device.type, run, unskipped=3, skipped=1) 6041*da0073e9SAndroid Build Coastguard Worker 6042*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 6043*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, 'FIXME: fix this test 
for Windows') 6044*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_penalty(self, device): 6045*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 6046*da0073e9SAndroid Build Coastguard Worker 6047*da0073e9SAndroid Build Coastguard Worker def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api): 6048*da0073e9SAndroid Build Coastguard Worker for i, (input, target) in enumerate(data): 6049*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad() 6050*da0073e9SAndroid Build Coastguard Worker output = model(input) 6051*da0073e9SAndroid Build Coastguard Worker loss = loss_fn(output, target) 6052*da0073e9SAndroid Build Coastguard Worker 6053*da0073e9SAndroid Build Coastguard Worker if try_scaling_api: 6054*da0073e9SAndroid Build Coastguard Worker grad_params = torch.autograd.grad(scaler.scale(loss), 6055*da0073e9SAndroid Build Coastguard Worker model.parameters(), create_graph=True) 6056*da0073e9SAndroid Build Coastguard Worker inv_scale = 1. 
/ scaler.get_scale() 6057*da0073e9SAndroid Build Coastguard Worker grad_params = [p * inv_scale for p in grad_params] 6058*da0073e9SAndroid Build Coastguard Worker else: 6059*da0073e9SAndroid Build Coastguard Worker grad_params = torch.autograd.grad(loss, model.parameters(), create_graph=True) 6060*da0073e9SAndroid Build Coastguard Worker 6061*da0073e9SAndroid Build Coastguard Worker grad_norm = 0 6062*da0073e9SAndroid Build Coastguard Worker for grad in grad_params: 6063*da0073e9SAndroid Build Coastguard Worker grad_norm += grad.pow(2).sum() 6064*da0073e9SAndroid Build Coastguard Worker grad_norm = grad_norm.sqrt() 6065*da0073e9SAndroid Build Coastguard Worker loss = loss + grad_norm 6066*da0073e9SAndroid Build Coastguard Worker 6067*da0073e9SAndroid Build Coastguard Worker if try_scaling_api: 6068*da0073e9SAndroid Build Coastguard Worker scaler.scale(loss).backward() 6069*da0073e9SAndroid Build Coastguard Worker if i == skip_iter and scaler.is_enabled(): 6070*da0073e9SAndroid Build Coastguard Worker model[1].weight.grad.data.fill_(float('inf')) 6071*da0073e9SAndroid Build Coastguard Worker scaler.step(optimizer) 6072*da0073e9SAndroid Build Coastguard Worker scaler.update() 6073*da0073e9SAndroid Build Coastguard Worker else: 6074*da0073e9SAndroid Build Coastguard Worker loss.backward() 6075*da0073e9SAndroid Build Coastguard Worker if (not scaler.is_enabled()) or (i != skip_iter): 6076*da0073e9SAndroid Build Coastguard Worker optimizer.step() 6077*da0073e9SAndroid Build Coastguard Worker 6078*da0073e9SAndroid Build Coastguard Worker self._run_scaling_case(device.type, run, unskipped=3, skipped=1) 6079*da0073e9SAndroid Build Coastguard Worker 6080*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 6081*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_accumulation(self, device): 6082*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 6083*da0073e9SAndroid Build Coastguard Worker 6084*da0073e9SAndroid Build Coastguard 
Worker def run(device, data, model, optimizer, scaler, loss_fn, skip_iter, try_scaling_api): 6085*da0073e9SAndroid Build Coastguard Worker iters_to_accumulate = 2 6086*da0073e9SAndroid Build Coastguard Worker for i, (input, target) in enumerate(data): 6087*da0073e9SAndroid Build Coastguard Worker output = model(input) 6088*da0073e9SAndroid Build Coastguard Worker loss = loss_fn(output, target) 6089*da0073e9SAndroid Build Coastguard Worker loss = loss / iters_to_accumulate 6090*da0073e9SAndroid Build Coastguard Worker if try_scaling_api: 6091*da0073e9SAndroid Build Coastguard Worker scaler.scale(loss).backward() 6092*da0073e9SAndroid Build Coastguard Worker else: 6093*da0073e9SAndroid Build Coastguard Worker loss.backward() 6094*da0073e9SAndroid Build Coastguard Worker if (i + 1) % iters_to_accumulate == 0: 6095*da0073e9SAndroid Build Coastguard Worker if try_scaling_api: 6096*da0073e9SAndroid Build Coastguard Worker scaler.step(optimizer) 6097*da0073e9SAndroid Build Coastguard Worker scaler.update() 6098*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad() 6099*da0073e9SAndroid Build Coastguard Worker else: 6100*da0073e9SAndroid Build Coastguard Worker optimizer.step() 6101*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad() 6102*da0073e9SAndroid Build Coastguard Worker 6103*da0073e9SAndroid Build Coastguard Worker self._run_scaling_case(device.type, run, unskipped=2, skipped=0) 6104*da0073e9SAndroid Build Coastguard Worker 6105*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 6106*da0073e9SAndroid Build Coastguard Worker def test_grad_scaling_multiple(self, device): 6107*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 6108*da0073e9SAndroid Build Coastguard Worker # Tests gradient scaling with 2 models and 2 optimizers that both receive gradients from 2 losses. 6109*da0073e9SAndroid Build Coastguard Worker # Some of the logic here cannot reuse the generic helper functions created for the 1-optimizer cases. 
6110*da0073e9SAndroid Build Coastguard Worker for enabled in True, False: 6111*da0073e9SAndroid Build Coastguard Worker mod_control0, mod_scaling0, opt_control0, opt_scaling0, data, loss_fn, skip_iter = \ 6112*da0073e9SAndroid Build Coastguard Worker _create_scaling_case(device.type) 6113*da0073e9SAndroid Build Coastguard Worker mod_control1, mod_scaling1, opt_control1, opt_scaling1 = \ 6114*da0073e9SAndroid Build Coastguard Worker _create_scaling_models_optimizers(device.type) 6115*da0073e9SAndroid Build Coastguard Worker 6116*da0073e9SAndroid Build Coastguard Worker GradScaler = partial(torch.GradScaler, device=device.type) 6117*da0073e9SAndroid Build Coastguard Worker scaler = GradScaler(init_scale=128., growth_factor=2.0, enabled=enabled, growth_interval=1) 6118*da0073e9SAndroid Build Coastguard Worker 6119*da0073e9SAndroid Build Coastguard Worker def run(model0, model1, optimizer0, optimizer1, try_scaling_api): 6120*da0073e9SAndroid Build Coastguard Worker for i, (input, target) in enumerate(data): 6121*da0073e9SAndroid Build Coastguard Worker optimizer0.zero_grad() 6122*da0073e9SAndroid Build Coastguard Worker optimizer1.zero_grad() 6123*da0073e9SAndroid Build Coastguard Worker output0 = model0(input) 6124*da0073e9SAndroid Build Coastguard Worker output1 = model1(input) 6125*da0073e9SAndroid Build Coastguard Worker loss0 = loss_fn(0.3 * output0 + 0.7 * output1, target) 6126*da0073e9SAndroid Build Coastguard Worker loss1 = loss_fn(0.6 * output0 - 0.4 * output1, target) 6127*da0073e9SAndroid Build Coastguard Worker 6128*da0073e9SAndroid Build Coastguard Worker if try_scaling_api: 6129*da0073e9SAndroid Build Coastguard Worker scaler.scale(loss0).backward(retain_graph=True) 6130*da0073e9SAndroid Build Coastguard Worker scaler.scale(loss1).backward() 6131*da0073e9SAndroid Build Coastguard Worker if i == skip_iter and scaler.is_enabled(): 6132*da0073e9SAndroid Build Coastguard Worker model1[1].weight.grad.data.fill_(float('inf')) 6133*da0073e9SAndroid Build 
Coastguard Worker 6134*da0073e9SAndroid Build Coastguard Worker # As an additional stress test, separately unscale for one of the optimizers. 6135*da0073e9SAndroid Build Coastguard Worker scaler.unscale_(optimizer0) 6136*da0073e9SAndroid Build Coastguard Worker 6137*da0073e9SAndroid Build Coastguard Worker scaler.step(optimizer0) 6138*da0073e9SAndroid Build Coastguard Worker scaler.step(optimizer1) 6139*da0073e9SAndroid Build Coastguard Worker scaler.update() 6140*da0073e9SAndroid Build Coastguard Worker else: 6141*da0073e9SAndroid Build Coastguard Worker loss0.backward(retain_graph=True) 6142*da0073e9SAndroid Build Coastguard Worker loss1.backward() 6143*da0073e9SAndroid Build Coastguard Worker optimizer0.step() 6144*da0073e9SAndroid Build Coastguard Worker if (not scaler.is_enabled()) or (i != skip_iter): 6145*da0073e9SAndroid Build Coastguard Worker optimizer1.step() 6146*da0073e9SAndroid Build Coastguard Worker 6147*da0073e9SAndroid Build Coastguard Worker run(mod_control0, mod_control1, opt_control0, opt_control1, False) 6148*da0073e9SAndroid Build Coastguard Worker run(mod_scaling0, mod_scaling1, opt_scaling0, opt_scaling1, True) 6149*da0073e9SAndroid Build Coastguard Worker 6150*da0073e9SAndroid Build Coastguard Worker # The loss scale should have been multiplied by the growth factor 3 times and the backoff factor once. 6151*da0073e9SAndroid Build Coastguard Worker self.assertTrue(scaler.get_scale() == (128. 
* scaler.get_growth_factor()**3 * 6152*da0073e9SAndroid Build Coastguard Worker scaler.get_backoff_factor()**1) if enabled else 1.0) 6153*da0073e9SAndroid Build Coastguard Worker 6154*da0073e9SAndroid Build Coastguard Worker for c, s in zip(chain(mod_control0.parameters(), mod_control1.parameters()), 6155*da0073e9SAndroid Build Coastguard Worker chain(mod_scaling0.parameters(), mod_scaling1.parameters())): 6156*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c, s, rtol=1e-5, atol=1e-7) 6157*da0073e9SAndroid Build Coastguard Worker 6158*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 6159*da0073e9SAndroid Build Coastguard Worker def test_grad_scaler_pass_itself(self, device): 6160*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 6161*da0073e9SAndroid Build Coastguard Worker GradScaler = partial(torch.amp.GradScaler, device=device.type) 6162*da0073e9SAndroid Build Coastguard Worker 6163*da0073e9SAndroid Build Coastguard Worker class _PlaceHolderOptimizer(torch.optim.Optimizer): 6164*da0073e9SAndroid Build Coastguard Worker tester = self 6165*da0073e9SAndroid Build Coastguard Worker 6166*da0073e9SAndroid Build Coastguard Worker def __init__(self, params, defaults=None): 6167*da0073e9SAndroid Build Coastguard Worker if defaults is None: 6168*da0073e9SAndroid Build Coastguard Worker defaults = {} 6169*da0073e9SAndroid Build Coastguard Worker super().__init__(params, defaults) 6170*da0073e9SAndroid Build Coastguard Worker self._step_supports_amp_scaling = True 6171*da0073e9SAndroid Build Coastguard Worker 6172*da0073e9SAndroid Build Coastguard Worker class Optimizer1(_PlaceHolderOptimizer): 6173*da0073e9SAndroid Build Coastguard Worker def step(self, closure=None, *, grad_scaler=None): 6174*da0073e9SAndroid Build Coastguard Worker self.tester.assertTrue(isinstance(grad_scaler, torch.amp.GradScaler)) 6175*da0073e9SAndroid Build Coastguard Worker self.tester.assertFalse(hasattr(self, "grad_scale")) 6176*da0073e9SAndroid Build 
Coastguard Worker self.tester.assertFalse(hasattr(self, "found_inf")) 6177*da0073e9SAndroid Build Coastguard Worker 6178*da0073e9SAndroid Build Coastguard Worker class Optimizer2(_PlaceHolderOptimizer): 6179*da0073e9SAndroid Build Coastguard Worker def step(self, closure=None): 6180*da0073e9SAndroid Build Coastguard Worker self.tester.assertTrue(isinstance(self.grad_scale, torch.Tensor)) 6181*da0073e9SAndroid Build Coastguard Worker self.tester.assertTrue(isinstance(self.found_inf, torch.Tensor)) 6182*da0073e9SAndroid Build Coastguard Worker 6183*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 4).to(device) 6184*da0073e9SAndroid Build Coastguard Worker m = torch.nn.Linear(4, 1).to(device) 6185*da0073e9SAndroid Build Coastguard Worker o1 = Optimizer1(m.parameters()) 6186*da0073e9SAndroid Build Coastguard Worker o2 = Optimizer2(m.parameters()) 6187*da0073e9SAndroid Build Coastguard Worker scaler = GradScaler(init_scale=2.0) 6188*da0073e9SAndroid Build Coastguard Worker 6189*da0073e9SAndroid Build Coastguard Worker with torch.autocast(device_type=device.type, dtype=torch.half): 6190*da0073e9SAndroid Build Coastguard Worker y = m(x) 6191*da0073e9SAndroid Build Coastguard Worker loss = y.mean() 6192*da0073e9SAndroid Build Coastguard Worker scaler.scale(loss).backward() 6193*da0073e9SAndroid Build Coastguard Worker with self.assertWarns(FutureWarning): 6194*da0073e9SAndroid Build Coastguard Worker scaler.step(o1) 6195*da0073e9SAndroid Build Coastguard Worker scaler.step(o2) 6196*da0073e9SAndroid Build Coastguard Worker scaler.update() 6197*da0073e9SAndroid Build Coastguard Worker 6198*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 6199*da0073e9SAndroid Build Coastguard Worker def test_grad_scaler_deprecated_warning(self, device): 6200*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 6201*da0073e9SAndroid Build Coastguard Worker GradScaler = torch.cuda.amp.GradScaler if "cuda" == device.type else torch.cpu.amp.GradScaler 
6202*da0073e9SAndroid Build Coastguard Worker 6203*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 6204*da0073e9SAndroid Build Coastguard Worker FutureWarning, 6205*da0073e9SAndroid Build Coastguard Worker rf"`torch.{device.type}.amp.GradScaler\(args...\)` is deprecated.", 6206*da0073e9SAndroid Build Coastguard Worker ): 6207*da0073e9SAndroid Build Coastguard Worker _ = GradScaler(init_scale=2.0) 6208*da0073e9SAndroid Build Coastguard Worker 6209*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA(torch.float, torch.double, torch.half) 6210*da0073e9SAndroid Build Coastguard Worker @dtypesIfCPU(torch.float, torch.double, torch.bfloat16, torch.half) 6211*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 6212*da0073e9SAndroid Build Coastguard Worker def test_multinomial_cpu(self, device, dtype): 6213*da0073e9SAndroid Build Coastguard Worker def make_prob_dist(shape, is_contiguous): 6214*da0073e9SAndroid Build Coastguard Worker if is_contiguous: 6215*da0073e9SAndroid Build Coastguard Worker if dtype == torch.half or dtype == torch.bfloat16: 6216*da0073e9SAndroid Build Coastguard Worker return torch.zeros(shape, device=device).uniform_().to(dtype=dtype) 6217*da0073e9SAndroid Build Coastguard Worker return torch.zeros(shape, device=device, dtype=dtype).uniform_() 6218*da0073e9SAndroid Build Coastguard Worker elif len(shape) == 1: 6219*da0073e9SAndroid Build Coastguard Worker if dtype == torch.half or dtype == torch.bfloat16: 6220*da0073e9SAndroid Build Coastguard Worker return torch.zeros((shape + [5]), device=device).uniform_().to(dtype=dtype)[:, 2] 6221*da0073e9SAndroid Build Coastguard Worker return torch.zeros((shape + [5]), device=device, dtype=dtype).uniform_()[:, 2] 6222*da0073e9SAndroid Build Coastguard Worker else: 6223*da0073e9SAndroid Build Coastguard Worker # num dim = 2 6224*da0073e9SAndroid Build Coastguard Worker new_shape = [2, shape[1], 7, 1, shape[0], 1, 10] 6225*da0073e9SAndroid Build Coastguard Worker 
if dtype == torch.half or dtype == torch.bfloat16: 6226*da0073e9SAndroid Build Coastguard Worker prob_dist = torch.zeros(new_shape, device=device).uniform_().to(dtype=dtype) 6227*da0073e9SAndroid Build Coastguard Worker else: 6228*da0073e9SAndroid Build Coastguard Worker prob_dist = torch.zeros(new_shape, device=device, dtype=dtype).uniform_() 6229*da0073e9SAndroid Build Coastguard Worker prob_dist = prob_dist.transpose(1, 4) 6230*da0073e9SAndroid Build Coastguard Worker prob_dist = prob_dist[1, :, 5, 0, :, 0, 4] 6231*da0073e9SAndroid Build Coastguard Worker assert not prob_dist.is_contiguous() # sanity check 6232*da0073e9SAndroid Build Coastguard Worker return prob_dist 6233*da0073e9SAndroid Build Coastguard Worker 6234*da0073e9SAndroid Build Coastguard Worker # FIXME: move to elementwise ternary test suite 6235*da0073e9SAndroid Build Coastguard Worker # As the test fails with Runtime Error not raised on XLA 6236*da0073e9SAndroid Build Coastguard Worker @onlyNativeDeviceTypes 6237*da0073e9SAndroid Build Coastguard Worker def test_where_scalar_handcrafted_values(self, device): 6238*da0073e9SAndroid Build Coastguard Worker # Tests ScalarxScalar, ScalarxTensor and TensorxScalar 6239*da0073e9SAndroid Build Coastguard Worker # variant of `where` against NumPy version with 6240*da0073e9SAndroid Build Coastguard Worker # handcrafted values. 
6241*da0073e9SAndroid Build Coastguard Worker condition_shape = (5, 5) 6242*da0073e9SAndroid Build Coastguard Worker dtypes = ( 6243*da0073e9SAndroid Build Coastguard Worker torch.bool, torch.uint8, torch.int8, torch.int16, torch.int64, 6244*da0073e9SAndroid Build Coastguard Worker torch.float16, torch.float32, torch.float64, 6245*da0073e9SAndroid Build Coastguard Worker torch.complex64, torch.complex128, 6246*da0073e9SAndroid Build Coastguard Worker ) 6247*da0073e9SAndroid Build Coastguard Worker shapes = ((), (5,), (1, 5),) 6248*da0073e9SAndroid Build Coastguard Worker 6249*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 6250*da0073e9SAndroid Build Coastguard Worker tensors = (torch.empty(shape, dtype=dtype, device=device).fill_(17) 6251*da0073e9SAndroid Build Coastguard Worker for shape, dtype in product(shapes, dtypes)) 6252*da0073e9SAndroid Build Coastguard Worker 6253*da0073e9SAndroid Build Coastguard Worker # Use different values for `x` and `y` 6254*da0073e9SAndroid Build Coastguard Worker # as they are the output values which are compared. 
6255*da0073e9SAndroid Build Coastguard Worker x_vals = (True, 3, 7.0, 1 + 0.5j) 6256*da0073e9SAndroid Build Coastguard Worker y_vals = itertools.chain((False, 4, 8.0, 2 + 0.5j), tensors) 6257*da0073e9SAndroid Build Coastguard Worker for x in x_vals: 6258*da0073e9SAndroid Build Coastguard Worker for y in y_vals: 6259*da0073e9SAndroid Build Coastguard Worker condition = torch.empty(*condition_shape, dtype=torch.bool, device=device).bernoulli_() 6260*da0073e9SAndroid Build Coastguard Worker common_dtype = torch.result_type(x, y) 6261*da0073e9SAndroid Build Coastguard Worker 6262*da0073e9SAndroid Build Coastguard Worker def check_equal(condition, x, y): 6263*da0073e9SAndroid Build Coastguard Worker condition_np = condition.cpu().numpy() 6264*da0073e9SAndroid Build Coastguard Worker x_np = x.cpu().numpy() if isinstance(x, torch.Tensor) else x 6265*da0073e9SAndroid Build Coastguard Worker y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y 6266*da0073e9SAndroid Build Coastguard Worker 6267*da0073e9SAndroid Build Coastguard Worker # NumPy aggressively promotes to double, hence cast to output to correct dtype 6268*da0073e9SAndroid Build Coastguard Worker expected = torch.from_numpy(np.where(condition_np, x_np, y_np)).to(common_dtype) 6269*da0073e9SAndroid Build Coastguard Worker result = torch.where(condition, x, y) 6270*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, result) 6271*da0073e9SAndroid Build Coastguard Worker 6272*da0073e9SAndroid Build Coastguard Worker check_equal(condition, x, y) 6273*da0073e9SAndroid Build Coastguard Worker check_equal(condition, y, x) 6274*da0073e9SAndroid Build Coastguard Worker if self.device_type == "cuda": 6275*da0073e9SAndroid Build Coastguard Worker check_equal(condition, torch.tensor(x), y) 6276*da0073e9SAndroid Build Coastguard Worker check_equal(condition, y, torch.tensor(x)) 6277*da0073e9SAndroid Build Coastguard Worker if not isinstance(y, torch.Tensor): 6278*da0073e9SAndroid Build Coastguard Worker 
check_equal(condition, torch.tensor(y), torch.tensor(x)) 6279*da0073e9SAndroid Build Coastguard Worker if isinstance(y, torch.Tensor) and y.ndim > 0: 6280*da0073e9SAndroid Build Coastguard Worker check_equal(torch.tensor(True), x, y) 6281*da0073e9SAndroid Build Coastguard Worker check_equal(torch.tensor(True), y, x) 6282*da0073e9SAndroid Build Coastguard Worker 6283*da0073e9SAndroid Build Coastguard Worker 6284*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("FIXME") 6285*da0073e9SAndroid Build Coastguard Worker def test_hook_remove(self, device): 6286*da0073e9SAndroid Build Coastguard Worker # Reference: https://github.com/pytorch/pytorch/issues/58354 6287*da0073e9SAndroid Build Coastguard Worker def _test_helper(remove_hook): 6288*da0073e9SAndroid Build Coastguard Worker def install_hook(tensor): 6289*da0073e9SAndroid Build Coastguard Worker handle = None 6290*da0073e9SAndroid Build Coastguard Worker 6291*da0073e9SAndroid Build Coastguard Worker def hook(tensor): 6292*da0073e9SAndroid Build Coastguard Worker if remove_hook: 6293*da0073e9SAndroid Build Coastguard Worker handle.remove() 6294*da0073e9SAndroid Build Coastguard Worker return torch.zeros_like(tensor) 6295*da0073e9SAndroid Build Coastguard Worker handle = tensor.register_hook(hook) 6296*da0073e9SAndroid Build Coastguard Worker 6297*da0073e9SAndroid Build Coastguard Worker t = torch.ones((1, 5), device=device, requires_grad=True) 6298*da0073e9SAndroid Build Coastguard Worker install_hook(t) 6299*da0073e9SAndroid Build Coastguard Worker 6300*da0073e9SAndroid Build Coastguard Worker # First call to backward 6301*da0073e9SAndroid Build Coastguard Worker t.mean().backward() 6302*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.grad, torch.zeros_like(t)) 6303*da0073e9SAndroid Build Coastguard Worker 6304*da0073e9SAndroid Build Coastguard Worker # Second call to backward 6305*da0073e9SAndroid Build Coastguard Worker t.mean().backward() 6306*da0073e9SAndroid Build Coastguard Worker if 
remove_hook: 6307*da0073e9SAndroid Build Coastguard Worker # After removing the hook, make sure the usual gradient is returned 6308*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.grad, 0.2 * torch.ones_like(t)) 6309*da0073e9SAndroid Build Coastguard Worker else: 6310*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.grad, torch.zeros_like(t)) 6311*da0073e9SAndroid Build Coastguard Worker 6312*da0073e9SAndroid Build Coastguard Worker _test_helper(remove_hook=True) 6313*da0073e9SAndroid Build Coastguard Worker _test_helper(remove_hook=False) 6314*da0073e9SAndroid Build Coastguard Worker 6315*da0073e9SAndroid Build Coastguard Worker # FIXME: get PyTorch/XLA to run test_testing 6316*da0073e9SAndroid Build Coastguard Worker # This test should ideally be in test_testing.py, 6317*da0073e9SAndroid Build Coastguard Worker # but since pytorch/xla runs tests from test_torch.py, we have it here. 6318*da0073e9SAndroid Build Coastguard Worker @skipXLA 6319*da0073e9SAndroid Build Coastguard Worker def test_skip_xla(self, device): 6320*da0073e9SAndroid Build Coastguard Worker if self.device_type == 'xla': 6321*da0073e9SAndroid Build Coastguard Worker # Should not reach here! 6322*da0073e9SAndroid Build Coastguard Worker self.assertTrue(False) 6323*da0073e9SAndroid Build Coastguard Worker 6324*da0073e9SAndroid Build Coastguard Worker # FIXME: get PyTorch/XLA to run test_testing 6325*da0073e9SAndroid Build Coastguard Worker # This test should ideally be in test_testing.py, 6326*da0073e9SAndroid Build Coastguard Worker # but since pytorch/xla runs tests from test_torch.py, we have it here. 
6327*da0073e9SAndroid Build Coastguard Worker @expectedFailureXLA 6328*da0073e9SAndroid Build Coastguard Worker def test_expected_failure_xla(self, device): 6329*da0073e9SAndroid Build Coastguard Worker if self.device_type == 'xla': 6330*da0073e9SAndroid Build Coastguard Worker self.assertTrue(False) 6331*da0073e9SAndroid Build Coastguard Worker 6332*da0073e9SAndroid Build Coastguard Worker # FIXME: get PyTorch/XLA to run test_testing 6333*da0073e9SAndroid Build Coastguard Worker # This test should ideally be in test_testing.py, 6334*da0073e9SAndroid Build Coastguard Worker # but since pytorch/xla runs tests from test_torch.py, we have it here. 6335*da0073e9SAndroid Build Coastguard Worker def test_assertRaisesRegex_ignore_msg_non_native_device(self, device): 6336*da0073e9SAndroid Build Coastguard Worker # Verify that self.assertRaisesRegex only checks the Error and ignores 6337*da0073e9SAndroid Build Coastguard Worker # message for non-native devices. 6338*da0073e9SAndroid Build Coastguard Worker x = torch.randn((10, 3), device=device) 6339*da0073e9SAndroid Build Coastguard Worker t = torch.empty(10, dtype=torch.int64, device=device).random_(0, 3) 6340*da0073e9SAndroid Build Coastguard Worker invalid_weight = torch.randn(4, device=device) 6341*da0073e9SAndroid Build Coastguard Worker msg = "weight tensor should be defined either for all 3 classes or no classes" 6342*da0073e9SAndroid Build Coastguard Worker 6343*da0073e9SAndroid Build Coastguard Worker # XLA raises RuntimeError with a different message. 
6344*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 6345*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.nll_loss(x, t, weight=invalid_weight) 6346*da0073e9SAndroid Build Coastguard Worker 6347*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32)) 6348*da0073e9SAndroid Build Coastguard Worker def test_copy_(self, device, dtype): 6349*da0073e9SAndroid Build Coastguard Worker def can_cast(src_dtype, dst_dtype): 6350*da0073e9SAndroid Build Coastguard Worker # torch.can_cast(torch.int16, torch.uint8) returns True 6351*da0073e9SAndroid Build Coastguard Worker # which isn't actually safe-cast. 6352*da0073e9SAndroid Build Coastguard Worker # This function returns False in this case. 6353*da0073e9SAndroid Build Coastguard Worker def is_unsigned_int(dtype): 6354*da0073e9SAndroid Build Coastguard Worker return dtype is torch.uint8 6355*da0073e9SAndroid Build Coastguard Worker 6356*da0073e9SAndroid Build Coastguard Worker if is_unsigned_int(dst_dtype): 6357*da0073e9SAndroid Build Coastguard Worker return is_unsigned_int(src_dtype) 6358*da0073e9SAndroid Build Coastguard Worker return torch.can_cast(src_dtype, dst_dtype) 6359*da0073e9SAndroid Build Coastguard Worker 6360*da0073e9SAndroid Build Coastguard Worker def make_tensor_wrapper(shape, dtype): 6361*da0073e9SAndroid Build Coastguard Worker if dtype is not torch.complex32: 6362*da0073e9SAndroid Build Coastguard Worker # Make tensor does not support generating 6363*da0073e9SAndroid Build Coastguard Worker # complex32 tensor 6364*da0073e9SAndroid Build Coastguard Worker return make_tensor(shape, device=device, dtype=dtype) 6365*da0073e9SAndroid Build Coastguard Worker return torch.randn(shape, device=device, dtype=dtype) 6366*da0073e9SAndroid Build Coastguard Worker 6367*da0073e9SAndroid Build Coastguard Worker t = make_tensor_wrapper((50,), dtype) 6368*da0073e9SAndroid Build Coastguard Worker 
src_dtypes = all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.complex32) 6369*da0073e9SAndroid Build Coastguard Worker for src_dtype in src_dtypes: 6370*da0073e9SAndroid Build Coastguard Worker src = make_tensor_wrapper((50,), dtype=src_dtype) 6371*da0073e9SAndroid Build Coastguard Worker t.copy_(src) 6372*da0073e9SAndroid Build Coastguard Worker dst = make_tensor_wrapper((50, ), dtype=src_dtype) 6373*da0073e9SAndroid Build Coastguard Worker if can_cast(src_dtype, dtype): 6374*da0073e9SAndroid Build Coastguard Worker rtol = None 6375*da0073e9SAndroid Build Coastguard Worker atol = None 6376*da0073e9SAndroid Build Coastguard Worker if dtype in (torch.half, torch.complex32): 6377*da0073e9SAndroid Build Coastguard Worker rtol = 1e-3 6378*da0073e9SAndroid Build Coastguard Worker atol = 1e-3 6379*da0073e9SAndroid Build Coastguard Worker if dtype in (torch.bfloat16,): 6380*da0073e9SAndroid Build Coastguard Worker rtol = 1e-2 6381*da0073e9SAndroid Build Coastguard Worker atol = 1e-2 6382*da0073e9SAndroid Build Coastguard Worker self.assertEqual(src, dst.copy_(t), rtol=rtol, atol=atol) 6383*da0073e9SAndroid Build Coastguard Worker 6384*da0073e9SAndroid Build Coastguard Worker @dtypes(*all_types_and_complex_and( 6385*da0073e9SAndroid Build Coastguard Worker torch.bool, torch.half, torch.bfloat16, torch.complex32, 6386*da0073e9SAndroid Build Coastguard Worker torch.uint16, torch.uint32, torch.uint64)) 6387*da0073e9SAndroid Build Coastguard Worker def test_item(self, device, dtype): 6388*da0073e9SAndroid Build Coastguard Worker if torch.device(device).type == 'xla' and dtype in [torch.uint16, torch.uint32, torch.uint64]: 6389*da0073e9SAndroid Build Coastguard Worker self.skipTest('uint16,32,64 not implemented on XLA') 6390*da0073e9SAndroid Build Coastguard Worker t = torch.ones((), device=device, dtype=dtype) 6391*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, t.item()) 6392*da0073e9SAndroid Build Coastguard Worker 6393*da0073e9SAndroid 
Build Coastguard Worker @onlyNativeDeviceTypes 6394*da0073e9SAndroid Build Coastguard Worker def test_masked_scatter_inplace_noncontiguous(self, device): 6395*da0073e9SAndroid Build Coastguard Worker t = torch.zeros(5, 2, dtype=torch.long, device=device) 6396*da0073e9SAndroid Build Coastguard Worker t_non_contig = t.transpose(0, 1) 6397*da0073e9SAndroid Build Coastguard Worker t_contig = t_non_contig.contiguous() 6398*da0073e9SAndroid Build Coastguard Worker 6399*da0073e9SAndroid Build Coastguard Worker assert t_contig.is_contiguous() 6400*da0073e9SAndroid Build Coastguard Worker assert not t_non_contig.is_contiguous() 6401*da0073e9SAndroid Build Coastguard Worker 6402*da0073e9SAndroid Build Coastguard Worker mask = torch.tensor([[False, True], [False, True], [False, False], [True, True], [True, True]], device=device) 6403*da0073e9SAndroid Build Coastguard Worker mask_non_contig = mask.transpose(0, 1) 6404*da0073e9SAndroid Build Coastguard Worker mask_contig = mask_non_contig.contiguous() 6405*da0073e9SAndroid Build Coastguard Worker 6406*da0073e9SAndroid Build Coastguard Worker assert mask_contig.is_contiguous() 6407*da0073e9SAndroid Build Coastguard Worker assert not mask_non_contig.is_contiguous() 6408*da0073e9SAndroid Build Coastguard Worker 6409*da0073e9SAndroid Build Coastguard Worker # source is always converted to contiguous by the op. 
6410*da0073e9SAndroid Build Coastguard Worker source = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 9]], device=device) 6411*da0073e9SAndroid Build Coastguard Worker 6412*da0073e9SAndroid Build Coastguard Worker # t: contig, mask: contig 6413*da0073e9SAndroid Build Coastguard Worker expected = t_contig.masked_scatter_(mask_contig, source) 6414*da0073e9SAndroid Build Coastguard Worker 6415*da0073e9SAndroid Build Coastguard Worker # t: non-contig, mask: non-contig 6416*da0073e9SAndroid Build Coastguard Worker actual = t_non_contig.masked_scatter_(mask_non_contig, source) 6417*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 6418*da0073e9SAndroid Build Coastguard Worker 6419*da0073e9SAndroid Build Coastguard Worker # t: contig, mask: non-contig 6420*da0073e9SAndroid Build Coastguard Worker actual = t_contig.masked_scatter_(mask_non_contig, source) 6421*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 6422*da0073e9SAndroid Build Coastguard Worker 6423*da0073e9SAndroid Build Coastguard Worker # t: non-contig, mask: contig 6424*da0073e9SAndroid Build Coastguard Worker actual = t_non_contig.masked_scatter_(mask_contig, source) 6425*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, expected) 6426*da0073e9SAndroid Build Coastguard Worker 6427*da0073e9SAndroid Build Coastguard Worker 6428*da0073e9SAndroid Build Coastguard Worker# Tests that compare a device's computation with the (gold-standard) CPU's. 
class TestDevicePrecision(TestCase):
    """Tests that check a device's results (and device placement) against CPU."""
    exact_dtype = True

    # FIXME: move to indexing test suite
    @onlyCUDA
    def test_index_add_bfloat16(self, device):
        # bfloat16 index_add on the device must match the CPU result with an
        # absolute tolerance of 1e-2.
        inp_tensor = torch.randn(5, 3, device='cpu').bfloat16()
        t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.bfloat16, device='cpu')
        index = torch.tensor([0, 4, 2], device='cpu')
        out_cpu = inp_tensor.index_add(0, index, t)

        inp_tensor = inp_tensor.to(device=device)
        t = t.to(device=device)
        index = index.to(device=device)
        out_gpu = inp_tensor.index_add(0, index, t)

        self.assertEqual(out_cpu, out_gpu, atol=1e-2, rtol=0)

    # FIXME: move to serialization test suite
    def test_device_serialization(self, device):
        # A torch.save / torch.load round trip must preserve value, type and device.
        x = torch.randn(4, 4, device=device)

        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f)

        self.assertEqual(x_copy, x)
        self.assertIs(type(x_copy), type(x))
        self.assertEqual(x_copy.device, x.device)

    # FIXME: move to serialization test suite
    @deviceCountAtLeast(2)
    def test_multidevice_serialization(self, devices):
        # Same round trip as above, for a list of tensors on two different devices.
        x = [torch.randn(4, 4, device=devices[0]),
             torch.randn(4, 4, device=devices[1])]

        with tempfile.NamedTemporaryFile() as f:
            torch.save(x, f)
            f.seek(0)
            x_copy = torch.load(f)

        for original, cp in zip(x, x_copy):
            self.assertEqual(cp, original)
            self.assertIs(type(cp), type(original))
            self.assertEqual(cp.device, original.device)

    # FIXME: move to data movement test suite
    @deviceCountAtLeast(1)
    def test_copy_noncontig(self, devices):
        def do_test(d0, d1):
            # Strided float -> int64 copy between (possibly different) devices;
            # the fractional parts are dropped (1.5 -> 1, ...).
            x = torch.tensor([1.5, 2.5, 3.5, 4.5, 5.5, 6.5], device=d0)
            y = torch.tensor([0, 0, 0, 0, 0, 0], device=d1)
            self.assertNotEqual(x.dtype, y.dtype)

            y[::2].copy_(x[::2])
            self.assertEqual(y, [1, 0, 3, 0, 5, 0])

        do_test('cpu', devices[0])
        do_test(devices[0], 'cpu')

        if len(devices) > 1:
            do_test(devices[0], devices[1])

    @deviceCountAtLeast(2)
    def test_type_conversions_same_device(self, devices):
        # dtype conversions must keep the tensor on its (non-default) device.
        x = torch.randn(5, 5, device=devices[1])
        self.assertEqual(x.int().device, torch.device(devices[1]))
        self.assertEqual(x.type(torch.int).device, torch.device(devices[1]))
        self.assertEqual(x.to(torch.int).device, torch.device(devices[1]))

    @dtypesIfCUDA(torch.half, torch.float, torch.double,
                  torch.int8, torch.short, torch.int, torch.long,
                  torch.uint8)
    @dtypes(torch.float, torch.double,
            torch.int8, torch.short, torch.int, torch.long,
            torch.uint8)
    def test_from_sequence(self, device, dtype):
        # torch.tensor on a nested list equals the matching reshaped arange.
        seq = [list(range(i * 4, i * 4 + 4)) for i in range(5)]
        reference = torch.arange(0, 20).resize_(5, 4)
        self.assertEqual(torch.tensor(seq, dtype=dtype, device=device), reference, exact_dtype=False)

    # FIXME: move to indexing test suite
    @deviceCountAtLeast(1)
    def test_advancedindex_mixed_cpu_devices(self, devices) -> None:
        def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None:
            # test getitem
            self.assertEqual(x[:, ia, None, ib, 0].cpu(),
                             x.cpu()[:, ia.cpu(), None, ib.cpu(), 0])
            self.assertEqual(x[ia], x.cpu()[ia.cpu()])
            # test setitem: each result must match the same assignment done on
            # a CPU copy. The callers' index tensors hold no repeated
            # (ia, ib) pairs, so the scattered result is deterministic.
            first_value = torch.randn(x[:, ia, None, ib, 0].shape)
            second_value = torch.randn(x[ia].shape)

            x_clone1 = x.clone()
            x_clone2 = x.clone()
            x_clone1[:, ia, None, ib, 0] = first_value.to(x_clone1)
            x_clone2[ia] = second_value.to(x_clone2)

            # .clone() so the references never alias x, even when x is on CPU.
            expected1 = x.cpu().clone()
            expected2 = x.cpu().clone()
            expected1[:, ia.cpu(), None, ib.cpu(), 0] = first_value
            expected2[ia.cpu()] = second_value
            self.assertEqual(x_clone1, expected1)
            self.assertEqual(x_clone2, expected2)

        cpu = torch.device('cpu')
        for device in devices:
            x = torch.randn(3, 4, 4, 4, 3)
            ia = torch.tensor([0, 2, 1])
            ib = torch.tensor([0, 2, 1])

            # Index device tensor with cpu tensor
            x = x.to(device)
            ia = ia.to(cpu)
            ib = ib.to(cpu)
            test(x, ia, ib)

            # Index device tensor with mixed cpu, device tensors
            x = x.to(device)
            ia = ia.to(cpu)
            ib = ib.to(device)
            test(x, ia, ib)

    @deviceCountAtLeast(1)
    def test_advancedindex_mixed_devices_error(self, devices) -> None:
        def test(x: torch.Tensor, ia: torch.Tensor, ib: torch.Tensor) -> None:
            # test getitem: indexing with indices on an incompatible device
            # must raise; the subscript result itself is never used.
            with self.assertRaisesRegex(RuntimeError, fr"indices should be either .* \({x.device}\)"):
                x[:, ia, None, ib, 0]
            with self.assertRaisesRegex(RuntimeError, fr"indices should be either .* \({x.device}\)"):
                x[ib]

        cpu = torch.device('cpu')
        for device in devices:
            # Index cpu tensor with device tensor
            x = torch.randn(3, 4, 4, 4, 3)
            ia = torch.tensor([0, 2, 1]).to(device)
            ib = torch.tensor([0, 2, 1]).to(device)
            test(x, ia, ib)

            # Index cpu tensor with mixed cpu, device tensors
            x = x.to(cpu)
            ia = ia.to(cpu)
            ib = ib.to(device)
            test(x, ia, ib)

            if len(devices) > 1:
                other_device = devices[0] if device == devices[1] else devices[1]

                # Index device tensor with mixed cpu, device tensors on different devices
                x = x.to(device)
                ia = ia.to(cpu)
                ib = ib.to(other_device)
                test(x, ia, ib)

    # FIXME: move to data movement test suite
    def test_copy_broadcast(self, device) -> None:
        # copy_ broadcasts a (5,) source over a (10, 5) destination in both
        # directions (device -> cpu and cpu -> device).
        x = torch.randn(10, 5)
        y = torch.randn(5, device=device)
        x.copy_(y)
        self.assertEqual(x[3], y)

        x = torch.randn(10, 5, device=device)
        y = torch.randn(5)
        x.copy_(y)
        self.assertEqual(x[3], y)

    # FIXME: move to an elementwise ternary test suite
    @dtypes(torch.int64, torch.float32, torch.float64)
    def test_clamp(self, device, dtype):
        # Tensor-bound clamp is checked against max/min composition and np.clip,
        # for contiguous and non-contiguous inputs and with broadcasting.
        test_args = [
            *product(
                [(100, 50), (10, 64), (97,)],  # shape
                (True, False),  # non-contiguous
            )
        ]

        for shape, noncontig in test_args:
            x = make_tensor(shape, device=device, dtype=dtype,
                            noncontiguous=noncontig)
            ub = make_tensor(shape, device=device, dtype=dtype,
                             noncontiguous=noncontig)
            lb = make_tensor(shape, device=device, dtype=dtype,
                             noncontiguous=noncontig)

            expect = x.max(lb).min(ub)
            actual = x.clamp(lb, ub)
            self.assertEqual(expect, actual)

            expect = np.clip(x.cpu().numpy(), lb.cpu().numpy(), ub.cpu().numpy())
            self.assertEqual(expect, actual)

            expect = x.max(lb)
            actual = x.clamp(min=lb)
            self.assertEqual(expect, actual)

            expect = x.min(ub)
            actual = x.clamp(max=ub)
            self.assertEqual(expect, actual)

            # Test broadcasting min & max
            expect = x.max(lb[0]).min(ub[..., :1])
            actual = x.clamp(lb[0], ub[..., :1])
            self.assertEqual(expect, actual)

            # Test broadcasting x
            expect = x[..., :1].max(lb).min(ub)
            actual = x[..., :1].clamp(lb, ub)
            self.assertEqual(expect, actual)

    def test_cuda_device_idx(self, device):
        # An efficient zero tensor lands on the same device as a regular one.
        x = torch.zeros(3, device=device)
        y = torch._efficientzerotensor(3, device=device)
        self.assertEqual(x.device, y.device)

# we implemented custom deallocation for subclasses, so it behooves
# us to make sure all of these bits work.
We'll use __del__ to 6639*da0073e9SAndroid Build Coastguard Worker# track if objects die or not 6640*da0073e9SAndroid Build Coastguard Workerclass Tracker: 6641*da0073e9SAndroid Build Coastguard Worker def __init__(self, marker): 6642*da0073e9SAndroid Build Coastguard Worker self.marker = marker 6643*da0073e9SAndroid Build Coastguard Worker 6644*da0073e9SAndroid Build Coastguard Worker @staticmethod 6645*da0073e9SAndroid Build Coastguard Worker def make(): 6646*da0073e9SAndroid Build Coastguard Worker marker = [False] 6647*da0073e9SAndroid Build Coastguard Worker return marker, Tracker(marker) 6648*da0073e9SAndroid Build Coastguard Worker 6649*da0073e9SAndroid Build Coastguard Worker def __del__(self): 6650*da0073e9SAndroid Build Coastguard Worker self.marker[0] = True 6651*da0073e9SAndroid Build Coastguard Worker 6652*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager 6653*da0073e9SAndroid Build Coastguard Workerdef disable_gc(): 6654*da0073e9SAndroid Build Coastguard Worker if gc.isenabled(): 6655*da0073e9SAndroid Build Coastguard Worker try: 6656*da0073e9SAndroid Build Coastguard Worker gc.disable() 6657*da0073e9SAndroid Build Coastguard Worker yield 6658*da0073e9SAndroid Build Coastguard Worker finally: 6659*da0073e9SAndroid Build Coastguard Worker gc.enable() 6660*da0073e9SAndroid Build Coastguard Worker else: 6661*da0073e9SAndroid Build Coastguard Worker yield 6662*da0073e9SAndroid Build Coastguard Worker 6663*da0073e9SAndroid Build Coastguard Workerclass TestTorch(TestCase): 6664*da0073e9SAndroid Build Coastguard Worker exact_dtype = True 6665*da0073e9SAndroid Build Coastguard Worker 6666*da0073e9SAndroid Build Coastguard Worker def test_dir(self): 6667*da0073e9SAndroid Build Coastguard Worker dir(torch) 6668*da0073e9SAndroid Build Coastguard Worker 6669*da0073e9SAndroid Build Coastguard Worker def test_wildcard_import(self): 6670*da0073e9SAndroid Build Coastguard Worker exec('from torch import *') 6671*da0073e9SAndroid Build Coastguard 
def test_newaxis_numpy_comparison(self):
    """Indexing with mixes of ``None``, ``Ellipsis`` and ints must agree with NumPy."""
    def check(tensor, *index):
        # Apply the same index tuple to the tensor and to its NumPy view.
        expected = tensor.numpy()[index]
        self.assertEqual(tensor[index], expected)

    vector_cases = (
        (None,),
        (None, None),
        (..., None),
        (None, ...),
        (2, None),
        (None, 2),
        (..., None, 2),
        (..., 2, None),
        (2, ..., None),
        (2, None, ...),
        (None, 2, ...),
        (None, ..., 2),
    )
    matrix_cases = (
        (None,),
        (None, None),
        (None, None, None),
        (..., None),
        (..., None, None),
        (None, ...),
        (None, ..., None),
        (None, None, ...),
        (2, None),
        (2, None, ...),
        (2, ..., None),
        (None, 2, ...),
        (..., 2, None),
        (..., None, 2),
        (None, ..., 2),
        (1, 2, None),
        (1, 2, ..., None),
        (1, ..., 2, None),
        (..., 1, None, 2),
        (..., 1, 2, None),
        (1, None, 2, ...),
        (None, 1, ..., 2),
        (None, 1, 2, ...),
    )

    # 1D tensor first, then a 3x4 matrix.
    suites = (
        (torch.arange(0, 10), vector_cases),
        (torch.arange(0, 12).view(3, 4), matrix_cases),
    )
    for tensor, cases in suites:
        for index in cases:
            check(tensor, *index)

def _consecutive(self, size, start=1):
    """Return a default-dtype tensor of shape ``size`` holding start, start+1, ...

    Values are laid out in row-major order.
    """
    count = torch.tensor(size).prod(0)
    values = torch.ones(count).cumsum(0)
    values.add_(start - 1)
    return values.resize_(*size)
Build Coastguard Worker def test_newindex(self): 6734*da0073e9SAndroid Build Coastguard Worker reference = self._consecutive((3, 3, 3)) 6735*da0073e9SAndroid Build Coastguard Worker # This relies on __index__() being correct - but we have separate tests for that 6736*da0073e9SAndroid Build Coastguard Worker 6737*da0073e9SAndroid Build Coastguard Worker def checkPartialAssign(index): 6738*da0073e9SAndroid Build Coastguard Worker reference = torch.zeros(3, 3, 3) 6739*da0073e9SAndroid Build Coastguard Worker reference[index] = self._consecutive((3, 3, 3))[index] 6740*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference[index], self._consecutive((3, 3, 3))[index], atol=0, rtol=0) 6741*da0073e9SAndroid Build Coastguard Worker reference[index] = 0 6742*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reference, torch.zeros(3, 3, 3), atol=0, rtol=0) 6743*da0073e9SAndroid Build Coastguard Worker 6744*da0073e9SAndroid Build Coastguard Worker checkPartialAssign(0) 6745*da0073e9SAndroid Build Coastguard Worker checkPartialAssign(1) 6746*da0073e9SAndroid Build Coastguard Worker checkPartialAssign(2) 6747*da0073e9SAndroid Build Coastguard Worker checkPartialAssign((0, 1)) 6748*da0073e9SAndroid Build Coastguard Worker checkPartialAssign((1, 2)) 6749*da0073e9SAndroid Build Coastguard Worker checkPartialAssign((0, 2)) 6750*da0073e9SAndroid Build Coastguard Worker checkPartialAssign(torch.LongTensor((0, 2))) 6751*da0073e9SAndroid Build Coastguard Worker 6752*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 6753*da0073e9SAndroid Build Coastguard Worker reference[1, 1, 1, 1] = 1 6754*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 6755*da0073e9SAndroid Build Coastguard Worker reference[1, 1, 1, (1, 1)] = 1 6756*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 6757*da0073e9SAndroid Build Coastguard Worker reference[3, 3, 3, 3, 3, 3, 3, 3] = 1 6758*da0073e9SAndroid Build 
Coastguard Worker with self.assertRaises(IndexError): 6759*da0073e9SAndroid Build Coastguard Worker reference[0.0] = 1 6760*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 6761*da0073e9SAndroid Build Coastguard Worker reference[0.0:2.0] = 1 6762*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 6763*da0073e9SAndroid Build Coastguard Worker reference[0.0, 0.0:2.0] = 1 6764*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 6765*da0073e9SAndroid Build Coastguard Worker reference[0.0, :, 0.0:2.0] = 1 6766*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 6767*da0073e9SAndroid Build Coastguard Worker reference[0.0, ..., 0.0:2.0] = 1 6768*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(IndexError): 6769*da0073e9SAndroid Build Coastguard Worker reference[0.0, :, 0.0] = 1 6770*da0073e9SAndroid Build Coastguard Worker 6771*da0073e9SAndroid Build Coastguard Worker # Test `torch._check*` functions 6772*da0073e9SAndroid Build Coastguard Worker def test_check(self): 6773*da0073e9SAndroid Build Coastguard Worker test_cases = [ 6774*da0073e9SAndroid Build Coastguard Worker # check function, expected error 6775*da0073e9SAndroid Build Coastguard Worker (torch._check, RuntimeError), 6776*da0073e9SAndroid Build Coastguard Worker (torch._check_index, IndexError), 6777*da0073e9SAndroid Build Coastguard Worker (torch._check_value, ValueError), 6778*da0073e9SAndroid Build Coastguard Worker (torch._check_type, TypeError), 6779*da0073e9SAndroid Build Coastguard Worker (torch._check_not_implemented, NotImplementedError), 6780*da0073e9SAndroid Build Coastguard Worker ] 6781*da0073e9SAndroid Build Coastguard Worker 6782*da0073e9SAndroid Build Coastguard Worker for check_fn, expected_error in test_cases: 6783*da0073e9SAndroid Build Coastguard Worker # cond=True should not raise an error 6784*da0073e9SAndroid Build Coastguard Worker check_fn(True) 6785*da0073e9SAndroid 
Build Coastguard Worker 6786*da0073e9SAndroid Build Coastguard Worker # Test default failure message for cond=False 6787*da0073e9SAndroid Build Coastguard Worker default_message = 'Expected cond to be True' 6788*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(expected_error, default_message): 6789*da0073e9SAndroid Build Coastguard Worker check_fn(False) 6790*da0073e9SAndroid Build Coastguard Worker 6791*da0073e9SAndroid Build Coastguard Worker # Test a simple failure message 6792*da0073e9SAndroid Build Coastguard Worker message = 'message' 6793*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(expected_error, message): 6794*da0073e9SAndroid Build Coastguard Worker check_fn(False, lambda: message) 6795*da0073e9SAndroid Build Coastguard Worker 6796*da0073e9SAndroid Build Coastguard Worker # Test message with tensor 6797*da0073e9SAndroid Build Coastguard Worker def message(): 6798*da0073e9SAndroid Build Coastguard Worker return torch.arange(4) 6799*da0073e9SAndroid Build Coastguard Worker 6800*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(expected_error, re.escape(str(message()))): 6801*da0073e9SAndroid Build Coastguard Worker check_fn(False, message) 6802*da0073e9SAndroid Build Coastguard Worker 6803*da0073e9SAndroid Build Coastguard Worker # Test format string message 6804*da0073e9SAndroid Build Coastguard Worker def message(): 6805*da0073e9SAndroid Build Coastguard Worker return f"{'test'} {[1, 2, 'a', True]} {True} {100} {torch.arange(4)}" 6806*da0073e9SAndroid Build Coastguard Worker 6807*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(expected_error, re.escape(str(message()))): 6808*da0073e9SAndroid Build Coastguard Worker check_fn(False, message) 6809*da0073e9SAndroid Build Coastguard Worker 6810*da0073e9SAndroid Build Coastguard Worker # Test incorrect `cond` arg type 6811*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, 'cond must be a bool'): 
6812*da0073e9SAndroid Build Coastguard Worker check_fn('wrong type') 6813*da0073e9SAndroid Build Coastguard Worker 6814*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, 'cond must be a bool'): 6815*da0073e9SAndroid Build Coastguard Worker check_fn(torch.tensor(True)) 6816*da0073e9SAndroid Build Coastguard Worker 6817*da0073e9SAndroid Build Coastguard Worker # FIXME: move to indexing test suite 6818*da0073e9SAndroid Build Coastguard Worker def test_index_add(self): 6819*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 6820*da0073e9SAndroid Build Coastguard Worker for dest_contig, src_contig, index_contig in product([True, False], repeat=3): 6821*da0073e9SAndroid Build Coastguard Worker for other_sizes in ((), (4, 5)): 6822*da0073e9SAndroid Build Coastguard Worker for dtype in [torch.int, torch.long]: 6823*da0073e9SAndroid Build Coastguard Worker num_copy, num_dest = 3, 3 6824*da0073e9SAndroid Build Coastguard Worker dest = torch.randn(num_dest, *other_sizes, device=device) 6825*da0073e9SAndroid Build Coastguard Worker if not dest_contig: 6826*da0073e9SAndroid Build Coastguard Worker dest = make_tensor(dest.shape, device=device, dtype=dest.dtype, noncontiguous=True) 6827*da0073e9SAndroid Build Coastguard Worker src = torch.randn(num_copy, *other_sizes, device=device) 6828*da0073e9SAndroid Build Coastguard Worker if not src_contig: 6829*da0073e9SAndroid Build Coastguard Worker src = noncontiguous_like(src) 6830*da0073e9SAndroid Build Coastguard Worker idx = torch.randperm(num_dest, dtype=dtype, device=device).narrow(0, 0, num_copy) 6831*da0073e9SAndroid Build Coastguard Worker if not index_contig: 6832*da0073e9SAndroid Build Coastguard Worker idx = noncontiguous_like(idx) 6833*da0073e9SAndroid Build Coastguard Worker # index_add_ without alpha argument 6834*da0073e9SAndroid Build Coastguard Worker dest2 = dest.clone() 6835*da0073e9SAndroid Build Coastguard Worker dest.index_add_(0, idx, src) 
6836*da0073e9SAndroid Build Coastguard Worker for i in range(idx.size(0)): 6837*da0073e9SAndroid Build Coastguard Worker dest2[idx[i]] += src[i] 6838*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dest, dest2) 6839*da0073e9SAndroid Build Coastguard Worker # index_add_ with alpha argument 6840*da0073e9SAndroid Build Coastguard Worker dest2 = dest.clone() 6841*da0073e9SAndroid Build Coastguard Worker dest.index_add_(0, idx, src, alpha=2) 6842*da0073e9SAndroid Build Coastguard Worker for i in range(idx.size(0)): 6843*da0073e9SAndroid Build Coastguard Worker dest2[idx[i]] += src[i] * 2 6844*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dest, dest2) 6845*da0073e9SAndroid Build Coastguard Worker 6846*da0073e9SAndroid Build Coastguard Worker # FIXME: resolve comment below and move this to indexing test suite 6847*da0073e9SAndroid Build Coastguard Worker # add coverage for issue with atomic add that appeared only for 6848*da0073e9SAndroid Build Coastguard Worker # specific dtypes on cuda: 6849*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/29153 6850*da0073e9SAndroid Build Coastguard Worker def test_index_add_all_dtypes(self): 6851*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 6852*da0073e9SAndroid Build Coastguard Worker for dtype in get_all_math_dtypes(device): 6853*da0073e9SAndroid Build Coastguard Worker for idx_dtype in [torch.int, torch.long]: 6854*da0073e9SAndroid Build Coastguard Worker size = [5, 5] 6855*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point or dtype.is_complex: 6856*da0073e9SAndroid Build Coastguard Worker tensor = torch.rand(size, dtype=dtype, device=device) 6857*da0073e9SAndroid Build Coastguard Worker elif dtype.is_signed: 6858*da0073e9SAndroid Build Coastguard Worker tensor = torch.randint(-5, 15, size, dtype=dtype, device=device) 6859*da0073e9SAndroid Build Coastguard Worker else: 6860*da0073e9SAndroid Build Coastguard Worker tensor = 
torch.randint(0, 10, size, dtype=dtype, device=device) 6861*da0073e9SAndroid Build Coastguard Worker 6862*da0073e9SAndroid Build Coastguard Worker # index_add calls atomicAdd on cuda. 6863*da0073e9SAndroid Build Coastguard Worker zeros = torch.zeros(size, dtype=dtype, device=device) 6864*da0073e9SAndroid Build Coastguard Worker 6865*da0073e9SAndroid Build Coastguard Worker added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor) 6866*da0073e9SAndroid Build Coastguard Worker self.assertEqual(added, tensor) 6867*da0073e9SAndroid Build Coastguard Worker 6868*da0073e9SAndroid Build Coastguard Worker added = zeros.index_add(0, torch.arange(0, size[0], dtype=idx_dtype, device=device), tensor, alpha=-1) 6869*da0073e9SAndroid Build Coastguard Worker self.assertEqual(added, -tensor) 6870*da0073e9SAndroid Build Coastguard Worker 6871*da0073e9SAndroid Build Coastguard Worker @unittest.mock.patch.object(torch._dynamo.config, "suppress_errors", False) 6872*da0073e9SAndroid Build Coastguard Worker @set_default_dtype(torch.double) 6873*da0073e9SAndroid Build Coastguard Worker def test_index_add_correctness(self): 6874*da0073e9SAndroid Build Coastguard Worker # Check whether index_add can get correct result when 6875*da0073e9SAndroid Build Coastguard Worker # alpha is 1, and dtype of index is torch.long, 6876*da0073e9SAndroid Build Coastguard Worker # i.e., using scatter_add 6877*da0073e9SAndroid Build Coastguard Worker def helper(dim, dtype, device, size_result, size_source): 6878*da0073e9SAndroid Build Coastguard Worker tensor = torch.zeros(size_result, dtype=dtype, device=device) 6879*da0073e9SAndroid Build Coastguard Worker index = torch.randint(0, size_result[dim], (size_source[dim],), 6880*da0073e9SAndroid Build Coastguard Worker dtype=torch.long, device=device) 6881*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point or dtype.is_complex: 6882*da0073e9SAndroid Build Coastguard Worker source = torch.rand(size_source, 
dtype=dtype, device=device) 6883*da0073e9SAndroid Build Coastguard Worker elif dtype.is_signed: 6884*da0073e9SAndroid Build Coastguard Worker source = torch.randint(-2, 5, size_source, dtype=dtype, device=device) 6885*da0073e9SAndroid Build Coastguard Worker else: 6886*da0073e9SAndroid Build Coastguard Worker source = torch.randint(0, 5, size_source, dtype=dtype, device=device) 6887*da0073e9SAndroid Build Coastguard Worker 6888*da0073e9SAndroid Build Coastguard Worker ref_out = tensor.index_add(dim, index, source, alpha=2.) / 2. 6889*da0073e9SAndroid Build Coastguard Worker ref_out = ref_out.to(dtype=dtype) 6890*da0073e9SAndroid Build Coastguard Worker out = tensor.index_add(dim, index, source) 6891*da0073e9SAndroid Build Coastguard Worker if device == 'cuda': 6892*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, ref_out, atol=1e-2, rtol=1e-2) 6893*da0073e9SAndroid Build Coastguard Worker else: 6894*da0073e9SAndroid Build Coastguard Worker # scatter_add uses fp32 as accumulate type, while index_add doesn't. 
6895*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, ref_out.to(dtype=dtype), atol=1e-2, rtol=1e-2) 6896*da0073e9SAndroid Build Coastguard Worker 6897*da0073e9SAndroid Build Coastguard Worker for dim in [-1, -2, -3]: 6898*da0073e9SAndroid Build Coastguard Worker for dtype in all_types_and_complex_and(torch.half, torch.bfloat16): 6899*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 6900*da0073e9SAndroid Build Coastguard Worker for size in [(2, 512, 256), (5, 256, 256)]: 6901*da0073e9SAndroid Build Coastguard Worker helper(dim, dtype, device, size, size) 6902*da0073e9SAndroid Build Coastguard Worker 6903*da0073e9SAndroid Build Coastguard Worker # Check bound 6904*da0073e9SAndroid Build Coastguard Worker result = torch.zeros(1, 512, 256, dtype=dtype) 6905*da0073e9SAndroid Build Coastguard Worker source = torch.ones(1, 512, 256, dtype=dtype) 6906*da0073e9SAndroid Build Coastguard Worker index = torch.ones(257).to(dtype=torch.long) 6907*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: result.index_add_(dim, index, source)) 6908*da0073e9SAndroid Build Coastguard Worker index = (torch.ones(256) * 257).to(dtype=torch.long) 6909*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: result.index_add_(dim, index, source)) 6910*da0073e9SAndroid Build Coastguard Worker 6911*da0073e9SAndroid Build Coastguard Worker def test_index_add_cornercase(self): 6912*da0073e9SAndroid Build Coastguard Worker for device in get_all_device_types(): 6913*da0073e9SAndroid Build Coastguard Worker dest = torch.randn((), device=device) 6914*da0073e9SAndroid Build Coastguard Worker index = torch.tensor([0], device=device) 6915*da0073e9SAndroid Build Coastguard Worker source = torch.randn(1, 1, 1, device=device) 6916*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 6917*da0073e9SAndroid Build Coastguard Worker RuntimeError, 6918*da0073e9SAndroid Build Coastguard Worker 
r"source tensor shape must match self tensor shape, excluding the specified dimension", 6919*da0073e9SAndroid Build Coastguard Worker ): 6920*da0073e9SAndroid Build Coastguard Worker dest.index_add(0, index, source) 6921*da0073e9SAndroid Build Coastguard Worker 6922*da0073e9SAndroid Build Coastguard Worker def test_linspace_logspace(self): 6923*da0073e9SAndroid Build Coastguard Worker # Ensure the output does not require grad regardless of inputs requiring gard or not. 6924*da0073e9SAndroid Build Coastguard Worker # The output of factory functions should not be part of any computational graph. 6925*da0073e9SAndroid Build Coastguard Worker start = 0.0 6926*da0073e9SAndroid Build Coastguard Worker end = 3.0 6927*da0073e9SAndroid Build Coastguard Worker 6928*da0073e9SAndroid Build Coastguard Worker for step in [0, 1, 2]: 6929*da0073e9SAndroid Build Coastguard Worker self.assertFalse( 6930*da0073e9SAndroid Build Coastguard Worker torch.linspace( 6931*da0073e9SAndroid Build Coastguard Worker torch.tensor(start, requires_grad=True), 6932*da0073e9SAndroid Build Coastguard Worker torch.tensor(end, requires_grad=True), step 6933*da0073e9SAndroid Build Coastguard Worker ).requires_grad 6934*da0073e9SAndroid Build Coastguard Worker ) 6935*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.linspace(torch.tensor(start, requires_grad=True), end, step).requires_grad) 6936*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.linspace(start, torch.tensor(end, requires_grad=True), step).requires_grad) 6937*da0073e9SAndroid Build Coastguard Worker self.assertFalse( 6938*da0073e9SAndroid Build Coastguard Worker torch.logspace( 6939*da0073e9SAndroid Build Coastguard Worker torch.tensor(start, requires_grad=True), 6940*da0073e9SAndroid Build Coastguard Worker torch.tensor(end, requires_grad=True), step 6941*da0073e9SAndroid Build Coastguard Worker ).requires_grad 6942*da0073e9SAndroid Build Coastguard Worker ) 6943*da0073e9SAndroid Build Coastguard Worker 
self.assertFalse(torch.logspace(torch.tensor(start, requires_grad=True), end, step).requires_grad) 6944*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.logspace(start, torch.tensor(end, requires_grad=True), step).requires_grad) 6945*da0073e9SAndroid Build Coastguard Worker 6946*da0073e9SAndroid Build Coastguard Worker # FIXME: move to shape ops test suite 6947*da0073e9SAndroid Build Coastguard Worker def test_unflatten(self): 6948*da0073e9SAndroid Build Coastguard Worker # test args: tensor, int, sizes 6949*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([]).unflatten(0, (0, 1)), torch.empty(0, 1)) 6950*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([1]).unflatten(0, (1, 1)), torch.tensor([[1]])) 6951*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, (2, 2)), torch.tensor([[1, 2], [3, 4]])) 6952*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, [2, 2]), torch.tensor([[1, 2], [3, 4]])) 6953*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, torch.Size([2, 2])), torch.tensor([[1, 2], [3, 4]])) 6954*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ones(2, 10).unflatten(1, (5, 2)), torch.ones(2, 5, 2)) 6955*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.tensor([1, 2, 3, 4]).unflatten(0, (-1, 2)), 6956*da0073e9SAndroid Build Coastguard Worker torch.tensor([[1, 2], [3, 4]])) 6957*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ones(2, 10).unflatten(1, (5, -1)), 6958*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 5, 2)) 6959*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ones(2, 10).unflatten(1, (-1,)), 6960*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 10)) 6961*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ones(2, 3 * 4 * 5 * 6).unflatten(1, (3, 4, -1, 6)), 
6962*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 3, 4, 5, 6)) 6963*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ones(2, 0, 2).unflatten(1, (3, -1, 4, 5)), 6964*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 3, 0, 4, 5, 2)) 6965*da0073e9SAndroid Build Coastguard Worker 6966*da0073e9SAndroid Build Coastguard Worker # test invalid args: tensor, str, sizes 6967*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"unflatten\(\): argument 'dim' \(position 1\) must be int, not str"): 6968*da0073e9SAndroid Build Coastguard Worker torch.tensor([1]).unflatten('A', (1, 1)) 6969*da0073e9SAndroid Build Coastguard Worker 6970*da0073e9SAndroid Build Coastguard Worker # test invalid args: tensor, str, namedshape 6971*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Name 'A' not found in Tensor\[None\]."): 6972*da0073e9SAndroid Build Coastguard Worker torch.ones(4).unflatten('A', (('A', 2), ('B', 2))) 6973*da0073e9SAndroid Build Coastguard Worker 6974*da0073e9SAndroid Build Coastguard Worker # test other invalid arguments 6975*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"sizes must be non-empty"): 6976*da0073e9SAndroid Build Coastguard Worker torch.tensor([1]).unflatten(0, []) 6977*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Provided sizes \[2, 2\] don't multiply up to the size of dim 0 \(1\)"): 6978*da0073e9SAndroid Build Coastguard Worker torch.tensor([1]).unflatten(0, [2, 2]) 6979*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(IndexError, r"Dimension specified as 0 but tensor has no dimensions"): 6980*da0073e9SAndroid Build Coastguard Worker torch.tensor(1).unflatten(0, [0]) 6981*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"only one dimension can be inferred"): 6982*da0073e9SAndroid Build Coastguard Worker torch.randn(5, 10).unflatten(1, 
(-1, -1)) 6983*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 6984*da0073e9SAndroid Build Coastguard Worker r"Provided sizes \[-1, 4\] don't multiply up to the size of dim 1 \(10\)"): 6985*da0073e9SAndroid Build Coastguard Worker torch.randn(5, 10).unflatten(1, (-1, 4)) 6986*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 6987*da0073e9SAndroid Build Coastguard Worker r"the unspecified dimension size -1 can be any value and is ambiguous"): 6988*da0073e9SAndroid Build Coastguard Worker torch.randn(2, 0).unflatten(1, (2, -1, 0)) 6989*da0073e9SAndroid Build Coastguard Worker 6990*da0073e9SAndroid Build Coastguard Worker # Test that warnings generated from C++ are translated to the correct type 6991*da0073e9SAndroid Build Coastguard Worker def test_warn_types(self): 6992*da0073e9SAndroid Build Coastguard Worker test_cases = [ 6993*da0073e9SAndroid Build Coastguard Worker # function, warning type, message 6994*da0073e9SAndroid Build Coastguard Worker (torch._C._warn, UserWarning, r"Test message for TORCH_WARN"), 6995*da0073e9SAndroid Build Coastguard Worker (torch._C._warn_deprecation, DeprecationWarning, r"Test message for TORCH_WARN_DEPRECATION"), 6996*da0073e9SAndroid Build Coastguard Worker ] 6997*da0073e9SAndroid Build Coastguard Worker 6998*da0073e9SAndroid Build Coastguard Worker for fn, warning_type, message in test_cases: 6999*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 7000*da0073e9SAndroid Build Coastguard Worker warnings.resetwarnings() 7001*da0073e9SAndroid Build Coastguard Worker warnings.filterwarnings('always', category=warning_type) 7002*da0073e9SAndroid Build Coastguard Worker fn() 7003*da0073e9SAndroid Build Coastguard Worker 7004*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1, msg=f'{warning_type} not raised') 7005*da0073e9SAndroid Build Coastguard Worker warning = w[0].message 7006*da0073e9SAndroid Build Coastguard 
Worker self.assertTrue(isinstance(warning, warning_type), msg=f'{warning_type} not raised') 7007*da0073e9SAndroid Build Coastguard Worker self.assertTrue(re.search( 7008*da0073e9SAndroid Build Coastguard Worker message, 7009*da0073e9SAndroid Build Coastguard Worker str(warning))) 7010*da0073e9SAndroid Build Coastguard Worker 7011*da0073e9SAndroid Build Coastguard Worker def test_structseq_repr(self): 7012*da0073e9SAndroid Build Coastguard Worker a = torch.arange(250).reshape(5, 5, 10) 7013*da0073e9SAndroid Build Coastguard Worker expected = """ 7014*da0073e9SAndroid Build Coastguard Worker torch.return_types.max( 7015*da0073e9SAndroid Build Coastguard Worker values=tensor([[ 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], 7016*da0073e9SAndroid Build Coastguard Worker [ 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], 7017*da0073e9SAndroid Build Coastguard Worker [140, 141, 142, 143, 144, 145, 146, 147, 148, 149], 7018*da0073e9SAndroid Build Coastguard Worker [190, 191, 192, 193, 194, 195, 196, 197, 198, 199], 7019*da0073e9SAndroid Build Coastguard Worker [240, 241, 242, 243, 244, 245, 246, 247, 248, 249]]), 7020*da0073e9SAndroid Build Coastguard Worker indices=tensor([[4, 4, 4, 4, 4, 4, 4, 4, 4, 4], 7021*da0073e9SAndroid Build Coastguard Worker [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], 7022*da0073e9SAndroid Build Coastguard Worker [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], 7023*da0073e9SAndroid Build Coastguard Worker [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], 7024*da0073e9SAndroid Build Coastguard Worker [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]))""" 7025*da0073e9SAndroid Build Coastguard Worker self.assertEqual(repr(a.max(1)), textwrap.dedent(expected).strip()) 7026*da0073e9SAndroid Build Coastguard Worker 7027*da0073e9SAndroid Build Coastguard Worker def test_is_same_size(self): 7028*da0073e9SAndroid Build Coastguard Worker t1 = torch.empty(3, 4, 9, 10) 7029*da0073e9SAndroid Build Coastguard Worker t2 = torch.empty(3, 4) 7030*da0073e9SAndroid Build Coastguard Worker t3 = torch.empty(1, 9, 3, 3) 7031*da0073e9SAndroid 
Build Coastguard Worker t4 = torch.empty(3, 4, 9, 10) 7032*da0073e9SAndroid Build Coastguard Worker 7033*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t1.is_same_size(t2)) 7034*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t1.is_same_size(t3)) 7035*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t1.is_same_size(t4)) 7036*da0073e9SAndroid Build Coastguard Worker 7037*da0073e9SAndroid Build Coastguard Worker nt1 = torch.nested.nested_tensor([torch.ones(2, 4), torch.ones(3, 4), torch.ones(5, 4)]) 7038*da0073e9SAndroid Build Coastguard Worker nt2 = torch.nested.nested_tensor([torch.ones(2, 4), torch.ones(2, 4), torch.ones(2, 4)]) 7039*da0073e9SAndroid Build Coastguard Worker nt3 = torch.nested.nested_tensor([torch.ones(2, 4, 5), torch.ones(2, 6, 5)]) 7040*da0073e9SAndroid Build Coastguard Worker nt4 = torch.nested.nested_tensor([torch.ones(2, 4), torch.ones(3, 4), torch.ones(5, 4)]) 7041*da0073e9SAndroid Build Coastguard Worker 7042*da0073e9SAndroid Build Coastguard Worker self.assertFalse(nt1.is_same_size(nt2)) 7043*da0073e9SAndroid Build Coastguard Worker self.assertFalse(nt1.is_same_size(nt3)) 7044*da0073e9SAndroid Build Coastguard Worker self.assertTrue(nt1.is_same_size(nt4)) 7045*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected both self and other to be nested tensors."): 7046*da0073e9SAndroid Build Coastguard Worker t1.is_same_size(nt1) 7047*da0073e9SAndroid Build Coastguard Worker 7048*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected both self and other to be nested tensors."): 7049*da0073e9SAndroid Build Coastguard Worker nt1.is_same_size(t1) 7050*da0073e9SAndroid Build Coastguard Worker 7051*da0073e9SAndroid Build Coastguard Worker def test_tensor_set(self): 7052*da0073e9SAndroid Build Coastguard Worker t1 = torch.tensor([]) 7053*da0073e9SAndroid Build Coastguard Worker t2 = torch.empty(3, 4, 9, 10).uniform_() 7054*da0073e9SAndroid Build 
Coastguard Worker t1.set_(t2) 7055*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) 7056*da0073e9SAndroid Build Coastguard Worker size = torch.Size([9, 3, 4, 10]) 7057*da0073e9SAndroid Build Coastguard Worker t1.set_(t2.storage(), 0, size) 7058*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.size(), size) 7059*da0073e9SAndroid Build Coastguard Worker t1.set_(t2.storage(), 0, tuple(size)) 7060*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.size(), size) 7061*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.stride(), (120, 40, 10, 1)) 7062*da0073e9SAndroid Build Coastguard Worker stride = (10, 360, 90, 1) 7063*da0073e9SAndroid Build Coastguard Worker t1.set_(t2.storage(), 0, size, stride) 7064*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.stride(), stride) 7065*da0073e9SAndroid Build Coastguard Worker t1.set_(t2.storage(), 0, size=size, stride=stride) 7066*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.size(), size) 7067*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.stride(), stride) 7068*da0073e9SAndroid Build Coastguard Worker 7069*da0073e9SAndroid Build Coastguard Worker # test argument names 7070*da0073e9SAndroid Build Coastguard Worker t1 = torch.tensor([]) 7071*da0073e9SAndroid Build Coastguard Worker # 1. case when source is tensor 7072*da0073e9SAndroid Build Coastguard Worker t1.set_(source=t2) 7073*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) 7074*da0073e9SAndroid Build Coastguard Worker # 2. case when source is storage 7075*da0073e9SAndroid Build Coastguard Worker t1.set_(source=t2.storage()) 7076*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) 7077*da0073e9SAndroid Build Coastguard Worker # 3. 
case when source is storage, and other args also specified 7078*da0073e9SAndroid Build Coastguard Worker t1.set_(source=t2.storage(), storage_offset=0, size=size, stride=stride) 7079*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.size(), size) 7080*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.stride(), stride) 7081*da0073e9SAndroid Build Coastguard Worker 7082*da0073e9SAndroid Build Coastguard Worker t1 = torch.tensor([True, True], dtype=torch.bool) 7083*da0073e9SAndroid Build Coastguard Worker t2 = torch.tensor([False, False], dtype=torch.bool) 7084*da0073e9SAndroid Build Coastguard Worker t1.set_(t2) 7085*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) 7086*da0073e9SAndroid Build Coastguard Worker 7087*da0073e9SAndroid Build Coastguard Worker def test_tensor_set_errors(self): 7088*da0073e9SAndroid Build Coastguard Worker f_cpu = torch.randn((2, 3), dtype=torch.float32) 7089*da0073e9SAndroid Build Coastguard Worker d_cpu = torch.randn((2, 3), dtype=torch.float64) 7090*da0073e9SAndroid Build Coastguard Worker 7091*da0073e9SAndroid Build Coastguard Worker # change dtype 7092*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage())) 7093*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, 7094*da0073e9SAndroid Build Coastguard Worker lambda: f_cpu.set_(d_cpu.storage(), 0, d_cpu.size(), d_cpu.stride())) 7095*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu)) 7096*da0073e9SAndroid Build Coastguard Worker 7097*da0073e9SAndroid Build Coastguard Worker # change device 7098*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 7099*da0073e9SAndroid Build Coastguard Worker f_cuda = torch.randn((2, 3), dtype=torch.float32, device='cuda') 7100*da0073e9SAndroid Build Coastguard Worker 7101*da0073e9SAndroid Build Coastguard Worker # cpu -> cuda 
7102*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda.storage())) 7103*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, 7104*da0073e9SAndroid Build Coastguard Worker lambda: f_cpu.set_(f_cuda.storage(), 0, f_cuda.size(), f_cuda.stride())) 7105*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda)) 7106*da0073e9SAndroid Build Coastguard Worker 7107*da0073e9SAndroid Build Coastguard Worker # cuda -> cpu 7108*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu.storage())) 7109*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, 7110*da0073e9SAndroid Build Coastguard Worker lambda: f_cuda.set_(f_cpu.storage(), 0, f_cpu.size(), f_cpu.stride())) 7111*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu)) 7112*da0073e9SAndroid Build Coastguard Worker 7113*da0073e9SAndroid Build Coastguard Worker # FIXME: move this test test_testing.py (along with allclose testing) 7114*da0073e9SAndroid Build Coastguard Worker # NOTE: test_equal will be deprecated in favor of torch.testing.assert_close 7115*da0073e9SAndroid Build Coastguard Worker # once torch.testing is out of beta 7116*da0073e9SAndroid Build Coastguard Worker def test_equal(self): 7117*da0073e9SAndroid Build Coastguard Worker devices = [torch.cpu, torch.cuda] 7118*da0073e9SAndroid Build Coastguard Worker for device in ["cpu", "cuda"]: 7119*da0073e9SAndroid Build Coastguard Worker if device == "cuda" and not torch.cuda.is_available(): 7120*da0073e9SAndroid Build Coastguard Worker continue 7121*da0073e9SAndroid Build Coastguard Worker 7122*da0073e9SAndroid Build Coastguard Worker # Contiguous, 1D 7123*da0073e9SAndroid Build Coastguard Worker t1 = torch.tensor((3., 4., 9., 10.), device=device) 7124*da0073e9SAndroid Build Coastguard Worker t2 = t1.contiguous() 7125*da0073e9SAndroid Build 
Coastguard Worker t3 = torch.tensor((1., 9., 3., 10.), device=device) 7126*da0073e9SAndroid Build Coastguard Worker t4 = torch.tensor((3., 4., 9.), device=device) 7127*da0073e9SAndroid Build Coastguard Worker t5 = torch.tensor([], device=device) 7128*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t1.equal(t2)) 7129*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t1.equal(t3)) 7130*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t1.equal(t4)) 7131*da0073e9SAndroid Build Coastguard Worker self.assertFalse(t1.equal(t5)) 7132*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal(t1, t2)) 7133*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.equal(t1, t3)) 7134*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.equal(t1, t4)) 7135*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.equal(t1, t5)) 7136*da0073e9SAndroid Build Coastguard Worker 7137*da0073e9SAndroid Build Coastguard Worker # Non contiguous, 2D 7138*da0073e9SAndroid Build Coastguard Worker s = torch.tensor(((1, 2, 3, 4), (5, 6, 7, 8)), device=device) 7139*da0073e9SAndroid Build Coastguard Worker s1 = s[:, 1:3] 7140*da0073e9SAndroid Build Coastguard Worker s2 = s1.clone() 7141*da0073e9SAndroid Build Coastguard Worker s3 = torch.tensor(((2, 3), (6, 7)), device=device) 7142*da0073e9SAndroid Build Coastguard Worker s4 = torch.tensor(((0, 0), (0, 0)), device=device) 7143*da0073e9SAndroid Build Coastguard Worker 7144*da0073e9SAndroid Build Coastguard Worker self.assertFalse(s1.is_contiguous()) 7145*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s1.equal(s2)) 7146*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s1.equal(s3)) 7147*da0073e9SAndroid Build Coastguard Worker self.assertFalse(s1.equal(s4)) 7148*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal(s1, s2)) 7149*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal(s1, s3)) 7150*da0073e9SAndroid Build Coastguard Worker 
self.assertFalse(torch.equal(s1, s4)) 7151*da0073e9SAndroid Build Coastguard Worker 7152*da0073e9SAndroid Build Coastguard Worker # Different dtypes 7153*da0073e9SAndroid Build Coastguard Worker x = torch.tensor((1, 2, 3), dtype=torch.float, device=device) 7154*da0073e9SAndroid Build Coastguard Worker y = torch.tensor((1, 2, 3), dtype=torch.int, device=device) 7155*da0073e9SAndroid Build Coastguard Worker z = torch.tensor((1, -1), dtype=torch.int, device=device) 7156*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal(x, y)) 7157*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.equal(z, x)) 7158*da0073e9SAndroid Build Coastguard Worker 7159*da0073e9SAndroid Build Coastguard Worker # Fast path test: tensor flags, like neg and conj 7160*da0073e9SAndroid Build Coastguard Worker neg_0 = torch.tensor((1, 2, 3), dtype=torch.float, device=device) 7161*da0073e9SAndroid Build Coastguard Worker neg_1 = neg_0._neg_view() 7162*da0073e9SAndroid Build Coastguard Worker self.assertTrue(neg_1.is_neg()) 7163*da0073e9SAndroid Build Coastguard Worker self.assertEqual(neg_0.data_ptr(), neg_1.data_ptr()) 7164*da0073e9SAndroid Build Coastguard Worker self.assertEqual(neg_0.storage_offset(), neg_1.storage_offset()) 7165*da0073e9SAndroid Build Coastguard Worker self.assertEqual(neg_0.stride(), neg_1.stride()) 7166*da0073e9SAndroid Build Coastguard Worker self.assertEqual(neg_0.size(), neg_1.size()) 7167*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.equal(neg_0, neg_1)) 7168*da0073e9SAndroid Build Coastguard Worker # FIXME: Disable the following check due to the inductor failure 7169*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/100340 and 7170*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/98175 7171*da0073e9SAndroid Build Coastguard Worker if not TEST_WITH_TORCHINDUCTOR: 7172*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal(neg_0, 
neg_1._neg_view())) 7173*da0073e9SAndroid Build Coastguard Worker 7174*da0073e9SAndroid Build Coastguard Worker conj_0 = torch.tensor([1.0 + 2.0j, 2.0 + 1.0j], device=device) 7175*da0073e9SAndroid Build Coastguard Worker conj_1 = conj_0.conj() 7176*da0073e9SAndroid Build Coastguard Worker self.assertTrue(conj_1.is_conj()) 7177*da0073e9SAndroid Build Coastguard Worker self.assertEqual(conj_0.data_ptr(), conj_1.data_ptr()) 7178*da0073e9SAndroid Build Coastguard Worker self.assertEqual(conj_0.storage_offset(), conj_1.storage_offset()) 7179*da0073e9SAndroid Build Coastguard Worker self.assertEqual(conj_0.stride(), conj_1.stride()) 7180*da0073e9SAndroid Build Coastguard Worker self.assertEqual(conj_0.size(), conj_1.size()) 7181*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.equal(conj_0, conj_1)) 7182*da0073e9SAndroid Build Coastguard Worker # FIXME: Disable the following check due to the inductor failure 7183*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/100340 and 7184*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/98175 7185*da0073e9SAndroid Build Coastguard Worker if not TEST_WITH_TORCHINDUCTOR: 7186*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal(conj_0, conj_1.conj())) 7187*da0073e9SAndroid Build Coastguard Worker 7188*da0073e9SAndroid Build Coastguard Worker # Fast path test: two tensors share the same storage, but different dtype 7189*da0073e9SAndroid Build Coastguard Worker s_0 = torch.rand((2, 3), dtype=torch.float, device=device) 7190*da0073e9SAndroid Build Coastguard Worker s_1 = s_0.view(dtype=torch.int32) 7191*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s_0.data_ptr(), s_1.data_ptr()) 7192*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s_0.storage_offset(), s_1.storage_offset()) 7193*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s_0.stride(), s_1.stride()) 7194*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(s_0.size(), s_1.size()) 7195*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.equal(s_0, s_1)) 7196*da0073e9SAndroid Build Coastguard Worker 7197*da0073e9SAndroid Build Coastguard Worker # Fast path test: two tensors share the same storage, but different strides 7198*da0073e9SAndroid Build Coastguard Worker t_0 = torch.rand((2, 3), dtype=torch.float, device=device) 7199*da0073e9SAndroid Build Coastguard Worker t_1 = t_0.t() 7200*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t_0.data_ptr(), t_1.data_ptr()) 7201*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t_0.storage_offset(), t_1.storage_offset()) 7202*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(t_0.stride(), t_1.stride()) 7203*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(t_0.size(), t_1.size()) 7204*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.equal(t_0, t_1)) 7205*da0073e9SAndroid Build Coastguard Worker 7206*da0073e9SAndroid Build Coastguard Worker # Fast path: tensor containing `nan` is not equal to self 7207*da0073e9SAndroid Build Coastguard Worker for dtype in floating_and_complex_types(): 7208*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([1., float('nan')], dtype=dtype) 7209*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.equal(t, t)) 7210*da0073e9SAndroid Build Coastguard Worker 7211*da0073e9SAndroid Build Coastguard Worker def test_element_size(self): 7212*da0073e9SAndroid Build Coastguard Worker byte = torch.ByteStorage().element_size() 7213*da0073e9SAndroid Build Coastguard Worker char = torch.CharStorage().element_size() 7214*da0073e9SAndroid Build Coastguard Worker short = torch.ShortStorage().element_size() 7215*da0073e9SAndroid Build Coastguard Worker int = torch.IntStorage().element_size() 7216*da0073e9SAndroid Build Coastguard Worker long = torch.LongStorage().element_size() 7217*da0073e9SAndroid Build Coastguard Worker float = torch.FloatStorage().element_size() 
7218*da0073e9SAndroid Build Coastguard Worker double = torch.DoubleStorage().element_size() 7219*da0073e9SAndroid Build Coastguard Worker bool = torch.BoolStorage().element_size() 7220*da0073e9SAndroid Build Coastguard Worker bfloat16 = torch.BFloat16Storage().element_size() 7221*da0073e9SAndroid Build Coastguard Worker complexfloat = torch.ComplexFloatStorage().element_size() 7222*da0073e9SAndroid Build Coastguard Worker complexdouble = torch.ComplexDoubleStorage().element_size() 7223*da0073e9SAndroid Build Coastguard Worker 7224*da0073e9SAndroid Build Coastguard Worker self.assertEqual(byte, torch.ByteTensor().element_size()) 7225*da0073e9SAndroid Build Coastguard Worker self.assertEqual(byte, torch.ByteTensor().itemsize) 7226*da0073e9SAndroid Build Coastguard Worker self.assertEqual(char, torch.CharTensor().element_size()) 7227*da0073e9SAndroid Build Coastguard Worker self.assertEqual(char, torch.CharTensor().itemsize) 7228*da0073e9SAndroid Build Coastguard Worker self.assertEqual(short, torch.ShortTensor().element_size()) 7229*da0073e9SAndroid Build Coastguard Worker self.assertEqual(short, torch.ShortTensor().itemsize) 7230*da0073e9SAndroid Build Coastguard Worker self.assertEqual(int, torch.IntTensor().element_size()) 7231*da0073e9SAndroid Build Coastguard Worker self.assertEqual(int, torch.IntTensor().itemsize) 7232*da0073e9SAndroid Build Coastguard Worker self.assertEqual(long, torch.LongTensor().element_size()) 7233*da0073e9SAndroid Build Coastguard Worker self.assertEqual(long, torch.LongTensor().itemsize) 7234*da0073e9SAndroid Build Coastguard Worker self.assertEqual(float, torch.FloatTensor().element_size()) 7235*da0073e9SAndroid Build Coastguard Worker self.assertEqual(float, torch.FloatTensor().itemsize) 7236*da0073e9SAndroid Build Coastguard Worker self.assertEqual(double, torch.DoubleTensor().element_size()) 7237*da0073e9SAndroid Build Coastguard Worker self.assertEqual(double, torch.DoubleTensor().itemsize) 7238*da0073e9SAndroid Build Coastguard 
Worker self.assertEqual(bool, torch.BoolTensor().element_size()) 7239*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bool, torch.BoolTensor().itemsize) 7240*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bfloat16, torch.tensor([], dtype=torch.bfloat16).element_size()) 7241*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bfloat16, torch.tensor([], dtype=torch.bfloat16).itemsize) 7242*da0073e9SAndroid Build Coastguard Worker self.assertEqual(complexfloat, torch.tensor([], dtype=torch.complex64).element_size()) 7243*da0073e9SAndroid Build Coastguard Worker self.assertEqual(complexfloat, torch.tensor([], dtype=torch.complex64).itemsize) 7244*da0073e9SAndroid Build Coastguard Worker self.assertEqual(complexdouble, torch.tensor([], dtype=torch.complex128).element_size()) 7245*da0073e9SAndroid Build Coastguard Worker self.assertEqual(complexdouble, torch.tensor([], dtype=torch.complex128).itemsize) 7246*da0073e9SAndroid Build Coastguard Worker 7247*da0073e9SAndroid Build Coastguard Worker self.assertGreater(byte, 0) 7248*da0073e9SAndroid Build Coastguard Worker self.assertGreater(char, 0) 7249*da0073e9SAndroid Build Coastguard Worker self.assertGreater(short, 0) 7250*da0073e9SAndroid Build Coastguard Worker self.assertGreater(int, 0) 7251*da0073e9SAndroid Build Coastguard Worker self.assertGreater(long, 0) 7252*da0073e9SAndroid Build Coastguard Worker self.assertGreater(float, 0) 7253*da0073e9SAndroid Build Coastguard Worker self.assertGreater(double, 0) 7254*da0073e9SAndroid Build Coastguard Worker self.assertGreater(bool, 0) 7255*da0073e9SAndroid Build Coastguard Worker self.assertGreater(bfloat16, 0) 7256*da0073e9SAndroid Build Coastguard Worker self.assertGreater(complexfloat, 0) 7257*da0073e9SAndroid Build Coastguard Worker self.assertGreater(complexdouble, 0) 7258*da0073e9SAndroid Build Coastguard Worker 7259*da0073e9SAndroid Build Coastguard Worker # These tests are portable, not necessarily strict for your system. 
7260*da0073e9SAndroid Build Coastguard Worker self.assertEqual(byte, 1) 7261*da0073e9SAndroid Build Coastguard Worker self.assertEqual(char, 1) 7262*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bool, 1) 7263*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual(short, 2) 7264*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual(int, 2) 7265*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual(int, short) 7266*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual(long, 4) 7267*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual(long, int) 7268*da0073e9SAndroid Build Coastguard Worker self.assertGreaterEqual(double, float) 7269*da0073e9SAndroid Build Coastguard Worker 7270*da0073e9SAndroid Build Coastguard Worker def test_permute(self): 7271*da0073e9SAndroid Build Coastguard Worker orig = [1, 2, 3, 4, 5, 6, 7] 7272*da0073e9SAndroid Build Coastguard Worker perm = torch.randperm(7).tolist() 7273*da0073e9SAndroid Build Coastguard Worker x = torch.empty(*orig).fill_(0) 7274*da0073e9SAndroid Build Coastguard Worker new = [i - 1 for i in x.permute(*perm).size()] 7275*da0073e9SAndroid Build Coastguard Worker self.assertEqual(perm, new) 7276*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.size(), orig) 7277*da0073e9SAndroid Build Coastguard Worker 7278*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 7279*da0073e9SAndroid Build Coastguard Worker def test_reversed(self): 7280*da0073e9SAndroid Build Coastguard Worker val = torch.arange(0, 10) 7281*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reversed(val), torch.arange(9, -1, -1)) 7282*da0073e9SAndroid Build Coastguard Worker 7283*da0073e9SAndroid Build Coastguard Worker val = torch.arange(1, 10).view(3, 3) 7284*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reversed(val), torch.tensor([[7, 8, 9], [4, 5, 6], [1, 2, 3]])) 7285*da0073e9SAndroid Build Coastguard Worker 
7286*da0073e9SAndroid Build Coastguard Worker val = torch.tensor(42) 7287*da0073e9SAndroid Build Coastguard Worker self.assertEqual(reversed(val), torch.tensor(42)) 7288*da0073e9SAndroid Build Coastguard Worker 7289*da0073e9SAndroid Build Coastguard Worker def test_contains(self): 7290*da0073e9SAndroid Build Coastguard Worker x = torch.arange(0, 10) 7291*da0073e9SAndroid Build Coastguard Worker self.assertEqual(4 in x, True) 7292*da0073e9SAndroid Build Coastguard Worker self.assertEqual(12 in x, False) 7293*da0073e9SAndroid Build Coastguard Worker 7294*da0073e9SAndroid Build Coastguard Worker x = torch.arange(1, 10).view(3, 3) 7295*da0073e9SAndroid Build Coastguard Worker val = torch.arange(1, 4) 7296*da0073e9SAndroid Build Coastguard Worker self.assertEqual(val in x, True) 7297*da0073e9SAndroid Build Coastguard Worker val += 10 7298*da0073e9SAndroid Build Coastguard Worker self.assertEqual(val in x, False) 7299*da0073e9SAndroid Build Coastguard Worker 7300*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 7301*da0073e9SAndroid Build Coastguard Worker RuntimeError, 7302*da0073e9SAndroid Build Coastguard Worker f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {str}.", 7303*da0073e9SAndroid Build Coastguard Worker lambda: "foo" in x) 7304*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 7305*da0073e9SAndroid Build Coastguard Worker RuntimeError, 7306*da0073e9SAndroid Build Coastguard Worker f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {type([1, 2])}.", 7307*da0073e9SAndroid Build Coastguard Worker lambda: [1, 2] in x) 7308*da0073e9SAndroid Build Coastguard Worker 7309*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 7310*da0073e9SAndroid Build Coastguard Worker def test_deepcopy_parameter(self): 7311*da0073e9SAndroid Build Coastguard Worker from copy import deepcopy 7312*da0073e9SAndroid Build Coastguard Worker l = 
torch.nn.Linear(10, 1) 7313*da0073e9SAndroid Build Coastguard Worker s = l.state_dict(keep_vars=True) 7314*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.nn.Parameter, type(s['weight'])) 7315*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.nn.Parameter, type(s['bias'])) 7316*da0073e9SAndroid Build Coastguard Worker 7317*da0073e9SAndroid Build Coastguard Worker s2 = deepcopy(s) 7318*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.nn.Parameter, type(s2['weight'])) 7319*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.nn.Parameter, type(s2['bias'])) 7320*da0073e9SAndroid Build Coastguard Worker 7321*da0073e9SAndroid Build Coastguard Worker def test_pickle(self): 7322*da0073e9SAndroid Build Coastguard Worker import pickle 7323*da0073e9SAndroid Build Coastguard Worker a = torch.randn(5, 5) 7324*da0073e9SAndroid Build Coastguard Worker serialized = pickle.dumps(a) 7325*da0073e9SAndroid Build Coastguard Worker b = pickle.loads(serialized) 7326*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 7327*da0073e9SAndroid Build Coastguard Worker 7328*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 7329*da0073e9SAndroid Build Coastguard Worker def test_pickle_parameter(self): 7330*da0073e9SAndroid Build Coastguard Worker import pickle 7331*da0073e9SAndroid Build Coastguard Worker a = torch.nn.Parameter(torch.randn(5, 5)) 7332*da0073e9SAndroid Build Coastguard Worker serialized = pickle.dumps(a) 7333*da0073e9SAndroid Build Coastguard Worker b = pickle.loads(serialized) 7334*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(b, torch.nn.Parameter)) 7335*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.requires_grad, b.requires_grad) 7336*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 7337*da0073e9SAndroid Build Coastguard Worker 7338*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo 
fails with unknown reason") 7339*da0073e9SAndroid Build Coastguard Worker def test_pickle_parameter_no_requires_grad(self): 7340*da0073e9SAndroid Build Coastguard Worker import pickle 7341*da0073e9SAndroid Build Coastguard Worker a = torch.nn.Parameter(torch.randn(5, 5), requires_grad=False) 7342*da0073e9SAndroid Build Coastguard Worker serialized = pickle.dumps(a) 7343*da0073e9SAndroid Build Coastguard Worker b = pickle.loads(serialized) 7344*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(b, torch.nn.Parameter)) 7345*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.requires_grad, b.requires_grad) 7346*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 7347*da0073e9SAndroid Build Coastguard Worker 7348*da0073e9SAndroid Build Coastguard Worker def test_pickle_dtype(self): 7349*da0073e9SAndroid Build Coastguard Worker t = torch.float32 7350*da0073e9SAndroid Build Coastguard Worker serialized = pickle.dumps(t) 7351*da0073e9SAndroid Build Coastguard Worker b = pickle.loads(serialized) 7352*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(b, torch.dtype)) 7353*da0073e9SAndroid Build Coastguard Worker self.assertEqual(id(b), id(t)) 7354*da0073e9SAndroid Build Coastguard Worker 7355*da0073e9SAndroid Build Coastguard Worker def test_pickle_size(self): 7356*da0073e9SAndroid Build Coastguard Worker a = torch.rand(10).size() 7357*da0073e9SAndroid Build Coastguard Worker serialized = pickle.dumps(a) 7358*da0073e9SAndroid Build Coastguard Worker b = pickle.loads(serialized) 7359*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(b, torch.Size)) 7360*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 7361*da0073e9SAndroid Build Coastguard Worker 7362*da0073e9SAndroid Build Coastguard Worker def test_pickle_function(self): 7363*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/37703 7364*da0073e9SAndroid Build Coastguard Worker a = torch.tanh 
7365*da0073e9SAndroid Build Coastguard Worker serialized = pickle.dumps(a) 7366*da0073e9SAndroid Build Coastguard Worker b = pickle.loads(serialized) 7367*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 7368*da0073e9SAndroid Build Coastguard Worker 7369*da0073e9SAndroid Build Coastguard Worker def test_generator_cpu(self): 7370*da0073e9SAndroid Build Coastguard Worker # test default generators are equal 7371*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.default_generator, torch.default_generator) 7372*da0073e9SAndroid Build Coastguard Worker 7373*da0073e9SAndroid Build Coastguard Worker # tests Generator API 7374*da0073e9SAndroid Build Coastguard Worker # manual_seed, seed, initial_seed, get_state, set_state 7375*da0073e9SAndroid Build Coastguard Worker g1 = torch.Generator() 7376*da0073e9SAndroid Build Coastguard Worker g2 = torch.Generator() 7377*da0073e9SAndroid Build Coastguard Worker g1.manual_seed(12345) 7378*da0073e9SAndroid Build Coastguard Worker g2.manual_seed(12345) 7379*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g1.initial_seed(), g2.initial_seed()) 7380*da0073e9SAndroid Build Coastguard Worker 7381*da0073e9SAndroid Build Coastguard Worker g1.seed() 7382*da0073e9SAndroid Build Coastguard Worker g2.seed() 7383*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(g1.initial_seed(), g2.initial_seed()) 7384*da0073e9SAndroid Build Coastguard Worker 7385*da0073e9SAndroid Build Coastguard Worker g1 = torch.Generator() 7386*da0073e9SAndroid Build Coastguard Worker g2_state = g2.get_state() 7387*da0073e9SAndroid Build Coastguard Worker g2_randn = torch.randn(1, generator=g2) 7388*da0073e9SAndroid Build Coastguard Worker g1.set_state(g2_state) 7389*da0073e9SAndroid Build Coastguard Worker g1_randn = torch.randn(1, generator=g1) 7390*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g1_randn, g2_randn) 7391*da0073e9SAndroid Build Coastguard Worker 7392*da0073e9SAndroid Build Coastguard Worker 
default_state = torch.default_generator.get_state() 7393*da0073e9SAndroid Build Coastguard Worker q = torch.empty(100) 7394*da0073e9SAndroid Build Coastguard Worker g1_normal = q.normal_() 7395*da0073e9SAndroid Build Coastguard Worker g2 = torch.Generator() 7396*da0073e9SAndroid Build Coastguard Worker g2.set_state(default_state) 7397*da0073e9SAndroid Build Coastguard Worker g2_normal = q.normal_(generator=g2) 7398*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g1_normal, g2_normal) 7399*da0073e9SAndroid Build Coastguard Worker 7400*da0073e9SAndroid Build Coastguard Worker def test_invalid_generator_raises(self): 7401*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.Generator('opengl')) 7402*da0073e9SAndroid Build Coastguard Worker 7403*da0073e9SAndroid Build Coastguard Worker def test_pickle_generator(self) -> None: 7404*da0073e9SAndroid Build Coastguard Worker devices = ['cpu'] 7405*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 7406*da0073e9SAndroid Build Coastguard Worker devices += ['cuda'] 7407*da0073e9SAndroid Build Coastguard Worker 7408*da0073e9SAndroid Build Coastguard Worker for device in devices: 7409*da0073e9SAndroid Build Coastguard Worker with self.subTest(device=device): 7410*da0073e9SAndroid Build Coastguard Worker generator = torch.Generator(device=device).manual_seed(12345) 7411*da0073e9SAndroid Build Coastguard Worker if device != "cpu": 7412*da0073e9SAndroid Build Coastguard Worker generator.set_offset(100) 7413*da0073e9SAndroid Build Coastguard Worker torch.randn((100, 100), generator=generator, device=device) # progress the RNG state 7414*da0073e9SAndroid Build Coastguard Worker 7415*da0073e9SAndroid Build Coastguard Worker reserialized: torch.Generator = pickle.loads(pickle.dumps(generator)) 7416*da0073e9SAndroid Build Coastguard Worker 7417*da0073e9SAndroid Build Coastguard Worker self.assertEqual(generator.device, reserialized.device) 7418*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(generator.initial_seed(), reserialized.initial_seed()) 7419*da0073e9SAndroid Build Coastguard Worker if device != "cpu": 7420*da0073e9SAndroid Build Coastguard Worker self.assertEqual(generator.get_offset(), reserialized.get_offset()) 7421*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(generator.get_state(), reserialized.get_state()) 7422*da0073e9SAndroid Build Coastguard Worker 7423*da0073e9SAndroid Build Coastguard Worker def _sobol_reference_samples(self, scramble: bool) -> torch.Tensor: 7424*da0073e9SAndroid Build Coastguard Worker if not scramble: 7425*da0073e9SAndroid Build Coastguard Worker # theoretical values from Joe Kuo 2010 7426*da0073e9SAndroid Build Coastguard Worker return torch.tensor( 7427*da0073e9SAndroid Build Coastguard Worker [ 7428*da0073e9SAndroid Build Coastguard Worker [0., 0.], 7429*da0073e9SAndroid Build Coastguard Worker [0.5, 0.5], 7430*da0073e9SAndroid Build Coastguard Worker [0.75, 0.25], 7431*da0073e9SAndroid Build Coastguard Worker [0.25, 0.75], 7432*da0073e9SAndroid Build Coastguard Worker [0.375, 0.375], 7433*da0073e9SAndroid Build Coastguard Worker [0.875, 0.875], 7434*da0073e9SAndroid Build Coastguard Worker [0.625, 0.125], 7435*da0073e9SAndroid Build Coastguard Worker [0.125, 0.625], 7436*da0073e9SAndroid Build Coastguard Worker ], 7437*da0073e9SAndroid Build Coastguard Worker ) 7438*da0073e9SAndroid Build Coastguard Worker else: 7439*da0073e9SAndroid Build Coastguard Worker # theoretical values unknown: convergence properties checked 7440*da0073e9SAndroid Build Coastguard Worker return torch.tensor( 7441*da0073e9SAndroid Build Coastguard Worker [ 7442*da0073e9SAndroid Build Coastguard Worker [0.50860737, 0.29320504], 7443*da0073e9SAndroid Build Coastguard Worker [0.07116939, 0.89594537], 7444*da0073e9SAndroid Build Coastguard Worker [0.49354145, 0.11524881], 7445*da0073e9SAndroid Build Coastguard Worker [0.93097717, 0.70244044], 7446*da0073e9SAndroid Build Coastguard 
Worker [0.87266153, 0.23887917], 7447*da0073e9SAndroid Build Coastguard Worker [0.31021884, 0.57600391], 7448*da0073e9SAndroid Build Coastguard Worker [0.13687253, 0.42054182], 7449*da0073e9SAndroid Build Coastguard Worker [0.69931293, 0.77336788], 7450*da0073e9SAndroid Build Coastguard Worker ], 7451*da0073e9SAndroid Build Coastguard Worker ) 7452*da0073e9SAndroid Build Coastguard Worker 7453*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_bounds(self, scramble: bool = False): 7454*da0073e9SAndroid Build Coastguard Worker engine = torch.quasirandom.SobolEngine(100, scramble=scramble, seed=123456) 7455*da0073e9SAndroid Build Coastguard Worker sample = engine.draw(512) 7456*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.all(sample >= 0)) 7457*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.all(sample <= 1)) 7458*da0073e9SAndroid Build Coastguard Worker 7459*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_bounds_scrambled(self): 7460*da0073e9SAndroid Build Coastguard Worker self.test_sobolengine_bounds(scramble=True) 7461*da0073e9SAndroid Build Coastguard Worker 7462*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_draw(self, scramble: bool = False): 7463*da0073e9SAndroid Build Coastguard Worker ref_sample = self._sobol_reference_samples(scramble=scramble) 7464*da0073e9SAndroid Build Coastguard Worker engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456) 7465*da0073e9SAndroid Build Coastguard Worker sample = engine.draw(n=len(ref_sample)) 7466*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sample, ref_sample) 7467*da0073e9SAndroid Build Coastguard Worker self.assertEqual(engine.num_generated, len(ref_sample)) 7468*da0073e9SAndroid Build Coastguard Worker 7469*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_draw_scrambled(self): 7470*da0073e9SAndroid Build Coastguard Worker self.test_sobolengine_draw(scramble=True) 7471*da0073e9SAndroid Build Coastguard 
Worker 7472*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_first_point(self): 7473*da0073e9SAndroid Build Coastguard Worker for dtype in (torch.float, torch.double): 7474*da0073e9SAndroid Build Coastguard Worker engine = torch.quasirandom.SobolEngine(2, scramble=False) 7475*da0073e9SAndroid Build Coastguard Worker sample = engine.draw(1, dtype=dtype) 7476*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.all(sample == 0)) 7477*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sample.dtype, dtype) 7478*da0073e9SAndroid Build Coastguard Worker for dtype in (torch.float, torch.double): 7479*da0073e9SAndroid Build Coastguard Worker engine = torch.quasirandom.SobolEngine(2, scramble=True, seed=123456) 7480*da0073e9SAndroid Build Coastguard Worker sample = engine.draw(1, dtype=dtype) 7481*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.all(sample != 0)) 7482*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sample.dtype, dtype) 7483*da0073e9SAndroid Build Coastguard Worker 7484*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_continuing(self, scramble: bool = False): 7485*da0073e9SAndroid Build Coastguard Worker ref_sample = self._sobol_reference_samples(scramble=scramble) 7486*da0073e9SAndroid Build Coastguard Worker engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456) 7487*da0073e9SAndroid Build Coastguard Worker n_half = len(ref_sample) // 2 7488*da0073e9SAndroid Build Coastguard Worker _ = engine.draw(n=n_half) 7489*da0073e9SAndroid Build Coastguard Worker sample = engine.draw(n=n_half) 7490*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(sample, ref_sample[n_half:]) 7491*da0073e9SAndroid Build Coastguard Worker 7492*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_continuing_scrambled(self): 7493*da0073e9SAndroid Build Coastguard Worker self.test_sobolengine_continuing(scramble=True) 7494*da0073e9SAndroid Build Coastguard Worker 
7495*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_reset(self, scramble: bool = False): 7496*da0073e9SAndroid Build Coastguard Worker ref_sample = self._sobol_reference_samples(scramble=scramble) 7497*da0073e9SAndroid Build Coastguard Worker engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456) 7498*da0073e9SAndroid Build Coastguard Worker _ = engine.draw(n=len(ref_sample) // 2) 7499*da0073e9SAndroid Build Coastguard Worker engine.reset() 7500*da0073e9SAndroid Build Coastguard Worker self.assertEqual(engine.num_generated, 0) 7501*da0073e9SAndroid Build Coastguard Worker sample = engine.draw(n=len(ref_sample)) 7502*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(sample, ref_sample) 7503*da0073e9SAndroid Build Coastguard Worker 7504*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_reset_scrambled(self): 7505*da0073e9SAndroid Build Coastguard Worker self.test_sobolengine_reset(scramble=True) 7506*da0073e9SAndroid Build Coastguard Worker 7507*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_fast_forward(self, scramble: bool = False): 7508*da0073e9SAndroid Build Coastguard Worker ref_sample = self._sobol_reference_samples(scramble=scramble) 7509*da0073e9SAndroid Build Coastguard Worker engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456) 7510*da0073e9SAndroid Build Coastguard Worker engine.fast_forward(4) 7511*da0073e9SAndroid Build Coastguard Worker sample = engine.draw(n=4) 7512*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(sample, ref_sample[4:]) 7513*da0073e9SAndroid Build Coastguard Worker # alternate fast forwarding with sampling 7514*da0073e9SAndroid Build Coastguard Worker engine.reset() 7515*da0073e9SAndroid Build Coastguard Worker even_draws = [] 7516*da0073e9SAndroid Build Coastguard Worker for i in range(8): 7517*da0073e9SAndroid Build Coastguard Worker if i % 2 == 0: 7518*da0073e9SAndroid Build Coastguard Worker 
even_draws.append(engine.draw()) 7519*da0073e9SAndroid Build Coastguard Worker else: 7520*da0073e9SAndroid Build Coastguard Worker engine.fast_forward(1) 7521*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close( 7522*da0073e9SAndroid Build Coastguard Worker ref_sample[[i for i in range(8) if i % 2 == 0]], 7523*da0073e9SAndroid Build Coastguard Worker torch.from_numpy(np.concatenate(even_draws)), 7524*da0073e9SAndroid Build Coastguard Worker ) 7525*da0073e9SAndroid Build Coastguard Worker 7526*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_fast_forward_scrambled(self): 7527*da0073e9SAndroid Build Coastguard Worker self.test_sobolengine_fast_forward(scramble=True) 7528*da0073e9SAndroid Build Coastguard Worker 7529*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_default_dtype(self): 7530*da0073e9SAndroid Build Coastguard Worker engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123456) 7531*da0073e9SAndroid Build Coastguard Worker # Check that default dtype is correctly handled 7532*da0073e9SAndroid Build Coastguard Worker self.assertEqual(engine.draw(n=5).dtype, torch.float32) 7533*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float64): 7534*da0073e9SAndroid Build Coastguard Worker engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123456) 7535*da0073e9SAndroid Build Coastguard Worker # Check that default dtype is correctly handled (when set to float64) 7536*da0073e9SAndroid Build Coastguard Worker self.assertEqual(engine.draw(n=5).dtype, torch.float64) 7537*da0073e9SAndroid Build Coastguard Worker # Check that explicitly passed dtype is adhered to 7538*da0073e9SAndroid Build Coastguard Worker self.assertEqual(engine.draw(n=5, dtype=torch.float32).dtype, torch.float32) 7539*da0073e9SAndroid Build Coastguard Worker # Reinitialize the engine and check that first draw dtype is correctly handled 7540*da0073e9SAndroid Build Coastguard Worker engine = 
torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123456) 7541*da0073e9SAndroid Build Coastguard Worker self.assertEqual(engine.draw(n=5, dtype=torch.float32).dtype, torch.float32) 7542*da0073e9SAndroid Build Coastguard Worker 7543*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("np.float64 restored as float32 after graph break.") 7544*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_distribution(self, scramble=False): 7545*da0073e9SAndroid Build Coastguard Worker d = 50 7546*da0073e9SAndroid Build Coastguard Worker engine = torch.quasirandom.SobolEngine(d, scramble=scramble, seed=123456) 7547*da0073e9SAndroid Build Coastguard Worker sample = engine.draw(1024) 7548*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close( 7549*da0073e9SAndroid Build Coastguard Worker torch.mean(sample, dim=0), torch.full((d,), 0.5), atol=2, rtol=2 7550*da0073e9SAndroid Build Coastguard Worker ) 7551*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close( 7552*da0073e9SAndroid Build Coastguard Worker np.percentile(sample, 25, axis=0), np.repeat(0.25, d), atol=2, rtol=2 7553*da0073e9SAndroid Build Coastguard Worker ) 7554*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close( 7555*da0073e9SAndroid Build Coastguard Worker np.percentile(sample, 75, axis=0), np.repeat(0.75, d), atol=2, rtol=2 7556*da0073e9SAndroid Build Coastguard Worker ) 7557*da0073e9SAndroid Build Coastguard Worker 7558*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("np.float64 restored as float32 after graph break.") 7559*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_distribution_scrambled(self): 7560*da0073e9SAndroid Build Coastguard Worker self.test_sobolengine_distribution(scramble=True) 7561*da0073e9SAndroid Build Coastguard Worker 7562*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_draw_base2(self, scramble=False): 7563*da0073e9SAndroid Build Coastguard Worker ref_sample = 
self._sobol_reference_samples(scramble=scramble) 7564*da0073e9SAndroid Build Coastguard Worker engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456) 7565*da0073e9SAndroid Build Coastguard Worker sample = engine.draw_base2(2) 7566*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref_sample[:4], sample) 7567*da0073e9SAndroid Build Coastguard Worker # resampling still having N=2**n 7568*da0073e9SAndroid Build Coastguard Worker sample = engine.draw_base2(2) 7569*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref_sample[4:8], sample) 7570*da0073e9SAndroid Build Coastguard Worker 7571*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_draw_base2_scrambled(self): 7572*da0073e9SAndroid Build Coastguard Worker self.test_sobolengine_draw_base2(scramble=True) 7573*da0073e9SAndroid Build Coastguard Worker 7574*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_raise(self): 7575*da0073e9SAndroid Build Coastguard Worker maxdim = torch.quasirandom.SobolEngine.MAXDIM 7576*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 7577*da0073e9SAndroid Build Coastguard Worker torch.quasirandom.SobolEngine(maxdim + 1) 7578*da0073e9SAndroid Build Coastguard Worker 7579*da0073e9SAndroid Build Coastguard Worker def test_sobolengine_high_dim(self): 7580*da0073e9SAndroid Build Coastguard Worker engine = torch.quasirandom.SobolEngine(1111, scramble=False, seed=123456) 7581*da0073e9SAndroid Build Coastguard Worker samples1 = engine.draw() 7582*da0073e9SAndroid Build Coastguard Worker vals1, counts1 = torch.unique(samples1, return_counts=True) 7583*da0073e9SAndroid Build Coastguard Worker samples2 = engine.draw() 7584*da0073e9SAndroid Build Coastguard Worker vals2, counts2 = torch.unique(samples2, return_counts=True) 7585*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vals1.item(), 0.0) 7586*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counts1.item(), 1111) 7587*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(vals2.item(), 0.5) 7588*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counts1.item(), 1111) 7589*da0073e9SAndroid Build Coastguard Worker 7590*da0073e9SAndroid Build Coastguard Worker def test_parsing_int64(self): 7591*da0073e9SAndroid Build Coastguard Worker # accepts integer arguments 7592*da0073e9SAndroid Build Coastguard Worker x = torch.cumsum(torch.ones(5, 5), 0) 7593*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.cumsum(torch.ones(5, 5), torch.tensor(0))) 7594*da0073e9SAndroid Build Coastguard Worker # doesn't accept floating point variables 7595*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: torch.cumsum(torch.ones(5, 5), torch.tensor(0.))) 7596*da0073e9SAndroid Build Coastguard Worker 7597*da0073e9SAndroid Build Coastguard Worker def test_parsing_double(self): 7598*da0073e9SAndroid Build Coastguard Worker # accepts floating point and integer arguments 7599*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3) 7600*da0073e9SAndroid Build Coastguard Worker torch.isclose(x, x, 1, 1) 7601*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.isclose(x, x, 1, 1).all()) 7602*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.isclose(x, x, 1.5, 1.).all()) 7603*da0073e9SAndroid Build Coastguard Worker # accepts floating point and integer tensors 7604*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.isclose(x, x, torch.tensor(1), torch.tensor(1)).all()) 7605*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1.)).all()) 7606*da0073e9SAndroid Build Coastguard Worker # doesn't accept variables with requires_grad 7607*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, 7608*da0073e9SAndroid Build Coastguard Worker lambda: torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1., requires_grad=True)).all()) 7609*da0073e9SAndroid Build Coastguard Worker 
7610*da0073e9SAndroid Build Coastguard Worker def test_parsing_intlist(self): 7611*da0073e9SAndroid Build Coastguard Worker # parse with integer variables 7612*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.Size([3, 4]), torch.ones((torch.tensor(3), torch.tensor(4))).shape) 7613*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.Size([3, 4]), torch.ones(torch.tensor(3), torch.tensor(4)).shape) 7614*da0073e9SAndroid Build Coastguard Worker # parse with numpy integers 7615*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.Size([3, 4]), torch.ones((np.array(3), np.int64(4))).shape) 7616*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.Size([3, 4]), torch.ones(np.array(3), np.int64(4)).shape) 7617*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.Size([3, 4]), torch.ones((np.int64(3), np.array(4))).shape) 7618*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.Size([3, 4]), torch.ones(np.int64(3), np.array(4)).shape) 7619*da0073e9SAndroid Build Coastguard Worker 7620*da0073e9SAndroid Build Coastguard Worker # fail parse with float variables 7621*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: torch.ones((torch.tensor(3.), torch.tensor(4)))) 7622*da0073e9SAndroid Build Coastguard Worker # fail parse with numpy floats 7623*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: torch.ones((3., torch.tensor(4)))) 7624*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: torch.ones((np.array(3.), torch.tensor(4)))) 7625*da0073e9SAndroid Build Coastguard Worker 7626*da0073e9SAndroid Build Coastguard Worker # fail parse with > 1 element variables 7627*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: torch.ones(torch.tensor(3, 3))) 7628*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: torch.ones(torch.tensor(3, 3))) 7629*da0073e9SAndroid Build Coastguard 
Worker self.assertRaises(TypeError, lambda: torch.ones(np.array(3, 3))) 7630*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: torch.ones(np.array(3, 3))) 7631*da0073e9SAndroid Build Coastguard Worker 7632*da0073e9SAndroid Build Coastguard Worker # fail parse with additional positional args after intlist arg 7633*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(TypeError, 7634*da0073e9SAndroid Build Coastguard Worker "received an invalid combination of arguments", 7635*da0073e9SAndroid Build Coastguard Worker lambda: torch.LongTensor((6, 0), 1, 1, 0)) 7636*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(TypeError, 7637*da0073e9SAndroid Build Coastguard Worker "missing 1 required positional arguments", 7638*da0073e9SAndroid Build Coastguard Worker lambda: torch.tensor().new_zeros((5, 5), 0)) 7639*da0073e9SAndroid Build Coastguard Worker 7640*da0073e9SAndroid Build Coastguard Worker def test_from_buffer(self): 7641*da0073e9SAndroid Build Coastguard Worker a = bytearray([1, 2, 3, 4]) 7642*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4]) 7643*da0073e9SAndroid Build Coastguard Worker shorts = torch.ShortStorage.from_buffer(a, 'big') 7644*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shorts.size(), 2) 7645*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shorts.tolist(), [258, 772]) 7646*da0073e9SAndroid Build Coastguard Worker ints = torch.IntStorage.from_buffer(a, 'little') 7647*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ints.size(), 1) 7648*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ints[0], 67305985) 7649*da0073e9SAndroid Build Coastguard Worker f = bytearray([0x40, 0x10, 0x00, 0x00]) 7650*da0073e9SAndroid Build Coastguard Worker floats = torch.FloatStorage.from_buffer(f, 'big') 7651*da0073e9SAndroid Build Coastguard Worker self.assertEqual(floats.size(), 1) 7652*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(floats[0], 2.25) 7653*da0073e9SAndroid Build Coastguard Worker 7654*da0073e9SAndroid Build Coastguard Worker f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40]) 7655*da0073e9SAndroid Build Coastguard Worker bools = torch.BoolStorage.from_buffer(f, 'big') 7656*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bools.size(), 8) 7657*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bools.tolist(), [False, True, True, True, True, True, True, True]) 7658*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bools.type(), 'torch.BoolStorage') 7659*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(bools, torch.BoolStorage)) 7660*da0073e9SAndroid Build Coastguard Worker 7661*da0073e9SAndroid Build Coastguard Worker f = bytearray(b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9') 7662*da0073e9SAndroid Build Coastguard Worker bools = torch.BoolStorage.from_buffer(f, 'big') 7663*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bools.size(), 19) 7664*da0073e9SAndroid Build Coastguard Worker 7665*da0073e9SAndroid Build Coastguard Worker f = bytearray(b'\0x4A') 7666*da0073e9SAndroid Build Coastguard Worker bools = torch.BoolStorage.from_buffer(f, 'big') 7667*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bools.size(), 4) 7668*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bools.tolist(), [False, True, True, True]) 7669*da0073e9SAndroid Build Coastguard Worker bytes = torch.ByteStorage.from_buffer(a) 7670*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bytes.nbytes(), 4) 7671*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bytes.tolist(), [1, 2, 3, 4]) 7672*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(bytes, torch.ByteStorage)) 7673*da0073e9SAndroid Build Coastguard Worker 7674*da0073e9SAndroid Build Coastguard Worker def test_storage_error(self): 7675*da0073e9SAndroid Build Coastguard Worker quantized_storages = [ 
7676*da0073e9SAndroid Build Coastguard Worker torch.QInt32Storage, 7677*da0073e9SAndroid Build Coastguard Worker torch.QInt8Storage, 7678*da0073e9SAndroid Build Coastguard Worker torch.QUInt2x4Storage, 7679*da0073e9SAndroid Build Coastguard Worker torch.QUInt4x2Storage, 7680*da0073e9SAndroid Build Coastguard Worker torch.QUInt8Storage, 7681*da0073e9SAndroid Build Coastguard Worker ] 7682*da0073e9SAndroid Build Coastguard Worker 7683*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Only child classes of _LegacyStorage can be instantiated"): 7684*da0073e9SAndroid Build Coastguard Worker torch.storage._LegacyStorage() 7685*da0073e9SAndroid Build Coastguard Worker 7686*da0073e9SAndroid Build Coastguard Worker for storage_class in torch._storage_classes: 7687*da0073e9SAndroid Build Coastguard Worker if storage_class in [torch.UntypedStorage, torch.TypedStorage]: 7688*da0073e9SAndroid Build Coastguard Worker continue 7689*da0073e9SAndroid Build Coastguard Worker 7690*da0073e9SAndroid Build Coastguard Worker device = 'cuda' if storage_class.__module__ == 'torch.cuda' else 'cpu' 7691*da0073e9SAndroid Build Coastguard Worker dtype = storage_class.dtype 7692*da0073e9SAndroid Build Coastguard Worker 7693*da0073e9SAndroid Build Coastguard Worker if device == 'cuda' and not torch.cuda.is_available(): 7694*da0073e9SAndroid Build Coastguard Worker continue 7695*da0073e9SAndroid Build Coastguard Worker 7696*da0073e9SAndroid Build Coastguard Worker # Legacy <type>Storage constructor errors 7697*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"'device' cannot be specified"): 7698*da0073e9SAndroid Build Coastguard Worker storage_class(device='cpu') 7699*da0073e9SAndroid Build Coastguard Worker 7700*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"'dtype' cannot be specified"): 7701*da0073e9SAndroid Build Coastguard Worker storage_class(dtype=torch.float) 7702*da0073e9SAndroid 
Build Coastguard Worker 7703*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"got an unexpected keyword"): 7704*da0073e9SAndroid Build Coastguard Worker storage_class(sdlkjf=torch.float) 7705*da0073e9SAndroid Build Coastguard Worker 7706*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"): 7707*da0073e9SAndroid Build Coastguard Worker storage_class(0, 0) 7708*da0073e9SAndroid Build Coastguard Worker 7709*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"invalid data type"): 7710*da0073e9SAndroid Build Coastguard Worker storage_class('string') 7711*da0073e9SAndroid Build Coastguard Worker 7712*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"Argument type not recognized"): 7713*da0073e9SAndroid Build Coastguard Worker storage_class(torch.tensor([])) 7714*da0073e9SAndroid Build Coastguard Worker 7715*da0073e9SAndroid Build Coastguard Worker s = storage_class() 7716*da0073e9SAndroid Build Coastguard Worker 7717*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"No positional arguments"): 7718*da0073e9SAndroid Build Coastguard Worker storage_class(0, wrap_storage=s.untyped()) 7719*da0073e9SAndroid Build Coastguard Worker 7720*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"must be UntypedStorage"): 7721*da0073e9SAndroid Build Coastguard Worker storage_class(wrap_storage=s) 7722*da0073e9SAndroid Build Coastguard Worker 7723*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 7724*da0073e9SAndroid Build Coastguard Worker if storage_class in quantized_storages: 7725*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"): 7726*da0073e9SAndroid Build Coastguard Worker s.cuda() 7727*da0073e9SAndroid Build Coastguard Worker 7728*da0073e9SAndroid Build 
Coastguard Worker else: 7729*da0073e9SAndroid Build Coastguard Worker 7730*da0073e9SAndroid Build Coastguard Worker if s.is_cuda: 7731*da0073e9SAndroid Build Coastguard Worker s_other_device = s.cpu() 7732*da0073e9SAndroid Build Coastguard Worker else: 7733*da0073e9SAndroid Build Coastguard Worker s_other_device = s.cuda() 7734*da0073e9SAndroid Build Coastguard Worker 7735*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Device of 'wrap_storage' must be"): 7736*da0073e9SAndroid Build Coastguard Worker storage_class(wrap_storage=s_other_device.untyped()) 7737*da0073e9SAndroid Build Coastguard Worker 7738*da0073e9SAndroid Build Coastguard Worker # TypedStorage constructor errors 7739*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"No positional arguments"): 7740*da0073e9SAndroid Build Coastguard Worker torch.TypedStorage(0, wrap_storage=s.untyped(), dtype=dtype) 7741*da0073e9SAndroid Build Coastguard Worker 7742*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Argument 'dtype' must be specified"): 7743*da0073e9SAndroid Build Coastguard Worker torch.TypedStorage(wrap_storage=s.untyped()) 7744*da0073e9SAndroid Build Coastguard Worker 7745*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"Argument 'dtype' must be torch.dtype"): 7746*da0073e9SAndroid Build Coastguard Worker torch.TypedStorage(wrap_storage=s.untyped(), dtype=0) 7747*da0073e9SAndroid Build Coastguard Worker 7748*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Argument 'device' should not be specified"): 7749*da0073e9SAndroid Build Coastguard Worker torch.TypedStorage(wrap_storage=s.untyped(), dtype=dtype, device=device) 7750*da0073e9SAndroid Build Coastguard Worker 7751*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be UntypedStorage"): 7752*da0073e9SAndroid Build 
Coastguard Worker torch.TypedStorage(wrap_storage=s, dtype=dtype) 7753*da0073e9SAndroid Build Coastguard Worker 7754*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Storage device not recognized"): 7755*da0073e9SAndroid Build Coastguard Worker torch.TypedStorage(dtype=dtype, device='xla') 7756*da0073e9SAndroid Build Coastguard Worker 7757*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 7758*da0073e9SAndroid Build Coastguard Worker if storage_class in quantized_storages: 7759*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"): 7760*da0073e9SAndroid Build Coastguard Worker torch.TypedStorage(dtype=dtype, device='cuda') 7761*da0073e9SAndroid Build Coastguard Worker 7762*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, r"Argument type not recognized"): 7763*da0073e9SAndroid Build Coastguard Worker torch.TypedStorage(torch.tensor([]), dtype=dtype, device=device) 7764*da0073e9SAndroid Build Coastguard Worker 7765*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"): 7766*da0073e9SAndroid Build Coastguard Worker torch.TypedStorage(0, 0, dtype=dtype, device=device) 7767*da0073e9SAndroid Build Coastguard Worker 7768*da0073e9SAndroid Build Coastguard Worker if isinstance(s, torch.TypedStorage): 7769*da0073e9SAndroid Build Coastguard Worker s_other = torch.TypedStorage([1, 2, 3, 4], device=device, dtype=dtype) 7770*da0073e9SAndroid Build Coastguard Worker 7771*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'cannot set item'): 7772*da0073e9SAndroid Build Coastguard Worker s.fill_(s_other) 7773*da0073e9SAndroid Build Coastguard Worker 7774*da0073e9SAndroid Build Coastguard Worker def test_storage_error_no_attribute(self): 7775*da0073e9SAndroid Build Coastguard Worker storage_classes = [ 
7776*da0073e9SAndroid Build Coastguard Worker torch.cuda.ByteStorage, 7777*da0073e9SAndroid Build Coastguard Worker torch.cuda.FloatStorage, 7778*da0073e9SAndroid Build Coastguard Worker ] 7779*da0073e9SAndroid Build Coastguard Worker for storage_class in storage_classes: 7780*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'): 7781*da0073e9SAndroid Build Coastguard Worker storage_class.from_buffer() 7782*da0073e9SAndroid Build Coastguard Worker 7783*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'): 7784*da0073e9SAndroid Build Coastguard Worker storage_class._new_with_weak_ptr() 7785*da0073e9SAndroid Build Coastguard Worker 7786*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'): 7787*da0073e9SAndroid Build Coastguard Worker storage_class._new_shared_filename(0, 0, 0) 7788*da0073e9SAndroid Build Coastguard Worker 7789*da0073e9SAndroid Build Coastguard Worker def test_storage_casts(self): 7790*da0073e9SAndroid Build Coastguard Worker storage = torch.IntStorage([-1, 0, 1, 2, 3, 4]) 7791*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage.size(), 6) 7792*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage.tolist(), [-1, 0, 1, 2, 3, 4]) 7793*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage.type(), 'torch.IntStorage') 7794*da0073e9SAndroid Build Coastguard Worker self.assertIs(storage.dtype, torch.int32) 7795*da0073e9SAndroid Build Coastguard Worker 7796*da0073e9SAndroid Build Coastguard Worker floatStorage = storage.float() 7797*da0073e9SAndroid Build Coastguard Worker self.assertEqual(floatStorage.size(), 6) 7798*da0073e9SAndroid Build Coastguard Worker self.assertEqual(floatStorage.tolist(), [-1, 0, 1, 2, 3, 4]) 7799*da0073e9SAndroid Build Coastguard Worker self.assertEqual(floatStorage.type(), 'torch.FloatStorage') 
7800*da0073e9SAndroid Build Coastguard Worker self.assertEqual(floatStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) 7801*da0073e9SAndroid Build Coastguard Worker self.assertIs(floatStorage.dtype, torch.float32) 7802*da0073e9SAndroid Build Coastguard Worker 7803*da0073e9SAndroid Build Coastguard Worker halfStorage = storage.half() 7804*da0073e9SAndroid Build Coastguard Worker self.assertEqual(halfStorage.size(), 6) 7805*da0073e9SAndroid Build Coastguard Worker self.assertEqual(halfStorage.tolist(), [-1, 0, 1, 2, 3, 4]) 7806*da0073e9SAndroid Build Coastguard Worker self.assertEqual(halfStorage.type(), 'torch.HalfStorage') 7807*da0073e9SAndroid Build Coastguard Worker self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) 7808*da0073e9SAndroid Build Coastguard Worker self.assertIs(halfStorage.dtype, torch.float16) 7809*da0073e9SAndroid Build Coastguard Worker 7810*da0073e9SAndroid Build Coastguard Worker bfloat16Storage = storage.bfloat16() 7811*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bfloat16Storage.size(), 6) 7812*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bfloat16Storage.tolist(), [-1, 0, 1, 2, 3, 4]) 7813*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bfloat16Storage.type(), 'torch.BFloat16Storage') 7814*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bfloat16Storage.int().tolist(), [-1, 0, 1, 2, 3, 4]) 7815*da0073e9SAndroid Build Coastguard Worker self.assertIs(bfloat16Storage.dtype, torch.bfloat16) 7816*da0073e9SAndroid Build Coastguard Worker 7817*da0073e9SAndroid Build Coastguard Worker longStorage = storage.long() 7818*da0073e9SAndroid Build Coastguard Worker self.assertEqual(longStorage.size(), 6) 7819*da0073e9SAndroid Build Coastguard Worker self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4]) 7820*da0073e9SAndroid Build Coastguard Worker self.assertEqual(longStorage.type(), 'torch.LongStorage') 7821*da0073e9SAndroid Build Coastguard Worker self.assertEqual(longStorage.int().tolist(), [-1, 
0, 1, 2, 3, 4]) 7822*da0073e9SAndroid Build Coastguard Worker self.assertIs(longStorage.dtype, torch.int64) 7823*da0073e9SAndroid Build Coastguard Worker 7824*da0073e9SAndroid Build Coastguard Worker shortStorage = storage.short() 7825*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shortStorage.size(), 6) 7826*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shortStorage.tolist(), [-1, 0, 1, 2, 3, 4]) 7827*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shortStorage.type(), 'torch.ShortStorage') 7828*da0073e9SAndroid Build Coastguard Worker self.assertEqual(shortStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) 7829*da0073e9SAndroid Build Coastguard Worker self.assertIs(shortStorage.dtype, torch.int16) 7830*da0073e9SAndroid Build Coastguard Worker 7831*da0073e9SAndroid Build Coastguard Worker doubleStorage = storage.double() 7832*da0073e9SAndroid Build Coastguard Worker self.assertEqual(doubleStorage.size(), 6) 7833*da0073e9SAndroid Build Coastguard Worker self.assertEqual(doubleStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]) 7834*da0073e9SAndroid Build Coastguard Worker self.assertEqual(doubleStorage.type(), 'torch.DoubleStorage') 7835*da0073e9SAndroid Build Coastguard Worker self.assertEqual(doubleStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) 7836*da0073e9SAndroid Build Coastguard Worker self.assertIs(doubleStorage.dtype, torch.float64) 7837*da0073e9SAndroid Build Coastguard Worker 7838*da0073e9SAndroid Build Coastguard Worker charStorage = storage.char() 7839*da0073e9SAndroid Build Coastguard Worker self.assertEqual(charStorage.size(), 6) 7840*da0073e9SAndroid Build Coastguard Worker self.assertEqual(charStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]) 7841*da0073e9SAndroid Build Coastguard Worker self.assertEqual(charStorage.type(), 'torch.CharStorage') 7842*da0073e9SAndroid Build Coastguard Worker self.assertEqual(charStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) 7843*da0073e9SAndroid Build Coastguard Worker 
self.assertIs(charStorage.dtype, torch.int8) 7844*da0073e9SAndroid Build Coastguard Worker 7845*da0073e9SAndroid Build Coastguard Worker byteStorage = storage.byte() 7846*da0073e9SAndroid Build Coastguard Worker self.assertEqual(byteStorage.size(), 6) 7847*da0073e9SAndroid Build Coastguard Worker self.assertEqual(byteStorage.tolist(), [255, 0, 1, 2, 3, 4]) 7848*da0073e9SAndroid Build Coastguard Worker self.assertEqual(byteStorage.type(), 'torch.ByteStorage') 7849*da0073e9SAndroid Build Coastguard Worker self.assertEqual(byteStorage.int().tolist(), [255, 0, 1, 2, 3, 4]) 7850*da0073e9SAndroid Build Coastguard Worker self.assertIs(byteStorage.dtype, torch.uint8) 7851*da0073e9SAndroid Build Coastguard Worker 7852*da0073e9SAndroid Build Coastguard Worker boolStorage = storage.bool() 7853*da0073e9SAndroid Build Coastguard Worker self.assertEqual(boolStorage.size(), 6) 7854*da0073e9SAndroid Build Coastguard Worker self.assertEqual(boolStorage.tolist(), [True, False, True, True, True, True]) 7855*da0073e9SAndroid Build Coastguard Worker self.assertEqual(boolStorage.type(), 'torch.BoolStorage') 7856*da0073e9SAndroid Build Coastguard Worker self.assertEqual(boolStorage.int().tolist(), [1, 0, 1, 1, 1, 1]) 7857*da0073e9SAndroid Build Coastguard Worker self.assertIs(boolStorage.dtype, torch.bool) 7858*da0073e9SAndroid Build Coastguard Worker 7859*da0073e9SAndroid Build Coastguard Worker complexfloat_storage = torch.ComplexFloatStorage([-1, 0, 1 + 2j, 2.5j, 3.5, 4 - 2j]) 7860*da0073e9SAndroid Build Coastguard Worker self.assertEqual(complexfloat_storage.size(), 6) 7861*da0073e9SAndroid Build Coastguard Worker self.assertEqual(complexfloat_storage.tolist(), [-1, 0, 1 + 2j, 2.5j, 3.5, 4 - 2j]) 7862*da0073e9SAndroid Build Coastguard Worker self.assertEqual(complexfloat_storage.type(), 'torch.ComplexFloatStorage') 7863*da0073e9SAndroid Build Coastguard Worker self.assertIs(complexfloat_storage.dtype, torch.complex64) 7864*da0073e9SAndroid Build Coastguard Worker 
7865*da0073e9SAndroid Build Coastguard Worker complexdouble_storage = complexfloat_storage.complex_double() 7866*da0073e9SAndroid Build Coastguard Worker self.assertEqual(complexdouble_storage.size(), 6) 7867*da0073e9SAndroid Build Coastguard Worker self.assertEqual(complexdouble_storage.tolist(), [-1, 0, 1 + 2j, 2.5j, 3.5, 4 - 2j]) 7868*da0073e9SAndroid Build Coastguard Worker self.assertEqual(complexdouble_storage.type(), 'torch.ComplexDoubleStorage') 7869*da0073e9SAndroid Build Coastguard Worker self.assertIs(complexdouble_storage.dtype, torch.complex128) 7870*da0073e9SAndroid Build Coastguard Worker 7871*da0073e9SAndroid Build Coastguard Worker def test_storage_byteswap(self): 7872*da0073e9SAndroid Build Coastguard Worker input = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] 7873*da0073e9SAndroid Build Coastguard Worker swapped_8bytes = [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8] 7874*da0073e9SAndroid Build Coastguard Worker swapped_4bytes = [3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12] 7875*da0073e9SAndroid Build Coastguard Worker swapped_2bytes = [1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14] 7876*da0073e9SAndroid Build Coastguard Worker swapped_1byte = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] 7877*da0073e9SAndroid Build Coastguard Worker 7878*da0073e9SAndroid Build Coastguard Worker storage = torch.storage.TypedStorage(input, dtype=torch.uint8)._untyped_storage 7879*da0073e9SAndroid Build Coastguard Worker 7880*da0073e9SAndroid Build Coastguard Worker storage_f64 = storage.__copy__() 7881*da0073e9SAndroid Build Coastguard Worker storage_f64.byteswap(torch.float64) 7882*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage_f64.tolist(), swapped_8bytes) 7883*da0073e9SAndroid Build Coastguard Worker 7884*da0073e9SAndroid Build Coastguard Worker storage_f32 = storage.__copy__() 7885*da0073e9SAndroid Build Coastguard Worker storage_f32.byteswap(torch.float32) 7886*da0073e9SAndroid Build Coastguard 
Worker self.assertEqual(storage_f32.tolist(), swapped_4bytes) 7887*da0073e9SAndroid Build Coastguard Worker 7888*da0073e9SAndroid Build Coastguard Worker storage_f16 = storage.__copy__() 7889*da0073e9SAndroid Build Coastguard Worker storage_f16.byteswap(torch.float16) 7890*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage_f16.tolist(), swapped_2bytes) 7891*da0073e9SAndroid Build Coastguard Worker 7892*da0073e9SAndroid Build Coastguard Worker storage_bf16 = storage.__copy__() 7893*da0073e9SAndroid Build Coastguard Worker storage_bf16.byteswap(torch.bfloat16) 7894*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage_bf16.tolist(), swapped_2bytes) 7895*da0073e9SAndroid Build Coastguard Worker 7896*da0073e9SAndroid Build Coastguard Worker storage_i64 = storage.__copy__() 7897*da0073e9SAndroid Build Coastguard Worker storage_i64.byteswap(torch.int64) 7898*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage_i64.tolist(), swapped_8bytes) 7899*da0073e9SAndroid Build Coastguard Worker 7900*da0073e9SAndroid Build Coastguard Worker storage_i32 = storage.__copy__() 7901*da0073e9SAndroid Build Coastguard Worker storage_i32.byteswap(torch.int32) 7902*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage_i32.tolist(), swapped_4bytes) 7903*da0073e9SAndroid Build Coastguard Worker 7904*da0073e9SAndroid Build Coastguard Worker storage_i16 = storage.__copy__() 7905*da0073e9SAndroid Build Coastguard Worker storage_i16.byteswap(torch.int16) 7906*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage_i16.tolist(), swapped_2bytes) 7907*da0073e9SAndroid Build Coastguard Worker 7908*da0073e9SAndroid Build Coastguard Worker storage_i8 = storage.__copy__() 7909*da0073e9SAndroid Build Coastguard Worker storage_i8.byteswap(torch.int8) 7910*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage_i8.tolist(), swapped_1byte) 7911*da0073e9SAndroid Build Coastguard Worker 7912*da0073e9SAndroid Build Coastguard Worker 
storage_ui8 = storage.__copy__() 7913*da0073e9SAndroid Build Coastguard Worker storage_ui8.byteswap(torch.uint8) 7914*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage_ui8.tolist(), swapped_1byte) 7915*da0073e9SAndroid Build Coastguard Worker 7916*da0073e9SAndroid Build Coastguard Worker storage_bool = storage.__copy__() 7917*da0073e9SAndroid Build Coastguard Worker storage_bool.byteswap(torch.bool) 7918*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage_bool.tolist(), swapped_1byte) 7919*da0073e9SAndroid Build Coastguard Worker 7920*da0073e9SAndroid Build Coastguard Worker storage_c128 = storage.__copy__() 7921*da0073e9SAndroid Build Coastguard Worker storage_c128.byteswap(torch.complex128) 7922*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage_c128.tolist(), swapped_8bytes) 7923*da0073e9SAndroid Build Coastguard Worker 7924*da0073e9SAndroid Build Coastguard Worker storage_c64 = storage.__copy__() 7925*da0073e9SAndroid Build Coastguard Worker storage_c64.byteswap(torch.complex64) 7926*da0073e9SAndroid Build Coastguard Worker self.assertEqual(storage_c64.tolist(), swapped_4bytes) 7927*da0073e9SAndroid Build Coastguard Worker 7928*da0073e9SAndroid Build Coastguard Worker # Test that internal versions of functions related to TypedStorage do not 7929*da0073e9SAndroid Build Coastguard Worker # produce a deprecation warning 7930*da0073e9SAndroid Build Coastguard Worker def test_typed_storage_internal_no_warning(self): 7931*da0073e9SAndroid Build Coastguard Worker s0 = torch.FloatStorage(10) 7932*da0073e9SAndroid Build Coastguard Worker s0_untyped = s0.untyped() 7933*da0073e9SAndroid Build Coastguard Worker t0 = torch.randn(10) 7934*da0073e9SAndroid Build Coastguard Worker 7935*da0073e9SAndroid Build Coastguard Worker funcs = [ 7936*da0073e9SAndroid Build Coastguard Worker lambda: torch.FloatStorage(_internal=True), 7937*da0073e9SAndroid Build Coastguard Worker lambda: torch.TypedStorage( 7938*da0073e9SAndroid Build 
Coastguard Worker dtype=torch.float, 7939*da0073e9SAndroid Build Coastguard Worker device='cpu', 7940*da0073e9SAndroid Build Coastguard Worker _internal=True), 7941*da0073e9SAndroid Build Coastguard Worker lambda: torch.TypedStorage( 7942*da0073e9SAndroid Build Coastguard Worker wrap_storage=s0_untyped, 7943*da0073e9SAndroid Build Coastguard Worker dtype=s0.dtype, 7944*da0073e9SAndroid Build Coastguard Worker _internal=True), 7945*da0073e9SAndroid Build Coastguard Worker lambda: torch.FloatStorage._dtype, 7946*da0073e9SAndroid Build Coastguard Worker lambda: s0._resize_(20), 7947*da0073e9SAndroid Build Coastguard Worker lambda: s0._size(), 7948*da0073e9SAndroid Build Coastguard Worker lambda: s0._untyped_storage, 7949*da0073e9SAndroid Build Coastguard Worker lambda: s0._is_shared(), 7950*da0073e9SAndroid Build Coastguard Worker lambda: s0._share_memory_(), 7951*da0073e9SAndroid Build Coastguard Worker lambda: s0._pickle_storage_type(), 7952*da0073e9SAndroid Build Coastguard Worker lambda: s0._setitem(slice(0, s0._size()), 1), 7953*da0073e9SAndroid Build Coastguard Worker lambda: s0._element_size(), 7954*da0073e9SAndroid Build Coastguard Worker lambda: s0._deepcopy({}), 7955*da0073e9SAndroid Build Coastguard Worker lambda: s0._data_ptr(), 7956*da0073e9SAndroid Build Coastguard Worker lambda: s0._nbytes(), 7957*da0073e9SAndroid Build Coastguard Worker lambda: t0._typed_storage(), 7958*da0073e9SAndroid Build Coastguard Worker ] 7959*da0073e9SAndroid Build Coastguard Worker 7960*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 7961*da0073e9SAndroid Build Coastguard Worker s1 = torch.cuda.FloatStorage(10) 7962*da0073e9SAndroid Build Coastguard Worker s1_untyped = s1.untyped() 7963*da0073e9SAndroid Build Coastguard Worker t1 = torch.randn(10, device='cuda') 7964*da0073e9SAndroid Build Coastguard Worker 7965*da0073e9SAndroid Build Coastguard Worker funcs += [ 7966*da0073e9SAndroid Build Coastguard Worker lambda: 
torch.cuda.FloatStorage(_internal=True), 7967*da0073e9SAndroid Build Coastguard Worker lambda: torch.TypedStorage( 7968*da0073e9SAndroid Build Coastguard Worker dtype=torch.float, 7969*da0073e9SAndroid Build Coastguard Worker device='cuda', 7970*da0073e9SAndroid Build Coastguard Worker _internal=True), 7971*da0073e9SAndroid Build Coastguard Worker lambda: torch.TypedStorage( 7972*da0073e9SAndroid Build Coastguard Worker wrap_storage=s1_untyped, 7973*da0073e9SAndroid Build Coastguard Worker dtype=s1.dtype, 7974*da0073e9SAndroid Build Coastguard Worker _internal=True), 7975*da0073e9SAndroid Build Coastguard Worker lambda: torch.cuda.FloatStorage._dtype, 7976*da0073e9SAndroid Build Coastguard Worker lambda: s1._resize_(20), 7977*da0073e9SAndroid Build Coastguard Worker lambda: s1._size(), 7978*da0073e9SAndroid Build Coastguard Worker lambda: s1._untyped_storage, 7979*da0073e9SAndroid Build Coastguard Worker lambda: s1._is_shared(), 7980*da0073e9SAndroid Build Coastguard Worker lambda: s1._share_memory_(), 7981*da0073e9SAndroid Build Coastguard Worker lambda: s1._pickle_storage_type(), 7982*da0073e9SAndroid Build Coastguard Worker lambda: s1._setitem(slice(0, s1._size()), 1), 7983*da0073e9SAndroid Build Coastguard Worker lambda: s1._element_size(), 7984*da0073e9SAndroid Build Coastguard Worker lambda: s1._deepcopy({}), 7985*da0073e9SAndroid Build Coastguard Worker lambda: s1._data_ptr(), 7986*da0073e9SAndroid Build Coastguard Worker lambda: s1._nbytes(), 7987*da0073e9SAndroid Build Coastguard Worker lambda: t1._typed_storage(), 7988*da0073e9SAndroid Build Coastguard Worker ] 7989*da0073e9SAndroid Build Coastguard Worker 7990*da0073e9SAndroid Build Coastguard Worker # Check that each of the TypedStorage internal function calls do not 7991*da0073e9SAndroid Build Coastguard Worker # produce a deprecation warning 7992*da0073e9SAndroid Build Coastguard Worker for f in funcs: 7993*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(): 7994*da0073e9SAndroid 
Build Coastguard Worker warnings.filterwarnings('error', "TypedStorage is deprecated") 7995*da0073e9SAndroid Build Coastguard Worker f() 7996*da0073e9SAndroid Build Coastguard Worker 7997*da0073e9SAndroid Build Coastguard Worker # Test that public functions related to TypedStorage produce a deprecation 7998*da0073e9SAndroid Build Coastguard Worker # warning 7999*da0073e9SAndroid Build Coastguard Worker @skipIfTorchInductor("FIXME") 8000*da0073e9SAndroid Build Coastguard Worker def test_typed_storage_deprecation_warning(self): 8001*da0073e9SAndroid Build Coastguard Worker s0 = torch.FloatStorage(10) 8002*da0073e9SAndroid Build Coastguard Worker funcs = [ 8003*da0073e9SAndroid Build Coastguard Worker lambda: torch.FloatStorage(), 8004*da0073e9SAndroid Build Coastguard Worker lambda: torch.FloatStorage.dtype, 8005*da0073e9SAndroid Build Coastguard Worker lambda: s0.fill_(0), 8006*da0073e9SAndroid Build Coastguard Worker lambda: s0.is_cuda, 8007*da0073e9SAndroid Build Coastguard Worker lambda: s0.untyped(), 8008*da0073e9SAndroid Build Coastguard Worker lambda: len(s0), 8009*da0073e9SAndroid Build Coastguard Worker lambda: s0[0], 8010*da0073e9SAndroid Build Coastguard Worker ] 8011*da0073e9SAndroid Build Coastguard Worker 8012*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 8013*da0073e9SAndroid Build Coastguard Worker s1 = torch.cuda.FloatStorage(10) 8014*da0073e9SAndroid Build Coastguard Worker funcs += [ 8015*da0073e9SAndroid Build Coastguard Worker lambda: torch.cuda.FloatStorage(), 8016*da0073e9SAndroid Build Coastguard Worker lambda: torch.cuda.FloatStorage.dtype, 8017*da0073e9SAndroid Build Coastguard Worker lambda: s1.fill_(0), 8018*da0073e9SAndroid Build Coastguard Worker lambda: s1.is_cuda, 8019*da0073e9SAndroid Build Coastguard Worker lambda: s1.untyped(), 8020*da0073e9SAndroid Build Coastguard Worker lambda: len(s1), 8021*da0073e9SAndroid Build Coastguard Worker lambda: s1[0], 8022*da0073e9SAndroid Build Coastguard Worker ] 
8023*da0073e9SAndroid Build Coastguard Worker 8024*da0073e9SAndroid Build Coastguard Worker # Check that each of the TypedStorage function calls produce a warning 8025*da0073e9SAndroid Build Coastguard Worker # if warnings are reset between each 8026*da0073e9SAndroid Build Coastguard Worker for f in funcs: 8027*da0073e9SAndroid Build Coastguard Worker with AlwaysWarnTypedStorageRemoval(True): 8028*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 8029*da0073e9SAndroid Build Coastguard Worker warnings.resetwarnings() 8030*da0073e9SAndroid Build Coastguard Worker f() 8031*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1, msg=str([str(a) for a in w])) 8032*da0073e9SAndroid Build Coastguard Worker warning = w[0].message 8033*da0073e9SAndroid Build Coastguard Worker self.assertTrue(warning, DeprecationWarning) 8034*da0073e9SAndroid Build Coastguard Worker self.assertTrue(re.search( 8035*da0073e9SAndroid Build Coastguard Worker '^TypedStorage is deprecated', 8036*da0073e9SAndroid Build Coastguard Worker str(warning))) 8037*da0073e9SAndroid Build Coastguard Worker 8038*da0073e9SAndroid Build Coastguard Worker # Test that only the first warning is raised by default 8039*da0073e9SAndroid Build Coastguard Worker torch.storage._reset_warn_typed_storage_removal() 8040*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 8041*da0073e9SAndroid Build Coastguard Worker warnings.resetwarnings() 8042*da0073e9SAndroid Build Coastguard Worker torch.FloatStorage() 8043*da0073e9SAndroid Build Coastguard Worker torch.randn(10).storage() 8044*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 1, msg=str([str(a) for a in w])) 8045*da0073e9SAndroid Build Coastguard Worker warning = w[0].message 8046*da0073e9SAndroid Build Coastguard Worker self.assertTrue(re.search( 8047*da0073e9SAndroid Build Coastguard Worker '^TypedStorage is deprecated', 8048*da0073e9SAndroid Build Coastguard 
Worker str(warning))) 8049*da0073e9SAndroid Build Coastguard Worker # Check the line of code from the warning's stack 8050*da0073e9SAndroid Build Coastguard Worker with open(w[0].filename, encoding="utf-8") as f: 8051*da0073e9SAndroid Build Coastguard Worker code_line = f.readlines()[w[0].lineno - 1] 8052*da0073e9SAndroid Build Coastguard Worker self.assertTrue(re.search(re.escape('torch.FloatStorage()'), code_line)) 8053*da0073e9SAndroid Build Coastguard Worker 8054*da0073e9SAndroid Build Coastguard Worker # Check that warnings are not emitted if it happened in the past 8055*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as w: 8056*da0073e9SAndroid Build Coastguard Worker warnings.resetwarnings() 8057*da0073e9SAndroid Build Coastguard Worker torch.FloatStorage() 8058*da0073e9SAndroid Build Coastguard Worker torch.randn(10).storage() 8059*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(w), 0, msg=str([str(a) for a in w])) 8060*da0073e9SAndroid Build Coastguard Worker 8061*da0073e9SAndroid Build Coastguard Worker def test_from_file(self): 8062*da0073e9SAndroid Build Coastguard Worker def assert_with_filename(filename): 8063*da0073e9SAndroid Build Coastguard Worker size = 10000 8064*da0073e9SAndroid Build Coastguard Worker s1 = torch.FloatStorage.from_file(filename, True, size) 8065*da0073e9SAndroid Build Coastguard Worker t1 = torch.FloatTensor(s1).copy_(torch.randn(size)) 8066*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s1.data_ptr(), torch.FloatTensor(s1).data_ptr()) 8067*da0073e9SAndroid Build Coastguard Worker 8068*da0073e9SAndroid Build Coastguard Worker # check mapping 8069*da0073e9SAndroid Build Coastguard Worker s2 = torch.FloatStorage.from_file(filename, True, size) 8070*da0073e9SAndroid Build Coastguard Worker t2 = torch.FloatTensor(s2) 8071*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1, t2, atol=0, rtol=0) 8072*da0073e9SAndroid Build Coastguard Worker 8073*da0073e9SAndroid 
Build Coastguard Worker # check changes to t1 from t2 8074*da0073e9SAndroid Build Coastguard Worker rnum = random.uniform(-1, 1) 8075*da0073e9SAndroid Build Coastguard Worker t1.fill_(rnum) 8076*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1, t2, atol=0, rtol=0) 8077*da0073e9SAndroid Build Coastguard Worker 8078*da0073e9SAndroid Build Coastguard Worker # check changes to t2 from t1 8079*da0073e9SAndroid Build Coastguard Worker rnum = random.uniform(-1, 1) 8080*da0073e9SAndroid Build Coastguard Worker t2.fill_(rnum) 8081*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1, t2, atol=0, rtol=0) 8082*da0073e9SAndroid Build Coastguard Worker 8083*da0073e9SAndroid Build Coastguard Worker # release the tensors 8084*da0073e9SAndroid Build Coastguard Worker del s1, t1, s2, t2 8085*da0073e9SAndroid Build Coastguard Worker 8086*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 8087*da0073e9SAndroid Build Coastguard Worker assert_with_filename(fname) 8088*da0073e9SAndroid Build Coastguard Worker 8089*da0073e9SAndroid Build Coastguard Worker if IS_FILESYSTEM_UTF8_ENCODING: 8090*da0073e9SAndroid Build Coastguard Worker with TemporaryDirectoryName(suffix='\u4e2d\u6587') as dname, TemporaryFileName(dir=dname) as fname: 8091*da0073e9SAndroid Build Coastguard Worker assert_with_filename(fname) 8092*da0073e9SAndroid Build Coastguard Worker 8093*da0073e9SAndroid Build Coastguard Worker def test_torch_from_file(self): 8094*da0073e9SAndroid Build Coastguard Worker def assert_with_filename(filename): 8095*da0073e9SAndroid Build Coastguard Worker size = 10000 8096*da0073e9SAndroid Build Coastguard Worker s1 = torch.from_file(filename, True, size, dtype=torch.float) 8097*da0073e9SAndroid Build Coastguard Worker t1 = torch.FloatTensor(s1).copy_(torch.randn(size)) 8098*da0073e9SAndroid Build Coastguard Worker 8099*da0073e9SAndroid Build Coastguard Worker # check mapping 8100*da0073e9SAndroid Build Coastguard Worker s2 = 
torch.from_file(filename, True, size, dtype=torch.float) 8101*da0073e9SAndroid Build Coastguard Worker t2 = torch.FloatTensor(s2) 8102*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1, t2, atol=0, rtol=0) 8103*da0073e9SAndroid Build Coastguard Worker 8104*da0073e9SAndroid Build Coastguard Worker # check changes to t1 from t2 8105*da0073e9SAndroid Build Coastguard Worker rnum = random.uniform(-1, 1) 8106*da0073e9SAndroid Build Coastguard Worker t1.fill_(rnum) 8107*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1, t2, atol=0, rtol=0) 8108*da0073e9SAndroid Build Coastguard Worker 8109*da0073e9SAndroid Build Coastguard Worker # check changes to t2 from t1 8110*da0073e9SAndroid Build Coastguard Worker rnum = random.uniform(-1, 1) 8111*da0073e9SAndroid Build Coastguard Worker t2.fill_(rnum) 8112*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1, t2, atol=0, rtol=0) 8113*da0073e9SAndroid Build Coastguard Worker 8114*da0073e9SAndroid Build Coastguard Worker # release the tensors 8115*da0073e9SAndroid Build Coastguard Worker del s1, t1, s2, t2 8116*da0073e9SAndroid Build Coastguard Worker 8117*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 8118*da0073e9SAndroid Build Coastguard Worker assert_with_filename(fname) 8119*da0073e9SAndroid Build Coastguard Worker 8120*da0073e9SAndroid Build Coastguard Worker if IS_FILESYSTEM_UTF8_ENCODING: 8121*da0073e9SAndroid Build Coastguard Worker with TemporaryDirectoryName(suffix='\u4e2d\u6587') as dname, TemporaryFileName(dir=dname) as fname: 8122*da0073e9SAndroid Build Coastguard Worker assert_with_filename(fname) 8123*da0073e9SAndroid Build Coastguard Worker 8124*da0073e9SAndroid Build Coastguard Worker def test_print(self): 8125*da0073e9SAndroid Build Coastguard Worker default_type = torch.tensor([]).type() 8126*da0073e9SAndroid Build Coastguard Worker for t in torch._tensor_classes: 8127*da0073e9SAndroid Build Coastguard Worker if t == torch.HalfTensor: 
8128*da0073e9SAndroid Build Coastguard Worker continue # HalfTensor does not support fill 8129*da0073e9SAndroid Build Coastguard Worker if t.is_sparse: 8130*da0073e9SAndroid Build Coastguard Worker continue 8131*da0073e9SAndroid Build Coastguard Worker if t.is_cuda and not torch.cuda.is_available(): 8132*da0073e9SAndroid Build Coastguard Worker continue 8133*da0073e9SAndroid Build Coastguard Worker obj = t(100, 100).fill_(1) 8134*da0073e9SAndroid Build Coastguard Worker obj.__repr__() 8135*da0073e9SAndroid Build Coastguard Worker str(obj) 8136*da0073e9SAndroid Build Coastguard Worker # test half tensor 8137*da0073e9SAndroid Build Coastguard Worker obj = torch.rand(100, 100, device='cpu').half() 8138*da0073e9SAndroid Build Coastguard Worker obj.__repr__() 8139*da0073e9SAndroid Build Coastguard Worker str(obj) 8140*da0073e9SAndroid Build Coastguard Worker for t in torch._storage_classes: 8141*da0073e9SAndroid Build Coastguard Worker if t == torch.BFloat16Storage: 8142*da0073e9SAndroid Build Coastguard Worker continue # Fix once fill is enabled for bfloat16 8143*da0073e9SAndroid Build Coastguard Worker if t.is_cuda and not torch.cuda.is_available(): 8144*da0073e9SAndroid Build Coastguard Worker continue 8145*da0073e9SAndroid Build Coastguard Worker if t == torch.BoolStorage or t == torch.cuda.BoolStorage: 8146*da0073e9SAndroid Build Coastguard Worker obj = t(100).fill_(True) 8147*da0073e9SAndroid Build Coastguard Worker else: 8148*da0073e9SAndroid Build Coastguard Worker obj = t(100).fill_(1) 8149*da0073e9SAndroid Build Coastguard Worker obj.__repr__() 8150*da0073e9SAndroid Build Coastguard Worker str(obj) 8151*da0073e9SAndroid Build Coastguard Worker 8152*da0073e9SAndroid Build Coastguard Worker # test complex tensor 8153*da0073e9SAndroid Build Coastguard Worker # complex tensor print uses two formatters, one for real values 8154*da0073e9SAndroid Build Coastguard Worker # and the other for imag values. 
this is consistent with numpy 8155*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([2.3 + 4j, 7 + 6j]) 8156*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8157*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([2.3000+4.j, 7.0000+6.j])''') 8158*da0073e9SAndroid Build Coastguard Worker 8159*da0073e9SAndroid Build Coastguard Worker # test complex half tensor 8160*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1.25 + 4j, -7. + 6j], dtype=torch.chalf) 8161*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8162*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([ 1.2500+4.j, -7.0000+6.j], dtype=torch.complex32)''') 8163*da0073e9SAndroid Build Coastguard Worker 8164*da0073e9SAndroid Build Coastguard Worker # test scientific notation for complex tensors 8165*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1e28 + 2j , -1e-28j]) 8166*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8167*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([1.0000e+28+2.0000e+00j, -0.0000e+00-1.0000e-28j])''') 8168*da0073e9SAndroid Build Coastguard Worker 8169*da0073e9SAndroid Build Coastguard Worker # test big integer 8170*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(2341234123412341) 8171*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8172*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor(2341234123412341)''') 8173*da0073e9SAndroid Build Coastguard Worker 8174*da0073e9SAndroid Build Coastguard Worker # test scientific notation 8175*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1e28, 1e-28]) 8176*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8177*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), 
'''tensor([1.0000e+28, 1.0000e-28])''') 8178*da0073e9SAndroid Build Coastguard Worker 8179*da0073e9SAndroid Build Coastguard Worker # test scientific notation using set_printoptions 8180*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1e2, 1e-2]) 8181*da0073e9SAndroid Build Coastguard Worker torch.set_printoptions(sci_mode=True) 8182*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8183*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([1.0000e+02, 1.0000e-02])''') 8184*da0073e9SAndroid Build Coastguard Worker torch.set_printoptions(sci_mode=False) 8185*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8186*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([ 100.0000, 0.0100])''') 8187*da0073e9SAndroid Build Coastguard Worker torch.set_printoptions(sci_mode=None) # reset to the default value 8188*da0073e9SAndroid Build Coastguard Worker 8189*da0073e9SAndroid Build Coastguard Worker # test no leading space if all elements positive 8190*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2]) 8191*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8192*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([1, 2])''') 8193*da0073e9SAndroid Build Coastguard Worker 8194*da0073e9SAndroid Build Coastguard Worker # test for leading space if there are negative elements 8195*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, -2]) 8196*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8197*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([ 1, -2])''') 8198*da0073e9SAndroid Build Coastguard Worker 8199*da0073e9SAndroid Build Coastguard Worker # test inf and nan 8200*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([4, inf, 1.5, -inf, 0, nan, 1]) 8201*da0073e9SAndroid Build Coastguard 
Worker self.assertEqual(x.__repr__(), str(x)) 8202*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([4.0000, inf, 1.5000, -inf, 0.0000, nan, 1.0000])''') 8203*da0073e9SAndroid Build Coastguard Worker 8204*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([4, inf, complex(1.5, inf), complex(-inf, 4), 0, complex(nan, inf), complex(3, nan)]) 8205*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.__repr__(), str(y)) 8206*da0073e9SAndroid Build Coastguard Worker expected_str = '''\ 8207*da0073e9SAndroid Build Coastguard Workertensor([4.0000+0.j, inf+0.j, 1.5000+infj, -inf+4.j, 0.0000+0.j, nan+infj, 8208*da0073e9SAndroid Build Coastguard Worker 3.0000+nanj])''' 8209*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(y), expected_str) 8210*da0073e9SAndroid Build Coastguard Worker 8211*da0073e9SAndroid Build Coastguard Worker # test dtype 8212*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 8213*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64) 8214*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8215*da0073e9SAndroid Build Coastguard Worker expected_str = '''\ 8216*da0073e9SAndroid Build Coastguard Workertensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308, 8217*da0073e9SAndroid Build Coastguard Worker inf], dtype=torch.float64)''' 8218*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), expected_str) 8219*da0073e9SAndroid Build Coastguard Worker 8220*da0073e9SAndroid Build Coastguard Worker # test changing default dtype 8221*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float64): 8222*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8223*da0073e9SAndroid Build Coastguard Worker expected_str = '''\ 8224*da0073e9SAndroid Build Coastguard Workertensor([ 0.0000e+00, 
9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308, 8225*da0073e9SAndroid Build Coastguard Worker inf])''' 8226*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), expected_str) 8227*da0073e9SAndroid Build Coastguard Worker 8228*da0073e9SAndroid Build Coastguard Worker # test summary 8229*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(10000) 8230*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8231*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([0., 0., 0., ..., 0., 0., 0.])''') 8232*da0073e9SAndroid Build Coastguard Worker 8233*da0073e9SAndroid Build Coastguard Worker # test internal summary function 8234*da0073e9SAndroid Build Coastguard Worker x = torch.rand(1, 20, 5, 30) 8235*da0073e9SAndroid Build Coastguard Worker summary = torch._tensor_str.get_summarized_data(x) 8236*da0073e9SAndroid Build Coastguard Worker self.assertEqual(summary.shape, (1, 6, 5, 6)) 8237*da0073e9SAndroid Build Coastguard Worker first_and_last = [0, 1, 2, -3, -2, -1] 8238*da0073e9SAndroid Build Coastguard Worker self.assertEqual(summary, x[:, first_and_last][..., first_and_last]) 8239*da0073e9SAndroid Build Coastguard Worker 8240*da0073e9SAndroid Build Coastguard Worker # test device 8241*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 8242*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([123], device='cuda:0') 8243*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8244*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([123], device='cuda:0')''') 8245*da0073e9SAndroid Build Coastguard Worker 8246*da0073e9SAndroid Build Coastguard Worker # test changing default to cuda 8247*da0073e9SAndroid Build Coastguard Worker torch.set_default_tensor_type(torch.cuda.FloatTensor) 8248*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8249*da0073e9SAndroid Build Coastguard 
Worker self.assertExpectedInline(str(x), '''tensor([123])''') 8250*da0073e9SAndroid Build Coastguard Worker 8251*da0073e9SAndroid Build Coastguard Worker # test printing a tensor on a different gpu than current one. 8252*da0073e9SAndroid Build Coastguard Worker if torch.cuda.device_count() >= 2: 8253*da0073e9SAndroid Build Coastguard Worker with torch.cuda.device(1): 8254*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8255*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([123], device='cuda:0')''') 8256*da0073e9SAndroid Build Coastguard Worker 8257*da0073e9SAndroid Build Coastguard Worker # test printing cpu tensor when default device is cuda 8258*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([123], device='cpu') 8259*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.__repr__(), str(y)) 8260*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(y), '''tensor([123], device='cpu')''') 8261*da0073e9SAndroid Build Coastguard Worker torch.set_default_tensor_type(default_type) 8262*da0073e9SAndroid Build Coastguard Worker 8263*da0073e9SAndroid Build Coastguard Worker 8264*da0073e9SAndroid Build Coastguard Worker # test integral floats and requires_grad 8265*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([123.], requires_grad=True) 8266*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8267*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([123.], requires_grad=True)''') 8268*da0073e9SAndroid Build Coastguard Worker 8269*da0073e9SAndroid Build Coastguard Worker # test non-contiguous print 8270*da0073e9SAndroid Build Coastguard Worker # sliced tensor should have > PRINT_OPTS.threshold elements 8271*da0073e9SAndroid Build Coastguard Worker x = torch.ones(100, 2, 2, 10) 8272*da0073e9SAndroid Build Coastguard Worker y = x.as_strided(size=(100, 2, 10), stride=(2 * 2 * 10, 2 * 10, 1)) 
8273*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(y), y.__repr__()) 8274*da0073e9SAndroid Build Coastguard Worker expected_str = '''\ 8275*da0073e9SAndroid Build Coastguard Workertensor([[[1., 1., 1., ..., 1., 1., 1.], 8276*da0073e9SAndroid Build Coastguard Worker [1., 1., 1., ..., 1., 1., 1.]], 8277*da0073e9SAndroid Build Coastguard Worker 8278*da0073e9SAndroid Build Coastguard Worker [[1., 1., 1., ..., 1., 1., 1.], 8279*da0073e9SAndroid Build Coastguard Worker [1., 1., 1., ..., 1., 1., 1.]], 8280*da0073e9SAndroid Build Coastguard Worker 8281*da0073e9SAndroid Build Coastguard Worker [[1., 1., 1., ..., 1., 1., 1.], 8282*da0073e9SAndroid Build Coastguard Worker [1., 1., 1., ..., 1., 1., 1.]], 8283*da0073e9SAndroid Build Coastguard Worker 8284*da0073e9SAndroid Build Coastguard Worker ..., 8285*da0073e9SAndroid Build Coastguard Worker 8286*da0073e9SAndroid Build Coastguard Worker [[1., 1., 1., ..., 1., 1., 1.], 8287*da0073e9SAndroid Build Coastguard Worker [1., 1., 1., ..., 1., 1., 1.]], 8288*da0073e9SAndroid Build Coastguard Worker 8289*da0073e9SAndroid Build Coastguard Worker [[1., 1., 1., ..., 1., 1., 1.], 8290*da0073e9SAndroid Build Coastguard Worker [1., 1., 1., ..., 1., 1., 1.]], 8291*da0073e9SAndroid Build Coastguard Worker 8292*da0073e9SAndroid Build Coastguard Worker [[1., 1., 1., ..., 1., 1., 1.], 8293*da0073e9SAndroid Build Coastguard Worker [1., 1., 1., ..., 1., 1., 1.]]])\ 8294*da0073e9SAndroid Build Coastguard Worker''' 8295*da0073e9SAndroid Build Coastguard Worker 8296*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(y), expected_str) 8297*da0073e9SAndroid Build Coastguard Worker 8298*da0073e9SAndroid Build Coastguard Worker x = torch.ones(100, 2, 2, 10) * (1 + 1j) 8299*da0073e9SAndroid Build Coastguard Worker y = x.as_strided(size=(100, 2, 10), stride=(2 * 2 * 10, 2 * 10, 1)) 8300*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(y), y.__repr__()) 8301*da0073e9SAndroid Build Coastguard Worker 
expected_str = '''\ 8302*da0073e9SAndroid Build Coastguard Workertensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j], 8303*da0073e9SAndroid Build Coastguard Worker [1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j]], 8304*da0073e9SAndroid Build Coastguard Worker 8305*da0073e9SAndroid Build Coastguard Worker [[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j], 8306*da0073e9SAndroid Build Coastguard Worker [1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j]], 8307*da0073e9SAndroid Build Coastguard Worker 8308*da0073e9SAndroid Build Coastguard Worker [[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j], 8309*da0073e9SAndroid Build Coastguard Worker [1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j]], 8310*da0073e9SAndroid Build Coastguard Worker 8311*da0073e9SAndroid Build Coastguard Worker ..., 8312*da0073e9SAndroid Build Coastguard Worker 8313*da0073e9SAndroid Build Coastguard Worker [[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j], 8314*da0073e9SAndroid Build Coastguard Worker [1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j]], 8315*da0073e9SAndroid Build Coastguard Worker 8316*da0073e9SAndroid Build Coastguard Worker [[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j], 8317*da0073e9SAndroid Build Coastguard Worker [1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j]], 8318*da0073e9SAndroid Build Coastguard Worker 8319*da0073e9SAndroid Build Coastguard Worker [[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j], 8320*da0073e9SAndroid Build Coastguard Worker [1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j]]])\ 8321*da0073e9SAndroid Build Coastguard Worker''' 8322*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(y), expected_str) 8323*da0073e9SAndroid Build Coastguard Worker 8324*da0073e9SAndroid Build Coastguard Worker # test print 0-dim tensor: there's no 0-dim in Numpy, we match arrayprint style 8325*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(0.00002) 8326*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8327*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor(2.0000e-05)''') 8328*da0073e9SAndroid Build Coastguard Worker 8329*da0073e9SAndroid Build Coastguard Worker # test print boolean tensor 8330*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([True]) 8331*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8332*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([True])''') 8333*da0073e9SAndroid Build Coastguard Worker 8334*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(True) 8335*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8336*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor(True)''') 8337*da0073e9SAndroid Build Coastguard Worker 8338*da0073e9SAndroid Build Coastguard Worker # [Numpy] test print float in sci_mode when min < 0.0001. 8339*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.00002]) 8340*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8341*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([2.0000e-05])''') 8342*da0073e9SAndroid Build Coastguard Worker 8343*da0073e9SAndroid Build Coastguard Worker # [Numpy] test print complex in sci_mode when real_min < 0.0001 and (or) imag_min < 0.0001. 8344*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.00002]) * (1 + 1j) 8345*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8346*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([2.0000e-05+2.0000e-05j])''') 8347*da0073e9SAndroid Build Coastguard Worker 8348*da0073e9SAndroid Build Coastguard Worker # [Numpy] test print float in sci_mode when max > 1e8. 
8349*da0073e9SAndroid Build Coastguard Worker # TODO: Pytorch uses fixed precision to print, while Numpy uses dragon4_scientific 8350*da0073e9SAndroid Build Coastguard Worker # to do automatic trimming and padding. 8351*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([123456789.]) 8352*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8353*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([1.2346e+08])''') 8354*da0073e9SAndroid Build Coastguard Worker 8355*da0073e9SAndroid Build Coastguard Worker # [Numpy] test print float in sci_mode when max / min > 1000. 8356*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.01, 11]) 8357*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8358*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([1.0000e-02, 1.1000e+01])''') 8359*da0073e9SAndroid Build Coastguard Worker 8360*da0073e9SAndroid Build Coastguard Worker # [Numpy] test print int max / min > 1000, no sci_mode 8361*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 1010]) 8362*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8363*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([ 1, 1010])''') 8364*da0073e9SAndroid Build Coastguard Worker 8365*da0073e9SAndroid Build Coastguard Worker # [Numpy] test print int > 1e8, no sci_mode 8366*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1000000000]) # 1e9 8367*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8368*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([1000000000])''') 8369*da0073e9SAndroid Build Coastguard Worker 8370*da0073e9SAndroid Build Coastguard Worker # [Numpy] test printing float in int_mode 8371*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1., 1000.]) 8372*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(x.__repr__(), str(x)) 8373*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([ 1., 1000.])''') 8374*da0073e9SAndroid Build Coastguard Worker 8375*da0073e9SAndroid Build Coastguard Worker # [Numpy] test printing float in int_mode in sci format when max / min > 1000. 8376*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1., 1010.]) 8377*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.__repr__(), str(x)) 8378*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(str(x), '''tensor([1.0000e+00, 1.0100e+03])''') 8379*da0073e9SAndroid Build Coastguard Worker 8380*da0073e9SAndroid Build Coastguard Worker def test_sizeof(self) -> None: 8381*da0073e9SAndroid Build Coastguard Worker sizeof_empty = torch.randn(0).storage().__sizeof__() 8382*da0073e9SAndroid Build Coastguard Worker sizeof_10 = torch.randn(10).storage().__sizeof__() 8383*da0073e9SAndroid Build Coastguard Worker sizeof_100 = torch.randn(100).storage().__sizeof__() 8384*da0073e9SAndroid Build Coastguard Worker self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10) 8385*da0073e9SAndroid Build Coastguard Worker self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0) 8386*da0073e9SAndroid Build Coastguard Worker 8387*da0073e9SAndroid Build Coastguard Worker sizeof_empty = torch.randn(0).to(torch.uint8).storage().__sizeof__() 8388*da0073e9SAndroid Build Coastguard Worker sizeof_10 = torch.randn(10).to(torch.uint8).storage().__sizeof__() 8389*da0073e9SAndroid Build Coastguard Worker sizeof_100 = torch.randn(100).to(torch.uint8).storage().__sizeof__() 8390*da0073e9SAndroid Build Coastguard Worker self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10) 8391*da0073e9SAndroid Build Coastguard Worker self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0) 8392*da0073e9SAndroid Build Coastguard Worker 8393*da0073e9SAndroid Build Coastguard Worker 
@skipIfTorchDynamo("Not a suitable test for TorchDynamo") 8394*da0073e9SAndroid Build Coastguard Worker def test_resizable(self) -> None: 8395*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5) 8396*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.storage().resizable()) 8397*da0073e9SAndroid Build Coastguard Worker x.numpy() 8398*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x.storage().resizable()) 8399*da0073e9SAndroid Build Coastguard Worker 8400*da0073e9SAndroid Build Coastguard Worker def test_iter(self) -> None: 8401*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5) 8402*da0073e9SAndroid Build Coastguard Worker for i, sub in enumerate(x): 8403*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sub, x[i]) # noqa: PLR1736 8404*da0073e9SAndroid Build Coastguard Worker 8405*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([]) 8406*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(x), []) 8407*da0073e9SAndroid Build Coastguard Worker 8408*da0073e9SAndroid Build Coastguard Worker def test_new(self) -> None: 8409*da0073e9SAndroid Build Coastguard Worker x = torch.autograd.Variable(torch.tensor([])) 8410*da0073e9SAndroid Build Coastguard Worker y = torch.autograd.Variable(torch.randn(4, 4)) 8411*da0073e9SAndroid Build Coastguard Worker z = torch.autograd.Variable(torch.IntTensor([1, 2, 3])) 8412*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new().shape, [0]) 8413*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new(), x) 8414*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new(1, 2).shape, [1, 2]) 8415*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new(torch.Size([3, 4])).shape, [3, 4]) 8416*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new([3, 4]).shape, [2]) 8417*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new([3, 4]).tolist(), [3, 4]) 8418*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new((3, 
4)).tolist(), [3, 4]) 8419*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new([np.int32(3), np.float64(4)]).tolist(), [3, 4]) 8420*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new(np.array((3, 4))).tolist(), [3, 4]) 8421*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new([z[2], z[0] + 3]).tolist(), [3, 4]) 8422*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new(size=(3, 4)).shape, [3, 4]) 8423*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new(()).shape, [0]) 8424*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new(y.storage()).data_ptr(), y.data_ptr()) 8425*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.new(y).data_ptr(), y.data_ptr()) 8426*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(x.new(y), y) 8427*da0073e9SAndroid Build Coastguard Worker 8428*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: x.new(z)) 8429*da0073e9SAndroid Build Coastguard Worker # TypeError would be better 8430*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: x.new(z.storage())) 8431*da0073e9SAndroid Build Coastguard Worker 8432*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property") 8433*da0073e9SAndroid Build Coastguard Worker def test_pin_memory(self): 8434*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 5) 8435*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x.is_pinned()) 8436*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 8437*da0073e9SAndroid Build Coastguard Worker pinned = x.pin_memory() 8438*da0073e9SAndroid Build Coastguard Worker self.assertTrue(pinned.is_pinned()) 8439*da0073e9SAndroid Build Coastguard Worker self.assertEqual(pinned, x) 8440*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(pinned.data_ptr(), x.data_ptr()) 8441*da0073e9SAndroid Build Coastguard Worker # test that 
pin_memory on already pinned tensor has no effect 8442*da0073e9SAndroid Build Coastguard Worker self.assertIs(pinned, pinned.pin_memory()) 8443*da0073e9SAndroid Build Coastguard Worker self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr()) 8444*da0073e9SAndroid Build Coastguard Worker 8445*da0073e9SAndroid Build Coastguard Worker def test_error_msg_type_translation(self): 8446*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 8447*da0073e9SAndroid Build Coastguard Worker RuntimeError, 8448*da0073e9SAndroid Build Coastguard Worker # message includes both Double and Long 8449*da0073e9SAndroid Build Coastguard Worker '(?=.*Double)(?=.*Long)'): 8450*da0073e9SAndroid Build Coastguard Worker 8451*da0073e9SAndroid Build Coastguard Worker # Calls model with a LongTensor input but DoubleTensor weights 8452*da0073e9SAndroid Build Coastguard Worker input = torch.zeros(1, 1, 1, 6, dtype=torch.long) 8453*da0073e9SAndroid Build Coastguard Worker weight = torch.nn.Parameter(torch.zeros(1, 1, 1, 3, dtype=torch.double)) 8454*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Conv2d(1, 1, (1, 3), stride=1, padding=0, bias=False) 8455*da0073e9SAndroid Build Coastguard Worker model.weight = weight 8456*da0073e9SAndroid Build Coastguard Worker out = model(input) 8457*da0073e9SAndroid Build Coastguard Worker 8458*da0073e9SAndroid Build Coastguard Worker def test_apply(self): 8459*da0073e9SAndroid Build Coastguard Worker x = torch.arange(1, 6) 8460*da0073e9SAndroid Build Coastguard Worker res = x.clone().apply_(lambda k: k + k) 8461*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, x * 2) 8462*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: x.apply_(lambda k: "str")) 8463*da0073e9SAndroid Build Coastguard Worker 8464*da0073e9SAndroid Build Coastguard Worker def test_map(self): 8465*da0073e9SAndroid Build Coastguard Worker x = torch.autograd.Variable(torch.randn(3, 3)) 8466*da0073e9SAndroid Build Coastguard 
Worker y = torch.autograd.Variable(torch.randn(3)) 8467*da0073e9SAndroid Build Coastguard Worker res = x.clone() 8468*da0073e9SAndroid Build Coastguard Worker res.map_(y, lambda a, b: a + b) 8469*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, x + y) 8470*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(TypeError, "not callable", lambda: res.map_(y, "str")) 8471*da0073e9SAndroid Build Coastguard Worker 8472*da0073e9SAndroid Build Coastguard Worker def test_map2(self): 8473*da0073e9SAndroid Build Coastguard Worker x = torch.autograd.Variable(torch.randn(3, 3)) 8474*da0073e9SAndroid Build Coastguard Worker y = torch.autograd.Variable(torch.randn(3)) 8475*da0073e9SAndroid Build Coastguard Worker z = torch.autograd.Variable(torch.randn(1, 3)) 8476*da0073e9SAndroid Build Coastguard Worker res = x.clone() 8477*da0073e9SAndroid Build Coastguard Worker res.map2_(y, z, lambda a, b, c: a + b * c) 8478*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, x + y * z) 8479*da0073e9SAndroid Build Coastguard Worker z.requires_grad = True 8480*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 8481*da0073e9SAndroid Build Coastguard Worker RuntimeError, "requires grad", 8482*da0073e9SAndroid Build Coastguard Worker lambda: res.map2_(y, z, lambda a, b, c: a + b * c)) 8483*da0073e9SAndroid Build Coastguard Worker 8484*da0073e9SAndroid Build Coastguard Worker def test_Size(self): 8485*da0073e9SAndroid Build Coastguard Worker x = torch.Size([1, 2, 3]) 8486*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x, tuple) 8487*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[0], 1) 8488*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[1], 2) 8489*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[2], 3) 8490*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(x), 3) 8491*da0073e9SAndroid Build Coastguard Worker self.assertRaises(TypeError, lambda: torch.Size(torch.ones(3))) 
8492*da0073e9SAndroid Build Coastguard Worker 8493*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x * 2, torch.Size) 8494*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x[:-1], torch.Size) 8495*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x + x, torch.Size) 8496*da0073e9SAndroid Build Coastguard Worker 8497*da0073e9SAndroid Build Coastguard Worker def test_Size_scalar(self): 8498*da0073e9SAndroid Build Coastguard Worker three = torch.tensor(3) 8499*da0073e9SAndroid Build Coastguard Worker two = torch.tensor(2) 8500*da0073e9SAndroid Build Coastguard Worker x = torch.Size([0, 1, two, three, 4]) 8501*da0073e9SAndroid Build Coastguard Worker for i in range(1, 5): 8502*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[i], i) 8503*da0073e9SAndroid Build Coastguard Worker 8504*da0073e9SAndroid Build Coastguard Worker def test_Size_iter(self): 8505*da0073e9SAndroid Build Coastguard Worker for sizes in [iter([1, 2, 3, 4, 5]), range(1, 6)]: 8506*da0073e9SAndroid Build Coastguard Worker x = torch.Size(sizes) 8507*da0073e9SAndroid Build Coastguard Worker for i in range(0, 5): 8508*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[i], i + 1) 8509*da0073e9SAndroid Build Coastguard Worker 8510*da0073e9SAndroid Build Coastguard Worker def test_t_not_2d_error(self): 8511*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t()) 8512*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t_()) 8513*da0073e9SAndroid Build Coastguard Worker 8514*da0073e9SAndroid Build Coastguard Worker # skip this test for now as it affects all tests 8515*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(True, "flush_denormal not supported") 8516*da0073e9SAndroid Build Coastguard Worker def test_set_flush_denormal(self): 8517*da0073e9SAndroid Build Coastguard Worker tiny_float = 1e-42 8518*da0073e9SAndroid Build Coastguard Worker 
tiny_double = 1e-320 8519*da0073e9SAndroid Build Coastguard Worker float_tensor = torch.FloatTensor([1.0, tiny_float]) 8520*da0073e9SAndroid Build Coastguard Worker double_tensor = torch.DoubleTensor([1.0, tiny_float, tiny_double]) 8521*da0073e9SAndroid Build Coastguard Worker 8522*da0073e9SAndroid Build Coastguard Worker self.assertEqual(float_tensor[0], 1.0, atol=0.0, rtol=0) 8523*da0073e9SAndroid Build Coastguard Worker self.assertEqual(float_tensor[1], tiny_float, atol=tiny_float / 16, rtol=0) 8524*da0073e9SAndroid Build Coastguard Worker self.assertEqual(double_tensor[0], 1.0, atol=0.0, rtol=0) 8525*da0073e9SAndroid Build Coastguard Worker self.assertEqual(double_tensor[1], tiny_float, atol=0.0, rtol=0) 8526*da0073e9SAndroid Build Coastguard Worker self.assertEqual(double_tensor[2], tiny_double, atol=0.0, rtol=0) 8527*da0073e9SAndroid Build Coastguard Worker 8528*da0073e9SAndroid Build Coastguard Worker torch.set_flush_denormal(True) 8529*da0073e9SAndroid Build Coastguard Worker self.assertEqual(float_tensor[0], 1.0, atol=0.0, rtol=0) 8530*da0073e9SAndroid Build Coastguard Worker self.assertEqual(float_tensor[1], 0.0, atol=0.0, rtol=0) # tiny_float to zero 8531*da0073e9SAndroid Build Coastguard Worker self.assertEqual(double_tensor[0], 1.0, atol=0.0, rtol=0) 8532*da0073e9SAndroid Build Coastguard Worker # tiny_float is not converted to zero in double type 8533*da0073e9SAndroid Build Coastguard Worker self.assertEqual(double_tensor[1], tiny_float, atol=0.0, rtol=0) 8534*da0073e9SAndroid Build Coastguard Worker self.assertEqual(double_tensor[2], 0.0, atol=0.0, rtol=0) # tiny_double to zero 8535*da0073e9SAndroid Build Coastguard Worker torch.set_flush_denormal(False) 8536*da0073e9SAndroid Build Coastguard Worker 8537*da0073e9SAndroid Build Coastguard Worker def test_show_config(self): 8538*da0073e9SAndroid Build Coastguard Worker # We can't usefully test the output; just make sure this doesn't crash 8539*da0073e9SAndroid Build Coastguard Worker 
torch.__config__.show() 8540*da0073e9SAndroid Build Coastguard Worker 8541*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_FBCODE, "CXX_FLAGS is only for OSS build.") 8542*da0073e9SAndroid Build Coastguard Worker def test_cxx_flags(self): 8543*da0073e9SAndroid Build Coastguard Worker torch.__config__._cxx_flags() 8544*da0073e9SAndroid Build Coastguard Worker 8545*da0073e9SAndroid Build Coastguard Worker def test_parallel_info(self): 8546*da0073e9SAndroid Build Coastguard Worker torch.__config__.parallel_info() 8547*da0073e9SAndroid Build Coastguard Worker 8548*da0073e9SAndroid Build Coastguard Worker def test_get_cpu_capability(self): 8549*da0073e9SAndroid Build Coastguard Worker # This method is primarily exposed for torchvision's resize 8550*da0073e9SAndroid Build Coastguard Worker torch.backends.cpu.get_cpu_capability() 8551*da0073e9SAndroid Build Coastguard Worker 8552*da0073e9SAndroid Build Coastguard Worker # We have to ensure that method is torchscriptable as torchvision's resize 8553*da0073e9SAndroid Build Coastguard Worker # should be torchscriptable 8554*da0073e9SAndroid Build Coastguard Worker torch.jit.script(torch.backends.cpu.get_cpu_capability) 8555*da0073e9SAndroid Build Coastguard Worker 8556*da0073e9SAndroid Build Coastguard Worker @slowTest 8557*da0073e9SAndroid Build Coastguard Worker def test_slow_test(self): 8558*da0073e9SAndroid Build Coastguard Worker # Just a smoketest to make sure our slowTest decorator works. 
8559*da0073e9SAndroid Build Coastguard Worker pass 8560*da0073e9SAndroid Build Coastguard Worker 8561*da0073e9SAndroid Build Coastguard Worker def test_is_nonzero(self): 8562*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"): 8563*da0073e9SAndroid Build Coastguard Worker torch.tensor([]).is_nonzero() 8564*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with more than one value is ambiguous"): 8565*da0073e9SAndroid Build Coastguard Worker torch.tensor([0, 0]).is_nonzero() 8566*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.tensor(0).is_nonzero()) 8567*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.tensor(1).is_nonzero()) 8568*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.tensor([0]).is_nonzero()) 8569*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.tensor([1]).is_nonzero()) 8570*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.tensor([[0]]).is_nonzero()) 8571*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.tensor([[1]]).is_nonzero()) 8572*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.tensor(0.1).is_nonzero()) 8573*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.tensor(-0.1).is_nonzero()) 8574*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.tensor(0.0).is_nonzero()) 8575*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.tensor(True).is_nonzero()) 8576*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.tensor(False).is_nonzero()) 8577*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.tensor(0 + 0j).is_nonzero()) 8578*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.tensor(0 + 0.1j).is_nonzero()) 8579*da0073e9SAndroid Build Coastguard Worker 8580*da0073e9SAndroid Build Coastguard Worker def test_assert_async(self): 
8581*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"): 8582*da0073e9SAndroid Build Coastguard Worker torch._assert_async(torch.tensor([])) 8583*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with more than one value is ambiguous"): 8584*da0073e9SAndroid Build Coastguard Worker torch._assert_async(torch.tensor([0, 0])) 8585*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"): 8586*da0073e9SAndroid Build Coastguard Worker torch._assert_async(torch.tensor(0)) 8587*da0073e9SAndroid Build Coastguard Worker torch._assert_async(torch.tensor(1)) 8588*da0073e9SAndroid Build Coastguard Worker torch._assert_async(torch.tensor(0.1)) 8589*da0073e9SAndroid Build Coastguard Worker torch._assert_async(torch.tensor(-0.1)) 8590*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"): 8591*da0073e9SAndroid Build Coastguard Worker torch._assert_async(torch.tensor(0.0)) 8592*da0073e9SAndroid Build Coastguard Worker torch._assert_async(torch.tensor(True)) 8593*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"): 8594*da0073e9SAndroid Build Coastguard Worker torch._assert_async(torch.tensor(False)) 8595*da0073e9SAndroid Build Coastguard Worker torch._assert_async(torch.tensor(0 + 0.1j)) 8596*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"): 8597*da0073e9SAndroid Build Coastguard Worker torch._assert_async(torch.tensor(0 + 0j)) 8598*da0073e9SAndroid Build Coastguard Worker 8599*da0073e9SAndroid Build Coastguard Worker # NB: we must not be built with CUDA; if we are built with CUDA but no 
CUDA 8600*da0073e9SAndroid Build Coastguard Worker # is available, we get a different error. 8601*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(torch.backends.cuda.is_built() or IS_SANDCASTLE, "CUDA is built, can't test CUDA not built error") 8602*da0073e9SAndroid Build Coastguard Worker def test_cuda_not_built(self): 8603*da0073e9SAndroid Build Coastguard Worker msg = "Torch not compiled with CUDA enabled" 8604*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(AssertionError, msg, lambda: torch.cuda.current_device()) 8605*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1], device="cuda")) 8606*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).cuda()) 8607*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(TypeError, msg, lambda: torch.cuda.FloatTensor()) 8608*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(TypeError, msg, lambda: torch.set_default_tensor_type(torch.cuda.FloatTensor)) 8609*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).to(device="cuda")) 8610*da0073e9SAndroid Build Coastguard Worker 8611*da0073e9SAndroid Build Coastguard Worker def test_has_internal_overlap(self): 8612*da0073e9SAndroid Build Coastguard Worker OVERLAP_NO = 0 8613*da0073e9SAndroid Build Coastguard Worker OVERLAP_YES = 1 8614*da0073e9SAndroid Build Coastguard Worker OVERLAP_TOO_HARD = 2 8615*da0073e9SAndroid Build Coastguard Worker 8616*da0073e9SAndroid Build Coastguard Worker # Check for contiguous tensors 8617*da0073e9SAndroid Build Coastguard Worker a = torch.randn(3, 3) 8618*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch._debug_has_internal_overlap(a), OVERLAP_NO) 8619*da0073e9SAndroid Build Coastguard Worker 8620*da0073e9SAndroid Build Coastguard Worker # Checks for zero strides 8621*da0073e9SAndroid Build Coastguard Worker b = 
torch.randn(1, 3) 8622*da0073e9SAndroid Build Coastguard Worker b_expanded = b.expand(4, 3) 8623*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch._debug_has_internal_overlap(b_expanded), OVERLAP_YES) 8624*da0073e9SAndroid Build Coastguard Worker 8625*da0073e9SAndroid Build Coastguard Worker # Check for zero strided, size 1 axis, in non-contiguous storage (gh-33812) 8626*da0073e9SAndroid Build Coastguard Worker c = torch.randn(10).as_strided([2, 1, 5], [1, 0, 2]) 8627*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch._debug_has_internal_overlap(c), OVERLAP_NO) 8628*da0073e9SAndroid Build Coastguard Worker c = torch.randn(2, 1, 10)[::2].as_strided((2, 1, 5), (10, 0, 2)) 8629*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch._debug_has_internal_overlap(c), OVERLAP_TOO_HARD) 8630*da0073e9SAndroid Build Coastguard Worker 8631*da0073e9SAndroid Build Coastguard Worker def test_allow_tensor_metadata_change(self): 8632*da0073e9SAndroid Build Coastguard Worker a = torch.ones(2, 3) 8633*da0073e9SAndroid Build Coastguard Worker # Metadata changes are allowed on view tensors that are created from detach(). 
8634*da0073e9SAndroid Build Coastguard Worker 8635*da0073e9SAndroid Build Coastguard Worker def test_memory_format(self): 8636*da0073e9SAndroid Build Coastguard Worker def test_helper(x, memory_format): 8637*da0073e9SAndroid Build Coastguard Worker y = x.contiguous(memory_format=memory_format) 8638*da0073e9SAndroid Build Coastguard Worker self.assertFalse(y.is_contiguous()) 8639*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.is_contiguous(memory_format=memory_format)) 8640*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, x) 8641*da0073e9SAndroid Build Coastguard Worker 8642*da0073e9SAndroid Build Coastguard Worker test_helper(torch.randn(4, 3, 8, 8), torch.channels_last) 8643*da0073e9SAndroid Build Coastguard Worker test_helper(torch.randn(4, 3, 8, 8, 8), torch.channels_last_3d) 8644*da0073e9SAndroid Build Coastguard Worker 8645*da0073e9SAndroid Build Coastguard Worker def test_memory_format_contiguous_returns_same_tensor_if_already_satisfies(self): 8646*da0073e9SAndroid Build Coastguard Worker def test_helper(x, memory_format): 8647*da0073e9SAndroid Build Coastguard Worker alias = x.contiguous(memory_format=memory_format) 8648*da0073e9SAndroid Build Coastguard Worker alias.fill_(7) 8649*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, alias) 8650*da0073e9SAndroid Build Coastguard Worker 8651*da0073e9SAndroid Build Coastguard Worker test_helper(torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2), torch.channels_last) 8652*da0073e9SAndroid Build Coastguard Worker test_helper(torch.randn(4, 8, 8, 8, 3).permute(0, 4, 1, 2, 3), torch.channels_last_3d) 8653*da0073e9SAndroid Build Coastguard Worker 8654*da0073e9SAndroid Build Coastguard Worker def test_memory_format_empty(self): 8655*da0073e9SAndroid Build Coastguard Worker def test_helper(dim1, dim2, memory_format): 8656*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 8657*da0073e9SAndroid Build Coastguard Worker x = torch.empty(dim1, 
memory_format=memory_format) 8658*da0073e9SAndroid Build Coastguard Worker x = torch.empty(dim2, memory_format=memory_format) 8659*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.is_contiguous(memory_format=memory_format)) 8660*da0073e9SAndroid Build Coastguard Worker 8661*da0073e9SAndroid Build Coastguard Worker test_helper((3, 3), (3, 3, 3, 3), torch.channels_last) 8662*da0073e9SAndroid Build Coastguard Worker test_helper((3, 3, 3), (3, 3, 3, 3, 3), torch.channels_last_3d) 8663*da0073e9SAndroid Build Coastguard Worker 8664*da0073e9SAndroid Build Coastguard Worker def test_dim_order(self): 8665*da0073e9SAndroid Build Coastguard Worker shape = (2, 3, 5, 7) 8666*da0073e9SAndroid Build Coastguard Worker 8667*da0073e9SAndroid Build Coastguard Worker t = torch.empty(shape) 8668*da0073e9SAndroid Build Coastguard Worker self.assertSequenceEqual(t.dim_order(), (0, 1, 2, 3), seq_type=tuple) 8669*da0073e9SAndroid Build Coastguard Worker # transpose doesn't really change the underlying physical memory 8670*da0073e9SAndroid Build Coastguard Worker # so expecting dim_order change to reflect that (like strides) 8671*da0073e9SAndroid Build Coastguard Worker self.assertSequenceEqual(t.transpose(0, 1).dim_order(), (1, 0, 2, 3)) 8672*da0073e9SAndroid Build Coastguard Worker 8673*da0073e9SAndroid Build Coastguard Worker t = torch.empty(shape, memory_format=torch.channels_last) 8674*da0073e9SAndroid Build Coastguard Worker self.assertSequenceEqual(t.dim_order(), (0, 2, 3, 1)) 8675*da0073e9SAndroid Build Coastguard Worker 8676*da0073e9SAndroid Build Coastguard Worker t = torch.empty((2, 3, 5, 7, 8), memory_format=torch.channels_last_3d) 8677*da0073e9SAndroid Build Coastguard Worker self.assertSequenceEqual(t.dim_order(), (0, 2, 3, 4, 1)) 8678*da0073e9SAndroid Build Coastguard Worker 8679*da0073e9SAndroid Build Coastguard Worker for dim_order in itertools.permutations(range(4)): 8680*da0073e9SAndroid Build Coastguard Worker self.assertSequenceEqual( 8681*da0073e9SAndroid 
Build Coastguard Worker dim_order, torch.empty_permuted(shape, dim_order).dim_order() 8682*da0073e9SAndroid Build Coastguard Worker ) 8683*da0073e9SAndroid Build Coastguard Worker 8684*da0073e9SAndroid Build Coastguard Worker for shape in [(2, 2, 2, 2), (2, 1, 2, 2), (2, 2, 1, 2), (2, 2, 2, 1), (2, 2, 1, 1), (2, 1, 1, 2)]: 8685*da0073e9SAndroid Build Coastguard Worker for memory_format in (torch.contiguous_format, torch.channels_last): 8686*da0073e9SAndroid Build Coastguard Worker t = torch.empty(shape).to(memory_format=memory_format) 8687*da0073e9SAndroid Build Coastguard Worker if memory_format == torch.contiguous_format: 8688*da0073e9SAndroid Build Coastguard Worker dim_order_target = list(range(len(shape))) 8689*da0073e9SAndroid Build Coastguard Worker elif memory_format == torch.channels_last: 8690*da0073e9SAndroid Build Coastguard Worker dim_order_target = [0, *list(range(2, len(shape))), 1] 8691*da0073e9SAndroid Build Coastguard Worker 8692*da0073e9SAndroid Build Coastguard Worker self.assertSequenceEqual(dim_order_target, t.dim_order()) 8693*da0073e9SAndroid Build Coastguard Worker 8694*da0073e9SAndroid Build Coastguard Worker def test_subclass_tensors(self): 8695*da0073e9SAndroid Build Coastguard Worker # raise an error when trying to subclass FloatTensor 8696*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "type 'torch.FloatTensor' is not an acceptable base type"): 8697*da0073e9SAndroid Build Coastguard Worker class Foo1(torch.FloatTensor): 8698*da0073e9SAndroid Build Coastguard Worker pass 8699*da0073e9SAndroid Build Coastguard Worker 8700*da0073e9SAndroid Build Coastguard Worker # but allow subclassing Tensor: 8701*da0073e9SAndroid Build Coastguard Worker class Foo2(torch.Tensor): 8702*da0073e9SAndroid Build Coastguard Worker def foo(self): 8703*da0073e9SAndroid Build Coastguard Worker return 5 8704*da0073e9SAndroid Build Coastguard Worker f = Foo2() 8705*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f.foo(), 
5) 8706*da0073e9SAndroid Build Coastguard Worker 8707*da0073e9SAndroid Build Coastguard Worker def test_ndim(self): 8708*da0073e9SAndroid Build Coastguard Worker a = torch.randn(1, 2, 3) 8709*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3, a.ndim) 8710*da0073e9SAndroid Build Coastguard Worker b = torch.randn(()) 8711*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, b.ndim) 8712*da0073e9SAndroid Build Coastguard Worker c = torch.randn(1, 0) 8713*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, c.ndim) 8714*da0073e9SAndroid Build Coastguard Worker 8715*da0073e9SAndroid Build Coastguard Worker def test_nbytes(self): 8716*da0073e9SAndroid Build Coastguard Worker a = torch.randn(1, 2, 3, dtype=torch.float64) 8717*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.numel() * a.element_size(), a.nbytes) 8718*da0073e9SAndroid Build Coastguard Worker b = torch.randn(()) 8719*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.numel() * b.element_size(), b.nbytes) 8720*da0073e9SAndroid Build Coastguard Worker c = torch.randn(1, 0) 8721*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c.numel() * c.element_size(), c.nbytes) 8722*da0073e9SAndroid Build Coastguard Worker 8723*da0073e9SAndroid Build Coastguard Worker def test_fill_diagonal(self): 8724*da0073e9SAndroid Build Coastguard Worker a1 = torch.randn(7, 3) 8725*da0073e9SAndroid Build Coastguard Worker a2 = a1.clone() 8726*da0073e9SAndroid Build Coastguard Worker v = 1 8727*da0073e9SAndroid Build Coastguard Worker for i in range(3): 8728*da0073e9SAndroid Build Coastguard Worker a2[i][i] = v 8729*da0073e9SAndroid Build Coastguard Worker a1.fill_diagonal_(v) 8730*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a1, a2) 8731*da0073e9SAndroid Build Coastguard Worker 8732*da0073e9SAndroid Build Coastguard Worker b1 = torch.randn(7, 3) 8733*da0073e9SAndroid Build Coastguard Worker b2 = b1.clone() 8734*da0073e9SAndroid Build Coastguard Worker for i in 
range(3): 8735*da0073e9SAndroid Build Coastguard Worker b2[i][i] = v 8736*da0073e9SAndroid Build Coastguard Worker b2[i + 4][i] = v 8737*da0073e9SAndroid Build Coastguard Worker b1.fill_diagonal_(v, wrap=True) 8738*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b1, b2) 8739*da0073e9SAndroid Build Coastguard Worker 8740*da0073e9SAndroid Build Coastguard Worker c1 = torch.rand(3, 3, 3) 8741*da0073e9SAndroid Build Coastguard Worker c2 = c1.clone() 8742*da0073e9SAndroid Build Coastguard Worker for i in range(3): 8743*da0073e9SAndroid Build Coastguard Worker c2[i][i][i] = v 8744*da0073e9SAndroid Build Coastguard Worker c1.fill_diagonal_(v) 8745*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c1, c2) 8746*da0073e9SAndroid Build Coastguard Worker 8747*da0073e9SAndroid Build Coastguard Worker # non-contiguous tensor 8748*da0073e9SAndroid Build Coastguard Worker d1 = torch.rand(3, 3, 3)[:, 1, ...] 8749*da0073e9SAndroid Build Coastguard Worker d2 = d1.clone() 8750*da0073e9SAndroid Build Coastguard Worker for i in range(3): 8751*da0073e9SAndroid Build Coastguard Worker d2[i][i] = v 8752*da0073e9SAndroid Build Coastguard Worker d1.fill_diagonal_(v) 8753*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d1, d2) 8754*da0073e9SAndroid Build Coastguard Worker 8755*da0073e9SAndroid Build Coastguard Worker e1 = torch.rand(7, 3, 3)[:, 1, ...] 
8756*da0073e9SAndroid Build Coastguard Worker e2 = e1.clone() 8757*da0073e9SAndroid Build Coastguard Worker for i in range(3): 8758*da0073e9SAndroid Build Coastguard Worker e2[i][i] = v 8759*da0073e9SAndroid Build Coastguard Worker e2[i + 4][i] = v 8760*da0073e9SAndroid Build Coastguard Worker e1.fill_diagonal_(v, wrap=True) 8761*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e1, e2) 8762*da0073e9SAndroid Build Coastguard Worker 8763*da0073e9SAndroid Build Coastguard Worker def test_setting_real_imag_to_a_number(self): 8764*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, dtype=torch.cfloat) 8765*da0073e9SAndroid Build Coastguard Worker x.real = 0 8766*da0073e9SAndroid Build Coastguard Worker x.imag = 0 8767*da0073e9SAndroid Build Coastguard Worker zeros = torch.zeros(4) 8768*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.real, zeros) 8769*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.imag, zeros) 8770*da0073e9SAndroid Build Coastguard Worker 8771*da0073e9SAndroid Build Coastguard Worker def test_batch_norm_cpu_inference(self): 8772*da0073e9SAndroid Build Coastguard Worker # input nchw in (2,1,1,1), (2,2,2,2) 8773*da0073e9SAndroid Build Coastguard Worker inputs = [ 8774*da0073e9SAndroid Build Coastguard Worker torch.tensor([[[[-0.5000]]], [[[0.5000]]]]), 8775*da0073e9SAndroid Build Coastguard Worker torch.tensor([ 8776*da0073e9SAndroid Build Coastguard Worker [ 8777*da0073e9SAndroid Build Coastguard Worker [[-0.5000, 0.5000], [-1.0000, 1.0000]], 8778*da0073e9SAndroid Build Coastguard Worker [[-0.2500, -0.5000], [0.2500, 0.5000]] 8779*da0073e9SAndroid Build Coastguard Worker ], 8780*da0073e9SAndroid Build Coastguard Worker [ 8781*da0073e9SAndroid Build Coastguard Worker [[0.1000, 1.0000], [1.0000, 0.1000]], 8782*da0073e9SAndroid Build Coastguard Worker [[1.0000, 0.5000], [1.5000, -1.5000]] 8783*da0073e9SAndroid Build Coastguard Worker ]])] 8784*da0073e9SAndroid Build Coastguard Worker # output nchw in (2,1,1,1), 
(2,2,2,2) 8785*da0073e9SAndroid Build Coastguard Worker outputs = [ 8786*da0073e9SAndroid Build Coastguard Worker torch.tensor([ 8787*da0073e9SAndroid Build Coastguard Worker [[[-0.499997496604919433593750000]]], 8788*da0073e9SAndroid Build Coastguard Worker [[[0.499997496604919433593750000]]]]), 8789*da0073e9SAndroid Build Coastguard Worker torch.tensor([ 8790*da0073e9SAndroid Build Coastguard Worker [[[-0.499997496604919433593750000, 0.499997496604919433593750000], 8791*da0073e9SAndroid Build Coastguard Worker [-0.999994993209838867187500000, 0.999994993209838867187500000]], 8792*da0073e9SAndroid Build Coastguard Worker [[-0.249998748302459716796875000, -0.499997496604919433593750000], 8793*da0073e9SAndroid Build Coastguard Worker [0.249998748302459716796875000, 0.499997496604919433593750000]]], 8794*da0073e9SAndroid Build Coastguard Worker [[[0.099999502301216125488281250, 0.999994993209838867187500000], 8795*da0073e9SAndroid Build Coastguard Worker [0.999994993209838867187500000, 0.099999502301216125488281250]], 8796*da0073e9SAndroid Build Coastguard Worker [[0.999994993209838867187500000, 0.499997496604919433593750000], 8797*da0073e9SAndroid Build Coastguard Worker [1.499992489814758300781250000, -1.499992489814758300781250000]]]])] 8798*da0073e9SAndroid Build Coastguard Worker 8799*da0073e9SAndroid Build Coastguard Worker 8800*da0073e9SAndroid Build Coastguard Worker for i in range(len(inputs)): 8801*da0073e9SAndroid Build Coastguard Worker for affine in [False, True]: 8802*da0073e9SAndroid Build Coastguard Worker m = torch.nn.BatchNorm2d(inputs[i].size()[1], 1e-05, 0.1, affine=affine) 8803*da0073e9SAndroid Build Coastguard Worker m.eval() 8804*da0073e9SAndroid Build Coastguard Worker # contiguous case 8805*da0073e9SAndroid Build Coastguard Worker input1 = inputs[i].contiguous() 8806*da0073e9SAndroid Build Coastguard Worker output1 = m(input1) 8807*da0073e9SAndroid Build Coastguard Worker # non-contiguous case 8808*da0073e9SAndroid Build Coastguard Worker 
input2 = input1.permute(0, 1, 3, 2) 8809*da0073e9SAndroid Build Coastguard Worker output2 = m(input2).permute(0, 1, 3, 2) 8810*da0073e9SAndroid Build Coastguard Worker # channels last case 8811*da0073e9SAndroid Build Coastguard Worker input3 = input1.contiguous(memory_format=torch.channels_last) 8812*da0073e9SAndroid Build Coastguard Worker output3 = m(input3) 8813*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output3, outputs[i]) 8814*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output3, output1) 8815*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output3, output2) 8816*da0073e9SAndroid Build Coastguard Worker 8817*da0073e9SAndroid Build Coastguard Worker # FIXME: move these meta tests to their own test suite/class or 8818*da0073e9SAndroid Build Coastguard Worker # distribute them among the appropriate test suites for their ops 8819*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Fails after Triton update, see https://github.com/pytorch/pytorch/issues/94687") 8820*da0073e9SAndroid Build Coastguard Worker def test_empty_meta(self): 8821*da0073e9SAndroid Build Coastguard Worker x = torch.empty(2 ** 20, 2 ** 20, device='meta') 8822*da0073e9SAndroid Build Coastguard Worker y = torch.empty(2 ** 20, device='meta') 8823*da0073e9SAndroid Build Coastguard Worker z = x + y 8824*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z.size(), (2 ** 20, 2 ** 20)) 8825*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: z[0][0].item()) 8826*da0073e9SAndroid Build Coastguard Worker 8827*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Fails after Triton update, see https://github.com/pytorch/pytorch/issues/94687") 8828*da0073e9SAndroid Build Coastguard Worker def test_format_scalar_meta(self): 8829*da0073e9SAndroid Build Coastguard Worker x = torch.empty((), device='meta') 8830*da0073e9SAndroid Build Coastguard Worker self.assertEqual(format(x), repr(x)) 8831*da0073e9SAndroid Build 

    def test_upsample_nearest1d_meta(self):
        """Shape-only check of upsample_nearest1d on the meta device.

        The sizes used (2e8 x 3 x 2e8 elements) are far too large to
        materialize. Only output sizes are checked, and item() on the result
        is expected to raise RuntimeError. Both the functional interpolate()
        path and the raw out= overload are exercised.
        """
        # TODO: this test should be triggered by test_nn.py but right
        # now meta is not enabled (and even if it was, we are probably
        # missing too many meta functions to get through the test unmolested)

        # NB: Can't make the exponent too big, or it will overflow
        # signed 64-bit integer
        x = torch.empty(2 * 10 ** 8, 3, 2 * 10 ** 8, device='meta')
        z = torch.nn.functional.interpolate(x, scale_factor=2)
        self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
        self.assertRaises(RuntimeError, lambda: z[0][0][0].item())

        # TODO: the out tests cannot be triggered by test_nn.py because
        # we don't actually do out= arguments for nn functions, so there
        # is no public API by which to get the out version

        # interpolate doesn't seem to support out=
        # (not sure why passing None here doesn't work? How strange...)
        # The ATen op is called directly so that its out= overload is used; the
        # 0-element `out` is expected to be resized to the full output shape.
        z = torch.empty(0, device='meta')
        torch._C._nn.upsample_nearest1d(x, (4 * 10 ** 8,), 2, out=z)
        self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
        self.assertRaises(RuntimeError, lambda: z[0][0][0].item())

    def test_upsample_nearest2d_meta(self):
        """Check how upsample_nearest2d treats a user-supplied out= tensor on meta.

        - A correctly sized out keeps its own memory format.
        - A resized (0-element) out takes the input's memory format.
        - Mismatched dtype or device raises with the messages pinned below.
        """
        # TODO: the out tests cannot be triggered by test_nn.py because
        # we don't actually do out= arguments for nn functions, so there
        # is no public API by which to get the out version

        # Make sure we don't clobber strides of out tensor. NB: this
        # test must be done on 2d/3d, because 1d doesn't have any meaningful
        # layout support
        x = torch.empty(4, 3, 8, 8, device='meta')
        out = torch.empty(4, 3, 16, 16, device='meta', memory_format=torch.channels_last)
        torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))

        x = torch.empty(4, 3, 8, 8, device='meta', memory_format=torch.channels_last)
        out = torch.empty(4, 3, 16, 16, device='meta')
        torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
        self.assertTrue(out.is_contiguous())

        # But if resize occurs, do clobber
        x = torch.empty(4, 3, 8, 8, device='meta', memory_format=torch.channels_last)
        out = torch.empty(0, device='meta')
        torch._C._nn.upsample_nearest2d(x, (16, 16), out=out)
        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))

        # Complain if out dtype mismatch
        x = torch.empty(4, 3, 8, 8, device='meta', dtype=torch.float)
        out = torch.empty(4, 3, 16, 16, device='meta', dtype=torch.double)
        self.assertExpectedRaisesInline(
            RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
            """Expected out tensor to have dtype torch.float32 but got torch.float64 instead"""
        )

        # Complain if out device mismatch
        x = torch.empty(0, 3, 8, 8, device='meta')
        out = torch.empty(0, 3, 16, 16, device='cpu')
        # FIXME: compiling should properly error with a device mismatch.
        if not TEST_WITH_TORCHINDUCTOR:
            self.assertExpectedRaisesInline(
                RuntimeError, lambda: torch._C._nn.upsample_nearest2d(x, (16, 16), out=out),
                """Attempting to copy from device meta to device cpu, but cross-device copies are not allowed!"""
            )

    def test_add_meta_scalar(self):
        """Adding a Python scalar to a meta tensor keeps the tensor's size."""
        # From https://github.com/pytorch/pytorch/issues/53815
        x = torch.empty(2, device='meta')
        y = x + 2
        self.assertEqual(y.size(), x.size())

    def test_normal_shape(self):
        warned = False
        for device in get_all_device_types():
            tensor1 = torch.rand(1, device=device)
            tensor4 = torch.rand(4, device=device)
            tensor120 = torch.rand(120, device=device)
            tensor2145 = torch.rand(2, 1, 4, 5, device=device)
            tensor2345 = torch.rand(2, 3, 4, 5, device=device)
            tensor2345_non_contiguous = torch.rand(2, 4, 3, 5, device=device).permute(0, 2, 1, 3)
            tensor2345_channels_last = tensor2345.contiguous(memory_format=torch.channels_last)
            output2345 = torch.zeros(2, 3, 4, 5, device=device)
8914*da0073e9SAndroid Build Coastguard Worker output345 = torch.zeros(3, 4, 5, device=device) 8915*da0073e9SAndroid Build Coastguard Worker 8916*da0073e9SAndroid Build Coastguard Worker # inputs have same size 8917*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.normal(tensor2345, tensor2345).size(), (2, 3, 4, 5)) 8918*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.normal(tensor2345_non_contiguous, tensor2345).size(), (2, 3, 4, 5)) 8919*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.normal(tensor2345, tensor2345_channels_last).size(), (2, 3, 4, 5)) 8920*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.normal(tensor2345_non_contiguous, tensor2345_channels_last).size(), (2, 3, 4, 5)) 8921*da0073e9SAndroid Build Coastguard Worker 8922*da0073e9SAndroid Build Coastguard Worker # scalar case 8923*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.normal(tensor2345, 2).size(), (2, 3, 4, 5)) 8924*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.normal(2, tensor2345).size(), (2, 3, 4, 5)) 8925*da0073e9SAndroid Build Coastguard Worker 8926*da0073e9SAndroid Build Coastguard Worker # inputs are expandable tensors 8927*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.normal(tensor2345, tensor1).size(), (2, 3, 4, 5)) 8928*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.normal(tensor2145, tensor2345).size(), (2, 3, 4, 5)) 8929*da0073e9SAndroid Build Coastguard Worker 8930*da0073e9SAndroid Build Coastguard Worker # inputs are non-expandable tensors, but they have same number of elements 8931*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 8932*da0073e9SAndroid Build Coastguard Worker RuntimeError, 8933*da0073e9SAndroid Build Coastguard Worker r"The size of tensor a \(120\) must match the size of " 8934*da0073e9SAndroid Build Coastguard Worker r"tensor b \(5\) at non-singleton dimension 3"): 8935*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(torch.normal(tensor120, tensor2345).size(), (120,)) 8936*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 8937*da0073e9SAndroid Build Coastguard Worker RuntimeError, 8938*da0073e9SAndroid Build Coastguard Worker r"The size of tensor a \(5\) must match the size of " 8939*da0073e9SAndroid Build Coastguard Worker r"tensor b \(120\) at non-singleton dimension 3"): 8940*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.normal(tensor2345, tensor120).size(), (2, 3, 4, 5)) 8941*da0073e9SAndroid Build Coastguard Worker 8942*da0073e9SAndroid Build Coastguard Worker # inputs are non-expandable tensors and they don't have same number of elements 8943*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 8944*da0073e9SAndroid Build Coastguard Worker RuntimeError, 8945*da0073e9SAndroid Build Coastguard Worker r"The size of tensor a \(5\) must match the size of " 8946*da0073e9SAndroid Build Coastguard Worker r"tensor b \(4\) at non-singleton dimension 3"): 8947*da0073e9SAndroid Build Coastguard Worker torch.normal(tensor2345, tensor4) 8948*da0073e9SAndroid Build Coastguard Worker 8949*da0073e9SAndroid Build Coastguard Worker # output and inputs are size compatible 8950*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.normal(tensor2345, tensor2345, out=output2345).size(), (2, 3, 4, 5)) 8951*da0073e9SAndroid Build Coastguard Worker 8952*da0073e9SAndroid Build Coastguard Worker # output and inputs are not size compatible 8953*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 8954*da0073e9SAndroid Build Coastguard Worker UserWarning, 8955*da0073e9SAndroid Build Coastguard Worker "This behavior is deprecated, and in a future PyTorch " 8956*da0073e9SAndroid Build Coastguard Worker "release outputs will not be resized unless they have " 8957*da0073e9SAndroid Build Coastguard Worker "zero elements"): 8958*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.normal(tensor2345, 
tensor2145, out=output345).size(), (2, 3, 4, 5)) 8959*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 8960*da0073e9SAndroid Build Coastguard Worker RuntimeError, 8961*da0073e9SAndroid Build Coastguard Worker r"The size of tensor a \(5\) must match the size of " 8962*da0073e9SAndroid Build Coastguard Worker r"tensor b \(120\) at non-singleton dimension 3"): 8963*da0073e9SAndroid Build Coastguard Worker # inputs are not expandable, output size is not the same as mean 8964*da0073e9SAndroid Build Coastguard Worker torch.normal(tensor2345, tensor120, out=output345) 8965*da0073e9SAndroid Build Coastguard Worker 8966*da0073e9SAndroid Build Coastguard Worker def test_tensoriterator_output_setup(self): 8967*da0073e9SAndroid Build Coastguard Worker # Test whether the output's memory layout is correct 8968*da0073e9SAndroid Build Coastguard Worker def test_memory_layout(x, y, scale, zero_point, out): 8969*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.dim(), 4) 8970*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.size(), y.size()) 8971*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.size(), out.size()) 8972*da0073e9SAndroid Build Coastguard Worker 8973*da0073e9SAndroid Build Coastguard Worker shape = x.size() 8974*da0073e9SAndroid Build Coastguard Worker for n in range(shape[0]): 8975*da0073e9SAndroid Build Coastguard Worker for c in range(shape[1]): 8976*da0073e9SAndroid Build Coastguard Worker for h in range(shape[2]): 8977*da0073e9SAndroid Build Coastguard Worker for w in range(shape[3]): 8978*da0073e9SAndroid Build Coastguard Worker if scale is not None and zero_point is not None: 8979*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 8980*da0073e9SAndroid Build Coastguard Worker out[n][c][h][w], 8981*da0073e9SAndroid Build Coastguard Worker torch.ops.quantized.add(x[n][c][h][w], y[n][c][h][w], scale, zero_point)) 8982*da0073e9SAndroid Build Coastguard Worker else: 8983*da0073e9SAndroid Build Coastguard 
Worker self.assertEqual(out[n][c][h][w], x[n][c][h][w] + y[n][c][h][w]) 8984*da0073e9SAndroid Build Coastguard Worker 8985*da0073e9SAndroid Build Coastguard Worker xraw = torch.rand(2, 3, 4, 4) 8986*da0073e9SAndroid Build Coastguard Worker yraw = torch.rand(2, 3, 4, 4) 8987*da0073e9SAndroid Build Coastguard Worker qxraw = torch.quantize_per_tensor(xraw, 0.1, 5, torch.quint8) 8988*da0073e9SAndroid Build Coastguard Worker qyraw = torch.quantize_per_tensor(yraw, 0.1, 5, torch.quint8) 8989*da0073e9SAndroid Build Coastguard Worker 8990*da0073e9SAndroid Build Coastguard Worker # contiguous case fast setup 8991*da0073e9SAndroid Build Coastguard Worker test_memory_layout(xraw, yraw, None, None, xraw + yraw) 8992*da0073e9SAndroid Build Coastguard Worker test_memory_layout(qxraw, qyraw, 0.1, 5, torch.ops.quantized.add(qxraw, qyraw, 0.1, 5)) 8993*da0073e9SAndroid Build Coastguard Worker 8994*da0073e9SAndroid Build Coastguard Worker # channels last case fast setup 8995*da0073e9SAndroid Build Coastguard Worker x = xraw.contiguous(memory_format=torch.channels_last) 8996*da0073e9SAndroid Build Coastguard Worker y = yraw.contiguous(memory_format=torch.channels_last) 8997*da0073e9SAndroid Build Coastguard Worker test_memory_layout(x, y, None, None, x + y) 8998*da0073e9SAndroid Build Coastguard Worker qx = qxraw.contiguous(memory_format=torch.channels_last) 8999*da0073e9SAndroid Build Coastguard Worker qy = qyraw.contiguous(memory_format=torch.channels_last) 9000*da0073e9SAndroid Build Coastguard Worker test_memory_layout(qx, qy, 0.1, 5, torch.ops.quantized.add(qx, qy, 0.1, 5)) 9001*da0073e9SAndroid Build Coastguard Worker 9002*da0073e9SAndroid Build Coastguard Worker # non contiguous case fast setup (dense, non-overlapping, same shape and strides) 9003*da0073e9SAndroid Build Coastguard Worker x = xraw.permute(0, 2, 3, 1) 9004*da0073e9SAndroid Build Coastguard Worker y = yraw.permute(0, 2, 3, 1) 9005*da0073e9SAndroid Build Coastguard Worker test_memory_layout(x, y, None, None, x + 
y) 9006*da0073e9SAndroid Build Coastguard Worker qx = qxraw.permute(0, 2, 3, 1) 9007*da0073e9SAndroid Build Coastguard Worker qy = qyraw.permute(0, 2, 3, 1) 9008*da0073e9SAndroid Build Coastguard Worker test_memory_layout(qx, qy, 0.1, 5, torch.ops.quantized.add(qx, qy, 0.1, 5)) 9009*da0073e9SAndroid Build Coastguard Worker 9010*da0073e9SAndroid Build Coastguard Worker # non contiguous case fast setup (dense, non-overlapping) 9011*da0073e9SAndroid Build Coastguard Worker # input tensors have same shape and strides 9012*da0073e9SAndroid Build Coastguard Worker # output tensor have same shape as input tensors but different stride 9013*da0073e9SAndroid Build Coastguard Worker # output tensor should preserve its strides in this case 9014*da0073e9SAndroid Build Coastguard Worker x = xraw.permute(0, 2, 3, 1) 9015*da0073e9SAndroid Build Coastguard Worker y = yraw.permute(0, 2, 3, 1) 9016*da0073e9SAndroid Build Coastguard Worker out = torch.empty_like(xraw) 9017*da0073e9SAndroid Build Coastguard Worker out = out.permute(0, 3, 2, 1) 9018*da0073e9SAndroid Build Coastguard Worker expected_stride = out.stride() 9019*da0073e9SAndroid Build Coastguard Worker test_memory_layout(x, y, None, None, torch.add(x, y, out=out)) 9020*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_stride, out.stride()) 9021*da0073e9SAndroid Build Coastguard Worker 9022*da0073e9SAndroid Build Coastguard Worker # non contiguous case non fast setup 9023*da0073e9SAndroid Build Coastguard Worker x = xraw.permute(0, 2, 3, 1) 9024*da0073e9SAndroid Build Coastguard Worker y = yraw.permute(0, 3, 2, 1) 9025*da0073e9SAndroid Build Coastguard Worker test_memory_layout(x, y, None, None, x + y) 9026*da0073e9SAndroid Build Coastguard Worker qx = qxraw.permute(0, 2, 3, 1) 9027*da0073e9SAndroid Build Coastguard Worker qy = qyraw.permute(0, 3, 2, 1) 9028*da0073e9SAndroid Build Coastguard Worker test_memory_layout(qx, qy, 0.1, 5, torch.ops.quantized.add(qx, qy, 0.1, 5)) 9029*da0073e9SAndroid Build 
Coastguard Worker 9030*da0073e9SAndroid Build Coastguard Worker # Tests to make sure we still handle .data properly until it is removed 9031*da0073e9SAndroid Build Coastguard Worker def test_dot_data_use(self): 9032*da0073e9SAndroid Build Coastguard Worker # .data allows to change the Tensors types inplace, check that we still 9033*da0073e9SAndroid Build Coastguard Worker # raise a nice error. 9034*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 9035*da0073e9SAndroid Build Coastguard Worker RuntimeError, 9036*da0073e9SAndroid Build Coastguard Worker # message includes both Double and ComplexFloat 9037*da0073e9SAndroid Build Coastguard Worker '(?=.*Double)(?=.*ComplexFloat)'): 9038*da0073e9SAndroid Build Coastguard Worker 9039*da0073e9SAndroid Build Coastguard Worker # Calls model with a LongTensor input but DoubleTensor weights 9040*da0073e9SAndroid Build Coastguard Worker input = torch.randn(1, 1, 1, 6, dtype=torch.double) 9041*da0073e9SAndroid Build Coastguard Worker weight = torch.zeros(1, 1, 1, 3, dtype=torch.complex64) 9042*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Conv2d(1, 1, (1, 3), stride=1, padding=0, bias=False) 9043*da0073e9SAndroid Build Coastguard Worker model.weight.data = weight 9044*da0073e9SAndroid Build Coastguard Worker out = model(input) 9045*da0073e9SAndroid Build Coastguard Worker 9046*da0073e9SAndroid Build Coastguard Worker def test_empty_storage_view(self): 9047*da0073e9SAndroid Build Coastguard Worker # we should be able to "modify" slices of a 0-element 9048*da0073e9SAndroid Build Coastguard Worker # array without an error being raised due to 9049*da0073e9SAndroid Build Coastguard Worker # trying to resize its storage 9050*da0073e9SAndroid Build Coastguard Worker t = torch.from_numpy(np.empty((0, 4))) 9051*da0073e9SAndroid Build Coastguard Worker t[:, 1::2] *= 1 9052*da0073e9SAndroid Build Coastguard Worker 9053*da0073e9SAndroid Build Coastguard Worker def test_has_storage(self): 
9054*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(torch.tensor([]).storage()) 9055*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(torch.empty(0).storage()) 9056*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(torch.tensor([]).clone().storage()) 9057*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(torch.tensor([0, 0, 0]).nonzero().storage()) 9058*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(torch.tensor([]).new().storage()) 9059*da0073e9SAndroid Build Coastguard Worker 9060*da0073e9SAndroid Build Coastguard Worker # FIXME: Extend this test and put in a TensorProperties test class 9061*da0073e9SAndroid Build Coastguard Worker def test_numel(self): 9062*da0073e9SAndroid Build Coastguard Worker b = torch.ByteTensor(3, 100, 100) 9063*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.nelement(), 3 * 100 * 100) 9064*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.numel(), 3 * 100 * 100) 9065*da0073e9SAndroid Build Coastguard Worker 9066*da0073e9SAndroid Build Coastguard Worker # Verifies that (deep)copies of dtypes are the same objects 9067*da0073e9SAndroid Build Coastguard Worker def test_copy_dtypes(self): 9068*da0073e9SAndroid Build Coastguard Worker for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): 9069*da0073e9SAndroid Build Coastguard Worker copied_dtype = copy.deepcopy(dtype) 9070*da0073e9SAndroid Build Coastguard Worker self.assertIs(dtype, copied_dtype) 9071*da0073e9SAndroid Build Coastguard Worker 9072*da0073e9SAndroid Build Coastguard Worker def test_dtype_is_signed(self): 9073*da0073e9SAndroid Build Coastguard Worker for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.half): 9074*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dtype.is_signed, torch.is_signed(torch.tensor(0, dtype=dtype))) 9075*da0073e9SAndroid Build Coastguard Worker 9076*da0073e9SAndroid Build Coastguard Worker 
self.assertRaisesRegex(RuntimeError, 'not supported for quantized', lambda: torch.quint8.is_signed) 9077*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 'not supported for quantized', lambda: torch.qint8.is_signed) 9078*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 'not supported for quantized', lambda: torch.qint32.is_signed) 9079*da0073e9SAndroid Build Coastguard Worker 9080*da0073e9SAndroid Build Coastguard Worker # FIXME: Put the following random tests into their own test class or test suite 9081*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098") 9082*da0073e9SAndroid Build Coastguard Worker def test_RNGState(self): 9083*da0073e9SAndroid Build Coastguard Worker state = torch.get_rng_state() 9084*da0073e9SAndroid Build Coastguard Worker stateCloned = state.clone() 9085*da0073e9SAndroid Build Coastguard Worker before = torch.rand(1000) 9086*da0073e9SAndroid Build Coastguard Worker 9087*da0073e9SAndroid Build Coastguard Worker self.assertEqual(state.ne(stateCloned).long().sum(), 0, atol=0, rtol=0) 9088*da0073e9SAndroid Build Coastguard Worker 9089*da0073e9SAndroid Build Coastguard Worker torch.set_rng_state(state) 9090*da0073e9SAndroid Build Coastguard Worker after = torch.rand(1000) 9091*da0073e9SAndroid Build Coastguard Worker self.assertEqual(before, after, atol=0, rtol=0) 9092*da0073e9SAndroid Build Coastguard Worker 9093*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098") 9094*da0073e9SAndroid Build Coastguard Worker def test_RNGStateAliasing(self): 9095*da0073e9SAndroid Build Coastguard Worker # Fork the random number stream at this point 9096*da0073e9SAndroid Build Coastguard Worker gen = torch.Generator() 9097*da0073e9SAndroid Build Coastguard Worker gen.set_state(torch.get_rng_state()) 9098*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(gen.get_state(), torch.get_rng_state()) 9099*da0073e9SAndroid Build Coastguard Worker 9100*da0073e9SAndroid Build Coastguard Worker target_value = torch.rand(1000) 9101*da0073e9SAndroid Build Coastguard Worker # Dramatically alter the internal state of the main generator 9102*da0073e9SAndroid Build Coastguard Worker _ = torch.rand(100000) 9103*da0073e9SAndroid Build Coastguard Worker forked_value = torch.rand(1000, generator=gen) 9104*da0073e9SAndroid Build Coastguard Worker self.assertEqual(target_value, forked_value, atol=0, rtol=0, msg="RNG has not forked correctly.") 9105*da0073e9SAndroid Build Coastguard Worker 9106*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098") 9107*da0073e9SAndroid Build Coastguard Worker def test_RNG_after_pickle(self): 9108*da0073e9SAndroid Build Coastguard Worker torch.random.manual_seed(100) 9109*da0073e9SAndroid Build Coastguard Worker before = torch.rand(10) 9110*da0073e9SAndroid Build Coastguard Worker 9111*da0073e9SAndroid Build Coastguard Worker torch.random.manual_seed(100) 9112*da0073e9SAndroid Build Coastguard Worker buf = io.BytesIO() 9113*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor([1, 2, 3]) 9114*da0073e9SAndroid Build Coastguard Worker ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(tensor) 9115*da0073e9SAndroid Build Coastguard Worker after = torch.rand(10) 9116*da0073e9SAndroid Build Coastguard Worker 9117*da0073e9SAndroid Build Coastguard Worker self.assertEqual(before, after, atol=0, rtol=0) 9118*da0073e9SAndroid Build Coastguard Worker 9119*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098") 9120*da0073e9SAndroid Build Coastguard Worker def test_boxMullerState(self): 9121*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(123) 9122*da0073e9SAndroid Build Coastguard Worker odd_number = 101 9123*da0073e9SAndroid Build Coastguard Worker 
seeded = torch.randn(odd_number) 9124*da0073e9SAndroid Build Coastguard Worker state = torch.get_rng_state() 9125*da0073e9SAndroid Build Coastguard Worker midstream = torch.randn(odd_number) 9126*da0073e9SAndroid Build Coastguard Worker torch.set_rng_state(state) 9127*da0073e9SAndroid Build Coastguard Worker repeat_midstream = torch.randn(odd_number) 9128*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(123) 9129*da0073e9SAndroid Build Coastguard Worker reseeded = torch.randn(odd_number) 9130*da0073e9SAndroid Build Coastguard Worker self.assertEqual(midstream, repeat_midstream, atol=0, rtol=0, 9131*da0073e9SAndroid Build Coastguard Worker msg='get_rng_state/set_rng_state not generating same sequence of normally distributed numbers') 9132*da0073e9SAndroid Build Coastguard Worker self.assertEqual(seeded, reseeded, atol=0, rtol=0, 9133*da0073e9SAndroid Build Coastguard Worker msg='repeated calls to manual_seed not generating same sequence of normally distributed numbers') 9134*da0073e9SAndroid Build Coastguard Worker 9135*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("requires https://github.com/pytorch/torchdynamo/pull/1098") 9136*da0073e9SAndroid Build Coastguard Worker def test_manual_seed(self): 9137*da0073e9SAndroid Build Coastguard Worker rng_state = torch.get_rng_state() 9138*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(2) 9139*da0073e9SAndroid Build Coastguard Worker x = torch.randn(100) 9140*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.initial_seed(), 2) 9141*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(2) 9142*da0073e9SAndroid Build Coastguard Worker y = torch.randn(100) 9143*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, y) 9144*da0073e9SAndroid Build Coastguard Worker 9145*da0073e9SAndroid Build Coastguard Worker max_int64 = 0x7fff_ffff_ffff_ffff 9146*da0073e9SAndroid Build Coastguard Worker min_int64 = -max_int64 - 1 9147*da0073e9SAndroid Build Coastguard Worker 
max_uint64 = 0xffff_ffff_ffff_ffff 9148*da0073e9SAndroid Build Coastguard Worker # Check all boundary cases of valid seed value inputs 9149*da0073e9SAndroid Build Coastguard Worker test_cases = [ 9150*da0073e9SAndroid Build Coastguard Worker # (seed, expected_initial_seed) 9151*da0073e9SAndroid Build Coastguard Worker # Positive seeds should be unchanged 9152*da0073e9SAndroid Build Coastguard Worker (max_int64, max_int64), 9153*da0073e9SAndroid Build Coastguard Worker (max_int64 + 1, max_int64 + 1), 9154*da0073e9SAndroid Build Coastguard Worker (max_uint64, max_uint64), 9155*da0073e9SAndroid Build Coastguard Worker (0, 0), 9156*da0073e9SAndroid Build Coastguard Worker # Negative seeds wrap around starting from the largest seed value 9157*da0073e9SAndroid Build Coastguard Worker (-1, max_uint64), 9158*da0073e9SAndroid Build Coastguard Worker (min_int64, max_int64 + 1) 9159*da0073e9SAndroid Build Coastguard Worker ] 9160*da0073e9SAndroid Build Coastguard Worker for seed, expected_initial_seed in test_cases: 9161*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(seed) 9162*da0073e9SAndroid Build Coastguard Worker actual_initial_seed = torch.initial_seed() 9163*da0073e9SAndroid Build Coastguard Worker msg = (f"expected initial_seed() = {expected_initial_seed:x} " 9164*da0073e9SAndroid Build Coastguard Worker f"after calling manual_seed({seed:x}), but got {actual_initial_seed:x} instead") 9165*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_initial_seed, actual_initial_seed, msg=msg) 9166*da0073e9SAndroid Build Coastguard Worker for invalid_seed in [min_int64 - 1, max_uint64 + 1]: 9167*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'Overflow when unpacking long'): 9168*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(invalid_seed) 9169*da0073e9SAndroid Build Coastguard Worker 9170*da0073e9SAndroid Build Coastguard Worker torch.set_rng_state(rng_state) 9171*da0073e9SAndroid Build Coastguard Worker 
9172*da0073e9SAndroid Build Coastguard Worker # FIXME: Describe this test and port to the generic device framework in a more 9173*da0073e9SAndroid Build Coastguard Worker # appropriate test suite for the copy operation 9174*da0073e9SAndroid Build Coastguard Worker def test_copy_transpose(self): 9175*da0073e9SAndroid Build Coastguard Worker x = torch.arange(100 * 100, dtype=torch.float).reshape(100, 100).t() 9176*da0073e9SAndroid Build Coastguard Worker y = torch.empty(100, 100, dtype=torch.float) 9177*da0073e9SAndroid Build Coastguard Worker y.copy_(x) 9178*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y[:, 0], range(100)) 9179*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y[:, 40], range(4000, 4100)) 9180*da0073e9SAndroid Build Coastguard Worker 9181*da0073e9SAndroid Build Coastguard Worker y = torch.empty(100, 100, dtype=torch.double) 9182*da0073e9SAndroid Build Coastguard Worker y.copy_(x) 9183*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y[:, 0], range(100)) 9184*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y[:, 40], range(4000, 4100)) 9185*da0073e9SAndroid Build Coastguard Worker 9186*da0073e9SAndroid Build Coastguard Worker # Validates regression reported in https://github.com/pytorch/pytorch/issues/45269 9187*da0073e9SAndroid Build Coastguard Worker x = torch.arange(100 * 100).reshape(100, 100).to(dtype=torch.cfloat).t() 9188*da0073e9SAndroid Build Coastguard Worker y = torch.empty(100, 100, dtype=torch.cfloat) 9189*da0073e9SAndroid Build Coastguard Worker y.copy_(x) 9190*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y[:, 0], range(100)) 9191*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y[:, 40], range(4000, 4100)) 9192*da0073e9SAndroid Build Coastguard Worker 9193*da0073e9SAndroid Build Coastguard Worker x = torch.arange(100 * 100).reshape(100, 100).to(dtype=torch.complex32).t() 9194*da0073e9SAndroid Build Coastguard Worker y = torch.empty(100, 100, dtype=torch.complex32) 
9195*da0073e9SAndroid Build Coastguard Worker y.copy_(x) 9196*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y[:, 0], range(100)) 9197*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y[:, 40], range(4000, 4100)) 9198*da0073e9SAndroid Build Coastguard Worker 9199*da0073e9SAndroid Build Coastguard Worker # FIXME: Port to a more appropriate test suite 9200*da0073e9SAndroid Build Coastguard Worker def test_copy_broadcast(self): 9201*da0073e9SAndroid Build Coastguard Worker torch.zeros(5, 6).copy_(torch.zeros(6)) 9202*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.zeros(5, 6).copy_(torch.zeros(30))) 9203*da0073e9SAndroid Build Coastguard Worker 9204*da0073e9SAndroid Build Coastguard Worker # FIXME: Port to a more appropriate test suite 9205*da0073e9SAndroid Build Coastguard Worker # Fails with inductor (and aot_eager) because functionalization replaces copy_ with copy, 9206*da0073e9SAndroid Build Coastguard Worker # which doesn't properly error on bad inputs. 9207*da0073e9SAndroid Build Coastguard Worker def test_copy_many_to_one(self): 9208*da0073e9SAndroid Build Coastguard Worker # Testing in-place copy where it attempt to write from many memory 9209*da0073e9SAndroid Build Coastguard Worker # storage to a single storage would cause RuntimeError to be thrown 9210*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.zeros(1, 6).expand(5, 6).copy_(torch.zeros(5, 6))) 9211*da0073e9SAndroid Build Coastguard Worker 9212*da0073e9SAndroid Build Coastguard Worker def test_copy_float16(self): 9213*da0073e9SAndroid Build Coastguard Worker # Check that fbgemm code no longer reads memory out of bounds, see 9214*da0073e9SAndroid Build Coastguard Worker # copy_impl and fbgemm::Float16ToFloat_ref. 
9215*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/88543 9216*da0073e9SAndroid Build Coastguard Worker 9217*da0073e9SAndroid Build Coastguard Worker # Types to test different code paths in copy_impl. 9218*da0073e9SAndroid Build Coastguard Worker dtypes = ( 9219*da0073e9SAndroid Build Coastguard Worker # out_dtype, src_dtype 9220*da0073e9SAndroid Build Coastguard Worker (torch.float32, torch.float16), # fbgemm 9221*da0073e9SAndroid Build Coastguard Worker (torch.float16, torch.float32), # fbgemm 9222*da0073e9SAndroid Build Coastguard Worker (torch.float32, torch.float32), # TensorIterator 9223*da0073e9SAndroid Build Coastguard Worker ) 9224*da0073e9SAndroid Build Coastguard Worker 9225*da0073e9SAndroid Build Coastguard Worker cases = ( 9226*da0073e9SAndroid Build Coastguard Worker # out_shape, src_shape, is_ok 9227*da0073e9SAndroid Build Coastguard Worker # These cases used to crash with fbgemm, make sure these also raise 9228*da0073e9SAndroid Build Coastguard Worker # exceptions with TensorIterator. 9229*da0073e9SAndroid Build Coastguard Worker ((1, 2, 3), (0, 2, 3), False), # same strides, not allowed by TI 9230*da0073e9SAndroid Build Coastguard Worker ((1, 5, 6), (4, 5, 6), False), # same strides, not allowed by TI 9231*da0073e9SAndroid Build Coastguard Worker (1, (0, 2, 3), False), # different strides 9232*da0073e9SAndroid Build Coastguard Worker ((4, 5, 6), (0, 2, 3), False), # different strides 9233*da0073e9SAndroid Build Coastguard Worker ((4, 5, 6), (1, 2, 3), False), # different strides 9234*da0073e9SAndroid Build Coastguard Worker ((4, 5, 6), (6, 5, 4), False), # same numel 9235*da0073e9SAndroid Build Coastguard Worker 9236*da0073e9SAndroid Build Coastguard Worker # These cases should pass with fbgemm and TensorIterator. 
9237*da0073e9SAndroid Build Coastguard Worker ((4, 5, 6), (1, 5, 6), True), # same strides 9238*da0073e9SAndroid Build Coastguard Worker ((4, 5, 6), (4, 5, 6), True), # same strides 9239*da0073e9SAndroid Build Coastguard Worker ((0, 2, 3), 1, True), # different strides, allowed by TI 9240*da0073e9SAndroid Build Coastguard Worker ((4, 5, 6), (4, 5, 1), True), # different strides, allowed by TI 9241*da0073e9SAndroid Build Coastguard Worker ) 9242*da0073e9SAndroid Build Coastguard Worker 9243*da0073e9SAndroid Build Coastguard Worker for (out_shape, src_shape, is_ok), (out_dtype, src_dtype) in itertools.product(cases, dtypes): 9244*da0073e9SAndroid Build Coastguard Worker out = torch.zeros(out_shape, dtype=out_dtype, device=torch.device('cpu')) 9245*da0073e9SAndroid Build Coastguard Worker src = torch.ones(src_shape, dtype=src_dtype, device=torch.device('cpu')) 9246*da0073e9SAndroid Build Coastguard Worker if is_ok: 9247*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 9248*da0073e9SAndroid Build Coastguard Worker out_cuda = out.cuda() 9249*da0073e9SAndroid Build Coastguard Worker src_cuda = src.cuda() 9250*da0073e9SAndroid Build Coastguard Worker res = out.copy_(src) 9251*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 9252*da0073e9SAndroid Build Coastguard Worker res_cuda = out_cuda.copy_(src_cuda) 9253*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, res_cuda) 9254*da0073e9SAndroid Build Coastguard Worker else: 9255*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: out.copy_(src)) 9256*da0073e9SAndroid Build Coastguard Worker 9257*da0073e9SAndroid Build Coastguard Worker # FIXME: Port to a more appropriate test suite 9258*da0073e9SAndroid Build Coastguard Worker def _test_to_with_layout(self, layout): 9259*da0073e9SAndroid Build Coastguard Worker def test_copy_behavior(t, non_blocking=False): 9260*da0073e9SAndroid Build Coastguard Worker self.assertIs(t, t.to(t, 
non_blocking=non_blocking)) 9261*da0073e9SAndroid Build Coastguard Worker self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking)) 9262*da0073e9SAndroid Build Coastguard Worker self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking)) 9263*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True)) 9264*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True)) 9265*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)) 9266*da0073e9SAndroid Build Coastguard Worker 9267*da0073e9SAndroid Build Coastguard Worker devices = [t.device] 9268*da0073e9SAndroid Build Coastguard Worker if t.device.type == 'cuda': 9269*da0073e9SAndroid Build Coastguard Worker if t.device.index == -1: 9270*da0073e9SAndroid Build Coastguard Worker devices.append(f'cuda:{torch.cuda.current_device()}') 9271*da0073e9SAndroid Build Coastguard Worker elif t.device.index == torch.cuda.current_device(): 9272*da0073e9SAndroid Build Coastguard Worker devices.append('cuda') 9273*da0073e9SAndroid Build Coastguard Worker for device in devices: 9274*da0073e9SAndroid Build Coastguard Worker self.assertIs(t, t.to(device, non_blocking=non_blocking)) 9275*da0073e9SAndroid Build Coastguard Worker self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking)) 9276*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True)) 9277*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True)) 9278*da0073e9SAndroid Build Coastguard Worker 9279*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(5) 9280*da0073e9SAndroid Build Coastguard Worker if layout == torch.sparse_csr: 9281*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[0, 1, 2], [2, 0, 3]]).to_sparse_csr() 9282*da0073e9SAndroid Build Coastguard 
Worker test_copy_behavior(a) 9283*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.device, a.to('cpu').device) 9284*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.device, a.to('cpu', dtype=torch.float32).device) 9285*da0073e9SAndroid Build Coastguard Worker self.assertIs(torch.float32, a.to('cpu', dtype=torch.float32).dtype) 9286*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.device, a.to(torch.float32).device) 9287*da0073e9SAndroid Build Coastguard Worker self.assertIs(torch.float32, a.to(dtype=torch.float32).dtype) 9288*da0073e9SAndroid Build Coastguard Worker 9289*da0073e9SAndroid Build Coastguard Worker def test_data_ptr(getter): 9290*da0073e9SAndroid Build Coastguard Worker self.assertEqual(getter(a), getter(a.to('cpu'))) 9291*da0073e9SAndroid Build Coastguard Worker self.assertEqual(getter(a), getter(a.to(dtype=a.dtype, device=a.device, copy=False))) 9292*da0073e9SAndroid Build Coastguard Worker self.assertEqual(getter(a), getter(a.to('cpu', copy=False))) 9293*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(getter(a), getter(a.to('cpu', copy=True))) 9294*da0073e9SAndroid Build Coastguard Worker if layout == torch.sparse_csr: 9295*da0073e9SAndroid Build Coastguard Worker # TODO: compressed sparse tensors currently don't support data_ptr. 9296*da0073e9SAndroid Build Coastguard Worker # Exercising failure will allow us to widen coverage of this test once it does. 9297*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer of Tensor that doesn't have storage"): 9298*da0073e9SAndroid Build Coastguard Worker a.data_ptr() 9299*da0073e9SAndroid Build Coastguard Worker # While compressed sparse tensors don't have a concept of data_ptr 9300*da0073e9SAndroid Build Coastguard Worker # the underlying tensors do. The implementation of to appropriately forwards 9301*da0073e9SAndroid Build Coastguard Worker # the call to the components, which is what we're test here. 
9302*da0073e9SAndroid Build Coastguard Worker test_data_ptr(lambda a: a.values().data_ptr()) 9303*da0073e9SAndroid Build Coastguard Worker test_data_ptr(lambda a: a.crow_indices().data_ptr()) 9304*da0073e9SAndroid Build Coastguard Worker test_data_ptr(lambda a: a.col_indices().data_ptr()) 9305*da0073e9SAndroid Build Coastguard Worker else: 9306*da0073e9SAndroid Build Coastguard Worker test_data_ptr(lambda a: a.data_ptr()) 9307*da0073e9SAndroid Build Coastguard Worker 9308*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 9309*da0073e9SAndroid Build Coastguard Worker for non_blocking in [True, False]: 9310*da0073e9SAndroid Build Coastguard Worker for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']: 9311*da0073e9SAndroid Build Coastguard Worker b = torch.tensor(5., device=cuda) 9312*da0073e9SAndroid Build Coastguard Worker test_copy_behavior(b, non_blocking) 9313*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.device, b.to(cuda, non_blocking=non_blocking).device) 9314*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.device, b.to('cpu', non_blocking=non_blocking).device) 9315*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.device, a.to(cuda, non_blocking=non_blocking).device) 9316*da0073e9SAndroid Build Coastguard Worker self.assertIs(torch.int32, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype) 9317*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.device, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).device) 9318*da0073e9SAndroid Build Coastguard Worker self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype) 9319*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.device, b.to(dtype=torch.int32).device) 9320*da0073e9SAndroid Build Coastguard Worker 9321*da0073e9SAndroid Build Coastguard Worker def test_to(self): 9322*da0073e9SAndroid Build Coastguard Worker self._test_to_with_layout(torch.strided) 9323*da0073e9SAndroid Build Coastguard 
Worker is_cuda10_2_or_higher = ( 9324*da0073e9SAndroid Build Coastguard Worker (torch.version.cuda is not None) 9325*da0073e9SAndroid Build Coastguard Worker and ([int(x) for x in torch.version.cuda.split(".")] >= [10, 2])) 9326*da0073e9SAndroid Build Coastguard Worker if is_cuda10_2_or_higher: # in cuda10_1 sparse_csr is beta 9327*da0073e9SAndroid Build Coastguard Worker self._test_to_with_layout(torch.sparse_csr) 9328*da0073e9SAndroid Build Coastguard Worker 9329*da0073e9SAndroid Build Coastguard Worker # FIXME: describe this test 9330*da0073e9SAndroid Build Coastguard Worker def test_as_subclass(self): 9331*da0073e9SAndroid Build Coastguard Worker class SubTensor(torch.Tensor): 9332*da0073e9SAndroid Build Coastguard Worker member_var = object() 9333*da0073e9SAndroid Build Coastguard Worker 9334*da0073e9SAndroid Build Coastguard Worker t0 = torch.tensor(0) 9335*da0073e9SAndroid Build Coastguard Worker t1 = torch.tensor([1, 2]) 9336*da0073e9SAndroid Build Coastguard Worker t2 = torch.tensor([[3, 4], [5, 6]]) 9337*da0073e9SAndroid Build Coastguard Worker 9338*da0073e9SAndroid Build Coastguard Worker s0 = t0.as_subclass(SubTensor) 9339*da0073e9SAndroid Build Coastguard Worker s1 = t1.as_subclass(SubTensor) 9340*da0073e9SAndroid Build Coastguard Worker s2 = t2.as_subclass(SubTensor) 9341*da0073e9SAndroid Build Coastguard Worker 9342*da0073e9SAndroid Build Coastguard Worker # Check that the correct type is returned. 9343*da0073e9SAndroid Build Coastguard Worker self.assertTrue(type(s0) is SubTensor) 9344*da0073e9SAndroid Build Coastguard Worker self.assertTrue(type(s1) is SubTensor) 9345*da0073e9SAndroid Build Coastguard Worker self.assertTrue(type(s2) is SubTensor) 9346*da0073e9SAndroid Build Coastguard Worker 9347*da0073e9SAndroid Build Coastguard Worker # Check that the data is equal. 
9348*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t0, s0) 9349*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1, s1) 9350*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t2, s2) 9351*da0073e9SAndroid Build Coastguard Worker 9352*da0073e9SAndroid Build Coastguard Worker t0[()] = 1 9353*da0073e9SAndroid Build Coastguard Worker t1[1] = 3 9354*da0073e9SAndroid Build Coastguard Worker t2[1, 1] = 7 9355*da0073e9SAndroid Build Coastguard Worker 9356*da0073e9SAndroid Build Coastguard Worker # Check that the data is equal even after modification. 9357*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t0, s0) 9358*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1, s1) 9359*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t2, s2) 9360*da0073e9SAndroid Build Coastguard Worker 9361*da0073e9SAndroid Build Coastguard Worker # Check that member variables are passed through. 9362*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s0.member_var is SubTensor.member_var) 9363*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s1.member_var is SubTensor.member_var) 9364*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s2.member_var is SubTensor.member_var) 9365*da0073e9SAndroid Build Coastguard Worker 9366*da0073e9SAndroid Build Coastguard Worker # Test that autograd is propagated. 9367*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(5, dtype=torch.float32, requires_grad=True) 9368*da0073e9SAndroid Build Coastguard Worker 9369*da0073e9SAndroid Build Coastguard Worker # Run a calculation on the tensor. 9370*da0073e9SAndroid Build Coastguard Worker exp_t = torch.exp(t) 9371*da0073e9SAndroid Build Coastguard Worker 9372*da0073e9SAndroid Build Coastguard Worker # Cast exp_t to a subclass. 
9373*da0073e9SAndroid Build Coastguard Worker exp_s = exp_t.as_subclass(SubTensor) 9374*da0073e9SAndroid Build Coastguard Worker 9375*da0073e9SAndroid Build Coastguard Worker # Make sure that t.grad was initially None 9376*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.grad is None) 9377*da0073e9SAndroid Build Coastguard Worker 9378*da0073e9SAndroid Build Coastguard Worker # Run the autograd calculation. 9379*da0073e9SAndroid Build Coastguard Worker exp_s.backward() 9380*da0073e9SAndroid Build Coastguard Worker 9381*da0073e9SAndroid Build Coastguard Worker # Make sure autograd was propagated to the original tensor 9382*da0073e9SAndroid Build Coastguard Worker # declared with requires_grad. 9383*da0073e9SAndroid Build Coastguard Worker self.assertTrue(t.grad is not None) 9384*da0073e9SAndroid Build Coastguard Worker 9385*da0073e9SAndroid Build Coastguard Worker # Make sure invalid subclasses raise nice errors 9386*da0073e9SAndroid Build Coastguard Worker class BadSubTensor: 9387*da0073e9SAndroid Build Coastguard Worker member_var = object() 9388*da0073e9SAndroid Build Coastguard Worker 9389*da0073e9SAndroid Build Coastguard Worker err_msg = "Creating a Tensor subclass from a class that does not inherit from Tensor" 9390*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 9391*da0073e9SAndroid Build Coastguard Worker s0 = t0.as_subclass(BadSubTensor) 9392*da0073e9SAndroid Build Coastguard Worker 9393*da0073e9SAndroid Build Coastguard Worker # FIXME: Port to a test suite that better fits slicing 9394*da0073e9SAndroid Build Coastguard Worker def test_slice(self): 9395*da0073e9SAndroid Build Coastguard Worker empty = torch.empty(0, 4) 9396*da0073e9SAndroid Build Coastguard Worker x = torch.arange(0., 16).view(4, 4) 9397*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[:], x) 9398*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[:4], x) 9399*da0073e9SAndroid Build Coastguard Worker # start and 
stop are clamped to the size of dim 9400*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[:5], x) 9401*da0073e9SAndroid Build Coastguard Worker # if start >= stop then the result is empty 9402*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[2:1], empty) 9403*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[2:2], empty) 9404*da0073e9SAndroid Build Coastguard Worker # out of bounds is also empty 9405*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[10:12], empty) 9406*da0073e9SAndroid Build Coastguard Worker # additional correctness checks 9407*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[:1].tolist(), [[0, 1, 2, 3]]) 9408*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[:-3].tolist(), [[0, 1, 2, 3]]) 9409*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[:, -2:3].tolist(), [[2], [6], [10], [14]]) 9410*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x[0:-1:2].tolist(), [[0, 1, 2, 3], [8, 9, 10, 11]]) 9411*da0073e9SAndroid Build Coastguard Worker 9412*da0073e9SAndroid Build Coastguard Worker def test_split_with_sizes_copy_out(self): 9413*da0073e9SAndroid Build Coastguard Worker device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 9414*da0073e9SAndroid Build Coastguard Worker shape = (30, 40, 50) 9415*da0073e9SAndroid Build Coastguard Worker x = torch.rand(*shape, device=device) 9416*da0073e9SAndroid Build Coastguard Worker cases = [ 9417*da0073e9SAndroid Build Coastguard Worker (0, [3, 7, 8, 12]), 9418*da0073e9SAndroid Build Coastguard Worker (1, [3, 7, 10, 20]), 9419*da0073e9SAndroid Build Coastguard Worker (-2, [3, 7, 10, 20]), 9420*da0073e9SAndroid Build Coastguard Worker (2, [3, 7, 10, 12, 18]), 9421*da0073e9SAndroid Build Coastguard Worker (-1, [3, 7, 10, 12, 18]), 9422*da0073e9SAndroid Build Coastguard Worker (2, [3, 7, 10, 0, 30]), 9423*da0073e9SAndroid Build Coastguard Worker ] 9424*da0073e9SAndroid Build Coastguard Worker for dim, 
split_sizes in cases: 9425*da0073e9SAndroid Build Coastguard Worker views = x.split_with_sizes(split_sizes, dim=dim) 9426*da0073e9SAndroid Build Coastguard Worker expects = [v.clone() for v in views] 9427*da0073e9SAndroid Build Coastguard Worker out = [torch.zeros_like(v) for v in views] 9428*da0073e9SAndroid Build Coastguard Worker for expect, t in zip(expects, out): 9429*da0073e9SAndroid Build Coastguard Worker if expect.numel() != 0: 9430*da0073e9SAndroid Build Coastguard Worker self.assertFalse(expect.eq(t).all().item()) 9431*da0073e9SAndroid Build Coastguard Worker 9432*da0073e9SAndroid Build Coastguard Worker torch.split_with_sizes_copy(x, split_sizes, dim=dim, out=out) 9433*da0073e9SAndroid Build Coastguard Worker for expect, t in zip(expects, out): 9434*da0073e9SAndroid Build Coastguard Worker self.assertTrue(expect.eq(t).all().item()) 9435*da0073e9SAndroid Build Coastguard Worker 9436*da0073e9SAndroid Build Coastguard Worker if not torch.cuda.is_available(): 9437*da0073e9SAndroid Build Coastguard Worker continue 9438*da0073e9SAndroid Build Coastguard Worker 9439*da0073e9SAndroid Build Coastguard Worker # Test with cuda graph 9440*da0073e9SAndroid Build Coastguard Worker out = [torch.zeros_like(v) for v in views] 9441*da0073e9SAndroid Build Coastguard Worker for expect, t in zip(expects, out): 9442*da0073e9SAndroid Build Coastguard Worker if expect.numel() != 0: 9443*da0073e9SAndroid Build Coastguard Worker self.assertFalse(expect.eq(t).all().item()) 9444*da0073e9SAndroid Build Coastguard Worker 9445*da0073e9SAndroid Build Coastguard Worker g = torch.cuda.CUDAGraph() 9446*da0073e9SAndroid Build Coastguard Worker with torch.cuda.graph(g): 9447*da0073e9SAndroid Build Coastguard Worker torch.split_with_sizes_copy(x, split_sizes, dim=dim, out=out) 9448*da0073e9SAndroid Build Coastguard Worker 9449*da0073e9SAndroid Build Coastguard Worker g.replay() 9450*da0073e9SAndroid Build Coastguard Worker for expect, t in zip(expects, out): 9451*da0073e9SAndroid Build 
Coastguard Worker self.assertTrue(expect.eq(t).all().item()) 9452*da0073e9SAndroid Build Coastguard Worker 9453*da0073e9SAndroid Build Coastguard Worker def test_type(self): 9454*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3).double() 9455*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.type('torch.FloatTensor').dtype, torch.float32) 9456*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.type(torch.FloatTensor).dtype, torch.float32) 9457*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.int().type(torch.Tensor).dtype, torch.get_default_dtype()) 9458*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.type(torch.int32).dtype, torch.int32) 9459*da0073e9SAndroid Build Coastguard Worker 9460*da0073e9SAndroid Build Coastguard Worker # FIXME: port to a quantization test suite 9461*da0073e9SAndroid Build Coastguard Worker def test_qengine(self): 9462*da0073e9SAndroid Build Coastguard Worker qengines = torch.backends.quantized.supported_engines 9463*da0073e9SAndroid Build Coastguard Worker original_qe = torch.backends.quantized.engine 9464*da0073e9SAndroid Build Coastguard Worker for qe in qengines: 9465*da0073e9SAndroid Build Coastguard Worker torch.backends.quantized.engine = qe 9466*da0073e9SAndroid Build Coastguard Worker assert torch.backends.quantized.engine == qe, 'qengine not set successfully' 9467*da0073e9SAndroid Build Coastguard Worker torch.backends.quantized.engine = original_qe 9468*da0073e9SAndroid Build Coastguard Worker 9469*da0073e9SAndroid Build Coastguard Worker def test_terminate_handler_on_crash(self): 9470*da0073e9SAndroid Build Coastguard Worker cmd = [sys.executable, '-c', "import os; os.environ[\"TORCH_CUSTOM_TERMINATE\"] ='1'; \ 9471*da0073e9SAndroid Build Coastguard Worker import torch; import torch._C; torch._C._abort()"] 9472*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(subprocess.CalledProcessError) as cm: 9473*da0073e9SAndroid Build Coastguard Worker 
subprocess.check_output(cmd, shell=False) 9474*da0073e9SAndroid Build Coastguard Worker e = cm.exception 9475*da0073e9SAndroid Build Coastguard Worker output = e.stdout.decode("utf-8") 9476*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(e.returncode, 0) 9477*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(output, None) 9478*da0073e9SAndroid Build Coastguard Worker self.assertIn('Unhandled exception caught in c10/util/AbortHandler.h', output) 9479*da0073e9SAndroid Build Coastguard Worker 9480*da0073e9SAndroid Build Coastguard Worker # FIXME: port to a distributed test suite -- also... how could this be OOMing on Windows CUDA? 9481*da0073e9SAndroid Build Coastguard Worker @slowTest 9482*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \ 9483*da0073e9SAndroid Build Coastguard Worker don't support multiprocessing with spawn start method") 9484*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, 'FIXME: CUDA OOM error on Windows') 9485*da0073e9SAndroid Build Coastguard Worker def test_multinomial_invalid_probs(self): 9486*da0073e9SAndroid Build Coastguard Worker def _spawn_method(self, method, arg): 9487*da0073e9SAndroid Build Coastguard Worker try: 9488*da0073e9SAndroid Build Coastguard Worker mp.set_start_method('spawn') 9489*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 9490*da0073e9SAndroid Build Coastguard Worker pass 9491*da0073e9SAndroid Build Coastguard Worker with mp.Pool(1) as pool: 9492*da0073e9SAndroid Build Coastguard Worker out = pool.map(method, [arg]) 9493*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out[0]) 9494*da0073e9SAndroid Build Coastguard Worker 9495*da0073e9SAndroid Build Coastguard Worker def _test_multinomial_invalid_probs(probs): 9496*da0073e9SAndroid Build Coastguard Worker try: 9497*da0073e9SAndroid Build Coastguard Worker # n_sample = 1 is a special case, test n_sample=2 which is more general 
9498*da0073e9SAndroid Build Coastguard Worker torch.multinomial(probs.to('cpu'), 2) 9499*da0073e9SAndroid Build Coastguard Worker return False # Should not be reached 9500*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 9501*da0073e9SAndroid Build Coastguard Worker return 'probability tensor contains either `inf`, `nan` or element < 0' in str(e) 9502*da0073e9SAndroid Build Coastguard Worker 9503*da0073e9SAndroid Build Coastguard Worker _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., -1., 1.])) 9504*da0073e9SAndroid Build Coastguard Worker _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., inf, 1.])) 9505*da0073e9SAndroid Build Coastguard Worker _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., -inf, 1.])) 9506*da0073e9SAndroid Build Coastguard Worker _spawn_method(_test_multinomial_invalid_probs, torch.tensor([1., 1., nan])) 9507*da0073e9SAndroid Build Coastguard Worker 9508*da0073e9SAndroid Build Coastguard Worker # FIXME: port to more appropriate test suite 9509*da0073e9SAndroid Build Coastguard Worker def test_to_with_tensor(self): 9510*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(5) 9511*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.device, a.to(a).device) 9512*da0073e9SAndroid Build Coastguard Worker 9513*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 9514*da0073e9SAndroid Build Coastguard Worker for non_blocking in [True, False]: 9515*da0073e9SAndroid Build Coastguard Worker for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']: 9516*da0073e9SAndroid Build Coastguard Worker b = torch.tensor(5., device=cuda) 9517*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.device, b.to(b, non_blocking=non_blocking).device) 9518*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.device, b.to(a, non_blocking=non_blocking).device) 9519*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.device, a.to(b, 
non_blocking=non_blocking).device) 9520*da0073e9SAndroid Build Coastguard Worker 9521*da0073e9SAndroid Build Coastguard Worker def test_device(self): 9522*da0073e9SAndroid Build Coastguard Worker cpu = torch.device('cpu') 9523*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cpu', str(cpu)) 9524*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cpu', cpu.type) 9525*da0073e9SAndroid Build Coastguard Worker self.assertEqual(None, cpu.index) 9526*da0073e9SAndroid Build Coastguard Worker 9527*da0073e9SAndroid Build Coastguard Worker cpu0 = torch.device('cpu:0') 9528*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cpu:0', str(cpu0)) 9529*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cpu', cpu0.type) 9530*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, cpu0.index) 9531*da0073e9SAndroid Build Coastguard Worker 9532*da0073e9SAndroid Build Coastguard Worker cpu0 = torch.device('cpu', 0) 9533*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cpu:0', str(cpu0)) 9534*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cpu', cpu0.type) 9535*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, cpu0.index) 9536*da0073e9SAndroid Build Coastguard Worker 9537*da0073e9SAndroid Build Coastguard Worker cuda = torch.device('cuda') 9538*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cuda', str(cuda)) 9539*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cuda', cuda.type) 9540*da0073e9SAndroid Build Coastguard Worker self.assertEqual(None, cuda.index) 9541*da0073e9SAndroid Build Coastguard Worker 9542*da0073e9SAndroid Build Coastguard Worker cuda1 = torch.device('cuda:1') 9543*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cuda:1', str(cuda1)) 9544*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cuda', cuda1.type) 9545*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, cuda1.index) 9546*da0073e9SAndroid Build Coastguard Worker 9547*da0073e9SAndroid Build 
Coastguard Worker cuda1 = torch.device('cuda', 1) 9548*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cuda:1', str(cuda1)) 9549*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cuda', cuda1.type) 9550*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, cuda1.index) 9551*da0073e9SAndroid Build Coastguard Worker 9552*da0073e9SAndroid Build Coastguard Worker cuda90 = torch.device('cuda', 90) 9553*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cuda:90', str(cuda90)) 9554*da0073e9SAndroid Build Coastguard Worker self.assertEqual('cuda', cuda90.type) 9555*da0073e9SAndroid Build Coastguard Worker self.assertEqual(90, cuda90.index) 9556*da0073e9SAndroid Build Coastguard Worker 9557*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('cpu:-1')) 9558*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('cuda:-1')) 9559*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 ')) 9560*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('cuda: 2')) 9561*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 2')) 9562*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('cuda:2.')) 9563*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('cuda:2?')) 9564*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('cuda:?2')) 9565*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('cuda:')) 9566*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('cuda:2.232')) 9567*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('cuda:2 cuda:3')) 9568*da0073e9SAndroid Build Coastguard Worker 
self.assertRaises(RuntimeError, lambda: torch.device('cuda:2+cuda:3')) 9569*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('cuda:2cuda:3')) 9570*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device(-1)) 9571*da0073e9SAndroid Build Coastguard Worker 9572*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('other')) 9573*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.device('other:0')) 9574*da0073e9SAndroid Build Coastguard Worker 9575*da0073e9SAndroid Build Coastguard Worker device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'} 9576*da0073e9SAndroid Build Coastguard Worker device_hash_set = set() 9577*da0073e9SAndroid Build Coastguard Worker device_hash_set.update(hash(torch.device(device)) for device in device_set) 9578*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(device_set), len(device_hash_set)) 9579*da0073e9SAndroid Build Coastguard Worker 9580*da0073e9SAndroid Build Coastguard Worker def get_expected_device_repr(device): 9581*da0073e9SAndroid Build Coastguard Worker if device.index is not None: 9582*da0073e9SAndroid Build Coastguard Worker return f"device(type='{device.type}', index={device.index})" 9583*da0073e9SAndroid Build Coastguard Worker 9584*da0073e9SAndroid Build Coastguard Worker return f"device(type='{device.type}')" 9585*da0073e9SAndroid Build Coastguard Worker 9586*da0073e9SAndroid Build Coastguard Worker for device in device_set: 9587*da0073e9SAndroid Build Coastguard Worker dev = torch.device(device) 9588*da0073e9SAndroid Build Coastguard Worker self.assertEqual(repr(dev), get_expected_device_repr(dev)) 9589*da0073e9SAndroid Build Coastguard Worker 9590*da0073e9SAndroid Build Coastguard Worker # Tests that the use_deterministic_flag can be set as expected 9591*da0073e9SAndroid Build Coastguard Worker @wrapDeterministicFlagAPITest 
9592*da0073e9SAndroid Build Coastguard Worker def test_deterministic_flag(self): 9593*da0073e9SAndroid Build Coastguard Worker for deterministic, warn_only in product([True, False], [True, False]): 9594*da0073e9SAndroid Build Coastguard Worker torch.use_deterministic_algorithms(deterministic, warn_only=warn_only) 9595*da0073e9SAndroid Build Coastguard Worker self.assertEqual(deterministic, torch.are_deterministic_algorithms_enabled()) 9596*da0073e9SAndroid Build Coastguard Worker self.assertEqual(warn_only, torch.is_deterministic_algorithms_warn_only_enabled()) 9597*da0073e9SAndroid Build Coastguard Worker 9598*da0073e9SAndroid Build Coastguard Worker if deterministic: 9599*da0073e9SAndroid Build Coastguard Worker if warn_only: 9600*da0073e9SAndroid Build Coastguard Worker debug_mode = 1 9601*da0073e9SAndroid Build Coastguard Worker else: 9602*da0073e9SAndroid Build Coastguard Worker debug_mode = 2 9603*da0073e9SAndroid Build Coastguard Worker else: 9604*da0073e9SAndroid Build Coastguard Worker debug_mode = 0 9605*da0073e9SAndroid Build Coastguard Worker 9606*da0073e9SAndroid Build Coastguard Worker self.assertEqual(debug_mode, torch.get_deterministic_debug_mode()) 9607*da0073e9SAndroid Build Coastguard Worker 9608*da0073e9SAndroid Build Coastguard Worker for debug_mode in [0, 1, 2]: 9609*da0073e9SAndroid Build Coastguard Worker torch.set_deterministic_debug_mode(debug_mode) 9610*da0073e9SAndroid Build Coastguard Worker self.assertEqual(debug_mode, torch.get_deterministic_debug_mode()) 9611*da0073e9SAndroid Build Coastguard Worker deterministic = debug_mode in [1, 2] 9612*da0073e9SAndroid Build Coastguard Worker warn_only = debug_mode == 1 9613*da0073e9SAndroid Build Coastguard Worker 9614*da0073e9SAndroid Build Coastguard Worker self.assertEqual(deterministic, torch.are_deterministic_algorithms_enabled()) 9615*da0073e9SAndroid Build Coastguard Worker self.assertEqual(warn_only, torch.is_deterministic_algorithms_warn_only_enabled()) 9616*da0073e9SAndroid Build 
Coastguard Worker 9617*da0073e9SAndroid Build Coastguard Worker for debug_mode, debug_mode_str in [(0, 'default'), (1, 'warn'), (2, 'error')]: 9618*da0073e9SAndroid Build Coastguard Worker torch.set_deterministic_debug_mode(debug_mode_str) 9619*da0073e9SAndroid Build Coastguard Worker self.assertEqual(debug_mode, torch.get_deterministic_debug_mode()) 9620*da0073e9SAndroid Build Coastguard Worker 9621*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 9622*da0073e9SAndroid Build Coastguard Worker TypeError, 9623*da0073e9SAndroid Build Coastguard Worker r"_set_deterministic_algorithms\(\): argument 'mode' \(position 1\) must be bool, not int"): 9624*da0073e9SAndroid Build Coastguard Worker torch.use_deterministic_algorithms(1) 9625*da0073e9SAndroid Build Coastguard Worker 9626*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 9627*da0073e9SAndroid Build Coastguard Worker TypeError, 9628*da0073e9SAndroid Build Coastguard Worker r"_set_deterministic_algorithms\(\): argument 'warn_only' must be bool, not int"): 9629*da0073e9SAndroid Build Coastguard Worker torch.use_deterministic_algorithms(False, warn_only=1) 9630*da0073e9SAndroid Build Coastguard Worker 9631*da0073e9SAndroid Build Coastguard Worker # Tests that torch.utils.deterministic.fill_uninitialized_memory can be set as expected 9632*da0073e9SAndroid Build Coastguard Worker def test_deterministic_fill_uninitialized_memory(self): 9633*da0073e9SAndroid Build Coastguard Worker with DeterministicGuard(True, fill_uninitialized_memory=False): 9634*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory) 9635*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory()) 9636*da0073e9SAndroid Build Coastguard Worker 9637*da0073e9SAndroid Build Coastguard Worker with DeterministicGuard(True, fill_uninitialized_memory=True): 9638*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory) 9639*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory()) 9640*da0073e9SAndroid Build Coastguard Worker 9641*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory) 9642*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory()) 9643*da0073e9SAndroid Build Coastguard Worker 9644*da0073e9SAndroid Build Coastguard Worker torch.utils.deterministic.fill_uninitialized_memory = False 9645*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory) 9646*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory()) 9647*da0073e9SAndroid Build Coastguard Worker 9648*da0073e9SAndroid Build Coastguard Worker torch.utils.deterministic.fill_uninitialized_memory = True 9649*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory) 9650*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory()) 9651*da0073e9SAndroid Build Coastguard Worker 9652*da0073e9SAndroid Build Coastguard Worker torch._C._set_deterministic_fill_uninitialized_memory(False) 9653*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory) 9654*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory()) 9655*da0073e9SAndroid Build Coastguard Worker 9656*da0073e9SAndroid Build Coastguard Worker torch._C._set_deterministic_fill_uninitialized_memory(True) 9657*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory) 9658*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory()) 9659*da0073e9SAndroid Build Coastguard Worker 9660*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"expected a bool, but got int"): 9661*da0073e9SAndroid Build Coastguard Worker torch.utils.deterministic.fill_uninitialized_memory = 1 9662*da0073e9SAndroid Build Coastguard Worker 9663*da0073e9SAndroid Build Coastguard Worker def test_type_conversion_via_dtype_name(self): 9664*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1]) 9665*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.byte().dtype, torch.uint8) 9666*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.bool().dtype, torch.bool) 9667*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.char().dtype, torch.int8) 9668*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.double().dtype, torch.float64) 9669*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.float().dtype, torch.float32) 9670*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.half().dtype, torch.float16) 9671*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.int().dtype, torch.int32) 9672*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.bfloat16().dtype, torch.bfloat16) 9673*da0073e9SAndroid Build Coastguard Worker cfloat = x.cfloat() 9674*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cfloat.dtype, torch.complex64) 9675*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cfloat.real, x.float()) 9676*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cfloat.imag, torch.zeros_like(cfloat.imag)) 9677*da0073e9SAndroid Build Coastguard Worker cdouble = x.cdouble() 9678*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cdouble.dtype, torch.complex128) 9679*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cdouble.real, x.double()) 9680*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cdouble.imag, 
torch.zeros_like(cdouble.imag)) 9681*da0073e9SAndroid Build Coastguard Worker chalf = x.chalf() 9682*da0073e9SAndroid Build Coastguard Worker self.assertEqual(chalf.dtype, torch.complex32) 9683*da0073e9SAndroid Build Coastguard Worker self.assertEqual(chalf.real, x.half()) 9684*da0073e9SAndroid Build Coastguard Worker self.assertEqual(chalf.imag, torch.zeros_like(chalf.imag)) 9685*da0073e9SAndroid Build Coastguard Worker 9686*da0073e9SAndroid Build Coastguard Worker def test_type_alias(self): 9687*da0073e9SAndroid Build Coastguard Worker type_alias_map = {torch.float64: torch.double, 9688*da0073e9SAndroid Build Coastguard Worker torch.float32: torch.float, 9689*da0073e9SAndroid Build Coastguard Worker torch.int32: torch.int, 9690*da0073e9SAndroid Build Coastguard Worker torch.int64: torch.long, 9691*da0073e9SAndroid Build Coastguard Worker torch.int16: torch.short, 9692*da0073e9SAndroid Build Coastguard Worker torch.float16: torch.half, 9693*da0073e9SAndroid Build Coastguard Worker torch.complex32: torch.chalf, 9694*da0073e9SAndroid Build Coastguard Worker torch.complex64: torch.cfloat} 9695*da0073e9SAndroid Build Coastguard Worker for dtype, alias in type_alias_map.items(): 9696*da0073e9SAndroid Build Coastguard Worker self.assertIs(alias, dtype) 9697*da0073e9SAndroid Build Coastguard Worker 9698*da0073e9SAndroid Build Coastguard Worker def test_doc_template(self) -> None: 9699*da0073e9SAndroid Build Coastguard Worker """ 9700*da0073e9SAndroid Build Coastguard Worker Test that all public API doc strings use the same standard template for 9701*da0073e9SAndroid Build Coastguard Worker all common arguments such as tensor or dim 9702*da0073e9SAndroid Build Coastguard Worker """ 9703*da0073e9SAndroid Build Coastguard Worker from torch._torch_docs import __file__ as doc_file 9704*da0073e9SAndroid Build Coastguard Worker from torch._torch_docs import multi_dim_common, single_dim_common, factory_common_args, factory_like_common_args 9705*da0073e9SAndroid Build Coastguard 
Worker 9706*da0073e9SAndroid Build Coastguard Worker with open(doc_file, encoding="utf-8") as f: 9707*da0073e9SAndroid Build Coastguard Worker doc_strs = f.read() 9708*da0073e9SAndroid Build Coastguard Worker 9709*da0073e9SAndroid Build Coastguard Worker matches = re.findall( 9710*da0073e9SAndroid Build Coastguard Worker r'add_docstr\(([^,]+?),[^"\']*?(?:"""|\'\'\')(.*?)(?:"""|\'\'\')(?:\.|,?[^,\)]*?\))', 9711*da0073e9SAndroid Build Coastguard Worker doc_strs, 9712*da0073e9SAndroid Build Coastguard Worker re.MULTILINE | re.DOTALL, 9713*da0073e9SAndroid Build Coastguard Worker ) 9714*da0073e9SAndroid Build Coastguard Worker self.assertTrue(matches) 9715*da0073e9SAndroid Build Coastguard Worker 9716*da0073e9SAndroid Build Coastguard Worker for m in matches: 9717*da0073e9SAndroid Build Coastguard Worker func = m[0].strip() 9718*da0073e9SAndroid Build Coastguard Worker desc = m[1].strip() 9719*da0073e9SAndroid Build Coastguard Worker 9720*da0073e9SAndroid Build Coastguard Worker for common_args in [multi_dim_common, single_dim_common, factory_common_args, factory_like_common_args]: 9721*da0073e9SAndroid Build Coastguard Worker for k, v in common_args.items(): 9722*da0073e9SAndroid Build Coastguard Worker self.assertNotIn(v, desc, f'The argument description "{v}" in {func} can be ' 9723*da0073e9SAndroid Build Coastguard Worker f'replaced by {{{k}}}') 9724*da0073e9SAndroid Build Coastguard Worker 9725*da0073e9SAndroid Build Coastguard Worker def test_doc(self): 9726*da0073e9SAndroid Build Coastguard Worker checked_types = (types.MethodType, types.FunctionType, 9727*da0073e9SAndroid Build Coastguard Worker types.BuiltinFunctionType, types.BuiltinMethodType) 9728*da0073e9SAndroid Build Coastguard Worker 9729*da0073e9SAndroid Build Coastguard Worker def _test_namespace(ns, *skips): 9730*da0073e9SAndroid Build Coastguard Worker if isinstance(ns, object): 9731*da0073e9SAndroid Build Coastguard Worker ns_name = ns.__class__.__name__ 9732*da0073e9SAndroid Build Coastguard 
Worker else: 9733*da0073e9SAndroid Build Coastguard Worker ns_name = ns.__name__ 9734*da0073e9SAndroid Build Coastguard Worker skip_regexes = [] 9735*da0073e9SAndroid Build Coastguard Worker for r in skips: 9736*da0073e9SAndroid Build Coastguard Worker if isinstance(r, str): 9737*da0073e9SAndroid Build Coastguard Worker skip_regexes.append(re.compile(f'^{re.escape(r)}$')) 9738*da0073e9SAndroid Build Coastguard Worker else: 9739*da0073e9SAndroid Build Coastguard Worker skip_regexes.append(r) 9740*da0073e9SAndroid Build Coastguard Worker 9741*da0073e9SAndroid Build Coastguard Worker for name in dir(ns): 9742*da0073e9SAndroid Build Coastguard Worker if name.startswith('_'): 9743*da0073e9SAndroid Build Coastguard Worker continue 9744*da0073e9SAndroid Build Coastguard Worker if name in ['real', 'imag']: 9745*da0073e9SAndroid Build Coastguard Worker y = torch.randn(1, dtype=torch.cfloat) 9746*da0073e9SAndroid Build Coastguard Worker var = getattr(y, name) 9747*da0073e9SAndroid Build Coastguard Worker elif name in ["H", "mT", "mH"]: 9748*da0073e9SAndroid Build Coastguard Worker y = torch.randn(1, 1) 9749*da0073e9SAndroid Build Coastguard Worker var = getattr(y, name) 9750*da0073e9SAndroid Build Coastguard Worker else: 9751*da0073e9SAndroid Build Coastguard Worker var = getattr(ns, name) 9752*da0073e9SAndroid Build Coastguard Worker if not isinstance(var, checked_types): 9753*da0073e9SAndroid Build Coastguard Worker continue 9754*da0073e9SAndroid Build Coastguard Worker doc = var.__doc__ 9755*da0073e9SAndroid Build Coastguard Worker has_doc = doc is not None and len(doc.strip()) > 0 9756*da0073e9SAndroid Build Coastguard Worker full_name = ns_name + '.' 
+ name 9757*da0073e9SAndroid Build Coastguard Worker if any(r.match(name) for r in skip_regexes): 9758*da0073e9SAndroid Build Coastguard Worker self.assertFalse(has_doc, 9759*da0073e9SAndroid Build Coastguard Worker f'New docs have been added for {full_name}, please remove ' 9760*da0073e9SAndroid Build Coastguard Worker 'it from the skipped list in TestTorch.test_doc') 9761*da0073e9SAndroid Build Coastguard Worker else: 9762*da0073e9SAndroid Build Coastguard Worker self.assertTrue(has_doc, f'{full_name} is missing documentation') 9763*da0073e9SAndroid Build Coastguard Worker 9764*da0073e9SAndroid Build Coastguard Worker # FIXME: All of the following should be marked as expected failures 9765*da0073e9SAndroid Build Coastguard Worker # so that it is easier to tell when missing has been added. 9766*da0073e9SAndroid Build Coastguard Worker # FIXME: fix all the skipped ones below! 9767*da0073e9SAndroid Build Coastguard Worker test_namespace(torch.randn(1), # noqa: F821 9768*da0073e9SAndroid Build Coastguard Worker 'as_strided_', 9769*da0073e9SAndroid Build Coastguard Worker re.compile('^clamp_(min|max)_?$'), 9770*da0073e9SAndroid Build Coastguard Worker 'is_distributed', 9771*da0073e9SAndroid Build Coastguard Worker 'is_nonzero', 9772*da0073e9SAndroid Build Coastguard Worker 'is_same_size', 9773*da0073e9SAndroid Build Coastguard Worker 'log_softmax', 9774*da0073e9SAndroid Build Coastguard Worker 'map2_', 9775*da0073e9SAndroid Build Coastguard Worker 'new', 9776*da0073e9SAndroid Build Coastguard Worker 'reinforce', 9777*da0073e9SAndroid Build Coastguard Worker 'relu', 9778*da0073e9SAndroid Build Coastguard Worker 'relu_', 9779*da0073e9SAndroid Build Coastguard Worker 'prelu', 9780*da0073e9SAndroid Build Coastguard Worker 'resize', 9781*da0073e9SAndroid Build Coastguard Worker 'resize_as', 9782*da0073e9SAndroid Build Coastguard Worker 'softmax', 9783*da0073e9SAndroid Build Coastguard Worker 'split_with_sizes', 9784*da0073e9SAndroid Build Coastguard Worker 
'unsafe_split_with_sizes', 9785*da0073e9SAndroid Build Coastguard Worker '_autocast_to_fp16', 9786*da0073e9SAndroid Build Coastguard Worker '_autocast_to_fp32', 9787*da0073e9SAndroid Build Coastguard Worker ) 9788*da0073e9SAndroid Build Coastguard Worker 9789*da0073e9SAndroid Build Coastguard Worker test_namespace(torch.nn) # noqa: F821 9790*da0073e9SAndroid Build Coastguard Worker test_namespace(torch.nn.functional, 'assert_int_or_pair') # noqa: F821 9791*da0073e9SAndroid Build Coastguard Worker # TODO: add torch.* tests when we have proper namespacing on ATen functions 9792*da0073e9SAndroid Build Coastguard Worker # test_namespace(torch) 9793*da0073e9SAndroid Build Coastguard Worker 9794*da0073e9SAndroid Build Coastguard Worker # FIXME: deprecate torch.Tensor constructor 9795*da0073e9SAndroid Build Coastguard Worker def test_tensor_ctor_scalar(self): 9796*da0073e9SAndroid Build Coastguard Worker x = torch.Tensor(torch.tensor(1.0)) 9797*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, torch.tensor(1.0)) 9798*da0073e9SAndroid Build Coastguard Worker 9799*da0073e9SAndroid Build Coastguard Worker def test_deepcopy_gradient(self): 9800*da0073e9SAndroid Build Coastguard Worker from copy import deepcopy 9801*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(10) 9802*da0073e9SAndroid Build Coastguard Worker a.grad = torch.ones(10) 9803*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.grad, deepcopy(a).grad) 9804*da0073e9SAndroid Build Coastguard Worker s = torch.zeros(10).to_sparse() 9805*da0073e9SAndroid Build Coastguard Worker s.grad = torch.ones(10).to_sparse() 9806*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.grad, deepcopy(s).grad) 9807*da0073e9SAndroid Build Coastguard Worker 9808*da0073e9SAndroid Build Coastguard Worker # ensure sharing is not broken 9809*da0073e9SAndroid Build Coastguard Worker c = deepcopy([a, a.grad]) 9810*da0073e9SAndroid Build Coastguard Worker self.assertTrue(c[0].grad is c[1]) 
9811*da0073e9SAndroid Build Coastguard Worker 9812*da0073e9SAndroid Build Coastguard Worker def test_tensor_base_init(self): 9813*da0073e9SAndroid Build Coastguard Worker # Direct construction not OK 9814*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch._C.TensorBase()) 9815*da0073e9SAndroid Build Coastguard Worker 9816*da0073e9SAndroid Build Coastguard Worker # Subclassing it directly no OK 9817*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot subclass"): 9818*da0073e9SAndroid Build Coastguard Worker class Tfail(torch._C.TensorBase): 9819*da0073e9SAndroid Build Coastguard Worker pass 9820*da0073e9SAndroid Build Coastguard Worker 9821*da0073e9SAndroid Build Coastguard Worker # Doing so with Tensor is ok though 9822*da0073e9SAndroid Build Coastguard Worker class T(torch.Tensor): 9823*da0073e9SAndroid Build Coastguard Worker pass 9824*da0073e9SAndroid Build Coastguard Worker 9825*da0073e9SAndroid Build Coastguard Worker T() 9826*da0073e9SAndroid Build Coastguard Worker 9827*da0073e9SAndroid Build Coastguard Worker def test_storage_base_init(self): 9828*da0073e9SAndroid Build Coastguard Worker # Direct construction not OK 9829*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch._C.StorageBase()) 9830*da0073e9SAndroid Build Coastguard Worker 9831*da0073e9SAndroid Build Coastguard Worker # But construction of subclass is OK 9832*da0073e9SAndroid Build Coastguard Worker class T(torch._C.StorageBase): 9833*da0073e9SAndroid Build Coastguard Worker pass 9834*da0073e9SAndroid Build Coastguard Worker 9835*da0073e9SAndroid Build Coastguard Worker T() 9836*da0073e9SAndroid Build Coastguard Worker 9837*da0073e9SAndroid Build Coastguard Worker def test_tensor_base_new(self): 9838*da0073e9SAndroid Build Coastguard Worker 9839*da0073e9SAndroid Build Coastguard Worker # OK to call super().__new__, see 9840*da0073e9SAndroid Build Coastguard Worker # 
https://github.com/pytorch/pytorch/issues/57421 9841*da0073e9SAndroid Build Coastguard Worker class TestTensor(torch.Tensor): 9842*da0073e9SAndroid Build Coastguard Worker @staticmethod 9843*da0073e9SAndroid Build Coastguard Worker def __new__(cls, x, *args, **kwargs): 9844*da0073e9SAndroid Build Coastguard Worker return super().__new__(cls, x, *args, **kwargs) 9845*da0073e9SAndroid Build Coastguard Worker 9846*da0073e9SAndroid Build Coastguard Worker x = torch.ones(5) 9847*da0073e9SAndroid Build Coastguard Worker test_tensor = TestTensor(x) 9848*da0073e9SAndroid Build Coastguard Worker 9849*da0073e9SAndroid Build Coastguard Worker def test_storage_base_new(self): 9850*da0073e9SAndroid Build Coastguard Worker 9851*da0073e9SAndroid Build Coastguard Worker # OK to call super().__new__, see 9852*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/57421 9853*da0073e9SAndroid Build Coastguard Worker class TestStorage(torch._C.StorageBase): 9854*da0073e9SAndroid Build Coastguard Worker @staticmethod 9855*da0073e9SAndroid Build Coastguard Worker def __new__(cls, x, *args, **kwargs): 9856*da0073e9SAndroid Build Coastguard Worker return super().__new__(cls, x, *args, **kwargs) 9857*da0073e9SAndroid Build Coastguard Worker 9858*da0073e9SAndroid Build Coastguard Worker x = torch.UntypedStorage(5) 9859*da0073e9SAndroid Build Coastguard Worker test_storage = TestStorage(x) 9860*da0073e9SAndroid Build Coastguard Worker 9861*da0073e9SAndroid Build Coastguard Worker def test_pyobj_preserved(self): 9862*da0073e9SAndroid Build Coastguard Worker x = torch.empty(2) 9863*da0073e9SAndroid Build Coastguard Worker x.foo = 2 # put something on __dict__ 9864*da0073e9SAndroid Build Coastguard Worker y = torch.empty(2) 9865*da0073e9SAndroid Build Coastguard Worker y.grad = x 9866*da0073e9SAndroid Build Coastguard Worker del x # x is dead in Python 9867*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.grad.foo, 2) 9868*da0073e9SAndroid Build 
Coastguard Worker z = y.grad # it's live 9869*da0073e9SAndroid Build Coastguard Worker del z # it's dead again 9870*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.grad.foo, 2) 9871*da0073e9SAndroid Build Coastguard Worker 9872*da0073e9SAndroid Build Coastguard Worker def test_subclass_preserved(self): 9873*da0073e9SAndroid Build Coastguard Worker class MyTensor(torch.Tensor): 9874*da0073e9SAndroid Build Coastguard Worker pass 9875*da0073e9SAndroid Build Coastguard Worker 9876*da0073e9SAndroid Build Coastguard Worker x = MyTensor(torch.empty(2)) 9877*da0073e9SAndroid Build Coastguard Worker y = torch.empty(2) 9878*da0073e9SAndroid Build Coastguard Worker y.grad = x 9879*da0073e9SAndroid Build Coastguard Worker del x # x is dead in Python 9880*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(y.grad), MyTensor) 9881*da0073e9SAndroid Build Coastguard Worker z = y.grad # it's live 9882*da0073e9SAndroid Build Coastguard Worker del z # it's dead again 9883*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(y.grad), MyTensor) 9884*da0073e9SAndroid Build Coastguard Worker 9885*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo") 9886*da0073e9SAndroid Build Coastguard Worker def test_storage_dealloc(self): 9887*da0073e9SAndroid Build Coastguard Worker m, t = Tracker.make() 9888*da0073e9SAndroid Build Coastguard Worker s0 = torch.UntypedStorage(10) 9889*da0073e9SAndroid Build Coastguard Worker s1 = s0 9890*da0073e9SAndroid Build Coastguard Worker s0._tracker = t 9891*da0073e9SAndroid Build Coastguard Worker del t 9892*da0073e9SAndroid Build Coastguard Worker 9893*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9894*da0073e9SAndroid Build Coastguard Worker del s0 9895*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9896*da0073e9SAndroid Build Coastguard Worker del s1 9897*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 9898*da0073e9SAndroid 
Build Coastguard Worker 9899*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo") 9900*da0073e9SAndroid Build Coastguard Worker def test_storage_from_tensor_dealloc(self): 9901*da0073e9SAndroid Build Coastguard Worker m, t = Tracker.make() 9902*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10) 9903*da0073e9SAndroid Build Coastguard Worker s0 = a.untyped_storage() 9904*da0073e9SAndroid Build Coastguard Worker s0._tracker = t 9905*da0073e9SAndroid Build Coastguard Worker del t 9906*da0073e9SAndroid Build Coastguard Worker 9907*da0073e9SAndroid Build Coastguard Worker s1 = a.untyped_storage() 9908*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s0 is s1) 9909*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(s1, '_tracker')) 9910*da0073e9SAndroid Build Coastguard Worker 9911*da0073e9SAndroid Build Coastguard Worker del a 9912*da0073e9SAndroid Build Coastguard Worker 9913*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9914*da0073e9SAndroid Build Coastguard Worker del s0 9915*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9916*da0073e9SAndroid Build Coastguard Worker del s1 9917*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 9918*da0073e9SAndroid Build Coastguard Worker 9919*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo") 9920*da0073e9SAndroid Build Coastguard Worker def test_storage_from_tensor_dealloc_zombie(self): 9921*da0073e9SAndroid Build Coastguard Worker m, t = Tracker.make() 9922*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10) 9923*da0073e9SAndroid Build Coastguard Worker s0 = a.untyped_storage() 9924*da0073e9SAndroid Build Coastguard Worker s0._tracker = t 9925*da0073e9SAndroid Build Coastguard Worker del t 9926*da0073e9SAndroid Build Coastguard Worker 9927*da0073e9SAndroid Build Coastguard Worker s1 = a.untyped_storage() 9928*da0073e9SAndroid Build 
Coastguard Worker self.assertTrue(s0 is s1) 9929*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(s1, '_tracker')) 9930*da0073e9SAndroid Build Coastguard Worker 9931*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9932*da0073e9SAndroid Build Coastguard Worker del s0 9933*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9934*da0073e9SAndroid Build Coastguard Worker del s1 9935*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9936*da0073e9SAndroid Build Coastguard Worker del a 9937*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 9938*da0073e9SAndroid Build Coastguard Worker 9939*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo") 9940*da0073e9SAndroid Build Coastguard Worker def test_storage_from_tensor_dealloc_resurrected(self): 9941*da0073e9SAndroid Build Coastguard Worker m, t = Tracker.make() 9942*da0073e9SAndroid Build Coastguard Worker a = torch.randn(10) 9943*da0073e9SAndroid Build Coastguard Worker s0 = a.untyped_storage() 9944*da0073e9SAndroid Build Coastguard Worker s0._tracker = t 9945*da0073e9SAndroid Build Coastguard Worker del t 9946*da0073e9SAndroid Build Coastguard Worker 9947*da0073e9SAndroid Build Coastguard Worker s1 = a.untyped_storage() 9948*da0073e9SAndroid Build Coastguard Worker self.assertTrue(s0 is s1) 9949*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(s1, '_tracker')) 9950*da0073e9SAndroid Build Coastguard Worker 9951*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9952*da0073e9SAndroid Build Coastguard Worker del s0 9953*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9954*da0073e9SAndroid Build Coastguard Worker del s1 9955*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9956*da0073e9SAndroid Build Coastguard Worker 9957*da0073e9SAndroid Build Coastguard Worker s0 = a.untyped_storage() 9958*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(isinstance(s0, torch.UntypedStorage)) 9959*da0073e9SAndroid Build Coastguard Worker 9960*da0073e9SAndroid Build Coastguard Worker del a 9961*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9962*da0073e9SAndroid Build Coastguard Worker del s0 9963*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 9964*da0073e9SAndroid Build Coastguard Worker 9965*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo") 9966*da0073e9SAndroid Build Coastguard Worker def test_storage_dealloc_resurrected(self): 9967*da0073e9SAndroid Build Coastguard Worker m, t = Tracker.make() 9968*da0073e9SAndroid Build Coastguard Worker s = torch.UntypedStorage(10) 9969*da0073e9SAndroid Build Coastguard Worker s._tracker = t 9970*da0073e9SAndroid Build Coastguard Worker del t 9971*da0073e9SAndroid Build Coastguard Worker 9972*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(s) 9973*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9974*da0073e9SAndroid Build Coastguard Worker del s 9975*da0073e9SAndroid Build Coastguard Worker 9976*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9977*da0073e9SAndroid Build Coastguard Worker 9978*da0073e9SAndroid Build Coastguard Worker s = a.untyped_storage() 9979*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s, torch.UntypedStorage)) 9980*da0073e9SAndroid Build Coastguard Worker 9981*da0073e9SAndroid Build Coastguard Worker del a 9982*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 9983*da0073e9SAndroid Build Coastguard Worker del s 9984*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 9985*da0073e9SAndroid Build Coastguard Worker 9986*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo") 9987*da0073e9SAndroid Build Coastguard Worker def test_storage_dealloc_subclass_zombie(self): 9988*da0073e9SAndroid Build Coastguard Worker class 
MyStorage(torch.UntypedStorage): 9989*da0073e9SAndroid Build Coastguard Worker finalized_count = 0 9990*da0073e9SAndroid Build Coastguard Worker 9991*da0073e9SAndroid Build Coastguard Worker def __del__(self): 9992*da0073e9SAndroid Build Coastguard Worker MyStorage.finalized_count += 1 9993*da0073e9SAndroid Build Coastguard Worker 9994*da0073e9SAndroid Build Coastguard Worker m, t = Tracker.make() 9995*da0073e9SAndroid Build Coastguard Worker s = MyStorage(10) 9996*da0073e9SAndroid Build Coastguard Worker s._tracker = t 9997*da0073e9SAndroid Build Coastguard Worker del t 9998*da0073e9SAndroid Build Coastguard Worker 9999*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(s) 10000*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 10001*da0073e9SAndroid Build Coastguard Worker del s 10002*da0073e9SAndroid Build Coastguard Worker 10003*da0073e9SAndroid Build Coastguard Worker self.assertEqual(MyStorage.finalized_count, 0) 10004*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 10005*da0073e9SAndroid Build Coastguard Worker 10006*da0073e9SAndroid Build Coastguard Worker del a 10007*da0073e9SAndroid Build Coastguard Worker self.assertEqual(MyStorage.finalized_count, 1) 10008*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 10009*da0073e9SAndroid Build Coastguard Worker 10010*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Tracker hook does not work in TorchDynamo") 10011*da0073e9SAndroid Build Coastguard Worker def test_storage_dealloc_subclass_resurrected(self): 10012*da0073e9SAndroid Build Coastguard Worker class MyStorage(torch.UntypedStorage): 10013*da0073e9SAndroid Build Coastguard Worker finalized_count = 0 10014*da0073e9SAndroid Build Coastguard Worker 10015*da0073e9SAndroid Build Coastguard Worker def __del__(self): 10016*da0073e9SAndroid Build Coastguard Worker MyStorage.finalized_count += 1 10017*da0073e9SAndroid Build Coastguard Worker 10018*da0073e9SAndroid Build Coastguard Worker m, t = 
Tracker.make() 10019*da0073e9SAndroid Build Coastguard Worker s = MyStorage(10) 10020*da0073e9SAndroid Build Coastguard Worker s._tracker = t 10021*da0073e9SAndroid Build Coastguard Worker del t 10022*da0073e9SAndroid Build Coastguard Worker 10023*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(s) 10024*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 10025*da0073e9SAndroid Build Coastguard Worker del s 10026*da0073e9SAndroid Build Coastguard Worker 10027*da0073e9SAndroid Build Coastguard Worker self.assertEqual(MyStorage.finalized_count, 0) 10028*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 10029*da0073e9SAndroid Build Coastguard Worker 10030*da0073e9SAndroid Build Coastguard Worker s = a.untyped_storage() 10031*da0073e9SAndroid Build Coastguard Worker del a 10032*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 10033*da0073e9SAndroid Build Coastguard Worker self.assertEqual(MyStorage.finalized_count, 0) 10034*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s, MyStorage)) 10035*da0073e9SAndroid Build Coastguard Worker del s 10036*da0073e9SAndroid Build Coastguard Worker self.assertEqual(MyStorage.finalized_count, 1) 10037*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 10038*da0073e9SAndroid Build Coastguard Worker 10039*da0073e9SAndroid Build Coastguard Worker def test_tensor_slot_dealloc(self): 10040*da0073e9SAndroid Build Coastguard Worker 10041*da0073e9SAndroid Build Coastguard Worker class SlotTensor1(torch.Tensor): 10042*da0073e9SAndroid Build Coastguard Worker __slots__ = ['slot1'] 10043*da0073e9SAndroid Build Coastguard Worker 10044*da0073e9SAndroid Build Coastguard Worker class SlotTensor2(SlotTensor1): 10045*da0073e9SAndroid Build Coastguard Worker __slots__ = ['slot2'] 10046*da0073e9SAndroid Build Coastguard Worker 10047*da0073e9SAndroid Build Coastguard Worker m1, t1 = Tracker.make() 10048*da0073e9SAndroid Build Coastguard Worker m2, t2 = Tracker.make() 
10049*da0073e9SAndroid Build Coastguard Worker slot_tensor = SlotTensor2(torch.empty(2)) 10050*da0073e9SAndroid Build Coastguard Worker slot_tensor.slot1 = t1 10051*da0073e9SAndroid Build Coastguard Worker slot_tensor.slot2 = t2 10052*da0073e9SAndroid Build Coastguard Worker del t1 10053*da0073e9SAndroid Build Coastguard Worker del t2 10054*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m1[0]) 10055*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2[0]) 10056*da0073e9SAndroid Build Coastguard Worker del slot_tensor 10057*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m1[0]) 10058*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2[0]) 10059*da0073e9SAndroid Build Coastguard Worker 10060*da0073e9SAndroid Build Coastguard Worker def test_storage_slot_dealloc(self): 10061*da0073e9SAndroid Build Coastguard Worker 10062*da0073e9SAndroid Build Coastguard Worker class SlotStorage1(torch._C.StorageBase): 10063*da0073e9SAndroid Build Coastguard Worker __slots__ = ['slot1'] 10064*da0073e9SAndroid Build Coastguard Worker 10065*da0073e9SAndroid Build Coastguard Worker class SlotStorage2(SlotStorage1): 10066*da0073e9SAndroid Build Coastguard Worker __slots__ = ['slot2'] 10067*da0073e9SAndroid Build Coastguard Worker 10068*da0073e9SAndroid Build Coastguard Worker m1, t1 = Tracker.make() 10069*da0073e9SAndroid Build Coastguard Worker m2, t2 = Tracker.make() 10070*da0073e9SAndroid Build Coastguard Worker slot_storage = SlotStorage2(torch.UntypedStorage(2)) 10071*da0073e9SAndroid Build Coastguard Worker slot_storage.slot1 = t1 10072*da0073e9SAndroid Build Coastguard Worker slot_storage.slot2 = t2 10073*da0073e9SAndroid Build Coastguard Worker del t1 10074*da0073e9SAndroid Build Coastguard Worker del t2 10075*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m1[0]) 10076*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2[0]) 10077*da0073e9SAndroid Build Coastguard Worker del slot_storage 10078*da0073e9SAndroid Build Coastguard 
Worker self.assertTrue(m1[0]) 10079*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2[0]) 10080*da0073e9SAndroid Build Coastguard Worker 10081*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 10082*da0073e9SAndroid Build Coastguard Worker def test_tensor_dict_dealloc(self): 10083*da0073e9SAndroid Build Coastguard Worker m, t = Tracker.make() 10084*da0073e9SAndroid Build Coastguard Worker x = torch.empty(2) 10085*da0073e9SAndroid Build Coastguard Worker x.arf = t 10086*da0073e9SAndroid Build Coastguard Worker del t 10087*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 10088*da0073e9SAndroid Build Coastguard Worker del x 10089*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 10090*da0073e9SAndroid Build Coastguard Worker 10091*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 10092*da0073e9SAndroid Build Coastguard Worker def test_storage_dict_dealloc(self): 10093*da0073e9SAndroid Build Coastguard Worker m, t = Tracker.make() 10094*da0073e9SAndroid Build Coastguard Worker x = torch.UntypedStorage(2) 10095*da0073e9SAndroid Build Coastguard Worker x.arf = t 10096*da0073e9SAndroid Build Coastguard Worker del t 10097*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 10098*da0073e9SAndroid Build Coastguard Worker del x 10099*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 10100*da0073e9SAndroid Build Coastguard Worker 10101*da0073e9SAndroid Build Coastguard Worker def test_tensor_finalizer_dealloc(self): 10102*da0073e9SAndroid Build Coastguard Worker m = [False] 10103*da0073e9SAndroid Build Coastguard Worker 10104*da0073e9SAndroid Build Coastguard Worker class FinalizerTensor(torch.Tensor): 10105*da0073e9SAndroid Build Coastguard Worker def __del__(self): 10106*da0073e9SAndroid Build Coastguard Worker m[0] = True 10107*da0073e9SAndroid Build Coastguard Worker 10108*da0073e9SAndroid Build Coastguard Worker 
fin_tensor = FinalizerTensor(torch.empty(2)) 10109*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 10110*da0073e9SAndroid Build Coastguard Worker del fin_tensor 10111*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 10112*da0073e9SAndroid Build Coastguard Worker 10113*da0073e9SAndroid Build Coastguard Worker def test_storage_finalizer_dealloc(self): 10114*da0073e9SAndroid Build Coastguard Worker m = [False] 10115*da0073e9SAndroid Build Coastguard Worker 10116*da0073e9SAndroid Build Coastguard Worker class FinalizerStorage(torch._C.StorageBase): 10117*da0073e9SAndroid Build Coastguard Worker def __del__(self): 10118*da0073e9SAndroid Build Coastguard Worker m[0] = True 10119*da0073e9SAndroid Build Coastguard Worker 10120*da0073e9SAndroid Build Coastguard Worker fin_storage = FinalizerStorage(torch.UntypedStorage(2)) 10121*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 10122*da0073e9SAndroid Build Coastguard Worker del fin_storage 10123*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 10124*da0073e9SAndroid Build Coastguard Worker 10125*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") 10126*da0073e9SAndroid Build Coastguard Worker def test_tensor_weakref_dealloc(self): 10127*da0073e9SAndroid Build Coastguard Worker x = torch.empty(2) 10128*da0073e9SAndroid Build Coastguard Worker m = [False] 10129*da0073e9SAndroid Build Coastguard Worker 10130*da0073e9SAndroid Build Coastguard Worker def cb(r): 10131*da0073e9SAndroid Build Coastguard Worker m[0] = True 10132*da0073e9SAndroid Build Coastguard Worker 10133*da0073e9SAndroid Build Coastguard Worker wref = weakref.ref(x, cb) 10134*da0073e9SAndroid Build Coastguard Worker del x 10135*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 10136*da0073e9SAndroid Build Coastguard Worker self.assertEqual(wref(), None) 10137*da0073e9SAndroid Build Coastguard Worker 10138*da0073e9SAndroid Build 
Coastguard Worker @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") 10139*da0073e9SAndroid Build Coastguard Worker def test_storage_weakref_dealloc(self): 10140*da0073e9SAndroid Build Coastguard Worker 10141*da0073e9SAndroid Build Coastguard Worker x = torch.UntypedStorage(2) 10142*da0073e9SAndroid Build Coastguard Worker m = [False] 10143*da0073e9SAndroid Build Coastguard Worker 10144*da0073e9SAndroid Build Coastguard Worker def cb(r): 10145*da0073e9SAndroid Build Coastguard Worker m[0] = True 10146*da0073e9SAndroid Build Coastguard Worker 10147*da0073e9SAndroid Build Coastguard Worker wref = weakref.ref(x, cb) 10148*da0073e9SAndroid Build Coastguard Worker del x 10149*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 10150*da0073e9SAndroid Build Coastguard Worker self.assertEqual(wref(), None) 10151*da0073e9SAndroid Build Coastguard Worker 10152*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 10153*da0073e9SAndroid Build Coastguard Worker def test_tensor_cycle_via_dict(self): 10154*da0073e9SAndroid Build Coastguard Worker m1, t1 = Tracker.make() 10155*da0073e9SAndroid Build Coastguard Worker x = torch.empty(2) 10156*da0073e9SAndroid Build Coastguard Worker x._tracker = t1 10157*da0073e9SAndroid Build Coastguard Worker del t1 10158*da0073e9SAndroid Build Coastguard Worker 10159*da0073e9SAndroid Build Coastguard Worker m2, t2 = Tracker.make() 10160*da0073e9SAndroid Build Coastguard Worker y = torch.empty(2) 10161*da0073e9SAndroid Build Coastguard Worker y._tracker = t2 10162*da0073e9SAndroid Build Coastguard Worker del t2 10163*da0073e9SAndroid Build Coastguard Worker 10164*da0073e9SAndroid Build Coastguard Worker x._loop = y 10165*da0073e9SAndroid Build Coastguard Worker y._loop = x 10166*da0073e9SAndroid Build Coastguard Worker 10167*da0073e9SAndroid Build Coastguard Worker # C++ reference should keep the cycle live! 
10168*da0073e9SAndroid Build Coastguard Worker # This exercise THPVariable_subtype_traverse 10169*da0073e9SAndroid Build Coastguard Worker # NB: Because z.grad is a reference done entirely in C++, cycles 10170*da0073e9SAndroid Build Coastguard Worker # involving it directly are NOT broken by Python GC; you've 10171*da0073e9SAndroid Build Coastguard Worker # set up a good old C++ reference cycle which we cannot safely 10172*da0073e9SAndroid Build Coastguard Worker # break (because C++ references are allowed to be accessed 10173*da0073e9SAndroid Build Coastguard Worker # multithreaded-ly) (TODO: except maybe if you can prove that 10174*da0073e9SAndroid Build Coastguard Worker # only Python has access to the C++ object, in which case you can 10175*da0073e9SAndroid Build Coastguard Worker # also prove that no multithreaded access occurs) 10176*da0073e9SAndroid Build Coastguard Worker z = torch.empty(2) 10177*da0073e9SAndroid Build Coastguard Worker z.grad = x 10178*da0073e9SAndroid Build Coastguard Worker 10179*da0073e9SAndroid Build Coastguard Worker del x 10180*da0073e9SAndroid Build Coastguard Worker del y 10181*da0073e9SAndroid Build Coastguard Worker 10182*da0073e9SAndroid Build Coastguard Worker gc.collect() 10183*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m1[0]) 10184*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2[0]) 10185*da0073e9SAndroid Build Coastguard Worker 10186*da0073e9SAndroid Build Coastguard Worker with disable_gc(): 10187*da0073e9SAndroid Build Coastguard Worker del z 10188*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m1[0]) 10189*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2[0]) 10190*da0073e9SAndroid Build Coastguard Worker 10191*da0073e9SAndroid Build Coastguard Worker gc.collect() 10192*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m1[0]) 10193*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2[0]) 10194*da0073e9SAndroid Build Coastguard Worker 10195*da0073e9SAndroid 
Build Coastguard Worker @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 10196*da0073e9SAndroid Build Coastguard Worker def test_storage_cycle_via_dict(self): 10197*da0073e9SAndroid Build Coastguard Worker m1, t1 = Tracker.make() 10198*da0073e9SAndroid Build Coastguard Worker x = torch.UntypedStorage(2) 10199*da0073e9SAndroid Build Coastguard Worker x._tracker = t1 10200*da0073e9SAndroid Build Coastguard Worker del t1 10201*da0073e9SAndroid Build Coastguard Worker 10202*da0073e9SAndroid Build Coastguard Worker m2, t2 = Tracker.make() 10203*da0073e9SAndroid Build Coastguard Worker y = torch.UntypedStorage(2) 10204*da0073e9SAndroid Build Coastguard Worker y._tracker = t2 10205*da0073e9SAndroid Build Coastguard Worker del t2 10206*da0073e9SAndroid Build Coastguard Worker 10207*da0073e9SAndroid Build Coastguard Worker x._loop = y 10208*da0073e9SAndroid Build Coastguard Worker y._loop = x 10209*da0073e9SAndroid Build Coastguard Worker 10210*da0073e9SAndroid Build Coastguard Worker # C++ reference should keep the cycle live! 
10211*da0073e9SAndroid Build Coastguard Worker # This exercise THPVariable_subtype_traverse 10212*da0073e9SAndroid Build Coastguard Worker # NB: Because z.grad is a reference done entirely in C++, cycles 10213*da0073e9SAndroid Build Coastguard Worker # involving it directly are NOT broken by Python GC; you've 10214*da0073e9SAndroid Build Coastguard Worker # set up a good old C++ reference cycle which we cannot safely 10215*da0073e9SAndroid Build Coastguard Worker # break (because C++ references are allowed to be accessed 10216*da0073e9SAndroid Build Coastguard Worker # multithreaded-ly) (TODO: except maybe if you can prove that 10217*da0073e9SAndroid Build Coastguard Worker # only Python has access to the C++ object, in which case you can 10218*da0073e9SAndroid Build Coastguard Worker # also prove that no multithreaded access occurs) 10219*da0073e9SAndroid Build Coastguard Worker z = torch.UntypedStorage(2) 10220*da0073e9SAndroid Build Coastguard Worker z.grad = x 10221*da0073e9SAndroid Build Coastguard Worker 10222*da0073e9SAndroid Build Coastguard Worker del x 10223*da0073e9SAndroid Build Coastguard Worker del y 10224*da0073e9SAndroid Build Coastguard Worker 10225*da0073e9SAndroid Build Coastguard Worker gc.collect() 10226*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m1[0]) 10227*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2[0]) 10228*da0073e9SAndroid Build Coastguard Worker 10229*da0073e9SAndroid Build Coastguard Worker with disable_gc(): 10230*da0073e9SAndroid Build Coastguard Worker del z 10231*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m1[0]) 10232*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2[0]) 10233*da0073e9SAndroid Build Coastguard Worker 10234*da0073e9SAndroid Build Coastguard Worker gc.collect() 10235*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m1[0]) 10236*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2[0]) 10237*da0073e9SAndroid Build Coastguard Worker 
10238*da0073e9SAndroid Build Coastguard Worker def test_tensor_cycle_via_slots(self): 10239*da0073e9SAndroid Build Coastguard Worker m1 = [False] 10240*da0073e9SAndroid Build Coastguard Worker m2 = [False] 10241*da0073e9SAndroid Build Coastguard Worker 10242*da0073e9SAndroid Build Coastguard Worker class SlotTensor1(torch.Tensor): 10243*da0073e9SAndroid Build Coastguard Worker __slots__ = ['slot1'] 10244*da0073e9SAndroid Build Coastguard Worker 10245*da0073e9SAndroid Build Coastguard Worker def __del__(self): 10246*da0073e9SAndroid Build Coastguard Worker m1[0] = True 10247*da0073e9SAndroid Build Coastguard Worker 10248*da0073e9SAndroid Build Coastguard Worker class SlotTensor2(SlotTensor1): 10249*da0073e9SAndroid Build Coastguard Worker __slots__ = ['slot2'] 10250*da0073e9SAndroid Build Coastguard Worker 10251*da0073e9SAndroid Build Coastguard Worker def __del__(self): 10252*da0073e9SAndroid Build Coastguard Worker m2[0] = True 10253*da0073e9SAndroid Build Coastguard Worker 10254*da0073e9SAndroid Build Coastguard Worker x = SlotTensor1(torch.empty(2)) 10255*da0073e9SAndroid Build Coastguard Worker y = SlotTensor2(torch.empty(2)) 10256*da0073e9SAndroid Build Coastguard Worker 10257*da0073e9SAndroid Build Coastguard Worker x.slot1 = y 10258*da0073e9SAndroid Build Coastguard Worker y.slot2 = x 10259*da0073e9SAndroid Build Coastguard Worker 10260*da0073e9SAndroid Build Coastguard Worker del x 10261*da0073e9SAndroid Build Coastguard Worker with disable_gc(): 10262*da0073e9SAndroid Build Coastguard Worker del y 10263*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m1[0]) 10264*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2[0]) 10265*da0073e9SAndroid Build Coastguard Worker 10266*da0073e9SAndroid Build Coastguard Worker gc.collect() 10267*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m1[0]) 10268*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2[0]) 10269*da0073e9SAndroid Build Coastguard Worker 10270*da0073e9SAndroid Build 
Coastguard Worker def test_storage_cycle_via_slots(self): 10271*da0073e9SAndroid Build Coastguard Worker m1 = [False] 10272*da0073e9SAndroid Build Coastguard Worker m2 = [False] 10273*da0073e9SAndroid Build Coastguard Worker 10274*da0073e9SAndroid Build Coastguard Worker class SlotStorage1(torch._C.StorageBase): 10275*da0073e9SAndroid Build Coastguard Worker __slots__ = ['slot1'] 10276*da0073e9SAndroid Build Coastguard Worker 10277*da0073e9SAndroid Build Coastguard Worker def __del__(self): 10278*da0073e9SAndroid Build Coastguard Worker m1[0] = True 10279*da0073e9SAndroid Build Coastguard Worker 10280*da0073e9SAndroid Build Coastguard Worker class SlotStorage2(SlotStorage1): 10281*da0073e9SAndroid Build Coastguard Worker __slots__ = ['slot2'] 10282*da0073e9SAndroid Build Coastguard Worker 10283*da0073e9SAndroid Build Coastguard Worker def __del__(self): 10284*da0073e9SAndroid Build Coastguard Worker m2[0] = True 10285*da0073e9SAndroid Build Coastguard Worker 10286*da0073e9SAndroid Build Coastguard Worker x = SlotStorage1(torch.UntypedStorage(2)) 10287*da0073e9SAndroid Build Coastguard Worker y = SlotStorage2(torch.UntypedStorage(2)) 10288*da0073e9SAndroid Build Coastguard Worker 10289*da0073e9SAndroid Build Coastguard Worker x.slot1 = y 10290*da0073e9SAndroid Build Coastguard Worker y.slot2 = x 10291*da0073e9SAndroid Build Coastguard Worker 10292*da0073e9SAndroid Build Coastguard Worker del x 10293*da0073e9SAndroid Build Coastguard Worker with disable_gc(): 10294*da0073e9SAndroid Build Coastguard Worker del y 10295*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m1[0]) 10296*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2[0]) 10297*da0073e9SAndroid Build Coastguard Worker 10298*da0073e9SAndroid Build Coastguard Worker gc.collect() 10299*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m1[0]) 10300*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2[0]) 10301*da0073e9SAndroid Build Coastguard Worker 10302*da0073e9SAndroid Build 
Coastguard Worker @skipIfTorchDynamo("Not a suitable test for TorchDynamo") 10303*da0073e9SAndroid Build Coastguard Worker def test_storage_preserve_nonhermetic_in_hermetic_context(self): 10304*da0073e9SAndroid Build Coastguard Worker from torch.library import Library, impl 10305*da0073e9SAndroid Build Coastguard Worker global _my_storage 10306*da0073e9SAndroid Build Coastguard Worker 10307*da0073e9SAndroid Build Coastguard Worker my_lib = Library("my_lib", "DEF") # noqa: TOR901 10308*da0073e9SAndroid Build Coastguard Worker my_lib.define('my_func() -> None') 10309*da0073e9SAndroid Build Coastguard Worker 10310*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1.]) 10311*da0073e9SAndroid Build Coastguard Worker _my_storage = a.untyped_storage() 10312*da0073e9SAndroid Build Coastguard Worker 10313*da0073e9SAndroid Build Coastguard Worker m, t = Tracker.make() 10314*da0073e9SAndroid Build Coastguard Worker _my_storage._tracker = t 10315*da0073e9SAndroid Build Coastguard Worker del t 10316*da0073e9SAndroid Build Coastguard Worker 10317*da0073e9SAndroid Build Coastguard Worker @impl(my_lib, 'my_func', '') 10318*da0073e9SAndroid Build Coastguard Worker def my_func(): 10319*da0073e9SAndroid Build Coastguard Worker global _my_storage 10320*da0073e9SAndroid Build Coastguard Worker del _my_storage 10321*da0073e9SAndroid Build Coastguard Worker 10322*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 10323*da0073e9SAndroid Build Coastguard Worker torch.ops.my_lib.my_func() 10324*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m[0]) 10325*da0073e9SAndroid Build Coastguard Worker 10326*da0073e9SAndroid Build Coastguard Worker s = a.untyped_storage() 10327*da0073e9SAndroid Build Coastguard Worker del a 10328*da0073e9SAndroid Build Coastguard Worker del s 10329*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m[0]) 10330*da0073e9SAndroid Build Coastguard Worker 10331*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test_autograd? 
10332*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo does not work well with hooks") 10333*da0073e9SAndroid Build Coastguard Worker def test_backward_hooks_traverse(self): 10334*da0073e9SAndroid Build Coastguard Worker m1, t1 = Tracker.make() 10335*da0073e9SAndroid Build Coastguard Worker m2, t2 = Tracker.make() 10336*da0073e9SAndroid Build Coastguard Worker x = torch.empty(2, requires_grad=True) 10337*da0073e9SAndroid Build Coastguard Worker x._tracker = t1 10338*da0073e9SAndroid Build Coastguard Worker y = torch.empty(2, requires_grad=True) 10339*da0073e9SAndroid Build Coastguard Worker y._tracker = t2 10340*da0073e9SAndroid Build Coastguard Worker del t1 10341*da0073e9SAndroid Build Coastguard Worker del t2 10342*da0073e9SAndroid Build Coastguard Worker 10343*da0073e9SAndroid Build Coastguard Worker # this hits a special setter, it's not just a __dict__ entry 10344*da0073e9SAndroid Build Coastguard Worker x._backward_hooks = y 10345*da0073e9SAndroid Build Coastguard Worker y._backward_hooks = x 10346*da0073e9SAndroid Build Coastguard Worker 10347*da0073e9SAndroid Build Coastguard Worker del x 10348*da0073e9SAndroid Build Coastguard Worker with disable_gc(): 10349*da0073e9SAndroid Build Coastguard Worker del y 10350*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m1[0]) 10351*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2[0]) 10352*da0073e9SAndroid Build Coastguard Worker 10353*da0073e9SAndroid Build Coastguard Worker gc.collect() 10354*da0073e9SAndroid Build Coastguard Worker 10355*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m1[0]) 10356*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2[0]) 10357*da0073e9SAndroid Build Coastguard Worker 10358*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") 10359*da0073e9SAndroid Build Coastguard Worker def test_tensor_dead_weak_ref(self): 10360*da0073e9SAndroid Build Coastguard Worker x = 
torch.empty(2) 10361*da0073e9SAndroid Build Coastguard Worker w_x = weakref.ref(x) 10362*da0073e9SAndroid Build Coastguard Worker y = torch.empty(2) 10363*da0073e9SAndroid Build Coastguard Worker y.grad = x 10364*da0073e9SAndroid Build Coastguard Worker del x 10365*da0073e9SAndroid Build Coastguard Worker 10366*da0073e9SAndroid Build Coastguard Worker x = w_x() 10367*da0073e9SAndroid Build Coastguard Worker # Ideally, x would keep the tensor live. But CPython doesn't 10368*da0073e9SAndroid Build Coastguard Worker # provide enough hooks to do this. So it will go dead and x 10369*da0073e9SAndroid Build Coastguard Worker # will transmute into an undefined tensor. Not great, but the 10370*da0073e9SAndroid Build Coastguard Worker # best we can do. 10371*da0073e9SAndroid Build Coastguard Worker del y 10372*da0073e9SAndroid Build Coastguard Worker 10373*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: x.sigmoid()) 10374*da0073e9SAndroid Build Coastguard Worker 10375*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") 10376*da0073e9SAndroid Build Coastguard Worker def test_storage_dead_weak_ref(self): 10377*da0073e9SAndroid Build Coastguard Worker x = torch.UntypedStorage(2) 10378*da0073e9SAndroid Build Coastguard Worker w_x = weakref.ref(x) 10379*da0073e9SAndroid Build Coastguard Worker y = torch.tensor(x) 10380*da0073e9SAndroid Build Coastguard Worker del x 10381*da0073e9SAndroid Build Coastguard Worker 10382*da0073e9SAndroid Build Coastguard Worker x = w_x() 10383*da0073e9SAndroid Build Coastguard Worker # Ideally, x would keep the storage live. But CPython doesn't 10384*da0073e9SAndroid Build Coastguard Worker # provide enough hooks to do this. So it will go dead and x 10385*da0073e9SAndroid Build Coastguard Worker # will transmute into storage with null StorageImpl. Not great, but the 10386*da0073e9SAndroid Build Coastguard Worker # best we can do. 
10387*da0073e9SAndroid Build Coastguard Worker del y 10388*da0073e9SAndroid Build Coastguard Worker 10389*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x[0]) 10390*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x.float()) 10391*da0073e9SAndroid Build Coastguard Worker 10392*da0073e9SAndroid Build Coastguard Worker def test_tensor_resurrected_weak_ref(self): 10393*da0073e9SAndroid Build Coastguard Worker x = torch.empty(2) 10394*da0073e9SAndroid Build Coastguard Worker w_x = weakref.ref(x) 10395*da0073e9SAndroid Build Coastguard Worker y = torch.empty(2) 10396*da0073e9SAndroid Build Coastguard Worker y.grad = x 10397*da0073e9SAndroid Build Coastguard Worker del x 10398*da0073e9SAndroid Build Coastguard Worker 10399*da0073e9SAndroid Build Coastguard Worker x = w_x() 10400*da0073e9SAndroid Build Coastguard Worker # Use this to manually fix weak references after dereferencing them 10401*da0073e9SAndroid Build Coastguard Worker x._fix_weakref() 10402*da0073e9SAndroid Build Coastguard Worker del y 10403*da0073e9SAndroid Build Coastguard Worker x.sigmoid() 10404*da0073e9SAndroid Build Coastguard Worker 10405*da0073e9SAndroid Build Coastguard Worker def test_storage_resurrected_weak_ref(self): 10406*da0073e9SAndroid Build Coastguard Worker x = torch.UntypedStorage(2) 10407*da0073e9SAndroid Build Coastguard Worker w_x = weakref.ref(x) 10408*da0073e9SAndroid Build Coastguard Worker y = torch.tensor(x) 10409*da0073e9SAndroid Build Coastguard Worker del x 10410*da0073e9SAndroid Build Coastguard Worker 10411*da0073e9SAndroid Build Coastguard Worker x = w_x() 10412*da0073e9SAndroid Build Coastguard Worker # Use this to manually fix weak reference after dereferencing them 10413*da0073e9SAndroid Build Coastguard Worker x._fix_weakref() 10414*da0073e9SAndroid Build Coastguard Worker del y 10415*da0073e9SAndroid Build Coastguard Worker x.float() 
10416*da0073e9SAndroid Build Coastguard Worker 10417*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") 10418*da0073e9SAndroid Build Coastguard Worker def test_tensor_fix_weakref_no_leak(self): 10419*da0073e9SAndroid Build Coastguard Worker import weakref 10420*da0073e9SAndroid Build Coastguard Worker 10421*da0073e9SAndroid Build Coastguard Worker called = False 10422*da0073e9SAndroid Build Coastguard Worker 10423*da0073e9SAndroid Build Coastguard Worker a = torch.randn(1) 10424*da0073e9SAndroid Build Coastguard Worker 10425*da0073e9SAndroid Build Coastguard Worker def callback(w): 10426*da0073e9SAndroid Build Coastguard Worker nonlocal called 10427*da0073e9SAndroid Build Coastguard Worker called = True 10428*da0073e9SAndroid Build Coastguard Worker wa = weakref.ref(a, callback) 10429*da0073e9SAndroid Build Coastguard Worker a._fix_weakref() 10430*da0073e9SAndroid Build Coastguard Worker del a 10431*da0073e9SAndroid Build Coastguard Worker 10432*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 10433*da0073e9SAndroid Build Coastguard Worker 10434*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") 10435*da0073e9SAndroid Build Coastguard Worker def test_storage_fix_weakref_no_leak(self): 10436*da0073e9SAndroid Build Coastguard Worker import weakref 10437*da0073e9SAndroid Build Coastguard Worker 10438*da0073e9SAndroid Build Coastguard Worker called = False 10439*da0073e9SAndroid Build Coastguard Worker 10440*da0073e9SAndroid Build Coastguard Worker a = torch.UntypedStorage(1) 10441*da0073e9SAndroid Build Coastguard Worker 10442*da0073e9SAndroid Build Coastguard Worker def callback(w): 10443*da0073e9SAndroid Build Coastguard Worker nonlocal called 10444*da0073e9SAndroid Build Coastguard Worker called = True 10445*da0073e9SAndroid Build Coastguard Worker wa = weakref.ref(a, callback) 10446*da0073e9SAndroid Build Coastguard Worker 
a._fix_weakref() 10447*da0073e9SAndroid Build Coastguard Worker del a 10448*da0073e9SAndroid Build Coastguard Worker 10449*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 10450*da0073e9SAndroid Build Coastguard Worker 10451*da0073e9SAndroid Build Coastguard Worker # FIXME: move to test_linalg 10452*da0073e9SAndroid Build Coastguard Worker @torch.inference_mode() 10453*da0073e9SAndroid Build Coastguard Worker def test_bmm_multithreaded(self): 10454*da0073e9SAndroid Build Coastguard Worker device = 'cpu' 10455*da0073e9SAndroid Build Coastguard Worker num_threads = torch.get_num_threads() 10456*da0073e9SAndroid Build Coastguard Worker 10457*da0073e9SAndroid Build Coastguard Worker torch.set_num_threads(4) 10458*da0073e9SAndroid Build Coastguard Worker batch_sizes = [1, 10] 10459*da0073e9SAndroid Build Coastguard Worker M, N, O = 23, 8, 12 10460*da0073e9SAndroid Build Coastguard Worker dtype = torch.float32 10461*da0073e9SAndroid Build Coastguard Worker numpy_dtype = dtype 10462*da0073e9SAndroid Build Coastguard Worker 10463*da0073e9SAndroid Build Coastguard Worker def invert_perm(p): 10464*da0073e9SAndroid Build Coastguard Worker d = {x: i for i, x in enumerate(p)} 10465*da0073e9SAndroid Build Coastguard Worker return (d[0], d[1], d[2]) 10466*da0073e9SAndroid Build Coastguard Worker 10467*da0073e9SAndroid Build Coastguard Worker def generate_inputs(num_batches): 10468*da0073e9SAndroid Build Coastguard Worker # transposed tensors 10469*da0073e9SAndroid Build Coastguard Worker for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2): 10470*da0073e9SAndroid Build Coastguard Worker b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1) 10471*da0073e9SAndroid Build Coastguard Worker b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1) 10472*da0073e9SAndroid Build Coastguard Worker b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) 10473*da0073e9SAndroid Build 
Coastguard Worker b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) 10474*da0073e9SAndroid Build Coastguard Worker yield b1, b2 10475*da0073e9SAndroid Build Coastguard Worker # broadcasting tensors 10476*da0073e9SAndroid Build Coastguard Worker for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6): 10477*da0073e9SAndroid Build Coastguard Worker shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1) 10478*da0073e9SAndroid Build Coastguard Worker shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1) 10479*da0073e9SAndroid Build Coastguard Worker b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N) 10480*da0073e9SAndroid Build Coastguard Worker b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O) 10481*da0073e9SAndroid Build Coastguard Worker yield b1, b2 10482*da0073e9SAndroid Build Coastguard Worker # zero-sized tensors 10483*da0073e9SAndroid Build Coastguard Worker for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): 10484*da0073e9SAndroid Build Coastguard Worker shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) 10485*da0073e9SAndroid Build Coastguard Worker shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) 10486*da0073e9SAndroid Build Coastguard Worker b1 = torch.randn(shape1, dtype=dtype, device=device) 10487*da0073e9SAndroid Build Coastguard Worker b2 = torch.randn(shape2, dtype=dtype, device=device) 10488*da0073e9SAndroid Build Coastguard Worker yield b1, b2 10489*da0073e9SAndroid Build Coastguard Worker 10490*da0073e9SAndroid Build Coastguard Worker try: 10491*da0073e9SAndroid Build Coastguard Worker for num_batches in batch_sizes: 10492*da0073e9SAndroid Build Coastguard Worker for (b1, b2), perm3 in itertools.product(generate_inputs(num_batches), itertools.permutations((0, 1, 2))): 10493*da0073e9SAndroid Build Coastguard Worker res1 = torch.bmm(b1, b2) 
10494*da0073e9SAndroid Build Coastguard Worker res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \ 10495*da0073e9SAndroid Build Coastguard Worker .permute(perm3).contiguous().permute(invert_perm(perm3)) 10496*da0073e9SAndroid Build Coastguard Worker torch.bmm(b1, b2, out=res2) 10497*da0073e9SAndroid Build Coastguard Worker expect = torch.from_numpy( 10498*da0073e9SAndroid Build Coastguard Worker b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype) 10499*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expect, res1) 10500*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expect, res2) 10501*da0073e9SAndroid Build Coastguard Worker finally: 10502*da0073e9SAndroid Build Coastguard Worker torch.set_num_threads(num_threads) 10503*da0073e9SAndroid Build Coastguard Worker 10504*da0073e9SAndroid Build Coastguard Worker def test_conj_neg_tolist(self): 10505*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, dtype=torch.cfloat) 10506*da0073e9SAndroid Build Coastguard Worker y1 = x.conj() 10507*da0073e9SAndroid Build Coastguard Worker y1_expect = x.conj_physical() 10508*da0073e9SAndroid Build Coastguard Worker y2 = y1.imag 10509*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y1, y1_expect.tolist()) 10510*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y2, y1_expect.imag.tolist()) 10511*da0073e9SAndroid Build Coastguard Worker 10512*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(torch.backends.cuda.is_built(), "Skipped for cuda-enabled build") 10513*da0073e9SAndroid Build Coastguard Worker def test_no_cuda_monkeypatch(self): 10514*da0073e9SAndroid Build Coastguard Worker # Note that this is not in test_cuda.py as this whole file is skipped when cuda 10515*da0073e9SAndroid Build Coastguard Worker # is not available. 
10516*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class Stream"): 10517*da0073e9SAndroid Build Coastguard Worker torch.cuda.Stream() 10518*da0073e9SAndroid Build Coastguard Worker 10519*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class Event"): 10520*da0073e9SAndroid Build Coastguard Worker torch.cuda.Event() 10521*da0073e9SAndroid Build Coastguard Worker 10522*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Tried to instantiate dummy base class CUDAGraph"): 10523*da0073e9SAndroid Build Coastguard Worker torch.cuda.graphs.CUDAGraph() 10524*da0073e9SAndroid Build Coastguard Worker 10525*da0073e9SAndroid Build Coastguard Worker def test_tensor_where_scalar(self): 10526*da0073e9SAndroid Build Coastguard Worker 10527*da0073e9SAndroid Build Coastguard Worker a = torch.arange(4.0) 10528*da0073e9SAndroid Build Coastguard Worker not_zero = 0.001 10529*da0073e9SAndroid Build Coastguard Worker 10530*da0073e9SAndroid Build Coastguard Worker # b is generated through torch.where function with not_zero being a scalar parameter 10531*da0073e9SAndroid Build Coastguard Worker b = torch.where(a != 0, a, not_zero) 10532*da0073e9SAndroid Build Coastguard Worker # c is generated through Tensor.where method with not_zero being a scalar parameter 10533*da0073e9SAndroid Build Coastguard Worker c = a.where(a != 0, not_zero) 10534*da0073e9SAndroid Build Coastguard Worker 10535*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b, c) 10536*da0073e9SAndroid Build Coastguard Worker 10537*da0073e9SAndroid Build Coastguard Worker def test_data_ptr_of_empty_tensor_with_storage(self): 10538*da0073e9SAndroid Build Coastguard Worker t = torch.empty((2, 2)) 10539*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(t.data_ptr(), 0) 10540*da0073e9SAndroid Build Coastguard Worker t.resize_((0, 2)) 
10541*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.data_ptr(), 0) 10542*da0073e9SAndroid Build Coastguard Worker 10543*da0073e9SAndroid Build Coastguard Worker def test_data_ptr_of_empty_view_with_storage(self): 10544*da0073e9SAndroid Build Coastguard Worker t = torch.empty((2, 2)) 10545*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(t.data_ptr(), 0) 10546*da0073e9SAndroid Build Coastguard Worker t2 = t[0:0].view(0, 1) 10547*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t2.data_ptr(), 0) 10548*da0073e9SAndroid Build Coastguard Worker 10549*da0073e9SAndroid Build Coastguard Worker def test_size_stride(self) -> None: 10550*da0073e9SAndroid Build Coastguard Worker t = torch.rand(2, 3, dtype=torch.float32) 10551*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.size(0), 2) 10552*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.size(dim=None), torch.Size([2, 3])) 10553*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.stride(dim=None), torch.Size([3, 1])) 10554*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.t().stride(), torch.Size([1, 3])) 10555*da0073e9SAndroid Build Coastguard Worker 10556*da0073e9SAndroid Build Coastguard Worker def test_invalid_arg_error_handling(self) -> None: 10557*da0073e9SAndroid Build Coastguard Worker """ Tests that errors from old TH functions are propagated back """ 10558*da0073e9SAndroid Build Coastguard Worker for invalid_val in [-1, 2**65]: 10559*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.set_num_threads(invalid_val)) 10560*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: torch.set_num_interop_threads(invalid_val)) 10561*da0073e9SAndroid Build Coastguard Worker 10562*da0073e9SAndroid Build Coastguard Worker def _get_tensor_prop(self, t): 10563*da0073e9SAndroid Build Coastguard Worker preserved = ( 10564*da0073e9SAndroid Build Coastguard Worker id(t), 10565*da0073e9SAndroid Build 
Coastguard Worker # Refcount values get modified by Dynamo resume frames 10566*da0073e9SAndroid Build Coastguard Worker 0 if TEST_WITH_TORCHDYNAMO else sys.getrefcount(t), 10567*da0073e9SAndroid Build Coastguard Worker ) 10568*da0073e9SAndroid Build Coastguard Worker slotnames = copyreg._slotnames(t.__class__) 10569*da0073e9SAndroid Build Coastguard Worker moved = ( 10570*da0073e9SAndroid Build Coastguard Worker slotnames, 10571*da0073e9SAndroid Build Coastguard Worker id(t.__dict__), 10572*da0073e9SAndroid Build Coastguard Worker tuple(t.__dict__.keys()), 10573*da0073e9SAndroid Build Coastguard Worker [getattr(t, name, None) for name in slotnames] 10574*da0073e9SAndroid Build Coastguard Worker ) 10575*da0073e9SAndroid Build Coastguard Worker return preserved, moved 10576*da0073e9SAndroid Build Coastguard Worker 10577*da0073e9SAndroid Build Coastguard Worker def _checked_swap(self, t1, t2): 10578*da0073e9SAndroid Build Coastguard Worker t1_pres, t1_moved = self._get_tensor_prop(t1) 10579*da0073e9SAndroid Build Coastguard Worker t2_pres, t2_moved = self._get_tensor_prop(t2) 10580*da0073e9SAndroid Build Coastguard Worker 10581*da0073e9SAndroid Build Coastguard Worker torch.utils.swap_tensors(t1, t2) 10582*da0073e9SAndroid Build Coastguard Worker 10583*da0073e9SAndroid Build Coastguard Worker new_t1_pres, new_t1_moved = self._get_tensor_prop(t1) 10584*da0073e9SAndroid Build Coastguard Worker new_t2_pres, new_t2_moved = self._get_tensor_prop(t2) 10585*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1_pres, new_t1_pres) 10586*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t2_pres, new_t2_pres) 10587*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1_moved, new_t2_moved) 10588*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t2_moved, new_t1_moved) 10589*da0073e9SAndroid Build Coastguard Worker 10590*da0073e9SAndroid Build Coastguard Worker # tests that PyObject slots on TensorImpl are correctly swapped by 10591*da0073e9SAndroid 
Build Coastguard Worker # checking that when the function applied on a swapped tensor is 10592*da0073e9SAndroid Build Coastguard Worker # returns doesn't change the TensorImpl, the returned value (which is 10593*da0073e9SAndroid Build Coastguard Worker # given by returning the reference to the PyObject in the TensorImpl's 10594*da0073e9SAndroid Build Coastguard Worker # PyObjectSlot) is still correct 10595*da0073e9SAndroid Build Coastguard Worker self.assertEqual(id(t1.fill_(0.5)), id(t1)) 10596*da0073e9SAndroid Build Coastguard Worker self.assertEqual(id(t2.fill_(0.5)), id(t2)) 10597*da0073e9SAndroid Build Coastguard Worker 10598*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo adds weakrefs") 10599*da0073e9SAndroid Build Coastguard Worker def test_swap_basic(self): 10600*da0073e9SAndroid Build Coastguard Worker ts = [ 10601*da0073e9SAndroid Build Coastguard Worker torch.rand(2), 10602*da0073e9SAndroid Build Coastguard Worker torch.rand(3, 3), 10603*da0073e9SAndroid Build Coastguard Worker torch.empty(3, dtype=torch.int), 10604*da0073e9SAndroid Build Coastguard Worker TwoTensor(torch.rand(4), torch.rand(4)) 10605*da0073e9SAndroid Build Coastguard Worker ] 10606*da0073e9SAndroid Build Coastguard Worker 10607*da0073e9SAndroid Build Coastguard Worker for t1, t2 in itertools.combinations(ts, 2): 10608*da0073e9SAndroid Build Coastguard Worker t1 = t1.clone() 10609*da0073e9SAndroid Build Coastguard Worker t2 = t2.clone() 10610*da0073e9SAndroid Build Coastguard Worker t2.foo = "bar" 10611*da0073e9SAndroid Build Coastguard Worker holder = [] 10612*da0073e9SAndroid Build Coastguard Worker holder.append(t1) 10613*da0073e9SAndroid Build Coastguard Worker 10614*da0073e9SAndroid Build Coastguard Worker self._checked_swap(t1, t2) 10615*da0073e9SAndroid Build Coastguard Worker 10616*da0073e9SAndroid Build Coastguard Worker self.assertIs(holder[0], t1) 10617*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.foo, "bar") 
10618*da0073e9SAndroid Build Coastguard Worker 10619*da0073e9SAndroid Build Coastguard Worker if t1.is_floating_point(): 10620*da0073e9SAndroid Build Coastguard Worker t3 = t1.clone().detach().requires_grad_(True) 10621*da0073e9SAndroid Build Coastguard Worker out = t3 * 2 10622*da0073e9SAndroid Build Coastguard Worker torch.utils.swap_tensors(t3, t2) 10623*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "AccumulateGrad node that was poisoned by swap_tensors"): 10624*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 10625*da0073e9SAndroid Build Coastguard Worker 10626*da0073e9SAndroid Build Coastguard Worker wr = weakref.ref(t1) 10627*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "has weakref"): 10628*da0073e9SAndroid Build Coastguard Worker torch.utils.swap_tensors(t1, t2) 10629*da0073e9SAndroid Build Coastguard Worker 10630*da0073e9SAndroid Build Coastguard Worker 10631*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo adds weakrefs") 10632*da0073e9SAndroid Build Coastguard Worker def test_swap_fail_slots(self): 10633*da0073e9SAndroid Build Coastguard Worker class MyTwoTensor(TwoTensor): 10634*da0073e9SAndroid Build Coastguard Worker __slots__ = ("a", "b") 10635*da0073e9SAndroid Build Coastguard Worker 10636*da0073e9SAndroid Build Coastguard Worker class MyTwoTensor2(TwoTensor): 10637*da0073e9SAndroid Build Coastguard Worker __slots__ = ("b", "a") 10638*da0073e9SAndroid Build Coastguard Worker 10639*da0073e9SAndroid Build Coastguard Worker class MyTwoTensor3(TwoTensor): 10640*da0073e9SAndroid Build Coastguard Worker __slots__ = ("a", "b", "c", "d") 10641*da0073e9SAndroid Build Coastguard Worker 10642*da0073e9SAndroid Build Coastguard Worker class MyTwoTensor4(TwoTensor): 10643*da0073e9SAndroid Build Coastguard Worker __slots__ = ("a", "c") 10644*da0073e9SAndroid Build Coastguard Worker 10645*da0073e9SAndroid Build Coastguard Worker 
10646*da0073e9SAndroid Build Coastguard Worker t1 = torch.rand(4) 10647*da0073e9SAndroid Build Coastguard Worker t2 = TwoTensor(torch.rand(4), torch.rand(4)) 10648*da0073e9SAndroid Build Coastguard Worker t3 = MyTwoTensor(torch.rand(4), torch.rand(4)) 10649*da0073e9SAndroid Build Coastguard Worker t4 = MyTwoTensor(torch.rand(4), torch.rand(4)) 10650*da0073e9SAndroid Build Coastguard Worker t5 = MyTwoTensor2(torch.rand(4), torch.rand(4)) 10651*da0073e9SAndroid Build Coastguard Worker t6 = MyTwoTensor3(torch.rand(4), torch.rand(4)) 10652*da0073e9SAndroid Build Coastguard Worker t7 = MyTwoTensor3(torch.rand(4), torch.rand(4)) 10653*da0073e9SAndroid Build Coastguard Worker t8 = MyTwoTensor4(torch.rand(4), torch.rand(4)) 10654*da0073e9SAndroid Build Coastguard Worker 10655*da0073e9SAndroid Build Coastguard Worker self._checked_swap(t1, t2) 10656*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot swap t1 and t2 if they have different slots"): 10657*da0073e9SAndroid Build Coastguard Worker torch.utils.swap_tensors(t1, t3) 10658*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot swap t1 and t2 if they have different slots"): 10659*da0073e9SAndroid Build Coastguard Worker torch.utils.swap_tensors(t2, t3) 10660*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot swap t1 and t2 if they have different slots"): 10661*da0073e9SAndroid Build Coastguard Worker torch.utils.swap_tensors(t2, t8) 10662*da0073e9SAndroid Build Coastguard Worker self._checked_swap(t3, t4) 10663*da0073e9SAndroid Build Coastguard Worker self._checked_swap(t3, t5) 10664*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot swap t1 and t2 if they have different slots"): 10665*da0073e9SAndroid Build Coastguard Worker torch.utils.swap_tensors(t3, t6) 10666*da0073e9SAndroid Build Coastguard Worker t3.c = "foo" 10667*da0073e9SAndroid Build Coastguard Worker 
t4.d = "bar" 10668*da0073e9SAndroid Build Coastguard Worker self._checked_swap(t3, t4) 10669*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t4.c, "foo") 10670*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t3.d, "bar") 10671*da0073e9SAndroid Build Coastguard Worker t6.c = "cat" 10672*da0073e9SAndroid Build Coastguard Worker t7.d = "dog" 10673*da0073e9SAndroid Build Coastguard Worker self._checked_swap(t6, t7) 10674*da0073e9SAndroid Build Coastguard Worker 10675*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(torch.cuda.is_available(), "Test specific for CPU") 10676*da0073e9SAndroid Build Coastguard Worker def test_bf16_supported_on_cpu(self): 10677*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.cuda.is_bf16_supported()) 10678*da0073e9SAndroid Build Coastguard Worker 10679*da0073e9SAndroid Build Coastguard Worker 10680*da0073e9SAndroid Build Coastguard Worker# The following block extends TestTorch with negative dim wrapping tests 10681*da0073e9SAndroid Build Coastguard Worker# FIXME: replace these with OpInfo sample inputs or systemic OpInfo tests 10682*da0073e9SAndroid Build Coastguard Worker# Functions to test negative dimension wrapping 10683*da0073e9SAndroid Build Coastguard WorkerMETHOD = 1 10684*da0073e9SAndroid Build Coastguard WorkerINPLACE_METHOD = 2 10685*da0073e9SAndroid Build Coastguard WorkerFUNCTIONAL = 4 10686*da0073e9SAndroid Build Coastguard WorkerDIM_ARG: None = None 10687*da0073e9SAndroid Build Coastguard Worker 10688*da0073e9SAndroid Build Coastguard Workerdef make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim=0): 10689*da0073e9SAndroid Build Coastguard Worker def neg_dim_test(self): 10690*da0073e9SAndroid Build Coastguard Worker if isinstance(tensor_arg, list): 10691*da0073e9SAndroid Build Coastguard Worker assert METHOD not in types and INPLACE_METHOD not in types 10692*da0073e9SAndroid Build Coastguard Worker x = [torch.randn(arg) for arg in tensor_arg] 10693*da0073e9SAndroid Build 
Coastguard Worker ndim = len(tensor_arg[-1]) 10694*da0073e9SAndroid Build Coastguard Worker else: 10695*da0073e9SAndroid Build Coastguard Worker x = torch.randn(*tensor_arg) 10696*da0073e9SAndroid Build Coastguard Worker ndim = len(tensor_arg) 10697*da0073e9SAndroid Build Coastguard Worker ndim += extra_dim 10698*da0073e9SAndroid Build Coastguard Worker 10699*da0073e9SAndroid Build Coastguard Worker n_dim_to_test = sum(e is DIM_ARG for e in arg_constr()) 10700*da0073e9SAndroid Build Coastguard Worker 10701*da0073e9SAndroid Build Coastguard Worker for dims_val in combinations(range(ndim), n_dim_to_test): 10702*da0073e9SAndroid Build Coastguard Worker arg = arg_constr() 10703*da0073e9SAndroid Build Coastguard Worker arg_neg = copy.deepcopy(arg) 10704*da0073e9SAndroid Build Coastguard Worker idx = 0 10705*da0073e9SAndroid Build Coastguard Worker for i, v in enumerate(arg): 10706*da0073e9SAndroid Build Coastguard Worker if v is DIM_ARG: 10707*da0073e9SAndroid Build Coastguard Worker arg[i] = dims_val[idx] 10708*da0073e9SAndroid Build Coastguard Worker arg_neg[i] = dims_val[idx] - ndim 10709*da0073e9SAndroid Build Coastguard Worker idx += 1 10710*da0073e9SAndroid Build Coastguard Worker 10711*da0073e9SAndroid Build Coastguard Worker if METHOD in types: 10712*da0073e9SAndroid Build Coastguard Worker a = getattr(x, name)(*arg) 10713*da0073e9SAndroid Build Coastguard Worker b = getattr(x, name)(*arg_neg) 10714*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 10715*da0073e9SAndroid Build Coastguard Worker 10716*da0073e9SAndroid Build Coastguard Worker if INPLACE_METHOD in types: 10717*da0073e9SAndroid Build Coastguard Worker a = x.clone() 10718*da0073e9SAndroid Build Coastguard Worker getattr(a, name + '_')(*arg) 10719*da0073e9SAndroid Build Coastguard Worker b = x.clone() 10720*da0073e9SAndroid Build Coastguard Worker getattr(b, name + '_')(*arg_neg) 10721*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 10722*da0073e9SAndroid Build Coastguard 
Worker 10723*da0073e9SAndroid Build Coastguard Worker if FUNCTIONAL in types: 10724*da0073e9SAndroid Build Coastguard Worker a = getattr(torch, name)(x, *arg) 10725*da0073e9SAndroid Build Coastguard Worker b = getattr(torch, name)(x, *arg_neg) 10726*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 10727*da0073e9SAndroid Build Coastguard Worker 10728*da0073e9SAndroid Build Coastguard Worker return neg_dim_test 10729*da0073e9SAndroid Build Coastguard Worker 10730*da0073e9SAndroid Build Coastguard Workerdef idx_tensor(size, max_val): 10731*da0073e9SAndroid Build Coastguard Worker return torch.LongTensor(*size).random_(0, max_val - 1) 10732*da0073e9SAndroid Build Coastguard Worker 10733*da0073e9SAndroid Build Coastguard Workerdef add_neg_dim_tests(): 10734*da0073e9SAndroid Build Coastguard Worker neg_dim_tests = [ 10735*da0073e9SAndroid Build Coastguard Worker ('narrow', (10, 20, 30), lambda: [DIM_ARG, 0, 5], [METHOD]), 10736*da0073e9SAndroid Build Coastguard Worker ('transpose', (10, 20, 30), lambda: [DIM_ARG, DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]), 10737*da0073e9SAndroid Build Coastguard Worker ('size', (10, 20, 30), lambda: [DIM_ARG], [METHOD]), 10738*da0073e9SAndroid Build Coastguard Worker ('cat', [(2, 3, 4), (2, 3, 4)], lambda: [DIM_ARG], [FUNCTIONAL]), 10739*da0073e9SAndroid Build Coastguard Worker ('chunk', (10, 20, 30), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]), 10740*da0073e9SAndroid Build Coastguard Worker ('gather', (10, 20), lambda: [DIM_ARG, idx_tensor((10, 20), 10)], [METHOD, FUNCTIONAL]), 10741*da0073e9SAndroid Build Coastguard Worker ('index_select', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10)], [METHOD, FUNCTIONAL]), 10742*da0073e9SAndroid Build Coastguard Worker ('split', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]), 10743*da0073e9SAndroid Build Coastguard Worker ('squeeze', (10, 1, 20, 1), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]), 10744*da0073e9SAndroid Build Coastguard Worker ('unbind', (2, 
3, 4), lambda: [DIM_ARG], [FUNCTIONAL]), 10745*da0073e9SAndroid Build Coastguard Worker ('unsqueeze', (10, 20), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL], 1), 10746*da0073e9SAndroid Build Coastguard Worker ('logcumsumexp', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10747*da0073e9SAndroid Build Coastguard Worker ('cumprod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10748*da0073e9SAndroid Build Coastguard Worker ('cumsum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10749*da0073e9SAndroid Build Coastguard Worker ('cummax', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10750*da0073e9SAndroid Build Coastguard Worker ('cummin', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10751*da0073e9SAndroid Build Coastguard Worker ('mean', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10752*da0073e9SAndroid Build Coastguard Worker ('median', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10753*da0073e9SAndroid Build Coastguard Worker ('nanmedian', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10754*da0073e9SAndroid Build Coastguard Worker ('mode', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10755*da0073e9SAndroid Build Coastguard Worker ('norm', (10, 20), lambda: [2, DIM_ARG], [METHOD, FUNCTIONAL]), 10756*da0073e9SAndroid Build Coastguard Worker ('prod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10757*da0073e9SAndroid Build Coastguard Worker ('std', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10758*da0073e9SAndroid Build Coastguard Worker ('sum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10759*da0073e9SAndroid Build Coastguard Worker ('var', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10760*da0073e9SAndroid Build Coastguard Worker ('kthvalue', (10, 20), lambda: [3, DIM_ARG], [METHOD, FUNCTIONAL]), 10761*da0073e9SAndroid Build Coastguard Worker ('max', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10762*da0073e9SAndroid Build Coastguard Worker ('min', (10, 20), 
lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10763*da0073e9SAndroid Build Coastguard Worker ('sort', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), 10764*da0073e9SAndroid Build Coastguard Worker ('topk', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]), 10765*da0073e9SAndroid Build Coastguard Worker ('renorm', (10, 20), lambda: [2, DIM_ARG, 1], [METHOD, INPLACE_METHOD, FUNCTIONAL]), 10766*da0073e9SAndroid Build Coastguard Worker ('index_add', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]), 10767*da0073e9SAndroid Build Coastguard Worker ('index_copy', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]), 10768*da0073e9SAndroid Build Coastguard Worker ('index_fill', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), 12], [INPLACE_METHOD]), 10769*da0073e9SAndroid Build Coastguard Worker ('scatter', (10, 10), lambda: [DIM_ARG, idx_tensor((10, 10), 10), torch.randn(10, 10)], [INPLACE_METHOD]), 10770*da0073e9SAndroid Build Coastguard Worker ('select', (10, 20), lambda: [DIM_ARG, 3], [METHOD]), 10771*da0073e9SAndroid Build Coastguard Worker ('unfold', (10, 20), lambda: [DIM_ARG, 5, 2], [METHOD]), 10772*da0073e9SAndroid Build Coastguard Worker ] 10773*da0073e9SAndroid Build Coastguard Worker 10774*da0073e9SAndroid Build Coastguard Worker for decl in neg_dim_tests: 10775*da0073e9SAndroid Build Coastguard Worker if len(decl) == 4: 10776*da0073e9SAndroid Build Coastguard Worker name, tensor_arg, arg_constr, types = decl 10777*da0073e9SAndroid Build Coastguard Worker extra_dim = 0 10778*da0073e9SAndroid Build Coastguard Worker elif len(decl) == 5: 10779*da0073e9SAndroid Build Coastguard Worker name, tensor_arg, arg_constr, types, extra_dim = decl 10780*da0073e9SAndroid Build Coastguard Worker 10781*da0073e9SAndroid Build Coastguard Worker test_name = 'test_' + name + '_neg_dim' 10782*da0073e9SAndroid Build Coastguard Worker 10783*da0073e9SAndroid Build Coastguard Worker assert not 
hasattr(TestTorch, test_name), "Duplicated test name: " + test_name 10784*da0073e9SAndroid Build Coastguard Worker setattr(TestTorch, test_name, make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim)) 10785*da0073e9SAndroid Build Coastguard Worker 10786*da0073e9SAndroid Build Coastguard Worker# TODO: these empy classes are temporarily instantiated for XLA compatibility 10787*da0073e9SAndroid Build Coastguard Worker# once XLA updates their test suite it should be removed 10788*da0073e9SAndroid Build Coastguard Workerclass TestViewOps(TestCase): 10789*da0073e9SAndroid Build Coastguard Worker pass 10790*da0073e9SAndroid Build Coastguard Worker 10791*da0073e9SAndroid Build Coastguard Workerclass TestTensorDeviceOps(TestCase): 10792*da0073e9SAndroid Build Coastguard Worker pass 10793*da0073e9SAndroid Build Coastguard Worker 10794*da0073e9SAndroid Build Coastguard Worker# Generates tests 10795*da0073e9SAndroid Build Coastguard Worker# Note: test generation must be done at file scope, not within main, or 10796*da0073e9SAndroid Build Coastguard Worker# pytest will fail. 10797*da0073e9SAndroid Build Coastguard Workeradd_neg_dim_tests() 10798*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestViewOps, globals()) 10799*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestVitalSignsCuda, globals()) 10800*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestTensorDeviceOps, globals()) 10801*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestTorchDeviceType, globals()) 10802*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestDevicePrecision, globals(), except_for='cpu') 10803*da0073e9SAndroid Build Coastguard Worker 10804*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 10805*da0073e9SAndroid Build Coastguard Worker TestCase._default_dtype_check_enabled = True 10806*da0073e9SAndroid Build Coastguard Worker run_tests() 10807